import logging
import pickle
import numpy as np
from nltk.tokenize import TweetTokenizer

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils

# Size of each token embedding vector (second argument of nn.Embedding in ModelMultimodal).
MM_EMBEDDINGS_DIM = 50
# Hidden state size of the LSTM that ModelMultimodal runs over the embedded tokens.
MM_HIDDEN_SIZE = 128
# Default token dictionary capacity, shared by ModelMultimodal and MultimodalPreprocessor.
MM_MAX_DICT_SIZE = 100

# Token pre-registered with id 0 by MultimodalPreprocessor; out-of-dictionary words map to that id.
TOKEN_UNK = "#unk"


class Model(nn.Module):
    """
    Image-only actor-critic network.

    Two convolution layers extract features, which are flattened and fed into two
    independent linear heads: one producing a score per action and one producing
    a single state value.

    :param input_shape: shape of one observation, channels first (its first
        element is used as the number of input channels)
    :param n_actions: number of outputs of the policy head
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()

        in_channels = input_shape[0]
        conv_layers = [
            nn.Conv2d(in_channels, 64, 5, stride=5),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=2),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*conv_layers)

        # Width of the flattened convolution output, measured on a dummy input
        n_features = self._get_conv_out(input_shape)

        self.policy = nn.Sequential(nn.Linear(n_features, n_actions))
        self.value = nn.Sequential(nn.Linear(n_features, 1))

    def _get_conv_out(self, shape):
        """
        Return the number of values the convolution stack emits for one observation
        :param shape: shape of a single observation (without the batch dimension)
        """
        probe = torch.zeros(1, *shape)
        return int(np.prod(self.conv(probe).size()))

    def forward(self, x):
        """
        Compute both heads for a batch of observations
        :param x: batch of observations; converted to float and divided by 256
            (presumably 8-bit pixel values, verify against the caller)
        :return: tuple (policy head output, value head output)
        """
        scaled = x.float() / 256
        batch_size = scaled.size()[0]
        features = self.conv(scaled).view(batch_size, -1)
        return self.policy(features), self.value(features)


class ModelMultimodal(nn.Module):
    """
    Actor-critic network consuming an image together with a token sequence.

    The image goes through the same convolution stack as the image-only model,
    the token ids are embedded and run through a single-layer LSTM. The flattened
    convolution output is concatenated with the final LSTM hidden and cell states
    and fed into two linear heads (policy and value).

    :param input_shape: shape of one image observation, channels first
    :param n_actions: number of outputs of the policy head
    :param max_dict_size: number of rows in the embedding table; token ids passed
        to forward() must be smaller than this value
    """

    def __init__(self, input_shape, n_actions, max_dict_size=MM_MAX_DICT_SIZE):
        super(ModelMultimodal, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 64, 5, stride=5),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=2),
            nn.ReLU(),
        )

        conv_out_size = self._get_conv_out(input_shape)

        self.emb = nn.Embedding(max_dict_size, MM_EMBEDDINGS_DIM)
        self.rnn = nn.LSTM(MM_EMBEDDINGS_DIM, MM_HIDDEN_SIZE, batch_first=True)

        # The LSTM returns both a hidden and a cell state and both are used as
        # text features, hence the factor of two.
        self.policy = nn.Sequential(
            nn.Linear(conv_out_size + MM_HIDDEN_SIZE*2, n_actions),
        )

        self.value = nn.Sequential(
            nn.Linear(conv_out_size + MM_HIDDEN_SIZE*2, 1),
        )

    def _get_conv_out(self, shape):
        """
        Return the number of values the convolution stack emits for one observation
        :param shape: shape of a single image (without the batch dimension)
        """
        # Only the output shape is needed, so skip building an autograd graph
        with torch.no_grad():
            o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def _concat_features(self, img_out, rnn_hidden):
        """
        Join flattened image features with the final RNN state
        :param img_out: tensor whose first dimension is the batch size
        :param rnn_hidden: final RNN state, either a single tensor or a tuple of
            tensors (the LSTM used here returns a tuple of hidden and cell state)
        :return: tensor with the image features followed by the RNN features
        """
        batch_size = img_out.size()[0]
        if isinstance(rnn_hidden, tuple):
            flat_h = [t.view(batch_size, -1) for t in rnn_hidden]
            rnn_h = torch.cat(flat_h, dim=1)
        else:
            rnn_h = rnn_hidden.view(batch_size, -1)
        return torch.cat((img_out, rnn_h), dim=1)

    def forward(self, x):
        """
        Compute both heads for a batch of multimodal observations
        :param x: tuple (images, text) where text is a PackedSequence of token ids
        :return: tuple (policy head output, value head output)
        """
        x_img, x_text = x
        assert isinstance(x_text, rnn_utils.PackedSequence), \
            "text observation must be a PackedSequence"

        # deal with text data
        emb_out = self.emb(x_text.data)
        # Keep the permutation recorded by pack_padded_sequence(enforce_sorted=False).
        # Without it the LSTM returns its final state in length-sorted order and the
        # text features would be paired with the wrong images. For already sorted
        # input both index fields are None and nothing changes.
        emb_out_seq = rnn_utils.PackedSequence(
            emb_out,
            x_text.batch_sizes,
            x_text.sorted_indices,
            x_text.unsorted_indices,
        )
        rnn_out, rnn_h = self.rnn(emb_out_seq)

        # extract image features
        fx = x_img.float() / 256
        conv_out = self.conv(fx).view(fx.size()[0], -1)

        feats = self._concat_features(conv_out, rnn_h)
        return self.policy(feats), self.value(feats)


class MultimodalPreprocessor:
    """
    Turn batches of (image, text) observations into tensors for ModelMultimodal.

    Text is tokenized and every new token gets the next free integer id until
    max_dict_size ids are in use; after that unseen tokens map to id 0 (#unk).
    The dictionary therefore grows as a side effect of calling the object.
    """
    log = logging.getLogger("MulitmodalPreprocessor")

    def __init__(self, max_dict_size=MM_MAX_DICT_SIZE, device="cpu"):
        """
        :param max_dict_size: upper bound (exclusive) for token ids handed out
        :param device: device the produced tensors are moved to
        """
        self.max_dict_size = max_dict_size
        self.token_to_id = {TOKEN_UNK: 0}
        self.next_id = 1
        self.tokenizer = TweetTokenizer(preserve_case=True)
        self.device = device

    def __len__(self):
        """Return the number of tokens currently in the dictionary (including #unk)"""
        return len(self.token_to_id)

    def __call__(self, batch):
        """
        Convert list of multimodal observations (tuples with image and text string) into the form suitable
        for ModelMultimodal to digest

        NOTE: the returned samples are ordered by decreasing token count, NOT in the
        order of the input batch. A caller which pairs the result with per-sample
        data (actions, rewards, ...) has to take this reordering into account.

        :param batch: non-empty iterable of (image, text) tuples
        :return: tuple (float image tensor, PackedSequence with token ids)
        """
        # Tokenize in input order so that ids are assigned in the order tokens are seen
        samples = [
            (img_obs, self.tokens_to_idx(self.tokenizer.tokenize(txt_obs)))
            for img_obs, txt_obs in batch
        ]
        # sort batch decreasing to seq len (the sort is stable, ties keep input order)
        samples.sort(key=lambda p: len(p[1]), reverse=True)
        img_batch, seq_batch = zip(*samples)
        # Map empty sequences into single #UNK token (id 0 is the padding value below)
        lens = [max(len(seq), 1) for seq in seq_batch]

        # images: stack through numpy first, building a tensor directly from a
        # sequence of arrays is very slow
        img_v = torch.from_numpy(np.asarray(img_batch, dtype=np.float32)).to(self.device)

        # sequences: zero-padded matrix, the first row is the longest one
        seq_arr = np.zeros(shape=(len(seq_batch), lens[0]), dtype=np.int64)
        for idx, seq in enumerate(seq_batch):
            seq_arr[idx, :len(seq)] = seq
        seq_v = torch.from_numpy(seq_arr).to(self.device)
        seq_p = rnn_utils.pack_padded_sequence(seq_v, lens, batch_first=True)
        return img_v, seq_p

    def tokens_to_idx(self, tokens):
        """
        Translate tokens into ids, registering unseen tokens while there is room
        :param tokens: iterable of token strings
        :return: list of integer ids, one per token
        """
        res = []
        for token in tokens:
            idx = self.token_to_id.get(token)
            if idx is None:
                # ">=" rather than "==" so the limit also holds if next_id ever
                # starts beyond it (e.g. max_dict_size of 0 or a loaded state)
                if self.next_id >= self.max_dict_size:
                    self.log.warning("Maximum size of dict reached, token '%s' converted to #UNK token", token)
                    idx = 0
                else:
                    idx = self.next_id
                    self.next_id += 1
                    self.token_to_id[token] = idx
            res.append(idx)
        return res

    def save(self, file_name):
        """
        Write the dictionary state into a file
        :param file_name: path of the file to create; it receives three consecutive
            pickles: token_to_id, max_dict_size and next_id
        """
        with open(file_name, 'wb') as fd:
            pickle.dump(self.token_to_id, fd)
            pickle.dump(self.max_dict_size, fd)
            pickle.dump(self.next_id, fd)

    @classmethod
    def load(cls, file_name, device="cpu"):
        """
        Restore a preprocessor from a file written by save()
        :param file_name: path of the file to read
        :param device: device the restored preprocessor puts its tensors on
        :return: new instance of the class load() was called on
        """
        # SECURITY: pickle.load executes arbitrary code embedded in the file,
        # only load files coming from a trusted source.
        with open(file_name, "rb") as fd:
            token_to_id = pickle.load(fd)
            max_dict_size = pickle.load(fd)
            next_id = pickle.load(fd)

        res = cls(max_dict_size, device=device)
        res.token_to_id = token_to_id
        res.next_id = next_id
        return res


def _to_device(data, device):
    """
    Move a tensor, a PackedSequence or a (nested) tuple/list of them to the device
    """
    # PackedSequence is a tuple subclass but has its own .to(), so check for the
    # method before treating the value as a plain container.
    if hasattr(data, "to"):
        return data.to(device)
    return tuple(_to_device(item, device) for item in data)


def train_demo(net, optimizer, batch, writer, step_idx, preprocessor, device="cpu"):
    """
    Train net on demonstration batch

    Performs one optimizer step minimizing the cross-entropy between the first
    output of the network and the demonstrated actions, then logs the loss.

    NOTE: observations and actions are matched by position. A preprocessor that
    reorders the samples (MultimodalPreprocessor sorts them by text length)
    breaks this pairing unless the batch is already in that order.

    :param net: callable returning a sequence whose first item holds the action scores
    :param optimizer: optimizer over the parameters of net
    :param batch: iterable of (observation, action index) tuples
    :param writer: object with add_scalar(tag, value, step), receives "demo_loss"
    :param step_idx: step value passed to the writer
    :param preprocessor: callable converting the observations into the network
        input, either a single tensor or a tuple such as (images, packed text)
    :param device: device the network input and the actions are moved to
    """
    batch_obs, batch_act = zip(*batch)
    # The preprocessor may return a tuple, which has no .to() of its own
    batch_v = _to_device(preprocessor(batch_obs), device)
    optimizer.zero_grad()
    ref_actions_v = torch.LongTensor(batch_act).to(device)
    policy_v = net(batch_v)[0]
    loss_v = F.cross_entropy(policy_v, ref_actions_v)
    loss_v.backward()
    optimizer.step()
    writer.add_scalar("demo_loss", loss_v.item(), step_idx)