import torch
from torch import nn
from src.util import load_embedding
from src.bert_embedding import BertEmbeddingPredictor


class EmbeddingRegularizer(nn.Module):
    '''Word embedding regularization (and optional embedding-fusion decoding) for ASR training.'''
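    # Constructor summary (derived from the code below): `src` is the source of
    # the pretrained word embeddings, `distance` selects the regression loss
    # ('CosEmb' or 'MSE'), `weight` is the regularization weight (stored here,
    # not applied inside this module), `fuse` controls probability fusion
    # (0 = off, -1 = scalar learnable lambda, -2 = vocab-wise learnable lambda,
    # any other value = fixed lambda), and `temperature` scales the embedding
    # logits (-1 = learnable scalar, -2 = element-wise learnable, otherwise fixed).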

    def __init__(self, tokenizer, dec_dim, enable, src, distance, weight, fuse, temperature,
                 freeze=True, fuse_normalize=False, dropout=0.0, bert=None):
        super(EmbeddingRegularizer, self).__init__()
        self.enable = enable
        if enable:
            if bert is not None:
                self.use_bert = True
                if not isinstance(bert, str):
                    raise ValueError(
                        "`bert` should be a str specifying bert config such as \"bert-base-uncased\".")
                self.emb_table = BertEmbeddingPredictor(bert, tokenizer, src)
                vocab_size, emb_dim = self.emb_table.model.bert.embeddings.word_embeddings.weight.shape
                vocab_size = vocab_size - 3  # [CLS], [SEP], [MASK] are not used
                self.dim = emb_dim
            else:
                self.use_bert = False
                pretrained_emb = torch.FloatTensor(
                    load_embedding(tokenizer, src))
                # pretrained_emb = nn.functional.normalize(pretrained_emb,dim=-1) # ToDo : Check impact on old version
                vocab_size, emb_dim = pretrained_emb.shape
                self.dim = emb_dim

                self.emb_table = nn.Embedding.from_pretrained(
                    pretrained_emb, freeze=freeze, padding_idx=0)

            self.emb_net = nn.Sequential(nn.Linear(dec_dim, (emb_dim+dec_dim)//2),
                                         nn.ReLU(),
                                         nn.Linear((emb_dim+dec_dim)//2, emb_dim))
            self.weight = weight
            self.distance = distance
            self.fuse_normalize = fuse_normalize
            if distance == 'CosEmb':
                # This may be somewhat redundant since the cosine embedding loss already includes ||x||
                self.measurement = nn.CosineEmbeddingLoss(reduction='none')
            elif distance == 'MSE':
                self.measurement = nn.MSELoss(reduction='none')
            else:
                raise NotImplementedError(
                    'Unsupported embedding distance: {}'.format(distance))

            self.apply_dropout = dropout > 0
            if self.apply_dropout:
                self.dropout = nn.Dropout(dropout)

            self.apply_fuse = fuse != 0
            if self.apply_fuse:
                # Weight for mixing emb/dec prob
                if fuse == -1:
                    # Learnable fusion
                    self.fuse_type = "learnable"
                    self.fuse_learnable = True
                    self.fuse_lambda = nn.Parameter(
                        data=torch.FloatTensor([0.5]))
                elif fuse == -2:
                    # Learnable vocab-wise fusion
                    self.fuse_type = "vocab-wise learnable"
                    self.fuse_learnable = True
                    self.fuse_lambda = nn.Parameter(
                        torch.ones((vocab_size))*0.5)
                else:
                    self.fuse_type = str(fuse)
                    self.fuse_learnable = False
                    self.register_buffer(
                        'fuse_lambda', torch.FloatTensor([fuse]))
                # Temperature applied to the embedding logits before softmax
                if temperature == -1:
                    self.temperature = 'learnable'
                    self.temp = nn.Parameter(data=torch.FloatTensor([1]))
                elif temperature == -2:
                    self.temperature = 'elementwise'
                    self.temp = nn.Parameter(torch.ones((vocab_size)))
                else:
                    self.temperature = str(temperature)
                    self.register_buffer(
                        'temp', torch.FloatTensor([temperature]))
                self.eps = 1e-8

    def create_msg(self):
        msg = ['Plugin.    | Word embedding regularization enabled (type:{}, weight:{})'.format(
            self.distance, self.weight)]
        if self.apply_fuse:
            msg.append('           | Embedding-fusion decoder enabled ( temp. = {}, lambda = {} )'.
                       format(self.temperature, self.fuse_type))
        return msg

    def get_weight(self):
        if self.fuse_learnable:
            return torch.sigmoid(self.fuse_lambda).mean().detach().cpu()
        else:
            return self.fuse_lambda

    def get_temp(self):
        return nn.functional.relu(self.temp).mean()

    def fuse_prob(self, x_emb, dec_logit):
        ''' Takes context and decoder logit to perform word embedding fusion '''
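        # Fusion computed below:
        #   P_fused = (1 - lambda) * softmax(dec_logit)
        #             + lambda * softmax(relu(temp) * x_emb @ E^T)
        # where E is the embedding table (both E and x_emb are L2-normalized
        # first when fuse_normalize is set) and lambda = sigmoid(fuse_lambda)
        # if learnable, otherwise the fixed fusion weight.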
        # Compute distribution for dec/emb
        if self.fuse_normalize:
            emb_logit = nn.functional.linear(nn.functional.normalize(x_emb, dim=-1),
                                             nn.functional.normalize(self.emb_table.weight, dim=-1))
        else:
            emb_logit = nn.functional.linear(x_emb, self.emb_table.weight)
        emb_prob = (nn.functional.relu(self.temp)*emb_logit).softmax(dim=-1)
        dec_prob = dec_logit.softmax(dim=-1)
        # Mix distribution
        if self.fuse_learnable:
            fused_prob = (1-torch.sigmoid(self.fuse_lambda))*dec_prob +\
                torch.sigmoid(self.fuse_lambda)*emb_prob
        else:
            fused_prob = (1-self.fuse_lambda)*dec_prob + \
                self.fuse_lambda*emb_prob
        # Log-prob
        log_fused_prob = (fused_prob+self.eps).log()

        return log_fused_prob

    def forward(self, dec_state, dec_logit, label=None, return_loss=True):
        log_fused_prob = None
        loss = None

        # x_emb = nn.functional.normalize(self.emb_net(dec_state), dim=-1)
        if self.apply_dropout:
            dec_state = self.dropout(dec_state)
        # Match embedding dim. (project decoder state into the embedding space)
        x_emb = self.emb_net(dec_state)

        if return_loss:
            # Compute embedding loss
            b, t = label.shape
            # Retrieve embedding
            if self.use_bert:
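                # BERT embeddings are used as fixed regression targets; no gradient flows back into BERT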
                with torch.no_grad():
                    y_emb = self.emb_table(label).contiguous()
            else:
                y_emb = self.emb_table(label)
            # Regression loss on embedding
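            # For CosEmb, a target of 1 drives the cosine similarity between
            # x_emb and y_emb toward 1 (same direction in embedding space)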
            if self.distance == 'CosEmb':
                loss = self.measurement(
                    x_emb.view(-1, self.dim), y_emb.view(-1, self.dim), torch.ones(1).to(dec_state.device))
            else:
                loss = self.measurement(
                    x_emb.view(-1, self.dim), y_emb.view(-1, self.dim))
            loss = loss.view(b, t)
            # Mask out padding
            loss = torch.where(label != 0, loss, torch.zeros_like(loss))
            loss = torch.mean(loss.sum(dim=-1) /
                              (label != 0).sum(dim=-1).float())

        if self.apply_fuse:
            log_fused_prob = self.fuse_prob(x_emb, dec_logit)

        return loss, log_fused_prob
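

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): the
# tokenizer, embedding path, dimensions and loss combination below are
# placeholder assumptions consistent with this class' signature.
#
#   reg = EmbeddingRegularizer(tokenizer, dec_dim=512, enable=True,
#                              src='path/to/word_vectors.vec',
#                              distance='CosEmb', weight=1.0,
#                              fuse=-1, temperature=1.0)
#   emb_loss, log_fused_prob = reg(dec_state, dec_logit, label=labels)
#   total_loss = asr_loss + reg.weight * emb_loss
#   # When fuse != 0, log_fused_prob can replace the decoder's own
#   # log-probabilities when computing the sequence loss / decoding.
# ---------------------------------------------------------------------------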