# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FisherBlock definitions.

This library contains classes for estimating blocks in a model's Fisher
Information matrix. Suppose one has a model that parameterizes a posterior
distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its
Fisher Information matrix is given by,

  F(params) = E[ v(x, y, params) v(x, y, params)^T ]

where,

  v(x, y, params) = (d / d params) log p(y | x, params)

and the expectation is taken with respect to the data's distribution for 'x' and
the model's posterior distribution for 'y',

  x ~ p(x)
  y ~ p(y | x, params)

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

# Dependency imports
import six
import tensorflow as tf

from tensorflow.python.util import nest
from kfac.python.ops import fisher_factors
from kfac.python.ops import utils

# For blocks corresponding to convolutional layers, or any type of block where
# the parameters can be thought of as being replicated in time or space,
# we want to adjust the scale of the damping by
#   damping /= num_replications ** NORMALIZE_DAMPING_POWER
# normalize_damping() reads this value at call time; a falsy value (e.g. 0 or
# None) makes it return the damping unchanged. Override it with
# set_global_constants().
NORMALIZE_DAMPING_POWER = 1.0

# Methods for adjusting damping for FisherBlocks. See
# compute_pi_adjusted_damping() for details.
# PI_TYPE selects the method in use and must equal one of the *_NAME values:
#   PI_TRACENORM_NAME: damping is scaled by the pi of compute_pi_tracenorm().
#   PI_OFF_NAME: damping is passed through unchanged for both factors.
# Override PI_TYPE with set_global_constants().
PI_OFF_NAME = "off"
PI_TRACENORM_NAME = "tracenorm"
PI_TYPE = PI_TRACENORM_NAME


def set_global_constants(normalize_damping_power=None, pi_type=None):
  """Overrides module-level configuration constants.

  Any argument left as None leaves the corresponding constant untouched.

  Args:
    normalize_damping_power: New value for NORMALIZE_DAMPING_POWER, or None.
    pi_type: New value for PI_TYPE, or None.
  """
  requested = (("NORMALIZE_DAMPING_POWER", normalize_damping_power),
               ("PI_TYPE", pi_type))
  module_globals = globals()

  for constant_name, new_value in requested:
    if new_value is not None:
      module_globals[constant_name] = new_value


def normalize_damping(damping, num_replications):
  """Rescales damping by the number of parameter replications.

  Computes damping / num_replications ** NORMALIZE_DAMPING_POWER. When
  NORMALIZE_DAMPING_POWER is falsy the damping is returned untouched.

  Args:
    damping: The damping value to rescale.
    num_replications: The number of replications used as the base of the
      scale factor.

  Returns:
    The rescaled damping, or `damping` itself when normalization is disabled.
  """
  if not NORMALIZE_DAMPING_POWER:
    return damping

  scale = num_replications ** NORMALIZE_DAMPING_POWER
  return damping / scale


def compute_pi_tracenorm(left_cov, right_cov):
  """Computes the trace-norm based pi constant used for factored damping.

  The value is

    pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) )

  where A is `left_cov` and B is `right_cov`. See section 6.3 of
  https://arxiv.org/pdf/1503.05671.pdf for details.

  Args:
    left_cov: A LinearOperator object. The left Kronecker factor "covariance".
    right_cov: A LinearOperator object. The right Kronecker factor "covariance".

  Returns:
    The computed scalar constant pi for these Kronecker Factors (as a Tensor).
  """
  # Rather than dividing each trace by its own dimension, each trace is
  # multiplied by the dimension of the *other* factor. The ratio is unchanged.
  numerator = left_cov.trace() * int(right_cov.domain_dimension)
  denominator = right_cov.trace() * int(left_cov.domain_dimension)

  # The square root of the ratio is only evaluated once the denominator has
  # been checked to be positive.
  positivity_check = tf.assert_positive(
      denominator,
      message="PI computation, trace of right cov matrix should be positive. "
      "Note that most likely cause of this error is that the optimizer "
      "diverged (e.g. due to bad hyperparameters).")
  with tf.control_dependencies([positivity_check]):
    return tf.sqrt(numerator / denominator)


def compute_pi_adjusted_damping(left_cov, right_cov, damping):
  """Splits a damping value between the left and right Kronecker factors.

  The behavior is selected by the module-level PI_TYPE constant:

    * PI_TRACENORM_NAME: pi is obtained from compute_pi_tracenorm() and the
      result is (damping * pi, damping / pi), with `damping` first cast to
      the dtype of pi.
    * PI_OFF_NAME: the damping is returned unchanged for both factors.

  Args:
    left_cov: A LinearOperator object. The left Kronecker factor "covariance".
    right_cov: A LinearOperator object. The right Kronecker factor "covariance".
    damping: The damping value to distribute (float or Tensor).

  Returns:
    A 2-tuple (left_damping, right_damping).

  Raises:
    ValueError: if PI_TYPE is not one of PI_TRACENORM_NAME or PI_OFF_NAME.
  """
  if PI_TYPE == PI_TRACENORM_NAME:
    pi = compute_pi_tracenorm(left_cov, right_cov)
    damping = tf.cast(damping, dtype=pi.dtype)
    return (damping * pi, damping / pi)

  if PI_TYPE == PI_OFF_NAME:
    return (damping, damping)

  # Previously an unknown PI_TYPE fell through and returned None, which only
  # surfaced later as an opaque error when the caller indexed the result.
  raise ValueError(
      "Unrecognized PI_TYPE {!r}; must be one of {!r} or {!r}.".format(
          PI_TYPE, PI_TRACENORM_NAME, PI_OFF_NAME))


class PackagedFunc(object):
  """Pairs a zero-argument callable with a stable identifier.

  Lambdas have no dependable name of their own; the identifier carried here
  gives them one.
  """

  def __init__(self, func, func_id):
    """Initializes PackagedFunc.

    Args:
      func: a zero-arg Python function.
      func_id: a hashable, function that produces a hashable, or a list/tuple
        thereof. A value that is not a list/tuple is wrapped in a 1-tuple.
    """
    if not isinstance(func_id, (tuple, list)):
      func_id = (func_id,)
    self._func_id = func_id
    self._func = func

  def __call__(self):
    """Invokes the wrapped function and returns its result."""
    return self._func()

  @property
  def func_id(self):
    """A hashable identifier for this function.

    Callable elements of the stored identifier are invoked each time this
    property is read; all other elements are used as they are.
    """
    resolved = []
    for part in self._func_id:
      resolved.append(part() if callable(part) else part)
    return tuple(resolved)


def _package_func(func, func_id):
  """Wraps `func` and its identifier `func_id` in a PackagedFunc."""
  packaged = PackagedFunc(func, func_id)
  return packaged


@six.add_metaclass(abc.ABCMeta)
class FisherBlock(object):
  """Abstract interface for an approximate block of the Fisher matrix.

  Concrete subclasses must supply instantiate_factors, register_matpower,
  register_cholesky, register_cholesky_inverse, multiply_matpower,
  multiply_cholesky, multiply_cholesky_inverse, tensors_to_compute_grads and
  the num_registered_towers property. register_inverse, multiply_inverse and
  multiply are provided here in terms of the matrix-power methods.
  """

  def __init__(self, layer_collection):
    """Stores the owning layer collection.

    Args:
      layer_collection: The LayerCollection object which owns this block.
    """
    self._layer_collection = layer_collection

  @abc.abstractmethod
  def instantiate_factors(self, grads_list, damping):
    """Creates and registers the component factors of this Fisher block.

    Args:
      grads_list: A list of gradients (each a Tensor or tuple of Tensors) with
          respect to the tensors returned by tensors_to_compute_grads() that
          are to be used to estimate the block.
      damping: The damping factor (float or Tensor).
    """

  @abc.abstractmethod
  def register_matpower(self, exp):
    """Registers a matrix power to be computed by the block.

    Args:
      exp: A float representing the power to raise the block by.
    """

  @abc.abstractmethod
  def register_cholesky(self):
    """Registers a Cholesky factor to be computed by the block."""

  @abc.abstractmethod
  def register_cholesky_inverse(self):
    """Registers an inverse Cholesky factor to be computed by the block."""

  def register_inverse(self):
    """Registers a matrix inverse, i.e. the matrix power with exponent -1."""
    self.register_matpower(-1)

  @abc.abstractmethod
  def multiply_matpower(self, vector, exp):
    """Multiplies the vector by the (damped) matrix-power of the block.

    Args:
      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
      exp: A float representing the power to raise the block by before
        multiplying it by the vector.

    Returns:
      The vector left-multiplied by the (damped) matrix-power of the block.
    """

  def multiply_inverse(self, vector):
    """Multiplies the vector by the (damped) inverse of the block.

    Implemented as multiply_matpower with exponent -1.

    Args:
      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.

    Returns:
      The vector left-multiplied by the (damped) inverse of the block.
    """
    return self.multiply_matpower(vector, -1)

  def multiply(self, vector):
    """Multiplies the vector by the (damped) block.

    Implemented as multiply_matpower with exponent 1.

    Args:
      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.

    Returns:
      The vector left-multiplied by the (damped) block.
    """
    return self.multiply_matpower(vector, 1)

  @abc.abstractmethod
  def multiply_cholesky(self, vector, transpose=False):
    """Multiplies the vector by the (damped) Cholesky-factor of the block.

    Args:
      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
      transpose: Bool. If true the Cholesky factor is transposed before
        multiplying the vector. (Default: False)

    Returns:
      The vector left-multiplied by the (damped) Cholesky-factor of the block.
    """

  @abc.abstractmethod
  def multiply_cholesky_inverse(self, vector, transpose=False):
    """Multiplies vector by the (damped) inverse Cholesky-factor of the block.

    Args:
      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
      transpose: Bool. If true the Cholesky factor inverse is transposed
        before multiplying the vector. (Default: False)

    Returns:
      Vector left-multiplied by (damped) inverse Cholesky-factor of the block.
    """

  @abc.abstractmethod
  def tensors_to_compute_grads(self):
    """Returns the Tensor(s) with respect to which this FisherBlock needs grads.
    """

  @abc.abstractproperty
  def num_registered_towers(self):
    """Number of towers registered for this FisherBlock.

    Typically equal to the number of towers in a multi-tower setup.
    """


@six.add_metaclass(abc.ABCMeta)
class FullFB(FisherBlock):
  """Base class for blocks that use full matrix representations (no approx).

  Every method here reads `self._factor` and `self._damping_func`; this class
  does not assign them itself.
  """

  def register_matpower(self, exp):
    """Registers the matrix power `exp` with the underlying factor."""
    self._factor.register_matpower(exp, self._damping_func)

  def register_cholesky(self):
    """Registers a Cholesky factor with the underlying factor."""
    self._factor.register_cholesky(self._damping_func)

  def register_cholesky_inverse(self):
    """Registers an inverse Cholesky factor with the underlying factor."""
    self._factor.register_cholesky_inverse(self._damping_func)

  def _multiply_matrix(self, matrix, vector, transpose=False):
    """Applies `matrix` to `vector` and restores the structure of `vector`.

    Args:
      matrix: Object exposing matmul(x, adjoint=...).
      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
      transpose: Bool forwarded to matmul as `adjoint`. (Default: False)

    Returns:
      The product, converted back with utils.column_to_tensors using `vector`
      as the template.
    """
    column = utils.tensors_to_column(vector)
    product = matrix.matmul(column, adjoint=transpose)
    return utils.column_to_tensors(vector, product)

  def multiply_matpower(self, vector, exp):
    """Multiplies `vector` by the factor's matrix power `exp`."""
    return self._multiply_matrix(
        self._factor.get_matpower(exp, self._damping_func), vector)

  def multiply_cholesky(self, vector, transpose=False):
    """Multiplies `vector` by the factor's Cholesky factor."""
    return self._multiply_matrix(
        self._factor.get_cholesky(self._damping_func),
        vector,
        transpose=transpose)

  def multiply_cholesky_inverse(self, vector, transpose=False):
    """Multiplies `vector` by the factor's inverse Cholesky factor."""
    return self._multiply_matrix(
        self._factor.get_cholesky_inverse(self._damping_func),
        vector,
        transpose=transpose)

  def full_fisher_block(self):
    """Explicitly constructs the full Fisher block."""
    cov_operator = self._factor.get_cov_as_linear_operator()
    return cov_operator.to_dense()


class NaiveFullFB(FullFB):
  """FisherBlock using a full matrix estimate (no approximations).

  Because the whole matrix is represented, this block should only ever be used
  for very low dimensional parameters.

  Note that this uses the naive "square the sum estimator", and so is applicable
  to any type of parameter in principle, but has very high variance.
  """

  def __init__(self, layer_collection, params):
    """Creates a NaiveFullFB block.

    Args:
      layer_collection: The LayerCollection object which owns this block.
      params: The parameters of this layer (Tensor or tuple of Tensors).
    """
    self._params = params
    # One entry is appended per call to register_additional_tower().
    self._batch_sizes = []

    super(NaiveFullFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    """Creates the damping function and the NaiveFullFactor for this block.

    Args:
      grads_list: Gradients used to estimate the block; passed through to the
        factor unchanged.
      damping: The damping factor (float or Tensor).
    """
    def damping_func():
      return damping

    self._damping_func = _package_func(damping_func, (damping,))

    factor_args = (grads_list, self._batch_size)
    self._factor = self._layer_collection.make_or_get_factor(
        fisher_factors.NaiveFullFactor, factor_args)

  def tensors_to_compute_grads(self):
    """Returns the parameters this block was constructed with."""
    return self._params

  def register_additional_tower(self, batch_size):
    """Register an additional tower.

    Args:
      batch_size: The batch size, used in the covariance estimator.
    """
    self._batch_sizes.append(batch_size)

  @property
  def num_registered_towers(self):
    """Number of towers registered so far (one per recorded batch size)."""
    return len(self._batch_sizes)

  @property
  def _batch_size(self):
    """Sum of the batch sizes of all registered towers (via tf.reduce_sum)."""
    return tf.reduce_sum(self._batch_sizes)


@six.add_metaclass(abc.ABCMeta)
class DiagonalFB(FisherBlock):
  """A base class for FisherBlocks that use diagonal approximations.

  Every multiply method reads `self._factor` and `self._damping_func`; this
  class does not assign them itself.
  """

  def register_matpower(self, exp):
    """No-op: matrix powers are computed on demand in the diagonal case."""

  def register_cholesky(self):
    """No-op: Cholesky factors are computed on demand in the diagonal case."""

  def register_cholesky_inverse(self):
    """No-op: Cholesky inverses are computed on demand in the diagonal case."""

  def _multiply_matrix(self, matrix, vector):
    """Applies `matrix` to `vector` and restores the structure of `vector`.

    Args:
      matrix: Object exposing matmul(x).
      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.

    Returns:
      The product, converted back with utils.column_to_tensors using `vector`
      as the template.
    """
    column = utils.tensors_to_column(vector)
    product = matrix.matmul(column)
    return utils.column_to_tensors(vector, product)

  def multiply_matpower(self, vector, exp):
    """Multiplies `vector` by the factor's matrix power `exp`."""
    return self._multiply_matrix(
        self._factor.get_matpower(exp, self._damping_func), vector)

  def multiply_cholesky(self, vector, transpose=False):
    """Multiplies `vector` by the factor's Cholesky factor.

    The `transpose` argument is accepted for interface compatibility but is
    not used by this implementation.
    """
    return self._multiply_matrix(
        self._factor.get_cholesky(self._damping_func), vector)

  def multiply_cholesky_inverse(self, vector, transpose=False):
    """Multiplies `vector` by the factor's inverse Cholesky factor.

    The `transpose` argument is accepted for interface compatibility but is
    not used by this implementation.
    """
    return self._multiply_matrix(
        self._factor.get_cholesky_inverse(self._damping_func), vector)

  def full_fisher_block(self):
    """Explicitly constructs the block as a dense matrix."""
    cov_operator = self._factor.get_cov_as_linear_operator()
    return cov_operator.to_dense()


class NaiveDiagonalFB(DiagonalFB):
  """FisherBlock using a diagonal matrix approximation.

  This type of approximation is generically applicable but quite primitive.

  Note that this uses the naive "square the sum estimator", and so is applicable
  to any type of parameter in principle, but has very high variance.
  """

  def __init__(self, layer_collection, params):
    """Creates a NaiveDiagonalFB block.

    Args:
      layer_collection: The LayerCollection object which owns this block.
      params: The parameters of this layer (must be a single Tensor).
    """
    # One entry is appended per call to register_additional_tower().
    self._batch_sizes = []
    self._params = params

    super(NaiveDiagonalFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    """Creates the damping function and the NaiveDiagonalFactor for this block.

    Args:
      grads_list: Gradients used to estimate the block; passed through to the
        factor unchanged.
      damping: The damping factor (float or Tensor).
    """
    def damping_func():
      return damping

    self._damping_func = _package_func(damping_func, (damping,))

    factor_args = (grads_list, self._batch_size)
    self._factor = self._layer_collection.make_or_get_factor(
        fisher_factors.NaiveDiagonalFactor, factor_args)

  def tensors_to_compute_grads(self):
    """Returns the parameters this block was constructed with."""
    return self._params

  def register_additional_tower(self, batch_size):
    """Register an additional tower.

    Args:
      batch_size: The batch size, used in the covariance estimator.
    """
    self._batch_sizes.append(batch_size)

  @property
  def num_registered_towers(self):
    """Number of towers registered so far (one per recorded batch size)."""
    return len(self._batch_sizes)

  @property
  def _batch_size(self):
    """Sum of the batch sizes of all registered towers (via tf.reduce_sum)."""
    return tf.reduce_sum(self._batch_sizes)


class InputOutputMultiTower(object):
  """Mix-in class for blocks with inputs & outputs and multiple mini-batches."""

  def __init__(self, *args, **kwargs):
    """Creates empty per-tower lists and forwards all arguments to super."""
    self.__outputs = []
    self.__inputs = []
    super(InputOutputMultiTower, self).__init__(*args, **kwargs)

  def _process_data(self, grads_list):
    """Process data into the format used by the factors.

    Converts self._inputs and `grads_list` into one of the formats expected by
    the FisherFactor classes, as selected by the global configuration variable
    fisher_factors.TOWER_STRATEGY.

    On entry self._inputs is expected to be a list of Tensors over towers, and
    `grads_list` a list over sources of such lists.

    With TOWER_STRATEGY "concat", 'inputs' becomes a singleton tuple holding a
    PartitionedTensor built from all elements of self._inputs, and grads_list
    becomes a tuple (over sources) of singleton tuples built the same way.

    With TOWER_STRATEGY "separate", both structures keep their layout and are
    only converted to tuples at the outermost level.

    Args:
      grads_list: grads_list in its initial format (see above).

    Returns:
      inputs: self._inputs transformed into the appropriate format (see
        above).
      grads_list: grads_list transformed into the appropriate format (see
        above).

    Raises:
      ValueError: if TOWER_STRATEGY is not one of "separate" or "concat".
    """
    strategy = fisher_factors.TOWER_STRATEGY

    if strategy == "separate":
      return tuple(self._inputs), tuple(grads_list)

    if strategy == "concat":
      # Towers are merged into a single PartitionedTensor, wrapped in a
      # singleton tuple because the factors expect a sequence over towers.
      merged_inputs = (utils.PartitionedTensor(self._inputs),)
      # Same treatment per source, keeping the leading sources dimension.
      merged_grads = tuple(
          (utils.PartitionedTensor(tower_grads),) for tower_grads in grads_list)
      return merged_inputs, merged_grads

    raise ValueError("Global config variable TOWER_STRATEGY must be one of "
                     "'concat' or 'separate'.")

  def tensors_to_compute_grads(self):
    """Tensors to compute derivative of loss with respect to."""
    return tuple(self._outputs)

  def register_additional_tower(self, inputs, outputs):
    """Records the inputs and outputs of one more tower.

    Args:
      inputs: The tower's inputs; appended to self._inputs.
      outputs: The tower's outputs; appended to self._outputs.
    """
    self._inputs.append(inputs)
    self._outputs.append(outputs)

  @property
  def num_registered_towers(self):
    """Number of towers registered so far."""
    num_towers = len(self._inputs)
    assert num_towers == len(self._outputs)
    return num_towers

  @property
  def _inputs(self):
    """The list of registered inputs, one entry per tower."""
    return self.__inputs

  @property
  def _outputs(self):
    """The list of registered outputs, one entry per tower."""
    return self.__outputs


class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB):
  """FisherBlock for fully-connected (dense) layers using a diagonal approx.

  Estimates the diagonal entries of the Fisher Information matrix for a fully
  connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of
  squares" estimator.

  Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
  into it. The quantity of interest is

    Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
                         = E[ v(x, y, params)[i] ^ 2 ]

  Consider a fully connected layer in this model with (unshared) weight matrix
  'w'. For an example 'x' that produces layer inputs 'a' and output
  preactivations 's',

    v(x, y, w) = vec( a (d loss / d s)^T )

  This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
  to the layer's parameters 'w'.
  """

  def __init__(self, layer_collection, has_bias=False):
    """Creates a FullyConnectedDiagonalFB block.

    Args:
      layer_collection: The LayerCollection object which owns this block.
      has_bias: bool. If True, estimates Fisher with respect to a bias
        parameter as well as the layer's weights. (Default: False)
    """
    self._has_bias = has_bias

    super(FullyConnectedDiagonalFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    """Creates the FullyConnectedDiagonalFactor and the damping function.

    Args:
      grads_list: Gradients in the initial format accepted by _process_data().
      damping: The damping factor (float or Tensor).
    """
    inputs, grads_list = self._process_data(grads_list)

    factor_args = (inputs, grads_list, self._has_bias)
    self._factor = self._layer_collection.make_or_get_factor(
        fisher_factors.FullyConnectedDiagonalFactor, factor_args)

    def damping_func():
      return damping

    self._damping_func = _package_func(damping_func, (damping,))


class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):
  """FisherBlock for 2-D convolutional layers using a diagonal approx.

  Estimates the diagonal entries of the Fisher Information matrix for a
  convolutional layer. Unlike NaiveDiagonalFB this uses the low-variance "sum
  of squares" estimator.

  Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
  into it. The quantity of interest is

    Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
                         = E[ v(x, y, params)[i] ^ 2 ]

  Consider a convolutional layer in this model with (unshared) filter matrix
  'w'. For an example image 'x' that produces layer inputs 'a' and output
  preactivations 's',

    v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )

  where 'loc' is a single (x, y) location in an image.

  This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
  to the layer's parameters 'w'.
  """

  def __init__(self,
               layer_collection,
               params,
               strides,
               padding,
               data_format=None,
               dilations=None,
               patch_mask=None):
    """Creates a ConvDiagonalFB block.

    Args:
      layer_collection: The LayerCollection object which owns this block.
      params: The parameters (Tensor or tuple of Tensors) of this layer. If
        kernel alone, a Tensor of shape [kernel_height, kernel_width,
        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
        containing the previous and a Tensor of shape [out_channels].
      strides: The stride size in this layer (1-D Tensor of length 4).
      padding: The padding in this layer (e.g. "SAME").
      data_format: str or None. Format of input data.
      dilations: List of 4 ints or None. Rate for dilation along all dimensions.
        None is treated as [1, 1, 1, 1].
      patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels]
        or None. If not None this is multiplied against the extracted patches
        Tensor (broadcasting along the batch dimension) before statistics are
        computed. (Default: None)

    Raises:
      ValueError: if strides is not length-4.
      ValueError: if dilations is not length-4.
      ValueError: if channel is not last dimension.
      ValueError: if the convolution filter is not 4-dimensional.
    """
    if len(strides) != 4:
      raise ValueError("strides must contain 4 numbers.")

    dilations = [1, 1, 1, 1] if dilations is None else dilations
    if len(dilations) != 4:
      raise ValueError("dilations must contain 4 numbers.")

    channels_last = utils.is_data_format_channel_last(data_format)
    if not channels_last:
      raise ValueError("data_format must be channels-last.")

    self._strides = maybe_tuple(strides)
    self._padding = padding
    self._data_format = data_format
    self._dilations = maybe_tuple(dilations)

    # A tuple/list of params is taken to carry the kernel as its first
    # element; anything else is taken to be the kernel itself.
    self._has_bias = isinstance(params, (tuple, list))
    if self._has_bias:
      kernel = params[0]
    else:
      kernel = params

    self._filter_shape = tuple(kernel.shape.as_list())
    if len(self._filter_shape) != 4:
      raise ValueError(
          "Convolution filter must be of shape"
          " [filter_height, filter_width, in_channels, out_channels].")

    self._patch_mask = patch_mask

    super(ConvDiagonalFB, self).__init__(layer_collection)

  @property
  def _factor_implementation(self):
    """The factor class passed to make_or_get_factor by this block."""
    return fisher_factors.ConvDiagonalFactor

  def instantiate_factors(self, grads_list, damping):
    """Creates the convolutional diagonal factor and the damping function.

    Args:
      grads_list: Gradients in the initial format accepted by _process_data().
      damping: The damping factor (float or Tensor).
    """
    inputs, grads_list = self._process_data(grads_list)

    # Infer number of locations upon which convolution is applied.
    first_input_shape = inputs[0].shape.as_list()
    self._num_locations = utils.num_conv_locations(
        first_input_shape, list(self._filter_shape), self._strides,
        self._padding)

    factor_args = (
        inputs,
        grads_list,
        self._filter_shape,
        self._strides,
        self._padding,
        self._data_format,
        self._dilations,
        self._has_bias,
        self._patch_mask,
    )
    self._factor = self._layer_collection.make_or_get_factor(
        self._factor_implementation, factor_args)

    def damping_func():
      # self._num_locations is read when the function is called, not when it
      # is defined.
      normalized = normalize_damping(damping, self._num_locations)
      return self._num_locations * normalized

    damping_id = (self._num_locations, "mult", "normalize_damping", damping,
                  self._num_locations)
    self._damping_func = _package_func(damping_func, damping_id)


class ScaleAndShiftFullFB(InputOutputMultiTower, FullFB):
  """A FisherBlock class for scale and shift ops that uses no approximations.

  This class estimates the same thing that NaiveFullFB would (when applied
  to the scale and shift params), but with a lower variance estimator. In
  particular it uses a "sum the squares estimator", and thus the variance will
  shrink as 1/batch_size.
  """

  def __init__(self, layer_collection, broadcast_dim, has_shift=True):
    """Creates a ScaleAndShiftFullFB block.

    Args:
      layer_collection: The LayerCollection object which owns this block.
      broadcast_dim: The dimension of the input up to which broadcasting
        takes place when the scale and shift are multiplied/added.
      has_shift: bool. If True, estimates Fisher with respect to a shift
        parameter as well as the scale parameter (which is always included).
    """
    self._has_shift = has_shift
    self._broadcast_dim = broadcast_dim

    super(ScaleAndShiftFullFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    """Creates the ScaleAndShiftFullFactor and the damping function.

    Args:
      grads_list: Gradients in the initial format accepted by _process_data().
      damping: The damping factor (float or Tensor).
    """
    inputs, grads_list = self._process_data(grads_list)

    factor_args = (inputs, grads_list, self._broadcast_dim, self._has_shift)
    self._factor = self._layer_collection.make_or_get_factor(
        fisher_factors.ScaleAndShiftFullFactor, factor_args)

    def damping_func():
      return damping

    self._damping_func = _package_func(damping_func, (damping,))


class ScaleAndShiftDiagonalFB(InputOutputMultiTower, DiagonalFB):
  """A FisherBlock class for scale and shift ops that uses a diagonal approx.

  This class estimates the same thing that NaiveDiagonalFB would (when applied
  to the scale and shift params), but with a lower variance estimator. In
  particular it uses a "sum the squares estimator", and thus the variance will
  shrink as 1/batch_size.
  """

  def __init__(self, layer_collection, broadcast_dim, has_shift=True):
    """Creates a ScaleAndShiftDiagonalFB block.

    Args:
      layer_collection: The LayerCollection object which owns this block.
      broadcast_dim: The dimension of the input up to which broadcasting
        takes place when the scale and shift are multiplied/added.
      has_shift: bool. If True, estimates Fisher with respect to a shift
        parameter as well as the scale parameter (which is always included).
    """
    self._has_shift = has_shift
    self._broadcast_dim = broadcast_dim

    super(ScaleAndShiftDiagonalFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    """Creates the ScaleAndShiftDiagonalFactor and the damping function.

    Args:
      grads_list: Gradients in the initial format accepted by _process_data().
      damping: The damping factor (float or Tensor).
    """
    inputs, grads_list = self._process_data(grads_list)

    factor_args = (inputs, grads_list, self._broadcast_dim, self._has_shift)
    self._factor = self._layer_collection.make_or_get_factor(
        fisher_factors.ScaleAndShiftDiagonalFactor, factor_args)

    def damping_func():
      return damping

    self._damping_func = _package_func(damping_func, (damping,))


class KroneckerProductFB(FisherBlock):
  """A base class for blocks with separate input and output Kronecker factors.

  The Fisher block is approximated as a Kronecker product of the input and
  output factors. Methods here read `self._input_factor` and
  `self._output_factor`; this class does not assign them itself.
  """

  def _setup_damping(self, damping, normalization=None):
    """Makes functions that compute the damping values for both factors.

    Sets self._input_damping_func, self._output_damping_func and
    self._damping.

    Args:
      damping: The damping value (float or Tensor).
      normalization: If not None, the value passed to normalize_damping() as
        the number of replications. (Default: None)
    """
    has_normalization = normalization is not None

    def compute_damping():
      """Returns the (input, output) damping pair for the two factors."""
      if has_normalization:
        effective_damping = normalize_damping(damping, normalization)
      else:
        effective_damping = damping

      # The square root of the damping is what gets split across the factors.
      return compute_pi_adjusted_damping(
          self._input_factor.get_cov_as_linear_operator(),
          self._output_factor.get_cov_as_linear_operator(),
          effective_damping**0.5)

    id_prefix = ("compute_pi_adjusted_damping",
                 "cov", self._input_factor.name,
                 "cov", self._output_factor.name)
    if has_normalization:
      id_suffix = ("normalize_damping", damping, normalization, "power", 0.5)
    else:
      id_suffix = (damping, "power", 0.5)
    damping_id = id_prefix + id_suffix

    self._input_damping_func = _package_func(lambda: compute_damping()[0],
                                             damping_id + ("ref", 0))
    self._output_damping_func = _package_func(lambda: compute_damping()[1],
                                              damping_id + ("ref", 1))

    # Also store the damping op for access to the effective damping later on,
    # such as when writing to summary.
    self._damping = (
        normalize_damping(damping, normalization)
        if has_normalization else damping)

  def register_matpower(self, exp):
    """Registers the matrix power `exp` with both Kronecker factors."""
    self._input_factor.register_matpower(exp, self._input_damping_func)
    self._output_factor.register_matpower(exp, self._output_damping_func)

  def register_cholesky(self):
    """Registers a Cholesky factor with both Kronecker factors."""
    self._input_factor.register_cholesky(self._input_damping_func)
    self._output_factor.register_cholesky(self._output_damping_func)

  def register_cholesky_inverse(self):
    """Registers an inverse Cholesky factor with both Kronecker factors."""
    self._input_factor.register_cholesky_inverse(self._input_damping_func)
    self._output_factor.register_cholesky_inverse(self._output_damping_func)

  @property
  def damping(self):
    """A copy of the damping op.

    This is not used (and should never be used) in KFAC computations. A valid
    usage of this property could be to write damping values to the summary.

    Returns:
      0-D Tensor.
    """
    return self._damping

  @property
  def input_factor(self):
    """The input (left) Kronecker factor."""
    return self._input_factor

  @property
  def output_factor(self):
    """The output (right) Kronecker factor."""
    return self._output_factor

  @property
  def _renorm_coeff(self):
    """Kronecker factor multiplier coefficient.

    If this FisherBlock is represented as 'FB = c * kron(left, right)', then
    this is 'c'.

    Returns:
      0-D Tensor.
    """
    return 1.0

  def _multiply_factored_matrix(self, left_factor, right_factor, vector,
                                extra_scale=1.0, transpose_left=False,
                                transpose_right=False):
    """Multiplies a factored matrix.

    Args:
      left_factor: Object exposing matmul(x, adjoint=...).
      right_factor: Object exposing matmul_right(x, adjoint=...).
      vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
      extra_scale: Scalar applied to the product; skipped when equal to 1.0.
      transpose_left: Bool forwarded as `adjoint` for the left factor.
      transpose_right: Bool forwarded as `adjoint` for the right factor.

    Returns:
      The product, converted back with utils.mat2d_to_layer_params using
      `vector` as the template.
    """
    mat = utils.layer_params_to_mat2d(vector)
    # The right factor is applied first, then the left factor.
    mat = right_factor.matmul_right(mat, adjoint=transpose_right)
    mat = left_factor.matmul(mat, adjoint=transpose_left)
    if extra_scale != 1.0:
      mat = tf.scalar_mul(extra_scale, mat)
    return utils.mat2d_to_layer_params(vector, mat)

  def multiply_matpower(self, vector, exp):
    """Multiplies `vector` by the (damped) matrix power `exp` of the block."""
    left_factor = self._input_factor.get_matpower(
        exp, self._input_damping_func)
    right_factor = self._output_factor.get_matpower(
        exp, self._output_damping_func)
    scale = float(self._renorm_coeff)**exp
    return self._multiply_factored_matrix(
        left_factor, right_factor, vector, extra_scale=scale)

  def multiply_cholesky(self, vector, transpose=False):
    """Multiplies `vector` by the (damped) Cholesky factor of the block.

    The left factor receives `transpose` as its adjoint flag; the right factor
    receives its negation.
    """
    left_factor = self._input_factor.get_cholesky(self._input_damping_func)
    right_factor = self._output_factor.get_cholesky(self._output_damping_func)
    scale = float(self._renorm_coeff)**0.5
    return self._multiply_factored_matrix(
        left_factor, right_factor, vector,
        extra_scale=scale,
        transpose_left=transpose,
        transpose_right=not transpose)

  def multiply_cholesky_inverse(self, vector, transpose=False):
    """Multiplies `vector` by the (damped) inverse Cholesky factor of the block.

    The left factor receives `transpose` as its adjoint flag; the right factor
    receives its negation.
    """
    left_factor = self._input_factor.get_cholesky_inverse(
        self._input_damping_func)
    right_factor = self._output_factor.get_cholesky_inverse(
        self._output_damping_func)
    scale = float(self._renorm_coeff)**-0.5
    return self._multiply_factored_matrix(
        left_factor, right_factor, vector,
        extra_scale=scale,
        transpose_left=transpose,
        transpose_right=not transpose)

  def full_fisher_block(self):
    """Explicitly constructs the full Fisher block.

    Used for testing purposes. (In general, the result may be very large.)

    Returns:
      The full Fisher block.
    """
    dense_left = self._input_factor.get_cov_as_linear_operator().to_dense()
    dense_right = self._output_factor.get_cov_as_linear_operator().to_dense()
    kron = utils.kronecker_product(dense_left, dense_right)
    return self._renorm_coeff * kron


class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
  """K-FAC FisherBlock for fully-connected (dense) layers.

  Implements the Kronecker-factored approximation introduced in the original
  K-FAC paper (https://arxiv.org/abs/1503.05671). Either Kronecker factor can
  optionally be replaced by a diagonal approximation.
  """

  def __init__(self, layer_collection, has_bias=False,
               diagonal_approx_for_input=False,
               diagonal_approx_for_output=False):
    """Creates a FullyConnectedKFACBasicFB block.

    Args:
      layer_collection: The LayerCollection object which owns this block.
      has_bias: bool. If True, the Fisher is estimated with respect to a bias
        parameter in addition to the layer's weights. (Default: False)
      diagonal_approx_for_input: bool. If True, the input Kronecker factor is
        a DiagonalKroneckerFactor instead of a FullyConnectedKroneckerFactor.
        (Default: False)
      diagonal_approx_for_output: bool. If True, the output Kronecker factor
        is a DiagonalKroneckerFactor instead of a
        FullyConnectedKroneckerFactor. (Default: False)
    """
    self._diagonal_approx_for_output = diagonal_approx_for_output
    self._diagonal_approx_for_input = diagonal_approx_for_input
    self._has_bias = has_bias

    super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    """Instantiate Kronecker Factors for this FisherBlock.

    Args:
      grads_list: List of list of Tensors. grads_list[i][j] is the
        gradient of the loss with respect to 'outputs' from source 'i' and
        tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
      damping: 0-D Tensor or float. 'damping' * identity is approximately added
        to this FisherBlock's Fisher approximation.
    """
    inputs, grads_list = self._process_data(grads_list)
    make_factor = self._layer_collection.make_or_get_factor

    # Pick the factor type for the input side, then register/fetch it.
    input_factor_cls = (
        fisher_factors.DiagonalKroneckerFactor
        if self._diagonal_approx_for_input
        else fisher_factors.FullyConnectedKroneckerFactor)
    self._input_factor = make_factor(
        input_factor_cls, ((inputs,), self._has_bias))

    # Same choice for the output side, which is built from the gradients.
    output_factor_cls = (
        fisher_factors.DiagonalKroneckerFactor
        if self._diagonal_approx_for_output
        else fisher_factors.FullyConnectedKroneckerFactor)
    self._output_factor = make_factor(output_factor_cls, (grads_list,))

    self._setup_damping(damping)


class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
  """FisherBlock for convolutional layers using the basic KFC approx.

  Estimates the Fisher Information matrix's block for a convolutional
  layer.

  Consider a convolutional layer in this model with (unshared) filter matrix
  'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
  this FisherBlock estimates,

    F(w) = #locations * kronecker(E[flat(a) flat(a)^T],
                                  E[flat(ds) flat(ds)^T])

  where

    ds = (d / ds) log p(y | x, w)
    #locations = number of (x, y) locations where 'w' is applied.

  The expectation is taken over all examples and locations, and flat()
  concatenates an array's leading dimensions.

  See equation 23 in https://arxiv.org/abs/1602.01407 for details.
  """

  def __init__(self,
               layer_collection,
               params,
               padding,
               strides=None,
               dilation_rate=None,
               data_format=None,
               extract_patches_fn=None,
               sub_sample_inputs=None,
               sub_sample_patches=None,
               use_sua_approx_for_input_factor=False,
               patch_mask=None):
    """Creates a ConvKFCBasicFB block.

    Args:
      layer_collection: The LayerCollection object which owns this block.
      params: The parameters (Tensor or tuple of Tensors) of this layer. If
        kernel alone, a Tensor of shape [..spatial_filter_shape..,
        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
        containing the previous and a Tensor of shape [out_channels]. A tuple
        or list is taken to mean that a bias is present.
      padding: str. Padding method.
      strides: List of ints or None. Contains [..spatial_filter_strides..] if
        'extract_patches_fn' is compatible with tf.nn.convolution(), else
        [1, ..spatial_filter_strides, 1]. (Default: None)
      dilation_rate: List of ints or None. Rate for dilation along each spatial
        dimension if 'extract_patches_fn' is compatible with
        tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
        (Default: None)
      data_format: str or None. Format of input data. (Default: None)
      extract_patches_fn: str or None. Name of function that extracts image
        patches. One of "extract_convolution_patches", "extract_image_patches",
        "extract_pointwise_conv2d_patches". (Default: None)
      sub_sample_inputs: `bool` or None. If True, then subsample the inputs
        from which the image patches are extracted. (Default: None)
      sub_sample_patches: `bool` or None. If `True` then subsample the
        extracted patches. (Default: None)
      use_sua_approx_for_input_factor: `bool`. If `True` then use
        `ConvInputSUAKroneckerFactor` for input factor. Otherwise use
        `ConvInputKroneckerFactor`. (Default: False)
      patch_mask: Tensor of shape [kernel_height, kernel_width, in_channels]
        or None. If not None this is multiplied against the extracted patches
        Tensor (broadcasting along the batch dimension) before statistics are
        computed in the input factor. (Default: None)
    """
    has_bias = isinstance(params, (tuple, list))
    kernel = params[0] if has_bias else params

    self._has_bias = has_bias
    self._filter_shape = tuple(kernel.shape.as_list())

    # Lists are turned into tuples by maybe_tuple(); other values pass through.
    self._strides = maybe_tuple(strides)
    self._dilation_rate = maybe_tuple(dilation_rate)
    self._padding = padding
    self._data_format = data_format
    self._extract_patches_fn = extract_patches_fn

    self._use_sua_approx_for_input_factor = use_sua_approx_for_input_factor
    self._sub_sample_inputs = sub_sample_inputs
    self._sub_sample_patches = sub_sample_patches
    self._patch_mask = patch_mask

    super(ConvKFCBasicFB, self).__init__(layer_collection)

  def instantiate_factors(self, grads_list, damping):
    """Instantiate Kronecker Factors for this FisherBlock.

    Args:
      grads_list: Gradients with respect to the layer's outputs, in the format
        accepted by _process_data().
      damping: 0-D Tensor or float. Forwarded to _setup_damping() together
        with the number of convolution locations as the normalization.
    """
    inputs, grads_list = self._process_data(grads_list)

    # Infer number of locations upon which convolution is applied.
    self._num_locations = utils.num_conv_locations(
        inputs[0].shape.as_list(),
        list(self._filter_shape),
        self._strides,
        self._padding)

    # The two input-factor flavours take different constructor arguments.
    if self._use_sua_approx_for_input_factor:
      input_factor_cls = fisher_factors.ConvInputSUAKroneckerFactor
      input_factor_args = (inputs, self._filter_shape, self._has_bias)
    else:
      input_factor_cls = fisher_factors.ConvInputKroneckerFactor
      input_factor_args = (
          inputs, self._filter_shape, self._padding, self._strides,
          self._dilation_rate, self._data_format, self._extract_patches_fn,
          self._has_bias, self._sub_sample_inputs, self._sub_sample_patches,
          self._patch_mask)
    self._input_factor = self._layer_collection.make_or_get_factor(
        input_factor_cls, input_factor_args)

    self._output_factor = self._layer_collection.make_or_get_factor(
        fisher_factors.ConvOutputKroneckerFactor, (grads_list,))

    self._setup_damping(damping, normalization=self._num_locations)

  @property
  def _renorm_coeff(self):
    # The '#locations' multiplier from the class docstring's formula.
    return self._num_locations


class DepthwiseConvDiagonalFB(ConvDiagonalFB):
  """FisherBlock for depthwise_conv2d().

  Equivalent to ConvDiagonalFB applied to each input channel in isolation.
  """

  def __init__(self,
               layer_collection,
               params,
               strides,
               padding,
               rate=None,
               data_format=None):
    """Creates a DepthwiseConvDiagonalFB block.

    Args:
      layer_collection: The LayerCollection object which owns this block.
      params: Tensor of shape [filter_height, filter_width, in_channels,
        channel_multiplier].
      strides: List of 4 ints. Strides along all dimensions.
      padding: str. Padding method.
      rate: List of 2 ints or None. Rate for dilation along the two spatial
        dimensions. It is expanded to [1, rate[0], rate[1], 1] before being
        handed to the parent class. (Default: None)
      data_format: str or None. Format of input data. (Default: None)

    Raises:
      NotImplementedError: If parameters contains bias.
      ValueError: If filter is not 4-D.
      ValueError: If strides is not length-4.
      ValueError: If rates is not length-2.
      ValueError: If channels are not last dimension.
    """
    if isinstance(params, (tuple, list)):
      raise NotImplementedError("Bias not yet supported.")

    if params.shape.ndims != 4:
      raise ValueError("Filter must be 4-D.")

    if len(strides) != 4:
      raise ValueError("strides must account for 4 dimensions.")

    dilations = rate
    if rate is not None:
      if len(rate) != 2:
        raise ValueError("rate must only account for spatial dimensions.")
      # conv2d expects 4-element rate.
      dilations = [1, rate[0], rate[1], 1]

    if not utils.is_data_format_channel_last(data_format):
      raise ValueError("data_format must be channels-last.")

    super(DepthwiseConvDiagonalFB, self).__init__(
        layer_collection=layer_collection,
        params=params,
        strides=strides,
        padding=padding,
        dilations=dilations,
        data_format=data_format)

    # This is a hack: overwrite the filter shape after the parent __init__()
    # has run, replacing it with the shape of the equivalent conv2d filter.
    height, width, channels_in, multiplier = params.shape.as_list()
    self._filter_shape = (height, width, channels_in,
                          channels_in * multiplier)

  def _multiply_matrix(self, matrix, vector):
    """Runs the parent's _multiply_matrix() in conv2d-filter layout.

    Args:
      matrix: Forwarded untouched to the parent's _multiply_matrix().
      vector: Filter-shaped Tensor in depthwise_conv2d layout.

    Returns:
      The parent's result converted back to depthwise_conv2d layout.
    """
    as_conv2d = depthwise_conv2d_filter_to_conv2d_filter(vector)
    parent = super(DepthwiseConvDiagonalFB, self)
    product = parent._multiply_matrix(matrix, as_conv2d)
    return conv2d_filter_to_depthwise_conv2d_filter(product)


class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
  """FisherBlock for depthwise_conv2d().

  Equivalent to ConvKFCBasicFB applied to each input channel in isolation.
  """

  def __init__(self,
               layer_collection,
               params,
               strides,
               padding,
               rate=None,
               data_format=None):
    """Creates a DepthwiseConvKFCBasicFB block.

    Args:
      layer_collection: The LayerCollection object which owns this block.
      params: Tensor of shape [filter_height, filter_width, in_channels,
        channel_multiplier].
      strides: List of 4 ints. Strides along all dimensions.
      padding: str. Padding method.
      rate: List of 2 ints or None. Rate for dilation along the two spatial
        dimensions. It is expanded to [1, rate[0], rate[1], 1] before being
        handed to the parent class. (Default: None)
      data_format: str or None. Format of input data. (Default: None)

    Raises:
      NotImplementedError: If parameters contains bias.
      ValueError: If filter is not 4-D.
      ValueError: If strides is not length-4.
      ValueError: If rates is not length-2.
      ValueError: If channels are not last dimension.
    """
    if isinstance(params, (tuple, list)):
      raise NotImplementedError("Bias not yet supported.")

    if params.shape.ndims != 4:
      raise ValueError("Filter must be 4-D.")

    if len(strides) != 4:
      raise ValueError("strides must account for 4 dimensions.")

    dilation_rate = rate
    if rate is not None:
      if len(rate) != 2:
        raise ValueError("rate must only account for spatial dimensions.")
      # conv2d expects 4-element rate.
      dilation_rate = [1, rate[0], rate[1], 1]

    if not utils.is_data_format_channel_last(data_format):
      raise ValueError("data_format must be channels-last.")

    super(DepthwiseConvKFCBasicFB, self).__init__(
        layer_collection=layer_collection,
        params=params,
        padding=padding,
        strides=strides,
        dilation_rate=dilation_rate,
        data_format=data_format,
        extract_patches_fn="extract_image_patches")

    # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__():
    # the stored shape becomes that of the equivalent conv2d filter.
    height, width, channels_in, multiplier = params.shape.as_list()
    self._filter_shape = (height, width, channels_in,
                          channels_in * multiplier)

  def _multiply_factored_matrix(self, left_factor, right_factor, vector,
                                extra_scale=1.0, transpose_left=False,
                                transpose_right=False):
    """Runs the parent's _multiply_factored_matrix() in conv2d layout.

    Args:
      left_factor: Forwarded untouched to the parent implementation.
      right_factor: Forwarded untouched to the parent implementation.
      vector: Filter-shaped Tensor in depthwise_conv2d layout.
      extra_scale: Forwarded untouched. (Default: 1.0)
      transpose_left: Forwarded untouched. (Default: False)
      transpose_right: Forwarded untouched. (Default: False)

    Returns:
      The parent's result converted back to depthwise_conv2d layout.
    """
    as_conv2d = depthwise_conv2d_filter_to_conv2d_filter(vector)
    parent = super(DepthwiseConvKFCBasicFB, self)
    product = parent._multiply_factored_matrix(
        left_factor,
        right_factor,
        as_conv2d,
        extra_scale=extra_scale,
        transpose_left=transpose_left,
        transpose_right=transpose_right)
    return conv2d_filter_to_depthwise_conv2d_filter(product)


def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None):  # pylint: disable=redefined-builtin
  """Converts a convolution filter for use with conv2d.

  Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's
  compatible with tf.nn.conv2d(). Each input channel's slice of the filter is
  padded with zeros along the in_channel axis, so that it only touches its own
  channel, and the padded slices are then concatenated along the out_channel
  axis.

  Args:
    filter: Tensor of shape [height, width, in_channels, channel_multiplier].
      The shape must be statically known. Any dtype accepted by tf.zeros() and
      tf.concat() is supported.
    name: None or str. Name of Op.

  Returns:
    Tensor of shape [height, width, in_channels, out_channels], with the same
    dtype as 'filter', where out_channels = in_channels * channel_multiplier.

  """
  with tf.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter",
                     [filter]):
    filter = tf.convert_to_tensor(filter)
    filter_height, filter_width, in_channels, channel_multiplier = (
        filter.shape.as_list())

    results = []
    for i in range(in_channels):
      # Slice out one in_channel's filter. Insert zeros around it to force it
      # to affect that channel and that channel alone. The zeros must share
      # the filter's dtype, since tf.concat() requires matching dtypes.
      elements = []
      if i > 0:
        elements.append(
            tf.zeros([filter_height, filter_width, i, channel_multiplier],
                     dtype=filter.dtype))
      elements.append(filter[:, :, i:(i + 1), :])
      if i + 1 < in_channels:
        elements.append(
            tf.zeros([
                filter_height, filter_width, in_channels - (i + 1),
                channel_multiplier
            ], dtype=filter.dtype))

      # Concat along in_channel.
      results.append(tf.concat(elements, axis=-2, name="in_channel_%d" % i))

    # Concat along out_channel.
    return tf.concat(results, axis=-1, name="out_channel")


def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None):  # pylint: disable=redefined-builtin
  """Converts a convolution filter for use with depthwise_conv2d.

  Transforms a filter for use with tf.nn.conv2d() to one that's
  compatible with tf.nn.depthwise_conv2d(). Only the filters along the
  diagonal (input channel i paired with its own group of output channels) are
  kept; everything else is ignored.

  Args:
    filter: Tensor of shape [height, width, in_channels, out_channels].
    name: None or str. Name of Op.

  Returns:
    Tensor of shape,
      [height, width, in_channels, channel_multiplier]
    where channel_multiplier = out_channels // in_channels.

  Raises:
    ValueError: if out_channels is not evenly divisible by in_channels.
  """
  with tf.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter",
                     [filter]):
    filter = tf.convert_to_tensor(filter)
    height, width, in_channels, out_channels = filter.shape.as_list()

    if out_channels % in_channels != 0:
      raise ValueError("out_channels must be evenly divisible by in_channels.")
    channel_multiplier = out_channels // in_channels

    # Split the out_channel axis into (in_channels, channel_multiplier) so the
    # diagonal entries can be addressed as [:, :, i, i, :].
    blocked = tf.reshape(
        filter,
        [height, width, in_channels, in_channels, channel_multiplier])

    # Slice out the output corresponding to the correct filter for each input
    # channel, restoring a singleton in_channel axis on each slice.
    diagonal_slices = [
        tf.reshape(blocked[:, :, i, i, :],
                   [height, width, 1, channel_multiplier])
        for i in range(in_channels)
    ]

    # Concat along in_channel.
    return tf.concat(diagonal_slices, axis=-2, name="in_channels")


def maybe_tuple(obj):
  """Returns tuple(obj) if 'obj' is a list, otherwise 'obj' unchanged."""
  return tuple(obj) if isinstance(obj, list) else obj


class InputOutputMultiTowerMultiUse(InputOutputMultiTower):
  """Adds methods for multi-use/time-step case to InputOutputMultiTower."""

  def __init__(self, num_uses=None, *args, **kwargs):
    """Stores 'num_uses' and forwards everything else to the parent class.

    Args:
      num_uses: int or None. Number of uses/time-steps. May be left as None if
        it can be inferred later from list-formatted data in _process_data().
        (Default: None)
      *args: Positional arguments for the parent constructor.
      **kwargs: Keyword arguments for the parent constructor.
    """
    self._num_uses = num_uses
    super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs)

  def _process_data(self, grads_list):
    """Process temporal/multi-use data into the format used by the factors.

    Converts self._inputs and 'grads_list' into one of the layouts expected by
    the FisherFactor classes, as selected by the global configuration variable
    TOWER_STRATEGY.

    Two incoming layouts are accepted:

    1. Lists of lists. self._inputs is a list (over towers) of lists (over
       uses/time-steps) of Tensors, and grads_list is a list over sources of
       such lists of lists.

    2. Folded Tensors. self._inputs is a list (over towers) of Tensors in which
       uses/time-steps are folded into the batch dimension, i.e. Tensors of
       shape [num_uses * batch_size, ...] that are reshapes of Tensors of shape
       [num_uses, batch_size, ...]. grads_list is a list over sources of such
       lists of Tensors. Data in this layout is only converted to tuples.

    For data arriving in layout 1 the result depends on TOWER_STRATEGY:

    * "concat": 'inputs' becomes a singleton tuple holding one
      PartitionedTensor with the data of all towers and all uses/time-steps
      concatenated. The leading dimension folds use/time-step and batch
      together, with 'use' as the major index, so the tensor can be thought of
      as a reshape of one of shape [num_uses, batch_size, ...]. grads_list
      becomes a tuple over sources of such singleton tuples.

    * "separate": 'inputs' becomes a tuple over towers of PartitionedTensors,
      each laid out like the "concat" tensor but holding only one tower's
      data. grads_list becomes a tuple over sources of such tuples.

    As a side effect, self._num_uses is set from the data when the data is in
    layout 1.

    Args:
      grads_list: grads_list in its initial format (see above).

    Returns:
      inputs: self._inputs transformed into the appropriate format (see
        above).
      grads_list: grads_list transformed into the appropriate format (see
        above).

    Raises:
      ValueError: If TOWER_STRATEGY is not one of "separate" or "concat" (only
        checked for data in layout 1).
      ValueError: If list-formatted data disagrees with self._num_uses, or has
        an inconsistent number of uses across towers.
      ValueError: If self._num_uses is still None after processing.
    """
    inputs = self._inputs

    if not isinstance(inputs[0], (list, tuple)):
      # The second data format: nothing to merge.
      inputs = tuple(inputs)
    else:
      # The first data format.
      uses_in_inputs = len(inputs[0])

      if self._num_uses is not None and self._num_uses != uses_in_inputs:
        raise ValueError("num_uses argument doesn't match length of inputs.")
      self._num_uses = uses_in_inputs

      # Check that all mini-batches/towers have the same number of uses
      if any(len(tower_inputs) != uses_in_inputs for tower_inputs in inputs):
        raise ValueError("Length of inputs argument is inconsistent across "
                         "towers.")

      if fisher_factors.TOWER_STRATEGY == "concat":
        # Swap the tower and use/time-step indices (use first, tower second)
        # and flatten the two into a single list.
        use_major = nest.flatten(tuple(zip(*inputs)))

        # Merge everything together into a PartitionedTensor. We package it in
        # a singleton tuple since the factors will expect a list over towers
        inputs = (utils.PartitionedTensor(use_major),)

      elif fisher_factors.TOWER_STRATEGY == "separate":
        # Merge together the uses/time-step dimension into PartitionedTensors,
        # but keep the leading dimension (towers) intact for the factors to
        # process individually.
        inputs = tuple(
            utils.PartitionedTensor(tower_inputs) for tower_inputs in inputs)

      else:
        raise ValueError("Global config variable TOWER_STRATEGY must be one of "
                         "'concat' or 'separate'.")

    # Now we perform the analogous processing for grads_list

    if not isinstance(grads_list[0][0], (list, tuple)):
      # The second data format.
      grads_list = tuple(tuple(source_grads) for source_grads in grads_list)
    else:
      # The first data format.
      uses_in_grads = len(grads_list[0][0])

      if self._num_uses is not None and self._num_uses != uses_in_grads:
        raise ValueError("num_uses argument doesn't match length of outputs, "
                         "or length of outputs is inconsistent with length of "
                         "inputs.")
      self._num_uses = uses_in_grads

      if any(len(tower_grads) != uses_in_grads
             for source_grads in grads_list
             for tower_grads in source_grads):
        raise ValueError("Length of outputs argument is inconsistent across "
                         "towers.")

      if fisher_factors.TOWER_STRATEGY == "concat":
        # Swap the tower and use/time-step indices (use first, tower second)
        # within every source.
        grads_list = tuple(
            tuple(zip(*source_grads)) for source_grads in grads_list)

        # Flatten the two dimensions, leaving the leading dimension (source)
        # intact
        grads_list = tuple(
            nest.flatten(source_grads) for source_grads in grads_list)

        # Merge inner dimensions together into PartitionedTensors. We package
        # them in a singleton tuple since the factors will expect a list over
        # towers
        grads_list = tuple(
            (utils.PartitionedTensor(source_grads),)
            for source_grads in grads_list)

      elif fisher_factors.TOWER_STRATEGY == "separate":
        # Merge together the uses/time-step dimension into PartitionedTensors,
        # but keep the leading dimension (towers) intact for the factors to
        # process individually.
        grads_list = tuple(
            tuple(utils.PartitionedTensor(tower_grads)
                  for tower_grads in source_grads)
            for source_grads in grads_list)

      else:
        raise ValueError("Global config variable TOWER_STRATEGY must be one of "
                         "'concat' or 'separate'.")

    if self._num_uses is None:
      raise ValueError("You must supply a value for the num_uses argument if "
                       "the number of uses cannot be inferred from inputs or "
                       "outputs arguments (e.g. if they are both given in the "
                       "single Tensor format, instead of as lists of Tensors.")

    return inputs, grads_list


class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse,
                                 KroneckerProductFB):
  """FisherBlock for fully-connected layers that share parameters.

  This class implements the "independence across time" approximation from the
  following paper:
    https://openreview.net/pdf?id=HyMTkQZAb
  """

  def __init__(self, layer_collection, has_bias=False, num_uses=None,
               diagonal_approx_for_input=False,
               diagonal_approx_for_output=False):
    """Creates a FullyConnectedMultiIndepFB block.

    Args:
      layer_collection: LayerCollection instance.
      has_bias: bool. If True, estimates Fisher with respect to a bias
        parameter as well as the layer's weights. (Default: False)
      num_uses: int or None. Number of uses of the layer in the model's graph.
        Only required if the data is formatted with uses/time folded into the
        batch dimension (instead of uses/time being a list dimension).
        (Default: None)
      diagonal_approx_for_input: bool. If True, the input Kronecker factor is
        a DiagonalMultiKF instead of a FullyConnectedMultiKF. (Default: False)
      diagonal_approx_for_output: bool. If True, the output Kronecker factor
        is a DiagonalMultiKF instead of a FullyConnectedMultiKF.
        (Default: False)
    """
    self._diagonal_approx_for_output = diagonal_approx_for_output
    self._diagonal_approx_for_input = diagonal_approx_for_input
    self._has_bias = has_bias

    super(FullyConnectedMultiIndepFB, self).__init__(
        layer_collection=layer_collection,
        num_uses=num_uses)

  def instantiate_factors(self, grads_list, damping):
    """Instantiate Kronecker Factors for this FisherBlock.

    Args:
      grads_list: Gradients with respect to the layer's outputs, in one of the
        formats accepted by _process_data().
      damping: 0-D Tensor or float. Forwarded to _setup_damping() together
        with the number of uses as the normalization.
    """
    inputs, grads_list = self._process_data(grads_list)
    make_factor = self._layer_collection.make_or_get_factor

    # Input side: diagonal or full factor, same constructor arguments.
    if self._diagonal_approx_for_input:
      input_factor_cls = fisher_factors.DiagonalMultiKF
    else:
      input_factor_cls = fisher_factors.FullyConnectedMultiKF
    self._input_factor = make_factor(
        input_factor_cls, ((inputs,), self._num_uses, self._has_bias))

    # Output side: diagonal or full factor, built from the gradients.
    if self._diagonal_approx_for_output:
      output_factor_cls = fisher_factors.DiagonalMultiKF
    else:
      output_factor_cls = fisher_factors.FullyConnectedMultiKF
    self._output_factor = make_factor(
        output_factor_cls, (grads_list, self._num_uses))

    self._setup_damping(damping, normalization=self._num_uses)

  @property
  def _renorm_coeff(self):
    # The number of uses, as a float.
    return float(self._num_uses)


class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse,
                               KroneckerProductFB):
  """FisherBlock for 2D convolutional layers using the basic KFC approx.

  Similar to ConvKFCBasicFB except that this version supports multiple
  uses/time-steps via a standard independence approximation.  Similar to the
  "independence across time" used in FullyConnectedMultiIndepFB but generalized
  in the obvious way to conv layers.
  """

  def __init__(self,
               layer_collection,
               params,
               padding,
               strides=None,
               dilation_rate=None,
               data_format=None,
               extract_patches_fn=None,
               num_uses=None):
    """Creates a ConvKFCBasicMultiIndepFB block.

    Args:
      layer_collection: The LayerCollection object which owns this block.
      params: The parameters (Tensor or tuple of Tensors) of this layer. If
        kernel alone, a Tensor of shape [..spatial_filter_shape..,
        in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
        containing the previous and a Tensor of shape [out_channels]. A tuple
        or list is taken to mean that a bias is present.
      padding: str. Padding method.
      strides: List of ints or None. Contains [..spatial_filter_strides..] if
        'extract_patches_fn' is compatible with tf.nn.convolution(), else
        [1, ..spatial_filter_strides, 1]. (Default: None)
      dilation_rate: List of ints or None. Rate for dilation along each spatial
        dimension if 'extract_patches_fn' is compatible with
        tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
        (Default: None)
      data_format: str or None. Format of input data. (Default: None)
      extract_patches_fn: str or None. Name of function that extracts image
        patches. One of "extract_convolution_patches", "extract_image_patches",
        "extract_pointwise_conv2d_patches". (Default: None)
      num_uses: int or None. Number of uses of the layer in the model's graph.
        Only required if the data is formatted with uses/time folded into the
        batch dimension (instead of uses/time being a list dimension).
        (Default: None)
    """
    self._padding = padding
    # Lists are turned into tuples by maybe_tuple(); other values pass through.
    self._strides = maybe_tuple(strides)
    self._dilation_rate = maybe_tuple(dilation_rate)
    self._data_format = data_format
    self._extract_patches_fn = extract_patches_fn
    self._has_bias = isinstance(params, (tuple, list))

    fltr = params[0] if self._has_bias else params
    self._filter_shape = tuple(fltr.shape.as_list())

    super(ConvKFCBasicMultiIndepFB, self).__init__(
        layer_collection=layer_collection,
        num_uses=num_uses)

  def instantiate_factors(self, grads_list, damping):
    """Instantiate Kronecker Factors for this FisherBlock.

    Args:
      grads_list: Gradients with respect to the layer's outputs, in one of the
        formats accepted by _process_data().
      damping: 0-D Tensor or float. Forwarded to _setup_damping() together
        with num_locations * num_uses as the normalization.
    """
    inputs, grads_list = self._process_data(grads_list)

    # Infer number of locations upon which convolution is applied.
    self._num_locations = utils.num_conv_locations(inputs[0].shape.as_list(),
                                                   list(self._filter_shape),
                                                   self._strides,
                                                   self._padding)

    # num_uses is supplied exactly once, directly after padding. A second,
    # trailing copy would be bound to whichever optional constructor parameter
    # follows has_bias.
    # NOTE(review): confirm the positional order against
    # ConvInputMultiKF.__init__.
    self._input_factor = self._layer_collection.make_or_get_factor(
        fisher_factors.ConvInputMultiKF,
        (inputs, self._filter_shape, self._padding, self._num_uses,
         self._strides, self._dilation_rate, self._data_format,
         self._extract_patches_fn, self._has_bias))
    self._output_factor = self._layer_collection.make_or_get_factor(
        fisher_factors.ConvOutputMultiKF, (grads_list, self._num_uses,
                                           self._data_format))

    self._setup_damping(damping,
                        normalization=(self._num_locations * self._num_uses))

  @property
  def _renorm_coeff(self):
    # Number of conv locations times number of uses.
    return self._num_locations * self._num_uses


class SeriesFBApproximation(object):
  """Selects the approximation used by FullyConnectedSeriesFB.

  See FullyConnectedSeriesFB.__init__ for description and usage.
  """
  # Integer tags for the two supported options.
  option1, option2 = 1, 2


class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
                             KroneckerProductFB):
  """FisherBlock for fully-connected layers that share parameters across time.

  This class implements the "Option 1" and "Option 2" approximation from the
  following paper:
    https://openreview.net/pdf?id=HyMTkQZAb

  See the end of the appendix of the paper for a pseudo-code of the
  algorithm being implemented by multiply_matpower here.  Note that we are
  using pre-computed versions of certain matrix-matrix products to speed
  things up.  This is explicitly explained wherever it is done.

  Only multiplication by the inverse (exp == -1) is implemented. The
  Cholesky-based multiplications raise NotImplementedError.
  """

  def __init__(self,
               layer_collection,
               has_bias=False,
               num_uses=None,
               option=SeriesFBApproximation.option2):
    """Constructs a new `FullyConnectedSeriesFB`.

    Args:
      layer_collection: The collection of all layers in the K-FAC approximate
        Fisher information matrix to which this FisherBlock belongs.
      has_bias: bool. If True, estimates Fisher with respect to a bias
        parameter as well as the layer's weights.
      num_uses: int or None. Number of time-steps over which the layer
        is used. Only required if the data is formatted with time folded into
        the batch dimension (instead of time being a list dimension).
        (Default: None)
      option: A `SeriesFBApproximation` specifying the simplifying assumption
        to be used in this block. `option1` approximates the cross-covariance
        over time as a symmetric matrix, while `option2` makes
        the assumption that training sequences are infinitely long. See section
        3.5 of the paper for more details.
    """

    self._has_bias = has_bias
    self._option = option

    super(FullyConnectedSeriesFB, self).__init__(
        layer_collection=layer_collection,
        num_uses=num_uses)

  @property
  def _num_timesteps(self):
    """Number of time-steps; an alias for `self._num_uses`."""
    return self._num_uses

  @property
  def _renorm_coeff(self):
    """Always raises: this block does not use a renormalization coefficient.

    Raises:
      AssertionError: on every access.
    """
    # This should no longer be used since the multiply_X functions from the base
    # class have been overridden. The error is raised explicitly (instead of
    # via `assert False`) so that it is not stripped when running under -O.
    raise AssertionError(
        "FullyConnectedSeriesFB._renorm_coeff should never be accessed.")

  def instantiate_factors(self, grads_list, damping):
    """Creates (or fetches) the input/output factors and sets up damping.

    Both factors are `fisher_factors.FullyConnectedMultiKF` instances on which
    `register_cov_dt1()` is called.

    Args:
      grads_list: Handed to `self._process_data`; the processed gradients it
        returns are used as the data for the output factor.
      damping: Forwarded to `self._setup_damping`, with `self._num_uses` as
        the normalization.
    """
    inputs, grads_list = self._process_data(grads_list)

    self._input_factor = self._layer_collection.make_or_get_factor(
        fisher_factors.FullyConnectedMultiKF,
        ((inputs,), self._num_uses, self._has_bias))
    self._input_factor.register_cov_dt1()

    self._output_factor = self._layer_collection.make_or_get_factor(
        fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
    self._output_factor.register_cov_dt1()

    self._setup_damping(damping, normalization=self._num_uses)

  def register_matpower(self, exp):
    """Registers the factor quantities needed by `multiply_matpower`.

    Args:
      exp: The matrix power. Only -1 (the inverse) is supported.

    Raises:
      NotImplementedError: if `exp` is not -1.
      ValueError: if the `option` given at construction is not recognized.
    """
    if exp != -1:
      raise NotImplementedError("FullyConnectedSeriesFB only supports inverse "
                                "multiplications.")

    if self._option == SeriesFBApproximation.option1:
      self._input_factor.register_option1quants(self._input_damping_func)
      self._output_factor.register_option1quants(self._output_damping_func)
    elif self._option == SeriesFBApproximation.option2:
      self._input_factor.register_option2quants(self._input_damping_func)
      self._output_factor.register_option2quants(self._output_damping_func)
    else:
      raise ValueError(
          "Unrecognized FullyConnectedSeriesFB approximation: {}".format(
              self._option))

  def multiply_matpower(self, vector, exp):
    """Multiplies `vector` by the inverse of this block's approximate Fisher.

    Args:
      vector: The layer parameters to multiply. Converted to a 2D matrix with
        `utils.layer_params_to_mat2d` and converted back with
        `utils.mat2d_to_layer_params` before being returned.
      exp: The matrix power. Only -1 (the inverse) is supported.

    Returns:
      The result of the multiplication, in the same structure as `vector`.

    Raises:
      NotImplementedError: if `exp` is not -1.
      ValueError: if the `option` given at construction is not recognized.
    """
    if exp != -1:
      raise NotImplementedError("FullyConnectedSeriesFB only supports inverse "
                                "multiplications.")

    # pylint: disable=invalid-name

    Z = utils.layer_params_to_mat2d(vector)

    # Derivations were done for "batch_dim==1" case so we need to convert to
    # that orientation:
    Z = tf.transpose(Z)

    if self._option == SeriesFBApproximation.option1:

      # Note that L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.
      L_A, psi_A = self._input_factor.get_option1quants(
          self._input_damping_func)
      L_G, psi_G = self._output_factor.get_option1quants(
          self._output_damping_func)

      def gamma(x):
        # We are assuming that each case has the same number of time-steps.
        # If this stops being the case one shouldn't simply replace this T
        # with its average value.  Instead, one needs to go back to the
        # definition of the gamma function from the paper.
        T = self._num_timesteps
        return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T))

      # Y = \gamma( psi_G*psi_A^T ) (computed element-wise)
      # Even though Y is Z-independent we are recomputing it from the psi's
      # each time since Y depends on both A and G quantities, and it is
      # relatively cheap to compute.
      Y = gamma(tf.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A)

      # Z = L_G^T * Z * L_A
      # This is equivalent to the following computation from the original
      # pseudo-code:
      # Z = G0^{-1/2} * Z * A0^{-1/2}
      # Z = U_G^T * Z * U_A
      Z = tf.matmul(L_G, tf.matmul(Z, L_A), transpose_a=True)

      # Z = Z .* Y
      Z *= Y

      # Z = L_G * Z * L_A^T
      # This is equivalent to the following computation from the original
      # pseudo-code:
      # Z = U_G * Z * U_A^T
      # Z = G0^{-1/2} * Z * A0^{-1/2}
      Z = tf.matmul(L_G, tf.matmul(Z, L_A, transpose_b=True))

    elif self._option == SeriesFBApproximation.option2:

      # Note that P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1},
      # and K_A = A_0^{-1/2} * E_A and K_G = G_0^{-1/2} * E_G.
      P_A, K_A, mu_A = self._input_factor.get_option2quants(
          self._input_damping_func)
      P_G, K_G, mu_G = self._output_factor.get_option2quants(
          self._output_damping_func)

      # Our approach differs superficially from the pseudo-code in the paper
      # in order to reduce the total number of matrix-matrix multiplies.
      # In particular, the first three computations in the pseudo code are
      # Z = G0^{-1/2} * Z * A0^{-1/2}
      # Z = Z - hPsi_G^T * Z * hPsi_A
      # Z = E_G^T * Z * E_A
      # Noting that hPsi = C0^{-1/2} * C1 * C0^{-1/2}, so that
      # C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}
      # the entire computation can be written as
      # Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}
      #     - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A
      #   = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}
      #     - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A
      #   = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A
      #     -  E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A
      #   = K_G^T * Z * K_A  -  K_G^T * P_G * Z * P_A^T * K_A
      # This final expression is computed by the following two lines:
      # Z = Z - P_G * Z * P_A^T
      Z -= tf.matmul(P_G, tf.matmul(Z, P_A, transpose_b=True))
      # Z = K_G^T * Z * K_A
      Z = tf.matmul(K_G, tf.matmul(Z, K_A), transpose_a=True)

      # Z = Z ./ (1*1^T - mu_G*mu_A^T)
      # Be careful with the outer product.  We don't want to accidentally
      # make it an inner-product instead.
      tmp = 1.0 - tf.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A
      # Prevent some numerical issues by setting any 0.0 eigs to 1.0
      tmp += 1.0 * tf.cast(tf.equal(tmp, 0.0), dtype=tmp.dtype)
      Z /= tmp

      # We now perform the transpose/reverse version of the operations
      # derived above, whose derivation from the original pseudo-code is
      # analogous.
      # Z = K_G * Z * K_A^T
      Z = tf.matmul(K_G, tf.matmul(Z, K_A, transpose_b=True))

      # Z = Z - P_G^T * Z * P_A
      Z -= tf.matmul(P_G, tf.matmul(Z, P_A), transpose_a=True)

      # Z = normalize (1/E[T]) * Z
      # Note that this normalization is done because we compute the statistics
      # by averaging, not summing, over time. (And the gradient is presumably
      # summed over time, not averaged, and thus their scales are different.)
      Z /= tf.cast(self._num_timesteps, Z.dtype)

    else:
      # Mirror register_matpower: an unknown option must not silently fall
      # through and hand back the input unchanged.
      raise ValueError(
          "Unrecognized FullyConnectedSeriesFB approximation: {}".format(
              self._option))

    # Convert back to the "batch_dim==0" orientation.
    Z = tf.transpose(Z)

    return utils.mat2d_to_layer_params(vector, Z)

    # pylint: enable=invalid-name

  def multiply_cholesky(self, vector):
    """Not supported by this block; always raises NotImplementedError."""
    raise NotImplementedError("FullyConnectedSeriesFB does not support "
                              "Cholesky computations.")

  def multiply_cholesky_inverse(self, vector):
    """Not supported by this block; always raises NotImplementedError."""
    raise NotImplementedError("FullyConnectedSeriesFB does not support "
                              "Cholesky computations.")