import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
import numpy as np
from typing import List
import re

class SimpleHindiTTS:
    """
    Simple Hindi Text-to-Speech using Indic Parler-TTS.

    No reference audio is needed: speech is generated from the text to speak
    plus a free-form voice description.

    Attributes:
        device: Torch device string the model was moved to.
        model: The loaded ``ParlerTTSForConditionalGeneration`` instance.
        tokenizer: Tokenizer applied to the text that is spoken.
        description_tokenizer: Tokenizer applied to the voice description.
        sampling_rate: Sample rate in Hz of generated audio, read from the
            model configuration.
    """

    # A sentence is a run of non-terminator characters followed by whatever
    # terminators (danda, "!", "?") end it, so punctuation stays attached.
    _SENTENCE_PATTERN = re.compile(r'[^।!?]+[।!?]*')

    # Length of the silent stand-in returned when generation fails.
    _FALLBACK_SILENCE_SECONDS = 2.0
    # Pause inserted between consecutive chunks in the final file.
    _CHUNK_PAUSE_SECONDS = 0.5

    def __init__(self):
        """Load the model and both tokenizers onto the best available device."""
        print("Loading Indic Parler-TTS model...")
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # Load model and tokenizers
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            "ai4bharat/indic-parler-tts"
        ).to(self.device)

        self.tokenizer = AutoTokenizer.from_pretrained("ai4bharat/indic-parler-tts")
        self.description_tokenizer = AutoTokenizer.from_pretrained(
            self.model.config.text_encoder._name_or_path
        )

        # Take the output rate from the model instead of hard-coding a guess;
        # writing a file with the wrong rate changes playback speed and pitch.
        self.sampling_rate = int(self.model.config.sampling_rate)

        print(f"✅ Model loaded successfully on {self.device}")

    @staticmethod
    def _pack(pieces: List[str], max_length: int) -> List[str]:
        """Greedily join pieces with single spaces without exceeding max_length.

        A piece that is longer than ``max_length`` on its own is kept whole
        as its own chunk.
        """
        packed = []
        current = ""
        for piece in pieces:
            candidate = f"{current} {piece}" if current else piece
            # Measure the joined string so the separating space is counted.
            if current and len(candidate) > max_length:
                packed.append(current)
                current = piece
            else:
                current = candidate
        if current:
            packed.append(current)
        return packed

    def split_text(self, text: str, max_length: int = 100) -> List[str]:
        """Split text into chunks of at most ``max_length`` characters.

        Sentences are delimited by the danda, "!" and "?"; the terminator is
        kept with its sentence. Whitespace runs inside a sentence are
        collapsed to single spaces. Short sentences are packed together; a
        sentence longer than ``max_length`` is split on whitespace instead.

        Args:
            text: Text to split.
            max_length: Maximum chunk length in characters. Only a single
                word longer than this can produce a longer chunk.

        Returns:
            Non-empty chunks in their original order.
        """
        chunks = []
        pending = []  # short sentences waiting to be packed together

        for raw_sentence in self._SENTENCE_PATTERN.findall(text):
            sentence = " ".join(raw_sentence.split())
            if not sentence:
                continue

            if len(sentence) <= max_length:
                pending.append(sentence)
                continue

            # Oversized sentence: flush what came before so order is kept,
            # then fall back to word-level packing for this sentence alone.
            chunks.extend(self._pack(pending, max_length))
            pending = []
            chunks.extend(self._pack(sentence.split(" "), max_length))

        chunks.extend(self._pack(pending, max_length))
        return chunks

    def _silence(self, seconds: float) -> np.ndarray:
        """Return ``seconds`` of float32 silence at the model's sample rate."""
        return np.zeros(int(seconds * self.sampling_rate), dtype=np.float32)

    def generate_speech(self, text: str, voice_description: str) -> np.ndarray:
        """Generate speech for one piece of text.

        Args:
            text: Text to speak.
            voice_description: Description of the desired voice.

        Returns:
            A 1-D float32 array of samples. Empty when ``text`` is blank.
            When generation fails or yields nothing, the error is printed
            and a short block of silence is returned instead (best effort),
            so callers that need to detect failure must check for all-zero
            output.
        """
        text = text.strip()
        if not text:
            return np.zeros(0, dtype=np.float32)

        try:
            # Tokenize inputs with padding and truncation
            description_inputs = self.description_tokenizer(
                voice_description,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(self.device)

            prompt_inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(self.device)

            # pad_token_id is deliberately not overridden: the text
            # tokenizer's EOS id is not an id from the generated audio-code
            # vocabulary, so the model's own generation config is used.
            with torch.no_grad():
                generation = self.model.generate(
                    input_ids=description_inputs.input_ids,
                    attention_mask=description_inputs.attention_mask,
                    prompt_input_ids=prompt_inputs.input_ids,
                    prompt_attention_mask=prompt_inputs.attention_mask,
                    do_sample=True,
                    temperature=0.7,
                    max_new_tokens=1000,
                )

            # Flatten rather than squeeze so a one-sample result stays 1-D.
            audio_arr = generation.cpu().float().numpy().reshape(-1)

            if audio_arr.size == 0:
                print("Warning: Generated empty audio, creating silence")
                return self._silence(self._FALLBACK_SILENCE_SECONDS)

            return audio_arr

        except Exception as e:
            # Deliberate best effort: one bad chunk must not abort a long text.
            print(f"Error generating speech: {e}")
            print(f"Text length: {len(text)}, Description length: {len(voice_description)}")
            return self._silence(self._FALLBACK_SILENCE_SECONDS)

    def text_to_audio(self,
                     hindi_text: str,
                     voice_style: str = "female",
                     quality: str = "high",
                     speed: str = "moderate",
                     output_file: str = "hindi_output.wav") -> str:
        """
        Convert Hindi text to an audio file.

        The text is split into chunks, each chunk is synthesised separately,
        and the results are joined with a short pause between them.

        Args:
            hindi_text: Text to convert
            voice_style: "male", "female", "young", "old"; any other value
                falls back to "female"
            quality: "high", "medium", "clear"
            speed: "slow", "moderate", "fast"
            output_file: Output filename

        Returns:
            Path to generated audio file
        """

        # Create voice description based on preferences
        voice_descriptions = {
            "male": f"A male speaker delivers speech with {speed} speed. The recording is of {quality} quality with clear voice.",
            "female": f"A female speaker delivers speech with {speed} speed. The recording is of {quality} quality with clear voice.",
            "young": f"A young speaker delivers energetic speech with {speed} speed. The recording is of {quality} quality.",
            "old": f"An elderly speaker delivers calm speech with {speed} speed. The recording is of {quality} quality.",
        }

        description = voice_descriptions.get(voice_style)
        if description is None:
            print(f"⚠️  Unknown voice style '{voice_style}', using 'female'")
            description = voice_descriptions["female"]

        print(f"Converting text to audio...")
        print(f"Voice style: {voice_style}")
        print(f"Text length: {len(hindi_text)} characters")

        # Split text if too long
        chunks = self.split_text(hindi_text)
        print(f"Processing {len(chunks)} chunks...")

        # Generate audio for each chunk
        audio_parts = []
        for number, chunk in enumerate(chunks, start=1):
            print(f"Processing chunk {number}/{len(chunks)}: {chunk[:50]}...")

            audio = self.generate_speech(chunk, description)
            if audio.size > 0:
                audio_parts.append(audio)

        sample_rate = self.sampling_rate

        if not audio_parts:
            print("⚠️  Warning: No audio generated, creating placeholder")
            # Create a simple one-second 440Hz beep as placeholder
            t = np.arange(sample_rate) / sample_rate
            beep = 0.3 * np.sin(2 * np.pi * 440 * t)
            sf.write(output_file, beep, sample_rate)
            return output_file

        # Combine all audio parts
        print("Combining audio chunks...")

        # Interleave a short pause between chunks and concatenate once,
        # instead of re-copying the growing array for every chunk.
        pause = self._silence(self._CHUNK_PAUSE_SECONDS)
        pieces = [audio_parts[0]]
        for audio_part in audio_parts[1:]:
            pieces.extend((pause, audio_part))
        final_audio = np.concatenate(pieces)

        sf.write(output_file, final_audio, sample_rate)

        duration = len(final_audio) / sample_rate
        print(f"✅ Audio generated successfully!")
        print(f"📁 Saved as: {output_file}")
        print(f"⏱️  Duration: {duration:.2f} seconds")

        return output_file

# Predefined voice styles for easy use
# Ready-made voice descriptions keyed by a short style name. Each value is a
# complete description string that can be passed to
# SimpleHindiTTS.generate_speech as the voice description.
VOICE_STYLES = {
    "female_calm": (
        "A female speaker delivers calm and clear speech with moderate speed and pitch. "
        "The recording is of very high quality with no background noise."
    ),
    "male_energetic": (
        "A male speaker delivers energetic and expressive speech with moderate speed. "
        "The recording is of very high quality with clear voice."
    ),
    "female_professional": (
        "A female speaker delivers professional and confident speech with moderate speed. "
        "The recording is of very high quality."
    ),
    "male_calm": (
        "A male speaker delivers calm and soothing speech with slow to moderate speed. "
        "The recording is of very high quality."
    ),
    "young_female": (
        "A young female speaker delivers lively and animated speech with moderate to fast speed. "
        "The recording is of very high quality."
    ),
    "narrator": (
        "A clear narrator voice delivers storytelling speech with moderate speed and good expression. "
        "The recording is of very high quality."
    ),
}

def quick_hindi_tts():
    """Demo: render one sample Hindi passage twice, once per voice.

    Builds a SimpleHindiTTS engine, then writes ``hindi_female.wav`` and
    ``hindi_male.wav`` from the same text using the "female" and "male"
    voice styles at high quality and moderate speed.
    """
    engine = SimpleHindiTTS()

    # Passage that is spoken in both renders
    sample_text = """
    नमस्ते! आज मैं आपको कृत्रिम बुद्धिमत्ता के बारे में बताने जा रहा हूँ। 
    यह एक अत्यंत रोचक और तेज़ी से विकसित होने वाला क्षेत्र है।
    
    आजकल AI का उपयोग हर जगह हो रहा है - मोबाइल फोन से लेकर स्मार्ट कारों तक। 
    यह तकनीक हमारे जीवन को आसान और बेहतर बना रही है।
    
    भारत में भी इस क्षेत्र में बहुत तेज़ी से काम हो रहा है। 
    हमारे इंजीनियर और वैज्ञानिक दुनिया भर में अपना नाम बना रहे हैं।
    
    आने वाले समय में यह तकनीक और भी जबरदस्त काम करेगी। 
    हमें इसका सदुपयोग करना चाहिए और इसके फायदे उठाने चाहिए।
    """

    # (progress banner, voice style, destination file) for each render
    renders = (
        ("🎤 Generating with female calm voice...", "female", "hindi_female.wav"),
        ("\n🎤 Generating with male energetic voice...", "male", "hindi_male.wav"),
    )

    for banner, style, destination in renders:
        print(banner)
        engine.text_to_audio(
            hindi_text=sample_text,
            voice_style=style,
            quality="high",
            speed="moderate",
            output_file=destination,
        )

    print("\n🎉 Done! Check the generated audio files.")

# Advanced usage with custom voice descriptions
def advanced_hindi_tts():
    """Generate one sentence with a hand-written voice description.

    Calls ``generate_speech`` directly (no chunking) and writes the result
    to ``custom_voice.wav``.
    """

    tts = SimpleHindiTTS()

    # Custom voice description.
    # NOTE(review): written in English like every other description in this
    # file. The description tokenizer is loaded from the model's text encoder,
    # which presumably does not cover Devanagari; confirm against the model
    # card before using non-English descriptions.
    custom_voice = "A female speaker speaks in a slow and clear voice. The recording is of very high quality."

    text = "यह एक टेस्ट है।"

    # Generate with custom description
    audio = tts.generate_speech(text, custom_voice)

    # Write with the rate the model reports. A hard-coded rate that differs
    # from the model's output plays the audio at the wrong speed and pitch.
    sf.write("custom_voice.wav", audio, tts.model.config.sampling_rate)
    print("Custom voice audio saved!")

# Simple test function to debug the issue
def test_simple_tts():
    """Smoke-test the pipeline with a one-word input, then a short sentence.

    Any exception is caught and reported with an install hint, so this is
    safe to run as the script entry point.
    """
    print("🧪 Testing simple TTS...")

    try:
        tts = SimpleHindiTTS()

        # Very simple test
        simple_text = "नमस्ते"
        simple_description = "A female speaker."

        print(f"Testing with: '{simple_text}'")
        print(f"Voice description: '{simple_description}'")

        audio = tts.generate_speech(simple_text, simple_description)

        # generate_speech reports failures by returning all-zero silence, so
        # a non-empty array alone does not prove that synthesis worked.
        if audio.size > 0 and np.any(audio):
            # Use the model's own output rate rather than a hard-coded one.
            sf.write("test_simple.wav", audio, tts.model.config.sampling_rate)
            print("✅ Simple test successful!")
            print("Now trying with full text...")

            # Try with full pipeline
            tts.text_to_audio(
                hindi_text="नमस्ते! यह एक परीक्षण है।",
                voice_style="female",
                output_file="test_full.wav"
            )
        else:
            print("❌ Simple test failed")
            print("The model might have compatibility issues")

    except Exception as e:
        print(f"❌ Test failed with error: {e}")
        print("Suggestion: Try installing specific versions:")
        print("pip install parler-tts==0.1.* transformers==4.* torch")

if __name__ == "__main__":
    # Script entry point: only the smoke test runs here. quick_hindi_tts()
    # and advanced_hindi_tts() are not called from this guard.
    test_simple_tts()