# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains convenience wrappers for typical Neural Network TensorFlow layers.

Additionally it maintains a collection of update_ops that need to be run after
the ops have been computed, for example to update the moving means and moving
variances of batch_norm.

Ops that have different behavior during training or eval have an is_training
parameter. Additionally, ops that contain variables.variable have a trainable
parameter, which controls whether the op's variables are trainable or not.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.math_ops import sigmoid
from tensorflow.python.ops.math_ops import tanh
from tensorflow.python.ops.rnn_cell_impl import _RNNCell as RNNCell
from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.python.ops import nn
from tensorflow.python.training import moving_averages
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.contrib.framework.python.ops import variables

DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'


def signed_sqrt(inputs, scope):
  """Signed square root: sign(x) * sqrt(|x|), applied element-wise."""
  with tf.variable_scope(scope, 'signed_sqrt', [inputs]):
    return tf.sign(inputs) * tf.sqrt(tf.abs(inputs))


def l2_norm(inputs, scope):
  """L2-normalizes `inputs` along its last dimension."""
  with tf.variable_scope(scope, 'l2_norm', [inputs]):
    dim = len(inputs.get_shape()) - 1
    return tf.nn.l2_normalize(inputs, dim)


class GRUCell(RNNCell):
  """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""

  def __init__(self, num_units, input_size=None, activation=tanh):
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._activation = activation

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
    """Gated recurrent unit (GRU) with num_units cells."""
    with tf.variable_scope(scope or type(self).__name__):  # "GRUCell"
      with tf.variable_scope("Gates"):  # Reset gate and update gate.
        # We start with bias of 1.0 to not reset and not update.
        r, u = array_ops.split(
            _linear([inputs, state], 2 * self._num_units, True, 1.0), 2, 1)
        r, u = sigmoid(r), sigmoid(u)
      with tf.variable_scope("Candidate"):
        c = self._activation(
            _linear([inputs, r * state], self._num_units, True))
      new_h = u * state + (1 - u) * c
    return new_h, new_h

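
# A minimal usage sketch for the GRUCell above (illustrative only and not used
# elsewhere in this module). It assumes `inputs` is a [batch, time, depth]
# float32 tensor and relies on the standard tf.nn.dynamic_rnn API; the default
# `num_units` value is an arbitrary placeholder.
def _gru_cell_example(inputs, num_units=128):
  """Runs `inputs` through a GRUCell and returns (outputs, final_state)."""
  cell = GRUCell(num_units=num_units)
  return tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
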

def _linear(args, output_size, bias, bias_start=0.0, scope=None):
  """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

  Args:
    args: a 2D Tensor or a list of 2D, batch x n, Tensors.
    output_size: int, second dimension of W[i].
    bias: boolean, whether to add a bias term or not.
    bias_start: starting value to initialize the bias; 0 by default.
    scope: VariableScope for the created subgraph; defaults to "Linear".

  Returns:
    A 2D Tensor with shape [batch x output_size] equal to
    sum_i(args[i] * W[i]), where W[i]s are newly created matrices.

  Raises:
    ValueError: if some of the arguments have an unspecified or wrong shape.
  """
  if args is None or (nest.is_sequence(args) and not args):
    raise ValueError("`args` must be specified")
  if not nest.is_sequence(args):
    args = [args]

  # Calculate the total size of arguments on dimension 1.
  total_arg_size = 0
  shapes = [a.get_shape().as_list() for a in args]
  for shape in shapes:
    if len(shape) != 2:
      raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
    if not shape[1]:
      raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes))
    else:
      total_arg_size += shape[1]

  dtype = [a.dtype for a in args][0]

  # Now the computation.
  with tf.variable_scope(scope or "Linear"):
    matrix = tf.get_variable(
        "Matrix", [total_arg_size, output_size], dtype=dtype)
    if len(args) == 1:
      res = math_ops.matmul(args[0], matrix)
    else:
      res = math_ops.matmul(array_ops.concat(args, 1), matrix)
    if not bias:
      return res
    bias_term = tf.get_variable(
        "Bias", [output_size],
        dtype=dtype,
        initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
    return res + bias_term


def adjust_max(start, stop, start_value, stop_value, name=None):
  """Returns a scalar that moves linearly from `start_value` to `stop_value`.

  The returned value equals `start_value` while the global step is at most
  `start`, interpolates linearly between steps `start` and `stop`, and equals
  `stop_value` afterwards. Returns None if the graph has no global step.
  """
  with ops.name_scope(name, "AdjustMax", [start, stop, name]) as name:
    global_step = tf.train.get_global_step()
    if global_step is not None:
      start = tf.convert_to_tensor(start, dtype=tf.int64)
      stop = tf.convert_to_tensor(stop, dtype=tf.int64)
      start_value = tf.convert_to_tensor(start_value, dtype=tf.float32)
      stop_value = tf.convert_to_tensor(stop_value, dtype=tf.float32)

      pred_fn_pairs = {}
      pred_fn_pairs[global_step <= start] = lambda: start_value
      pred_fn_pairs[(global_step > start) & (global_step <= stop)] = (
          lambda: tf.train.polynomial_decay(
              start_value,
              global_step - start,
              stop - start,
              end_learning_rate=stop_value,
              power=1.0,
              cycle=False))
      default = lambda: stop_value
      return tf.case(pred_fn_pairs, default, exclusive=True)
    else:
      return None

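
# A hedged usage sketch for adjust_max (illustrative only): ramping the batch
# renormalization clipping limits rmax/dmax from their "no correction" values
# (1.0 and 0.0) up to looser limits over a window of training steps. The step
# boundaries and target values below are placeholders, not recommendations.
def _adjust_max_example():
  """Builds rmax/dmax schedules; a global step must exist in the graph."""
  tf.train.get_or_create_global_step()
  rmax = adjust_max(5000, 40000, 1.0, 3.0, name='rmax')
  dmax = adjust_max(5000, 25000, 0.0, 5.0, name='dmax')
  return rmax, dmax
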

@add_arg_scope
def fused_batch_norm(
    inputs,
    renorm=False,
    RMAX=None,
    DMAX=None,
    decay=0.999,
    center=True,
    scale=False,
    epsilon=0.001,
    activation_fn=None,
    param_initializers=None,
    is_training=True,
    reuse=None,
    variables_collections=None,
    outputs_collections=None,
    trainable=True,
    data_format=DATA_FORMAT_NHWC,
    zero_debias_moving_mean=False,
    scope=None):
  """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.

  "Batch Normalization: Accelerating Deep Network Training by Reducing
  Internal Covariate Shift", Sergey Ioffe, Christian Szegedy.

  Can be used as a normalizer function for conv2d and fully_connected.

  Note: when `is_training` is True the `moving_mean` and `moving_variance`
  need to be updated. By default the update ops are placed in
  `tf.GraphKeys.UPDATE_OPS`, so they need to be added as a dependency to the
  `train_op`. For example:

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
      updates = tf.group(*update_ops)
      total_loss = control_flow_ops.with_dependencies([updates], total_loss)

  Args:
    inputs: a tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC` and the second dimension if `data_format` is
      `NCHW`.
    renorm: whether to apply the batch renormalization correction
      (https://arxiv.org/abs/1702.03275) while training.
    RMAX: scalar `Tensor`; clipping limit for the renorm scale correction `r`,
      which is clipped to `[1/RMAX, RMAX]`. Only used when `renorm` is True.
    DMAX: scalar `Tensor`; clipping limit for the renorm shift correction `d`,
      which is clipped to `[-DMAX, DMAX]`. Only used when `renorm` is True.
    decay: decay for the moving average. Reasonable values for `decay` are
      close to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9,
      etc. Lower the `decay` value (try `decay=0.9`) if the model experiences
      reasonably good training performance but poor validation and/or test
      performance.
    center: If True, add offset of `beta` to the normalized tensor. If False,
      `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
      the next layer is linear (also e.g. `nn.relu`), this can be disabled
      since the scaling can be done by the next layer.
    epsilon: small float added to variance to avoid dividing by zero.
    activation_fn: activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: optional initializers for beta, gamma, moving mean and
      moving variance.
    is_training: whether or not the layer is in training mode. In training
      mode it accumulates the statistics of the moments into `moving_mean` and
      `moving_variance` using an exponential moving average with the given
      `decay`. When it is not in training mode it uses the values of
      `moving_mean` and `moving_variance`.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer, scope must be given.
    variables_collections: optional collections for the variables.
    outputs_collections: collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    zero_debias_moving_mean: Use zero_debias for moving_mean.
    scope: Optional scope for `variable_scope`.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: if `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: if the rank of `inputs` is undefined.
    ValueError: if the rank of `inputs` is neither 2 nor 4.
    ValueError: if rank or `C` dimension of `inputs` is undefined.
  """
  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')
  with tf.variable_scope(scope, 'BatchNorm', [inputs], reuse=reuse) as sc:
    inputs = ops.convert_to_tensor(inputs)
    original_shape = inputs.get_shape()
    original_rank = original_shape.ndims
    if original_rank is None:
      raise ValueError('Inputs %s has undefined rank' % inputs.name)
    elif original_rank not in [2, 4]:
      raise ValueError('Inputs %s has unsupported rank.'
                       ' Expected 2 or 4 but got %d' %
                       (inputs.name, original_rank))
    if original_rank == 2:
      channels = inputs.get_shape()[-1].value
      if channels is None:
        raise ValueError('`C` dimension must be known but is None')
      new_shape = [-1, 1, 1, channels]
      if data_format == DATA_FORMAT_NCHW:
        new_shape = [-1, channels, 1, 1]
      inputs = array_ops.reshape(inputs, new_shape)
    inputs_shape = inputs.get_shape()
    dtype = inputs.dtype.base_dtype
    if data_format == DATA_FORMAT_NHWC:
      params_shape = inputs_shape[-1:]
    else:
      params_shape = inputs_shape[1:2]
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined `C` dimension %s.' %
                       (inputs.name, params_shape))

    if not param_initializers:
      param_initializers = {}

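    # Note on the beta/gamma handling below: `nn.fused_batch_norm` is always
    # fed a zero "fakebeta"; the real `beta` offset is added afterwards with
    # `tf.nn.bias_add`, after the optional renorm correction. `beta` is a
    # variable only when both `trainable` and `center` are set, and `gamma` is
    # a variable only when both `trainable` and `scale` are set; otherwise
    # constant zeros/ones tensors are used in their place.
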
    # Allocate parameters for the beta and gamma of the normalization.
    trainable_beta = trainable and center
    if trainable_beta:
      beta_collections = utils.get_variable_collections(
          variables_collections, 'beta')
      beta_initializer = param_initializers.get('beta',
                                                init_ops.zeros_initializer())
      real_beta = variables.model_variable(
          'beta',
          shape=params_shape,
          dtype=dtype,
          initializer=beta_initializer,
          collections=beta_collections,
          trainable=trainable_beta)
      beta = tf.zeros(params_shape, name='fakebeta')
    else:
      real_beta = tf.zeros(params_shape, name='beta')
      beta = tf.zeros(params_shape, name='fakebeta')
    trainable_gamma = trainable and scale
    if trainable_gamma:
      gamma_collections = utils.get_variable_collections(
          variables_collections, 'gamma')
      gamma_initializer = param_initializers.get('gamma',
                                                 init_ops.ones_initializer())
      gamma = variables.model_variable(
          'gamma',
          shape=params_shape,
          dtype=dtype,
          initializer=gamma_initializer,
          collections=gamma_collections,
          trainable=trainable_gamma)
    else:
      gamma = tf.ones(params_shape, name='gamma')

    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections.
    moving_mean_collections = utils.get_variable_collections(
        variables_collections, 'moving_mean')
    moving_mean_initializer = param_initializers.get(
        'moving_mean', init_ops.zeros_initializer())
    moving_mean = variables.model_variable(
        'moving_mean',
        shape=params_shape,
        dtype=dtype,
        initializer=moving_mean_initializer,
        trainable=False,
        collections=moving_mean_collections)
    moving_variance_collections = utils.get_variable_collections(
        variables_collections, 'moving_variance')
    moving_variance_initializer = param_initializers.get(
        'moving_variance', init_ops.ones_initializer())
    moving_variance = variables.model_variable(
        'moving_variance',
        shape=params_shape,
        dtype=dtype,
        initializer=moving_variance_initializer,
        trainable=False,
        collections=moving_variance_collections)

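    # The training branch below can optionally apply the batch renormalization
    # correction (cf. https://arxiv.org/abs/1702.03275): the batch-normalized
    # output is rescaled by r = clip(sigma_batch / sigma_moving, 1/RMAX, RMAX)
    # and shifted by d = clip((mu_batch - mu_moving) / sigma_moving, -DMAX,
    # DMAX), with gradients stopped through both correction terms.
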
    def _fused_batch_norm_training():
      outputs, mean, variance = nn.fused_batch_norm(
          inputs, gamma, beta, epsilon=epsilon, data_format=data_format)
      if renorm:
        moving_inv = math_ops.rsqrt(moving_variance + epsilon)
        r = tf.stop_gradient(
            tf.clip_by_value(
                tf.sqrt(variance + epsilon) * moving_inv, 1 / RMAX, RMAX))
        d = tf.stop_gradient(
            tf.clip_by_value((mean - moving_mean) * moving_inv, -DMAX, DMAX))
        outputs = outputs * r + d
      return outputs, mean, variance

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=moving_mean,
          variance=moving_variance,
          epsilon=epsilon,
          is_training=False,
          data_format=data_format)

    outputs, mean, variance = utils.smart_cond(is_training,
                                               _fused_batch_norm_training,
                                               _fused_batch_norm_inference)
    # Add the real beta offset, which was withheld from the fused op above.
    outputs = tf.nn.bias_add(outputs, real_beta, data_format=data_format)

    # If `is_training` doesn't have a constant value, because it is a `Tensor`,
    # a `Variable` or a `Placeholder`, then is_training_value will be None and
    # `need_updates` will be True.
    is_training_value = utils.constant_value(is_training)
    need_updates = is_training_value is None or is_training_value
    if need_updates:
      moving_vars_fn = lambda: (moving_mean, moving_variance)

      def _delay_updates():
        """Internal function that delays updates to moving_vars if is_training."""
        update_moving_mean = moving_averages.assign_moving_average(
            moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
        update_moving_variance = moving_averages.assign_moving_average(
            moving_variance, variance, decay, zero_debias=False)
        return update_moving_mean, update_moving_variance

      update_mean, update_variance = utils.smart_cond(is_training,
                                                      _delay_updates,
                                                      moving_vars_fn)
      ops.add_to_collections(ops.GraphKeys.UPDATE_OPS, update_mean)
      ops.add_to_collections(ops.GraphKeys.UPDATE_OPS, update_variance)

    outputs.set_shape(inputs_shape)
    if original_shape.ndims == 2:
      outputs = array_ops.reshape(outputs, original_shape)
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)
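

# A minimal end-to-end sketch of fused_batch_norm (illustrative only; `images`
# is assumed to be a 4-D NHWC float tensor and `conv2d` comes from
# tf.contrib.layers). It shows the update_ops pattern from the docstring: the
# moving-average update ops land in tf.GraphKeys.UPDATE_OPS and must be run
# together with the training step.
def _fused_batch_norm_example(images, is_training):
  """Applies conv + fused_batch_norm and groups the moving-average updates."""
  net = tf.contrib.layers.conv2d(images, 64, [3, 3], activation_fn=None)
  net = fused_batch_norm(
      net, scale=True, is_training=is_training, activation_fn=tf.nn.relu)
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  updates = tf.group(*update_ops) if update_ops else tf.no_op()
  return net, updates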