# Lint as: python3 # Copyright 2020 Google LLC. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Implement the components needed for HiFiC. For more details, see the paper: https://arxiv.org/abs/2006.09965 The default values for all constructors reflect what was used in the paper. """ import collections from compare_gan.architectures import abstract_arch from compare_gan.architectures import arch_ops import numpy as np import tensorflow.compat.v1 as tf import tensorflow_compression as tfc from .helpers import ModelMode SCALES_MIN = 0.11 SCALES_MAX = 256 SCALES_LEVELS = 64 # Output of discriminator, where real and fake are merged into single tensors. DiscOutAll = collections.namedtuple( "DiscOutAll", ["d_all", "d_all_logits"]) # Split each tensor in a DiscOutAll into 2. 
DiscOutSplit = collections.namedtuple(
    "DiscOutSplit",
    ["d_real", "d_fake", "d_real_logits", "d_fake_logits"])

# Result of `estimate_entropy`: the noisy and quantized versions of the input
# plus differential (n*) and discrete (q*) entropy estimates in bits and
# bits-per-pixel.
EntropyInfo = collections.namedtuple(
    "EntropyInfo",
    "noisy quantized nbits nbpp qbits qbpp",
)

# Result of `FactorizedPriorLayer.call`.
FactorizedPriorInfo = collections.namedtuple(
    "FactorizedPriorInfo",
    "decoded latent_shape total_nbpp total_qbpp bitstring",
)

# Result of `Hyperprior.call`: decoded latents, shapes, and the bpp
# contributions of the latents (nbpp/qbpp), the side information
# (side_nbpp/side_qbpp), and their totals.
HyperInfo = collections.namedtuple(
    "HyperInfo",
    "decoded latent_shape hyper_latent_shape "
    "nbpp side_nbpp total_nbpp qbpp side_qbpp total_qbpp "
    "bitstring side_bitstring",
)


class Encoder(tf.keras.Sequential):
  """Encoder architecture.

  A conv head followed by `num_down` stride-2 convolutions, each with
  LayerNorm + ReLU, and a final linear conv that produces the bottleneck
  (latent) channels. Overall spatial downsampling factor is 2**num_down.
  """

  def __init__(self,
               name="Encoder",
               num_down=4,
               num_filters_base=60,
               num_filters_bottleneck=220):
    """Instantiate model.

    Args:
      name: Name of the layer.
      num_down: How many downsampling layers to use.
      num_filters_base: Num filters to base multiplier on.
      num_filters_bottleneck: Num filters to output for bottleneck (latent).
    """
    self._num_down = num_down

    model = [
        tf.keras.layers.Conv2D(
            filters=num_filters_base, kernel_size=7, padding="same"),
        LayerNorm(),
        tf.keras.layers.ReLU()
    ]

    # Each downsampling stage halves the resolution and doubles the filters.
    for i in range(num_down):
      model.extend([
          tf.keras.layers.Conv2D(
              filters=num_filters_base * 2 ** (i + 1),
              kernel_size=3, padding="same", strides=2),
          LayerNorm(),
          tf.keras.layers.ReLU()])

    # Final conv to the bottleneck channels; no activation (linear latent).
    model.append(
        tf.keras.layers.Conv2D(
            filters=num_filters_bottleneck, kernel_size=3, padding="same"))

    super(Encoder, self).__init__(layers=model, name=name)

  @property
  def num_downsampling_layers(self):
    # Number of stride-2 convolutions, i.e. log2 of the spatial downsampling.
    return self._num_down


class Decoder(tf.keras.layers.Layer):
  """Decoder/generator architecture."""

  def __init__(self,
               name="Decoder",
               num_up=4,
               num_filters_base=60,
               num_residual_blocks=9,
               ):
    """Instantiate layer.

    Args:
      name: name of the layer.
      num_up: how many upsampling layers.
      num_filters_base: base number of filters.
      num_residual_blocks: number of residual blocks.
""" head = [LayerNorm(), tf.keras.layers.Conv2D( filters=num_filters_base * (2 ** num_up), kernel_size=3, padding="same"), LayerNorm()] residual_blocks = [] for block_idx in range(num_residual_blocks): residual_blocks.append( ResidualBlock( filters=num_filters_base * (2 ** num_up), kernel_size=3, name="block_{}".format(block_idx), activation="relu", padding="same")) tail = [] for scale in reversed(range(num_up)): filters = num_filters_base * (2 ** scale) tail += [ tf.keras.layers.Conv2DTranspose( filters=filters, kernel_size=3, padding="same", strides=2), LayerNorm(), tf.keras.layers.ReLU()] # Final conv layer. tail.append( tf.keras.layers.Conv2D( filters=3, kernel_size=7, padding="same")) self._head = tf.keras.Sequential(head) self._residual_blocks = tf.keras.Sequential(residual_blocks) self._tail = tf.keras.Sequential(tail) super(Decoder, self).__init__(name=name) def call(self, inputs): after_head = self._head(inputs) after_res = self._residual_blocks(after_head) after_res += after_head # Skip connection return self._tail(after_res) class ResidualBlock(tf.keras.layers.Layer): """Implement a residual block.""" def __init__( self, filters, kernel_size, name=None, activation="relu", **kwargs_conv2d): """Instantiate layer. Args: filters: int, number of filters, passed to the conv layers. kernel_size: int, kernel_size, passed to the conv layers. name: str, name of the layer. activation: function or string, resolved with keras. **kwargs_conv2d: Additional arguments to be passed directly to Conv2D. E.g. 'padding'. 
""" super(ResidualBlock, self).__init__() kwargs_conv2d["filters"] = filters kwargs_conv2d["kernel_size"] = kernel_size block = [ tf.keras.layers.Conv2D(**kwargs_conv2d), LayerNorm(), tf.keras.layers.Activation(activation), tf.keras.layers.Conv2D(**kwargs_conv2d), LayerNorm()] self.block = tf.keras.Sequential(name=name, layers=block) def call(self, inputs, **kwargs): return inputs + self.block(inputs, **kwargs) class LayerNorm(tf.keras.layers.Layer): """Implement LayerNorm. Based on this paper and keras' InstanceNorm layer: Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016). """ def __init__(self, epsilon: float = 1e-3, center: bool = True, scale: bool = True, beta_initializer="zeros", gamma_initializer="ones", **kwargs): """Instantiate layer. Args: epsilon: For stability when normalizing. center: Whether to create and use a {beta}. scale: Whether to create and use a {gamma}. beta_initializer: Initializer for beta. gamma_initializer: Initializer for gamma. **kwargs: Passed to keras. """ super(LayerNorm, self).__init__(**kwargs) self.axis = -1 self.epsilon = epsilon self.center = center self.scale = scale self.beta_initializer = tf.keras.initializers.get(beta_initializer) self.gamma_initializer = tf.keras.initializers.get(gamma_initializer) def build(self, input_shape): self._add_gamma_weight(input_shape) self._add_beta_weight(input_shape) self.built = True super().build(input_shape) def call(self, inputs, modulation=None): mean, variance = self._get_moments(inputs) # inputs = tf.Print(inputs, [mean, variance, self.beta, self.gamma], "NORM") return tf.nn.batch_normalization( inputs, mean, variance, self.beta, self.gamma, self.epsilon, name="normalize") def _get_moments(self, inputs): # Like tf.nn.moments but unbiased sample std. deviation. # Reduce over channels only. 
    mean = tf.reduce_mean(inputs, [self.axis], keepdims=True,
                          name="mean")
    # Variance around a stop-gradient'ed mean: gradients do not flow through
    # the mean inside the squared difference.
    variance = tf.reduce_sum(
        tf.squared_difference(inputs, tf.stop_gradient(mean)),
        [self.axis], keepdims=True, name="variance_sum")
    # Divide by N-1 (Bessel's correction -> unbiased sample variance).
    inputs_shape = tf.shape(inputs)
    counts = tf.reduce_prod([inputs_shape[ax] for ax in [self.axis]])
    variance /= (tf.cast(counts, tf.float32) - 1)
    return mean, variance

  def _add_gamma_weight(self, input_shape):
    """Create `self.gamma` over the channel dim, or None if scale=False."""
    dim = input_shape[self.axis]
    shape = (dim,)

    if self.scale:
      self.gamma = self.add_weight(
          shape=shape,
          name="gamma",
          initializer=self.gamma_initializer)
    else:
      self.gamma = None

  def _add_beta_weight(self, input_shape):
    """Create `self.beta` over the channel dim, or None if center=False."""
    dim = input_shape[self.axis]
    shape = (dim,)

    if self.center:
      self.beta = self.add_weight(
          shape=shape,
          name="beta",
          initializer=self.beta_initializer)
    else:
      self.beta = None


class _PatchDiscriminatorCompareGANImpl(abstract_arch.AbstractDiscriminator):
  """PatchDiscriminator architecture.

  Implemented as a compare_gan layer. This has the benefit that we can use
  spectral_norm from that framework.
  """

  def __init__(self,
               name,
               num_filters_base=64,
               num_layers=3,
               ):
    """Instantiate discriminator.

    Args:
      name: Name of the layer.
      num_filters_base: Number of base filters. will be multiplied as we
        go down in resolution.
      num_layers: Number of downscaling convolutions.
    """
    super(_PatchDiscriminatorCompareGANImpl, self).__init__(
        name, batch_norm_fn=None, layer_norm=False, spectral_norm=True)
    self._num_layers = num_layers
    self._num_filters_base = num_filters_base

  def __call__(self, x):
    """Overwriting compare_gan's __call__ as we only need `x`."""
    with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
      return self.apply(x)

  def apply(self, x):
    """Overwriting compare_gan's apply as we only need `x`.

    Args:
      x: 2-tuple of (image, latent); the latent is upscaled and concatenated
        to the image channels before the conv stack (conditional
        discriminator).

    Returns:
      A DiscOutAll with sigmoid outputs and logits, flattened into the batch
      dimension.

    Raises:
      ValueError: if `x` is not a 2-tuple.
    """
    if not isinstance(x, tuple) or len(x) != 2:
      raise ValueError("Expected 2-tuple, got {}".format(x))
    x, latent = x
    x_shape = tf.shape(x)

    # Upscale and fuse latent.
    # Project the latent to 12 channels, then nearest-neighbor upscale it to
    # the image resolution and concatenate along channels.
    latent = arch_ops.conv2d(latent, 12, 3, 3, 1, 1,
                             name="latent", use_sn=self._spectral_norm)
    latent = arch_ops.lrelu(latent, leak=0.2)
    latent = tf.image.resize(latent, [x_shape[1], x_shape[2]],
                             tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    x = tf.concat([x, latent], axis=-1)

    # The discriminator:
    k = 4
    net = arch_ops.conv2d(x, self._num_filters_base, k, k, 2, 2,
                          name="d_conv_head", use_sn=self._spectral_norm)
    net = arch_ops.lrelu(net, leak=0.2)

    # Downscaling stages: double filters (capped at 512), stride 2.
    num_filters = self._num_filters_base
    for i in range(self._num_layers - 1):
      num_filters = min(num_filters * 2, 512)
      net = arch_ops.conv2d(net, num_filters, k, k, 2, 2,
                            name=f"d_conv_{i}",
                            use_sn=self._spectral_norm)
      net = arch_ops.lrelu(net, leak=0.2)

    # One more conv at stride 1 before the output projection.
    num_filters = min(num_filters * 2, 512)
    net = arch_ops.conv2d(net, num_filters, k, k, 1, 1,
                          name="d_conv_a",
                          use_sn=self._spectral_norm)
    net = arch_ops.lrelu(net, leak=0.2)

    # Final stride-1 conv (kernel k x k) that maps to 1 channel: one logit
    # per output patch.
    net = arch_ops.conv2d(net, 1, k, k, 1, 1,
                          name="d_conv_b",
                          use_sn=self._spectral_norm)

    out_logits = tf.reshape(net, [-1, 1])  # Reshape all into batch dimension.
    out = tf.nn.sigmoid(out_logits)

    return DiscOutAll(out, out_logits)


class _CompareGANLayer(tf.keras.layers.Layer):
  """Base class for wrapping compare_gan classes as keras layers.

  The main task of this class is to provide a keras-like interface, which
  includes a `trainable_variables`. This is non-trivial however, as
  compare_gan uses tf.get_variable. So we try to use the name scope to find
  these variables.
  """

  def __init__(self,
               name,
               compare_gan_cls,
               **compare_gan_kwargs):
    """Constructor.

    Args:
      name: Name of the layer. IMPORTANT: Setting this to the same string
        for two different layers will cause unexpected behavior since
        variables are found using this name.
      compare_gan_cls: A class from compare_gan, which should inherit from
        either AbstractGenerator or AbstractDiscriminator.
      **compare_gan_kwargs: keyword arguments passed to compare_gan_cls to
        construct it.
""" super(_CompareGANLayer, self).__init__(name=name) compare_gan_kwargs["name"] = name self._name = name self._model = compare_gan_cls(**compare_gan_kwargs) def call(self, x): return self._model(x) @property def trainable_variables(self): """Get trainable variables.""" # Note: keras only returns something if self.training is true, but we # don't have training as a flag to the constructor, so we always return. # However, we only call trainable_variables when we are training. return tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES, scope=self._model.name) class Discriminator(_CompareGANLayer): def __init__(self): super(Discriminator, self).__init__( name="Discriminator", compare_gan_cls=_PatchDiscriminatorCompareGANImpl) class Hyperprior(tf.keras.layers.Layer): """Hyperprior architecture (probability model).""" def __init__(self, num_chan_bottleneck=220, num_filters=320, name="Hyperprior"): super(Hyperprior, self).__init__(name=name) self._num_chan_bottleneck = num_chan_bottleneck self._num_filters = num_filters self._analysis = tf.keras.Sequential([ tfc.SignalConv2D( num_filters, (3, 3), name=f"layer_{name}_0", corr=True, padding="same_zeros", use_bias=True, activation=tf.nn.relu), tfc.SignalConv2D( num_filters, (5, 5), name=f"layer_{name}_1", corr=True, strides_down=2, padding="same_zeros", use_bias=True, activation=tf.nn.relu), tfc.SignalConv2D( num_filters, (5, 5), name=f"layer_{name}_2", corr=True, strides_down=2, padding="same_zeros", use_bias=True, activation=None)], name="HyperAnalysis") def _make_synthesis(syn_name): return tf.keras.Sequential([ tfc.SignalConv2D( num_filters, (5, 5), name=f"layer_{syn_name}_0", corr=False, strides_up=2, padding="same_zeros", use_bias=True, kernel_parameterizer=None, activation=tf.nn.relu), tfc.SignalConv2D( num_filters, (5, 5), name=f"layer_{syn_name}_1", corr=False, strides_up=2, padding="same_zeros", use_bias=True, kernel_parameterizer=None, activation=tf.nn.relu), tfc.SignalConv2D( num_chan_bottleneck, (3, 3), 
name=f"layer_{syn_name}_2", corr=False, padding="same_zeros", use_bias=True, kernel_parameterizer=None, activation=None), ], name="HyperSynthesis") self._synthesis_scale = _make_synthesis("scale") self._synthesis_mean = _make_synthesis("mean") self._side_entropy_model = FactorizedPriorLayer() @property def transform_layers(self): return [self._analysis, self._synthesis_scale, self._synthesis_mean] @property def entropy_layers(self): return [self._side_entropy_model] def call(self, latents, image_shape, mode: ModelMode) -> HyperInfo: """Apply this layer to code `latents`. Args: latents: Tensor of latent values to code. image_shape: The [height, width] of a reference frame. mode: The training, evaluation or validation mode of the model. Returns: A HyperInfo tuple. """ training = (mode == ModelMode.TRAINING) validation = (mode == ModelMode.VALIDATION) latent_shape = tf.shape(latents)[1:-1] hyper_latents = self._analysis(latents, training=training) # Model hyperprior distributions and entropy encode/decode hyper-latents. 
    side_info = self._side_entropy_model(
        hyper_latents, image_shape=image_shape, mode=mode, training=training)
    hyper_decoded = side_info.decoded

    # Log-spaced table of SCALES_LEVELS scales in [SCALES_MIN, SCALES_MAX].
    scale_table = np.exp(np.linspace(
        np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))

    # Predict per-latent Gaussian parameters from the decoded hyper-latents.
    latent_scales = self._synthesis_scale(
        hyper_decoded, training=training)
    latent_means = self._synthesis_mean(
        tf.cast(hyper_decoded, tf.float32), training=training)

    if not (training or validation):
      # Evaluation mode: crop predictions back to the latent's spatial
      # extent (presumably to undo padding introduced on the coding path —
      # TODO confirm).
      latent_scales = latent_scales[:, :latent_shape[0], :latent_shape[1], :]
      latent_means = latent_means[:, :latent_shape[0], :latent_shape[1], :]

    conditional_entropy_model = tfc.GaussianConditional(
        latent_scales, scale_table, mean=latent_means,
        name="conditional_entropy_model")

    entropy_info = estimate_entropy(
        conditional_entropy_model, latents, spatial_shape=image_shape)

    compressed = None
    if training:
      # Straight-through quantization around the predicted mean.
      latents_decoded = _quantize(latents, latent_means)
    elif validation:
      latents_decoded = entropy_info.quantized
    else:
      # Evaluation: actually run the range coder round-trip.
      compressed = conditional_entropy_model.compress(latents)
      latents_decoded = conditional_entropy_model.decompress(compressed)

    info = HyperInfo(
        decoded=latents_decoded,
        latent_shape=latent_shape,
        hyper_latent_shape=side_info.latent_shape,
        nbpp=entropy_info.nbpp,
        side_nbpp=side_info.total_nbpp,
        total_nbpp=entropy_info.nbpp + side_info.total_nbpp,
        qbpp=entropy_info.qbpp,
        side_qbpp=side_info.total_qbpp,
        total_qbpp=entropy_info.qbpp + side_info.total_qbpp,
        bitstring=compressed,
        side_bitstring=side_info.bitstring)

    tf.summary.scalar("bpp/total/noisy", info.total_nbpp)
    tf.summary.scalar("bpp/total/quantized", info.total_qbpp)

    return info


def _quantize(inputs, mean):
  """Round `inputs` to integers around `mean`, with straight-through grads."""
  half = tf.constant(.5, dtype=tf.float32)
  outputs = inputs
  outputs -= mean
  # Rounding latents for the forward pass (straight-through).
  # Forward pass computes floor(x + .5) (i.e. round-half-up); the
  # stop_gradient makes the backward pass see identity.
  outputs = outputs + tf.stop_gradient(tf.math.floor(outputs + half) - outputs)
  outputs += mean
  return outputs


class FactorizedPriorLayer(tf.keras.layers.Layer):
  """Factorized prior to code a discrete tensor."""

  def __init__(self):
    """Instantiate layer."""
    super(FactorizedPriorLayer, self).__init__(name="FactorizedPrior")
    self._entropy_model = tfc.EntropyBottleneck(
        name="entropy_model")

  def compute_output_shape(self, input_shape):
    """Shapes of the FactorizedPriorInfo fields returned by `call`."""
    batch_size = input_shape[0]
    shapes = (
        input_shape,   # decoded
        [2],           # latent_shape = [height, width]
        [],            # total_nbpp
        [],            # total_qbpp
        [batch_size],  # bitstring
    )
    return tuple(tf.TensorShape(x) for x in shapes)

  @property
  def losses(self):
    # Forward the entropy model's auxiliary losses (keras would not find
    # them automatically since the model is not tracked as a sublayer here).
    return self._entropy_model.losses

  @property
  def updates(self):
    return self._entropy_model.updates

  def call(self, latents, image_shape, mode: ModelMode) -> FactorizedPriorInfo:
    """Apply this layer to code `latents`.

    Args:
      latents: Tensor of latent values to code.
      image_shape: The [height, width] of a reference frame.
      mode: The training, evaluation or validation mode of the model.

    Returns:
      A FactorizedPriorInfo tuple
    """
    training = (mode == ModelMode.TRAINING)
    validation = (mode == ModelMode.VALIDATION)
    latent_shape = tf.shape(latents)[1:-1]

    with tf.name_scope("factorized_entropy_model"):
      noisy, quantized, _, nbpp, _, qbpp = estimate_entropy(
          self._entropy_model, latents, spatial_shape=image_shape)

      compressed = None
      if training:
        latents_decoded = noisy
      elif validation:
        latents_decoded = quantized
      else:
        compressed = self._entropy_model.compress(latents)

        # Decompress using the spatial shape tensor and get tensor coming
        # out of range decoder.
        # Channel count must be known statically; spatial dims come from the
        # dynamic `latent_shape` tensor.
        num_channels = latents.shape[-1].value
        latents_decoded = self._entropy_model.decompress(
            compressed, shape=tf.concat([latent_shape, [num_channels]], 0))

      return FactorizedPriorInfo(
          decoded=latents_decoded,
          latent_shape=latent_shape,
          total_nbpp=nbpp,
          total_qbpp=qbpp,
          bitstring=compressed)


def estimate_entropy(entropy_model, inputs, spatial_shape=None) -> EntropyInfo:
  """Compresses `inputs` with the given entropy model and estimates entropy.

  Arguments:
    entropy_model: An `EntropyModel` instance.
    inputs: The input tensor to be fed to the entropy model.
    spatial_shape: Shape of the input image (HxW). Used to normalize the bit
      estimates to bits-per-pixel; must be provided.

  Returns:
    The 'noisy' and quantized inputs, as well as differential and discrete
    entropy estimates, as an `EntropyInfo` named tuple.
  """
  # We are summing over the log likelihood tensor, so we need to explicitly
  # divide by the batch size.
  batch = tf.cast(tf.shape(inputs)[0], tf.float32)

  # Divide by this to flip sign and convert from nats to bits.
  quotient = tf.constant(-np.log(2), dtype=tf.float32)

  num_pixels = tf.cast(tf.reduce_prod(spatial_shape), tf.float32)

  # Compute noisy outputs and estimate differential entropy.
  # (training=True makes the entropy model return its noisy/relaxed sample.)
  noisy, likelihood = entropy_model(inputs, training=True)
  log_likelihood = tf.log(likelihood)
  nbits = tf.reduce_sum(log_likelihood) / (quotient * batch)
  nbpp = nbits / num_pixels

  # Compute quantized outputs and estimate discrete entropy.
  quantized, likelihood = entropy_model(inputs, training=False)
  log_likelihood = tf.log(likelihood)
  qbits = tf.reduce_sum(log_likelihood) / (quotient * batch)
  qbpp = qbits / num_pixels

  return EntropyInfo(noisy, quantized, nbits, nbpp, qbits, qbpp)