from groq import Groq
import os
import pandas as pd
import numpy as np
from moviepy import VideoFileClip
from pydub import AudioSegment
import ffmpeg
from langchain.text_splitter import TokenTextSplitter
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader, UnstructuredPowerPointLoader
from langchain_community.vectorstores import FAISS
#from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
import time
#from pyannote.audio import Pipeline
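
# Pipeline overview: convert a video to MP3, transcribe/translate it with Groq's
# Whisper endpoint, chunk the transcript, embed the chunks with Ollama, and save
# them in a local FAISS index for later retrieval.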

client = Groq(api_key=os.environ.get("GROQ_API_KEY"))  # Read the key from the environment instead of hardcoding it
model = 'whisper-large-v3'

class VideoTooLargeException(Exception):
    pass

# 1. Convert MP4 to MP3
def convert_video_to_audio(video_path, mp3_output_path):
    clip = VideoFileClip(video_path)
    video_duration = clip.duration / 60  # Duration in minutes
    clip.close()

    print(f"Video duration: {video_duration:.2f} minutes")
    audio = AudioSegment.from_file(video_path, format="mp4")
    audio.export(mp3_output_path, format="mp3")
    print(f"Converted {video_path} to {mp3_output_path}")

""" def convert_video_to_audio(video_path, mp3_output_path):
    ffmpeg.input(video_path).output(mp3_output_path).run()
    print(f"Converted {video_path} to {mp3_output_path}") """

# def audio_to_text_diarization(filepath, output_file):

#     # Step 1: Check if output file already exists
#     if os.path.exists(output_file):
#         print(f"Reading translation from {output_file}")
#         with open(output_file, "r", encoding="utf-8") as f:
#             translation_text = f.read()
#     else:
#         # Step 2: Generate transcription using Whisper
#         print(f"Generating translation for {filepath}")
#         result = model.transcribe(filepath)
#         translation_text = result['text']
        
#         # Step 3: Write Whisper transcription to output file
#         with open(output_file, "w", encoding="utf-8") as f:
#             f.write(translation_text)

#     # Step 4: Perform speaker diarization using PyAnnote
#     print(f"Performing speaker diarization for {filepath}")
#     diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization")
#     diarization = diarization_pipeline(filepath)

#     # Step 5: Format and integrate diarization with transcription
#     formatted_text = "\n".join(translation_text.split('. '))
#     speaker_timed_segments = []

#     # Iterate over speaker diarization results
#     for turn, _, speaker in diarization.itertracks(yield_label=True):
#         start_time = turn.start
#         end_time = turn.end
#         speaker_timed_segments.append(f"Speaker {speaker}: {start_time:.2f}s - {end_time:.2f}s")
    
#     # Combine speaker labels with transcription
#     final_output = "\n".join(speaker_timed_segments) + "\n\nTranscription:\n" + formatted_text

#     # Step 6: Save combined output to the file
#     with open(output_file, "w", encoding="utf-8") as f:
#         f.write(final_output)

#     return final_output


# 2. Audio Transcription
def audio_to_text(filepath, output_file):
    if os.path.exists(output_file):
        print(f"Reading translation from {output_file}")
        with open(output_file, "r", encoding="utf-8") as f:
            translation_text = f.read()
    else:
        print(f"Generating translation for {filepath}")
        with open(filepath, "rb") as file:
            translation = client.audio.translations.create(
                file=(filepath, file.read()),
                model=model,
            )
        translation_text = translation.text
        
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(translation_text)
    
    formatted_text = "\n".join(translation_text.split('. '))
    return formatted_text
    
    # Split the text into lines based on punctuation for better readability
    formatted_text = "\n".join(translation_text.split('. '))

    return formatted_text

# 3. Transcript Chat Completion
def transcript_chat_completion(client, transcript, user_question):
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": f'''Use this transcript or transcripts to answer any user questions, citing specific quotes:

                {transcript}
                '''
            },
            {
                "role": "user",
                "content": user_question,
            }
        ],
        model="llama3-8b-8192",
    )
    response_text = chat_completion.choices[0].message.content
    print(response_text)
    return response_text
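
# Example usage of transcript_chat_completion, assuming a transcript has already
# been written by audio_to_text; the question below is illustrative only.
# with open("Gaur_Gopal.txt", "r", encoding="utf-8") as f:
#     transcript = f.read()
# transcript_chat_completion(client, transcript, "What are the main points discussed?")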

# 4. Preparing Podcast Files
""" def split_audio(mp3_file_folder, mp3_chunk_folder, episode_id, chunk_length_ms, overlap_ms, print_output=True):
    audio = AudioSegment.from_file(f"{mp3_file_folder}/{episode_id}.mp3", format="mp3")
    num_chunks = len(audio) // (chunk_length_ms - overlap_ms) + (1 if len(audio) % chunk_length_ms else 0)
    
    for i in range(num_chunks):
        start_ms = i * chunk_length_ms - (i * overlap_ms)
        end_ms = start_ms + chunk_length_ms
        chunk = audio[start_ms:end_ms]
        export_fp = f"{mp3_chunk_folder}/{episode_id}_chunk{i+1}.mp3"
        chunk.export(export_fp, format="mp3")
        if print_output:
            print('Exporting', export_fp)
        
    return num_chunks

# 5. Transcribing and Storing Podcast Chunks Without Metadata
def process_chunks(chunk_fps, text_splitter, mp3_chunk_folder, output_file):
    documents = []
    cnt = 0
    for chunk_fp in chunk_fps:
        cnt += 1
        audio_filepath = f"{mp3_chunk_folder}/{chunk_fp}"
        transcript = audio_to_text(audio_filepath, output_file)
        chunks = text_splitter.split_text(transcript)
        episode_id = chunk_fp.split('_chunk')[0]
        
        for chunk in chunks:
            header = f"Episode ID: {episode_id}\nFile: {chunk_fp}\n\n"
            documents.append(Document(page_content=header + chunk, metadata={"source": "local"}))
        
        if np.mod(cnt, round(len(chunk_fps) / 5)) == 0:
            print(f"{round(cnt / len(chunk_fps), 2) * 100} % of transcripts processed...")

    print('# Transcription Chunks:', len(documents))
    return documents
 """

""" def load_and_split_document1(file_path):
    # Ensure the file is a .txt file
    if not file_path.lower().endswith('.txt'):
        raise ValueError("Unsupported file type. Only .txt files are supported.")
    
    # Load and split document
    loader = TextLoader(file_path, encoding='utf-8')
    docs = loader.load()
    
    # Set up the text splitter
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    
    # Split the loaded documents
    return text_splitter.split_documents(docs) """




def load_and_split_document1(file_path):
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
        
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        chunks = text_splitter.split_text(text)
        
        documents = [Document(page_content=chunk, metadata={"source": file_path}) for chunk in chunks]
        return documents
    except Exception as e:
        raise RuntimeError(f"Error processing document: {e}") from e

# 6. Save FAISS Index
def save_faiss_index(documents, faiss_save_path):
    embedding_function = OllamaEmbeddings(model="nomic-embed-text")
    docsearch = FAISS.from_documents(documents, embedding_function)
    docsearch.save_local(faiss_save_path)
    print(f"FAISS index saved to {faiss_save_path}")

def count_tokens_video(documents):
    # Approximate the token count by counting whitespace-separated words
    return sum(len(doc.page_content.split()) for doc in documents)

# Same word-count approximation, applied to the transcript documents
def count_tokens(documents):
    return sum(len(doc.page_content.split()) for doc in documents)
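
# If exact model token counts are needed, a real tokenizer could replace the
# whitespace approximation above. A commented sketch, assuming the tiktoken
# package (not used elsewhere in this script) is installed:
# import tiktoken
# def count_tokens_exact(documents, encoding_name="cl100k_base"):
#     encoding = tiktoken.get_encoding(encoding_name)
#     return sum(len(encoding.encode(doc.page_content)) for doc in documents)
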
# Main function
def main():
    video_path = "Gaur_Gopal.mp4"
    mp3_output_path = "mp3-files/Gaur_Gopal.mp3"
    output_file = "Gaur_Gopal.txt"
    # mp3_file_folder = "mp3-files"
    # mp3_chunk_folder = "mp3-chunks"
    # chunk_length_ms = 1000000
    # overlap_ms = 10000
    # model_name = "all-MiniLM-L6-v2"
    #faiss_save_path = "faiss_index"
    client_id = 66
    reference_id = "Gaur_Gopal"

    folder_path = f"my_embeddings_video/{client_id}/{reference_id}"

    # Step 1: Convert MP4 to MP3
    start_time = time.time()
    os.makedirs(os.path.dirname(mp3_output_path), exist_ok=True)  # pydub needs the output folder to exist
    convert_video_to_audio(video_path, mp3_output_path)

    # Step 2: Transcribe the audio and cache the transcript in output_file
    audio_to_text(mp3_output_path, output_file)
    #audio_to_text_diarization(mp3_output_path, output_file)

    # Step 3: Load the transcript and split it into chunks
    documents = load_and_split_document1(output_file)

    # Step 4: Split the audio into chunks
    """ for fil in os.listdir(mp3_file_folder):
        episode_id = fil.split('.')[0]
        print('Splitting Episode ID:', episode_id)
        split_audio(mp3_file_folder, mp3_chunk_folder, episode_id, chunk_length_ms, overlap_ms)

    # Step 5: Process the chunks
    chunk_fps = os.listdir(mp3_chunk_folder)
    text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=100)
    documents = process_chunks(chunk_fps, text_splitter, mp3_chunk_folder, output_file) """

    # Step 6: Save the FAISS index
    save_faiss_index(documents, folder_path)
    end_time = time.time()
    token_count = count_tokens(documents)
    print(f"Training process took {end_time - start_time:.2f} seconds")
    print(f"Total tokens processed during training: {token_count}")

if __name__ == "__main__":
    main()
