import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = ParlerTTSForConditionalGeneration.from_pretrained("ai4bharat/indic-parler-tts").to(device)
tokenizer = AutoTokenizer.from_pretrained("ai4bharat/indic-parler-tts")
description_tokenizer = AutoTokenizer.from_pretrained(model.config.text_encoder._name_or_path)

#prompt = "अरे, तुम आज कैसे हो?"
prompt  = '''

प्रकृति हमारे जीवन का अभिन्न हिस्सा है। 
यह हमें शुद्ध वायु, जल, भोजन और जीवन जीने के लिए आवश्यक 
सभी संसाधन प्रदान करती है। पेड़-पौधे, नदियाँ, पर्वत और वन्य जीव –

'''
description = "Divya's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."

description_input_ids = description_tokenizer(description, return_tensors="pt").to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").to(device)

generation = model.generate(input_ids=description_input_ids.input_ids, attention_mask=description_input_ids.attention_mask, prompt_input_ids=prompt_input_ids.input_ids, prompt_attention_mask=prompt_input_ids.attention_mask)
audio_arr = generation.cpu().numpy().squeeze()
sf.write("indic_tts_out.wav", audio_arr, model.config.sampling_rate)