from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import random_ops
from collections import OrderedDict

# Method used for inverting matrices.
POSDEF_INV_METHOD = "eig"
POSDEF_EIG_METHOD = "self_adjoint"


def set_global_constants(posdef_inv_method=None):
  """Sets various global constants used by the classes in this module."""
  global POSDEF_INV_METHOD

  if posdef_inv_method is not None:
    POSDEF_INV_METHOD = posdef_inv_method


class SequenceDict(object):
    """A dict convenience wrapper that allows getting/setting with sequences."""

    def __init__(self, iterable=None):
        self._dict = dict(iterable or [])

    def __getitem__(self, key_or_keys):
        if isinstance(key_or_keys, (tuple, list)):
            return list(map(self.__getitem__, key_or_keys))
        else:
            return self._dict[key_or_keys]

    def __setitem__(self, key_or_keys, val_or_vals):
        if isinstance(key_or_keys, (tuple, list)):
            for key, value in zip(key_or_keys, val_or_vals):
                self[key] = value
        else:
            self._dict[key_or_keys] = val_or_vals

    def items(self):
        return list(self._dict.items())


def tensors_to_column(tensors):
    """Converts a tensor or list of tensors to a column vector.
    Args:
        tensors: A tensor or list of tensors.
    Returns:
        The tensors reshaped into vectors and stacked on top of each other.
    """
    if isinstance(tensors, (tuple, list)):
        return array_ops.concat(
            tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0)
    else:
        return array_ops.reshape(tensors, [-1, 1])


def column_to_tensors(tensors_template, colvec):
  """Converts a column vector back to the shape of the given template.
  Args:
    tensors_template: A tensor or list of tensors.
    colvec: A 2d column vector with the same shape as the value of
        tensors_to_column(tensors_template).
  Returns:
    X, where X is tensor or list of tensors with the properties:
     1) tensors_to_column(X) = colvec
     2) X (or its elements) have the same shape as tensors_template (or its
        elements)
  """
  if isinstance(tensors_template, (tuple, list)):
    offset = 0
    tensors = []
    for tensor_template in tensors_template:
      sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32)
      tensor = array_ops.reshape(colvec[offset:(offset + sz)],
                                 tensor_template.shape)
      tensors.append(tensor)
      offset += sz

    tensors = tuple(tensors)
  else:
    tensors = array_ops.reshape(colvec, tensors_template.shape)

  return tensors


def kronecker_product(mat1, mat2):
    """Computes the Kronecker product two matrices."""
    m1, n1 = mat1.get_shape().as_list()
    mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
    m2, n2 = mat2.get_shape().as_list()
    mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
    return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])


def layer_params_to_mat2d(vector):
    """Converts a vector shaped like layer parameters to a 2D matrix.
    In particular, we reshape the weights/filter component of the vector to be
    2D, flattening all leading (input) dimensions. If there is a bias component,
    we concatenate it to the reshaped weights/filter component.
    Args:
        vector: A Tensor or pair of Tensors shaped like layer parameters.
    Returns:
        A 2D Tensor with the same coefficients and the same output dimension.
    """
    if isinstance(vector, (tuple, list)):
        w_part, b_part = vector
        w_part_reshaped = array_ops.reshape(w_part,
                                            [-1, w_part.shape.as_list()[-1]])
        return array_ops.concat(
            (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0)
    else:
        return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]])


def mat2d_to_layer_params(vector_template, mat2d):
    """Converts a canonical 2D matrix representation back to a vector.
    Args:
        vector_template: A Tensor or pair of Tensors shaped like layer parameters.
        mat2d: A 2D Tensor with the same shape as the value of
            layer_params_to_mat2d(vector_template).
    Returns:
        A Tensor or pair of Tensors with the same coefficients as mat2d and the same
            shape as vector_template.
    """
    if isinstance(vector_template, (tuple, list)):
        w_part, b_part = mat2d[:-1], mat2d[-1]
        return array_ops.reshape(w_part, vector_template[0].shape), b_part
    else:
        return array_ops.reshape(mat2d, vector_template.shape)


def posdef_inv(tensor, damping):
    """Computes the inverse of tensor + damping * identity."""
    identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
    damping = math_ops.cast(damping, dtype=tensor.dtype)
    return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping)


def posdef_inv_matrix_inverse(tensor, identity, damping):
    """Computes inverse(tensor + damping * identity) directly."""
    return linalg_ops.matrix_inverse(tensor + damping * identity)


def posdef_inv_cholesky(tensor, identity, damping):
    """Computes inverse(tensor + damping * identity) with Cholesky."""
    chol = linalg_ops.cholesky(tensor + damping * identity)
    return linalg_ops.cholesky_solve(chol, identity)


def posdef_inv_eig(tensor, identity, damping):
    """Computes inverse(tensor + damping * identity) with eigendecomposition."""
    eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(
        tensor + damping * identity)
    # TODO(GD): it's a little hacky
    eigenvalues = gen_math_ops.maximum(eigenvalues, damping)
    return math_ops.matmul(
        eigenvectors / eigenvalues, eigenvectors, transpose_b=True)


posdef_inv_functions = {
    "matrix_inverse": posdef_inv_matrix_inverse,
    "cholesky": posdef_inv_cholesky,
    "eig": posdef_inv_eig,
}


def posdef_eig(mat):
    """Computes the eigendecomposition of a positive semidefinite matrix."""
    return posdef_eig_functions[POSDEF_EIG_METHOD](mat)


def posdef_eig_svd(mat):
    """Computes the singular values and left singular vectors of a matrix."""
    evals, evecs, _ = linalg_ops.svd(mat)

    return evals, evecs


def posdef_eig_self_adjoint(mat):
    """Computes eigendecomposition using self_adjoint_eig."""
    evals, evecs = linalg_ops.self_adjoint_eig(mat)
    evals = math_ops.abs(evals)  # Should be equivalent to svd approach.

    return evals, evecs


posdef_eig_functions = {
    "self_adjoint": posdef_eig_self_adjoint,
    "svd": posdef_eig_svd,
}


def generate_random_signs(shape, dtype=dtypes.float32):
    """Generate a random tensor with {-1, +1} entries."""
    ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32)
    return 2 * math_ops.cast(ints, dtype=dtype) - 1


def ensure_sequence(obj):
    """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
    if isinstance(obj, (tuple, list)):
        return obj
    else:
        return (obj,)


class LayerParametersDict(OrderedDict):
    """An OrderedDict where keys are Tensors or tuples of Tensors.
    Ensures that no Tensor is associated with two different keys.
    """

    def __init__(self, *args, **kwargs):
        self._tensors = set()
        super(LayerParametersDict, self).__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        key = self._canonicalize_key(key)
        tensors = key if isinstance(key, (tuple, list)) else (key,)
        key_collisions = self._tensors.intersection(tensors)
        if key_collisions:
            raise ValueError("Key(s) already present: {}".format(key_collisions))
        self._tensors.update(tensors)
        super(LayerParametersDict, self).__setitem__(key, value)

    def __delitem__(self, key):
        key = self._canonicalize_key(key)
        self._tensors.remove(key)
        super(LayerParametersDict, self).__delitem__(key)

    def __getitem__(self, key):
        key = self._canonicalize_key(key)
        return super(LayerParametersDict, self).__getitem__(key)

    def __contains__(self, key):
        key = self._canonicalize_key(key)
        return super(LayerParametersDict, self).__contains__(key)

    def _canonicalize_key(self, key):
        if isinstance(key, (list, tuple)):
            return tuple(key)
        return key