# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DNC util ops and modules."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


def batch_invert_permutation(permutations):
  """Returns batched `tf.invert_permutation` for every row in `permutations`."""
  with tf.name_scope('batch_invert_permutation', values=[permutations]):
    perm = tf.cast(permutations, tf.float32)
    dim = int(perm.get_shape()[-1])
    size = tf.cast(tf.shape(perm)[0], tf.float32)
    delta = tf.cast(tf.shape(perm)[-1], tf.float32)
    rg = tf.range(0, size * delta, delta, dtype=tf.float32)
    rg = tf.expand_dims(rg, 1)
    rg = tf.tile(rg, [1, dim])
    perm = tf.add(perm, rg)
    flat = tf.reshape(perm, [-1])
    perm = tf.invert_permutation(tf.cast(flat, tf.int32))
    perm = tf.reshape(perm, [-1, dim])
    return tf.subtract(perm, tf.cast(rg, tf.int32))


def batch_gather(values, indices):
  """Returns batched `tf.gather` for every row in the input."""
  with tf.name_scope('batch_gather', values=[values, indices]):
    idx = tf.expand_dims(indices, -1)
    size = tf.shape(indices)[0]
    rg = tf.range(size, dtype=tf.int32)
    rg = tf.expand_dims(rg, -1)
    rg = tf.tile(rg, [1, int(indices.get_shape()[-1])])
    rg = tf.expand_dims(rg, -1)
    gidx = tf.concat([rg, idx], -1)
    return tf.gather_nd(values, gidx)


def one_hot(length, index):
  """Return an nd array of given `length` filled with 0s and a 1 at `index`."""
  result = np.zeros(length)
  result[index] = 1
  return result

def reduce_prod(x, axis, name=None):
  """Efficient reduce product over axis.

  Uses tf.cumprod and tf.gather_nd as a workaround to the poor performance of calculating tf.reduce_prod's gradient on CPU.
  """
  with tf.name_scope(name, 'util_reduce_prod', values=[x]):
    cp = tf.cumprod(x, axis, reverse=True)
    size = tf.shape(cp)[0]
    idx1 = tf.range(tf.cast(size, tf.float32), dtype=tf.float32)
    idx2 = tf.zeros([size], tf.float32)
    indices = tf.stack([idx1, idx2], 1)
    return tf.gather_nd(cp, tf.cast(indices, tf.int32))