import time
import os
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pptx import Presentation
import textract
import subprocess

from langchain_experimental.text_splitter import SemanticChunker
import shutil

def convert_ppt_to_pptx(file_path):
    """Convert a legacy .ppt file to .pptx via a headless LibreOffice run."""
    output_dir = os.path.dirname(os.path.abspath(file_path))
    # LibreOffice writes the converted file into --outdir, so point it at the source directory.
    subprocess.run(['libreoffice', '--headless', '--convert-to', 'pptx', '--outdir', output_dir, file_path], check=True)
    return os.path.splitext(file_path)[0] + '.pptx'

def load_and_split_document(file_path):
    file_extension = os.path.splitext(file_path)[1].lower()
    
    if file_extension == '.pdf':
        loader = PyPDFLoader(file_path)
        docs = loader.load()
    elif file_extension == '.txt':
        loader = TextLoader(file_path)
        docs = loader.load()
    elif file_extension == '.docx':
        loader = Docx2txtLoader(file_path)
        docs = loader.load()
    elif file_extension == '.doc':
        text = textract.process(file_path).decode('utf-8')
        docs = [text]
    elif file_extension == '.pptx':
        prs = Presentation(file_path)
        text = []
        for slide in prs.slides:
            for shape in slide.shapes:
                if hasattr(shape, "text"):
                    text.append(shape.text)
        docs = ['\n'.join(text)]
    elif file_extension == '.ppt':
        pptx_file_path = convert_ppt_to_pptx(file_path)
        docs = load_and_split_document(pptx_file_path)
        os.remove(pptx_file_path)
        return docs
    else:
        raise ValueError(f"Unsupported file type: {file_extension}")

    if not docs:
        print("No documents loaded from file.")
        return []

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    final_documents = []
    for doc in docs:
        if isinstance(doc, str):
            text = doc
        elif hasattr(doc, 'page_content'):
            text = doc.page_content
        else:
            print("Unsupported document format:", doc)
            continue

        chunks = text_splitter.split_text(text)
        final_documents.extend(chunks)

    return final_documents

def create_embeddings(documents, model="nomic-embed-text"):
    embeddings = OllamaEmbeddings(model=model)

    texts = []
    for doc in documents:
        if isinstance(doc, str):
            texts.append(doc)
        elif hasattr(doc, 'page_content'):
            texts.append(doc.page_content)
        else:
            print("Unsupported document format:", doc)

    vectors = FAISS.from_texts(texts=texts, embedding=embeddings)
    return vectors
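
# Hedged usage sketch for create_embeddings (the file name and query below are
# placeholders, not part of this project):
#
#   store = create_embeddings(load_and_split_document("notes.pdf"))
#   for hit in store.similarity_search("what is this document about?", k=3):
#       print(hit.page_content[:120])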

def count_tokens(documents):
    """Approximate token count for a list of text chunks (whitespace split)."""
    return sum(len(doc.split()) for doc in documents)




### NEW


def load_pdf(file_path):
    """Load a PDF and return its pages as LangChain documents."""
    loader = PyPDFLoader(file_path)
    return loader.load()

def count_tokens_in_documents(documents):
    """Count the total number of whitespace-delimited tokens in the given documents."""
    total_tokens = 0
    for document in documents:
        # Simple whitespace tokenization; this approximates but does not match a model tokenizer.
        tokens = document.page_content.split()
        total_tokens += len(tokens)
    return total_tokens
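
# Hedged alternative: a more model-faithful count via the optional `tiktoken`
# package. The cl100k_base encoding is only a stand-in; it does not match the
# nomic-embed-text tokenizer exactly.
def count_tokens_with_tiktoken(documents, encoding_name="cl100k_base"):
    """Count tokens with a real tokenizer instead of a whitespace split."""
    import tiktoken  # optional dependency, imported lazily
    encoding = tiktoken.get_encoding(encoding_name)
    return sum(len(encoding.encode(doc.page_content)) for doc in documents)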

def count_total_words(docs):
    """Count the total number of words in the documents."""
    return sum(len(doc.page_content.split()) for doc in docs)

def split_text_with_semantic_chunker(docs, embeddings):
    """Splits the text into semantic chunks using the given embeddings."""
    text_splitter = SemanticChunker(
        embeddings, breakpoint_threshold_type="percentile"  # Can also be "standard_deviation" or "interquartile"
    )
    documents = text_splitter.create_documents([doc.page_content for doc in docs])
    print("Documents split into semantic chunks.")
    return documents
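
# Hedged variant (not wired into the pipeline): SemanticChunker also accepts a
# breakpoint_threshold_amount to tune how aggressively it splits; the value
# below is an arbitrary example.
#
#   text_splitter = SemanticChunker(
#       embeddings,
#       breakpoint_threshold_type="standard_deviation",
#       breakpoint_threshold_amount=3.0,
#   )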

def save_documents_to_txt(documents, output_dir):
    """Saves each document in the documents list as a separate .txt file."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)  # Create the output directory if it doesn't exist
    
    for i, document in enumerate(documents):
        file_name = f"document_part_{i+1}.txt"
        file_path = os.path.join(output_dir, file_name)
        
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(document.page_content)  # Assuming each document object has a 'page_content' attribute
        
        print(f"Saved: {file_path}")

def create_and_save_embeddings(split_documents, client_id, reference_id):
    client_id = str(client_id)
    reference_id = str(reference_id)
    # Base folder structure: my_embeddings/{client_id}/{reference_id}
    embedding_folder_base = os.path.join("my_embeddings", client_id, reference_id)

    # Make sure the base embedding folder exists
    os.makedirs(embedding_folder_base, exist_ok=True)

    # Initialize the embedding model
    embeddings = OllamaEmbeddings(model='nomic-embed-text')

    # Build and save a separate FAISS index for each document chunk
    for idx, doc in enumerate(split_documents, start=1):
        # Create a FAISS index for this chunk
        temp_db = FAISS.from_documents([doc], embedding=embeddings)

        # Save the FAISS index for this chunk under an incremental folder name
        embedding_file_path = os.path.join(embedding_folder_base, f"faiss_index{idx}")
        temp_db.save_local(embedding_file_path)

        print(f"Saved FAISS embedding for document part {idx} at {embedding_file_path}")


def merge_all_faiss(client_id, reference_id, base_path='my_embeddings'):
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    # Initialize an empty FAISS vectorstore for merging
    merged_faiss = None

    # Construct the base folder path
    folder_path = os.path.join(base_path, str(client_id), str(reference_id))

    # List all folders that match the pattern 'faiss_index{i}'
    faiss_folders = [
        folder for folder in os.listdir(folder_path)
        if folder.startswith('faiss_index') and folder[len('faiss_index'):].isdigit()
    ]

    # Sort folders by the index number extracted from 'faiss_index{i}'
    sorted_folders = sorted(faiss_folders, key=lambda x: int(x.replace('faiss_index', '')))

    # Loop through the sorted folders and merge FAISS stores
    for folder in sorted_folders:
        faiss_path = os.path.join(folder_path, folder)
        print(f"Loading FAISS index from: {faiss_path}")  # Debugging: see the order of loading
        current_faiss = FAISS.load_local(faiss_path, embeddings, allow_dangerous_deserialization=True)

        if merged_faiss is None:
            merged_faiss = current_faiss
        else:
            # merge_from reuses the stored vectors and metadata, so nothing is re-embedded
            merged_faiss.merge_from(current_faiss)

    # Save the merged FAISS index to its own folder
    if merged_faiss is not None:
        merged_faiss.save_local(os.path.join(folder_path, 'merged_faiss'))
        print(f"Merged {len(sorted_folders)} FAISS indexes into {os.path.join(folder_path, 'merged_faiss')}")

    # Delete individual FAISS index folders, except for the 'merged_faiss'
    for folder in sorted_folders:
        faiss_path = os.path.join(folder_path, folder)
        try:
            # Delete the entire directory for the FAISS index (e.g., faiss_index3)
            shutil.rmtree(faiss_path)
            print(f"Deleted FAISS index folder: {faiss_path}")
        except FileNotFoundError:
            print(f"Folder not found: {faiss_path}")
        except OSError as e:
            print(f"Error deleting {faiss_path}: {e}")
    return merged_faiss
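
# Hedged usage sketch: the merged index returned by merge_all_faiss can be
# queried directly or wrapped as a retriever (the query string is arbitrary):
#
#   merged = merge_all_faiss(4567, 1145)
#   retriever = merged.as_retriever(search_kwargs={"k": 4})
#   print(merged.similarity_search("summarize the key topics", k=2))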


""" def main():
    file_path = "python_basics.pdf"
    client_id = 4
    reference_id = "1011"

    folder_path = f"my_embeddings/{client_id}/{reference_id}"
    os.makedirs(folder_path, exist_ok=True)

    start_time = time.time()
    final_documents = load_and_split_document(file_path)

    if not final_documents:
        print("No valid documents to process.")
        return

    token_count = count_tokens(final_documents)
    vectors = create_embeddings(final_documents)
    vectors.save_local(folder_path)
    end_time = time.time()

    print(f"Training process took {end_time - start_time:.2f} seconds")
    print(f"Total tokens processed during training: {token_count}") """

def main(file_path, output_dir, client_id, reference_id):
    start_time = time.time()
    print(f"Start Time: {start_time}")

    # Load the PDF and count its tokens
    docs = load_pdf(file_path)
    total_tokens = count_tokens_in_documents(docs)

    # Create the embedding model
    embeddings = OllamaEmbeddings(model='nomic-embed-text')

    # Split the text with the semantic chunker
    split_documents = split_text_with_semantic_chunker(docs, embeddings)

    # Save the split documents to text files
    save_documents_to_txt(split_documents, output_dir)

    # Create per-chunk embeddings, then merge them into a single FAISS index
    create_and_save_embeddings(split_documents, client_id, reference_id)
    merge_all_faiss(client_id, reference_id)

    end_time = time.time()
    print(f"Training process took {end_time - start_time:.2f} seconds")
    print(f"Total tokens processed during training: {total_tokens}")

if __name__ == "__main__":
    file_path = "cyber.pdf"
    output_dir = f'temp/{file_path}'
    client_id = 4567
    reference_id = 1145
    main(file_path, output_dir, client_id, reference_id)