from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import contextlib

import numpy as np
import six

from ops import utils
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import moving_averages

# Whether to initialize covariance estimators at a zero matrix (or the identity
# matrix).
INIT_COVARIANCES_AT_ZERO = False

# Whether to zero-debias the moving averages.
ZERO_DEBIAS = False

# When the number of damped inverses registered with a FisherFactor is at least
# this value (i.e. `num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD`), the
# inverses are computed from one shared eigenvalue decomposition instead of one
# independent inverse per damping value. (The eigenvalue decomposition is also
# used whenever a matrix power or the decomposition itself has been registered.)
EIGENVALUE_DECOMPOSITION_THRESHOLD = 2

# Numerical eigenvalues computed from covariance matrix estimates are clipped to
# be at least as large as this value before they are used to compute inverses or
# matrix powers. Must be nonnegative.
EIGENVALUE_CLIPPING_THRESHOLD = 0.0


@contextlib.contextmanager
def _maybe_colocate_with(op, colocate_cov_ops_with_inputs):
    """Context to colocate with `op` if `colocate_cov_ops_with_inputs`."""
    if colocate_cov_ops_with_inputs:
        if isinstance(op, (list, tuple)):
            with tf_ops.colocate_with(op[0]):
                yield
        else:
            with tf_ops.colocate_with(op):
                yield
    else:
        yield


def set_global_constants(init_covariances_at_zero=None,
                         zero_debias=None,
                         eigenvalue_decomposition_threshold=None,
                         eigenvalue_clipping_threshold=None):
    """Overrides the module-level configuration constants.

    Any argument left as None leaves the corresponding constant untouched.

    Args:
      init_covariances_at_zero: New value for INIT_COVARIANCES_AT_ZERO.
      zero_debias: New value for ZERO_DEBIAS.
      eigenvalue_decomposition_threshold: New value for
        EIGENVALUE_DECOMPOSITION_THRESHOLD.
      eigenvalue_clipping_threshold: New value for
        EIGENVALUE_CLIPPING_THRESHOLD.
    """
    global INIT_COVARIANCES_AT_ZERO, ZERO_DEBIAS
    global EIGENVALUE_DECOMPOSITION_THRESHOLD, EIGENVALUE_CLIPPING_THRESHOLD

    if init_covariances_at_zero is not None:
        INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero

    if zero_debias is not None:
        ZERO_DEBIAS = zero_debias

    if eigenvalue_decomposition_threshold is not None:
        EIGENVALUE_DECOMPOSITION_THRESHOLD = (
            eigenvalue_decomposition_threshold)

    if eigenvalue_clipping_threshold is not None:
        EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold


def inverse_initializer(shape, dtype, partition_info=None):  # pylint: disable=unused-argument
    """Variable initializer returning a shape[0] x shape[0] identity matrix.

    Args:
      shape: Shape requested for the variable; only shape[0] is read.
      dtype: dtype of the returned Tensor.
      partition_info: Unused; accepted to match the initializer signature.

    Returns:
      A square Tensor with ones on the diagonal and zeros elsewhere.
    """
    unit_diagonal = array_ops.ones(shape[0], dtype)
    return array_ops.diag(unit_diagonal)


def covariance_initializer(shape, dtype, partition_info=None):  # pylint: disable=unused-argument
    """Variable initializer for full (square) covariance estimates.

    Args:
      shape: Shape requested for the variable; only shape[0] is read.
      dtype: dtype of the returned Tensor.
      partition_info: Unused; accepted to match the initializer signature.

    Returns:
      A shape[0] x shape[0] diagonal matrix: all zeros when
      INIT_COVARIANCES_AT_ZERO is set, otherwise the identity.
    """
    make_diagonal = array_ops.zeros if INIT_COVARIANCES_AT_ZERO else array_ops.ones
    return array_ops.diag(make_diagonal(shape[0], dtype))


def diagonal_covariance_initializer(shape, dtype, partition_info=None):  # pylint: disable=unused-argument
    """Variable initializer for diagonal covariance estimates.

    Unlike covariance_initializer, the whole requested shape is filled (the
    variable stores only diagonal entries, not a square matrix).

    Args:
      shape: Full shape of the variable.
      dtype: dtype of the returned Tensor.
      partition_info: Unused; defaults to None like the other initializers in
        this module so the function can also be called without it.

    Returns:
      A Tensor of `shape` filled with zeros when INIT_COVARIANCES_AT_ZERO is
      set, otherwise with ones.
    """
    if INIT_COVARIANCES_AT_ZERO:
        return array_ops.zeros(shape, dtype)
    return array_ops.ones(shape, dtype)


def _compute_cov(tensor, tensor_right=None, normalizer=None):
    """Computes the empirical second moment of the rows of a 2D Tensor.

    Intended for random matrices whose true row mean is zero, in which case
    the true second moment coincides with the true covariance.

    Args:
      tensor: A 2D Tensor.
      tensor_right: Optional 2D Tensor. When given, the (normalized) product
        tensor^T * tensor_right is returned instead of tensor^T * tensor.
      normalizer: Optional scalar divisor. Defaults to the number of rows of
        `tensor`.

    Returns:
      A 2D Tensor. Without `tensor_right` it is square (one row/col per input
      column) and exactly symmetric; with `tensor_right` it has one column per
      column of `tensor_right`.
    """
    if normalizer is None:
        normalizer = array_ops.shape(tensor)[0]

    if tensor_right is not None:
        cross_moment = math_ops.matmul(tensor, tensor_right, transpose_a=True)
        return cross_moment / math_ops.cast(normalizer, tensor.dtype)

    second_moment = math_ops.matmul(tensor, tensor, transpose_a=True)
    second_moment /= math_ops.cast(normalizer, tensor.dtype)
    # Average with the transpose so round-off in the matmul cannot leave the
    # result slightly asymmetric.
    symmetrized = second_moment + array_ops.transpose(second_moment)
    return symmetrized / math_ops.cast(2.0, second_moment.dtype)


def _append_homog(tensor):
    """Appends a homogeneous coordinate to the last dimension of a Tensor.

    Args:
      tensor: A Tensor of rank >= 1. Its static rank does not need to be known.

    Returns:
      A Tensor identical to the input but one larger in the last dimension. The
      new entries are filled with ones.
    """
    # Dynamic shape of `tensor` with the last dimension replaced by 1.
    ones_shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0)
    ones = array_ops.ones(ones_shape, dtype=tensor.dtype)
    # A negative axis addresses the last dimension without requiring the
    # static rank (the previous `len(tensor.shape.as_list()) - 1` raised for
    # tensors of unknown rank).
    return array_ops.concat([tensor, ones], axis=-1)


def scope_string_from_params(params):
    """Builds a variable scope string name from the given parameters.

    Supported parameter types:
      * tensors and variables
      * booleans
      * ints
      * strings
      * depth-1 tuples/lists of ints
      * tuples/lists (any depth) of tensors

    The per-parameter pieces are joined with "_". Tuples/lists made up only of
    ints are joined with "-".

    Args:
      params: A parameter or a list/tuple of parameters.

    Returns:
      A string to use for the variable scope.

    Raises:
      ValueError: if params includes an unsupported type.
    """
    if not isinstance(params, (tuple, list)):
        params = (params,)

    def _name_part(param):
        # Translates a single parameter into its piece of the scope name.
        if isinstance(param, (tuple, list)):
            if all(isinstance(item, int) for item in param):
                return "-".join(str(item) for item in param)
            return scope_string_from_name(param)
        if isinstance(param, (str, int, bool)):
            return str(param)
        if isinstance(param, (tf_ops.Tensor, variables.Variable)):
            return scope_string_from_name(param)
        raise ValueError("Encountered an unsupported param type {}".format(
            type(param)))

    return "_".join(_name_part(param) for param in params)


def scope_string_from_name(tensor):
    """Turns an object's `.name` (or a nested list/tuple of them) into a scope string.

    The output-index suffix (everything from the first ":") is dropped and "/"
    is replaced by "_". List/tuple elements are converted recursively and
    joined with "__".

    Example: "gradients/add_4_grad/Reshape:0" -> "gradients_add_4_grad_Reshape"
    """
    if not isinstance(tensor, (tuple, list)):
        base_name, _, _ = tensor.name.partition(":")
        return base_name.replace("/", "_")
    return "__".join(scope_string_from_name(item) for item in tensor)


def scalar_or_tensor_to_string(val):
    """Returns repr(val) for scalars, otherwise the scope-string of its name."""
    if np.isscalar(val):
        return repr(val)
    return scope_string_from_name(val)


@six.add_metaclass(abc.ABCMeta)
class FisherFactor(object):
    """Base class for objects modeling factors of approximate Fisher blocks.

    For blocks that are not built from an approximation, a "factor" can be the
    entire block itself, as is the case for the diagonal and full
    representations.

    Subclasses must implement `_compute_new_cov` and `make_inverse_update_ops`,
    as well as the `_var_scope`, `_cov_shape`, `_num_sources` and `_dtype`
    properties. The constructor creates the covariance Variable right away, so
    any state those properties read must be set before this `__init__` runs.
    """

    def __init__(self):
        self.instantiate_covariance()

    @abc.abstractproperty
    def _var_scope(self):
        """Name of the variable scope holding this factor's Variables."""

    @abc.abstractproperty
    def _cov_shape(self):
        """The shape of the cov matrix."""

    @abc.abstractproperty
    def _num_sources(self):
        """The number of things to sum over when computing cov.

        make_covariance_update_op calls _compute_new_cov once for each index in
        range(_num_sources) and adds up the results. Typically a factor sums
        the statistics it computes over several backpropped "gradients"
        (usually passed in via "tensors" or "outputs_grads" arguments).
        """

    @abc.abstractproperty
    def _dtype(self):
        """dtype of the covariance Variable."""

    @property
    def _cov_initializer(self):
        """Initializer used when creating the covariance Variable."""
        return covariance_initializer

    def instantiate_covariance(self):
        """Creates the non-trainable covariance Variable and stores it as _cov."""
        with variable_scope.variable_scope(self._var_scope):
            cov_var = variable_scope.get_variable(
                "cov",
                initializer=self._cov_initializer,
                shape=self._cov_shape,
                trainable=False,
                dtype=self._dtype)
        self._cov = cov_var

    @abc.abstractmethod
    def _compute_new_cov(self, idx=0):
        """Returns the fresh covariance contribution of source `idx`."""

    def make_covariance_update_op(self, ema_decay):
        """Builds the op folding a fresh covariance estimate into the EMA.

        Args:
          ema_decay: The exponential moving average decay (float or Tensor).

        Returns:
          An Op for updating the covariance Variable referenced by _cov.
        """
        source_covs = tuple(
            self._compute_new_cov(source_idx)
            for source_idx in range(self._num_sources))
        summed_cov = math_ops.add_n(source_covs)
        return moving_averages.assign_moving_average(
            self._cov, summed_cov, ema_decay, zero_debias=ZERO_DEBIAS)

    @abc.abstractmethod
    def make_inverse_update_ops(self):
        """Create and return update ops corresponding to registered computations."""

    def get_cov(self):
        """Returns the covariance Variable."""
        return self._cov


class InverseProvidingFactor(FisherFactor):
    """Base class for FisherFactors that maintain inverses, powers, etc of _cov.

    Assumes that the _cov property is a square PSD matrix.
    Subclasses must implement the _compute_new_cov method, and the _var_scope,
    _cov_shape, _num_sources and _dtype properties.
    """

    # TODO(b/69108481): This class (and its subclasses) should be refactored to
    # serve the matrix quantities it computes as both (potentially stale)
    # variables, updated by the inverse update ops, and fresh values stored in
    # tensors that are recomputed once every session.run() call.  Currently
    # matpower and damp_inverse have the former behavior, while
    # eigendecomposition has the latter.

    def __init__(self):
        # damping -> Variable holding the damped inverse.
        self._inverses_by_damping = {}
        # (exp, damping) -> Variable holding the damped matrix power.
        self._matpower_by_exp_and_damping = {}
        # (clipped_eigenvalues, eigenvectors) once registered, else None.
        self._eigendecomp = None

        super(InverseProvidingFactor, self).__init__()

    def register_damped_inverse(self, damping):
        """Registers a damped inverse needed by a FisherBlock.

        This creates a variable and signals make_inverse_update_ops to make the
        corresponding update op.  The variable can be read via the method
        get_damped_inverse. Registering the same damping twice is a no-op.

        Args:
          damping: The damping value (float or Tensor) for this factor.
        """
        if damping in self._inverses_by_damping:
            return
        damping_string = scalar_or_tensor_to_string(damping)
        with variable_scope.variable_scope(self._var_scope):
            inv = variable_scope.get_variable(
                "inv_damp{}".format(damping_string),
                initializer=inverse_initializer,
                shape=self._cov_shape,
                trainable=False,
                dtype=self._dtype)
        self._inverses_by_damping[damping] = inv

    def register_matpower(self, exp, damping):
        """Registers a matrix power needed by a FisherBlock.

        This creates a variable and signals make_inverse_update_ops to make the
        corresponding update op.  The variable can be read via the method
        get_matpower. Registering the same (exp, damping) twice is a no-op.

        Args:
          exp: The exponent (float or Tensor) to raise the matrix to.
          damping: The damping value (float or Tensor).
        """
        if (exp, damping) in self._matpower_by_exp_and_damping:
            return
        exp_string = scalar_or_tensor_to_string(exp)
        damping_string = scalar_or_tensor_to_string(damping)
        with variable_scope.variable_scope(self._var_scope):
            matpower = variable_scope.get_variable(
                "matpower_exp{}_damp{}".format(exp_string, damping_string),
                initializer=inverse_initializer,
                shape=self._cov_shape,
                trainable=False,
                dtype=self._dtype)
        self._matpower_by_exp_and_damping[(exp, damping)] = matpower

    def register_eigendecomp(self):
        """Registers an eigendecomposition.

        Unlike register_damped_inverse and register_matpower this doesn't create
        any variables or inverse ops.  Instead it merely makes tensors containing
        the eigendecomposition available to anyone that wants them.  They will be
        recomputed (once) for each session.run() call (when they are needed by
        some op).
        """
        if self._eigendecomp is not None:
            return
        eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov)

        # The matrix self._cov is positive semidefinite by construction, but the
        # numerical eigenvalues could be negative due to numerical errors, so
        # here we clip them to be at least EIGENVALUE_CLIPPING_THRESHOLD.
        clipped_eigenvalues = math_ops.maximum(eigenvalues,
                                               EIGENVALUE_CLIPPING_THRESHOLD)
        self._eigendecomp = (clipped_eigenvalues, eigenvectors)

    def make_inverse_update_ops(self):
        """Create and return update ops corresponding to registered computations.

        Returns:
          A list of ops. When the eigendecomposition path is used, the list holds
          a single grouped op; otherwise one assign op per registered inverse.
        """
        ops = []

        num_inverses = len(self._inverses_by_damping)
        matrix_power_registered = bool(self._matpower_by_exp_and_damping)
        use_eig = (
            self._eigendecomp is not None or matrix_power_registered or
            num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD)

        if use_eig:
            self.register_eigendecomp()  # ensures self._eigendecomp is set
            eigenvalues, eigenvectors = self._eigendecomp  # pylint: disable=unpacking-non-sequence
            # V^T is shared by every reconstruction below; build it once rather
            # than emitting a separate transpose op per registered quantity.
            eigenvectors_t = array_ops.transpose(eigenvectors)

            for damping, inv in self._inverses_by_damping.items():
                # V diag(1 / (lambda + damping)) V^T: the eigenvalue vector
                # broadcasts over the last axis, scaling column j of V.
                ops.append(
                    inv.assign(
                        math_ops.matmul(eigenvectors / (eigenvalues + damping),
                                        eigenvectors_t)))

            for (exp, damping), matpower in self._matpower_by_exp_and_damping.items():
                # V diag((lambda + damping) ** exp) V^T.
                ops.append(
                    matpower.assign(
                        math_ops.matmul(eigenvectors *
                                        (eigenvalues + damping)**exp,
                                        eigenvectors_t)))
            # These ops share computation and should be run on a single device.
            ops = [control_flow_ops.group(*ops)]
        else:
            for damping, inv in self._inverses_by_damping.items():
                ops.append(inv.assign(utils.posdef_inv(self._cov, damping)))

        return ops

    def get_damped_inverse(self, damping):
        # Note that this function returns a variable which gets updated by the
        # inverse ops.  It may be stale / inconsistent with the latest value of
        # get_cov().
        return self._inverses_by_damping[damping]

    def get_matpower(self, exp, damping):
        # Note that this function returns a variable which gets updated by the
        # inverse ops.  It may be stale / inconsistent with the latest value of
        # get_cov().
        return self._matpower_by_exp_and_damping[(exp, damping)]

    def get_eigendecomp(self):
        # Unlike get_damped_inverse and get_matpower this doesn't retrieve a
        # stored variable, but instead returns tensors computed from the current
        # value of get_cov(). Returns None until register_eigendecomp (or
        # make_inverse_update_ops on the eigendecomposition path) has run.
        return self._eigendecomp


class FullFactor(InverseProvidingFactor):
    """FisherFactor for a full matrix representation of the Fisher of a parameter.

    This uses the naive "square the sum" estimator, so it applies in principle
    to any type of parameter, but it has very high variance.
    """

    def __init__(self,
                 params_grads,
                 batch_size,
                 colocate_cov_ops_with_inputs=False):
        """Instantiates FullFactor.

        Args:
          params_grads: Sequence of per-source gradients; each entry is
            flattened with utils.tensors_to_column.
          batch_size: Batch size used to normalize the estimate.
          colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
            their inputs.
        """
        self._batch_size = batch_size
        self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
        # The scope name is derived from the gradients as given, before they
        # are flattened.
        self._orig_params_grads_name = scope_string_from_params(
            [params_grads, self._batch_size])
        columns = []
        for grad in params_grads:
            with _maybe_colocate_with(grad, self._colocate_cov_ops_with_inputs):
                columns.append(utils.tensors_to_column(grad))
        self._params_grads_flat = tuple(columns)
        super(FullFactor, self).__init__()

    @property
    def _var_scope(self):
        return "ff_full/" + self._orig_params_grads_name

    @property
    def _cov_shape(self):
        num_rows = self._params_grads_flat[0].shape[0]
        return [num_rows, num_rows]

    @property
    def _num_sources(self):
        return len(self._params_grads_flat)

    @property
    def _dtype(self):
        return self._params_grads_flat[0].dtype

    def _compute_new_cov(self, idx=0):
        column = self._params_grads_flat[idx]
        with _maybe_colocate_with(column, self._colocate_cov_ops_with_inputs):
            # Very basic rank-1 estimate: broadcasting the column against its
            # transpose forms the outer product (assumes utils.tensors_to_column
            # yields an [n, 1] column -- TODO confirm).
            outer = column * array_ops.transpose(column)
            return outer / math_ops.cast(self._batch_size, column.dtype)


class DiagonalFactor(FisherFactor):
    """A core class for FisherFactors that use diagonal approximations.

    The covariance Variable is created with diagonal_covariance_initializer,
    which fills the whole requested shape instead of building a square
    identity matrix, and no inverse update ops are produced.
    """

    @property
    def _cov_initializer(self):
        return diagonal_covariance_initializer

    def make_inverse_update_ops(self):
        """Diagonal factors register no inverse Variables; returns []."""
        return []


class NaiveDiagonalFactor(DiagonalFactor):
    """FisherFactor for a diagonal approximation of any type of param's Fisher.

    This uses the naive "square the sum" estimator, so it applies in principle
    to any type of parameter, but it has very high variance.
    """

    def __init__(self,
                 params_grads,
                 batch_size,
                 colocate_cov_ops_with_inputs=False):
        """Instantiates NaiveDiagonalFactor.

        Args:
          params_grads: Sequence of per-source gradients; each entry is
            flattened with utils.tensors_to_column.
          batch_size: Batch size used to normalize the estimate.
          colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
            their inputs.
        """
        self._batch_size = batch_size
        self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
        flattened = []
        for grad in params_grads:
            with _maybe_colocate_with(grad, self._colocate_cov_ops_with_inputs):
                flattened.append(utils.tensors_to_column(grad))
        self._params_grads = tuple(flattened)
        # The scope name is built from the flattened columns, not the raw
        # gradients.
        self._orig_params_grads_name = scope_string_from_params(
            [self._params_grads, self._batch_size])
        super(NaiveDiagonalFactor, self).__init__()

    @property
    def _var_scope(self):
        return "ff_naivediag/" + self._orig_params_grads_name

    @property
    def _cov_shape(self):
        return self._params_grads[0].shape

    @property
    def _num_sources(self):
        return len(self._params_grads)

    @property
    def _dtype(self):
        return self._params_grads[0].dtype

    def _compute_new_cov(self, idx=0):
        column = self._params_grads[idx]
        with _maybe_colocate_with(column, self._colocate_cov_ops_with_inputs):
            # Entry-wise square of the flattened gradient, normalized by the
            # batch size.
            squared = math_ops.square(column)
            return squared / math_ops.cast(self._batch_size, column.dtype)


class FullyConnectedDiagonalFactor(DiagonalFactor):
    r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher.

    Given in = [batch_size, input_size] and out_grad = [batch_size, output_size],
    approximates the covariance as,
      Cov(in, out) = (1/batch_size) \sum_{i} outer(in[i], out_grad[i]) ** 2.0
    where the square is taken element-wise.
    """

    # TODO(jamesmartens): add units tests for this class

    def __init__(self,
                 inputs,
                 outputs_grads,
                 has_bias=False,
                 colocate_cov_ops_with_inputs=False):
        """Instantiate FullyConnectedDiagonalFactor.

        Args:
          inputs: Tensor of shape [batch_size, input_size]. Inputs to fully
            connected layer.
          outputs_grads: List of Tensors of shape [batch_size, output_size].
            Gradient of loss with respect to layer's preactivations.
          has_bias: bool. If True, append '1' to each input.
          colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
              their inputs.
        """
        self._outputs_grads = outputs_grads
        self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
        self._batch_size = array_ops.shape(inputs)[0]
        self._orig_tensors_name = scope_string_from_params(
            (inputs,) + tuple(outputs_grads))

        # The inputs do not depend on `idx`, so their entry-wise square is built
        # once here; only the selected _outputs_grads entry varies between
        # _compute_new_cov calls.
        with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs):
            augmented = _append_homog(inputs) if has_bias else inputs
            self._squared_inputs = math_ops.square(augmented)

        super(FullyConnectedDiagonalFactor, self).__init__()

    @property
    def _var_scope(self):
        return "ff_diagfc/" + self._orig_tensors_name

    @property
    def _cov_shape(self):
        num_in = self._squared_inputs.shape[1]
        num_out = self._outputs_grads[0].shape[1]
        return [num_in, num_out]

    @property
    def _num_sources(self):
        return len(self._outputs_grads)

    @property
    def _dtype(self):
        return self._outputs_grads[0].dtype

    def _compute_new_cov(self, idx=0):
        # The entry-wise square of an outer product equals the outer product of
        # the entry-wise squares. Each per-example gradient is the outer product
        # of an input row and an output-gradient row, so summing their squares
        # over the batch is a single matmul of the two squared matrices.
        with _maybe_colocate_with(self._squared_inputs,
                                  self._colocate_cov_ops_with_inputs):
            squared_grads = math_ops.square(self._outputs_grads[idx])
            summed = math_ops.matmul(
                self._squared_inputs, squared_grads, transpose_a=True)
            return summed / math_ops.cast(self._batch_size, summed.dtype)


class ConvDiagonalFactor(DiagonalFactor):
    """FisherFactor for a diagonal approx of a convolutional layer's Fisher."""

    # TODO(jamesmartens): add units tests for this class

    def __init__(self,
                 inputs,
                 outputs_grads,
                 filter_shape,
                 strides,
                 padding,
                 has_bias=False,
                 colocate_cov_ops_with_inputs=False):
        """Creates a ConvDiagonalFactor object.

        Args:
          inputs: Tensor of shape [batch_size, height, width, in_channels].
            Input activations to this layer.
          outputs_grads: List of Tensors, each of shape
            [batch_size, height, width, out_channels]. Per-example gradients to
            the loss with respect to the layer's output preactivations; one
            covariance contribution is computed per list entry.
          filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels,
            out_channels). Represents shape of kernel used in this layer.
          strides: Sequence of 4 ints; the stride size in this layer. Passed
            directly as the `strides` of extract_image_patches.
          padding: str. Padding method for the layer ("SAME" or "VALID"),
            passed directly to extract_image_patches.
          has_bias: Python bool. If True, the layer is assumed to have a bias
            parameter in addition to its filter parameter.
          colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
              their inputs.
        """
        self._filter_shape = filter_shape
        self._has_bias = has_bias
        self._outputs_grads = outputs_grads
        self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs

        self._orig_tensors_name = scope_string_from_name(
            (inputs,) + tuple(outputs_grads))

        # Note that we precompute the required operations on the inputs since the
        # inputs don't change with the 'idx' argument to _compute_new_cov.  (Only
        # the target entry of _outputs_grads changes with idx.)
        with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs):
            filter_height, filter_width, _, _ = self._filter_shape

            # TODO(b/64144716): there is potential here for a big savings in terms of
            # memory use.
            patches = array_ops.extract_image_patches(
                inputs,
                ksizes=[1, filter_height, filter_width, 1],
                strides=strides,
                rates=[1, 1, 1, 1],
                padding=padding)

            if has_bias:
                patches = _append_homog(patches)

            self._patches = patches

        super(ConvDiagonalFactor, self).__init__()

    @property
    def _var_scope(self):
        return "ff_convdiag/" + self._orig_tensors_name

    @property
    def _cov_shape(self):
        filter_height, filter_width, in_channels, out_channels = self._filter_shape
        # `self._has_bias` is a bool, so it adds one row for the bias when set.
        return [
            filter_height * filter_width * in_channels + self._has_bias,
            out_channels
        ]

    @property
    def _num_sources(self):
        return len(self._outputs_grads)

    @property
    def _dtype(self):
        return self._outputs_grads[0].dtype

    def _compute_new_cov(self, idx=0):
        with _maybe_colocate_with(self._outputs_grads[idx],
                                  self._colocate_cov_ops_with_inputs):
            outputs_grad = self._outputs_grads[idx]
            batch_size = array_ops.shape(self._patches)[0]

            new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad)
            new_cov /= math_ops.cast(batch_size, new_cov.dtype)

            return new_cov

    def _convdiag_sum_of_squares(self, patches, outputs_grad):
        # This computes the sum of the squares of the per-training-case "gradients".
        # It does this simply by computing a giant tensor containing all of these,
        # doing an entry-wise square, and then summing along the batch dimension.
        # The einsum contracts the two spatial axes (i, j), leaving one
        # [patch_size, out_channels] gradient per example b.
        case_wise_gradients = special_math_ops.einsum("bijk,bijl->bkl", patches,
                                                      outputs_grad)
        return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0)

class FullyConnectedKroneckerFactor(InverseProvidingFactor):
    """Kronecker factor for the input or output side of a fully-connected layer."""

    def __init__(self,
                 tensors,
                 has_bias=False,
                 colocate_cov_ops_with_inputs=False):
        """Instantiate FullyConnectedKroneckerFactor.

        Args:
          tensors: List of Tensors of shape [batch_size, n]. Represents either a
            layer's inputs or its output's gradients.
          has_bias: bool. If True, append '1' to each row.
          colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
              their inputs.
        """
        # Each entry of `tensors` is either a batch of input activations or a
        # batch of output pre-activation gradients.
        self._has_bias = has_bias
        self._tensors = tensors
        self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
        super(FullyConnectedKroneckerFactor, self).__init__()

    @property
    def _var_scope(self):
        name_suffix = scope_string_from_params([self._tensors, self._has_bias])
        return "ff_fckron/" + name_suffix

    @property
    def _cov_shape(self):
        # `self._has_bias` is a bool: it adds one row/col for the bias when set.
        dim = self._tensors[0].shape[1] + self._has_bias
        return [dim, dim]

    @property
    def _num_sources(self):
        return len(self._tensors)

    @property
    def _dtype(self):
        return self._tensors[0].dtype

    def _compute_new_cov(self, idx=0):
        source = self._tensors[idx]
        with _maybe_colocate_with(source, self._colocate_cov_ops_with_inputs):
            if self._has_bias:
                source = _append_homog(source)
            return _compute_cov(source)


class ConvInputKroneckerFactor(InverseProvidingFactor):
    r"""Kronecker factor for the input side of a convolutional layer.

    Estimates E[ a a^T ] where a is the inputs to a convolutional layer given
    example x. Expectation is taken over all examples and locations.
    Equivalent to \Omega in https://arxiv.org/abs/1602.01407; see Section 3.1,
    "Estimating the factors", for details.
    """

    def __init__(self,
                 inputs,
                 filter_shape,
                 strides,
                 padding,
                 has_bias=False,
                 colocate_cov_ops_with_inputs=False):
        """Initializes ConvInputKroneckerFactor.

        Args:
          inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
            to layer.
          filter_shape: Sequence of 4 Python ints: [kernel_height, kernel_width,
            in_channels, out_channels]. It is unpacked and multiplied in Python
            and used to build `ksizes`, so it cannot be a Tensor.
          strides: Sequence of 4 Python ints: [batch_stride, height_stride,
            width_stride, in_channel_stride]. Passed directly as the `strides`
            of extract_image_patches.
          padding: str. Padding method for layer. "SAME" or "VALID".
          has_bias: bool. If True, append 1 to in_channel.
          colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
              their inputs.
        """
        self._filter_shape = filter_shape
        self._strides = strides
        self._padding = padding
        self._has_bias = has_bias
        self._inputs = inputs
        self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
        super(ConvInputKroneckerFactor, self).__init__()

    @property
    def _var_scope(self):
        return "ff_convinkron/" + scope_string_from_params([
            self._inputs, self._filter_shape, self._strides, self._padding,
            self._has_bias
        ])

    @property
    def _cov_shape(self):
        filter_height, filter_width, in_channels, _ = self._filter_shape
        # `self._has_bias` is a bool: it adds one row/col for the bias when set.
        size = filter_height * filter_width * in_channels + self._has_bias
        return [size, size]

    @property
    def _num_sources(self):
        return 1

    @property
    def _dtype(self):
        return self._inputs.dtype

    def _compute_new_cov(self, idx=0):
        if idx != 0:
            raise ValueError("ConvInputKroneckerFactor only supports idx = 0")

        # TODO(jamesmartens): factor this patches stuff out into a utility function
        with _maybe_colocate_with(self._inputs, self._colocate_cov_ops_with_inputs):
            filter_height, filter_width, in_channels, _ = self._filter_shape

            # TODO(b/64144716): there is potential here for a big savings in terms of
            # memory use.
            patches = array_ops.extract_image_patches(
                self._inputs,
                ksizes=[1, filter_height, filter_width, 1],
                strides=self._strides,
                rates=[1, 1, 1, 1],
                padding=self._padding)

            # Every (example, location) patch becomes one row of the sample
            # matrix, so the expectation covers both examples and locations.
            flatten_size = filter_height * filter_width * in_channels
            patches_flat = array_ops.reshape(patches, [-1, flatten_size])

            if self._has_bias:
                patches_flat = _append_homog(patches_flat)

            return _compute_cov(patches_flat)


class ConvOutputKroneckerFactor(InverseProvidingFactor):
    r"""Kronecker factor for the output side of a convolutional layer.

    Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer
    given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over
    all examples and locations.
    Equivalent to \Gamma in https://arxiv.org/abs/1602.01407; see Section 3.1,
    "Estimating the factors", for details.
    """

    def __init__(self, outputs_grads, colocate_cov_ops_with_inputs=False):
        """Initializes ConvOutputKroneckerFactor.

        Args:
          outputs_grads: list of Tensors. Each Tensor is of shape
              [batch_size, height, width, out_channels].
          colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with
              their inputs.
        """
        # The channel count must be statically known: it sizes the cov Variable.
        self._out_channels = outputs_grads[0].shape.as_list()[3]
        self._outputs_grads = outputs_grads
        self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs
        super(ConvOutputKroneckerFactor, self).__init__()

    @property
    def _var_scope(self):
        name_suffix = scope_string_from_params(self._outputs_grads)
        return "ff_convoutkron/" + name_suffix

    @property
    def _cov_shape(self):
        return [self._out_channels, self._out_channels]

    @property
    def _num_sources(self):
        return len(self._outputs_grads)

    @property
    def _dtype(self):
        return self._outputs_grads[0].dtype

    def _compute_new_cov(self, idx=0):
        grads = self._outputs_grads[idx]
        with _maybe_colocate_with(grads, self._colocate_cov_ops_with_inputs):
            # Each (example, location) pair becomes one row of out_channels.
            rows = array_ops.reshape(grads, [-1, self._out_channels])
            return _compute_cov(rows)