import tensorflow as tf
from keras import backend as K
from keras.models import Model
from keras.layers import Conv2D, Dense, Activation, Input, UpSampling2D
from keras.layers import concatenate, Flatten, Reshape, Lambda
from keras.layers import LeakyReLU, MaxPooling2D
import keras


def my_conv(x_in, nf, ks=3, strides=1, activation='lrelu', name=None):
    """Apply a 'same'-padded Conv2D followed by an optional activation layer.

    Args:
        x_in: Input tensor fed to the convolution.
        nf: Number of convolution filters.
        ks: Kernel size passed to Conv2D.
        strides: Stride passed to Conv2D.
        activation: 'lrelu' selects LeakyReLU(0.2); 'none' skips the
            activation layer entirely; any other value is handed to the
            keras Activation layer.
        name: Name given to the activation layer. It is not used when
            activation is 'none' (the Conv2D layer itself is never named).

    Returns:
        The output tensor of the last layer that was applied.
    """
    features = Conv2D(nf, kernel_size=ks, padding='same', strides=strides)(x_in)

    # Guard clauses: handle the "no activation" case first, then the two
    # kinds of activation layers.
    if activation == 'none':
        return features
    if activation == 'lrelu':
        return LeakyReLU(0.2, name=name)(features)
    return Activation(activation, name=name)(features)


def vgg_loss(feat_net, feat_weights, n_layers, reg=0.1):
    """Build a feature-space L1 loss function.

    Args:
        feat_net: Callable mapping a preprocessed image batch to an indexable
            collection of feature tensors (at least n_layers of them).
        feat_weights: Mapping keyed by the layer index as a string; element
            [1] of each entry is used as the divisor for that layer's
            feature difference (presumably a per-channel standard deviation
            -- confirm against the code that produces it).
        n_layers: Number of feature layers included in the loss.
        reg: Constant added to the divisor before dividing.

    Returns:
        A function loss_fcn(y_true, y_pred) returning the per-layer losses
        averaged over n_layers.
    """
    def loss_fcn(y_true, y_pred):
        true_feats = feat_net(Lambda(vgg_preprocess)(y_true))
        pred_feats = feat_net(Lambda(vgg_preprocess)(y_pred))

        total = []
        for layer_idx in range(n_layers):
            # Prepend three singleton axes so the divisor broadcasts against
            # the feature tensor.
            scale = feat_weights[str(layer_idx)][1] + reg
            for _ in range(3):
                scale = tf.expand_dims(scale, 0)

            diff = tf.subtract(true_feats[layer_idx], pred_feats[layer_idx])
            layer_loss = tf.reduce_mean(tf.abs(tf.divide(diff, scale)))

            total = layer_loss if layer_idx == 0 else tf.add(total, layer_loss)

        return total / (n_layers * 1.0)

    return loss_fcn


def vgg_preprocess(arg):
    """Rescale a 4-D image batch to [0, 255] and subtract per-channel means.

    The input is mapped with 255 * (arg + 1) / 2, so values in [-1, 1] land in
    [0, 255]. The constants 103.939, 116.779 and 123.68 are then subtracted
    from channels 0, 1 and 2 respectively.

    NOTE(review): these constants match the usual VGG means in B, G, R order,
    so this presumably expects BGR input -- confirm against the data loader.
    """
    channel_means = (103.939, 116.779, 123.68)
    scaled = 255.0 * (arg + 1.0) / 2.0
    centered = [scaled[:, :, :, idx] - mean
                for idx, mean in enumerate(channel_means)]
    return tf.stack(centered, axis=3)


def make_trainable(net, val):
    """Set the `trainable` flag on a model and on each of its layers.

    Args:
        net: Object exposing a `trainable` attribute and an iterable `layers`.
        val: Value assigned to every `trainable` attribute.
    """
    # The model-level flag is set first, then every layer in order.
    setattr(net, 'trainable', val)
    for layer in net.layers:
        setattr(layer, 'trainable', val)


def discriminator(param):
    """Build the discriminator model.

    Args:
        param: Dict providing 'IMG_HEIGHT', 'IMG_WIDTH', 'n_joints' and
            'posemap_downsample'.

    Returns:
        A keras Model named 'discriminator' taking
        [target image, source pose map, target pose map] and producing a
        single sigmoid output per sample.
    """
    img_h = param['IMG_HEIGHT']
    img_w = param['IMG_WIDTH']
    n_joints = param['n_joints']
    pose_dn = param['posemap_downsample']

    # Input shapes must be integers; plain '/' yields floats on Python 3.
    pose_shape = (int(img_h // pose_dn), int(img_w // pose_dn), n_joints)

    x_tgt = Input(shape=(img_h, img_w, 3))
    x_src_pose = Input(shape=pose_shape)
    x_tgt_pose = Input(shape=pose_shape)

    # Trailing size comments presumably assume a 256x256 input image.
    x = my_conv(x_tgt, 64, ks=5)
    x = MaxPooling2D()(x)  # 128
    # The pose maps are merged after one pooling step, so their spatial size
    # has to match the pooled image features.
    x = concatenate([x, x_src_pose, x_tgt_pose])
    x = my_conv(x, 128, ks=5)
    x = MaxPooling2D()(x)  # 64
    x = my_conv(x, 256)
    x = MaxPooling2D()(x)  # 32
    x = my_conv(x, 256)
    x = MaxPooling2D()(x)  # 16
    x = my_conv(x, 256)
    x = MaxPooling2D()(x)  # 8
    x = my_conv(x, 256)  # 8

    x = Flatten()(x)

    x = Dense(256, activation='relu')(x)
    x = Dense(256, activation='relu')(x)
    y = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=[x_tgt, x_src_pose, x_tgt_pose], outputs=y, name='discriminator')
    return model


def wass(y_true, y_pred):
    """Return the mean of the elementwise product of y_true and y_pred.

    NOTE(review): the name suggests a Wasserstein-style loss where y_true
    carries the sign; confirm against the training code.
    """
    product = y_true * y_pred
    return tf.reduce_mean(product)


def gan(gen_model, disc_model, param):
    """Chain a generator and a frozen discriminator into one model.

    Args:
        gen_model: Generator taking
            [src image, src pose, tgt pose, mask, transforms].
        disc_model: Discriminator taking [image, src pose, tgt pose]. Its
            `trainable` flags are set to False here via make_trainable.
        param: Dict providing 'IMG_HEIGHT', 'IMG_WIDTH', 'n_joints',
            'n_limbs' and 'posemap_downsample'.

    Returns:
        A keras Model named 'gan' with the five generator inputs and two
        outputs: the generated image and the discriminator's score for it.
    """
    img_h = param['IMG_HEIGHT']
    img_w = param['IMG_WIDTH']
    n_joints = param['n_joints']
    n_limbs = param['n_limbs']
    pose_dn = param['posemap_downsample']

    # Input shapes must be integers; plain '/' yields floats on Python 3.
    pose_shape = (int(img_h // pose_dn), int(img_w // pose_dn), n_joints)

    src_in = Input(shape=(img_h, img_w, 3))
    pose_src = Input(shape=pose_shape)
    pose_tgt = Input(shape=pose_shape)
    mask_in = Input(shape=(img_h, img_w, n_limbs + 1))
    trans_in = Input(shape=(2, 3, n_limbs + 1))

    # Freeze the discriminator before wiring it behind the generator.
    make_trainable(disc_model, False)
    y_gen = gen_model([src_in, pose_src, pose_tgt, mask_in, trans_in])
    y_class = disc_model([y_gen, pose_src, pose_tgt])

    gan_model = Model(inputs=[src_in, pose_src, pose_tgt, mask_in, trans_in],
                      outputs=[y_gen, y_class], name='gan')

    return gan_model


def repeat(x, n_repeats):
    """Repeat every element of x n_repeats times, consecutively.

    x is flattened into a column and multiplied by a (1, n_repeats) row of
    int32 ones; the product is flattened again, so each element appears
    n_repeats times in a row. Because the ones are int32, x must be an
    int32 tensor for the matmul.
    """
    ones_row = tf.ones(shape=tf.stack([1, n_repeats]), dtype=tf.int32)
    column = tf.reshape(x, (-1, 1))
    repeated = tf.matmul(column, ones_row)
    return tf.reshape(repeated, [-1])
        
def meshgrid(height, width):
    """Return float32 coordinate grids (x_t, y_t), each of shape (height, width).

    x_t[i, j] runs from 0 to width - 1 along axis 1 and y_t[i, j] runs from
    0 to height - 1 along axis 0, both built as outer products with a
    vector of ones.
    """
    last_col = tf.cast(width, tf.float32) - 1.0
    last_row = tf.cast(height, tf.float32) - 1.0

    # Column vectors of evenly spaced coordinates.
    col_coords = tf.expand_dims(tf.linspace(0.0, last_col, width), 1)
    row_coords = tf.expand_dims(tf.linspace(0.0, last_row, height), 1)

    ones_col = tf.ones(shape=tf.stack([height, 1]))
    ones_row = tf.ones(shape=tf.stack([1, width]))

    x_t = tf.matmul(ones_col, tf.transpose(col_coords, [1, 0]))
    y_t = tf.matmul(row_coords, ones_row)
    return x_t, y_t
        
def interpolate(im,x,y):
    """Bilinearly sample a 4-D image batch at the given pixel coordinates.

    Args:
        im: 4-D tensor (batch, height, width, channels). It is padded by one
            pixel on each spatial side with REFLECT before sampling.
        x: Tensor of x (column) sample coordinates; its axes 1 and 2 give the
            output height and width.
        y: Tensor of y (row) sample coordinates; flattened the same way as x.

    Returns:
        float32 tensor reshaped to (-1, out_height, out_width, channels).
    """

    im = tf.pad(im,[[0,0],[1,1],[1,1],[0,0]],"REFLECT")

    # From here on, height/width refer to the padded image.
    num_batch = tf.shape(im)[0]
    height = tf.shape(im)[1]
    width = tf.shape(im)[2]
    channels = tf.shape(im)[3]

    out_height = tf.shape(x)[1]
    out_width = tf.shape(x)[2]

    x = tf.reshape(x,[-1])
    y = tf.reshape(y,[-1])

    # Shift by one so coordinates index into the padded image.
    x = tf.cast(x, 'float32')+1
    y = tf.cast(y, 'float32')+1

    max_x = tf.cast(width - 1, 'int32')
    max_y = tf.cast(height - 1, 'int32')

    # Integer corners surrounding each sample point.
    x0 = tf.cast(tf.floor(x), 'int32')
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), 'int32')
    y1 = y0 + 1

    # Keep all corner indices inside the padded image.
    x0 = tf.clip_by_value(x0, 0, max_x)
    x1 = tf.clip_by_value(x1, 0, max_x)
    y0 = tf.clip_by_value(y0, 0, max_y)
    y1 = tf.clip_by_value(y1, 0, max_y)

    # Offset of each batch element in the flattened image, repeated once per
    # output pixel.
    base = repeat(tf.range(num_batch)*width*height, (out_height*out_width))

    base_y0 = base + y0*width
    base_y1 = base + y1*width

    # Flat indices of the four corners: a=(y0,x0), b=(y1,x0), c=(y0,x1),
    # d=(y1,x1).
    idx_a = base_y0 + x0
    idx_b = base_y1 + x0
    idx_c = base_y0 + x1
    idx_d = base_y1 + x1

    # use indices to lookup pixels in the flat image and restore
    # channels dim
    im_flat = tf.reshape(im, tf.stack([-1, channels]))
    im_flat = tf.cast(im_flat, 'float32')

    Ia = tf.gather(im_flat, idx_a)
    Ib = tf.gather(im_flat, idx_b)
    Ic = tf.gather(im_flat, idx_c)
    Id = tf.gather(im_flat, idx_d)

    # and finally calculate interpolated values
    # NOTE(review): dx/dy use the clipped x1/y1, so samples falling outside
    # the padded image do not get the usual [0, 1] weights.
    x1_f = tf.cast(x1, 'float32')
    y1_f = tf.cast(y1, 'float32')

    dx = x1_f - x
    dy = y1_f - y

    wa = tf.expand_dims((dx * dy), 1)
    wb = tf.expand_dims((dx * (1-dy)), 1)
    wc = tf.expand_dims(((1-dx) * dy), 1)
    wd = tf.expand_dims(((1-dx) * (1-dy)), 1)

    output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
    output = tf.reshape(output, tf.stack([-1,out_height,out_width,channels]))
    return output


def affine_warp(im, theta):
    """Warp an image batch with per-sample affine transforms.

    A homogeneous pixel grid [x; y; 1] of the image's size is multiplied by
    theta; rows 0 and 1 of the product are used as the x and y sampling
    coordinates passed to interpolate.

    Args:
        im: 4-D image tensor (batch, height, width, channels).
        theta: Batched transform matrices; multiplied against a
            (batch, 3, height*width) grid, so presumably (batch, 2, 3).

    Returns:
        The result of interpolate(im, x_s, y_s).
    """
    batch = tf.shape(im)[0]
    rows = tf.shape(im)[1]
    cols = tf.shape(im)[2]

    # Homogeneous coordinates of every pixel, laid out as a (3, H*W) matrix.
    grid_x, grid_y = meshgrid(rows, cols)
    x_row = tf.reshape(grid_x, (1, -1))
    y_row = tf.reshape(grid_y, (1, -1))
    homogeneous = tf.concat(axis=0, values=[x_row, y_row, tf.ones_like(x_row)])

    # Replicate the grid once per batch element: flatten, tile, reshape.
    tiled = tf.tile(tf.reshape(homogeneous, [-1]), tf.stack([batch]))
    tiled = tf.reshape(tiled, tf.stack([batch, 3, -1]))

    transformed = tf.matmul(theta, tiled)
    x_src = tf.slice(transformed, [0, 0, 0], [-1, 1, -1])
    y_src = tf.slice(transformed, [0, 1, 0], [-1, 1, -1])

    sample_shape = [batch, rows, cols]
    x_src = tf.reshape(x_src, sample_shape)
    y_src = tf.reshape(y_src, sample_shape)

    return interpolate(im, x_src, y_src)


def make_warped_stack(args):
    """Mask the source image per mask channel and warp every channel but 0.

    Args:
        args: Sequence [mask, src_in, trans_in] where
            mask is (batch, H, W, n_channels),
            src_in is the source image (each mask channel is repeated three
            times before multiplying, so presumably 3 channels), and
            trans_in holds one transform per mask channel on its last axis.

    Returns:
        Tensor concatenating, along axis 3, the masked source for channel 0
        (left unwarped) followed by the affine-warped masked source for each
        remaining channel.
    """
    mask, src_in, trans_in = args[0], args[1], args[2]

    # Use the mask's static channel count rather than a fixed constant; fall
    # back to the historical value of 11 when it is not statically known.
    static_shape = K.int_shape(mask)
    n_channels = static_shape[-1] if static_shape else None
    if n_channels is None:
        n_channels = 11

    layers = []
    for i in range(n_channels):
        mask_i = K.repeat_elements(tf.expand_dims(mask[:, :, :, i], 3), 3, 3)
        src_masked = tf.multiply(mask_i, src_in)

        if i == 0:
            # Channel 0 is kept in place (no transform is applied to it).
            layers.append(src_masked)
        else:
            layers.append(affine_warp(src_masked, trans_in[:, :, :, i]))

    # One concat instead of growing the tensor on every iteration.
    return tf.concat(layers, 3)


def interp_upsampling(im):
    """Upsample a 4-D image batch by 2x in height and width via interpolate.

    A grid of twice the input size is built and its coordinates are halved,
    so every output pixel samples the input at half its own coordinate.
    """
    dims = tf.shape(im)
    batch = dims[0]

    grid_x, grid_y = meshgrid(dims[1]*2, dims[2]*2)

    # Halve, add a batch axis, and replicate for every sample in the batch.
    sample_x = tf.tile(tf.expand_dims(grid_x/2.0, 0), [batch, 1, 1])
    sample_y = tf.tile(tf.expand_dims(grid_y/2.0, 0), [batch, 1, 1])
    return interpolate(im, sample_x, sample_y)

def unet(x_in, pose_in, nf_enc, nf_dec):
    """Build a U-Net style encoder/decoder with five 2x down/up stages.

    Args:
        x_in: Input image tensor.
        pose_in: Pose tensor concatenated after the first stride-2 conv, so
            its spatial size must match the half-resolution features.
        nf_enc: Filter counts for the 11 encoder convolutions (indices 0-10).
        nf_dec: Filter counts for the 5 decoder convolutions (indices 0-4).

    Returns:
        The output tensor of the last decoder convolution.
    """
    # Trailing size comments presumably assume a 256x256 input.
    x0 = my_conv(x_in, nf_enc[0], ks=7)  # 256
    x1 = my_conv(x0, nf_enc[1], strides=2)  # 128
    x2 = concatenate([x1, pose_in])
    x3 = my_conv(x2, nf_enc[2])
    x4 = my_conv(x3, nf_enc[3], strides=2)  # 64
    x5 = my_conv(x4, nf_enc[4])
    x6 = my_conv(x5, nf_enc[5], strides=2)  # 32
    x7 = my_conv(x6, nf_enc[6])
    x8 = my_conv(x7, nf_enc[7], strides=2)  # 16
    x9 = my_conv(x8, nf_enc[8])
    x10 = my_conv(x9, nf_enc[9], strides=2)  # 8
    x = my_conv(x10, nf_enc[10])

    skips = [x9, x7, x5, x3, x0]

    for i, skip in enumerate(skips):
        # Declare the upsampled shape from the tensor's own static shape, so
        # it stays correct for non-square or non-256 inputs and always uses
        # the actual channel count of x.
        _, in_h, in_w, in_c = K.int_shape(x)
        up_shape = (None if in_h is None else 2 * in_h,
                    None if in_w is None else 2 * in_w,
                    in_c)
        x = Lambda(interp_upsampling, output_shape=up_shape)(x)
        x = concatenate([x, skip])
        x = my_conv(x, nf_dec[i])

    return x


def network_posewarp(param):
    """Build the pose-warping generator.

    The model predicts a per-limb mask for the source image, splits the
    source into a background layer and per-limb foreground layers via
    make_warped_stack, synthesizes target background and foreground with two
    further U-Nets, and blends them with a predicted foreground mask.

    Args:
        param: Dict providing 'IMG_HEIGHT', 'IMG_WIDTH', 'n_joints',
            'posemap_downsample' and 'n_limbs'.

    Returns:
        A keras Model with inputs
        [src image, src pose, tgt pose, src mask prior, transforms] and a
        single output image.
    """
    img_h = param['IMG_HEIGHT']
    img_w = param['IMG_WIDTH']
    n_joints = param['n_joints']
    pose_dn = param['posemap_downsample']
    n_limbs = param['n_limbs']

    # Input shapes must be integers; plain '/' yields floats on Python 3.
    pose_shape = (int(img_h // pose_dn), int(img_w // pose_dn), n_joints)

    # Inputs
    src_in = Input(shape=(img_h, img_w, 3))
    pose_src = Input(shape=pose_shape)
    pose_tgt = Input(shape=pose_shape)
    src_mask_prior = Input(shape=(img_h, img_w, n_limbs + 1))
    trans_in = Input(shape=(2, 3, n_limbs + 1))

    # 1. FG/BG separation
    x = unet(src_in, pose_src, [64]*2 + [128]*9, [128]*4 + [32])
    # The delta is added to the prior, so it needs the same channel count
    # (n_limbs + 1) rather than a fixed number.
    src_mask_delta = my_conv(x, n_limbs + 1, activation='linear')
    src_mask = keras.layers.add([src_mask_delta, src_mask_prior])
    src_mask = Activation('softmax', name='mask_src')(src_mask)

    # 2. Separate into fg limbs and background
    warped_stack = Lambda(make_warped_stack)([src_mask, src_in, trans_in])
    fg_stack = Lambda(lambda arg: arg[:, :, :, 3:], output_shape=(img_h, img_w, 3*n_limbs),
                      name='fg_stack')(warped_stack)
    bg_src = Lambda(lambda arg: arg[:, :, :, 0:3], output_shape=(img_h, img_w, 3),
                    name='bg_src')(warped_stack)
    bg_src_mask = Lambda(lambda arg: tf.expand_dims(arg[:, :, :, 0], 3))(src_mask)

    # 3. BG synthesis
    x = unet(concatenate([bg_src, bg_src_mask]), pose_src, [64]*2 + [128]*9, [128]*4 + [64])
    bg_tgt = my_conv(x, 3, activation='tanh', name='bg_tgt')

    # 4. FG synthesis and foreground mask
    x = unet(fg_stack, pose_tgt, [64]*2 + [128]*9, [128]*4 + [64])

    fg_tgt = my_conv(x, 3, activation='tanh', name='fg_tgt')

    fg_mask = my_conv(x, 1, activation='sigmoid', name='fg_mask_tgt')
    # Replicate the single-channel mask across the three image channels.
    fg_mask = concatenate([fg_mask, fg_mask, fg_mask])
    bg_mask = Lambda(lambda arg: 1 - arg)(fg_mask)

    # 5. Merge bg and fg
    fg_tgt = keras.layers.multiply([fg_tgt, fg_mask], name='fg_tgt_masked')
    bg_tgt = keras.layers.multiply([bg_tgt, bg_mask], name='bg_tgt_masked')
    y = keras.layers.add([fg_tgt, bg_tgt])

    model = Model(inputs=[src_in, pose_src, pose_tgt, src_mask_prior, trans_in], outputs=[y])

    return model


def network_unet(param):
    """Build a plain U-Net generator conditioned on source and target pose.

    Args:
        param: Dict providing 'n_joints', 'posemap_downsample', 'IMG_HEIGHT'
            and 'IMG_WIDTH'.

    Returns:
        A keras Model with inputs [src image, src pose, tgt pose] and one
        3-channel tanh output.
    """
    n_joints = param['n_joints']
    pose_dn = param['posemap_downsample']
    img_h = param['IMG_HEIGHT']
    img_w = param['IMG_WIDTH']

    # Input shapes must be integers; plain '/' yields floats on Python 3.
    pose_shape = (int(img_h // pose_dn), int(img_w // pose_dn), n_joints)

    src_in = Input(shape=(img_h, img_w, 3))
    pose_src = Input(shape=pose_shape)
    pose_tgt = Input(shape=pose_shape)

    x = unet(src_in, concatenate([pose_src, pose_tgt]), [64] + [128] * 3 + [256] * 7,
             [256, 256, 256, 128, 64])
    y = my_conv(x, 3, activation='tanh')

    model = Model(inputs=[src_in, pose_src, pose_tgt], outputs=[y])
    return model