```
import functools
import math
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F

class biDafAttn(nn.Module):
    def __init__(self, channel_size):
        """
        BiDAF-style attention from s2 to s1:
        the returned tensor has the same size as s1.
        :param channel_size: Hidden size of the input (kept for the interface; unused here).
        """
        super(biDafAttn, self).__init__()
        # self.mlp = nn.Linear(channel_size * 3, 1, bias=False)

    def similarity(self, s1, l1, s2, l2):
        """
        :param s1: [B, t1, D]
        :param l1: [B]
        :param s2: [B, t2, D]
        :param l2: [B]
        :return: similarity matrix S [B, t1, t2] with padded positions filled with -inf
        """
        batch_size = s1.size(0)
        t1 = s1.size(1)
        t2 = s2.size(1)
        # [B, t1, D] * [B, D, t2] -> [B, t1, t2]; S is the similarity matrix from the BiDAF paper.
        S = torch.bmm(s1, s2.transpose(1, 2))

        # Init similarity mask using lengths, then fill padded positions with -inf so the
        # softmax in get_U_tile/get_both_tile ignores them.  (The loop body was missing in
        # this file; it is reconstructed from how s_mask is used and from the nan-removal
        # comments below.)
        s_mask = S.data.new(*S.size()).fill_(1).byte()  # [B, t1, t2], 1 = padded
        for i, (l_1, l_2) in enumerate(zip(l1, l2)):
            s_mask[i][:l_1, :l_2] = 0  # mark valid positions

        S.data.masked_fill_(s_mask, -math.inf)
        return S

def get_U_tile(self, S, s2):
a_weight = F.softmax(S, dim=2)  # [B, t1, t2]
a_weight.data.masked_fill_(a_weight.data != a_weight.data, 0)  # remove nan from softmax on -inf
U_tile = torch.bmm(a_weight, s2)  # [B, t1, t2] * [B, t2, D] -> [B, t1, D]
return U_tile

def get_both_tile(self, S, s1, s2):
a_weight = F.softmax(S, dim=2)  # [B, t1, t2]
a_weight.data.masked_fill_(a_weight.data != a_weight.data, 0)  # remove nan from softmax on -inf
U_tile = torch.bmm(a_weight, s2)  # [B, t1, t2] * [B, t2, D] -> [B, t1, D]

a1_weight = F.softmax(S, dim=1)  # [B, t1, t2]
a1_weight.data.masked_fill_(a1_weight.data != a1_weight.data, 0)  # remove nan from softmax on -inf
U1_tile = torch.bmm(a1_weight.transpose(1, 2), s1)  # [B, t2, t1] * [B, t1, D] -> [B, t2, D]
return U_tile, U1_tile

def forward(self, s1, l1, s2, l2):
S = self.similarity(s1, l1, s2, l2)
U_tile = self.get_U_tile(S, s2)
return U_tile
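
# --- Hedged usage sketch (not part of the original file): how biDafAttn is expected
# to be called.  Shapes follow the docstrings above; this assumes the Variable-era
# PyTorch the rest of this file targets.
def _bidaf_attn_example():
    attn = biDafAttn(channel_size=8)       # channel_size is unused internally
    s1 = torch.randn(2, 5, 8)              # [B=2, t1=5, D=8]
    s2 = torch.randn(2, 7, 8)              # [B=2, t2=7, D=8]
    l1 = torch.LongTensor([5, 3])          # valid lengths of s1
    l2 = torch.LongTensor([7, 4])          # valid lengths of s2
    u_tile = attn(s1, l1, s2, l2)          # [B, t1, D]: s2 summarized for every s1 step
    assert u_tile.size() == (2, 5, 8)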

class CoattMaxPool(nn.Module):
    def __init__(self, args):
        super(CoattMaxPool, self).__init__()
        # Fixed hyper-parameters; only the vocabulary size, dropout rate and cell type
        # are read from args.
        h_size = [300, 300]
        d = 300
        mlp_d = 300
        v_size = args.max_snli_vocab_size
        max_l = None
        num_of_class = 3
        drop_r = args.dropout
        featurizer = None
        itos = None
        with_emlo = False
        activation_type = 'relu'
self.h_size = h_size
self.e_embd = nn.Embedding(v_size, d)
self.embd_dropout = nn.Dropout(drop_r)
self.featurizer = featurizer
self.itos = itos
self.args = args
if self.featurizer is not None:
fcount = self.featurizer.n_context_features()
else:
fcount = 0

self.emlo_embedding_d = 0
        if with_emlo:
            # NOTE: this branch is unreachable (with_emlo is hard-coded to False above) and
            # relies on ElmoEmbedder (from allennlp) and a cuda device id `n_device`,
            # neither of which is imported or defined in this file.
            self.emlo_ee = ElmoEmbedder(cuda_device=n_device)
            self.emlo_embedding_d = 1024

self.emlo_gamma = nn.Parameter(torch.FloatTensor([1]))
self.emlo_s_vector = nn.Parameter(torch.FloatTensor([1, 1, 1]))

if self.args.cell_type=='gru':
self.lstm = nn.GRU(input_size=d + fcount + self.emlo_embedding_d * 1, hidden_size=h_size[0],
num_layers=1, bidirectional=True, batch_first=True)

self.lstm_1 = nn.GRU(input_size=h_size[1] + fcount + self.emlo_embedding_d * 1, hidden_size=h_size[1],
num_layers=1, bidirectional=True, batch_first=True)

else:
self.lstm = nn.LSTM(input_size=d + fcount + self.emlo_embedding_d * 1, hidden_size=h_size[0],
num_layers=1, bidirectional=True, batch_first=True)

self.lstm_1 = nn.LSTM(input_size=h_size[1] + fcount + self.emlo_embedding_d * 1, hidden_size=h_size[1],
num_layers=1, bidirectional=True, batch_first=True)

self.projection = nn.Linear(h_size[0] * 2 * 4, h_size[1])
self.projection_dropout = nn.Dropout(drop_r)

self.max_l = max_l
self.bidaf = biDafAttn(300)

self.mlp_1 = nn.Linear(h_size[1] * 2 * 4, mlp_d)
self.sm = nn.Linear(mlp_d, num_of_class)
if activation_type == 'relu':
activation = nn.ReLU()
# self.classifier = nn.Sequential(*[self.mlp_1, nn.ReLU(), nn.Dropout(drop_r), self.sm])
elif activation_type == 'tanh':
activation = nn.Tanh()
else:
raise ValueError("Not a valid activation!")

self.classifier = nn.Sequential(*[nn.Dropout(drop_r), self.mlp_1, activation, nn.Dropout(drop_r), self.sm])

def count_params(self):
total_c = 0
for param in self.parameters():
if len(param.size()) == 2:
d1, d2 = param.size()[0], param.size()[1]
total_c += d1 * d2
print("Total count:", total_c)

def display(self):
for param in self.parameters():
print(param.data.size())

def forward(self, s1, l1, s2, l2):  # [B, T]
if self.max_l:
max_l = min(s1.size(1), self.max_l)
max_l = max(1, max_l)
max_s1_l = min(max(l1), max_l)

l1 = l1.clamp(min=1, max=max_s1_l)
if s1.size(1) > max_s1_l:
s1 = s1[:, :max_s1_l]

s1_max_l = s1.size(1)

if self.max_l:
max_l = min(s2.size(1), self.max_l)
max_l = max(1, max_l)
max_s2_l = min(max(l2), max_l)

l2 = l2.clamp(min=1, max=max_s2_l)
if s2.size(1) > max_s2_l:
s2 = s2[:, :max_s2_l]

s2_max_l = s2.size(1)

batch_size = s1.size(0)

th_packed_f_s1, th_packed_f_s2 = None, None
emlo_s1_sum, emlo_s2_sum = None, None

p_s1 = self.e_embd(s1)
p_s2 = self.e_embd(s2)

p_s1 = self.embd_dropout(p_s1)  # Embedding dropout
p_s2 = self.embd_dropout(p_s2)  # Embedding dropout

feature_p_s1 = torch.cat([seq for seq in [p_s1, th_packed_f_s1, emlo_s1_sum] if seq is not None], dim=2)
feature_p_s2 = torch.cat([seq for seq in [p_s2, th_packed_f_s2, emlo_s2_sum] if seq is not None], dim=2)

s1_layer1_out = self.auto_rnn(self.lstm, feature_p_s1, l1)
s2_layer1_out = self.auto_rnn(self.lstm, feature_p_s2, l2)

S = self.bidaf.similarity(s1_layer1_out, l1, s2_layer1_out, l2)
s1_att, s2_att = self.bidaf.get_both_tile(S, s1_layer1_out, s2_layer1_out)

s1_coattentioned = torch.cat([s1_layer1_out, s1_att, s1_layer1_out - s1_att,
s1_layer1_out * s1_att], dim=2)

s2_coattentioned = torch.cat([s2_layer1_out, s2_att, s2_layer1_out - s2_att,
s2_layer1_out * s2_att], dim=2)

p_s1_coattentioned = self.projection_dropout(F.relu(self.projection(s1_coattentioned)))
p_s2_coattentioned = self.projection_dropout(F.relu(self.projection(s2_coattentioned)))

s1_coatt_features = torch.cat(
[seq for seq in [p_s1_coattentioned, th_packed_f_s1, emlo_s1_sum] if seq is not None], dim=2)
s2_coatt_features = torch.cat(
[seq for seq in [p_s2_coattentioned, th_packed_f_s2, emlo_s2_sum] if seq is not None], dim=2)

s1_layer2_out = self.auto_rnn(self.lstm_1, s1_coatt_features, l1)
s2_layer2_out = self.auto_rnn(self.lstm_1, s2_coatt_features, l2)

s1_lay2_maxout = max_along_time(s1_layer2_out, l1)
s2_lay2_maxout = max_along_time(s2_layer2_out, l2)

s1_lay2_avgout = avg_along_time(s1_layer2_out, l1)
s2_lay2_avgout = avg_along_time(s2_layer2_out, l2)

features = torch.cat([s1_lay2_maxout, s2_lay2_maxout,
s1_lay2_avgout, s2_lay2_avgout], dim=1)

logits = self.classifier(features)
probs = F.softmax(logits, 1)
pred = torch.max(probs, 1)[1]

return logits, probs, pred

def auto_rnn(self, rnn: nn.RNN, seqs, lengths, batch_first=True, init_state=None, output_last_states=False):
batch_size = seqs.size(0) if batch_first else seqs.size(1)
state_shape = get_state_shape(rnn, batch_size, rnn.bidirectional)

if not init_state:
h0 = c0 = Variable(seqs.data.new(*state_shape).zero_())
else:
h0 = init_state['h0'].expand(state_shape)
c0 = init_state['c0'].expand(state_shape)

packed_pinputs, r_index = pack_for_rnn_seq(seqs, lengths, batch_first)
        if self.args.cell_type == 'gru':
            output, hn = rnn(packed_pinputs, h0)
            cn = None  # GRU has no cell state; avoids a NameError when output_last_states=True
        else:
            output, (hn, cn) = rnn(packed_pinputs, (h0, c0))
output = unpack_from_rnn_seq(output, r_index, batch_first)

if not output_last_states:
return output
else:
return output, (hn, cn)
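
# --- Hedged usage sketch (not part of the original file).  CoattMaxPool only reads
# args.max_snli_vocab_size, args.dropout and args.cell_type; argparse.Namespace is
# used here purely as a stand-in for the real args object.
def _coatt_max_pool_example():
    from argparse import Namespace
    args = Namespace(max_snli_vocab_size=40000, dropout=0.1, cell_type='lstm')
    model = CoattMaxPool(args)
    s1 = torch.randint(0, 40000, (4, 20)).long()    # premise token ids    [B, T1]
    s2 = torch.randint(0, 40000, (4, 15)).long()    # hypothesis token ids [B, T2]
    l1 = torch.LongTensor([20, 18, 12, 7])          # true lengths (need not be sorted)
    l2 = torch.LongTensor([15, 15, 9, 5])
    logits, probs, pred = model(s1, l1, s2, l2)
    assert logits.size() == (4, 3)                  # 3-way NLI classification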

"""
The seq is a sequence having shape [T, ..]. Note: The seq contains only one instance. This is not batched.

:param seq:  Input sequence with shape [T, ...]
:return:  Output sequence will have shape [Pad_L, ...]
"""
l = seq.size(0)
return seq[:pad_l, ]  # Truncate the length if the length is bigger than required padded_length.
else:

# TODO: The method seems useless to me. Delete?
def pad_to_length(seqs, length, batch_first=True):
    # NOTE: the def line and the return statements were missing from this file; the
    # name `pad_to_length` is a reconstruction based on the docstring, not the original source.
    """
    Pad (or truncate) a batch of sequences to a fixed length.

    :param seqs: [B, T, D] or [B, T] if batch_first else [T, B, D] or [T, B]
    :param length: target length (a scalar, not per-example lengths)
    :param batch_first:
    :return:
    """
    if batch_first:
        # [B, T, D]
        if length <= seqs.size(1):
            return seqs[:, :length]
        else:
            batch_size = seqs.size(0)
            pad_seq = Variable(seqs.data.new(batch_size, length - seqs.size(1), *seqs.size()[2:]).zero_())
            # [B, T, D]
            return torch.cat([seqs, pad_seq], dim=1)
    else:
        # [T, B, D]
        if length <= seqs.size(0):
            return seqs
        else:
            pad_seq = Variable(seqs.data.new(length - seqs.size(0), *seqs.size()[1:]).zero_())
            return torch.cat([seqs, pad_seq], dim=0)

def batch_first2time_first(inputs):
    """
    Convert input from batch_first to time_first:
    [B, T, D] -> [T, B, D]

    :param inputs:
    :return:
    """
    return inputs.transpose(0, 1)  # body was missing; the transpose follows directly from the docstring


def time_first2batch_first(inputs):
    """
    Convert input from time_first to batch_first:
    [T, B, D] -> [B, T, D]

    :param inputs:
    :return:
    """
    return inputs.transpose(0, 1)  # body was missing; the transpose follows directly from the docstring

def get_state_shape(rnn: nn.RNN, batch_size, bidirectional=False):
"""
Return the state shape of a given RNN. This is helpful when you want to create a init state for RNN.
Example:
c0 = h0 = Variable(src_seq_p.data.new(*get_state_shape([your rnn], 3, bidirectional)).zero_())

:param rnn: nn.LSTM, nn.GRU or subclass of nn.RNN
:param batch_size:
:param bidirectional:
:return:
"""
if bidirectional:
return rnn.num_layers * 2, batch_size, rnn.hidden_size
else:
return rnn.num_layers, batch_size, rnn.hidden_size
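
# --- Hedged usage sketch (not part of the original file): building a zero initial
# state from get_state_shape, mirroring the docstring example above.
def _get_state_shape_example():
    rnn = nn.LSTM(input_size=16, hidden_size=32, num_layers=1, bidirectional=True, batch_first=True)
    state_shape = get_state_shape(rnn, batch_size=4, bidirectional=rnn.bidirectional)  # (2, 4, 32)
    h0 = c0 = Variable(torch.zeros(*state_shape))
    return h0, c0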

def pack_list_sequence(inputs, l, max_l=None, batch_first=True):
    """
    Pack a list of variable-length Tensors into one padded Tensor.
    :param inputs: a list (length B) of [T_i, ...] Tensors
    :param l: [B] lengths
    :param max_l: The max length of the packed sequence.
    :param batch_first:
    :return:
    """
    batch_list = []
    max_l = max(list(l)) if not max_l else max_l
    batch_size = len(inputs)

    for b_i in range(batch_size):
        # The loop body was missing; pad every sequence to max_l with the (reconstructed) pad() above.
        batch_list.append(pad(inputs[b_i], max_l))

    pack_batch_list = torch.stack(batch_list, dim=1) if not batch_first \
        else torch.stack(batch_list, dim=0)
    return pack_batch_list

def pack_for_rnn_seq(inputs, lengths, batch_first=True):
"""
:param inputs: Shape of the input should be [B, T, D] if batch_first else [T, B, D].
:param lengths:  [B]
:param batch_first:
:return:
"""
if not batch_first:
_, sorted_indices = lengths.sort()
'''
Reverse to decreasing order
'''
r_index = reversed(list(sorted_indices))

s_inputs_list = []
lengths_list = []
reverse_indices = np.zeros(lengths.size(0), dtype=np.int64)

for j, i in enumerate(r_index):
s_inputs_list.append(inputs[:, i, :].unsqueeze(1))
lengths_list.append(lengths[i])
reverse_indices[i] = j

reverse_indices = list(reverse_indices)

        s_inputs = torch.cat(s_inputs_list, 1)
        # The packing call was missing; reconstructed so packed_seq is defined before the return.
        packed_seq = nn.utils.rnn.pack_padded_sequence(s_inputs, lengths_list)

        return packed_seq, reverse_indices

else:
#print(lengths)
#_, sorted_indices = lengths.sort()
r_index = reversed(list(np.argsort(lengths)))
'''
Reverse to decreasing order
'''
#r_index = reversed(list(sorted_indices))

s_inputs_list = []
lengths_list = []
#reverse_indices = np.zeros(lengths.size(0), dtype=np.int64)
reverse_indices = np.zeros(len(lengths), dtype=np.int64)

for j, i in enumerate(r_index):
s_inputs_list.append(inputs[i, :, :])
lengths_list.append(lengths[i])
reverse_indices[i] = j

reverse_indices = list(reverse_indices)

        s_inputs = torch.stack(s_inputs_list, dim=0)
        # The packing call was missing; reconstructed so packed_seq is defined before the return.
        packed_seq = nn.utils.rnn.pack_padded_sequence(s_inputs, lengths_list, batch_first=batch_first)

        return packed_seq, reverse_indices

def unpack_from_rnn_seq(packed_seq, reverse_indices, batch_first=True):
    # The unpacking call and the returns were missing; reconstructed from the packing
    # logic above (pad back to a dense tensor, then restore the original batch order).
    unpacked_seq, _ = nn.utils.rnn.pad_packed_sequence(packed_seq, batch_first=batch_first)
    s_inputs_list = []

    if not batch_first:
        for i in reverse_indices:
            s_inputs_list.append(unpacked_seq[:, i, :].unsqueeze(1))
        return torch.cat(s_inputs_list, 1)
    else:
        for i in reverse_indices:
            s_inputs_list.append(unpacked_seq[i, :, :].unsqueeze(0))
        return torch.cat(s_inputs_list, 0)
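
# --- Hedged usage sketch (not part of the original file): pack_for_rnn_seq /
# unpack_from_rnn_seq round trip around an RNN with unsorted lengths, mirroring
# what auto_rnn below does internally.
def _pack_unpack_rnn_example():
    rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True)
    seqs = torch.randn(3, 6, 8)                 # [B, T, D], zero-padded
    lengths = torch.LongTensor([4, 6, 2])       # valid lengths, deliberately unsorted
    packed, r_index = pack_for_rnn_seq(seqs, lengths, batch_first=True)
    output, _ = rnn(packed)                     # PackedSequence in, PackedSequence out
    output = unpack_from_rnn_seq(output, r_index, batch_first=True)  # back to [B, T, 16], original order
    return output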

def auto_rnn(rnn: nn.RNN, seqs, lengths, batch_first=True, init_state=None, output_last_states=False):
batch_size = seqs.size(0) if batch_first else seqs.size(1)
state_shape = get_state_shape(rnn, batch_size, rnn.bidirectional)

if not init_state:
h0 = c0 = Variable(seqs.data.new(*state_shape).zero_())
else:
h0 = init_state['h0'].expand(state_shape)
c0 = init_state['c0'].expand(state_shape)

packed_pinputs, r_index = pack_for_rnn_seq(seqs, lengths, batch_first)
output, (hn, cn) = rnn(packed_pinputs, (h0, c0))
output = unpack_from_rnn_seq(output, r_index, batch_first)

if not output_last_states:
return output
else:
return output, (hn, cn)

def pack_sequence_for_linear(inputs, lengths, batch_first=True):
"""
:param inputs: [B, T, D] if batch_first
:param lengths:  [B]
:param batch_first:
:return:
"""
batch_list = []
if batch_first:
for i, l in enumerate(lengths):
# print(inputs[i, :l].size())
batch_list.append(inputs[i, :l])
packed_sequence = torch.cat(batch_list, 0)
# if chuck:
#     return list(torch.chunk(packed_sequence, chuck, dim=0))
# else:
return packed_sequence
    else:
        raise NotImplementedError()

def chucked_forward(inputs, net, chuck=None):
    if not chuck:
        return net(inputs)
    else:
        # Run the network chunk by chunk; the final concatenation/return was missing.
        output_list = [net(chunk_i) for chunk_i in torch.chunk(inputs, chuck, dim=0)]
        return torch.cat(output_list, dim=0)

def unpack_sequence_for_linear(inputs, lengths, batch_first=True):
    batch_list = []
    max_l = max(lengths)

    if not isinstance(inputs, list):
        inputs = [inputs]
    inputs = torch.cat(inputs)

    if batch_first:
        start = 0
        for l in lengths:
            end = start + l
            # The loop body was missing; slice this sequence out and pad it back to max_l
            # with the (reconstructed) pad() above.
            batch_list.append(pad(inputs[start:end], max_l))
            start = end
        return torch.stack(batch_list)  # the return was missing
    else:
        raise NotImplementedError()
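
# --- Hedged usage sketch (not part of the original file): pack_sequence_for_linear
# flattens only the valid time steps so a Linear layer never sees padding;
# unpack_sequence_for_linear restores the padded [B, T, D] layout.
def _pack_unpack_for_linear_example():
    linear = nn.Linear(8, 4)
    inputs = torch.randn(3, 6, 8)                          # [B, T, D], zero-padded
    lengths = [6, 4, 2]                                    # valid lengths per sequence
    packed = pack_sequence_for_linear(inputs, lengths)     # [sum(lengths), 8] = [12, 8]
    out = linear(packed)                                   # [12, 4]
    out = unpack_sequence_for_linear(out, lengths)         # [B, max(lengths), 4]
    return out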

def seq2seq_cross_entropy(logits, label, l, chuck=None, sos_truncate=True):
    """
    :param logits: [exB, V] : exB = sum(l)
    :param label: [B, T] : a batch of step-level labels (packed below using the lengths)
    :param l: [B] : a batch of LongTensor indicating the lengths of each input
    :param chuck: Number of chunks to process
    :param sos_truncate: unused
    :return: A loss value
    """
packed_label = pack_sequence_for_linear(label, l)
cross_entropy_loss = functools.partial(F.cross_entropy, size_average=False)
total = sum(l)

assert total == logits.size(0) or packed_label.size(0) == logits.size(0),\
"logits length mismatch with label length."

if chuck:
logits_losses = 0
for x, y in zip(torch.chunk(logits, chuck, dim=0), torch.chunk(packed_label, chuck, dim=0)):
logits_losses += cross_entropy_loss(x, y)
return logits_losses * (1 / total)
else:
return cross_entropy_loss(logits, packed_label) * (1 / total)
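
# --- Hedged usage sketch (not part of the original file): seq2seq_cross_entropy
# expects step-level logits already flattened to the valid positions, i.e. [sum(l), V].
def _seq2seq_cross_entropy_example():
    logits = torch.randn(5, 10, requires_grad=True)   # exB = 3 + 2 valid steps, V = 10 classes
    label = torch.randint(0, 10, (2, 4)).long()       # [B, T] labels; only the first l[i] steps count
    l = [3, 2]                                        # valid lengths per sequence
    loss = seq2seq_cross_entropy(logits, label, l)    # scalar, averaged over sum(l) steps
    loss.backward()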

def avg_along_time(inputs, lengths, list_in=False):
    """
    :param inputs: [B, T, D]
    :param lengths: [B]
    :param list_in:
    :return: [B, D] average over the valid time steps of each sequence
    """
    ls = list(lengths)

    if not list_in:
        b_seq_avg_list = []
        for i, l in enumerate(ls):
            seq_i = inputs[i, :l, :]
            seq_i_avg = seq_i.mean(dim=0)
            seq_i_avg = seq_i_avg.squeeze()
            b_seq_avg_list.append(seq_i_avg)
        return torch.stack(b_seq_avg_list)  # the return statements were missing; stack back to [B, D]
    else:
        b_seq_avg_list = []
        for i, l in enumerate(ls):
            seq_i = inputs[i]
            seq_i_avg = seq_i.mean(dim=0)
            seq_i_avg = seq_i_avg.squeeze()
            b_seq_avg_list.append(seq_i_avg)
        return torch.stack(b_seq_avg_list)

def max_along_time(inputs, lengths, list_in=False):
    """
    :param inputs: [B, T, D]
    :param lengths: [B]
    :param list_in:
    :return: [B, D] max over the valid time steps of each sequence
    """
ls = list(lengths)

if not list_in:
b_seq_max_list = []
for i, l in enumerate(ls):
seq_i = inputs[i, :l, :]
seq_i_max, _ = seq_i.max(dim=0)
seq_i_max = seq_i_max.squeeze()
b_seq_max_list.append(seq_i_max)