#!/usr/bin/env python3
"""
Video Language Changer - Step 3: Text-to-Speech
Generates Hindi audio from translated text using the Sarvam AI TTS API
(with Coqui TTS and gTTS available as fallback engines)
"""

import os
import sys
import json
import subprocess
import tempfile
import numpy as np
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple

def install_requirements():
    """Install required packages for Sarvam AI TTS"""
    packages = [
        "requests",
        "librosa",
        "soundfile", 
        "scipy"
    ]
    
    for package in packages:
        try:
            if package == "requests":
                import requests
            elif package == "librosa":
                import librosa
            elif package == "soundfile":
                import soundfile as sf
            elif package == "scipy":
                import scipy
        except ImportError:
            print(f"Installing {package}...")
            subprocess.run(["pip", "install", package], check=True)

def check_espeak():
    """Check if espeak is installed (needed for some TTS engines)"""
    try:
        subprocess.run(["espeak", "--version"], capture_output=True, check=True)
        print("✓ espeak is available")
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("ℹ️  espeak not found (optional for some TTS engines)")
        print("Install with: sudo apt install espeak espeak-data (Ubuntu/Debian)")
        return False

def load_translated_data(translated_path: str) -> Dict[str, Any]:
    """Load translated JSON file"""
    if not os.path.exists(translated_path):
        raise FileNotFoundError(f"Translated file not found: {translated_path}")
    
    with open(translated_path, 'r', encoding='utf-8') as f:
        translated_data = json.load(f)
    
    print(f"✓ Loaded translated data from: {translated_path}")
    return translated_data

def setup_sarvam_tts():
    """Set up Sarvam AI TTS and validate the API key"""
    try:
        print("Setting up Sarvam AI TTS...")
        
        # Get API key from user
        api_key = input("Enter your Sarvam AI API key: ").strip()
        
        # Reject empty input or an unedited placeholder key
        if not api_key or api_key == "*****":
            print("❌ Please provide a valid Sarvam AI API key")
            print("Get your API key from: https://www.sarvam.ai/")
            raise ValueError("No valid API key provided")
        
        # Test the API key
        import requests
        
        test_response = requests.post(
            "https://api.sarvam.ai/text-to-speech",
            headers={
                "api-subscription-key": api_key
            },
            json={
                "text": "टेस्ट",
                "target_language_code": "hi-IN"
            }
        )
        
        if test_response.status_code == 200:
            print("✓ Sarvam AI API key validated successfully")
            return "sarvam_ai", {"api_key": api_key}
        else:
            print(f"❌ API key validation failed: {test_response.status_code}")
            print(f"Response: {test_response.text}")
            raise ValueError("Invalid API key")
        
    except Exception as e:
        print(f"❌ Failed to setup Sarvam AI: {e}")
        raise
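
# Convenience sketch (assumption: you want non-interactive runs; the
# SARVAM_API_KEY variable name is our choice, not part of Sarvam's docs):
#
#   api_key = os.environ.get("SARVAM_API_KEY") or input(
#       "Enter your Sarvam AI API key: ").strip()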

def calculate_speech_timing(text: str, target_duration: float) -> float:
    """
    Calculate appropriate speech rate for given text and duration
    
    Args:
        text: Text to be spoken
        target_duration: Target duration in seconds
    
    Returns:
        Speech rate factor
    """
    # Average speaking rate: ~150 words per minute = 2.5 words per second
    # Average word length: ~5 characters
    # So roughly 12.5 characters per second
    
    char_count = len(text)
    natural_duration = char_count / 12.5  # Rough estimate
    
    if target_duration <= 0:
        return 1.0
    
    # Calculate rate factor
    rate_factor = natural_duration / target_duration
    
    # Clamp to reasonable bounds (0.5x to 2.0x speed)
    rate_factor = max(0.5, min(2.0, rate_factor))
    
    return rate_factor
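
# Worked example (illustrative numbers): a 100-character segment has an
# estimated natural duration of 100 / 12.5 = 8.0 s. Against a 4.0 s target
# slot, the raw rate factor is 8.0 / 4.0 = 2.0, which the clamp keeps at the
# 2.0x ceiling; a 10.0 s slot gives 0.8, inside the bounds. (Note: this
# helper is not currently called by the main pipeline.)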

def synthesize_with_sarvam_ai(text: str, model_data: Dict, target_duration: Optional[float] = None) -> np.ndarray:
    """
    Synthesize speech using Sarvam AI TTS
    
    Args:
        text: Hindi text to synthesize
        model_data: Dictionary containing API key
        target_duration: Target duration in seconds
    
    Returns:
        Audio array (16kHz, mono)
    """
    import requests
    import librosa
    import tempfile
    import base64
    from scipy.signal import resample
    
    if not text.strip():
        duration = target_duration or 1.0
        return np.zeros(int(16000 * duration))
    
    try:
        api_key = model_data["api_key"]
        
        print(f"Synthesizing with Sarvam AI: {text[:50]}...")
        
        # Make API request
        response = requests.post(
            "https://api.sarvam.ai/text-to-speech",
            headers={
                "api-subscription-key": api_key
            },
            json={
                "text": text,
                "target_language_code": "hi-IN"
            }
        )
        
        if response.status_code != 200:
            print(f"❌ Sarvam AI API error: {response.status_code}")
            print(f"Response: {response.text}")
            # Return silence as fallback
            duration = target_duration or 1.0
            return np.zeros(int(16000 * duration))
        
        response_data = response.json()
        
        # Check if audio is in the response
        if "audio" not in response_data:
            print("❌ No audio in Sarvam AI response")
            duration = target_duration or 1.0
            return np.zeros(int(16000 * duration))
        
        # Decode base64 audio
        audio_base64 = response_data["audio"]
        audio_bytes = base64.b64decode(audio_base64)
        
        # Save to temporary file and load
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_file.write(audio_bytes)
            temp_path = temp_file.name
        
        # Load audio with librosa
        audio, sr = librosa.load(temp_path, sr=16000, mono=True)
        
        # Clean up temp file
        os.unlink(temp_path)
        
        # Ensure audio is float64 for compatibility
        audio = audio.astype(np.float64)
        
        current_duration = len(audio) / 16000
        print(f"Generated audio duration: {current_duration:.2f}s")
        
        # Timing adjustment: nudge the synthesized audio toward the segment's
        # slot. Note that scipy.signal.resample changes duration by resampling,
        # which also shifts pitch; see the pitch-preserving sketch below.
        if target_duration and target_duration > 0:
            speed_ratio = current_duration / target_duration

            # Keep playback between 80% and 120% of natural speed
            min_speed_factor = 0.8
            max_speed_factor = 1.2
            
            print(f"Target duration: {target_duration:.2f}s, Speed ratio: {speed_ratio:.2f}")
            
            if speed_ratio > max_speed_factor:
                # Audio is too long, speed it up but not too much
                stretch_factor = max_speed_factor
                target_length = int(current_duration * 16000 / stretch_factor)
                audio = resample(audio, target_length)
                new_duration = len(audio) / 16000
                print(f"⚡ Sarvam AI: Sped up audio to {new_duration:.2f}s (max {max_speed_factor}x speed)")
                
            elif speed_ratio < min_speed_factor:
                # Audio is too short, slow it down but not too much
                stretch_factor = min_speed_factor
                target_length = int(current_duration * 16000 / stretch_factor)
                audio = resample(audio, target_length)
                new_duration = len(audio) / 16000
                print(f"🐌 Sarvam AI: Slowed down audio to {new_duration:.2f}s (min {min_speed_factor}x speed)")
                
            else:
                # Speed ratio is reasonable, apply gentle adjustment
                if abs(speed_ratio - 1.0) > 0.1:  # Only adjust if >10% difference
                    target_length = int(target_duration * 16000)
                    audio = resample(audio, target_length)
                    new_duration = len(audio) / 16000
                    print(f"🎯 Sarvam AI: Adjusted audio to {new_duration:.2f}s (gentle timing sync)")
                else:
                    print(f"✅ Sarvam AI: Audio duration is good, no adjustment needed")
        
        return audio
        
    except Exception as e:
        print(f"❌ Sarvam AI synthesis failed for text: {text[:50]}... Error: {e}")
        # Return silence as fallback
        duration = target_duration or 1.0
        return np.zeros(int(16000 * duration))
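
# A minimal pitch-preserving alternative to the resample-based adjustment
# above (a sketch, not wired into the pipeline): scipy.signal.resample
# changes duration by shifting pitch, whereas librosa's phase-vocoder time
# stretch keeps pitch constant. The 0.8x-1.2x clamp mirrors the bounds used
# in synthesize_with_sarvam_ai.
def stretch_to_duration(audio: np.ndarray, target_duration: float,
                        sample_rate: int = 16000) -> np.ndarray:
    import librosa

    current_duration = len(audio) / sample_rate
    if target_duration <= 0 or current_duration <= 0:
        return audio
    # rate > 1.0 speeds the audio up (shortens it); clamp to the same window
    rate = max(0.8, min(1.2, current_duration / target_duration))
    return librosa.effects.time_stretch(audio.astype(np.float32), rate=rate)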

def synthesize_with_coqui_indic(text: str, model_data: Dict, target_duration: Optional[float] = None) -> np.ndarray:
    """
    Synthesize speech using Coqui TTS with Indic models
    
    Args:
        text: Hindi text to synthesize  
        model_data: Dictionary containing TTS object
        target_duration: Target duration in seconds
    
    Returns:
        Audio array (16kHz, mono)
    """
    import librosa
    import tempfile
    
    if not text.strip():
        duration = target_duration or 1.0
        return np.zeros(int(16000 * duration))
    
    try:
        tts = model_data["tts"]
        
        # Generate speech to temporary file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_path = temp_file.name
        
        # Synthesize
        tts.tts_to_file(text=text, file_path=temp_path)
        
        # Load audio
        audio, sr = librosa.load(temp_path, sr=16000, mono=True)
        
        # Clean up
        os.unlink(temp_path)
        
        # Adjust duration if needed
        if target_duration and target_duration > 0:
            current_duration = len(audio) / 16000
            if abs(current_duration - target_duration) > 0.1:
                stretch_factor = current_duration / target_duration
                audio = librosa.effects.time_stretch(audio, rate=stretch_factor)
        
        return audio
        
    except Exception as e:
        print(f"Coqui synthesis failed for text: {text[:50]}... Error: {e}")
        duration = target_duration or 1.0
        return np.zeros(int(16000 * duration))
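
# Illustrative setup for the Coqui path (assumption: a Coqui TTS package
# exposing TTS.api is installed; the model name below is a placeholder, not
# a verified release):
#
#   from TTS.api import TTS
#   tts = TTS(model_name="<a Hindi-capable Coqui model>")
#   model_data = {"tts": tts}
#   audio = synthesize_with_coqui_indic("नमस्ते", model_data, target_duration=2.0)
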
def synthesize_with_gtts(text: str, target_duration: Optional[float] = None, lang: str = "hi") -> np.ndarray:
    """
    Synthesize speech using gTTS (fallback option)
    
    Args:
        text: Hindi text to synthesize
        target_duration: Target duration in seconds
        lang: Language code
    
    Returns:
        Audio array (16kHz, mono)
    """
    import tempfile
    from gtts import gTTS
    import librosa
    
    if not text.strip():
        # Return silence for empty text
        duration = target_duration or 1.0
        return np.zeros(int(16000 * duration))
    
    try:
        # Create TTS
        tts = gTTS(text=text, lang=lang, slow=False)
        
        # Save to a temporary file (close the handle first so the path is
        # writable on all platforms)
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
            temp_path = temp_file.name
        tts.save(temp_path)
        
        # Load audio
        audio, sr = librosa.load(temp_path, sr=16000, mono=True)
        
        # Clean up temp file
        os.unlink(temp_path)
        
        # Adjust duration if needed
        if target_duration and target_duration > 0:
            current_duration = len(audio) / 16000
            if abs(current_duration - target_duration) > 0.1:  # If difference > 100ms
                # Time-stretch to match target duration
                stretch_factor = current_duration / target_duration
                audio = librosa.effects.time_stretch(audio, rate=stretch_factor)
        
        return audio
        
    except Exception as e:
        print(f"gTTS synthesis failed for text: {text[:50]}... Error: {e}")
        # Return silence as fallback
        duration = target_duration or 1.0
        return np.zeros(int(16000 * duration))

def create_silence(duration: float) -> np.ndarray:
    """Create silence of specified duration"""
    return np.zeros(int(16000 * duration))

def synthesize_segments(translated_data: Dict[str, Any], tts_engine: str, model_data: Optional[Dict] = None) -> List[Tuple[np.ndarray, float, float]]:
    """
    Synthesize audio for all segments using the selected TTS engine

    Args:
        translated_data: Translated transcription data
        tts_engine: TTS engine to use ("sarvam_ai", "coqui", or "gtts")
        model_data: Engine-specific data (e.g. the Sarvam AI API key)

    Returns:
        List of (audio_array, start_time, end_time) tuples
    """
    print(f"Synthesizing Hindi audio segments using {tts_engine}...")
    
    segments = translated_data.get("segments", [])
    if not segments:
        print("No segments found, synthesizing full text...")
        # Fallback to full text
        hindi_text = translated_data.get("text_hindi", translated_data.get("text", ""))
        if hindi_text:
            audio = synthesize_with_sarvam_ai(hindi_text, model_data)
            return [(audio, 0.0, len(audio) / 16000)]
        else:
            return []
    
    synthesized_segments = []
    
    for i, segment in enumerate(segments):
        print(f"Synthesizing segment {i+1}/{len(segments)}")
        
        hindi_text = segment.get("text_hindi", segment.get("text", "")).strip()
        start_time = segment.get("start", 0.0)
        end_time = segment.get("end", start_time + 2.0)
        target_duration = end_time - start_time
        
        if not hindi_text:
            # Create silence for empty segments
            audio = create_silence(target_duration)
        else:
            # Dispatch on the selected engine (setup currently only returns
            # "sarvam_ai"; the other branches use the fallback synthesizers)
            if tts_engine == "coqui":
                audio = synthesize_with_coqui_indic(hindi_text, model_data, target_duration)
            elif tts_engine == "gtts":
                audio = synthesize_with_gtts(hindi_text, target_duration)
            else:
                audio = synthesize_with_sarvam_ai(hindi_text, model_data, target_duration)
        
        synthesized_segments.append((audio, start_time, end_time))
        
        # Print progress
        if i % 5 == 0 or i == len(segments) - 1:
            print(f"  Progress: {i+1}/{len(segments)} segments completed")
    
    print("✓ All segments synthesized with Sarvam AI")
    return synthesized_segments

def combine_audio_segments(synthesized_segments: List[Tuple[np.ndarray, float, float]],
                           total_duration: Optional[float] = None) -> np.ndarray:
    """
    Combine audio segments into a single audio track with proper timing
    
    Args:
        synthesized_segments: List of (audio, start_time, end_time) tuples
        total_duration: Total duration of the final audio
    
    Returns:
        Combined audio array
    """
    print("Combining audio segments...")
    
    if not synthesized_segments:
        return np.array([])
    
    # Calculate total duration if not provided
    if total_duration is None:
        total_duration = max(end_time for _, _, end_time in synthesized_segments)
    
    # Create output audio buffer
    sample_rate = 16000
    total_samples = int(total_duration * sample_rate)
    combined_audio = np.zeros(total_samples)
    
    # Place each segment at its correct position
    for audio, start_time, end_time in synthesized_segments:
        start_sample = int(start_time * sample_rate)
        end_sample = min(start_sample + len(audio), total_samples)
        
        # Ensure we don't exceed bounds
        audio_length = end_sample - start_sample
        if audio_length > 0:
            # Trim audio if needed
            audio_segment = audio[:audio_length]
            combined_audio[start_sample:end_sample] = audio_segment
    
    print("✓ Audio segments combined")
    return combined_audio
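
# Placement note (illustrative timings): two 1.0 s clips placed at 0.0 s and
# 2.5 s in a 4.0 s buffer leave silence over 1.0-2.5 s and 3.5-4.0 s. When
# segments overlap, the later write wins; a mixing variant would instead do
#   combined_audio[start_sample:end_sample] += audio_segment
# and rely on the normalization in save_audio to prevent clipping.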

def save_audio(audio: np.ndarray, output_path: str, sample_rate: int = 16000):
    """Save audio array to WAV file"""
    import soundfile as sf
    
    # Normalize audio to prevent clipping
    if len(audio) > 0:
        max_val = np.max(np.abs(audio))
        if max_val > 0:
            audio = audio / max_val * 0.95
    
    # Save as WAV file
    sf.write(output_path, audio, sample_rate)
    print(f"✓ Audio saved to: {output_path}")

def create_audio_info(translated_data: Dict[str, Any], audio_duration: float, output_path: str) -> Dict[str, Any]:
    """Create metadata about the generated audio"""
    return {
        "audio_file": output_path,
        "duration": audio_duration,
        "sample_rate": 16000,
        "channels": 1,
        "language": "hi",
        "tts_engine": "Sarvam AI",
        "segments_count": len(translated_data.get("segments", [])),
        "original_transcription": translated_data.get("text_english", ""),
        "hindi_text": translated_data.get("text_hindi", "")
    }
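
# Example of the resulting metadata (hypothetical values; text fields omitted):
#   {
#     "audio_file": "lecture_hindi_audio.wav",
#     "duration": 93.52,
#     "sample_rate": 16000,
#     "channels": 1,
#     "language": "hi",
#     "tts_engine": "Sarvam AI",
#     "segments_count": 42
#   }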

def save_audio_info(audio_info: Dict[str, Any], output_path: str):
    """Save audio metadata to JSON file"""
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(audio_info, f, indent=2, ensure_ascii=False)
    
    print(f"✓ Audio metadata saved to: {output_path}")

def print_tts_summary(audio_info: Dict[str, Any]):
    """Print a summary of the TTS process"""
    print("\n" + "="*50)
    print("TEXT-TO-SPEECH SUMMARY")
    print("="*50)
    
    print(f"Audio file: {audio_info['audio_file']}")
    print(f"Duration: {audio_info['duration']:.2f} seconds")
    print(f"Sample rate: {audio_info['sample_rate']} Hz")
    print(f"Channels: {audio_info['channels']}")
    print(f"TTS engine: {audio_info['tts_engine']}")
    print(f"Segments processed: {audio_info['segments_count']}")
    print(f"\nOriginal text: {audio_info['original_transcription'][:100]}...")
    print(f"Hindi text: {audio_info['hindi_text'][:100]}...")

def main():
    """Main function to generate Hindi TTS"""
    print("Video Language Changer - Step 3: Text-to-Speech")
    print("="*55)
    
    # Install requirements
    install_requirements()
    check_espeak()
    
    # Get translated file
    translated_path = input("Enter path to translated JSON file: ").strip().strip('"\'')
    
    if not os.path.exists(translated_path):
        print(f"❌ Translated file not found: {translated_path}")
        return
    
    try:
        # Step 1: Load translated data
        translated_data = load_translated_data(translated_path)
        
        # Step 2: Setup TTS
        tts_engine, tts_model_data = setup_sarvam_tts()
        
        # Step 3: Synthesize segments
        synthesized_segments = synthesize_segments(translated_data, tts_engine, tts_model_data)
        
        if not synthesized_segments:
            print("❌ No audio segments were generated")
            return
        
        # Step 4: Combine audio
        combined_audio = combine_audio_segments(synthesized_segments)
        
        # Step 5: Save audio
        base_name = Path(translated_path).stem.replace("_translated", "")
        audio_output_path = f"{base_name}_hindi_audio.wav"
        save_audio(combined_audio, audio_output_path)
        
        # Step 6: Save metadata
        audio_duration = len(combined_audio) / 16000
        audio_info = create_audio_info(translated_data, audio_duration, audio_output_path)
        
        info_output_path = f"{base_name}_audio_info.json"
        save_audio_info(audio_info, info_output_path)
        
        # Step 7: Display summary
        print_tts_summary(audio_info)
        
        print(f"\n✓ Step 3 completed successfully!")
        print(f"Hindi audio: {audio_output_path}")
        print(f"Audio info: {info_output_path}")
        print("\nReady for Step 4: Video reconstruction")
        
    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()