#!/usr/bin/env python3

import random
import unittest

import numpy as np
import pytorch_translate.attention.multihead_attention as multihead_attention
import torch
from pytorch_translate.attention import attention_utils, dot_attention, mlp_attention


class TestAttention(unittest.TestCase):
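    """Tests for masked_softmax, DotAttention, and MLPAttention."""
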
    def setUp(self):
        self.bsz = 10
        self.src_len = 5
        self.ctx_dim = 3
        self.dec_dim = 4
        self.att_dim = 2

    def test_masked_softmax(self):
        scores = torch.rand(20, 20)
        lengths = torch.arange(start=1, end=21)

        masked_normalized_scores = attention_utils.masked_softmax(
            scores, lengths, src_length_masking=True
        )

        for i in range(20):
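            # Each row of masked, normalized scores should still sum to 1.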
            scores_sum = masked_normalized_scores[i].numpy().sum()
            self.assertAlmostEqual(scores_sum, 1, places=6)

    def _test_attention(self, attention):
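        """Smoke test: forward pass with random inputs and cycling source lengths."""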
        dummy_source_hids = torch.rand(self.src_len, self.bsz, self.ctx_dim)
        dummy_decoder_state = torch.rand(self.bsz, self.dec_dim)
        dummy_src_lengths = torch.fmod(torch.arange(self.bsz), self.src_len) + 1
        attention(dummy_decoder_state, dummy_source_hids, dummy_src_lengths)

    def test_dot_attention(self):
        self._test_attention(
            dot_attention.DotAttention(
                self.dec_dim,
                self.ctx_dim,
                src_length_masking=True,
                force_projection=True,
            )
        )

    def test_mlp_attention(self):
        self._test_attention(
            mlp_attention.MLPAttention(
                self.dec_dim,
                self.ctx_dim,
                src_length_masking=True,
                attention_dim=self.att_dim,
            )
        )


def _softmax(x):
    """Numpy-based reference softmax over the last axis of a 4-D array."""
    output = np.zeros(x.shape, dtype=np.float32)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            for k in range(x.shape[2]):
                x_curr = x[i, j, k, :]
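                # Subtract the max for numerical stability before exponentiating.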
                e_x = np.exp(x_curr - np.amax(x_curr))
                output[i, j, k, :] = e_x / np.sum(e_x)
    return output


def _batchmatmul(a, b):
    """Numpy-based batch matrix multiply over the last two axes of 4-D arrays."""
    assert a.shape[0] == b.shape[0]
    assert a.shape[1] == b.shape[1]
    retval = np.zeros(
        (a.shape[0], a.shape[1], a.shape[2], b.shape[3]), dtype=np.float32
    )
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            retval[i, j, :, :] = np.matmul(a[i, j, :, :], b[i, j, :, :])
    return retval


class MultiheadAttentionTest(unittest.TestCase):
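    """Tests MultiheadAttention against a numpy reference implementation."""
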
    def _scaled_dot_attn_ref(self, Q, K, V, dims, unseen_mask=False, src_lengths=None):
        """ Numpy-based reference implementation of scaled dot attention
        for testing"""
        QKT = _batchmatmul(
            Q,
            np.transpose(K, axes=[0, 1, 3, 2])
            / np.sqrt(dims[3], dtype=np.float32),  # divide by sqrt(d_head)
        )
        if unseen_mask or src_lengths is not None:
            b1, b2, s1, s2 = QKT.shape
            # Set masked scores to -inf: future key positions (unseen_mask) and
            # source positions at or beyond each example's src_length.
            for i in range(b1):
                for j in range(b2):
                    for m in range(s1):
                        for n in range(s2):
                            if unseen_mask and n > m:
                                QKT[i, j, m, n] = -np.inf
                            if src_lengths is not None and n >= src_lengths[i]:
                                QKT[i, j, m, n] = -np.inf
        reference = _softmax(QKT)
        reference = _batchmatmul(reference, V)
        return reference

    def _generate_src_lengths(self, batch_size, seq_len):
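        """Sample a source length in [1, seq_len] for each example in the batch."""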
        src_lengths = np.array([random.randint(1, seq_len) for _ in range(batch_size)])

        # max source length has to equal seq_len, so randomly choose
        # one example to have source length = seq_len
        max_len_example_i = random.randint(0, batch_size - 1)
        src_lengths[max_len_example_i] = seq_len

        src_lengths_tensor = torch.from_numpy(src_lengths).int()
        return src_lengths, src_lengths_tensor

    def _split_heads_ref(self, X, dims, nheads, d_head):
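        """Reshape (batch, seq, nheads * d_head) into (batch, nheads, seq, d_head)."""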
        X_split = np.reshape(X, dims[:2] + [nheads, d_head])
        X_split_transposed = np.transpose(X_split, [0, 2, 1, 3])
        reference = np.reshape(X_split_transposed, [dims[0], nheads, dims[1], d_head])
        return reference

    def _combine_heads_ref(self, X, dims, nheads, d_head):
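        """Inverse of _split_heads_ref: back to (batch, seq, nheads * d_head)."""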
        X_transposed = np.transpose(X, [0, 2, 1, 3])
        reference = np.reshape(X_transposed, dims[:2] + [nheads * d_head])
        return reference

    def _fc(self, X, X_name, module):
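        """Numpy reference for the linear projection named X_name in module.

        Finds the parameters "<X_name>.weight"/".bias" (or the
        "_weight"/"_bias" variants) and returns X @ W.T + b.
        """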
        X_fc_b = None
        X_fc_w = None
        for name, param in module.named_parameters():
            if X_name + ".weight" in name or X_name + "_weight" in name:
                if X_fc_w is not None:
                    raise Exception(f"Duplicate FC name {name} found")
                X_fc_w = param.detach().numpy()
            elif X_name + ".bias" in name or X_name + "_bias" in name:
                if X_fc_b is not None:
                    raise Exception(f"Duplicate FC name {name} found")
                X_fc_b = param.detach().numpy()
        return np.matmul(X, np.transpose(X_fc_w)) + X_fc_b

    def _multihead_attn_test_helper(self, use_src_lengths):
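        """Compare MultiheadAttention output against the numpy reference
        implementation over randomized batch sizes, lengths, and head counts."""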
        for _ in range(100):
            batch_sz, seq_len = [random.randint(2, 10) for _ in range(2)]
            d_head = random.randint(3, 10)
            nheads = random.randint(3, 10)
            d_model = d_head * nheads
            dims = [batch_sz, seq_len, d_model]

            src_lengths = None
            src_lengths_tensor = None
            if use_src_lengths:
                src_lengths, src_lengths_tensor = self._generate_src_lengths(
                    batch_size=batch_sz, seq_len=seq_len
                )

            decoder_state = np.random.rand(batch_sz, d_model).astype(np.float32)
            K = np.random.rand(*dims).astype(np.float32)
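            # K serves as both keys and values; the query is the decoder state
            # with an added length-1 time dimension.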
            V = K
            Q = np.expand_dims(decoder_state, 1)

            decoder_state_tensor = torch.from_numpy(decoder_state).float()
            source_hid_tensor = torch.from_numpy(K).float().transpose(0, 1)

            multihead_attn_module = multihead_attention.MultiheadAttention(
                context_dim=d_model, decoder_hidden_state_dim=d_model, nheads=nheads
            )

            for squeeze in [True, False]:
                result = (
                    multihead_attn_module(
                        decoder_state=decoder_state_tensor,
                        source_hids=source_hid_tensor,
                        src_lengths=src_lengths_tensor,
                        squeeze=squeeze,
                    )[0]
                    .detach()
                    .numpy()
                )

                if not squeeze:
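                    # Without squeeze, the output keeps a leading length-1 time
                    # dimension; drop it so both paths compare to the same reference.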
                    self.assertEqual(result.ndim, 3)
                    result = np.squeeze(result, axis=0)

                Q_fc = self._fc(Q, "q_proj", multihead_attn_module, end=d_model)
                K_fc = self._fc(
                    K, "k_proj", multihead_attn_module, start=d_model, end=2 * d_model
                )
                V_fc = self._fc(V, "v_proj", multihead_attn_module, start=2 * d_model)

                Q_split = self._split_heads_ref(
                    Q_fc, [batch_sz, 1, d_model], nheads, d_head
                )
                K_split = self._split_heads_ref(K_fc, dims, nheads, d_head)
                V_split = self._split_heads_ref(V_fc, dims, nheads, d_head)

                attn_heads = self._scaled_dot_attn_ref(
                    Q=Q_split,
                    K=K_split,
                    V=V_split,
                    dims=Q_split.shape,
                    src_lengths=src_lengths,
                )

                combined_attn_heads = self._combine_heads_ref(
                    X=attn_heads, dims=[batch_sz, 1], nheads=nheads, d_head=d_head
                )

                reference = self._fc(
                    combined_attn_heads, "out_proj", multihead_attn_module
                )
                reference = np.squeeze(reference, axis=1)
                self.assertEqual(tuple(result.shape), (batch_sz, d_model))
                np.testing.assert_allclose(result, reference, atol=1e-5)

    def test_multihead_attn_no_masking(self):
        self._multihead_attn_test_helper(use_src_lengths=False)

    def test_multihead_attn_with_src_lengths(self):
        self._multihead_attn_test_helper(use_src_lengths=True)
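

if __name__ == "__main__":
    # Allow running this test module directly from the command line.
    unittest.main()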