# Based on the WaveGAN code (https://github.com/chrisdonahue/wavegan)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

def tf_repeat(output, idx, dim1, dim2, bias):
    # Tile the per-example label `idx` (shape [batch, 1]) to [batch, dim1, dim2]
    # (the tensor analogue of np.repeat) and gate the feature map by
    # element-wise multiplication; a no-op when `bias` is False.
    if bias:
        idx = tf.tile(idx, [1, dim1 * dim2])
        idx = tf.reshape(idx, [-1, dim1, dim2])
        return output * idx
    else:
        return output
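
# Illustrative usage sketch (not part of the original code): `labels` is assumed
# to hold one scalar per example, shape [batch, 1], which tf_repeat broadcasts
# over the feature map and applies as a multiplicative gate.
def _tf_repeat_example():
    labels = tf.placeholder(tf.float32, [None, 1])          # per-example scalar label
    feats = tf.placeholder(tf.float32, [None, 4096, 64])    # dummy conv activations
    gated = tf_repeat(feats, labels, 4096, 64, bias=True)   # -> [batch, 4096, 64]
    return gated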

def conv1d_transpose(
    inputs,
    filters,
    kernel_width,
    stride=4,
    padding='same',
    upsample='zeros'):
    if upsample == 'zeros':
        # 1D transposed convolution emulated as a 2D transposed convolution on a
        # height-1 "image": [batch, len, ch] -> [batch, 1, len, ch] -> back.
        return tf.layers.conv2d_transpose(
            tf.expand_dims(inputs, axis=1),
            filters,
            (1, kernel_width),
            strides=(1, stride),
            padding=padding
        )[:, 0]
    else:
        raise NotImplementedError
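
# Illustrative usage sketch (not part of the original code): with 'same' padding
# the transposed convolution multiplies the time length by the stride,
# e.g. [batch, 16, 1024] -> [batch, 64, 512] below.
def _conv1d_transpose_example():
    feats = tf.placeholder(tf.float32, [None, 16, 1024])   # [batch, time, channels]
    up = conv1d_transpose(feats, 512, 25, stride=4)         # -> [batch, 64, 512]
    return up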


def lrelu(inputs, alpha=0.2):
    return tf.maximum(alpha * inputs, inputs)


def apply_phaseshuffle(x, rad, pad_type='reflect'):
    # Phase shuffle (from WaveGAN): shift the signal in time by a single random
    # offset in [-rad, rad], pad at the border, and crop back to the original
    # length.
    b, x_len, nch = x.get_shape().as_list()

    phase = tf.random_uniform([], minval=-rad, maxval=rad + 1, dtype=tf.int32)
    pad_l = tf.maximum(phase, 0)
    pad_r = tf.maximum(-phase, 0)
    phase_start = pad_r
    x = tf.pad(x, [[0, 0], [pad_l, pad_r], [0, 0]], mode=pad_type)

    x = x[:, phase_start:phase_start + x_len]
    x.set_shape([b, x_len, nch])

    return x
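
# Illustrative usage sketch (not part of the original code): the whole batch is
# shifted in time by one random offset of at most `rad` samples; the output has
# the same static shape as the input.
def _phaseshuffle_example():
    feats = tf.placeholder(tf.float32, [None, 4096, 64])   # [batch, time, channels]
    shuffled = apply_phaseshuffle(feats, rad=2)             # -> [batch, 4096, 64]
    return shuffled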


"""
  Input: [None, 8192, 1]
  Output: [None] (linear output)
"""
def discriminator_wavegan(
    x,
    labels,
    kernel_len=25,
    dim=64,
    use_batchnorm=True,
    phaseshuffle_rad=0,
    reuse=False,
    scope='Discriminator',
    bias=False):
    with tf.variable_scope(scope, reuse=reuse):
        batch_size = tf.shape(x)[0]

        if use_batchnorm:
            batchnorm = lambda x: tf.layers.batch_normalization(x, training=True)
        else:
            batchnorm = lambda x: x

        if phaseshuffle_rad > 0:
            phaseshuffle = lambda x: apply_phaseshuffle(x, phaseshuffle_rad)
        else:
            phaseshuffle = lambda x: x

        with tf.variable_scope('discriminator_0', reuse=reuse):
            # Layer 0
            # [8192, 1] -> [4096, 64]
            output = x
            output = tf.layers.conv1d(output, dim, kernel_len, 2, padding='SAME', name='downconv_0')
            output = tf_repeat(output, labels, 4096, dim, bias)
            output = lrelu(output)
            output = phaseshuffle(output)

            # Layer 1
            # [4096, 64] -> [1024, 128]
            output = tf.layers.conv1d(output, dim * 2, kernel_len, 4, padding='SAME', name='downconv_1')
            output = tf_repeat(output, labels, 1024, dim * 2, bias)
            output = batchnorm(output)
            output = lrelu(output)
            output = phaseshuffle(output)

            # Layer 2
            # [1024, 128] -> [256, 256]
            output = tf.layers.conv1d(output, dim * 4, kernel_len, 4, padding='SAME', name='downconv_2')
            output = tf_repeat(output, labels, 256, dim * 4, bias)
            output = batchnorm(output)
            output = lrelu(output)
            output = phaseshuffle(output)

            # Layer 3
            # [256, 256] -> [64, 512]
            output = tf.layers.conv1d(output, dim * 8, kernel_len, 4, padding='SAME', name='downconv_3')
            output = tf_repeat(output, labels, 64, dim * 8, bias)
            output = batchnorm(output)
            output = lrelu(output)
            output = phaseshuffle(output)

            # Layer 4
            # [64, 512] -> [16, 1024]
            output = tf.layers.conv1d(output, dim * 16, kernel_len, 4, padding='SAME', name='downconv_4')
            output = tf_repeat(output, labels, 16, dim * 16, bias)
            output = batchnorm(output)
            output = lrelu(output)

        # Flatten: [16, 1024] -> [16384]
        output = tf.reshape(output, [batch_size, 16 * dim * 16])

        # Connect to single logit
        with tf.variable_scope('output', reuse=reuse):
            output = tf.layers.dense(output, 1)[:, 0]

        # Unlike the generator, there is no need to aggregate batchnorm update
        # ops here: the discriminator is only used during training.

        return output
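
# Illustrative usage sketch (not part of the original code): the label shape
# [batch, 1] is an assumption of this example, matching what tf_repeat expects.
def _discriminator_example():
    x = tf.placeholder(tf.float32, [None, 8192, 1])    # waveform batch
    labels = tf.placeholder(tf.float32, [None, 1])      # per-example scalar label
    D_real = discriminator_wavegan(x, labels, phaseshuffle_rad=2)  # -> [batch]
    return D_real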


"""
  Input: [None, 100]
  Output: [None, 8192, 1]
"""
def generator_wavegan(
    z,
    labels,
    kernel_len=25,
    dim=64,
    use_batchnorm=True,
    upsample='zeros',
    train=False,
    scope='Generator',
    bias=False):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        batch_size = tf.shape(z)[0]

        if use_batchnorm:
            batchnorm = lambda x: tf.layers.batch_normalization(x, training=train)
        else:
            batchnorm = lambda x: x

        # FC and reshape for convolution
        # [100] -> [16, 1024]
        output = z
        with tf.variable_scope('z_project'):
            output = tf.layers.dense(output, 4 * 4 * dim * 16)
            output = tf.reshape(output, [batch_size, 16, dim * 16])
            output = batchnorm(output)
        output = tf_repeat(output, labels, 16, dim * 16, bias)
        output = tf.nn.relu(output)

        # Layer 0
        # [16, 1024] -> [64, 512]
        with tf.variable_scope('upconv_0'):
            output = conv1d_transpose(output, dim * 8, kernel_len, 4, upsample=upsample)
            output = batchnorm(output)
        output = tf_repeat(output, labels, 64, dim * 8, bias)
        output = tf.nn.relu(output)

        # Layer 1
        # [64, 512] -> [256, 256]
        with tf.variable_scope('upconv_1'):
            output = conv1d_transpose(output, dim * 4, kernel_len, 4, upsample=upsample)
            output = batchnorm(output)
        output = tf_repeat(output, labels, 256, dim * 4, bias)
        output = tf.nn.relu(output)

        # Layer 2
        # [256, 256] -> [1024, 128]
        with tf.variable_scope('upconv_2'):
            output = conv1d_transpose(output, dim * 2, kernel_len, 4, upsample=upsample)
            output = batchnorm(output)
        output = tf_repeat(output, labels, 1024, dim * 2, bias)
        output = tf.nn.relu(output)

        # Layer 3
        # [1024, 128] -> [4096, 64]
        with tf.variable_scope('upconv_3'):
            output = conv1d_transpose(output, dim, kernel_len, 4, upsample=upsample)
            output = batchnorm(output)
        output = tf_repeat(output, labels, 4096, dim, bias)
        output = tf.nn.relu(output)

        # Layer 4
        # [4096, 64] -> [8192, 1]
        with tf.variable_scope('upconv_4'):
            # Stride 2 here (the original 16384-sample WaveGAN uses stride 4)
            # so the output length is 8192 rather than 16384.
            output = conv1d_transpose(output, 1, kernel_len, 2, upsample=upsample)
        output = tf_repeat(output, labels, 8192, 1, bias)
        output = tf.nn.tanh(output)

        # Automatically update batchnorm moving averages every time G is used during training
        if train and use_batchnorm:
            # Expect exactly 10 ops: 5 batchnorm layers x 2 moving statistics each.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if len(update_ops) != 10:
                raise Exception('Other update ops found in graph')
            with tf.control_dependencies(update_ops):
                output = tf.identity(output)

        return output
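
# Illustrative usage sketch (not part of the original code): build the generator
# from a [batch, 100] latent and score its output with the discriminator. The
# label shape [batch, 1] is an assumption of this example.
def _generator_example():
    z = tf.placeholder(tf.float32, [None, 100])
    labels = tf.placeholder(tf.float32, [None, 1])
    G_z = generator_wavegan(z, labels, train=True)                     # -> [batch, 8192, 1]
    D_fake = discriminator_wavegan(G_z, labels, reuse=tf.AUTO_REUSE)   # -> [batch]
    return G_z, D_fake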