import os
import json
import logging
from enum import Enum
from typing import Dict, List, Optional, Any

# Module-wide logging: timestamped, level-tagged lines for this processor.
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("assessment_report_processor")

# Startup breadcrumb emitted at import time so deployments can confirm the
# module loaded cleanly.
logger.info("Initializing Agentic Analytics Processor with os module ready.")

# Ollama integration is disabled in this build (see check_ollama_health,
# which unconditionally reports healthy).
OLLAMA_MODEL = "none"
OLLAMA_BASE_URL = ""

# All per-client artifacts (status.json / result.json) live under DATA_ROOT.
# Overridable via the TRAINED_DATA_DIR environment variable; defaults to a
# "trained_data" folder next to this file. Created eagerly at import time.
DATA_ROOT = os.environ.get("TRAINED_DATA_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "trained_data"))
os.makedirs(DATA_ROOT, exist_ok=True)

class TrainingStatus(str, Enum):
    """Lifecycle states persisted to each client's status.json.

    Mixes in str so members serialize naturally and compare equal to their
    plain string values.
    """
    QUEUED = "queued"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"

def _client_dir(client_id: int) -> str:
    """Return the storage directory path for *client_id* (not created here)."""
    return os.path.join(DATA_ROOT, str(client_id))

def _ensure_client_dir(client_id: int) -> str:
    """Create (if missing) and return the storage directory for *client_id*."""
    client_path = _client_dir(client_id)
    os.makedirs(client_path, exist_ok=True)
    return client_path

def _classify_question(q: dict) -> dict:
    qid = str(q.get("id", ""))
    cat = str(q.get("categoryName") or "General Knowledge").strip()
    
    # Extract the root domain (e.g., "Physics > Class 11" -> "Physics")
    root_domain = cat.split(">")[0].split("-")[0].strip()
    
    # Clean up formatting to make it look highly professional
    # e.g. "english grammar and speaking" -> "English Grammar And Speaking"
    if not root_domain or root_domain.lower() in ["input type", "test", "uncategorized", "default"]:
        skill = "General Assessment"
    else:
        # Title case but fix small words if necessary, title() is usually fine for domains
        skill = root_domain.title()
        
        # Specific overrides for professional feel based on their payload
        if skill.lower() == "gk":
            skill = "General Knowledge"
            
    return {"question_id": qid, "categoryName": cat, "skill": skill}

def start_training_meta(client_id, question_bank, assessment_summary, completed_users_data, user_attempt_details, manager_mapping=None):
    """Record a client's training job as queued and return the status payload.

    Only the queued marker is written here; the payload arguments are accepted
    for interface compatibility and are not persisted by this function.
    """
    client_path = _ensure_client_dir(client_id)
    status_payload = {"status": TrainingStatus.QUEUED.value}
    status_file = os.path.join(client_path, "status.json")
    with open(status_file, "w") as fh:
        json.dump(status_payload, fh)
    return status_payload

def get_training_status(client_id):
    """Load the persisted status dict for *client_id*, or None if absent."""
    status_file = os.path.join(_client_dir(client_id), "status.json")
    if not os.path.exists(status_file):
        return None
    with open(status_file, "r") as fh:
        return json.load(fh)

def get_analysis_result(client_id):
    """Load the persisted analysis result for *client_id*, or None if absent."""
    result_file = os.path.join(_client_dir(client_id), "result.json")
    if not os.path.exists(result_file):
        return None
    with open(result_file, "r") as fh:
        return json.load(fh)

def list_all_clients():
    """Stub: client enumeration is not implemented; always returns an empty list."""
    return []
def delete_client_data(client_id):
    """Stub: deletion is a no-op in this build; returns None."""
    return None
def check_ollama_health():
    """Stub health probe: always reports healthy (no Ollama call is made)."""
    return True

def build_skill_chart(skill_stats: dict, title: str):
    """Build a bar-chart payload of per-skill accuracy percentages.

    skill_stats maps skill name -> {"questions_attempted": int,
    "questions_correct": int}. Skills with zero attempts get 0.0 accuracy.
    """
    def _accuracy_pct(stats: dict) -> float:
        # Accuracy as a percentage, rounded to two decimals; 0.0 when
        # nothing was attempted (avoids division by zero).
        attempted = stats.get("questions_attempted", 0)
        correct = stats.get("questions_correct", 0)
        return round((correct / attempted) * 100, 2) if attempted > 0 else 0.0

    skills = list(skill_stats.keys())
    return {
        "labels": skills,
        "datasets": [
            {
                "label": "Accuracy %",
                "data": [_accuracy_pct(skill_stats[sk]) for sk in skills]
            }
        ],
        "type": "bar",
        "title": title
    }

def _evaluate_correctness(q_data: dict) -> bool:
    obtained = float(q_data.get("obtained_marks", 0))
    if obtained > 0:
        return True
        
    expected = str(q_data.get("correct_answer") or "").strip().lower()
    actual = str(q_data.get("user_answer") or "").strip().lower()
    
    if not expected:
        return False
        
    if expected == actual:
        return True
        
    # Handle comma-separated list sorting matching
    if ',' in expected and ',' in actual:
        exp_list = sorted([x.strip() for x in expected.split(",")])
        act_list = sorted([x.strip() for x in actual.split(",")])
        if exp_list == act_list:
            return True

    # Handle multiple choice arrays if available
    q_options = q_data.get("question_option", [])
    if q_options and isinstance(q_options, list):
        true_corrects = 0
        user_hits = 0
        user_misses = 0
        
        for opt in q_options:
            is_truth_corr = bool(opt.get("correct_answer_option", False))
            is_user_sel = bool(opt.get("user_answer", False))
            
            if is_truth_corr:
                true_corrects += 1
            if is_user_sel and is_truth_corr:
                user_hits += 1
            if is_user_sel and not is_truth_corr:
                user_misses += 1
                
        if true_corrects > 0 and user_hits == true_corrects and user_misses == 0:
            return True
            
    return False

def run_training_pipeline(client_id: int, question_bank: List[dict], assessment_summary: List[dict], completed_users_data: List[dict], user_attempt_details: List[dict], manager_mapping: Optional[Dict[str, List[int]]] = None):
    """Aggregate assessment attempt data into admin/manager/user reports.

    Persists two artifacts in the client's data directory:
      * status.json -- "processing" at start, then "completed" or "failed"
      * result.json -- the full report payload (written on success only)

    question_bank drives the question-id -> skill mapping and the admin
    doughnut chart; user_attempt_details drives every accuracy tally.
    assessment_summary, completed_users_data and manager_mapping are
    accepted for interface compatibility but are not read by the
    aggregation below. Returns None; callers fetch the output via
    get_analysis_result().
    """
    try:
        # Mark the job as in-flight before doing any work.
        path = _ensure_client_dir(client_id)
        meta = {"status": TrainingStatus.PROCESSING.value}
        with open(os.path.join(path, "status.json"), "w") as f:
            json.dump(meta, f)
            
        logger.info(f"Processing client {client_id}")

        # Pass 1: classify every bank question into a skill, building the
        # qid -> skill lookup and per-skill question counts.
        cat_counts = {}
        qid_to_skill = {}
        for q in question_bank:
            cls = _classify_question(q)
            cat = cls["categoryName"]
            skill = cls["skill"]
            qid = cls["question_id"]
            
            qid_to_skill[qid] = skill
            cat_counts[skill] = cat_counts.get(skill, 0) + 1

        chart1_admin = {
            "labels": list(cat_counts.keys()),
            "data": list(cat_counts.values()),
            "type": "doughnut",
            "title": "Question Skill Mapping"
        }

        # Initialize admin_skill_stats with ALL skills from the bank so the
        # accuracy chart has parity with the mapping chart (zero rows included).
        admin_skill_stats = {}
        for s in cat_counts.keys():
            admin_skill_stats[s] = {"questions_correct": 0, "questions_attempted": 0}

        manager_stats = {} # manager_id -> {skill: stats}
        manager_unique_q = {} # manager_id -> {skill: set(qid)}
        user_stats = {}    # user_id -> {skills: stats, ...}
        manager_users = {} # manager_id -> set of user_id
        
        manager_names = {} # manager_id -> display name (last one seen wins)

        global_attempted = 0
        global_correct = 0
        
        total_q_bank = len(question_bank)
        all_user_ids = set()

        # Pass 2: walk every attempt record, accumulating admin, manager and
        # user tallies in a single sweep.
        for attempt in user_attempt_details:
            user_name = str(attempt.get("userName", "Unknown_User"))
            # Fall back to the user name as an id when no explicit id exists;
            # managers without an id are pooled under "Unassigned".
            user_id = str(attempt.get("userId") or attempt.get("user_id") or user_name)
            manager_id = str(attempt.get("managerId") or attempt.get("manager_id") or "Unassigned")
            manager_name = str(attempt.get("managerName", manager_id))
            
            manager_names[manager_id] = manager_name
            all_user_ids.add(user_id)
            
            if manager_id not in manager_stats:
                manager_stats[manager_id] = {}
                manager_users[manager_id] = set()
                manager_unique_q[manager_id] = {}
            manager_users[manager_id].add(user_id)
                
            if user_id not in user_stats:
                user_stats[user_id] = {"user_name": user_name, "skills": {}, "attempted": 0, "correct": 0, "unique_q": {}}

            for q_data in attempt.get("data", []):
                # NOTE: records missing the is_attempted flag default to True,
                # i.e. they are counted as attempted.
                if not q_data.get("is_attempted", True):
                    continue
                qid = str(q_data.get("question_id"))
                # Questions not present in the bank fall into "Uncategorized".
                skill = qid_to_skill.get(qid, "Uncategorized")
                
                # Correctness via marks / text match / option-array matching.
                is_correct = _evaluate_correctness(q_data)
                corr_val = 1 if is_correct else 0
                
                global_attempted += 1
                global_correct += corr_val
                
                # Admin aggregation (bank skills were pre-initialized above,
                # but "Uncategorized" still needs lazy creation here).
                if skill not in admin_skill_stats:
                    admin_skill_stats[skill] = {"questions_correct": 0, "questions_attempted": 0}
                admin_skill_stats[skill]["questions_correct"] += corr_val
                admin_skill_stats[skill]["questions_attempted"] += 1
                
                # Manager aggregation
                if skill not in manager_stats[manager_id]:
                    manager_stats[manager_id][skill] = {"questions_correct": 0, "questions_attempted": 0}
                manager_stats[manager_id][skill]["questions_correct"] += corr_val
                manager_stats[manager_id][skill]["questions_attempted"] += 1
                
                if skill not in manager_unique_q[manager_id]:
                    manager_unique_q[manager_id][skill] = set()
                manager_unique_q[manager_id][skill].add(qid)
                
                # User aggregation
                if skill not in user_stats[user_id]["skills"]:
                    user_stats[user_id]["skills"][skill] = {"questions_correct": 0, "questions_attempted": 0}
                user_stats[user_id]["skills"][skill]["questions_correct"] += corr_val
                user_stats[user_id]["skills"][skill]["questions_attempted"] += 1
                user_stats[user_id]["attempted"] += 1
                user_stats[user_id]["correct"] += corr_val
                
                if skill not in user_stats[user_id]["unique_q"]:
                    user_stats[user_id]["unique_q"][skill] = set()
                user_stats[user_id]["unique_q"][skill].add(qid)
                
        chart2_admin = build_skill_chart(admin_skill_stats, "Overall Skill Distribution & Accuracy")

        # Overall accuracy; guarded against zero attempts.
        global_acc = round(((global_correct / global_attempted) * 100), 2) if global_attempted > 0 else 0.0

        admin_summary = {
            "total_questions_in_bank": total_q_bank,
            "total_users": len(all_user_ids),
            "average_accuracy": global_acc,
            "total_managers": len(manager_stats.keys())
        }

        # Per-manager report payloads: summary plus two charts each.
        manager_reports = []
        for mid, stats in manager_stats.items():
            mgr_attempted = sum(s.get("questions_attempted", 0) for s in stats.values())
            mgr_correct = sum(s.get("questions_correct", 0) for s in stats.values())
            mgr_acc = round(((mgr_correct / mgr_attempted) * 100), 2) if mgr_attempted > 0 else 0.0
            
            # Unique-question counts per skill: a question answered by several
            # team members counts once.
            mgr_cat_counts = {sk: len(qids) for sk, qids in manager_unique_q.get(mid, {}).items()}
            mgr_total_questions = sum(mgr_cat_counts.values())
            
            chart1_manager = {
                "labels": list(mgr_cat_counts.keys()),
                "data": list(mgr_cat_counts.values()),
                "type": "doughnut",
                "title": "Team Question Skill Mapping"
            }
            
            manager_reports.append({
                "manager_id": mid,
                "manager_name": manager_names.get(mid, mid),
                "summary": {
                    "total_questions": mgr_total_questions,
                    "total_users": len(manager_users.get(mid, [])),
                    "average_accuracy": mgr_acc
                },
                "chart_data": {
                    "question_skill_mapping_chart": chart1_manager,
                    "skill_distribution_accuracy_chart": build_skill_chart(stats, "Team Skill Distribution & Accuracy")
                }
            })
            
        # Per-user report payloads, mirroring the manager structure.
        user_reports = []
        for uid, udata in user_stats.items():
            uname = udata["user_name"]
            stats = udata["skills"]
            u_attempted = udata["attempted"]
            u_correct = udata["correct"]
            u_acc = round(((u_correct / u_attempted) * 100), 2) if u_attempted > 0 else 0.0
            
            # Unique-question counts per skill for this user.
            usr_cat_counts = {sk: len(qids) for sk, qids in udata.get("unique_q", {}).items()}
            usr_total_questions = sum(usr_cat_counts.values())
            
            chart1_user = {
                "labels": list(usr_cat_counts.keys()),
                "data": list(usr_cat_counts.values()),
                "type": "doughnut",
                "title": "User Question Skill Mapping"
            }
            
            user_reports.append({
                "user_id": uid,
                "user_name": uname,
                "summary": {
                    "total_questions": usr_total_questions,
                    "average_accuracy": u_acc
                },
                "chart_data": {
                    "question_skill_mapping_chart": chart1_user,
                    "skill_distribution_accuracy_chart": build_skill_chart(stats, f"User Skill Distribution & Accuracy ({uname})")
                }
            })

        result = {
            "admin_report": {
                "summary": admin_summary,
                "chart_data": {
                    "question_skill_mapping_chart": chart1_admin,
                    "skill_distribution_accuracy_chart": chart2_admin
                }
            },
            "manager_reports": manager_reports,
            "user_reports": user_reports
        }
        
        # Persist the payload first, then flip the status to completed so
        # readers never see "completed" without a result.json.
        with open(os.path.join(path, "result.json"), "w") as f:
            json.dump(result, f, indent=4)
            
        meta = {"status": TrainingStatus.COMPLETED.value}
        with open(os.path.join(path, "status.json"), "w") as f:
            json.dump(meta, f)

        logger.info(f"Successfully processed client {client_id}")

    except Exception as e:
        # Best-effort failure marker: only written when the client directory
        # already exists (it may not, if _ensure_client_dir itself failed).
        logger.error(f"Error processing client {client_id}: {e}")
        meta = {"status": TrainingStatus.FAILED.value, "error": str(e)}
        path = _client_dir(client_id)
        if os.path.exists(path):
            with open(os.path.join(path, "status.json"), "w") as f:
                json.dump(meta, f)