#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Run inference for pre-processed data with a trained model.
"""

import ast
from collections import namedtuple
from dataclasses import dataclass, field
from enum import Enum, auto
import hydra
from hydra.core.config_store import ConfigStore
import logging
import math
import os
from omegaconf import OmegaConf
from typing import Optional
import sys

import editdistance
import torch

from hydra.core.hydra_config import HydraConfig

from fairseq import checkpoint_utils, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.dataclass.configs import FairseqDataclass, FairseqConfig
from fairseq.logging.meters import StopwatchMeter
from omegaconf import open_dict

from examples.speech_recognition.kaldi.kaldi_decoder import KaldiDecoderConfig

# Logging is configured at import time: the root logger is set to INFO and
# records are sent to stdout.
logging.root.setLevel(logging.INFO)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)


class DecoderType(Enum):
    """Decoder backends that can be selected through ``w2l_decoder``."""

    # Values are the ones ``auto()`` would assign (1, 2, 3, ... in order).
    VITERBI = 1
    KENLM = 2
    FAIRSEQ = 3
    KALDI = 4


@dataclass
class UnsupGenerateConfig(FairseqDataclass):
    """Configuration for this generation script.

    Holds the nested fairseq configuration plus the decoder, language-model,
    output and unsupervised-tuning options read by the functions in this file.
    """

    # A factory is used instead of a shared ``FairseqConfig()`` instance:
    # a class-level instance would be shared by every config object, and
    # Python 3.11+ dataclasses reject unhashable default values outright.
    fairseq: FairseqConfig = field(default_factory=FairseqConfig)

    # --- decoder selection and language model -------------------------------
    lm_weight: float = field(
        default=2.0,
        metadata={"help": "language model weight"},
    )
    w2l_decoder: DecoderType = field(
        default=DecoderType.VITERBI,
        metadata={"help": "type of decoder to use"},
    )
    # Must be set when ``w2l_decoder`` is KALDI (asserted in ``generate``).
    kaldi_decoder_config: Optional[KaldiDecoderConfig] = None
    lexicon: Optional[str] = field(
        default=None,
        metadata={
            "help": "path to lexicon. This is also used to 'phonemize' for unsupvised param tuning"
        },
    )
    lm_model: Optional[str] = field(
        default=None,
        metadata={"help": "path to language model (kenlm or fairseq)"},
    )
    decode_stride: Optional[float] = field(
        default=None,
        metadata={"help": "changing the decoding frequency of the generator"},
    )
    unit_lm: bool = field(
        default=False,
        metadata={"help": "whether to use unit lm"},
    )

    # --- beam search ---------------------------------------------------------
    beam_threshold: float = field(
        default=50.0,
        metadata={"help": "beam score threshold"},
    )
    beam_size_token: float = field(
        default=100.0,
        metadata={"help": "max tokens per beam"},
    )
    beam: int = field(
        default=5,
        metadata={"help": "decoder beam size"},
    )
    nbest: int = field(
        default=1,
        metadata={"help": "number of results to return"},
    )
    word_score: float = field(
        default=1.0,
        metadata={"help": "word score to add at end of word"},
    )
    unk_weight: float = field(
        default=-math.inf,
        metadata={"help": "unknown token weight"},
    )
    sil_weight: float = field(
        default=0.0,
        metadata={"help": "silence token weight"},
    )

    # --- references and outputs ---------------------------------------------
    targets: Optional[str] = field(
        default=None,
        metadata={"help": "extension of ground truth labels to compute UER"},
    )
    results_path: Optional[str] = field(
        default=None,
        metadata={"help": "where to store results"},
    )
    post_process: Optional[str] = field(
        default=None,
        metadata={"help": "how to post process results"},
    )

    # --- unsupervised parameter tuning --------------------------------------
    vocab_usage_power: float = field(
        default=2,
        metadata={"help": "for unsupervised param tuning"},
    )

    viterbi_transcript: Optional[str] = field(
        default=None,
        metadata={"help": "for unsupervised param tuning"},
    )
    # Lower bounds applied in ``main`` to the LM perplexity and to the unit
    # error rate against the viterbi transcript.
    min_lm_ppl: float = field(
        default=0,
        metadata={"help": "for unsupervised param tuning"},
    )
    min_vt_uer: float = field(
        default=0,
        metadata={"help": "for unsupervised param tuning"},
    )

    # --- model overrides applied when loading the checkpoint ----------------
    blank_weight: float = field(
        default=0,
        metadata={"help": "value to add or set for blank emission"},
    )
    blank_mode: str = field(
        default="set",
        metadata={
            "help": "can be add or set, how to modify blank emission with blank weight"
        },
    )
    sil_is_blank: bool = field(
        default=False,
        metadata={"help": "if true, <SIL> token is same as blank token"},
    )

    # --- what the entry point returns ---------------------------------------
    unsupervised_tuning: bool = field(
        default=False,
        metadata={
            "help": "if true, returns a score based on unsupervised param selection metric instead of UER"
        },
    )
    is_ax: bool = field(
        default=False,
        metadata={
            "help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume"
        },
    )


def get_dataset_itr(cfg, task):
    """Return a non-shuffled epoch iterator over the configured subset.

    Batching options are read from ``cfg.fairseq.dataset``; no limit is put
    on example positions (``max_positions`` is ``sys.maxsize`` on both sides).
    """
    data_cfg = cfg.fairseq.dataset
    unlimited = (sys.maxsize, sys.maxsize)

    batch_iterator = task.get_batch_iterator(
        dataset=task.dataset(data_cfg.gen_subset),
        max_tokens=data_cfg.max_tokens,
        max_sentences=data_cfg.batch_size,
        max_positions=unlimited,
        ignore_invalid_inputs=data_cfg.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=data_cfg.required_batch_size_multiple,
        num_shards=data_cfg.num_shards,
        shard_id=data_cfg.shard_id,
        num_workers=data_cfg.num_workers,
        data_buffer_size=data_cfg.data_buffer_size,
    )
    return batch_iterator.next_epoch_itr(shuffle=False)


def process_predictions(
    cfg: UnsupGenerateConfig,
    hypos,
    tgt_dict,
    target_tokens,
    res_files,
):
    """Score the top hypotheses of one utterance and write out the best one.

    Up to ``cfg.nbest`` entries of ``hypos`` are converted to a unit string
    and a word string and compared with the reference by word-level edit
    distance. The hypothesis with the smallest distance (the first one on
    ties) is written to the result files and returned.

    Args:
        cfg: generation config; ``nbest``, ``post_process`` and
            ``fairseq.common_eval.quiet`` are read.
        hypos: hypothesis dicts for one utterance. Each needs ``"tokens"``
            (a tensor of dictionary indices or an iterable of strings);
            ``"words"``, ``"score"``, ``"am_score"`` and ``"lm_score"`` are
            used when present.
        tgt_dict: dictionary providing ``nspecial`` and ``string()``.
        target_tokens: the reference, either a ``str`` (used verbatim for
            both units and words), a token sequence accepted by
            ``tgt_dict.string``, or ``None`` when there is no reference.
        res_files: mapping with the keys ``"hypo.units"``, ``"hypo.words"``,
            ``"ref.units"`` and ``"ref.words"`` to writable file objects, or
            ``None`` to skip writing.

    Returns:
        Tuple ``(edit_distance, n_hyp_words, n_ref_words, hyp_pieces,
        hyp_words)`` for the selected hypothesis. Without a reference the
        distance is computed against an empty word list.

    Raises:
        AssertionError: if there is no hypothesis to process.
    """
    # One (result tuple, {file object: line to write}) pair per hypothesis.
    scored = []

    for i, hypo in enumerate(hypos[: cfg.nbest]):
        if torch.is_tensor(hypo["tokens"]):
            tokens = hypo["tokens"].int().cpu()
            # Drop the dictionary's special symbols before stringifying.
            tokens = tokens[tokens >= tgt_dict.nspecial]
            hyp_pieces = tgt_dict.string(tokens)
        else:
            hyp_pieces = " ".join(hypo["tokens"])

        # Prefer words produced by the decoder; otherwise derive them from
        # the unit string.
        if "words" in hypo and len(hypo["words"]) > 0:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, cfg.post_process)

        to_write = {}
        if res_files is not None:
            to_write[res_files["hypo.units"]] = hyp_pieces
            to_write[res_files["hypo.words"]] = hyp_words

        tgt_words = ""
        if target_tokens is not None:
            if isinstance(target_tokens, str):
                tgt_pieces = tgt_words = target_tokens
            else:
                tgt_pieces = tgt_dict.string(target_tokens)
                tgt_words = post_process(tgt_pieces, cfg.post_process)

            if res_files is not None:
                to_write[res_files["ref.units"]] = tgt_pieces
                to_write[res_files["ref.words"]] = tgt_words

        if not cfg.fairseq.common_eval.quiet:
            logger.info(f"HYPO {i}:" + hyp_words)
            if tgt_words:
                logger.info("TARGET:" + tgt_words)

            if "am_score" in hypo and "lm_score" in hypo:
                logger.info(
                    f"DECODER AM SCORE: {hypo['am_score']}, DECODER LM SCORE: {hypo['lm_score']}, DECODER SCORE: {hypo['score']}"
                )
            elif "score" in hypo:
                logger.info(f"DECODER SCORE: {hypo['score']}")

            logger.info("___________________")

        hyp_words_arr = hyp_words.split()
        tgt_words_arr = tgt_words.split()

        result = (
            editdistance.eval(hyp_words_arr, tgt_words_arr),
            len(hyp_words_arr),
            len(tgt_words_arr),
            hyp_pieces,
            hyp_words,
        )
        scored.append((result, to_write))

    assert scored, "process_predictions needs at least one hypothesis"

    # ``min`` returns the first minimal element, so ties keep the
    # higher-ranked hypothesis.
    best_result, best_lines = min(scored, key=lambda pair: pair[0][0])
    for dest, line in best_lines.items():
        print(line, file=dest)
        dest.flush()

    return best_result


def prepare_result_files(cfg: UnsupGenerateConfig):
    """Open the per-subset result files inside ``cfg.results_path``.

    Returns ``None`` when ``cfg.results_path`` is unset or empty. Otherwise
    returns a dict mapping each result key to a line-buffered text file
    opened for writing. File names are ``<gen_subset><suffix>.txt``; when
    more than one shard is configured the suffix is prefixed with
    ``<shard_id>_``. The caller is responsible for closing the files.
    """
    if not cfg.results_path:
        return None

    dataset_cfg = cfg.fairseq.dataset

    def open_result(suffix):
        if dataset_cfg.num_shards > 1:
            suffix = f"{dataset_cfg.shard_id}_{suffix}"
        file_name = "{}{}.txt".format(dataset_cfg.gen_subset, suffix)
        return open(os.path.join(cfg.results_path, file_name), "w", buffering=1)

    suffix_by_key = (
        ("hypo.words", ""),
        ("hypo.units", "_units"),
        ("ref.words", "_ref"),
        ("ref.units", "_ref_units"),
        ("hypo.nbest.words", "_nbest_words"),
    )
    return {key: open_result(suffix) for key, suffix in suffix_by_key}


def optimize_models(cfg: UnsupGenerateConfig, use_cuda, models):
    """Prepare each model of the ensemble for generation, in place.

    Every model is switched to eval mode, converted to half precision when
    ``cfg.fairseq.common.fp16`` is set, and moved to GPU when ``use_cuda``
    is true.
    """
    for net in models:
        net.eval()
        if cfg.fairseq.common.fp16:
            net.half()
        if not use_cuda:
            continue
        net.cuda()


# Aggregate statistics produced by one decoding pass over a subset.
GenResult = namedtuple(
    "GenResult",
    "count errs_t gen_timer lengths_hyp_unit_t lengths_hyp_t lengths_t "
    "lm_score_t num_feats num_sentences num_symbols vt_err_t vt_length_t",
)


def generate(cfg: UnsupGenerateConfig, models, saved_cfg, use_cuda):
    """Decode the configured subset and collect aggregate statistics.

    Args:
        cfg: generation config.
        models: models handed to ``task.inference_step`` for every batch.
        saved_cfg: config whose ``task`` node is used to load the dataset.
            Its ``task.labels`` is overwritten with ``cfg.fairseq.task.labels``.
        use_cuda: whether batches are moved to GPU before inference.

    Returns:
        A ``GenResult`` with the number of processed utterances, summed word
        edit distance and lengths, the stopwatch, the summed LM score (0 when
        no LM is configured), feature / sentence / symbol counts and the edit
        distance and length totals against the viterbi transcript (both 0
        when no transcript is configured).

    Raises:
        NotImplementedError: if ``cfg.w2l_decoder`` is not a handled decoder.
        AssertionError: if the KALDI decoder is selected without
            ``cfg.kaldi_decoder_config``, or a hypothesis word is missing
            from the lexicon.
    """
    task = tasks.setup_task(cfg.fairseq.task)
    saved_cfg.task.labels = cfg.fairseq.task.labels
    task.load_dataset(cfg.fairseq.dataset.gen_subset, task_cfg=saved_cfg.task)
    # Set dictionary
    tgt_dict = task.target_dictionary
    logger.info(
        "| {} {} {} examples".format(
            cfg.fairseq.task.data,
            cfg.fairseq.dataset.gen_subset,
            len(task.dataset(cfg.fairseq.dataset.gen_subset)),
        )
    )
    # Load dataset (possibly sharded)
    itr = get_dataset_itr(cfg, task)
    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(cfg: UnsupGenerateConfig):
        # Decoder modules are imported lazily so that only the selected
        # backend has to be importable.
        w2l_decoder = cfg.w2l_decoder
        if w2l_decoder == DecoderType.VITERBI:
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.KENLM:
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.FAIRSEQ:
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(cfg, task.target_dictionary)
        elif w2l_decoder == DecoderType.KALDI:
            from examples.speech_recognition.kaldi.kaldi_decoder import KaldiDecoder

            assert cfg.kaldi_decoder_config is not None

            return KaldiDecoder(
                cfg.kaldi_decoder_config,
                cfg.beam,
            )
        else:
            raise NotImplementedError(
                "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment but found "
                + str(w2l_decoder)
            )

    generator = build_generator(cfg)

    # Optional KenLM model used only to score the final hypotheses.
    lm_scorer = None
    if cfg.lm_model is not None:
        import kenlm

        lm_scorer = kenlm.Model(cfg.lm_model)

    num_sentences = 0
    errs_t = 0
    lengths_hyp_t = 0
    lengths_hyp_unit_t = 0
    lengths_t = 0
    count = 0
    num_feats = 0
    all_hyp_pieces = []
    all_hyp_words = []

    # Dictionary symbols that are neither "madeup*" placeholders nor special.
    num_symbols = (
        len([s for s in tgt_dict.symbols if not s.startswith("madeup")])
        - tgt_dict.nspecial
    )

    # Reference transcripts, one per line, looked up by sample id below.
    targets = None
    if cfg.targets is not None:
        tgt_path = os.path.join(
            cfg.fairseq.task.data, cfg.fairseq.dataset.gen_subset + "." + cfg.targets
        )
        if os.path.exists(tgt_path):
            with open(tgt_path, "r") as f:
                targets = f.read().splitlines()

    viterbi_transcript = None
    if cfg.viterbi_transcript is not None and len(cfg.viterbi_transcript) > 0:
        logger.info(f"loading viterbi transcript from {cfg.viterbi_transcript}")
        with open(cfg.viterbi_transcript, "r") as vf:
            viterbi_transcript = vf.readlines()
            viterbi_transcript = [v.rstrip().split() for v in viterbi_transcript]

    # An empty results_path means "no result files" (prepare_result_files
    # returns None for it), so a directory is only created for a real path.
    if cfg.results_path:
        os.makedirs(cfg.results_path, exist_ok=True)
    res_files = prepare_result_files(cfg)

    try:
        gen_timer.start()

        start = 0
        end = len(itr)

        hypo_futures = None
        if cfg.w2l_decoder == DecoderType.KALDI:
            # For the KALDI decoder all batches are submitted first; the
            # per-utterance results are resolved with ``.result()`` in the
            # main loop below.
            logger.info("Extracting features")
            hypo_futures = []
            samples = []
            with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t:
                for i, sample in enumerate(t):
                    if "net_input" not in sample or i < start or i >= end:
                        continue
                    if "padding_mask" not in sample["net_input"]:
                        sample["net_input"]["padding_mask"] = None

                    hypos, num_feats = gen_hypos(
                        generator, models, num_feats, sample, task, use_cuda
                    )
                    hypo_futures.append(hypos)
                    samples.append(sample)
            itr = list(zip(hypo_futures, samples))
            start = 0
            end = len(itr)
            logger.info("Finished extracting features")

        with progress_bar.build_progress_bar(cfg.fairseq.common, itr) as t:
            for i, sample in enumerate(t):
                if i < start or i >= end:
                    continue

                if hypo_futures is not None:
                    hypos, sample = sample
                    hypos = [h.result() for h in hypos]
                else:
                    if "net_input" not in sample:
                        continue

                    hypos, num_feats = gen_hypos(
                        generator, models, num_feats, sample, task, use_cuda
                    )

                # ``j`` is the row inside the batch, ``sample_id`` the
                # dataset index used to look up the reference line.
                for j, sample_id in enumerate(sample["id"].tolist()):
                    if targets is not None:
                        target_tokens = targets[sample_id]
                    elif "target" in sample or "target_label" in sample:
                        toks = (
                            sample["target"][j, :]
                            if "target_label" not in sample
                            else sample["target_label"][j, :]
                        )

                        target_tokens = (
                            utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                        )
                    else:
                        target_tokens = None

                    # Process top predictions
                    (
                        errs,
                        length_hyp,
                        length,
                        hyp_pieces,
                        hyp_words,
                    ) = process_predictions(
                        cfg,
                        hypos[j],
                        tgt_dict,
                        target_tokens,
                        res_files,
                    )
                    errs_t += errs
                    lengths_hyp_t += length_hyp
                    # NOTE(review): hyp_pieces / hyp_words are strings, so
                    # this adds a character count, not a token count —
                    # confirm that is intended.
                    lengths_hyp_unit_t += (
                        len(hyp_pieces) if len(hyp_pieces) > 0 else len(hyp_words)
                    )
                    lengths_t += length
                    count += 1
                    all_hyp_pieces.append(hyp_pieces)
                    all_hyp_words.append(hyp_words)

                num_sentences += (
                    sample["nsentences"]
                    if "nsentences" in sample
                    else sample["id"].numel()
                )
    finally:
        # Close the result files even when decoding fails part-way.
        if res_files is not None:
            for r in res_files.values():
                r.close()

    lm_score_sum = 0
    if lm_scorer is not None:
        lm_inputs = all_hyp_pieces if cfg.unit_lm else all_hyp_words
        lm_score_sum = sum(lm_scorer.score(w) for w in lm_inputs)

    vt_err_t = 0
    vt_length_t = 0
    if viterbi_transcript is not None:
        unit_hyps = []
        if cfg.targets is not None and cfg.lexicon is not None:
            # Map every hypothesis entry to its lexicon expansion before
            # comparing with the viterbi transcript.
            lex = {}
            with open(cfg.lexicon, "r") as lf:
                for line in lf:
                    items = line.rstrip().split()
                    if not items:
                        # Blank lines carry no entry.
                        continue
                    lex[items[0]] = items[1:]
            for h in all_hyp_pieces:
                hyp_ws = []
                for w in h.split():
                    assert w in lex, w
                    hyp_ws.extend(lex[w])
                unit_hyps.append(hyp_ws)

        else:
            unit_hyps.extend([h.split() for h in all_hyp_words])

        # Hypotheses are paired with transcript lines by position.
        vt_err_t = sum(
            editdistance.eval(vt, h) for vt, h in zip(viterbi_transcript, unit_hyps)
        )

        vt_length_t = sum(len(h) for h in viterbi_transcript)

    gen_timer.stop(lengths_hyp_t)

    return GenResult(
        count,
        errs_t,
        gen_timer,
        lengths_hyp_unit_t,
        lengths_hyp_t,
        lengths_t,
        lm_score_sum,
        num_feats,
        num_sentences,
        num_symbols,
        vt_err_t,
        vt_length_t,
    )


def gen_hypos(generator, models, num_feats, sample, task, use_cuda):
    """Run one inference step and update the running feature count.

    When the batch carries ``net_input["features"]``, ``dense_x_only`` is
    set on the net input and the product of the first two feature
    dimensions is added to ``num_feats``.

    Returns:
        ``(hypos, num_feats)``: the output of ``task.inference_step`` and
        the updated feature count.
    """
    if use_cuda:
        sample = utils.move_to_cuda(sample)

    net_input = sample["net_input"]
    if "features" in net_input:
        net_input["dense_x_only"] = True
        feat_shape = net_input["features"].shape
        num_feats += feat_shape[0] * feat_shape[1]

    return task.inference_step(generator, models, sample, None), num_feats


def main(cfg: UnsupGenerateConfig, model=None):
    """Load the model(s), run generation and report the resulting score.

    Args:
        cfg: generation config. ``fairseq.dataset.max_tokens`` is set to
            1024000 when neither it nor ``batch_size`` is given.
        model: optional already-built model. When omitted, the ensemble is
            loaded from ``cfg.fairseq.common_eval.path``.

    Returns:
        ``(task, weighted_score)``. The score is the WER (``None`` when no
        reference words were seen) unless ``cfg.unsupervised_tuning`` is
        set, in which case it is ``log(lm_ppl)`` times the unit error rate
        against the viterbi transcript (or 1.0 when that rate is missing
        or zero).
    """
    dataset_cfg = cfg.fairseq.dataset
    if dataset_cfg.max_tokens is None and dataset_cfg.batch_size is None:
        dataset_cfg.max_tokens = 1024000

    use_cuda = torch.cuda.is_available() and not cfg.fairseq.common.cpu

    task = tasks.setup_task(cfg.fairseq.task)

    overrides = ast.literal_eval(cfg.fairseq.common_eval.model_overrides)

    # Model-level overrides; these replace any "model" entry coming from
    # ``model_overrides``.
    model_overrides = {
        "blank_weight": cfg.blank_weight,
        "blank_mode": cfg.blank_mode,
    }
    if cfg.fairseq.task._name == "unpaired_audio_text":
        model_overrides.update(
            {
                "blank_is_sil": cfg.sil_is_blank,
                "no_softmax": True,
                "segmentation": {
                    "type": "NONE",
                },
            }
        )
    if cfg.decode_stride:
        model_overrides["generator_stride"] = cfg.decode_stride
    overrides["model"] = model_overrides

    if model is not None:
        models = [model]
        saved_cfg = cfg.fairseq
    else:
        # Load ensemble
        checkpoint_path = cfg.fairseq.common_eval.path
        logger.info("| loading model(s) from {}".format(checkpoint_path))
        checkpoint_cfg = cfg.fairseq.checkpoint
        # NOTE(review): the path is split on a backslash — confirm this is
        # the intended separator for ensembles.
        models, saved_cfg = checkpoint_utils.load_model_ensemble(
            checkpoint_path.split("\\"),
            arg_overrides=overrides,
            task=task,
            suffix=checkpoint_cfg.checkpoint_suffix,
            strict=(checkpoint_cfg.checkpoint_shard_count == 1),
            num_shards=checkpoint_cfg.checkpoint_shard_count,
        )
        optimize_models(cfg, use_cuda, models)

    # Keep the dataset in its stored order during generation.
    with open_dict(saved_cfg.task):
        saved_cfg.task.shuffle = False
        saved_cfg.task.sort_by_length = False

    gen_result = generate(cfg, models, saved_cfg, use_cuda)

    wer = None
    if gen_result.lengths_t > 0:
        wer = gen_result.errs_t * 100.0 / gen_result.lengths_t
        logger.info(f"WER: {wer}")

    lm_ppl = float("inf")
    if gen_result.lm_score_t != 0 and gen_result.lengths_hyp_t > 0:
        # One extra token per sentence is added to the hypothesis length.
        scored_tokens = gen_result.lengths_hyp_t + gen_result.num_sentences
        lm_ppl = math.pow(10, -gen_result.lm_score_t / scored_tokens)
        logger.info(f"LM PPL: {lm_ppl}")

    timer = gen_result.gen_timer
    logger.info(
        "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
        " sentences/s, {:.2f} tokens/s)".format(
            gen_result.num_sentences,
            timer.n,
            timer.sum,
            gen_result.num_sentences / timer.sum,
            1.0 / timer.avg,
        )
    )

    vt_diff = None
    if gen_result.vt_length_t > 0:
        vt_diff = max(cfg.min_vt_uer, gen_result.vt_err_t / gen_result.vt_length_t)

    lm_ppl = max(cfg.min_lm_ppl, lm_ppl)

    if cfg.unsupervised_tuning:
        weighted_score = math.log(lm_ppl) * (vt_diff or 1.0)
    else:
        weighted_score = wer

    res = (
        f"| Generate {cfg.fairseq.dataset.gen_subset} with beam={cfg.beam}, "
        f"lm_weight={cfg.kaldi_decoder_config.acoustic_scale if cfg.kaldi_decoder_config else cfg.lm_weight}, "
        f"word_score={cfg.word_score}, sil_weight={cfg.sil_weight}, blank_weight={cfg.blank_weight}, "
        f"WER: {wer}, LM_PPL: {lm_ppl}, num feats: {gen_result.num_feats}, "
        f"length: {gen_result.lengths_hyp_t}, UER to viterbi: {(vt_diff or 0) * 100}, score: {weighted_score}"
    )
    logger.info(res)

    return task, weighted_score


@hydra.main(
    config_path=os.path.join("../../..", "fairseq", "config"),
    config_name="config",
)
def hydra_main(cfg):
    """Hydra entry point: normalise the config, run ``main``, return the score.

    Returns ``(score, None)`` when ``cfg.is_ax`` is set, otherwise ``score``.
    """
    # Make hydra logging work with ddp
    # (see https://github.com/facebookresearch/hydra/issues/1126).
    job_logging = OmegaConf.to_container(HydraConfig.get().job_logging, resolve=True)
    with open_dict(cfg):
        cfg.job_logging_cfg = job_logging

    # Rebuild the config from a plain container (interpolations unresolved,
    # enums kept as enums) and put it in struct mode.
    plain_cfg = OmegaConf.to_container(cfg, resolve=False, enum_to_str=False)
    cfg = OmegaConf.create(plain_cfg)
    OmegaConf.set_struct(cfg, True)
    logger.info(cfg)

    utils.import_user_module(cfg.fairseq.common)

    _task, score = main(cfg)

    return (score, None) if cfg.is_ax else score


def cli_main():
    """Register ``UnsupGenerateConfig`` with hydra and start ``hydra_main``.

    The config is stored under the config name given on the command line,
    falling back to ``"config"`` when it cannot be determined.
    """
    try:
        from hydra._internal.utils import get_args

        cfg_name = get_args().config_name or "config"
    except Exception:
        # Best effort: this relies on a private hydra helper. Only ordinary
        # errors fall back to the default; SystemExit and KeyboardInterrupt
        # are left to propagate.
        logger.warning("Failed to get config name from hydra args")
        cfg_name = "config"

    cs = ConfigStore.instance()
    cs.store(name=cfg_name, node=UnsupGenerateConfig)
    hydra_main()


# Start the command-line entry point only when this file is run as a script.
if __name__ == "__main__":
    cli_main()
