from functools import partial

from keras.models import Model, Input
from keras.optimizers import Adam
from keras.layers import Lambda
from config import main_config
from layers import custom_layers
from layers.losses import wasserstein_loss, gradient_penalty_loss, \
  confidence_reconstruction_loss, id_mrf_loss
from models.discriminator import GlobalDiscriminator, LocalDiscriminator
from models.generator import Generator
from models.wgan import WassersteinGAN
from utils import constants


class GMCNNGan(WassersteinGAN):
  """GMCNN image-inpainting GAN trained with a WGAN-GP style objective.

  Wires one generator and two critics (global and local) into three
  separately compiled Keras models that share the same underlying layers:

  * ``generator_model``            -- both critics frozen, generator trainable.
  * ``global_discriminator_model`` -- generator frozen, global critic trainable.
  * ``local_discriminator_model``  -- generator frozen, local critic trainable.

  The ``enable()`` / ``disable()`` calls in ``__init__`` are order-sensitive:
  each composite model must be defined/compiled while exactly the part it
  should train is enabled, because the raw sub-models are shared.
  """
  
  def __init__(self, batch_size, img_height, img_width, num_channels, warm_up_generator,
               config: main_config.MainConfig, output_paths: constants.OutputPaths):
    """Build and compile the generator model and both discriminator models.

    Args:
      batch_size: training batch size (forwarded to the ID-MRF loss).
      img_height: input image height in pixels.
      img_width: input image width in pixels.
      num_channels: number of image channels; note the mask input is given
        the same (H, W, C) shape as the image.
      warm_up_generator: when True, the generator is compiled with every
        loss weight except the reconstruction loss set to 0.
      config: hyper-parameter bundle (learning rate, loss weights, VGG-16
        layers for ID-MRF, Gaussian kernel settings, WGAN training ratio).
      output_paths: output locations forwarded to the raw sub-models.
    """
    super(GMCNNGan, self).__init__(img_height, img_width, num_channels, batch_size,
                                   config.training.wgan_training_ratio, output_paths)
    
    self.img_height = img_height
    self.img_width = img_width
    self.num_channels = num_channels
    self.warm_up_generator = warm_up_generator
    self.learning_rate = config.training.learning_rate
    self.num_gaussian_steps = config.model.num_gaussian_steps
    self.gradient_penalty_loss_weight = config.model.gradient_penalty_loss_weight
    self.id_mrf_loss_weight = config.model.id_mrf_loss_weight
    self.adversarial_loss_weight = config.model.adversarial_loss_weight
    self.nn_stretch_sigma = config.model.nn_stretch_sigma
    self.vgg_16_layers = config.model.vgg_16_layers
    self.id_mrf_style_weight = config.model.id_mrf_style_weight
    self.id_mrf_content_weight = config.model.id_mrf_content_weight
    self.gaussian_kernel_size = config.model.gaussian_kernel_size
    self.gaussian_kernel_std = config.model.gaussian_kernel_std
    self.add_mask_as_generator_input = config.model.add_mask_as_generator_input
    
    # Adam with beta_1=0.5, beta_2=0.9 matches the WGAN-GP recipe (see the
    # Gulrajani et al. comment below). ``lr`` is the legacy Keras keyword.
    self.generator_optimizer = Adam(lr=self.learning_rate, beta_1=0.5, beta_2=0.9)
    self.discriminator_optimizer = Adam(lr=self.learning_rate, beta_1=0.5, beta_2=0.9)
    
    # Raw wrappers around the underlying networks; their layers are shared by
    # all three composite models defined below.
    self.local_discriminator_raw = LocalDiscriminator(self.img_height, self.img_width,
                                                      self.num_channels, output_paths)
    self.global_discriminator_raw = GlobalDiscriminator(self.img_height, self.img_width,
                                                        self.num_channels, output_paths)
    self.generator_raw = Generator(self.img_height, self.img_width, self.num_channels,
                                   self.add_mask_as_generator_input, output_paths)
    
    # define generator model
    # NOTE(review): ``disable``/``enable`` presumably toggle layer
    # trainability on the raw wrappers (defined in models.*, not visible
    # here) -- confirm. Each composite model must be compiled while only its
    # trainable part is enabled, so do not reorder these sections.
    self.global_discriminator_raw.disable()
    self.local_discriminator_raw.disable()
    self.generator_model = self.define_generator_model(self.generator_raw,
                                                       self.local_discriminator_raw,
                                                       self.global_discriminator_raw)
    
    # define global discriminator model
    self.global_discriminator_raw.enable()
    self.generator_raw.disable()
    self.global_discriminator_model = self.define_global_discriminator(self.generator_raw,
                                                                       self.global_discriminator_raw)
    
    # define local discriminator model
    self.local_discriminator_raw.enable()
    self.global_discriminator_raw.disable()
    self.local_discriminator_model = self.define_local_discriminator(self.generator_raw,
                                                                     self.local_discriminator_raw)
  
  def define_generator_model(self, generator_raw, local_discriminator_raw,
                             global_discriminator_raw):
    """Compile the generator training model (both critics assumed frozen).

    The generated image is emitted twice so the 4-element loss list below
    lines up: [reconstruction, ID-MRF, global adversarial, local adversarial].

    Returns:
      Compiled ``keras.Model`` mapping [image, mask] ->
      [generated, generated, global_critic_score, local_critic_score].
    """
    
    generator_inputs_img = Input(shape=(self.img_height, self.img_width, self.num_channels))
    generator_inputs_mask = Input(shape=(self.img_height, self.img_width, self.num_channels))
    
    generator_outputs = generator_raw.model([generator_inputs_img, generator_inputs_mask])
    global_discriminator_outputs = global_discriminator_raw.model(generator_outputs)
    
    # The local critic also receives the mask so it can attend to the hole.
    local_discriminator_outputs = local_discriminator_raw.model([generator_outputs,
                                                                 generator_inputs_mask])
    
    generator_model = Model(inputs=[generator_inputs_img, generator_inputs_mask],
                            outputs=[generator_outputs, generator_outputs,
                                     global_discriminator_outputs,
                                     local_discriminator_outputs])
    
    # this partial trick is required for passing additional parameters for loss functions
    # (Keras losses only receive y_true/y_pred).
    partial_cr_loss = partial(confidence_reconstruction_loss,
                              mask=generator_inputs_mask,
                              num_steps=self.num_gaussian_steps,
                              gaussian_kernel_size=self.gaussian_kernel_size,
                              gaussian_kernel_std=self.gaussian_kernel_std)
    
    # Functions need a __name__ or Keras raises when compiling with partials.
    partial_cr_loss.__name__ = 'confidence_reconstruction_loss'
    
    partial_id_mrf_loss = partial(id_mrf_loss,
                                  mask=generator_inputs_mask,
                                  nn_stretch_sigma=self.nn_stretch_sigma,
                                  batch_size=self.batch_size,
                                  vgg_16_layers=self.vgg_16_layers,
                                  id_mrf_style_weight=self.id_mrf_style_weight,
                                  id_mrf_content_weight=self.id_mrf_content_weight,
                                  id_mrf_loss_weight=self.id_mrf_loss_weight)
    
    partial_id_mrf_loss.__name__ = 'id_mrf_loss'
    
    partial_wasserstein_loss = partial(wasserstein_loss,
                                       wgan_loss_weight=self.adversarial_loss_weight)
    
    partial_wasserstein_loss.__name__ = 'wasserstein_loss'
    
    if self.warm_up_generator:
      # set Wasserstein loss to 0 - total generator loss will be based only on reconstruction loss
      generator_model.compile(optimizer=self.generator_optimizer,
                              loss=[partial_cr_loss, partial_id_mrf_loss, partial_wasserstein_loss,
                                    partial_wasserstein_loss],
                              loss_weights=[1., 0., 0., 0.])
      # metrics=[metrics.psnr])
    else:
      generator_model.compile(optimizer=self.generator_optimizer,
                              loss=[partial_cr_loss, partial_id_mrf_loss, partial_wasserstein_loss,
                                    partial_wasserstein_loss])
    
    return generator_model
  
  def define_global_discriminator(self, generator_raw, global_discriminator_raw):
    """Compile the global critic training model (generator assumed frozen).

    Fake samples are composited (original pixels outside the mask, generated
    pixels inside) before scoring. Uses Wasserstein loss on real/fake scores
    and a gradient penalty on randomly weighted real/fake averages (WGAN-GP).

    Returns:
      Compiled ``keras.Model`` mapping [real, image, mask] ->
      [real_score, fake_score, averaged_score].
    """
    generator_inputs = Input(shape=(self.img_height, self.img_width, self.num_channels))
    generator_masks = Input(shape=(self.img_height, self.img_width, self.num_channels))
    
    real_samples = Input(shape=(self.img_height, self.img_width, self.num_channels))
    fake_samples = generator_raw.model([generator_inputs, generator_masks])
    # fake_samples = generator_inputs * (1 - generator_masks) + fake_samples * generator_masks
    fake_samples = Lambda(make_comp_sample)([generator_inputs, fake_samples, generator_masks])
    
    discriminator_output_from_fake_samples = global_discriminator_raw.model(fake_samples)
    discriminator_output_from_real_samples = global_discriminator_raw.model(real_samples)
    
    averaged_samples = custom_layers.RandomWeightedAverage()([real_samples, fake_samples])
    # We then run these samples through the discriminator as well. Note that we never
    # really use the discriminator output for these samples - we're only running them to
    # get the gradient norm for the gradient penalty loss.
    averaged_samples_outputs = global_discriminator_raw.model(averaged_samples)
    
    # The gradient penalty loss function requires the input averaged samples to get
    # gradients. However, Keras loss functions can only have two arguments, y_true and
    # y_pred. We get around this by making a partial() of the function with the averaged
    # samples here.
    partial_gp_loss = partial(gradient_penalty_loss,
                              averaged_samples=averaged_samples,
                              gradient_penalty_weight=self.gradient_penalty_loss_weight)
    # Functions need names or Keras will throw an error
    partial_gp_loss.__name__ = 'gradient_penalty'
    
    global_discriminator_model = Model(inputs=[real_samples, generator_inputs, generator_masks],
                                       outputs=[discriminator_output_from_real_samples,
                                                discriminator_output_from_fake_samples,
                                                averaged_samples_outputs])
    # We use the Adam paramaters from Gulrajani et al. We use the Wasserstein loss for both
    # the real and generated samples, and the gradient penalty loss for the averaged samples
    global_discriminator_model.compile(optimizer=self.discriminator_optimizer,
                                       loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss])
    
    return global_discriminator_model
  
  def define_local_discriminator(self, generator_raw, local_discriminator_raw):
    """Compile the local critic training model (generator assumed frozen).

    Same WGAN-GP scheme as the global critic, but the critic also receives
    the mask input alongside each sample.

    Returns:
      Compiled ``keras.Model`` mapping [real, image, mask] ->
      [real_score, fake_score, averaged_score].
    """
    generator_inputs = Input(shape=(self.img_height, self.img_width, self.num_channels))
    generator_masks = Input(shape=(self.img_height, self.img_width, self.num_channels))
    
    real_samples = Input(shape=(self.img_height, self.img_width, self.num_channels))
    fake_samples = generator_raw.model([generator_inputs, generator_masks])
    # NOTE(review): unlike the global critic above, the raw generator output
    # is scored here WITHOUT compositing it with the input image (the Lambda
    # below is commented out) -- confirm this asymmetry is intentional.
    # fake_samples = generator_inputs * (1 - generator_masks) + fake_samples * generator_masks
    # fake_samples = Lambda(make_comp_sample)([generator_inputs, fake_samples, generator_masks])
    
    discriminator_output_from_fake_samples = local_discriminator_raw.model(
      [fake_samples, generator_masks])
    discriminator_output_from_real_samples = local_discriminator_raw.model(
      [real_samples, generator_masks])
    
    averaged_samples = custom_layers.RandomWeightedAverage()([real_samples, fake_samples])
    averaged_samples_output = local_discriminator_raw.model([averaged_samples, generator_masks])
    
    # Same partial() trick as in define_global_discriminator: the gradient
    # penalty needs the averaged samples, but Keras losses only get y_true/y_pred.
    partial_gp_loss = partial(gradient_penalty_loss,
                              averaged_samples=averaged_samples,
                              gradient_penalty_weight=self.gradient_penalty_loss_weight)
    partial_gp_loss.__name__ = 'gradient_penalty'
    
    local_discriminator_model = Model(inputs=[real_samples, generator_inputs, generator_masks],
                                      outputs=[discriminator_output_from_real_samples,
                                               discriminator_output_from_fake_samples,
                                               averaged_samples_output])
    
    local_discriminator_model.compile(optimizer=self.discriminator_optimizer,
                                      loss=[wasserstein_loss, wasserstein_loss, partial_gp_loss])
    return local_discriminator_model
  
  @property
  def global_discriminator(self):
    """Compiled global-critic training model."""
    return self.global_discriminator_model
  
  @property
  def local_discriminator(self):
    """Compiled local-critic training model."""
    return self.local_discriminator_model
  
  @property
  def generator(self):
    """Compiled generator training model (critics frozen, 4 loss outputs)."""
    return self.generator_model
  
  @property
  def generator_for_prediction(self):
    """Raw generator network, for inference without the training heads."""
    return self.generator_raw.model


def make_comp_sample(inputs):
  """Blend generated pixels into the original image over the masked region.

  Where the mask is 0 the original pixel is kept, where it is 1 the generated
  pixel is taken; fractional mask values blend linearly. Intended for use
  inside a Keras ``Lambda`` layer, so it takes a single 3-element sequence.

  Args:
    inputs: sequence ``[original, generated, mask]`` of same-shaped tensors
      (or broadcastable numeric values).

  Returns:
    The composited sample ``original * (1 - mask) + generated * mask``.
  """
  original, generated, mask = inputs
  kept_background = original * (1 - mask)
  filled_region = generated * mask
  return kept_background + filled_region