"""
================================================================================
 Storigo Content Generator - Professional Edition v7.0
 - Synchronous image generation during slide creation
 - Direct image URL assignment (image stays None when no image is found)
 - One slide processed completely before moving to next
 - Professional error handling and logging
 - GUARANTEED: Every slide with is_image=1 gets an image URL or None
================================================================================
"""

import os
import time
import random
import re
import json
import requests
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from pydantic import BaseModel, Field, validator
from typing import List, Optional, Dict, Union
from collections import deque
from threading import Lock
import logging

# Import the professional synchronous image generator
from storigo_image_generator import fetch_image_for_slide

# --------------------------------------------------------------------------
# Configuration
# --------------------------------------------------------------------------
# The Groq API key is read from the GROQ_API_KEY environment variable so that
# no credential lives in source control.
# NOTE(review): a real key was previously hard-coded on this line; it should
# be treated as compromised and rotated.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - [%(levelname)s] - (StorigoGen) - %(message)s'
)
logger = logging.getLogger(__name__)

# Surface a missing key at import time instead of only at the first API call.
if not GROQ_API_KEY:
    logger.warning(
        "GROQ_API_KEY environment variable is not set; "
        "Groq API calls are expected to fail until it is provided"
    )

# --------------------------------------------------------------------------
# Rate Limiter Class
# --------------------------------------------------------------------------
class SmartRateLimiter:
    """Sliding-window rate limiter used before each Groq API call.

    Keeps the timestamps of recent requests in a deque. When the number of
    requests recorded within the last 60 seconds reaches ``max_requests``,
    ``wait_if_needed`` sleeps until the oldest one has left the window.

    Timestamps come from ``time.monotonic()`` so that system clock
    adjustments cannot shorten or lengthen the window.

    Attributes:
        max_requests: Maximum number of requests allowed per 60-second window.
        request_times: Monotonic timestamps of the requests still in the window.
        lock: Guards all state; it is held while sleeping, so concurrent
            callers queue up behind the waiting one.
        total_requests: Number of requests recorded so far.
        total_waits: Number of times a caller had to sleep.
    """

    # Length of the sliding window, in seconds.
    _WINDOW_SECONDS = 60

    def __init__(self, max_requests_per_minute=25):
        """Create a limiter allowing ``max_requests_per_minute`` calls per window.

        Raises:
            ValueError: If ``max_requests_per_minute`` is less than 1. Such a
                limit could never be satisfied and previously surfaced as an
                ``IndexError`` on the first ``wait_if_needed`` call.
        """
        if max_requests_per_minute < 1:
            raise ValueError("max_requests_per_minute must be at least 1")
        self.max_requests = max_requests_per_minute
        self.request_times = deque()
        self.lock = Lock()
        self.total_requests = 0
        self.total_waits = 0

    def _drop_expired(self, now):
        """Remove timestamps that are older than the window relative to ``now``."""
        while self.request_times and now - self.request_times[0] > self._WINDOW_SECONDS:
            self.request_times.popleft()

    def wait_if_needed(self):
        """Record one request, sleeping first if the window is already full."""
        with self.lock:
            now = time.monotonic()
            self._drop_expired(now)

            if len(self.request_times) >= self.max_requests:
                # Sleep until the oldest request is just past the window; the
                # extra second guarantees it is pruned afterwards.
                sleep_time = self._WINDOW_SECONDS - (now - self.request_times[0]) + 1
                logger.warning(f"⏳ Rate limit protection: waiting {sleep_time:.1f}s")
                self.total_waits += 1
                time.sleep(sleep_time)
                self._drop_expired(time.monotonic())

            self.request_times.append(time.monotonic())
            self.total_requests += 1

# Module-wide limiter shared by all Groq calls in this file.
rate_limiter = SmartRateLimiter(max_requests_per_minute=25)

# --------------------------------------------------------------------------
# Pydantic Models
# --------------------------------------------------------------------------
class SlideContent(BaseModel):
    """One generated slide: heading, body paragraphs, image hint and image URL."""

    # Item kind marker; defaults to "flash".
    type: str = Field(default="flash")
    # Optional slide heading.
    subheading: Optional[str] = Field(default=None)
    # Body paragraphs (required).
    paragraphs: List[str] = Field(default=...)
    # Short search phrase used to look up an image (required).
    visualization_suggestion: str = Field(default=...)
    # Image URL, or None when no image is attached.
    image: Optional[str] = Field(default=None)

    class Config:
        validate_assignment = True
        arbitrary_types_allowed = True

    @validator('image', pre=True, always=True)
    def validate_image_path(cls, v):
        """Normalize ``image`` to a non-empty stripped string or None.

        Anything that is not a string, a blank string, and the literal text
        "None" all collapse to None.
        """
        if not isinstance(v, str):
            return None
        cleaned = v.strip()
        if cleaned in ("", "None"):
            return None
        return cleaned

class MCQContent(BaseModel):
    """One multiple-choice question with its options and correct answer."""

    # Item kind marker; defaults to "Question".
    type: str = Field(default="Question")
    question: str = Field(default=..., description="The multiple-choice question")
    options: List[str] = Field(default=..., description="A list of 4 answer options")
    correct_answer: str = Field(default=..., description="The correct answer")

class StorigoContent(BaseModel):
    """Slides-only result: slide key -> slide content, plus a token count."""

    # Mapping of slide key to its content (required).
    slides: Dict[str, SlideContent] = Field(default=...)
    # Word-based token count of the generated content.
    token_count: int = Field(default=0)

class StorigoContentMCQMid(BaseModel):
    """Result whose ``slides`` mapping may hold both slides and MCQ items."""

    # Mapping of item key to either a slide or an MCQ (required).
    slides: Dict[str, Union[SlideContent, MCQContent]] = Field(default=...)
    # Word-based token count of the generated content.
    token_count: int = Field(default=0)

# --------------------------------------------------------------------------
# Utility Functions
# --------------------------------------------------------------------------
def count_tokens(text: str) -> int:
    """Approximate a token count as the number of word-character runs in *text*."""
    return sum(1 for _ in re.finditer(r'\w+', text))

def quick_json_fix(ai_message) -> str:
    """Extract the JSON object text from an LLM response.

    Args:
        ai_message: Either an object exposing a ``content`` attribute (its
            value is used) or any other value, which is converted with
            ``str()``.

    Returns:
        The text of the first ``{``-delimited JSON object when the text from
        the first ``{`` onward starts with a valid JSON document. Otherwise
        the span from the first ``{`` to the last ``}`` if both exist, and
        failing that the cleaned text unchanged. The caller is still
        responsible for ``json.loads`` and for handling malformed output.
    """
    text = ai_message.content if hasattr(ai_message, 'content') else str(ai_message)
    text = text.strip()

    # Strip a surrounding markdown code fence. The opening fence may carry any
    # language tag (or none); only a fence at the very end is removed so that
    # backticks inside JSON string values are left untouched.
    if text.startswith("```"):
        text = re.sub(r'^```[A-Za-z]*\s*', '', text)
        text = re.sub(r'\s*```\s*$', '', text)

    # Preferred path: decode exactly one JSON document starting at the first
    # brace. This ignores any commentary after the object, even if that
    # commentary itself contains braces.
    start = text.find('{')
    if start != -1:
        try:
            _, consumed = json.JSONDecoder().raw_decode(text[start:])
        except ValueError:
            pass  # not valid JSON from here; fall back to the greedy span
        else:
            return text[start:start + consumed]

    # Fallback: widest brace-delimited span, so the caller's JSON error
    # reporting still sees the model's (malformed) object.
    json_match = re.search(r'\{.*\}', text, re.DOTALL)
    if json_match:
        return json_match.group(0)

    return text

# --------------------------------------------------------------------------
# Main Generation Function - PROFESSIONAL VERSION
# --------------------------------------------------------------------------
def _slide_sort_key(slide_key: str):
    """Return a sort key ordering slide keys by their first embedded number.

    Keys without any digits sort after all numbered keys (keeping their
    relative order) instead of raising, so a single oddly named key in the
    LLM output does not abort the whole generation.
    """
    number = re.search(r'\d+', slide_key)
    if number:
        return (0, int(number.group(0)))
    return (1, 0)


def _merge_slides_and_mcqs(slides: dict, mcqs: dict, distribute: bool) -> dict:
    """Combine slides and MCQs into a single insertion-ordered dict.

    With ``distribute`` false (or no MCQs) all slides come first, followed by
    all MCQs. With ``distribute`` true an MCQ is inserted after every
    ``len(slides) // len(mcqs)`` slides (at least every slide); any MCQs that
    did not get a slot are appended at the end so none are lost.
    """
    merged = {}
    if not distribute or not mcqs:
        merged.update(slides)
        merged.update(mcqs)
        return merged

    mcq_items = list(mcqs.items())
    # max(1, ...) keeps interleaving working when there are more MCQs than slides.
    interval = max(1, len(slides) // len(mcq_items))
    placed = 0
    for position, (slide_key, slide) in enumerate(slides.items(), start=1):
        merged[slide_key] = slide
        if position % interval == 0 and placed < len(mcq_items):
            mcq_key, mcq = mcq_items[placed]
            merged[mcq_key] = mcq
            placed += 1

    # Leftovers (more MCQs than slots) go at the end rather than being dropped.
    for mcq_key, mcq in mcq_items[placed:]:
        merged[mcq_key] = mcq
    return merged


def generate_slide_content_from_prompt(
    prompt: str,
    num_slides: int,
    num_mcqs: int,
    is_image: bool,
    is_question: str,
    question_position: str,
    GPU: int
) -> Union[StorigoContent, StorigoContentMCQMid]:
    """
    Generate slide content with inline image lookup and optional MCQs.

    Steps:
    1. Ask the LLM for ``num_slides`` slides as JSON and parse them.
    2. If ``is_image`` is truthy, call ``fetch_image_for_slide`` for each
       slide in order and store the returned URL (or None) on the slide.
    3. Compute a word-based token count over the slide text.
    4. If MCQs are requested, ask the LLM for ``num_mcqs`` questions based on
       the slide text and merge them with the slides.

    Args:
        prompt: Main topic prompt
        num_slides: Number of slides to generate
        num_mcqs: Number of MCQs
        is_image: Whether to fetch images
        is_question: MCQs are generated when this is truthy and not one of
            "0", "false", "False"
        question_position: "1" (or 1) interleaves MCQs between slides; any
            other value places all MCQs after the slides
        GPU: Legacy parameter; not used by this function

    Returns:
        StorigoContentMCQMid when MCQs were requested and the MCQ step
        succeeded; StorigoContent otherwise (including when the MCQ step
        failed, in which case only the slides are returned).

    Raises:
        Exception: Any failure outside the MCQ step is logged and re-raised
            as ``Exception("Content generation failed: ...")`` chained to the
            original error.
    """
    generation_start = time.time()

    try:
        logger.info("=" * 80)
        logger.info("🚀 STORIGO PROFESSIONAL CONTENT GENERATOR v7.0")
        logger.info("=" * 80)
        logger.info(f"📋 Configuration:")
        logger.info(f"   - Slides: {num_slides}")
        logger.info(f"   - MCQs: {num_mcqs}")
        logger.info(f"   - Images: {is_image}")
        logger.info(f"   - Prompt: {prompt[:60]}...")
        logger.info("=" * 80)

        # Initialize Groq API
        logger.info("🔧 Initializing Groq API...")
        llm = ChatGroq(
            model_name='llama-3.1-8b-instant',
            groq_api_key=GROQ_API_KEY,
            temperature=0.7,
            max_tokens=4096,
            request_timeout=60,
            max_retries=3
        )
        logger.info("✅ Groq API ready")

        # ======================================================================
        # STEP 1: Generate Slide Content
        # ======================================================================
        logger.info("\n" + "=" * 80)
        logger.info("📝 STEP 1: Generating Slide Content")
        logger.info("=" * 80)

        slide_content_template = """
Based on the following prompt, generate professional content for exactly {num_slides} slides.

Each slide MUST include:
- A clear, descriptive subheading
- 2-3 informative paragraphs (well-written and engaging)
- A specific visualization suggestion (3-5 words describing a concrete, searchable image)

CRITICAL for visualization_suggestion:
- Use ONLY concrete, visual, searchable terms
- Think: "What would I search on a stock photo website?"
- Good examples: "team meeting conference room", "data analytics dashboard", "kitchen staff commercial"
- Bad examples: "concept of leadership", "abstract representation"

Prompt: {prompt}

Make content professional with:
- Thought-provoking insights
- Relevant examples or statistics
- Industry trends
- Clear, actionable takeaways

Return ONLY a valid JSON object with this exact structure (no explanations, no markdown, no schema definitions):
{{
  "slides": {{
    "slide_1": {{
      "subheading": "Introduction to Kitchen Safety",
      "paragraphs": [
        "Kitchen safety is paramount in any food service environment.",
        "Proper protocols protect staff and ensure food quality."
      ],
      "visualization_suggestion": "kitchen staff commercial",
      "image": null
    }},
    "slide_2": {{
      "subheading": "Food Handling Best Practices",
      "paragraphs": [
        "Proper food handling prevents contamination and foodborne illness.",
        "Key practices include temperature control and cross-contamination prevention."
      ],
      "visualization_suggestion": "food preparation hygiene",
      "image": null
    }}
  }},
  "token_count": 0
}}

Generate exactly {num_slides} slides numbered slide_1 through slide_{num_slides}.
"""

        slide_prompt = ChatPromptTemplate.from_template(slide_content_template)

        slide_chain = (
            {
                "prompt": lambda x: x["prompt"],
                "num_slides": lambda x: x["num_slides"]
            }
            | slide_prompt
            | llm
            | quick_json_fix
        )

        # Generate with rate limiting
        rate_limiter.wait_if_needed()

        slide_gen_start = time.time()
        raw_result = slide_chain.invoke({"prompt": prompt, "num_slides": num_slides})
        slide_gen_time = time.time() - slide_gen_start

        # Parse the JSON response
        try:
            result_data = json.loads(raw_result)
            result = StorigoContent(**result_data)
            logger.info(f"✅ Generated {len(result.slides)} slides in {slide_gen_time:.2f}s")
        except json.JSONDecodeError as e:
            logger.error(f"❌ JSON parsing failed: {str(e)}")
            logger.error(f"Raw response: {raw_result[:500]}...")
            raise Exception(f"Failed to parse slide content JSON: {str(e)}") from e

        # Sort slides by slide number (numerical order); keys without a number
        # are kept, after the numbered ones, instead of raising.
        ordered_slides = dict(
            sorted(result.slides.items(), key=lambda item: _slide_sort_key(item[0]))
        )

        # Validate content
        for slide_key, slide_content in ordered_slides.items():
            # Ensure minimum paragraphs
            while len(slide_content.paragraphs) < 2:
                slide_content.paragraphs.append("Additional information for this slide.")

            # Ensure visualization suggestion exists
            if not slide_content.visualization_suggestion or not slide_content.visualization_suggestion.strip():
                slide_content.visualization_suggestion = "professional business concept"
                logger.warning(f"⚠️ {slide_key}: Empty visualization, using default")

        # ======================================================================
        # STEP 2: Generate Images INLINE (One by One)
        # ======================================================================
        if is_image:
            logger.info("\n" + "=" * 80)
            logger.info("🎨 STEP 2: Generating Images (INLINE - One by One)")
            logger.info("=" * 80)

            success_count = 0
            failed_count = 0

            for slide_key, slide_content in ordered_slides.items():
                try:
                    viz_prompt = slide_content.visualization_suggestion
                    logger.info(f"\n📸 Processing {slide_key}...")
                    logger.info(f"   Visualization: '{viz_prompt}'")

                    # CRITICAL: Fetch image synchronously
                    image_start = time.time()
                    image_url = fetch_image_for_slide(slide_key, viz_prompt)
                    image_time = time.time() - image_start

                    if image_url:
                        # ASSIGN IMAGE URL
                        slide_content.image = image_url
                        success_count += 1
                        logger.info(f"   ✅ Image assigned: {image_url} ({image_time:.2f}s)")
                    else:
                        # NO IMAGE FOUND
                        slide_content.image = None
                        failed_count += 1
                        logger.warning(f"   ⚠️ No image found ({image_time:.2f}s)")

                except Exception as e:
                    # Best effort: one failed image must not abort the deck.
                    logger.error(f"   ❌ Error: {str(e)}")
                    slide_content.image = None
                    failed_count += 1

            logger.info("\n" + "=" * 80)
            logger.info(f"🎨 Image Generation Summary:")
            logger.info(f"   - Success: {success_count}/{len(ordered_slides)}")
            logger.info(f"   - Failed: {failed_count}/{len(ordered_slides)}")
            logger.info("=" * 80)
        else:
            # No images requested - set all to None
            logger.info("\n📷 Image generation disabled (is_image=False)")
            for slide_content in ordered_slides.values():
                slide_content.image = None

        # Update result with ordered slides
        result.slides = ordered_slides

        # ======================================================================
        # STEP 3: Calculate Token Count
        # ======================================================================
        token_count = 0
        for slide_content in result.slides.values():
            # A missing subheading contributes nothing (previously the literal
            # text "None" was counted as a token).
            text = f"{slide_content.subheading or ''} {' '.join(slide_content.paragraphs)} {slide_content.visualization_suggestion}"
            token_count += count_tokens(text)

        # ======================================================================
        # STEP 4: Generate MCQs (if requested)
        # ======================================================================
        if is_question and is_question not in ["0", "false", "False"]:
            logger.info("\n" + "=" * 80)
            logger.info(f"❓ STEP 4: Generating {num_mcqs} MCQs (Batch Mode)")
            logger.info("=" * 80)

            time.sleep(1.5)  # Rate limiting

            # Build context from slides
            all_context = []
            for slide in result.slides.values():
                title = slide.subheading
                paras = ' '.join(slide.paragraphs)
                all_context.append(f"**{title}**\n{paras}")

            full_context = "\n\n".join(all_context)

            mcq_template = """
You are an expert educational content creator. Generate exactly {num_mcqs} high-quality MCQs.

**Requirements:**
1. Test comprehension of key concepts
2. Distribute across different topics
3. Each question has exactly 4 options
4. Only one clearly correct answer
5. Avoid trivial questions

**Content:**
{context}

**Output Format:**
Return ONLY a valid JSON array (no markdown, no explanations):

[
    {{
        "type": "Question",
        "question": "Clear question text?",
        "options": ["Option 1", "Option 2", "Option 3", "Option 4"],
        "correct_answer": "Option 1"
    }}
]

Generate exactly {num_mcqs} questions now:
"""

            try:
                mcq_prompt = ChatPromptTemplate.from_template(mcq_template)

                rate_limiter.wait_if_needed()

                mcq_start = time.time()
                mcq_result = (mcq_prompt | llm).invoke({
                    "context": full_context,
                    "num_mcqs": num_mcqs
                })
                mcq_time = time.time() - mcq_start

                # Parse MCQs
                content = mcq_result.content.strip()

                # Remove markdown
                if content.startswith("```"):
                    content = re.sub(r'^```json?\s*', '', content)
                    content = re.sub(r'\s*```', '', content)

                # Extract JSON array
                array_match = re.search(r'\[\s*\{.*\}\s*\]', content, re.DOTALL)
                if array_match:
                    content = array_match.group(0)

                mcq_list = json.loads(content)

                mcqs = {}
                for idx, mcq_data in enumerate(mcq_list[:num_mcqs]):
                    try:
                        if 'type' not in mcq_data:
                            mcq_data['type'] = 'Question'

                        mcqs[f"mcq_{idx + 1}"] = MCQContent(**mcq_data)
                        logger.info(f"✅ MCQ {idx + 1}: {mcq_data['question'][:50]}...")
                    except Exception as e:
                        # Skip a single malformed MCQ and keep the rest.
                        logger.warning(f"⚠️ Failed to parse MCQ {idx + 1}: {e}")

                logger.info(f"✅ Generated {len(mcqs)}/{num_mcqs} MCQs in {mcq_time:.2f}s")

                # Calculate MCQ token count
                for mcq in mcqs.values():
                    text = f"{mcq.question} {' '.join(mcq.options)} {mcq.correct_answer}"
                    token_count += count_tokens(text)

                # Combine slides and MCQs
                distribute = question_position in ("1", 1)
                if distribute:
                    logger.info("📍 Distributing MCQs throughout slides")
                else:
                    logger.info("📍 Placing all MCQs at end")
                final_content = _merge_slides_and_mcqs(result.slides, mcqs, distribute)

                generation_time = time.time() - generation_start

                logger.info("\n" + "=" * 80)
                logger.info("🎉 GENERATION COMPLETE")
                logger.info("=" * 80)
                logger.info(f"📊 Final Statistics:")
                logger.info(f"   - Total time: {generation_time:.2f}s")
                logger.info(f"   - Slides: {len(result.slides)}")
                logger.info(f"   - MCQs: {len(mcqs)}")
                logger.info(f"   - Total items: {len(final_content)}")
                logger.info(f"   - Token count: {token_count}")
                if is_image:
                    logger.info(f"   - Images assigned: {success_count}/{len(result.slides)}")
                logger.info("=" * 80)

                return StorigoContentMCQMid(slides=final_content, token_count=token_count)

            except Exception as e:
                # Deliberate best effort: MCQ problems degrade to a slides-only result.
                logger.error(f"❌ MCQ generation failed: {str(e)}")
                # Return slides only
                return StorigoContent(slides=result.slides, token_count=token_count)

        else:
            # No MCQs requested
            generation_time = time.time() - generation_start

            logger.info("\n" + "=" * 80)
            logger.info("🎉 GENERATION COMPLETE")
            logger.info("=" * 80)
            logger.info(f"📊 Final Statistics:")
            logger.info(f"   - Total time: {generation_time:.2f}s")
            logger.info(f"   - Slides: {len(result.slides)}")
            logger.info(f"   - Token count: {token_count}")
            if is_image:
                logger.info(f"   - Images assigned: {success_count}/{len(result.slides)}")
            logger.info("=" * 80)

            return StorigoContent(slides=result.slides, token_count=token_count)

    except Exception as e:
        logger.error("=" * 80)
        logger.error(f"❌ FATAL ERROR")
        logger.error(f"Error: {str(e)}")
        logger.error("=" * 80)
        import traceback
        logger.error(traceback.format_exc())
        # Chain the original exception so callers can inspect the root cause.
        raise Exception(f"Content generation failed: {str(e)}") from e


# --------------------------------------------------------------------------
# Testing
# --------------------------------------------------------------------------
if __name__ == "__main__":
    # Manual smoke test: generates a small deck with images and MCQs and
    # prints a summary. Requires working Groq and image-service access.
    banner = "=" * 80

    print("\n" + banner)
    print("🧪 TESTING STORIGO CONTENT GENERATOR v7.0")
    print(banner + "\n")

    test_result = generate_slide_content_from_prompt(
        prompt="Kitchen safety and food handling best practices",
        num_slides=3,
        num_mcqs=2,
        is_image=True,
        is_question="1",
        question_position="end",
        GPU=1
    )

    print("\n" + banner)
    print("📋 TEST RESULTS")
    print(banner)

    # Split the combined mapping back into slides and MCQs by key prefix.
    slide_entries = [
        (item_key, item) for item_key, item in test_result.slides.items()
        if item_key.startswith('slide_')
    ]
    slide_count = len(slide_entries)
    mcq_count = len([k for k in test_result.slides if k.startswith('mcq_')])
    images_with_url = len(
        [item for _, item in slide_entries if getattr(item, 'image', None)]
    )

    print(f"Total items: {len(test_result.slides)}")
    print(f"  - Slides: {slide_count}")
    print(f"  - MCQs: {mcq_count}")
    print(f"  - Images assigned: {images_with_url}/{slide_count}")
    print(f"Token count: {test_result.token_count}")

    print("\n📸 Image URLs:")
    for key, content in slide_entries:
        img = getattr(content, 'image', None)
        if img:
            print(f"  ✅ {key}: {img}")
        else:
            print(f"  ⚠️ {key}: None")

    print("\n" + banner)
    print("✅ Test completed successfully!")
    print(banner)