# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Utility functions for Real NVP.
"""

# pylint: disable=dangerous-default-value

import numpy
import tensorflow as tf
from tensorflow.python.framework import ops

# Default `bn_lag` for batch_norm and batch_norm_log_diff. Both helpers only
# blend the batch moments with the stored moving averages when bn_lag > 0,
# so this default of 0 trains on the raw batch moments.
DEFAULT_BN_LAG = .0


def stable_var(input_, mean=None, axes=[0]):
    """Numerically more stable variance computation.

    Computes mean((input_ - mean) ** 2) over `axes`. The squared deviations
    are first divided by their maximum and the result multiplied back
    afterwards, so the values being averaged lie in [0, 1].

    Args:
        input_: input tensor.
        mean: optional precomputed mean over `axes`. It must broadcast against
            `input_`. When None, it is computed with tf.reduce_mean.
        axes: axes to reduce over.
    Returns:
        The variance of `input_` over `axes`.
    """
    if mean is None:
        mean = tf.reduce_mean(input_, axes)
    sqr_dev = tf.square(input_ - mean)
    max_sqr = tf.reduce_max(sqr_dev, axes)
    # If the input is constant along `axes`, every deviation is 0 and
    # max_sqr is 0. Dividing by it would produce 0 / 0 = NaN. Divide by 1
    # instead: the deviations are all 0 there, so the result is the correct
    # variance of 0.
    safe_max = tf.where(max_sqr > 0., max_sqr, tf.ones_like(max_sqr))
    res = tf.reduce_mean(sqr_dev / safe_max, axes)
    return res * safe_max


def variable_on_cpu(name, shape, initializer, trainable=True):
    """Create (or reuse) a variable through tf.get_variable.

    Despite its name, this helper does not pin any device itself. The
    variable is placed according to whatever device scope the caller is in.

    Args:
        name: name of the variable.
        shape: list of ints giving the variable shape.
        initializer: initializer for the variable.
        trainable: whether the variable is added to the trainable set.
    Returns:
        The variable returned by tf.get_variable.
    """
    return tf.get_variable(
        name, shape, initializer=initializer, trainable=trainable)


# layers
def conv_layer(input_,
               filter_size,
               dim_in,
               dim_out,
               name,
               stddev=1e-2,
               strides=[1, 1, 1, 1],
               padding="SAME",
               nonlinearity=None,
               bias=False,
               weight_norm=False,
               scale=False):
    """Convolutional layer.

    Creates, under variable scope `name`, a weight tensor of shape
    `filter_size + [dim_in, dim_out]`, initialized uniformly in
    [-stddev, stddev]. It then convolves `input_` with it and optionally adds
    a bias and applies a nonlinearity.

    Args:
        input_: input tensor. Its static shape is indexed as
            [batch, height, width, channels] by the padding logic below
            (NHWC assumed).
        filter_size: list of two ints [filter_height, filter_width].
        dim_in: number of input channels.
        dim_out: number of output channels.
        name: variable scope name; also used as the conv op name.
        stddev: half-width of the uniform weight initializer.
        strides: strides passed to tf.nn.conv2d.
        padding: padding mode passed to tf.nn.conv2d.
        nonlinearity: optional callable applied to the output.
        bias: if True, add a zero-initialized bias of size dim_out.
        weight_norm: if True, rescale each output filter to unit L2 norm.
        scale: with weight_norm, multiply by a learned per-filter magnitude.
    Returns:
        The output tensor.
    """
    with tf.variable_scope(name) as scope:
        weights = variable_on_cpu(
            "weights",
            filter_size + [dim_in, dim_out],
            tf.random_uniform_initializer(
                minval=-stddev, maxval=stddev))
        # Weight normalization: divide each output filter by its L2 norm.
        if weight_norm:
            weights /= tf.sqrt(tf.reduce_sum(tf.square(weights), [0, 1, 2]))
            if scale:
                magnitude = variable_on_cpu(
                    "magnitude", [dim_out],
                    tf.constant_initializer(
                        stddev * numpy.sqrt(dim_in * numpy.prod(filter_size) / 12.)))
                weights *= magnitude

        # Handle a filter taller than the image. Zero-pad the top (and the
        # left, when the image is also narrower than the filter) up to the
        # filter size, convolve, then crop the same offsets off the output.
        pad_h = pad_w = 0
        if hasattr(input_, "shape"):
            in_shape = input_.get_shape().as_list()
            if in_shape[1] is not None and in_shape[1] < filter_size[0]:
                pad_h = filter_size[0] - in_shape[1]
                # Clamp at 0: an image already at least as wide as the filter
                # needs no width padding (a negative pad size is invalid).
                pad_w = max(0, filter_size[1] - in_shape[2])

        res = input_
        if pad_h:
            # tf.pad keeps the input dtype and accepts an unknown batch size,
            # unlike concatenating a statically shaped tf.zeros block.
            res = tf.pad(res, [[0, 0], [pad_h, 0], [pad_w, 0], [0, 0]])
        res = tf.nn.conv2d(
            input=res,
            filter=weights,
            strides=strides,
            padding=padding,
            name=scope.name)
        if pad_h:
            res = tf.slice(res, [0, pad_h, pad_w, 0], [-1, -1, -1, -1])

        if bias:
            biases = variable_on_cpu("biases", [dim_out], tf.constant_initializer(0.))
            res = tf.nn.bias_add(res, biases)
        if nonlinearity is not None:
            res = nonlinearity(res)

    return res


def max_pool_2x2(input_):
    """Max pooling with a 2x2 window and stride 2 over dimensions 1 and 2.

    Uses SAME padding, so odd spatial sizes round up.
    """
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(input_, ksize=window, strides=window, padding="SAME")


def depool_2x2(input_, stride=2):
    """Depooling: scatter each pixel into the top-left of a zero block.

    The value at (i, j) moves to (stride * i, stride * j) in an output of
    spatial size (stride * height, stride * width). All other entries are 0.
    """
    dims = input_.get_shape().as_list()
    n, rows, cols, depth = dims[0], dims[1], dims[2], dims[3]
    # Give every pixel a singleton slot along a new row axis and a new column
    # axis. Padding those slots with zeros up to `stride` then spreads the
    # pixels apart once the axes are merged back.
    out = tf.reshape(input_, [n, rows, 1, cols, 1, depth])
    row_fill = tf.zeros([n, rows, stride - 1, cols, 1, depth])
    out = tf.concat(axis=2, values=[out, row_fill])
    col_fill = tf.zeros([n, rows, stride, cols, stride - 1, depth])
    out = tf.concat(axis=4, values=[out, col_fill])
    return tf.reshape(out, [n, stride * rows, stride * cols, depth])


# random flip on a batch of images
def batch_random_flip(input_):
    """Randomly flip each image of a batch left-right, independently.

    Python numbers are returned unchanged. Tensors are split along dimension
    0, flipped one image at a time, and concatenated back.
    """
    if isinstance(input_, (float, int)):
        return input_
    dims = input_.get_shape().as_list()
    n_images, rows, cols, depth = dims[0], dims[1], dims[2], dims[3]
    # All slices are built first, then all flips, then all reshapes, so the
    # random ops are created in the same sequence as before.
    pieces = tf.split(axis=0, num_or_size_splits=n_images, value=input_)
    images = [piece[0, :, :, :] for piece in pieces]
    flipped = [tf.image.random_flip_left_right(image) for image in images]
    restored = [tf.reshape(image, [1, rows, cols, depth]) for image in flipped]
    return tf.concat(axis=0, values=restored)


# build a one hot representation corresponding to the integer tensor
# the one-hot dimension is appended to the integer tensor shape
def as_one_hot(input_, n_indices):
    """Convert indices to one-hot.

    The one-hot dimension (of size `n_indices`) is appended to the static
    shape of `input_`, which must be fully defined.

    Args:
        input_: integer tensor of indices in [0, n_indices). Any integer dtype
            is accepted; it is cast to int64 internally.
        n_indices: size of the one-hot dimension.
    Returns:
        A float tensor of shape `input_.shape + [n_indices]` with 1. at each
        index position and 0. elsewhere.
    """
    shape = input_.get_shape().as_list()
    # int(): numpy.prod([]) is the float 1.0 for scalar input.
    n_elem = int(numpy.prod(shape))
    positions = tf.cast(tf.range(n_elem), tf.int64)
    # Cast so int32 (or other integer) inputs can be paired with the int64
    # positions. Stacking tensors of mixed dtypes fails.
    flat_input = tf.cast(tf.reshape(input_, [-1]), tf.int64)
    # Row k of sparse_indices is (k, input_flat[k]).
    sparse_indices = tf.stack([positions, flat_input], axis=1)
    res = tf.sparse_to_dense(
        sparse_indices, [n_elem, n_indices], 1., 0., name="flat_one_hot")
    return tf.reshape(res, list(shape) + [n_indices])


def squeeze_2x2(input_):
    """2x2 space-to-depth: a thin wrapper around squeeze_nxn with factor 2."""
    factor = 2
    return squeeze_nxn(input_, n_factor=factor)


def squeeze_nxn(input_, n_factor=2):
    """Space-to-depth: fold each n_factor x n_factor spatial patch into channels.

    Python numbers are returned unchanged. A tensor of static shape
    [batch, height, width, channels] becomes
    [batch, height // n_factor, width // n_factor, channels * n_factor ** 2].
    The new channel index is ordered as (channel, patch_row, patch_col).

    Raises:
        ValueError: if height or width is not divisible by n_factor.
    """
    if isinstance(input_, (float, int)):
        return input_
    dims = input_.get_shape().as_list()
    n, rows, cols, depth = dims[0], dims[1], dims[2], dims[3]
    if rows % n_factor != 0:
        raise ValueError("Height not divisible by %d." % n_factor)
    if cols % n_factor != 0:
        raise ValueError("Width not divisible by %d." % n_factor)
    out_rows = rows // n_factor
    out_cols = cols // n_factor
    patches = tf.reshape(
        input_, [n, out_rows, n_factor, out_cols, n_factor, depth])
    # Reorder the axes to [batch, out_rows, out_cols, channels, patch_row, patch_col].
    patches = tf.transpose(patches, [0, 1, 3, 5, 2, 4])
    return tf.reshape(
        patches, [n, out_rows, out_cols, depth * n_factor * n_factor])


def unsqueeze_2x2(input_):
    """Depth-to-space: the inverse of squeeze_2x2.

    Python numbers are returned unchanged. A tensor of static shape
    [batch, height, width, channels] becomes
    [batch, 2 * height, 2 * width, channels // 4]. The input channel index is
    read as (channel, patch_row, patch_col).

    Raises:
        ValueError: if channels is not divisible by 4.
    """
    if isinstance(input_, (float, int)):
        return input_
    dims = input_.get_shape().as_list()
    n, rows, cols, depth = dims[0], dims[1], dims[2], dims[3]
    if depth % 4 != 0:
        raise ValueError("Number of channels not divisible by 4.")
    base_depth = depth // 4
    patches = tf.reshape(input_, [n, rows, cols, base_depth, 2, 2])
    # Reorder the axes to [batch, rows, patch_row, cols, patch_col, channels].
    patches = tf.transpose(patches, [0, 1, 4, 2, 5, 3])
    return tf.reshape(patches, [n, 2 * rows, 2 * cols, base_depth])


# batch norm
def batch_norm(input_,
               dim,
               name,
               scale=True,
               train=True,
               epsilon=1e-8,
               decay=.1,
               axes=[0],
               bn_lag=DEFAULT_BN_LAG):
    """Batch normalization with optional lagged moments.

    Creates the following variables under variable scope `name`:
    non-trainable `var` (init 1), `mean` (init 0) and a scalar `step`
    (init 0); trainable `beta` (init 0); and, if `scale`, trainable `gamma`
    (init 1). All except `step` have size `dim`.

    Args:
        input_: tensor to normalize. Moments are taken over `axes`.
        dim: size of the per-feature variables. Presumably the size of the
            axis left after reducing over `axes` -- TODO confirm with callers.
        name: variable scope and name scope for the created variables/ops.
        scale: if True, multiply the normalized output by `gamma`.
        train: if True, normalize with batch moments and update the moving
            averages. Otherwise normalize with the stored `mean`/`var`.
        epsilon: added to the variance before taking the square root.
        decay: moving-average update rate:
            moving <- moving - decay * (moving - batch_moment).
        axes: axes that tf.nn.moments reduces over.
        bn_lag: if > 0 (training only), blend the batch moments with the
            stored moving averages before normalizing.
    Returns:
        The normalized tensor, times `gamma` when `scale` is set, plus `beta`.
        When `train` is True, evaluating it also runs the moving-average and
        step update ops.
    """
    # create variables
    with tf.variable_scope(name):
        var = variable_on_cpu(
            "var", [dim], tf.constant_initializer(1.), trainable=False)
        mean = variable_on_cpu(
            "mean", [dim], tf.constant_initializer(0.), trainable=False)
        step = variable_on_cpu("step", [], tf.constant_initializer(0.), trainable=False)
        if scale:
            gamma = variable_on_cpu("gamma", [dim], tf.constant_initializer(1.))
        beta = variable_on_cpu("beta", [dim], tf.constant_initializer(0.))
    # choose the appropriate moments
    if train:
        used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
        # cur_* keep the raw batch moments for the moving-average update below.
        # The `-=` / `/=` lines rebind used_* to new tensors, leaving cur_*
        # untouched.
        cur_mean, cur_var = used_mean, used_var
        if bn_lag > 0.:
            # used = bn_lag * batch + (1 - bn_lag) * moving, with no gradient
            # flowing into the moving averages, then rescaled by
            # 1 / (1 - bn_lag**(step + 1)).
            # NOTE(review): at step 0, with `mean` initialized to 0, this gives
            # bn_lag / (1 - bn_lag) * batch_mean rather than batch_mean.
            # Confirm the intended weighting.
            used_mean -= (1. - bn_lag) * (used_mean - tf.stop_gradient(mean))
            used_var -= (1 - bn_lag) * (used_var - tf.stop_gradient(var))
            used_mean /= (1. - bn_lag**(step + 1))
            used_var /= (1. - bn_lag**(step + 1))
    else:
        used_mean, used_var = mean, var
        cur_mean, cur_var = used_mean, used_var

    # normalize
    res = (input_ - used_mean) / tf.sqrt(used_var + epsilon)
    # de-normalize
    if scale:
        res *= gamma
    res += beta

    # update variables
    if train:
        # moving <- moving - decay * (moving - batch). check_numerics makes the
        # update fail when the increment contains NaN or Inf.
        with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
            with ops.colocate_with(mean):
                new_mean = tf.assign_sub(
                    mean,
                    tf.check_numerics(decay * (mean - cur_mean), "NaN in moving mean."))
        with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
            with ops.colocate_with(var):
                new_var = tf.assign_sub(
                    var,
                    tf.check_numerics(decay * (var - cur_var),
                                      "NaN in moving variance."))
        with tf.name_scope(name, "IncrementTime", [step]):
            with ops.colocate_with(step):
                new_step = tf.assign_add(step, 1.)
        # Adds zero numerically, but makes `res` depend on the three update
        # ops, so they run whenever `res` is evaluated.
        res += 0. * new_mean * new_var * new_step

    return res


# batch normalization taking into account the volume transformation
def batch_norm_log_diff(input_,
                        dim,
                        name,
                        train=True,
                        epsilon=1e-8,
                        decay=.1,
                        axes=[0],
                        reuse=None,
                        bn_lag=DEFAULT_BN_LAG):
    """Batch normalization moments for a volume-tracking (log-det) transform.

    Unlike batch_norm, this does not normalize `input_`. It returns the mean
    and variance to normalize with; the variance already includes `epsilon`.
    The log-determinant Jacobian is not computed here. Presumably the caller
    derives it from the returned variance -- TODO confirm against callers.

    Creates non-trainable `var` (init 1), `mean` (init 0) and a scalar
    `step` (init 0) under variable scope `name`.

    Args:
        input_: tensor whose moments are taken over `axes`.
        dim: size of the `mean`/`var` variables.
        name: variable scope and name scope for the created variables/ops.
        train: if True, use batch moments and update the moving averages.
            Otherwise return the stored `mean`/`var`.
        epsilon: added to the returned variance.
        decay: moving-average update rate:
            moving <- moving - decay * (moving - batch_moment).
        axes: axes that the moments reduce over.
        reuse: whether to reuse existing variables in scope `name`.
            When None, it defaults to `not train`.
        bn_lag: if > 0 (training only), recompute the variance with
            stable_var and blend the batch moments with the moving averages.
    Returns:
        A (mean, variance + epsilon) tuple. When `train` is True, the variance
        tensor also depends on the moving-average and step update ops.
    """
    if reuse is None:
        reuse = not train
    # create variables
    with tf.variable_scope(name) as scope:
        if reuse:
            scope.reuse_variables()
        var = variable_on_cpu(
            "var", [dim], tf.constant_initializer(1.), trainable=False)
        mean = variable_on_cpu(
            "mean", [dim], tf.constant_initializer(0.), trainable=False)
        step = variable_on_cpu("step", [], tf.constant_initializer(0.), trainable=False)
    # choose the appropriate moments
    if train:
        used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
        # cur_* hold the batch moments that feed the moving-average update.
        cur_mean, cur_var = used_mean, used_var
        if bn_lag > 0.:
            # Replace the variance with the max-rescaled estimate. That
            # estimate also feeds the moving-variance update (cur_var).
            used_var = stable_var(input_=input_, mean=used_mean, axes=axes)
            cur_var = used_var
            # used = bn_lag * batch + (1 - bn_lag) * moving, with no gradient
            # flowing into the moving averages, then rescaled by
            # 1 / (1 - bn_lag**(step + 1)).
            # NOTE(review): same weighting as batch_norm; at step 0 with
            # `mean` initialized to 0 the mean becomes
            # bn_lag / (1 - bn_lag) * batch_mean. Confirm the intent.
            used_mean -= (1 - bn_lag) * (used_mean - tf.stop_gradient(mean))
            used_mean /= (1. - bn_lag**(step + 1))
            used_var -= (1 - bn_lag) * (used_var - tf.stop_gradient(var))
            used_var /= (1. - bn_lag**(step + 1))
    else:
        used_mean, used_var = mean, var
        cur_mean, cur_var = used_mean, used_var

    # update variables
    if train:
        # moving <- moving - decay * (moving - batch). check_numerics makes the
        # update fail when the increment contains NaN or Inf.
        with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
            with ops.colocate_with(mean):
                new_mean = tf.assign_sub(
                    mean,
                    tf.check_numerics(
                        decay * (mean - cur_mean), "NaN in moving mean."))
        with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
            with ops.colocate_with(var):
                new_var = tf.assign_sub(
                    var,
                    tf.check_numerics(decay * (var - cur_var),
                                      "NaN in moving variance."))
        with tf.name_scope(name, "IncrementTime", [step]):
            with ops.colocate_with(step):
                new_step = tf.assign_add(step, 1.)
        # Adds zero numerically, but makes the returned variance depend on
        # the update ops, so they run whenever it is evaluated.
        used_var += 0. * new_mean * new_var * new_step
    used_var += epsilon

    return used_mean, used_var


def convnet(input_,
            dim_in,
            dim_hid,
            filter_sizes,
            dim_out,
            name,
            use_batch_norm=True,
            train=True,
            nonlinearity=tf.nn.relu):
    """Chaining of convolutional layers.

    Builds one conv layer per entry of `dim_hid`. Each is optionally followed
    by batch_norm and then `nonlinearity`. A final linear conv layer named
    "out" produces `dim_out` channels.

    Args:
        input_: input tensor.
        dim_in: number of input channels.
        dim_hid: sequence of hidden-layer channel counts. May be empty, in
            which case only the output layer is built.
        filter_sizes: per-layer filter sizes. Entry i is used by hidden layer
            i, and the last entry by the output layer.
        dim_out: number of output channels.
        name: variable scope name.
        use_batch_norm: if True, batch-normalize hidden layers. Otherwise the
            hidden conv layers get a bias instead.
        train: passed to batch_norm.
        nonlinearity: callable applied after each hidden layer, or None.
    Returns:
        The output tensor of the final conv layer.
    """
    # list(): allow tuples, and make the concatenation below well-defined.
    dim_hid = list(dim_hid)
    dims_in = [dim_in] + dim_hid[:-1]
    dims_out = dim_hid
    res = input_

    bias = not use_batch_norm
    with tf.variable_scope(name):
        # range (not xrange) works on both Python 2 and 3.
        for layer_idx in range(len(dim_hid)):
            res = conv_layer(
                input_=res,
                filter_size=filter_sizes[layer_idx],
                dim_in=dims_in[layer_idx],
                dim_out=dims_out[layer_idx],
                name="h_%d" % layer_idx,
                stddev=1e-2,
                nonlinearity=None,
                bias=bias)
            if use_batch_norm:
                res = batch_norm(
                    input_=res,
                    dim=dims_out[layer_idx],
                    name="bn_%d" % layer_idx,
                    # gamma is only learned when the nonlinearity is relu.
                    scale=(nonlinearity == tf.nn.relu),
                    train=train,
                    epsilon=1e-8,
                    axes=[0, 1, 2])
            if nonlinearity is not None:
                res = nonlinearity(res)

        # With no hidden layers, the output layer reads the input channels.
        last_dim = dims_out[-1] if dims_out else dim_in
        res = conv_layer(
            input_=res,
            filter_size=filter_sizes[-1],
            dim_in=last_dim,
            dim_out=dim_out,
            name="out",
            stddev=1e-2,
            nonlinearity=None)

    return res


# distributions
# log-likelihood estimation
def standard_normal_ll(input_):
    """Elementwise log-density of the standard normal, log N(input_; 0, 1)."""
    log_two_pi = numpy.log(2. * numpy.pi)
    return -.5 * (tf.square(input_) + log_two_pi)


def standard_normal_sample(shape):
    """Draw a tensor of the given shape with i.i.d. N(0, 1) entries."""
    samples = tf.random_normal(shape)
    return samples


# (2, 2, 1, 4) float64 filter used by squeeze_2x2_ordered. Spatial position
# (row, col) of a 2x2 patch is routed to output slot
# (0, 0) -> 0, (0, 1) -> 2, (1, 0) -> 3, (1, 1) -> 1.
SQUEEZE_MATRIX = numpy.eye(4)[[0, 2, 3, 1]].reshape(2, 2, 1, 4)


def squeeze_2x2_ordered(input_, reverse=False):
    """Squeezing operation with a controlled ordering.

    Implemented as a stride-2 convolution (or, with `reverse`, a transposed
    convolution) with a fixed 0/1 filter built from SQUEEZE_MATRIX.

    Args:
        input_: tensor with static shape [batch, height, width, channels].
        reverse: if False, squeeze 2x2 spatial patches into 4x the channels.
            If True, undo that: 4 * C channels go back to C channels at twice
            the spatial size.
    Returns:
        The squeezed (or unsqueezed) tensor.
    Raises:
        ValueError: if reverse and channels is not divisible by 4, or if not
            reverse and height/width is odd.
    """
    shape = input_.get_shape().as_list()
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]
    channels = shape[3]
    if reverse:
        if channels % 4 != 0:
            raise ValueError("Number of channels not divisible by 4.")
        # Integer division: `/` yields a float on Python 3, which
        # numpy.zeros, range() and output_shape all reject.
        channels //= 4
    else:
        if height % 2 != 0:
            raise ValueError("Height not divisible by 2.")
        if width % 2 != 0:
            raise ValueError("Width not divisible by 2.")
    # Block-diagonal filter: input channel c maps to output channels
    # 4c .. 4c + 3 according to SQUEEZE_MATRIX.
    weights = numpy.zeros((2, 2, channels, 4 * channels))
    for idx_ch in range(channels):
        weights[:, :, idx_ch:idx_ch + 1, 4 * idx_ch:4 * (idx_ch + 1)] = SQUEEZE_MATRIX
    # Reorder the output channels as [0, 4, 8, ..., 1, 5, 9, ..., 2, ..., 3, ...],
    # grouping channels by their SQUEEZE_MATRIX slot.
    shuffle_channels = numpy.arange(4 * channels).reshape(channels, 4).T.ravel()
    weights = weights[:, :, :, shuffle_channels].astype("float32")
    if reverse:
        res = tf.nn.conv2d_transpose(
            value=input_,
            filter=weights,
            output_shape=[batch_size, height * 2, width * 2, channels],
            strides=[1, 2, 2, 1],
            padding="SAME",
            name="unsqueeze_2x2")
    else:
        res = tf.nn.conv2d(
            input=input_,
            filter=weights,
            strides=[1, 2, 2, 1],
            padding="SAME",
            name="squeeze_2x2")

    return res