import itertools

import torch
import torch.nn as nn
import torch.nn.init as init

from jacinle.utils.container import GView
from jactorch.functional.indexing import index_one_hot_ellipsis
from jactorch.functional.kernel import cosine_distance
from jactorch.functional.linalg import normalize
from jactorch.nn.rnn_utils import rnn_with_length
from jactorch.graph.variable import var_with
from jactorch.quickstart.models import MLPModel


def cosine(input1, input2):
    """Cosine similarity of two tensors along their last dimension.

    Each argument is passed through ``normalize(..., eps=1e-6)`` (presumably
    unit-L2 normalization over the last axis with an epsilon floor -- confirm
    against jactorch) and the element-wise product is summed over ``dim=-1``,
    so the trailing dimension is reduced away.
    """
    unit_a, unit_b = (normalize(t, eps=1e-6) for t in (input1, input2))
    return torch.sum(unit_a * unit_b, dim=-1)


def cosine_loss(input1, input2):
    """Element-wise cosine loss: ``1 - cosine(input1, input2)``.

    The value is 0 when the two vectors point the same way and grows as they
    diverge. No reduction is applied; callers reduce (e.g. ``.mean()``) themselves.
    """
    similarity = cosine(input1, input2)
    return 1 - similarity


class CompletionModel(nn.Module):
    """Predict a word embedding from forward/backward sentence context plus an image feature.

    Two single-layer unidirectional GRUs encode ``sent_f`` and ``sent_b``. The
    hidden state at each sequence's last valid step is concatenated with
    ``image`` and mapped by an MLP into the embedding space. Training minimizes
    ``cosine_loss`` against the embedding of ``label``. Evaluation reports
    top-k retrieval hit rates over the whole embedding table.
    """

    # Cut-offs for the top-k retrieval metrics reported in evaluation mode.
    TOPK_CUTOFFS = (1, 10, 100, 1000)

    def __init__(self, embedding, hidden_dim=512, image_dim=2048):
        """
        Args:
            embedding: module exposing ``embedding_dim`` and ``weight``
                (presumably ``nn.Embedding``). It is used both to embed input
                tokens and as the label / retrieval space.
            hidden_dim: hidden size of each GRU (default keeps the previous
                hard-coded value 512).
            image_dim: size of the image feature concatenated before the MLP
                (default keeps the previous hard-coded value 2048).
        """
        super(CompletionModel, self).__init__()
        self.embedding = embedding
        self.embedding_dim = embedding.embedding_dim
        self.hidden_dim = hidden_dim
        self.image_dim = image_dim
        self.gru_f = nn.GRU(self.embedding_dim, self.hidden_dim, 1, batch_first=True, bidirectional=False)
        self.gru_b = nn.GRU(self.embedding_dim, self.hidden_dim, 1, batch_first=True, bidirectional=False)
        # MLP input = forward GRU feature + backward GRU feature + image feature.
        self.predict = MLPModel(self.hidden_dim * 2 + self.image_dim, self.embedding_dim, [], activation='relu')
        self.init_weights()

    def init_weights(self):
        """Orthogonally initialize every GRU weight matrix and zero every GRU bias.

        Raises:
            ValueError: if a GRU parameter name starts with neither 'weight' nor 'bias'.
        """
        for name, parameter in itertools.chain(self.gru_f.named_parameters(), self.gru_b.named_parameters()):
            if name.startswith('weight'):
                # ``orthogonal_`` is the in-place API; the un-suffixed ``orthogonal`` is deprecated.
                init.orthogonal_(parameter.data)
            elif name.startswith('bias'):
                parameter.data.zero_()
            else:
                raise ValueError('Unknown parameter type: {}'.format(name))

    def forward(self, feed_dict):
        """Run the model on one batch.

        Args:
            feed_dict: mapping providing ``sent_f``, ``sent_f_length``, ``sent_b``,
                ``sent_b_length``, ``image``, and ``label``. ``label`` is
                required in training and optional in evaluation.

        Returns:
            Training mode: ``(loss, {}, {})``, where ``loss`` is the mean cosine loss.
            Evaluation mode: a dict with ``pred``. When ``label`` is present, it
            also holds ``top1`` / ``top10`` / ``top100`` / ``top1000`` mean hit rates.
        """
        feed_dict = GView(feed_dict)
        feature_f = self._extract_sent_feature(feed_dict.sent_f, feed_dict.sent_f_length, self.gru_f)
        feature_b = self._extract_sent_feature(feed_dict.sent_b, feed_dict.sent_b_length, self.gru_b)
        feature_img = feed_dict.image

        feature = torch.cat([feature_f, feature_b, feature_img], dim=1)
        predict = self.predict(feature)

        if self.training:
            label = self.embedding(feed_dict.label)
            loss = cosine_loss(predict, label).mean()
            return loss, {}, {}

        output_dict = dict(pred=predict)
        if 'label' in feed_dict:
            # NOTE(review): topk keeps the *largest* scores, so this assumes
            # cosine_distance returns a similarity (higher = closer) despite
            # its name -- confirm against jactorch.
            dis = cosine_distance(predict, self.embedding.weight)
            # topk raises if k exceeds the number of candidates, so clamp it to
            # the vocabulary size. Larger cut-offs then simply cover every candidate.
            max_k = min(max(self.TOPK_CUTOFFS), dis.size(1))
            _, topk = dis.topk(max_k, dim=1, sorted=True)
            # hits[i, j] == 1 iff the j-th ranked candidate for sample i is its label.
            hits = torch.eq(topk, feed_dict.label.unsqueeze(-1)).float()
            for k in self.TOPK_CUTOFFS:
                output_dict['top{}'.format(k)] = hits[:, :k].sum(dim=1).mean()
        return output_dict

    def _extract_sent_feature(self, sent, length, gru):
        """Encode a padded batch of token sequences with ``gru``.

        Args:
            sent: token indices, presumably (batch, seq) -- TODO confirm.
            length: per-sample valid lengths (presumably shape (batch,)).
            gru: the GRU to run (``gru_f`` or ``gru_b``, each single-layer and unidirectional).

        Returns:
            The GRU output selected at index ``length - 1`` along dim 1,
            presumably each sequence's last valid step.
        """
        sent = self.embedding(sent)
        batch_size = sent.size(0)

        # h_0 layout for nn.GRU is (num_layers * num_directions, batch, hidden) = (1, batch, hidden) here.
        state_shape = (1, batch_size, self.hidden_dim)
        initial_state = var_with(torch.zeros(state_shape), sent)
        rnn_output, _ = rnn_with_length(gru, sent, length, initial_state)
        rnn_result = index_one_hot_ellipsis(rnn_output, 1, length - 1)

        return rnn_result