import math
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from attention import SimpleEncoder
from label_corr import build_concurr_matrix
from model_utils import (CNN, GCNMultiDecoder, GCNSimpleDecoder,
                         MultiSimpleDecoder, SelfAttentiveSum, SimpleDecoder,
                         sort_batch_by_length)

sys.path.insert(0, './resources')
import constant

def cosine_similarity(x1, x2=None, eps=1e-8):
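    """Pairwise cosine similarity between the rows of x1 and x2 (x2 defaults to x1)."""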
    x2 = x1 if x2 is None else x2
    w1 = x1.norm(p=2, dim=1, keepdim=True)
    w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
    return torch.mm(x1, x2.t()) / (w1 * w2.t()).clamp(min=eps)

def gelu(x):
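    """Exact GELU activation computed with the Gaussian error function."""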
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class Fusion(nn.Module):
    """docstring for Fusion"""
    def __init__(self, d_hid):
        super(Fusion, self).__init__()
        self.r = nn.Linear(d_hid*3, d_hid)
        self.g = nn.Linear(d_hid*3, d_hid)

    def forward(self, x, y):
        xy = torch.cat([x, y, x - y], dim=-1)
        r_ = gelu(self.r(xy))
        g_ = torch.sigmoid(self.g(xy))
        return g_ * r_ + (1 - g_) * x

class Model(nn.Module):
  def __init__(self, args, answer_num):
    super(Model, self).__init__()
    self.args = args
    self.output_dim = args.rnn_dim * 2
    self.mention_dropout = nn.Dropout(args.mention_dropout)
    self.input_dropout = nn.Dropout(args.input_dropout)
    self.dim_hidden = args.dim_hidden
    self.embed_dim = 300
    self.mention_dim = 300
    self.lstm_type = args.lstm_type
    self.enhanced_mention = args.enhanced_mention
    if args.enhanced_mention:
      self.head_attentive_sum = SelfAttentiveSum(self.mention_dim, 1)
      self.cnn = CNN()
      self.mention_dim += 50
    self.output_dim += self.mention_dim

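    # Experimental path (model_debug): project the mention and let it attend over context tokens.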
    if args.model_debug:
      self.mention_proj = nn.Linear(self.mention_dim, 2*args.rnn_dim)
      self.attn = nn.Linear(2*args.rnn_dim, 2*args.rnn_dim)
      self.fusion = Fusion(2*args.rnn_dim)
      self.output_dim = 2*args.rnn_dim*2

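    # Batch counter; define_loss applies the co-occurrence regularizer only once this exceeds args.regu_steps.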
    self.batch_num = 0

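    # Label co-occurrence matrix and (in)consistency masks for the auxiliary regularizer;
    # b and b_ are the bias terms of the GloVe-style loss in define_loss.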
    if args.add_regu:
      corr_matrix, _, _, mask, mask_inverse = build_concurr_matrix(goal=args.goal)
      corr_matrix -= np.identity(corr_matrix.shape[0])
      self.corr_matrix = torch.from_numpy(corr_matrix).to(torch.device('cuda')).float()
      self.incon_mask = torch.from_numpy(mask).to(torch.device('cuda')).float()
      self.con_mask = torch.from_numpy(mask_inverse).to(torch.device('cuda')).float()

      self.b = nn.Parameter(torch.rand(corr_matrix.shape[0], 1))
      self.b_ = nn.Parameter(torch.rand(corr_matrix.shape[0], 1))

    # Context encoder: attentive pooling plus either a two-LSTM or a single-LSTM setup.
    self.attentive_sum = SelfAttentiveSum(args.rnn_dim * 2, 100)
    if self.lstm_type == "two":
      self.left_lstm = nn.LSTM(self.embed_dim, 100, bidirectional=True, batch_first=True)
      self.right_lstm = nn.LSTM(self.embed_dim, 100, bidirectional=True, batch_first=True)
    elif self.lstm_type == 'single':
      self.lstm = nn.LSTM(self.embed_dim + 50, args.rnn_dim, bidirectional=True,
                          batch_first=True)
      self.token_mask = nn.Linear(4, 50)

    if args.self_attn:
      self.embed_proj = nn.Linear(self.embed_dim + 50, 2*args.rnn_dim)
      self.encoder = SimpleEncoder(2*args.rnn_dim, head=4, layer=1, dropout=0.2)

    self.loss_func = nn.BCEWithLogitsLoss()
    self.sigmoid_fn = nn.Sigmoid()
    self.goal = args.goal
    self.multitask = args.multitask

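    # Pick a decoder over the label space based on the data setup, multitask flag, and GCN flag.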
    if args.data_setup == 'joint' and args.multitask and args.gcn:
      print("Multi-task learning with gcn on labels")
      self.decoder = GCNMultiDecoder(self.output_dim)
    elif args.data_setup == 'joint' and args.multitask:
      print("Multi-task learning")
      self.decoder = MultiSimpleDecoder(self.output_dim)
    elif args.data_setup == 'joint' and not args.multitask and args.gcn:
      print("Joint training with GCN simple decoder")
      self.decoder = GCNSimpleDecoder(self.output_dim, answer_num, "open")
    elif args.goal == 'onto' and args.gcn:
      print("Ontonotes with gcn decoder")
      self.decoder = GCNSimpleDecoder(self.output_dim, answer_num, "onto")
    else:
      print("Ontonotes using simple decoder")
      self.decoder = SimpleDecoder(self.output_dim, answer_num)

  def sorted_rnn(self, sequences, sequence_lengths, rnn):
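    """Sort a padded batch by length, run `rnn` over the packed sequences, and restore the original order."""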
    sorted_inputs, sorted_sequence_lengths, restoration_indices = sort_batch_by_length(sequences, sequence_lengths)
    packed_sequence_input = pack_padded_sequence(sorted_inputs,
                                                 sorted_sequence_lengths.data.tolist(),
                                                 batch_first=True)
    packed_sequence_output, _ = rnn(packed_sequence_input, None)
    unpacked_sequence_tensor, _ = pad_packed_sequence(packed_sequence_output, batch_first=True)
    return unpacked_sequence_tensor.index_select(0, restoration_indices)

  def rnn(self, sequences, lstm):
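    """Run `lstm` over already-padded sequences without packing."""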
    outputs, _ = lstm(sequences)
    return outputs.contiguous()

  def define_loss(self, logits, targets, data_type):
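    """Binary cross-entropy over the label vocabulary.

    In the multi-task setup, the general / fine / finer label partitions are
    masked so that an example only contributes loss on partitions where it has
    at least one gold label.
    """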
    if not self.multitask or data_type == 'onto':
      loss = self.loss_func(logits, targets)
      return loss
    if data_type == 'wiki':
      gen_cutoff, fine_cutoff, final_cutoff = constant.ANSWER_NUM_DICT['gen'], constant.ANSWER_NUM_DICT['kb'], \
                                              constant.ANSWER_NUM_DICT[data_type]
    else:
      gen_cutoff, fine_cutoff, final_cutoff = constant.ANSWER_NUM_DICT['gen'], constant.ANSWER_NUM_DICT['kb'], None
    loss = 0.0
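    # torch.min(count, 1.0) turns per-example positive-label counts into 0/1 indicators;
    # nonzero() then selects the examples with gold labels in the partition.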
    comparison_tensor = torch.Tensor([1.0]).cuda()
    gen_targets = targets[:, :gen_cutoff]
    fine_targets = targets[:, gen_cutoff:fine_cutoff]
    gen_target_sum = torch.sum(gen_targets, 1)
    fine_target_sum = torch.sum(fine_targets, 1)

    if torch.sum(gen_target_sum.data) > 0:
      gen_mask = torch.squeeze(torch.nonzero(torch.min(gen_target_sum.data, comparison_tensor)), dim=1)
      gen_logit_masked = logits[:, :gen_cutoff][gen_mask, :]
      gen_target_masked = gen_targets.index_select(0, gen_mask)
      gen_loss = self.loss_func(gen_logit_masked, gen_target_masked)
      loss += gen_loss 
    if torch.sum(fine_target_sum.data) > 0:
      fine_mask = torch.squeeze(torch.nonzero(torch.min(fine_target_sum.data, comparison_tensor)), dim=1)
      fine_logit_masked = logits[:,gen_cutoff:fine_cutoff][fine_mask, :]
      fine_target_masked = fine_targets.index_select(0, fine_mask)
      fine_loss = self.loss_func(fine_logit_masked, fine_target_masked)
      loss += fine_loss 

    if data_type != 'kb':
      if final_cutoff:
        finer_targets = targets[:, fine_cutoff:final_cutoff]
        logit_masked = logits[:, fine_cutoff:final_cutoff]
      else:
        logit_masked = logits[:, fine_cutoff:]
        finer_targets = targets[:, fine_cutoff:]
      if torch.sum(torch.sum(finer_targets, 1).data) > 0:
        finer_mask = torch.squeeze(torch.nonzero(torch.min(torch.sum(finer_targets, 1).data, comparison_tensor)), dim=1)
        finer_target_masked = finer_targets.index_select(0, finer_mask)
        logit_masked = logit_masked[finer_mask, :]
        layer_loss = self.loss_func(logit_masked, finer_target_masked)
        loss += layer_loss

    if self.args.add_regu and self.batch_num > self.args.regu_steps:
      # Inconsistency loss 1: label pairs that never co-occur target -1; co-occurring pairs target the log count.
      # label_matrix = cosine_similarity(self.decoder.linear.weight, self.decoder.linear.weight)
      # target = -1 * self.incon_mask + self.con_mask * torch.log(self.corr_matrix + 1e-8)
      # auxiliary_loss = ((target - label_matrix) ** 2).mean()
      # loss += self.args.incon_w * auxiliary_loss

      # GloVe-style loss: fit label-embedding dot products (plus biases) to log co-occurrence counts,
      # down-weighting rare pairs with the (count / 100) ** 0.75 schedule.
      less_max_mask = (self.corr_matrix < 100).float()
      greater_max_mask = (self.corr_matrix >= 100).float()
      weight_matrix = less_max_mask * ((self.corr_matrix / 100.0) ** 0.75) + greater_max_mask
      auxiliary_loss = weight_matrix * (torch.mm(self.decoder.linear.weight, self.decoder.linear.weight.t())
                                        + self.b + self.b_.t() - torch.log(self.corr_matrix + 1e-8)) ** 2
      auxiliary_loss = auxiliary_loss.mean()

      # Inconsistency loss 2: only penalize mutually inconsistent label pairs.
      # label_matrix = cosine_similarity(self.decoder.linear.weight, self.decoder.linear.weight)
      # target = -1 * self.incon_mask
      # auxiliary_loss = (((target - label_matrix) * self.incon_mask) ** 2).sum() / self.incon_mask.sum()
      # loss += self.args.incon_w * auxiliary_loss

      # Inconsistency loss 3: margin loss separating consistent from contradictory label similarities.
      # label_matrix = cosine_similarity(self.decoder.linear.weight, self.decoder.linear.weight)
      # label_consistent = label_matrix * self.con_mask
      # label_contradict = label_matrix * self.incon_mask
      # distance = label_consistent.sum(1) / (self.con_mask.sum(1) + 1e-8) - label_contradict.sum(1) / (self.incon_mask.sum(1) + 1e-8)
      # margin = 0.2
      # auxiliary_loss = torch.max(torch.tensor(0.0).to(torch.device('cuda')), margin - distance).mean()

      loss += self.args.incon_w * auxiliary_loss

    return loss

  def normalize(self, raw_scores, lengths):
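    """Length-masked softmax over the last dimension: positions past each
    sequence's length are temporarily set to -1e30 so they get ~zero
    probability, and the original scores are then restored."""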
    backup = raw_scores.data.clone()
    max_len = raw_scores.size(2)

    for i, length in enumerate(lengths):
      if length == max_len:
        continue
      raw_scores.data[i, :, int(length):] = -1e30

    normalized_scores = F.softmax(raw_scores, dim=-1)
    raw_scores.data.copy_(backup)
    return normalized_scores

  def forward(self, feed_dict, data_type):
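    """Encode the context and the mention, combine them, and return (loss, logits) for data_type."""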
    if self.lstm_type == 'two':
      left_outputs = self.rnn(self.input_dropout(feed_dict['left_embed']), self.left_lstm)
      right_outputs = self.rnn(self.input_dropout(feed_dict['right_embed']), self.right_lstm)
      context_rep = torch.cat((left_outputs, right_outputs), 1)
      context_rep, _ = self.attentive_sum(context_rep)
    elif self.lstm_type == 'single':
      token_mask_embed = self.token_mask(feed_dict['token_bio'].view(-1, 4))
      token_mask_embed = token_mask_embed.view(feed_dict['token_embed'].size()[0], -1, 50)
      token_embed = torch.cat((feed_dict['token_embed'], token_mask_embed), 2)
      context_rep_ = self.sorted_rnn(self.input_dropout(token_embed), feed_dict['token_seq_length'], self.lstm)
      if self.args.goal == 'onto' or self.args.model_id == 'baseline':
        context_rep, _ = self.attentive_sum(context_rep_)
      else:
        context_rep, _ = self.attentive_sum(context_rep_, feed_dict["token_seq_length"])

    # Mention Representation
    if self.enhanced_mention:
      if self.args.goal == 'onto' or self.args.model_id == 'baseline':
        mention_embed, _ = self.head_attentive_sum(feed_dict['mention_embed'])
      else:
        mention_embed, _ = self.head_attentive_sum(feed_dict['mention_embed'], feed_dict['mention_len'])
      span_cnn_embed = self.cnn(feed_dict['span_chars'])
      mention_embed = torch.cat((span_cnn_embed, mention_embed), 1)
    else:
      mention_embed = torch.sum(feed_dict['mention_embed'], dim=1)
    mention_embed = self.mention_dropout(mention_embed)
    # Experimental mention-context fusion (enabled with args.model_debug).
    if self.args.model_debug:
      mention_embed_proj = self.mention_proj(mention_embed).tanh()
      affinity = self.attn(mention_embed_proj.unsqueeze(1)).bmm(
          F.dropout(context_rep_.transpose(2, 1), 0.1, self.training))  # B x 1 x seq_len
      m_over_c = self.normalize(affinity, feed_dict['token_seq_length'].squeeze().tolist())
      m_retrieve_c = torch.bmm(m_over_c, context_rep_)  # B x 1 x (2 * rnn_dim)
      fusioned = self.fusion(m_retrieve_c.squeeze(1), mention_embed_proj)
      output = F.dropout(torch.cat([fusioned, context_rep], dim=1), 0.2, self.training) # seems to be a good choice for ultra-fine
    else:
      output = F.dropout(torch.cat((context_rep, mention_embed), 1), 0.3, self.training)
      # output = torch.cat((context_rep, mention_embed), 1)

    logits = self.decoder(output, data_type)
    loss = self.define_loss(logits, feed_dict['y'], data_type)

    return loss, logits