import math
import numpy as np
import tensorflow as tf
from contextlib import contextmanager

from tensorflow.python.framework import ops

from utils import *

# Short alias for the TF-Slim namespace used by the layer helpers below.
slim = tf.contrib.slim
# Module-level NumPy random state created with a fixed seed, so values drawn
# from it are reproducible from run to run.
rng = np.random.RandomState([2016, 6, 1])
# Normalization layers: `ln` is layer normalization, `bn` is batch
# normalization (the default `norm` of fc / conv2d / deconv2d).
ln = tf.contrib.layers.layer_norm
bn = slim.batch_norm

def conv_cond_concat(x, y):
    """Broadcast the conditioning tensor `y` over `x` and append it on axis 3.

    `y` is multiplied by a ones tensor of shape
    [x.shape[0], x.shape[1], x.shape[2], y.shape[3]], which expands it to the
    first three dimensions of `x`; the product is then concatenated to `x`
    along axis 3.
    """
    x_dims = x.get_shape()
    y_dims = y.get_shape()
    tiling = tf.ones([x_dims[0], x_dims[1], x_dims[2], y_dims[3]])
    tiled_y = y * tiling
    return tf.concat([x, tiled_y], 3)

def lrelu(x, leak=0.2, name="lrelu"):
    """Leaky ReLU written as a weighted sum of `x` and `abs(x)`.

    The blend 0.5*(1+leak)*x + 0.5*(1-leak)*|x| evaluates to `x` where
    x >= 0 and to `leak * x` where x < 0.
    """
    with tf.variable_scope(name):
        linear_weight = (1 + leak) * 0.5
        abs_weight = (1 - leak) * 0.5
        return linear_weight * x + abs_weight * abs(x)

def sin_and_cos(x, name="ignored"):
    """Return sin(x) and cos(x) concatenated along the last axis of `x`.

    The last dimension of the result is twice that of `x`. `name` is accepted
    for signature compatibility with the other activations and is not used.
    """
    last_axis = len(x.get_shape()) - 1
    # tf.concat takes (values, axis), matching every other call in this file;
    # the previous (axis, values) order is the pre-1.0 TensorFlow signature.
    return tf.concat([tf.sin(x), tf.cos(x)], last_axis)

def maxout(x, k=2):
    """Maxout activation: max over groups of `k` consecutive channels.

    The last dimension of `x` (which must be statically known and divisible
    by `k`) is split into [channels // k, k] and reduced with a max over the
    new trailing axis, so the output has channels // k channels.
    """
    shape = [int(e) for e in x.get_shape()]
    ax = len(shape)  # index of the group axis appended below
    ch = shape[-1]
    assert ch % k == 0
    # Integer division keeps the dimension an int on Python 3 as well.
    shape[-1] = ch // k
    # Without this extra axis the reshape would drop elements and `ax` would
    # point past the last dimension.
    shape.append(k)
    x = tf.reshape(x, shape)
    return tf.reduce_max(x, ax)

def offset_maxout(x, k=2):
    """Maxout over groups of `k` channels, shifted by an estimated offset.

    Works like `maxout`, then subtracts a constant: the mean, over 1000
    draws from the module-level `rng`, of the max of `k` standard-normal
    samples. Each call advances the state of `rng`.
    """
    shape = [int(e) for e in x.get_shape()]
    ax = len(shape)  # index of the group axis appended below
    ch = shape[-1]
    assert ch % k == 0
    # Integer division keeps the dimension an int on Python 3 as well.
    shape[-1] = ch // k
    # Without this extra axis the reshape would drop elements and `ax` would
    # point past the last dimension.
    shape.append(k)
    x = tf.reshape(x, shape)
    # Monte Carlo estimate of E[max of k standard normals].
    ofs = rng.randn(1000, k).max(axis=1).mean()
    return tf.reduce_max(x, ax) - ofs

def lrelu_sq(x):
    """Concatenate lrelu(x) and a capped square of `x` along the last axis.

    The second half is min(|x|, x**2). The last dimension of the result is
    twice that of `x`.
    """
    dim = len(x.get_shape()) - 1
    return tf.concat([lrelu(x), tf.minimum(tf.abs(x), tf.square(x))], dim)

def nin(input_, output_size, name=None, mean=0., stddev=0.02, bias_start=0.0, with_w=False):
    """Apply a linear layer to the last axis of `input_`.

    All leading dimensions are flattened into one, `fc` is applied without
    activation or normalization, and the leading dimensions are restored with
    the last one replaced by `output_size`. The static shape of `input_` must
    be fully known.

    `name`, `mean`, `stddev`, `bias_start` and `with_w` are accepted for
    signature compatibility and are not used.
    """
    s = list(map(int, input_.get_shape()))
    flat_rows = int(np.prod(s[:-1]))
    input_ = tf.reshape(input_, [flat_rows, s[-1]])
    # `fc` requires `is_training`, but only reads it when norm is batch norm,
    # so the value passed here has no effect.
    input_ = fc(input_, output_size, is_training=False, act=None, norm=None)
    return tf.reshape(input_, s[:-1] + [output_size])

@contextmanager
def variables_on_cpu():
    """Context manager that places `tf.get_variable` calls on "/cpu:0".

    While the `with` block is active, `tf.get_variable` is replaced by a
    wrapper that runs the original function inside `tf.device("/cpu:0")`.
    The original function is restored on exit, including when the block
    raises.
    """
    old_fn = tf.get_variable

    def new_fn(*args, **kwargs):
        with tf.device("/cpu:0"):
            return old_fn(*args, **kwargs)

    tf.get_variable = new_fn
    try:
        yield
    finally:
        # Always undo the monkey patch so it cannot leak past the block.
        tf.get_variable = old_fn

@contextmanager
def variables_on_gpu0():
    """Context manager that places `tf.get_variable` calls on "/gpu:0".

    While the `with` block is active, `tf.get_variable` is replaced by a
    wrapper that runs the original function inside `tf.device("/gpu:0")`.
    The original function is restored on exit, including when the block
    raises.
    """
    old_fn = tf.get_variable

    def new_fn(*args, **kwargs):
        with tf.device("/gpu:0"):
            return old_fn(*args, **kwargs)

    tf.get_variable = new_fn
    try:
        yield
    finally:
        # Always undo the monkey patch so it cannot leak past the block.
        tf.get_variable = old_fn

def avg_grads(tower_grads):
    """Calculate the average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer
            list has one entry per tower; each inner list holds that tower's
            (gradient, variable) pairs, in the same variable order for every
            tower.

    Returns:
        List of pairs of (gradient, variable) where the gradient has been
        averaged across all towers. Empty when `tower_grads` is empty.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
        # Add a leading 'tower' dimension to every gradient so they can be
        # concatenated and averaged over that dimension.
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]

        # Average over the 'tower' dimension.
        grad = tf.reduce_mean(tf.concat(grads, 0), 0)

        # The Variables are redundant because they are shared across towers,
        # so return the first tower's pointer to the Variable.
        v = grad_and_vars[0][1]
        average_grads.append((grad, v))
    return average_grads

def decayer(x, name="decayer"):
    """ReLU of `x` multiplied by a learned scalar and divided by a learned term.

    Creates two scalar variables (shape [1], initialised to 1) in scope
    `name`: "scale" and "decay_scale". The output is
    scale * relu(x) / (1 + |decay_scale| * decay_scale**2).

    NOTE(review): the denominator depends only on `decay_scale`, not on `x` --
    confirm this is intended.
    """
    with tf.variable_scope(name):
        scale = tf.get_variable(
            "scale", [1], initializer=tf.constant_initializer(1.))
        decay_scale = tf.get_variable(
            "decay_scale", [1], initializer=tf.constant_initializer(1.))
        rectified = tf.nn.relu(x)
        numerator = scale * rectified
        denominator = 1. + tf.abs(decay_scale) * tf.square(decay_scale)
        return numerator / denominator

def decayer2(x, name="decayer"):
    """Per-channel variant of `decayer`.

    Creates "scale" and "decay_scale" variables in scope `name`, each with one
    entry per element of the last dimension of `x` and initialised to 1. The
    output is scale * relu(x) / (1 + |decay_scale| * decay_scale**2).

    NOTE(review): the denominator depends only on `decay_scale`, not on `x` --
    confirm this is intended.
    """
    with tf.variable_scope(name):
        scale = tf.get_variable(
            "scale", [int(x.get_shape()[-1])],
            initializer=tf.constant_initializer(1.))
        decay_scale = tf.get_variable(
            "decay_scale", [int(x.get_shape()[-1])],
            initializer=tf.constant_initializer(1.))
        rectified = tf.nn.relu(x)
        numerator = scale * rectified
        denominator = 1. + tf.abs(decay_scale) * tf.square(decay_scale)
        return numerator / denominator

def masked_relu(x, name="ignored"):
    """Gate the ReLU of the first half of the channels by the second half.

    The last dimension of `x` (statically known and even) is cut in two; the
    result is relu(first half) * sigmoid(second half), so the output has half
    as many channels as `x`. `name` is not used.
    """
    dims = [int(d) for d in x.get_shape()]
    leading_dims = dims[:-1]
    origin = [0] * len(leading_dims)
    assert dims[-1] % 2 == 0
    width = dims[-1] // 2
    values = tf.slice(x, origin + [0], leading_dims + [width])
    gates = tf.slice(x, origin + [width], leading_dims + [width])
    return tf.nn.relu(values) * tf.nn.sigmoid(gates)

def make_z(shape, minval=-1.0, maxval=1.0, name="z"):
    """Return a float32 tensor of `shape` sampled uniformly from [minval, maxval)."""
    return tf.random_uniform(
        shape, minval=minval, maxval=maxval, name=name, dtype=tf.float32)

def get_sample_zs(model):
    """Collect `model.sample_size` latent samples, one batch at a time.

    The per-step batch size is `model.batch_size`, divided by the number of
    devices when `model.config.multigpu` is set. Batches are gathered until
    `model.sample_size` rows are available and concatenated along axis 0.

    Returns:
        A NumPy array whose first dimension equals `model.sample_size`.
    """
    assert model.sample_size > model.batch_size
    assert model.sample_size % model.batch_size == 0
    if model.config.multigpu:
        batch_size = model.batch_size // len(model.devices)
    else:
        batch_size = model.batch_size

    steps = model.sample_size // batch_size
    assert steps > 0

    sample_zs = []
    for _ in range(steps):
        # NOTE(review): the right-hand side of this assignment was missing in
        # the source. Evaluating the model's latent tensor is the presumed
        # intent; this assumes `model.sess` is a tf.Session and `model.z` is a
        # [batch_size, z_dim] tensor -- confirm against the model class.
        cur_zs = model.sess.run(model.z)
        sample_zs.append(cur_zs)

    sample_zs = np.concatenate(sample_zs, axis=0)
    assert sample_zs.shape[0] == model.sample_size
    return sample_zs

def batch_to_grid(images, width=8):
    """Tile the first width**2 images of a batch into one square grid image.

    The selected images are squeezed, unstacked along axis 0, laid side by
    side along axis 1, cut into `width` equal strips and stacked along axis 0.
    A leading batch axis of size 1 is added, and a trailing axis is appended
    if the result has fewer than 4 dimensions.

    Assumes `images` is laid out as [N, H, W, C] with N >= width**2 -- TODO
    confirm with callers.
    """
    n_tiles = width ** 2
    squeezed = tf.squeeze(images[:n_tiles])
    tiles = tf.unstack(squeezed, num=n_tiles, axis=0)
    strip = tf.concat(tiles, axis=1)
    rows = tf.split(strip, width, axis=1)
    stacked = tf.concat(rows, axis=0)
    grid = tf.expand_dims(stacked, axis=0)
    if len(grid.get_shape().as_list()) >= 4:
        return grid
    return tf.expand_dims(grid, axis=-1)

def fc(x, out_dim, is_training, act=tf.nn.relu, norm=bn, init=tf.truncated_normal_initializer(stddev=0.02)):
    """Fully connected layer built on `slim.fully_connected`.

    Args:
        x: Input tensor.
        out_dim: Number of output units.
        is_training: Forwarded to the normalizer as `is_training`; only used
            when `norm` is batch normalization (`bn`).
        act: Activation function, or None for a linear layer.
        norm: Normalizer function, or None for no normalization.
        init: Weights initializer.

    Returns:
        The output tensor of the layer.
    """
    if norm == bn:
        # Only batch norm takes the training flag; other normalizers (or
        # None) must not receive normalizer_params they do not understand.
        return slim.fully_connected(
            x, out_dim, activation_fn=act, normalizer_fn=norm,
            weights_initializer=init, normalizer_params={'is_training': is_training})
    return slim.fully_connected(
        x, out_dim, activation_fn=act, normalizer_fn=norm, weights_initializer=init)

def deconv2d(x, out_dim, k, s, is_training, act=tf.nn.relu, norm=bn, init=tf.truncated_normal_initializer(stddev=0.02)):
    """Transposed 2-D convolution built on `slim.conv2d_transpose`.

    Args:
        x: Input tensor.
        out_dim: Number of output channels.
        k: Kernel size.
        s: Stride.
        is_training: Forwarded to the normalizer as `is_training`; only used
            when `norm` is batch normalization (`bn`).
        act: Activation function, or None for a linear layer.
        norm: Normalizer function, or None for no normalization.
        init: Weights initializer.

    Returns:
        The output tensor of the layer.
    """
    if norm == bn:
        # Only batch norm takes the training flag; other normalizers (or
        # None) must not receive normalizer_params they do not understand.
        return slim.conv2d_transpose(
            x, out_dim, k, s, activation_fn=act, normalizer_fn=norm,
            weights_initializer=init, normalizer_params={'is_training': is_training})
    return slim.conv2d_transpose(
        x, out_dim, k, s, activation_fn=act, normalizer_fn=norm, weights_initializer=init)

def conv2d(x, out_dim, k, s, is_training, act=tf.nn.relu, norm=bn, init=tf.truncated_normal_initializer(stddev=0.02)):
    """2-D convolution built on `slim.conv2d`.

    Args:
        x: Input tensor.
        out_dim: Number of output channels.
        k: Kernel size.
        s: Stride.
        is_training: Forwarded to the normalizer as `is_training`; only used
            when `norm` is batch normalization (`bn`).
        act: Activation function, or None for a linear layer.
        norm: Normalizer function, or None for no normalization.
        init: Weights initializer.

    Returns:
        The output tensor of the layer.
    """
    if norm == bn:
        # Only batch norm takes the training flag; other normalizers (or
        # None) must not receive normalizer_params they do not understand.
        return slim.conv2d(
            x, out_dim, k, s, activation_fn=act, normalizer_fn=norm,
            weights_initializer=init, normalizer_params={'is_training': is_training})
    return slim.conv2d(
        x, out_dim, k, s, activation_fn=act, normalizer_fn=norm, weights_initializer=init)

def preprocess_image(image, dataset, use_augmentation=False):
    """Scale an image to [0, 1], or to [-1, 1] for most datasets.

    The image is divided by 255, optionally jittered with random brightness
    (max delta 16/255) and random contrast (factor in [0.8, 1.2]), and clipped
    to [0, 1]. If `dataset` contains neither 'mnist' nor 'fashion', the result
    is then mapped to [-1, 1] via 2 * image - 1.
    """
    scaled = tf.divide(image, 255., name=None)

    if use_augmentation:
        scaled = tf.image.random_brightness(scaled, max_delta=16. / 255.)
        scaled = tf.image.random_contrast(scaled, lower=0.8, upper=1.2)

    clipped = tf.minimum(tf.maximum(scaled, 0.0), 1.0)

    keep_unit_range = ('mnist' in dataset) or ('fashion' in dataset)
    if keep_unit_range:
        return clipped
    return tf.subtract(clipped * 2., 1.)

def conv_mean_pool(x, out_dim, k=3, act=tf.nn.relu, norm=bn, init=tf.truncated_normal_initializer(stddev=0.02), is_training=True):
    """Stride-1 convolution followed by 2x2 mean pooling.

    Args:
        x: Input tensor, indexed as [batch, height, width, channels].
        out_dim: Number of output channels.
        k: Kernel size of the convolution.
        act: Activation for the convolution, or None.
        norm: Normalizer for the convolution, or None.
        init: Weights initializer.
        is_training: Forwarded to `conv2d`, which only uses it for batch
            norm. Defaults to True, the same default `slim.batch_norm` uses.

    Returns:
        The convolution output averaged over non-overlapping 2x2 windows.
    """
    h = conv2d(x, out_dim, k=k, s=1, is_training=is_training, act=act, norm=norm, init=init)
    # Average the four pixels of every 2x2 window via strided slices.
    return tf.add_n([h[:,::2,::2,:], h[:,1::2,::2,:], h[:,::2,1::2,:], h[:,1::2,1::2,:]]) / 4.

def resize_conv2d(x, out_dim, k=3, scale=2, act=tf.nn.relu, norm=bn, init=tf.truncated_normal_initializer(stddev=0.02), is_training=True):
    """Nearest-neighbour upsampling by `scale` followed by a stride-1 convolution.

    Args:
        x: Input tensor, indexed as [batch, height, width, channels].
        out_dim: Number of output channels.
        k: Kernel size of the convolution.
        scale: Integer upsampling factor applied to height and width.
        act: Activation for the convolution, or None.
        norm: Normalizer for the convolution, or None.
        init: Weights initializer.
        is_training: Forwarded to `conv2d`, which only uses it for batch
            norm. Defaults to True, the same default `slim.batch_norm` uses.

    Returns:
        The output tensor of the convolution on the upsampled input.
    """
    # Stacking scale**2 copies on the channel axis and applying depth_to_space
    # writes one copy into each cell of every scale x scale output window,
    # i.e. each input pixel is repeated. For scale=2 this is the original
    # four-copy concat with block size 2.
    h = tf.concat([x] * (scale ** 2), axis=3)
    h = tf.depth_to_space(h, scale)
    return conv2d(h, out_dim, k=k, s=1, is_training=is_training, act=act, norm=norm, init=init)

def _norm_then_act(h, act, norm, is_training):
    """Apply `norm` (if given) and then `act` (if given) to `h`."""
    if norm is not None:
        # Only batch norm understands `is_training`.
        h = norm(h, is_training=is_training) if norm == bn else norm(h)
    if act is not None:
        h = act(h)
    return h

def residual_block(x, resample=None, labels=None, act=tf.nn.relu, norm=bn, init=tf.truncated_normal_initializer(stddev=0.02), is_training=True):
    """Residual block that keeps the channel count and optionally resamples.

    Each variant applies two convolutions, adds a shortcut branch, and then
    applies `norm` followed by `act`. The first convolution always uses the
    layer's default batch normalization rather than `norm`.

    Args:
        x: Input tensor, indexed as [batch, height, width, channels].
        resample: 'down' (2x2 mean pooling on both branches), 'up' (2x
            upsampling on both branches) or None (identity shortcut).
        labels: Not used.
        act: Activation function.
        norm: Normalizer applied to the sum of the two branches.
        init: Weights initializer.
        is_training: Training flag for the batch-norm layers. Defaults to
            True, the same default `slim.batch_norm` uses.

    Returns:
        The output tensor of the block.

    Raises:
        ValueError: If `resample` is not 'down', 'up' or None.
    """
    c_dim = x.get_shape().as_list()[-1]

    if resample == 'down':
        h = conv2d(x, c_dim, 3, 1, is_training, act=act, init=init)
        h = conv_mean_pool(h, c_dim, 3, act=None, norm=None, init=init, is_training=is_training)
        h += conv_mean_pool(x, c_dim, 1, act=None, norm=None, init=init, is_training=is_training)
    elif resample == 'up':
        h = resize_conv2d(x, c_dim, 3, act=act, init=init, is_training=is_training)
        h = conv2d(h, c_dim, 3, 1, is_training, act=None, norm=None, init=init)
        h += resize_conv2d(x, c_dim, 1, act=None, norm=None, init=init, is_training=is_training)
    elif resample is None:
        h = conv2d(x, c_dim, 3, 1, is_training, act=act, init=init)
        h = conv2d(h, c_dim, 3, 1, is_training, act=None, norm=None, init=init)
        h += x
    else:
        raise ValueError('invalid resample value')

    return _norm_then_act(h, act, norm, is_training)

def get_vars_maybe_avg(var_names, ema, **kwargs):
    """Look up each name in `var_names` via `get_var_maybe_avg`.

    `ema` and any extra keyword arguments are forwarded unchanged to every
    lookup. Returns the results as a list, in the order of `var_names`.
    """
    return [get_var_maybe_avg(var_name, ema, **kwargs) for var_name in var_names]

def concat_relu(x):
    """Concatenated ReLU: relu applied to [x, -x] joined on the last axis.

    The last dimension of the result is twice that of `x`.
    """
    last_axis = len(x.get_shape().as_list()) - 1
    both_signs = tf.concat([x, -x], last_axis)
    return tf.nn.relu(both_signs)

def concat_elu(x):
    """Like `concat_relu`, but with ELU: elu applied to [x, -x] joined on the last axis.

    The last dimension of the result is twice that of `x`.
    """
    last_axis = len(x.get_shape().as_list()) - 1
    both_signs = tf.concat([x, -x], last_axis)
    return tf.nn.elu(both_signs)

def fc_wn(x, num_units, nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
    """Fully connected layer with weight normalization.

    Args:
        x: 2-D input tensor; its second dimension must be statically known.
        num_units: Number of output units.
        nonlinearity: Optional activation applied to the output.
        init_scale: Target scale of the pre-activations for the data-dependent
            initialization.
        counters: Dict passed to `get_name` to generate the layer's scope
            name. NOTE: the shared mutable default is kept on purpose --
            switching to a fresh dict per call would change the names
            generated for callers that rely on the default.
        init: If True, create the variables 'V', 'g' and 'b' with 'g' and 'b'
            initialised from the statistics of `x`, and return the
            correspondingly normalised output. If False, look the variables
            up through `get_vars_maybe_avg` and run the regular forward pass.
        ema: Forwarded to `get_vars_maybe_avg` (presumably an exponential
            moving average object, or None -- confirm against `utils`).
        **kwargs: Ignored.

    Returns:
        The output tensor of the layer.
    """
    name = get_name('dense', counters)
    with tf.variable_scope(name):
        if init:
            # data based initialization of parameters
            V = tf.get_variable('V', [int(x.get_shape()[1]),num_units], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True)
            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0])
            x_init = tf.matmul(x, V_norm)
            m_init, v_init = tf.nn.moments(x_init, [0])
            scale_init = init_scale/tf.sqrt(v_init + 1e-10)
            # 'g' and 'b' are created for later (non-init) calls; the init
            # output below is computed directly from the batch statistics.
            g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True)
            b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True)
            x_init = tf.reshape(scale_init,[1,num_units])*(x_init-tf.reshape(m_init,[1,num_units]))
            if nonlinearity is not None:
                x_init = nonlinearity(x_init)
            return x_init

        # Regular forward pass (previously unreachable after the early return).
        V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema)

        # use weight normalization (Salimans & Kingma, 2016)
        x = tf.matmul(x, V)
        scaler = g/tf.sqrt(tf.reduce_sum(tf.square(V),[0]))
        x = tf.reshape(scaler,[1,num_units])*x + tf.reshape(b,[1,num_units])

        # apply nonlinearity
        if nonlinearity is not None:
            x = nonlinearity(x)
        return x

def conv2d_wn(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
    """2-D convolutional layer with weight normalization.

    Args:
        x: Input tensor; its last dimension must be statically known.
        num_filters: Number of output channels.
        filter_size: Two-element list with the kernel height and width.
        stride: Two-element list with the vertical and horizontal stride.
        pad: Padding mode passed to `tf.nn.conv2d`.
        nonlinearity: Optional activation applied to the output.
        init_scale: Target scale of the pre-activations for the data-dependent
            initialization.
        counters: Dict passed to `get_name` to generate the layer's scope
            name. NOTE: the shared mutable default is kept on purpose --
            switching to a fresh dict per call would change the names
            generated for callers that rely on the default.
        init: If True, create the variables 'V', 'g' and 'b' with 'g' and 'b'
            initialised from the statistics of `x`, and return the
            correspondingly normalised output. If False, look the variables
            up through `get_vars_maybe_avg` and run the regular forward pass.
        ema: Forwarded to `get_vars_maybe_avg` (presumably an exponential
            moving average object, or None -- confirm against `utils`).
        **kwargs: Ignored.

    Returns:
        The output tensor of the layer.
    """
    name = get_name('conv2d', counters)
    with tf.variable_scope(name):
        if init:
            # data based initialization of parameters
            V = tf.get_variable('V', filter_size+[int(x.get_shape()[-1]),num_filters], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True)
            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0,1,2])
            x_init = tf.nn.conv2d(x, V_norm, [1]+stride+[1], pad)
            m_init, v_init = tf.nn.moments(x_init, [0,1,2])
            scale_init = init_scale/tf.sqrt(v_init + 1e-8)
            # 'g' and 'b' are created for later (non-init) calls; the init
            # output below is computed directly from the batch statistics.
            g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True)
            b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True)
            x_init = tf.reshape(scale_init,[1,1,1,num_filters])*(x_init-tf.reshape(m_init,[1,1,1,num_filters]))
            if nonlinearity is not None:
                x_init = nonlinearity(x_init)
            return x_init

        # Regular forward pass (previously unreachable after the early return).
        V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema)

        # use weight normalization (Salimans & Kingma, 2016)
        W = tf.reshape(g,[1,1,1,num_filters])*tf.nn.l2_normalize(V,[0,1,2])

        # calculate convolutional layer output
        x = tf.nn.bias_add(tf.nn.conv2d(x, W, [1]+stride+[1], pad), b)

        # apply nonlinearity
        if nonlinearity is not None:
            x = nonlinearity(x)
        return x

def deconv2d_wn(x, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs):
    """Transposed 2-D convolutional layer with weight normalization.

    Args:
        x: Input tensor; its shape is read through `int_shape`.
        num_filters: Number of output channels.
        filter_size: Two-element list with the kernel height and width.
        stride: Two-element list with the vertical and horizontal stride.
        pad: Padding mode passed to `tf.nn.conv2d_transpose`. With 'SAME' the
            output size is input size * stride; otherwise the kernel size
            minus one is added on top of that.
        nonlinearity: Optional activation applied to the output.
        init_scale: Target scale of the pre-activations for the data-dependent
            initialization.
        counters: Dict passed to `get_name` to generate the layer's scope
            name. NOTE: the shared mutable default is kept on purpose --
            switching to a fresh dict per call would change the names
            generated for callers that rely on the default.
        init: If True, create the variables 'V', 'g' and 'b' with 'g' and 'b'
            initialised from the statistics of `x`, and return the
            correspondingly normalised output. If False, look the variables
            up through `get_vars_maybe_avg` and run the regular forward pass.
        ema: Forwarded to `get_vars_maybe_avg` (presumably an exponential
            moving average object, or None -- confirm against `utils`).
        **kwargs: Ignored.

    Returns:
        The output tensor of the layer.
    """
    name = get_name('deconv2d', counters)
    xs = int_shape(x)
    if pad=='SAME':
        target_shape = [xs[0], xs[1]*stride[0], xs[2]*stride[1], num_filters]
    else:
        target_shape = [xs[0], xs[1]*stride[0] + filter_size[0]-1, xs[2]*stride[1] + filter_size[1]-1, num_filters]
    with tf.variable_scope(name):
        if init:
            # data based initialization of parameters
            V = tf.get_variable('V', filter_size+[num_filters,int(x.get_shape()[-1])], tf.float32, tf.random_normal_initializer(0, 0.05), trainable=True)
            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0,1,3])
            x_init = tf.nn.conv2d_transpose(x, V_norm, target_shape, [1]+stride+[1], padding=pad)
            m_init, v_init = tf.nn.moments(x_init, [0,1,2])
            scale_init = init_scale/tf.sqrt(v_init + 1e-8)
            # 'g' and 'b' are created for later (non-init) calls; the init
            # output below is computed directly from the batch statistics.
            g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init, trainable=True)
            b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init*scale_init, trainable=True)
            x_init = tf.reshape(scale_init,[1,1,1,num_filters])*(x_init-tf.reshape(m_init,[1,1,1,num_filters]))
            if nonlinearity is not None:
                x_init = nonlinearity(x_init)
            return x_init

        # Regular forward pass (previously unreachable after the early return).
        V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema)

        # use weight normalization (Salimans & Kingma, 2016)
        W = tf.reshape(g,[1,1,num_filters,1])*tf.nn.l2_normalize(V,[0,1,3])

        # calculate convolutional layer output
        x = tf.nn.conv2d_transpose(x, W, target_shape, [1]+stride+[1], padding=pad)
        x = tf.nn.bias_add(x, b)

        # apply nonlinearity
        if nonlinearity is not None:
            x = nonlinearity(x)
        return x

def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d_wn, init=False, counters={}, ema=None, dropout_p=0., **kwargs):
    """Gated residual block: x + a_half * sigmoid(b_half).

    Args:
        x: Input tensor; its shape is read through `int_shape` and its last
            dimension gives the number of filters.
        a: Optional auxiliary input added to the first convolution through
            `nin` (short-cut connection).
        h: Optional conditioning matrix; `h @ hw` is added to the second
            convolution's output, broadcast over the two middle axes.
        nonlinearity: Activation applied before each convolution.
        conv: Layer called as `conv(inputs, num_filters)` and
            `conv(inputs, 2 * num_filters, init_scale=0.1)`. The default is
            `conv2d_wn`; the former default, `conv2d`, cannot be called with
            these arguments.
        init: If True, the conditioning weights are read through
            `initialized_value()`.
        counters: Dict passed to `get_name` for the conditional-weights scope.
            NOTE: the shared mutable default is kept on purpose so generated
            names stay unique across calls that rely on it.
        ema: Forwarded to `get_var_maybe_avg` for the conditioning weights.
        dropout_p: Dropout probability applied after the first convolution;
            0 disables dropout.
        **kwargs: Ignored.

    Returns:
        Tensor with the same shape as `x`.
    """
    xs = int_shape(x)
    num_filters = xs[-1]

    c1 = conv(nonlinearity(x), num_filters)
    if a is not None: # add short-cut connection if auxiliary input 'a' is given
        c1 += nin(nonlinearity(a), num_filters)
    c1 = nonlinearity(c1)
    if dropout_p > 0:
        c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)
    c2 = conv(c1, num_filters * 2, init_scale=0.1)

    # add projection of h vector if included: conditional generation
    if h is not None:
        with tf.variable_scope(get_name('conditional_weights', counters)):
            hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32,
                                    initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
        if init:
            hw = hw.initialized_value()
        c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])

    # tf.split takes (value, num_or_size_splits, axis), as used elsewhere in
    # this file; the previous (axis, num, value) order is the pre-1.0 API.
    a, b = tf.split(c2, 2, 3)
    c3 = a * tf.nn.sigmoid(b)
    return x + c3

def cross_entropy(y, smooth=1e-3):
    """Mean of -log(y) after clipping `y` to [smooth, 1 - smooth].

    The clipping keeps the logarithm finite when `y` reaches 0.
    """
    lower = smooth
    upper = 1 - smooth
    clipped = tf.minimum(upper, tf.maximum(lower, y))
    neg_log = -tf.log(clipped)
    return tf.reduce_mean(neg_log)

# Layer helpers that default to batch normalization and take `is_training`.
ops_with_bn = [fc, conv2d, deconv2d]