import os
import time
import random
import re
import json
import heapq
from jsonschema import validate, ValidationError
from langchain_core.messages import AIMessage
import requests
import numpy as np
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from crawl4ai import AsyncWebCrawler
import asyncio
from langchain.output_parsers import OutputFixingParser
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader, UnstructuredPowerPointLoader
from PyPDF2 import PdfReader
from langchain_experimental.text_splitter import SemanticChunker
from pydantic import BaseModel, Field
from typing import List, Optional, Dict
from urllib.parse import urlparse
from typing import Union, Dict
from langchain_core.runnables import RunnableLambda
import shutil
from pptx import Presentation
from langchain_ollama import OllamaLLM
from langchain_ollama import ChatOllama
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
from pathlib import Path
from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain.output_parsers import OutputFixingParser

# Import the image generator function
from storigo_image_generator import fetch_image_for_slide

# Embedding model name for Ollama; not referenced anywhere in this module —
# presumably consumed by other modules importing it. TODO confirm callers.
OLLAMA_MODEL = "nomic-embed-text"

class SlideContent(BaseModel):
    """Schema for a single content ("flash") slide emitted by the LLM."""
    type: str = Field("flash")
    subheading: Optional[str] = Field(None, description="An optional subheading for the slide")
    paragraphs: List[str] = Field(..., description="List of paragraphs for the slide content")
    visualization_suggestion: str = Field(..., description="A detailed suggestion for a relevant visualization or image (5-8 words with specific elements)")
    image: Optional[str] = Field(None, description="URL of the image for the slide")

class MCQContent(BaseModel):
    """Schema for one multiple-choice question generated from slide context."""
    type: str = Field("Question")
    question: str = Field(..., description="The multiple-choice question")
    options: List[str] = Field(..., description="A list of 4 answer options")
    correct_answer: str = Field(..., description="The correct answer (e.g., 'a', 'b', 'c', or 'd')")

class StorigoContent(BaseModel):
    """Presentation payload: slides only (no MCQs)."""
    slides: Dict[str, SlideContent] = Field(..., description="Dictionary of slide contents with slide numbers as keys")
    token_count: int = Field(..., description="Total token count for all the generated content")

class StorigoContentMCQ(BaseModel):
    """Presentation payload with slides and a separate MCQ mapping."""
    slides: Dict[str, SlideContent] = Field(..., description="Dictionary of slide contents with slide numbers as keys")
    mcqs: Dict[str, MCQContent] = Field(..., description="Dictionary of MCQs with identifiers like 'mcq_1' as keys")
    token_count: int = Field(..., description="Total token count for all the generated content")

class StorigoContentMCQMid(BaseModel):
    """Presentation payload where MCQs are interleaved among the slides.

    The single ``slides`` dict holds both slide entries (``slide_N`` keys)
    and MCQ entries (``mcq_N`` keys) in presentation order.
    """
    slides: Dict[str, Union[SlideContent, MCQContent]] = Field(..., description="Dictionary of slide contents with slide numbers as keys and MCQs with MCQ numbers as keys")
    token_count: int = Field(..., description="Total token count for all the generated content")

def count_tokens(text):
    """Approximate a token count as the number of word-character runs."""
    return sum(1 for _ in re.finditer(r'\w+', text))

def parsing(input):
    """Convert the document at *input* to markdown text via marker's PdfConverter.

    Rendered images are discarded; only the extracted text is returned.
    """
    pdf_converter = PdfConverter(artifact_dict=create_model_dict())
    rendered_doc = pdf_converter(input)
    extracted_text, _, _ = text_from_rendered(rendered_doc)
    return extracted_text

def marks_splitter(headers_to_split_on, content):
    """Split markdown *content* into header-delimited chunks.

    *headers_to_split_on* is a list of (marker, name) pairs, e.g.
    ``[("#", "Header 1"), ("##", "Header 2")]``.
    """
    splitter = MarkdownHeaderTextSplitter(headers_to_split_on)
    return splitter.split_text(content)

def allocate_slides(chunks, total_slides, min_chars=60):
    """Distribute *total_slides* across *chunks* in proportion to text length.

    Chunks shorter than *min_chars* characters are ignored. Each surviving
    chunk gets at least one slide, then the grand total is nudged toward
    *total_slides*: over-allocation is trimmed from the largest entry (never
    below 1), under-allocation is topped up cycling through the longest
    chunks first.

    Fixes vs. the previous version: returns {} for non-positive
    total_slides (previously still handed every chunk one slide), and the
    top-up sort is computed once instead of on every loop pass.

    Args:
        chunks: Sequence of Document-like objects (``page_content`` or
            ``content`` attribute), plain strings, or anything str()-able.
        total_slides: Desired number of slides overall.
        min_chars: Minimum chunk length (characters) to be considered.

    Returns:
        Dict mapping chunk index -> slide count; empty when no chunk
        qualifies or total_slides <= 0.
    """
    if total_slides <= 0:
        return {}

    def _text_of(chunk):
        # Accept langchain Documents, raw strings, or arbitrary objects.
        if hasattr(chunk, "page_content"):
            return chunk.page_content
        if hasattr(chunk, "content"):
            return chunk.content
        if isinstance(chunk, str):
            return chunk
        return str(chunk)

    valid = [
        (i, txt)
        for i, txt in ((i, _text_of(c)) for i, c in enumerate(chunks))
        if len(txt) >= min_chars
    ]
    if not valid:
        return {}

    total_chars = sum(len(txt) for _, txt in valid)

    allocations = {}
    remaining = total_slides
    for idx, txt in valid:
        # Proportional share, but every surviving chunk gets at least one.
        cnt = max(1, round(len(txt) / total_chars * total_slides))
        allocations[idx] = cnt
        remaining -= cnt

    # Over-allocated: repeatedly shave the largest entry, never below 1.
    while remaining < 0:
        max_idx = max(allocations, key=allocations.get)
        if allocations[max_idx] <= 1:
            break
        allocations[max_idx] -= 1
        remaining += 1

    # Under-allocated: award extras to the longest chunks, cycling in
    # descending-size order (sort hoisted out of the loop — it is invariant).
    if remaining > 0:
        by_size = sorted(valid, key=lambda item: len(item[1]), reverse=True)
        pos = 0
        while remaining > 0:
            idx = by_size[pos % len(by_size)][0]
            allocations[idx] = allocations.get(idx, 0) + 1
            remaining -= 1
            pos += 1

    return allocations

class SlideCollection:
    """Insertion-ordered mapping of slide keys to slide payloads.

    A thin read-mostly wrapper over a plain dict. The backing dict is
    deliberately public as ``self.slides`` because callers reassign it
    wholesale to reorder slides.
    """

    def __init__(self):
        self.slides = {}

    def add_slide(self, key, content):
        """Store *content* under *key*, replacing any existing entry."""
        self.slides[key] = content

    # --- dict-like read protocol, all delegated to the backing dict ---

    def __getitem__(self, key):
        return self.slides[key]

    def __iter__(self):
        return iter(self.slides)

    def keys(self):
        return self.slides.keys()

    def values(self):
        return self.slides.values()

    def items(self):
        return self.slides.items()

    def __repr__(self):
        return repr(self.slides)

def quick_json_fix(ai_message) -> str:
    """Best-effort cleanup of an LLM reply so it parses as JSON.

    Steps: take ``.content`` when given a message object, strip chatty
    apology/retry preambles down to the outermost ``{...}``, repair keys
    that were opened with a single quote but closed with a double quote,
    and unwrap a top-level ``"properties"`` object (a common failure mode
    where the model echoes the JSON schema wrapper).

    Fix vs. the previous version: the bare ``except:`` (which swallowed
    everything, including SystemExit/KeyboardInterrupt) is narrowed to
    ``ValueError``, which covers ``json.JSONDecodeError``.

    Args:
        ai_message: An object with a ``.content`` attribute (e.g. an
            AIMessage) or anything str()-able.

    Returns:
        The cleaned text. Parsing is not guaranteed to succeed.
    """
    if hasattr(ai_message, 'content'):
        text = ai_message.content
    else:
        text = str(ai_message)

    text = text.strip()

    # Models sometimes prepend retry chatter; keep only the outer JSON object.
    if "Here's another attempt" in text or "I apologize" in text:
        start = text.find('{')
        end = text.rfind('}') + 1
        if start != -1 and end != 0:
            text = text[start:end]

    # Repair mixed-quote keys: 'key": -> "key":
    text = re.sub(r"'(\w+)\":", r'"\1":', text)

    try:
        parsed = json.loads(text)
    except ValueError:
        # Not valid JSON — return the cleaned text and let downstream
        # fixers (e.g. OutputFixingParser) take another pass.
        return text

    # Unwrap a schema echo: {"properties": {...}} -> {...}. The isinstance
    # guard keeps non-dict results (lists, strings) on the plain-text path,
    # matching the old behavior where indexing them raised and was swallowed.
    if isinstance(parsed, dict) and "properties" in parsed:
        return json.dumps(parsed["properties"])

    return text

async def generate_slide_content_document(temp_file_path, client_id, num_slides, num_mcqs, is_image, is_question, question_position, GPU):
    """Generate Storigo slide content (optionally with MCQs) from a document.

    Pipeline: parse the file to markdown with marker, split on H1/H2
    headers, allocate ``num_slides`` across the chunks by text length,
    invoke a local Ollama LLM per chunk to produce slide JSON, optionally
    attach an image per slide, and optionally generate MCQs and place
    them among the slides.

    Args:
        temp_file_path: Path to the uploaded document (parsed by marker).
        client_id: Not referenced in this body — TODO confirm callers rely on it.
        num_slides: Total slides to generate across all chunks.
        num_mcqs: Number of MCQs to generate when ``is_question`` is truthy.
        is_image: When truthy, fetch an image per slide from its
            visualization suggestion; otherwise every ``image`` is set to None.
        is_question: When truthy, also generate MCQs and return a
            StorigoContentMCQMid; otherwise return a StorigoContent.
        question_position: 1 interleaves MCQs between slides at regular
            intervals; any other value appends them after all slides.
        GPU: Not referenced in this body — TODO confirm callers rely on it.

    Returns:
        StorigoContentMCQMid when ``is_question`` is truthy, else
        StorigoContent. ``token_count`` is always 0 (not computed here).

    Raises:
        Exception: wraps any underlying failure with a summary message.
    """
    try:
        # Run parsing in executor since it's sync
        loop = asyncio.get_event_loop()
        parse_data = await loop.run_in_executor(None, parsing, temp_file_path)

        # NOTE(review): the markdown is written to parse_data.md and
        # immediately read back (content == parse_data) — the round-trip
        # looks redundant; presumably kept as a debugging artifact.
        with open("parse_data.md", "w", encoding="utf-8") as f:
            f.write(parse_data)
        headers_to_split_on = [
            ("#", "Header 1"),
            ("##", "Header 2")
        ]
        with open("parse_data.md", "r", encoding="utf-8") as f:
            content = f.read()
        chunks = marks_splitter(headers_to_split_on, content)
        # Proportional slide allocation; chunks under 100 chars are dropped.
        allocation = allocate_slides(chunks, num_slides, min_chars=100)

        # NOTE(review): this prompt says "maximum of 5 words" for the
        # visualization suggestion while SlideContent's field description
        # says 5-8 words — confirm which is intended.
        slide_content_template = """
Based on the following context, generate professional and engaging content for exactly {num_slides} slides in a Storigos presentation.

Each slide must include:
- A clear and concise **sub-heading**
- **Paragraphs** that effectively communicate the key ideas and insights
- A specific, concise **visualization suggestion**

**Context**: {query}

Focus on creating content that is both informative and engaging. Ensure each slide:
- Has a well-structured sub-heading that captures the main point
- Uses clear and concise paragraphs to communicate important information

Use a professional and creative tone throughout. Each slide should incorporate the following elements where appropriate:
- **Thought-provoking questions** to encourage reflection
- **Relevant statistics** or data points that add credibility
- **Industry insights** or emerging trends to demonstrate expertise
- **Practical examples** or case studies to illustrate key concepts
- **Calls to action** to guide the audience toward specific actions or takeaways

For the visualization suggestion:
- Provide a clear and specific description of an image that would be relevant to the slide content.
- Keep it very concise, using a maximum of 5 words.
- Focus on concrete objects, scenes, or concepts that can be easily visualized.
- Avoid abstract or overly complex ideas.
- Include the context of the topic (e.g., "Python programming logo" instead of just "Python logo").

Make sure all content is drawn exclusively from the provided context or embedded data. Avoid introducing external information not found in the source material.

{format_instructions}

CRITICAL: The output must be a valid JSON object with this EXACT structure:
{{
  "slides": {{
    "slide_1": {{
      "type": "flash",
      "subheading": "...",
      "paragraphs": ["...", "..."],
      "visualization_suggestion": "...",
      "image": null
    }},
    "slide_2": {{ ... }}
  }},
  "token_count": 0
}}

DO NOT put "token_count" inside the "slides" object. It must be at the root level.
DO NOT include any explanations or additional text - only the JSON object.
The final output must be in strict sequential order: "slide_1", "slide_2", ..., up to "slide_{num_slides}".
"""

        parser = PydanticOutputParser(pydantic_object=StorigoContent)
        slide_content_prompt = ChatPromptTemplate.from_template(slide_content_template)

        # Local Ollama chat model; base_url assumes a default local install.
        llm = ChatOllama(
            base_url='http://127.0.0.1:11434',
            model="gemma3:12b"
        )

        # Prompt -> LLM -> Pydantic parse; format instructions are injected
        # from the parser regardless of the chain input.
        slide_content_chain = (
            {
                "query": lambda x: x["query"],
                "num_slides": lambda x: x["num_slides"],
                "format_instructions": lambda x: parser.get_format_instructions()
            }
            | slide_content_prompt
            | llm
            | parser
        )

        all_slides = SlideCollection()
        counter = 1  # global slide number across all chunks

        # Generate each chunk's allotted slides and renumber sequentially.
        for chunk_idx in sorted(allocation):
            n = allocation[chunk_idx]
            chunk = chunks[chunk_idx]
            query = getattr(chunk, "page_content", getattr(chunk, "content", str(chunk)))

            result = slide_content_chain.invoke({"query": query, "num_slides": n})

            # The parser normally yields a StorigoContent; fall back to a
            # plain dict when it does not.
            if hasattr(result, 'slides'):
                slide_items = result.slides
            else:
                raw = result.model_dump() if hasattr(result, 'model_dump') else result
                slide_items = raw.get("slides", {})

            # Models sometimes misplace token_count inside "slides"; drop it.
            if isinstance(slide_items, dict) and "token_count" in slide_items:
                slide_items.pop("token_count")

            # Re-key the chunk's slides as slide_<counter> in overall order.
            for slide_key in sorted(slide_items.keys(), key=lambda k: int(k.split("_")[1])):
                all_slides.add_slide(f"slide_{counter}", slide_items[slide_key])
                counter += 1

        def custom_slide_sort_key(item_key: str):
            # Order all slide_* keys before mcq_* keys, each numerically by
            # suffix; unknown prefixes / non-numeric suffixes sort last.
            prefix, num_str = item_key.split('_', 1)
            if prefix == "slide":
                group = 0
            elif prefix == "mcq":
                group = 1
            else:
                group = 999
            number = int(num_str) if num_str.isdigit() else 999
            return (group, number)

        ordered_slides = dict(sorted(all_slides.slides.items(), key=lambda kv: custom_slide_sort_key(kv[0])))
        all_slides.slides = ordered_slides

        if is_image:
            # Fetch one image per slide from its visualization suggestion;
            # slides may be SlideContent objects or plain dicts, so every
            # write below is branched on the container type.
            for slide_key, slide_content in all_slides.slides.items():
                if isinstance(slide_content, dict):
                    slide_obj = SlideContent(**slide_content)
                else:
                    slide_obj = slide_content

                if slide_obj.visualization_suggestion:
                    image_path = fetch_image_for_slide(
                        slide_key,
                        slide_obj.visualization_suggestion
                    )
                    if image_path:
                        if isinstance(slide_content, dict):
                            all_slides.slides[slide_key]['image'] = image_path
                        else:
                            slide_content.image = image_path
                    else:
                        if isinstance(slide_content, dict):
                            all_slides.slides[slide_key]['image'] = None
                        else:
                            slide_content.image = None
                else:
                    if isinstance(slide_content, dict):
                        all_slides.slides[slide_key]['image'] = None
                    else:
                        slide_content.image = None
        else:
            # Images disabled: null out the image field on every slide.
            for slide_key in all_slides.slides:
                if isinstance(all_slides.slides[slide_key], dict):
                    all_slides.slides[slide_key]['image'] = None
                else:
                    all_slides.slides[slide_key].image = None

        if is_question:
            # NOTE(review): the ```json fence opened below is never closed
            # before {format_instructions} — confirm the model tolerates it.
            mcq_template = """
Based on the following context from the last two slides, generate one multiple-choice question (MCQ). The question should be relevant to the content and designed to test comprehension.

**Context**: {context}

The MCQ must include:
- A **question** related to the context
- Exactly **4 answer options**
- A clear indication of the **correct answer** as a single letter: 'a', 'b', 'c', or 'd'

⚠️ **Critical Requirements**:
- ✅ Return **only valid JSON** — no explanations, headers, or extra text.
- ✅ Ensure all fields and options are enclosed in **double quotes (`"`)**.
- ✅ Do **not** use letters like "A.", "B." in the options — just the plain text.
- Directly follow the format given below

The final output **must strictly follow** this format:
```json
{{
    "question": "<The MCQ question>",
    "options": [
        "<Option 1>",
        "<Option 2>",
        "<Option 3>",
        "<Option 4>"
    ],
    "correct_answer": "<Correct option (e.g., 'a', 'b', 'c', or 'd')>"
}}

{format_instructions}
"""
            mcq_parser = PydanticOutputParser(pydantic_object=MCQContent)
            mcq_prompt = ChatPromptTemplate.from_template(mcq_template)

            # Two-stage repair: quick_json_fix cleans the raw reply, then the
            # LLM-backed OutputFixingParser retries anything still malformed.
            json_fixer = RunnableLambda(quick_json_fix)
            output_fixing_parser = OutputFixingParser.from_llm(llm=llm, parser=mcq_parser)

            mcqs = {}
            slide_keys = list(all_slides.slides.keys())

            # One MCQ per pair of consecutive slides until num_mcqs is met.
            # NOTE(review): the loop keeps iterating (no break) after enough
            # MCQs are collected; remaining passes are no-ops.
            for i in range(0, len(slide_keys), 2):
                if len(mcqs) < num_mcqs:
                    context_slides = []
                    for j in range(i, min(i + 2, len(slide_keys))):
                        key = slide_keys[j]
                        slide = all_slides.slides[key]

                        if isinstance(slide, dict):
                            title = slide.get("subheading", "")
                            paras = slide.get("paragraphs", [])
                        else:
                            title = slide.subheading
                            paras = slide.paragraphs

                        context_slides.append(f"{title}: {' '.join(paras)}")

                    context_text = "\n".join(context_slides)
                    try:
                        mcq_result = (
                            mcq_prompt
                            | llm
                            | json_fixer
                            | output_fixing_parser
                        ).invoke({
                            "context": context_text,
                            "format_instructions": mcq_parser.get_format_instructions()
                        })

                        mcqs[f"mcq_{len(mcqs) + 1}"] = mcq_result
                    except Exception as e:
                        # Best-effort: a failed MCQ is skipped, not fatal.
                        continue

            interleaved_content = {}
            mcq_counter = 0

            if question_position == 1:
                # Spread MCQs evenly: one after every `interval` slides.
                # NOTE(review): interval is 0 when num_mcqs > num_slides,
                # which makes the modulo below raise ZeroDivisionError —
                # confirm inputs are validated upstream.
                interval = num_slides // num_mcqs if num_mcqs > 0 else num_slides

                for idx, slide_key in enumerate(slide_keys):
                    interleaved_content[slide_key] = all_slides.slides[slide_key]

                    if (idx + 1) % interval == 0 and mcq_counter < num_mcqs:
                        mcq_key = f"mcq_{mcq_counter + 1}"
                        if mcq_key in mcqs:
                            interleaved_content[mcq_key] = mcqs[mcq_key]
                        mcq_counter += 1
            else:
                # Append all MCQs after the slides.
                for slide_key in slide_keys:
                    interleaved_content[slide_key] = all_slides.slides[slide_key]

                for mcq_counter in range(num_mcqs):
                    mcq_key = f"mcq_{mcq_counter + 1}"
                    if mcq_key in mcqs:
                        interleaved_content[mcq_key] = mcqs[mcq_key]

            return StorigoContentMCQMid(slides=interleaved_content, token_count=0)
        else:
            return StorigoContent(slides=all_slides.slides, token_count=0)

    except Exception as e:
        # Boundary handler: wrap any failure with a summary message.
        # NOTE(review): re-raising bare Exception drops the original
        # traceback chain (`from e` would preserve it).
        raise Exception(f"Error generating slide content: {str(e)}")