import base64
import hashlib
import io
import queue
import threading
from typing import Optional, Union

import numpy as np
import soundfile as sf
import torch
from transformers import VitsModel, AutoTokenizer


class HindiTTSProcessor:
    """
    Professional Hindi Text-to-Speech Processing System
    Optimized for speed, accuracy, and production deployment
    """
    
    def __init__(self, 
                 model_name: str = "facebook/mms-tts-hin",
                 device: Optional[str] = None,
                 enable_cuda_optimization: bool = True,
                 cache_size: int = 100,
                 use_half_precision: bool = True):
        """
        Initialize the TTS system with performance optimizations
        
        Args:
            model_name: HuggingFace model identifier
            device: Target device ('cuda', 'cpu', or None for auto-detection)
            enable_cuda_optimization: Enable CUDA-specific optimizations
            cache_size: Size of the audio cache for repeated texts
            use_half_precision: Use FP16 for faster inference (if supported)
        """
        self.model_name = model_name
        self.cache_size = cache_size
        self.audio_cache = {}
        self.cache_queue = queue.Queue(maxsize=cache_size)
        
        # Device optimization
        self.device = self._setup_device(device)
        self.use_half_precision = use_half_precision and self.device.type == 'cuda'
        
        # Load model and tokenizer with optimizations
        self._load_models()
        
        # Performance optimizations
        if enable_cuda_optimization and torch.cuda.is_available():
            self._apply_cuda_optimizations()
        
        # Warm up so the first real request does not pay lazy-init / compile cost
        self._warmup_model()
    
    def _setup_device(self, device: Optional[str]) -> torch.device:
        """Setup optimal device configuration"""
        if device is None:
            if torch.cuda.is_available():
                device = 'cuda'
            else:
                device = 'cpu'
        
        device_obj = torch.device(device)
        
        # CUDA memory optimization
        if device_obj.type == 'cuda':
            torch.cuda.empty_cache()
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False
            
        return device_obj
    
    def _load_models(self):
        """Load models with optimization flags"""
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            use_fast=True,
            cache_dir=".cache/tokenizers"
        )
        
        # Load model with optimizations
        self.model = VitsModel.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16 if self.use_half_precision else torch.float32,
            low_cpu_mem_usage=True,
            cache_dir=".cache/models"
        )
        
        # Move to device and set eval mode
        self.model = self.model.to(self.device)
        self.model.eval()
        
        # Note: `use_cache` is a KV-cache flag for autoregressive decoders;
        # VITS is non-autoregressive, so there is nothing further to configure here.
    
    def _apply_cuda_optimizations(self):
        """Apply CUDA-specific optimizations"""
        if not torch.cuda.is_available():
            return

        # Compile model for faster inference (PyTorch 2.0+); compilation is
        # best-effort and falls back to eager mode if it fails
        if hasattr(torch, 'compile'):
            try:
                self.model = torch.compile(self.model, mode='max-autotune')
            except Exception:
                pass

        # Memory-format hint: channels_last only affects 4-D (image-style)
        # parameters, so it is effectively a no-op for VITS's 1-D convolutions,
        # but requesting it is harmless
        try:
            self.model = self.model.to(memory_format=torch.channels_last)
        except Exception:
            pass
    
    def _warmup_model(self):
        """Warmup model for consistent performance"""
        warmup_text = "नमस्ते"
        
        with torch.no_grad():
            inputs = self.tokenizer(warmup_text, return_tensors="pt").to(self.device)

            # Single forward pass to trigger lazy initialization / compilation
            _ = self.model(**inputs)
            
        # Clear cache after warmup
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()
    
    def _get_cache_key(self, text: str) -> str:
        """Generate a stable, collision-resistant cache key for the text"""
        normalized = text.strip().lower()
        return f"tts_{hashlib.sha1(normalized.encode('utf-8')).hexdigest()}"
    
    def _manage_cache(self, key: str, audio_data: np.ndarray):
        """Manage audio cache with LRU-like behavior"""
        if len(self.audio_cache) >= self.cache_size:
            # Remove oldest entry
            try:
                oldest_key = self.cache_queue.get_nowait()
                if oldest_key in self.audio_cache:
                    del self.audio_cache[oldest_key]
            except queue.Empty:
                pass
        
        self.audio_cache[key] = audio_data
        try:
            self.cache_queue.put_nowait(key)
        except queue.Full:
            pass
    
    def synthesize(self, 
                   text: str, 
                   use_cache: bool = True,
                   normalize_audio: bool = True) -> np.ndarray:
        """
        Synthesize speech from Hindi text, using the in-memory cache and the
        FP16/autocast path when they are enabled
        
        Args:
            text: Hindi text to synthesize
            use_cache: Use audio cache for repeated texts
            normalize_audio: Normalize output audio
            
        Returns:
            Mono audio waveform as a float32 numpy array at the model's sampling rate
        """
        if not text or not text.strip():
            raise ValueError("Text cannot be empty")
        
        text = text.strip()
        cache_key = self._get_cache_key(text) if use_cache else None
        
        # Check cache first
        if use_cache and cache_key in self.audio_cache:
            return self.audio_cache[cache_key]
        
        # Tokenize with optimizations
        with torch.no_grad():
            inputs = self.tokenizer(
                text, 
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.device)
            
            # Generate speech; tokenizer outputs are integer tensors, so no
            # explicit dtype cast is needed, and autocast handles mixed
            # precision when FP16 is enabled on CUDA
            if self.use_half_precision and torch.cuda.is_available():
                with torch.amp.autocast('cuda', enabled=True):
                    output = self.model(**inputs)
            else:
                output = self.model(**inputs)
            
            # Extract waveform
            if hasattr(output, 'waveform'):
                waveform = output.waveform
            elif hasattr(output, 'audio'):
                waveform = output.audio
            else:
                # Fallback - check all tensor attributes
                for attr_name in dir(output):
                    attr_value = getattr(output, attr_name)
                    if isinstance(attr_value, torch.Tensor) and attr_value.ndim >= 2:
                        waveform = attr_value
                        break
                else:
                    raise RuntimeError("Could not find audio output in model response")
            
            # Process output
            audio_np = waveform.squeeze().cpu().float().numpy()
            
            # Normalize audio if requested
            if normalize_audio:
                audio_np = self._normalize_audio(audio_np)
        
        # Cache the result
        if use_cache and cache_key:
            self._manage_cache(cache_key, audio_np)
        
        return audio_np
    
    def _normalize_audio(self, audio: np.ndarray) -> np.ndarray:
        """Normalize audio to prevent clipping and ensure consistent volume"""
        # Remove DC offset
        audio = audio - np.mean(audio)
        
        # Normalize to [-0.9, 0.9] to prevent clipping
        max_val = np.max(np.abs(audio))
        if max_val > 0:
            audio = audio * (0.9 / max_val)
        
        return audio
    
    def synthesize_to_bytes(self, text: str, sample_rate: Optional[int] = None) -> bytes:
        """
        Synthesize speech and return WAV-encoded bytes for an API response
        
        Args:
            text: Hindi text to synthesize
            sample_rate: Output sample rate; defaults to the model's own
                sampling rate (16 kHz for facebook/mms-tts-hin)
            
        Returns:
            WAV audio data as bytes
        """
        try:
            audio_np = self.synthesize(text)
            
            # Fall back to the rate the model was trained with
            if sample_rate is None:
                sample_rate = getattr(self.model.config, 'sampling_rate', 16000)
            
            # Ensure audio is in the format soundfile expects
            if audio_np.dtype != np.float32:
                audio_np = audio_np.astype(np.float32)
            
            # Encode as WAV in memory
            buffer = io.BytesIO()
            sf.write(buffer, audio_np, sample_rate, format='WAV')
            buffer.seek(0)
            
            return buffer.getvalue()
        except Exception as e:
            print(f"Error in synthesize_to_bytes: {type(e).__name__}: {e}")
            raise
    
    def synthesize_to_base64(self, text: str, sample_rate: Optional[int] = None) -> str:
        """
        Synthesize speech and return the WAV data as a base64-encoded string
        
        Args:
            text: Hindi text to synthesize
            sample_rate: Output sample rate; defaults to the model's own sampling rate
            
        Returns:
            Base64-encoded WAV audio data
        """
        audio_bytes = self.synthesize_to_bytes(text, sample_rate)
        return base64.b64encode(audio_bytes).decode('utf-8')
    
    def clear_cache(self):
        """Clear audio cache to free memory"""
        self.audio_cache.clear()
        while not self.cache_queue.empty():
            try:
                self.cache_queue.get_nowait()
            except queue.Empty:
                break
    
    def get_stats(self) -> dict:
        """Get performance statistics"""
        return {
            'device': str(self.device),
            'half_precision': self.use_half_precision,
            'cache_size': len(self.audio_cache),
            'max_cache_size': self.cache_size,
            'model_name': self.model_name,
            'cuda_available': torch.cuda.is_available(),
            'cuda_memory_allocated': torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
            'cuda_memory_reserved': torch.cuda.memory_reserved() if torch.cuda.is_available() else 0
        }
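

# Example: direct use of HindiTTSProcessor (a minimal sketch, not part of the
# production API below). It assumes the checkpoint downloads on first use and
# that the model exposes its output rate as config.sampling_rate (16 kHz for
# facebook/mms-tts-hin); the function name, default text, and output path are
# illustrative only.
def example_save_to_wav(text: str = "नमस्ते, आप कैसे हैं?",
                        path: str = "example.wav") -> str:
    """Synthesize `text` once and write it to `path` as a WAV file."""
    tts = HindiTTSProcessor()
    audio = tts.synthesize(text)
    sf.write(path, audio, getattr(tts.model.config, 'sampling_rate', 16000))
    return path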


# Global TTS processor instance (lazily created; the lock prevents concurrent
# callers from loading the model twice)
_tts_processor = None
_tts_processor_lock = threading.Lock()

def get_tts_processor() -> HindiTTSProcessor:
    """Get or create the global TTS processor instance"""
    global _tts_processor
    if _tts_processor is None:
        with _tts_processor_lock:
            if _tts_processor is None:
                _tts_processor = HindiTTSProcessor(
                    enable_cuda_optimization=True,
                    use_half_precision=True,
                    cache_size=50
                )
    return _tts_processor

def process_text_to_speech(text: str, output_format: str = "bytes") -> Union[bytes, str]:
    """
    Main processing function for text to speech conversion
    
    Args:
        text: Hindi text to convert to speech
        output_format: 'bytes' for raw WAV bytes or 'base64' for a base64-encoded string
        
    Returns:
        Audio data in requested format
    """
    processor = get_tts_processor()
    
    if output_format == "base64":
        return processor.synthesize_to_base64(text)
    else:
        return processor.synthesize_to_bytes(text)
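

if __name__ == "__main__":
    # Smoke-test sketch for running this module directly: exercises the
    # module-level API end to end. The sample sentence and output file name
    # are arbitrary, and the first run downloads facebook/mms-tts-hin.
    sample_text = "भारत एक विशाल देश है।"

    wav_bytes = process_text_to_speech(sample_text, output_format="bytes")
    with open("demo_output.wav", "wb") as f:
        f.write(wav_bytes)
    print(f"Wrote {len(wav_bytes)} bytes to demo_output.wav")

    b64_audio = process_text_to_speech(sample_text, output_format="base64")
    print(f"Base64 payload length: {len(b64_audio)} characters")

    print("Processor stats:", get_tts_processor().get_stats())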