import os
import time
import random
import re
import json
import requests
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from langchain_core.runnables import RunnablePassthrough
from PyPDF2 import PdfReader
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Union
from urllib.parse import urlparse
from langchain_core.runnables import RunnableLambda
import shutil
from langchain_ollama import ChatOllama


# API keys are read from the environment so secrets are not hard-coded in source
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
OLLAMA_MODEL = "nomic-embed-text"
PIXABAY_API_KEY = os.environ.get("PIXABAY_API_KEY", "")

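# Pydantic models describing the structure of generated slides, MCQs, and the combined Storigo payloads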
class SlideContent(BaseModel):
    type: str = Field("flash")
    subheading: Optional[str] = Field(None, description="An optional subheading for the slide")
    paragraphs: List[str] = Field(..., description="List of paragraphs for the slide content")
    visualization_suggestion: str = Field(..., description="A specific and concise suggestion for a relevant visualization or image (max 5 words)")
    image: Optional[str] = Field(None, description="URL of the image for the slide")

class StorigoContent1(BaseModel):
    slides: Dict[str, SlideContent] = Field(..., description="Dictionary of slide contents with slide numbers as keys")
    token_count: int = Field(..., description="Total token count for all the generated content")

class MCQContent(BaseModel):
    type: str = Field("Question")
    question: str = Field(..., description="The multiple-choice question")
    options: List[str] = Field(..., description="A list of 4 answer options")
    correct_answer: str = Field(..., description="The correct answer (e.g., 'a', 'b', 'c', or 'd')")

class StorigoContent(BaseModel):
    slides: Dict[str, SlideContent] = Field(..., description="Dictionary of slide contents with slide numbers as keys")
    token_count: int = Field(0, description="Total token count for all the generated content")


class StorigoContentMCQ(BaseModel):
    slides: Dict[str, SlideContent] = Field(..., description="Dictionary of slide contents with slide numbers as keys")
    mcqs: Dict[str, MCQContent] = Field(..., description="Dictionary of MCQs with identifiers like 'mcq_1' as keys")


class StorigoContentMCQMid(BaseModel):
    slides: Dict[str, Union[SlideContent, MCQContent]] = Field(..., description="Dictionary mapping slide keys (e.g. 'slide_1') and MCQ keys (e.g. 'mcq_1') to their content")
    token_count: int = Field(..., description="Total token count for all the generated content")

def extract_text_from_pdf(pdf_path):
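    """Extract and concatenate the text of every page of the PDF at pdf_path."""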
    try:
        reader = PdfReader(pdf_path)
        text = ""
        for page in reader.pages:
            # extract_text() can return None for pages with no extractable text
            text += page.extract_text() or ""
        return text
    except Exception as e:
        raise Exception(f"Error extracting text from PDF: {str(e)}")

def create_embeddings(text, client_id):
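    """Embed the text with Ollama and persist a FAISS index under my_embeddings/<client_id>."""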
    try:
        embeddings = OllamaEmbeddings(model=OLLAMA_MODEL)
        vectors = FAISS.from_texts([text], embeddings)
        
        client_dir = f"my_embeddings/{client_id}"
        os.makedirs(client_dir, exist_ok=True)
        vectors.save_local(client_dir)
        
        return vectors
    except Exception as e:
        raise Exception(f"Error creating embeddings: {str(e)}")

def generate_search_query(visualization_suggestion, slide_content):
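    """Build a short (max 5 word) Pixabay search query from the visualization suggestion plus context keywords."""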
    context_keywords = extract_context_keywords(slide_content)
    
    # Combine visualization suggestion with context keywords
    combined_query = f"{visualization_suggestion} {' '.join(context_keywords)}"
    
    # Extract key words from the combined query
    words = re.findall(r'\w+', combined_query.lower())
    
    # Remove common words
    common_words = set(['and', 'or', 'the', 'a', 'an', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'])
    filtered_words = [word for word in words if word not in common_words]
    
    # Prioritize words from the visualization suggestion
    suggestion_words = visualization_suggestion.lower().split()
    prioritized_words = suggestion_words + [word for word in filtered_words if word not in suggestion_words]
    
    # Take the first 5 words (slicing safely handles shorter lists)
    query_words = prioritized_words[:5]
    
    return " ".join(query_words)

def extract_context_keywords(slide_content):
    # Extract keywords from slide content to provide context
    text = f"{slide_content.subheading or ''} {' '.join(slide_content.paragraphs)}"
    words = re.findall(r'\w+', text.lower())
    common_words = set(['and', 'or', 'the', 'a', 'an', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'])
    keywords = [word for word in words if word not in common_words]
    return list(set(keywords))[:3]  # Return up to 3 unique keywords

def fetch_pixabay_image(query):
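    """Query the Pixabay API and return the URL of the most popular matching horizontal photo, or None."""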
    url = "https://pixabay.com/api/"
    params = {
        "key": PIXABAY_API_KEY,
        "q": query,
        "image_type": "photo",
        "orientation": "horizontal",
        "per_page": 5,  # Fetch top 5 images
        "safesearch": "true",
        "order": "relevance"
    }
    
    try:
        response = requests.get(url, params=params)
        response.raise_for_status()
        data = response.json()
        
        if data["hits"]:
            # Rank hits by popularity (likes + downloads) and return the top result
            sorted_hits = sorted(data["hits"], key=lambda x: x["likes"] + x["downloads"], reverse=True)
            return sorted_hits[0]["webformatURL"]
        else:
            print(f"No image found for query: {query}")
            return None
    except requests.RequestException as e:
        print(f"Error fetching image from Pixabay: {str(e)}")
        return None
    except Exception as e:
        print(f"Unexpected error in fetch_pixabay_image: {str(e)}")
        return None

def get_valid_image(visualization_suggestion, slide_content, max_attempts=3):
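    """Try up to max_attempts times to fetch a Pixabay image for the slide; return its URL or None."""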
    if not visualization_suggestion:
        print("No visualization suggestion provided.")
        return None

    for attempt in range(max_attempts):
        try:
            query = generate_search_query(visualization_suggestion, slide_content)
            print(f"Attempt {attempt + 1} to fetch image for query: {query}")
            image_url = fetch_pixabay_image(query)
            
            if image_url:
                print(f"Valid image found: {image_url}")
                return image_url
            else:
                print(f"No image URL returned for query: {query}")
            
            time.sleep(1)
        except Exception as e:
            print(f"Error in get_valid_image (attempt {attempt + 1}): {str(e)}")
    
    print(f"No valid image found after {max_attempts} attempts")
    return None

def count_tokens(text):
    """Approximate the token count by counting word-like tokens (one word ≈ one token)."""
    tokens = re.findall(r'\w+', text)
    return len(tokens)


def generate_slide_content_from_prompt(prompt, num_slides, num_mcqs, is_image, is_question, question_position, GPU):
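    """Generate Storigo slide content (and optionally MCQs) from a free-form prompt.

    GPU == 0 selects the Groq-hosted llama3-70b model; any other value selects a local Ollama llama3:8b.
    When is_question is true, MCQs are interleaved with the slides (question_position == 1) or appended
    after them and a StorigoContentMCQMid is returned; otherwise a StorigoContent is returned.
    """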
    try:
        if GPU == 0:
            llm = ChatGroq(model_name='llama3-70b-8192', groq_api_key=GROQ_API_KEY)
        else:
            llm = ChatOllama(
                base_url='http://127.0.0.1:11434',
                model="llama3:8b"
                # model="deepseek-r1:8b"
            )

        slide_content_template = """
        Based on the following prompt, generate professional and engaging content for exactly {num_slides} slides for a Storigos presentation.
        Each slide MUST have ALL of the following:
        - A subheading
        - 2-3 paragraphs
        - A specific, concise visualization suggestion (max 5 words)
        
        For the visualization suggestion:
        - Provide a clear and specific description of an image that would be relevant to the slide content.
        - Keep it very concise, using a maximum of 5 words.
        - Focus on concrete objects, scenes, or concepts that can be easily visualized.
        - Avoid abstract or overly complex ideas.
        - Include the context of the topic.

        Prompt: {prompt}

        Be creative and professional. Include engaging elements like:
        - Thought-provoking questions
        - Relevant statistics or data points
        - Industry insights or trends
        - Practical examples or case studies
        - Calls to action

        {format_instructions}

        Ensure that the output is a valid JSON object with keys "slide_1", "slide_2", etc., each containing the slide content.
        ALL fields (subheading, paragraphs, visualization_suggestion) MUST be filled for each slide.
        Do not use null values. If you can't think of content for a field, provide a relevant placeholder or general information.
        """

        parser = PydanticOutputParser(pydantic_object=StorigoContent)
        slide_content_prompt = ChatPromptTemplate.from_template(slide_content_template)

        slide_content_chain = (
            {
                "prompt": lambda x: x["prompt"],
                "num_slides": lambda x: x["num_slides"],
                "format_instructions": lambda x: parser.get_format_instructions()
            }
            | slide_content_prompt
            | llm
            | parser
        )

        result = slide_content_chain.invoke({"prompt": prompt, "num_slides": num_slides})

        # Sort slides numerically so "slide_10" does not sort before "slide_2"
        def _slide_order(key):
            match = re.search(r'\d+', key)
            return int(match.group()) if match else 0

        ordered_slides = dict(sorted(result.slides.items(), key=lambda item: _slide_order(item[0])))
        print("Slide generation result:")
        print(result)

        # Ensure every slide has at least 2 paragraphs
        for slide_content in result.slides.values():
            while len(slide_content.paragraphs) < 2:
                slide_content.paragraphs.append("Additional information and context for this slide.")

        if is_image:
            for slide_key, slide_content in result.slides.items():
                if slide_content.visualization_suggestion:
                    image_url = get_valid_image(slide_content.visualization_suggestion, slide_content)
                    if image_url:
                        slide_content.image = image_url
                    else:
                        print(f"Warning: No suitable image found for slide {slide_key} after multiple attempts.")
                        slide_content.image = "https://via.placeholder.com/640x480.png?text=Placeholder+Image"
                else:
                    print(f"Warning: No visualization suggestion for slide {slide_key}.")
                    slide_content.image = "https://via.placeholder.com/640x480.png?text=Placeholder+Image"
        else:
            for slide_content in result.slides.values():
                slide_content.image = None

        # Calculate the total token count
        token_count_text = 0
        for slide_content in result.slides.values():
            text_content = f"{slide_content.subheading} {' '.join(slide_content.paragraphs)} {slide_content.visualization_suggestion}"
            token_count_text += count_tokens(text_content)

        if is_question:
            mcq_template = """
Based on the following context from the last two slides, generate one multiple-choice question (MCQ). The question should be relevant to the content and designed to test comprehension.

**Context**: {context}

The MCQ must include:
- A **question** related to the context  
- Exactly **4 answer options**  
- A clear indication of the **correct answer** as a single letter: 'a', 'b', 'c', or 'd'

⚠️ **Critical Requirements**:
- ✅ Return **only valid JSON**: no explanations, headers, or extra text.
- ✅ Enclose every field name and option in **double quotes (`"`)**.
- ✅ Do **not** prefix options with letters like "A." or "B."; give the plain option text only.
- ✅ Do not add any preamble such as "Here is the MCQ".
- ✅ Follow the format given below exactly.

The final output **must strictly follow** this format, with the "question", "options", and "correct_answer" labels always in double quotes:
```json
{{
    "question": "<The MCQ question>",
    "options": [
        "<Option 1>",
        "<Option 2>",
        "<Option 3>",
        "<Option 4>"
    ],
    "correct_answer": "<Correct option (e.g., 'a', 'b', 'c', or 'd')>"
}}
```
"""


            def is_valid_json(response):
                try:
                    json.loads(response)
                    return True
                except json.JSONDecodeError:
                    return False

            mcq_prompt = ChatPromptTemplate.from_template(mcq_template)
            # MCQs are always generated with the local Ollama model, regardless of the GPU flag
            llm = ChatOllama(
                base_url='http://127.0.0.1:11434',
                model="llama3:8b"
            )
            
            mcqs = {}
            slide_keys = list(ordered_slides.keys())
            
            for i in range(0, len(slide_keys), 2):
                if len(mcqs) < num_mcqs:
                    context_slides = []
                    for j in range(i, min(i + 2, len(slide_keys))):
                        slide = ordered_slides[slide_keys[j]]
                        context_slides.append(f"{slide.subheading}: {' '.join(slide.paragraphs)}")

                    context_text = "\n".join(context_slides)
                    print("MCQ context:")
                    print(context_text)
                    try:
                        mcq_result = (
                            RunnableLambda(lambda x: {"context": context_text})
                            | mcq_prompt
                            | llm
                        ).invoke({})

                        print("Raw MCQ response:")
                        print(mcq_result)

                        if hasattr(mcq_result, 'content'):
                            # Extract the content (this is the LLM's response)
                            content = mcq_result.content
                        else:
                            # If mcq_result does not have 'content', try to convert it to a string
                            content = str(mcq_result)

                        # Extract the first balanced {...} block from the response and parse it as JSON
                        start = content.find('{')
                        if start == -1:
                            raise ValueError("No opening brace found in MCQ response")

                        open_braces = 0
                        end = -1
                        for pos, char in enumerate(content[start:], start=start):
                            if char == '{':
                                open_braces += 1
                            elif char == '}':
                                open_braces -= 1
                                if open_braces == 0:
                                    end = pos
                                    break

                        if end == -1:
                            raise ValueError("No matching closing brace found in MCQ response")

                        content_only = content[start:end + 1]
                        print("Extracted MCQ JSON:")
                        print(content_only)
                        json_object = json.loads(content_only)
                        
                        formatted_mcq = {
                            "question": json_object.get("question", ""),
                            "options": json_object.get("options", []),
                            "correct_answer": json_object.get("correct_answer", "")
                        }

                        print("Formatted MCQ:")
                        print(json.dumps(formatted_mcq, indent=4))

                        if formatted_mcq["question"] and len(formatted_mcq["options"]) == 4 and formatted_mcq["correct_answer"]:
                            mcq_key = f"mcq_{len(mcqs) + 1}"
                            mcqs[mcq_key] = formatted_mcq
                            print(f"✅ MCQ {mcq_key} generated and saved!")
                        else:
                            raise ValueError("Incomplete MCQ data")
                    except Exception as e:
                        print(f"Error generating MCQ: {e}")
            
            # Approximate the token count contributed by the MCQs
            token_count = 0
            for mcq in mcqs.values():
                text_content = f"{mcq['question']} {' '.join(mcq['options'])} {mcq['correct_answer']}"
                token_count += count_tokens(text_content)

            interleaved_content = {}
            mcq_counter = 0
            total_slides = num_slides
            total_mcqs = num_mcqs
            
            if question_position == 1:
                # Interleave: insert an MCQ after every `interval` slides
                interval = max(1, total_slides // total_mcqs) if total_mcqs > 0 else total_slides

                for idx, slide_key in enumerate(slide_keys):
                    interleaved_content[slide_key] = ordered_slides[slide_key]

                    if (idx + 1) % interval == 0 and mcq_counter < total_mcqs:
                        mcq_key = f"mcq_{mcq_counter + 1}"
                        # Skip gracefully if fewer MCQs were generated than requested
                        if mcq_key in mcqs:
                            interleaved_content[mcq_key] = mcqs[mcq_key]
                        mcq_counter += 1
            else:
                # Append all MCQs after the slides
                for slide_key in slide_keys:
                    interleaved_content[slide_key] = ordered_slides[slide_key]

                for mcq_counter in range(total_mcqs):
                    mcq_key = f"mcq_{mcq_counter + 1}"
                    if mcq_key in mcqs:
                        interleaved_content[mcq_key] = mcqs[mcq_key]

            # Total token count covers both slide text and MCQ text
            storigo_content = StorigoContentMCQMid(slides=interleaved_content, token_count=token_count_text + token_count)
            return storigo_content
        else:
            storigo_content_without_mc = StorigoContent(slides=ordered_slides, token_count=token_count_text)
            return storigo_content_without_mc

    except Exception as e:
        raise Exception(f"Error generating slide content: {str(e)}")
        

def main(prompt, num_slides, is_image):
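    """Generate slide content for the given prompt and print it along with the elapsed time."""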
    
    start_time = time.time()
    num_mcqs = 2
    is_question = True
    question_position = 0
    GPU = 1
    print("Starting")
    slide_content = generate_slide_content_from_prompt(prompt, num_slides, num_mcqs, is_image, is_question, question_position, GPU)
    
    end_time = time.time()
    
    print("Generated Slide Content:")
    print(slide_content)
    print(f"\nProcess took {end_time - start_time:.2f} seconds")
    

if __name__ == "__main__":
    prompt = "The history and impact of artificial intelligence in modern society"
    num_slides = 8
    is_image = True
    main(prompt, num_slides, is_image)
