# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for a progressive GAN model.

This module contains basic building blocks to build a progressive GAN model.

See https://arxiv.org/abs/1710.10196 for details about the model.

See https://github.com/tkarras/progressive_growing_of_gans for the original
Theano implementation.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


def pixel_norm(images, epsilon=1.0e-8):
  """Pixel normalization.

  For each pixel a[i,j,k] of an image in HWC format, normalize its value to
  b[i,j,k] = a[i,j,k] / SQRT(SUM_k(a[i,j,k]^2) / C + eps).

  Args:
    images: A 4D `Tensor` of NHWC format.
    epsilon: A small positive number to avoid division by zero.

  Returns:
    A 4D `Tensor` with pixel-wise normalized channels.
  """
  return images * tf.rsqrt(
      tf.reduce_mean(tf.square(images), axis=3, keepdims=True) + epsilon)


def _get_validated_scale(scale):
  """Returns the scale guaranteed to be a positive integer."""
  scale = int(scale)
  if scale <= 0:
    raise ValueError('`scale` must be a positive integer.')
  return scale


def downscale(images, scale):
  """Box downscaling of images.

  Args:
    images: A 4D `Tensor` in NHWC format.
    scale: A positive integer scale.

  Returns:
    A 4D `Tensor` of `images` downscaled by a factor `scale`.

  Raises:
    ValueError: If `scale` is not a positive integer.
  """
  scale = _get_validated_scale(scale)
  if scale == 1:
    return images
  return tf.nn.avg_pool(
      images,
      ksize=[1, scale, scale, 1],
      strides=[1, scale, scale, 1],
      padding='VALID')


def upscale(images, scale):
  """Box upscaling (also called nearest-neighbor upscaling) of images.

  Args:
    images: A 4D `Tensor` in NHWC format.
    scale: A positive integer scale.

  Returns:
    A 4D `Tensor` of `images` upscaled by a factor `scale`.

  Raises:
    ValueError: If `scale` is not a positive integer.
  """
  scale = _get_validated_scale(scale)
  if scale == 1:
    return images
  return tf.batch_to_space(
      tf.tile(images, [scale**2, 1, 1, 1]),
      crops=[[0, 0], [0, 0]],
      block_size=scale)


def minibatch_mean_stddev(x):
  """Computes the mean standard deviation over the minibatch.

  This is used by the discriminator as a form of batch discrimination.

  Args:
    x: A `Tensor` for which to compute the statistic. The first dimension
      must be the batch size.

  Returns:
    A scalar `Tensor`: the per-element standard deviation across the batch,
    averaged over all other dimensions of `x`.
  """
  mean, var = tf.nn.moments(x, axes=[0])
  del mean
  return tf.reduce_mean(tf.sqrt(var))
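# Illustrative usage sketch (the batch size and image shape below are
# hypothetical, not prescribed by this module): wiring the image ops above
# together on a small NHWC batch.
def _example_image_ops():
  """Builds a tiny graph exercising pixel_norm, downscale and upscale."""
  images = tf.random_normal([8, 32, 32, 3])    # hypothetical NHWC batch
  normalized = pixel_norm(images)              # same shape, unit channel norm
  smaller = downscale(images, 2)               # [8, 16, 16, 3]
  larger = upscale(smaller, 2)                 # [8, 32, 32, 3], box upsampled
  stddev = minibatch_mean_stddev(images)       # scalar batch statistic
  return normalized, larger, stddev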
""" ndims = tensor.shape.ndims if ndims < 1: raise ValueError('`tensor` must have number of dimensions >= 1.') shape = tf.shape(tensor) return tf.concat( [tensor, tf.ones([shape[i] for i in range(ndims - 1)] + [1]) * scalar], axis=ndims - 1) def he_initializer_scale(shape, slope=1.0): """The scale of He neural network initializer. Args: shape: A list of ints representing the dimensions of a tensor. slope: A float representing the slope of the ReLu following the layer. Returns: A float of he initializer scale. """ fan_in = np.prod(shape[:-1]) return np.sqrt(2. / ((1. + slope**2) * fan_in)) def _custom_layer_impl(apply_kernel, kernel_shape, bias_shape, activation, he_initializer_slope, use_weight_scaling): """Helper function to implement custom_xxx layer. Args: apply_kernel: A function that transforms kernel to output. kernel_shape: An integer tuple or list of the kernel shape. bias_shape: An integer tuple or list of the bias shape. activation: An activation function to be applied. None means no activation. he_initializer_slope: A float slope for the He initializer. use_weight_scaling: Whether to apply weight scaling. Returns: A `Tensor` computed as apply_kernel(kernel) + bias where kernel is a `Tensor` variable with shape `kernel_shape`, bias is a `Tensor` variable with shape `bias_shape`. """ kernel_scale = he_initializer_scale(kernel_shape, he_initializer_slope) init_scale, post_scale = kernel_scale, 1.0 if use_weight_scaling: init_scale, post_scale = post_scale, init_scale kernel_initializer = tf.random_normal_initializer(stddev=init_scale) bias = tf.get_variable( 'bias', shape=bias_shape, initializer=tf.zeros_initializer()) output = post_scale * apply_kernel(kernel_shape, kernel_initializer) + bias if activation is not None: output = activation(output) return output def custom_conv2d(x, filters, kernel_size, strides=(1, 1), padding='SAME', activation=None, he_initializer_slope=1.0, use_weight_scaling=True, scope='custom_conv2d', reuse=None): """Custom conv2d layer. In comparison with tf.layers.conv2d this implementation use the He initializer to initialize convolutional kernel and the weight scaling trick (if `use_weight_scaling` is True) to equalize learning rates. See https://arxiv.org/abs/1710.10196 for more details. Args: x: A `Tensor` of NHWC format. filters: An int of output channels. kernel_size: An integer or a int tuple of [kernel_height, kernel_width]. strides: A list of strides. padding: One of "VALID" or "SAME". activation: An activation function to be applied. None means no activation. Defaults to None. he_initializer_slope: A float slope for the He initializer. Defaults to 1.0. use_weight_scaling: Whether to apply weight scaling. Defaults to True. scope: A string or variable scope. reuse: Whether to reuse the weights. Defaults to None. Returns: A `Tensor` of NHWC format where the last dimension has size `filters`. 
""" if not isinstance(kernel_size, (list, tuple)): kernel_size = [kernel_size] * 2 kernel_size = list(kernel_size) def _apply_kernel(kernel_shape, kernel_initializer): return tf.layers.conv2d( x, filters=filters, kernel_size=kernel_shape[0:2], strides=strides, padding=padding, use_bias=False, kernel_initializer=kernel_initializer) with tf.variable_scope(scope, reuse=reuse): return _custom_layer_impl( _apply_kernel, kernel_shape=kernel_size + [x.shape.as_list()[3], filters], bias_shape=(filters,), activation=activation, he_initializer_slope=he_initializer_slope, use_weight_scaling=use_weight_scaling) def custom_dense(x, units, activation=None, he_initializer_slope=1.0, use_weight_scaling=True, scope='custom_dense', reuse=None): """Custom dense layer. In comparison with tf.layers.dense This implementation use the He initializer to initialize weights and the weight scaling trick (if `use_weight_scaling` is True) to equalize learning rates. See https://arxiv.org/abs/1710.10196 for more details. Args: x: A `Tensor`. units: An int of the last dimension size of output. activation: An activation function to be applied. None means no activation. Defaults to None. he_initializer_slope: A float slope for the He initializer. Defaults to 1.0. use_weight_scaling: Whether to apply weight scaling. Defaults to True. scope: A string or variable scope. reuse: Whether to reuse the weights. Defaults to None. Returns: A `Tensor` where the last dimension has size `units`. """ x = tf.contrib.layers.flatten(x) def _apply_kernel(kernel_shape, kernel_initializer): return tf.layers.dense( x, kernel_shape[1], use_bias=False, kernel_initializer=kernel_initializer) with tf.variable_scope(scope, reuse=reuse): return _custom_layer_impl( _apply_kernel, kernel_shape=(x.shape.as_list()[-1], units), bias_shape=(units,), activation=activation, he_initializer_slope=he_initializer_slope, use_weight_scaling=use_weight_scaling)