import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate
import os
import shutil
import tempfile
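
# The app assumes a local Ollama server (default: http://127.0.0.1:11434) with the
# models used below already pulled, e.g.:
#   ollama pull nomic-embed-text
#   ollama pull llama3:8b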

# Set page configuration
st.set_page_config(
    page_title="PDF Chatbot",
    page_icon="📚",
    layout="wide"
)

# App title and description
st.title("📚 Multi-PDF Chatbot")
st.markdown("Upload multiple PDF documents and ask questions about their content.")

# Initialize session state variables
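# Streamlit reruns this script top to bottom on every interaction, so anything
# that must survive reruns (chat history, the vector store) lives in session state.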
if "conversation_history" not in st.session_state:
    st.session_state.conversation_history = []
if "vectorstore" not in st.session_state:
    st.session_state.vectorstore = None
if "uploaded_files" not in st.session_state:
    st.session_state.uploaded_files = []
if "processing_complete" not in st.session_state:
    st.session_state.processing_complete = False

# Sidebar for file upload and processing
with st.sidebar:
    st.header("Document Upload")
    uploaded_files = st.file_uploader("Upload PDF Documents", type="pdf", accept_multiple_files=True)
    
    # Process button
    if uploaded_files and st.button("Process Documents"):
        with st.spinner("Processing PDFs..."):
            # Reset processing state
            st.session_state.processing_complete = False
            st.session_state.conversation_history = []
            
            # Create a temporary directory to save uploaded files
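            # (PyPDFLoader reads from a filesystem path, so the in-memory uploads
            # are written to disk first.)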
            temp_dir = tempfile.mkdtemp()
            all_docs = []
            
            # Save uploaded files to temp directory and process them
            for file in uploaded_files:
                temp_file_path = os.path.join(temp_dir, file.name)
                with open(temp_file_path, "wb") as f:
                    f.write(file.getvalue())
                
                st.write(f"Loading {file.name}...")
                loader = PyPDFLoader(temp_file_path)
                docs = loader.load()
                all_docs.extend(docs)
            
            # Clean up the temporary copies now that their text is loaded
            shutil.rmtree(temp_dir, ignore_errors=True)
            
            st.session_state.uploaded_files = [file.name for file in uploaded_files]
            st.write(f"Loaded {len(all_docs)} pages in total.")
            
            # Split documents into chunks
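            # Sizes are in characters; the 200-character overlap keeps sentences
            # that straddle a chunk boundary retrievable from either side.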
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200
            )
            chunks = text_splitter.split_documents(all_docs)
            st.write(f"Split into {len(chunks)} chunks")
            
            # Create embeddings and store in vector database
            with st.spinner("Creating embeddings... this might take a while"):
                try:
                    embeddings = OllamaEmbeddings(model="nomic-embed-text")
                    st.session_state.vectorstore = Chroma.from_documents(chunks, embeddings)
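                    # Note: this Chroma store is in-memory and rebuilt on every
                    # "Process Documents" click; pass persist_directory=... to
                    # Chroma.from_documents if the index should survive restarts.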
                    st.session_state.processing_complete = True
                    st.success("Documents processed successfully!")
                except Exception as e:
                    st.error(f"Error creating embeddings: {str(e)}")
    
    # Display processed files
    if st.session_state.uploaded_files:
        st.header("Processed Documents")
        for file_name in st.session_state.uploaded_files:
            st.write(f"- {file_name}")
    
    # Model settings
    st.header("Model Settings")
    model_name = st.selectbox(
        "Select Ollama Model",
        ["llama3:8b", "deepseek-r1:8b"],
        index=0
    )
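    # Models listed here must already be pulled into Ollama. Reasoning models
    # such as deepseek-r1 may include <think>...</think> traces in their replies,
    # which this app displays verbatim.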

# Main chat interface
st.header("Chat with your PDFs")

# Display conversation history
for query, response in st.session_state.conversation_history:
    # User message
    with st.chat_message("user"):
        st.write(query)
    # Bot message
    with st.chat_message("assistant"):
        st.write(response)

# Chat input
user_query = st.chat_input("Ask a question about your PDFs...", disabled=not st.session_state.processing_complete)
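# (The input stays disabled until documents have been processed and embedded.)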

if user_query and st.session_state.processing_complete:
    # Display user message
    with st.chat_message("user"):
        st.write(user_query)
    
    # Generate response
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            # Define helper functions for generating responses
            def format_docs(docs):
                """Concatenate retrieved chunks into one context string for the prompt."""
                return "\n\n".join(doc.page_content for doc in docs)

            def format_chat_history(history):
                """Render past (question, answer) pairs as numbered plain text."""
                if not history:
                    return "No previous conversation."
                formatted = ""
                for i, (q, a) in enumerate(history):
                    formatted += f"Question {i+1}: {q}\nAnswer {i+1}: {a}\n\n"
                return formatted
            
            # Create LLM instance
            llm = ChatOllama(
                base_url="http://127.0.0.1:11434",  # default local Ollama endpoint
                model=model_name
            )
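            # ChatOllama also accepts sampling options; e.g. temperature=0 gives
            # more deterministic answers for document Q&A.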
            
            # Set up retriever
            retriever = st.session_state.vectorstore.as_retriever(search_kwargs={"k": 4})
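            # k=4: the four most similar chunks are stuffed into the prompt below;
            # raising k adds context at the cost of a longer prompt.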
            
            # Create prompt template
            template = """
            You are a helpful assistant that answers questions based on the provided PDF documents.
            Answer the question based only on the following context from the PDFs:
            {context}

            Previous conversation history:
            {chat_history}

            Current question: {question}

            Provide a comprehensive and accurate answer based on the information in the PDFs.
            """
            
            prompt = ChatPromptTemplate.from_template(template)
            
            try:
                # Retrieve relevant documents
                docs = retriever.invoke(user_query)
                context = format_docs(docs)
                
                # Format conversation history
                chat_history = format_chat_history(st.session_state.conversation_history)
                
                # Generate response
                messages = prompt.format_messages(
                    context=context,
                    chat_history=chat_history,
                    question=user_query
                )
                response = llm.invoke(messages).content
                
                # Display response
                st.write(response)
                
                # Update conversation history
                st.session_state.conversation_history.append((user_query, response))
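                # The full history is replayed into every prompt, so very long
                # sessions may eventually exceed the model's context window.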
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")

# Show a message if no documents have been processed
if not st.session_state.processing_complete:
    st.info("👈 Please upload and process PDF documents from the sidebar to start chatting.")

# Add a clear chat button
if st.session_state.conversation_history and st.button("Clear Chat History"):
    st.session_state.conversation_history = []
    st.rerun()

# Add footer
st.markdown("---")
st.markdown("Built with Langchain and Ollama")