# Multi-Agent Roleplay System with LangGraph & Streamlit
# Using Latest LangGraph 0.5+ patterns

import streamlit as st
from typing import Dict, List, Optional, TypedDict, Annotated
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_openai import ChatOpenAI
import operator
import json

# ============================================================================
# STATE DEFINITIONS (Latest LangGraph 0.5+ Pattern)
# ============================================================================

class AgentState(TypedDict):
    """Shared state passed between every node of the roleplay graph.

    Every key a node writes must be declared here; LangGraph only tracks
    declared keys as state channels (see the StateGraph reference), so an
    undeclared key such as the validator's raw reply would not reach the
    next node.
    """
    # Conversation history; the add_messages reducer merges node updates in.
    messages: Annotated[list[BaseMessage], add_messages]
    # Training content that questions are generated from and answers checked against.
    knowledge_base: str
    # Question currently awaiting an answer (None when no question is pending).
    current_question: Optional[str]
    # Learner's latest answer (None until one is submitted).
    user_answer: Optional[str]
    # Validator's verdict on the latest answer.
    is_correct: Optional[bool]
    # Raw validator reply (verdict word plus explanation), consumed by the feedback agent.
    validation_result: Optional[str]
    # Feedback text shown to the learner for the latest answer.
    feedback: Optional[str]
    # Counters keyed by "correct_answers" and "total_questions".
    session_progress: Dict[str, int]
    # Topic labels recorded by the progress tracker.
    topics_covered: List[str]
    # False once the session has been closed.
    should_continue: bool

# ============================================================================
# AGENT DEFINITIONS
# ============================================================================

class ScenarioManagerAgent:
    """Entry node of the graph: greets the learner once and closes the session.

    Annotations on ``__call__`` are quoted forward references so the class
    does not need the state schema to be defined at class-creation time.
    """

    # Number of correct answers after which the session is closed.
    CORRECT_ANSWERS_TO_FINISH = 3

    def __init__(self, llm):
        # Chat model used to produce the greeting.
        self.llm = llm

    def __call__(self, state: "AgentState") -> "AgentState":
        """Manage the overall scenario flow.

        Returns a state with a greeting appended when the conversation has
        not started, a state with the closing message and
        ``should_continue=False`` when enough answers were correct, and the
        unchanged ``state`` otherwise.
        """
        # Greet only at the very start of the conversation. ``current_question``
        # alone is not a reliable "starting" signal because the progress tracker
        # clears it after every answered question.
        if not state.get("messages") and not state.get("current_question"):
            greeting_prompt = f"""
            You are a friendly training scenario manager. Start with a warm greeting like:
            "Hi! I hope you are having a nice day. Welcome to the customer support training!"
            
            Then briefly introduce what we'll be doing based on this knowledge:
            {state['knowledge_base'][:200]}...
            
            Keep it conversational and encouraging.
            """

            response = self.llm.invoke([HumanMessage(content=greeting_prompt)])

            return {
                **state,
                "messages": [AIMessage(content=response.content)],
                "should_continue": True,
            }

        # Close the session once the learner has enough correct answers.
        progress = state.get("session_progress") or {}
        if progress.get("correct_answers", 0) >= self.CORRECT_ANSWERS_TO_FINISH:
            end_prompt = """
            Congratulations! You've completed this training session successfully. 
            Great job on learning the customer support guidelines!
            """
            return {
                **state,
                "messages": state["messages"] + [AIMessage(content=end_prompt)],
                "should_continue": False,
            }

        # Nothing to do at this step; routing decides what happens next.
        return state

class QuestionGeneratorAgent:
    """Node that asks the chat model for the next training question."""

    def __init__(self, llm):
        # Chat model used to write the question.
        self.llm = llm

    def __call__(self, state: AgentState) -> AgentState:
        """Produce one question from the knowledge base and record it in the state.

        The stripped reply becomes ``current_question``; the raw reply is
        appended to ``messages`` as an assistant message.
        """
        already_covered = state.get("topics_covered", [])

        instructions = f"""
        Based on this knowledge base:
        {state['knowledge_base']}
        
        Topics already covered: {already_covered}
        
        Generate ONE specific question to test the learner's understanding.
        Focus on practical scenarios they might encounter.
        
        Make it clear and concise.
        """

        reply = self.llm.invoke([HumanMessage(content=instructions)])

        # Copy the incoming state, then overlay the two keys this node owns.
        updated = dict(state)
        updated["current_question"] = reply.content.strip()
        updated["messages"] = state["messages"] + [AIMessage(content=reply.content)]
        return updated

class KnowledgeValidatorAgent:
    """Node that grades the learner's answer against the knowledge base.

    Annotations on ``__call__`` are quoted forward references so the class
    does not need the state schema to be defined at class-creation time.
    """

    def __init__(self, llm):
        # Chat model used to grade the answer.
        self.llm = llm

    @staticmethod
    def _parse_verdict(verdict_text: str) -> bool:
        """Return True only when the first verdict word in the reply is CORRECT.

        A plain substring test cannot be used: "INCORRECT" contains
        "CORRECT", so every reply would be graded as correct. The reply is
        split into alphabetic words and the first whole-word verdict wins.
        A reply with no verdict word is treated as incorrect.
        """
        letters_only = "".join(
            ch if ch.isalpha() else " " for ch in verdict_text.upper()
        )
        for word in letters_only.split():
            if word == "INCORRECT":
                return False
            if word == "CORRECT":
                return True
        return False

    def __call__(self, state: "AgentState") -> "AgentState":
        """Validate the user's answer against the knowledge base.

        Returns ``state`` unchanged when there is no answer to grade;
        otherwise returns a copy with ``is_correct`` and the raw model reply
        in ``validation_result``.
        """
        if not state.get("user_answer"):
            return state

        prompt = f"""
        Knowledge Base:
        {state['knowledge_base']}
        
        Question: {state['current_question']}
        User Answer: {state['user_answer']}
        
        Is the user's answer correct according to the knowledge base?
        Respond with only "CORRECT" or "INCORRECT" followed by a brief explanation.
        """

        response = self.llm.invoke([HumanMessage(content=prompt)])

        return {
            **state,
            "is_correct": self._parse_verdict(response.content),
            "validation_result": response.content,
        }

class FeedbackAgent:
    """Node that turns the validator's reply into learner-facing feedback.

    Annotations on ``__call__`` are quoted forward references so the class
    does not need the state schema to be defined at class-creation time.
    """

    # Markdown/quote decoration that may wrap the verdict word.
    _DECORATION = "*_#`\"' \t\r\n"
    # Separators that may follow the verdict word before the explanation.
    _SEPARATORS = "*_`\"' \t\r\n.:;,!-–—"

    def __init__(self, llm):
        # Kept for interface parity with the other agents; not used here.
        self.llm = llm

    @classmethod
    def _strip_verdict(cls, validation_result, verdict: str) -> str:
        """Return the explanation with only a *leading* verdict word removed.

        Removing every occurrence of the verdict would also delete the word
        from the explanation itself (e.g. "the CORRECT first step"). A missing
        or ``None`` validation result yields an empty string.
        """
        text = (validation_result or "").strip()
        candidate = text.lstrip(cls._DECORATION)
        if candidate.upper().startswith(verdict):
            remainder = candidate[len(verdict):]
            # Require a word boundary so "Correctly ..." is left untouched.
            if not remainder[:1].isalpha():
                text = remainder.lstrip(cls._SEPARATORS)
        return text.strip()

    def __call__(self, state: "AgentState") -> "AgentState":
        """Provide feedback and coaching.

        Stores the feedback text under ``feedback`` and appends it to
        ``messages`` as an assistant message.
        """
        raw_result = state.get("validation_result")
        if state.get("is_correct"):
            feedback = f"✅ Correct! {self._strip_verdict(raw_result, 'CORRECT')}"
        else:
            feedback = f"❌ {self._strip_verdict(raw_result, 'INCORRECT')}"

        return {
            **state,
            "feedback": feedback,
            "messages": state["messages"] + [AIMessage(content=feedback)],
        }

class ProgressTrackerAgent:
    """Node that updates score counters and covered topics after an answer.

    Annotations on ``__call__`` are quoted forward references so the class
    does not need the state schema to be defined at class-creation time.
    """

    # (keyword searched in the question, topic label recorded); first match wins.
    _TOPIC_KEYWORDS = (
        ("priority", "priority_levels"),
        ("greeting", "customer_greetings"),
        ("troubleshoot", "troubleshooting"),
    )

    def __init__(self):
        pass

    def __call__(self, state: "AgentState") -> "AgentState":
        """Track learner progress.

        Returns a copy of ``state`` with updated ``session_progress`` and
        ``topics_covered``, and with ``user_answer`` / ``current_question``
        reset to ``None`` for the next question. The incoming progress dict
        and topic list are never modified.
        """
        previous = state.get("session_progress") or {}
        # Work on a copy and default the counters so a partial or empty
        # progress dict cannot raise KeyError.
        progress = {
            **previous,
            "correct_answers": previous.get("correct_answers", 0),
            "total_questions": previous.get("total_questions", 0),
        }
        topics_covered = list(state.get("topics_covered") or [])

        if state.get("user_answer"):
            progress["total_questions"] += 1

            if state.get("is_correct"):
                progress["correct_answers"] += 1

            # Extract topic from current question (simplified keyword match).
            question = (state.get("current_question") or "").lower()
            for keyword, topic in self._TOPIC_KEYWORDS:
                if keyword in question:
                    topics_covered.append(topic)
                    break

        return {
            **state,
            "session_progress": progress,
            # De-duplicate while keeping first-seen order (deterministic output).
            "topics_covered": list(dict.fromkeys(topics_covered)),
            "user_answer": None,  # Reset for next question
            "current_question": None,  # Reset for next question
        }

# ============================================================================
# GRAPH BUILDER (Latest LangGraph 0.5+ Pattern)
# ============================================================================

def create_roleplay_graph(openai_api_key: str):
    """Create and compile the multi-agent roleplay graph.

    Args:
        openai_api_key: API key passed to the ``ChatOpenAI`` client shared by
            all LLM-backed agents.

    Returns:
        The compiled graph. One ``invoke`` either produces the next question
        (then stops to wait for the learner) or grades the pending answer,
        gives feedback, updates progress, and then asks the next question or
        closes the session.
    """
    # Correct answers required before the session is closed.
    correct_answers_to_finish = 3

    llm = ChatOpenAI(
        model="gpt-4o-mini",
        api_key=openai_api_key,
        temperature=0.7
    )

    # Create graph with explicit state schema
    graph = StateGraph(AgentState)

    # Add nodes
    graph.add_node("scenario_manager", ScenarioManagerAgent(llm))
    graph.add_node("question_generator", QuestionGeneratorAgent(llm))
    graph.add_node("knowledge_validator", KnowledgeValidatorAgent(llm))
    graph.add_node("feedback_agent", FeedbackAgent(llm))
    graph.add_node("progress_tracker", ProgressTrackerAgent())

    def complete_session(state: AgentState) -> dict:
        """Post the closing message and mark the session as finished.

        A routing function cannot write state, so ending straight from the
        router would leave ``should_continue`` True and the UI would never
        show the completed view.
        """
        return {
            "messages": [AIMessage(content=(
                "Congratulations! You've completed this training session "
                "successfully. Great job on learning the customer support "
                "guidelines!"
            ))],
            "should_continue": False,
        }

    graph.add_node("session_complete", complete_session)

    # Define conditional routing
    def should_generate_question(state: AgentState) -> str:
        """Route to the question generator, the validator, or the end."""
        if not state.get("should_continue", True):
            return END
        if not state.get("current_question"):
            return "question_generator"
        return "knowledge_validator"

    def should_continue_session(state: AgentState) -> str:
        """Close the session after enough correct answers, else keep going."""
        progress = state.get("session_progress") or {}
        if progress.get("correct_answers", 0) >= correct_answers_to_finish:
            return "session_complete"
        return "scenario_manager"

    # Add edges
    graph.add_edge(START, "scenario_manager")
    graph.add_conditional_edges("scenario_manager", should_generate_question)
    graph.add_edge("question_generator", END)  # Wait for user input
    graph.add_edge("knowledge_validator", "feedback_agent")
    graph.add_edge("feedback_agent", "progress_tracker")
    graph.add_conditional_edges("progress_tracker", should_continue_session)
    graph.add_edge("session_complete", END)

    return graph.compile()

# ============================================================================
# STREAMLIT UI
# ============================================================================

# Default training content offered in the sidebar text area.
_DEFAULT_KNOWLEDGE_BASE = """As part of our customer support process, here are some key guidelines you should follow:
When handling support tickets, prioritize them according to their impact. High priority issues include system outages, payment failures, or security concerns — these must be addressed immediately. Medium priority covers functionality bugs that affect many users but do not completely block service. Low priority items are minor issues such as small UI glitches or feature requests, which can be scheduled for later review.
For interacting with customers, always begin with a professional and friendly greeting. Recommended phrases include: "Hello! How can I assist you today?", "Thank you for contacting our support team!", or "We're glad to help you out." These set the right tone for the conversation.
When troubleshooting internet-related issues, follow a step-by-step approach. First, ask the customer if they have tried restarting their router, as this resolves many common problems. If that doesn't help, check whether the issue is affecting other users or is specific to that individual. If the problem remains unresolved after these checks, escalate the issue to the appropriate technical team for further investigation."""


def _initial_session_state(knowledge_base: str) -> dict:
    """Build the state dict a fresh training session starts from."""
    return {
        "messages": [],
        "knowledge_base": knowledge_base,
        "current_question": None,
        "user_answer": None,
        "is_correct": None,
        "feedback": None,
        "session_progress": {"correct_answers": 0, "total_questions": 0},
        "topics_covered": [],
        "should_continue": True,
    }


def _render_sidebar() -> None:
    """Draw the configuration sidebar and start a session when requested."""
    with st.sidebar:
        st.header("⚙️ Configuration")

        # API Key input
        openai_api_key = st.text_input(
            "OpenAI API Key",
            type="password",
            help="Enter your OpenAI API key"
        )

        # Knowledge base input
        st.subheader("📚 Knowledge Base")
        knowledge_base = st.text_area(
            "Enter your training content:",
            value=_DEFAULT_KNOWLEDGE_BASE,
            height=200
        )

        if st.button("🚀 Start New Training Session"):
            # Pasted keys often carry stray whitespace; whitespace-only
            # content is treated as missing.
            api_key = openai_api_key.strip()
            if api_key and knowledge_base.strip():
                st.session_state.graph = create_roleplay_graph(api_key)
                st.session_state.state = _initial_session_state(knowledge_base)
                st.success("Training session started!")
                st.rerun()
            else:
                st.error("Please provide OpenAI API Key and Knowledge Base")


def _render_progress(state: dict) -> None:
    """Show the three progress metrics for the current session."""
    progress = state.get("session_progress", {})
    topics = state.get("topics_covered", [])

    correct_col, total_col, topics_col = st.columns(3)
    with correct_col:
        st.metric("Correct Answers", progress.get("correct_answers", 0))
    with total_col:
        st.metric("Total Questions", progress.get("total_questions", 0))
    with topics_col:
        st.metric("Topics Covered", len(topics))


def _render_conversation(state: dict) -> None:
    """Replay the stored assistant and user messages as chat bubbles."""
    st.subheader("💬 Training Conversation")
    with st.container():
        for message in state.get("messages", []):
            if isinstance(message, AIMessage):
                with st.chat_message("assistant"):
                    st.write(message.content)
            elif isinstance(message, HumanMessage):
                with st.chat_message("user"):
                    st.write(message.content)


def _advance_session(pending_state: dict) -> None:
    """Run the graph on ``pending_state``; commit the result and rerun on success.

    On failure the stored session state is left untouched and the error is
    shown in the page, so the learner can simply retry.
    """
    try:
        result = st.session_state.graph.invoke(pending_state)
    except Exception as exc:  # UI boundary: report the failure instead of a raw traceback
        st.error(f"⚠️ The training step could not be completed: {exc}")
        return
    st.session_state.state = result
    st.rerun()


def _render_turn_controls(state: dict) -> None:
    """Show the control matching the session phase: start, answer, or restart."""
    if not state.get("should_continue", True):
        st.success("🎉 Training session completed!")
        if st.button("🔄 Start New Session"):
            del st.session_state.graph
            del st.session_state.state
            st.rerun()
        return

    if not state.get("current_question"):
        # Generate greeting/question
        if st.button("🎯 Start Training", key="start_btn"):
            _advance_session(state)
        return

    # Waiting for user answer
    user_input = st.chat_input("Type your answer here...")
    if user_input:
        # Build the next state on a copy so a failed model call does not
        # leave a half-recorded answer in the session.
        pending_state = {
            **state,
            "messages": [*state.get("messages", []), HumanMessage(content=user_input)],
            "user_answer": user_input,
        }
        _advance_session(pending_state)


def _render_landing() -> None:
    """Show the introductory page displayed before a session exists."""
    st.info("👈 Configure your settings in the sidebar and start a training session!")

    # Example showcase
    st.subheader("🌟 How it works")
    st.markdown("""
        This system uses **5 specialized AI agents** working together:
        
        1. **🎭 Scenario Manager** - Orchestrates the training flow
        2. **❓ Question Generator** - Creates questions from knowledge base  
        3. **✅ Knowledge Validator** - Checks answers against knowledge
        4. **💬 Feedback Agent** - Provides coaching and explanations
        5. **📊 Progress Tracker** - Monitors learning progress
        
        **Example Flow:**
        - System greets learner warmly
        - Generates contextual questions from your knowledge base
        - Validates answers in real-time
        - Provides instant feedback with explanations
        - Tracks progress across topics
        """)


def main():
    """Render the Streamlit page: sidebar, then the session view or landing page."""
    st.set_page_config(page_title="AI Roleplay Training System", page_icon="🤖", layout="wide")

    st.title("🤖 Multi-Agent Roleplay Training System")
    st.markdown("*Powered by LangGraph & Multiple AI Agents*")

    _render_sidebar()

    if "graph" in st.session_state and "state" in st.session_state:
        session = st.session_state.state
        _render_progress(session)
        st.divider()
        _render_conversation(session)
        _render_turn_controls(session)
    else:
        _render_landing()

# Build the page only when this file runs as the main module, not on import.
if __name__ == "__main__":
    main()