# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Normamlization methods that implements cross replica nomalization for TPU.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import functools import tensorflow as tf from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.keras import layers as keras_layers from tensorflow.python.ops import math_ops def cross_replica_average(t, num_groups=1): """Calculates the average value of input tensor across TPU replicas.""" num_shards = tpu_function.get_tpu_context().number_of_shards num_shards_per_group = 1 group_assignment = None if num_groups > 0: if num_shards % num_groups != 0: raise ValueError('num_shards: %d mod num_groups: %d, should be 0' % (num_shards, num_groups)) num_shards_per_group = num_shards // num_groups group_assignment = [[ x for x in range(num_shards) if x // num_shards_per_group == y ] for y in range(num_groups)] return tpu_ops.cross_replica_sum(t, group_assignment) / math_ops.cast( num_shards_per_group, t.dtype) class BatchNormalization(keras_layers.BatchNormalization, tf.layers.Layer): """Batch Normalization layer that supports cross replica computation on TPU. This class extends the keras.BatchNormalization implementation by supporting cross replica means and variances. The base class implementation only computes moments based on mini-batch per replica (TPU core). For detailed information of arguments and implementation, refer to: https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization Arguments: fused: if `None` or `True`, use a faster, fused implementation if possible. If `False`, use the system recommended implementation. cross_replica_average_fn: A function takes a tensor and outputs the mean value across all the replicas. Currently, only TPU version supports this feature. If specified, fused must be `False`. """ def __init__(self, fused=None, cross_replica_average_fn=None, **kwargs): super(BatchNormalization, self).__init__(**kwargs) self.cross_replica_average_fn = cross_replica_average_fn if fused and cross_replica_average_fn is not None: raise ValueError('fused must be `False` when sepcifying' ' cross_replica_average_fn') def _moments(self, inputs, reduction_axes, keep_dims): mean, variance = super(BatchNormalization, self)._moments( inputs, reduction_axes, keep_dims=keep_dims) if self.cross_replica_average_fn: mean = self.cross_replica_average_fn(mean) variance = self.cross_replica_average_fn(variance) return (mean, variance) def cross_replica_batch_normalization(inputs, training=False, num_distributed_groups=1, **kwargs): """Functional interface for the cross replica batch normalization layer. For detailed information of arguments and implementation, refer to: https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization Arguments: inputs: Tensor input. training: Either a Python boolean, or a TensorFlow boolean scalar tensor (e.g. a placeholder). Whether to return the output in training mode (normalized with statistics of the current batch) or in inference mode (normalized with moving statistics). **NOTE**: make sure to set this parameter correctly, or else your training/inference will not work properly. num_distributed_groups: Number of groups to normalize in the distributed batch normalization. Replicas will evenly split into groups. For example, 1 for global batch norm and -1 or None for per-replica batch norm. **kwargs: For passing through arguments to BatchNormalization. Returns: Output tensor. Raises: ValueError: if eager execution is enabled. """ layer = BatchNormalization( cross_replica_average_fn=functools.partial( cross_replica_average, num_groups=num_distributed_groups), **kwargs) return layer.apply(inputs, training=training)