"""
LOCAL GPU IMAGE GENERATOR (Optimized for SD15)
==============================================
VRAM-efficient Stable Diffusion v1.5 image generator.
"""

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import base64
from io import BytesIO
from datetime import datetime
import gc
from typing import Any, Dict, List, Optional, Tuple

# =====================================================================
# GPU DETECTION
# =====================================================================
def detect_device() -> str:
    if not torch.cuda.is_available():
        print("❌ CUDA not available! Install PyTorch with GPU support.")
        return "cpu"

    print(f"✅ Using GPU: {torch.cuda.get_device_name(0)}")
    total_vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"   Total VRAM: {total_vram:.2f} GB")
    return "cuda"


DEVICE = detect_device()
if DEVICE == "cpu":
    raise RuntimeError("🚨 GPU not detected! Please enable CUDA.")

# =====================================================================
# MODEL CONFIG (SD15 ONLY)
# =====================================================================
MODEL_CONFIGS = {
    "sd15": {
        "repo_id": "runwayml/stable-diffusion-v1-5",
        "dtype": torch.float16,
        "size_multiple": 64,
        "pipeline_class": StableDiffusionPipeline,
    },
}
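
# Any SD 1.5-compatible checkpoint in diffusers format can be registered the
# same way, e.g. a fine-tune from the Hugging Face Hub (the repo id below is
# a hypothetical placeholder — substitute your own):
#
# MODEL_CONFIGS["my-finetune"] = {
#     "repo_id": "your-org/your-sd15-finetune",  # hypothetical repo id
#     "dtype": torch.float16,
#     "size_multiple": 64,
#     "pipeline_class": StableDiffusionPipeline,
# }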

# =====================================================================
# CATEGORY PROMPTS
# =====================================================================
CATEGORY_STYLES = {
    "realistic": {
        "positive": "photorealistic, 8k uhd, professional photography, detailed lighting, ultra sharp",
        "negative": "cartoon, painting, sketch, low quality, blurry, deformed",
    },
    "anime": {
        "positive": "anime style, studio quality, vibrant colors, smooth shading",
        "negative": "realistic photo, ugly, deformed, blurry",
    },
    "3d": {
        "positive": "3d render, octane render, unreal engine, ray tracing",
        "negative": "2d, cartoon, low quality",
    },
    "cartoonistic": {
        "positive": "cartoon, colorful, animated style, clean outlines",
        "negative": "realistic, photo, blurry",
    },
    "comic": {
        "positive": "comic book style, inking, halftone, bold outlines",
        "negative": "photo, realistic, blurry",
    },
    "cinematic": {
        "positive": "cinematic lighting, professional composition, dramatic depth of field",
        "negative": "flat, low quality, dull",
    },
    "fantasy": {
        "positive": "fantasy concept art, magic, epic, detailed, trending on artstation",
        "negative": "modern, realistic, low quality",
    },
    "cyberpunk": {
        "positive": "cyberpunk, neon lights, futuristic, sci-fi, detailed cityscape",
        "negative": "fantasy, medieval, low quality",
    },
}
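
# Unknown category names fall back to the "realistic" style in
# _enhance_prompt() below.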

# =====================================================================
# MAIN GENERATOR CLASS
# =====================================================================
class LocalImageGenerator:
    def __init__(self, preload_models: Optional[List[str]] = None):
        self.pipelines: Dict[str, Any] = {}
        self.preload_models = preload_models or ["sd15"]
        print("\n🚀 Initializing SD15 image generator...")
        self._preload_all_models()
        print("✅ SD15 ready for generation!\n")

    def _preload_all_models(self):
        for model_name in self.preload_models:
            self._load_model_to_gpu(model_name)
        self._print_vram_status()

    def _load_model_to_gpu(self, model_name: str) -> bool:
        if model_name in self.pipelines:
            print(f"ℹ️  {model_name} already loaded")
            return True

        cfg = MODEL_CONFIGS.get(model_name)
        if cfg is None:
            print(f"❌ Unknown model '{model_name}' (expected one of {list(MODEL_CONFIGS)})")
            return False
        print("⏳ Loading SD15 model...")

        try:
            pipeline = cfg["pipeline_class"].from_pretrained(
                cfg["repo_id"],
                torch_dtype=cfg["dtype"],
                use_safetensors=True,
            )
            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
            pipeline = pipeline.to(DEVICE)
            pipeline.enable_attention_slicing()

            try:
                pipeline.enable_xformers_memory_efficient_attention()
                print("   ✅ xformers enabled")
            except Exception:
                print("   ℹ️ xformers not available")

            self.pipelines[model_name] = pipeline
            print("✅ SD15 loaded to GPU")
            return True
        except Exception as e:
            print(f"❌ Failed to load model: {e}")
            return False

    def _print_vram_status(self):
        if torch.cuda.is_available():
            total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            reserved = torch.cuda.memory_reserved(0) / 1024**3
            allocated = torch.cuda.memory_allocated(0) / 1024**3
            print(f"\n📊 VRAM: Allocated={allocated:.2f} GB | Reserved={reserved:.2f} GB | Total={total:.2f} GB")

    def _enhance_prompt(self, prompt: str, category: str) -> Tuple[str, str]:
        style = CATEGORY_STYLES.get(category.lower(), CATEGORY_STYLES["realistic"])
        return f"{prompt}, {style['positive']}", style["negative"]

    def _adjust_dimensions(self, width: int, height: int, model_name: str) -> Tuple[int, int]:
        multiple = MODEL_CONFIGS[model_name]["size_multiple"]
        width = ((width + multiple - 1) // multiple) * multiple
        height = ((height + multiple - 1) // multiple) * multiple
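        # Example: a 500x720 request rounds up to 512x768, since
        # ((500 + 63) // 64) * 64 == 512 and ((720 + 63) // 64) * 64 == 768.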
        return width, height

    def generate_single_image(self, prompt: str, category: str, model_name: str, width: int, height: int) -> Optional[bytes]:
        if model_name not in self.pipelines:
            print(f"❌ Model '{model_name}' not loaded")
            return None

        pipeline = self.pipelines[model_name]
        enhanced, negative = self._enhance_prompt(prompt, category)
        width, height = self._adjust_dimensions(width, height, model_name)

        torch.cuda.empty_cache()
        print(f"\n🎨 Generating {width}x{height} image...")

        try:
            with torch.inference_mode():
                result = pipeline(
                    prompt=enhanced,
                    negative_prompt=negative,
                    num_inference_steps=25,
                    guidance_scale=7.5,
                    width=width,
                    height=height,
                )

            img = result.images[0]
            buf = BytesIO()
            img.save(buf, format="PNG")

            # Drop the pipeline output, collect Python garbage, then release
            # cached CUDA blocks so the VRAM is actually returned.
            del result
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

            print("✅ Generation complete and VRAM cleared.")
            return buf.getvalue()
        except Exception as e:
            print(f"❌ Error: {e}")
            gc.collect()
            torch.cuda.empty_cache()
            return None

    def generate_multiple_images(self, prompt: str, category: str, model_name: str, width: int, height: int, image_count: int) -> List[Dict[str, str]]:
        imgs = []
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        print(f"\n🚀 Starting generation batch of {image_count} images...")
        for i in range(image_count):
            print(f"📸 Generating image {i + 1}/{image_count}")
            img_bytes = self.generate_single_image(prompt, category, model_name, width, height)
            if img_bytes:
                imgs.append({
                    "name": f"{category}_{model_name}_{timestamp}_{i+1:03d}",
                    "base64": base64.b64encode(img_bytes).decode("utf-8"),
                })
        print(f"✅ Batch complete: {len(imgs)}/{image_count} successful")
        return imgs

    def check_health(self) -> Dict[str, Any]:
        gpu = {
            "gpu_name": torch.cuda.get_device_name(0),
            "vram_total_gb": round(torch.cuda.get_device_properties(0).total_memory / 1024**3, 2),
            "vram_allocated_gb": round(torch.cuda.memory_allocated(0) / 1024**3, 2),
            "vram_reserved_gb": round(torch.cuda.memory_reserved(0) / 1024**3, 2),
        }
        return {
            "status": "healthy",
            "device": DEVICE,
            "gpu_available": torch.cuda.is_available(),
            "models_loaded": list(self.pipelines.keys()),
            "pytorch_version": torch.__version__,
            "cuda_version": torch.version.cuda,
            **gpu,
        }

    def unload_model(self, model_name: str) -> bool:
        if model_name not in self.pipelines:
            print(f"⚠️ Model '{model_name}' not loaded")
            return False
        print(f"🧹 Unloading {model_name.upper()} from GPU...")
        del self.pipelines[model_name]
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print(f"✅ {model_name.upper()} unloaded and VRAM cleared.")
        self._print_vram_status()
        return True
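
# =====================================================================
# USAGE EXAMPLE
# =====================================================================
# Minimal usage sketch. The prompt, category, and image count below are
# illustrative values, not part of the generator's API.
if __name__ == "__main__":
    generator = LocalImageGenerator()
    print(generator.check_health())

    results = generator.generate_multiple_images(
        prompt="a lighthouse at dusk",  # illustrative prompt
        category="cinematic",           # any key in CATEGORY_STYLES
        model_name="sd15",
        width=512,
        height=512,
        image_count=2,
    )
    for item in results:
        with open(f"{item['name']}.png", "wb") as fh:
            fh.write(base64.b64decode(item["base64"]))

    generator.unload_model("sd15")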
