# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FisherFactor definitions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import contextlib
import math

# Dependency imports
import numpy as np
import six
import tensorflow as tf

from collections import OrderedDict
from tensorflow.python.util import nest
from kfac.python.ops import linear_operator as lo
from kfac.python.ops import utils

# Whether to initialize covariance estimators at a zero matrix (or the identity
# matrix).
INIT_COVARIANCES_AT_ZERO = True

# Whether to zero-debias the moving averages.
ZERO_DEBIAS = True

# Whether to initialize inverse (and other such matrices computed from the cov
# matrices) to the zero matrix (or the identity matrix). Initializing to
# zero is a safeguard against anything using the inverse before their first
# proper update, and so is preferred.
INIT_INVERSES_AT_ZERO = True

# When the number of inverses requested from a FisherFactor is >= this value,
# the inverses are computed using an eigenvalue decomposition.
EIGENVALUE_DECOMPOSITION_THRESHOLD = 4

# Numerical eigenvalues computed from covariance matrix estimates are clipped to
# be at least as large as this value before they are used to compute inverses or
# matrix powers. Must be nonnegative.
EIGENVALUE_CLIPPING_THRESHOLD = 0.0

# When approximating conv layer input factor using spatially uncorrelated
# activations (`ConvInputSUAKroneckerfactor`) if this is True then assumes the
# activations to have zero mean.
ASSUME_ZERO_MEAN_ACTIVATIONS = False

# When approximating conv layer input factor using spatially uncorrelated
# activations (`ConvInputSUAKroneckerfactor`) if this is True then do
# mean subtraction from covariance matrix. Note this flag is only checked in the
# case where ASSUME_ZERO_MEAN_ACTIVATIONS is set to True. If
# ASSUME_ZERO_MEAN_ACTIVATIONS is False then mean is always subtracted from the
# covariance matrix and this flag is redundant.
SUBTRACT_MEAN_CONTRIB_FROM_COV = True

# Subsample the inputs passed to the extract image patches. The number of
# inputs is normally batch_size. If _SUB_SAMPLE_INPUTS = True then
# the inputs will be randomly subsampled down to a total of
# _INPUTS_TO_EXTRACT_PATCHES_FACTOR * batch_size.
#
# Note that the value of _SUB_SAMPLE_INPUTS can be overridden locally for a
# particular layer by passing in an argument to the factor class (or the
# registration function for the corresponding layer).
_SUB_SAMPLE_INPUTS = False
_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.2

# Subsample the extracted image patches during covariance estimation for
# input factors in conv layer. The number of patches subsampled will be
# calculated based on the following formula:
#
# if _SUB_SAMPLE_PATCHES:
#   num_patches = min(_MAX_NUM_PATCHES,
#                     ceil(_MAX_NUM_PATCHES_PER_DIMENSION*dimension))
# else:
#   num_patches = total_patches
#
# where dimension is the number of rows (or columns) of the input factor matrix,
# which is typically the number of input channels times the number of pixels
# in a patch.
#
# Note that the value of _SUB_SAMPLE_PATCHES can be overridden locally for a
# particular layer by passing in an argument to the factor class (or the
# registration function for the corresponding layer).
_SUB_SAMPLE_PATCHES = False
_MAX_NUM_PATCHES = 10000000
_MAX_NUM_PATCHES_PER_DIMENSION = 3.0

# If true we use the custom XLA implementation of an op to compute the second
# moment of the patch vectors. Note that _SUB_SAMPLE_PATCHES doesn't do anything
# when this is enabled. Also note that _SUB_SAMPLE_INPUTS probably doesn't
# need to be used either, since that feature was designed to mitigate the
# extreme memory consumption of the naive implementation of this op.
_USE_PATCHES_SECOND_MOMENT_OP = False

# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data
# passed to the factors from the blocks will be concatenated across towers
# (lazily via PartitionedTensor objects). Otherwise a tuple of tensors over
# towers will be passed in, and the factors will iterate over this and do the
# cov computations separately for each one, averaging the results together.
TOWER_STRATEGY = "separate"
#TOWER_STRATEGY = "concat"

# The variable scope names can be edited by passing a custom sanitizer function.
# By default the scope name is unchanged.
_GET_SANITIZED_NAME_FN = lambda x: x def set_global_constants(init_covariances_at_zero=None, zero_debias=None, init_inverses_at_zero=None, eigenvalue_decomposition_threshold=None, eigenvalue_clipping_threshold=None, assume_zero_mean_activations=None, subtract_mean_contrib_from_cov=None, sub_sample_inputs=None, inputs_to_extract_patches_factor=None, sub_sample_patches=None, max_num_patches=None, max_num_patches_per_dimension=None, tower_strategy=None, get_sanitized_name_fn=None, use_patches_second_moment_op=None): """Sets various global constants used by the classes in this module.""" global INIT_COVARIANCES_AT_ZERO global ZERO_DEBIAS global INIT_INVERSES_AT_ZERO global EIGENVALUE_DECOMPOSITION_THRESHOLD global EIGENVALUE_CLIPPING_THRESHOLD global ASSUME_ZERO_MEAN_ACTIVATIONS global SUBTRACT_MEAN_CONTRIB_FROM_COV global _SUB_SAMPLE_INPUTS global _INPUTS_TO_EXTRACT_PATCHES_FACTOR global _SUB_SAMPLE_PATCHES global _MAX_NUM_PATCHES global _MAX_NUM_PATCHES_PER_DIMENSION global _GET_SANITIZED_NAME_FN global TOWER_STRATEGY global _USE_PATCHES_SECOND_MOMENT_OP if init_covariances_at_zero is not None: INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero if zero_debias is not None: ZERO_DEBIAS = zero_debias if init_inverses_at_zero is not None: INIT_INVERSES_AT_ZERO = init_inverses_at_zero if eigenvalue_decomposition_threshold is not None: EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold if eigenvalue_clipping_threshold is not None: EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold if assume_zero_mean_activations is not None: ASSUME_ZERO_MEAN_ACTIVATIONS = assume_zero_mean_activations if subtract_mean_contrib_from_cov is not None: SUBTRACT_MEAN_CONTRIB_FROM_COV = subtract_mean_contrib_from_cov if sub_sample_inputs is not None: _SUB_SAMPLE_INPUTS = sub_sample_inputs if inputs_to_extract_patches_factor is not None: _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor if sub_sample_patches is not None: _SUB_SAMPLE_PATCHES = 
sub_sample_patches if max_num_patches is not None: _MAX_NUM_PATCHES = max_num_patches if max_num_patches_per_dimension is not None: _MAX_NUM_PATCHES_PER_DIMENSION = max_num_patches_per_dimension if tower_strategy is not None: TOWER_STRATEGY = tower_strategy if get_sanitized_name_fn is not None: _GET_SANITIZED_NAME_FN = get_sanitized_name_fn if use_patches_second_moment_op is not None: _USE_PATCHES_SECOND_MOMENT_OP = use_patches_second_moment_op if INIT_INVERSES_AT_ZERO: inverse_initializer = tf.zeros_initializer else: inverse_initializer = tf.initializers.identity if INIT_COVARIANCES_AT_ZERO: covariance_initializer = tf.zeros_initializer else: covariance_initializer = tf.initializers.identity if INIT_COVARIANCES_AT_ZERO: diagonal_covariance_initializer = tf.zeros_initializer else: diagonal_covariance_initializer = tf.ones_initializer @contextlib.contextmanager def maybe_place_on_device(device): if device is not None and len(device) and TOWER_STRATEGY == "separate": with tf.device(device): yield else: yield def compute_cov(tensor, tensor_right=None, normalizer=None): """Compute the empirical second moment of the rows of a 2D Tensor. This function is meant to be applied to random matrices for which the true row mean is zero, so that the true second moment equals the true covariance. Args: tensor: A 2D Tensor. tensor_right: An optional 2D Tensor. If provided, this function computes the matrix product tensor^T * tensor_right instead of tensor^T * tensor. normalizer: optional scalar for the estimator (by default, the normalizer is the number of rows of tensor). Returns: A square 2D Tensor with as many rows/cols as the number of input columns. 
""" if normalizer is None: normalizer = utils.get_shape(tensor)[0] if tensor_right is None: cov = ( tf.matmul(tensor, tensor, transpose_a=True) / tf.cast( normalizer, tensor.dtype)) return (cov + tf.transpose(cov)) / tf.cast(2.0, cov.dtype) else: return (tf.matmul(tensor, tensor_right, transpose_a=True) / tf.cast(normalizer, tensor.dtype)) def append_homog(tensor, homog_value=None): """Appends a homogeneous coordinate to the last dimension of a Tensor. Args: tensor: A Tensor. homog_value: Value to append as homogeneous coordinate to the last dimension of `tensor`. If None 1.0 is used. (Default: None) Returns: A Tensor identical to the input but one larger in the last dimension. The new entries are filled with ones. """ shape = tensor.shape.as_list() rank = len(shape) if any(elt is None for elt in shape): shape = tf.concat([tf.shape(tensor)[:-1], [1]], axis=0) else: shape[-1] = 1 if homog_value is not None: appendage = homog_value * tf.ones(shape, dtype=tensor.dtype) else: appendage = tf.ones(shape, dtype=tensor.dtype) return tf.concat([tensor, appendage], axis=-1) def scope_string_from_params(params): """Builds a variable scope string name from the given parameters. Supported parameters are: * tensors * booleans * ints * strings * depth-1 tuples/lists of ints * any depth tuples/lists of tensors Other parameter types will throw an error. Args: params: A parameter or list of parameters. Returns: A string to use for the variable scope. Raises: ValueError: if params includes an unsupported type. 
""" params = params if isinstance(params, (tuple, list)) else (params,) name_parts = [] for param in params: if param is None: name_parts.append("None") elif isinstance(param, (tuple, list)): if all([isinstance(p, int) for p in param]): name_parts.append("-".join([str(p) for p in param])) else: name_parts.append(scope_string_from_name(param)) elif isinstance(param, (six.string_types, int, bool)): name_parts.append(str(param)) elif isinstance(param, (tf.Tensor, tf.Variable)): name_parts.append(scope_string_from_name(param)) elif isinstance(param, utils.PartitionedTensor): name_parts.append(scope_string_from_name(param.tensors)) else: raise ValueError("Encountered an unsupported param {} of type {}".format( param, type(param))) return "_".join(name_parts) def scope_string_from_name(tensor): if isinstance(tensor, (tuple, list)): return "__".join([scope_string_from_name(t) for t in tensor]) # "gradients/add_4_grad/Reshape:0/replica_0" -> # "gradients_add_4_grad_Reshape_0_replica_0" tensor_name = tensor.name.replace("/", "_").replace(":", "_") return _GET_SANITIZED_NAME_FN(tensor_name) def scalar_or_tensor_to_string(val): return repr(val) if np.isscalar(val) else scope_string_from_name(val) def list_to_string(lst): return "_".join(val if isinstance(val, six.string_types) else scalar_or_tensor_to_string(val) for val in lst) def graph_func_to_id(func): """Returns a hashable object that represents func's computation.""" # TODO(b/74201126): replace with Topohash of func's output return func.func_id def graph_func_to_string(func): # TODO(b/74201126): replace with Topohash of func's output return list_to_string(func.func_id) def _subsample_patches(patches, name=None): """Subsample a patches matrix. Subsample an array of image patches. The number of patches subsampled will be calculated based on the following formula: num_patches = min(_MAX_NUM_PATCHES, ceil(_MAX_NUM_PATCHES_PER_DIMENSION*dimension)) Args: patches: Tensor, of shape `[total_patches, dimension]`. 
name: `string`, Default (None) Returns: A tensor of shape `[num_patches, dimension]`. Raises: ValueError: If patches is not matrix-shaped. ValueError: If total_patches cannot be inferred. """ with tf.name_scope(name, "subsample", [patches]): patches = tf.convert_to_tensor(patches) if len(patches.shape) != 2: raise ValueError("Input param patches must be a matrix.") total_patches = patches.shape.as_list()[0] dimension = patches.shape.as_list()[1] num_patches = min(_MAX_NUM_PATCHES, int(math.ceil(_MAX_NUM_PATCHES_PER_DIMENSION*dimension))) if total_patches is None: total_patches = utils.get_shape(patches)[0] should_subsample = tf.less(num_patches, total_patches) return tf.cond(should_subsample, lambda: _random_tensor_gather(patches, num_patches, name), lambda: patches) else: if num_patches < total_patches: return _random_tensor_gather(patches, num_patches, name) else: return patches def _random_tensor_gather(array, num_ind, name=None): """Samples random indices of an array (along the first dimension). Args: array: Tensor of shape `[batch_size, ...]`. num_ind: int. Number of indices to sample. name: `string`. (Default: None) Returns: A tensor of shape `[num_ind, ...]`. """ with tf.name_scope(name, "random_gather", [array]): array = tf.convert_to_tensor(array) total_size = array.shape.as_list()[0] if total_size is None: total_size = utils.get_shape(array)[0] indices = tf.random_shuffle(tf.range(0, total_size))[:num_ind] return tf.gather(array, indices, axis=0) @six.add_metaclass(abc.ABCMeta) class FisherFactor(object): """Base class for objects modeling factors of approximate Fisher blocks. A FisherFactor represents part of an approximate Fisher Information matrix. For example, one approximation to the Fisher uses the Kronecker product of two FisherFactors A and B, F = kron(A, B). FisherFactors are composed with FisherBlocks to construct a block-diagonal approximation to the full Fisher. 
FisherFactors are backed by a single, non-trainable variable that is updated by running FisherFactor.make_covariance_update_op(). The shape and type of this variable is implementation specific. Note that for blocks that aren't based on approximations, a 'factor' can be the entire block itself, as is the case for the diagonal and full representations. """ def __init__(self): self._cov_tensor = None self._cov = None self._acc_cov = None @abc.abstractproperty def _var_scope(self): """Variable scope for this FisherFactor instance. Returns: string that unique identifies this FisherFactor instance. """ pass @property def name(self): return self._var_scope @abc.abstractproperty def _cov_shape(self): """The shape of the variable backing this FisherFactor.""" pass @abc.abstractproperty def _num_sources(self): """The number of things to sum over when updating covariance variable. The default make_covariance_update_op function will call _compute_new_cov with indices ranging from 0 to _num_sources-1. The typical situation is where the factor wants to sum the statistics it computes over multiple backpropped "gradients" (typically passed in via "tensors" or "outputs_grads" arguments). """ pass @abc.abstractproperty def _num_towers(self): pass @abc.abstractproperty def _dtype(self): """dtype for variable backing this factor.""" pass @abc.abstractmethod def _partial_batch_size(self, source=0, tower=0): """Returns (partial) batch size associated with given source and tower.""" pass def batch_size(self, source=0): """Returns (total) batch size associated with given source.""" return sum(self._partial_batch_size(source=source, tower=tower) for tower in range(self._num_towers)) def check_partial_batch_sizes(self): """Ensures partial batch sizes are equal across towers and source.""" # While it could be okay in principle for the different batch sizes for # different towers, the way the code has been written isn't compatible with # this. 
Basically, the normalizations occur for each tower and then the # results are summed across towers and divided by the number of towers. # The only way this is correct is if the towers all have the same batch # size. # Should make these messages use quote characters instead of parentheses # when the bug with quote character rendering in assertion messages is # fixed. See b/129476712 msg = ("Inconsistent (partial) batch sizes detected for factor ({}) of type" " {}. This can be caused by passing Tensors with the wrong sizes to " "the registration functions, or misspecification of arguments like " "batch_size, num_uses, or num_timesteps.".format( self.name, utils.cls_name(self))) partial_batch_size = self._partial_batch_size() if self._num_sources > 1 or self._num_towers > 1: if isinstance(partial_batch_size, int): checks = tuple( partial_batch_size == self._partial_batch_size(source=source, tower=tower) for source, tower in zip(range(self._num_sources), range(self._num_towers))) if not all(checks): raise ValueError(msg) return tf.no_op() else: asserts = tuple( tf.assert_equal(partial_batch_size, self._partial_batch_size(source=source, tower=tower), message=msg) for source, tower in zip(range(self._num_sources), range(self._num_towers))) return tf.group(asserts) return tf.no_op() @property def _cov_initializer(self): """Function for initializing covariance variable.""" return covariance_initializer def instantiate_cov_variables(self): """Makes the internal cov variable(s).""" assert self._cov is None with tf.variable_scope(self._var_scope): self._cov = utils.MovingAverageVariable( name="cov", shape=self._cov_shape, dtype=self._dtype, initializer=self._cov_initializer, normalize_value=ZERO_DEBIAS) @abc.abstractmethod def _compute_new_cov(self, source, tower): """Computes minibatch-estimated covariance for a single source. Args: source: int in [0, self._num_sources). Which source to use when computing the cov update. tower: int in [0, self._num_towers). 
Which tower to use when computing the cov update. Returns: Tensor of same shape as self.cov. """ pass def _compute_total_new_cov(self): """Computes covariance by summing across (source, towers).""" new_cov_contribs = [] for source in range(self._num_sources): for tower in range(self._num_towers): with maybe_place_on_device(self._get_data_device(tower)): new_cov_contribs.append(self._compute_new_cov(source, tower)) new_cov = tf.add_n(new_cov_contribs) / float(self._num_towers) # Compute average of 'new_cov' across all replicas. On a replica, each # instance of 'new_cov' will be based on a different minibatch. This ensures # that by the time variable assignment happens, all replicas have the same # value. # # Other implementations of make_covariance_update_op() that accumulate # statistics in other variables should mimic this behavior. # # NOTE: communicating this matrix at every iteration is wasteful in the # sense that we might only need fresh copies when we do the inversions. # (Although be careful about factors [e.g. diagonal] or ops # [e.g. multiply()] that directly use the cov vars instead of the inv vars!) new_cov = utils.all_average(new_cov) return new_cov def make_covariance_update_op(self, ema_decay, ema_weight): """Constructs and returns the covariance update Op. Args: ema_decay: float or Tensor. The exponential moving average decay. ema_weight: float or Tensor. The weight to put on the newly computed values. This is typically 1.0 - ema_decay. Returns: The op which updates the cov variable (via acc_cov). """ cov_tensor = self._compute_total_new_cov() self._cov_tensor = cov_tensor # This is used for non-standard applications # and debugging I think. 
return self._cov.add_to_average(cov_tensor, decay=ema_decay, weight=ema_weight) @abc.abstractmethod def _get_data_device(self, tower): pass @abc.abstractmethod def instantiate_inv_variables(self): """Makes the internal "inverse" variable(s).""" pass @abc.abstractmethod def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" pass @property def cov(self): return self._cov.value def get_cov_vars(self): return [self.cov] def get_inv_vars(self): return [] @abc.abstractmethod def get_cov_as_linear_operator(self): """Returns `LinearOperator` instance which wraps the cov matrix.""" pass @abc.abstractmethod def register_matpower(self, exp, damping_func): pass @abc.abstractmethod def register_cholesky(self, damping_func): pass @abc.abstractmethod def register_cholesky_inverse(self, damping_func): pass @abc.abstractmethod def get_matpower(self, exp, damping_func): pass @abc.abstractmethod def get_cholesky(self, damping_func): pass @abc.abstractmethod def get_cholesky_inverse(self, damping_func): pass class DenseSquareMatrixFactor(FisherFactor): """Base class for FisherFactors that are stored as dense square matrices. This class explicitly calculates and stores inverses of their `cov` matrices, which must be square dense matrices. Subclasses must implement the _compute_new_cov method, and the _var_scope and _cov_shape properties. """ # TODO(b/69108481): This class (and its subclasses) should be refactored to # serve the matrix quantities it computes as both (potentially stale) # variables, updated by the inverse update ops, and fresh values stored in # tensors that recomputed once every session.run() call. Currently matpower # and damp_inverse have the former behavior, while eigendecomposition has # the latter. 
def __init__(self): self._matpower_by_exp_and_damping = OrderedDict() # { (float, hashable): variable } self._matpower_registrations = set() # { (float, hashable) } self._eigendecomp = None self._damping_funcs_by_id = OrderedDict() # {hashable: lambda} self._cholesky_registrations = set() # { hashable } self._cholesky_inverse_registrations = set() # { hashable } self._cholesky_by_damping = OrderedDict() # { hashable: variable } self._cholesky_inverse_by_damping = OrderedDict() # { hashable: variable } super(DenseSquareMatrixFactor, self).__init__() def get_cov_as_linear_operator(self): """Returns `LinearOperator` instance which wraps the cov matrix.""" assert self.cov.shape.ndims == 2 return lo.LinearOperatorFullMatrix(self.cov, is_self_adjoint=True, is_square=True) def _register_damping(self, damping_func): damping_id = graph_func_to_id(damping_func) if damping_id not in self._damping_funcs_by_id: self._damping_funcs_by_id[damping_id] = damping_func return damping_id def register_inverse(self, damping_func): # Just for backwards compatibility of some old code and tests self.register_matpower(-1, damping_func) def register_matpower(self, exp, damping_func): """Registers a matrix power to be maintained and served on demand. This creates a variable and signals make_inverse_update_ops to make the corresponding update op. The variable can be read via the method get_matpower. Args: exp: float. The exponent to use in the matrix power. damping_func: A function that computes a 0-D Tensor or a float which will be the damping value used. i.e. damping = damping_func(). """ if exp == 1.0: return damping_id = self._register_damping(damping_func) if (exp, damping_id) not in self._matpower_registrations: self._matpower_registrations.add((exp, damping_id)) def register_cholesky(self, damping_func): """Registers a Cholesky factor to be maintained and served on demand. This creates a variable and signals make_inverse_update_ops to make the corresponding update op. 
The variable can be read via the method get_cholesky. Args: damping_func: A function that computes a 0-D Tensor or a float which will be the damping value used. i.e. damping = damping_func(). """ damping_id = self._register_damping(damping_func) if damping_id not in self._cholesky_registrations: self._cholesky_registrations.add(damping_id) def register_cholesky_inverse(self, damping_func): """Registers an inverse Cholesky factor to be maintained/served on demand. This creates a variable and signals make_inverse_update_ops to make the corresponding update op. The variable can be read via the method get_cholesky_inverse. Args: damping_func: A function that computes a 0-D Tensor or a float which will be the damping value used. i.e. damping = damping_func(). """ damping_id = self._register_damping(damping_func) if damping_id not in self._cholesky_inverse_registrations: self._cholesky_inverse_registrations.add(damping_id) def get_inv_vars(self): inv_vars = [] inv_vars.extend(self._matpower_by_exp_and_damping.values()) inv_vars.extend(self._cholesky_by_damping.values()) inv_vars.extend(self._cholesky_inverse_by_damping.values()) return inv_vars def instantiate_inv_variables(self): """Makes the internal "inverse" variable(s).""" for (exp, damping_id) in self._matpower_registrations: exp_string = scalar_or_tensor_to_string(exp) damping_func = self._damping_funcs_by_id[damping_id] damping_string = graph_func_to_string(damping_func) with tf.variable_scope(self._var_scope): matpower = tf.get_variable( "matpower_exp{}_damp{}".format(exp_string, damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, dtype=self._dtype, use_resource=True) assert (exp, damping_id) not in self._matpower_by_exp_and_damping self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower for damping_id in self._cholesky_registrations: damping_func = self._damping_funcs_by_id[damping_id] damping_string = graph_func_to_string(damping_func) with 
tf.variable_scope(self._var_scope): chol = tf.get_variable( "cholesky_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, dtype=self._dtype, use_resource=True) assert damping_id not in self._cholesky_by_damping self._cholesky_by_damping[damping_id] = chol for damping_id in self._cholesky_inverse_registrations: damping_func = self._damping_funcs_by_id[damping_id] damping_string = graph_func_to_string(damping_func) with tf.variable_scope(self._var_scope): cholinv = tf.get_variable( "cholesky_inverse_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, dtype=self._dtype, use_resource=True) assert damping_id not in self._cholesky_inverse_by_damping self._cholesky_inverse_by_damping[damping_id] = cholinv def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" ops = [] num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping if exp == -1) num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses other_matrix_power_registered = num_other_matpower >= 1 use_eig = ( self._eigendecomp or other_matrix_power_registered or num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) # We precompute these so we don't need to evaluate them multiple times (for # each matrix power that uses them) damping_value_by_id = {damping_id: tf.cast( self._damping_funcs_by_id[damping_id](), self._dtype) for damping_id in self._damping_funcs_by_id} if use_eig: eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence for (exp, damping_id), matpower in ( self._matpower_by_exp_and_damping.items()): damping = damping_value_by_id[damping_id] ops.append( utils.smart_assign( matpower, tf.matmul(eigenvectors * (eigenvalues + damping)**exp, tf.transpose(eigenvectors)))) # These ops share computation and should be run on a single device. 
ops = [tf.group(*ops)] else: for (exp, damping_id), matpower in ( self._matpower_by_exp_and_damping.items()): assert exp == -1 damping = damping_value_by_id[damping_id] ops.append( utils.smart_assign(matpower, utils.posdef_inv(self.cov, damping))) # TODO(b/77902055): If inverses are being computed with Cholesky's # we can share the work. Instead this code currently just computes the # Cholesky a second time. It does at least share work between requests for # Cholesky's and Cholesky inverses with the same damping id. for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items(): cholesky_ops = [] damping = damping_value_by_id[damping_id] cholesky_value = utils.cholesky(self.cov, damping) if damping_id in self._cholesky_by_damping: cholesky = self._cholesky_by_damping[damping_id] cholesky_ops.append(utils.smart_assign(cholesky, cholesky_value)) identity = tf.eye( cholesky_value.shape.as_list()[0], dtype=cholesky_value.dtype) cholesky_inv_value = tf.matrix_triangular_solve(cholesky_value, identity) cholesky_ops.append(utils.smart_assign(cholesky_inv, cholesky_inv_value)) ops.append(tf.group(*cholesky_ops)) for damping_id, cholesky in self._cholesky_by_damping.items(): if damping_id not in self._cholesky_inverse_by_damping: damping = damping_value_by_id[damping_id] cholesky_value = utils.cholesky(self.cov, damping) ops.append(utils.smart_assign(cholesky, cholesky_value)) self._eigendecomp = False return ops def get_inverse(self, damping_func): # Just for backwards compatibility of some old code and tests return self.get_matpower(-1, damping_func) def get_matpower(self, exp, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. It may be stale / inconsistent with the latest value of # self.cov (except when exp == 1). 
if exp != 1: damping_id = graph_func_to_id(damping_func) matpower = self._matpower_by_exp_and_damping[(exp, damping_id)] else: cov = self.cov identity = tf.eye(cov.shape.as_list()[0], dtype=cov.dtype) matpower = cov + tf.cast(damping_func(), dtype=self.cov.dtype)*identity assert matpower.shape.ndims == 2 return lo.LinearOperatorFullMatrix(matpower, is_non_singular=True, is_self_adjoint=True, is_positive_definite=True, is_square=True) def get_cholesky(self, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. It may be stale / inconsistent with the latest value of # self.cov. damping_id = graph_func_to_id(damping_func) cholesky = self._cholesky_by_damping[damping_id] assert cholesky.shape.ndims == 2 return lo.LinearOperatorFullMatrix(cholesky, is_non_singular=True, is_square=True) def get_cholesky_inverse(self, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. It may be stale / inconsistent with the latest value of # self.cov. damping_id = graph_func_to_id(damping_func) cholesky_inv = self._cholesky_inverse_by_damping[damping_id] assert cholesky_inv.shape.ndims == 2 return lo.LinearOperatorFullMatrix(cholesky_inv, is_non_singular=True, is_square=True) def get_eigendecomp(self): """Creates or retrieves eigendecomposition of self._cov.""" # Unlike get_matpower this doesn't retrieve a stored variable, but instead # always computes a fresh version from the current value of self.cov. 
if not self._eigendecomp: eigenvalues, eigenvectors = tf.self_adjoint_eig(self.cov) # The matrix self._cov is positive semidefinite by construction, but the # numerical eigenvalues could be negative due to numerical errors, so here # we clip them to be at least FLAGS.eigenvalue_clipping_threshold clipped_eigenvalues = tf.maximum(eigenvalues, EIGENVALUE_CLIPPING_THRESHOLD) self._eigendecomp = (clipped_eigenvalues, eigenvectors) return self._eigendecomp class NaiveFullFactor(DenseSquareMatrixFactor): """FisherFactor for a full matrix representation of the Fisher of a parameter. Note that this uses the naive "square the sum estimator", and so is applicable to any type of parameter in principle, but has very high variance. """ def __init__(self, params_grads, batch_size): self._batch_size = batch_size self._params_grads = tuple(utils.ensure_sequence(params_grad) for params_grad in params_grads) super(NaiveFullFactor, self).__init__() @property def _var_scope(self): return "ff_naivefull_" + scope_string_from_params( [self._params_grads, self._batch_size]) @property def _cov_shape(self): size = sum(param_grad.shape.num_elements() for param_grad in self._params_grads[0]) return (size, size) @property def _num_sources(self): return len(self._params_grads) @property def _num_towers(self): return 1 @property def _dtype(self): return self._params_grads[0][0].dtype def _partial_batch_size(self, source=0, tower=0): assert source == 0 and tower == 0 return self._batch_size def _compute_new_cov(self, source, tower): assert tower == 0 # This will be a very basic rank 1 estimate params_grads_flat = utils.tensors_to_column(self._params_grads[source]) return ((params_grads_flat * tf.transpose(params_grads_flat)) / tf.cast( self._batch_size, params_grads_flat.dtype)) def _get_data_device(self, tower): return None @six.add_metaclass(abc.ABCMeta) class DiagonalFactor(FisherFactor): """A base class for FisherFactors that use diagonal approximations. 
class NaiveDiagonalFactor(DiagonalFactor):
  """Diagonal Fisher approximation usable with any type of parameter.

  Relies on the naive "square the sum" estimator, so it is applicable to any
  type of parameter in principle, but has very high variance.
  """

  def __init__(self, params_grads, batch_size):
    """Initializes NaiveDiagonalFactor instance.

    Args:
      params_grads: List of tensors (or lists), with the first index
        corresponding to source, and the second optional index corresponding
        to the element of the parameter list.
      batch_size: int or 0-D Tensor. The batch size.
    """
    self._batch_size = batch_size
    self._params_grads = params_grads
    super(NaiveDiagonalFactor, self).__init__()

  @property
  def _var_scope(self):
    scope_params = [self._params_grads, self._batch_size]
    return "ff_naivediag_" + scope_string_from_params(scope_params)

  @property
  def _cov_shape(self):
    first_source = self._params_grads[0]
    return first_source.shape

  @property
  def _num_sources(self):
    return len(self._params_grads)

  @property
  def _num_towers(self):
    return 1

  @property
  def _dtype(self):
    first_source = self._params_grads[0]
    return first_source.dtype

  def _partial_batch_size(self, source=0, tower=0):
    assert source == 0 and tower == 0
    return self._batch_size

  def _compute_new_cov(self, source, tower):
    assert tower == 0
    # Entry-wise square of this source's gradient, scaled by the batch size.
    grad = self._params_grads[source]
    squared = tf.square(grad)
    return squared / tf.cast(self._batch_size, grad.dtype)

  def _get_data_device(self, tower):
    return None
def _compute_new_cov(self, source, tower):
  """Computes the diagonal second-moment estimate for one (source, tower).

  Dense inputs are squared entry-wise. Sparse inputs (integer ids) are
  expanded to one-hot rows. In both cases the rows are summed and divided by
  the size of the tensor's first dimension. If `has_bias` was set, a
  homogeneous coordinate is appended to the result.

  Args:
    source: Index into the first axis of `self._tensors`.
    tower: Index into the second axis of `self._tensors`.

  Returns:
    The new statistics, as returned by the final `append_homog` call when
    `has_bias` is set and by the normalized row sum otherwise.

  Raises:
    ValueError: If the selected tensor has rank greater than 2.
  """
  tensor = self._tensors[source][tower]

  if len(tensor.shape) > 2:
    raise ValueError(
        "Input tensors to DiagonalKroneckerFactor must have rank <= 2. "
        "Found tensor with wrong rank: {}".format(tensor))
  batch_size = utils.get_shape(tensor)[0]

  if self._dense_input:
    new_cov = tf.square(tensor)
  else:
    # Transform indices into one-hot vectors.
    #
    # TODO(b/72714822): There must be a faster way to construct the diagonal
    # covariance matrix! This operation is O(batch_size * vocab_size), where
    # it should be O(batch_size * input_size).
    flat_input_ids = tf.reshape(tensor, [-1])
    # Build the one-hot rows directly in the covariance dtype. tf.one_hot
    # defaults to float32, which would not match the statistics' dtype
    # whenever the caller asked for a different `dtype` for sparse inputs.
    new_cov = tf.one_hot(flat_input_ids,
                         self._one_hot_depth,
                         dtype=self._cov_dtype)  # [?, vocab_size]

    # Take average across examples. Note that, because all entries have
    # magnitude zero or one, there's no need to square the entries.
    #
    # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation
    # within an example such as average.
    #
    # TODO(b/72714822): Support for partitioned embeddings.

  new_cov = tf.reduce_sum(new_cov, axis=0)
  new_cov /= tf.cast(batch_size, new_cov.dtype)

  if self._has_bias:
    new_cov = append_homog(new_cov)

  return new_cov
def make_covariance_update_op(self, ema_decay, ema_weight):
  """Caches the squared inputs of every tower, then defers to the parent.

  For each tower the inputs (with a homogeneous coordinate appended when
  `has_bias` is set) are squared entry-wise on that tower's data device and
  stored in `self._squared_inputs`, where `_compute_new_cov` reads them.

  Args:
    ema_decay: Forwarded unchanged to the parent implementation.
    ema_weight: Forwarded unchanged to the parent implementation.

  Returns:
    Whatever the parent class's `make_covariance_update_op` returns.
  """
  squared_per_tower = self._squared_inputs = []
  for tower_idx in range(self._num_towers):
    activations = self._inputs[tower_idx]
    with maybe_place_on_device(self._get_data_device(tower_idx)):
      if self._has_bias:
        activations = append_homog(activations)
      squared_per_tower.append(tf.square(activations))

  parent = super(FullyConnectedDiagonalFactor, self)
  return parent.make_covariance_update_op(ema_decay, ema_weight)
def __init__(self,
             inputs,
             outputs_grads,
             broadcast_dim,
             has_shift=True,
             approx="full"):
  """Stores the configuration of a scale-and-shift factor.

  Args:
    inputs: Sequence of input tensors; its length is used as the number of
      towers.
    outputs_grads: Nested sequence of output gradients, indexed first by
      source and then by tower.
    broadcast_dim: Index of the first input dimension that contributes to
      the factor's size.
    has_shift: If True, the factor's size is doubled to make room for the
      shift parameter's gradients.
    approx: Either "full" or "diagonal".
  """
  assert approx in ("full", "diagonal")

  self._approx = approx
  self._has_shift = has_shift
  self._broadcast_dim = broadcast_dim
  self._outputs_grads = outputs_grads
  self._inputs = inputs

  super(ScaleAndShiftFactor, self).__init__()
class ConvDiagonalFactor(DiagonalFactor):
  """FisherFactor for a diagonal approx of a convolutional layer's Fisher."""

  def __init__(self,
               inputs,
               outputs_grads,
               filter_shape,
               strides,
               padding,
               data_format=None,
               dilations=None,
               has_bias=False,
               patch_mask=None):
    """Creates a ConvDiagonalFactor object.

    Args:
      inputs: List of Tensors of shape [batch_size, height, width,
        in_channels]. Input activations to this layer. List index is towers.
      outputs_grads: List of Tensors, each of shape [batch_size,
        height, width, out_channels], which are the gradients of the loss
        with respect to the layer's outputs. First index is source, second
        index is tower.
      filter_shape: Tuple of 4 ints: (kernel_height, kernel_width,
        in_channels, out_channels). Represents shape of kernel used in this
        layer.
      strides: The stride size in this layer (length-4 list of ints).
      padding: The padding for this layer. Forwarded unchanged as the
        `padding` argument of `tf.extract_image_patches`.
      data_format: None or str. Format of conv2d inputs.
      dilations: None or tuple of 4 ints.
      has_bias: Python bool. If True, the layer is assumed to have a bias
        parameter in addition to its filter parameter.
      patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels]
        or None. If not None this is flattened and multiplied against the
        last dimension of the extracted patches Tensor (broadcasting along
        the remaining dimensions) before statistics are computed.
        (Default: None)

    Raises:
      ValueError: If inputs, output_grads, and filter_shape do not agree on
        in_channels or out_channels.
      ValueError: If strides, dilations are not length-4 lists of ints.
      ValueError: If data_format does not put channel last.
    """
    if not utils.is_data_format_channel_last(data_format):
      raise ValueError("Channel must be last.")
    if any(input_.shape.ndims != 4 for input_ in inputs):
      raise ValueError("inputs must be a list of 4-D Tensors.")
    if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs):
      raise ValueError("inputs and filter_shape must agree on in_channels.")
    for i, outputs_grad in enumerate(outputs_grads):
      if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad):
        raise ValueError("outputs[%d] must be 4-D Tensor." % i)
      if any(output_grad.shape.as_list()[-1] != filter_shape[-1]
             for output_grad in outputs_grad):
        raise ValueError(
            "outputs[%d] and filter_shape must agree on out_channels." % i)
    if len(strides) != 4:
      raise ValueError("strides must be length-4 list of ints.")
    if dilations is not None and len(dilations) != 4:
      raise ValueError("dilations must be length-4 list of ints.")

    self._inputs = inputs
    self._outputs_grads = outputs_grads
    self._filter_shape = filter_shape
    self._strides = strides
    self._padding = padding
    self._data_format = data_format
    self._dilations = dilations
    self._has_bias = has_bias
    # Filled in by make_covariance_update_op; one entry per tower.
    self._patches = None

    self._patch_mask = patch_mask

    super(ConvDiagonalFactor, self).__init__()

  @property
  def _var_scope(self):
    return "ff_convdiag_" + scope_string_from_params(
        tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))

  @property
  def _cov_shape(self):
    filter_height, filter_width, in_channels, out_channels = self._filter_shape
    return [
        filter_height * filter_width * in_channels + self._has_bias,
        out_channels
    ]

  @property
  def _num_sources(self):
    return len(self._outputs_grads)

  @property
  def _num_towers(self):
    return len(self._inputs)

  @property
  def _dtype(self):
    return self._inputs[0].dtype

  def _partial_batch_size(self, source=0, tower=0):
    return utils.get_shape(self._outputs_grads[source][tower])[0]

  def make_covariance_update_op(self, ema_decay, ema_weight):
    """Extracts (and caches) the image patches, then defers to the parent.

    Args:
      ema_decay: Forwarded unchanged to the parent implementation.
      ema_weight: Forwarded unchanged to the parent implementation.

    Returns:
      Whatever the parent class's `make_covariance_update_op` returns.
    """
    filter_height, filter_width, _, _ = self._filter_shape

    # TODO(b/64144716): there is potential here for a big savings in terms
    # of memory use.
    if self._dilations is None:
      rates = (1, 1, 1, 1)
    else:
      rates = tuple(self._dilations)

    self._patches = []
    for tower in range(self._num_towers):
      with maybe_place_on_device(self._get_data_device(tower)):
        patches = tf.extract_image_patches(
            self._inputs[tower],
            ksizes=[1, filter_height, filter_width, 1],
            strides=self._strides,
            rates=rates,
            padding=self._padding)

        if self._patch_mask is not None:
          assert self._patch_mask.shape == self._filter_shape[0:-1]
          # The last dimension of `patches` is the flattened
          # (kernel_height, kernel_width, in_channels) patch, so the mask has
          # to be flattened the same way before it can broadcast against it
          # (this mirrors what ConvInputKroneckerFactor does). Multiplying by
          # the unflattened 3-D mask would instead try to line its axes up
          # with the spatial axes of `patches`.
          patches *= tf.reshape(self._patch_mask, [-1])

        if self._has_bias:
          patches = append_homog(patches)

        self._patches.append(patches)

    return super(ConvDiagonalFactor, self).make_covariance_update_op(
        ema_decay, ema_weight)

  def _compute_new_cov(self, source, tower):
    patches = self._patches[tower]
    batch_size = utils.get_shape(patches)[0]
    outputs_grad = self._outputs_grads[source][tower]

    new_cov = self._convdiag_sum_of_squares(patches, outputs_grad)
    new_cov /= tf.cast(batch_size, new_cov.dtype)

    return new_cov

  def _convdiag_sum_of_squares(self, patches, outputs_grad):
    # This computes the sum of the squares of the per-training-case
    # "gradients". It does this simply by computing a giant tensor containing
    # all of these, doing an entry-wise square, and then summing along the
    # batch dimension.
    case_wise_gradients = tf.einsum("bijk,bijl->bkl", patches, outputs_grad)
    return tf.reduce_sum(tf.square(case_wise_gradients), axis=0)

  def _get_data_device(self, tower):
    return self._inputs[tower].device
""" # The tensor argument is either a tensor of input activations or a tensor of # output pre-activation gradients. self._has_bias = has_bias self._tensors = tensors super(FullyConnectedKroneckerFactor, self).__init__() @property def _var_scope(self): return "ff_fckron_" + scope_string_from_params( tuple(nest.flatten(self._tensors)) + (self._has_bias,)) @property def _cov_shape(self): size = self._tensors[0][0].shape[1] + self._has_bias return [size, size] @property def _num_sources(self): return len(self._tensors) @property def _num_towers(self): return len(self._tensors[0]) @property def _dtype(self): return self._tensors[0][0].dtype def _partial_batch_size(self, source=0, tower=0): return utils.get_shape(self._tensors[source][tower])[0] def _compute_new_cov(self, source, tower): tensor = self._tensors[source][tower] if self._has_bias: tensor = append_homog(tensor) return compute_cov(tensor) def _get_data_device(self, tower): return self._tensors[0][tower].device class ConvInputKroneckerFactor(DenseSquareMatrixFactor): r"""Kronecker factor for the input side of a convolutional layer. Estimates E[ a a^T ] where a is the inputs to a convolutional layer given example x. Expectation is taken over all examples and locations. Equivalent to Omega in https://arxiv.org/abs/1602.01407 for details. See Section 3.1 Estimating the factors. """ def __init__(self, inputs, filter_shape, padding, strides=None, dilation_rate=None, data_format=None, extract_patches_fn=None, has_bias=False, sub_sample_inputs=None, sub_sample_patches=None, patch_mask=None): """Initializes ConvInputKroneckerFactor. Args: inputs: List of Tensors of shape [batch_size, ..spatial_input_size.., in_channels]. Inputs to layer. List index is tower. filter_shape: List of ints. Contains [..spatial_filter_size.., in_channels, out_channels]. Shape of convolution kernel. padding: str. Padding method for layer. "SAME" or "VALID". strides: List of ints or None. Contains [..spatial_filter_strides..] 
if 'extract_patches_fn' is compatible with tf.nn.convolution(), else [1, ..spatial_filter_strides, 1]. dilation_rate: List of ints or None. Rate for dilation along each spatial dimension if 'extract_patches_fn' is compatible with tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. data_format: str or None. Format of input data. extract_patches_fn: str or None. Name of function that extracts image patches. One of "extract_convolution_patches", "extract_image_patches", "extract_pointwise_conv2d_patches". has_bias: bool. If True, append 1 to in_channel. sub_sample_inputs: `bool`. If True, then subsample the inputs from which the image patches are extracted. (Default: None) sub_sample_patches: `bool`, If `True` then subsample the extracted patches. (Default: None) patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels] or None. If not None this is multiplied against the extracted patches Tensor (broadcasting along the batch dimension) before statistics are computed. 
def _compute_new_cov(self, source, tower):
  """Computes the second-moment matrix of the patches for one tower.

  The steps are, in order: optional subsampling of the inputs, patch
  extraction with the configured extraction function, optional masking of
  the patches, flattening to a 2-D matrix, optional subsampling of the
  patches, optional appending of a homogeneous coordinate, and finally
  `compute_cov`.

  Args:
    source: Must be 0 (this factor has a single source).
    tower: Index into `self._inputs`.

  Returns:
    The result of `compute_cov` on the flattened patches.

  Raises:
    NotImplementedError: If `_USE_PATCHES_SECOND_MOMENT_OP` is set, or if
      `extract_patches_fn` is not one of the supported names.
  """
  assert source == 0

  inputs = self._inputs[tower]
  if self._sub_sample_inputs:
    batch_size = utils.get_shape(inputs)[0]
    # computes: int(math.ceil(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR))
    new_size = tf.cast(
        tf.ceil(tf.multiply(tf.cast(batch_size, dtype=tf.float32),
                            _INPUTS_TO_EXTRACT_PATCHES_FACTOR)),
        dtype=tf.int32)
    inputs = _random_tensor_gather(inputs, new_size)

  # TODO(b/64144716): there is potential here for a big savings in terms of
  # memory use.
  if _USE_PATCHES_SECOND_MOMENT_OP:
    raise NotImplementedError  # patches op is not available outside of
                               # Google, sorry! You'll need to turn it off to
                               # proceed.
  else:
    if self._extract_patches_fn in [None, "extract_convolution_patches"]:
      patches = utils.extract_convolution_patches(
          inputs,
          self._filter_shape,
          padding=self._padding,
          strides=self._strides,
          dilation_rate=self._dilation_rate,
          data_format=self._data_format)

    elif self._extract_patches_fn == "extract_image_patches":
      # This extraction path takes length-4 strides and rates, with the
      # first and last rate required to be 1.
      assert inputs.shape.ndims == 4
      assert len(self._filter_shape) == 4
      assert len(self._strides) == 4, self._strides
      if self._dilation_rate is None:
        rates = [1, 1, 1, 1]
      else:
        rates = self._dilation_rate
        assert len(rates) == 4
        assert rates[0] == rates[-1] == 1
      patches = tf.extract_image_patches(
          inputs,
          ksizes=[1] + list(self._filter_shape[0:-2]) + [1],
          strides=self._strides,
          rates=rates,
          padding=self._padding)

    elif self._extract_patches_fn == "extract_pointwise_conv2d_patches":
      # Only unit strides and 1x1 kernels are accepted on this path.
      assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)]
      assert self._filter_shape[0] == self._filter_shape[1] == 1
      patches = utils.extract_pointwise_conv2d_patches(
          inputs, self._filter_shape, data_format=None)

    else:
      raise NotImplementedError(self._extract_patches_fn)

    if self._patch_mask is not None:
      assert self._patch_mask.shape == self._filter_shape[0:-1]
      # The mask is flattened so that it broadcasts against the last
      # dimension of `patches`.
      patches *= tf.reshape(self._patch_mask, [-1])

    flatten_size = np.prod(self._filter_shape[0:-1])
    # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde
    # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),
    # where M = minibatch size, |T| = number of spatial locations,
    # |Delta| = number of spatial offsets, and J = number of input maps
    # for convolutional layer l.
    patches_flat = tf.reshape(patches, [-1, flatten_size])

    # Optionally subsample the rows (patches) before computing statistics.
    if self._sub_sample_patches:
      patches_flat = _subsample_patches(patches_flat)

    # We append a homogenous coordinate to patches_flat if the layer has
    # bias parameters. This gives us [[A_l]]_H from the paper.
    if self._has_bias:
      patches_flat = append_homog(patches_flat)
    # We call compute_cov without passing in a normalizer. compute_cov uses
    # the first dimension of patches_flat i.e. M|T| as the normalizer by
    # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
    # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
    # the paper but has a different scale here for consistency with
    # ConvOutputKroneckerFactor.
    # (Tilde omitted over A for clarity.)
    return compute_cov(patches_flat)
@property
def _cov_shape(self):
  """Shape of the stored covariance: [in_channels, in_channels].

  NOTE: This is not the shape of the full cov matrix. It is the shape of the
  matrix holding the covariance of the input channel activations under the
  assumption mentioned in Theorem 4 in https://arxiv.org/pdf/1602.01407.pdf.
  It excludes the bias dimension and covers only the `Sigma` term from
  Theorem 4 in the paper.
  """
  num_channels = self._in_channels
  return [num_channels, num_channels]
def register_matpower(self, exp, damping_func):
  """Registers a matrix power to be maintained and served on demand.

  Only the inverse (exp == -1) is tracked. The request is recorded in
  `self._matpower_registrations` together with the id of the damping
  function; no variable is created here. A request for exp == 1 is accepted
  and ignored, since `get_matpower` builds that operator directly from the
  covariance.

  Args:
    exp: float. The exponent to use in the matrix power. Must be 1 or -1.
    damping_func: A function that computes a 0-D Tensor or a float which will
      be the damping value used. i.e. damping = damping_func().

  Raises:
    ValueError: If `exp` is neither 1 nor -1.
  """
  if exp == 1.0:
    return

  if exp != -1:
    # The two literals are concatenated, so the first one needs the trailing
    # space to keep the words apart in the final message.
    raise ValueError("ConvInputSUAKroneckerFactor supports only "
                     "matrix inversion")

  damping_id = self._register_damping(damping_func)
  # The registrations are kept in a set, so adding an existing entry is a
  # no-op and no membership test is needed first.
  self._matpower_registrations.add((exp, damping_id))
Also # we need to append inverse of damping constant (1 * 1 matrix) to # to the replicated cov inverse matrix. if self._has_bias: bias_operator = tf.linalg.LinearOperatorFullMatrix( [[1. / damping_value]], is_non_singular=True, is_self_adjoint=True, is_positive_definite=True, is_square=True) cov_inv_kron_identity_operator = tf.linalg.LinearOperatorBlockDiag( [quant_1, bias_operator]) if not ASSUME_ZERO_MEAN_ACTIVATIONS: cov_inv_mu = self._cov_inv_mu_by_damping_id[damping_id] scale = self._rank_one_update_scale_by_damping_id[damping_id] # Compute cov_inv_mu kron 1's vec. We tile the cov_inv_mu on the last # dim and then reshape. mean_update = ( tf.expand_dims( append_homog( tf.reshape(tf.tile(cov_inv_mu, [1, self._kw_kh]), (-1,)), homog_value=(1. / damping_value)), axis=1)) else: cov_inv_kron_identity_operator = quant_1 if not ASSUME_ZERO_MEAN_ACTIVATIONS: cov_inv_mu = self._cov_inv_mu_by_damping_id[damping_id] scale = self._rank_one_update_scale_by_damping_id[damping_id] # Compute cov_inv_mu kron 1's vec. We tile the cov_inv_mu on the last # dim and then reshape. mean_update = tf.reshape( tf.tile(cov_inv_mu, [1, self._kw_kh]), (-1, 1)) if ASSUME_ZERO_MEAN_ACTIVATIONS: return cov_inv_kron_identity_operator else: # To include the contribution from the mean activations we need to # low rank update op. Note the Sherman Morrison formula requires # negative of (mean_update * mean_update^T) / scale term to be added. # In order to achieve this using `LinearOperatorLowRankUpdate` set `v` # to negative of mean update vector multiplied by scale. 
return tf.linalg.LinearOperatorLowRankUpdate( cov_inv_kron_identity_operator, mean_update, v=-scale * mean_update, is_non_singular=True, is_self_adjoint=True, is_positive_definite=True, is_square=True) else: raise ValueError("ConvInputSUAKroneckerFactor only supports" "computing inverse of cov matrix.") def make_inverse_update_ops(self): """Creates and return update ops for registered computations.""" inverse_ops = [] for (exp, damping_id), matpower in self._matpower_by_exp_and_damping.items(): assert exp == -1 damping = tf.cast(self._damping_funcs_by_id[damping_id](), self._dtype) damping_assign_op = utils.smart_assign( self._damping_var_by_id[damping_id], damping) inverse_op = utils.smart_assign(matpower, utils.posdef_inv(self.cov, damping)) inverse_ops.append(damping_assign_op) if not ASSUME_ZERO_MEAN_ACTIVATIONS: with tf.control_dependencies([inverse_op]): (cov_inv_mu, rank_one_update_scale) = self._compute_sm_rank_one_update_quants( exp, damping_id, damping) inverse_ops.append( utils.smart_assign(self._cov_inv_mu_by_damping_id[damping_id], cov_inv_mu)) inverse_ops.append( utils.smart_assign( self._rank_one_update_scale_by_damping_id[damping_id], rank_one_update_scale)) else: inverse_ops.append(inverse_op) return inverse_ops def get_inverse(self, damping_func): # Just for backwards compatibility of some old code and tests return self.get_matpower(-1, damping_func) def instantiate_inv_variables(self): """Makes the internal "inverse" variable(s).""" for (exp, damping_id) in self._matpower_registrations: if exp != -1.: raise ValueError("ConvInputSUAKroneckerFactor only supports inverse" "computation") exp_string = scalar_or_tensor_to_string(exp) damping_func = self._damping_funcs_by_id[damping_id] damping_string = graph_func_to_string(damping_func) with tf.variable_scope(self._var_scope): matpower = tf.get_variable( "matpower_exp{}_damp{}".format(exp_string, damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, 
dtype=self._dtype, use_resource=True) assert (exp, damping_id) not in self._matpower_by_exp_and_damping self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower self._damping_var_by_id[damping_id] = tf.get_variable( "damping_var_{}_{}".format(exp_string, damping_string), initializer=tf.zeros_initializer(), shape=(), trainable=False, dtype=self._dtype, use_resource=True) if not ASSUME_ZERO_MEAN_ACTIVATIONS: self._cov_inv_mu_by_damping_id[damping_id] = tf.get_variable( "cov_inv_mu_{}_{}".format(exp_string, damping_string), initializer=tf.zeros_initializer(), shape=(self._in_channels, 1), trainable=False, dtype=self._dtype, use_resource=True) self._rank_one_update_scale_by_damping_id[damping_id] = tf.get_variable( "rank_one_update_scale_{}_{}".format(exp_string, damping_string), initializer=tf.zeros_initializer(), shape=(), trainable=False, dtype=self._dtype, use_resource=True) def _make_cov_linear_operator(self, damping=None): """Returns cov as a linear operator. Args: damping: Damping value tensor. If `damping` is not None then returns damped covariance matrix. Returns: tf.linalg.LinearOperator instance. """ if damping is not None: cov = self.cov + damping * tf.eye(self._cov_shape[0], dtype=self._dtype) else: cov = self.cov cov_operator = tf.linalg.LinearOperatorKronecker([ tf.linalg.LinearOperatorFullMatrix( cov, is_self_adjoint=True, is_square=True), tf.linalg.LinearOperatorIdentity( num_rows=self._kw_kh, dtype=self._dtype) ]) if self._has_bias: bias_value = damping if damping is not None else 0. bias_operator = tf.linalg.LinearOperatorFullMatrix([[bias_value]], is_self_adjoint=True, is_square=True) cov_operator = tf.linalg.LinearOperatorBlockDiag( [cov_operator, bias_operator]) if ASSUME_ZERO_MEAN_ACTIVATIONS: return cov_operator else: # self.mu kron 1's vec is computed below by tiling mu. 
hatmu = tf.tile(self.mu, [1, self._kw_kh]) if self._has_bias: tildemu = append_homog(tf.reshape(hatmu, (-1,))) mean_update = tf.expand_dims(tildemu, axis=1) else: mean_update = tf.reshape(hatmu, (-1, 1)) return tf.linalg.LinearOperatorLowRankUpdate( cov_operator, mean_update, is_self_adjoint=True, is_square=True) def get_cov_as_linear_operator(self): return self._make_cov_linear_operator() def get_cholesky(self, damping_func): raise NotImplementedError("ConvInputSUAKroneckerFactor does not support" "cholesky factorization") def get_cholesky_inverse(self, damping_func): raise NotImplementedError("ConvInputSUAKroneckerFactor does not support" "cholesky inverse computation") def register_cholesky(self): raise NotImplementedError("ConvInputSUAKroneckerFactor does not support" "cholesky factorization") def register_cholesky_inverse(self): raise NotImplementedError("ConvInputSUAKroneckerFactor does not support" "cholesky inverse computation") def _get_data_device(self, tower): return self._inputs[tower].device class ConvOutputKroneckerFactor(DenseSquareMatrixFactor): r"""Kronecker factor for the output side of a convolutional layer. Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over all examples and locations. Equivalent to Gamma in https://arxiv.org/abs/1602.01407 for details. See Section 3.1 Estimating the factors. """ def __init__(self, outputs_grads, data_format=None): """Initializes ConvOutputKroneckerFactor. Args: outputs_grads: List of list of Tensors. Each Tensor is of shape [batch_size, ..spatial_input_size.., out_channels]. First list index is source, the second is tower. data_format: None or str. Format of outputs_grads. Raises: ValueError: If channels are not final dimension. 
""" if not utils.is_data_format_channel_last(data_format): raise ValueError("Channel must be last.") self._out_channels = outputs_grads[0][0].shape.as_list()[-1] self._outputs_grads = outputs_grads super(ConvOutputKroneckerFactor, self).__init__() @property def _var_scope(self): return "ff_convoutkron_" + scope_string_from_params( nest.flatten(self._outputs_grads)) @property def _cov_shape(self): size = self._out_channels return [size, size] @property def _num_sources(self): return len(self._outputs_grads) @property def _num_towers(self): return len(self._outputs_grads[0]) @property def _dtype(self): return self._outputs_grads[0][0].dtype def _partial_batch_size(self, source=0, tower=0): return utils.get_shape(self._outputs_grads[source][tower])[0] def _compute_new_cov(self, source, tower): outputs_grad = self._outputs_grads[source][tower] # reshaped_tensor below is the matrix DS_l defined in the KFC paper # (tilde omitted over S for clarity). It has shape M|T| x I, where # M = minibatch size, |T| = number of spatial locations, and # I = number of output maps for convolutional layer l. reshaped_tensor = tf.reshape(outputs_grad, [-1, self._out_channels]) # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov, # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l # as defined in the paper, with shape I x I. # (Tilde omitted over S for clarity.) return compute_cov(reshaped_tensor) def _get_data_device(self, tower): return self._outputs_grads[0][tower].device class ConvOutputMultiKF(ConvOutputKroneckerFactor): def __init__(self, outputs_grads, num_uses, data_format=None): super(ConvOutputMultiKF, self).__init__(outputs_grads, data_format=data_format) self._num_uses = num_uses def _partial_batch_size(self, source=0, tower=0): # Note that some internal comptutations of "batch_size" done in the parent # class won't actually be the proper batch size. Instead, they will be # just "the thing to normalize the statistics by", essentially. 
This is okay # as we don't mix the two things up. return (super(ConvOutputMultiKF, self)._partial_batch_size(source=source, tower=tower) // self._num_uses) class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): """Kronecker factor for a fully connected layer used multiple times.""" def __init__(self, tensors, num_uses=None, has_bias=False): """Constructs a new `FullyConnectedMultiKF`. Args: tensors: List of list of Tensors of shape, each of shape [num_uses * batch_size, n], and is a reshape version of a Tensor of shape [num_uses, batch_size, n]. Each of these tensors is usually a layer's inputs or its output's gradients. The first list index is sources, the second is towers. num_uses: int. The number of time-steps / uses. has_bias: bool. If True, '1' is appended to each row. """ self._num_uses = num_uses self._cov_dt1 = None self._acc_cov_dt1 = None self._make_cov_dt1 = False self._option1quants_by_damping = OrderedDict() self._option2quants_by_damping = OrderedDict() self._option1quants_registrations = set() self._option2quants_registrations = set() super(FullyConnectedMultiKF, self).__init__(tensors=tensors, has_bias=has_bias) @property def _num_timesteps(self): return self._num_uses def _partial_batch_size(self, source=0, tower=0): total_len = utils.get_shape(self._tensors[source][tower])[0] return total_len // self._num_timesteps @property def _var_scope(self): return "ff_fc_multi_" + scope_string_from_params( tuple(nest.flatten(self._tensors)) + (self._num_timesteps, self._has_bias,)) def get_inv_vars(self): inv_vars = super(FullyConnectedMultiKF, self).get_inv_vars() inv_vars.extend(self._option1quants_by_damping.values()) inv_vars.extend(self._option2quants_by_damping.values()) return inv_vars def make_covariance_update_op(self, ema_decay, ema_weight): op = super(FullyConnectedMultiKF, self).make_covariance_update_op( ema_decay, ema_weight) if self._cov_dt1 is not None: new_cov_dt1_contribs = [] for source in range(self._num_sources): for tower in 
range(self._num_towers): with maybe_place_on_device(self._get_data_device(tower)): new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source, tower)) new_cov_dt1 = (tf.add_n(new_cov_dt1_contribs) / float(self._num_towers)) # See comments in FisherFactor.make_covariance_update_op() for details. new_cov_dt1 = utils.all_average(new_cov_dt1) op2 = self._cov_dt1.add_to_average(new_cov_dt1, decay=ema_decay, weight=ema_weight) # TODO(b/69112164): # It's important that _cov and _cov_dt1 remain consistent with each # other while the inverse ops are happening. How can we ensure this? # We will need to add explicit synchronization for this to # work with asynchronous training. op = tf.group(op, op2) return op def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring tensor = self._tensors[source][tower] if self._has_bias: # This appending is technically done twice (the other time is for # _compute_new_cov()) tensor = append_homog(tensor) total_len = utils.get_shape(tensor)[0] batch_size = total_len // self._num_timesteps tensor_present = tensor[:-batch_size, :] tensor_future = tensor[batch_size:, :] # We specify a normalizer for this computation to ensure a PSD Fisher # block estimate. This is equivalent to padding with zeros, as was done # in Section B.2 of the appendix. 
return compute_cov( tensor_future, tensor_right=tensor_present, normalizer=total_len) def _get_data_device(self, tower): return self._tensors[0][tower].device @property def _vec_shape(self): size = self._tensors[0][0].shape[1] + self._has_bias return [size] def get_option1quants(self, damping_func): damping_id = graph_func_to_id(damping_func) return self._option1quants_by_damping[damping_id] def get_option2quants(self, damping_func): damping_id = graph_func_to_id(damping_func) return self._option2quants_by_damping[damping_id] @property def cov_dt1(self): assert self._cov_dt1 is not None return self._cov_dt1.value def get_cov_vars(self): cov_vars = super(FullyConnectedMultiKF, self).get_cov_vars() if self._make_cov_dt1: cov_vars += [self.cov_dt1] return cov_vars def register_cov_dt1(self): self._make_cov_dt1 = True def instantiate_cov_variables(self): super(FullyConnectedMultiKF, self).instantiate_cov_variables() assert self._cov_dt1 is None if self._make_cov_dt1: with tf.variable_scope(self._var_scope): self._cov_dt1 = utils.MovingAverageVariable( name="cov_dt1", shape=self._cov_shape, dtype=self._dtype, initializer=tf.zeros_initializer(), normalize_value=ZERO_DEBIAS) def register_option1quants(self, damping_func): damping_id = self._register_damping(damping_func) if damping_id not in self._option1quants_registrations: self._option1quants_registrations.add(damping_id) def register_option2quants(self, damping_func): damping_id = self._register_damping(damping_func) if damping_id not in self._option2quants_registrations: self._option2quants_registrations.add(damping_id) def instantiate_inv_variables(self): super(FullyConnectedMultiKF, self).instantiate_inv_variables() for damping_id in self._option1quants_registrations: damping_func = self._damping_funcs_by_id[damping_id] damping_string = graph_func_to_string(damping_func) # It's questionable as to whether we should initialize with stuff like # this at all. 
Ideally these values should never be used until they are # updated at least once. with tf.variable_scope(self._var_scope): Lmat = tf.get_variable( # pylint: disable=invalid-name "Lmat_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, dtype=self._dtype, use_resource=True) psi = tf.get_variable( "psi_damp{}".format(damping_string), initializer=tf.ones_initializer(), shape=self._vec_shape, trainable=False, dtype=self._dtype, use_resource=True) assert damping_id not in self._option1quants_by_damping self._option1quants_by_damping[damping_id] = (Lmat, psi) for damping_id in self._option2quants_registrations: damping_func = self._damping_funcs_by_id[damping_id] damping_string = graph_func_to_string(damping_func) # It's questionable as to whether we should initialize with stuff like # this at all. Ideally these values should never be used until they are # updated at least once. with tf.variable_scope(self._var_scope): Pmat = tf.get_variable( # pylint: disable=invalid-name "Lmat_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, dtype=self._dtype, use_resource=True) Kmat = tf.get_variable( # pylint: disable=invalid-name "Kmat_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, trainable=False, dtype=self._dtype, use_resource=True) mu = tf.get_variable( "mu_damp{}".format(damping_string), initializer=tf.ones_initializer(), shape=self._vec_shape, trainable=False, dtype=self._dtype, use_resource=True) assert damping_id not in self._option2quants_by_damping self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu) def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" # TODO(b/69918258): Add correctness tests for this method. 
# pylint: disable=invalid-name ops = [] if (len(self._option1quants_by_damping) + len(self._option2quants_by_damping)): # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from # the pseudo-code in the original paper. Because the computations for # the A and G case are essentially the same they can both be performed by # the same class (this one). C1 = self.cov_dt1 # Get the eigendecomposition of C0 (= self.cov) eigen_e, eigen_V = self.get_eigendecomp() # TODO(b/69678661): Note, there is an implicit assumption here that C1 # and C0 (as represented here by its eigen-decomp) are consistent. This # could fail to be the case if self._cov and self._cov_dt1 are not updated # consistently, or are somehow read between or during the cov updates. # Can this possibly happen? Is there a way to prevent it? for damping_id, (Lmat_var, psi_var) in self._option1quants_by_damping.items(): damping = self._damping_funcs_by_id[damping_id]() damping = tf.cast(damping, self._dtype) invsqrtC0 = tf.matmul( eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) # Might need to enforce symmetry lost due to numerical issues. invsqrtC0 = (invsqrtC0 + tf.transpose(invsqrtC0)) / 2.0 # The following line imposes the symmetry assumed by "Option 1" on C1. # Strangely the code can work okay with this line commented out, # depending on how psd_eig is defined. I'm not sure why. 
C1 = (C1 + tf.transpose(C1)) / 2.0 # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) hPsi = tf.matmul(tf.matmul(invsqrtC0, C1), invsqrtC0) # Compute the decomposition U*diag(psi)*U^T = hPsi psi, U = utils.posdef_eig(hPsi) # L = C0^(-1/2) * U Lmat = tf.matmul(invsqrtC0, U) ops.append(utils.smart_assign(Lmat_var, Lmat)) ops.append(utils.smart_assign(psi_var, psi)) for damping_id, (Pmat_var, Kmat_var, mu_var) in self._option2quants_by_damping.items(): damping = self._damping_funcs_by_id[damping_id]() damping = tf.cast(damping, self._dtype) # compute C0^(-1/2) invsqrtC0 = tf.matmul( eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) # Might need to enforce symmetry lost due to numerical issues. invsqrtC0 = (invsqrtC0 + tf.transpose(invsqrtC0)) / 2.0 # Compute the product C0^(-1/2) * C1 invsqrtC0C1 = tf.matmul(invsqrtC0, C1) # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) hPsi = tf.matmul(invsqrtC0C1, invsqrtC0) # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi # Note that we using the notation mu instead of "m" for the eigenvalues. # Instead of computing the product hPsi^T * hPsi and then doing an # eigen-decomposition of this we just compute the SVD of hPsi and then # square the singular values to get the eigenvalues. For a justification # of this approach, see: # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition sqrtmu, _, E = tf.svd(hPsi) mu = tf.square(sqrtmu) # Mathematically, the eigenvalues should not should not exceed 1.0, but # due to numerical issues, or possible issues with inconsistent # values of C1 and (the eigen-decomposition of) C0 they might. So # we enforce this condition. 
mu = tf.minimum(mu, 1.0) # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1) Pmat = tf.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True) # K = C_0^(-1/2) * E Kmat = tf.matmul(invsqrtC0, E) ops.append(utils.smart_assign(Pmat_var, Pmat)) ops.append(utils.smart_assign(Kmat_var, Kmat)) ops.append(utils.smart_assign(mu_var, mu)) ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops() return [tf.group(*ops)] # pylint: enable=invalid-name