# coding=utf-8
# Copyright 2019 The Interval Bound Propagation Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Additional Sonnet modules."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sonnet as snt
import tensorflow.compat.v1 as tf


# Slightly altered version of snt.BatchNorm that makes it easy to grab which
# mean and variance are currently in use (i.e., whether the last _build was
# invoked with is_training=True or False).
# Modifications include:
# - Removing the fused option (which we do not support).
# - Removing test_local_stats (which we do not support).
# - Providing mean and variance properties.
# - Providing scale and bias properties that return None if there are none.
class BatchNorm(snt.BatchNorm):
  """Batch normalization module, including optional affine transformation."""

  def __init__(self, axis=None, offset=True, scale=False, decay_rate=0.999,
               eps=1e-3, initializers=None, partitioners=None,
               regularizers=None, update_ops_collection=None,
               name='batch_norm'):
    """Constructs a BatchNorm module. See original code for more details."""
    super(BatchNorm, self).__init__(
        axis=axis, offset=offset, scale=scale, decay_rate=decay_rate, eps=eps,
        initializers=initializers, partitioners=partitioners,
        regularizers=regularizers, fused=False,
        update_ops_collection=update_ops_collection, name=name)

  def _build_statistics(self, input_batch, axis, use_batch_stats, stat_dtype):
    """Builds the statistics part of the graph when using moving variance."""
    self._mean, self._variance = super(BatchNorm, self)._build_statistics(
        input_batch, axis, use_batch_stats, stat_dtype)
    return self._mean, self._variance

  def _build(self, input_batch, is_training=True, test_local_stats=False,
             reuse=False):
    """Connects the BatchNorm module into the graph.

    Args:
      input_batch: A Tensor of arbitrary dimension. By default, the final
        dimension is not reduced over when computing the minibatch statistics.
      is_training: A boolean to indicate if the module should be connected in
        training mode, meaning the moving averages are updated. Can be a
        Tensor.
      test_local_stats: A boolean to indicate if the statistics should be from
        the local batch. When is_training is True, test_local_stats is not
        used.
      reuse: If True, the statistics computed by a previous call to _build
        are used and is_training is ignored. Otherwise, behaves like a normal
        batch normalization layer.

    Returns:
      A tensor with the same shape as `input_batch`.

    Raises:
      ValueError: If `axis` is not valid for the input shape or has negative
        entries.
    """
    if reuse:
      self._ensure_is_connected()
      return tf.nn.batch_normalization(
          input_batch, self._mean, self._variance, self._beta, self._gamma,
          self._eps, name='batch_norm')
    else:
      return super(BatchNorm, self)._build(input_batch, is_training,
                                           test_local_stats=test_local_stats)

  @property
  def scale(self):
    self._ensure_is_connected()
    return tf.stop_gradient(self._gamma) if self._gamma is not None else None

  @property
  def bias(self):
    self._ensure_is_connected()
    return tf.stop_gradient(self._beta) if self._beta is not None else None

  @property
  def mean(self):
    self._ensure_is_connected()
    return tf.stop_gradient(self._mean)

  @property
  def variance(self):
    self._ensure_is_connected()
    return tf.stop_gradient(self._variance)

  @property
  def epsilon(self):
    self._ensure_is_connected()
    return self._eps


class ImageNorm(snt.AbstractModule):
  """Module that does per-channel normalization."""

  def __init__(self, mean, std, name='image_norm'):
    """Constructs a module that does (x[:, :, c] - mean[c]) / std[c]."""
    super(ImageNorm, self).__init__(name=name)
    if isinstance(mean, float):
      mean = [mean]
    if isinstance(std, float):
      std = [std]
    scale = []
    for s in std:
      if s <= 0.:
        raise ValueError('Standard deviations must be positive.')
      scale.append(1. / s)
    with self._enter_variable_scope():
      # Using broadcasting over the channel dimension.
      self._scale = tf.constant(scale, dtype=tf.float32)
      self._offset = tf.constant(mean, dtype=tf.float32)

  def _build(self, inputs):
    return self.apply(inputs)

  @property
  def scale(self):
    return self._scale

  @property
  def offset(self):
    return self._offset

  # Provides a function so that this module can be used with the
  # IncreasingMonotonicWrapper.
  def apply(self, inputs):
    return (inputs - self._offset) * self._scale