# python3.7
"""Contains the implementation of generator described in ProgressiveGAN.

Different from the official tensorflow model in folder `pggan_tf_official`, this
is a simple pytorch version which only contains the generator part. This class
is specially used for inference.

For more details, please check the original paper:
https://arxiv.org/pdf/1710.10196.pdf
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

# Only the generator is exported by `from <module> import *`; the layer classes
# defined below are building blocks used by the generator.
__all__ = ['PGGANGeneratorModel']

# Maps the target resolution of the final generated image to the numbers of
# filters used by the convolutional layers, in sequence. A resolution of
# `2 ** depth` uses the first `depth` entries of the full 1024x1024 sequence,
# so the table covers resolutions 8, 16, ..., 1024. Every entry is its own
# list object.
_RESOLUTIONS_TO_CHANNELS = {
    2 ** depth: [512, 512, 512, 512, 512, 256, 128, 64, 32, 16][:depth]
    for depth in range(3, 11)
}

def _build_pggan_var_mapping():
  """Builds the pytorch-to-tensorflow variable name mapping.

  The naming scheme is:
    * `lod` maps to `lod` (a scalar).
    * `layer0` / `layer1` map to the `4x4/Dense` / `4x4/Conv` scopes.
    * For every resolution R in 8, 16, ..., 1024, the next two layers map to
      the `RxR/Conv0` and `RxR/Conv1` scopes (`layer2`, `layer3` for 8x8, up to
      `layer16`, `layer17` for 1024x1024).
    * `output{i}` maps to `ToRGB_lod{8 - i}` for i in 0, ..., 8.

  Within each scope, `<name>.conv.weight` maps to `<scope>/weight` (layer
  weights are [out_channels, in_channels, k, k]; output weights are
  [3, in_channels, 1, 1]) and `<name>.wscale.bias` maps to `<scope>/bias`
  ([out_channels]).

  Returns:
    A dictionary from pytorch variable names to tensorflow variable names,
      ordered as `lod`, then all `layer*` entries, then all `output*` entries.
  """
  mapping = {'lod': 'lod'}

  # Tensorflow scopes of the 18 convolutional layers, in layer-index order.
  layer_scopes = ['4x4/Dense', '4x4/Conv']
  for res_log2 in range(3, 11):
    res = 2 ** res_log2
    layer_scopes.append(f'{res}x{res}/Conv0')
    layer_scopes.append(f'{res}x{res}/Conv1')
  for layer_idx, scope in enumerate(layer_scopes):
    mapping[f'layer{layer_idx}.conv.weight'] = f'{scope}/weight'
    mapping[f'layer{layer_idx}.wscale.bias'] = f'{scope}/bias'

  # Output (to-RGB) layers: `output0` pairs with lod 8, `output8` with lod 0.
  for output_idx in range(9):
    scope = f'ToRGB_lod{8 - output_idx}'
    mapping[f'output{output_idx}.conv.weight'] = f'{scope}/weight'
    mapping[f'output{output_idx}.wscale.bias'] = f'{scope}/bias'

  return mapping


# Variable mapping from pytorch model to official tensorflow model.
_PGGAN_PTH_VARS_TO_TF_VARS = _build_pggan_var_mapping()


class PGGANGeneratorModel(nn.Module):
  """Defines the generator module in ProgressiveGAN.

  The generator is a stack of blocks. Each block consists of two convolutional
  layers (`layer{2i}`, `layer{2i + 1}`) plus a 1x1 linear output layer
  (`output{i}`) that converts the block's feature maps to an image.

  Note that the generated images are meant to be with RGB color channels with
  range [-1, 1]. The output layers use a linear activation, so this range is
  not clamped by the module itself.
  """

  def __init__(self,
               resolution=1024,
               fused_scale=False,
               output_channels=3):
    """Initializes the generator with basic settings.

    Args:
      resolution: The resolution of the final output image. (default: 1024)
      fused_scale: Whether to fuse `upsample` and `conv2d` together, resulting
        in `conv2d_transpose`. (default: False)
      output_channels: Number of channels of the output image. (default: 3)

    Raises:
      ValueError: If the input `resolution` is not supported.
    """
    super().__init__()

    try:
      self.channels = _RESOLUTIONS_TO_CHANNELS[resolution]
    except KeyError:
      # `from None` hides the internal `KeyError`, which carries no extra
      # information for the caller.
      raise ValueError(f'Invalid resolution: {resolution}!\n'
                       f'Resolutions allowed: '
                       f'{list(_RESOLUTIONS_TO_CHANNELS)}.') from None
    assert len(self.channels) == int(np.log2(resolution))

    self.resolution = resolution
    self.fused_scale = fused_scale
    self.output_channels = output_channels

    for block_idx in range(1, len(self.channels)):
      in_channels = self.channels[block_idx - 1]
      out_channels = self.channels[block_idx]
      if block_idx == 1:
        # First block: a 4x4 kernel with padding 3 turns the 1x1 latent "image"
        # into a 4x4 feature map.
        first_layer = ConvBlock(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=4,
                                padding=3)
      else:
        # Every later block doubles the spatial resolution.
        first_layer = ConvBlock(in_channels=in_channels,
                                out_channels=out_channels,
                                upsample=True,
                                fused_scale=self.fused_scale)
      self.add_module(f'layer{2 * block_idx - 2}', first_layer)
      self.add_module(
          f'layer{2 * block_idx - 1}',
          ConvBlock(in_channels=out_channels,
                    out_channels=out_channels))
      self.add_module(
          f'output{block_idx - 1}',
          ConvBlock(in_channels=out_channels,
                    out_channels=self.output_channels,
                    kernel_size=1,
                    padding=0,
                    wscale_gain=1.0,
                    activation_type='linear'))

    self.upsample = ResolutionScalingLayer()
    # Level-of-details. Stored as a parameter so that it is part of the
    # `state_dict` (it has a counterpart in the variable mapping below).
    self.lod = nn.Parameter(torch.zeros(()))

    # NOTE: The mapping is built from the full table, hence it also contains
    # entries for layers beyond `resolution` that this model does not own.
    self.pth_to_tf_var_mapping = {}
    for pth_var_name, tf_var_name in _PGGAN_PTH_VARS_TO_TF_VARS.items():
      if self.fused_scale and 'Conv0' in tf_var_name:
        # Fused blocks own a raw `weight` parameter instead of a `conv` module,
        # and are mapped to the `Conv0_up` scope.
        pth_var_name = pth_var_name.replace('conv.weight', 'weight')
        tf_var_name = tf_var_name.replace('Conv0', 'Conv0_up')
      self.pth_to_tf_var_mapping[pth_var_name] = tf_var_name

  def forward(self, x):
    """Synthesizes images from latent codes.

    Blocks are executed in order as long as `block_idx + lod` is smaller than
    the number of entries in `self.channels`. Each remaining block is skipped
    and replaced by a 2x nearest-neighbor upsampling of the image produced by
    the last executed block.

    Args:
      x: Latent codes with shape [batch_size, noise_dim].

    Returns:
      The image tensor produced by the last executed output layer, upsampled
        once for every skipped block.

    Raises:
      ValueError: If `x` is not a 2D tensor, or if `lod` is so large (or NaN)
        that not even the first block would be executed, in which case there
        is no image to return.
    """
    if len(x.shape) != 2:
      raise ValueError(f'The input tensor should be with shape [batch_size, '
                       f'noise_dim], but {x.shape} received!')
    x = x.view(x.shape[0], x.shape[1], 1, 1)

    lod = self.lod.cpu().tolist()
    num_stages = len(self.channels)
    # The execution condition only gets harder to satisfy as `block_idx` grows,
    # so checking the first block is enough. Written as `not (... < ...)` so
    # that a NaN `lod` is rejected as well.
    if not 1 + lod < num_stages:
      raise ValueError(f'Invalid `lod`: {lod}! It should be smaller than '
                       f'{num_stages - 1} such that at least the first block '
                       f'is executed.')

    image = None
    for block_idx in range(1, num_stages):
      if block_idx + lod < num_stages:
        x = getattr(self, f'layer{2 * block_idx - 2}')(x)
        x = getattr(self, f'layer{2 * block_idx - 1}')(x)
        image = getattr(self, f'output{block_idx - 1}')(x)
      else:
        image = self.upsample(image)
    return image


class PixelNormLayer(nn.Module):
  """Implements pixel-wise feature vector normalization layer.

  Every feature vector (taken along dimension 1) is divided by the square root
  of its mean squared value plus `epsilon`.
  """

  def __init__(self, epsilon=1e-8):
    """Initializes the layer.

    Args:
      epsilon: Small constant added under the square root to avoid division by
        zero. (default: 1e-8)
    """
    super().__init__()
    self.epsilon = epsilon

  def forward(self, x):
    """Normalizes `x` along dimension 1, keeping its shape."""
    mean_square = x.pow(2).mean(dim=1, keepdim=True)
    norm = (mean_square + self.epsilon).sqrt()
    return x / norm


class ResolutionScalingLayer(nn.Module):
  """Implements the resolution scaling layer.

  The layer resizes feature maps in the spatial domain with nearest neighbor
  interpolation. A `scale_factor` larger than 1 upsamples, a `scale_factor`
  smaller than 1 downsamples.
  """

  def __init__(self, scale_factor=2):
    """Initializes the layer.

    Args:
      scale_factor: Multiplier applied to the spatial size. (default: 2)
    """
    super().__init__()
    self.scale_factor = scale_factor

  def forward(self, x):
    """Resizes `x` by `self.scale_factor` with nearest neighbor sampling."""
    resize_kwargs = {'scale_factor': self.scale_factor, 'mode': 'nearest'}
    return F.interpolate(x, **resize_kwargs)


class WScaleLayer(nn.Module):
  """Implements the layer to scale weight variable and add bias.

  The convolution weight itself lives in the preceding convolutional layer.
  This layer only multiplies the convolution result by a constant, which is
  not trainable, and adds a bias, which is a trainable parameter owned by this
  layer.
  """

  def __init__(self, in_channels, out_channels, kernel_size, gain=np.sqrt(2.0)):
    """Initializes the layer.

    Args:
      in_channels: Number of input channels of the associated convolution.
      out_channels: Number of output channels, i.e., the size of the bias.
      kernel_size: Kernel size of the associated convolution.
      gain: Numerator of the scaling constant `gain / sqrt(fan_in)`, where
        `fan_in = in_channels * kernel_size * kernel_size`.
        (default: sqrt(2))
    """
    super().__init__()
    num_inputs_per_output = in_channels * kernel_size * kernel_size
    self.scale = gain / np.sqrt(num_inputs_per_output)
    self.bias = nn.Parameter(torch.zeros(out_channels))

  def forward(self, x):
    """Scales `x` and adds the per-channel bias (broadcast over N, H, W)."""
    channel_bias = self.bias.view(1, -1, 1, 1)
    scaled = x * self.scale
    return scaled + channel_bias


class ConvBlock(nn.Module):
  """Implements the convolutional block used in ProgressiveGAN.

  Basically, this block executes pixel-wise normalization layer, upsampling
  layer (if needed), convolutional layer, weight-scale layer, and activation
  layer in sequence.

  When both `upsample` and `fused_scale` are set, the upsampling and the
  convolution are fused into one transposed convolution with stride 2. In that
  case the block owns a raw `weight` parameter instead of a `conv` submodule.
  """

  def __init__(self,
               in_channels,
               out_channels,
               kernel_size=3,
               stride=1,
               padding=1,
               dilation=1,
               add_bias=False,
               upsample=False,
               fused_scale=False,
               wscale_gain=np.sqrt(2.0),
               activation_type='lrelu'):
    """Initializes the class with block settings.

    Args:
      in_channels: Number of channels of the input tensor fed into this block.
      out_channels: Number of channels (kernels) of the output tensor.
      kernel_size: Size of the convolutional kernel.
      stride: Stride parameter for convolution operation. Ignored by the fused
        upsampling convolution.
      padding: Padding parameter for convolution operation. Ignored by the
        fused upsampling convolution.
      dilation: Dilation rate for convolution operation. Ignored by the fused
        upsampling convolution.
      add_bias: Whether to add bias onto the convolutional result. Ignored by
        the fused upsampling convolution. Note that the `wscale` layer always
        adds its own bias.
      upsample: Whether to upsample the input tensor before convolution.
      fused_scale: Whether to fuse `upsample` and `conv2d` together, resulting
        in `conv2d_transpose`. Only takes effect if `upsample` is set. The
        output is exactly twice as large as the input only for odd
        `kernel_size`.
      wscale_gain: The gain factor for `wscale` layer.
      activation_type: Type of activation function. Support `linear`, `lrelu`
        and `tanh`. NOTE(review): `tanh` is implemented with `nn.Hardtanh`,
        i.e., a clamp to [-1, 1], rather than the smooth hyperbolic tangent.
        This is kept as is in case existing callers rely on it -- confirm
        whether `nn.Tanh` was intended.

    Raises:
      NotImplementedError: If the input `activation_type` is not supported.
    """
    super().__init__()
    self.pixel_norm = PixelNormLayer()

    if upsample and not fused_scale:
      self.upsample = ResolutionScalingLayer()
    else:
      self.upsample = nn.Identity()

    if upsample and fused_scale:
      # Kernel layout is [height, width, in_channels, out_channels]. It is
      # rearranged to the layout expected by pytorch in `forward()`.
      self.weight = nn.Parameter(
          torch.randn(kernel_size, kernel_size, in_channels, out_channels))
      fan_in = in_channels * kernel_size * kernel_size
      self.scale = wscale_gain / np.sqrt(fan_in)
    else:
      self.conv = nn.Conv2d(in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=padding,
                            dilation=dilation,
                            groups=1,
                            bias=add_bias)

    self.wscale = WScaleLayer(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              gain=wscale_gain)

    if activation_type == 'linear':
      self.activate = nn.Identity()
    elif activation_type == 'lrelu':
      self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    elif activation_type == 'tanh':
      self.activate = nn.Hardtanh()
    else:
      raise NotImplementedError(f'Not implemented activation function: '
                                f'{activation_type}!')

  def forward(self, x):
    """Runs normalization, (upsampled) convolution, scaling and activation.

    Args:
      x: Input feature maps, with channels on dimension 1.

    Returns:
      The activated feature maps.
    """
    x = self.pixel_norm(x)
    x = self.upsample(x)
    if hasattr(self, 'conv'):
      x = self.conv(x)
    else:
      # Fused upsampling + convolution. The kernel is zero-padded by one pixel
      # on each spatial side, and the four one-pixel-shifted copies are summed,
      # which yields a kernel of spatial size `kernel_size + 1`.
      kernel = self.weight * self.scale
      kernel = F.pad(kernel, (0, 0, 0, 0, 1, 1, 1, 1), 'constant', 0.0)
      kernel = (kernel[1:, 1:] + kernel[:-1, 1:] +
                kernel[1:, :-1] + kernel[:-1, :-1])
      # [height, width, in, out] -> [in, out, height, width].
      kernel = kernel.permute(2, 3, 0, 1)
      # With stride 2 and a kernel of size `kernel_size + 1`, the output is
      # exactly twice as large as the input if the padding is
      # `(kernel_size - 1) / 2`, which equals 1 for the default 3x3 kernel.
      fused_padding = (self.weight.shape[0] - 1) // 2
      x = F.conv_transpose2d(x, kernel, stride=2, padding=fused_padding)
      # Undo the scaling here, since `self.wscale` applies the same scaling
      # again right below.
      x = x / self.scale
    x = self.wscale(x)
    x = self.activate(x)
    return x