# llama3_chatting.py

import os
import time
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

# Groq API key, read from the environment so no credential lives in source.
# NOTE(review): a literal key was previously committed here -- treat it as
# leaked and revoke/rotate it. Evaluates to None if the variable is unset.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

def load_all_client_embeddings(client_id, base_dir="my_embeddings"):
    """Load every saved FAISS index for a client and merge them into one store.

    Each sub-directory of ``<base_dir>/<client_id>`` is loaded as a saved FAISS
    index. Its directory name becomes the ``reference_id`` written into the
    metadata of every document it contains, so callers can tell which index a
    retrieved chunk came from.

    Args:
        client_id: Client identifier; converted with ``str`` into a path part.
        base_dir: Root directory holding per-client folders. Defaults to
            ``"my_embeddings"``, the location previously hard-coded.

    Returns:
        A single FAISS vector store containing all of the client's documents.

    Raises:
        FileNotFoundError: If the client directory is missing (or is not a
            directory), or if it contains no index sub-directories.
    """
    client_dir = os.path.join(base_dir, str(client_id))

    # isdir rather than exists: a stray *file* with the client's name should be
    # reported as "no embeddings", not crash later inside os.listdir.
    if not os.path.isdir(client_dir):
        raise FileNotFoundError(f"No embeddings found for client {client_id}")

    embeddings = OllamaEmbeddings(model="nomic-embed-text")

    all_vectors = None
    # os.listdir returns entries in arbitrary order; sort so the merge order
    # (and hence the combined index/docstore order) is deterministic.
    for reference_id in sorted(os.listdir(client_dir)):
        folder_path = os.path.join(client_dir, reference_id)
        if not os.path.isdir(folder_path):
            continue

        # SECURITY: allow_dangerous_deserialization permits load_local to
        # unpickle the saved docstore. Only use it on indexes this application
        # wrote itself; never on user-supplied files.
        vectors = FAISS.load_local(
            folder_path, embeddings, allow_dangerous_deserialization=True
        )

        # Tag each document with the folder it came from. `_dict` is a private
        # (underscore) attribute of the docstore -- NOTE(review): may break on
        # LangChain upgrades.
        for doc in vectors.docstore._dict.values():
            doc.metadata['reference_id'] = reference_id

        if all_vectors is None:
            all_vectors = vectors
        else:
            all_vectors.merge_from(vectors)

    if all_vectors is None:
        raise FileNotFoundError(f"No valid embeddings found for client {client_id}")

    print(f"All embeddings loaded for client {client_id}")
    return all_vectors

def answer_question(client_id, question, max_tokens=200, k=3):
    """Answer ``question`` from a client's stored documents using Groq Llama 3.

    Loads all FAISS indexes for ``client_id`` via
    ``load_all_client_embeddings``, retrieves the top ``k`` chunks and asks
    the LLM to answer from that context only.

    Args:
        client_id: Client whose embeddings are searched.
        question: Natural-language question to answer.
        max_tokens: Maximum number of tokens the LLM may generate.
        k: Number of chunks to retrieve as context. Defaults to 3, the value
            previously hard-coded.

    Returns:
        Always a 4-tuple ``(answer, source_reference_ids, total_tokens,
        processing_time)``:

        * ``answer``: the LLM's answer, or a fallback message if the client's
          embeddings could not be loaded.
        * ``source_reference_ids``: ``reference_id`` metadata of each
          retrieved chunk, in retrieval order (may repeat); ``[]`` on failure.
        * ``total_tokens``: rough size of question + retrieved context, counted
          as whitespace-separated words (NOT model tokens); ``0`` on failure.
        * ``processing_time``: wall-clock seconds spent in this call.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()

    try:
        vectors = load_all_client_embeddings(client_id)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        # Same 4-tuple shape as the success path so callers can always unpack
        # (answer, reference_ids, total_tokens, processing_time).
        return (
            "I couldn't access the necessary information to answer your question.",
            [],
            0,
            time.perf_counter() - start_time,
        )

    llm = ChatGroq(
        model_name='Llama3-8b-8192',
        groq_api_key=GROQ_API_KEY,
        max_tokens=max_tokens
    )

    # With chain_type="stuff" the retrieved chunks fill {context} below.
    template = """You are a highly intelligent and professional AI assistant. Use the following context to answer the question. If the answer is not in the context, say "I don't have enough information to answer that question."

Context: {context}

Question: {question}

Answer: """

    prompt = PromptTemplate(
        template=template,
        input_variables=["context", "question"]
    )

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectors.as_retriever(search_kwargs={"k": k}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt}
    )

    result = qa_chain.invoke({"query": question})
    processing_time = time.perf_counter() - start_time

    answer = result["result"]
    source_docs = result["source_documents"]
    # Approximation only: whitespace-separated words, not tokenizer tokens.
    total_tokens = sum(len(doc.page_content.split()) for doc in source_docs) + len(question.split())

    # Return only the reference_ids (tagged at load time), not the Documents.
    source_reference_ids = [doc.metadata.get('reference_id', 'Unknown') for doc in source_docs]

    return answer, source_reference_ids, total_tokens, processing_time


# Usage example
if __name__ == "__main__":
    client_id = 7  # Replace with the actual client ID
    question = "What are the types of leaves?"

    # Keep the try body to the call that can fail; reporting happens in else.
    try:
        answer, source_reference_ids, total_tokens, processing_time = answer_question(client_id, question)
    except Exception as e:
        # Top-level boundary of this demo script: report instead of a traceback.
        print(f"An error occurred: {e}")
    else:
        print("Answer:", answer)
        print(f"\nChatting process took {processing_time:.2f} seconds")
        print(f"Total tokens (question + context, approx. word count): {total_tokens}")
        print("\nSource documents:")
        # answer_question returns reference_id strings, not Document objects,
        # so there is no .metadata to read here.
        for reference_id in source_reference_ids:
            print(f"- {reference_id}")
