# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest
import zipfile
from argparse import Namespace
from pathlib import Path
from typing import Dict, List, Optional

import torch
from tqdm import tqdm

from fairseq import utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.scoring.bleu import SacrebleuScorer
from fairseq.scoring.wer import WerScorer

S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"


class TestFairseqSpeech(unittest.TestCase):
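    """Base class for fairseq speech integration tests.

    Subclasses pick a dataset via one of the `set_up_*` helpers, then call
    `base_test` with a checkpoint filename and a reference score. A minimal
    sketch (the checkpoint name and score below are illustrative
    placeholders, not published results)::

        class TestLibrispeechS2T(TestFairseqSpeech):
            def setUp(self):
                self.set_up_librispeech()

            def test_transformer_checkpoint(self):
                self.base_test("checkpoint.pt", reference_score=9.0)
    """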
    @classmethod
    def download(cls, base_url: str, out_root: Path, filename: str):
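        """Fetch `filename` from `base_url` into `out_root`, skipping the
        download when the file is already cached locally."""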
        url = f"{base_url}/{filename}"
        path = out_root / filename
        if not path.exists():
            torch.hub.download_url_to_file(url, path.as_posix(), progress=True)
        return path

    def _set_up(self, dataset_id: str, s3_dir: str, data_filenames: List[str]):
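        """Create a per-dataset cache under ~/.cache/fairseq and download the
        listed config/vocabulary/manifest files."""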
        self.use_cuda = torch.cuda.is_available()
        self.root = Path.home() / ".cache" / "fairseq" / dataset_id
        self.root.mkdir(exist_ok=True, parents=True)
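        # Run from the cache root so relative paths in the downloaded
        # manifests resolve correctly.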
        os.chdir(self.root)
        self.base_url = (
            s3_dir if s3_dir.startswith("https:") else f"{S3_BASE_URL}/{s3_dir}"
        )
        for filename in data_filenames:
            self.download(self.base_url, self.root, filename)

    def set_up_librispeech(self):
        self._set_up(
            "librispeech",
            "s2t/librispeech",
            [
                "cfg_librispeech.yaml",
                "spm_librispeech_unigram10000.model",
                "spm_librispeech_unigram10000.txt",
                "librispeech_test-other.tsv",
                "librispeech_test-other.zip",
            ],
        )

    def set_up_ljspeech(self):
        self._set_up(
            "ljspeech",
            "s2/ljspeech",
            [
                "cfg_ljspeech_g2p.yaml",
                "ljspeech_g2p_gcmvn_stats.npz",
                "ljspeech_g2p.txt",
                "ljspeech_test.tsv",
                "ljspeech_test.zip",
            ],
        )

    def set_up_sotasty_es_en(self):
        self._set_up(
            "sotasty_es_en",
            "s2t/big/es-en",
            [
                "cfg_es_en.yaml",
                "spm_bpe32768_es_en.model",
                "spm_bpe32768_es_en.txt",
                "sotasty_es_en_test_ted.tsv",
                "sotasty_es_en_test_ted.zip",
            ],
        )

    def set_up_mustc_de_fbank(self):
        self._set_up(
            "mustc_de_fbank",
            "https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de",
            [
                "config.yaml",
                "spm.model",
                "dict.txt",
                "src_dict.txt",
                "tgt_dict.txt",
                "tst-COMMON.tsv",
                "tst-COMMON.zip",
            ],
        )

    def download_and_load_checkpoint(
        self,
        checkpoint_filename: str,
        arg_overrides: Optional[Dict[str, str]] = None,
        strict: bool = True,
    ):
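        """Download a checkpoint, load the model ensemble and task, move the
        models to GPU when available, and build a sequence generator."""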
        path = self.download(self.base_url, self.root, checkpoint_filename)
        _arg_overrides = arg_overrides or {}
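        # Point the task's data root at the local cache directory.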
        _arg_overrides["data"] = self.root.as_posix()
        models, cfg, task = load_model_ensemble_and_task(
            [path.as_posix()], arg_overrides=_arg_overrides, strict=strict
        )
        if self.use_cuda:
            for model in models:
                model.cuda()

        return models, cfg, task, self.build_generator(task, models, cfg)

    def build_generator(
        self,
        task,
        models,
        cfg,
    ):
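        """Thin wrapper around `task.build_generator`, kept as an override
        hook for subclasses."""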
        return task.build_generator(models, cfg)

    @classmethod
    def get_batch_iterator(cls, task, test_split, max_tokens, max_positions):
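        """Load `test_split` and return a non-shuffled epoch iterator over it."""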
        task.load_dataset(test_split)
        return task.get_batch_iterator(
            dataset=task.dataset(test_split),
            max_tokens=max_tokens,
            max_positions=max_positions,
            num_workers=1,
        ).next_epoch_itr(shuffle=False)

    @classmethod
    def get_wer_scorer(
        cls, tokenizer="none", lowercase=False, remove_punct=False, char_level=False
    ):
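        # Keys mirror the WER scorer's config field names.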
        scorer_args = {
            "wer_tokenizer": tokenizer,
            "wer_lowercase": lowercase,
            "wer_remove_punct": remove_punct,
            "wer_char_level": char_level,
        }
        return WerScorer(Namespace(**scorer_args))

    @classmethod
    def get_bleu_scorer(cls, tokenizer="13a", lowercase=False, char_level=False):
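        # Keys mirror the sacrebleu scorer's config field names.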
        scorer_args = {
            "sacrebleu_tokenizer": tokenizer,
            "sacrebleu_lowercase": lowercase,
            "sacrebleu_char_level": char_level,
        }
        return SacrebleuScorer(Namespace(**scorer_args))

    @torch.no_grad()
    def base_test(
        self,
        ckpt_name,
        reference_score,
        score_delta=0.3,
        dataset="librispeech_test-other",
        max_tokens=65_536,
        max_positions=(4_096, 1_024),
        arg_overrides=None,
        strict=True,
        score_type="wer",
    ):
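        """Decode `dataset` with the downloaded checkpoint and assert that the
        corpus-level score is within `score_delta` of `reference_score`."""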
        models, _, task, generator = self.download_and_load_checkpoint(
            ckpt_name, arg_overrides=arg_overrides, strict=strict
        )
        if not self.use_cuda:
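            # Skip the decoding pass on CPU-only hosts; only the checkpoint
            # download and load above are exercised.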
            return

        batch_iterator = self.get_batch_iterator(
            task, dataset, max_tokens, max_positions
        )
        if score_type == "bleu":
            scorer = self.get_bleu_scorer()
        elif score_type == "wer":
            scorer = self.get_wer_scorer()
        else:
            raise ValueError(f"Unsupported score type: {score_type}")

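        # Decode batch by batch, scoring each hypothesis against its
        # reference transcript.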
        progress = tqdm(enumerate(batch_iterator), total=len(batch_iterator))
        for batch_idx, sample in progress:
            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
            hypo = task.inference_step(generator, models, sample)
            for i, sample_id in enumerate(sample["id"].tolist()):
                tgt_str, hypo_str = self.postprocess_tokens(
                    task,
                    sample["target"][i, :],
                    hypo[i][0]["tokens"].int().cpu(),
                )
                if batch_idx == 0 and i < 3:
                    print(f"T-{sample_id} {tgt_str}")
                    print(f"H-{sample_id} {hypo_str}")
                scorer.add_string(tgt_str, hypo_str)

        print(scorer.result_string() + f" (reference: {reference_score})")
        self.assertAlmostEqual(scorer.score(), reference_score, delta=score_delta)

    def postprocess_tokens(self, task, target, hypo_tokens):
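        """Strip padding from the target, then detokenize both target and
        hypothesis with the SentencePiece post-processor."""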
        tgt_tokens = utils.strip_pad(target, task.tgt_dict.pad()).int().cpu()
        tgt_str = task.tgt_dict.string(tgt_tokens, "sentencepiece")
        hypo_str = task.tgt_dict.string(hypo_tokens, "sentencepiece")
        return tgt_str, hypo_str

    def unzip_files(self, zip_file_name):
        """Extract a downloaded ZIP into a directory named after the archive."""
        zip_file_path = self.root / zip_file_name
        # str.strip(".zip") strips characters, not a suffix; use the path
        # stem to get the archive name without its extension.
        with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
            zip_ref.extractall(self.root / zip_file_path.stem)
