import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig, AutoTokenizer
from IndicTransToolkit.processor import IndicProcessor
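# assumes IndicTransToolkit is installed (e.g. pip install IndicTransToolkit)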

import pdfplumber
from fpdf import FPDF
import textwrap

BATCH_SIZE = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
quantization = None
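# set quantization to "4-bit" or "8-bit" to load the model via bitsandbytes,
# or leave as None for full precision (fp16 on GPU)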

def initialize_model_and_tokenizer(ckpt_dir, quantization):
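    """Load the IndicTrans2 tokenizer and model from ckpt_dir, optionally
    quantized to 4-bit or 8-bit via bitsandbytes; returns (tokenizer, model)."""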
    if quantization == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif quantization == "8-bit":
        qconfig = BitsAndBytesConfig(
            load_in_8bit=True,
            # double quantization and compute dtype are 4-bit-only options,
            # so only load_in_8bit is needed here
        )
    else:
        qconfig = None

    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        ckpt_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    if qconfig is None:
        model = model.to(DEVICE)
        if DEVICE == "cuda":
            model.half()

    model.eval()

    return tokenizer, model


def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
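    """Translate input_sentences from src_lang to tgt_lang in batches of
    BATCH_SIZE, using IndicProcessor `ip` for pre- and post-processing."""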
    translations = []
    for i in range(0, len(input_sentences), BATCH_SIZE):
        batch = input_sentences[i : i + BATCH_SIZE]

        # Preprocess the batch and extract entity mappings
        batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)

        # Tokenize the batch and generate input encodings
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate translations using the model
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated tokens into text
        decoded = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # Postprocess the translations, including entity replacement
        translations += ip.postprocess_batch(decoded, lang=tgt_lang)

        del inputs
        torch.cuda.empty_cache()

    return translations


en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B"  # or the lighter distilled variant: ai4bharat/indictrans2-en-indic-dist-200M
en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, quantization)

ip = IndicProcessor(inference=True)
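# IndicProcessor (inference=True) applies the language tagging and entity
# placeholder handling that IndicTrans2 expects around generation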

# en_sents = [
#     "When I was young, I used to go to the park every day.",
#     "He has many old books, which he inherited from his ancestors.",
#     "I can't figure out how to solve my problem.",
#     "She is very hardworking and intelligent, which is why she got all the good marks.",
#     "We watched a new movie last week, which was very inspiring.",
#     "If you had met me at that time, we would have gone out to eat.",
#     "She went to the market with her sister to buy a new sari.",
#     "Raj told me that he is going to his grandmother's house next month.",
#     "All the kids were having fun at the party and were eating lots of sweets.",
#     "My friend has invited me to his birthday party, and I will give him a gift.",
# ]

# src_lang, tgt_lang = "eng_Latn", "hin_Deva"
# hi_translations = batch_translate(en_sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip)

# print(f"\n{src_lang} - {tgt_lang}")
# for input_sentence, translation in zip(en_sents, hi_translations):
#     print(f"{src_lang}: {input_sentence}")
#     print(f"{tgt_lang}: {translation}")

def extract_sentences_from_pdf(pdf_path):
    """Extract text from PDF and split into sentence-like chunks."""
    with pdfplumber.open(pdf_path) as pdf:
        page_texts = [page.extract_text() for page in pdf.pages]
        full_text = "\n".join(text for text in page_texts if text)
    
    # Split the text into ~200-character chunks (words kept intact); a rough
    # heuristic that can split mid-sentence, but keeps each chunk well within
    # the model's 256-token generation limit
    chunks = textwrap.wrap(full_text, width=200, break_long_words=False, replace_whitespace=False)
    return [chunk.strip() for chunk in chunks if chunk.strip()]

def save_translations_to_pdf(translations, output_pdf_path):
    """Save the translated lines into a new PDF file using Unicode-safe fpdf2."""
    pdf = FPDF()
    pdf.add_page()
    pdf.set_auto_page_break(auto=True, margin=15)

    # Register a Unicode Devanagari font (adjust the TTF path as needed);
    # fpdf2 treats TTF fonts as Unicode by default
    pdf.add_font("Devanagari", "", "NotoSansDevanagari-Regular.ttf")
    pdf.set_font("Devanagari", size=14)

    for line in translations:
        pdf.multi_cell(w=0, h=10, text=line, new_x="LMARGIN", new_y="NEXT")

    pdf.output(output_pdf_path)
    print(f"✅ Translated PDF saved at: {output_pdf_path}")

def translate_pdf(input_pdf_path, output_pdf_path, model, tokenizer, ip, src_lang="eng_Latn", tgt_lang="hin_Deva"):
    """Extract, translate, and save PDF."""
    print(f"🔍 Extracting text from {input_pdf_path}...")
    input_sentences = extract_sentences_from_pdf(input_pdf_path)
    print(f"📝 Total chunks: {len(input_sentences)}")

    print("🚀 Translating...")
    translations = batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip)

    print("💾 Saving to PDF...")
    save_translations_to_pdf(translations, output_pdf_path)

translate_pdf(
    input_pdf_path="cyber.pdf",
    output_pdf_path="translated_marathi_output.pdf",
    model=en_indic_model,
    tokenizer=en_indic_tokenizer,
    ip=ip,
    src_lang="eng_Latn",
    tgt_lang="mar_Deva"
)

# flush the model and tokenizer to free GPU memory
del en_indic_tokenizer, en_indic_model
if DEVICE == "cuda":
    torch.cuda.empty_cache()