import os
import time
import random
import re
import json
import heapq
from jsonschema import validate, ValidationError
from langchain_core.messages import AIMessage
import requests
import numpy as np
#from langchain_community.embeddings import OllamaEmbeddings
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from crawl4ai import AsyncWebCrawler
import asyncio
from langchain.output_parsers import OutputFixingParser
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader, UnstructuredPowerPointLoader
from PyPDF2 import PdfReader
from langchain_experimental.text_splitter import SemanticChunker
from pydantic import BaseModel, Field
from typing import List, Optional, Dict
from urllib.parse import urlparse
from typing import Union, Dict
from langchain_core.runnables import RunnableLambda
import shutil
from pptx import Presentation
from langchain_ollama import OllamaLLM
from langchain_ollama import ChatOllama
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
from pathlib import Path
from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain.output_parsers import OutputFixingParser

# Import the image generator function
from storigo_image_generator import fetch_image_for_slide

# Name of the Ollama embedding model. NOTE(review): the embedding helpers in
# this module pass 'nomic-embed-text' directly instead of reading this constant.
OLLAMA_MODEL = "nomic-embed-text"



class SlideContent(BaseModel):
    # One content slide. This model is nested in StorigoContent, which
    # generate_slide_content hands to PydanticOutputParser, so its schema ends
    # up in the LLM's format instructions. Documentation is kept in '#'
    # comments on purpose: pydantic copies a class docstring into the JSON
    # schema, which would change the prompt text.
    type: str = Field("flash")
    subheading: Optional[str] = Field(None, description="An optional subheading for the slide")
    paragraphs: List[str] = Field(..., description="List of paragraphs for the slide content")
    visualization_suggestion: str = Field(..., description="A specific and concise suggestion for a relevant visualization or image (max 5 words)")
    # Set after parsing by generate_slide_content: whatever fetch_image_for_slide
    # returned for the slide, or None when images are disabled or unavailable.
    image: Optional[str] = Field(None, description="URL of the image for the slide")

class MCQContent(BaseModel):
    # One multiple-choice question. generate_slide_content produces plain dicts
    # with "question", "options" and "correct_answer" keys and lets
    # StorigoContentMCQMid validate them into this model.
    # Documented with '#' comments (not a docstring) so the JSON schema is unchanged.
    type: str = Field("Question")
    question: str = Field(..., description="The multiple-choice question")
    # The prompt asks for exactly 4 options; this model itself does not enforce the count.
    options: List[str] = Field(..., description="A list of 4 answer options")
    correct_answer: str = Field(..., description="The correct answer (e.g., 'a', 'b', 'c', or 'd')")

class StorigoContent(BaseModel):
    # Slide-only presentation, keyed "slide_1", "slide_2", ...
    # generate_slide_content also uses this model as the PydanticOutputParser
    # schema, so its fields are part of the LLM format instructions; keep
    # documentation in comments (a class docstring would enter the schema).
    slides: Dict[str, SlideContent] = Field(..., description="Dictionary of slide contents with slide numbers as keys")
    # Required, so the LLM output must include it for parsing to succeed;
    # generate_slide_content recomputes it with count_tokens before building
    # the object it returns.
    token_count: int = Field(..., description="Total token count for all the generated content")


class StorigoContentMCQ(BaseModel):
    # Slides and MCQs kept in two separate dicts.
    # NOTE(review): not referenced in the visible part of this module
    # (generate_slide_content returns StorigoContentMCQMid instead); it may be
    # used by other modules - verify before removing.
    slides: Dict[str, SlideContent] = Field(..., description="Dictionary of slide contents with slide numbers as keys")
    mcqs: Dict[str, MCQContent] = Field(..., description="Dictionary of MCQs with identifiers like 'mcq_1' as keys")
    token_count: int = Field(..., description="Total token count for all the generated content")



class StorigoContentMCQMid(BaseModel):
    # Slides and MCQs in one ordered dict ("slide_<n>" and "mcq_<n>" keys), so
    # questions can sit between slides. Each value is validated against the
    # Union, i.e. as either a SlideContent or an MCQContent.
    slides: Dict[str, Union[SlideContent, MCQContent]] = Field(..., description="Dictionary of slide contents with slide numbers as keys and MCQs with MCQ numbers as keys")
    token_count: int = Field(..., description="Total token count for all the generated content")

class CustomMCQParser(PydanticOutputParser):
    """PydanticOutputParser that tolerates common LLM formatting slips in MCQ JSON.

    The raw output may arrive as a string or as a list of generations. An
    optional "Here is the MCQ:" preamble and a surrounding Markdown code fence
    are removed. Single quotes are converted to double quotes only when the
    text is not already valid JSON.
    """

    def parse_result(self, result, *, partial=False):
        """Parse *result* into an instance of ``self.pydantic_object``.

        Args:
            result: A string, or a list whose items are strings or generation
                objects exposing ``.text`` (LangChain passes a list of
                generations here).
            partial: Accepted for compatibility with the base-class signature;
                not used.

        Returns:
            The validated Pydantic model instance.

        Raises:
            Exception: "Error decoding JSON: ..." if the text is not JSON even
                after the quote repair, or "Error parsing result: ..." if the
                JSON does not validate against the model.
        """
        # Step 1: normalise to one string. Generation objects cannot be fed to
        # str.join directly, so use their .text.
        if isinstance(result, list):
            result = " ".join(
                item if isinstance(item, str) else getattr(item, "text", str(item))
                for item in result
            )

        # Step 2: drop the preamble some models add, plus any ```json fence.
        text = result.replace("Here is the MCQ:", "").strip()
        text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text).strip()

        # Step 3: decode. Try the text as-is first so apostrophes inside valid
        # double-quoted JSON survive; only fall back to the quote repair for
        # single-quoted (Python-literal style) output.
        try:
            json_object = json.loads(text)
        except json.JSONDecodeError:
            try:
                json_object = json.loads(text.replace("'", "\""))
            except json.JSONDecodeError as e:
                raise Exception(f"Error decoding JSON: {str(e)}") from e

        # Step 4: validate against the Pydantic model.
        try:
            return self.pydantic_object.model_validate(json_object)
        except Exception as e:
            raise Exception(f"Error parsing result: {str(e)}") from e


def extract_text_from_pdf(pdf_path):
    """Return the text of every page of the PDF at *pdf_path*, concatenated in page order.

    Raises:
        Exception: wrapping any PyPDF2 error as "Error extracting text from PDF: ...",
            with the original error chained as ``__cause__``.
    """
    try:
        reader = PdfReader(pdf_path)
        # join avoids quadratic string concatenation; `or ""` guards against a
        # page yielding no text (None) instead of an empty string.
        return "".join(page.extract_text() or "" for page in reader.pages)
    except Exception as e:
        raise Exception(f"Error extracting text from PDF: {str(e)}") from e


def count_tokens(text):
    """Approximate a token count as the number of word-character runs in *text*.

    Letters, digits and underscores form words; everything else separates
    them. This is a cheap word count, not an LLM tokenizer count.
    """
    return sum(1 for _match in re.finditer(r'\w+', text))




_SLIDE_CONTENT_TEMPLATE = """
Based on the following context, generate professional and engaging content for exactly {num_slides} slides in a Storigos presentation.

STRICT RULES:
- ❗ ONLY use the content provided in the 'context' below.
- ❌ DO NOT introduce any external knowledge, definitions, or examples not present in the context.
- ⚠️ Do not assume common sense or use general facts. Stick to the exact information given.
- ⚠️ Avoid generic phrases like “as we know”, “in general”, or “in this video”.

Each slide must include:
- A clear and concise **sub-heading**
- **Exactly 2–4 concise paragraphs** derived solely from the context
- A **visualization suggestion** (max 5 words, specific to the content)

Important: Only output the final JSON object. No additional text, markdown, or explanation should be included.

Context:
{context}

{format_instructions}

The final output must be a valid JSON object where each slide is represented as "slide_1", "slide_2", ..., up to "slide_{num_slides}".
Each slide must contain:
- "subheading"
- "paragraphs"
- "visualization_suggestion"
"""

_MCQ_TEMPLATE = """
Based on the following context from the last two slides, generate one multiple-choice question (MCQ). The question should be relevant to the content and designed to test comprehension.

**Context**: {context}

The MCQ must include:
- A **question** related to the context
- Exactly **4 answer options**
- A clear indication of the **correct answer** as a single letter: 'a', 'b', 'c', or 'd'

⚠️ **Critical Requirements**:
- ✅ Return **only valid JSON** — no explanations, headers, or extra text.
- ✅ Ensure all fields and options are enclosed in **double quotes (`"`)**.
- ✅ Do **not** use letters like "A.", "B." in the options — just the plain text.
- Dont give Here is the MCQ while generating MCQ
- Directly follow the format given below

The final output **must strictly follow** this format:
```json
{{
    "question": "<The MCQ question>",
    "options": [
        "<Option 1>",
        "<Option 2>",
        "<Option 3>",
        "<Option 4>"
    ],
    "correct_answer": "<Correct option (e.g., 'a', 'b', 'c', or 'd')>"

    Always give "question","options","correct_answer" these labels in double quotes only
}}

"""


def _parse_flag(value):
    """Interpret an int/bool/str request flag: 1, True, "1", "true", "yes" (any case) are on."""
    if isinstance(value, int):  # also covers bool (True == 1)
        return value == 1
    if isinstance(value, str):
        return value.strip().lower() in ("1", "true", "yes")
    return False


def _select_context_chunks(all_chunks, num_slides):
    """Pick the chunks whose text forms the slide prompt's context.

    With at least ``num_slides`` chunks, ``num_slides`` evenly spaced chunks
    are taken. Otherwise every chunk is repeated roughly in proportion to its
    word count (at least once each) so that exactly ``num_slides`` entries are
    returned.
    """
    total_chunks = len(all_chunks)
    if num_slides <= total_chunks:
        indices = np.linspace(0, total_chunks - 1, num=num_slides, dtype=int)
        return [all_chunks[i] for i in indices]

    if total_chunks == 0:
        raise ValueError("The vector store contains no documents to build slides from.")

    word_counts = [len(chunk.page_content.split()) for chunk in all_chunks]
    total_words = sum(word_counts)
    if total_words == 0:
        # All chunks are empty: weight them equally instead of dividing by zero.
        word_counts = [1] * total_chunks
        total_words = total_chunks
    alloc = [max(1, round(w / total_words * num_slides)) for w in word_counts]

    # Nudge the rounded allocation until it sums to exactly num_slides. This
    # terminates: here num_slides > total_chunks, so while diff < 0 some
    # allocation is still > 1.
    diff = num_slides - sum(alloc)
    i = 0
    while diff != 0:
        if diff > 0:
            alloc[i] += 1
            diff -= 1
        elif alloc[i] > 1:
            alloc[i] -= 1
            diff += 1
        i = (i + 1) % total_chunks

    selected = []
    for chunk, count in zip(all_chunks, alloc):
        selected.extend([chunk] * count)
    return selected


def _slide_sort_key(item_key):
    """Sort "slide_<n>" before "mcq_<n>" before anything else, then by numeric suffix."""
    prefix, _, num_str = item_key.partition('_')
    group = {"slide": 0, "mcq": 1}.get(prefix, 999)
    number = int(num_str) if num_str.isdecimal() else 999
    return (group, number)


def _attach_images(ordered_slides, want_images):
    """Set each slide's ``image`` from fetch_image_for_slide, or to None."""
    for slide_key, slide_content in ordered_slides.items():
        if not want_images:
            slide_content.image = None
            continue
        if not slide_content.visualization_suggestion:
            print(f"Warning: No visualization suggestion for slide {slide_key}.")
            slide_content.image = None
            continue
        image_path = fetch_image_for_slide(slide_key, slide_content.visualization_suggestion)
        if not image_path:
            print(f"Warning: No suitable image generated for slide {slide_key}.")
        slide_content.image = image_path or None


def _extract_first_json_object(text):
    """Decode the JSON value starting at the first '{' in *text*; None if there is none.

    raw_decode understands JSON strings, so braces inside option text do not
    confuse it, and trailing prose after the object is ignored.
    """
    start = text.find("{")
    if start == -1:
        return None
    try:
        obj, _end = json.JSONDecoder().raw_decode(text[start:])
    except json.JSONDecodeError:
        return None
    return obj


def _to_mcq(json_object):
    """Return a normalised MCQ dict, or None unless it has a question, 4 options and an answer."""
    if not isinstance(json_object, dict):
        return None
    question = json_object.get("question", "")
    options = json_object.get("options", [])
    correct_answer = json_object.get("correct_answer", "")
    if not (question and correct_answer and isinstance(options, list) and len(options) == 4):
        return None
    # Coerce to str so token counting and MCQContent validation cannot trip
    # over numeric answers/options.
    return {
        "question": str(question),
        "options": [str(option) for option in options],
        "correct_answer": str(correct_answer),
    }


def _generate_mcqs(ordered_slides, num_mcqs):
    """Generate up to ``num_mcqs`` MCQs, one per consecutive pair of slides.

    The blocking LLM calls run concurrently in the default thread pool.
    Failed or malformed answers are skipped, so fewer MCQs than requested may
    come back; keys are numbered "mcq_1".."mcq_k" without gaps.
    """
    mcq_chain = ChatPromptTemplate.from_template(_MCQ_TEMPLATE) | ChatOllama(
        base_url='http://127.0.0.1:11434',
        model="gemma3:12b",
        temperature=0.7
    )

    slide_keys = list(ordered_slides)
    contexts = []
    for i in range(0, len(slide_keys), 2):
        if len(contexts) >= num_mcqs:
            break
        pair = [ordered_slides[key] for key in slide_keys[i:i + 2]]
        contexts.append("\n".join(f"{s.subheading}: {' '.join(s.paragraphs)}" for s in pair))

    def _generate_one(ctx_text):
        try:
            response = mcq_chain.invoke({"context": ctx_text})
            content = response.content if hasattr(response, "content") else str(response)
            return _to_mcq(_extract_first_json_object(content))
        except Exception:
            # Best effort: one bad LLM answer must not fail the whole deck.
            return None

    async def _run_all():
        loop = asyncio.get_running_loop()
        return await asyncio.gather(
            *(loop.run_in_executor(None, _generate_one, ctx) for ctx in contexts),
            return_exceptions=True,
        )

    results = asyncio.run(_run_all()) if contexts else []
    mcqs = {}
    for result in results:
        if isinstance(result, dict) and result:
            mcqs[f"mcq_{len(mcqs) + 1}"] = result
    return mcqs


def _interleave_mcqs(ordered_slides, mcqs, question_position):
    """Merge slides and generated MCQs into one ordered dict.

    With ``question_position == 1`` an MCQ follows every ``interval``-th slide,
    where ``interval = len(slides) // len(mcqs)`` (at least 1). For any other
    value all MCQs follow the slides. Every generated MCQ is placed exactly once.
    """
    slide_keys = list(ordered_slides)
    mcq_items = list(mcqs.items())
    combined = {}
    remaining = mcq_items
    if question_position == 1 and mcq_items:
        interval = max(1, len(slide_keys) // len(mcq_items))
        next_mcq = 0
        for position, slide_key in enumerate(slide_keys, start=1):
            combined[slide_key] = ordered_slides[slide_key]
            if position % interval == 0 and next_mcq < len(mcq_items):
                mcq_key, mcq = mcq_items[next_mcq]
                combined[mcq_key] = mcq
                next_mcq += 1
        remaining = mcq_items[next_mcq:]
    else:
        combined.update(ordered_slides)
    combined.update(remaining)
    return combined


def generate_slide_content(vectors, client_id, num_slides, num_mcqs, is_image, is_question, question_position, GPU):
    """Generate slides (and optionally MCQs) from the chunks stored in *vectors*.

    Args:
        vectors: Vector store whose ``docstore._dict`` holds the source chunks
            (objects with ``page_content``).
        client_id: Unused here; kept for interface compatibility.
        num_slides: Number of slides requested from the LLM.
        num_mcqs: Maximum number of MCQs to generate when *is_question* is truthy.
        is_image: 1 / True / "1" / "true" / "yes" (any case) attaches an image
            per slide via fetch_image_for_slide; anything else disables images.
        is_question: Truthy to also generate MCQs.
        question_position: 1 interleaves MCQs between slides; any other value
            appends them after the slides.
        GPU: Unused; kept for interface compatibility.

    Returns:
        StorigoContent when no questions are requested, otherwise
        StorigoContentMCQMid with slides and MCQs in one ordered dict.

    Raises:
        Exception: any failure, wrapped as "Error generating slide content: ...".
    """
    try:
        # Always use Ollama for inference.
        llm = ChatOllama(
            base_url='http://127.0.0.1:11434',
            model="gemma3:12b",
            temperature=0.7
        )

        all_chunks = list(vectors.docstore._dict.values())
        selected_chunks = _select_context_chunks(all_chunks, num_slides)
        context_text = "\n\n".join(doc.page_content for doc in selected_chunks[:num_slides])

        raw_parser = PydanticOutputParser(pydantic_object=StorigoContent)
        # The fixing parser asks the same LLM to repair malformed JSON.
        parser = OutputFixingParser.from_llm(parser=raw_parser, llm=llm)
        slide_content_prompt = ChatPromptTemplate.from_template(_SLIDE_CONTENT_TEMPLATE)
        slide_content_chain = (
            {
                "context": lambda x: context_text,
                "num_slides": lambda x: x["num_slides"],
                "format_instructions": lambda x: parser.get_format_instructions(),
            }
            | slide_content_prompt
            | llm
            | parser
        )
        result = slide_content_chain.invoke({"query": "", "num_slides": num_slides})

        ordered_slides = dict(
            sorted(result.slides.items(), key=lambda kv: _slide_sort_key(kv[0]))
        )
        _attach_images(ordered_slides, _parse_flag(is_image))

        # The token count is recomputed locally; the LLM-supplied value is ignored.
        token_count = sum(
            count_tokens(f"{slide.subheading} {' '.join(slide.paragraphs)} {slide.visualization_suggestion}")
            for slide in ordered_slides.values()
        )

        if not is_question:
            return StorigoContent(slides=ordered_slides, token_count=token_count)

        mcqs = _generate_mcqs(ordered_slides, num_mcqs)
        token_count += sum(
            count_tokens(f"{mcq['question']} {' '.join(mcq['options'])} {mcq['correct_answer']}")
            for mcq in mcqs.values()
        )
        return StorigoContentMCQMid(
            slides=_interleave_mcqs(ordered_slides, mcqs, question_position),
            token_count=token_count,
        )

    except Exception as e:
        raise Exception(f"Error generating slide content: {str(e)}") from e

def load_pdf(file_path):
    """Load the PDF at *file_path* with PyPDFLoader and return the loaded documents."""
    pdf_loader = PyPDFLoader(file_path)
    documents = pdf_loader.load()
    return documents

def count_total_words(docs):
    """Count whitespace-separated words across *docs*.

    Each item may be a mapping with a ``'page_content'`` key (the original
    contract) or an object with a ``page_content`` attribute, such as a
    LangChain Document. The rest of this module works with the latter.
    """
    total = 0
    for doc in docs:
        text = doc.page_content if hasattr(doc, "page_content") else doc['page_content']
        total += len(text.split())
    return total

def split_text_with_semantic_chunker(docs, embeddings):
    """Split the text of *docs* into semantic chunks and return them.

    Every item of *docs* must expose ``page_content``. Breakpoints use
    SemanticChunker's "percentile" threshold ("standard_deviation" and
    "interquartile" are the alternatives).

    NOTE: this module defines ``split_text_with_semantic_chunker`` twice with
    the same behaviour; the later definition is the one bound at import time.
    """
    chunker = SemanticChunker(embeddings, breakpoint_threshold_type="percentile")
    page_texts = [doc.page_content for doc in docs]
    chunks = chunker.create_documents(page_texts)
    print("Documents split into semantic chunks.")
    return chunks

async def crawlerrr(file):
    """Crawl the URL *file*, save its markdown to a UTF-8 ``.txt`` file, and return the filename.

    The file is created in the current working directory. Its name is the
    URL's host plus path, with ``/`` and any other character outside
    word characters, ``.`` and ``-`` replaced by ``_``. A ``.txt`` suffix is
    added when missing. The crawled markdown is also printed.
    """
    async with AsyncWebCrawler(verbose=True) as crawler:
        result = await crawler.arun(url=file)
        print(result.markdown)

        parsed_url = urlparse(file)
        filename = parsed_url.netloc + parsed_url.path.replace('/', '_')
        # Ports (':') and other reserved characters are not valid in file
        # names on every OS.
        filename = re.sub(r'[^\w.\-]', '_', filename)
        if not filename.endswith('.txt'):
            filename += '.txt'

        # Explicit UTF-8: crawled pages routinely contain non-ASCII text, and
        # read_file_url reads these files back as UTF-8.
        with open(filename, 'w', encoding='utf-8') as txt_file:
            txt_file.write(result.markdown)
        print(filename)
        return filename

def read_file_url(input):
    """Return the contents of the UTF-8 text file at path *input*.

    Despite the name, *input* is a local filesystem path, not a URL. Every
    failure (empty or missing path, a directory, an unreadable or non-UTF-8
    file) is reported on stdout and yields "" instead of raising.
    """
    try:
        print("Started")
        path_ok = bool(input) and os.path.isfile(input)
        if not path_ok:
            print(f"Error: The file '{input}' does not exist or the path is incorrect.")
            return ""

        with open(input, 'r', encoding='utf-8') as handle:
            text = handle.read()
        return text or ""

    except Exception as err:
        print(f"Error reading the file: {err}")
        return ""



def clean_using_llm(content):
    """Ask the local Ollama model to keep only the meaningful parts of *content*.

    *content* is inserted into a fixed extraction prompt and sent to
    ``gemma3:12b`` at ``http://127.0.0.1:11434``. The raw response and the
    type of its ``.content`` are printed, and that ``.content`` is returned.
    """
    template_text = """
    Extract only the meaningful content from the text below. Focus on descriptions, value propositions, mission statements,
    features, and anything that provides valuable information about the company, products, or services. Ignore any URLs,
    navigation links, contact forms, or irrelevant sections.

    Here is the content to process:

    {context}
    """

    chat_model = ChatOllama(
        base_url='http://127.0.0.1:11434',
        model="gemma3:12b",
    )
    extraction_chain = PromptTemplate(input_variables=["context"], template=template_text) | chat_model

    response = extraction_chain.invoke({"context": content})
    print(response)
    cleaned_text = response.content
    print(type(cleaned_text))
    return cleaned_text

def split_text_with_semantic_chunker_for_url(docs, embeddings):
    """Split raw text (or text-bearing documents) into semantic chunks.

    Args:
        docs: A single string, or a list whose items are strings, dicts with a
            ``'page_content'`` key, or objects with a ``page_content``
            attribute (items may be mixed).
        embeddings: Embedding model handed to SemanticChunker.

    Returns:
        The chunked documents. Returns [] when *docs* is empty or any item
        has an unsupported structure; the latter case is reported on stdout.
    """
    text_splitter = SemanticChunker(
        embeddings, breakpoint_threshold_type="percentile"
    )

    # A bare string is treated as a single document.
    if isinstance(docs, str):
        docs = [docs]

    print(f"Type of docs after conversion: {type(docs)}")
    print(f"First item in docs: {docs[0] if docs else 'Empty list'}")

    if not docs:
        return []

    # Normalise each item independently so mixed lists work and dicts are
    # never wrapped twice.
    texts = []
    for doc in docs:
        if isinstance(doc, str):
            texts.append(doc)
        elif isinstance(doc, dict) and 'page_content' in doc:
            texts.append(doc['page_content'])
        elif hasattr(doc, 'page_content'):
            texts.append(doc.page_content)
        else:
            print("Error: Invalid document structure.")
            return []

    documents = text_splitter.create_documents(texts)
    print("Documents split into semantic chunks.")
    print(documents)
    return documents

def save_documents_to_txt(documents, output_dir):
    """Write each document's ``page_content`` to ``<output_dir>/document_part_<n>.txt``.

    ``n`` starts at 1. Files are UTF-8 and overwrite any existing file of the
    same name. *output_dir* (including missing parents) is created if needed.
    Each saved path is printed.
    """
    # exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(output_dir, exist_ok=True)

    for part_number, document in enumerate(documents, start=1):
        file_path = os.path.join(output_dir, f"document_part_{part_number}.txt")
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(document.page_content)
        print(f"Saved: {file_path}")


def create_and_save_embeddings(split_documents, client_id):
    """Embed each chunk into its own FAISS index under ``my_embeddings/<client_id>``.

    Chunk *k* (1-based) is saved to ``my_embeddings/<client_id>/faiss_index<k>``.
    This is the ``faiss_index<N>`` folder pattern that ``merge_all_faiss`` and
    ``merge_all_faiss1`` look for when combining the chunks.

    NOTE(review): folders from an earlier run with more chunks are not removed,
    so a later merge would also pick them up - confirm callers clean up first.
    """
    client_id = str(client_id)
    embedding_folder = os.path.join("my_embeddings", client_id)
    os.makedirs(embedding_folder, exist_ok=True)

    embeddings = OllamaEmbeddings(model='nomic-embed-text')

    for idx, doc in enumerate(split_documents, start=1):
        # One single-document FAISS store per chunk, saved with an incremental name.
        temp_db = FAISS.from_documents([doc], embedding=embeddings)
        temp_db.save_local(os.path.join(embedding_folder, f"faiss_index{idx}"))
        print(f"Saved FAISS embedding for document part {idx} as faiss_index{idx} in {embedding_folder}")

def create_embeddings(split_documents, client_id):
    """Embed all chunks into one FAISS store, save it to disk, and return it.

    Despite the log message, the store is also persisted with ``save_local``
    to the fixed path ``faiss_supplier_index``.

    NOTE(review): that path does not depend on *client_id*, so calls for
    different clients overwrite each other's saved index - confirm intended.
    """
    client_id = str(client_id)

    embedder = OllamaEmbeddings(model='nomic-embed-text')
    store = FAISS.from_documents(split_documents, embedding=embedder)
    store.save_local("faiss_supplier_index")

    print(f"✅ Created FAISS vectorstore in memory for client {client_id}")
    return store
import pickle

def save_faiss_per_chunk(documents, base_path="faiss_chunks", embedding_model=None, api_key=None):
    """
    Save each chunk in its own FAISS vector store directory.

    Chunk ``i`` (0-based) is written to ``<base_path>/chunk_<i>/``. That folder
    holds the FAISS index written by ``save_local`` plus ``doc_metadata.pkl``,
    a pickle of the original Document.

    Args:
        documents (List[Document]): List of LangChain Document objects.
        base_path (str): Base directory for all chunk folders; created if missing.
        embedding_model: Optional embedding model instance. Defaults to
            ``OllamaEmbeddings(model='nomic-embed-text')``.
        api_key (str): Unused; kept only so existing callers keep working.

    Returns:
        List[str]: The chunk folder paths, in document order.
    """
    if embedding_model is None:
        embedding_model = OllamaEmbeddings(model='nomic-embed-text')

    os.makedirs(base_path, exist_ok=True)
    chunk_paths = []

    for i, doc in enumerate(documents):
        chunk_dir = os.path.join(base_path, f"chunk_{i}")
        os.makedirs(chunk_dir, exist_ok=True)

        # One single-document FAISS store per chunk.
        vector_store = FAISS.from_documents([doc], embedding_model)
        vector_store.save_local(chunk_dir)
        chunk_paths.append(chunk_dir)

        # Keep the original Document next to its index. This is a pickle:
        # only load it back from storage you control.
        with open(os.path.join(chunk_dir, "doc_metadata.pkl"), "wb") as f:
            pickle.dump(doc, f)

    print(f"✅ Saved {len(documents)} chunks as individual FAISS indexes in '{base_path}'")
    return chunk_paths

def create_and_save_embeddings_new(split_documents, client_id):
    """Same as :func:`create_and_save_embeddings`, kept for callers that use this name.

    The two functions had identical bodies. Delegating keeps them from
    drifting apart: each chunk is saved as
    ``my_embeddings/<client_id>/faiss_index<k>`` (k starting at 1).
    """
    return create_and_save_embeddings(split_documents, client_id)



def merge_all_faiss1(client_id, base_path='my_embeddings'):
    """Merge every ``faiss_index<N>`` store of a client with FAISS ``merge_from``.

    Stores are loaded from ``<base_path>/<client_id>`` in ascending ``N``. The
    first loaded store is reused as the merge target. The result is saved to
    ``<base_path>/<client_id>/merged_faiss`` and printed, then the per-chunk
    folders are deleted (deletion errors are printed, not raised).

    Loading uses ``allow_dangerous_deserialization=True``; this is only
    appropriate because these folders are written by this module.

    Returns:
        The merged store, or None when no ``faiss_index<N>`` folder exists.
    """
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    folder_path = f'{base_path}/{client_id}'
    prefix = 'faiss_index'

    index_folders = sorted(
        (name for name in os.listdir(folder_path)
         if name.startswith(prefix) and name[len(prefix):].isdigit()),
        key=lambda name: int(name.replace(prefix, '')),
    )

    merged_faiss = None
    for name in index_folders:
        faiss_path = os.path.join(folder_path, name)
        print(f"Loading FAISS index from: {faiss_path}")
        store = FAISS.load_local(faiss_path, embeddings, allow_dangerous_deserialization=True)
        if merged_faiss is None:
            merged_faiss = store
        else:
            merged_faiss.merge_from(store)

    if merged_faiss is not None:
        merged_faiss.save_local(f'{folder_path}/merged_faiss')
    print(merged_faiss)

    for name in index_folders:
        faiss_path = os.path.join(folder_path, name)
        try:
            shutil.rmtree(faiss_path)
            print(f"Deleted FAISS index folder: {faiss_path}")
        except FileNotFoundError:
            print(f"Folder not found: {faiss_path}")
        except OSError as e:
            print(f"Error deleting {faiss_path}: {e}")
    return merged_faiss

def merge_all_faiss(client_id, base_path='my_embeddings'):
    """Combine the per-chunk ``faiss_index<N>`` stores of a client into one store.

    Stores are loaded from ``<base_path>/<client_id>`` in ascending ``N``. The
    first loaded store becomes the merged store. Every later store's documents
    are re-added with ``add_texts``, which re-embeds their text, together with
    their metadata. The result is saved to
    ``<base_path>/<client_id>/merged_faiss`` and the per-chunk folders are
    deleted. Deletion errors are printed, not raised.

    Loading uses ``allow_dangerous_deserialization=True``; this is only
    appropriate because these folders are written by this module.

    Returns:
        The merged store, or None when no ``faiss_index<N>`` folder exists.
    """
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    merged_faiss = None

    folder_path = f'{base_path}/{client_id}'
    index_folders = [
        name for name in os.listdir(folder_path)
        if name.startswith('faiss_index') and name[len('faiss_index'):].isdigit()
    ]
    sorted_folders = sorted(index_folders, key=lambda x: int(x.replace('faiss_index', '')))

    for name in sorted_folders:
        faiss_path = os.path.join(folder_path, name)
        print(f"Loading FAISS index from: {faiss_path}")
        current_faiss = FAISS.load_local(faiss_path, embeddings, allow_dangerous_deserialization=True)

        if merged_faiss is None:
            merged_faiss = current_faiss
            continue

        current_docs = [
            current_faiss.docstore.search(doc_id)
            for doc_id in current_faiss.index_to_docstore_id.values()
        ]
        # Pass metadatas explicitly: add_texts alone would keep only the text
        # and silently drop per-chunk metadata.
        merged_faiss.add_texts(
            [doc.page_content for doc in current_docs],
            metadatas=[doc.metadata for doc in current_docs],
        )

    if merged_faiss is not None:
        merged_faiss.save_local(f'{folder_path}/merged_faiss')
        print("Merged FAISS index saved as merged_faiss")

    # Clean up the individual per-chunk indexes.
    for name in sorted_folders:
        faiss_path = os.path.join(folder_path, name)
        try:
            shutil.rmtree(faiss_path)
            print(f"Deleted FAISS index folder: {faiss_path}")
        except FileNotFoundError:
            print(f"Folder not found: {faiss_path}")
        except OSError as e:
            print(f"Error deleting {faiss_path}: {e}")

    return merged_faiss




# YouTube transcripts
from youtube_transcript_api import YouTubeTranscriptApi

def _extract_youtube_video_id(youtube_video_url):
    """Return the video id from a YouTube watch / youtu.be / embed / shorts URL.

    Raises:
        ValueError: if no video id can be found.
    """
    from urllib.parse import parse_qs, urlparse

    parsed = urlparse(youtube_video_url)
    host = (parsed.hostname or "").lower()
    path_parts = [part for part in parsed.path.split("/") if part]

    if host == "youtu.be" or host.endswith(".youtu.be"):
        video_id = path_parts[0] if path_parts else ""
    else:
        # Read the "v" query parameter rather than "text after the first '='",
        # which broke on URLs such as ...watch?v=ID&t=10.
        video_id = parse_qs(parsed.query).get("v", [""])[0]
        if not video_id and len(path_parts) >= 2 and path_parts[0] in ("embed", "shorts", "live", "v"):
            video_id = path_parts[1]

    if not video_id and "=" in youtube_video_url:
        # Fallback matching the previous behaviour, minus trailing parameters.
        video_id = youtube_video_url.split("=")[1].split("&")[0]

    if not video_id:
        raise ValueError(f"Could not extract a YouTube video id from {youtube_video_url!r}")
    return video_id


def transcribe(youtube_video_url):
    """Fetch a YouTube video's transcript and save it to a file named after the video id.

    The transcript is the segments' text, each preceded by a single space,
    written as UTF-8 to a file in the current directory. The file name is the
    video id, with no extension.

    Returns:
        The video id, which is also the name of the written file.
    """
    video_id = _extract_youtube_video_id(youtube_video_url)
    transcript_text = YouTubeTranscriptApi.get_transcript(video_id)
    print(transcript_text)

    transcript = "".join(" " + segment["text"] for segment in transcript_text)

    with open(video_id, "w", encoding="utf-8") as f:
        f.write(transcript)

    print(f"Transcript saved to {video_id}")
    return video_id

def load_txt(file_path, encoding=None):
    """Load a plain-text file and return it as LangChain documents.

    Args:
        file_path: Path of the text file, such as the file written by
            ``transcribe`` (which writes UTF-8).
        encoding: Forwarded to ``TextLoader``. ``None`` (the default) keeps
            TextLoader's own default; pass ``"utf-8"`` to read UTF-8 files
            reliably regardless of the platform's locale encoding.

    Returns:
        The list of documents returned by ``TextLoader.load()``.
    """
    loader = TextLoader(file_path, encoding=encoding)
    return loader.load()

def split_text_with_semantic_chunker(docs, embeddings):
    """Re-split documents into semantically coherent chunks.

    Args:
        docs: Iterable of objects exposing ``page_content``.
        embeddings: Embedding model handed to ``SemanticChunker`` to locate
            breakpoints between chunks.

    Returns:
        The documents produced by ``SemanticChunker.create_documents``.
    """
    # Breakpoints at a percentile threshold; SemanticChunker also supports
    # "standard_deviation" and "interquartile".
    chunker = SemanticChunker(embeddings, breakpoint_threshold_type="percentile")
    page_texts = [document.page_content for document in docs]
    semantic_chunks = chunker.create_documents(page_texts)
    print("Documents split into semantic chunks.")
    return semantic_chunks

def parsing(input):
    """Convert a PDF into text using marker's ``PdfConverter``.

    A fresh model dictionary is created on every call.

    Args:
        input: Source passed straight to the converter (a PDF path in ``main``).

    Returns:
        The text component of ``text_from_rendered``'s result (used as
        Markdown by ``main``); the other two components are discarded.
    """
    models = create_model_dict()
    pdf_converter = PdfConverter(artifact_dict=models)
    rendered_doc = pdf_converter(input)
    markdown_text, _unused, _images = text_from_rendered(rendered_doc)
    return markdown_text

def marks_splitter(headers_to_split_on, content):
    """Split Markdown text into sections at the given header levels.

    Args:
        headers_to_split_on: ``(marker, name)`` pairs such as ``("#", "Header 1")``,
            forwarded to ``MarkdownHeaderTextSplitter``.
        content: The Markdown text to split.

    Returns:
        The result of ``MarkdownHeaderTextSplitter.split_text``.
    """
    splitter = MarkdownHeaderTextSplitter(headers_to_split_on)
    sections = splitter.split_text(content)
    # Debug trace of the produced sections.
    print("CCC")
    print(sections)
    return sections

def allocate_slides(chunks, total_slides, min_chars=60):
    """
    Distribute ``total_slides`` across chunks in proportion to their text length.

    Text is read from ``page_content``, then ``content``, then the chunk itself
    if it is a ``str``, otherwise ``str(chunk)``. Chunks shorter than
    ``min_chars`` are ignored. Every remaining chunk gets at least one slide, so
    when there are more eligible chunks than ``total_slides`` the returned
    counts sum to more than ``total_slides``.

    Args:
        chunks: Sequence of Document-like objects or strings.
        total_slides: Target number of slides across all chunks.
        min_chars: Minimum text length for a chunk to receive slides.

    Returns:
        Dict mapping the original chunk index to its slide count, in chunk
        order; ``{}`` when no chunk is long enough.
    """
    def _chunk_text(chunk):
        # Duck-typed extraction, checked in this priority order.
        if hasattr(chunk, "page_content"):
            return chunk.page_content
        if hasattr(chunk, "content"):
            return chunk.content
        return chunk if isinstance(chunk, str) else str(chunk)

    texts = [(position, _chunk_text(chunk)) for position, chunk in enumerate(chunks)]
    eligible = [(position, body) for position, body in texts if len(body) >= min_chars]
    if not eligible:
        return {}

    char_total = sum(len(body) for _, body in eligible)

    # Proportional share, rounded (Python's round-half-to-even), floored at 1.
    plan = {}
    for position, body in eligible:
        plan[position] = max(1, round(len(body) / char_total * total_slides))
    unassigned = total_slides - sum(plan.values())

    # Overshoot: take slides back from the chunk currently holding the most
    # (first such chunk on ties) until balanced or that chunk is down to 1.
    while unassigned < 0:
        heaviest = max(plan, key=plan.get)
        if plan[heaviest] <= 1:
            break
        plan[heaviest] -= 1
        unassigned += 1

    # Undershoot: hand one extra slide to each of the longest chunks
    # (stable order on equal length), repeating until nothing is left.
    if unassigned > 0:
        longest_first = sorted(eligible, key=lambda item: len(item[1]), reverse=True)
        while unassigned > 0:
            batch = longest_first[:unassigned]
            for position, _ in batch:
                plan[position] += 1
            unassigned -= len(batch)

    return plan


  # adjust as needed

class SlideCollection:
    """Insertion-ordered container of slides keyed by name (e.g. ``"slide_1"``).

    Offers a read-only dict-like API (iteration, ``len``, ``in``, indexing,
    ``keys``/``values``/``items``/``get``) over the public ``slides`` dict,
    which callers may also read or reassign directly. Like a dict, an empty
    collection is falsy.
    """

    def __init__(self):
        self.slides = {}

    def add_slide(self, key, content):
        """Store *content* under *key*, overwriting any existing slide."""
        self.slides[key] = content

    def keys(self):
        return self.slides.keys()

    def __iter__(self):
        return iter(self.slides)

    def __len__(self):
        return len(self.slides)

    def __contains__(self, key):
        # O(1) lookup instead of the O(n) scan Python falls back to via __iter__.
        return key in self.slides

    def __getitem__(self, key):
        return self.slides[key]

    def get(self, key, default=None):
        """Return the slide for *key*, or *default* if absent."""
        return self.slides.get(key, default)

    def items(self):
        return self.slides.items()

    def values(self):
        return self.slides.values()

    def __repr__(self):
        return repr(self.slides)

def quick_json_fix(ai_message) -> str:
    """Best-effort cleanup of an LLM reply before it reaches a JSON parser.

    Steps, in order:
      1. Take ``ai_message.content`` if present, else ``str(ai_message)``; strip.
      2. If the whole reply is wrapped in a Markdown code fence
         (e.g. ```` ```json ... ``` ````), keep only the fenced body.
      3. If the reply contains "Here's another attempt" or "I apologize",
         keep only the span from the first ``{`` to the last ``}``.
      4. Repair keys opened with a single quote and closed with a double
         quote (``'key":`` -> ``"key":``).
      5. If the result parses as a JSON object with a top-level
         ``"properties"`` key, return ``json.dumps`` of that value.

    Never raises for unparsable JSON; the cleaned text is returned as-is so
    the downstream (fixing) parser can deal with it.
    """
    text = ai_message.content if hasattr(ai_message, 'content') else str(ai_message)
    text = text.strip()

    # Unwrap a code fence around the entire reply.
    fenced = re.match(r"```[\w-]*\s*(.*?)\s*```\s*$", text, re.DOTALL)
    if fenced:
        text = fenced.group(1)

    # Remove explanatory text around a retried answer.
    if "Here's another attempt" in text or "I apologize" in text:
        start = text.find('{')
        end = text.rfind('}') + 1
        # end > start also rules out a '}' that precedes the first '{',
        # which previously sliced the reply down to an empty string.
        if start != -1 and end > start:
            text = text[start:end]

    # Fix mismatched quote issues.
    text = re.sub(r"'(\w+)\":", r'"\1":', text)

    # Unwrap a schema-style {"properties": ...} wrapper.
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return text
    if isinstance(parsed, dict) and "properties" in parsed:
        return json.dumps(parsed["properties"])

    return text

def _slide_attr(slide, name, default=None):
    """Read field *name* from a slide stored either as a dict or as a model object."""
    if isinstance(slide, dict):
        return slide.get(name, default)
    return getattr(slide, name, default)


def _set_slide_image(slide, image):
    """Store *image* on a slide stored either as a dict or as a model object."""
    if isinstance(slide, dict):
        slide['image'] = image
    else:
        slide.image = image


def _slide_key_number(key):
    """Sort key for LLM slide keys such as ``"slide_3"``.

    Returns the integer after the first underscore. Keys without a numeric
    part sort last (``inf``) instead of raising and aborting the whole run.
    """
    parts = str(key).split("_")
    if len(parts) > 1 and parts[1].isdigit():
        return int(parts[1])
    return float("inf")


def _generate_allocated_slides(chunks, allocations, llm):
    """Ask *llm* for ``allocations[i]`` slides from ``chunks[i]``; renumber them globally.

    Chunks are processed in ascending index order, and their slides are stored
    as ``slide_1``, ``slide_2``, ... in that order in a new SlideCollection.
    """
    slide_content_template = """
Based on the following context, generate professional and engaging content for exactly {num_slides} slides in a Storigos presentation.

Each slide must include:
- A clear and concise **sub-heading**
- **Paragraphs** that effectively communicate the key ideas and insights
- A specific, concise **visualization suggestion**

**Context**: {query}

Focus on creating content that is both informative and engaging. Ensure each slide:
- Has a well-structured sub-heading that captures the main point
- Uses clear and concise paragraphs to communicate important information

Use a professional and creative tone throughout. Each slide should incorporate the following elements where appropriate:
- **Thought-provoking questions** to encourage reflection
- **Relevant statistics** or data points that add credibility
- **Industry insights** or emerging trends to demonstrate expertise
- **Practical examples** or case studies to illustrate key concepts
- **Calls to action** to guide the audience toward specific actions or takeaways

For the visualization suggestion:
- Provide a clear and specific description of an image that would be relevant to the slide content.
- Keep it very concise, using a maximum of 5 words.
- Focus on concrete objects, scenes, or concepts that can be easily visualized.
- Avoid abstract or overly complex ideas.
- Include the context of the topic (e.g., "Python programming logo" instead of just "Python logo").

Make sure all content is drawn exclusively from the provided context or embedded data. Avoid introducing external information not found in the source material.

{format_instructions}

CRITICAL: The output must be a valid JSON object with this EXACT structure:
{{
  "slides": {{
    "slide_1": {{
      "type": "flash",
      "subheading": "...",
      "paragraphs": ["...", "..."],
      "visualization_suggestion": "...",
      "image": null
    }},
    "slide_2": {{ ... }}
  }},
  "token_count": 0
}}

DO NOT put "token_count" inside the "slides" object. It must be at the root level.
DO NOT include any explanations or additional text - only the JSON object.
The final output must be in strict sequential order: "slide_1", "slide_2", ..., up to "slide_{num_slides}".
"""

    parser = PydanticOutputParser(pydantic_object=StorigoContent)
    slide_content_prompt = ChatPromptTemplate.from_template(slide_content_template)
    slide_content_chain = (
        {
            "query": lambda x: x["query"],
            "num_slides": lambda x: x["num_slides"],
            "format_instructions": lambda x: parser.get_format_instructions()
        }
        | slide_content_prompt
        | llm
        | parser
    )

    all_slides = SlideCollection()
    counter = 1
    for chunk_idx in sorted(allocations):
        chunk = chunks[chunk_idx]
        query = getattr(chunk, "page_content", getattr(chunk, "content", str(chunk)))
        result = slide_content_chain.invoke({"query": query, "num_slides": allocations[chunk_idx]})

        if hasattr(result, 'slides'):
            slide_items = result.slides
        else:
            raw = result.model_dump() if hasattr(result, 'model_dump') else result
            slide_items = raw.get("slides", {})

        # The prompt forbids "token_count" inside "slides"; drop it if the model put it there anyway.
        if isinstance(slide_items, dict) and "token_count" in slide_items:
            slide_items.pop("token_count")

        for slide_key in sorted(slide_items, key=_slide_key_number):
            all_slides.add_slide(f"slide_{counter}", slide_items[slide_key])
            counter += 1

    return all_slides


def _attach_slide_images(slides, is_image):
    """Set every slide's ``image`` to a fetched image path, or None.

    The image is None when *is_image* is false, when the slide has no
    ``visualization_suggestion``, or when fetch_image_for_slide returns a
    falsy value.
    """
    for slide_key, slide in slides.items():
        image_path = None
        if is_image:
            suggestion = _slide_attr(slide, "visualization_suggestion")
            if suggestion:
                image_path = fetch_image_for_slide(slide_key, suggestion) or None
        _set_slide_image(slide, image_path)


def _generate_slide_mcqs(slides, num_mcqs, llm):
    """Generate up to *num_mcqs* MCQs, one per consecutive pair of slides.

    Returns a dict ``{"mcq_1": ..., "mcq_2": ...}`` numbered consecutively.
    A pair whose generation fails is reported and skipped.
    """
    mcq_template = """
Based on the following context from the last two slides, generate one multiple-choice question (MCQ). The question should be relevant to the content and designed to test comprehension.

**Context**: {context}

The MCQ must include:
- A **question** related to the context
- Exactly **4 answer options**
- A clear indication of the **correct answer** as a single letter: 'a', 'b', 'c', or 'd'

⚠️ **Critical Requirements**:
- ✅ Return **only valid JSON** — no explanations, headers, or extra text.
- ✅ Ensure all fields and options are enclosed in **double quotes (`"`)**.
- ✅ Do **not** use letters like "A.", "B." in the options — just the plain text.
- Directly follow the format given below

The final output **must strictly follow** this format:
```json
{{
    "question": "<The MCQ question>",
    "options": [
        "<Option 1>",
        "<Option 2>",
        "<Option 3>",
        "<Option 4>"
    ],
    "correct_answer": "<Correct option (e.g., 'a', 'b', 'c', or 'd')>"
}}

{format_instructions}
"""
    mcq_parser = PydanticOutputParser(pydantic_object=MCQContent)
    mcq_prompt = ChatPromptTemplate.from_template(mcq_template)
    output_fixing_parser = OutputFixingParser.from_llm(llm=llm, parser=mcq_parser)
    # The chain is the same for every slide pair, so build it once.
    mcq_chain = mcq_prompt | llm | RunnableLambda(quick_json_fix) | output_fixing_parser
    format_instructions = mcq_parser.get_format_instructions()

    slide_values = list(slides.values())
    mcqs = {}
    for start in range(0, len(slide_values), 2):
        if len(mcqs) >= num_mcqs:
            break

        context_slides = []
        for slide in slide_values[start:start + 2]:
            title = _slide_attr(slide, "subheading", "") or ""
            paras = _slide_attr(slide, "paragraphs", []) or []
            context_slides.append(f"{title}: {' '.join(paras)}")

        try:
            mcq_result = mcq_chain.invoke({
                "context": "\n".join(context_slides),
                "format_instructions": format_instructions
            })
        except Exception as exc:
            # Best effort: one failed pair must not abort the deck, but don't hide it.
            print(f"MCQ generation failed for slides {start + 1}-{min(start + 2, len(slide_values))}: {exc}")
            continue

        mcqs[f"mcq_{len(mcqs) + 1}"] = mcq_result

    return mcqs


def _interleave_mcqs(slides, mcqs, num_slides, num_mcqs, question_position):
    """Merge slides and MCQs into one ordered dict.

    With ``question_position == 1``, MCQ ``k`` is placed after every
    ``num_slides // num_mcqs``-th slide (interval clamped to at least 1), and
    MCQs the loop never reaches are appended at the end. Otherwise all slides
    come first, followed by all MCQs.
    """
    interleaved_content = {}
    if question_position == 1:
        # max(1, ...) prevents ZeroDivisionError when num_mcqs > num_slides or num_slides == 0.
        interval = max(1, num_slides // num_mcqs if num_mcqs > 0 else num_slides)
        mcq_counter = 0
        for position, (slide_key, slide) in enumerate(slides.items(), start=1):
            interleaved_content[slide_key] = slide
            if position % interval == 0 and mcq_counter < num_mcqs:
                mcq_key = f"mcq_{mcq_counter + 1}"
                if mcq_key in mcqs:
                    interleaved_content[mcq_key] = mcqs[mcq_key]
                mcq_counter += 1
        # Generated MCQs that were never reached (e.g. fewer slides than requested)
        # are appended instead of being silently dropped.
        for mcq_key, mcq in mcqs.items():
            interleaved_content.setdefault(mcq_key, mcq)
    else:
        interleaved_content.update(slides)
        interleaved_content.update(mcqs)
    return interleaved_content


def generate_slide_content_alloc1(chunks, allocations, num_slides, num_mcqs, is_image, is_question, question_position, GPU):
    """Generate slide content (and optionally MCQs) for pre-allocated chunks.

    Uses a local Ollama ``gemma3:12b`` chat model at ``http://127.0.0.1:11434``.

    Args:
        chunks: Indexable sequence of chunks. Text is taken from
            ``page_content``, then ``content``, then ``str(chunk)``.
        allocations: Mapping of chunk index to the number of slides to request
            for that chunk (as produced by ``allocate_slides``).
        num_slides: Requested slide total. Here it is only used to space MCQs
            when ``question_position == 1``.
        num_mcqs: Maximum number of MCQs to generate.
        is_image: Fetch an image per slide via ``fetch_image_for_slide``.
        is_question: Generate MCQs from consecutive slide pairs.
        question_position: ``1`` interleaves MCQs between slides; any other
            value appends them after all slides.
        GPU: Accepted for interface compatibility; not used.

    Returns:
        ``StorigoContentMCQMid`` (slides plus MCQs) when *is_question* is
        true, otherwise ``StorigoContent`` (slides only). ``token_count`` is
        always 0.

    Raises:
        Exception: Wraps, and chains via ``__cause__``, any error raised while
            generating, with the message prefix "Error generating slide content: ".
    """
    try:
        llm = ChatOllama(
            base_url='http://127.0.0.1:11434',
            model="gemma3:12b"
        )

        all_slides = _generate_allocated_slides(chunks, allocations, llm)
        _attach_slide_images(all_slides.slides, is_image)

        if not is_question:
            return StorigoContent(slides=all_slides.slides, token_count=0)

        mcqs = _generate_slide_mcqs(all_slides.slides, num_mcqs, llm)
        interleaved_content = _interleave_mcqs(
            all_slides.slides, mcqs, num_slides, num_mcqs, question_position
        )
        return StorigoContentMCQMid(slides=interleaved_content, token_count=0)

    except Exception as e:
        raise Exception(f"Error generating slide content: {str(e)}") from e

def main(input, num_slides, num_mcqs, is_image, is_question, question_position, GPU):
    """Run the PDF-to-slides pipeline end to end, print and return the result.

    Steps: convert *input* with ``parsing``, save the text to
    ``parse_data.md`` in the working directory, split it at Markdown headers
    (levels 1-4), allocate slides per chunk (chunks under 100 characters are
    skipped), then generate the slides/MCQs.

    Args:
        input: Document passed to ``parsing`` (a PDF path).
        num_slides, num_mcqs, is_image, is_question, question_position, GPU:
            Forwarded to ``generate_slide_content_alloc1``.

    Returns:
        The object returned by ``generate_slide_content_alloc1``.
    """
    parse_data = parsing(input)
    with open("parse_data.md", "w", encoding="utf-8") as f:
        f.write(parse_data)
    headers_to_split_on = [
        ("#", "Header 1"),
        ("##", "Header 2"),
        ("###", "Header 3"),
        ("####", "Header 4"),
    ]
    # Reading the file back (rather than reusing parse_data) means text-mode
    # newline translation applies to the splitter's input.
    with open("parse_data.md", "r", encoding="utf-8") as f:
        content = f.read()
    chunks = marks_splitter(headers_to_split_on, content)
    print("Chunks")

    allocation = allocate_slides(chunks, num_slides, min_chars=100)
    print(f"Total chunks: {len(chunks)}")
    print(f"Valid chunks: {len(allocation)}")
    print("\nSlide allocation:")
    for chunk_idx, slide_count in sorted(allocation.items()):
        print(f"  chunk {chunk_idx}: {slide_count} slide(s)")

    slide_content = generate_slide_content_alloc1(
        chunks, allocation, num_slides, num_mcqs, is_image, is_question, question_position, GPU
    )
    print(slide_content)
    return slide_content

if __name__ == "__main__":
    # Run the (synchronous) PDF-to-slides pipeline on a sample document.
    input = 'Chapter3-Basic-Requirement-in-the-Kitchen.pdf'
    #input = 'https://edurigo.com/'
    # output_dir and client_id are only referenced by the commented-out
    # alternative entry points below; embedding_folder_base is unused here.
    output_dir = f'temp/{input}'
    embedding_folder_base = 'output_embeddings'

    client_id = 1113331144
    num_slides = 30
    is_image = True
    num_mcqs = 14
    is_question = True
    question_position = 1  # 1 = interleave MCQs between slides; any other value appends them at the end
    GPU = 1  # accepted by main() but not used by the slide generator
    #main(input, output_dir,client_id, is_image)
    main(input, num_slides, num_mcqs, is_image, is_question, question_position, GPU)

    #asyncio.run(main(input, output_dir,client_id, is_image))
    
