#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Translate pre-processed data with a trained model.
"""

import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.sequence_generator import EnsembleModel
from fairseq.utils import safe_hasattr


def get_avg_pool(
    models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
):
    model = EnsembleModel(models)

    # model.forward normally channels prev_output_tokens into the decoder
    # separately, but SequenceGenerator directly calls model.encoder
    encoder_input = {
        k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
    }

    # compute the encoder output for each beam
    encoder_outs = model.forward_encoder(encoder_input)
    np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
    encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
        np.float32
    )
    encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
    if has_langtok:
        encoder_mask = encoder_mask[1:, :, :]
        np_encoder_outs = np_encoder_outs[1, :, :]
    masked_encoder_outs = encoder_mask * np_encoder_outs
    avg_pool = (masked_encoder_outs / encoder_mask.sum(axis=0)).sum(axis=0)
    return avg_pool


def main(args):
    assert args.path is not None, "--path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"

    args.beam = 1
    utils.import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 12000
    print(args)
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print("| loading model(s) from {}".format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(":"),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    num_sentences = 0
    source_sentences = []
    shard_id = 0
    all_avg_pool = None
    encoder_has_langtok = (
        safe_hasattr(task.args, "encoder_langtok")
        and task.args.encoder_langtok is not None
        and safe_hasattr(task.args, "lang_tok_replacing_bos_eos")
        and not task.args.lang_tok_replacing_bos_eos
    )
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            if sample is None:
                print("Skipping None")
                continue
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            with torch.no_grad():
                avg_pool = get_avg_pool(
                    models,
                    sample,
                    prefix_tokens,
                    src_dict,
                    args.post_process,
                    has_langtok=encoder_has_langtok,
                )
                if all_avg_pool is not None:
                    all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
                else:
                    all_avg_pool = avg_pool

            if not isinstance(sample["id"], list):
                sample_ids = sample["id"].tolist()
            else:
                sample_ids = sample["id"]
            for i, sample_id in enumerate(sample_ids):
                # Remove padding
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                )

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(
                        sample_id
                    )
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.post_process)
                    else:
                        src_str = ""

                if not args.quiet:
                    if src_dict is not None:
                        print("S-{}\t{}".format(sample_id, src_str))

                source_sentences.append(f"{sample_id}\t{src_str}")

            num_sentences += sample["nsentences"]
            if all_avg_pool.shape[0] >= 1000000:
                with open(
                    f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
                    "w",
                ) as avg_pool_file:
                    all_avg_pool.tofile(avg_pool_file)
                with open(
                    f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
                    "w",
                ) as sentence_file:
                    sentence_file.writelines(f"{line}\n" for line in source_sentences)
                all_avg_pool = None
                source_sentences = []
                shard_id += 1

    if all_avg_pool is not None:
        with open(
            f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
        ) as avg_pool_file:
            all_avg_pool.tofile(avg_pool_file)
        with open(
            f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
        ) as sentence_file:
            sentence_file.writelines(f"{line}\n" for line in source_sentences)
    return None


def cli_main():
    parser = options.get_generation_parser()
    parser.add_argument(
        "--encoder-save-dir",
        default="",
        type=str,
        metavar="N",
        help="directory to save encoder outputs",
    )
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()
