import tensorflow as tf
import numpy as np
from collections import OrderedDict

##################################################################################
# Layer
##################################################################################

# pad = ceil[ (kernel - stride) / 2 ]

def get_weight(weight_shape, gain, lrmul):
    """Create an equalized-learning-rate weight variable named 'weight'.

    The stored variable is initialized from N(0, 1 / lrmul) and multiplied at
    runtime by he_std * lrmul, where he_std = gain / sqrt(fan_in) and fan_in is
    the product of every dimension of `weight_shape` except the last.

    Args:
        weight_shape: [kernel, kernel, fmaps_in, fmaps_out] for convolutions,
            or [in, out] for dense layers.
        gain: Multiplier applied to the He standard deviation.
        lrmul: Learning-rate multiplier.

    Returns:
        The runtime-scaled weight tensor.
    """
    fan_in = np.prod(weight_shape[:-1])
    he_std = gain / np.sqrt(fan_in)  # He init

    # Equalized learning rate: small stored values, rescaled on every use.
    runtime_coef = he_std * lrmul
    initializer = tf.initializers.random_normal(0, 1.0 / lrmul)

    variable = tf.get_variable('weight', shape=weight_shape, dtype=tf.float32,
                               initializer=initializer)
    return variable * runtime_coef

def conv(x, channels, kernel=3, stride=1, gain=np.sqrt(2), lrmul=1.0, sn=False, scope='conv_0'):
    """2-D convolution with an equalized-learning-rate kernel and 'SAME' padding.

    Args:
        x: Input tensor; its last dimension is read as the input channel count
            (strides are given in NHWC order).
        channels: Number of output channels.
        kernel: Square kernel size.
        stride: Spatial stride used for both height and width.
        gain, lrmul: Forwarded to get_weight.
        sn: If True, spectrally normalize the kernel.
        scope: Variable scope that holds the kernel (and 'u' when sn is True).

    Returns:
        The convolution output (no bias, no activation).
    """
    with tf.variable_scope(scope):
        in_channels = x.get_shape().as_list()[-1]
        kernel_weight = get_weight([kernel, kernel, in_channels, channels], gain, lrmul)
        if sn:
            kernel_weight = spectral_norm(kernel_weight)
        return tf.nn.conv2d(input=x, filter=kernel_weight, strides=[1, stride, stride, 1], padding='SAME')

def fully_connected(x, units, gain=np.sqrt(2), lrmul=1.0, sn=False, scope='linear'):
    """Bias-free dense layer on the flattened input with an equalized-LR weight.

    Args:
        x: Input tensor; every non-batch dimension is flattened first.
        units: Number of output units.
        gain, lrmul: Forwarded to get_weight.
        sn: If True, spectrally normalize the weight matrix.
        scope: Variable scope that holds the weight (and 'u' when sn is True).

    Returns:
        Tensor of shape [batch, units].
    """
    with tf.variable_scope(scope):
        flat = flatten(x)
        in_units = flat.get_shape().as_list()[-1]
        matrix = get_weight([in_units, units], gain, lrmul)
        if sn:
            matrix = spectral_norm(matrix)
        return tf.matmul(flat, matrix)

def flatten(x):
    """Collapse every non-batch dimension of `x` into one (tf.layers.flatten)."""
    flat = tf.layers.flatten(inputs=x)
    return flat

##################################################################################
# Activation function
##################################################################################

def lrelu(x, alpha=0.2):
    """Leaky ReLU with negative-side slope `alpha` (tf.nn.leaky_relu)."""
    activated = tf.nn.leaky_relu(features=x, alpha=alpha)
    return activated


##################################################################################
# Normalization function
##################################################################################

def spectral_norm(w, iteration=1):
    """Divide `w` by a power-iteration estimate of its largest singular value.

    `w` is viewed as a matrix of shape [prod(shape[:-1]), shape[-1]]. A
    non-trainable variable 'u' of shape [1, shape[-1]] (created in the current
    variable scope) keeps the power-iteration state. It is assigned the new
    estimate through a control dependency on the normalized weight.

    Args:
        w: Weight tensor with a fully static shape.
        iteration: Number of power-iteration steps (1 is usually enough; 0 would
            leave v_hat as None and fail in tf.stop_gradient).

    Returns:
        Tensor with the same shape as `w`.
    """
    original_shape = w.shape.as_list()
    out_dim = original_shape[-1]
    w_mat = tf.reshape(w, [-1, out_dim])

    u = tf.get_variable("u", [1, out_dim], initializer=tf.random_normal_initializer(), trainable=False)

    u_hat, v_hat = u, None
    for _ in range(iteration):
        # Power iteration: v <- normalize(u W^T), then u <- normalize(v W).
        v_hat = tf.nn.l2_normalize(tf.matmul(u_hat, tf.transpose(w_mat)))
        u_hat = tf.nn.l2_normalize(tf.matmul(v_hat, w_mat))

    # The singular-vector estimates are constants as far as backprop goes.
    u_hat = tf.stop_gradient(u_hat)
    v_hat = tf.stop_gradient(v_hat)

    sigma = tf.matmul(tf.matmul(v_hat, w_mat), tf.transpose(u_hat))

    # Evaluating the normalized weight also persists the updated 'u'.
    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = tf.reshape(w_mat / sigma, original_shape)

    return w_norm

def pixel_norm(x, epsilon=1e-8):
    """Scale each feature vector (last axis) to unit root-mean-square.

    Computes x / sqrt(mean(x**2, axis=-1) + epsilon).
    """
    with tf.variable_scope('PixelNorm'):
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        normalized = x * tf.rsqrt(mean_square + epsilon)
    return normalized

def adaptive_instance_norm(x, w):
    """AdaIN: instance-normalize `x`, then apply the style computed from `w`."""
    return style_mod(instance_norm(x), w)

def instance_norm(x, epsilon=1e-8):
    """Normalize over axes 1 and 2 (height and width for an NHWC input).

    Subtracts the mean over those axes, then divides by sqrt(mean square of the
    centered values + epsilon). There is no learned scale or offset.
    """
    spatial_axes = [1, 2]
    with tf.variable_scope('InstanceNorm'):
        centered = x - tf.reduce_mean(x, axis=spatial_axes, keepdims=True)
        variance = tf.reduce_mean(tf.square(centered), axis=spatial_axes, keepdims=True)
        normalized = centered * tf.rsqrt(variance + epsilon)
    return normalized

##################################################################################
# StyleGAN trick function
##################################################################################

def compute_loss(real_images, real_logit, fake_logit, r1_gamma=10.0):
    """Logistic GAN losses with an R1 gradient penalty on real images.

    Args:
        real_images: Real image batch that `real_logit` was computed from. The
            R1 penalty differentiates `real_logit` with respect to it and sums
            over axes 1..3 (a rank-4 tensor is assumed).
        real_logit: Discriminator scores for real images, one per sample
            (e.g. shape [batch, 1] or [batch]).
        fake_logit: Discriminator scores for generated images, same layout.
        r1_gamma: Weight of the R1 penalty; the penalty enters as
            r1_gamma / 2 * mean(||grad||^2). Defaults to 10.0.

    Returns:
        (d_loss, g_loss) as scalar tensors.
    """
    # Discriminator: softplus(D(fake)) + softplus(-D(real)), averaged.
    d_loss_gan = tf.reduce_mean(tf.nn.softplus(fake_logit) + tf.nn.softplus(-real_logit))

    # R1 penalty: per-sample squared L2 norm of dD(real)/d(real_images).
    real_grads = tf.gradients(tf.reduce_sum(real_logit), [real_images])[0]
    r1_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1, 2, 3])

    # Reduce the two terms separately. Adding the [batch, 1] logistic term to
    # the [batch] penalty would broadcast to [batch, batch] before averaging
    # (same mean, but O(batch^2) work).
    d_loss = d_loss_gan + tf.reduce_mean(r1_penalty) * (r1_gamma * 0.5)

    # Generator: non-saturating logistic loss.
    g_loss = tf.reduce_mean(tf.nn.softplus(-fake_logit))

    return d_loss, g_loss

def lerp(a, b, t):
    """Linear interpolation a + (b - a) * t (t == 0 gives a, t == 1 gives b)."""
    with tf.name_scope("Lerp"):
        delta = b - a
        blended = a + delta * t
    return blended

def lerp_clip(a, b, t):
    """Linear interpolation with `t` clipped to [0, 1] (t <= 0 -> a, t >= 1 -> b)."""
    with tf.name_scope("LerpClip"):
        t_clipped = tf.clip_by_value(t, 0.0, 1.0)
        blended = a + (b - a) * t_clipped
    return blended

def smooth_transition(prv, cur, res, transition_res, alpha):
    """Blend previous- and current-resolution outputs during progressive growing.

    alpha == 1.0 selects `prv` only and alpha == 0.0 selects `cur` only.
    `alpha` is used only when `res == transition_res`. For any other `res`
    the blend weight is 0.0, so the output is `cur`. For example, with
    transition_res=32 the res=16 block output passes through unchanged.
    """
    scope_name = '{:d}x{:d}'.format(res, res)
    weight = alpha if transition_res == res else 0.0
    with tf.variable_scope(scope_name):
        with tf.variable_scope('smooth_transition'):
            out = lerp_clip(cur, prv, weight)
    return out

def smooth_transition_state(batch_size, global_step, train_trans_images_per_res_tensor, zero_constant):
    """Compute the fade-in weight alpha from the number of images seen so far.

    Let images_seen = batch_size * global_step and T =
    train_trans_images_per_res_tensor. While images_seen <= T, alpha =
    (T - images_seen) / T, which falls linearly from 1 to 0. Afterwards alpha
    is `zero_constant`. alpha == 1.0 means previous-resolution output only;
    alpha == 0.0 means current-resolution output only.

    NOTE(review): T == 0 gives 0/0 (NaN) at step 0. Confirm that callers never
    pass zero.
    """
    total = train_trans_images_per_res_tensor
    images_seen = tf.cast(batch_size * global_step, dtype=tf.float32)

    def _fading():
        return (total - images_seen) / total

    def _finished():
        return zero_constant

    in_transition = tf.less_equal(images_seen, total)
    return tf.cond(in_transition, true_fn=_fading, false_fn=_finished)

def get_alpha_const(iterations, batch_size, global_step):
    """Build the transition alpha for a fade-in lasting `iterations` images.

    Returns:
        (alpha_const, zero_constant): the alpha tensor produced by
        smooth_transition_state, and the scalar 0.0 constant it falls back to
        (returned so callers can reuse it).
    """
    zero_constant = tf.constant(0.0, dtype=tf.float32, shape=[])
    images_per_transition = tf.constant(iterations, dtype=tf.float32, shape=[],
                                        name='train_trans_images_per_res')
    alpha_const = smooth_transition_state(batch_size, global_step,
                                          images_per_transition, zero_constant)
    return alpha_const, zero_constant

##################################################################################
# StyleGAN discriminator
##################################################################################

def discriminator_block(x, res, n_f0, n_f1, sn=False):
    """Discriminator block for resolution `res`.

    The block applies a 3x3 conv (n_f0 maps), then blur and a 2x downscaling
    3x3 conv (n_f1 maps). Each conv is followed by a bias and leaky ReLU
    (0.2). Variables live under '<res>x<res>/Conv0' and
    '<res>x<res>/Conv1_down'.
    """
    gain = np.sqrt(2)
    with tf.variable_scope('{:d}x{:d}'.format(res, res)):
        with tf.variable_scope('Conv0'):
            h = conv(x, channels=n_f0, kernel=3, stride=1, gain=gain, lrmul=1.0, sn=sn)
            h = lrelu(apply_bias(h, lrmul=1.0), 0.2)

        with tf.variable_scope('Conv1_down'):
            h = blur2d(h, [1, 2, 1])
            h = downscale_conv(h, n_f1, kernel=3, gain=gain, lrmul=1.0, sn=sn)
            h = lrelu(apply_bias(h, lrmul=1.0), 0.2)

    return h

def discriminator_last_block(x, res, n_f0, n_f1, sn=False):
    """Final discriminator block.

    The block runs minibatch stddev, a 3x3 conv, and a dense layer with
    n_f1 units. It ends with a 1-unit dense layer (gain 1.0) plus bias.
    No activation is applied to the returned score.
    """
    with tf.variable_scope('{:d}x{:d}'.format(res, res)):
        h = minibatch_stddev_layer(x, group_size=4, num_new_features=1)

        with tf.variable_scope('Conv0'):
            h = conv(h, channels=n_f0, kernel=3, stride=1, gain=np.sqrt(2), lrmul=1.0, sn=sn)
            h = lrelu(apply_bias(h, lrmul=1.0), 0.2)

        with tf.variable_scope('Dense0'):
            h = fully_connected(h, units=n_f1, gain=np.sqrt(2), lrmul=1.0, sn=sn)
            h = lrelu(apply_bias(h, lrmul=1.0), 0.2)

        with tf.variable_scope('Dense1'):
            logit = fully_connected(h, units=1, gain=1.0, lrmul=1.0, sn=sn)
            logit = apply_bias(logit, lrmul=1.0)

    return logit

##################################################################################
# StyleGAN generator
##################################################################################

def get_style_class(resolutions, featuremaps):
    """Split (resolution, feature-map count) pairs into coarse/middle/fine groups.

    coarse: 4 <= res <= 8; middle: 16 <= res <= 32; fine: every other res.
    Pairs come from zip(resolutions, featuremaps), so surplus items in the
    longer sequence are ignored.

    Returns:
        (coarse_styles, middle_styles, fine_styles), each an OrderedDict that
        maps res -> n_f in input order.
    """
    coarse_styles, middle_styles, fine_styles = OrderedDict(), OrderedDict(), OrderedDict()

    for res, n_f in zip(resolutions, featuremaps):
        if 4 <= res <= 8:
            bucket = coarse_styles
        elif 16 <= res <= 32:
            bucket = middle_styles
        else:
            bucket = fine_styles
        bucket[res] = n_f

    return coarse_styles, middle_styles, fine_styles

def synthesis_const_block(res, w_broadcasted, n_f, sn=False):
    """First synthesis block: a learned constant input followed by one 3x3 conv.

    The constant 'Const' has shape [1, 4, 4, n_f], is initialized to ones, and
    is tiled over the batch. Styles w_broadcasted[:, 0] and w_broadcasted[:, 1]
    drive the two AdaIN layers.
    """
    w0 = w_broadcasted[:, 0]
    w1 = w_broadcasted[:, 1]
    batch_size = tf.shape(w0)[0]

    def _noise_bias_act_adain(h, style):
        # Applies noise (B module), bias, leaky ReLU, then AdaIN (A module).
        h = apply_noise(h)
        h = apply_bias(h, lrmul=1.0)
        h = lrelu(h, 0.2)
        return adaptive_instance_norm(h, style)

    with tf.variable_scope('{:d}x{:d}'.format(res, res)):
        with tf.variable_scope('const_block'):
            const = tf.get_variable('Const', shape=[1, 4, 4, n_f], dtype=tf.float32,
                                    initializer=tf.initializers.ones())
            x = _noise_bias_act_adain(tf.tile(const, [batch_size, 1, 1, 1]), w0)

        with tf.variable_scope('Conv'):
            x = conv(x, channels=n_f, kernel=3, stride=1, gain=np.sqrt(2), lrmul=1.0, sn=sn)
            x = _noise_bias_act_adain(x, w1)

    return x

def synthesis_block(x, res, w_broadcasted, layer_index, n_f, sn=False):
    """Synthesis block for resolution `res`.

    The block applies a 2x upscaling conv plus blur, then a 3x3 conv.
    Styles w_broadcasted[:, layer_index] and w_broadcasted[:, layer_index + 1]
    drive the two AdaIN layers.
    """
    w_first = w_broadcasted[:, layer_index]
    w_second = w_broadcasted[:, layer_index + 1]

    def _noise_bias_act_adain(h, style):
        # Applies noise (B module), bias, leaky ReLU, then AdaIN (A module).
        h = apply_noise(h)
        h = apply_bias(h, lrmul=1.0)
        h = lrelu(h, 0.2)
        return adaptive_instance_norm(h, style)

    with tf.variable_scope('{:d}x{:d}'.format(res, res)):
        with tf.variable_scope('Conv0_up'):
            h = upscale_conv(x, n_f, kernel=3, gain=np.sqrt(2), lrmul=1.0, sn=sn)
            h = blur2d(h, [1, 2, 1])
            h = _noise_bias_act_adain(h, w_first)

        with tf.variable_scope('Conv1'):
            h = conv(h, n_f, kernel=3, stride=1, gain=np.sqrt(2), lrmul=1.0, sn=sn)
            h = _noise_bias_act_adain(h, w_second)

    return h

##################################################################################
# StyleGAN Etc
##################################################################################

def downscale_conv(x, channels, kernel, gain, lrmul, sn=False):
    """Convolution followed by 2x spatial downscaling.

    When min(height, width) * 2 >= 128, both steps are fused into one stride-2
    convolution. Its (kernel+1)x(kernel+1) filter is the average of the four
    1-pixel shifts of the learned kernel. Otherwise conv() runs, followed by
    downscale2d(). The paths create differently named variables: 'weight' in
    the current scope when fused, and variables under 'conv_0' otherwise.

    NOTE(review): this uses the same "* 2 >= 128" fusing rule as upscale_conv.
    The official StyleGAN downscale appears to use min(H, W) >= 128. Changing
    it would rename variables, so confirm the intent before touching it.
    """
    height, width = x.shape[1], x.shape[2]
    fused_scale = (min(height, width) * 2) >= 128

    if fused_scale:
        in_channels = x.get_shape().as_list()[-1]
        w = get_weight([kernel, kernel, in_channels, channels], gain, lrmul)
        w = tf.pad(w, [[1, 1], [1, 1], [0, 0], [0, 0]], mode='CONSTANT')
        w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) * 0.25
        if sn:
            w = spectral_norm(w)
        return tf.nn.conv2d(input=x, filter=w, strides=[1, 2, 2, 1], padding='SAME')

    # Unfused: call the individual ops.
    x = conv(x, channels=channels, kernel=kernel, stride=1, gain=gain, lrmul=lrmul, sn=sn)
    return downscale2d(x)

def upscale_conv(x, channels, kernel, gain=np.sqrt(2), lrmul=1.0, sn=False):
    """2x spatial upscaling followed by a convolution.

    When min(height, width) * 2 >= 128, both steps are fused into one stride-2
    tf.nn.conv2d_transpose. Its (kernel+1)x(kernel+1) filter is the sum of the
    four 1-pixel shifts of the learned kernel. Otherwise upscale2d() runs,
    followed by conv(). The paths create differently named variables: 'weight'
    in the current scope when fused, and variables under 'conv_0' otherwise.
    """
    batch_size = tf.shape(x)[0]
    height, width = x.shape[1], x.shape[2]
    fused_scale = (min(height, width) * 2) >= 128

    if fused_scale:
        in_channels = x.get_shape().as_list()[-1]
        # conv2d_transpose takes its filter as [h, w, out_channels, in_channels].
        w = get_weight([kernel, kernel, channels, in_channels], gain, lrmul)
        w = tf.pad(w, [[1, 1], [1, 1], [0, 0], [0, 0]], mode='CONSTANT')
        w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]])
        if sn:
            w = spectral_norm(w)
        output_shape = [batch_size, height * 2, width * 2, channels]
        return tf.nn.conv2d_transpose(x, filter=w, output_shape=output_shape,
                                      strides=[1, 2, 2, 1], padding='SAME')

    # Unfused: call the individual ops.
    x = upscale2d(x)
    return conv(x, channels=channels, kernel=kernel, stride=1, gain=gain, lrmul=lrmul, sn=sn)

def torgb(x, res, sn=False):
    """Project features to 3 channels with a 1x1 conv (gain 1.0) plus bias.

    Variables live under '<res>x<res>/ToRGB'.
    """
    with tf.variable_scope('{:d}x{:d}'.format(res, res)), tf.variable_scope('ToRGB'):
        rgb = conv(x, channels=3, kernel=1, stride=1, gain=1.0, lrmul=1.0, sn=sn)
        rgb = apply_bias(rgb, lrmul=1.0)
    return rgb

def fromrgb(x, res, n_f, sn=False):
    """Map an image to n_f feature maps: 1x1 conv, bias, then leaky ReLU (0.2).

    Variables live under '<res>x<res>/FromRGB'.
    """
    with tf.variable_scope('{:d}x{:d}'.format(res, res)), tf.variable_scope('FromRGB'):
        features = conv(x, channels=n_f, kernel=1, stride=1, gain=np.sqrt(2), lrmul=1.0, sn=sn)
        features = lrelu(apply_bias(features, lrmul=1.0), 0.2)
    return features

def style_mod(x, w):
    """Apply a per-channel affine style computed from `w` to `x`.

    A dense layer (gain 1.0, scope 'StyleMod/linear') plus bias maps `w` to
    2 * C values, where C is the last dimension of `x`. These are reshaped to
    [-1, 2, 1, 1, C] and split into (scale, shift). The output is
    x * (scale + 1) + shift.
    """
    n_channels = x.shape[-1]
    with tf.variable_scope('StyleMod'):
        style = fully_connected(w, units=n_channels * 2, gain=1.0, lrmul=1.0)
        style = apply_bias(style, lrmul=1.0)
        style = tf.reshape(style, [-1, 2, 1, 1, n_channels])
        scale, shift = style[:, 0], style[:, 1]
        x = x * (scale + 1) + shift
    return x

def apply_noise(x):
    """Add Gaussian noise scaled by a learned per-channel weight.

    Each call draws one noise map of shape [batch, x.shape[1], x.shape[2], 1],
    which is broadcast across channels. The weight 'Noise/weight' has one entry
    per last-axis channel and is zero-initialized, so the layer starts out
    adding nothing.
    """
    with tf.variable_scope('Noise'):
        noise = tf.random_normal([tf.shape(x)[0], x.shape[1], x.shape[2], 1])
        n_channels = x.get_shape().as_list()[-1]
        strength = tf.get_variable('weight', shape=[n_channels], initializer=tf.initializers.zeros())
        x = x + noise * tf.reshape(strength, [1, 1, 1, -1])
    return x

def apply_bias(x, lrmul):
    """Add a learned bias over the last axis of `x`.

    The bias is the zero-initialized variable 'bias', scaled by lrmul. For a
    rank-2 `x` it is added directly. For any other rank it is reshaped to
    [1, 1, 1, C] (shaped for a rank-4 input) before adding.
    """
    bias = tf.get_variable('bias', shape=[x.shape[-1]], initializer=tf.initializers.zeros()) * lrmul
    if len(x.shape) != 2:
        bias = tf.reshape(bias, [1, 1, 1, -1])
    return x + bias

##################################################################################
# StyleGAN Official operation
##################################################################################

# ----------------------------------------------------------------------------
# Primitive ops for manipulating 4D activation tensors.
# The gradients of these are not necessarily efficient or even meaningful.
def _blur2d(x, f, normalize=True, flip=False, stride=1):
    """Per-channel 2-D FIR filtering with 'SAME' padding (depthwise_conv2d).

    Args:
        x: Rank-4 tensor with static non-batch dimensions. The last axis is
            treated as channels, and each channel is filtered independently.
        f: Filter taps, either 1-D (made separable via an outer product) or 2-D.
        normalize: Divide the filter by the sum of its taps.
        flip: Flip the filter in both spatial directions (used for gradients).
        stride: Spatial stride applied while filtering.

    Returns:
        Filtered tensor with the same dtype as `x`.
    """
    assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:])
    assert isinstance(stride, int) and stride >= 1

    # Finalize filter kernel.
    f = np.array(f, dtype=np.float32)
    if f.ndim == 1:
        f = f[:, np.newaxis] * f[np.newaxis, :]
    assert f.ndim == 2
    if normalize:
        f /= np.sum(f)
    if flip:
        f = f[::-1, ::-1]

    # No-op => early exit. This must be checked on the 2-D kernel, before the
    # expansion below; after it, f.shape can never equal (1, 1). It also only
    # applies without striding, because a strided identity still subsamples.
    if f.shape == (1, 1) and f[0, 0] == 1 and stride == 1:
        return x

    f = f[:, :, np.newaxis, np.newaxis]
    f = np.tile(f, [1, 1, int(x.shape[-1]), 1])

    # Convolve using depthwise_conv2d.
    orig_dtype = x.dtype
    x = tf.cast(x, tf.float32)  # tf.nn.depthwise_conv2d() doesn't support fp16
    f = tf.constant(f, dtype=x.dtype, name='filter')
    x = tf.nn.depthwise_conv2d(x, f, strides=[1, stride, stride, 1], padding='SAME')
    return tf.cast(x, orig_dtype)


def _upscale2d(x, factor=2, gain=1):
    """Upsample by an integer `factor` by repeating each pixel factor x factor times.

    The input is multiplied by `gain` first when gain != 1. Axes 1 and 2 are
    treated as height and width, and the last axis as channels.
    """
    assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:])
    assert isinstance(factor, int) and factor >= 1

    if gain != 1:
        x *= gain

    if factor == 1:
        return x

    # Insert singleton axes after H and W, tile them, then fold them back in.
    _, height, width, channels = x.shape
    x = tf.reshape(x, [-1, height, 1, width, 1, channels])
    x = tf.tile(x, [1, 1, factor, 1, factor, 1])
    return tf.reshape(x, [-1, height * factor, width * factor, channels])


def _downscale2d(x, factor=2, gain=1):
    """Downsample by an integer `factor` by averaging factor x factor blocks (times gain).

    For factor == 2 on float32 input, a stride-2 _blur2d runs with
    taps sqrt(gain) / 2, so the 2x2 outer product sums to `gain`. Otherwise `x`
    is scaled by `gain` (when != 1) and, for factor > 1, average-pooled.
    """
    assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:])
    assert isinstance(factor, int) and factor >= 1

    if factor == 2 and x.dtype == tf.float32:
        taps = [np.sqrt(gain) / factor] * factor
        return _blur2d(x, f=taps, normalize=False, stride=factor)

    if gain != 1:
        x *= gain

    if factor == 1:
        return x

    # Large factor => downscale using tf.nn.avg_pool().
    # NOTE: Requires tf_config['graph_options.place_pruned_graph']=True to work.
    window = [1, factor, factor, 1]
    return tf.nn.avg_pool(x, ksize=window, strides=window, padding='VALID')


# ----------------------------------------------------------------------------
# High-level ops for manipulating 4D activation tensors.
# The gradients of these are meant to be as efficient as possible.

def blur2d(x, f, normalize=True):
    """Blur `x` with filter `f` (see _blur2d), using hand-written gradients.

    The first-order gradient is the same filter flipped spatially. The gradient
    of that gradient is the unflipped filter again. Both are supplied through
    nested tf.custom_gradient, so higher-order derivatives (e.g. a gradient
    penalty through the discriminator) are computed as blurs too.

    Args:
        x: Rank-4 tensor, as required by _blur2d.
        f: Filter taps (1-D separable or 2-D).
        normalize: Normalize the taps to sum to 1.
    """
    with tf.variable_scope('Blur2D'):
        # Forward pass: plain blur.
        @tf.custom_gradient
        def func(x):
            y = _blur2d(x, f, normalize)

            # Backward pass: blur with the flipped filter. Its own gradient
            # (second order) flips back to the original filter.
            @tf.custom_gradient
            def grad(dy):
                dx = _blur2d(dy, f, normalize, flip=True)
                return dx, lambda ddx: _blur2d(ddx, f, normalize)

            return y, grad

        return func(x)


def upscale2d(x, factor=2):
    """Pixel-repeat upsampling by `factor` (see _upscale2d) with custom gradients.

    _upscale2d copies each pixel into a factor x factor block, so its gradient
    sums each block, i.e. the block average times factor**2. The gradient of
    that gradient is the upscale again.
    """
    with tf.variable_scope('Upscale2D'):
        @tf.custom_gradient
        def func(x):
            y = _upscale2d(x, factor)

            # d(upscale) = block sum = average * factor**2; its own gradient is
            # the plain upscale.
            @tf.custom_gradient
            def grad(dy):
                dx = _downscale2d(dy, factor, gain=factor ** 2)
                return dx, lambda ddx: _upscale2d(ddx, factor)

            return y, grad

        return func(x)


def downscale2d(x, factor=2):
    """Block-average downsampling by `factor` (see _downscale2d) with custom gradients.

    Each output is the mean of a factor x factor block, so the gradient spreads
    dy evenly back over the block: an upscale scaled by 1 / factor**2. The
    gradient of that gradient is the block average again.
    """
    with tf.variable_scope('Downscale2D'):
        @tf.custom_gradient
        def func(x):
            y = _downscale2d(x, factor)

            # d(average) = upscale / factor**2; its own gradient is the plain
            # downscale.
            @tf.custom_gradient
            def grad(dy):
                dx = _upscale2d(dy, factor, gain=1 / factor ** 2)
                return dx, lambda ddx: _downscale2d(ddx, factor)

            return y, grad

        return func(x)


def minibatch_stddev_layer(x, group_size=4, num_new_features=1):
    """Append minibatch standard-deviation feature maps to an NHWC tensor.

    The batch is split into groups of `group_size` samples, capped at the batch
    size. Within each group, the per-element standard deviation across the
    group is averaged over height, width and the channels of each of the
    `num_new_features` channel groups. The result is broadcast back to every
    sample of the group and every pixel, then appended as extra channels.

    Args:
        x: [N, H, W, C] tensor with static H, W, C. N must be divisible by the
            effective group size, and C by `num_new_features`.
        group_size: Samples per statistics group.
        num_new_features: Number of channel groups, i.e. number of appended
            channels.

    Returns:
        [N, H, W, C + num_new_features] tensor.
    """
    with tf.variable_scope('MinibatchStddev'):
        group_size = tf.minimum(group_size, tf.shape(x)[0])  # batch must be divisible by (or smaller than) group_size
        s = x.shape  # [NHWC]
        # Split the batch into M groups of G samples, and the channel axis
        # (last in NHWC) into n groups of c channels. Splitting the trailing
        # axis keeps each group made of real channels rather than mixed pixels.
        y = tf.reshape(x, [group_size, -1, s[1], s[2], num_new_features, s[3] // num_new_features])  # [GMHWnc]
        y = tf.cast(y, tf.float32)
        y -= tf.reduce_mean(y, axis=0, keepdims=True)  # [GMHWnc] subtract group mean
        y = tf.reduce_mean(tf.square(y), axis=0)  # [MHWnc] variance over group
        y = tf.sqrt(y + 1e-8)  # [MHWnc] stddev over group
        y = tf.reduce_mean(y, axis=[1, 2, 4])  # [Mn] average over pixels and channels of each group
        y = tf.cast(y, x.dtype)
        y = tf.reshape(y, [-1, 1, 1, num_new_features])  # [M11n]
        y = tf.tile(y, [group_size, s[1], s[2], 1])  # [NHWn] replicate over group and pixels
        return tf.concat([x, y], axis=-1)  # [NHW(C+n)]

##################################################################################
# Etc
##################################################################################

def filter_trainable_variables(res):
    """Select the trainable variables that are active when training up to `res`.

    The in-focus resolutions are 4, 8, ..., 2 ** int(log2(res)). Returns
    (d_vars, g_vars), where:
      * d_vars: variables whose name starts with 'discriminator' and contains
        an in-focus '<r>x<r>' tag;
      * g_vars: variables whose name starts with 'generator' and either
        contains 'g_mapping', or contains 'g_synthesis' plus an in-focus
        '<r>x<r>' tag.
    Tags are plain substring matches, and a variable is appended once per tag
    it matches. All other variables are ignored.
    """
    top_exponent = int(np.log2(res))
    res_tags = ['{:d}x{:d}'.format(2 ** e, 2 ** e) for e in range(2, top_exponent + 1)]

    d_vars, g_vars = [], []
    for var in tf.trainable_variables():
        name = var.name
        if name.startswith('generator'):
            if 'g_mapping' in name:
                g_vars.append(var)
            elif 'g_synthesis' in name:
                g_vars.extend(var for tag in res_tags if tag in name)
        elif name.startswith('discriminator'):
            d_vars.extend(var for tag in res_tags if tag in name)

    return d_vars, g_vars

def resolution_list(img_size):
    """Return the resolutions 4, 8, 16, ... (doubling) that do not exceed `img_size`.

    Returns an empty list when img_size < 4.
    """
    resolutions = []
    res = 4
    while True:
        if res > img_size:
            return resolutions
        resolutions.append(res)
        res *= 2

def featuremap_list(img_size):
    """Return feature-map counts, one per halving of `img_size` down to 4.

    The list has one entry for each of img_size, img_size // 2, ... that is
    still >= 4. Entry i (0-based) equals 512 // 2 ** max(i - 3, 0): the first
    four entries are 512, and each later entry halves.
    """
    max_feature_map = 512

    level_count = 0
    size = img_size
    while True:
        if size < 4:
            break
        level_count += 1
        size = size // 2

    return [max_feature_map >> max(level - 3, 0) for level in range(level_count)]

def get_batch_sizes(gpu_num):
    """Return the per-GPU batch size for each resolution, given the GPU count.

    gpu_num is bucketed as 1, 2-3, 4-6, 7-9, or anything else (treated as >= 10).
    Returns a new OrderedDict mapping resolution (4 ... 1024) -> batch size.
    """
    resolutions = (4, 8, 16, 32, 64, 128, 256, 512, 1024)

    if gpu_num == 1:
        per_res = (128, 128, 128, 64, 32, 16, 8, 4, 4)
    elif gpu_num in (2, 3):
        per_res = (128, 128, 64, 32, 16, 8, 4, 4, 4)
    elif gpu_num in (4, 5, 6):
        per_res = (128, 64, 32, 16, 8, 4, 4, 4, 4)
    elif gpu_num in (7, 8, 9):
        per_res = (64, 32, 16, 8, 4, 4, 4, 4, 4)
    else:  # >= 10
        per_res = (32, 16, 8, 4, 2, 2, 2, 2, 2)

    return OrderedDict(zip(resolutions, per_res))

def get_end_iteration(iter, max_iter, do_trans, res_list, start_res):
    """Return max_iter minus the iterations consumed before the final resolution.

    Resolutions are walked from `start_res` (inclusive) up to, but excluding,
    the last entry of `res_list`. Each costs `iter` iterations if do_trans[res]
    is truthy, and `iter // 2` otherwise. (The parameter name `iter` shadows
    the builtin but is kept for caller compatibility.)
    """
    first = res_list.index(start_res)
    remaining = max_iter
    for res in res_list[first:-1]:
        remaining -= iter if do_trans[res] else iter // 2
    return remaining