# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. All Rights Reserved


import torch

from torch import nn

try:
    # transformers is optional at import time; the classes below require it
    # when they are constructed.
    from transformers import AutoConfig, AutoTokenizer
except ImportError:
    pass

from . import transformermodel


class MMPTModel(nn.Module):
    """An end-to-end inference wrapper: encodes raw video frames with the
    video encoder (S3D by default) and feeds the resulting features, together
    with tokenized text, to the multimodal transformer.
    """
    @classmethod
    def from_pretrained(cls, config, checkpoint="checkpoint_best.pt"):
        import os
        from ..utils import recursive_config
        from ..tasks import Task
        config = recursive_config(config)
        mmtask = Task.config_task(config)
        checkpoint_path = os.path.join(config.eval.save_path, checkpoint)
        mmtask.build_model(checkpoint=checkpoint_path)
        # TODO(huxu): make the video encoder configurable.
        # NOTE: assumes the S3D vocabulary and HowTo100M weights have been
        # downloaded into `pretrained_models/`.
        from ..processors.models.s3dg import S3D
        video_encoder = S3D('pretrained_models/s3d_dict.npy', 512)
        video_encoder.load_state_dict(
            torch.load('pretrained_models/s3d_howto100m.pth'))
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            config.dataset.bert_name, use_fast=config.dataset.use_fast
        )
        from ..processors import Aligner
        aligner = Aligner(config.dataset)
        return (
            MMPTModel(config, mmtask.model, video_encoder),
            tokenizer,
            aligner
        )

    def __init__(self, config, model, video_encoder, **kwargs):
        super().__init__()
        self.max_video_len = config.dataset.max_video_len
        self.video_encoder = video_encoder
        self.model = model

    def forward(self, video_frames, caps, cmasks, return_score=False):
        bsz = video_frames.size(0)
        assert bsz == 1, "only bsz=1 is supported now."
        seq_len = video_frames.size(1)
        video_frames = video_frames.view(-1, *video_frames.size()[2:])
        vfeats = self.video_encoder(video_frames.permute(0, 4, 1, 2, 3))
        vfeats = vfeats['video_embedding']
        vfeats = vfeats.view(bsz, seq_len, vfeats.size(-1))
        padding = torch.zeros(
            bsz, self.max_video_len - seq_len, vfeats.size(-1),
            dtype=vfeats.dtype, device=vfeats.device)
        vfeats = torch.cat([vfeats, padding], dim=1)
        vmasks = torch.cat([
            torch.ones(
                (bsz, seq_len), dtype=torch.bool, device=vfeats.device),
            torch.zeros(
                (bsz, self.max_video_len - seq_len),
                dtype=torch.bool, device=vfeats.device)
            ],
            dim=1
        )
        output = self.model(caps, cmasks, vfeats, vmasks)
        if return_score:
            output = {"score": torch.bmm(
                output["pooled_video"][:, None, :],
                output["pooled_text"][:, :, None]
            ).squeeze(-1).squeeze(-1)}
        return output
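
# A minimal usage sketch for MMPTModel (illustrative; the config path, frame
# shape, and the aligner helper are assumptions that depend on the project
# configs and processors):
#
#     import torch
#     from mmpt.models import MMPTModel
#
#     model, tokenizer, aligner = MMPTModel.from_pretrained(
#         "projects/retri/videoclip/how2.yaml")
#     model.eval()
#
#     # (bsz, seconds, fps, height, width, channels) of raw RGB frames.
#     video_frames = torch.randn(1, 2, 30, 224, 224, 3)
#     caps, cmasks = aligner._build_text_seq(
#         tokenizer("some text", add_special_tokens=False)["input_ids"])
#     caps, cmasks = caps[None, :], cmasks[None, :]  # bsz=1
#
#     with torch.no_grad():
#         output = model(video_frames, caps, cmasks, return_score=True)
#     print(output["score"])  # dot product of pooled video/text embeddings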


class MMFusion(nn.Module):
    """An MMPT wrapper class for MMBert-style models.
    TODO: move isolated mask to a subclass.
    """
    def __init__(self, config, **kwargs):
        super().__init__()
        transformer_config = AutoConfig.from_pretrained(
            config.dataset.bert_name)
        self.hidden_size = transformer_config.hidden_size
        self.is_train = config.dataset.train_path is not None
        # 0 means no isolation; a value of k isolates the first k layers, so
        # cross-modal attention only starts at layer k (e.g. num_iso_layer=6
        # on a 12-layer BERT keeps layers 0-5 unimodal and fuses layers 6-11).
        self.num_hidden_layers = transformer_config.num_hidden_layers
        self.last_iso_layer = 0
        if config.dataset.num_iso_layer is not None:
            self.last_iso_layer = config.dataset.num_iso_layer

        if config.model.mm_encoder_cls is not None:
            mm_encoder_cls = getattr(transformermodel, config.model.mm_encoder_cls)
            model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
            model_config.max_video_len = config.dataset.max_video_len
            # TODO: a general way to add parameter for a model.
            model_config.use_seg_emb = config.model.use_seg_emb
            self.mm_encoder = mm_encoder_cls.from_pretrained(
                config.dataset.bert_name, config=model_config)
        elif config.model.video_encoder_cls is not None\
                and config.model.text_encoder_cls is not None:
            video_encoder_cls = getattr(transformermodel, config.model.video_encoder_cls)
            model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
            model_config.max_video_len = config.dataset.max_video_len
            # TODO: make each model a set of config class.
            if hasattr(model_config, "num_layers"):
                model_config.num_layers = config.model.num_hidden_video_layers
            else:
                model_config.num_hidden_layers = config.model.num_hidden_video_layers
            self.video_encoder = video_encoder_cls.from_pretrained(
                config.dataset.bert_name, config=model_config)
            # the text encoder is the unmodified Hugging Face NLP model.
            text_encoder_cls = getattr(transformermodel, config.model.text_encoder_cls)
            self.text_encoder = text_encoder_cls.from_pretrained(
                config.dataset.bert_name)
        else:
            raise ValueError(
                "the encoder must be either a single MM encoder "
                "or separate video/text backbones.")

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        **kwargs
    ):
        raise NotImplementedError(
            "Please derive MMFusion module."
        )

    def _mm_on_the_fly(
        self,
        cmasks,
        vmasks,
        attention_mask
    ):
        """helper function for mask, seg_ids and token_type_ids."""
        if attention_mask is None:
            attention_mask = self._mm_attention_mask(cmasks, vmasks)

        """
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        """
        token_type_ids = torch.cat(
            [
                torch.zeros(
                    (vmasks.size(0), vmasks.size(1) + 2),
                    dtype=torch.long,
                    device=vmasks.device,
                ),
                torch.ones(
                    (cmasks.size(0), cmasks.size(1) - 2),
                    dtype=torch.long,
                    device=cmasks.device,
                ),
            ],
            dim=1,
        )
        return attention_mask, token_type_ids

    def _mm_attention_mask(self, cmasks, vmasks):
        assert cmasks.size(0) == vmasks.size(0), "{}, {}, {}, {}".format(
            str(cmasks.size()),
            str(vmasks.size()),
            str(cmasks.size(0)),
            str(vmasks.size(0)),
        )

        mm_mask = torch.cat([cmasks[:, :1], vmasks, cmasks[:, 1:]], dim=1)
        if self.last_iso_layer == 0:
            # plain (batch, seq) padding mask; every layer is MM-fused.
            return mm_mask
        else:
            # per-layer (batch, layers, seq, seq) mask: layers
            # [0, last_iso_layer) use the isolated mask; layers
            # [last_iso_layer, num_hidden_layers) use the fused mask.
            batch_size = cmasks.size(0)
            iso_mask = self._make_iso_mask(batch_size, cmasks, vmasks)
            mm_mask = mm_mask[:, None, :].repeat(1, mm_mask.size(-1), 1)
            iso_mm_masks = []
            # isolated mask for the first last_iso_layer layers.
            iso_mask = iso_mask[:, None, :, :].repeat(
                1, self.last_iso_layer, 1, 1)
            iso_mm_masks.append(iso_mask)
            if self.last_iso_layer < self.num_hidden_layers:
                mm_mask = mm_mask[:, None, :, :].repeat(
                    1, self.num_hidden_layers - self.last_iso_layer, 1, 1
                )
                iso_mm_masks.append(mm_mask)
            iso_mm_masks = torch.cat(iso_mm_masks, dim=1)
            return iso_mm_masks

    def _make_iso_mask(self, batch_size, cmasks, vmasks):
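        """Builds a (batch, seq, seq) isolated self-attention mask: the [CLS]
        row attends only to itself, video rows attend to the video tokens plus
        the first [SEP], and text rows attend to the text tokens plus the last
        [SEP].
        """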
        cls_self_mask = torch.cat(
            [
                torch.ones(
                    (batch_size, 1), dtype=torch.bool, device=cmasks.device),
                torch.zeros(
                    (batch_size, cmasks.size(1) + vmasks.size(1) - 1),
                    dtype=torch.bool, device=cmasks.device)
            ], dim=1)

        iso_video_mask = torch.cat(
            [
                # video rows do not attend to [CLS].
                torch.zeros(
                    (batch_size, 1), dtype=torch.bool, device=cmasks.device
                ),
                vmasks,
                # the first [SEP]; assumed to be a valid (1) position.
                cmasks[:, 1:2],
                # mask out the text side: cmasks.size(1) - 2 covers the text
                # tokens and the last [SEP].
                torch.zeros(
                    (batch_size, cmasks.size(1) - 2),
                    dtype=torch.bool,
                    device=cmasks.device,
                ),
            ],
            dim=1,
        )
        iso_text_mask = torch.cat(
            [
                torch.zeros(
                    (batch_size, 2 + vmasks.size(1)),
                    dtype=torch.bool,
                    device=cmasks.device,
                ),  # mask out [CLS], the first [SEP] and all video positions.
                cmasks[:, 2:],  # text tokens + the last [SEP].
            ],
            dim=1,
        )
        cls_self_mask = cls_self_mask[:, None, :]
        iso_video_mask = iso_video_mask[:, None, :].repeat(
            1, vmasks.size(1) + 1, 1)
        iso_text_mask = iso_text_mask[:, None, :].repeat(
            1, cmasks.size(1) - 2, 1)
        return torch.cat([cls_self_mask, iso_video_mask, iso_text_mask], dim=1)

    def _pooling_vt_layer(
        self,
        layered_sequence_output,
        cmasks,
        vmasks
    ):
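        """Masked mean-pools the video span (video tokens + first [SEP]) and
        the text span (text tokens + last [SEP]) from the hidden states of the
        last isolated layer, or of the final layer when no isolation is used.
        """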
        layer_idx = self.last_iso_layer \
                if self.last_iso_layer > 0 else self.num_hidden_layers
        hidden_state = layered_sequence_output[layer_idx]
        batch_size = cmasks.size(0)
        # pool each modality separately.
        text_offset = vmasks.size(1) + 2  # skip [CLS] + video + first [SEP].
        # video span: video tokens + the first [SEP].
        video_outputs = hidden_state[:, 1:text_offset]
        video_attention_mask = torch.cat(
            [
                vmasks,
                torch.ones(
                    (batch_size, 1), dtype=torch.bool, device=vmasks.device),
            ],
            dim=1,
        )
        assert video_outputs.size(1) == video_attention_mask.size(1)
        pooled_video = torch.sum(
            video_outputs * video_attention_mask.unsqueeze(-1), dim=1
        ) / video_attention_mask.sum(1, keepdim=True)

        # text tokens + [SEP]
        text_attention_mask = cmasks[:, 2:]
        text_outputs = hidden_state[:, text_offset:]
        assert text_outputs.size(1) == text_attention_mask.size(1)
        pooled_text = torch.sum(
            text_outputs * text_attention_mask.unsqueeze(-1), dim=1
        ) / text_attention_mask.sum(1, keepdim=True)
        return pooled_video, pooled_text


class MMFusionMFMMLM(MMFusion):
    """Forward wrapper for masked frame modeling (MFM) and masked language
    modeling (MLM) pre-training."""
    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        attention_mask=None,
        video_label=None,
        text_label=None,
        **kwargs
    ):
        output_hidden_states = not self.is_train

        target_vfeats, non_masked_frame_mask = None, None
        if video_label is not None:
            target_vfeats = vfeats.masked_select(
                video_label.unsqueeze(-1)).view(
                -1, vfeats.size(-1)
            )
            # mask video token.
            vfeats[video_label] = 0.0
            non_masked_frame_mask = vmasks.clone()
            non_masked_frame_mask[video_label] = False

        attention_mask, token_type_ids = self._mm_on_the_fly(
            cmasks, vmasks, attention_mask)

        outputs = self.mm_encoder(
            input_ids=caps,
            input_video_embeds=vfeats,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            masked_frame_labels=video_label,
            target_video_hidden_states=target_vfeats,
            non_masked_frame_mask=non_masked_frame_mask,
            masked_lm_labels=text_label,
            output_hidden_states=output_hidden_states,
        )
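        # outputs: (video_logits, text_logits), followed by the tuple of
        # per-layer hidden states when output_hidden_states=True.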

        video_logits, text_logits = outputs[0], outputs[1]

        if self.is_train:  # return earlier for training.
            return {
                "video_logits": video_logits,
                "text_logits": text_logits,
            }

        pooled_video, pooled_text = self._pooling_vt_layer(
            outputs[2], cmasks, vmasks)
        return {"pooled_video": pooled_video, "pooled_text": pooled_text}


class MMFusionMTM(MMFusionMFMMLM):
    def __init__(self, config, **kwargs):
        """For reproducibility: the mm_encoder built by the parent class is
        initialized first and then replaced by an MMBertForMTM encoder below.
        """
        super().__init__(config)
        from .transformermodel import MMBertForMTM
        model_config = AutoConfig.from_pretrained(config.dataset.bert_name)
        model_config.max_video_len = config.dataset.max_video_len
        model_config.use_seg_emb = config.model.use_seg_emb
        self.mm_encoder = MMBertForMTM.from_pretrained(
            config.dataset.bert_name, config=model_config)


class MMFusionShare(MMFusion):
    """A retrieval wrapper that uses a single shared mm_encoder as both the
    video and the text backbone.
    TODO: move formally.
    """
    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        attention_mask=None,
        video_label=None,
        text_label=None,
        output_hidden_states=False,
        **kwargs
    ):
        pooled_video = self.forward_video(
            vfeats,
            vmasks,
            caps,
            cmasks,
            output_hidden_states
        )

        pooled_text = self.forward_text(
            caps,
            cmasks,
            output_hidden_states
        )

        return {"pooled_video": pooled_video, "pooled_text": pooled_text}

    def forward_video(
        self,
        vfeats,
        vmasks,
        caps,
        cmasks,
        output_hidden_states=False,
        **kwargs
    ):
        # only [CLS] + [SEP]; the caption text is not used for the video pass.
        input_ids = caps[:, :2]

        attention_mask = torch.cat([
            cmasks[:, :1],
            vmasks,
            cmasks[:, 1:2]
        ], dim=1)

        token_type_ids = torch.zeros(
            (vmasks.size(0), vmasks.size(1) + 2),
            dtype=torch.long,
            device=vmasks.device)
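        # token_type_ids: the whole [CLS] + video + [SEP] sequence is segment
        # 0; the mm_encoder is expected to splice input_video_embeds in after
        # the first token, matching this layout.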

        outputs = self.mm_encoder(
            input_ids=input_ids,
            input_video_embeds=vfeats,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True
        )
        video_outputs = outputs[0]

        if output_hidden_states:
            return video_outputs

        batch_size = cmasks.size(0)

        video_attention_mask = torch.cat(
            [
                torch.zeros(
                    (batch_size, 1), dtype=torch.bool, device=vmasks.device),
                vmasks,
                torch.ones(
                    (batch_size, 1), dtype=torch.bool, device=vmasks.device),
            ],
            dim=1,
        )
        assert video_outputs.size(1) == video_attention_mask.size(1)

        # normalize the boolean mask into per-position weights so the bmm
        # below computes a masked mean over the video tokens + trailing [SEP].
        video_attention_mask = video_attention_mask.type(video_outputs.dtype) \
            / video_attention_mask.sum(1, keepdim=True)

        pooled_video = torch.bmm(
            video_outputs.transpose(2, 1),
            video_attention_mask.unsqueeze(2)
        ).squeeze(-1)
        return pooled_video  # video_outputs

    def forward_text(
        self,
        caps,
        cmasks,
        output_hidden_states=False,
        **kwargs
    ):
        input_ids = torch.cat([
            caps[:, :1], caps[:, 2:],
            ], dim=1)

        attention_mask = torch.cat([
            cmasks[:, :1],
            cmasks[:, 2:]
        ], dim=1)

        token_type_ids = torch.cat([
            torch.zeros(
                (cmasks.size(0), 1),
                dtype=torch.long,
                device=cmasks.device),
            torch.ones(
                (cmasks.size(0), cmasks.size(1) - 2),
                dtype=torch.long,
                device=cmasks.device)
            ], dim=1)

        outputs = self.mm_encoder(
            input_ids=input_ids,
            input_video_embeds=None,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True
        )
        text_outputs = outputs[0]

        if output_hidden_states:
            return text_outputs

        batch_size = caps.size(0)
        # text tokens + [SEP]
        text_attention_mask = torch.cat([
            torch.zeros(
                (batch_size, 1), dtype=torch.bool, device=cmasks.device),
            cmasks[:, 2:]
        ], dim=1)

        assert text_outputs.size(1) == text_attention_mask.size(1)

        text_attention_mask = text_attention_mask.type(text_outputs.dtype) \
            / text_attention_mask.sum(1, keepdim=True)

        pooled_text = torch.bmm(
            text_outputs.transpose(2, 1),
            text_attention_mask.unsqueeze(2)
        ).squeeze(-1)
        return pooled_text  # text_outputs


class MMFusionSeparate(MMFusionShare):
    def forward_video(
        self,
        vfeats,
        vmasks,
        caps,
        cmasks,
        output_hidden_states=False,
        **kwargs
    ):
        input_ids = caps[:, :2]

        attention_mask = torch.cat([
            cmasks[:, :1],
            vmasks,
            cmasks[:, 1:2]
        ], dim=1)

        token_type_ids = torch.zeros(
            (vmasks.size(0), vmasks.size(1) + 2),
            dtype=torch.long,
            device=vmasks.device)

        outputs = self.video_encoder(
            input_ids=input_ids,
            input_video_embeds=vfeats,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True
        )
        video_outputs = outputs[0]

        if output_hidden_states:
            return video_outputs

        batch_size = cmasks.size(0)

        video_attention_mask = torch.cat(
            [
                torch.zeros(
                    (batch_size, 1), dtype=torch.bool, device=vmasks.device),
                vmasks,
                torch.ones(
                    (batch_size, 1), dtype=torch.bool, device=vmasks.device),
            ],
            dim=1,
        )
        assert video_outputs.size(1) == video_attention_mask.size(1)

        video_attention_mask = video_attention_mask.type(video_outputs.dtype) \
            / video_attention_mask.sum(1, keepdim=True)

        pooled_video = torch.bmm(
            video_outputs.transpose(2, 1),
            video_attention_mask.unsqueeze(2)
        ).squeeze(-1)
        return pooled_video  # video_outputs

    def forward_text(
        self,
        caps,
        cmasks,
        output_hidden_states=False,
        **kwargs
    ):
        input_ids = torch.cat([
            caps[:, :1], caps[:, 2:],
            ], dim=1)

        attention_mask = torch.cat([
            cmasks[:, :1],
            cmasks[:, 2:]
        ], dim=1)
        # unlike the shared encoder, use all-zero token_type_ids for text.
        token_type_ids = torch.zeros(
            (cmasks.size(0), cmasks.size(1) - 1),
            dtype=torch.long,
            device=cmasks.device)

        outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True
        )
        text_outputs = outputs[0]

        if output_hidden_states:
            return text_outputs

        batch_size = caps.size(0)
        # text tokens + [SEP]
        text_attention_mask = torch.cat([
            torch.zeros(
                (batch_size, 1), dtype=torch.bool, device=cmasks.device),
            cmasks[:, 2:]
        ], dim=1)

        assert text_outputs.size(1) == text_attention_mask.size(1)

        text_attention_mask = text_attention_mask.type(text_outputs.dtype) \
            / text_attention_mask.sum(1, keepdim=True)

        pooled_text = torch.bmm(
            text_outputs.transpose(2, 1),
            text_attention_mask.unsqueeze(2)
        ).squeeze(-1)
        return pooled_text  # text_outputs


class MMFusionJoint(MMFusion):
    """Fine-tuning wrapper for the retrieval task."""

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        attention_mask=None,
        video_label=None,
        text_label=None,
        **kwargs
    ):
        # TODO (huxu): other ways to do negative examples; move the following
        # into your criterion forward.
        output_hidden_states = True

        attention_mask, token_type_ids = self._mm_on_the_fly(
            cmasks, vmasks, attention_mask)

        separate_forward_split = (
            None if self.is_train else vmasks.size(1) + 2
        )  # at inference, split the forward after [CLS] + video + [SEP].

        outputs = self.mm_encoder(
            input_ids=caps,
            input_video_embeds=vfeats,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=output_hidden_states,
            separate_forward_split=separate_forward_split,
        )

        pooled_video, pooled_text = self._pooling_vt_layer(
            outputs[2], cmasks, vmasks)
        return {"pooled_video": pooled_video, "pooled_text": pooled_text}


class MMFusionActionSegmentation(MMFusion):
    """Fine-tuning wrapper for action segmentation.
    TODO: rename this for VLM.
    """
    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        attention_mask=None,
        **kwargs
    ):
        # action segmentation assumes batch_size=1; flatten the clip dimension.
        caps = caps.view(-1, caps.size(-1))
        cmasks = cmasks.view(-1, cmasks.size(-1))
        vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
        vmasks = vmasks.view(-1, vmasks.size(-1))

        # this may not cover all shapes of attention_mask.
        attention_mask = attention_mask.view(
            -1, attention_mask.size(2), attention_mask.size(3)) \
            if attention_mask is not None else None

        # TODO (huxu): other ways to do negative examples; move the following
        # into your criterion forward.
        output_hidden_states = True

        # build the joint attention mask and token_type_ids on the fly.
        attention_mask, token_type_ids = self._mm_on_the_fly(
            cmasks, vmasks, attention_mask)

        logits = self.mm_encoder(
            input_ids=caps,
            input_video_embeds=vfeats,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=output_hidden_states,
        )
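        # keep only the logits at the video-token positions (drop [CLS] at
        # index 0 and the text side of the sequence).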
        return {"logits": logits[0][:, 1:vmasks.size(1)+1]}


class MMFusionActionLocalization(MMFusion):
    """Fine-tuning wrapper for action localization."""

    def __init__(self, config, **kwargs):
        super().__init__(config)
        tokenizer = AutoTokenizer.from_pretrained(
            config.dataset.bert_name)
        self.cls_token_id = tokenizer.cls_token_id
        self.sep_token_id = tokenizer.sep_token_id
        self.pad_token_id = tokenizer.pad_token_id

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        attention_mask=None,
        **kwargs
    ):
        # action localization assumes batch_size=1; squeeze the batch dimension.
        caps = caps.squeeze(0)
        cmasks = cmasks.squeeze(0)
        vfeats = vfeats.squeeze(0)
        vmasks = vmasks.squeeze(0)
        attention_mask = attention_mask.squeeze(0) if attention_mask is not None else None

        # TODO (huxu): other ways to do negative examples; move the following
        # into your criterion forward.
        output_hidden_states = True

        # a length-1 dummy video segment for the text-only forward pass.
        dummy_vfeats = torch.zeros(
            (caps.size(0), 1, vfeats.size(-1)), device=vfeats.device, dtype=vfeats.dtype)
        dummy_vmasks = torch.ones(
            (caps.size(0), 1), dtype=torch.bool,
            device=vfeats.device)

        dummy_caps = torch.LongTensor(
            [[self.cls_token_id, self.sep_token_id,
              self.pad_token_id, self.sep_token_id]],
            ).to(caps.device).repeat(vfeats.size(0), 1)
        dummy_cmasks = torch.BoolTensor(
            [[0, 1, 0, 1]]  # only the two [SEP]s are valid; [CLS]/[PAD] masked.
            ).to(caps.device).repeat(vfeats.size(0), 1)
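        # the dummy caption/mask pair lets the MM encoder run a video-only
        # forward pass while keeping the [CLS] ... [SEP] structure it expects.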

        # video forward pass with the dummy caption; rebuild the attention
        # mask from the dummy text masks.
        attention_mask, token_type_ids = self._mm_on_the_fly(
            dummy_cmasks, vmasks, None)

        outputs = self.mm_encoder(
            input_ids=dummy_caps,
            input_video_embeds=vfeats,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=output_hidden_states,
        )

        layer_idx = self.last_iso_layer \
                if self.last_iso_layer > 0 else self.num_hidden_layers

        video_seq = outputs[2][layer_idx][:, 1:vmasks.size(1)+1].masked_select(
                vmasks.unsqueeze(-1)
            ).view(-1, self.hidden_size)

        # text forward pass: use the real caption with a dummy video segment.
        attention_mask, token_type_ids = self._mm_on_the_fly(
            cmasks, dummy_vmasks, None)

        outputs = self.mm_encoder(
            input_ids=caps,
            input_video_embeds=dummy_vfeats,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=output_hidden_states,
        )

        _, pooled_text = self._pooling_vt_layer(
            outputs[2], cmasks, dummy_vmasks)
        # FIXME: this similarity computation is not right yet.
        logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
        return {"logits": logits}


# --------------- MMFusionSeparate for end tasks ---------------

class MMFusionSeparateActionSegmentation(MMFusionSeparate):
    """Fine-tuning wrapper for action segmentation."""
    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        attention_mask=None,
        **kwargs
    ):
        # action segmentation assumes batch_size=1; flatten the clip dimension.
        caps = caps.view(-1, caps.size(-1))
        cmasks = cmasks.view(-1, cmasks.size(-1))
        vfeats = vfeats.view(-1, vfeats.size(2), vfeats.size(3))
        vmasks = vmasks.view(-1, vmasks.size(-1))
        logits = self.forward_video(
            vfeats,
            vmasks,
            caps,
            cmasks,
            output_hidden_states=True
        )
        return {"logits": logits[:, 1:vmasks.size(1)+1]}


class MMFusionSeparateActionLocalization(MMFusionSeparate):
    def __init__(self, config, **kwargs):
        super().__init__(config)
        tokenizer = AutoTokenizer.from_pretrained(
            config.dataset.bert_name)
        self.cls_token_id = tokenizer.cls_token_id
        self.sep_token_id = tokenizer.sep_token_id
        self.pad_token_id = tokenizer.pad_token_id

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        **kwargs
    ):
        # action localization assumes batch_size=1; squeeze the batch dimension.
        caps = caps.squeeze(0)
        cmasks = cmasks.squeeze(0)
        vfeats = vfeats.squeeze(0)
        vmasks = vmasks.squeeze(0)

        # TODO (huxu): other ways to do negative examples; move the following
        # into your criterion forward.
        dummy_caps = torch.LongTensor(
            [[self.cls_token_id, self.sep_token_id,
              self.pad_token_id, self.sep_token_id]],
            ).to(caps.device).repeat(vfeats.size(0), 1)
        dummy_cmasks = torch.BoolTensor(
            [[0, 1, 0, 1]]  # only the two [SEP]s are valid; [CLS]/[PAD] masked.
            ).to(caps.device).repeat(vfeats.size(0), 1)

        outputs = self.forward_video(
            vfeats,
            vmasks,
            dummy_caps,
            dummy_cmasks,
            output_hidden_states=True
        )

        video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select(
                vmasks.unsqueeze(-1)
            ).view(-1, self.hidden_size)

        pooled_text = self.forward_text(
            caps,
            cmasks,
            output_hidden_states=False
        )

        # FIXME: this similarity computation is not right yet.
        logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
        return {"logits": logits}


class MMFusionShareActionLocalization(MMFusionShare):
    def __init__(self, config, **kwargs):
        super().__init__(config)
        tokenizer = AutoTokenizer.from_pretrained(
            config.dataset.bert_name)
        self.cls_token_id = tokenizer.cls_token_id
        self.sep_token_id = tokenizer.sep_token_id
        self.pad_token_id = tokenizer.pad_token_id

    def forward(
        self,
        caps,
        cmasks,
        vfeats,
        vmasks,
        **kwargs
    ):
        # action localization assumes batch_size=1; squeeze the batch dimension.
        caps = caps.squeeze(0)
        cmasks = cmasks.squeeze(0)
        vfeats = vfeats.squeeze(0)
        vmasks = vmasks.squeeze(0)

        # TODO (huxu): other ways to do negative examples; move the following
        # into your criterion forward.
        dummy_caps = torch.LongTensor(
            [[self.cls_token_id, self.sep_token_id,
              self.pad_token_id, self.sep_token_id]],
            ).to(caps.device).repeat(vfeats.size(0), 1)
        dummy_cmasks = torch.BoolTensor(
            [[0, 1, 0, 1]]  # only the two [SEP]s are valid; [CLS]/[PAD] masked.
            ).to(caps.device).repeat(vfeats.size(0), 1)

        outputs = self.forward_video(
            vfeats,
            vmasks,
            dummy_caps,
            dummy_cmasks,
            output_hidden_states=True
        )

        video_seq = outputs[:, 1:vmasks.size(1)+1].masked_select(
                vmasks.unsqueeze(-1)
            ).view(-1, self.hidden_size)

        pooled_text = self.forward_text(
            caps,
            cmasks,
            output_hidden_states=False
        )

        # FIXME: this similarity computation is not right yet.
        logits = torch.mm(video_seq, pooled_text.transpose(1, 0))
        return {"logits": logits}
