# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Regularizers that group other regularizers for residual connections.

An element-wise operation between two tensors (addition, multiplication, maximum
etc.) imposes a constraint of equality on the shapes of the constituents. For
example, if A, B are convolutions, and another op in the network receives A + B
as input, it means that the i-th output of A is tied to the i-th output of B.
Only if the i-th output was regularized away by the regularizer in both A and B
can we discard the i-th activation in both.

Therefore we group the i-th output of A and the i-th output of B in a group
LASSO, a group for each i. The grouping methods can vary, and this file offers
several variants.

Residual connections, in ResNet or in RNNs, are examples where this kind of
grouping is needed.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from morph_net.framework import generic_regularizers
import tensorflow.compat.v1 as tf

# Default L2-norm cutoff above which L2GroupingRegularizer marks a group alive.
DEFAULT_THRESHOLD = 0.01


class MaxGroupingRegularizer(generic_regularizers.OpRegularizer):
  """A regularizer that groups others by taking their maximum."""

  def __init__(self, regularizers_to_group):
    """Builds the grouped regularization and alive vectors.

    The grouped regularization vector is the element-wise maximum of the
    regularization vectors of all grouped regularizers, and the grouped alive
    vector is the element-wise logical OR of their alive vectors.

    Args:
      regularizers_to_group: A list of generic_regularizers.OpRegularizer
        objects. Their regularization_vector (alive_vector) are expected to be
        of the same length.

    Raises:
      ValueError: regularizers_to_group is not of length at least 2.
    """
    if len(regularizers_to_group) < 2:
      raise ValueError('Groups must be of at least size 2.')

    # Seed the running reduction with the first regularizer, then fold in the
    # remaining ones in their original order.
    first = regularizers_to_group[0]
    regularization_vector = first.regularization_vector
    alive_vector = first.alive_vector
    for regularizer in regularizers_to_group[1:]:
      regularization_vector = tf.maximum(regularization_vector,
                                         regularizer.regularization_vector)
      alive_vector = tf.logical_or(alive_vector, regularizer.alive_vector)

    self._regularization_vector = regularization_vector
    self._alive_vector = alive_vector

  @property
  def regularization_vector(self):
    """The element-wise maximum of the grouped regularization vectors."""
    return self._regularization_vector

  @property
  def alive_vector(self):
    """The element-wise logical OR of the grouped alive vectors."""
    return self._alive_vector


class L2GroupingRegularizer(generic_regularizers.OpRegularizer):
  r"""A regularizer that groups others by taking their L2 norm.

  R_j = sqrt((\sum_i r_{ij}^2))

  Where r_i is the i-th regularization vector, r_{ij} is its j-th element, and
  R_j is the j-th element of the resulting regularization vector.
  """

  def __init__(self, regularizers_to_group, threshold=DEFAULT_THRESHOLD):
    """Builds the grouped regularization and alive vectors.

    Args:
      regularizers_to_group: A list of generic_regularizers.OpRegularizer
        objects. Their regularization_vector (alive_vector) are expected to be
        of the same length.
      threshold: A float. A group of activations will be considered alive if
        its L2 norm is greater than `threshold`.

    Raises:
      ValueError: regularizers_to_group is not of length at least 2.
    """
    if len(regularizers_to_group) < 2:
      raise ValueError('Groups must be of at least size 2.')

    # lazy_square reuses the input of a regularization vector that is itself
    # the output of a Sqrt op, rather than squaring the square root.
    squared_vectors = [
        lazy_square(r.regularization_vector) for r in regularizers_to_group
    ]
    self._regularization_vector = tf.sqrt(tf.add_n(squared_vectors))
    self._alive_vector = self._regularization_vector > threshold

  @property
  def regularization_vector(self):
    """The element-wise L2 norm of the grouped regularization vectors."""
    return self._regularization_vector

  @property
  def alive_vector(self):
    """Boolean tensor: whether each group's L2 norm exceeds the threshold."""
    return self._alive_vector


def lazy_square(tensor):
  """Computes the square of a tensor in a lazy way.

  This function is lazy in the following sense, for:

    tensor = tf.sqrt(input)

  will return input (and not tf.square(tensor)).

  Args:
    tensor: A `Tensor` of floats to compute the square of.

  Returns:
    The square of the input tensor.
  """
  # If the tensor was produced by a Sqrt op, its square is that op's input.
  if tensor.op.type == 'Sqrt':
    return tensor.op.inputs[0]
  return tf.square(tensor)