import torch
from torch import nn
from torch import distributions


class CaptioningModel(nn.Module):
    """Base class for captioning models.

    Subclasses implement `init_weights`, `init_state` and `step`; this class
    then provides teacher-forced training (`forward`), greedy decoding
    (`test`), multinomial sampling (`sample_rl`) and beam search.
    """

    def __init__(self):
        super().__init__()

    def init_weights(self):
        raise NotImplementedError

    def init_state(self, b_s, device):
        """Return the initial decoder state for a batch of size `b_s`."""
        raise NotImplementedError

    def step(self, t, state, prev_output, images, seq, *args, mode='teacher_forcing'):
        """Run one decoding step and return `(word_logprob, state)`.

        In 'teacher_forcing' mode the input token is taken from `seq`;
        in 'feedback' mode it is `prev_output` (None at t=0).
        """
        raise NotImplementedError

    def forward(self, images, seq, *args):
        """Teacher-forced decoding: the ground-truth token is fed at each step."""
        device = images.device
        b_s = images.size(0)
        seq_len = seq.size(1)
        state = self.init_state(b_s, device)
        out = None

        outputs = []
        for t in range(seq_len):
            out, state = self.step(t, state, out, images, seq, *args, mode='teacher_forcing')
            outputs.append(out)

        # (b_s, seq_len, vocab_size)
        outputs = torch.stack(outputs, 1)
        return outputs
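
    # A minimal training sketch (illustrative: `model`, `images`, `seq`,
    # `pad_idx` and the target alignment are assumptions of the concrete
    # subclass, and `step` is assumed to return log-probabilities):
    #
    #   log_probs = model(images, seq)                   # (b_s, seq_len, vocab)
    #   loss = nn.NLLLoss(ignore_index=pad_idx)(
    #       log_probs.transpose(1, 2), target)           # target: (b_s, seq_len)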

    def test(self, images, seq_len, *args):
        """Greedy decoding: the argmax token is fed back at each step."""
        device = images.device
        b_s = images.size(0)
        state = self.init_state(b_s, device)
        out = None

        outputs = []
        for t in range(seq_len):
            out, state = self.step(t, state, out, images, None, *args, mode='feedback')
            out = out.argmax(-1)  # greedy: pick the most likely token
            outputs.append(out)

        return torch.stack(outputs, 1)
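
    # Greedy decoding sketch (hypothetical values):
    #
    #   captions = model.test(images, seq_len=20)        # (b_s, 20) token ids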

    def sample_rl(self, images, seq_len, *args):
        """Sample a caption from the model's distribution, returning the
        sampled tokens and their log-probabilities (for policy-gradient
        training)."""
        device = images.device
        b_s = images.size(0)
        state = self.init_state(b_s, device)
        out = None

        outputs = []
        log_probs = []
        for t in range(seq_len):
            out, state = self.step(t, state, out, images, None, *args, mode='feedback')
            distr = distributions.Categorical(logits=out)
            out = distr.sample()
            outputs.append(out)
            log_probs.append(distr.log_prob(out))

        return torch.stack(outputs, 1), torch.stack(log_probs, 1)
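
    # Policy-gradient sketch in the style of self-critical training (the
    # reward function and baseline below are placeholders, not provided here):
    #
    #   sampled, log_probs = model.sample_rl(images, seq_len=20)
    #   advantage = reward(sampled) - baseline           # (b_s,)
    #   loss = -(log_probs.sum(-1) * advantage).mean()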

    def beam_search(self, images, seq_len, eos_idx, beam_size, out_size=1, *args):
        """Beam search decoding: keep the `beam_size` most likely partial
        sequences at each step and return the best `out_size` of them,
        together with their per-step log-probabilities."""
        device = images.device
        b_s = images.size(0)
        images_shape = images.shape
        state = self.init_state(b_s, device)

        # seq_mask stays 1 while a beam is alive and drops to 0 once it emits EOS
        seq_mask = images.data.new_ones((b_s, beam_size, 1))
        seq_logprob = images.data.new_zeros((b_s, 1, 1))
        outputs = []
        log_probs = []
        selected_words = None

        for t in range(seq_len):
            cur_beam_size = 1 if t == 0 else beam_size  # all beams share one prefix at t=0

            word_logprob, state = self.step(t, state, selected_words, images, None, *args, mode='feedback')
            old_seq_logprob = seq_logprob
            word_logprob = word_logprob.view(b_s, cur_beam_size, -1)
            seq_logprob = seq_logprob + word_logprob

            # Once a beam emits EOS, freeze it: zero its future word log-probs
            # and force every continuation except token 0 to be highly
            # unlikely, so finished beams are padded with token 0 while
            # keeping their accumulated score.
            if t > 0:
                mask = (selected_words.view(b_s, cur_beam_size) != eos_idx).float().unsqueeze(-1)
                seq_mask = seq_mask * mask
                word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
                old_seq_logprob = old_seq_logprob.expand_as(seq_logprob).contiguous()
                old_seq_logprob[:, :, 1:] = -999
                seq_logprob = seq_mask * seq_logprob + old_seq_logprob * (1 - seq_mask)

            # Select the top `beam_size` candidates over the flattened
            # (beam, word) axis, then recover the beam and word indices:
            # with a vocabulary of size V, flat index i maps to beam i // V
            # and word i % V.
            selected_logprob, selected_idx = torch.sort(seq_logprob.view(b_s, -1), -1, descending=True)
            selected_logprob, selected_idx = selected_logprob[:, :beam_size], selected_idx[:, :beam_size]

            selected_beam = torch.div(selected_idx, seq_logprob.shape[-1], rounding_mode='floor')
            selected_words = selected_idx - selected_beam * seq_logprob.shape[-1]

            # Reorder every state tensor so that it follows the selected beams
            new_state = []
            for s in state:
                shape = [int(sh) for sh in s.shape]
                beam = selected_beam
                for _ in shape[1:]:
                    beam = beam.unsqueeze(-1)
                s = torch.gather(s.view(*([b_s, cur_beam_size] + shape[1:])), 1,
                                 beam.expand(*([b_s, beam_size] + shape[1:])))
                s = s.view(*([-1,] + shape[1:]))
                new_state.append(s)
            state = tuple(new_state)

            # Expand the visual features along the beam dimension and reorder
            # them to follow the selected beams as well
            images_exp_shape = (b_s, cur_beam_size) + images_shape[1:]
            images_red_shape = (b_s * beam_size, ) + images_shape[1:]
            selected_beam_red_size = (b_s, beam_size) + tuple(1 for _ in range(len(images_exp_shape)-2))
            selected_beam_exp_size = (b_s, beam_size) + images_exp_shape[2:]
            images_exp = images.view(images_exp_shape)
            selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
            images = torch.gather(images_exp, 1, selected_beam_exp).view(images_red_shape)
            seq_logprob = selected_logprob.unsqueeze(-1)
            seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1))

            # Reorder the tokens and log-probs emitted so far, then append
            # the newly selected words and their log-probs
            outputs = [torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs]
            outputs.append(selected_words.unsqueeze(-1))

            this_word_logprob = torch.gather(word_logprob, 1,
                                             selected_beam.unsqueeze(-1).expand(b_s, beam_size, word_logprob.shape[-1]))
            this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
            log_probs = [torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(b_s, beam_size, 1)) for o in log_probs]
            log_probs.append(this_word_logprob)
            selected_words = selected_words.view(-1)

        # Sort the completed beams by their total log-probability
        seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)
        outputs = torch.cat(outputs, -1)
        outputs = torch.gather(outputs, 1, sort_idxs.expand(b_s, beam_size, seq_len))
        log_probs = torch.cat(log_probs, -1)
        log_probs = torch.gather(log_probs, 1, sort_idxs.expand(b_s, beam_size, seq_len))

        outputs = outputs.contiguous()[:, :out_size]
        log_probs = log_probs.contiguous()[:, :out_size]
        if out_size == 1:
            outputs = outputs.squeeze(1)
            log_probs = log_probs.squeeze(1)
        return outputs, log_probs
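

if __name__ == '__main__':
    # Smoke test with a toy LSTM subclass. Everything below (ToyCaptioner,
    # its dimensions, BOS assumed to be token 0) is an illustrative
    # assumption, not part of the interface above.
    import torch.nn.functional as F

    class ToyCaptioner(CaptioningModel):
        def __init__(self, feat_dim=64, hidden_dim=32, vocab_size=100):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden_dim)
            self.proj = nn.Linear(feat_dim, hidden_dim)
            self.lstm = nn.LSTMCell(hidden_dim, hidden_dim)
            self.fc = nn.Linear(hidden_dim, vocab_size)

        def init_weights(self):
            pass

        def init_state(self, b_s, device):
            # Each state tensor is (b_s, ...) so beam_search can reindex it
            h = torch.zeros(b_s, self.lstm.hidden_size, device=device)
            return h, h.clone()

        def step(self, t, state, prev_output, images, seq, *args, mode='teacher_forcing'):
            if mode == 'teacher_forcing':
                it = seq[:, t]
            elif prev_output is None:  # first feedback step: BOS assumed to be 0
                it = torch.zeros(images.size(0), dtype=torch.long, device=images.device)
            else:
                it = prev_output
            h, c = self.lstm(self.embed(it) + self.proj(images), state)
            return F.log_softmax(self.fc(h), dim=-1), (h, c)

    torch.manual_seed(0)
    model = ToyCaptioner()
    images, seq = torch.randn(2, 64), torch.randint(0, 100, (2, 7))
    print(model(images, seq).shape)                      # torch.Size([2, 7, 100])
    print(model.test(images, 7).shape)                   # torch.Size([2, 7])
    out, lp = model.beam_search(images, 7, eos_idx=1, beam_size=3)
    print(out.shape, lp.shape)                           # torch.Size([2, 7]) twice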