"""An implementation of ArcI Model."""
import typing

import torch
import torch.nn as nn

from matchzoo.engine.param_table import ParamTable
from matchzoo.engine.base_callback import BaseCallback
from matchzoo.engine.param import Param
from matchzoo.engine.base_model import BaseModel
from matchzoo.engine import hyper_spaces
from matchzoo.dataloader import callbacks
from matchzoo.utils import parse_activation


class ArcI(BaseModel):
    """
    ArcI Model.

    Examples:
        >>> model = ArcI()
        >>> model.params['left_filters'] = [32]
        >>> model.params['right_filters'] = [32]
        >>> model.params['left_kernel_sizes'] = [3]
        >>> model.params['right_kernel_sizes'] = [3]
        >>> model.params['left_pool_sizes'] = [2]
        >>> model.params['right_pool_sizes'] = [4]
        >>> model.params['conv_activation_func'] = 'relu'
        >>> model.params['mlp_num_layers'] = 1
        >>> model.params['mlp_num_units'] = 64
        >>> model.params['mlp_num_fan_out'] = 32
        >>> model.params['mlp_activation_func'] = 'relu'
        >>> model.params['dropout_rate'] = 0.5
        >>> model.guess_and_fill_missing_params(verbose=0)
        >>> model.build()

    """

    @classmethod
    def get_default_params(cls) -> ParamTable:
        """:return: model default parameters."""
        params = super().get_default_params(
            with_embedding=True,
            with_multi_layer_perceptron=True
        )
        params.add(Param(name='left_length', value=10,
                         desc='Length of left input.'))
        params.add(Param(name='right_length', value=100,
                         desc='Length of right input.'))
        params.add(Param(name='conv_activation_func', value='relu',
                         desc="The activation function in the "
                         "convolution layer."))
        params.add(Param(name='left_filters', value=[32],
                         desc="The filter size of each convolution "
                         "blocks for the left input."))
        params.add(Param(name='left_kernel_sizes', value=[3],
                         desc="The kernel size of each convolution "
                         "blocks for the left input."))
        params.add(Param(name='left_pool_sizes', value=[2],
                         desc="The pooling size of each convolution "
                         "blocks for the left input."))
        params.add(Param(name='right_filters', value=[32],
                         desc="The filter size of each convolution "
                         "blocks for the right input."))
        params.add(Param(name='right_kernel_sizes', value=[3],
                         desc="The kernel size of each convolution "
                         "blocks for the right input."))
        params.add(Param(name='right_pool_sizes', value=[2],
                         desc="The pooling size of each convolution "
                         "blocks for the right input."))
        params.add(Param(
            'dropout_rate', 0.0,
            hyper_space=hyper_spaces.quniform(
                low=0.0, high=0.8, q=0.01),
            desc="The dropout rate."
        ))
        return params

    @classmethod
    def get_default_padding_callback(
        cls,
        fixed_length_left: int = 10,
        fixed_length_right: int = 100,
        pad_word_value: typing.Union[int, str] = 0,
        pad_word_mode: str = 'pre',
        with_ngram: bool = False,
        fixed_ngram_length: int = None,
        pad_ngram_value: typing.Union[int, str] = 0,
        pad_ngram_mode: str = 'pre'
    ) -> BaseCallback:
        """
        Model default padding callback.

        The padding callback's on_batch_unpacked would pad a batch of data to
        a fixed length.

        :return: Default padding callback.
        """
        return callbacks.BasicPadding(
            fixed_length_left=fixed_length_left,
            fixed_length_right=fixed_length_right,
            pad_word_value=pad_word_value,
            pad_word_mode=pad_word_mode,
            with_ngram=with_ngram,
            fixed_ngram_length=fixed_ngram_length,
            pad_ngram_value=pad_ngram_value,
            pad_ngram_mode=pad_ngram_mode
        )

    def build(self):
        """
        Build model structure.

        ArcI use Siamese arthitecture.
        """
        self.embedding = self._make_default_embedding_layer()

        # Build conv
        activation = parse_activation(self._params['conv_activation_func'])
        left_in_channels = [
            self._params['embedding_output_dim'],
            *self._params['left_filters'][:-1]
        ]
        right_in_channels = [
            self._params['embedding_output_dim'],
            *self._params['right_filters'][:-1]
        ]
        conv_left = [
            self._make_conv_pool_block(ic, oc, ks, activation, ps)
            for ic, oc, ks, ps in zip(left_in_channels,
                                      self._params['left_filters'],
                                      self._params['left_kernel_sizes'],
                                      self._params['left_pool_sizes'])
        ]
        conv_right = [
            self._make_conv_pool_block(ic, oc, ks, activation, ps)
            for ic, oc, ks, ps in zip(right_in_channels,
                                      self._params['right_filters'],
                                      self._params['right_kernel_sizes'],
                                      self._params['right_pool_sizes'])
        ]
        self.conv_left = nn.Sequential(*conv_left)
        self.conv_right = nn.Sequential(*conv_right)

        self.dropout = nn.Dropout(p=self._params['dropout_rate'])

        left_length = self._params['left_length']
        right_length = self._params['right_length']
        for ps in self._params['left_pool_sizes']:
            left_length = left_length // ps
        for ps in self._params['right_pool_sizes']:
            right_length = right_length // ps
        self.mlp = self._make_multi_layer_perceptron_layer(
            left_length * self._params['left_filters'][-1] + (
                right_length * self._params['right_filters'][-1])
        )

        self.out = self._make_output_layer(
            self._params['mlp_num_fan_out']
        )

    def forward(self, inputs):
        """Forward."""

        # Scalar dimensions referenced here:
        #   B = batch size (number of sequences)
        #   D = embedding size
        #   L = `input_left` sequence length
        #   R = `input_right` sequence length
        #   F = number of filters
        #   P = pool size

        # Left input and right input.
        # shape = [B, L]
        # shape = [B, R]
        input_left, input_right = inputs['text_left'], inputs['text_right']

        # Process left and right input.
        # shape = [B, D, L]
        # shape = [B, D, R]
        embed_left = self.embedding(input_left.long()).transpose(1, 2)
        embed_right = self.embedding(input_right.long()).transpose(1, 2)

        # Convolution
        # shape = [B, F, L // P]
        # shape = [B, F, R // P]
        conv_left = self.conv_left(embed_left)
        conv_right = self.conv_right(embed_right)

        # shape = [B, F * (L // P)]
        # shape = [B, F * (R // P)]
        rep_left = torch.flatten(conv_left, start_dim=1)
        rep_right = torch.flatten(conv_right, start_dim=1)

        # shape = [B, F * (L // P) + F * (R // P)]
        concat = self.dropout(torch.cat((rep_left, rep_right), dim=1))

        # shape = [B, *]
        dense_output = self.mlp(concat)

        out = self.out(dense_output)
        return out

    @classmethod
    def _make_conv_pool_block(
        cls,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        activation: nn.Module,
        pool_size: int,
    ) -> nn.Module:
        """Make conv pool block."""
        return nn.Sequential(
            nn.ConstantPad1d((0, kernel_size - 1), 0),
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size
            ),
            activation,
            nn.MaxPool1d(kernel_size=pool_size)
        )