#!/usr/bin/env python3
"""
Video Language Changer - Step 2: Translation
Translates an English transcription to Hindi using AI4Bharat's IndicTrans2,
falling back to Helsinki-NLP's opus-mt-en-hi (MarianMT) if IndicTrans2 cannot be loaded.
"""

import os
import json
import subprocess
from pathlib import Path
from typing import Dict, List, Any, Optional
import re

def install_requirements():
    """Import-check each translation dependency and pip-install what is missing.

    Packages are installed with the running interpreter's pip
    (``sys.executable -m pip``) so they land in the same environment this script
    imports from. A requirement with a ``>=`` minimum (transformers) is also
    reinstalled when the importable version is older than that minimum; pip
    upgrades it to satisfy the specifier. Newly installed packages are only
    importable after the script is restarted.
    """
    import importlib
    import sys

    # (pip requirement, module to import as a probe, attribute that must exist or None)
    requirements = [
        ("torch", "torch", None),
        ("transformers>=4.33.0", "transformers", None),
        ("sentencepiece", "sentencepiece", None),
        ("sacremoses", "sacremoses", None),
        ("IndicTransToolkit", "IndicTransToolkit", "IndicProcessor"),
    ]
    # Packages that are not installed from PyPI by name.
    pip_targets = {
        "IndicTransToolkit": "git+https://github.com/VarunGumma/IndicTransToolkit",
    }

    def _version_tuple(version: str) -> tuple:
        # Compare only the leading numeric components ("4.40.0.dev0" -> (4, 40, 0)).
        return tuple(int(part) for part in re.findall(r"\d+", version)[:3])

    for requirement, module_name, attr in requirements:
        needs_install = False
        try:
            module = importlib.import_module(module_name)
            if attr is not None:
                getattr(module, attr)
        except (ImportError, AttributeError):
            needs_install = True
        else:
            if ">=" in requirement:
                version = getattr(module, "__version__", "")
                minimum = requirement.split(">=", 1)[1]
                print(f"{module_name.capitalize()} version: {version}")
                if _version_tuple(version) < _version_tuple(minimum):
                    needs_install = True

        if needs_install:
            print(f"Installing {requirement}...")
            subprocess.run(
                [sys.executable, "-m", "pip", "install", pip_targets.get(requirement, requirement)],
                check=True,
            )

def load_transcription(transcription_path: str) -> Dict[str, Any]:
    """Read and parse a transcription JSON file.

    Args:
        transcription_path: Path to the transcription JSON file.

    Returns:
        The decoded JSON content.

    Raises:
        FileNotFoundError: If nothing exists at ``transcription_path``.
    """
    missing = not os.path.exists(transcription_path)
    if missing:
        raise FileNotFoundError(f"Transcription file not found: {transcription_path}")

    with open(transcription_path, encoding='utf-8') as json_file:
        payload = json_file.read()
    data = json.loads(payload)

    print(f"✓ Loaded transcription from: {transcription_path}")
    return data

def setup_indictrans2():
    """Load an English-to-Hindi translation model, preferring IndicTrans2.

    Tries ``ai4bharat/indictrans2-en-indic-1B`` together with an
    ``IndicProcessor`` first. If anything in that path fails, it falls back to
    ``Helsinki-NLP/opus-mt-en-hi`` (MarianMT).

    Returns:
        ``(tokenizer, model, ip, device)``. ``ip`` is the IndicProcessor when
        IndicTrans2 was loaded and ``None`` when the MarianMT fallback was
        loaded; callers use this to choose the translation routine. ``device``
        is ``"cuda"`` when available, otherwise ``"cpu"``.

    Raises:
        Exception: The fallback's error, re-raised when both models fail to load.
    """
    try:
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        from IndicTransToolkit import IndicProcessor
        import torch

        print("Loading IndicTrans2 model...")

        # Model name for English-Hindi translation
        model_name = "ai4bharat/indictrans2-en-indic-1B"

        # IndicTrans2 ships custom tokenizer/model code on the Hub.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)

        # IndicProcessor handles IndicTrans2's pre/post-processing.
        ip = IndicProcessor(inference=True)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)

        print(f"✓ IndicTrans2 model loaded on {device}")
        return tokenizer, model, ip, device

    except Exception as e:
        print(f"❌ Failed to load IndicTrans2: {e}")
        print("Trying alternative approach...")

        try:
            # Import torch here as well: if the IndicTrans2 path failed before its
            # own `import torch` ran (e.g. IndicTransToolkit is missing), the local
            # name `torch` is unbound and the fallback would fail with
            # UnboundLocalError instead of loading.
            import torch
            from transformers import MarianMTModel, MarianTokenizer

            model_name = "Helsinki-NLP/opus-mt-en-hi"
            tokenizer = MarianTokenizer.from_pretrained(model_name)
            model = MarianMTModel.from_pretrained(model_name)

            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = model.to(device)

            print(f"✓ Fallback model (Helsinki NLP) loaded on {device}")
            return tokenizer, model, None, device

        except Exception as e2:
            print(f"❌ Fallback also failed: {e2}")
            raise

def preprocess_text_for_translation(text: str) -> str:
    """Clean transcribed English text before it is sent to the translation model.

    Curly apostrophes are normalized to ASCII. Then every character except
    word characters, whitespace, basic punctuation (``. , ! ? -``) and
    apostrophes is removed. Whitespace is collapsed afterwards, so removed
    characters do not leave double spaces behind.

    Args:
        text: Raw English text, e.g. one transcription segment.

    Returns:
        The cleaned text, or ``""`` if nothing translatable remains.
    """
    # Keep apostrophes: dropping them turns contractions into different words
    # ("we'll" -> "well", "it's" -> "its") and changes the translation.
    text = text.replace("\u2019", "'")

    # Remove special characters that might confuse translation
    text = re.sub(r"[^\w\s.,!?'-]", "", text)

    # Collapse whitespace *after* removal so "a & b" becomes "a b", not "a  b".
    return re.sub(r"\s+", " ", text).strip()

def translate_with_indictrans2(text: str, tokenizer, model, ip, device) -> str:
    """
    Translate one English string to Hindi with IndicTrans2.

    Args:
        text: English text to translate (already cleaned).
        tokenizer: IndicTrans2 tokenizer.
        model: IndicTrans2 model, already on ``device``.
        ip: IndicProcessor used for pre/post-processing, or None to feed the
            raw text to the tokenizer and skip post-processing.
        device: Device the tokenized inputs are moved to (cuda/cpu).

    Returns:
        Hindi translated text. Returns ``""`` without calling the model when
        ``text`` is empty or whitespace-only.
    """
    # Cleaning can reduce a segment (e.g. a music-note-only segment) to nothing;
    # there is nothing to translate, so skip the model call entirely.
    if not (text and text.strip()):
        return ""

    import torch

    # IndicProcessor adds the source/target language tags IndicTrans2 expects.
    if ip is not None:
        batch = ip.preprocess_batch([text], src_lang="eng_Latn", tgt_lang="hin_Deva")
    else:
        batch = [text]

    # Inputs longer than 512 tokens are truncated.
    inputs = tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    ).to(device)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=1,
            max_length=512,
            num_beams=5,
            num_return_sequences=1
        )

    translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    if ip is not None:
        translated_text = ip.postprocess_batch([translated_text], lang="hin_Deva")[0]

    return translated_text

def translate_with_fallback(text: str, tokenizer, model, device) -> str:
    """
    Translate one English string to Hindi with the fallback MarianMT model.

    Args:
        text: English text to translate (already cleaned).
        tokenizer: Marian tokenizer.
        model: Marian model, already on ``device``.
        device: Device the tokenized inputs are moved to (cuda/cpu).

    Returns:
        Hindi translated text. Returns ``""`` without calling the model when
        ``text`` is empty or whitespace-only.
    """
    # Cleaning can leave nothing to translate; skip the model call in that case.
    if not (text and text.strip()):
        return ""

    import torch

    # Inputs longer than 512 tokens are truncated.
    inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)

    with torch.no_grad():
        generated_tokens = model.generate(**inputs, num_beams=5, max_length=512)

    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

def translate_segments(transcription_data: Dict[str, Any], tokenizer, model, ip, device) -> Dict[str, Any]:
    """
    Translate the transcript and each of its segments to Hindi, keeping timestamps.

    Args:
        transcription_data: Transcription dict. Its optional "text" key and
            "segments" list are read; for each segment, "text" and the
            optional "words" list are read.
        tokenizer: Translation tokenizer returned by setup_indictrans2.
        model: Translation model returned by setup_indictrans2.
        ip: IndicProcessor (None for the MarianMT fallback).
        device: Device the model lives on.

    Returns:
        A shallow copy of ``transcription_data`` with:
        - "text"/"text_hindi" holding the Hindi text and "text_english" the
          original, at both the top level and per segment;
        - segments with blank text dropped;
        - word entries keeping their English word (under "word" and
          "word_english") so their timing can still be used;
        - a "translation_info" entry naming the backend used.
    """
    print("Starting translation process...")

    use_indictrans = ip is not None

    def _translate(text: str) -> str:
        # Clean the text and dispatch to whichever backend was loaded.
        cleaned = preprocess_text_for_translation(text)
        if use_indictrans:
            return translate_with_indictrans2(cleaned, tokenizer, model, ip, device)
        return translate_with_fallback(cleaned, tokenizer, model, device)

    translated_data = transcription_data.copy()

    # Segments are translated first so their translations can be reused as
    # the full-text translation below.
    segments = transcription_data.get("segments") or []
    translated_segments = []
    if segments:
        total = len(segments)
        print(f"Translating {total} segments...")

        for i, segment in enumerate(segments, start=1):
            print(f"Translating segment {i}/{total}")

            segment_text = (segment.get("text") or "").strip()
            if not segment_text:
                continue

            translated_segment_text = _translate(segment_text)

            translated_segment = segment.copy()
            translated_segment["text"] = translated_segment_text
            translated_segment["text_hindi"] = translated_segment_text
            translated_segment["text_english"] = segment_text

            if segment.get("words"):
                translated_words = []
                for word_data in segment["words"]:
                    word_text = (word_data.get("word") or "").strip()
                    if not word_text:
                        continue

                    # Single words are not translated (no context); the English
                    # word is kept so its timing can still be used.
                    translated_word_data = word_data.copy()
                    translated_word_data["word_english"] = word_text
                    translated_word_data["word"] = word_text
                    translated_words.append(translated_word_data)

                translated_segment["words"] = translated_words

            translated_segments.append(translated_segment)

        translated_data["segments"] = translated_segments

    full_text = transcription_data.get("text", "")
    if full_text:
        if translated_segments:
            # Both translation routines truncate their input at 512 tokens, so
            # translating a long transcript in one call would silently drop
            # everything past that point. Reuse the per-segment translations.
            translated_full_text = " ".join(
                seg["text_hindi"] for seg in translated_segments if seg["text_hindi"]
            )
        else:
            print("Translating full text...")
            translated_full_text = _translate(full_text)

        translated_data["text"] = translated_full_text
        translated_data["text_hindi"] = translated_full_text
        translated_data["text_english"] = full_text

    translated_data["translation_info"] = {
        "source_language": "en",
        "target_language": "hi",
        "model_used": "ai4bharat/indictrans2-en-indic-1B" if use_indictrans else "Helsinki-NLP/opus-mt-en-hi",
        "translation_method": "IndicTrans2" if use_indictrans else "MarianMT"
    }

    print("✓ Translation completed")
    return translated_data

def save_translated_data(translated_data: Dict[str, Any], output_path: str):
    """Write ``translated_data`` to ``output_path`` as indented UTF-8 JSON.

    ``ensure_ascii=False`` keeps the Devanagari text readable in the file
    instead of writing ``\\uXXXX`` escapes.
    """
    with open(output_path, mode='w', encoding='utf-8') as out_file:
        json.dump(
            translated_data,
            out_file,
            ensure_ascii=False,
            indent=2,
        )
    print(f"✓ Translated data saved to: {output_path}")

def print_translation_summary(translated_data: Dict[str, Any]):
    """Print a summary of the translation"""
    print("\n" + "="*50)
    print("TRANSLATION SUMMARY")
    print("="*50)
    
    # Translation info
    trans_info = translated_data.get("translation_info", {})
    print(f"Translation method: {trans_info.get('translation_method', 'Unknown')}")
    print(f"Model: {trans_info.get('model_used', 'Unknown')}")
    print(f"Source: {trans_info.get('source_language', 'en')} → Target: {trans_info.get('target_language', 'hi')}")
    
    # Sample translations
    print(f"\nOriginal text: {translated_data.get('text_english', '')[:100]}...")
    print(f"Hindi text: {translated_data.get('text_hindi', '')[:100]}...")
    
    if translated_data.get("segments"):
        print(f"\nTranslated segments: {len(translated_data['segments'])}")
        print("\nSample segment translations:")
        
        for i, segment in enumerate(translated_data["segments"][:3]):
            start = segment.get("start", 0)
            end = segment.get("end", 0)
            eng_text = segment.get("text_english", "")
            hin_text = segment.get("text_hindi", "")
            
            print(f"\n  Segment {i+1} [{start:.2f}s - {end:.2f}s]:")
            print(f"    English: {eng_text}")
            print(f"    Hindi: {hin_text}")

def main():
    """Run Step 2 interactively: translate a transcription JSON file to Hindi.

    Prompts for the transcription path, then:
    - loads the translation model;
    - translates the transcript;
    - writes ``<stem without "_transcription">_translated.json`` to the
      current working directory;
    - prints a summary.

    A missing input file or any error raised during those steps is reported
    on stdout (with a traceback for errors) instead of propagating.
    """
    print("Video Language Changer - Step 2: Translation")
    print("=" * 50)

    # Dependencies are not installed automatically; call install_requirements()
    # here to have the script pip-install missing packages first.

    raw_path = input("Enter path to transcription JSON file: ")
    # Tolerate paths pasted with surrounding quotes.
    transcription_path = raw_path.strip().strip('"\'')

    if not os.path.exists(transcription_path):
        print(f"❌ Transcription file not found: {transcription_path}")
        return

    try:
        source_data = load_transcription(transcription_path)
        tokenizer, model, ip, device = setup_indictrans2()
        result = translate_segments(source_data, tokenizer, model, ip, device)

        stem = Path(transcription_path).stem
        output_path = stem.replace("_transcription", "") + "_translated.json"
        save_translated_data(result, output_path)

        print_translation_summary(result)

        print("\n✓ Step 2 completed successfully!")
        print("Translated file: " + output_path)
        print("\nReady for Step 3: Text-to-Speech")
    except Exception as err:
        import traceback
        print(f"❌ Error: {err}")
        traceback.print_exc()

# Run the interactive translation step only when executed as a script.
if __name__ == "__main__":
    raise SystemExit(main())