from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
import os
import sys
import uuid
from pydantic import BaseModel

# Root directory of the virtual environment whose site-packages and bin
# directories are made visible to this process by activate_environment().
ENV_PATH = "/var/www/eduai.edurigo.com/doc_train/edurigo_ai/Puru/ai4bharat/ai4bharat_env"
def activate_environment():
    """Activate the virtual environment"""
    site_packages = os.path.join(ENV_PATH, "lib", "python3.10", "site-packages")
    if site_packages not in sys.path:
        sys.path.insert(0, site_packages)
    os.environ['VIRTUAL_ENV'] = ENV_PATH
    os.environ['PATH'] = os.path.join(ENV_PATH, 'bin') + os.pathsep + os.environ.get('PATH', '')
# Activate the environment at import time. The third-party imports further
# down come after this call, presumably so that they resolve from the
# virtual environment's site-packages.
# The process environment is deliberately not printed here: it can contain
# credentials that must not end up in logs.
activate_environment()

from transformers import VitsModel, AutoTokenizer
import torch
import scipy 
import soundfile as sf
import numpy as np

class TTSRequest(BaseModel):
    """Request body for POST /generate-tts-simple-stream."""
    # Text to synthesize; the handler rejects values that are empty or
    # whitespace-only with HTTP 400.
    text: str

# Application instance; the @app.post / @app.get decorators in this module
# register their handlers on it.
app = FastAPI()

# Hugging Face checkpoint used for synthesis.
_TTS_MODEL_NAME = "facebook/mms-tts-hin"
# Lazily populated (model, tokenizer) pair shared by all requests.
_tts_components = None


def _get_tts_components():
    """Return the cached ``(model, tokenizer)`` pair, loading it on first use."""
    global _tts_components
    if _tts_components is None:
        _tts_components = (
            VitsModel.from_pretrained(_TTS_MODEL_NAME),
            AutoTokenizer.from_pretrained(_TTS_MODEL_NAME),
        )
    return _tts_components


@app.post("/generate-tts-simple-stream")
async def generate_tts_simple_stream(request: TTSRequest):
    """Synthesize speech for the given text and save it as a WAV file.

    The audio is produced with the ``facebook/mms-tts-hin`` checkpoint and
    written to the current working directory under a generated name of the
    form ``techno_<8 hex chars>.wav``.

    Args:
        request: Request body carrying the text to synthesize.

    Returns:
        A dict with ``success``, ``message``, the generated ``filename`` and
        a ``download_url`` pointing at the download route.

    Raises:
        HTTPException: 400 if the text is empty or whitespace-only; 500 if
            model loading, inference, or writing the file fails.
    """
    text = request.text
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    try:
        # Loading the checkpoint is expensive, so it is done once and reused
        # instead of on every request.
        model, tokenizer = _get_tts_components()

        inputs = tokenizer(text, return_tensors="pt")

        # NOTE(review): inference runs synchronously inside this async
        # handler, so the event loop is blocked while synthesis is running.
        with torch.no_grad():
            waveform = model(**inputs).waveform

        # Convert tensor to numpy float32
        samples = waveform.cpu().numpy().astype(np.float32)

        # If output is shape (1, samples), squeeze to (samples,)
        if samples.ndim == 2 and samples.shape[0] == 1:
            samples = samples.squeeze(0)

        # Short random suffix keeps concurrent requests from overwriting
        # each other's output.
        filename = f"techno_{str(uuid.uuid4())[:8]}.wav"

        # Take the sample rate from the model configuration rather than
        # hard-coding it, so the file header always matches the audio.
        sf.write(filename, samples, samplerate=model.config.sampling_rate)
    except Exception as e:
        # Top-level boundary: report any failure as a 500 and keep the
        # original exception chained for server-side tracebacks.
        raise HTTPException(
            status_code=500, detail=f"TTS generation failed: {str(e)}"
        ) from e

    return {
        "success": True,
        "message": "TTS generated successfully",
        "filename": filename,
        "download_url": f"/download/{filename}",
    }

@app.get("/download/{filename}")
async def download_file(filename: str):
    """Serve a previously generated WAV file from the working directory.

    Only names of the form ``techno_<8 hex chars>.wav`` (the pattern the
    TTS endpoint writes) are accepted, so the route cannot be used to read
    other files that happen to live in the working directory.

    Args:
        filename: File name taken from the URL path.

    Returns:
        A ``FileResponse`` streaming the file as ``audio/wav``.

    Raises:
        HTTPException: 404 if the name does not match the generated-file
            pattern or no such regular file exists.
    """
    import re  # local import keeps this block self-contained

    # The name comes straight from the URL, so validate it before touching
    # the filesystem. A non-matching name gets the same 404 as a missing
    # file so the response does not reveal which files exist.
    is_generated_name = re.fullmatch(r"techno_[0-9a-f]{8}\.wav", filename) is not None
    if not is_generated_name or not os.path.isfile(filename):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(filename, media_type="audio/wav", filename=filename)

@app.get("/")
async def root():
    """Return a readiness message together with the configured environment path."""
    status = dict(message="TTS API Ready", environment=ENV_PATH)
    return status