import os
import time
import random
import re
import json
import heapq
from jsonschema import validate, ValidationError
from langchain_core.messages import AIMessage
import requests
import numpy as np
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from crawl4ai import AsyncWebCrawler
import asyncio
from langchain.output_parsers import OutputFixingParser
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader, UnstructuredPowerPointLoader
from PyPDF2 import PdfReader
from langchain_experimental.text_splitter import SemanticChunker
from pydantic import BaseModel, Field
from typing import List, Optional, Dict
from urllib.parse import urlparse
from typing import Union, Dict
from langchain_core.runnables import RunnableLambda
import shutil
from pptx import Presentation
from langchain_ollama import OllamaLLM
from langchain_ollama import ChatOllama
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
from pathlib import Path
from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain.output_parsers import OutputFixingParser
from youtube_transcript_api import YouTubeTranscriptApi

# Import the image generator function
from storigo_image_generator import fetch_image_for_slide

# Default Ollama embedding model name.
# NOTE(review): the functions below hardcode 'nomic-embed-text' instead of
# referencing this constant — it may exist for external importers; confirm.
OLLAMA_MODEL = "nomic-embed-text"

class SlideContent(BaseModel):
    """One presentation slide: text content plus an optional image URL."""
    # Discriminator used when slides and MCQs share one mapping; "flash" marks
    # a regular content slide.
    type: str = Field("flash")
    subheading: Optional[str] = Field(None, description="An optional subheading for the slide")
    paragraphs: List[str] = Field(..., description="List of paragraphs for the slide content")
    visualization_suggestion: str = Field(..., description="A detailed suggestion for a relevant visualization or image (5-8 words with specific elements)")
    # Filled in after generation by fetch_image_for_slide when images are enabled.
    image: Optional[str] = Field(None, description="URL of the image for the slide")

class MCQContent(BaseModel):
    """A single multiple-choice question with four options and an answer key."""
    # Discriminator distinguishing MCQ entries from SlideContent entries.
    type: str = Field("Question")
    question: str = Field(..., description="The multiple-choice question")
    options: List[str] = Field(..., description="A list of 4 answer options")
    correct_answer: str = Field(..., description="The correct answer (e.g., 'a', 'b', 'c', or 'd')")

class StorigoContent(BaseModel):
    """A slide-only deck (no MCQs) plus a whole-deck token count."""
    slides: Dict[str, SlideContent] = Field(..., description="Dictionary of slide contents with slide numbers as keys")
    token_count: int = Field(..., description="Total token count for all the generated content")

class StorigoContentMCQ(BaseModel):
    """A deck with slides and MCQs kept in separate mappings.

    NOTE(review): not referenced in this module's visible code — presumably
    kept for external callers; verify before removing.
    """
    slides: Dict[str, SlideContent] = Field(..., description="Dictionary of slide contents with slide numbers as keys")
    mcqs: Dict[str, MCQContent] = Field(..., description="Dictionary of MCQs with identifiers like 'mcq_1' as keys")
    token_count: int = Field(..., description="Total token count for all the generated content")

class StorigoContentMCQMid(BaseModel):
    """A deck where slides and MCQs share one ordered mapping (interleaved or appended)."""
    slides: Dict[str, Union[SlideContent, MCQContent]] = Field(..., description="Dictionary of slide contents with slide numbers as keys and MCQs with MCQ numbers as keys")
    token_count: int = Field(..., description="Total token count for all the generated content")

def count_tokens(text):
    """Return the number of word-like tokens (regex ``\\w+`` runs) in *text*."""
    return len(re.findall(r'\w+', text))

async def crawlerrr(file):
    """Crawl the URL *file*, save the page's markdown to a local .txt file.

    The filename is derived from the URL's host and path ('/' replaced by
    '_'), with a '.txt' suffix appended if missing. Returns the filename.
    """
    async with AsyncWebCrawler(verbose=True) as crawler:
        result = await crawler.arun(url=file)
        parsed_url = urlparse(file)
        filename = parsed_url.netloc + parsed_url.path.replace('/', '_')
        if not filename.endswith('.txt'):
            filename += '.txt'
        # Fix: write with explicit UTF-8 — crawled markdown routinely contains
        # non-ASCII characters, which crashed on platforms whose default
        # encoding is not UTF-8.
        with open(filename, 'w', encoding='utf-8') as txt_file:
            txt_file.write(result.markdown)
        return filename

def read_file_url(input):
    """Read a text file and return its contents as a string.

    Returns "" when the path is falsy, does not point to an existing
    regular file, or the file is empty.
    """
    if input and os.path.isfile(input):
        with open(input, 'r', encoding='utf-8') as handle:
            data = handle.read()
        return data or ""
    return ""

def transcribe(youtube_video_url):
    """Fetch a YouTube video's transcript and save it to a local file.

    Supports youtu.be, watch?v=, /embed/ and /shorts/ URL forms (the latter
    two are a backward-compatible generalization). The transcript text is
    written to a file named after the video id and that filename is returned.

    Raises:
        ValueError: if the URL matches no known YouTube format.
        Exception: if the transcript cannot be retrieved (no subtitles,
            video unavailable, network failure, ...).
    """
    # Try each known URL shape; the id ends at the first '?' or '&'.
    video_id = None
    for marker in ("youtu.be/", "watch?v=", "/embed/", "/shorts/"):
        if marker in youtube_video_url:
            video_id = youtube_video_url.split(marker)[1].split("?")[0].split("&")[0]
            break
    if not video_id:
        raise ValueError("Invalid YouTube URL format")

    try:
        transcript_text = YouTubeTranscriptApi.get_transcript(video_id)
        # Keep the original formatting: every segment is prefixed by a space.
        transcript = "".join(" " + entry["text"] for entry in transcript_text)
        with open(video_id, "w", encoding="utf-8") as f:
            f.write(transcript)
        return video_id
    except Exception as e:
        # Chain the cause so the original traceback survives the re-raise.
        raise Exception(f"Unable to retrieve transcript for YouTube video {video_id}. The video may not have subtitles enabled or may be unavailable. Error: {str(e)}") from e

def load_txt(file_path):
    """Load a plain-text file as a list of LangChain documents."""
    return TextLoader(file_path).load()

def split_text_with_semantic_chunker(docs, embeddings):
    """Split loaded documents into semantically coherent chunks.

    Breakpoints are chosen by the percentile threshold on embedding
    distances between adjacent sentences.
    """
    chunker = SemanticChunker(embeddings, breakpoint_threshold_type="percentile")
    texts = [doc.page_content for doc in docs]
    return chunker.create_documents(texts)

def clean_using_llm(content):
    """Strip navigation/boilerplate from crawled text via a local Ollama model.

    Sends *content* through a cleaning prompt and returns the model's text
    response (the meaningful descriptions, features, value propositions...).
    """
    prompt_template = """
    Extract only the meaningful content from the text below. Focus on descriptions, value propositions, mission statements,
    features, and anything that provides valuable information about the company, products, or services. Ignore any URLs,
    navigation links, contact forms, or irrelevant sections.

    Here is the content to process:

    {context}
    """

    model = ChatOllama(
        base_url='http://127.0.0.1:11434',
        model="gemma3:12b"
    )
    cleaning_chain = PromptTemplate(input_variables=["context"], template=prompt_template) | model
    response = cleaning_chain.invoke({"context": content})
    return response.content

def split_text_with_semantic_chunker_for_url(docs, embeddings):
    """Semantically chunk content coming from crawled URLs or cleaned text.

    Accepts a raw string, a list of strings, or a list of dicts carrying a
    'page_content' key; other shapes yield []. Fix: an empty list previously
    crashed on ``docs[0]`` (IndexError) — it now returns [].
    """
    # Normalize the input into a list of {'page_content': ...} dicts.
    if isinstance(docs, str):
        docs = [docs]
    if not docs:
        return []
    if isinstance(docs[0], str):
        docs = [{'page_content': doc} for doc in docs]
    if not all(isinstance(doc, dict) and 'page_content' in doc for doc in docs):
        return []
    # Build the splitter only after validation so malformed input never
    # touches the embedding model.
    text_splitter = SemanticChunker(
        embeddings, breakpoint_threshold_type="percentile"
    )
    return text_splitter.create_documents([doc['page_content'] for doc in docs])

def create_and_save_embeddings(split_documents, client_id):
    """Embed each document into its own FAISS index on disk.

    Indexes are written as my_embeddings/<client_id>/faiss_index<N> with N
    starting at 1; merge_all_faiss later folds them into a single index.
    """
    base_dir = os.path.join("my_embeddings", str(client_id))
    os.makedirs(base_dir, exist_ok=True)
    embeddings = OllamaEmbeddings(model='nomic-embed-text')
    for position, document in enumerate(split_documents, start=1):
        single_doc_index = FAISS.from_documents([document], embedding=embeddings)
        single_doc_index.save_local(os.path.join(base_dir, f"faiss_index{position}"))

def create_embeddings(split_documents, client_id):
    """Build one FAISS index over all documents, persist it, and return it.

    The index is saved under 'faiss_supplier_index'. *client_id* is accepted
    for interface parity with create_and_save_embeddings but is not used here.
    """
    embeddings = OllamaEmbeddings(model='nomic-embed-text')
    store = FAISS.from_documents(split_documents, embedding=embeddings)
    store.save_local("faiss_supplier_index")
    return store

def merge_all_faiss(client_id, base_path='my_embeddings'):
    """Merge every per-chunk FAISS index for a client into one index.

    Loads each folder named faiss_index<N> under <base_path>/<client_id> in
    numeric order, folds later indexes' texts into the first, saves the
    result as 'merged_faiss' in the same folder, deletes the per-chunk
    folders, and returns the merged index (None if no indexes were found).
    """
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    merged_faiss = None
    folder_path = f'{base_path}/{client_id}'
    # Only folders literally named faiss_index<digits> participate.
    faiss_files = [
        folder for folder in os.listdir(folder_path)
        if folder.startswith('faiss_index') and folder[len('faiss_index'):].isdigit()
    ]
    # Numeric sort preserves chunk order (faiss_index2 before faiss_index10).
    sorted_files = sorted(faiss_files, key=lambda x: int(x.replace('faiss_index', '')))
    for file in sorted_files:
        faiss_path = os.path.join(folder_path, file)
        # allow_dangerous_deserialization: these pickled indexes were produced
        # locally by create_and_save_embeddings, not taken from untrusted input.
        current_faiss = FAISS.load_local(faiss_path, embeddings, allow_dangerous_deserialization=True)
        current_texts = [current_faiss.docstore.search(doc_id).page_content
                         for doc_id in current_faiss.index_to_docstore_id.values()]
        if merged_faiss is None:
            merged_faiss = current_faiss
        else:
            # add_texts re-embeds the raw text; any per-document metadata in
            # the source index is not carried over.
            merged_faiss.add_texts(current_texts)
    if merged_faiss is not None:
        merged_faiss.save_local(f'{folder_path}/merged_faiss')
    # Best-effort cleanup of the per-chunk folders: the merged index is
    # already saved, so a locked or vanished folder must not abort the call.
    for file in sorted_files:
        faiss_path = os.path.join(folder_path, file)
        try:
            shutil.rmtree(faiss_path)
        except FileNotFoundError:
            pass
        except OSError as e:
            pass
    return merged_faiss

async def generate_slide_content_youtube(document_url, client_id, num_slides, num_mcqs, is_image, is_question, question_position, GPU):
    """Generate Storigo slide (and optionally MCQ) content from a YouTube URL.

    Pipeline: transcript (falling back to crawled page text) -> LLM cleanup
    -> semantic chunking -> FAISS embeddings -> LLM slide generation ->
    optional per-slide image fetching -> optional MCQ generation.

    Args:
        document_url: YouTube video URL.
        client_id: Namespace for the on-disk embedding store.
        num_slides: Exact number of slides to generate.
        num_mcqs: Number of MCQs to attempt when is_question is truthy.
        is_image: int 1 or string "1"/"true"/"yes" enables image fetching.
        is_question: Whether to generate MCQs.
        question_position: 1 interleaves MCQs between slides; any other
            value appends them after all slides.
        GPU: Unused here; kept for interface compatibility.

    Returns:
        StorigoContent when is_question is falsy, otherwise StorigoContentMCQMid.

    Raises:
        Exception: wraps any underlying failure with a descriptive message.
    """
    try:
        try:
            # Preferred path: use the official video transcript.
            file_path = transcribe(document_url)
            embeddings = OllamaEmbeddings(model='nomic-embed-text')
            text = load_txt(file_path)
            split_documents1 = split_text_with_semantic_chunker(text, embeddings)
            meaningful_content = clean_using_llm(split_documents1)
            split_documents = split_text_with_semantic_chunker_for_url(meaningful_content, embeddings)
        except Exception as transcript_error:
            print(f"Transcript not available for {document_url}, falling back to page content extraction. Error: {str(transcript_error)}")
            # Fallback: crawl the YouTube page itself and extract content.
            filename = await crawlerrr(document_url)
            raw_content = read_file_url(filename)
            if not raw_content:
                raise Exception("Failed to retrieve content from YouTube URL")

            # Clean and extract meaningful content using the LLM.
            meaningful_content = clean_using_llm(raw_content)
            embeddings = OllamaEmbeddings(model='nomic-embed-text')
            split_documents = split_text_with_semantic_chunker_for_url(meaningful_content, embeddings)

        # Persist per-chunk embeddings, then merge them into one index.
        create_and_save_embeddings(split_documents, client_id)
        merge_embeddings = merge_all_faiss(client_id)

        llm = ChatOllama(
            base_url='http://127.0.0.1:11434',
            model="gemma3:12b",
            temperature=0.7
        )

        all_chunks = list(merge_embeddings.docstore._dict.values())
        total_chunks = len(all_chunks)

        # Pick one source chunk per slide. With enough chunks, sample them
        # evenly across the document; otherwise allocate slides to chunks
        # proportionally to their word counts so longer chunks get more slides.
        selected_chunks = []
        if num_slides <= total_chunks:
            indices = np.linspace(0, total_chunks - 1, num=num_slides, dtype=int)
            selected_chunks = [all_chunks[i] for i in indices]
        else:
            word_counts = [len(c.page_content.split()) for c in all_chunks]
            total_words = sum(word_counts)
            raw_alloc = [w / total_words * num_slides for w in word_counts]
            alloc = [max(1, round(x)) for x in raw_alloc]
            # Round-robin adjust until the allocation sums exactly to num_slides.
            diff = num_slides - sum(alloc)
            i = 0
            while diff != 0:
                if diff > 0:
                    alloc[i] += 1
                    diff -= 1
                elif diff < 0 and alloc[i] > 1:
                    alloc[i] -= 1
                    diff += 1
                i = (i + 1) % total_chunks
            for chunk, count in zip(all_chunks, alloc):
                selected_chunks.extend([chunk] * count)

        context_text = "\n\n".join(doc.page_content for doc in selected_chunks[:num_slides])

        slide_content_template = """
Based on the following context, generate professional and engaging content for exactly {num_slides} slides in a Storigos presentation.

STRICT RULES:
- ❗ ONLY use the content provided in the 'context' below.
- ❌ DO NOT introduce any external knowledge, definitions, or examples not present in the context.
- ⚠️ Do not assume common sense or use general facts. Stick to the exact information given.
- ⚠️ Avoid generic phrases like "as we know", "in general", or "in this video".

Each slide must include:
- A clear and concise **sub-heading**
- **Exactly 2–4 concise paragraphs** derived solely from the context
- A **detailed visualization suggestion** (5-8 words, specific to the content with concrete elements)
- Include specific details: people, objects, actions, settings to make it highly unique
- CRITICAL: Each slide's visualization suggestion MUST BE COMPLETELY UNIQUE across all slides - no overlapping concepts, objects, or scenes
- If slides are related, vary all elements (people, objects, actions, settings) significantly to ensure completely different images

Important: Only output the final JSON object. No additional text, markdown, or explanation should be included.

Context:
{context}

{format_instructions}

The final output must be a valid JSON object where each slide is represented as "slide_1", "slide_2", ..., up to "slide_{num_slides}".
Each slide must contain:
- "subheading"
- "paragraphs"
- "visualization_suggestion"
"""
        # OutputFixingParser retries malformed LLM output through the LLM itself.
        raw_parser = PydanticOutputParser(pydantic_object=StorigoContent)
        parser = OutputFixingParser.from_llm(parser=raw_parser, llm=llm)
        slide_content_prompt = ChatPromptTemplate.from_template(slide_content_template)

        slide_content_chain = (
            {
                "context": lambda x: context_text,
                "num_slides": lambda x: x["num_slides"],
                "format_instructions": lambda x: parser.get_format_instructions()
            }
            | slide_content_prompt
            | llm
            | parser
        )

        result = slide_content_chain.invoke({"query": "", "num_slides": num_slides})

        # The model may return slides as a dict, a list, or nested in a dump.
        if hasattr(result, 'slides'):
            slides_data = result.slides
        else:
            raw = result.model_dump() if hasattr(result, 'model_dump') else result
            slides_data = raw.get("slides", {})

        if isinstance(slides_data, list):
            # Convert list to dict keyed slide_1, slide_2, ...
            slides_data = {f"slide_{i}": slide for i, slide in enumerate(slides_data, start=1)}

        def custom_slide_sort_key(item_key: str):
            # Order slides before MCQs, each group numerically by suffix.
            prefix, num_str = item_key.split('_', 1)
            if prefix == "slide":
                group = 0
            elif prefix == "mcq":
                group = 1
            else:
                group = 999
            number = int(num_str) if num_str.isdigit() else 999
            return (group, number)

        ordered_slides = dict(sorted(slides_data.items(), key=lambda kv: custom_slide_sort_key(kv[0])))

        # Normalize the is_image flag (int or string) into a bool.
        if isinstance(is_image, int):
            is_image_bool = is_image == 1
        elif isinstance(is_image, str):
            # Fix: "0" was previously accepted as truthy; only genuinely
            # affirmative values enable image generation now.
            is_image_bool = is_image.strip().lower() in ["1", "true", "yes"]
        else:
            is_image_bool = False

        if is_image_bool:
            for slide_key, slide_content in ordered_slides.items():
                if slide_content.visualization_suggestion:
                    image_path = fetch_image_for_slide(
                        slide_key,
                        slide_content.visualization_suggestion
                    )
                    slide_content.image = image_path if image_path else None
                else:
                    slide_content.image = None
        else:
            for slide_content in ordered_slides.values():
                slide_content.image = None

        # Rough token count over all generated slide text.
        token_count = 0
        for slide in ordered_slides.values():
            text_content = f"{slide.subheading} {' '.join(slide.paragraphs)} {slide.visualization_suggestion}"
            token_count += count_tokens(text_content)

        if is_question:
            mcq_template = """
Based on the following context from the last two slides, generate one multiple-choice question (MCQ). The question should be relevant to the content and designed to test comprehension.

**Context**: {context}

The MCQ must include:
- A **question** related to the context
- Exactly **4 answer options**
- A clear indication of the **correct answer** as a single letter: 'a', 'b', 'c', or 'd'

⚠️ **Critical Requirements**:
- ✅ Return **only valid JSON** — no explanations, headers, or extra text.
- ✅ Ensure all fields and options are enclosed in **double quotes (`"`)**.
- ✅ Do **not** use letters like "A.", "B." in the options — just the plain text.
- Dont give Here is the MCQ while generating MCQ
- Directly follow the format given below

The final output **must strictly follow** this format:
```json
{{
    "question": "<The MCQ question>",
    "options": [
        "<Option 1>",
        "<Option 2>",
        "<Option 3>",
        "<Option 4>"
    ],
    "correct_answer": "<Correct option (e.g., 'a', 'b', 'c', or 'd')>"

    Always give "question","options","correct_answer" these labels in double quotes only
}}
"""
            slide_keys = list(ordered_slides.keys())

            async def generate_mcqs_async():
                # One MCQ is generated per pair of consecutive slides, capped
                # at num_mcqs. All generations run concurrently.
                mcq_prompt = ChatPromptTemplate.from_template(mcq_template)
                mcq_llm = ChatOllama(
                    base_url='http://127.0.0.1:11434',
                    model="gemma3:12b",
                    temperature=0.7
                )

                mcqs = {}
                tasks = []

                for i in range(0, len(slide_keys), 2):
                    # Fix: the original checked len(mcqs), which is always 0
                    # before the gather below — cap on scheduled tasks instead.
                    if len(tasks) >= num_mcqs:
                        break
                    context_slides = []
                    for j in range(i, min(i + 2, len(slide_keys))):
                        slide = ordered_slides[slide_keys[j]]
                        context_slides.append(f"{slide.subheading}: {' '.join(slide.paragraphs)}")

                    context_text = "\n".join(context_slides)

                    async def generate_single_mcq(ctx_text, idx):
                        """Generate one MCQ from ctx_text; return a dict or None."""
                        try:
                            # The sync LangChain chain runs in a thread so the
                            # event loop stays free for the other MCQs.
                            mcq_result = await asyncio.get_event_loop().run_in_executor(
                                None,
                                lambda: (
                                    RunnableLambda(lambda x: {"context": ctx_text})
                                    | mcq_prompt
                                    | mcq_llm
                                ).invoke({})
                            )

                            content = mcq_result.content if hasattr(mcq_result, 'content') else str(mcq_result)

                            # Extract the first balanced {...} object from the reply.
                            start = content.find('{')
                            if start == -1:
                                return None
                            open_braces = 0
                            end = -1
                            for idx_char, char in enumerate(content[start:], start=start):
                                if char == '{':
                                    open_braces += 1
                                elif char == '}':
                                    open_braces -= 1
                                    if open_braces == 0:
                                        end = idx_char
                                        break
                            if end == -1:
                                return None
                            json_object = json.loads(content[start:end + 1])

                            formatted_mcq = {
                                "question": json_object.get("question", ""),
                                "options": json_object.get("options", []),
                                "correct_answer": json_object.get("correct_answer", "")
                            }

                            # Only accept complete, well-formed MCQs.
                            if (formatted_mcq["question"] and
                                len(formatted_mcq["options"]) == 4 and
                                formatted_mcq["correct_answer"]):
                                return formatted_mcq
                            return None
                        except Exception:
                            # Best-effort: a failed generation simply yields no MCQ.
                            return None

                    tasks.append(generate_single_mcq(context_text, i))

                results = await asyncio.gather(*tasks, return_exceptions=True)
                for result in results:
                    if isinstance(result, dict) and result:
                        mcqs[f"mcq_{len(mcqs) + 1}"] = result

                return mcqs

            mcqs = await generate_mcqs_async()

            for mcq in mcqs.values():
                text_content = f"{mcq['question']} {' '.join(mcq['options'])} {mcq['correct_answer']}"
                token_count += count_tokens(text_content)

            interleaved_content = {}
            mcq_counter = 0
            total_slides = num_slides
            total_mcqs = num_mcqs

            if question_position == 1:
                # Interleave: insert an MCQ after every `interval` slides.
                # Fix: when num_mcqs > num_slides the floor division yields 0,
                # which previously raised ZeroDivisionError on the modulo.
                interval = max(1, total_slides // total_mcqs) if total_mcqs > 0 else total_slides

                for idx, slide_key in enumerate(slide_keys):
                    interleaved_content[slide_key] = ordered_slides[slide_key]

                    if (idx + 1) % interval == 0 and mcq_counter < total_mcqs:
                        mcq_key = f"mcq_{mcq_counter + 1}"
                        # Fix: the LLM may produce fewer MCQs than requested;
                        # skip missing keys instead of raising KeyError.
                        if mcq_key in mcqs:
                            interleaved_content[mcq_key] = mcqs[mcq_key]
                            mcq_counter += 1

                storigo_content = StorigoContentMCQMid(slides=interleaved_content, token_count=token_count)
                return storigo_content

            else:
                # Append mode: all slides first, then every generated MCQ.
                for slide_key in slide_keys:
                    interleaved_content[slide_key] = ordered_slides[slide_key]

                for mcq_counter in range(total_mcqs):
                    mcq_key = f"mcq_{mcq_counter + 1}"
                    # Fix: tolerate fewer generated MCQs than requested.
                    if mcq_key in mcqs:
                        interleaved_content[mcq_key] = mcqs[mcq_key]

                storigo_content = StorigoContentMCQMid(slides=interleaved_content, token_count=token_count)
                return storigo_content
        else:
            storigo_content_without_mc = StorigoContent(slides=ordered_slides, token_count=token_count)
            return storigo_content_without_mc

    except Exception as e:
        # Chain the cause so the original traceback survives the re-raise.
        raise Exception(f"Error generating slide content: {str(e)}") from e