import torch
from torch.nn import functional as F, Parameter
from torch.autograd import Variable

from spodernet.utils.global_config import Config
from spodernet.utils.cuda_utils import CUDATimer
from torch.nn.init import xavier_normal_, xavier_uniform_
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn as nn
import pdb

from itertools import chain

timer = CUDATimer()


class Complex(torch.nn.Module):
    def __init__(self, num_entities, num_relations):
        super(Complex, self).__init__()
        self.num_entities = num_entities
        self.emb_e_real = torch.nn.Embedding(num_entities, Config.embedding_dim, padding_idx=0)
        self.emb_e_img = torch.nn.Embedding(num_entities, Config.embedding_dim, padding_idx=0)
        self.emb_rel_real = torch.nn.Embedding(num_relations, Config.embedding_dim, padding_idx=0)
        self.emb_rel_img = torch.nn.Embedding(num_relations, Config.embedding_dim, padding_idx=0)
        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.loss = torch.nn.BCELoss()

    def init(self):
        xavier_normal_(self.emb_e_real.weight.data)
        xavier_normal_(self.emb_e_img.weight.data)
        xavier_normal_(self.emb_rel_real.weight.data)
        xavier_normal_(self.emb_rel_img.weight.data)

    def forward(self, e1, rel):

        e1_embedded_real = self.emb_e_real(e1).view(Config.batch_size, -1)
        rel_embedded_real = self.emb_rel_real(rel).view(Config.batch_size, -1)
        e1_embedded_img = self.emb_e_img(e1).view(Config.batch_size, -1)
        rel_embedded_img = self.emb_rel_img(rel).view(Config.batch_size, -1)

        e1_embedded_real = self.inp_drop(e1_embedded_real)
        rel_embedded_real = self.inp_drop(rel_embedded_real)
        e1_embedded_img = self.inp_drop(e1_embedded_img)
        rel_embedded_img = self.inp_drop(rel_embedded_img)

        # complex space bilinear product (equivalent to HolE)
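        # The four terms below expand Re(<e1, rel, conj(e2)>) into real/imaginary
        # parts and score against every entity embedding at once (1-N scoring).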
        realrealreal = torch.mm(e1_embedded_real*rel_embedded_real, self.emb_e_real.weight.transpose(1,0))
        realimgimg = torch.mm(e1_embedded_real*rel_embedded_img, self.emb_e_img.weight.transpose(1,0))
        imgrealimg = torch.mm(e1_embedded_img*rel_embedded_real, self.emb_e_img.weight.transpose(1,0))
        imgimgreal = torch.mm(e1_embedded_img*rel_embedded_img, self.emb_e_real.weight.transpose(1,0))
        pred = realrealreal + realimgimg + imgrealimg - imgimgreal
        pred = torch.sigmoid(pred)

        return pred


class DistMult(torch.nn.Module):
    def __init__(self, num_entities, num_relations):
        super(DistMult, self).__init__()
        self.emb_e = torch.nn.Embedding(num_entities, Config.embedding_dim, padding_idx=0)
        self.emb_rel = torch.nn.Embedding(num_relations, Config.embedding_dim, padding_idx=0)
        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.loss = torch.nn.BCELoss()

    def init(self):
        xavier_normal_(self.emb_e.weight.data)
        xavier_normal_(self.emb_rel.weight.data)

    def forward(self, e1, rel):
        e1_embedded = self.emb_e(e1)
        rel_embedded = self.emb_rel(rel)
        e1_embedded = e1_embedded.view(-1, Config.embedding_dim)
        rel_embedded = rel_embedded.view(-1, Config.embedding_dim)

        e1_embedded = self.inp_drop(e1_embedded)
        rel_embedded = self.inp_drop(rel_embedded)

        pred = torch.mm(e1_embedded*rel_embedded, self.emb_e.weight.transpose(1,0))
        pred = torch.sigmoid(pred)

        return pred


class ConvE(torch.nn.Module):
    def __init__(self, num_entities, num_relations):
        super(ConvE, self).__init__()
        self.emb_e = torch.nn.Embedding(num_entities, Config.embedding_dim, padding_idx=0)
        self.emb_rel = torch.nn.Embedding(num_relations, Config.embedding_dim, padding_idx=0)
        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.hidden_drop = torch.nn.Dropout(Config.dropout)
        self.feature_map_drop = torch.nn.Dropout2d(Config.feature_map_dropout)
        self.loss = torch.nn.BCELoss()

        self.conv1 = torch.nn.Conv2d(1, 32, (3, 3), 1, 0, bias=Config.use_bias)
        self.bn0 = torch.nn.BatchNorm2d(1)
        self.bn1 = torch.nn.BatchNorm2d(32)
        self.bn2 = torch.nn.BatchNorm1d(Config.embedding_dim)
        self.register_parameter('b', Parameter(torch.zeros(num_entities)))
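        # 10368 = 32 feature maps x 18 x 18, the flattened output of the 3x3
        # convolution over the 20 x 20 stacked input (assumes Config.embedding_dim == 200).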
        self.fc = torch.nn.Linear(10368,Config.embedding_dim)
        print(num_entities, num_relations)

    def init(self):
        xavier_normal_(self.emb_e.weight.data)
        xavier_normal_(self.emb_rel.weight.data)

    def forward(self, e1, rel):
        e1_embedded = self.emb_e(e1).view(Config.batch_size, 1, 10, 20)
        rel_embedded = self.emb_rel(rel).view(Config.batch_size, 1, 10, 20)
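        # The hard-coded 10 x 20 reshape (and the 10368 fc input size) assumes
        # Config.embedding_dim == 200; stacking head and relation along the
        # height axis gives a (batch_size, 1, 20, 20) input for the convolution.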

        stacked_inputs = torch.cat([e1_embedded, rel_embedded], 2)

        stacked_inputs = self.bn0(stacked_inputs)
        x = self.inp_drop(stacked_inputs)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(Config.batch_size, -1)
        #print(x.size())
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = torch.mm(x, self.emb_e.weight.transpose(1,0))
        x += self.b.expand_as(x)
        pred = torch.sigmoid(x)

        return pred


"""
Literal Models
--------------
"""


class DistMultLiteral(torch.nn.Module):

    def __init__(self, num_entities, num_relations, numerical_literals):
        super(DistMultLiteral, self).__init__()

        self.emb_dim = Config.embedding_dim

        self.emb_e = torch.nn.Embedding(num_entities, self.emb_dim, padding_idx=0)
        self.emb_rel = torch.nn.Embedding(num_relations, self.emb_dim, padding_idx=0)

        # Literal
        # num_ent x n_num_lit
        self.numerical_literals = Variable(torch.from_numpy(numerical_literals)).cuda()
        self.n_num_lit = self.numerical_literals.size(1)

        self.emb_num_lit = torch.nn.Linear(self.emb_dim+self.n_num_lit, self.emb_dim)
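        # emb_num_lit is LiteralE's g: it maps [entity embedding; literal features]
        # back to embedding_dim (a plain linear map in this variant).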

        # Dropout + loss
        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.loss = torch.nn.BCELoss()

    def init(self):
        xavier_normal_(self.emb_e.weight.data)
        xavier_normal_(self.emb_rel.weight.data)

    def forward(self, e1, rel):
        e1_emb = self.emb_e(e1)
        rel_emb = self.emb_rel(rel)

        e1_emb = e1_emb.view(-1, self.emb_dim)
        rel_emb = rel_emb.view(-1, self.emb_dim)

        # Begin literals

        e1_num_lit = self.numerical_literals[e1.view(-1)]
        e1_emb = self.emb_num_lit(torch.cat([e1_emb, e1_num_lit], 1))

        e2_multi_emb = self.emb_num_lit(torch.cat([self.emb_e.weight, self.numerical_literals], 1))
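        # e2_multi_emb holds literal-enriched embeddings of all entities, used as
        # candidate tails in the 1-N scoring below.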

        # End literals

        e1_emb = self.inp_drop(e1_emb)
        rel_emb = self.inp_drop(rel_emb)

        pred = torch.mm(e1_emb*rel_emb, e2_multi_emb.t())
        pred = torch.sigmoid(pred)

        return pred


class KBLN(torch.nn.Module):

    def __init__(self, num_entities, num_relations, numerical_literals, c, var):
        super(KBLN, self).__init__()

        self.num_entities = num_entities
        self.emb_dim = Config.embedding_dim

        self.emb_e = torch.nn.Embedding(num_entities, self.emb_dim, padding_idx=0)
        self.emb_rel = torch.nn.Embedding(num_relations, self.emb_dim, padding_idx=0)

        # Literal
        # num_ent x n_num_lit
        self.numerical_literals = Variable(torch.from_numpy(numerical_literals)).cuda()
        self.n_num_lit = self.numerical_literals.size(1)

        # Fixed RBF parameters
        print(c)
        print(var)
        self.c = Variable(torch.FloatTensor(c)).cuda()
        self.var = Variable(torch.FloatTensor(var)).cuda()

        # Weights for the numerical features, one weight vector per relation
        self.nf_weights = nn.Embedding(num_relations, self.n_num_lit)

        # Dropout + loss
        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.loss = torch.nn.BCELoss()

    def init(self):
        xavier_normal_(self.emb_e.weight.data)
        xavier_normal_(self.emb_rel.weight.data)

    def forward(self, e1, rel):
        e1_emb = self.emb_e(e1).view(-1, self.emb_dim)
        rel_emb = self.emb_rel(rel).view(-1, self.emb_dim)

        e1_emb = self.inp_drop(e1_emb)
        rel_emb = self.inp_drop(rel_emb)

        score_l = torch.mm(e1_emb*rel_emb, self.emb_e.weight.t())
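        # score_l: relational (DistMult) score of (e1, rel) against every entity.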

        """ Begin numerical literals """
        n_h = self.numerical_literals[e1.view(-1)]  # (batch_size x n_lit)
        n_t = self.numerical_literals  # (num_ents x n_lit)

        # Features (batch_size x num_ents x n_lit)
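        # n holds the differences between the head's literal values and those of
        # every candidate tail; phi applies the RBF kernel elementwise.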
        n = n_h.unsqueeze(1).repeat(1, self.num_entities, 1) - n_t
        phi = self.rbf(n)
        # Weights (batch_size, 1, n_lits)
        w_nf = self.nf_weights(rel)

        # (batch_size, num_ents)
        score_n = torch.bmm(phi, w_nf.transpose(1, 2)).squeeze()
        """ End numerical literals """

        score = torch.sigmoid(score_l + score_n)

        return score

    def rbf(self, n):
        """
        Apply RBF kernel parameterized by (fixed) c and var, pointwise.
        n: (batch_size, num_ents, n_lit)
        """
        return torch.exp(-(n - self.c)**2 / self.var)


class MTKGNN_DistMult(torch.nn.Module):

    def __init__(self, num_entities, num_relations, numerical_literals):
        super(MTKGNN_DistMult, self).__init__()

        self.emb_dim = Config.embedding_dim
        self.num_entities = num_entities
        self.num_relations = num_relations

        self.emb_e = torch.nn.Embedding(num_entities, self.emb_dim, padding_idx=0)
        self.emb_rel = torch.nn.Embedding(num_relations, self.emb_dim, padding_idx=0)

        # Literal
        # num_ent x n_num_lit
        self.numerical_literals = Variable(torch.from_numpy(numerical_literals)).cuda()
        self.n_num_lit = self.numerical_literals.size(1)

        self.emb_attr = torch.nn.Embedding(self.n_num_lit, self.emb_dim)

        self.attr_net_left = torch.nn.Sequential(
            torch.nn.Linear(2*self.emb_dim, 100),
            torch.nn.Tanh(),
            torch.nn.Linear(100, 1))

        self.attr_net_right = torch.nn.Sequential(
            torch.nn.Linear(2*self.emb_dim, 100),
            torch.nn.Tanh(),
            torch.nn.Linear(100, 1))

        self.rel_params = chain(self.emb_e.parameters(), self.emb_rel.parameters())
        self.attr_params = chain(self.emb_e.parameters(), self.emb_attr.parameters(),
            self.attr_net_left.parameters(), self.attr_net_right.parameters())

        # Dropout + loss
        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.loss_rel = torch.nn.BCELoss()
        self.loss_attr = torch.nn.MSELoss()

    def init(self):
        xavier_normal_(self.emb_e.weight.data)
        xavier_normal_(self.emb_rel.weight.data)

    def forward(self, e1, rel):
        e1_embedded = self.emb_e(e1)
        rel_embedded = self.emb_rel(rel)
        e1_embedded = e1_embedded.view(-1, Config.embedding_dim)
        rel_embedded = rel_embedded.view(-1, Config.embedding_dim)

        e1_embedded = self.inp_drop(e1_embedded)
        rel_embedded = self.inp_drop(rel_embedded)

        pred = torch.mm(e1_embedded*rel_embedded, self.emb_e.weight.transpose(1,0))
        pred = torch.sigmoid(pred)

        return pred

    def forward_attr(self, e, mode='left'):
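        # Auxiliary attribute-regression head: for each entity, sample one
        # numerical literal at random and predict its value with a small MLP
        # (trained with self.loss_attr).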
        assert mode == 'left' or mode == 'right'

        e_emb = self.emb_e(e.view(-1))

        # Sample one numerical literal for each entity
        e_attr = self.numerical_literals[e.view(-1)]
        m = len(e_attr)
        idxs = torch.randint(self.n_num_lit, size=(m,)).cuda()
        attr_emb = self.emb_attr(idxs)

        inputs = torch.cat([e_emb, attr_emb], dim=1)
        pred = self.attr_net_left(inputs) if mode == 'left' else self.attr_net_right(inputs)
        target = e_attr[range(m), idxs]

        return pred, target

    def forward_AST(self):
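        # Auxiliary task over randomly sampled (entity, attribute) pairs: both
        # attribute networks predict the true literal value for each pair.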
        m = Config.batch_size

        idxs_attr = torch.randint(self.n_num_lit, size=(m,)).cuda()
        idxs_ent = torch.randint(self.num_entities, size=(m,)).cuda()

        attr_emb = self.emb_attr(idxs_attr)
        ent_emb = self.emb_e(idxs_ent)

        inputs = torch.cat([ent_emb, attr_emb], dim=1)
        pred_left = self.attr_net_left(inputs)
        pred_right = self.attr_net_right(inputs)
        target = self.numerical_literals[idxs_ent][range(m), idxs_attr]

        return pred_left, pred_right, target


class ComplexLiteral(torch.nn.Module):

    def __init__(self, num_entities, num_relations, numerical_literals):
        super(ComplexLiteral, self).__init__()

        self.emb_dim = Config.embedding_dim

        self.emb_e_real = torch.nn.Embedding(num_entities, self.emb_dim, padding_idx=0)
        self.emb_e_img = torch.nn.Embedding(num_entities, self.emb_dim, padding_idx=0)
        self.emb_rel_real = torch.nn.Embedding(num_relations, self.emb_dim, padding_idx=0)
        self.emb_rel_img = torch.nn.Embedding(num_relations, self.emb_dim, padding_idx=0)

        # Literal
        # num_ent x n_num_lit
        self.numerical_literals = Variable(torch.from_numpy(numerical_literals)).cuda()
        self.n_num_lit = self.numerical_literals.size(1)

        self.emb_num_lit_real = torch.nn.Sequential(
            torch.nn.Linear(self.emb_dim+self.n_num_lit, self.emb_dim),
            torch.nn.Tanh()
        )

        self.emb_num_lit_img = torch.nn.Sequential(
            torch.nn.Linear(self.emb_dim+self.n_num_lit, self.emb_dim),
            torch.nn.Tanh()
        )

        # Dropout + loss
        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.loss = torch.nn.BCELoss()

    def init(self):
        xavier_normal_(self.emb_e_real.weight.data)
        xavier_normal_(self.emb_e_img.weight.data)
        xavier_normal_(self.emb_rel_real.weight.data)
        xavier_normal_(self.emb_rel_img.weight.data)

    def forward(self, e1, rel):
        e1_emb_real = self.emb_e_real(e1).view(Config.batch_size, -1)
        rel_emb_real = self.emb_rel_real(rel).view(Config.batch_size, -1)
        e1_emb_img = self.emb_e_img(e1).view(Config.batch_size, -1)
        rel_emb_img = self.emb_rel_img(rel).view(Config.batch_size, -1)

        # Begin literals

        e1_num_lit = self.numerical_literals[e1.view(-1)]
        e1_emb_real = self.emb_num_lit_real(torch.cat([e1_emb_real, e1_num_lit], 1))
        e1_emb_img = self.emb_num_lit_img(torch.cat([e1_emb_img, e1_num_lit], 1))

        e2_multi_emb_real = self.emb_num_lit_real(torch.cat([self.emb_e_real.weight, self.numerical_literals], 1))
        e2_multi_emb_img = self.emb_num_lit_img(torch.cat([self.emb_e_img.weight, self.numerical_literals], 1))

        # End literals

        e1_emb_real = self.inp_drop(e1_emb_real)
        rel_emb_real = self.inp_drop(rel_emb_real)
        e1_emb_img = self.inp_drop(e1_emb_img)
        rel_emb_img = self.inp_drop(rel_emb_img)

        realrealreal = torch.mm(e1_emb_real*rel_emb_real, e2_multi_emb_real.t())
        realimgimg = torch.mm(e1_emb_real*rel_emb_img, e2_multi_emb_img.t())
        imgrealimg = torch.mm(e1_emb_img*rel_emb_real, e2_multi_emb_img.t())
        imgimgreal = torch.mm(e1_emb_img*rel_emb_img, e2_multi_emb_real.t())

        pred = realrealreal + realimgimg + imgrealimg - imgimgreal
        pred = torch.sigmoid(pred)

        return pred


class ConvELiteral(torch.nn.Module):

    def __init__(self, num_entities, num_relations, numerical_literals):
        super(ConvELiteral, self).__init__()

        self.emb_dim = Config.embedding_dim

        self.emb_e = torch.nn.Embedding(num_entities, self.emb_dim, padding_idx=0)
        self.emb_rel = torch.nn.Embedding(num_relations, self.emb_dim, padding_idx=0)

        # Literal
        # num_ent x n_num_lit
        self.numerical_literals = Variable(torch.from_numpy(numerical_literals)).cuda()
        self.n_num_lit = self.numerical_literals.size(1)

        self.emb_num_lit = torch.nn.Sequential(
            torch.nn.Linear(self.emb_dim+self.n_num_lit, self.emb_dim),
            torch.nn.Tanh()
        )

        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.hidden_drop = torch.nn.Dropout(Config.dropout)
        self.feature_map_drop = torch.nn.Dropout2d(Config.feature_map_dropout)
        self.loss = torch.nn.BCELoss()

        self.conv1 = torch.nn.Conv2d(1, 32, (3, 3), 1, 0, bias=Config.use_bias)
        self.bn0 = torch.nn.BatchNorm2d(1)
        self.bn1 = torch.nn.BatchNorm2d(32)
        self.bn2 = torch.nn.BatchNorm1d(self.emb_dim)
        self.register_parameter('b', Parameter(torch.zeros(num_entities)))
        self.fc = torch.nn.Linear(10368, self.emb_dim)
        print(num_entities, num_relations)

    def init(self):
        xavier_normal_(self.emb_e.weight.data)
        xavier_normal_(self.emb_rel.weight.data)

    def forward(self, e1, rel):
        e1_emb = self.emb_e(e1).view(Config.batch_size, -1)
        rel_emb = self.emb_rel(rel)

        # Begin literals

        e1_num_lit = self.numerical_literals[e1.view(-1)]
        e1_emb = self.emb_num_lit(torch.cat([e1_emb, e1_num_lit], 1))

        e2_multi_emb = self.emb_num_lit(torch.cat([self.emb_e.weight, self.numerical_literals], 1))

        # End literals

        e1_emb = e1_emb.view(Config.batch_size, 1, 10, self.emb_dim//10)
        rel_emb = rel_emb.view(Config.batch_size, 1, 10, self.emb_dim//10)

        stacked_inputs = torch.cat([e1_emb, rel_emb], 2)

        stacked_inputs = self.bn0(stacked_inputs)
        x = self.inp_drop(stacked_inputs)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(Config.batch_size, -1)
        # print(x.size())
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = torch.mm(x, e2_multi_emb.t())
        x += self.b.expand_as(x)
        pred = torch.sigmoid(x)

        return pred


class Gate(nn.Module):

    def __init__(self,
                 input_size,
                 output_size,
                 # gate_activation=nn.functional.softmax):
                 gate_activation=torch.sigmoid):

        super(Gate, self).__init__()
        self.output_size = output_size

        self.gate_activation = gate_activation
        self.g = nn.Linear(input_size, output_size)
        self.g1 = nn.Linear(output_size, output_size, bias=False)
        self.g2 = nn.Linear(input_size-output_size, output_size, bias=False)
        self.gate_bias = nn.Parameter(torch.zeros(output_size))

    def forward(self, x_ent, x_lit):
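        # Gated combination of entity embedding and literal features:
        #   output = (1 - z) * x_ent + z * tanh(W [x_ent; x_lit]),
        # with gate z = sigmoid(W1 x_ent + W2 x_lit + b) by default.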
        x = torch.cat([x_ent, x_lit], 1)
        g_embedded = torch.tanh(self.g(x))
        gate = self.gate_activation(self.g1(x_ent) + self.g2(x_lit) + self.gate_bias)
        output = (1-gate) * x_ent + gate * g_embedded

        return output


class DistMultLiteral_gate(torch.nn.Module):

    def __init__(self, num_entities, num_relations, numerical_literals):
        super(DistMultLiteral_gate, self).__init__()

        self.emb_dim = Config.embedding_dim

        self.emb_e = torch.nn.Embedding(num_entities, self.emb_dim, padding_idx=0)
        self.emb_rel = torch.nn.Embedding(num_relations, self.emb_dim, padding_idx=0)

        # Literal
        # num_ent x n_num_lit
        self.numerical_literals = Variable(torch.from_numpy(numerical_literals)).cuda()
        self.n_num_lit = self.numerical_literals.size(1)

        self.emb_num_lit = Gate(self.emb_dim+self.n_num_lit, self.emb_dim)


        # Dropout + loss
        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.loss = torch.nn.BCELoss()

    def init(self):
        xavier_normal_(self.emb_e.weight.data)
        xavier_normal_(self.emb_rel.weight.data)

    def forward(self, e1, rel):
        e1_emb = self.emb_e(e1)
        rel_emb = self.emb_rel(rel)

        e1_emb = e1_emb.view(-1, self.emb_dim)
        rel_emb = rel_emb.view(-1, self.emb_dim)

        # Begin literals

        e1_num_lit = self.numerical_literals[e1.view(-1)]
        e1_emb = self.emb_num_lit(e1_emb, e1_num_lit)
        e2_multi_emb = self.emb_num_lit(self.emb_e.weight, self.numerical_literals)

        # End literals

        e1_emb = self.inp_drop(e1_emb)
        rel_emb = self.inp_drop(rel_emb)

        pred = torch.mm(e1_emb*rel_emb, e2_multi_emb.t())
        pred = torch.sigmoid(pred)

        return pred


class ComplexLiteral_gate(torch.nn.Module):

    def __init__(self, num_entities, num_relations, numerical_literals):
        super(ComplexLiteral_gate, self).__init__()

        self.emb_dim = Config.embedding_dim

        self.emb_e_real = torch.nn.Embedding(num_entities, self.emb_dim, padding_idx=0)
        self.emb_e_img = torch.nn.Embedding(num_entities, self.emb_dim, padding_idx=0)
        self.emb_rel_real = torch.nn.Embedding(num_relations, self.emb_dim, padding_idx=0)
        self.emb_rel_img = torch.nn.Embedding(num_relations, self.emb_dim, padding_idx=0)

        # Literal
        # num_ent x n_num_lit
        self.numerical_literals = Variable(torch.from_numpy(numerical_literals)).cuda()
        self.n_num_lit = self.numerical_literals.size(1)

        self.emb_num_lit_real = Gate(self.emb_dim+self.n_num_lit, self.emb_dim)
        self.emb_num_lit_img = Gate(self.emb_dim+self.n_num_lit, self.emb_dim)

        # Dropout + loss
        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.loss = torch.nn.BCELoss()

    def init(self):
        xavier_normal_(self.emb_e_real.weight.data)
        xavier_normal_(self.emb_e_img.weight.data)
        xavier_normal_(self.emb_rel_real.weight.data)
        xavier_normal_(self.emb_rel_img.weight.data)

    def forward(self, e1, rel):
        e1_emb_real = self.emb_e_real(e1).view(Config.batch_size, -1)
        rel_emb_real = self.emb_rel_real(rel).view(Config.batch_size, -1)
        e1_emb_img = self.emb_e_img(e1).view(Config.batch_size, -1)
        rel_emb_img = self.emb_rel_img(rel).view(Config.batch_size, -1)

        # Begin literals

        e1_num_lit = self.numerical_literals[e1.view(-1)]
        e1_emb_real = self.emb_num_lit_real(e1_emb_real, e1_num_lit)
        e1_emb_img = self.emb_num_lit_img(e1_emb_img, e1_num_lit)

        e2_multi_emb_real = self.emb_num_lit_real(self.emb_e_real.weight, self.numerical_literals)
        e2_multi_emb_img = self.emb_num_lit_img(self.emb_e_img.weight, self.numerical_literals)

        # End literals

        e1_emb_real = self.inp_drop(e1_emb_real)
        rel_emb_real = self.inp_drop(rel_emb_real)
        e1_emb_img = self.inp_drop(e1_emb_img)
        rel_emb_img = self.inp_drop(rel_emb_img)

        realrealreal = torch.mm(e1_emb_real*rel_emb_real, e2_multi_emb_real.t())
        realimgimg = torch.mm(e1_emb_real*rel_emb_img, e2_multi_emb_img.t())
        imgrealimg = torch.mm(e1_emb_img*rel_emb_real, e2_multi_emb_img.t())
        imgimgreal = torch.mm(e1_emb_img*rel_emb_img, e2_multi_emb_real.t())

        pred = realrealreal + realimgimg + imgrealimg - imgimgreal
        pred = torch.sigmoid(pred)

        return pred


class ConvELiteral_gate(torch.nn.Module):

    def __init__(self, num_entities, num_relations, numerical_literals):
        super(ConvELiteral_gate, self).__init__()

        self.emb_dim = Config.embedding_dim

        self.emb_e = torch.nn.Embedding(num_entities, self.emb_dim, padding_idx=0)
        self.emb_rel = torch.nn.Embedding(num_relations, self.emb_dim, padding_idx=0)

        # Literal
        # num_ent x n_num_lit
        self.numerical_literals = Variable(torch.from_numpy(numerical_literals)).cuda()
        self.n_num_lit = self.numerical_literals.size(1)

        self.emb_num_lit = Gate(self.emb_dim+self.n_num_lit, self.emb_dim)

        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.hidden_drop = torch.nn.Dropout(Config.dropout)
        self.feature_map_drop = torch.nn.Dropout2d(Config.feature_map_dropout)
        self.loss = torch.nn.BCELoss()

        self.conv1 = torch.nn.Conv2d(1, 32, (3, 3), 1, 0, bias=Config.use_bias)
        self.bn0 = torch.nn.BatchNorm2d(1)
        self.bn1 = torch.nn.BatchNorm2d(32)
        self.bn2 = torch.nn.BatchNorm1d(self.emb_dim)
        self.register_parameter('b', Parameter(torch.zeros(num_entities)))
        self.fc = torch.nn.Linear(10368, self.emb_dim)
        print(num_entities, num_relations)

    def init(self):
        xavier_normal_(self.emb_e.weight.data)
        xavier_normal_(self.emb_rel.weight.data)

    def forward(self, e1, rel):
        e1_emb = self.emb_e(e1).view(Config.batch_size, -1)
        rel_emb = self.emb_rel(rel)

        # Begin literals

        e1_num_lit = self.numerical_literals[e1.view(-1)]

        e1_emb = self.emb_num_lit(e1_emb, e1_num_lit)
        e2_multi_emb = self.emb_num_lit(self.emb_e.weight, self.numerical_literals)

        # End literals

        e1_emb = e1_emb.view(Config.batch_size, 1, 10, self.emb_dim//10)
        rel_emb = rel_emb.view(Config.batch_size, 1, 10, self.emb_dim//10)

        stacked_inputs = torch.cat([e1_emb, rel_emb], 2)

        stacked_inputs = self.bn0(stacked_inputs)
        x = self.inp_drop(stacked_inputs)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.view(Config.batch_size, -1)
        # print(x.size())
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = torch.mm(x, e2_multi_emb.t())
        x += self.b.expand_as(x)
        pred = torch.sigmoid(x)

        return pred


"""
TEXT LITERALS
-----------------------------------
"""

class GateMulti(nn.Module):

    def __init__(self, emb_size, num_lit_size, txt_lit_size, gate_activation=torch.sigmoid):
        super(GateMulti, self).__init__()

        self.emb_size = emb_size
        self.num_lit_size = num_lit_size
        self.txt_lit_size = txt_lit_size

        self.gate_activation = gate_activation
        self.g = nn.Linear(emb_size+num_lit_size+txt_lit_size, emb_size)

        self.gate_ent = nn.Linear(emb_size, emb_size, bias=False)
        self.gate_num_lit = nn.Linear(num_lit_size, emb_size, bias=False)
        self.gate_txt_lit = nn.Linear(txt_lit_size, emb_size, bias=False)
        self.gate_bias = nn.Parameter(torch.zeros(emb_size))

    def forward(self, x_ent, x_lit_num, x_lit_txt):
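        # Same gating as Gate above, extended to two literal inputs
        # (numerical and text features) feeding both g and the gate.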
        x = torch.cat([x_ent, x_lit_num, x_lit_txt], 1)
        g_embedded = torch.tanh(self.g(x))
        gate = self.gate_activation(self.gate_ent(x_ent) + self.gate_num_lit(x_lit_num) + self.gate_txt_lit(x_lit_txt) + self.gate_bias)
        output = (1-gate) * x_ent + gate * g_embedded

        return output


class DistMultLiteral_gate_text(torch.nn.Module):

    def __init__(self, num_entities, num_relations, numerical_literals, text_literals):
        super(DistMultLiteral_gate_text, self).__init__()

        self.emb_dim = Config.embedding_dim

        self.emb_e = torch.nn.Embedding(num_entities, self.emb_dim, padding_idx=0)
        self.emb_rel = torch.nn.Embedding(num_relations, self.emb_dim, padding_idx=0)

        # Num. Literal
        # num_ent x n_num_lit
        self.numerical_literals = Variable(torch.from_numpy(numerical_literals)).cuda()
        self.n_num_lit = self.numerical_literals.size(1)

        # Txt. Literal
        # num_ent x n_txt_lit
        self.text_literals = Variable(torch.from_numpy(text_literals)).cuda()
        self.n_txt_lit = self.text_literals.size(1)

        # LiteralE's g
        self.emb_lit = GateMulti(self.emb_dim, self.n_num_lit, self.n_txt_lit)

        # Dropout + loss
        self.inp_drop = torch.nn.Dropout(Config.input_dropout)
        self.loss = torch.nn.BCELoss()

    def init(self):
        xavier_normal_(self.emb_e.weight.data)
        xavier_normal_(self.emb_rel.weight.data)

    def forward(self, e1, rel):
        e1_emb = self.emb_e(e1)
        rel_emb = self.emb_rel(rel)

        e1_emb = e1_emb.view(-1, self.emb_dim)
        rel_emb = rel_emb.view(-1, self.emb_dim)

        # Begin literals
        # --------------
        e1_num_lit = self.numerical_literals[e1.view(-1)]
        e1_txt_lit = self.text_literals[e1.view(-1)]
        e1_emb = self.emb_lit(e1_emb, e1_num_lit, e1_txt_lit)
        e2_multi_emb = self.emb_lit(self.emb_e.weight, self.numerical_literals, self.text_literals)
        # --------------
        # End literals

        e1_emb = self.inp_drop(e1_emb)
        rel_emb = self.inp_drop(rel_emb)

        pred = torch.mm(e1_emb*rel_emb, e2_multi_emb.t())
        pred = torch.sigmoid(pred)

        return pred