# coding=utf-8
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common util functions used by layers.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
from collections import OrderedDict

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables

__all__ = ['collect_named_outputs',
           'constant_value',
           'static_cond',
           'smart_cond',
           'get_variable_collections',
           'two_element_tuple',
           'n_positive_integers',
           'channel_dimension',
           'last_dimension']

NamedOutputs = namedtuple('NamedOutputs', ['name', 'outputs'])


def collect_named_outputs(collections, alias, outputs):
  """Add `Tensor` outputs tagged with alias to collections.

  It is useful to collect end-points or tags for summaries. Example of usage:

  logits = collect_named_outputs('end_points', 'inception_v3/logits', logits)
  assert 'inception_v3/logits' in logits.aliases

  Args:
    collections: A collection or list of collections. If None skip collection.
    alias: String to append to the list of aliases of outputs, for example,
           'inception_v3/conv1'.
    outputs: Tensor, an output tensor to collect

  Returns:
    The outputs Tensor to allow inline call.
  """
  if collections:
    append_tensor_alias(outputs, alias)
    ops.add_to_collections(collections, outputs)
  return outputs


def append_tensor_alias(tensor, alias):
  """Append an alias to the list of aliases of the tensor.

  Args:
    tensor: A `Tensor`.
    alias: String, to add to the list of aliases of the tensor.

  Returns:
    The tensor with a new alias appended to its list of aliases.
  """
  # Remove ending '/' if present.
  if alias[-1] == '/':
    alias = alias[:-1]
  if hasattr(tensor, 'aliases'):
    tensor.aliases.append(alias)
  else:
    tensor.aliases = [alias]
  return tensor


def gather_tensors_aliases(tensors):
  """Given a list of tensors, gather their aliases.

  Args:
    tensors: A list of `Tensors`.

  Returns:
    A list of strings with the aliases of all tensors.
  """
  aliases = []
  for tensor in tensors:
    aliases += get_tensor_aliases(tensor)
  return aliases


def get_tensor_aliases(tensor):
  """Get a list with the aliases of the input tensor.

  If the tensor does not have any alias, it would default to its its op.name or
  its name.

  Args:
    tensor: A `Tensor`.

  Returns:
    A list of strings with the aliases of the tensor.
  """
  if hasattr(tensor, 'aliases'):
    aliases = tensor.aliases
  else:
    if tensor.name[-2:] == ':0':
      # Use op.name for tensor ending in :0
      aliases = [tensor.op.name]
    else:
      aliases = [tensor.name]
  return aliases


def convert_collection_to_dict(collection, clear_collection=False):
  """Returns an OrderedDict of Tensors with their aliases as keys.

  Args:
    collection: A collection.
    clear_collection: When True, it clears the collection after converting to
      OrderedDict.

  Returns:
    An OrderedDict of {alias: tensor}
  """
  output = OrderedDict((alias, tensor)
                       for tensor in ops.get_collection(collection)
                       for alias in get_tensor_aliases(tensor))
  if clear_collection:
    ops.get_default_graph().clear_collection(collection)
  return output


def constant_value(value_or_tensor_or_var, dtype=None):
  """Returns value if value_or_tensor_or_var has a constant value.

  Args:
    value_or_tensor_or_var: A value, a `Tensor` or a `Variable`.
    dtype: Optional `tf.dtype`, if set it would check it has the right
      dtype.

  Returns:
    The constant value or None if it not constant.

  Raises:
    ValueError: if value_or_tensor_or_var is None or the tensor_variable has the
    wrong dtype.
  """
  if value_or_tensor_or_var is None:
    raise ValueError('value_or_tensor_or_var cannot be None')
  value = value_or_tensor_or_var
  if isinstance(value_or_tensor_or_var, (ops.Tensor, variables.Variable)):
    if dtype and value_or_tensor_or_var.dtype != dtype:
      raise ValueError('It has the wrong type %s instead of %s' % (
          value_or_tensor_or_var.dtype, dtype))
    if isinstance(value_or_tensor_or_var, variables.Variable):
      value = None
    else:
      value = tensor_util.constant_value(value_or_tensor_or_var)
  return value


def static_cond(pred, fn1, fn2):
  """Return either fn1() or fn2() based on the boolean value of `pred`.

  Same signature as `tf.cond()` but requires pred to be a bool.

  Args:
    pred: A value determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.

  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.

  Raises:
    TypeError: if `fn1` or `fn2` is not callable.
  """
  if not callable(fn1):
    raise TypeError('fn1 must be callable.')
  if not callable(fn2):
    raise TypeError('fn2 must be callable.')
  if pred:
    return fn1()
  else:
    return fn2()


def smart_cond(pred, fn1, fn2, name=None):
  """Return either fn1() or fn2() based on the boolean predicate/value `pred`.

  If `pred` is bool or has a constant value it would use `static_cond`,
  otherwise it would use `tf.cond`.

  Args:
    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.
    name: Optional name prefix when using tf.cond
  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.
  """
  pred_value = constant_value(pred)
  if pred_value is not None:
    # Use static_cond if pred has a constant value.
    return static_cond(pred_value, fn1, fn2)
  else:
    # Use dynamic cond otherwise.
    return control_flow_ops.cond(pred, fn1, fn2, name)


def get_variable_collections(variables_collections, name):
  if isinstance(variables_collections, dict):
    variable_collections = variables_collections.get(name, None)
  else:
    variable_collections = variables_collections
  return variable_collections


def _get_dimension(shape, dim, min_rank=1):
  """Returns the `dim` dimension of `shape`, while checking it has `min_rank`.

  Args:
    shape: A `TensorShape`.
    dim: Integer, which dimension to return.
    min_rank: Integer, minimum rank of shape.

  Returns:
    The value of the `dim` dimension.

  Raises:
    ValueError: if inputs don't have at least min_rank dimensions, or if the
      first dimension value is not defined.
  """
  dims = shape.dims
  if dims is None:
    raise ValueError('dims of shape must be known but is None')
  if len(dims) < min_rank:
    raise ValueError('rank of shape must be at least %d not: %d' % (min_rank,
                                                                    len(dims)))
  value = dims[dim].value
  if value is None:
    raise ValueError(
        'dimension %d of shape must be known but is None: %s' % (dim, shape))
  return value


def channel_dimension(shape, data_format, min_rank=1):
  """Returns the channel dimension of shape, while checking it has min_rank.

  Args:
    shape: A `TensorShape`.
    data_format: `channels_first` or `channels_last`.
    min_rank: Integer, minimum rank of shape.

  Returns:
    The value of the first dimension.

  Raises:
    ValueError: if inputs don't have at least min_rank dimensions, or if the
      first dimension value is not defined.
  """
  return _get_dimension(shape, 1 if data_format == 'channels_first' else -1,
                        min_rank=min_rank)


def last_dimension(shape, min_rank=1):
  """Returns the last dimension of shape while checking it has min_rank.

  Args:
    shape: A `TensorShape`.
    min_rank: Integer, minimum rank of shape.

  Returns:
    The value of the last dimension.

  Raises:
    ValueError: if inputs don't have at least min_rank dimensions, or if the
      last dimension value is not defined.
  """
  return _get_dimension(shape, -1, min_rank=min_rank)


def two_element_tuple(int_or_tuple):
  """Converts `int_or_tuple` to height, width.

  Several of the functions that follow accept arguments as either
  a tuple of 2 integers or a single integer.  A single integer
  indicates that the 2 values of the tuple are the same.

  This functions normalizes the input value by always returning a tuple.

  Args:
    int_or_tuple: A list of 2 ints, a single int or a `TensorShape`.

  Returns:
    A tuple with 2 values.

  Raises:
    ValueError: If `int_or_tuple` it not well formed.
  """
  if isinstance(int_or_tuple, (list, tuple)):
    if len(int_or_tuple) != 2:
      raise ValueError('Must be a list with 2 elements: %s' % int_or_tuple)
    return int(int_or_tuple[0]), int(int_or_tuple[1])
  if isinstance(int_or_tuple, int):
    return int(int_or_tuple), int(int_or_tuple)
  if isinstance(int_or_tuple, tensor_shape.TensorShape):
    if len(int_or_tuple) == 2:
      return int_or_tuple[0], int_or_tuple[1]
  raise ValueError('Must be an int, a list with 2 elements or a TensorShape of '
                   'length 2')


def n_positive_integers(n, value):
  """Converts `value` to a sequence of `n` positive integers.

  `value` may be either be a sequence of values convertible to `int`, or a
  single value convertible to `int`, in which case the resulting integer is
  duplicated `n` times.  It may also be a TensorShape of rank `n`.

  Args:
    n: Length of sequence to return.
    value: Either a single value convertible to a positive `int` or an
      `n`-element sequence of values convertible to a positive `int`.

  Returns:
    A tuple of `n` positive integers.

  Raises:
    TypeError: If `n` is not convertible to an integer.
    ValueError: If `n` or `value` are invalid.
  """

  n_orig = n
  n = int(n)
  if n < 1 or n != n_orig:
    raise ValueError('n must be a positive integer')

  try:
    value = int(value)
  except (TypeError, ValueError):
    sequence_len = len(value)
    if sequence_len != n:
      raise ValueError(
          'Expected sequence of %d positive integers, but received %r' %
          (n, value))
    try:
      values = tuple(int(x) for x in value)
    except:
      raise ValueError(
          'Expected sequence of %d positive integers, but received %r' %
          (n, value))
    for x in values:
      if x < 1:
        raise ValueError('expected positive integer, but received %d' % x)
    return values

  if value < 1:
    raise ValueError('expected positive integer, but received %d' % value)
  return (value,) * n