# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from tqdm import tqdm

from fairseq import utils
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
from tests.speech import TestFairseqSpeech


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestTTSTransformer(TestFairseqSpeech):
    """Regression test for the pretrained LJSpeech TTS Transformer checkpoint.

    The whole class is skipped when CUDA is not available.
    """

    def setUp(self):
        """Prepare the LJSpeech fixtures via the base-class helper."""
        self.set_up_ljspeech()

    @torch.no_grad()
    def test_ljspeech_tts_transformer_checkpoint(self):
        """Check the checkpoint's average MCD on the ``ljspeech_test`` split.

        Loads ``ljspeech_transformer_g2p.pt`` with the ``griffin_lim``
        vocoder override and fp16 disabled, generates a waveform for every
        sample, and accumulates the mel-cepstral distortion between the
        target and generated waveforms. The per-sample average, rounded to
        one decimal, is compared with the reference value.
        """
        models, cfg, task, generator = self.download_and_load_checkpoint(
            "ljspeech_transformer_g2p.pt",
            arg_overrides={
                "config_yaml": "cfg_ljspeech_g2p.yaml",
                "vocoder": "griffin_lim",
                "fp16": False,
            },
        )

        # NOTE(review): 65_536 and 1024 are presumably the token and position
        # limits -- confirm against TestFairseqSpeech.get_batch_iterator.
        batch_iterator = self.get_batch_iterator(task, "ljspeech_test", 65_536, 1024)
        progress = tqdm(batch_iterator, total=len(batch_iterator))

        total_mcd = 0.0
        n_samples = 0
        for sample in progress:
            if self.use_cuda:
                sample = utils.move_to_cuda(sample)
            hypos = generator.generate(models[0], sample, has_targ=True)
            rets = batch_mel_cepstral_distortion(
                [hypo["targ_waveform"] for hypo in hypos],
                [hypo["waveform"] for hypo in hypos],
                sr=task.sr,
            )
            # Each result is a pair; only its first element (the distortion
            # value) is accumulated, the second one is ignored.
            total_mcd += sum(dist.item() for dist, _ in rets)
            # Count the ids directly instead of materializing a list first.
            n_samples += len(sample["id"])

        # Fail with a clear message rather than a ZeroDivisionError when the
        # iterator produced no samples.
        self.assertGreater(
            n_samples, 0, "no samples were generated from 'ljspeech_test'"
        )

        mcd = round(total_mcd / n_samples, 1)
        reference_mcd = 3.3
        print(f"MCD: {mcd} (reference: {reference_mcd})")
        # NOTE(review): the average is rounded before the comparison. In
        # binary floating point 3.4 - 3.3 is slightly above 0.1 while
        # 3.3 - 3.2 is slightly below, so a rounded 3.4 fails but a rounded
        # 3.2 passes. Confirm that this asymmetry is intended.
        self.assertAlmostEqual(mcd, reference_mcd, delta=0.1)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
