import asyncio
import json
import logging
import os
import subprocess
import sys
import tempfile
from typing import Any, Dict, List, Optional

from fastapi import BackgroundTasks, FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, validator

# -------------------------------------------------------------------
# Logging setup
# -------------------------------------------------------------------
# Configure root logging at INFO so this module's logger.info(...) calls are shown.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name="recommendation_api")

# -------------------------------------------------------------------
# FastAPI app
# -------------------------------------------------------------------
app = FastAPI(
    version="1.0.2",
    title="Recommendation API",
    description="API for course recommendations using Ollama + GPU",
)

# Fully permissive CORS: every origin, method and header is accepted.
# NOTE(review): browsers refuse a literal "*" origin on credentialed requests,
# and Starlette's CORSMiddleware works around that by echoing the caller's
# Origin -- i.e. any site can make credentialed calls. Confirm this is
# intended, or restrict allow_origins to known front-ends.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# -------------------------------------------------------------------
# Models
# -------------------------------------------------------------------

class RecSkill(BaseModel):
    """A skill entry, attached to job profiles, users and courses in this API."""

    skill_id: str
    # NOTE(review): the set of allowed values is not defined here -- confirm.
    skill_type: str
    skill_name: str


class RecJobProfile(BaseModel):
    """A user's job profile: its name and the skills attached to it."""

    # Both fields are optional; omitted values default to empty.
    job_profile_name: Optional[str] = ""
    job_profile_skills: Optional[List[RecSkill]] = []


class RecUserInfo(BaseModel):
    """A user's profile as forwarded to the recommendation processor.

    Field names (including the camelCase ones) are part of the JSON request
    schema, so they must not be renamed.
    """

    user_name: str
    managerId: Optional[str] = ""
    designation: Optional[str] = ""
    jobProfile: Optional[RecJobProfile] = None
    # NOTE(review): presumably lists of course IDs (matching
    # RecCourse.courseId) -- confirm with the caller.
    assignedCourses: Optional[List[str]] = []
    completedCourses: Optional[List[str]] = []
    skills: Optional[List[RecSkill]] = []


class RecCourse(BaseModel):
    """A course from a client's catalogue (see ``client_all_courses_data``)."""

    courseId: str
    name: str
    # Optional text fields; omitted values default to empty.
    short_description: Optional[str] = ""
    description: Optional[str] = ""
    skills: Optional[List[RecSkill]] = []


class UsersUnderManager(BaseModel):
    """Users supplied in a request's ``users_under_manager_data`` field."""

    # NOTE(review): keys are presumably user IDs -- confirm with the caller.
    users: Dict[str, RecUserInfo] = {}


class RecommendationRequest(BaseModel):
    """Body of ``POST /recommendations``.

    ``limit`` caps how many recommendations the endpoint returns. It is
    coerced to ``int`` before validation and must be at least 1.
    """

    client_id: str
    # NOTE(review): keys presumably identify users -- confirm with the caller.
    user_data: Dict[str, RecUserInfo]
    users_under_manager_data: Optional[UsersUnderManager] = None
    # NOTE(review): meaning of the outer keys is not visible here -- confirm.
    client_all_courses_data: Dict[str, List[RecCourse]]
    limit: int = Field(..., ge=1)

    @validator("limit", pre=True)
    def validate_limit(cls, v):
        # pre=True: this runs before pydantic's own int handling, so it
        # decides how strings/floats are converted; Field(ge=1) still applies.
        try:
            as_int = int(v)
        except Exception:
            raise ValueError("limit must be an integer")
        if as_int >= 1:
            return as_int
        raise ValueError("limit must be >= 1")


class SingleRecommendation(BaseModel):
    """One recommended course in the endpoint's response."""

    courseId: str
    courseName: str
    # NOTE(review): scale/range of ``score`` is decided by the processor and
    # is not defined here -- confirm.
    score: float
    # String fields taken from the processor's output.
    reason: str
    scenario_relevance: str


class RecommendationsResponse(BaseModel):
    """Response body of ``POST /recommendations``."""

    # At most ``limit`` items: the endpoint truncates the processor's list.
    recommendations: List[SingleRecommendation]
    # Token count reported by the processor (0 when it reports none).
    token_count: int


# -------------------------------------------------------------------
# Utils
# -------------------------------------------------------------------

def cleanup_temp_file(file_path: str) -> None:
    """Best-effort removal of a temporary file.

    A file that no longer exists is not an error. Any other failure is
    logged as a warning (with the underlying exception) instead of raised,
    so cleanup never interrupts the caller.

    Args:
        file_path: Path of the file to delete.
    """
    try:
        # EAFP: checking os.path.exists() first is racy and costs an extra stat.
        os.unlink(file_path)
    except FileNotFoundError:
        # Already gone -- nothing to clean up.
        pass
    except Exception:
        # Deliberately broad: this is a best-effort cleanup boundary.
        logger.warning("Failed to cleanup temp file: %s", file_path, exc_info=True)


# -------------------------------------------------------------------
# Endpoint
# -------------------------------------------------------------------

# Processor script run once per request. As before, the relative path is
# resolved against the server's current working directory.
PROCESSOR_SCRIPT = "recommendations_processor.py"

# Optional upper bound, in seconds, on one processor run. None (the default)
# means no limit. When set, an overrunning processor is killed and the client
# receives a 504.
PROCESSOR_TIMEOUT_SECONDS: Optional[float] = None


def _run_processor(payload: Dict[str, Any]) -> subprocess.CompletedProcess:
    """Run the recommendation processor on *payload* and return its result.

    The payload is written to a temporary UTF-8 JSON file, and that file's
    path is passed as the processor's only argument. The file is always
    removed once the processor has exited or failed to start. This is done
    here rather than via BackgroundTasks, because FastAPI does not run
    background tasks when the handler raises.

    Blocks until the child exits, so call it from a worker thread, not
    directly on the event loop.

    Args:
        payload: JSON-serialisable request data.

    Returns:
        The finished process, with ``stdout``/``stderr`` captured as text.

    Raises:
        subprocess.TimeoutExpired: if ``PROCESSOR_TIMEOUT_SECONDS`` is set
            and elapses (``subprocess.run`` kills the child first).
        OSError: if the temp file cannot be written or the interpreter
            cannot be launched.
    """
    temp_path: Optional[str] = None
    try:
        # Explicit UTF-8: ensure_ascii=False writes raw non-ASCII characters,
        # which the locale's default encoding may not represent.
        # NOTE(review): the processor should open the file as UTF-8 -- confirm.
        with tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".json",
            encoding="utf-8",
            delete=False,
        ) as tmp:
            temp_path = tmp.name
            json.dump(payload, tmp, ensure_ascii=False)
        # Leaving the ``with`` closes (and so flushes) the file before the
        # child opens it. An fsync is not needed for the child to see the data.
        return subprocess.run(
            # Use the interpreter running this API (same virtualenv) rather
            # than whatever "python3" happens to resolve to on PATH.
            [sys.executable or "python3", PROCESSOR_SCRIPT, temp_path],
            capture_output=True,
            text=True,
            timeout=PROCESSOR_TIMEOUT_SECONDS,
        )
    finally:
        if temp_path is not None:
            cleanup_temp_file(temp_path)


def _internal_error(detail: str) -> HTTPException:
    """Build the 500 HTTPException this endpoint raises on failures."""
    return HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=detail,
    )


@app.post("/recommendations", response_model=RecommendationsResponse)
async def get_recommendations(
    request: RecommendationRequest,
    background_tasks: BackgroundTasks
):
    """Return up to ``request.limit`` course recommendations.

    The validated request is handed to the external recommendation
    processor. The processor's stdout must be a JSON object with a
    ``recommendations`` list and a ``token_count``. The list is truncated
    to ``limit`` regardless of how many items the processor produced.

    ``background_tasks`` is no longer used, because temp-file cleanup now
    happens in ``_run_processor``. It is kept in the signature so the
    endpoint's interface is unchanged.

    Raises:
        HTTPException: 500 in any of these cases:
            - the processor exits non-zero;
            - it prints nothing;
            - it prints invalid JSON or JSON of an unexpected shape;
            - any other error occurs.
            504 if ``PROCESSOR_TIMEOUT_SECONDS`` is exceeded.
    """
    try:
        payload = request.dict()
        limit = payload["limit"]

        logger.info("Request received | client_id=%s | limit=%d",
                    payload.get("client_id"), limit)

        # subprocess.run blocks; run it in the default thread pool so the
        # event loop keeps serving other requests while the processor works.
        loop = asyncio.get_running_loop()
        try:
            result = await loop.run_in_executor(None, _run_processor, payload)
        except subprocess.TimeoutExpired as exc:
            logger.error("Recommendation processor timed out after %s s",
                         exc.timeout)
            raise HTTPException(
                status_code=status.HTTP_504_GATEWAY_TIMEOUT,
                detail="Recommendation processor timed out",
            ) from exc

        if result.returncode != 0:
            logger.error("Recommendation processor exited with code %d: %s",
                         result.returncode, result.stderr)
            raise _internal_error("Recommendation processor failed")

        output = result.stdout.strip()
        if not output:
            raise _internal_error(
                "Empty response from recommendation processor"
            )

        try:
            data = json.loads(output)
        except json.JSONDecodeError as exc:
            logger.error(
                "Invalid JSON from recommendation processor: %s | stdout=%.500s",
                exc, output,
            )
            raise _internal_error(
                "Invalid JSON from recommendation processor"
            ) from exc

        if not isinstance(data, dict):
            raise _internal_error(
                "Unexpected output format from recommendation processor"
            )

        # Explicit nulls are treated the same as missing keys.
        recommendations = data.get("recommendations") or []
        token_count = data.get("token_count") or 0
        if not isinstance(recommendations, list):
            raise _internal_error(
                "Unexpected output format from recommendation processor"
            )

        # Hard limit: never return more than ``limit`` items, whatever the
        # processor produced.
        recommendations = recommendations[:limit]

        logger.info("Returning %d recommendations (limit=%d)",
                    len(recommendations), limit)

        return RecommendationsResponse(
            recommendations=recommendations,
            token_count=token_count
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Unhandled error")
        raise _internal_error(str(e)) from e
