# roleplay_fast_api.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
from datetime import datetime
import os
import base64
import requests
from dotenv import load_dotenv

# Image processing
from PIL import Image
import io

# Load variables from a local .env file (python-dotenv) into os.environ before
# the project services below are imported and constructed, so settings such as
# GROQ_API_KEY (read via os.getenv in health_check, and presumably by the
# services themselves) are available to them.
load_dotenv()

# ----- Project services and models -----
from services.scenario_generator import ScenarioGenerator
from services.roleplay_engine import RoleplayEngine
from services.groq_service import GroqService
from services.ollama_service import OllamaService
from services.skill_analyzer import SkillAnalyzer
from models.scenario import RoleplayScenario

# FastAPI application; title/description/version populate the generated
# OpenAPI schema and the interactive /docs page.
app = FastAPI(
    title="AI Roleplay Service",
    description="AI service for roleplay scenario preview and conversations",
    version="2.0.0",
)

# CORS
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive. Starlette's CORSMiddleware answers credentialed requests by
# echoing the caller's Origin rather than "*", so any site can call this API
# with the user's cookies (verify against the installed Starlette version).
# Replace "*" with an explicit origin list before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize services
# Module-level singletons shared by every request. Both LLM backends are built
# up front and handed to each component. The per-request ``groqRoleplay`` flag
# (turned into ``use_groq`` by the endpoints) picks which backend a call uses.
groq_service = GroqService()
ollama_service = OllamaService()
scenario_generator = ScenarioGenerator(groq_service, ollama_service)
roleplay_engine = RoleplayEngine(groq_service, ollama_service)
skill_analyzer = SkillAnalyzer(groq_service, ollama_service)

# ---------- Pydantic Models ----------
class SkillData(BaseModel):
    """One skill the learner practises during the roleplay."""
    skill_id: str  # not read by the code shown here; presumably the caller's skill identifier
    skill_name: str  # collected into the scenario's skills_to_assess list

class RoleplayData(BaseModel):
    """Roleplay configuration sent by the client with every request.

    NOTE: ``learner_role`` and ``company_policies`` are populated from the
    aliased JSON keys "Learner role" and "Constraints/Policies". With the
    default model config, the attribute names themselves are not accepted as
    input keys.
    """
    category: str  # e.g. "sales"; lower-cased by generate_ai_role to pick the AI's role
    objective: str
    learner_role: str = Field(..., alias="Learner role")  # required
    additional_info: str  # free text; scenario background and keyword source for generate_ai_role
    company_policies: str = Field("", alias="Constraints/Policies")  # optional, defaults to ""
    skills_for_roleplay: List[SkillData]
    difficulty_level: str = "Easy"
    isAdmin: int  # forwarded as is_admin to the scenario/conversation/assessment storage helpers
    groqRoleplay: int  # 1 -> use the Groq backend; any other value -> use Ollama

class TokenCount(BaseModel):
    """Input, output and total token counts for one phase, as reported by the LLM service."""
    input: int
    output: int
    total: int

class TokenCounts(BaseModel):
    """Per-phase token usage of the backend that served the request.

    Built by ``get_token_counts_response`` from the service's
    ``get_token_counts()`` dictionary.
    """
    preview: TokenCount
    conversation: TokenCount
    assessment: TokenCount
    service_used: str  # "groq" or "ollama"

class RoleplayRequest(BaseModel):
    """Request body for /roleplay (one conversation turn) and /end_session."""
    client_id: str  # not read by the code shown here
    session_id: str
    roleplay_data: RoleplayData
    query: str  # learner's message; unused when /roleplay (re)starts the session and by /end_session

class RoleplayResponse(BaseModel):
    """Response for /roleplay: the AI's reply (or the scenario's opening line) plus token usage."""
    session_id: str
    response: str
    token_counts: TokenCounts

class CharacterDetails(BaseModel):
    """Description of the AI-played character.

    NOTE(review): not referenced by the endpoints shown in this file; only used
    as a field of PreviewResponse. Verify it is still needed before removing.
    """
    name: str
    personality: str
    goals: str
    background: str
    emotional_state: str

class ScenarioSetup(BaseModel):
    """Typed view of a scenario's setup (context, environment, constraints).

    NOTE(review): the endpoints shown here read ``scenario.scenario_setup`` as a
    plain dict; this model is only used as a field of PreviewResponse.
    """
    context: str
    environment: str
    constraints: str

class PreviewResponse(BaseModel):
    """Detailed scenario preview schema.

    NOTE(review): not used as a response model by the endpoints shown in this
    file (/roleplay_scenario returns ScenarioPreviewResponse). Presumably kept
    for API compatibility; verify before removing.
    """
    scenario_id: str
    category: str
    objective: str
    learner_role: str
    ai_role: str
    skills_to_assess: List[str]
    scenario_setup: ScenarioSetup
    character_details: CharacterDetails
    scenario_intro: str
    conversation_starter: str
    success_criteria: Dict[str, str]
    difficulty_level: str
    background_info: str

class Slide(BaseModel):
    """One preview slide: a heading, body text and an optional list of goals."""
    heading: str
    content: str
    goals: Optional[List[str]] = None  # only set on the "Goals" slide built by the preview endpoint

class ScenarioPreviewResponse(BaseModel):
    """Response for /roleplay_scenario (the pre-roleplay preview)."""
    slides: List[Slide]
    scenario: str  # one-paragraph summary of the scenario
    token_counts: TokenCounts
    # Base64-encoded 1920x1080 JPEG of an illustrative scene, or None when
    # image generation failed.
    roleplay_image_base64: Optional[str] = None

class EndSessionResponse(BaseModel):
    """Response for /end_session."""
    message: str
    session_id: str
    assessment: Optional[Dict[str, Any]] = None  # assessment.to_dict(); None when no assessment was produced
    token_counts: TokenCounts  # snapshot taken before the backend's counters are reset

class EndSessionRequest(BaseModel):
    """Session id plus roleplay configuration.

    NOTE: despite the name, this is the request body of /roleplay_scenario
    (the preview); /end_session itself takes a RoleplayRequest.
    """
    session_id: str
    roleplay_data: RoleplayData

class CleanupRequest(BaseModel):
    """Request body for /cleanup."""
    session_id: str
    is_admin: int  # snake_case here, unlike RoleplayData.isAdmin; passed to the delete_* helpers

class CleanupResponse(BaseModel):
    """Response for /cleanup with the per-artifact deletion results."""
    message: str
    session_id: str
    cleanup_status: Dict[str, bool]  # keys: scenario_deleted, assessment_deleted, conversation_deleted


# ---------- Health ----------
@app.get("/")
async def root():
    """Liveness probe: report the service name and that it is running."""
    return dict(message="AI Roleplay Service", status="running")

@app.get("/health")
async def health_check():
    """Report Groq configuration and Ollama reachability.

    Groq counts as healthy when GROQ_API_KEY is set (no call is made). Ollama
    is probed via ``ollama_service.health_check()``. ``overall`` is "healthy"
    when at least one backend is usable.

    Raises:
        HTTPException(500): the probe itself raised.
    """
    try:
        report = {"timestamp": datetime.now().isoformat()}

        groq_key = os.getenv("GROQ_API_KEY")
        if groq_key:
            report["groq"] = {"status": "healthy"}
        else:
            report["groq"] = {"status": "error", "message": "GROQ_API_KEY not configured"}

        ollama_up = ollama_service.health_check()
        report["ollama"] = {"status": "healthy" if ollama_up else "unavailable"}

        any_backend = bool(groq_key) or bool(ollama_up)
        report["overall"] = "healthy" if any_backend else "error"
        return report
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Health check failed: {str(exc)}")


# ---------- Roleplay endpoints ----------
@app.post("/roleplay", response_model=RoleplayResponse)
async def handle_roleplay(request: RoleplayRequest):
    """Advance a roleplay conversation by one turn.

    The session is (re)started and the scenario's conversation starter is
    returned in two cases:
    - no scenario is stored for the session yet (one is created from
      ``request.roleplay_data``);
    - the stored session has no conversation history.

    Otherwise ``request.query`` is passed to the engine as the learner's
    message and the AI's reply is returned.

    Raises:
        HTTPException(500): scenario creation or reply generation failed, or
            any unexpected error occurred (stack trace printed to stderr).
    """
    try:
        session_id = request.session_id
        is_admin = request.roleplay_data.isAdmin
        use_groq = request.roleplay_data.groqRoleplay == 1

        scenario = load_scenario_for_session(session_id, is_admin)
        if scenario:
            needs_start = roleplay_engine.get_conversation_history(session_id, is_admin) is None
        else:
            scenario = create_scenario_from_request(request, use_groq)
            if not scenario:
                raise HTTPException(status_code=500, detail="Failed to create scenario")
            needs_start = True

        if needs_start:
            roleplay_engine.start_session(scenario, is_admin, force_restart=True)
            reply = scenario.conversation_starter
        else:
            reply = roleplay_engine.add_learner_response(
                session_id, scenario, request.query, is_admin, use_groq
            )
            if not reply:
                raise HTTPException(status_code=500, detail="Failed to generate AI response")

        return RoleplayResponse(
            session_id=session_id,
            response=reply,
            token_counts=get_token_counts_response(use_groq),
        )

    except HTTPException:
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")


# ---------- Preview ----------
@app.post("/roleplay_scenario", response_model=ScenarioPreviewResponse)
async def get_roleplay_scenario_preview(request: EndSessionRequest):
    """Return the pre-roleplay preview for a session.

    Reuses the scenario already stored for ``request.session_id`` when there is
    one; otherwise creates it from ``request.roleplay_data``. The response
    contains:
    - a one-paragraph summary;
    - the Context / Constraints slides plus a "Goals" slide;
    - the current token counts;
    - a best-effort illustrative image (``None`` when image generation fails).

    Raises:
        HTTPException(500): the scenario could not be created, or any other
            unexpected error occurred (stack trace printed to stderr).
    """
    # Local import keeps this change self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    try:
        use_groq = request.roleplay_data.groqRoleplay == 1
        existing_scenario = load_scenario_for_session(request.session_id, request.roleplay_data.isAdmin)
        scenario = existing_scenario or create_scenario_from_request(request, use_groq)
        if not scenario:
            raise HTTPException(status_code=500, detail="Failed to create scenario for preview")

        scenario_text = build_professional_scenario_text(scenario)
        goals_list = build_short_goals(
            learner_role=scenario.learner_role,
            objective=scenario.objective,
            skills=scenario.skills_to_assess or []
        )
        base_slides = format_scenario_as_slides(scenario)
        slides_as_models: List[Slide] = [Slide(heading=s["heading"], content=s["content"]) for s in base_slides]
        slides_as_models.append(
            Slide(heading="Goals", content="Focus on these outcomes during the roleplay.", goals=goals_list)
        )
        token_counts = get_token_counts_response(use_groq)

        # Best-effort 1920x1080 image (Base64 JPEG, never written to disk): any
        # failure is logged and the preview is returned without an image.
        roleplay_image_b64 = None
        try:
            poll_prompt = build_pollinations_ultra_hd_prompt(
                category=request.roleplay_data.category,
                learner_role=request.roleplay_data.learner_role,
                objective=request.roleplay_data.objective,
                additional_info=request.roleplay_data.additional_info,
                company_policies=request.roleplay_data.company_policies
            )
            # The download is a blocking ``requests`` call (timeout=120 s).
            # Calling it directly inside this ``async def`` would freeze the event
            # loop, and with it every other request, until the image arrives.
            # Run it in the threadpool instead.
            roleplay_image_b64 = await run_in_threadpool(pollinations_generate_ultra_hd_base64, poll_prompt)
        except Exception as e:
            roleplay_image_b64 = None
            print(f"WARNING: Pollinations ultra HD image generation failed: {e}")

        return ScenarioPreviewResponse(
            slides=slides_as_models,
            scenario=scenario_text,
            token_counts=token_counts,
            roleplay_image_base64=roleplay_image_b64
        )

    except HTTPException:
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Error generating preview: {str(e)}")


# ---------- End session ----------
@app.post("/end_session", response_model=EndSessionResponse)
async def end_session(request: RoleplayRequest):
    """Finish a roleplay session, assess it, and return the final token counts.

    Steps:
    1. Load the session's scenario and collect the conversation turns.
    2. End the session in the engine.
    3. Run the skill assessment.
    4. Snapshot, then reset, the selected backend's token counters.
    5. Delete the stored conversation when an assessment was produced.

    Raises:
        HTTPException(404): no scenario, or no conversation history, for the session.
        HTTPException(500): the engine failed to end the session, or any other
            unexpected error occurred (stack trace printed to stderr, as in the
            other endpoints).
    """
    try:
        use_groq = request.roleplay_data.groqRoleplay == 1
        scenario = load_scenario_for_session(request.session_id, request.roleplay_data.isAdmin)
        if not scenario:
            raise HTTPException(status_code=404, detail="Session/Scenario not found")

        conversation_turns = roleplay_engine.get_conversation_turns_for_assessment(
            request.session_id, request.roleplay_data.isAdmin
        )
        if not conversation_turns:
            raise HTTPException(status_code=404, detail="No conversation history found")

        success = roleplay_engine.end_session(request.session_id, request.roleplay_data.isAdmin)
        if not success:
            raise HTTPException(status_code=500, detail="Failed to end session")

        assessment = skill_analyzer.analyze_session(
            request.session_id, scenario, conversation_turns, request.roleplay_data.isAdmin, use_groq
        )

        # Snapshot first: get_token_counts_response builds new model objects,
        # so resetting the service counters afterwards does not alter it.
        final_token_counts = get_token_counts_response(use_groq)
        ai_service = groq_service if use_groq else ollama_service
        ai_service.reset_token_counts()

        if assessment:
            conversation_deleted = roleplay_engine.json_handler.delete_conversation(
                request.session_id, request.roleplay_data.isAdmin
            )
            print(f"DEBUG: Conversation deleted: {conversation_deleted}")

        response_data: Dict[str, Any] = {
            "message": "Session ended successfully",
            "session_id": request.session_id,
            "token_counts": final_token_counts,
        }
        if assessment:
            response_data["assessment"] = assessment.to_dict()

        return EndSessionResponse(**response_data)

    except HTTPException:
        raise
    except Exception as e:
        # Print the stack trace like /roleplay and /roleplay_scenario do;
        # otherwise only str(e) would reach the client and the cause is lost.
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Error ending session: {str(e)}")


# ---------- Cleanup (scenario/assessment/conversation JSON only) ----------
@app.post("/cleanup", response_model=CleanupResponse)
async def cleanup_session_data(request: CleanupRequest):
    """Delete the stored scenario, assessment and conversation for a session.

    Each deletion is attempted independently. A failure is printed and leaves
    that entry ``False`` in ``cleanup_status`` without stopping the others.
    The success message is returned regardless of the individual results.
    """
    try:
        sid, admin = request.session_id, request.is_admin
        # (status key, error-message prefix, deferred delete call)
        deleters = (
            ("scenario_deleted", "Error deleting scenario",
             lambda: scenario_generator.json_handler.delete_scenario(sid, admin)),
            ("assessment_deleted", "Error deleting assessment",
             lambda: skill_analyzer.json_handler.delete_assessment(sid, admin)),
            ("conversation_deleted", "Error deleting conversation",
             lambda: roleplay_engine.json_handler.delete_conversation(sid, admin)),
        )
        cleanup_status = {key: False for key, _, _ in deleters}
        for key, error_prefix, delete in deleters:
            try:
                cleanup_status[key] = delete()
            except Exception as exc:
                print(f"{error_prefix}: {exc}")
        return CleanupResponse(
            message="Session data cleaned up successfully",
            session_id=request.session_id,
            cleanup_status=cleanup_status,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error cleaning up session data: {str(e)}")


# ---------- Helpers ----------
def create_scenario_from_request(request: RoleplayRequest, use_groq: bool = True) -> Optional[RoleplayScenario]:
    """Generate a new scenario from the request's roleplay configuration.

    Only ``request.session_id`` and ``request.roleplay_data`` are read, so an
    ``EndSessionRequest`` works too (the preview endpoint passes one).
    Any error is printed and swallowed; ``None`` is returned in that case.
    """
    try:
        cfg = request.roleplay_data
        skill_names = [skill.skill_name for skill in cfg.skills_for_roleplay]
        scenario_details = dict(
            background=cfg.additional_info,
            constraints=cfg.company_policies,
            environment="Roleplay conversation",
            difficulty_level=cfg.difficulty_level,
        )
        counterpart_role = generate_ai_role(cfg.category, cfg.additional_info)
        return scenario_generator.create_scenario(
            session_id=request.session_id,
            category=cfg.category,
            objective=cfg.objective,
            details=scenario_details,
            ai_role=counterpart_role,
            learner_role=cfg.learner_role,
            skills_to_assess=skill_names,
            is_admin=cfg.isAdmin,
            use_groq=use_groq,
        )
    except Exception as exc:
        print(f"Error creating scenario: {exc}")
        return None

def generate_ai_role(category: str, additional_info: str) -> str:
    """Pick the role the AI plays from the category and free-text details.

    The base role comes from the (lower-cased) category, falling back to
    "Conversation Partner". A single prefix is then added from the first
    matching keyword group in ``additional_info``, checked in this order:
    enterprise, budget/price, frustrated/complaint.
    """
    roles_by_category = {
        "sales": "Potential Customer",
        "customer service": "Customer with Issue",
        "leadership": "Team Member",
        "negotiation": "Negotiation Partner",
        "technical support": "User with Technical Problem",
    }
    # Order matters: the first matching keyword group wins.
    prefix_rules = (
        (("enterprise",), "Enterprise"),
        (("budget", "price"), "Budget-Conscious"),
        (("frustrated", "complaint"), "Frustrated"),
    )

    role = roles_by_category.get(category.lower(), "Conversation Partner")
    details = additional_info.lower()
    for keywords, prefix in prefix_rules:
        if any(word in details for word in keywords):
            return f"{prefix} {role}"
    return role

def load_scenario_for_session(session_id: str, is_admin: int) -> Optional[RoleplayScenario]:
    """Return the stored scenario for ``session_id``, or ``None`` if loading raised.

    Whatever ``scenario_generator.load_scenario`` returns is passed through
    unchanged; an exception is printed and turned into ``None``.
    """
    scenario = None
    try:
        scenario = scenario_generator.load_scenario(session_id, is_admin)
    except Exception as exc:
        print(f"Error loading scenario for session {session_id}: {exc}")
    return scenario

def format_scenario_as_slides(scenario: RoleplayScenario) -> List[Dict[str, str]]:
    """Turn the scenario setup into preview slides.

    Produces a "Context" slide when the setup has non-blank context. Produces a
    "Constraints/Policies" slide when constraints are non-blank and not a
    placeholder ("not specified" / "none", case-insensitive). Either or both
    may be absent.

    If reading the scenario fails, the error is printed and a single generic
    "Context" slide is returned instead.
    """
    try:
        setup = scenario.scenario_setup
        context_text = (setup.get("context") or "").strip()
        constraints_text = (setup.get("constraints") or "").strip()
        placeholder_values = {"", "not specified", "none"}
        candidates = (
            ("Context", context_text, True),
            ("Constraints/Policies", constraints_text,
             constraints_text.lower() not in placeholder_values),
        )
        return [
            {"heading": heading, "content": body}
            for heading, body, allowed in candidates
            if body and allowed
        ]
    except Exception as exc:
        print(f"Error formatting scenario as slides: {exc}")
        return [{"heading": "Context", "content": f"This is a {scenario.category} roleplay scenario where you will practice your skills."}]

def get_token_counts_response(use_groq: bool) -> TokenCounts:
    """Snapshot the selected backend's token usage as a ``TokenCounts`` model.

    Reads ``get_token_counts()`` from the Groq service when ``use_groq`` is
    truthy, otherwise from the Ollama service. The returned dict must contain
    "preview", "conversation" and "assessment" entries whose keys match
    ``TokenCount``'s fields.
    """
    if use_groq:
        backend, backend_name = groq_service, "groq"
    else:
        backend, backend_name = ollama_service, "ollama"
    usage = backend.get_token_counts()
    phases = {
        phase: TokenCount(**usage[phase])
        for phase in ("preview", "conversation", "assessment")
    }
    return TokenCounts(**phases, service_used=backend_name)

def build_professional_scenario_text(scn: RoleplayScenario) -> str:
    category = scn.category
    learner = scn.learner_role
    ai_role = scn.ai_role
    objective = scn.objective
    context = scn.scenario_setup.get("context") or ""
    parts = [
        f"In this {category.lower()} roleplay, you act as a {learner} collaborating with an {ai_role.lower()} to achieve the objective: {objective.strip()}."
    ]
    if context:
        # Shorten context to first sentence or limit to 100 characters
        short_context = context.strip().split('.')[0] + '.' if '.' in context else (context.strip()[:100] + '...' if len(context) > 100 else context.strip())
        parts.append(short_context)
    return " ".join(parts)

def build_short_goals(
    learner_role: str,
    objective: str,
    skills: List[str],
    max_goals: int = 4,
    max_custom: int = 2,
) -> List[str]:
    """Return up to ``max_goals`` short goals for the preview's "Goals" slide.

    Four generic goals form the base list. Keyword matches in the learner role,
    objective and skill names add up to ``max_custom`` tailored goals. These
    replace the trailing generic goals so the total stays within
    ``max_goals``. Previously the tailored goals were computed but never added,
    because the base list was already full.
    """
    base = [
        "Clarify the user's intent and missing details",
        "Structure effective prompts with clear instructions and examples",
        "Validate responses for accuracy, tone, and relevance",
        "Iterate based on feedback to improve the next output",
    ]
    # (keywords, goal): a goal applies when any keyword occurs in the text.
    keyword_goals = (
        (("prompt",), "Apply prompt patterns (role, task, context, constraints)"),
        (("customer", "support"), "Use empathetic, concise language for end users"),
        (("technical", "engineering"), "Ground outputs in accurate, testable details"),
        (("sales", "negotiation"), "Surface value, objections, and next steps"),
    )
    text = " ".join([learner_role or "", objective or "", *(skills or [])]).lower()

    custom = [
        goal for keywords, goal in keyword_goals if any(k in text for k in keywords)
    ][:max(max_custom, 0)]
    keep_base = max(max_goals - len(custom), 0)
    return (base[:keep_base] + custom)[:max(max_goals, 0)]


def build_pollinations_ultra_hd_prompt(
    category: str,
    learner_role: str,
    objective: str,
    additional_info: str,
    company_policies: str
) -> str:
    """Build a short, comma-separated image prompt for Pollinations.

    Combines the category and learner role with fixed quality tags, and adds
    the objective (trimmed to 100 chars) when it is longer than 10 characters
    after stripping. ``additional_info`` and ``company_policies`` are accepted
    for interface compatibility but not used. Short, focused prompts tend to
    work better.
    """
    max_elements = 12
    quality_tags = (
        "photorealistic",
        "8K ultra HD",
        "sharp focus",
        "professional lighting",
        "cinematic composition",
        "high detail",
    )
    pieces = [
        f"professional {category} scene",
        f"{learner_role} in modern office environment",
        *quality_tags,
    ]

    focus = objective.strip() if objective else ""
    if len(focus) > 10:
        pieces.append(focus[:100])

    return ", ".join(pieces[:max_elements])


def pollinations_generate_ultra_hd_base64(prompt: str) -> Optional[str]:
    """
    Enhanced version with multiple strategies for HD image generation.
    Uses higher base resolution and better upscaling techniques.
    """
    import urllib.parse as up
    from PIL import Image, ImageEnhance, ImageFilter

    TARGET_WIDTH = 1920
    TARGET_HEIGHT = 1080

    # Strategy 1: Request LARGER than target, then downscale for better quality
    REQUEST_WIDTH = 2560   # Request higher resolution
    REQUEST_HEIGHT = 1440  # 16:9 ratio maintained

    encoded_prompt = up.quote(prompt, safe="")

    # Enhanced URL parameters
    url = (
        f"https://image.pollinations.ai/prompt/{encoded_prompt}"
        f"?width={REQUEST_WIDTH}"
        f"&height={REQUEST_HEIGHT}"
        f"?nologo=true"
        f"&private=true"
        f"&enhance=true"
        f"&model=flux"  # Flux model generally produces better quality
        f"&seed={abs(hash(prompt)) % 1000000}"
    )

    headers = {
        "Accept": "image/jpeg,image/png,image/webp,image/*;q=0.9,*/*;q=0.8",
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
        "Accept-Encoding": "gzip, deflate, br",
        "Cache-Control": "no-cache"
    }

    try:
        print(f"INFO: Requesting HD image from Pollinations ({REQUEST_WIDTH}x{REQUEST_HEIGHT})...")
        resp = requests.get(url, headers=headers, timeout=120, stream=True)

        if resp.status_code != 200:
            print(f"ERROR: Pollinations returned status {resp.status_code}")
            return None

        content = resp.content
        if not content or len(content) < 1000:  # Sanity check
            print("ERROR: Received invalid/empty image data")
            return None

        # Load and process image
        img = Image.open(io.BytesIO(content))
        img = img.convert("RGB")

        orig_w, orig_h = img.size
        print(f"INFO: Received image: {orig_w}x{orig_h} ({len(content)} bytes)")

        # Apply image enhancement for better quality
        img = enhance_image_quality(img)

        # Resize to target dimensions using high-quality resampling
        img_final = resize_and_crop_hq(img, TARGET_WIDTH, TARGET_HEIGHT)

        # Apply slight sharpening after resize
        img_final = img_final.filter(ImageFilter.UnsharpMask(radius=1, percent=120, threshold=3))

        # Encode with maximum quality
        output_bytes = io.BytesIO()
        img_final.save(
            output_bytes,
            format="JPEG",
            quality=98,  # Higher quality
            optimize=True,
            progressive=True,  # Progressive JPEG for better quality perception
            subsampling=0  # No chroma subsampling for maximum quality
        )

        b64_result = base64.b64encode(output_bytes.getvalue()).decode("utf-8")
        final_size = len(output_bytes.getvalue())

        print(f"SUCCESS: Generated {TARGET_WIDTH}x{TARGET_HEIGHT} HD image ({final_size} bytes)")
        return b64_result

    except requests.Timeout:
        print("ERROR: Timeout while fetching image")
        return None
    except Exception as e:
        print(f"ERROR: Image generation failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return None


def enhance_image_quality(
    img: Image.Image,
    contrast: float = 1.1,
    sharpness: float = 1.15,
    color: float = 1.05,
) -> Image.Image:
    """Apply mild contrast, sharpness and colour boosts, in that order.

    Each factor is passed to the matching ``PIL.ImageEnhance`` enhancer
    (1.0 leaves that property unchanged). Returns a new image.
    """
    # Imported here because only ``Image`` is imported at module level. The
    # ``ImageEnhance`` import inside pollinations_generate_ultra_hd_base64 is
    # local to that function, so without this line the call raised NameError.
    # Callers swallowed that error, so every image generation returned None.
    from PIL import ImageEnhance

    for enhancer_cls, factor in (
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Sharpness, sharpness),
        (ImageEnhance.Color, color),
    ):
        img = enhancer_cls(img).enhance(factor)
    return img


def resize_and_crop_hq(img: Image.Image, target_w: int, target_h: int) -> Image.Image:
    """Resize ``img`` to cover ``target_w`` x ``target_h``, then centre-crop to exactly that.

    The aspect ratio is preserved. The image is scaled (Lanczos) just enough
    that both sides reach the target; the overflow on the longer axis is then
    trimmed equally from both edges.
    """
    orig_w, orig_h = img.size
    # "Cover" scale: the larger axis ratio guarantees both sides reach the target.
    scale = max(target_w / orig_w, target_h / orig_h)
    # round() + max() stop float error from leaving a side one pixel short.
    # With int() truncation that could happen, and crop() would then extend
    # past the image and pad the edge.
    new_w = max(target_w, round(orig_w * scale))
    new_h = max(target_h, round(orig_h * scale))

    img_resized = img.resize((new_w, new_h), resample=Image.LANCZOS)

    # Center crop to exact target dimensions
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    return img_resized.crop((left, top, left + target_w, top + target_h))


# Fallback chain: try each image strategy in turn and keep the first success.
def generate_hd_image_with_fallback(prompt: str) -> Optional[str]:
    """Return a Base64 JPEG for ``prompt`` from the first strategy that succeeds.

    Strategies, in order:
    1. The primary Pollinations request (pollinations_generate_ultra_hd_base64).
    2. A Pollinations retry with the "turbo" model.

    Returns ``None`` when both fail.
    NOTE(review): not called by the endpoints in the code shown here.
    """
    primary = pollinations_generate_ultra_hd_base64(prompt)
    if primary:
        return primary

    print("INFO: Pollinations failed, trying alternative approach...")
    # Further providers (e.g. Stability AI, DALL-E) could be chained after this one.
    fallback = try_alternative_pollinations(prompt, model="turbo")
    return fallback if fallback else None


def try_alternative_pollinations(prompt: str, model: str = "turbo") -> Optional[str]:
    """Fetch a 1920x1080 Pollinations image with ``model`` and return it as Base64 JPEG.

    Used as the fallback strategy. The image is enhanced, centre-cropped to
    1920x1080 and saved as a quality-98 JPEG.

    Returns ``None`` on a non-200 status, a body of 1000 bytes or less, or any
    error (printed, not raised).
    """
    import urllib.parse as up

    query = "&".join(("width=1920", "height=1080", f"model={model}", "enhance=true", "nologo=true"))
    url = "https://image.pollinations.ai/prompt/" + up.quote(prompt, safe="") + "?" + query

    try:
        response = requests.get(url, timeout=120)
        if response.status_code != 200 or len(response.content) <= 1000:
            return None
        picture = Image.open(io.BytesIO(response.content)).convert("RGB")
        picture = resize_and_crop_hq(enhance_image_quality(picture), 1920, 1080)
        buffer = io.BytesIO()
        picture.save(buffer, format="JPEG", quality=98, optimize=True)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as exc:
        print(f"Alternative strategy failed: {exc}")
        return None
