import torch
import torch.nn as nn
import torch.nn.functional as F
import math, random
from utils import activation_method
from collections import defaultdict


def create_module(module_type, **config):
    module_type = module_type.lower()
    if module_type == 'mlp':
        return MLP(**config)
    elif module_type == 'gcn':
        return AttentionGCN(**config)
    elif module_type == 'empty':
        return nn.Sequential()
    else:
        raise NotImplementedError(module_type)
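
# Illustrative usage sketch (not part of the original module); the sizes below are
# placeholder values chosen only for the example:
#
#   encoder = create_module('gcn', input_size=64, layer_num=2, attnhead_num=4)
#   scorer = create_module('mlp', input_size=128, hidden_layers=[(64, True, 0.1)], final_size=1)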


class MLP(nn.Module):
    def __init__(self, input_size, hidden_layers, final_size=0, final_activation="none", normalization="batch_norm",
                 activation='relu'):
        """
        :param input_size:
        :param hidden_layers: [(unit_num, normalization, dropout_rate)]
        :param final_size:
        :param final_activation:
        """
        nn.Module.__init__(self)
        self.input_size = input_size
        fcs = []
        last_size = self.input_size
        for size, to_norm, dropout_rate in hidden_layers:
            linear = nn.Linear(last_size, size)
            linear.bias.data.fill_(0.0)
            fcs.append(linear)
            last_size = size
            if to_norm:
                if normalization == 'batch_norm':
                    fcs.append(nn.BatchNorm1d(last_size))
                elif normalization == 'layer_norm':
                    fcs.append(nn.LayerNorm(last_size))
            fcs.append(activation_method(activation))
            if dropout_rate > 0.0:
                fcs.append(nn.Dropout(dropout_rate))
        self.fc = nn.Sequential(*fcs)
        if final_size > 0:
            linear = nn.Linear(last_size, final_size)
            linear.bias.data.fill_(0.0)
            finals = [linear, activation_method(final_activation)]
        else:
            finals = []
        self.final_layer = nn.Sequential(*finals)

    def forward(self, x):
        out = self.fc(x)
        out = self.final_layer(out)
        return out
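
# Illustrative usage sketch (assumes `activation_method` returns standard nn modules;
# all sizes below are placeholders, not values from the original code):
#
#   mlp = MLP(input_size=64, hidden_layers=[(128, True, 0.1), (64, False, 0.0)], final_size=32)
#   out = mlp(torch.randn(8, 64))   # -> (8, 32)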


class MultiheadAttention(nn.Module):
    def __init__(self, input_size, query_size, value_size, head_num, dropout=0.0, concatenate=True, configurable=False,
                 use_dot=True):
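        """
        :param input_size: dimension of the query input (and of the returned attention value)
        :param query_size: dimension of each key vector and of each per-head query
        :param value_size: dimension of each value vector before projection
        :param head_num: number of attention heads
        :param dropout: dropout rate applied to the attention weights
        :param concatenate: if True, concatenate the per-head outputs; otherwise average them
        :param configurable: if True, weights can be conditioned on an external vector via configure()
        :param use_dot: if True, use scaled dot-product scores; otherwise score with a linear layer
            over the concatenated query and key
        """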
        nn.Module.__init__(self)
        self.use_dot = use_dot
        if use_dot is True:
            self.query_heads = nn.Linear(input_size, head_num * query_size, bias=True)
        else:
            self.query_heads = nn.Linear(query_size + input_size, head_num, bias=False)
        self.head_num = head_num
        self.concatenate = concatenate
        self.input_size = input_size
        self.value_size = value_size
        if concatenate:
            self.value_proj = nn.Linear(value_size, input_size)
        else:
            self.value_proj = nn.Linear(value_size, input_size * head_num)
        if configurable:
            self.param_divide(self.query_heads, with_query=True)
            self.param_divide(self.value_proj, with_query=True)
        if dropout > 0.0:
            self.attn_dropout = nn.Dropout(dropout)
        else:
            self.attn_dropout = None
        self.attn = None

    @staticmethod
    def param_divide(linear_module, with_query):
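        """
        Replace the layer's weight parameter with a shared parameter ('share_weight') and,
        optionally, a learnable 'query' vector, so that configure() can later compose an
        input-conditioned weight from the two.
        """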
        weight = getattr(linear_module, 'weight')
        del linear_module._parameters['weight']
        linear_module.register_parameter('share_weight', weight)
        setattr(linear_module, 'weight', weight.data)
        if with_query:
            input_size, output_size = linear_module.in_features, linear_module.out_features
            bound = math.sqrt(6.0 / input_size)
            query_vector = torch.empty(output_size, dtype=torch.float)
            nn.init.uniform_(query_vector, -bound, bound)
            linear_module.register_parameter('query', nn.Parameter(query_vector))

    def configure(self, in_vector):
        """
        :param in_vector: (2, in_features)
        :return:
        """
        setattr(self.query_heads, 'weight',
                torch.matmul(self.query_heads.query.unsqueeze(-1), in_vector[0:1]) + self.query_heads.share_weight)
        setattr(self.value_proj, 'weight',
                torch.matmul(self.value_proj.query.unsqueeze(-1), in_vector[1:2]) + self.value_proj.share_weight)

    @staticmethod
    def attention(scores, value, mask=None, dropout=None):
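        """
        Standard softmax attention: masked positions receive a large negative score, the
        scores are normalized with softmax over the last dimension, and the (optionally
        dropped-out) weights are applied to the values. Returns the attended values and
        the pre-dropout attention weights.
        """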
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        p_attn_org = p_attn
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.matmul(p_attn, value), p_attn_org

    def forward(self, query, key, value, mask=None):
        """
        :param query: (batch_size, input_size)
        :param key: (batch_size, max_len, query_size)
        :param value: (batch_size, max_len, value_size)
        :return:
        """
        batch_size, max_len = key.size(0), key.size(1)
        value_size = self.value_proj.out_features // self.head_num
        value = self.value_proj(value)
        # (batch_size, head_num, max_len, per-head value size)
        value = value.view(batch_size, max_len, self.head_num, value_size).transpose(1, 2)
        # query: (batch_size, input_size); key: (batch_size, max_len, query_size)
        if self.use_dot:
            attnhead_size = self.query_heads.out_features // self.head_num
            query = self.query_heads(query)
            query = query.view(batch_size, self.head_num, 1, attnhead_size)
            # (batch_size, head_num, max_len, query_size)
            key = key.unsqueeze(1).expand(-1, self.head_num, -1, -1)
            # scaled dot-product scores: (batch_size, head_num, 1, max_len)
            scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(attnhead_size)
        else:
            # batch_size, max_len, query_size + input_size
            query = torch.cat((query.unsqueeze(1).expand(-1, max_len, -1), key), dim=-1)
            # concat-based scores: (batch_size, head_num, 1, max_len)
            scores = self.query_heads(query).transpose(-1, -2).unsqueeze(-2)
        attn_value, attn = self.attention(scores, value, mask=mask, dropout=self.attn_dropout)
        self.attn = attn.detach()
        if self.concatenate:
            attn_value = attn_value.view(batch_size, -1)
        else:
            attn_value = attn_value.mean(dim=1).squeeze(1)
        return attn_value
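
# Illustrative usage sketch (shapes follow the forward() docstring; the numbers are
# placeholders, not values from the original code):
#
#   attn = MultiheadAttention(input_size=64, query_size=64, value_size=64, head_num=4)
#   out = attn(torch.randn(8, 64),         # query: (batch_size, input_size)
#              torch.randn(8, 10, 64),     # key:   (batch_size, max_len, query_size)
#              torch.randn(8, 10, 64))     # value: (batch_size, max_len, value_size)
#   # out: (8, 64)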


class AttentionGCN(nn.Module):
    def __init__(self, input_size, layer_num, attnhead_num, attention=True, activation='relu', layer_norm=False,
                 attn_dropout=0.0, configurable=False, use_dot=True):
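        """
        :param input_size: node/neighbor embedding dimension (kept constant across layers)
        :param layer_num: number of aggregation layers
        :param attnhead_num: number of attention heads per layer
        :param attention: if False, neighbors are mean-pooled and passed through a linear layer instead
        :param activation: activation between layers (the final layer uses no activation)
        :param layer_norm: apply LayerNorm after every non-final layer
        :param attn_dropout: dropout rate on the attention weights
        :param configurable: passed through to MultiheadAttention (param_divide / configure)
        :param use_dot: passed through to MultiheadAttention (dot-product vs. concat scoring)
        """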
        nn.Module.__init__(self)
        self.layer_num = layer_num
        self.input_size = input_size
        self.attnhead_num = attnhead_num
        self.layers = []
        attn_heads, node_projs, activations, normalizations = [], [], [], []
        self.d_k = input_size // attnhead_num
        self.use_attention = attention
        for i in range(layer_num):
            if attention:
                attn_head = MultiheadAttention(input_size=input_size, query_size=input_size, value_size=input_size,
                                               head_num=attnhead_num, use_dot=use_dot,
                                               configurable=configurable, dropout=attn_dropout,
                                               concatenate=i != layer_num - 1)
            else:
                attn_head = nn.Linear(input_size, input_size)
            attn_heads.append(attn_head)
            node_projs.append(nn.Linear(input_size, input_size, bias=True))
            if i != layer_num - 1:
                activations.append(activation_method(activation))
            else:
                activations.append(activation_method('none'))
            if layer_norm and i != layer_num - 1:
                normalizations.append(nn.LayerNorm(input_size))
            else:
                normalizations.append(nn.Sequential())
        if len(attn_heads) > 0:
            self.attn_heads = nn.ModuleList(attn_heads)
        self.node_projs = nn.ModuleList(node_projs)
        self.activations = nn.ModuleList(activations)
        self.normalizations = nn.ModuleList(normalizations)
        self.attns = []
        if attn_dropout > 0.0:
            self.atten_dropout = nn.Dropout(attn_dropout)
        else:
            self.atten_dropout = None

    @staticmethod
    def generate_mask(lengths, max_len):
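        """
        Build a (batch_size, max_len) boolean mask that is True at positions smaller than
        the corresponding entry of `lengths`.
        """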
        batch_size = lengths.size(0)
        masks = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
        masks = masks < lengths.unsqueeze(1)
        return masks

    def forward(self, node_embeds, neighbor_embeds, node_degrees=None):
        """
        :param node_embeds: (batch_size, embed_dim)
        :param neighbor_embeds: (batch_size, max_len, embed_dim)
        :param node_degrees: (batch_size,)
        :return:
        """
        batch_size, max_len, embed_dim = neighbor_embeds.size()
        if node_degrees is not None:
            neighbor_mask = self.generate_mask(node_degrees, max_len).view(batch_size, 1, 1, max_len)
            neighbor_embeds = neighbor_embeds.clone().masked_fill_(neighbor_mask.view(batch_size, max_len, 1) == 0, 0.0)
        else:
            neighbor_mask = None
        if not self.use_attention:
            if node_degrees is None:
                neighbor_embeds = neighbor_embeds.mean(dim=1)
            else:
                neighbor_embeds = neighbor_embeds.sum(dim=1)
                nonzeros = node_degrees.nonzero().squeeze(-1)
                neighbor_embeds[nonzeros] /= node_degrees[nonzeros].unsqueeze(-1).type(dtype=torch.float)
        for i in range(self.layer_num):
            node_proj, activation, normalization = self.node_projs[i], \
                                                   self.activations[i], self.normalizations[i]
            attn_head = self.attn_heads[i]
            if self.use_attention:
                attn_value = attn_head(node_embeds, neighbor_embeds, neighbor_embeds, mask=neighbor_mask)
            else:
                attn_value = attn_head(neighbor_embeds)
            node_embeds = node_proj(node_embeds) + attn_value
            node_embeds = normalization(node_embeds)
            node_embeds = activation(node_embeds)
        return node_embeds
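
# Illustrative usage sketch (placeholder sizes; padded neighbors are masked via node_degrees):
#
#   gcn = AttentionGCN(input_size=64, layer_num=2, attnhead_num=4)
#   nodes = torch.randn(8, 64)               # (batch_size, embed_dim)
#   neighbors = torch.randn(8, 10, 64)       # (batch_size, max_len, embed_dim)
#   degrees = torch.randint(1, 11, (8,))     # number of valid neighbors per node
#   out = gcn(nodes, neighbors, degrees)     # -> (8, 64)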


class Recommender(nn.Module):
    def __init__(self, useritem_embeds, user_graph=False, item_graph=False):
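        """
        :param useritem_embeds: embedding module, called as useritem_embeds(*ids, is_user=..., with_neighbor=...)
        :param user_graph: if True, user lookups are done with with_neighbor=True
        :param item_graph: if True, item lookups are done with with_neighbor=True
        """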
        nn.Module.__init__(self)
        self.useritem_embeds = useritem_embeds
        self.user_graph = user_graph
        self.item_graph = item_graph

    def forward(self, query_users, query_items, with_attr=False):
        if query_users[0].dim() > 1:
            query_users = list(map(lambda x: x.squeeze(0), query_users))
        if query_items[0].dim() > 1:
            query_items = list(map(lambda x: x.squeeze(0), query_items))
        if not with_attr:
            query_users = self.useritem_embeds(*query_users, is_user=True, with_neighbor=self.user_graph)
            query_items = self.useritem_embeds(*query_items, is_user=False, with_neighbor=self.item_graph)
        return query_users, query_items


class InteractionRecommender(Recommender):
    def __init__(self, useritem_embeds, mlp_config):
        super(InteractionRecommender, self).__init__(useritem_embeds)
        self.mlp = MLP(**mlp_config)

    def forward(self, query_users, query_items, support_users=None, support_items=None, with_attr=False):
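        """
        Scores each (user, item) pair by feeding the concatenated embeddings through the MLP.
        support_users and support_items are accepted but not used in this forward pass.
        """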
        query_users, query_items = super(InteractionRecommender, self).forward(query_users, query_items,
                                                                               with_attr=with_attr)
        query_users, query_items = query_users[0], query_items[0]
        if query_users.size(0) == 1:
            query_users = query_users.expand(query_items.size(0), -1)
        query_embeds = torch.cat((query_users, query_items), dim=1)
        return self.mlp(query_embeds).squeeze(1)


class EmbedRecommender(Recommender):
    def __init__(self, useritem_embeds, user_config, item_config, user_graph=True, item_graph=True):
        super(EmbedRecommender, self).__init__(useritem_embeds, user_graph, item_graph)
        self.user_model = create_module(**user_config)
        self.item_model = create_module(**item_config)

    def forward(self, query_users, query_items, with_attr=False):
        """
        :param with_attr:
        :param query_users: (batch_size,)
        :param query_items: (batch_size)
        :return:
        """
        query_users, query_items = Recommender.forward(self, query_users, query_items, with_attr=with_attr)
        query_users = self.user_model(*query_users)
        query_items = self.item_model(*query_items)
        return (query_users * query_items).sum(dim=1)


class CoNet(nn.Module):
    def __init__(self, useritem_embeds, source_ratings, item_padding_idx, input_size, hidden_layers):
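        """
        Collaborative cross network: two stacked MLP towers (target and auxiliary domain)
        joined by shared transfer layers that exchange hidden states between the towers.

        :param useritem_embeds: shared user/item embedding module
        :param source_ratings: mapping from user id to the list of items rated in the source domain
        :param item_padding_idx: item index used when a user has no source-domain ratings
        :param input_size: embedding size of users and items
        :param hidden_layers: list of hidden sizes shared by both towers
        """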
        nn.Module.__init__(self)
        self.useritem_embeds = useritem_embeds
        self.source_ratings = source_ratings
        self.item_padding_idx = item_padding_idx
        last_size = input_size * 2
        layers1, layers2, transfer_layers = [], [], []
        for hidden_size in hidden_layers:
            layers1.append(nn.Linear(last_size, hidden_size))
            layers2.append(nn.Linear(last_size, hidden_size))
            transfer_layers.append(nn.Linear(last_size, hidden_size))
            last_size = hidden_size
        self.target_layers = nn.ModuleList(layers1)
        self.auxiliary_layers = nn.ModuleList(layers2)
        self.transfer_layers = nn.ModuleList(transfer_layers)
        self.target_output = nn.Linear(last_size, 1)
        self.auxiliary_output = nn.Linear(last_size, 1)

    def forward(self, query_users, target_items, auxiliary_items=None):
        only_target = False
        if auxiliary_items is None:
            # No auxiliary items given: sample one source-domain item per user from that
            # user's source ratings, falling back to the padding index for cold users.
            only_target = True
            auxiliary_items = [random.choice(self.source_ratings[user_id.item()])
                               if len(self.source_ratings[user_id.item()]) > 0 else self.item_padding_idx
                               for user_id in query_users[0]]
            auxiliary_items = (torch.tensor(auxiliary_items, dtype=torch.long, device=query_users[0].device),)
        query_users = list(map(lambda x: x.expand(target_items[0].size(0)), query_users))
        auxiliary_items = list(map(lambda x: x.expand(target_items[0].size(0)), auxiliary_items))
        query_users = self.useritem_embeds(*query_users, is_user=True)
        target_items, auxiliary_items = self.useritem_embeds(*target_items, is_user=False), self.useritem_embeds(
            *auxiliary_items, is_user=False)
        target_x = torch.cat((*query_users, *target_items), dim=1)
        auxiliary_x = torch.cat((*query_users, *auxiliary_items), dim=1)
        for target_layer, auxiliary_layer, transfer_layer in zip(self.target_layers, self.auxiliary_layers,
                                                                 self.transfer_layers):
            # Cross connections: each domain combines its own projection with the shared
            # transfer projection of the other domain's hidden state.
            new_target_x = target_layer(target_x) + transfer_layer(auxiliary_x)
            new_auxiliary_x = auxiliary_layer(auxiliary_x) + transfer_layer(target_x)
            target_x, auxiliary_x = new_target_x, new_auxiliary_x
            target_x, auxiliary_x = torch.relu_(target_x), torch.relu_(auxiliary_x)
        if only_target:
            return self.target_output(target_x).squeeze(-1)
        else:
            return self.target_output(target_x).squeeze(-1), self.auxiliary_output(auxiliary_x).squeeze(-1)


class HybridRecommender(Recommender):
    def __init__(self, useritem_embeds, input_size, hidden_layers, final_size, activation='relu',
                 normalization="batch_norm"):
        super(HybridRecommender, self).__init__(useritem_embeds, False, False)
        self.interaction_model = MLP(input_size=2 * input_size, hidden_layers=hidden_layers, activation=activation,
                                     normalization=normalization, final_activation='none', final_size=final_size)
        self.final_layer = nn.Linear(input_size + final_size, 1)

    def forward(self, query_users, query_items, with_attr=False):
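        """
        Combines an MLP over the concatenated user/item embeddings with their element-wise
        product, then maps the concatenation of both to a scalar score per pair.
        """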
        query_users, query_items = Recommender.forward(self, query_users, query_items, with_attr=with_attr)
        query_users, query_items = query_users[0], query_items[0]
        if query_users.size(0) == 1:
            query_users = query_users.expand(query_items.size(0), -1)
        interactions = torch.cat((query_users, query_items), dim=-1)
        interactions = self.interaction_model(interactions)
        product = query_users * query_items
        concatenation = torch.cat((interactions, product), dim=-1)
        return self.final_layer(concatenation).squeeze(-1)