from typing import Dict, List, Any, Optional
from models.scenario import RoleplayScenario
from models.assessment import SkillAssessment, SkillScore, ConversationTurn
from services.groq_service import GroqService
from utils.json_handler import JSONHandler

class SkillAnalyzer:
    """Score roleplay sessions via the Groq service and summarize results.

    Completed assessments are persisted and reloaded through ``JSONHandler``.
    """

    def __init__(self):
        self.groq_service = GroqService()
        self.json_handler = JSONHandler()

    def analyze_session(self, session_id: str, scenario: RoleplayScenario,
                        conversation_turns: List[ConversationTurn]) -> Optional[SkillAssessment]:
        """Perform comprehensive skill analysis on a completed session.

        Args:
            session_id: Identifier of the analyzed session.
                NOTE(review): currently unused -- it is not forwarded to
                ``SkillAssessment.create_new``; confirm how an assessment saved
                here is later found by ``load_assessment(session_id)``.
            scenario: Scenario the session was played against; its
                ``skills_to_assess`` selects which skills get scored.
            conversation_turns: The conversation turns, in order.

        Returns:
            The saved ``SkillAssessment``, or ``None`` if the analysis call
            returned nothing, its result could not be processed, or saving
            failed.
        """
        # The Groq service takes plain dicts rather than ConversationTurn objects.
        conversation_data = [
            {'speaker': turn.speaker, 'message': turn.message}
            for turn in conversation_turns
        ]

        analysis_result = self.groq_service.analyze_skills(scenario.to_dict(), conversation_data)
        if not analysis_result:
            return None

        try:
            skill_scores = self._extract_skill_scores(scenario, analysis_result)

            # Overall score is the plain mean of the per-skill scores. Round
            # *before* classifying so the stored score and the performance
            # level always agree (e.g. a mean of 8.96 is stored as 9.0 and is
            # therefore rated "Expert", not "Advanced").
            if skill_scores:
                overall_score = round(
                    sum(s.score for s in skill_scores) / len(skill_scores), 1
                )
            else:
                overall_score = 0.0
            performance_level = self._determine_performance_level(overall_score)

            assessment = SkillAssessment.create_new(
                scenario_id=scenario.id,
                overall_score=overall_score,
                performance_level=performance_level,
                skill_scores=skill_scores,
                conversation_turns=conversation_turns,
                conversation_analysis=analysis_result.get('conversation_analysis', {}),
                recommendations=analysis_result.get('recommendations', {})
            )

            if self.json_handler.save_assessment(assessment.to_dict()):
                return assessment
            print("Failed to save assessment")
            return None

        except Exception as e:
            # Boundary for malformed model output (missing keys, non-numeric
            # scores, ...): report it and signal failure with None.
            print(f"Error processing analysis result: {e}")
            return None

    def _extract_skill_scores(self, scenario: RoleplayScenario,
                              analysis_result: Dict[str, Any]) -> List[SkillScore]:
        """Build a SkillScore for each scenario skill present in the analysis.

        Skills the model did not return are skipped (not scored as 0).
        Raises KeyError if a skill is requested and ``analysis_result`` has
        no ``'skill_analysis'`` entry.
        """
        skill_scores = []
        for skill_name in scenario.skills_to_assess:
            if skill_name not in analysis_result['skill_analysis']:
                continue
            skill_data = analysis_result['skill_analysis'][skill_name]
            skill_scores.append(SkillScore(
                skill_name=skill_name,
                score=self._parse_score(skill_data.get('score', 0)),
                evidence=skill_data.get('evidence', []),
                strengths=skill_data.get('strengths', []),
                improvement_areas=skill_data.get('improvement_areas', [])
            ))
        return skill_scores

    @staticmethod
    def _parse_score(raw: Any) -> int:
        """Coerce a model-provided score to an int.

        ``None`` counts as 0, and numeric strings such as ``"7.5"`` are
        accepted (truncated like ``int(7.5)``). Non-numeric text still raises
        ValueError.
        """
        if raw is None:
            return 0
        return int(float(raw))

    def _determine_performance_level(self, overall_score: float) -> str:
        """Map an overall score to a performance label.

        Thresholds: >= 9.0 "Expert", >= 7.0 "Advanced", >= 5.0 "Intermediate",
        otherwise "Beginner".
        """
        if overall_score >= 9.0:
            return "Expert"
        elif overall_score >= 7.0:
            return "Advanced"
        elif overall_score >= 5.0:
            return "Intermediate"
        else:
            return "Beginner"

    def load_assessment(self, session_id: str) -> Optional[SkillAssessment]:
        """Load an existing assessment, or return None if none was found."""
        assessment_data = self.json_handler.load_assessment(session_id)
        if assessment_data:
            return SkillAssessment.from_dict(assessment_data)
        return None

    def list_assessments(self) -> List[Dict[str, Any]]:
        """List all stored assessments, as returned by the JSON handler."""
        return self.json_handler.list_assessments()

    def get_skill_summary(self, assessments: List[SkillAssessment]) -> Dict[str, Any]:
        """Generate summary statistics across multiple assessments.

        "Latest" and "improvement" values are taken from list order, so this
        presumably expects ``assessments`` oldest-first -- confirm with callers.

        Returns:
            ``{}`` for no assessments; otherwise a dict with total_sessions,
            overall_average, best_overall, latest_overall, a per-skill
            skill_breakdown, and performance_trend.
        """
        if not assessments:
            return {}

        overall_scores = [assessment.overall_score for assessment in assessments]

        # skill name -> that skill's scores, in assessment order.
        skill_history: Dict[str, List[int]] = {}
        for assessment in assessments:
            for skill_score in assessment.skill_scores:
                skill_history.setdefault(skill_score.skill_name, []).append(skill_score.score)

        skill_breakdown = {
            skill_name: {
                'average_score': round(sum(scores) / len(scores), 1),
                'best_score': max(scores),
                'latest_score': scores[-1],
                'improvement': scores[-1] - scores[0] if len(scores) > 1 else 0
            }
            for skill_name, scores in skill_history.items()
        }

        return {
            'total_sessions': len(assessments),
            'overall_average': round(sum(overall_scores) / len(overall_scores), 1),
            'best_overall': max(overall_scores),
            'latest_overall': overall_scores[-1],
            'skill_breakdown': skill_breakdown,
            'performance_trend': self._calculate_trend(overall_scores)
        }

    def _calculate_trend(self, scores: List[float]) -> str:
        """Classify the trend of a score series as Improving/Declining/Stable.

        Compares the mean of the earliest scores with the mean of the most
        recent ones (up to 3 each). A difference beyond +/-1.0 counts as a
        trend. Fewer than two scores gives "Insufficient data".
        """
        if len(scores) < 2:
            return "Insufficient data"

        # Keep the two windows disjoint. With fewer than 6 scores, fixed
        # 3-wide windows would overlap, and with 2-3 scores they would be
        # identical, so the trend was always reported as "Stable".
        window = min(3, len(scores) // 2)
        early_avg = sum(scores[:window]) / window
        recent_avg = sum(scores[-window:]) / window

        improvement = recent_avg - early_avg

        if improvement > 1.0:
            return "Improving"
        elif improvement < -1.0:
            return "Declining"
        else:
            return "Stable"