#!/usr/bin/env python3
"""
Course Recommendation Processor - PRODUCTION VERSION
Handles limit parameter correctly from FastAPI
"""
import sys
import json
import subprocess
from typing import Dict, List, Any, Tuple, Optional

OLLAMA_MODEL = "llama3.1:8b"

def safe_print_json(obj: Dict[str, Any]):
    """Serialize *obj* as a single JSON line on stdout and flush immediately.

    ensure_ascii=False keeps non-ASCII characters readable; the flush makes
    the line visible to the parent process right away.
    """
    sys.stdout.write(json.dumps(obj, ensure_ascii=False) + "\n")
    sys.stdout.flush()


def load_payload(path: str) -> Dict[str, Any]:
    """Read and parse the JSON payload stored at *path* (UTF-8)."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


def extract_limit(payload: Dict[str, Any]) -> int:
    """
    Extract and validate the 'limit' parameter from the payload.

    Accepts ints, floats, or numeric strings; rejects missing/null values,
    booleans, and anything outside the range [1, 100].

    Returns:
        The validated limit as an int.

    Raises:
        ValueError: if 'limit' is missing, null, boolean, non-numeric,
            or out of range.
    """
    print("[extract_limit] Checking for limit in payload...", file=sys.stderr)
    print(f"[extract_limit] Payload keys: {list(payload.keys())}", file=sys.stderr)

    if "limit" not in payload:
        print("[extract_limit] ERROR: 'limit' key not found", file=sys.stderr)
        raise ValueError("'limit' parameter is required in payload")

    raw_limit = payload["limit"]
    print(f"[extract_limit] Found limit: '{raw_limit}' (type: {type(raw_limit).__name__})", file=sys.stderr)

    # Handle None/null explicitly for a precise message.
    if raw_limit is None:
        print("[extract_limit] ERROR: limit is null", file=sys.stderr)
        raise ValueError("'limit' parameter cannot be null")

    # bool is a subclass of int, so int(True) would "work"; reject it first.
    if isinstance(raw_limit, bool):
        print("[extract_limit] ERROR: boolean not allowed", file=sys.stderr)
        raise ValueError(f"Invalid limit value '{raw_limit}': Boolean not allowed")

    # Narrow try: only the conversion can raise here. The original wrapped
    # its own range-check ValueErrors in a second "Invalid limit value"
    # message; keeping validation outside the try avoids that re-wrapping.
    try:
        if isinstance(raw_limit, str):
            limit = int(raw_limit.strip())
        else:
            limit = int(raw_limit)
    except (ValueError, TypeError, AttributeError) as e:
        print(f"[extract_limit] ERROR: {e}", file=sys.stderr)
        raise ValueError(f"Invalid limit value '{raw_limit}': {str(e)}") from e

    # Range validation happens after conversion so its message surfaces as-is.
    if limit < 1:
        print(f"[extract_limit] ERROR: limit must be >= 1, got {limit}", file=sys.stderr)
        raise ValueError(f"limit must be >= 1, got {limit}")
    if limit > 100:
        print(f"[extract_limit] ERROR: limit must be <= 100, got {limit}", file=sys.stderr)
        raise ValueError(f"limit must be <= 100, got {limit}")

    print(f"[extract_limit] ✓ Valid limit: {limit}", file=sys.stderr)
    return limit


def get_logged_in_user(user_data: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Return (user_id, user_record) for the first entry in *user_data*.

    Raises:
        ValueError: if *user_data* is empty.
    """
    if not user_data:
        raise ValueError("user_data is empty")
    first_id = next(iter(user_data))
    return first_id, user_data[first_id]


def flatten_courses(client_all_courses_data: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """Flatten the per-client course lists into a single list.

    Iterates only the values — the original looped over .items() and
    discarded every key — preserving the mapping's insertion order and
    each inner list's order.
    """
    all_courses: List[Dict[str, Any]] = []
    for courses in client_all_courses_data.values():
        all_courses.extend(courses)
    return all_courses


def map_skill_level(level_raw: Optional[str]) -> str:
    """Translate a raw numeric skill_type into a readable level name.

    None and anything that fails int() conversion fall back to "beginner".
    """
    if level_raw is None:
        return "beginner"
    try:
        numeric = int(level_raw)
    except Exception:
        numeric = 0

    if numeric <= 1:
        return "beginner"
    if numeric == 2:
        return "intermediate"
    if numeric == 3:
        return "advanced"
    return "expert"


def build_skill_sets(user: Dict[str, Any]) -> Dict[str, Any]:
    """Collect the user's skill IDs and (lower-cased) skill names.

    Merges skills from the user's own profile and from the attached job
    profile. IDs are normalized to stripped strings, names to stripped
    lower-case strings; empty values are skipped. The original duplicated
    the normalization loop for each source — both share one loop here.

    Returns:
        {"skill_ids": set of str, "skill_names": set of str}
    """
    skill_ids, skill_names = set(), set()
    job_profile = user.get("jobProfile") or {}

    # Both sources carry the same record shape, so one loop handles them.
    for source in (
        user.get("skills", []) or [],
        job_profile.get("job_profile_skills") or [],
    ):
        for s in source:
            sid = str(s.get("skill_id", "")).strip()
            sname = str(s.get("skill_name", "")).strip().lower()
            if sid:
                skill_ids.add(sid)
            if sname:
                skill_names.add(sname)

    return {"skill_ids": skill_ids, "skill_names": skill_names}


def extract_course_skills(course: Dict[str, Any]) -> Dict[str, Any]:
    """Collect normalized skill IDs and lower-cased skill names from a course.

    Returns:
        {"skill_ids": set of str, "skill_names": set of str}
    """
    ids, names = set(), set()
    for entry in course.get("skills", []) or []:
        identifier = str(entry.get("skill_id", "")).strip()
        label = str(entry.get("skill_name", "")).strip().lower()
        if identifier:
            ids.add(identifier)
        if label:
            names.add(label)
    return {"skill_ids": ids, "skill_names": names}


def simple_text_skill_overlap(course: Dict[str, Any], user_skill_names: set) -> int:
    """Count how many user skill names appear as substrings of the course text.

    The searchable text is name + short_description + description,
    space-separated and lower-cased; empty skill names never match.
    """
    haystack = " ".join([
        course.get("name") or "",
        course.get("short_description") or "",
        course.get("description") or "",
    ]).lower()
    return sum(1 for name in user_skill_names if name and name in haystack)


def build_peer_course_stats(users_under_manager: Dict[str, Any], logged_user_id: str) -> Dict[str, int]:
    """Tally how many peers (excluding the logged-in user) completed each course.

    Returns:
        Mapping of course id (str) -> number of peers who completed it.
    """
    if not users_under_manager:
        return {}

    stats: Dict[str, int] = {}
    peers = users_under_manager.get("users", {}) or {}
    me = str(logged_user_id)

    for peer_id, peer in peers.items():
        # The logged-in user's own completions are not peer signal.
        if str(peer_id) == me:
            continue
        for course_id in peer.get("completedCourses", []) or []:
            key = str(course_id)
            stats[key] = stats.get(key, 0) + 1

    return stats


def score_course(
    course: Dict[str, Any],
    user_features: Dict[str, Any],
    peer_stats: Dict[str, int],
    logged_user: Dict[str, Any],
    total_peers: int
) -> float:
    """Compute a 0-100 relevance score for *course* against the user profile.

    Components: skill overlap (up to 70 points), peer-completion signal
    (up to 25), description richness (up to 5), plus a +5 boost for
    assigned-but-not-completed courses. Completed courses always score 0.
    """
    course_id = str(course.get("courseId"))

    done = set(str(c) for c in (logged_user.get("completedCourses") or []))
    # Never re-recommend something the user already finished.
    if course_id in done:
        return 0.0

    todo = set(str(c) for c in (logged_user.get("assignedCourses") or []))

    my_skill_ids = user_features["skill_ids"]
    my_skill_names = user_features["skill_names"]
    course_skill_ids = extract_course_skills(course)["skill_ids"]

    # Jaccard similarity between user and course skill-ID sets.
    jaccard = 0.0
    if my_skill_ids or course_skill_ids:
        union_size = len(my_skill_ids | course_skill_ids)
        if union_size > 0:
            jaccard = len(my_skill_ids & course_skill_ids) / union_size

    # Fuzzy signal: user skill names mentioned in course text, capped at 3.
    mention_norm = min(simple_text_skill_overlap(course, my_skill_names) / 3.0, 1.0)

    skill_score = (jaccard * 0.6 + mention_norm * 0.4) * 70.0

    # Social proof: fraction of peers who completed this course.
    peer_score = 0.0
    if total_peers > 0:
        peer_score = peer_stats.get(course_id, 0) / total_peers * 25.0

    # Longer descriptions suggest richer content; cap at 2000 characters.
    richness = min(len((course.get("description") or "").strip()) / 2000.0, 1.0)
    misc_score = richness * 5.0

    total = skill_score + peer_score + misc_score
    if course_id in todo and course_id not in done:
        total += 5.0

    return max(0.0, min(100.0, total))


def select_top_courses(payload: Dict[str, Any], limit: int) -> Tuple[str, Dict[str, Any], List[Dict[str, Any]]]:
    """
    Score every available course for the logged-in user and return the
    top *limit* of them (fewer if not enough courses score above zero).

    Returns:
        (logged_user_id, logged_user_record, selected_courses) where each
        selected course is a copy carrying an added '_numeric_score' key.
    """
    print(f"\n{'='*70}", file=sys.stderr)
    print(f"SELECTING TOP {limit} COURSES", file=sys.stderr)
    print(f"{'='*70}", file=sys.stderr)

    logged_user_id, logged_user = get_logged_in_user(payload.get("user_data", {}) or {})
    user_features = build_skill_sets(logged_user)

    manager_data = payload.get("users_under_manager_data") or {}
    peer_stats = build_peer_course_stats(manager_data, logged_user_id)
    total_peers = len(manager_data.get("users") or {})

    candidates = flatten_courses(payload.get("client_all_courses_data", {}))
    print(f"Total courses available: {len(candidates)}", file=sys.stderr)

    # Score each candidate; keep only positive scores, annotated on a copy
    # so the original payload dicts stay untouched.
    scored = []
    for course in candidates:
        value = score_course(course, user_features, peer_stats, logged_user, total_peers)
        if value > 0:
            annotated = dict(course)
            annotated["_numeric_score"] = value
            scored.append(annotated)

    scored.sort(key=lambda item: item["_numeric_score"], reverse=True)
    print(f"Courses with positive scores: {len(scored)}", file=sys.stderr)

    selected = scored[:limit]
    print(f"✓ Selected {len(selected)} courses (limit={limit})", file=sys.stderr)

    return logged_user_id, logged_user, selected


def call_ollama(system_prompt: str, user_prompt: str) -> str:
    """Run the Ollama CLI on the concatenated prompts; return stripped stdout.

    Raises:
        RuntimeError: if the ollama process exits non-zero.
        subprocess.TimeoutExpired: if the call exceeds 120 seconds.
    """
    prompt = f"{system_prompt}\n\n{user_prompt}"

    proc = subprocess.run(
        ["ollama", "run", OLLAMA_MODEL],
        input=prompt,
        capture_output=True,
        text=True,
        timeout=120
    )

    if proc.returncode != 0:
        raise RuntimeError(f"Ollama error: {proc.stderr}")

    return proc.stdout.strip()


def build_llm_context(logged_user: Dict[str, Any], courses: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Assemble the JSON-serializable context dict handed to the LLM prompt.

    Summarizes the user's current skills, the job profile's required
    skills, and each course (ranked, description truncated to 600 chars).
    """

    def summarize(skills: List[Dict[str, Any]], level_key: str) -> List[Dict[str, Any]]:
        # level_key names the field each summary entry uses for the mapped level.
        return [
            {"name": s.get("skill_name", ""), level_key: map_skill_level(s.get("skill_type"))}
            for s in skills
        ]

    profile = logged_user.get("jobProfile") or {}

    courses_info = []
    for rank, course in enumerate(courses, 1):
        courses_info.append({
            "rank": rank,
            "courseId": str(course.get("courseId")),
            "courseName": course.get("name", ""),
            "short_description": course.get("short_description", ""),
            "description": (course.get("description") or "")[:600],
            "numeric_score": round(course.get("_numeric_score", 0.0), 2),
            "skills": summarize(course.get("skills", []) or [], "target_level"),
        })

    return {
        "user": {
            "name": logged_user.get("user_name", "the user"),
            "current_skills": summarize(logged_user.get("skills", []) or [], "current_level"),
            "job_profile": {
                "name": profile.get("job_profile_name") or "",
                "required_skills": summarize(profile.get("job_profile_skills") or [], "required_level"),
            },
        },
        "courses": courses_info,
    }


def generate_with_llm(logged_user: Dict[str, Any], courses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Generate recommendations using the LLM.

    Builds a JSON context from the user profile and selected courses,
    prompts the model for exactly one recommendation per course, then
    parses and normalizes the reply.

    Returns:
        List of recommendation dicts (courseId, courseName, score,
        reason, scenario_relevance). May contain fewer or more entries
        than requested if the model disobeys; the caller truncates.

    Raises:
        Any exception from the Ollama call or JSON parsing — the caller
        catches it and falls back to template-based recommendations.
    """
    num_courses = len(courses)
    print(f"\nLLM: Generating {num_courses} recommendations...", file=sys.stderr)
    
    context = build_llm_context(logged_user, courses)
    
    system_prompt = """You are an expert career development advisor.
Generate unique, personalized course recommendations.
Respond with ONLY valid JSON, no other text."""
    
    user_prompt = f"""Context:

{json.dumps(context, indent=2, ensure_ascii=False)}

Generate exactly {num_courses} recommendations.

For each course provide:
- courseId (exact from context)
- courseName (exact from context)
- score (0-100)
- reason (2-3 sentences about skill progression)
- scenario_relevance (1 sentence about work application)

Return ONLY this JSON structure:
{{
  "recommendations": [
    {{"courseId": "...", "courseName": "...", "score": 85, "reason": "...", "scenario_relevance": "..."}}
  ]
}}"""
    
    try:
        raw_output = call_ollama(system_prompt, user_prompt)
        
        # Extract JSON: take the span from the first '{' to the last '}'
        # so any prose the model wraps around the object is discarded.
        json_start = raw_output.find("{")
        json_end = raw_output.rfind("}") + 1
        
        if json_start < 0 or json_end <= json_start:
            raise ValueError("No JSON in LLM output")
        
        json_str = raw_output[json_start:json_end]
        data = json.loads(json_str)
        
        recommendations = data.get("recommendations", [])
        
        # Normalize fields: clamp score to [0, 100] (defaulting to 50 if
        # unparseable) and coerce the text fields to stripped strings.
        for rec in recommendations:
            try:
                rec["score"] = max(0.0, min(100.0, float(rec.get("score", 0))))
            except Exception:
                rec["score"] = 50.0
            
            rec["courseId"] = str(rec.get("courseId", ""))
            rec["courseName"] = str(rec.get("courseName", ""))
            rec["reason"] = str(rec.get("reason", "")).strip()
            rec["scenario_relevance"] = str(rec.get("scenario_relevance", "")).strip()
        
        print(f"✓ LLM generated {len(recommendations)} recommendations", file=sys.stderr)
        return recommendations
        
    except Exception as e:
        # Log and re-raise so main() can switch to the fallback generator.
        print(f"✗ LLM failed: {e}", file=sys.stderr)
        raise


def find_skill_context(logged_user: Dict[str, Any], course: Dict[str, Any]) -> Dict[str, Any]:
    """Pick the most relevant skill linking the user to the course.

    Preference order:
      1. A skill both the user and the course declare (has_overlap=True).
      2. The user's first listed skill, with the course pitched one level up.
      3. A generic "professional skills" placeholder.

    Returns:
        dict with skill_name, user_level, course_level, has_overlap.
    """
    user_skills = logged_user.get("skills", []) or []
    course_skills = course.get("skills", []) or []

    user_skill_map = {
        s.get("skill_name", "").lower(): {
            "name": s.get("skill_name", ""),
            "level": map_skill_level(s.get("skill_type"))
        }
        for s in user_skills if s.get("skill_name")
    }

    course_skill_map = {
        s.get("skill_name", "").lower(): {
            "name": s.get("skill_name", ""),
            "level": map_skill_level(s.get("skill_type"))
        }
        for s in course_skills if s.get("skill_name")
    }

    common = set(user_skill_map.keys()) & set(course_skill_map.keys())

    if common:
        # FIX: the original used list(common)[0], whose result depends on
        # set iteration order and so varies across interpreter runs (hash
        # randomization). Pick the alphabetically first shared skill so
        # the output is deterministic and reproducible.
        key = min(common)
        return {
            "skill_name": user_skill_map[key]["name"],
            "user_level": user_skill_map[key]["level"],
            "course_level": course_skill_map[key]["level"],
            "has_overlap": True
        }

    if user_skills:
        us = user_skills[0]
        user_level = map_skill_level(us.get("skill_type"))
        # Suggest the next rung on the ladder; unknown/expert levels map
        # to "advanced" via the dict default.
        next_level = {
            "beginner": "intermediate",
            "intermediate": "advanced",
            "advanced": "expert"
        }.get(user_level, "advanced")

        return {
            "skill_name": us.get("skill_name", "professional skills"),
            "user_level": user_level,
            "course_level": next_level,
            "has_overlap": False
        }

    # No skill signal at all: fall back to a generic framing.
    return {
        "skill_name": "professional skills",
        "user_level": "current",
        "course_level": "advanced",
        "has_overlap": False
    }


def generate_fallback(logged_user: Dict[str, Any], courses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Build deterministic, template-based recommendations without the LLM."""
    print(f"\nFallback: Generating {len(courses)} recommendations...", file=sys.stderr)

    scenario_templates = [
        "when working on {skill} in daily projects",
        "during {skill} technical discussions",
        "when mentoring team members on {skill}",
        "while solving {skill} challenges",
        "when collaborating on {skill} solutions"
    ]

    results: List[Dict[str, Any]] = []
    for position, course in enumerate(courses):
        ctx = find_skill_context(logged_user, course)
        skill = ctx["skill_name"]
        current = ctx["user_level"]
        target = ctx["course_level"]

        if ctx["has_overlap"]:
            reason = f"This course advances your {skill} skills from {current} to {target} level through structured, hands-on learning that builds on your existing knowledge."
        else:
            reason = f"This course develops essential {skill} capabilities, progressing from {current} fundamentals to {target} proficiency through practical exercises and real-world examples."

        # Cycle through the templates so consecutive items read differently.
        template = scenario_templates[position % len(scenario_templates)]

        results.append({
            "courseId": str(course.get("courseId")),
            "courseName": course.get("name", ""),
            "score": round(course.get("_numeric_score", 0.0), 2),
            "reason": reason,
            "scenario_relevance": template.format(skill=skill)
        })

    print(f"✓ Fallback generated {len(results)} recommendations", file=sys.stderr)
    return results


def estimate_tokens(text: str) -> int:
    """Approximate token count as ~1.3 tokens per whitespace-separated word."""
    word_count = len(text.split())
    return int(word_count * 1.3)


def main() -> None:
    """CLI entry point.

    Reads the payload file path from argv[1], validates the requested
    limit, selects and describes the top courses, and always prints a
    single JSON object on stdout (with an 'error' key on failure) so
    the calling service can parse the output unconditionally. All
    diagnostics go to stderr.
    """
    try:
        if len(sys.argv) < 2:
            safe_print_json({
                "recommendations": [],
                "token_count": 0,
                "error": "No input file provided"
            })
            return
        
        path = sys.argv[1]
        payload = load_payload(path)
        
        print(f"\n{'#'*70}", file=sys.stderr)
        print(f"# RECOMMENDATION PROCESSOR", file=sys.stderr)
        print(f"{'#'*70}\n", file=sys.stderr)
        
        # Extract and validate limit (raises ValueError on bad input,
        # handled by the except below so a JSON error is still emitted).
        limit = extract_limit(payload)
        
        print(f"\n{'*'*70}", file=sys.stderr)
        print(f"* LIMIT: {limit} recommendations requested", file=sys.stderr)
        print(f"{'*'*70}\n", file=sys.stderr)
        
        # Select top courses by numeric score.
        logged_user_id, logged_user, selected = select_top_courses(payload, limit)
        
        if not selected:
            print(f"No suitable courses found", file=sys.stderr)
            safe_print_json({"recommendations": [], "token_count": 0})
            return
        
        # Generate recommendations: try the LLM first, fall back to the
        # deterministic template generator on any failure or empty result.
        try:
            recommendations = generate_with_llm(logged_user, selected)
            if not recommendations:
                raise ValueError("LLM returned empty list")
        except Exception as e:
            print(f"LLM failed, using fallback: {e}", file=sys.stderr)
            recommendations = generate_fallback(logged_user, selected)
        
        # Final enforcement: never return more than the requested limit,
        # even if the LLM produced extra entries.
        recommendations = recommendations[:limit]
        actual = len(recommendations)
        
        print(f"\n{'='*70}", file=sys.stderr)
        print(f"FINAL OUTPUT", file=sys.stderr)
        print(f"Requested: {limit} | Generated: {actual} | Match: {'✓' if actual == limit else '✗'}", file=sys.stderr)
        print(f"{'='*70}\n", file=sys.stderr)
        
        # Rough token estimate over (at most) the first 5000 chars of payload.
        token_count = estimate_tokens(json.dumps(payload)[:5000])
        
        safe_print_json({
            "recommendations": recommendations,
            "token_count": token_count
        })
        
    except ValueError as ve:
        # Validation problems (missing/invalid limit, empty user_data).
        print(f"\n✗ Validation Error: {ve}\n", file=sys.stderr)
        safe_print_json({
            "recommendations": [],
            "token_count": 0,
            "error": str(ve)
        })
    except Exception as e:
        # Anything else: log the traceback but still emit parseable JSON.
        print(f"\n✗ Fatal Error: {e}\n", file=sys.stderr)
        import traceback
        traceback.print_exc(file=sys.stderr)
        safe_print_json({
            "recommendations": [],
            "token_count": 0,
            "error": f"Fatal error: {str(e)}"
        })


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()