import json
import logging
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from difflib import SequenceMatcher
import re
from langchain_ollama import ChatOllama
from langchain.prompts import PromptTemplate
from langchain.schema import HumanMessage

# Configure the root logger to emit INFO and above, then create the logger
# used by every class in this module.
# NOTE(review): basicConfig runs at import time, so importing this module
# configures logging for the host application as well — confirm that is intended.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class Department:
    """One department of the requesting organization."""
    # Identifier supplied in the request; returned as department_id on a match.
    id: int
    # Department name; compared against the LLM-inferred department name.
    name: str

@dataclass
class TNARequest:
    """Parsed form of an incoming TNA request."""
    # Free-text business goal or challenge passed to the LLM for inference.
    user_prompt: str
    # Candidate departments the inferred name is matched against.
    organization_departments: List[Department]

@dataclass
class TNAResponse:
    """Result of mapping a user prompt onto one of the organization's departments."""
    # Id of the matched department; 0 when nothing matched or processing failed.
    department_id: int
    # Name of the matched department; the literal "Null" when nothing matched.
    department_name: str
    # Department name produced by the LLM; "Error" when processing failed.
    inferred_department: str
    # Best similarity score found between the inferred name and any department.
    similarity_score: float
    # True only when the best score reached the configured threshold.
    matched: bool

class DepartmentInferenceEngine:
    """Infer the department responsible for a business goal by asking an LLM.

    The model is reached through ``ChatOllama``; its free-text answer is
    reduced to a bare department name by ``_clean_department_name``.
    """

    # Whole-word "Department"/"Team" labels (optionally followed by a colon)
    # that the model may wrap around its answer. Whole-word matching avoids
    # cutting into longer words such as "Teamwork".
    _LABEL_PATTERN = re.compile(r'\b(?:department|team)\b\s*:?', re.IGNORECASE)

    # Characters that only decorate an answer: whitespace, quotes, markdown emphasis.
    _WRAPPER_CHARS = ' \t"\'*`'

    # Punctuation that may trail an answer ("Sales.") but is not part of the name.
    _TRAILING_PUNCTUATION = '.:;,'

    # Name returned whenever no usable department could be obtained.
    _FALLBACK_DEPARTMENT = "Unknown"

    def __init__(self, model_name: str = "gemma3:12b"):
        """
        Initialize the LLM model

        Args:
            model_name: Name of the Ollama model to load

        Raises:
            Exception: Whatever ChatOllama raises on construction; it is
                logged and re-raised unchanged.
        """
        try:
            self.llm = ChatOllama(model=model_name)
            logger.info("Initialized LLM with model: %s", model_name)
        except Exception as e:
            logger.error("Failed to initialize LLM: %s", e)
            raise

    def infer_department(self, user_prompt: str) -> str:
        """
        Infer the most relevant department for the given user prompt

        Args:
            user_prompt: User's business goal or challenge

        Returns:
            Inferred department name, or "Unknown" when the LLM call fails
            or its answer contains no usable name
        """
        prompt_template = PromptTemplate(
            input_variables=["user_prompt"],
            template="""
            You are an expert business analyst specializing in organizational structure and training needs.
            
            Based on the following business goal or challenge, identify the SINGLE most relevant department.
            
            User Goal/Challenge: {user_prompt}
            
            Common departments include:
            - Sales
            - Marketing
            - HR (Human Resources)
            - Operations
            - Finance
            - Product Management
            - Development/Engineering
            - Customer Service
            - Quality Assurance
            - Research and Development
            
            Instructions:
            1. Analyze the user's goal/challenge carefully
            2. Identify which department would be primarily responsible for addressing this goal
            3. Return ONLY the department name (e.g., "Sales", "Marketing", "HR")
            4. Use general department names, not specific subdivisions
            5. If the goal spans multiple departments, choose the PRIMARY one
            
            Department:
            """
        )

        try:
            formatted_prompt = prompt_template.format(user_prompt=user_prompt)
            response = self.llm.invoke([HumanMessage(content=formatted_prompt)])

            # Reduce the raw answer to a bare department name
            inferred_dept = self._clean_department_name(response.content.strip())
        except Exception as e:
            # Best-effort contract: callers always get a string back
            logger.error("Error in department inference: %s", e)
            return self._FALLBACK_DEPARTMENT

        if not inferred_dept:
            # An empty name carries no information; report it like a failed inference
            logger.warning("LLM response contained no usable department name")
            return self._FALLBACK_DEPARTMENT

        logger.info("Inferred department: %s", inferred_dept)
        return inferred_dept

    def _clean_department_name(self, dept_name: str) -> str:
        """
        Clean and normalize a department name taken from an LLM answer

        Returns the first line that still has content after label words
        ("Department", "Team"), surrounding quotes/markdown emphasis and
        trailing punctuation are removed.

        Args:
            dept_name: Raw text returned by the model

        Returns:
            Cleaned department name, or "" when no line has usable content
        """
        for raw_line in dept_name.splitlines():
            # Drop label words, then collapse the whitespace they leave behind
            without_labels = self._LABEL_PATTERN.sub('', raw_line)
            candidate = ' '.join(without_labels.split())

            # Peel decoration from both ends, then trailing punctuation
            candidate = candidate.strip(self._WRAPPER_CHARS)
            candidate = candidate.rstrip(self._TRAILING_PUNCTUATION + self._WRAPPER_CHARS)

            # A line that was only a label (e.g. "Department:") is skipped so
            # an answer on the following line is still found
            if candidate:
                return candidate

        return ""

class SimilarityMatcher:
    """Handles string similarity matching between inferred and organization departments"""

    # Ordered (phrase, canonical form) pairs applied after punctuation has been
    # stripped, so every phrase is written the way it looks at that point
    # ("research & development" has become "research development" by then).
    # Longer phrases come before the shorter phrases they contain.
    _CANONICAL_FORMS = (
        ('human resources', 'hr'),
        ('human resource', 'hr'),
        ('information technology', 'it'),
        ('research and development', 'rd'),
        ('research development', 'rd'),
        ('customer service', 'customer support'),
        ('customer care', 'customer support'),
        ('product management', 'product'),
        ('product manager', 'product'),
        ('software development', 'development'),
        ('software engineering', 'development'),
        ('engineering', 'development'),
        ('programmer', 'developer'),
        ('programming', 'development'),
    )

    @staticmethod
    def calculate_similarity(str1: str, str2: str) -> float:
        """
        Calculate similarity between two strings using multiple methods

        Both strings are normalized first; the score is the maximum of a
        difflib sequence ratio, an exact-match check, a substring check
        (fixed score 0.9) and a word-set overlap ratio.

        Args:
            str1: First string
            str2: Second string

        Returns:
            Similarity score between 0 and 1; 0.0 when either string is
            empty after normalization
        """
        # Normalize strings
        left = SimilarityMatcher._normalize_string(str1)
        right = SimilarityMatcher._normalize_string(str2)

        # An empty string is a substring of everything, so without this guard
        # a blank or punctuation-only input would score 0.9 against any name
        if not left or not right:
            return 0.0

        # Exact match after normalization
        if left == right:
            return 1.0

        # Sequence Matcher (difflib)
        seq_similarity = SequenceMatcher(None, left, right).ratio()

        # Substring match in either direction
        substring_match = 0.9 if (left in right or right in left) else 0.0

        # Word-level similarity: shared words over all words.
        # Both sets are non-empty because both strings are non-empty here.
        left_words = set(left.split())
        right_words = set(right.split())
        word_similarity = len(left_words & right_words) / len(left_words | right_words)

        # Take the maximum similarity score
        return max(seq_similarity, substring_match, word_similarity)

    @staticmethod
    def _normalize_string(s: str) -> str:
        """
        Normalize string for comparison

        Lowercases, treats "/" as a word separator, removes every other
        character that is not an ASCII letter, digit or whitespace,
        collapses whitespace, then rewrites known phrases to a canonical form.

        Args:
            s: Raw department name

        Returns:
            Normalized string; "" for empty or None input
        """
        if not s:
            return ""

        # Convert to lowercase
        s = s.lower()

        # "Development/Engineering" is two words, not one fused token
        s = s.replace('/', ' ')

        # Remove special characters and extra spaces
        s = re.sub(r'[^a-z0-9\s]', '', s)
        s = re.sub(r'\s+', ' ', s).strip()

        # Common abbreviations and variations
        for phrase, canonical in SimilarityMatcher._CANONICAL_FORMS:
            s = s.replace(phrase, canonical)

        return s.strip()

class TNASystem:
    """Main TNA System class that orchestrates the entire process

    A request is parsed, the department is inferred by the LLM, the inferred
    name is compared with every organization department, and the best
    candidate is accepted only when its score reaches the similarity threshold.
    """

    def __init__(self, model_name: str = "gemma3:12b", similarity_threshold: float = 0.96):
        """
        Initialize TNA System

        Args:
            model_name: LLM model name
            similarity_threshold: Minimum similarity score for department matching
        """
        self.inference_engine = DepartmentInferenceEngine(model_name)
        self.similarity_matcher = SimilarityMatcher()
        self.similarity_threshold = similarity_threshold

        logger.info("TNA System initialized with similarity threshold: %s", similarity_threshold)

    def process_request(self, request_data: Dict) -> TNAResponse:
        """
        Process TNA request and return department mapping result

        This method never raises: any failure is logged and reported as an
        unmatched response whose inferred_department is "Error".

        Args:
            request_data: Dictionary containing user_prompt and organization_departments

        Returns:
            TNAResponse with department mapping result
        """
        try:
            return self._run_pipeline(request_data)
        except Exception as e:
            logger.error("Error processing TNA request: %s", e)
            return TNAResponse(
                department_id=0,
                department_name="Null",
                inferred_department="Error",
                similarity_score=0.0,
                matched=False
            )

    def _run_pipeline(self, request_data: Dict) -> TNAResponse:
        """Parse, infer, match and apply the threshold; errors propagate to the caller."""
        tna_request = self._parse_request(request_data)

        # Step 1: Infer department using LLM
        inferred_department = self.inference_engine.infer_department(tna_request.user_prompt)

        # Step 2: Find best matching department
        best_match = self._find_best_match(inferred_department, tna_request.organization_departments)
        department = best_match['department']
        score = best_match['similarity_score']

        # Step 3: Apply similarity threshold. The None check matters when no
        # department scored above zero and the threshold is 0 or lower.
        if department is not None and score >= self.similarity_threshold:
            response = TNAResponse(
                department_id=department.id,
                department_name=department.name,
                inferred_department=inferred_department,
                similarity_score=score,
                matched=True
            )
        else:
            response = TNAResponse(
                department_id=0,
                department_name="Null",
                inferred_department=inferred_department,
                similarity_score=score,
                matched=False
            )

        logger.info("TNA processing completed: %s", response)
        return response

    def _parse_request(self, request_data: Dict) -> TNARequest:
        """
        Parse request data into TNARequest object

        Raises:
            ValueError: If request_data is not a dictionary, lacks a required
                field, or contains a department entry without 'id' and 'name'
        """
        if not isinstance(request_data, dict):
            raise ValueError("Request must be an object with 'user_prompt' and 'organization_departments'")

        missing = [
            field for field in ('user_prompt', 'organization_departments')
            if field not in request_data
        ]
        if missing:
            raise ValueError(f"Missing required field(s): {', '.join(missing)}")

        try:
            departments = [
                Department(id=dept['id'], name=dept['name'])
                for dept in request_data['organization_departments']
            ]
        except (KeyError, TypeError) as e:
            # A bare KeyError('id') would tell the caller nothing useful
            raise ValueError(
                "organization_departments must be a list of objects with 'id' and 'name'"
            ) from e

        return TNARequest(
            user_prompt=request_data['user_prompt'],
            organization_departments=departments
        )

    def _find_best_match(self, inferred_dept: str, org_departments: List[Department]) -> Dict:
        """
        Find the best matching department from organization's departments

        Returns:
            Dict with 'department' (None when no department scored above 0)
            and 'similarity_score'. On equal scores the earliest department wins.
        """
        best_match = {
            'department': None,
            'similarity_score': 0.0
        }

        for dept in org_departments:
            similarity_score = self.similarity_matcher.calculate_similarity(
                inferred_dept, dept.name
            )

            if similarity_score > best_match['similarity_score']:
                best_match['department'] = dept
                best_match['similarity_score'] = similarity_score

        winner = best_match['department']
        logger.info(
            "Best match for '%s': %s (Score: %.4f)",
            inferred_dept,
            winner.name if winner else 'None',
            best_match['similarity_score'],
        )

        return best_match

    def _error_payload(self, message: str) -> Dict:
        """Build a failure response carrying the same keys as a success response."""
        return {
            'success': False,
            'error': message,
            'department_id': 0,
            'department_name': 'Null',
            'inferred_department': 'Error',
            'similarity_score': 0.0,
            'matched': False,
            'threshold_used': self.similarity_threshold
        }

    def process_json_request(self, json_data: str) -> Dict:
        """
        Process JSON request and return JSON response

        Args:
            json_data: JSON string with request data

        Returns:
            Dictionary with response data. 'success' is False, and 'error'
            holds the reason, when the JSON is invalid or processing fails.
        """
        try:
            request_data = json.loads(json_data)
            # Call the raising pipeline directly: process_request() swallows
            # errors, which would make every failure look like a success here
            response = self._run_pipeline(request_data)
        except json.JSONDecodeError as e:
            logger.error("JSON decode error: %s", e)
            return self._error_payload(f'Invalid JSON format: {e}')
        except Exception as e:
            logger.error("Processing error: %s", e)
            return self._error_payload(str(e))

        return {
            'success': True,
            'department_id': response.department_id,
            'department_name': response.department_name,
            'inferred_department': response.inferred_department,
            'similarity_score': round(response.similarity_score, 4),
            'matched': response.matched,
            'threshold_used': self.similarity_threshold
        }

# Example usage and testing
def main():
    """Run one hard-coded sample request through the TNA system and print the outcome."""

    # Build the system with the same threshold as its default
    engine = TNASystem(similarity_threshold=0.96)

    # (id, name) pairs describing the sample organization
    org_units = (
        (160, "Operations"),
        (161, "HR"),
        (162, "Finance"),
        (170, "Product Management"),
        (171, "sales"),
        (172, "Developer"),
    )
    payload = {
        "user_prompt": "I Want to improve my technical skills",
        "organization_departments": [
            {"id": unit_id, "name": unit_name} for unit_id, unit_name in org_units
        ],
    }

    # Send the request through the JSON entry point
    print("Processing TNA request...")
    outcome = engine.process_json_request(json.dumps(payload))

    # Report each field as "<label>: <value>", in a fixed order
    report_fields = (
        ("Success", "success"),
        ("Department ID", "department_id"),
        ("Department Name", "department_name"),
        ("Inferred Department", "inferred_department"),
        ("Similarity Score", "similarity_score"),
        ("Matched", "matched"),
        ("Threshold Used", "threshold_used"),
    )
    print("\nTNA Results:")
    for label, key in report_fields:
        print(f"{label}: {outcome[key]}")

# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()