from collections import defaultdict
import itertools
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.infer import SVI, Trace_ELBO

from .base_model import BaseModel
from models.networks.pose_rnn import PoseRNN
from models.networks.sequence_encoder import SequenceEncoder
from models.networks.encoder import ImageEncoder
from models.networks.decoder import ImageDecoder
import utils


class DDPAE(BaseModel):
  '''
  The DDPAE (Decompositional Disentangled Predictive Auto-Encoder) model.
  Decomposes a video into components and disentangles each component into
  pose and content.
  '''
  def __init__(self, opt):
    super(DDPAE, self).__init__()

    self.is_train = opt.is_train
    assert opt.image_size[0] == opt.image_size[1]
    self.image_size = opt.image_size[0]
    print('Image size: {}'.format(self.image_size))
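    # Each component is decoded as an object patch at half the frame resolution
    # and composited back into the frame with a spatial transformer.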
    self.object_size = self.image_size // 2

    # Data parameters
    # self.__dict__.update(opt.__dict__)
    self.n_channels = opt.n_channels
    self.n_components = opt.n_components
    self.total_components = self.n_components
    self.batch_size = opt.batch_size
    self.n_frames_input = opt.n_frames_input
    self.n_frames_output = opt.n_frames_output
    self.n_frames_total = self.n_frames_input + self.n_frames_output
    # Hyperparameters
    self.image_latent_size = opt.image_latent_size
    self.content_latent_size = opt.content_latent_size
    self.pose_latent_size = opt.pose_latent_size
    self.hidden_size = opt.hidden_size
    self.ngf = opt.ngf
    self.independent_components = opt.independent_components
    self.predict_loss_only = False

    # Training parameters
    if opt.is_train:
      self.lr_init = opt.lr_init
      self.lr_decay = opt.lr_decay
      self.when_to_predict_only = opt.when_to_predict_only

    # Networks
    self.setup_networks()

    # Priors
    self.scale = opt.stn_scale_prior
    # Initial pose prior
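    # Pose is (scale, x, y) for the spatial transformer; the prior mean places
    # objects at the frame center at the prior scale.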
    self.initial_pose_prior_mu = Variable(torch.cuda.FloatTensor([self.scale, 0, 0]))
    self.initial_pose_prior_sigma = Variable(torch.cuda.FloatTensor([0.2, 1, 1]))
    # Beta prior
    sd = 0.1
    self.beta_prior_mu = Variable(torch.zeros(self.pose_latent_size).cuda())
    self.beta_prior_sigma = Variable(torch.ones(self.pose_latent_size).cuda() * sd)

  def setup_networks(self):
    '''
    Networks for DDPAE.
    '''
    self.nets = {}
    # These will be registered in model() and guide() with pyro.module().
    self.model_modules = {}
    self.guide_modules = {}

    # Backbone, Pose RNN
    pose_model = PoseRNN(self.n_components, self.n_frames_output, self.n_channels,
                         self.image_size, self.image_latent_size, self.hidden_size,
                         self.ngf, self.pose_latent_size, self.independent_components)
    self.pose_model = nn.DataParallel(pose_model.cuda())

    self.nets['pose_model'] = self.pose_model
    self.guide_modules['pose_model'] = self.pose_model

    # Content LSTM
    content_lstm = SequenceEncoder(self.content_latent_size, self.hidden_size,
                                   self.content_latent_size * 2)
    self.content_lstm = nn.DataParallel(content_lstm.cuda())
    self.nets['content_lstm'] = self.content_lstm
    self.model_modules['content_lstm'] = self.content_lstm

    # Image encoder and decoder
    n_layers = int(np.log2(self.object_size)) - 1
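    # e.g. object_size = 32 gives 4 layers, assuming each conv layer halves
    # the spatial resolution.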
    object_encoder = ImageEncoder(self.n_channels, self.content_latent_size,
                                  self.ngf, n_layers)
    object_decoder = ImageDecoder(self.content_latent_size, self.n_channels,
                                  self.ngf, n_layers, 'sigmoid')
    self.object_encoder = nn.DataParallel(object_encoder.cuda())
    self.object_decoder = nn.DataParallel(object_decoder.cuda())
    self.nets.update({'object_encoder': self.object_encoder,
                      'object_decoder': self.object_decoder})
    self.model_modules['decoder'] = self.object_decoder
    self.guide_modules['encoder'] = self.object_encoder

  def setup_training(self):
    '''
    Setup Pyro SVI, optimizers.
    '''
    if not self.is_train:
      return

    self.pyro_optimizer = optim.Adam({'lr': self.lr_init})
    self.svis = {'elbo': SVI(self.model, self.guide, self.pyro_optimizer, loss=Trace_ELBO())}
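    # Note: svi.step() is not used. train() calls svi.loss_and_grads() to
    # compute gradients, and the torch optimizer below applies the update.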

    # Separate pose_model parameters and other networks' parameters
    params = []
    for name, net in self.nets.items():
      if name != 'pose_model':
        params.append(net.parameters())
    self.optimizer = torch.optim.Adam(
        [{'params': self.pose_model.parameters(), 'lr': self.lr_init},
         {'params': itertools.chain(*params), 'lr': self.lr_init}],
        betas=(0.5, 0.999))

  def get_objects(self, input, transformer):
    '''
    Crop objects from input given the transformer.
    '''
    # Repeat input: batch_size x n_frames_input x n_components x C x H x W
    repeated_input = torch.stack([input] * self.n_components, dim=2)
    repeated_input = repeated_input.view(-1, *input.size()[-3:])
    # Crop objects from the frames with the spatial transformer.
    transformer = transformer.contiguous().view(-1, transformer.size(-1))
    input_obj = utils.image_to_object(repeated_input, transformer, self.object_size)
    input_obj = input_obj.view(-1, *input_obj.size()[-3:])
    return input_obj

  def constrain_pose(self, pose):
    '''
    Constrain the value of the pose vectors.
    '''
    # Constraining the pose values to a bounded range makes training faster.
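    # Scale stays within +/-1 of the prior; translations are squashed by tanh
    # to +/-(scale - 0.5), tying how far an object can move to its scale.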
    scale = torch.clamp(pose[..., :1], self.scale - 1, self.scale + 1)
    xy = torch.tanh(pose[..., 1:]) * (scale - 0.5)
    pose = torch.cat([scale, xy], dim=-1)
    return pose

  def sample_latent(self, input, input_latent_mu, input_latent_sigma, pred_latent_mu,
                    pred_latent_sigma, initial_pose_mu, initial_pose_sigma, sample=True):
    '''
    Sample the transition variables and the initial pose, accumulate them into
    per-frame poses, then crop objects from the input frames and encode them
    into the content variable z.
    Return latent: a dictionary containing pose and content.
    '''
    latent = defaultdict(lambda: None)

    beta = self.get_transitions(input_latent_mu, input_latent_sigma,
                                pred_latent_mu, pred_latent_sigma, sample)
    pose = self.accumulate_pose(beta)
    # Sample initial pose
    initial_pose = self.pyro_sample('initial_pose', dist.Normal, initial_pose_mu,
                                    initial_pose_sigma, sample)
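    # Broadcast the initial pose over the frame dimension:
    # pose_t = initial_pose + sum_{i<=t} beta_i.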
    pose += initial_pose.view(-1, 1, self.n_components, self.pose_latent_size)
    pose = self.constrain_pose(pose)

    # Get input objects
    input_pose = pose[:, :self.n_frames_input, :, :]
    input_obj = self.get_objects(input, input_pose)
    # Encode the sampled objects
    z = self.object_encoder(input_obj)
    z = self.sample_content(z, sample)
    latent.update({'pose': pose, 'content': z})
    return latent

  def sample_latent_prior(self, input):
    '''
    Return latent variables [pose, z], sampled from the prior distributions.
    '''
    latent = defaultdict(lambda: None)

    batch_size = input.size(0)
    # z prior
    N = batch_size * self.total_components
    z_prior_mu = Variable(torch.zeros(N, self.content_latent_size).cuda())
    z_prior_sigma = Variable(torch.ones(N, self.content_latent_size).cuda())
    z = self.pyro_sample('content', dist.Normal, z_prior_mu, z_prior_sigma, sample=True)

    # input_beta prior
    N = batch_size * self.n_frames_input * self.n_components
    input_beta_prior_mu = self.beta_prior_mu.repeat(N, 1)
    input_beta_prior_sigma = self.beta_prior_sigma.repeat(N, 1)
    input_beta = self.pyro_sample('input_beta', dist.Normal, input_beta_prior_mu,
                                  input_beta_prior_sigma, sample=True)
    beta = input_beta.view(batch_size, self.n_frames_input, self.n_components,
                           self.pose_latent_size)

    # pred_beta prior
    M = batch_size * self.n_frames_output * self.n_components
    pred_beta_prior_mu = self.beta_prior_mu.repeat(M, 1)
    pred_beta_prior_sigma = self.beta_prior_sigma.repeat(M, 1)
    pred_beta = self.pyro_sample('pred_beta', dist.Normal, pred_beta_prior_mu,
                                 pred_beta_prior_sigma, sample=True)
    pred_beta = pred_beta.view(batch_size, self.n_frames_output, self.n_components,
                               self.pose_latent_size)
    beta = torch.cat([beta, pred_beta], dim=1)

    # Get pose
    pose = self.accumulate_pose(beta)
    # Sample and add initial pose
    N = batch_size * self.n_components
    initial_pose_prior_mu = self.initial_pose_prior_mu.repeat(N, 1)
    initial_pose_prior_sigma = self.initial_pose_prior_sigma.repeat(N, 1)
    initial_pose = self.pyro_sample('initial_pose', dist.Normal, initial_pose_prior_mu,
                                    initial_pose_prior_sigma, sample=True)
    pose += initial_pose.view(-1, 1, self.n_components, self.pose_latent_size)
    pose = self.constrain_pose(pose)

    latent.update({'pose': pose, 'content': z})
    return latent

  def decode_components(self, latent):
    '''
    param latent: return value from self.sample_latent()
    Return values:
    components: (batch_size * n_frames * n_components) x n_channels x image_size x image_size
    '''
    pose, z = latent['pose'], latent['content']

    # (batch_size * n_frames_total * n_components) x content_latent_size
    z = z.view(-1, self.content_latent_size)
    objects = self.object_decoder(z)
    objects = objects.view(-1, *objects.size()[-3:]) # N x C x H x W
    pose = pose.view(-1, self.pose_latent_size)
    components = utils.object_to_image(objects, pose, self.image_size)
    latent['Y'] = objects
    return components

  def get_transitions(self, input_latent_mu, input_latent_sigma, pred_latent_mu,
                      pred_latent_sigma, sample=True):
    '''
    Sample the transition variables beta.
    '''
    # input_beta: (batch_size * n_frames_input * n_components) x pose_latent_size
    input_beta = self.pyro_sample('input_beta', dist.Normal, input_latent_mu,
                                  input_latent_sigma, sample)
    beta = input_beta.view(-1, self.n_frames_input, self.n_components, self.pose_latent_size)

    # pred_beta: (batch_size * n_frames_output) x n_components x pose_latent_size
    pred_beta = self.pyro_sample('pred_beta', dist.Normal, pred_latent_mu,
                                 pred_latent_sigma, sample)
    pred_beta = pred_beta.view(-1, self.n_frames_output, self.n_components,
                               self.pose_latent_size)
    # Concatenate the input and prediction beta
    beta = torch.cat([beta, pred_beta], dim=1)
    return beta

  def accumulate_pose(self, beta):
    '''
    Accumulate pose from the transition variables beta.
    pose_k = sum_{i=1}^k beta_i
    '''
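    # Cumulative sum over the time dimension; a vectorized equivalent would be
    # torch.cumsum(beta, dim=1), kept as an explicit loop here.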
    batch_size, n_frames, _, pose_latent_size = beta.size()
    accumulated = []
    for i in range(n_frames):
      if i == 0:
        p_i = beta[:, 0:1, :, :]
      else:
        p_i = beta[:, i:(i+1), :, :] + accumulated[-1]
      accumulated.append(p_i)
    accumulated = torch.cat(accumulated, dim=1)
    return accumulated

  def sample_content(self, content, sample):
    '''
    Aggregate the per-frame content codes of each component with the content
    LSTM into a single time-invariant code, then sample the final content.
    '''
    content = content.view(-1, self.n_frames_input, self.total_components, self.content_latent_size)
    contents = []
    for i in range(self.total_components):
      z = content[:, :, i, :]
      z = self.content_lstm(z).unsqueeze(1) # batch_size x 1 x (content_latent_size * 2)
      contents.append(z)
    content = torch.cat(contents, dim=1).view(-1, self.content_latent_size * 2)

    # Get mu and sigma, and sample.
    content_mu = content[:, :self.content_latent_size]
    content_sigma = F.softplus(content[:, self.content_latent_size:])
    content = self.pyro_sample('content', dist.Normal, content_mu, content_sigma, sample)
    return content

  def get_output(self, components, latent):
    '''
    Take the sum of the components.
    '''
    # components: batch_size x n_frames_total x total_components x C x H x W
    # Sum the components
    output = torch.sum(components, dim=2)
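    # Overlapping components can push pixel intensities above 1, so clamp the
    # composite to the valid range.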
    output = torch.clamp(output, max=1)
    return output

  def encode(self, input, sample=True):
    '''
    Encode video with pose_model, and sample the latent variables for reconstruction
    and prediction.
    Note: pyro.sample is called in self.sample_latent().
    param input: video of size (batch_size, n_frames_input, C, H, W)
    param sample: True if called from guide(), in which case pyro.sample is used.
    Return latent: a dictionary {'pose': pose, 'content': content, ...}
    '''
    input_latent_mu, input_latent_sigma, pred_latent_mu, pred_latent_sigma,\
        initial_pose_mu, initial_pose_sigma = self.pose_model(input)

    # Sample latent variables
    latent = self.sample_latent(input, input_latent_mu, input_latent_sigma, pred_latent_mu,
                                pred_latent_sigma, initial_pose_mu, initial_pose_sigma, sample)
    return latent

  def decode(self, latent, batch_size):
    '''
    Decode the latent variables into components, and produce the final output.
    param latent: dictionary, return values from self.encode()
    Return values:
    output: batch_size x n_frames_total x n_channels x image_size x image_size
    components: batch_size x n_frames_total x total_components x n_channels x image_size x image_size
    '''
    # Get the final content: batch_size x 1 x total_components x content_latent_size
    content = latent['content']
    content = content.view(batch_size, 1, self.total_components, content.size(-1))
    # Content is time-invariant: repeat the same code across all frames.
    content = content.repeat(1, self.n_frames_total, 1, 1)
    latent['content'] = content

    components = self.decode_components(latent)
    components = components.view(-1, self.n_frames_total, self.total_components,
                                 self.n_channels, self.image_size, self.image_size)
    output = self.get_output(components, latent)
    return output, components

  def model(self, input, output):
    '''
    Likelihood model: sample from prior, then decode to video.
    param input: video of size (batch_size, self.n_frames_input, C, H, W)
    param output: video of size (batch_size, self.n_frames_output, C, H, W)
    '''
    # Register networks
    for name, net in self.model_modules.items():
      pyro.module(name, net)

    observation = torch.cat([input, output], dim=1)

    # Sample from prior
    latent = self.sample_latent_prior(input)
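    # The sites sampled here ('content', 'input_beta', 'pred_beta',
    # 'initial_pose') must match the names sampled in guide() so that SVI
    # pairs each prior with its approximate posterior.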
    # Decode
    decoded_output, components = self.decode(latent, input.size(0))
    decoded_output = decoded_output.view(*observation.size())
    if self.predict_loss_only:
      # Only consider loss from the predicted frames
      decoded_output = decoded_output[:, self.n_frames_input:]
      observation = observation[:, self.n_frames_input:]
      components = components[:, self.n_frames_input:, ...]

    # Observe the full video under a fixed-variance Gaussian likelihood
    # (the reconstruction term of the ELBO).
    sd = Variable(0.3 * torch.ones(*decoded_output.size()).cuda())
    pyro.sample('obs', dist.Normal(decoded_output, sd), obs=observation)

  def guide(self, input, output):
    '''
    Posterior model: encode the input video and sample the latent variables.
    param input: video of size (batch_size, n_frames_input, C, H, W).
    param output: not used.
    '''
    # Register networks
    for name, net in self.guide_modules.items():
      pyro.module(name, net)

    self.encode(input, sample=True)

  def train(self, input, output):
    '''
    param input: video of size (batch_size, n_frames_input, C, H, W)
    param output: video of size (batch_size, self.n_frames_output, C, H, W)
    Return video_dict, loss_dict
    '''
    input = Variable(input.cuda(), requires_grad=False)
    output = Variable(output.cuda(), requires_grad=False)
    assert input.size(1) == self.n_frames_input

    # SVI
    batch_size, _, C, H, W = input.size()
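    # Normalize the ELBO by the total number of pixels so the reported loss
    # is per pixel.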
    numel = batch_size * self.n_frames_total * C * H * W
    loss_dict = {}
    for name, svi in self.svis.items():
      # loss = svi.step(input, output)
      # Note: backward() is already called in loss_and_grads.
      loss = svi.loss_and_grads(svi.model, svi.guide, input, output)
      loss_dict[name] = loss / numel

    # Update parameters
    self.optimizer.step()
    self.optimizer.zero_grad()

    return {}, loss_dict

  def test(self, input, output):
    '''
    Return decoded output.
    '''
    input = Variable(input.cuda())
    batch_size, _, _, H, W = input.size()
    output = Variable(output.cuda())
    gt = torch.cat([input, output], dim=1)

    latent = self.encode(input, sample=False)
    decoded_output, components = self.decode(latent, input.size(0))
    decoded_output = decoded_output.view(*gt.size())
    components = components.view(batch_size, self.n_frames_total, self.total_components,
                                 self.n_channels, H, W)
    latent['components'] = components
    decoded_output = decoded_output.clamp(0, 1)

    self.save_visuals(gt, decoded_output, components, latent)
    return decoded_output.cpu(), latent

  def save_visuals(self, gt, output, components, latent):
    '''
    Save results. Draw bounding boxes on each component.
    '''
    pose = latent['pose']
    components = components.detach().cpu()
    for i in range(self.n_components):
      p = pose.data[0, :, i, :].cpu()
      images = components.data[0, :, i, ...]
      images = utils.draw_components(images, p)
      components.data[0, :, i, ...] = images

    super(DDPAE, self).save_visuals(gt, output, components, latent)

  def update_hyperparameters(self, epoch, n_epochs):
    '''
    If when_to_predict_only > 0 and training has passed that fraction of the
    total epochs, switch to training with the prediction loss only.
    '''
    lr_dict = super(DDPAE, self).update_hyperparameters(epoch, n_epochs)

    if self.when_to_predict_only > 0 and epoch > int(n_epochs * self.when_to_predict_only):
      self.predict_loss_only = True

    return lr_dict