import os
# Community integrations live in the langchain-community package in current LangChain releases
from langchain_community.document_loaders import PyMuPDFLoader  # Extract text from PDFs
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OllamaEmbeddings  # Use Ollama for local embeddings
from langchain.memory import ConversationBufferMemory  # Stores chat history
from langchain_community.llms import Ollama
from langchain.chains import ConversationalRetrievalChain

# --- Configuration ---
PDF_FOLDER = "pdfs/"  # Folder containing PDF files
EMBEDDING_DIR = "my_embeddings/"  # Where embeddings are stored
FAISS_INDEX_DIR = os.path.join(EMBEDDING_DIR, "course_101")  # FAISS save_local/load_local expect a folder, not a single file
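
# Assumed local setup (adjust package and model names to your environment):
#   pip install langchain langchain-community langchain-text-splitters faiss-cpu pymupdf
#   ollama pull nomic-embed-text   # embedding model used below
#   ollama pull mistral            # chat model used below
# An Ollama server must be running locally (default: http://localhost:11434).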

# --- Step 1: Load PDFs and Create Embeddings ---
def load_and_embed_pdfs():
    pdf_files = [f for f in os.listdir(PDF_FOLDER) if f.endswith(".pdf")]
    docs = []

    # Extract text from PDFs
    for pdf in pdf_files:
        loader = PyMuPDFLoader(os.path.join(PDF_FOLDER, pdf))
        docs.extend(loader.load())

    # Split text into chunks for embeddings
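    # chunk_size and chunk_overlap are character counts; 1000 characters with a 200-character
    # overlap is a common starting point, so sentences split at a chunk boundary still
    # appear whole in at least one chunk.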
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    text_chunks = text_splitter.split_documents(docs)

    # Use Ollama embeddings (can be swapped for OpenAI or other embedding models)
    embeddings = OllamaEmbeddings(model="nomic-embed-text")  # Local embedding model served by Ollama

    # Create FAISS vector store and save it locally (writes index.faiss and index.pkl into the folder)
    faiss_db = FAISS.from_documents(text_chunks, embeddings)
    faiss_db.save_local(FAISS_INDEX_DIR)
    return faiss_db

# --- Step 2: Load FAISS Index ---
def load_faiss():
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    # allow_dangerous_deserialization is required by recent langchain-community releases
    # because the index metadata is pickled; only load indexes you created yourself.
    return FAISS.load_local(FAISS_INDEX_DIR, embeddings, allow_dangerous_deserialization=True)

# --- Step 3: Initialize Conversational Agent ---
def create_conversational_chain():
    # Load the saved FAISS index (or build it from the PDFs if it does not exist yet)
    if os.path.exists(FAISS_INDEX_DIR):
        faiss_db = load_faiss()
    else:
        faiss_db = load_and_embed_pdfs()

    # Define chat memory (stores previous questions & answers)
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

    # Use a local LLM served by Ollama (swap in another Ollama model, or an API-backed LLM such as GPT)
    llm = Ollama(model="mistral")

    # Create conversational chain with memory
    chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=faiss_db.as_retriever(),
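        # The default retriever returns the top 4 most similar chunks; pass
        # search_kwargs={"k": ...} to as_retriever() to adjust how much context the LLM sees.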
        memory=memory,
    )
    return chain

# --- Step 4: Run Chatbot ---
if __name__ == "__main__":
    chat_agent = create_conversational_chain()
    
    print("AI Chat with PDFs (type 'exit' to stop)")
    while True:
        query = input("\nYou: ").strip()
        if query.lower() == "exit":
            break
        response = chat_agent.invoke({"question": query})
        print("\nAI:", response["answer"])
