import torch
from transformers import VitsModel, AutoTokenizer
import soundfile as sf
import numpy as np
import hashlib
import time
import logging
from pathlib import Path
from typing import Optional, Union, List
import threading
from concurrent.futures import ThreadPoolExecutor
import queue

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class ProfessionalHindiTTS:
    """
    Professional High-Performance Hindi Text-to-Speech System
    Optimized for speed, accuracy, and production deployment
    """
    
    def __init__(self, 
                 model_name: str = "facebook/mms-tts-hin",
                 device: Optional[str] = None,
                 enable_cuda_optimization: bool = True,
                 cache_size: int = 100,
                 use_half_precision: bool = True):
        """
        Initialize the TTS system with performance optimizations
        
        Args:
            model_name: HuggingFace model identifier
            device: Target device ('cuda', 'cpu', or None for auto-detection)
            enable_cuda_optimization: Enable CUDA-specific optimizations
            cache_size: Size of the audio cache for repeated texts
            use_half_precision: Use FP16 for faster inference (if supported)
        """
        self.model_name = model_name
        self.cache_size = cache_size
        self.audio_cache = {}
        self.cache_queue = queue.Queue(maxsize=cache_size)
        self._cache_lock = threading.Lock()  # guards cache state when synthesize_batch runs in threads
        
        # Device optimization
        self.device = self._setup_device(device)
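        # FP16 is only worthwhile on CUDA; many half-precision ops are slow or
        # unsupported on CPU, so fall back to FP32 there.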
        self.use_half_precision = use_half_precision and self.device.type == 'cuda'
        
        # Load model and tokenizer with optimizations
        self._load_models()
        
        # Performance optimizations
        if enable_cuda_optimization and torch.cuda.is_available():
            self._apply_cuda_optimizations()
        
        # Pre-compile for faster first inference
        self._warmup_model()
        
        logger.info(f"✅ Professional Hindi TTS initialized on {self.device}")
        logger.info(f"🚀 Half precision: {self.use_half_precision}")
        logger.info(f"📦 Cache enabled: {self.cache_size} slots")
    
    def _setup_device(self, device: Optional[str]) -> torch.device:
        """Setup optimal device configuration"""
        if device is None:
            if torch.cuda.is_available():
                device = 'cuda'
                logger.info(f"🎯 Auto-selected CUDA GPU: {torch.cuda.get_device_name()}")
            else:
                device = 'cpu'
                logger.info("🎯 Using CPU (CUDA not available)")
        
        device_obj = torch.device(device)
        
        # CUDA memory optimization
        if device_obj.type == 'cuda':
            torch.cuda.empty_cache()
            torch.backends.cudnn.benchmark = True  # Optimize for consistent input sizes
            torch.backends.cudnn.deterministic = False  # Allow non-deterministic for speed
            
        return device_obj
    
    def _load_models(self):
        """Load models with optimization flags"""
        logger.info("📥 Loading tokenizer and model...")
        start_time = time.time()
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            use_fast=True,  # Use fast tokenizer if available
            cache_dir=".cache/tokenizers"
        )
        
        # Load model with optimizations
        self.model = VitsModel.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16 if self.use_half_precision else torch.float32,
            low_cpu_mem_usage=True,
            cache_dir=".cache/models"
        )
        
        # Move to device and set eval mode
        self.model = self.model.to(self.device)
        self.model.eval()
        
        # MMS/VITS checkpoints expose their native output sampling rate in the config
        self.sample_rate = getattr(self.model.config, "sampling_rate", 16000)
        
        load_time = time.time() - start_time
        logger.info(f"✅ Models loaded in {load_time:.2f}s")
    
    def _apply_cuda_optimizations(self):
        """Apply CUDA-specific optimizations"""
        if torch.cuda.is_available():
            # Compile model for faster inference (PyTorch 2.0+).
            # Note: mode='max-autotune' makes the first call noticeably slower
            # while kernels are tuned; subsequent calls benefit.
            if hasattr(torch, 'compile'):
                try:
                    self.model = torch.compile(self.model, mode='max-autotune')
                    logger.info("🔥 Model compiled with torch.compile")
                except Exception as e:
                    logger.warning(f"⚠️  torch.compile failed: {e}")
            
            # channels_last only affects 4-D (NCHW) parameters; VITS is mostly
            # 1-D convolutions, so this is effectively a no-op but harmless.
            try:
                self.model = self.model.to(memory_format=torch.channels_last)
                logger.info("🧠 Memory format optimized")
            except Exception:
                pass
    
    def _warmup_model(self):
        """Warmup model for consistent performance"""
        logger.info("🔥 Warming up model...")
        warmup_text = "नमस्ते"
        
        with torch.no_grad():
            # Tokenizer outputs are integer tensors (input_ids / attention_mask),
            # so no dtype conversion is needed even in half-precision mode.
            inputs = self.tokenizer(warmup_text, return_tensors="pt").to(self.device)
            
            # Run inference once so CUDA kernels (and torch.compile graphs) are built
            _ = self.model(**inputs)
            
        # Clear cache after warmup
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()
        
        logger.info("✅ Model warmup completed")
    
    def _get_cache_key(self, text: str) -> str:
        """Generate a stable cache key for text"""
        # hashlib gives a deterministic digest; built-in hash() is salted per process
        return "tts_" + hashlib.md5(text.strip().encode("utf-8")).hexdigest()
    
    def _manage_cache(self, key: str, audio_data: np.ndarray):
        """Manage audio cache with LRU-like behavior (thread-safe)"""
        with self._cache_lock:
            if len(self.audio_cache) >= self.cache_size:
                # Evict the oldest entry
                try:
                    oldest_key = self.cache_queue.get_nowait()
                    self.audio_cache.pop(oldest_key, None)
                except queue.Empty:
                    pass
            
            self.audio_cache[key] = audio_data
            try:
                self.cache_queue.put_nowait(key)
            except queue.Full:
                pass
    
    def synthesize(self, 
                   text: str, 
                   use_cache: bool = True,
                   normalize_audio: bool = True) -> np.ndarray:
        """
        Synthesize speech from Hindi text with maximum performance
        
        Args:
            text: Hindi text to synthesize
            use_cache: Use audio cache for repeated texts
            normalize_audio: Normalize output audio
            
        Returns:
            Audio waveform as numpy array
        """
        if not text or not text.strip():
            raise ValueError("Text cannot be empty")
        
        text = text.strip()
        cache_key = self._get_cache_key(text) if use_cache else None
        
        # Check cache first
        if use_cache and cache_key in self.audio_cache:
            logger.info("🎯 Cache hit - returning cached audio")
            return self.audio_cache[cache_key]
        
        start_time = time.time()
        
        # Tokenize with optimizations
        with torch.no_grad():
            inputs = self.tokenizer(
                text, 
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512  # Reasonable limit for TTS
            ).to(self.device)
            
            # Tokenizer outputs are integer tensors, so no FP16 cast is needed here;
            # autocast handles mixed precision inside the model when enabled.
            with torch.autocast(device_type=self.device.type, enabled=self.use_half_precision):
                output = self.model(**inputs)
            
            # Extract waveform
            if hasattr(output, 'waveform'):
                waveform = output.waveform
            elif hasattr(output, 'audio'):
                waveform = output.audio
            else:
                # Fallback - check all tensor attributes
                for attr_name in dir(output):
                    attr_value = getattr(output, attr_name)
                    if isinstance(attr_value, torch.Tensor) and attr_value.ndim >= 2:
                        waveform = attr_value
                        break
                else:
                    raise RuntimeError("Could not find audio output in model response")
            
            # Process output
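            # .float() upcasts any FP16 output so downstream NumPy / audio code
            # always receives float32 samples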
            audio_np = waveform.squeeze().cpu().float().numpy()
            
            # Normalize audio if requested
            if normalize_audio:
                audio_np = self._normalize_audio(audio_np)
        
        # Cache the result
        if use_cache and cache_key:
            self._manage_cache(cache_key, audio_np)
        
        inference_time = time.time() - start_time
        audio_duration = len(audio_np) / self.sample_rate
        rtf = inference_time / audio_duration  # Real-time factor (< 1.0 means faster than real time)
        
        logger.info(f"⚡ Synthesis: {inference_time:.3f}s | Duration: {audio_duration:.2f}s | RTF: {rtf:.3f}")
        
        return audio_np
    
    def _normalize_audio(self, audio: np.ndarray) -> np.ndarray:
        """Normalize audio to prevent clipping and ensure consistent volume"""
        # Remove DC offset
        audio = audio - np.mean(audio)
        
        # Normalize to [-0.9, 0.9] to prevent clipping
        max_val = np.max(np.abs(audio))
        if max_val > 0:
            audio = audio * (0.9 / max_val)
        
        return audio
    
    def save_audio(self, 
                   audio: np.ndarray, 
                   filepath: Union[str, Path], 
                   sample_rate: Optional[int] = None,
                   format: str = 'wav') -> None:
        """Save audio with optimized I/O (defaults to the model's native sampling rate)"""
        if sample_rate is None:
            sample_rate = self.sample_rate
        
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        
        # Ensure audio is in correct format
        if audio.dtype != np.float32:
            audio = audio.astype(np.float32)
        
        # Save with optimal settings
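        # soundfile converts float32 samples in [-1, 1] to 16-bit integers when
        # subtype='PCM_16' is used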
        sf.write(
            str(filepath), 
            audio, 
            sample_rate,
            format=format.upper(),
            subtype='PCM_16' if format.lower() == 'wav' else None
        )
        
        file_size = filepath.stat().st_size / 1024  # KB
        logger.info(f"💾 Saved: {filepath} ({file_size:.1f}KB)")
    
    def synthesize_batch(self, 
                        texts: List[str], 
                        max_workers: int = 4) -> List[np.ndarray]:
        """
        Synthesize multiple texts in parallel for maximum throughput
        
        Args:
            texts: List of Hindi texts to synthesize
            max_workers: Maximum number of parallel workers
            
        Returns:
            List of audio waveforms
        """
        if not texts:
            return []
        
        start_time = time.time()
        
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
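            # Note: all workers share one model instance; PyTorch releases the GIL
            # inside heavy ops, so threads mostly overlap tokenization and I/O while
            # the actual inference is largely serialized on a single device.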
            futures = [executor.submit(self.synthesize, text) for text in texts]
            results = [future.result() for future in futures]
        
        total_time = time.time() - start_time
        logger.info(f"🚀 Batch synthesis: {len(texts)} texts in {total_time:.2f}s")
        
        return results
    
    def clear_cache(self):
        """Clear audio cache to free memory"""
        with self._cache_lock:
            self.audio_cache.clear()
            while not self.cache_queue.empty():
                try:
                    self.cache_queue.get_nowait()
                except queue.Empty:
                    break
        logger.info("🧹 Cache cleared")
    
    def get_stats(self) -> dict:
        """Get performance statistics"""
        return {
            'device': str(self.device),
            'half_precision': self.use_half_precision,
            'cache_size': len(self.audio_cache),
            'max_cache_size': self.cache_size,
            'model_name': self.model_name,
            'cuda_available': torch.cuda.is_available(),
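            # memory counters below are reported in bytes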
            'cuda_memory_allocated': torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
            'cuda_memory_reserved': torch.cuda.memory_reserved() if torch.cuda.is_available() else 0
        }


def main():
    """Professional usage example"""
    # Initialize with maximum performance settings
    tts = ProfessionalHindiTTS(
        enable_cuda_optimization=True,
        use_half_precision=True,
        cache_size=50
    )
    
    # Test texts
    test_texts = [
        "भारत एक महान देश है जिसकी संस्कृति और परंपराएं बहुत समृद्ध हैं।",
        "नमस्ते, आप कैसे हैं? मुझे हिंदी बोलना बहुत अच्छा लगता है।",
        "तकनीक के क्षेत्र में भारत तेजी से आगे बढ़ रहा है।"
    ]
    
    # Single synthesis
    logger.info("🎯 Single synthesis test...")
    audio = tts.synthesize(test_texts[0])
    tts.save_audio(audio, "output/hindi_speech_optimized.wav")
    
    # Batch synthesis for maximum throughput
    logger.info("🚀 Batch synthesis test...")
    batch_results = tts.synthesize_batch(test_texts, max_workers=2)
    
    # Save batch results
    for i, audio in enumerate(batch_results):
        tts.save_audio(audio, f"output/batch_speech_{i+1}.wav")
    
    # Performance stats
    stats = tts.get_stats()
    logger.info("📊 Performance Statistics:")
    for key, value in stats.items():
        logger.info(f"   {key}: {value}")
    
    logger.info("✅ Professional Hindi TTS demonstration completed!")


if __name__ == "__main__":
    main()