import os
import random
import json
import asyncio
import re
import time
from tqdm import tqdm
from langchain_community.vectorstores import FAISS
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain_ollama import ChatOllama, OllamaEmbeddings
# Read the Groq API key from the environment instead of hard-coding a secret.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

def load_all_embeddings(client_id, reference_id):
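    """Load a saved FAISS index for this client/reference pair, or return None.

    Assumes the index was written with FAISS.save_local() under
    my_embeddings_video/<client_id>/<reference_id>/.
    """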
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    reference_dir = f"my_embeddings_video/{client_id}/{reference_id}"

    if os.path.isdir(reference_dir):
        embedding_files = [f for f in os.listdir(reference_dir) if f.endswith((".faiss", ".pkl"))]
        if embedding_files:
            vectors = FAISS.load_local(reference_dir, embeddings, allow_dangerous_deserialization=True)
            print(f"Embeddings loaded from {reference_dir}")
            return vectors
        else:
            print(f"No embedding files found in {reference_dir}.")
    else:
        print(f"No embeddings found for client_id {client_id} and reference_id {reference_id}.")
    return None

class QuestionHistory:
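    """On-disk set of previously generated questions, stored as a JSON list per client/reference pair."""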
    def __init__(self, client_id, reference_id):
        self.file_path = f"question_history_video/{client_id}/{reference_id}.json"
        self.questions = self.load()

    def load(self):
        if os.path.exists(self.file_path):
            try:
                with open(self.file_path, 'r') as f:
                    content = f.read().strip()
                return set(json.loads(content)) if content else set()
            except json.JSONDecodeError:
                print(f"Warning: Invalid JSON in {self.file_path}. Starting with an empty history.")
                return set()
        else:
            print(f"No history file found. Starting with an empty history.")
            return set()

    def save(self):
        os.makedirs(os.path.dirname(self.file_path), exist_ok=True)
        with open(self.file_path, 'w') as f:
            json.dump(list(self.questions), f)

    def add(self, question):
        self.questions.add(question)

    def __contains__(self, question):
        return question in self.questions

    def __len__(self):
        return len(self.questions)

class RateLimiter:
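    """Sliding-window rate limiter: allow at most max_requests calls per period seconds."""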
    def __init__(self, max_requests, period):
        self.max_requests = max_requests
        self.period = period
        self.requests = []

    async def wait(self):
        now = time.time()
        self.requests = [req for req in self.requests if now - req < self.period]
        if len(self.requests) >= self.max_requests:
            sleep_time = self.period - (now - self.requests[0])
            if sleep_time > 0:
                await asyncio.sleep(sleep_time)
        self.requests.append(time.time())

async def generate_mcq_video(client_id, num_questions, reference_id, GPU=0):
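    """Generate up to num_questions unique MCQs from stored video embeddings.

    GPU == 0 selects the hosted Groq model; any other value selects a local
    Ollama server. Returns (mcqs, total history size, approximate token count).
    """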
    vectors = load_all_embeddings(client_id, reference_id)
    if not vectors:
        print("No embeddings found for the given client ID and reference ID.")
        return [], 0, 0

    question_history = QuestionHistory(client_id, reference_id)
    retriever = vectors.as_retriever(search_kwargs={"k": 10})
    # llm = ChatGroq(model_name="Llama3-8b-8192", groq_api_key=GROQ_API_KEY)
    if GPU == 0:
        # No local GPU: call the hosted Groq endpoint.
        llm = ChatGroq(model_name="llama3-70b-8192", groq_api_key=GROQ_API_KEY)
    else:
        # Local GPU available: use a locally served Ollama model.
        llm = ChatOllama(
            base_url="http://127.0.0.1:11434",
            model="llama3:8b",
            # model="deepseek-r1:8b",
        )
    mcq_template = PromptTemplate(
        input_variables=["context", "num_questions"],
        template="""
        {context}
        Generate {num_questions} multiple-choice questions based on the information provided above. Each question should have one correct option and three incorrect options. Use the following format:
        Q:
        A)
        B)
        C)
        D)
        Correct:
        Ensure questions are concise and directly related to the content. Try to cover different aspects of the provided information. Do not include phrases like "in the document", "according to the text", or any other references to the source material in the questions.
        """
    )

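    # Parse one question block: "Q: ...\nA) ...\nB) ...\nC) ...\nD) ...\nCorrect: <letter>".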
    mcq_regex = re.compile(r"Q: (.+?)\nA\) (.+?)\nB\) (.+?)\nC\) (.+?)\nD\) (.+?)\nCorrect: ([A-D])", re.DOTALL)
    mcqs = []
    question_number = 1
    total_token_count = 0
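    # Throttle LLM calls to 25 per minute (an assumed, conservative provider limit).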
    rate_limiter = RateLimiter(max_requests=25, period=60)

    async def process_context(context, start_number, batch_size):
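        """Request batch_size questions for one context chunk, parse them, and drop repeats."""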
        await rate_limiter.wait()
        prompt = mcq_template.format(context=context, num_questions=batch_size)
        response = await llm.ainvoke(prompt)
        response_str = response.content
        questions = []

        for match in mcq_regex.finditer(response_str):
            question, option_a, option_b, option_c, option_d, correct_letter = match.groups()
            question = question.strip()

            # Remove document references from the question
            question = re.sub(r'\b(in|from|according to|mentioned in|stated in|as per) (the|this)? ?(document|text|passage|content)\b', '', question, flags=re.IGNORECASE).strip()
            question = re.sub(r'^(The|This) (document|text|passage|content) (states|mentions|says|indicates) that', '', question, flags=re.IGNORECASE).strip()

            if question not in question_history:
                questions.append({
                    "number": start_number + len(questions),
                    "question": question,
                    "options": {
                        "A": option_a.strip(),
                        "B": option_b.strip(),
                        "C": option_c.strip(),
                        "D": option_d.strip()
                    },
                    "correct_answer": correct_letter
                })
        return questions

    with tqdm(total=num_questions, desc="Generating MCQs") as pbar:
        timeout = time.time() + 300  # 5 minutes timeout
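        # Retrieve a broad set of chunks once; each batch then samples one at random.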
        query = "Provide a comprehensive overview of the entire content"
        results = await retriever.ainvoke(query)
        contexts = [result.page_content for result in results]

        while len(mcqs) < num_questions and time.time() < timeout:
            remaining_questions = num_questions - len(mcqs)
            batch_size = min(remaining_questions, 3)
            context = random.choice(contexts)

            try:
                new_mcqs = await process_context(context, question_number, batch_size)

                added = 0
                for mcq in new_mcqs:
                    if len(mcqs) < num_questions and mcq['question'] not in question_history:
                        mcqs.append(mcq)
                        question_history.add(mcq['question'])
                        question_number += 1
                        added += 1
                        # Rough size estimate: whitespace-delimited words, not true tokens.
                        total_token_count += len(mcq['question'].split()) + sum(len(option.split()) for option in mcq['options'].values())
                # Advance the bar by the number of questions actually kept.
                pbar.update(added)

            except Exception as e:
                print(f"Error processing context: {e}")
                continue

    if len(mcqs) < num_questions:
        print(f"Warning: Only generated {len(mcqs)} unique questions out of {num_questions} requested.")

    question_history.save()
    return mcqs, len(question_history), total_token_count

async def main():
    client_id = 3
    num_questions = 3
    reference_id = 114
    mcqs, history_count, token_count = await generate_mcq_video(client_id, num_questions, reference_id, GPU=0)

    if not mcqs:
        print("No multiple-choice questions generated.")
        return

    print(f"Total questions in history: {history_count}")
    print(f"Total token count: {token_count}")

    for mcq in mcqs:
        print(f"Question {mcq['number']}: {mcq['question']}")
        for option, text in mcq['options'].items():
            print(f"{option}) {text}")
        print(f"Correct Answer: {mcq['correct_answer']}\n")

if __name__ == "__main__":
    asyncio.run(main())