"""
PROFESSIONAL IMAGE GENERATION API (Optimized for SD15 Only)
===========================================================
FastAPI server using Stable Diffusion v1.5 with VRAM-safe memory handling.
"""

import gc
import logging
from contextlib import asynccontextmanager
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, validator

from local_image_generator import LocalImageGenerator

# Root logging setup: INFO and above, one timestamped line per record,
# followed by this module's own named logger.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

# =====================================================================
# GLOBAL GENERATOR INSTANCE
# =====================================================================
# Assigned by `lifespan` during startup; remains None until initialization
# succeeds. Route handlers test for None and answer HTTP 503 in that case.
# Note: this annotation is evaluated at import time, so LocalImageGenerator
# must be importable even before any model is loaded.
generator: Optional[LocalImageGenerator] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager for initializing and cleaning up models.
    Only SD15 is preloaded to prevent GPU VRAM accumulation.

    Startup builds the module-level ``generator``. If that fails, the error is
    logged with its traceback and re-raised as ``RuntimeError`` (chained to the
    original cause), so the server does not start.

    Shutdown runs in a ``finally`` block. The SD15 model is therefore unloaded,
    and cached GPU memory released, even when the application body exits with an
    exception.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).

    Raises:
        RuntimeError: If the generator cannot be constructed.
    """
    global generator

    logger.info("🚀 Starting Image Generation API (SD15 only)...")

    try:
        # Initialize generator (SD15 only)
        generator = LocalImageGenerator(preload_models=["sd15"])
        logger.info("✅ SD15 model preloaded successfully!")
    except Exception as e:
        logger.exception(f"❌ Failed to initialize generator: {e}")
        raise RuntimeError(f"Model initialization failed: {e}") from e

    try:
        yield  # API runs here
    finally:
        logger.info("🛑 Shutting down API and releasing GPU memory...")
        if generator is not None:
            try:
                generator.unload_model("sd15")
            finally:
                # Rebind instead of `del`. `del` would remove the module-level
                # name, so any later access raises NameError instead of the
                # handlers' intended 503 "not initialized" path.
                generator = None
                gc.collect()
                # synchronize() raises on hosts without a usable CUDA device.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
        logger.info("✅ Clean shutdown complete!")


# =====================================================================
# FASTAPI APP INITIALIZATION
# =====================================================================
# `lifespan` constructs the SD15 generator before requests are served and
# unloads it on shutdown.
app = FastAPI(
    title="Professional Image Generation API (SD15)",
    description="🚀 GPU-based Self-Hosted Image Generator optimized for VRAM efficiency (Stable Diffusion 1.5)",
    version="7.0.0",  # NOTE(review): also hard-coded in root(); keep the two in sync
    lifespan=lifespan,
)

# Enable CORS
# NOTE(review): wildcard origins together with allow_credentials=True is a known
# CORS anti-pattern. Browsers reject "*" for credentialed requests, and Starlette
# can end up echoing the caller's Origin instead, which lets any site make
# credentialed calls. No route in this file reads cookies or auth headers, so
# allow_credentials=False (or an explicit origin list) is likely safe. Confirm
# with API consumers before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# =====================================================================
# REQUEST/RESPONSE MODELS
# =====================================================================
class ImageRequest(BaseModel):
    # Request body for POST /generate-image. Length and numeric bounds are
    # declared as pydantic Field constraints. Comments are used instead of a
    # class docstring because pydantic would publish a docstring in the schema.
    prompt: str = Field(
        ...,
        description="Text prompt",
        min_length=3,
        example="a futuristic city skyline at dusk",
    )
    category: str = Field(
        ...,
        description="Visual category (e.g., realistic, anime)",
        example="realistic",
    )
    width: int = Field(512, description="Width in pixels", ge=256, le=1024)
    height: int = Field(512, description="Height in pixels", ge=256, le=1024)
    image_count: int = Field(1, description="Number of images", ge=1, le=3)

    @validator("category")
    def validate_category(cls, value):
        """Lower-case ``category``; fall back to ``"realistic"`` if unrecognized.

        Unknown categories are not rejected. A warning is logged and the request
        continues with the default. NOTE(review): this list mirrors the one
        returned by ``list_models``; keep them in sync.
        """
        normalized = value.lower()
        if normalized in {
            "realistic",
            "3d",
            "cartoonistic",
            "comic",
            "anime",
            "cinematic",
            "fantasy",
            "cyberpunk",
        }:
            return normalized
        logger.warning(f"⚠️ Unknown category '{value}', defaulting to 'realistic'")
        return "realistic"


class ImageData(BaseModel):
    # NOTE(review): not referenced anywhere in the visible code. It is presumably
    # the intended shape of each entry in ImageResponse.data["images"]; confirm
    # against LocalImageGenerator.generate_multiple_images before relying on it.
    name: str  # presumably an image identifier/filename — TODO confirm
    base64: str  # presumably base64-encoded image bytes — TODO confirm


class ImageResponse(BaseModel):
    # Response envelope for POST /generate-image. generate_image fills it with
    # status="success", a human-readable message, and
    # data={"images": [...], "model_used": "sd15"}.
    # `data` is a plain dict, so only its dict-ness is checked, not its contents.
    status: str
    message: str
    data: dict


class HealthResponse(BaseModel):
    # Response envelope for GET /health. `data` carries whatever
    # generator.check_health() returns, which must therefore be a dict.
    status: str
    data: dict


# =====================================================================
# ROUTES
# =====================================================================
@app.get("/", tags=["Info"])
async def root():
    # Static service descriptor: identity, model in use, and a map of public
    # routes. (Comments rather than a docstring: FastAPI publishes endpoint
    # docstrings in the OpenAPI description.)
    # NOTE(review): "version" duplicates FastAPI(version=...); keep in sync.
    routes = {
        "generate": "/generate-image",
        "health": "/health",
        "models": "/models",
        "docs": "/docs",
    }
    return dict(
        service="Professional Image Generation API",
        version="7.0.0",
        model="sd15",
        status="operational",
        endpoints=routes,
    )


@app.post("/generate-image", response_model=ImageResponse, tags=["Generation"])
async def generate_image(request: ImageRequest):
    # Generate `request.image_count` SD15 images and return them wrapped in an
    # ImageResponse envelope.
    #
    # Errors (raised as HTTPException):
    #   503 - the generator is not initialized (startup failed or not finished).
    #   500 - generation raised, or it returned no images.
    #
    # NOTE(review): generate_multiple_images is called synchronously inside this
    # `async def`. It runs on the event-loop thread and blocks every other request
    # (including /health) until it finishes. This does serialize GPU work. Before
    # offloading it to a thread pool, add a lock around the call to keep that
    # property.
    if generator is None:
        raise HTTPException(status_code=503, detail="Generator not initialized. Restart server.")

    logger.info(f"📥 Generating {request.image_count} image(s) with SD15 | {request.width}x{request.height}")
    logger.info(f"   Prompt: {request.prompt[:100]}...")

    try:
        images = generator.generate_multiple_images(
            prompt=request.prompt,
            category=request.category,
            model_name="sd15",
            width=request.width,
            height=request.height,
            image_count=request.image_count,
        )
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"❌ Generation error: {str(e)}")
        # NOTE(review): the raw exception text is exposed to the client here.
        raise HTTPException(status_code=500, detail=f"Error generating image: {str(e)}") from e
    finally:
        # Collect first, so tensors reachable only from dead Python objects are
        # freed. Then return the now-unused cached blocks to the driver.
        gc.collect()
        torch.cuda.empty_cache()

    # Checked outside the `try`. Otherwise this HTTPException would be caught by
    # the generic handler above and re-wrapped as
    # "Error generating image: 500: No images generated.".
    if not images:
        raise HTTPException(status_code=500, detail="No images generated.")

    logger.info(f"✅ Successfully generated {len(images)} image(s) using SD15")
    return {
        "status": "success",
        "message": f"Generated {len(images)} image(s) successfully",
        "data": {"images": images, "model_used": "sd15"},
    }


@app.get("/health", response_model=HealthResponse, tags=["System"])
async def health_check():
    # Report the generator's own health payload.
    # Returns 503 when the generator was never initialized, and 500 (with the
    # error text) when check_health() itself raises.
    if generator is None:
        raise HTTPException(status_code=503, detail="Generator not initialized")
    try:
        report = generator.check_health()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success", "data": report}


@app.get("/models", tags=["System"])
async def list_models():
    # Static capability listing: this server only ever loads SD15.
    # NOTE(review): the category list mirrors ImageRequest.validate_category;
    # keep the two in sync (or hoist them into one shared constant).
    categories = "realistic 3d cartoonistic comic anime cinematic fantasy cyberpunk".split()
    catalog = {
        "available_models": ["sd15"],
        "preloaded_models": ["sd15"],
        "supported_categories": categories,
    }
    return {"status": "success", "data": catalog}


# =====================================================================
# ERROR HANDLER
# =====================================================================
@app.exception_handler(ValueError)
async def value_error_handler(request: Request, exc: ValueError) -> JSONResponse:
    """Translate an uncaught ``ValueError`` into a JSON HTTP 400 response.

    Starlette exception handlers must return a ``Response``. The previous version
    returned a plain dict, which Starlette then tried to call as an ASGI app.
    The handler itself failed, and the client received a generic 500 instead of
    this payload.

    Args:
        request: The incoming request (unused).
        exc: The ``ValueError`` that escaped a route.

    Returns:
        A 400 ``JSONResponse`` with the same body the original handler intended.
    """
    return JSONResponse(
        status_code=400,
        content={"status": "error", "message": str(exc), "type": "validation_error"},
    )
