import os
import random
from langchain_ollama import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain_ollama import ChatOllama
#from langchain.output_parsers import RegexParser
import re
import json
import asyncio
from tqdm import tqdm
import time

# Read the Groq API key from the environment; never hardcode secrets in source.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

def load_all_embeddings(client_id, reference_id):
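    """Load the merged FAISS index for a client/reference pair.

    Returns the FAISS vector store, or None when the index directory or
    its embedding files are missing.
    """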
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    reference_dir = f"my_embeddings/{client_id}/{reference_id}/merged_faiss"
    if os.path.isdir(reference_dir):
        embedding_files = [f for f in os.listdir(reference_dir) if f.endswith((".faiss", ".pkl"))]
        if embedding_files:
            vectors = FAISS.load_local(reference_dir, embeddings, allow_dangerous_deserialization=True)
            print(f"Embeddings loaded from {reference_dir}")
            return vectors
        else:
            print(f"No embedding files found in {reference_dir}.")
    else:
        print(f"No embeddings found for client_id {client_id} and reference_id {reference_id}.")
    return None

class QuestionHistory:
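    """Persistent set of previously generated questions, stored on disk
    as a JSON list per client/reference pair."""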
    def __init__(self, client_id, reference_id):
        self.file_path = f"question_history/{client_id}/{reference_id}.json"
        self.questions = self.load()

    def load(self):
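        """Return the saved question set, tolerating missing, empty, or corrupt files."""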
        if os.path.exists(self.file_path):
            try:
                with open(self.file_path, 'r') as f:
                    content = f.read().strip()
                    return set(json.loads(content)) if content else set()
            except json.JSONDecodeError:
                print(f"Warning: Invalid JSON in {self.file_path}. Starting with an empty history.")
                return set()
        else:
            print(f"No history file found. Starting with an empty history.")
            return set()

    def save(self):
        os.makedirs(os.path.dirname(self.file_path), exist_ok=True)
        with open(self.file_path, 'w') as f:
            json.dump(list(self.questions), f)

    def add(self, question):
        self.questions.add(question)

    def __contains__(self, question):
        return question in self.questions

    def __len__(self):
        return len(self.questions)

class RateLimiter:
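    """Sliding-window limiter: allow at most max_requests calls per period seconds."""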
    def __init__(self, max_requests, period):
        self.max_requests = max_requests
        self.period = period
        self.requests = []

    async def wait(self):
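        # Drop timestamps that have aged out of the window; if the window is
        # still full, sleep until the oldest request falls outside it.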
        now = time.time()
        self.requests = [req for req in self.requests if now - req < self.period]
        if len(self.requests) >= self.max_requests:
            sleep_time = self.period - (now - self.requests[0])
            if sleep_time > 0:
                await asyncio.sleep(sleep_time)
        self.requests.append(time.time())

async def generate_mcq(client_id, num_questions, reference_id, GPU):
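    """Generate up to num_questions unique MCQs from the stored embeddings.

    Returns a (questions, history_size) tuple, where each question dict has
    "number", "question", "options", and "correct_answer" keys.
    """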
    vectors = load_all_embeddings(client_id, reference_id)
    if vectors is None:
        print("No embeddings found for the given client ID and reference ID.")
        return [], 0

    question_history = QuestionHistory(client_id, reference_id)
    
    retriever = vectors.as_retriever()
    #llm = ChatGroq(model_name="llama-3.3-70b-versatile", groq_api_key=GROQ_API_KEY)
    # GPU == 0: use the hosted Groq model; otherwise use a local Ollama model.
    if GPU == 0:
        llm = ChatGroq(model_name="llama3-70b-8192", groq_api_key=GROQ_API_KEY)
    else:
        llm = ChatOllama(
            base_url="http://127.0.0.1:11434",
            model="llama3:8b",
            # model="deepseek-r1:8b"
        )
    mcq_template = PromptTemplate(
        input_variables=["context", "num_questions"],
        template="""
        {context}

        Generate {num_questions} multiple-choice questions based on the provided context. Each question should include one correct answer and three plausible but incorrect options. Follow the format below:
        Q: <question>
        A) <option_a>
        B) <option_b>
        C) <option_c>
        D) <option_d>
        Correct: <correct_letter>

        Guidelines:
        - Ensure questions are clear, relevant, and focused on key details from the context.
        - Phrase questions concisely to avoid ambiguity.
        - Avoid vague references such as "in this document" or "in this code" unless accompanied by specific, relevant information.
        - Make incorrect options (distractors) credible and similar in structure or content to the correct answer, encouraging thoughtful selection.
        """
    )
    
    # Captures the question text, the four options, and the correct letter
    # from each "Q:/A)/B)/C)/D)/Correct:" block in the model's response.
    mcq_regex = re.compile(r"Q: (.+?)\nA\) (.+?)\nB\) (.+?)\nC\) (.+?)\nD\) (.+?)\nCorrect: ([A-D])", re.DOTALL)

    mcqs = []
    question_number = 1
    rate_limiter = RateLimiter(max_requests=25, period=60)  # 25 requests per minute

    async def process_context(context, start_number, batch_size):
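        # Prompt the LLM for a batch of MCQs on one retrieved chunk, then
        # parse any well-formed, previously unseen questions from the reply.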
        await rate_limiter.wait()
        prompt = mcq_template.format(context=context, num_questions=batch_size)
        response = await llm.ainvoke([prompt])
        response_str = response.content
        questions = []
        for match in mcq_regex.finditer(response_str):
            question, option_a, option_b, option_c, option_d, correct_letter = match.groups()
            question = question.strip()
            if question not in question_history:
                questions.append({
                    "number": start_number + len(questions),
                    "question": question,
                    "options": {
                        "A": option_a.strip(),
                        "B": option_b.strip(),
                        "C": option_c.strip(),
                        "D": option_d.strip()
                    },
                    "correct_answer": correct_letter
                })
        return questions

    with tqdm(total=num_questions, desc="Generating MCQs") as pbar:
        timeout = time.time() + 300  # 5 minutes timeout
        while len(mcqs) < num_questions and time.time() < timeout:
            remaining_questions = num_questions - len(mcqs)
            batch_size = min(remaining_questions, 5)  # Adjust batch size dynamically
            
            query = f"Random query for retrieval {random.randint(1, 1000)}"  # Add randomness to query
            results = await retriever.ainvoke(query)
            
            for result in results:
                if len(mcqs) >= num_questions:
                    break
                context = result.page_content
                try:
                    new_mcqs = await process_context(context, question_number, batch_size)
                    for mcq in new_mcqs:
                        if len(mcqs) < num_questions and mcq['question'] not in question_history:
                            mcqs.append(mcq)
                            question_history.add(mcq['question'])
                            question_number += 1
                            pbar.update(1)
                except Exception as e:
                    print(f"Error processing context: {e}")
                    continue

                if len(mcqs) >= num_questions:
                    break

        if len(mcqs) < num_questions:
            print(f"Warning: Only generated {len(mcqs)} unique questions out of {num_questions} requested.")

    question_history.save()
    return mcqs, len(question_history)

async def main():
    client_id = 'rich'
    num_questions = 20
    reference_id = 15
    GPU = 1
    mcqs, history_count = await generate_mcq(client_id, num_questions, reference_id, GPU)
    if not mcqs:
        print("No multiple-choice questions generated.")
        return
    
    print(f"Total questions in history: {history_count}")
    
    for mcq in mcqs:
        print(f"Question {mcq['number']}: {mcq['question']}")
        for option, text in mcq['options'].items():
            print(f"{option}) {text}")
        print(f"Correct Answer: {mcq['correct_answer']}\n")

if __name__ == "__main__":
    asyncio.run(main())