from __future__ import print_function
# from trainer_Base import *

import os, pdb
import StringIO
import scipy.misc
import numpy as np
import glob
from itertools import chain
from collections import deque
import pickle, shutil
from tqdm import tqdm
from tqdm import trange

from skimage.measure import compare_ssim as ssim
from skimage.color import rgb2gray
from PIL import Image
from tensorflow.python.ops import control_flow_ops, sparse_ops

import models
from utils import *
import tflib as lib
from wgan_gp import *
from datasets import market1501, deepfashion, dataset_utils

##############################################################################################
######################### Market1501 with FgBgPose BodyROIVis ################################
class DPIG_Encoder_GAN_BodyROI(object):
    def __init__(self, config):
        """Set up hyper-parameters, the dataset split and the input tensors.

        Args:
            config: experiment configuration. Besides the attributes read by
                ``_common_init`` this uses ``D_arch``, ``dataset``,
                ``is_train`` and ``data_path``.

        Raises:
            ValueError: if ``config.dataset`` does not contain ``'market'``;
                Market-1501 is the only dataset this trainer loads.
        """
        self._common_init(config)

        self.keypoint_num = 18
        self.D_arch = config.D_arch
        self.part_num = 37  ## Also change *7 --> *37 in datasets/market1501.py
        if 'market' not in config.dataset.lower():
            # Without this guard ``dataset_obj`` was never assigned and the
            # failure surfaced below as an unrelated AttributeError.
            raise ValueError(
                'Unsupported dataset %r: only Market1501 is handled' % (config.dataset,))

        split_name = 'train' if config.is_train else 'test'
        self.dataset_obj = market1501.get_split(split_name, config.data_path)

        (self.x, self.x_target,
         self.pose, self.pose_target,
         self.pose_rcv, self.pose_rcv_target,
         self.mask_r4, self.mask_r4_target,
         self.mask_r6, self.mask_r6_target,
         self.part_bbox, self.part_bbox_target,
         self.part_vis, self.part_vis_target) = self._load_batch_pair_pose(self.dataset_obj)
                
    def _common_init(self, config):
        """Copy the shared experiment settings from ``config`` onto the trainer.

        Besides mirroring configuration values, this creates the global step
        variable, the generator / discriminator learning-rate variables with
        the ops that multiply them by 0.5, and derives ``channel`` and
        ``repeat_num`` from the input shape.
        """
        self.config = config
        self.data_loader = None

        # Settings mirrored one-to-one from the configuration object.
        mirrored = (
            'dataset',
            'beta1', 'beta2', 'optimizer', 'batch_size',
            'gamma', 'lambda_k',
            'z_num', 'conv_hidden_num', 'img_H', 'img_W',
            'model_dir', 'load_path',
            'use_gpu', 'data_format',
            'data_path', 'pretrained_path', 'pretrained_appSample_path',
            'pretrained_poseAE_path', 'pretrained_poseSample_path',
            'ckpt_path', 'z_emb_dir',
            'start_step', 'log_step', 'max_step',
            # 'save_model_secs',
            'lr_update_step',
            'is_train',
            'sample_app', 'sample_fg', 'sample_bg', 'sample_pose',
            'one_app_per_batch',
            'interpolate_fg', 'interpolate_fg_up', 'interpolate_fg_down',
            'interpolate_bg', 'interpolate_pose',
            'inverse_fg', 'inverse_bg', 'inverse_pose',
        )
        for attr_name in mirrored:
            setattr(self, attr_name, getattr(config, attr_name))

        self.step = tf.Variable(config.start_step, name='step', trainable=False)

        # Learning rates live in the graph so they can be decayed in place.
        self.g_lr = tf.Variable(config.g_lr, dtype=tf.float32, name='g_lr')
        self.d_lr = tf.Variable(config.d_lr, dtype=tf.float32, name='d_lr')
        self.g_lr_update = tf.assign(self.g_lr, self.g_lr * 0.5, name='g_lr_update')
        self.d_lr_update = tf.assign(self.d_lr, self.d_lr * 0.5, name='d_lr_update')

        _batch, height, _width, self.channel = self._get_conv_shape()
        self.repeat_num = int(np.log2(height)) - 2

        # During testing a single reader thread keeps the order of the input data.
        if self.is_train:
            self.num_threads, self.capacityCoff = 4, 2
        else:
            self.num_threads, self.capacityCoff = 1, 1

    def _get_conv_shape(self):
        """Return ``[batch_size, img_H, img_W, 3]``, the shape of an RGB input batch."""
        return [self.batch_size, self.img_H, self.img_W, 3]

    def _getOptimizer(self, wgan_gp, gen_cost, disc_cost, G_var, D_var):
        """Build the generator / discriminator training ops for the GAN mode.

        Args:
            wgan_gp: object whose ``MODE`` attribute is one of ``'wgan'``,
                ``'wgan-gp'``, ``'dcgan'`` or ``'lsgan'``.
            gen_cost: generator loss to minimise.
            disc_cost: discriminator loss to minimise.
            G_var: variables updated by the generator op.
            D_var: variables updated by the discriminator op.

        Returns:
            ``(gen_train_op, disc_train_op, clip_disc_weights)``. The last
            entry is ``None`` except in ``'wgan'`` mode, where it is an op
            clipping the ``'Discriminator'`` parameters to ``[-0.01, 0.01]``.

        Raises:
            ValueError: if ``wgan_gp.MODE`` is not one of the supported modes.
        """
        mode = wgan_gp.MODE

        # Pick the optimizer family once; both networks use the same settings.
        if mode in ('wgan', 'lsgan'):
            def make_optimizer(learning_rate):
                return tf.train.RMSPropOptimizer(learning_rate=learning_rate)
        elif mode == 'wgan-gp':
            def make_optimizer(learning_rate):
                return tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5, beta2=0.9)
        elif mode == 'dcgan':
            def make_optimizer(learning_rate):
                return tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5)
        else:
            raise ValueError('Unsupported GAN mode: %r' % (mode,))

        gen_train_op = make_optimizer(self.g_lr).minimize(
            gen_cost, var_list=G_var, colocate_gradients_with_ops=True)
        disc_train_op = make_optimizer(self.d_lr).minimize(
            disc_cost, var_list=D_var, colocate_gradients_with_ops=True)

        clip_disc_weights = None
        if mode == 'wgan':
            # Weight clipping is only applied by the original WGAN formulation.
            clip_low, clip_high = -.01, .01
            clip_ops = [tf.assign(var, tf.clip_by_value(var, clip_low, clip_high))
                        for var in lib.params_with_name('Discriminator')]
            clip_disc_weights = tf.group(*clip_ops)

        return gen_train_op, disc_train_op, clip_disc_weights

    def _getDiscriminator(self, wgan_gp, arch='DCGAN'):
        """Look up the discriminator builder on ``wgan_gp`` that matches ``arch``.

        ``'DCGAN'`` and ``'FCDis'`` are matched exactly; any name beginning
        with ``'DCGANRegion'`` selects the region discriminator.

        Raises:
            Exception: if ``arch`` matches none of the known architectures.
        """
        if arch == 'DCGAN':
            return wgan_gp.DCGANDiscriminator
        if arch == 'FCDis':
            return wgan_gp.FCDiscriminator
        if not arch.startswith('DCGANRegion'):
            raise Exception('You must choose an architecture!')
        return wgan_gp.DCGANDiscriminatoRegion
    # def _getDiscriminator(self, wgan_gp, arch='DCGAN'):
    #     if 'Patch70x70'==arch:
    #         return wgan_gp.PatchDiscriminator_70x70
    #     elif 'Patch46x46'==arch:
    #         return wgan_gp.PatchDiscriminator_46x46
    #     elif 'Patch28x28'==arch:
    #         return wgan_gp.PatchDiscriminator_28x28
    #     elif 'Patch16x16'==arch:
    #         return wgan_gp.PatchDiscriminator_16x16
    #     elif 'Patch13x13'==arch:
    #         return wgan_gp.PatchDiscriminator_13x13
    #     elif 'DCGAN'==arch:
    #         # Baseline (G: DCGAN, D: DCGAN)
    #         return wgan_gp.DCGANDiscriminator
    #     elif 'FCDis'==arch:
    #         return wgan_gp.FCDiscriminator
    #     raise Exception('You must choose an architecture!')

    def init_net(self):
        """Build the graph, open a session and restore any requested weights.

        Weights are restored in this order, each step being skipped when its
        path is ``None``: the ``Encoder`` / ``ID_AE`` scopes from
        ``pretrained_path``, the ``PoseAE`` scope from
        ``pretrained_poseAE_path``, then every variable from ``ckpt_path``.
        """
        self.build_model()

        load_pretrained = self.pretrained_path is not None
        load_pose_ae = self.pretrained_poseAE_path is not None

        # Partial savers cover only the variable scopes being warm-started.
        if load_pretrained:
            encoder_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='Encoder')
            id_ae_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='ID_AE')
            self.saverPart = tf.train.Saver(encoder_vars + id_ae_vars, max_to_keep=20)

        if load_pose_ae:
            pose_ae_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='PoseAE')
            self.saverPoseAEPart = tf.train.Saver(pose_ae_vars, max_to_keep=20)

        self.saver = tf.train.Saver(max_to_keep=20)
        self.summary_writer = tf.summary.FileWriter(self.model_dir)

        # The supervisor is given no saver and no summary op: checkpoints and
        # summaries are written explicitly by this class.
        supervisor = tf.train.Supervisor(
            logdir=self.model_dir,
            is_chief=True,
            saver=None,
            summary_op=None,
            summary_writer=self.summary_writer,
            global_step=self.step,
            save_model_secs=0,
            ready_for_local_init_op=None)

        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True))
        self.sess = supervisor.prepare_or_wait_for_session(config=sess_config)

        if load_pretrained:
            self.saverPart.restore(self.sess, self.pretrained_path)
            print('restored from pretrained_path:', self.pretrained_path)
        if load_pose_ae:
            self.saverPoseAEPart.restore(self.sess, self.pretrained_poseAE_path)
            print('restored from pretrained_poseAE_path:', self.pretrained_poseAE_path)
        if self.ckpt_path is not None:
            self.saver.restore(self.sess, self.ckpt_path)
            print('restored from ckpt_path:', self.ckpt_path)

        self.test_dir_name = 'test_result'

    def _gan_loss(self, wgan_gp, Discriminator, disc_real, disc_fake, real_data=None, fake_data=None, arch=None):
        """Return ``(gen_cost, disc_cost)`` for the GAN flavour in ``wgan_gp.MODE``.

        Args:
            wgan_gp: object providing ``MODE`` and, for ``'wgan-gp'``,
                ``BATCH_SIZE``, ``DEVICES`` and ``LAMBDA``.
            Discriminator: callable applied to the interpolated samples; only
                used in ``'wgan-gp'`` mode.
            disc_real: discriminator output on real samples.
            disc_fake: discriminator output on generated samples.
            real_data: real samples; required in ``'wgan-gp'`` mode.
            fake_data: generated samples; required in ``'wgan-gp'`` mode.
            arch: unused; kept for interface compatibility.

        Raises:
            ValueError: if the mode is unknown, or if ``'wgan-gp'`` is selected
                without ``real_data`` / ``fake_data``.
        """
        mode = wgan_gp.MODE

        if mode == 'wgan':
            gen_cost = -tf.reduce_mean(disc_fake)
            disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

        elif mode == 'wgan-gp':
            if real_data is None or fake_data is None:
                # The gradient penalty interpolates between the two batches, so
                # fail early instead of on ``None - None`` further down.
                raise ValueError("'wgan-gp' mode requires both real_data and fake_data")

            gen_cost = -tf.reduce_mean(disc_fake)
            disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

            # Floor division keeps the per-device batch size an integer.
            per_device_batch = wgan_gp.BATCH_SIZE // len(wgan_gp.DEVICES)
            alpha = tf.random_uniform(
                shape=[per_device_batch, 1, 1, 1],
                minval=0.,
                maxval=1.
            )
            differences = fake_data - real_data
            # NOTE(review): squeezing alpha to [batch, 1] and reducing over
            # axis 1 below looks like it expects flattened 2-D data -- confirm
            # before using this mode with image tensors.
            interpolates = real_data + (tf.squeeze(alpha, [2, 3]) * differences)
            gradients = tf.gradients(Discriminator(interpolates), [interpolates])[0]
            slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
            gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
            disc_cost += wgan_gp.LAMBDA * gradient_penalty

        elif mode == 'dcgan':
            gen_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake, labels=tf.ones_like(disc_fake)))
            disc_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_fake, labels=tf.zeros_like(disc_fake)))
            disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=disc_real, labels=tf.ones_like(disc_real)))
            disc_cost /= 2.

        elif mode == 'lsgan':
            gen_cost = tf.reduce_mean((disc_fake - 1)**2)
            disc_cost = (tf.reduce_mean((disc_real - 1)**2) + tf.reduce_mean((disc_fake - 0)**2))/2.

        else:
            raise ValueError('Unsupported GAN mode: %r' % (mode,))

        return gen_cost, disc_cost

    def _define_input(self):
        """Create the phase variable and select the generator / discriminator builders."""
        self.is_train_tensor = tf.Variable(self.is_train, name='phase')
        self.Generator_fn = models.GeneratorCNN_ID_UAEAfterResidual

        # GAN hyper-parameter holder; the loss flavour is fixed to DCGAN here.
        gan_settings = dict(
            DATA_DIR='',
            MODE='dcgan',
            DIM=64,
            BATCH_SIZE=self.batch_size,
            ITERS=200000,
            LAMBDA=10,
            G_OUTPUT_DIM=self.img_H*self.img_W*3,
        )
        self.wgan_gp = WGAN_GP(**gan_settings)
        self.Discriminator_fn = self._getDiscriminator(self.wgan_gp, arch=self.D_arch)

    def build_model(self):
        """Assemble the encoder, generator, discriminator, losses and summaries.

        The encoder embeds ``self.x`` using the first seven part bounding
        boxes; the embedding is repeated over every spatial position and fed,
        together with ``self.pose``, to the generator. The discriminator is
        applied to the real batch and the generated batch stacked along the
        batch axis.
        """
        self._define_input()
        with tf.variable_scope("Encoder") as vs:
            # One slice per body part along axis 1.
            pb_list = tf.split(self.part_bbox, self.part_num, axis=1)
            # pv_list is only consumed by the commented-out BodyROIVis variant below.
            pv_list = tf.split(self.part_vis, self.part_num, axis=1)
            ## Part 1-3 (totally 3)
            # self.embs, self.Encoder_var = GeneratorCNN_ID_Encoder_BodyROI(self.x, self.part_bbox, len(indices), 64, 
            #                                 self.repeat_num, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, reuse=False)
            # Part 1-7 (totally 7)
            indices = range(7)
            select_part_bbox = tf.concat([pb_list[i] for i in indices], axis=1)
            self.embs, _, self.Encoder_var = GeneratorCNN_ID_Encoder_BodyROI(self.x, select_part_bbox, len(indices), 32, 
                                            self.repeat_num, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, keep_part_prob=1.0, reuse=False)
            # self.embs, _, self.Encoder_var = GeneratorCNN_ID_Encoder_BodyROI2(self.x, select_part_bbox, len(indices), 32, 
            #                                 self.repeat_num, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, keep_part_prob=0.9, reuse=False)
            ## Part 1,4-8 (totally 6)
            # indices = [1] + range(4,9)
            # select_part_bbox = tf.concat([pb_list[i] for i in indices], axis=1)
            # self.embs, _, self.Encoder_var = GeneratorCNN_ID_Encoder_BodyROI(self.x, select_part_bbox, len(indices), 32, 
            #                                 self.repeat_num, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, keep_part_prob=1.0, reuse=False)
            ## Part 1,8-16 (totally 10)
            # indices = [0] + range(7,16)
            # select_part_bbox = tf.concat([pb_list[i] for i in indices], axis=1)
            # select_part_vis = tf.cast(tf.concat([pv_list[i] for i in indices], axis=1), tf.float32)
            # self.embs, _, self.Encoder_var = GeneratorCNN_ID_Encoder_BodyROIVis(self.x, select_part_bbox, select_part_vis, len(indices), 32, 
            #                                 self.repeat_num, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, keep_part_prob=1.0, reuse=False)

        # Repeat the embedding img_H*img_W times along a new last axis, fold
        # that axis into (img_H, img_W) and move channels last.
        self.embs_rep = tf.tile(tf.expand_dims(self.embs,-1), [1, 1, self.img_H*self.img_W])
        self.embs_rep = tf.reshape(self.embs_rep, [self.batch_size, -1, self.img_H, self.img_W])
        self.embs_rep = nchw_to_nhwc(self.embs_rep)

        with tf.variable_scope("ID_AE") as vs:
            G, _, self.G_var = self.Generator_fn(
                    self.embs_rep, self.pose, 
                    self.channel, self.z_num, self.repeat_num, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, reuse=False)
        # The encoder is trained together with the generator.
        self.G_var += self.Encoder_var

        self.G = denorm_img(G, self.data_format)
        # Real images first, generated images second, along the batch axis.
        pair = tf.concat([self.x, G], 0)
        self.D_z = self.Discriminator_fn(tf.transpose( pair, [0,3,1,2] ), input_dim=3)
        self.D_var = lib.params_with_name('Discriminator.')
        # Split back into the scores of the real and the generated half.
        D_z_pos, D_z_neg = tf.split(self.D_z, 2)

        self.g_loss, self.d_loss = self._gan_loss(self.wgan_gp, self.Discriminator_fn, D_z_pos, D_z_neg, arch=self.D_arch)
        self.PoseMaskLoss = tf.reduce_mean(tf.abs(G - self.x) * (self.mask_r6))
        self.L1Loss = tf.reduce_mean(tf.abs(G - self.x))
        # Reference to the adversarial loss taken before _define_loss_optim runs.
        self.g_loss_only = self.g_loss

        self._define_loss_optim()
        self.summary_op = tf.summary.merge([
            tf.summary.image("G", self.G),
            tf.summary.scalar("loss/PoseMaskLoss", self.PoseMaskLoss),
            tf.summary.scalar("loss/L1Loss", self.L1Loss),
            tf.summary.scalar("loss/g_loss", self.g_loss),
            tf.summary.scalar("loss/g_loss_only", self.g_loss_only),
            tf.summary.scalar("loss/d_loss", self.d_loss),
            tf.summary.scalar("misc/d_lr", self.d_lr),
            tf.summary.scalar("misc/g_lr", self.g_lr),
        ])

    def _define_loss_optim(self):
        """Add the weighted L1 term to the generator loss and create the train ops."""
        l1_weight = 20
        self.g_loss += self.L1Loss * l1_weight

        train_ops = self._getOptimizer(
            self.wgan_gp, self.g_loss, self.d_loss, self.G_var, self.D_var)
        self.g_optim, self.d_optim, self.clip_disc_weights = train_ops

    def train(self):
        """Run the training loop from ``start_step`` up to ``max_step``.

        A fixed batch is drawn once and saved for reference. Each step runs
        the generator update (skipped at step 0) followed by the
        discriminator update(s). Summaries are written every ``log_step``
        steps, sample images every ``3 * log_step`` steps, the learning-rate
        update ops run every ``lr_update_step`` steps and a checkpoint is
        saved every ``30 * log_step`` steps.
        """
        (x_fixed, x_target_fixed, pose_fixed, pose_target_fixed,
         mask_fixed, mask_target_fixed, part_bbox_fixed, part_bbox_target_fixed,
         part_vis_fixed, part_vis_target_fixed) = self.get_image_from_loader()

        # Reference images of the fixed batch; pose maps are collapsed over
        # the keypoint axis and rescaled for display.
        snapshots = (
            (x_fixed, '{}/x_fixed.png'),
            (x_target_fixed, '{}/x_target_fixed.png'),
            ((np.amax(pose_fixed, axis=-1, keepdims=True)+1.0)*127.5, '{}/pose_fixed.png'),
            ((np.amax(pose_target_fixed, axis=-1, keepdims=True)+1.0)*127.5, '{}/pose_target_fixed.png'),
            (mask_fixed, '{}/mask_fixed.png'),
            (mask_target_fixed, '{}/mask_target_fixed.png'),
        )
        for image, template in snapshots:
            save_image(image, template.format(self.model_dir))

        for step in trange(self.start_step, self.max_step):
            if step > 0:
                self.sess.run([self.g_optim])

            # Train critic
            gan_mode = self.wgan_gp.MODE
            if gan_mode in ('dcgan', 'lsgan'):
                critic_iters = 1
            else:
                critic_iters = self.wgan_gp.CRITIC_ITERS
            for _ in xrange(critic_iters):
                self.sess.run(self.d_optim)
                if self.wgan_gp.MODE == 'wgan':
                    self.sess.run(self.clip_disc_weights)

            summary_period = self.log_step
            if step == 0 or step % summary_period == summary_period - 1:
                result = self.sess.run({"summary": self.summary_op})
                self.summary_writer.add_summary(result['summary'], step)
                self.summary_writer.flush()

            sample_period = self.log_step * 3
            if step == 0 or step % sample_period == sample_period - 1:
                x = process_image(x_fixed, 127.5, 127.5)
                x_target = process_image(x_target_fixed, 127.5, 127.5)
                self.generate(x, x_target, pose_fixed, part_bbox_fixed, part_vis_fixed, self.model_dir, idx=step)

            if step % self.lr_update_step == self.lr_update_step - 1:
                self.sess.run([self.g_lr_update, self.d_lr_update])

            checkpoint_period = self.log_step * 30
            if step % checkpoint_period == checkpoint_period - 1:
                self.saver.save(self.sess, os.path.join(self.model_dir, 'model.ckpt'), global_step=step)

    def test(self):
        """Generate images for 100 batches and write them under ``model_dir/test_result``.

        Every sample's source image, target image, generated image, source /
        target pose map and source / target mask are saved as individual PNG
        files in per-kind sub-directories; grid overviews are written for the
        first batch only.
        """
        test_result_dir = os.path.join(self.model_dir, 'test_result')
        kinds = ('x', 'x_target', 'G', 'pose', 'pose_target', 'mask', 'mask_target')
        out_dirs = dict((kind, os.path.join(test_result_dir, kind)) for kind in kinds)
        for directory in [test_result_dir] + [out_dirs[kind] for kind in kinds]:
            if not os.path.exists(directory):
                os.makedirs(directory)

        for i in xrange(100):
            (x_fixed, x_target_fixed, pose_fixed, pose_target_fixed,
             mask_fixed, mask_target_fixed, part_bbox_fixed, part_bbox_target_fixed,
             part_vis_fixed, part_vis_target_fixed) = self.get_image_from_loader()
            x = process_image(x_fixed, 127.5, 127.5)
            x_target = process_image(x_target_fixed, 127.5, 127.5)
            # Only the first batch also gets a grid image from generate().
            is_first_batch = (0 == i)
            x_fake = self.generate(x, x_target, pose_fixed, part_bbox_fixed, part_vis_fixed,
                                   test_result_dir, idx=self.start_step, save=is_first_batch)
            p = (np.amax(pose_fixed, axis=-1, keepdims=False)+1.0)*127.5
            pt = (np.amax(pose_target_fixed, axis=-1, keepdims=False)+1.0)*127.5

            # (kind, batch array, whether the sample needs squeezing first)
            per_sample_outputs = (
                ('x', x_fixed, False),
                ('x_target', x_target_fixed, False),
                ('G', x_fake, False),
                ('pose', p, False),
                ('pose_target', pt, False),
                ('mask', mask_fixed, True),
                ('mask_target', mask_target_fixed, True),
            )
            for j in xrange(self.batch_size):
                idx = i*self.batch_size+j
                for kind, batch, needs_squeeze in per_sample_outputs:
                    sample = batch[j,:]
                    if needs_squeeze:
                        sample = sample.squeeze()
                    im = Image.fromarray(sample.astype(np.uint8))
                    im.save('%s/%05d.png'%(out_dirs[kind], idx))

            if is_first_batch:
                overviews = (
                    (x_fixed, '{}/x_fixed.png'),
                    (x_target_fixed, '{}/x_target_fixed.png'),
                    (mask_fixed, '{}/mask_fixed.png'),
                    (mask_target_fixed, '{}/mask_target_fixed.png'),
                    ((np.amax(pose_fixed, axis=-1, keepdims=True)+1.0)*127.5, '{}/pose_fixed.png'),
                    ((np.amax(pose_target_fixed, axis=-1, keepdims=True)+1.0)*127.5, '{}/pose_target_fixed.png'),
                )
                for image, template in overviews:
                    save_image(image, template.format(test_result_dir))

    def test_one_by_one(self, img_dir, pair_path, all_peaks_path, subsets_path,
                    pair_num=500, shuffle=True, random_int=0, result_dir_name='test_demo'):
        """Generate one image per listed pair and save inputs and outputs.

        Pairs for which either image has no entry in the peaks dictionary are
        skipped and do not count towards ``pair_num``.

        Args:
            img_dir: directory containing the images named in the pair list.
            pair_path: pickle file with the pair list; each entry is indexed
                as ``(source_name, target_name)``.
            all_peaks_path: pickle file mapping image name to pose peaks.
            subsets_path: pickle file mapping image name to pose subsets.
            pair_num: maximum number of pairs to process.
            shuffle: visit the pairs in a fixed permutation (seed 0) instead
                of list order.
            random_int: unused; kept for interface compatibility.
            result_dir_name: sub-directory of ``model_dir`` receiving the
                results.
        """
        test_result_dir = os.path.join(self.model_dir, result_dir_name)
        kinds = ('x', 'x_target', 'G', 'pose', 'pose_target', 'mask', 'mask_target')
        out_dirs = dict((kind, os.path.join(test_result_dir, kind)) for kind in kinds)
        for directory in [test_result_dir] + [out_dirs[kind] for kind in kinds]:
            if not os.path.exists(directory):
                os.makedirs(directory)

        def _read_pickle(path):
            # Binary mode plus a context manager: pickles are byte streams and
            # the handle must not be left open.
            # NOTE: unpickling can execute arbitrary code -- only use trusted files.
            with open(path, 'rb') as handle:
                return pickle.load(handle)

        pairs = _read_pickle(pair_path)
        all_peaks_dic = _read_pickle(all_peaks_path)
        subsets_dic = _read_pickle(subsets_path)

        if shuffle:
            np.random.seed(0)
            idx_all = np.random.permutation(len(pairs))
        else:
            idx_all = np.arange(len(pairs))
        # All images are assumed to share the size of the first one.
        height, width, _ = scipy.misc.imread(os.path.join(img_dir, pairs[0][0])).shape

        cnt = 0  # number of pairs written so far; also the output index
        for i in trange(len(idx_all)):
            if cnt >= pair_num:
                break
            name_0, name_1 = pairs[idx_all[i]][0], pairs[idx_all[i]][1]
            if name_0 not in all_peaks_dic or name_1 not in all_peaks_dic:
                continue

            ## Pose 0
            peaks_0 = _get_valid_peaks(all_peaks_dic[name_0], subsets_dic[name_0])
            indices_r4_0, values_r4_0, shape = _getSparsePose(peaks_0, height, width, self.keypoint_num, radius=4, mode='Solid')
            pose_dense_0 = _sparse2dense(indices_r4_0, values_r4_0, shape)
            pose_mask_r4_0 = _getPoseMask(peaks_0, height, width, radius=4, mode='Solid')
            ## Pose 1
            peaks_1 = _get_valid_peaks(all_peaks_dic[name_1], subsets_dic[name_1])
            indices_r4_1, values_r4_1, shape = _getSparsePose(peaks_1, height, width, self.keypoint_num, radius=4, mode='Solid')
            pose_dense_1 = _sparse2dense(indices_r4_1, values_r4_1, shape)
            pose_mask_r4_1 = _getPoseMask(peaks_1, height, width, radius=4, mode='Solid')

            ## Generate image
            x = scipy.misc.imread(os.path.join(img_dir, name_0))
            x = process_image(x, 127.5, 127.5)
            x_batch = np.expand_dims(x, axis=0)
            # The dense pose is multiplied by 2 and shifted by -1 before feeding.
            pose_batch = np.expand_dims(pose_dense_1*2-1, axis=0)
            # NOTE(review): build_model() in this class wires the generator to
            # self.pose, not self.pose_target, and the part boxes still come
            # from the input queue -- confirm this feed has the intended effect.
            G = self.sess.run(self.G, {self.x: x_batch, self.pose_target: pose_batch})

            ## Save
            shutil.copy(os.path.join(img_dir, name_0), os.path.join(out_dirs['x'], 'pair%05d-%s'%(cnt, name_0)))
            shutil.copy(os.path.join(img_dir, name_1), os.path.join(out_dirs['x_target'], 'pair%05d-%s'%(cnt, name_1)))
            im = Image.fromarray(G.squeeze().astype(np.uint8))
            im.save('%s/pair%05d-%s-%s.jpg'%(out_dirs['G'], cnt, name_0, name_1))

            single_outputs = (
                ('pose', np.amax(pose_dense_0, axis=-1, keepdims=False)*255, name_0),
                ('pose_target', np.amax(pose_dense_1, axis=-1, keepdims=False)*255, name_1),
                ('mask', pose_mask_r4_0*255, name_0),
                ('mask_target', pose_mask_r4_1*255, name_1),
            )
            for kind, array, name in single_outputs:
                im = Image.fromarray(array.astype(np.uint8))
                im.save('%s/pair%05d-%s.jpg'%(out_dirs[kind], cnt, name))

            cnt += 1

    def generate(self, x_fixed, x_target_fixed, pose_fixed, part_bbox_fixed, part_vis_fixed, root_path=None, path=None, idx=None, save=True):
        """Run the generator on one batch and optionally save the result as a grid image.

        Args:
            x_fixed: source images fed to ``self.x``; also the SSIM reference
                after being mapped back with ``(x + 1) * 127.5``.
            x_target_fixed: unused; kept for interface compatibility.
            pose_fixed: pose maps fed to ``self.pose``.
            part_bbox_fixed: part bounding boxes fed to ``self.part_bbox``.
            part_vis_fixed: part visibilities fed to ``self.part_vis``.
            root_path: directory for the saved grid image.
            path: explicit output path.
            idx: value used as the prefix of the generated file name.
            save: whether to write the grid image.

        Returns:
            The generated batch as returned by the session.
        """
        G = self.sess.run(self.G, {self.x: x_fixed, self.pose: pose_fixed, self.part_bbox: part_bbox_fixed, self.part_vis: part_vis_fixed})

        # NOTE(review): nothing is written when an explicit ``path`` is given;
        # kept as in the original -- confirm whether that is intended.
        if path is None and save:
            # The SSIM score only ends up in the file name, so it is computed
            # here rather than on every call.
            ssim_G_x_list = []
            for i in range(G.shape[0]):
                G_gray = rgb2gray((G[i,:]).clip(min=0,max=255).astype(np.uint8))
                x_gray = rgb2gray(((x_fixed[i,:]+1)*127.5).clip(min=0,max=255).astype(np.uint8))
                ssim_G_x_list.append(ssim(G_gray, x_gray, data_range=x_gray.max() - x_gray.min(), multichannel=False))
            ssim_G_x_mean = np.mean(ssim_G_x_list)

            path = os.path.join(root_path, '{}_G_ssim{}.png'.format(idx,ssim_G_x_mean))
            save_image(G, path)
            print("[*] Samples saved: {}".format(path))
        return G

    def get_image_from_loader(self):
        """Pull one batch from the input pipeline, converted for saving.

        Images go through ``unprocess_image(..., 127.5, 127.5)`` and the r6
        masks are multiplied by 255; poses, part boxes and part visibilities
        are returned as fetched.

        Returns:
            ``(x, x_target, pose, pose_target, mask, mask_target, part_bbox,
            part_bbox_target, part_vis, part_vis_target)``.
        """
        fetches = [
            self.x, self.x_target,
            self.pose, self.pose_target,
            self.mask_r6, self.mask_r6_target,
            self.part_bbox, self.part_bbox_target,
            self.part_vis, self.part_vis_target,
        ]
        fetched = self.sess.run(fetches)
        x, x_target, pose, pose_target, mask, mask_target = fetched[:6]
        part_bbox, part_bbox_target, part_vis, part_vis_target = fetched[6:]

        x = unprocess_image(x, 127.5, 127.5)
        x_target = unprocess_image(x_target, 127.5, 127.5)
        mask = mask*255
        mask_target = mask_target*255
        return (x, x_target, pose, pose_target, mask, mask_target,
                part_bbox, part_bbox_target, part_vis, part_vis_target)

    def _load_batch_pair_pose(self, dataset):
        """Create the batched input tensors for image pairs with pose annotations.

        Returns:
            ``(images_0, images_1, poses_0, poses_1, poses_rcv_0, poses_rcv_1,
            masks_r4_0, masks_r4_1, masks_r6_0, masks_r6_1, part_bboxs_0,
            part_bboxs_1, part_viss_0, part_viss_1)``.
        """
        data_provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset, common_queue_capacity=32, common_queue_min=8)

        field_names = [
            'image_raw_0', 'image_raw_1', 'label',
            'pose_peaks_0_rcv', 'pose_peaks_1_rcv',
            'pose_mask_r4_0', 'pose_mask_r4_1',
            'pose_mask_r6_0', 'pose_mask_r6_1',
            'part_bbox_0', 'part_bbox_1',
            'part_vis_0', 'part_vis_1']
        # ``label`` is fetched with the rest but is not part of the batch.
        (image_raw_0, image_raw_1, label,
         pose_rcv_0, pose_rcv_1,
         mask_r4_0, mask_r4_1,
         mask_r6_0, mask_r6_1,
         part_bbox_0, part_bbox_1,
         part_vis_0, part_vis_1) = data_provider.get(field_names)

        # Give every tensor a static shape before batching.
        image_shape = [128, 64, 3]
        mask_shape = [128, 64, 1]
        bbox_shape = [self.part_num, 4]
        image_raw_0 = tf.reshape(image_raw_0, image_shape)
        image_raw_1 = tf.reshape(image_raw_1, image_shape)
        mask_r4_0 = tf.cast(tf.reshape(mask_r4_0, mask_shape), tf.float32)
        mask_r4_1 = tf.cast(tf.reshape(mask_r4_1, mask_shape), tf.float32)
        mask_r6_0 = tf.cast(tf.reshape(mask_r6_0, mask_shape), tf.float32)
        mask_r6_1 = tf.cast(tf.reshape(mask_r6_1, mask_shape), tf.float32)
        part_bbox_0 = tf.reshape(part_bbox_0, bbox_shape)
        part_bbox_1 = tf.reshape(part_bbox_1, bbox_shape)

        per_example = [
            image_raw_0, image_raw_1,
            pose_rcv_0, pose_rcv_1,
            mask_r4_0, mask_r4_1,
            mask_r6_0, mask_r6_1,
            part_bbox_0, part_bbox_1,
            part_vis_0, part_vis_1]
        (images_0, images_1,
         poses_rcv_0, poses_rcv_1,
         masks_r4_0, masks_r4_1,
         masks_r6_0, masks_r6_1,
         part_bboxs_0, part_bboxs_1,
         part_viss_0, part_viss_1) = tf.train.batch(
             per_example,
             batch_size=self.batch_size,
             num_threads=self.num_threads,
             capacity=self.capacityCoff * self.batch_size)

        images_0 = process_image(tf.to_float(images_0), 127.5, 127.5)
        images_1 = process_image(tf.to_float(images_1), 127.5, 127.5)

        # Convert the keypoint coordinates to per-keypoint channels, then inflate them.
        pose_channels = [
            tf.cast(coord2channel_simple_rcv(rcv, keypoint_num=18, is_normalized=False, img_H=128, img_W=64), tf.float32)
            for rcv in (poses_rcv_0, poses_rcv_1)]
        poses_0, poses_1 = [
            tf_poseInflate(channels, keypoint_num=18, radius=4, img_H=self.img_H, img_W=self.img_W)
            for channels in pose_channels]

        return (images_0, images_1, poses_0, poses_1, poses_rcv_0, poses_rcv_1,
                masks_r4_0, masks_r4_1, masks_r6_0, masks_r6_1,
                part_bboxs_0, part_bboxs_1, part_viss_0, part_viss_1)


class DPIG_Encoder_GAN_BodyROI_FgBg(DPIG_Encoder_GAN_BodyROI):
    def build_model(self):
        """Build the Fg/Bg appearance auto-encoder graph together with its GAN losses.

        Ops are created in this order:
          1. ``_define_input`` sets up the generator / discriminator callables.
          2. Scope "Encoder": bounding boxes and visibility flags of the
             selected parts are gathered and passed, with ``self.x`` and
             ``self.mask_r6``, to the ``...FgBgFeaTwoBranch`` encoder.
          3. The embedding is repeated ``img_H*img_W`` times, reshaped to
             ``[batch_size, -1, img_H, img_W]`` and passed through
             ``nchw_to_nhwc``.
          4. Scope "ID_AE": ``self.Generator_fn`` decodes from that map and
             ``self.pose``.
          5. Discriminator outputs, GAN / L1 losses, optimizers and the merged
             summary op are defined.
        """
        self._define_input()

        # The first seven entries of the part split are used ("Part 1-7" in the
        # original notes). An alternative selection recorded there was parts 1
        # and 8-16 (ten parts): [1] + range(8,17).
        part_ids = range(0, 7)
        with tf.variable_scope("Encoder"):
            bbox_pieces = tf.split(self.part_bbox, self.part_num, axis=1)
            vis_pieces = tf.split(self.part_vis, self.part_num, axis=1)
            select_part_bbox = tf.concat([bbox_pieces[k] for k in part_ids], axis=1)
            vis_joined = tf.concat([vis_pieces[k] for k in part_ids], axis=1)
            select_part_vis = tf.cast(vis_joined, tf.float32)
            # The sibling encoders FgBgFeaOneBranch, FgBgImgOneBranch and
            # FgBgImgTwoBranch were left commented out in the original code.
            encoder_out = models.GeneratorCNN_ID_Encoder_BodyROIVis_FgBgFeaTwoBranch(
                self.x, self.mask_r6, select_part_bbox, select_part_vis,
                len(part_ids), 32,
                self.repeat_num, self.conv_hidden_num, self.data_format,
                activation_fn=tf.nn.relu, keep_part_prob=1.0, reuse=False)
            self.embs, _, _, self.Encoder_var = encoder_out

        # Repeat the embedding once per pixel location before decoding.
        pixel_count = self.img_H * self.img_W
        tiled = tf.tile(tf.expand_dims(self.embs, -1), [1, 1, pixel_count])
        as_maps = tf.reshape(tiled, [self.batch_size, -1, self.img_H, self.img_W])
        self.embs_rep = nchw_to_nhwc(as_maps)

        with tf.variable_scope("ID_AE"):
            decoded = self.Generator_fn(
                self.embs_rep, self.pose,
                self.channel, self.z_num, self.repeat_num, self.conv_hidden_num,
                self.data_format, activation_fn=tf.nn.relu, reuse=False)
            G, _, self.G_var = decoded

        # The encoder variables are appended to the generator variable list.
        self.G_var += self.Encoder_var

        self.G = denorm_img(G, self.data_format)

        # Axes are permuted with [0,3,1,2] before the discriminator
        # (presumably NHWC -> NCHW; confirm against the discriminator).
        real_in = tf.transpose(self.x, [0, 3, 1, 2])
        D_z_pos = self.Discriminator_fn(real_in, input_dim=3)
        fake_in = tf.transpose(G, [0, 3, 1, 2])
        D_z_neg = self.Discriminator_fn(fake_in, input_dim=3)
        self.D_var = lib.params_with_name('Discriminator.')

        gan_losses = self._gan_loss(self.wgan_gp, self.Discriminator_fn,
                                    D_z_pos, D_z_neg, arch=self.D_arch)
        self.g_loss, self.d_loss = gan_losses
        # PoseMaskLoss is only written to the summaries in this method.
        self.PoseMaskLoss = tf.reduce_mean(tf.abs(G - self.x) * (self.mask_r6))
        self.L1Loss = tf.reduce_mean(tf.abs(G - self.x))
        # Snapshot taken before `_define_loss_optim` adds the L1 term to `self.g_loss`.
        self.g_loss_only = self.g_loss

        self._define_loss_optim()
        summaries = [
            tf.summary.image("G", self.G),
            tf.summary.scalar("loss/PoseMaskLoss", self.PoseMaskLoss),
            tf.summary.scalar("loss/L1Loss", self.L1Loss),
            tf.summary.scalar("loss/g_loss", self.g_loss),
            tf.summary.scalar("loss/g_loss_only", self.g_loss_only),
            tf.summary.scalar("loss/d_loss", self.d_loss),
            tf.summary.scalar("misc/d_lr", self.d_lr),
            tf.summary.scalar("misc/g_lr", self.g_lr),
        ]
        self.summary_op = tf.summary.merge(summaries)

    def _define_loss_optim(self):
        """Add the weighted L1 term to the generator loss and create the optimizers.

        Sets ``self.g_optim``, ``self.d_optim`` and ``self.clip_disc_weights``
        from the three values returned by ``self._getOptimizer``.
        """
        l1_weight = 20
        self.g_loss += self.L1Loss * l1_weight
        optim_ops = self._getOptimizer(
            self.wgan_gp, self.g_loss, self.d_loss, self.G_var, self.D_var)
        self.g_optim, self.d_optim, self.clip_disc_weights = optim_ops



class DPIG_PoseRCV_AE_BodyROI(DPIG_Encoder_GAN_BodyROI):
    def _define_input(self):
        """Create the phase variable and select generator / discriminator builders.

        The WGAN_GP helper is configured in 'dcgan' mode with an output
        dimension of ``keypoint_num*3``; the discriminator uses the 'FCDis'
        architecture.
        """
        self.is_train_tensor = tf.Variable(self.is_train, name='phase')
        self.Generator_encoder_fn = models.GeneratorCNN_ID_Encoder
        self.Generator_fn = models.GeneratorCNN_ID_UAEAfterResidual

        gan_helper = WGAN_GP(
            DATA_DIR='', MODE='dcgan', DIM=64,
            BATCH_SIZE=self.batch_size, ITERS=200000,
            LAMBDA=10, G_OUTPUT_DIM=self.keypoint_num*3)
        self.wgan_gp = gan_helper
        self.Discriminator_fn = self._getDiscriminator(gan_helper, arch='FCDis')

    def build_model(self):
        """Build the pose auto-encoder on per-keypoint triples.

        ``self.pose_rcv`` is reshaped to ``[batch_size, keypoint_num, 3]``.
        The first two channels (named R and C in the original code, presumably
        row and column) are divided by ``img_H`` / ``img_W``, doubled and
        shifted by -1; the third (V, presumably visibility) is only cast to
        float. The flattened result is encoded by ``models.PoseEncoderFCRes``
        and decoded by ``models.PoseDecoderFCRes``. The loss is the mean
        squared difference between the normalised input and the decoded
        triples.

        When ``self.sample_pose`` is truthy, the encoded embedding is replaced
        by Gaussian noise (stddev 0.2) of the same shape before decoding.
        """
        self._define_input()
        with tf.variable_scope("PoseAE"):
            triples = tf.reshape(self.pose_rcv, [self.batch_size, self.keypoint_num, 3])

            def channel(idx):
                # One of the three per-keypoint channels, last axis kept, as float32.
                return tf.cast(tf.slice(triples, [0, 0, idx], [-1, -1, 1]), tf.float32)

            ## Norm to [-1, 1]
            rows = channel(0) / float(self.img_H) * 2.0 - 1
            cols = channel(1) / float(self.img_W) * 2.0 - 1
            visible = channel(2)
            pose_rcv_norm = tf.concat([rows, cols, visible], axis=-1)

            flat_pose = tf.reshape(pose_rcv_norm, [self.batch_size, -1])
            self.pose_embs, self.G_var_encoder = models.PoseEncoderFCRes(
                flat_pose,
                z_num=32, repeat_num=4, hidden_num=512,
                data_format=self.data_format, activation_fn=LeakyReLU, reuse=False)

            if self.sample_pose:
                ## Sampling new poses during testing
                noise_shape = tf.shape(self.pose_embs)
                self.pose_embs = tf.random_normal(noise_shape, mean=0.0, stddev=0.2, dtype=tf.float32)

            decoder_out = models.PoseDecoderFCRes(
                self.pose_embs, self.keypoint_num,
                repeat_num=4, hidden_num=512,
                data_format=self.data_format, activation_fn=LeakyReLU, reuse=False)
            G_pose_coord, G_pose_visible, self.G_var_decoder = decoder_out
            self.G_var_pose = self.G_var_encoder + self.G_var_decoder

        # Re-assemble decoded coordinates and visibility into triples.
        coords = tf.reshape(G_pose_coord, [self.batch_size, self.keypoint_num, 2])
        vis_column = tf.expand_dims(G_pose_visible, -1)
        G_pose_rcv = tf.concat([coords, vis_column], axis=-1)
        G_pose = coord2channel_simple_rcv(
            G_pose_rcv, self.keypoint_num, is_normalized=True,
            img_H=self.img_H, img_W=self.img_W)
        # Max over the last axis, repeated three times, for visualisation.
        strongest = tf.reduce_max(G_pose, axis=-1, keep_dims=True)
        self.G_pose = denorm_img(tf.tile(strongest, [1, 1, 1, 3]), self.data_format)
        self.G_pose_rcv = G_pose_rcv

        self.reconstruct_loss = tf.reduce_mean(tf.square(pose_rcv_norm - G_pose_rcv))

        self._define_loss_optim()
        loss_summary = tf.summary.scalar("loss/reconstruct_loss", self.reconstruct_loss)
        self.summary_op = tf.summary.merge([loss_summary])

    def _define_loss_optim(self):
        """Create ``self.g_optim``: Adam on 20x the reconstruction loss.

        Only the variables in ``self.G_var_pose`` are updated. ``beta1`` is
        fixed to 0.5 here rather than read from the instance.
        """
        adam = tf.train.AdamOptimizer(learning_rate=self.g_lr, beta1=0.5)
        weighted_loss = self.reconstruct_loss * 20
        self.g_optim = adam.minimize(
            weighted_loss,
            var_list=self.G_var_pose,
            colocate_gradients_with_ops=True)

    def train(self):
        """Run the pose auto-encoder training loop.

        One batch is fetched up front and saved to ``self.model_dir`` as
        reference images. For each step from ``start_step`` to ``max_step``:
        the optimizer runs (skipped on step 0), summaries are written on the
        last step of every ``log_step`` window, learning rates are updated on
        the last step of every ``lr_update_step`` window, and a checkpoint is
        saved on the last step of every ``log_step * 30`` window.
        """
        (x_fixed, x_target_fixed, pose_fixed, pose_target_fixed, _pose_rcv_fixed,
         mask_fixed, mask_target_fixed,
         _part_bbox_fixed, _part_bbox_target_fixed) = self.get_image_from_loader()

        def pose_to_image(pose_maps):
            # Max over the last axis, then (v + 1) * 127.5.
            return (np.amax(pose_maps, axis=-1, keepdims=True)+1.0)*127.5

        out_dir = self.model_dir
        save_image(x_fixed, '{}/x_fixed.png'.format(out_dir))
        save_image(x_target_fixed, '{}/x_target_fixed.png'.format(out_dir))
        save_image(pose_to_image(pose_fixed), '{}/pose_fixed.png'.format(out_dir))
        save_image(pose_to_image(pose_target_fixed), '{}/pose_target_fixed.png'.format(out_dir))
        save_image(mask_fixed, '{}/mask_fixed.png'.format(out_dir))
        save_image(mask_target_fixed, '{}/mask_target_fixed.png'.format(out_dir))

        for step in trange(self.start_step, self.max_step):
            if step > 0:
                self.sess.run([self.g_optim])

            # Summaries are only fetched on the last step of each log window;
            # otherwise the fetch dict stays empty.
            log_now = (step % self.log_step == self.log_step-1)
            fetches = {"summary": self.summary_op} if log_now else {}
            result = self.sess.run(fetches)
            if log_now:
                self.summary_writer.add_summary(result['summary'], step)
                self.summary_writer.flush()

            if step % self.lr_update_step == self.lr_update_step - 1:
                self.sess.run([self.g_lr_update, self.d_lr_update])

            ckpt_period = self.log_step * 30
            if step % ckpt_period == ckpt_period-1:
                ckpt_path = os.path.join(self.model_dir, 'model.ckpt')
                self.saver.save(self.sess, ckpt_path, global_step=step)

    def get_image_from_loader(self):
        """Fetch one batch from the graph and convert it for saving.

        Returns a 9-tuple ``(x, x_target, pose, pose_target, pose_rcv, mask,
        mask_target, part_bbox, part_bbox_target)``. The two images go through
        ``unprocess_image(., 127.5, 127.5)`` and the two masks (taken from the
        ``mask_r6`` tensors) are multiplied by 255; the rest is returned as
        fetched.
        """
        batch_tensors = [
            self.x, self.x_target, self.pose, self.pose_target,
            self.pose_rcv, self.mask_r6, self.mask_r6_target,
            self.part_bbox, self.part_bbox_target,
        ]
        (x, x_target, pose, pose_target, pose_rcv,
         mask, mask_target, part_bbox, part_bbox_target) = self.sess.run(batch_tensors)
        return (
            unprocess_image(x, 127.5, 127.5),
            unprocess_image(x_target, 127.5, 127.5),
            pose,
            pose_target,
            pose_rcv,
            mask*255,
            mask_target*255,
            part_bbox,
            part_bbox_target,
        )


################### Subnet of Appearance/Pose Sampling #################
class DPIG_Encoder_subSampleAppNetFgBg_GAN_BodyROI(DPIG_Encoder_GAN_BodyROI):
    def _define_input(self):
        """Create the phase variable, generator builders and two 'wgan' critics.

        Foreground and background embeddings each get their own ``WGAN_GP``
        helper (identical settings) and their own 'FCDis' discriminator
        builder.
        """
        self.is_train_tensor = tf.Variable(self.is_train, name='phase')
        self.Generator_encoder_fn = models.GeneratorCNN_ID_Encoder
        self.Generator_fn = models.GeneratorCNN_ID_UAEAfterResidual

        gan_settings = dict(DATA_DIR='', MODE='wgan', DIM=64,
                            BATCH_SIZE=self.batch_size, ITERS=200000,
                            LAMBDA=10)
        self.wgan_gp_fg = WGAN_GP(**gan_settings)
        self.Discriminator_fg_fn = self._getDiscriminator(self.wgan_gp_fg, arch='FCDis')
        self.wgan_gp_bg = WGAN_GP(**gan_settings)
        self.Discriminator_bg_fn = self._getDiscriminator(self.wgan_gp_bg, arch='FCDis')

    def build_model(self):
        """Build the appearance-sampling graph (Fg and Bg mapping nets + critics).

        Graph-construction order:
          * Scope "Encoder": encode ``self.x`` (with ``self.mask_r6`` and the
            selected part boxes / visibilities) and split the embedding into
            ``self.fg_embs`` (first ``len(indices)*32`` columns) and
            ``self.bg_embs`` (remaining columns).
          * Scope "ID_AE": call the generator on the encoded embedding with
            ``reuse=False``. The image output of this first call is not used
            afterwards; the second call further down passes ``reuse=True``.
          * Scopes "Gaussian_FC_Fg" / "Gaussian_FC_Bg": ``models.GaussianFCRes``
            produces ``self.app_embs_fg`` / ``self.app_embs_bg`` with the same
            shapes as the encoded Fg / Bg embeddings.
          * One critic per branch compares encoded (positive) and sampled
            (negative) embeddings; ``_gan_loss`` yields the four losses.
          * A mixed batch of sampled embeddings is decoded into ``self.G``.

        Only the two ``GaussianFCRes`` variable lists and the two critic
        variable lists are handed to the optimizers in this class's
        ``_define_loss_optim``.
        """
        self._define_input()
        with tf.variable_scope("Encoder") as vs:
            pb_list = tf.split(self.part_bbox, self.part_num, axis=1)
            pv_list = tf.split(self.part_vis, self.part_num, axis=1)
            ## Part 1,8-16 (totally 10)
            # indices = [0] + range(7,16)
            ## Part 1-7 (totally 7)
            indices = range(0,7)
            select_part_bbox = tf.concat([pb_list[i] for i in indices], axis=1)
            select_part_vis = tf.cast(tf.concat([pv_list[i] for i in indices], axis=1), tf.float32)
            self.embs, _, _, self.Encoder_var = models.GeneratorCNN_ID_Encoder_BodyROIVis_FgBgFeaTwoBranch(self.x, self.mask_r6, select_part_bbox, select_part_vis, len(indices), 32, 
                                            self.repeat_num, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, keep_part_prob=1.0, reuse=False)

            # Column-wise split of the embedding into its Fg and Bg parts.
            self.fg_embs = tf.slice(self.embs, [0,0], [-1,len(indices)*32])
            self.bg_embs = tf.slice(self.embs, [0,len(indices)*32], [-1,-1])
        # Repeat the embedding once per pixel location before decoding.
        self.embs_rep = tf.tile(tf.expand_dims(self.embs,-1), [1, 1, self.img_H*self.img_W])
        self.embs_rep = tf.reshape(self.embs_rep, [self.batch_size, -1, self.img_H, self.img_W])
        self.embs_rep = nchw_to_nhwc(self.embs_rep)

        with tf.variable_scope("ID_AE") as vs:
            G, _, self.G_var = self.Generator_fn(
                    self.embs_rep, self.pose, 
                    self.channel, self.z_num, self.repeat_num, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, reuse=False)

        with tf.variable_scope("Gaussian_FC_Fg") as vs:
            embs_shape = self.fg_embs.get_shape().as_list()
            self.app_embs_fg, self.G_var_app_embs_fg = models.GaussianFCRes(embs_shape, embs_shape[-1], repeat_num=4, hidden_num=512, data_format=self.data_format, activation_fn=LeakyReLU)

        with tf.variable_scope("Gaussian_FC_Bg") as vs:
            embs_shape = self.bg_embs.get_shape().as_list()
            self.app_embs_bg, self.G_var_app_embs_bg = models.GaussianFCRes(embs_shape, embs_shape[-1], repeat_num=4, hidden_num=256, data_format=self.data_format, activation_fn=LeakyReLU)

        ## Adversarial for Gaussian Fg
        # The critic is called separately on encoded and sampled embeddings,
        # both times with the same `name` and reuse=False.
        # NOTE(review): presumably the parameter store shares weights by name
        # between the two calls -- confirm against `_getDiscriminator`.
        # (A single call on the concatenated pair was left commented out in
        # the original code.)
        D_z_pos_embs_fg = self.Discriminator_fg_fn(self.fg_embs, input_dim=self.fg_embs.get_shape().as_list()[-1], FC_DIM=512, n_layers=3, reuse=False, name='Fg_FCDis_')
        D_z_neg_embs_fg = self.Discriminator_fg_fn(self.app_embs_fg, input_dim=self.app_embs_fg.get_shape().as_list()[-1], FC_DIM=512, n_layers=3, reuse=False, name='Fg_FCDis_')
        self.D_var_embs_fg = lib.params_with_name('Fg_FCDis_Discriminator.')
        self.g_loss_embs_fg, self.d_loss_embs_fg = self._gan_loss(self.wgan_gp_fg, self.Discriminator_fg_fn, 
                                            D_z_pos_embs_fg, D_z_neg_embs_fg, self.fg_embs, self.app_embs_fg)
        ## Adversarial for Gaussian Bg
        D_z_pos_embs_bg = self.Discriminator_bg_fn(self.bg_embs, input_dim=self.bg_embs.get_shape().as_list()[-1], FC_DIM=512, n_layers=3, reuse=False, name='Bg_FCDis_')
        D_z_neg_embs_bg = self.Discriminator_bg_fn(self.app_embs_bg, input_dim=self.app_embs_bg.get_shape().as_list()[-1], FC_DIM=512, n_layers=3, reuse=False, name='Bg_FCDis_')
        self.D_var_embs_bg = lib.params_with_name('Bg_FCDis_Discriminator.')
        self.g_loss_embs_bg, self.d_loss_embs_bg = self._gan_loss(self.wgan_gp_bg, self.Discriminator_bg_fn, 
                                            D_z_pos_embs_bg, D_z_neg_embs_bg, self.bg_embs, self.app_embs_bg)

        assert (self.batch_size>0)and(self.batch_size%2==0), 'batch should be Even and >0'
        # Explicit floor division: tile multiples and slice offsets/sizes must
        # be ints, also when `/` means true division.
        half_batch = self.batch_size // 2
        # Mixed batch for visualisation:
        #   rows [0, half)      -> Fg from sampled row 0 (repeated), Bg varying
        #   rows [half, batch)  -> Fg varying, Bg from sampled row 0 (repeated)
        self.app_embs_fixFg = tf.tile(tf.slice(self.app_embs_fg, [0,0], [1,-1]), [half_batch,1])
        self.app_embs_varyFg = tf.slice(self.app_embs_fg, [half_batch,0], [half_batch,-1])
        self.app_embs_fixBg = tf.tile(tf.slice(self.app_embs_bg, [0,0], [1,-1]), [half_batch,1])
        self.app_embs_varyBg = tf.slice(self.app_embs_bg, [half_batch,0], [half_batch,-1])
        self.app_embs = tf.concat([tf.concat([self.app_embs_fixFg,self.app_embs_varyFg],axis=0), tf.concat([self.app_embs_varyBg,self.app_embs_fixBg],axis=0)], axis=-1)
        self.embs_app_rep = tf.tile(tf.expand_dims(self.app_embs,-1), [1, 1, self.img_H*self.img_W])
        self.embs_app_rep = tf.reshape(self.embs_app_rep, [self.batch_size, -1, self.img_H, self.img_W])
        self.embs_app_rep = nchw_to_nhwc(self.embs_app_rep)
        with tf.variable_scope("ID_AE") as vs:
            # Second generator call: reuse=True, fed with the sampled embeddings.
            G_gaussian_app, _, _ = self.Generator_fn(
                    self.embs_app_rep, self.pose, 
                    self.channel, self.z_num, self.repeat_num, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, reuse=True)
        self.G = denorm_img(G_gaussian_app, self.data_format)

        self._define_loss_optim()
        self.summary_op = tf.summary.merge([
            tf.summary.image("G", self.G),
            tf.summary.scalar("loss/g_loss_embs_fg", self.g_loss_embs_fg),
            tf.summary.scalar("loss/d_loss_embs_fg", self.d_loss_embs_fg),
            tf.summary.scalar("loss/g_loss_embs_bg", self.g_loss_embs_bg),
            tf.summary.scalar("loss/d_loss_embs_bg", self.d_loss_embs_bg),
            tf.summary.scalar("misc/d_lr", self.d_lr),
            tf.summary.scalar("misc/g_lr", self.g_lr),
        ])

    def _define_loss_optim(self):
        """Create generator / critic / clip ops for the Fg branch, then the Bg branch.

        Each branch trains only its ``GaussianFCRes`` variables against its
        own critic variables.
        """
        fg_ops = self._getOptimizer(
            self.wgan_gp_fg,
            self.g_loss_embs_fg, self.d_loss_embs_fg,
            self.G_var_app_embs_fg, self.D_var_embs_fg)
        self.g_optim_embs_fg, self.d_optim_embs_fg, self.clip_disc_weights_embs_fg = fg_ops

        bg_ops = self._getOptimizer(
            self.wgan_gp_bg,
            self.g_loss_embs_bg, self.d_loss_embs_bg,
            self.G_var_app_embs_bg, self.D_var_embs_bg)
        self.g_optim_embs_bg, self.d_optim_embs_bg, self.clip_disc_weights_embs_bg = bg_ops

    def train(self):
        """Adversarially train the Fg and Bg embedding samplers.

        A reference batch is fetched once and saved to ``self.model_dir``.
        Every step handles the Fg branch and then the Bg branch: one generator
        update (skipped on step 0) followed by the critic updates, with weight
        clipping after each critic update in 'wgan' mode. Summaries, sample
        images, learning-rate updates and checkpoints follow their own
        periods, all expressed through ``log_step`` / ``lr_update_step``.
        """
        (x_fixed, x_target_fixed, pose_fixed, pose_target_fixed,
         mask_fixed, mask_target_fixed,
         part_bbox_fixed, _part_bbox_target_fixed,
         part_vis_fixed, _part_vis_target_fixed) = self.get_image_from_loader()

        def pose_to_image(pose_maps):
            # Max over the last axis, then (v + 1) * 127.5.
            return (np.amax(pose_maps, axis=-1, keepdims=True)+1.0)*127.5

        out_dir = self.model_dir
        save_image(x_fixed, '{}/x_fixed.png'.format(out_dir))
        save_image(x_target_fixed, '{}/x_target_fixed.png'.format(out_dir))
        save_image(pose_to_image(pose_fixed), '{}/pose_fixed.png'.format(out_dir))
        save_image(pose_to_image(pose_target_fixed), '{}/pose_target_fixed.png'.format(out_dir))
        save_image(mask_fixed, '{}/mask_fixed.png'.format(out_dir))
        save_image(mask_target_fixed, '{}/mask_target_fixed.png'.format(out_dir))

        for step in trange(self.start_step, self.max_step):
            # (gan helper, generator op, critic op, clip op): Fg first, then Bg.
            branches = (
                (self.wgan_gp_fg, self.g_optim_embs_fg,
                 self.d_optim_embs_fg, self.clip_disc_weights_embs_fg),
                (self.wgan_gp_bg, self.g_optim_embs_bg,
                 self.d_optim_embs_bg, self.clip_disc_weights_embs_bg),
            )
            for gan, gen_op, critic_op, clip_op in branches:
                if step > 0:
                    self.sess.run([gen_op])
                # Train critic
                if gan.MODE in ('dcgan', 'lsgan'):
                    critic_iters = 1
                else:
                    critic_iters = gan.CRITIC_ITERS
                for _ in xrange(critic_iters):
                    self.sess.run(critic_op)
                    if gan.MODE == 'wgan':
                        self.sess.run(clip_op)

            # Summaries are only fetched on the last step of each log window.
            log_now = (step % self.log_step == self.log_step-1)
            fetches = {"summary": self.summary_op} if log_now else {}
            result = self.sess.run(fetches)
            if log_now:
                self.summary_writer.add_summary(result['summary'], step)
                self.summary_writer.flush()

            # Samples: step 0, plus the last step of every 1.5x and 5x log window.
            fast_period = self.log_step * 1.5
            slow_period = self.log_step * 5
            if (step == 0
                    or step % fast_period == fast_period-1
                    or step % slow_period == slow_period-1):
                x = process_image(x_fixed, 127.5, 127.5)
                x_target = process_image(x_target_fixed, 127.5, 127.5)
                self.generate(x, x_target, pose_fixed, part_bbox_fixed, part_vis_fixed, self.model_dir, idx=step)

            if step % self.lr_update_step == self.lr_update_step - 1:
                self.sess.run([self.g_lr_update, self.d_lr_update])

            if step % slow_period == slow_period-1:
                ckpt_path = os.path.join(self.model_dir, 'model.ckpt')
                self.saver.save(self.sess, ckpt_path, global_step=step)


class DPIG_subnetSamplePoseRCV_GAN_BodyROI(DPIG_PoseRCV_AE_BodyROI):
    def _define_input(self):
        """Create the phase variable, generator builders and the embedding critic.

        The ``WGAN_GP`` helper runs in 'wgan' mode with an output dimension of
        ``keypoint_num*3``; the critic uses the 'FCDis' architecture.
        """
        self.is_train_tensor = tf.Variable(self.is_train, name='phase')
        self.Generator_encoder_fn = models.GeneratorCNN_ID_Encoder
        self.Generator_fn = models.GeneratorCNN_ID_UAEAfterResidual

        gan_helper = WGAN_GP(
            DATA_DIR='', MODE='wgan', DIM=64,
            BATCH_SIZE=self.batch_size, ITERS=200000,
            LAMBDA=10, G_OUTPUT_DIM=self.keypoint_num*3)
        self.wgan_gp_encoder = gan_helper
        self.Discriminator_encoder_fn = self._getDiscriminator(gan_helper, arch='FCDis')

    def build_model(self):
        """Build the pose-sampling graph: a mapping net trained against the pose encoder.

        Graph-construction order:
          * Scope "PoseAE": normalise ``self.pose_rcv`` and encode it with
            ``models.PoseEncoderFCRes`` into ``self.pose_embs``.
          * Scope "PoseGaussian": ``models.GaussianFCRes`` produces
            ``self.G_pose_embs`` with the shape of the encoded embedding.
          * Scope "PoseAE" again: ``models.PoseDecoderFCRes`` decodes the
            sampled embedding into keypoint coordinates and visibilities.
          * A critic scores the concatenation of encoded and sampled
            embeddings; its output is split in two halves (positive, negative)
            for ``_gan_loss``.
          * Scopes "Encoder" / "ID_AE": the appearance encoder and image
            generator are built so sampled poses can be rendered. The
            generator reads its pose from the placeholder
            ``self.G_pose_inflated``, which must be fed from Python.
        """
        self._define_input()
        with tf.variable_scope("PoseAE"):
            triples = tf.reshape(self.pose_rcv, [self.batch_size, self.keypoint_num, 3])

            def channel(idx):
                # One of the three per-keypoint channels, last axis kept, as float32.
                return tf.cast(tf.slice(triples, [0, 0, idx], [-1, -1, 1]), tf.float32)

            ## Norm to [-1, 1]
            rows = channel(0) / float(self.img_H) * 2.0 - 1
            cols = channel(1) / float(self.img_W) * 2.0 - 1
            visible = channel(2)
            pose_rcv_norm = tf.concat([rows, cols, visible], axis=-1)

            flat_pose = tf.reshape(pose_rcv_norm, [self.batch_size, -1])
            self.pose_embs, self.G_var_encoder = models.PoseEncoderFCRes(
                flat_pose,
                z_num=32, repeat_num=4, hidden_num=512,
                data_format=self.data_format, activation_fn=LeakyReLU, reuse=False)

        embs_shape = self.pose_embs.get_shape().as_list()
        # The original code also kept a commented-out variant of this scope
        # named "Gaussian_FC".
        with tf.variable_scope("PoseGaussian"):
            self.G_pose_embs, self.G_var_embs = models.GaussianFCRes(
                embs_shape, embs_shape[-1],
                repeat_num=4, hidden_num=512,
                data_format=self.data_format, activation_fn=LeakyReLU)

        with tf.variable_scope("PoseAE"):
            decoder_out = models.PoseDecoderFCRes(
                self.G_pose_embs, self.keypoint_num,
                repeat_num=4, hidden_num=512,
                data_format=self.data_format, activation_fn=LeakyReLU, reuse=False)
            G_pose_coord, G_pose_visible, self.G_var_decoder = decoder_out

        # Re-assemble decoded coordinates and visibility into triples.
        coords = tf.reshape(G_pose_coord, [self.batch_size, self.keypoint_num, 2])
        vis_column = tf.expand_dims(G_pose_visible, -1)
        G_pose_rcv = tf.concat([coords, vis_column], axis=-1)
        G_pose = coord2channel_simple_rcv(
            G_pose_rcv, self.keypoint_num, is_normalized=True,
            img_H=self.img_H, img_W=self.img_W)
        strongest = tf.reduce_max(G_pose, axis=-1, keep_dims=True)
        self.G_pose = denorm_img(tf.tile(strongest, [1, 1, 1, 3]), self.data_format)
        self.G_pose_rcv = G_pose_rcv

        ## Adversarial for pose_embs
        self.pose_embs = tf.reshape(self.pose_embs, [self.batch_size, -1])
        self.G_pose_embs = tf.reshape(self.G_pose_embs, [self.batch_size, -1])
        encode_pair = tf.concat([self.pose_embs, self.G_pose_embs], 0)
        pair_dim = encode_pair.get_shape().as_list()[-1]
        self.D_z_embs = self.Discriminator_encoder_fn(
            encode_pair, input_dim=pair_dim,
            FC_DIM=512, n_layers=3, reuse=False, name='Pose_emb_')
        self.D_var_embs = lib.params_with_name('Pose_emb_Discriminator.')
        D_z_pos, D_z_neg = tf.split(self.D_z_embs, 2)
        embs_losses = self._gan_loss(
            self.wgan_gp_encoder, self.Discriminator_encoder_fn,
            D_z_pos, D_z_neg, self.pose_embs, self.G_pose_embs)
        self.g_loss_embs, self.d_loss_embs = embs_losses

        ## Use the pose to generate person with pretrained generator
        # The first seven entries of the part split are used; the original
        # notes also list [1] + range(8,17) and [0] + range(7,16) as ten-part
        # alternatives.
        part_ids = range(7)
        with tf.variable_scope("Encoder"):
            bbox_pieces = tf.split(self.part_bbox, self.part_num, axis=1)
            vis_pieces = tf.split(self.part_vis, self.part_num, axis=1)
            select_part_bbox = tf.concat([bbox_pieces[k] for k in part_ids], axis=1)
            vis_joined = tf.concat([vis_pieces[k] for k in part_ids], axis=1)
            select_part_vis = tf.cast(vis_joined, tf.float32)
            # Calls to GeneratorCNN_ID_Encoder_BodyROIVis / ..._BodyROI were
            # left commented out here in the original code.
            encoder_out = models.GeneratorCNN_ID_Encoder_BodyROIVis_FgBgFeaTwoBranch(
                self.x, self.mask_r6, select_part_bbox, select_part_vis,
                len(part_ids), 32,
                self.repeat_num, self.conv_hidden_num, self.data_format,
                activation_fn=tf.nn.relu, keep_part_prob=1.0, reuse=False)
            self.embs, _, _, _ = encoder_out

        # Repeat the embedding once per pixel location before decoding.
        pixel_count = self.img_H * self.img_W
        tiled = tf.tile(tf.expand_dims(self.embs, -1), [1, 1, pixel_count])
        as_maps = tf.reshape(tiled, [self.batch_size, -1, self.img_H, self.img_W])
        self.embs_rep = nchw_to_nhwc(as_maps)

        ## Use py code to get G_pose_inflated, so the op is out of the graph
        self.G_pose_inflated = tf.placeholder(tf.float32, shape=G_pose.get_shape())
        with tf.variable_scope("ID_AE"):
            decoded = self.Generator_fn(
                self.embs_rep, self.G_pose_inflated,
                self.channel, self.z_num, self.repeat_num, self.conv_hidden_num,
                self.data_format, activation_fn=tf.nn.relu, reuse=False)
            G, _, _ = decoded
        self.G = denorm_img(G, self.data_format)

        self._define_loss_optim()
        summaries = [
            tf.summary.image("G_pose", self.G_pose),
            tf.summary.scalar("loss/g_loss_embs", self.g_loss_embs),
            tf.summary.scalar("loss/d_loss_embs", self.d_loss_embs),
            tf.summary.scalar("misc/d_lr", self.d_lr),
            tf.summary.scalar("misc/g_lr", self.g_lr),
            tf.summary.histogram("distribution/pose_emb", self.pose_embs),
            tf.summary.histogram("distribution/G_pose_embs", self.G_pose_embs),
            tf.summary.histogram("distribution/app_emb", self.embs),
        ]
        self.summary_op = tf.summary.merge(summaries)

    def _define_loss_optim(self):
        """Create the generator / critic / clip ops for the pose-embedding GAN.

        Only the ``GaussianFCRes`` variables (``self.G_var_embs``) are trained
        on the generator side.
        """
        embs_ops = self._getOptimizer(
            self.wgan_gp_encoder,
            self.g_loss_embs, self.d_loss_embs,
            self.G_var_embs, self.D_var_embs)
        self.g_optim_embs, self.d_optim_embs, self.clip_disc_weights_embs = embs_ops

    def train(self):
        """Adversarially train the pose-embedding sampler.

        A reference batch is fetched once and saved to ``self.model_dir``.
        Every step runs one generator update (skipped on step 0) followed by
        the critic updates, with weight clipping after each critic update in
        'wgan' mode. Sample images are written on steps 0, 10, 200 and on the
        last step of every ``log_step * 3`` window; checkpoints on the last
        step of every ``log_step * 30`` window.
        """
        (x_fixed, x_target_fixed, pose_fixed, pose_target_fixed, _pose_rcv_fixed,
         mask_fixed, mask_target_fixed,
         part_bbox_fixed, _part_bbox_target_fixed) = self.get_image_from_loader()

        def pose_to_image(pose_maps):
            # Max over the last axis, then (v + 1) * 127.5.
            return (np.amax(pose_maps, axis=-1, keepdims=True)+1.0)*127.5

        out_dir = self.model_dir
        save_image(x_fixed, '{}/x_fixed.png'.format(out_dir))
        save_image(x_target_fixed, '{}/x_target_fixed.png'.format(out_dir))
        save_image(pose_to_image(pose_fixed), '{}/pose_fixed.png'.format(out_dir))
        save_image(pose_to_image(pose_target_fixed), '{}/pose_target_fixed.png'.format(out_dir))
        save_image(mask_fixed, '{}/mask_fixed.png'.format(out_dir))
        save_image(mask_target_fixed, '{}/mask_target_fixed.png'.format(out_dir))

        for step in trange(self.start_step, self.max_step):
            # Use GAN for Pose Embedding
            if step > 0:
                self.sess.run([self.g_optim_embs])
            # Train critic
            gan = self.wgan_gp_encoder
            if gan.MODE in ('dcgan', 'lsgan'):
                critic_iters = 1
            else:
                critic_iters = gan.CRITIC_ITERS
            for _ in xrange(critic_iters):
                self.sess.run([self.d_optim_embs])
                if gan.MODE == 'wgan':
                    self.sess.run(self.clip_disc_weights_embs)

            # Summaries are only fetched on the last step of each log window.
            log_now = (step % self.log_step == self.log_step-1)
            fetches = {"summary": self.summary_op} if log_now else {}
            result = self.sess.run(fetches)
            if log_now:
                self.summary_writer.add_summary(result['summary'], step)
                self.summary_writer.flush()

            sample_period = self.log_step * 3
            if step in (0, 10, 200) or step % sample_period == sample_period-1:
                x = process_image(x_fixed, 127.5, 127.5)
                x_target = process_image(x_target_fixed, 127.5, 127.5)
                self.generate(x, x_target, pose_fixed, part_bbox_fixed, self.model_dir, idx=step)

            if step % self.lr_update_step == self.lr_update_step - 1:
                self.sess.run([self.g_lr_update, self.d_lr_update])

            ckpt_period = self.log_step * 30
            if step % ckpt_period == ckpt_period-1:
                ckpt_path = os.path.join(self.model_dir, 'model.ckpt')
                self.saver.save(self.sess, ckpt_path, global_step=step)

    def generate(self, x_fixed, x_target_fixed, pose_fixed, part_bbox_fixed, root_path=None, path=None, idx=None, save=True):
        """Sample poses, render them with the appearance of ``x_fixed`` and save the results.

        Args:
            x_fixed: batch fed as ``self.x``; also the SSIM reference after
                being mapped with ``(x + 1) * 127.5``.
            x_target_fixed: unused here, kept for interface compatibility.
            pose_fixed: unused here, kept for interface compatibility.
            part_bbox_fixed: batch fed as ``self.part_bbox``.
            root_path: directory receiving the sample images.
            path: when not None, nothing is written (images are only saved
                when ``path is None and save``).
            idx: value inserted in the output file names.
            save: whether to write the images.

        Returns:
            The generated image batch ``G`` as returned by the session.

        NOTE(review): ``build_model`` also wires ``self.part_vis`` and
        ``self.mask_r6`` into the appearance encoder; they are not fed here,
        so they take whatever the graph's input tensors produce -- confirm
        this is intended.
        """
        G_pose_rcv, G_pose = self.sess.run([self.G_pose_rcv, self.G_pose])
        # Use the instance image size (not a hard-coded 128x64) so the inflated
        # map always matches the `self.G_pose_inflated` placeholder, whose
        # shape is derived from self.img_H / self.img_W in `build_model`.
        G_pose_inflated = py_poseInflate(G_pose_rcv, is_normalized=True, radius=4,
                                         img_H=self.img_H, img_W=self.img_W)
        G = self.sess.run(self.G, {self.x: x_fixed, self.G_pose_inflated: G_pose_inflated, self.part_bbox: part_bbox_fixed})
        G_pose_inflated_img = np.tile(np.amax((G_pose_inflated+1)*127.5, axis=-1, keepdims=True), [1,1,1,3])

        # Per-image SSIM between generated and input images (grayscale); the
        # mean only ends up in the output file name.
        ssim_G_x_list = []
        for i in range(G.shape[0]):
            G_gray = rgb2gray((G[i,:]).clip(min=0,max=255).astype(np.uint8))
            x_gray = rgb2gray(((x_fixed[i,:]+1)*127.5).clip(min=0,max=255).astype(np.uint8))
            ssim_G_x_list.append(ssim(G_gray, x_gray, data_range=x_gray.max() - x_gray.min(), multichannel=False))
        ssim_G_x_mean = np.mean(ssim_G_x_list)
        if path is None and save:
            path = os.path.join(root_path, '{}_G_ssim{}.png'.format(idx,ssim_G_x_mean))
            save_image(G, path)
            print("[*] Samples saved: {}".format(path))
            path = os.path.join(root_path, '{}_G_pose.png'.format(idx))
            save_image(G_pose, path)
            print("[*] Samples saved: {}".format(path))
            path = os.path.join(root_path, '{}_G_pose_inflated.png'.format(idx))
            save_image(G_pose_inflated_img, path)
            print("[*] Samples saved: {}".format(path))
        return G





#################################################################################################
####################################### DF train models #########################################

######################### DeepFashion with AppPose BodyROI ################################
class DPIG_Encoder_GAN_BodyROI_256(DPIG_Encoder_GAN_BodyROI):
    def __init__(self, config):
        """Set up shared config state, pick the DeepFashion split and build the input batch tensors.

        The dataset object is only assigned when config.dataset (lower-cased)
        contains 'deepfashion' or 'df'; the split is 'train' or 'test'
        depending on config.is_train.
        """
        self._common_init(config)
        self.keypoint_num = 18
        self.part_num = 37
        self.D_arch = config.D_arch

        dataset_name = config.dataset.lower()
        if 'deepfashion' in dataset_name or 'df' in dataset_name:
            split_name = 'train' if config.is_train else 'test'
            self.dataset_obj = deepfashion.get_split(split_name, config.data_path)

        # The loader returns 12 batched tensors in a fixed order; bind each one.
        (self.x, self.x_target,
         self.pose, self.pose_target,
         self.pose_rcv, self.pose_rcv_target,
         self.mask, self.mask_target,
         self.part_bbox, self.part_bbox_target,
         self.part_vis, self.part_vis_target) = self._load_batch_pair_pose(self.dataset_obj)

    def _define_input(self):
        """Create the phase variable and select the generator / discriminator builders."""
        self.is_train_tensor = tf.Variable(self.is_train, name='phase')
        self.Generator_fn = models.GeneratorCNN_ID_UAEAfterResidual

        # GAN helper configured for image-sized outputs (H * W * 3 values).
        wgan_settings = dict(
            DATA_DIR='',
            MODE='dcgan',
            DIM=64,
            BATCH_SIZE=self.batch_size,
            ITERS=200000,
            LAMBDA=10,
            G_OUTPUT_DIM=self.img_H*self.img_W*3,
        )
        self.wgan_gp = WGAN_GP(**wgan_settings)
        self.Discriminator_fn = self._getDiscriminator(self.wgan_gp, arch=self.D_arch)

    def build_model(self):
        """Build the encoder / generator / discriminator graph, its losses and summaries.

        Steps, in graph-construction order:
          1. "Encoder" scope: embed self.x using the first 7 body-part boxes and
             their visibility flags -> self.embs.
          2. Tile the embedding over every spatial position and feed it, together
             with self.pose, to the generator in the "ID_AE" scope -> G.
          3. Run the discriminator on the concatenation of self.x and G and derive
             the GAN losses plus masked / plain L1 reconstruction terms.
          4. Call _define_loss_optim() and assemble self.summary_op.
        """
        self._define_input()
        with tf.variable_scope("Encoder") as vs:
            # One tensor per body part, split along axis 1.
            pb_list = tf.split(self.part_bbox, self.part_num, axis=1)
            pv_list = tf.split(self.part_vis, self.part_num, axis=1)
            ## Part 1-7 (totally 7)
            indices = range(7)
            select_part_bbox = tf.concat([pb_list[i] for i in indices], axis=1)
            select_part_vis = tf.cast(tf.concat([pv_list[i] for i in indices], axis=1), tf.float32)
            self.embs, _, self.Encoder_var = models.GeneratorCNN_ID_Encoder_BodyROIVis(self.x,select_part_bbox, select_part_vis, len(indices), 32, 
                                            self.repeat_num+1, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, keep_part_prob=1.0, roi_size=64, reuse=False)
            ## Part 1,4-8 (totally 6)
            # # indices = [1] + range(4,9)
            # indices = [0] + range(3,8)
            ## Part 1,8-16 (totally 10)
            # indices = [0] + range(7,16)
            
        # Replicate the embedding img_H*img_W times along a new last axis, fold that
        # axis into [img_H, img_W], then convert with nchw_to_nhwc.
        self.embs_rep = tf.tile(tf.expand_dims(self.embs,-1), [1, 1, self.img_H*self.img_W])
        self.embs_rep = tf.reshape(self.embs_rep, [self.batch_size, -1, self.img_H, self.img_W])
        self.embs_rep = nchw_to_nhwc(self.embs_rep)

        with tf.variable_scope("ID_AE") as vs:
            G, _, self.G_var = self.Generator_fn(
                    self.embs_rep, self.pose, 
                    self.channel, self.z_num, self.repeat_num-1, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, reuse=False)

        # The generator optimizer also updates the encoder variables.
        self.G_var += self.Encoder_var

        self.G = denorm_img(G, self.data_format)

        # Real images first, generated images second, stacked along axis 0.
        pair = tf.concat([self.x, G], 0)

        # NOTE(review): the [0,3,1,2] transpose presumably converts NHWC -> NCHW for the discriminator — confirm.
        self.D_z = self.Discriminator_fn(tf.transpose( pair, [0,3,1,2] ), input_dim=3)
        self.D_var = lib.params_with_name('Discriminator.')

        # Split back into the two halves matching the concat order above.
        D_z_pos, D_z_neg = tf.split(self.D_z, 2)

        self.g_loss, self.d_loss = self._gan_loss(self.wgan_gp, self.Discriminator_fn, D_z_pos, D_z_neg, arch=self.D_arch)
        # PoseMaskLoss is only reported in the summaries below; this class's
        # _define_loss_optim adds L1Loss (not PoseMaskLoss) to g_loss.
        self.PoseMaskLoss = tf.reduce_mean(tf.abs(G - self.x) * (self.mask))
        self.L1Loss = tf.reduce_mean(tf.abs(G - self.x))
        # Reference to g_loss taken before _define_loss_optim() modifies self.g_loss.
        self.g_loss_only = self.g_loss

        self._define_loss_optim()
        self.summary_op = tf.summary.merge([
            tf.summary.image("G", self.G),
            tf.summary.scalar("loss/PoseMaskLoss", self.PoseMaskLoss),
            tf.summary.scalar("loss/L1Loss", self.L1Loss),
            tf.summary.scalar("loss/g_loss", self.g_loss),
            tf.summary.scalar("loss/g_loss_only", self.g_loss_only),
            tf.summary.scalar("loss/d_loss", self.d_loss),
            tf.summary.scalar("misc/d_lr", self.d_lr),
            tf.summary.scalar("misc/g_lr", self.g_lr),
        ])

    def _define_loss_optim(self):
        """Add the weighted L1 term to the generator loss and create the optimizers."""
        # L1 reconstruction is weighted 20x relative to the adversarial term.
        self.g_loss += self.L1Loss * 20
        optimizers = self._getOptimizer(
            self.wgan_gp, self.g_loss, self.d_loss, self.G_var, self.D_var)
        self.g_optim, self.d_optim, self.clip_disc_weights = optimizers

    def train(self):
        """Alternate generator / critic updates, with periodic logging, sampling and checkpoints.

        One batch is drawn up front and reused as the fixed input for every
        preview produced by self.generate().
        """
        (x_fixed, x_target_fixed, pose_fixed, pose_target_fixed,
         pose_rcv_fixed, pose_rcv_target_fixed, mask_fixed, mask_target_fixed,
         part_bbox_fixed, part_bbox_target_fixed,
         part_vis_fixed, part_vis_target_fixed) = self.get_image_from_loader()

        # Reference images of the fixed batch, written once to the model directory.
        pose_img = (np.amax(pose_fixed, axis=-1, keepdims=True)+1.0)*127.5
        pose_target_img = (np.amax(pose_target_fixed, axis=-1, keepdims=True)+1.0)*127.5
        snapshots = [
            (x_fixed, '{}/x_fixed.png'),
            (x_target_fixed, '{}/x_target_fixed.png'),
            (pose_img, '{}/pose_fixed.png'),
            (pose_target_img, '{}/pose_target_fixed.png'),
            (mask_fixed, '{}/mask_fixed.png'),
            (mask_target_fixed, '{}/mask_target_fixed.png'),
        ]
        for image, template in snapshots:
            save_image(image, template.format(self.model_dir))

        for step in trange(self.start_step, self.max_step):
            # The generator is skipped on step 0 so the critic gets the first update.
            if step > 0:
                self.sess.run([self.g_optim])

            # Train critic
            gan_mode = self.wgan_gp.MODE
            critic_steps = 1 if gan_mode in ('dcgan', 'lsgan') else self.wgan_gp.CRITIC_ITERS
            for _ in xrange(critic_steps):
                self.sess.run(self.d_optim)
                if self.wgan_gp.MODE == 'wgan':
                    self.sess.run(self.clip_disc_weights)

            # Summaries are fetched only on the last step of each log_step window.
            should_log = step % self.log_step == self.log_step-1
            fetch_dict = {"summary": self.summary_op} if should_log else {}
            result = self.sess.run(fetch_dict)
            if should_log:
                self.summary_writer.add_summary(result['summary'], step)
                self.summary_writer.flush()

            sample_period = self.log_step * 3
            if step in (0, 10) or step % sample_period == sample_period-1:
                x = process_image(x_fixed, 127.5, 127.5)
                x_target = process_image(x_target_fixed, 127.5, 127.5)
                self.generate(x, x_target, pose_fixed, part_bbox_fixed, part_vis_fixed, self.model_dir, idx=step)

            if step % self.lr_update_step == self.lr_update_step - 1:
                self.sess.run([self.g_lr_update, self.d_lr_update])

            save_period = self.log_step * 30
            if step % save_period == save_period-1:
                self.saver.save(self.sess, os.path.join(self.model_dir, 'model.ckpt'), global_step=step)

    def test(self):
        """Generate images for 100 test batches and save inputs / outputs as PNGs.

        Under <model_dir>/<test_dir_name> one sub-directory is created per
        stream ('x', 'x_target', 'G', 'pose', 'pose_target', 'mask',
        'mask_target'); every sample is written there as '%05d.png'. For the
        first batch, grid images of the whole batch are saved as well.

        Fix over the previous version: generate() was called with
        test_result_dir and part_vis_fixed in swapped positions, so the
        directory string ended up being fed to self.part_vis.
        """
        test_result_dir = os.path.join(self.model_dir, self.test_dir_name)
        stream_names = ['x', 'x_target', 'G', 'pose', 'pose_target', 'mask', 'mask_target']
        out_dirs = dict((name, os.path.join(test_result_dir, name)) for name in stream_names)
        for directory in [test_result_dir] + [out_dirs[name] for name in stream_names]:
            if not os.path.exists(directory):
                os.makedirs(directory)

        def _save_png(array, directory, index):
            # One image per sample, named by its global index.
            Image.fromarray(array.astype(np.uint8)).save('%s/%05d.png'%(directory, index))

        for i in xrange(100): ## for test Samples
        # for i in xrange(800): ## for IS score
            (x_fixed, x_target_fixed, pose_fixed, pose_target_fixed,
             pose_rcv_fixed, pose_rcv_target_fixed, mask_fixed, mask_target_fixed,
             part_bbox_fixed, part_bbox_target_fixed,
             part_vis_fixed, part_vis_target_fixed) = self.get_image_from_loader()
            x = process_image(x_fixed, 127.5, 127.5)
            x_target = process_image(x_target_fixed, 127.5, 127.5)
            # Only the first batch also writes the SSIM-tagged grid from generate().
            x_fake = self.generate(x, x_target, pose_fixed, part_bbox_fixed, part_vis_fixed,
                                   root_path=test_result_dir, idx=i, save=(0==i))
            # Collapse the keypoint channels into a single 2-D map per sample.
            p = (np.amax(pose_fixed, axis=-1, keepdims=False)+1.0)*127.5
            pt = (np.amax(pose_target_fixed, axis=-1, keepdims=False)+1.0)*127.5
            for j in xrange(self.batch_size):
                idx = i*self.batch_size+j
                _save_png(x_fixed[j,:], out_dirs['x'], idx)
                _save_png(x_target_fixed[j,:], out_dirs['x_target'], idx)
                _save_png(x_fake[j,:], out_dirs['G'], idx)
                _save_png(p[j,:], out_dirs['pose'], idx)
                _save_png(pt[j,:], out_dirs['pose_target'], idx)
                # Masks carry a trailing singleton channel; drop it for PIL.
                _save_png(mask_fixed[j,:].squeeze(), out_dirs['mask'], idx)
                _save_png(mask_target_fixed[j,:].squeeze(), out_dirs['mask_target'], idx)
            if 0==i:
                save_image(x_fixed, '{}/x_fixed.png'.format(test_result_dir))
                save_image(x_target_fixed, '{}/x_target_fixed.png'.format(test_result_dir))
                save_image(mask_fixed, '{}/mask_fixed.png'.format(test_result_dir))
                save_image(mask_target_fixed, '{}/mask_target_fixed.png'.format(test_result_dir))
                save_image((np.amax(pose_fixed, axis=-1, keepdims=True)+1.0)*127.5, '{}/pose_fixed.png'.format(test_result_dir))
                save_image((np.amax(pose_target_fixed, axis=-1, keepdims=True)+1.0)*127.5, '{}/pose_target_fixed.png'.format(test_result_dir))

    def generate(self, x_fixed, x_target_fixed, pose_fixed, part_bbox_fixed, part_vis_fixed, root_path=None, path=None, idx=None, save=True):
        """Run the generator on the given batch and optionally save the result.

        Feeds x_fixed, pose_fixed, part_bbox_fixed and part_vis_fixed into the
        graph and fetches self.G. x_target_fixed is accepted but not used.
        When `save` is true and `path` is None, the batch is written to
        `root_path` with the mean SSIM (generated vs. input) in the file name.
        Returns the fetched value of self.G.
        """
        feed = {
            self.x: x_fixed,
            self.pose: pose_fixed,
            self.part_bbox: part_bbox_fixed,
            self.part_vis: part_vis_fixed,
        }
        G = self.sess.run(self.G, feed)

        def _as_gray(image):
            # Clip to the displayable range before the uint8 cast.
            return rgb2gray(image.clip(min=0,max=255).astype(np.uint8))

        ssim_scores = []
        for k in xrange(G.shape[0]):
            fake_gray = _as_gray(G[k,:])
            real_gray = _as_gray((x_fixed[k,:]+1)*127.5)
            ssim_scores.append(ssim(fake_gray, real_gray, data_range=real_gray.max() - real_gray.min(), multichannel=False))
        ssim_mean = np.mean(ssim_scores)

        if save and path is None:
            path = os.path.join(root_path, '{}_G_ssim{}.png'.format(idx,ssim_mean))
            save_image(G, path)
            print("[*] Samples saved: {}".format(path))
        return G

    # def generate(self, x_fixed, x_target_fixed, pose_fixed, pose_target_fixed, part_bbox_fixed, root_path=None, path=None, idx=None, save=True):
    #     G = self.sess.run(self.G, {self.x: x_fixed, self.pose: pose_fixed, self.pose_target: pose_target_fixed, self.part_bbox: part_bbox_fixed})
    #     ssim_G_x_list = []
    #     for i in xrange(G.shape[0]):
    #         G_gray = rgb2gray((G[i,:]).clip(min=0,max=255).astype(np.uint8))
    #         x_gray = rgb2gray(((x_fixed[i,:]+1)*127.5).clip(min=0,max=255).astype(np.uint8))
    #         ssim_G_x_list.append(ssim(G_gray, x_gray, data_range=x_gray.max() - x_gray.min(), multichannel=False))
    #     ssim_G_x_mean = np.mean(ssim_G_x_list)
    #     if path is None and save:
    #         path = os.path.join(root_path, '{}_G_ssim{}.png'.format(idx,ssim_G_x_mean))
    #         save_image(G, path)
    #         print("[*] Samples saved: {}".format(path))
    #     return G

    def get_image_from_loader(self):
        """Fetch one batch from the input tensors and post-process it.

        Both images go through unprocess_image(..., 127.5, 127.5) and both
        masks are multiplied by 255; the remaining arrays are returned as
        fetched. Returns a 12-tuple in the same order as the fetch list.
        """
        fetches = [
            self.x, self.x_target,
            self.pose, self.pose_target,
            self.pose_rcv, self.pose_rcv_target,
            self.mask, self.mask_target,
            self.part_bbox, self.part_bbox_target,
            self.part_vis, self.part_vis_target,
        ]
        (x, x_target, pose, pose_target, pose_rcv, pose_rcv_target,
         mask, mask_target, part_bbox, part_bbox_target,
         part_vis, part_vis_target) = self.sess.run(fetches)

        x = unprocess_image(x, 127.5, 127.5)
        x_target = unprocess_image(x_target, 127.5, 127.5)
        mask = mask*255
        mask_target = mask_target*255
        return (x, x_target, pose, pose_target, pose_rcv, pose_rcv_target,
                mask, mask_target, part_bbox, part_bbox_target, part_vis, part_vis_target)

    def _load_batch_pair_pose(self, dataset):
        """Build the batched input pipeline for image pairs of `dataset`.

        Returns 12 batched tensors: the two processed images, the two inflated
        pose maps, the two raw pose_rcv tensors, the two masks, the two
        part-bbox tensors and the two part-visibility tensors.
        """
        provider = slim.dataset_data_provider.DatasetDataProvider(dataset, common_queue_capacity=32, common_queue_min=8)

        item_keys = [
            'image_raw_0', 'image_raw_1', 'label', 'pose_peaks_0_rcv', 'pose_peaks_1_rcv',
            'pose_mask_r4_0', 'pose_mask_r4_1',
            'part_bbox_0', 'part_bbox_1', 'part_vis_0', 'part_vis_1',
        ]
        # 'label' is read from the provider but is not part of the batch below.
        (image_raw_0, image_raw_1, label, pose_rcv_0, pose_rcv_1, mask_0, mask_1,
         part_bbox_0, part_bbox_1, part_vis_0, part_vis_1) = provider.get(item_keys)

        # Give the tensors static shapes (256x256 images / masks).
        image_shape = [256, 256, 3]
        mask_shape = [256, 256, 1]
        bbox_shape = [self.part_num, 4]
        image_raw_0 = tf.reshape(image_raw_0, image_shape)
        image_raw_1 = tf.reshape(image_raw_1, image_shape)
        mask_0 = tf.cast(tf.reshape(mask_0, mask_shape), tf.float32)
        mask_1 = tf.cast(tf.reshape(mask_1, mask_shape), tf.float32)
        part_bbox_0 = tf.reshape(part_bbox_0, bbox_shape)
        part_bbox_1 = tf.reshape(part_bbox_1, bbox_shape)

        single_example = [image_raw_0, image_raw_1, pose_rcv_0, pose_rcv_1, mask_0, mask_1,
                          part_bbox_0, part_bbox_1, part_vis_0, part_vis_1]
        (images_0, images_1, poses_rcv_0, poses_rcv_1, masks_0, masks_1,
         part_bboxs_0, part_bboxs_1, part_viss_0, part_viss_1) = tf.train.batch(
            single_example,
            batch_size=self.batch_size,
            num_threads=self.num_threads,
            capacity=self.capacityCoff * self.batch_size)

        images_0 = process_image(tf.to_float(images_0), 127.5, 127.5)
        images_1 = process_image(tf.to_float(images_1), 127.5, 127.5)
        # Keypoint coordinates -> per-keypoint channels, then inflate with radius 4.
        poses_0 = tf.cast(coord2channel_simple_rcv(poses_rcv_0, keypoint_num=18, is_normalized=False, img_H=256, img_W=256), tf.float32)
        poses_1 = tf.cast(coord2channel_simple_rcv(poses_rcv_1, keypoint_num=18, is_normalized=False, img_H=256, img_W=256), tf.float32)
        poses_0 = tf_poseInflate(poses_0, keypoint_num=18, radius=4, img_H=self.img_H, img_W=self.img_W)
        poses_1 = tf_poseInflate(poses_1, keypoint_num=18, radius=4, img_H=self.img_H, img_W=self.img_W)

        return (images_0, images_1, poses_0, poses_1, poses_rcv_0, poses_rcv_1,
                masks_0, masks_1, part_bboxs_0, part_bboxs_1, part_viss_0, part_viss_1)



class DPIG_Encoder_subSampleAppNet_GAN_BodyROI_256(DPIG_Encoder_GAN_BodyROI_256):
    def init_net(self):
        """Build the graph, create savers and a session, then restore weights.

        If self.pretrained_path is set, only the 'Encoder' and 'ID_AE'
        variables are restored from it; otherwise, if self.ckpt_path is set,
        the full checkpoint is restored.
        """
        self.build_model()

        use_pretrained = self.pretrained_path is not None
        if use_pretrained:
            encoder_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='Encoder')
            generator_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='ID_AE')
            self.saverPart = tf.train.Saver(encoder_vars+generator_vars, max_to_keep=20)

        self.saver = tf.train.Saver(max_to_keep=20)
        self.summary_writer = tf.summary.FileWriter(self.model_dir)

        # Saving and summaries are handled manually, so the Supervisor gets neither.
        supervisor = tf.train.Supervisor(
            logdir=self.model_dir,
            is_chief=True,
            saver=None,
            summary_op=None,
            summary_writer=self.summary_writer,
            global_step=self.step,
            save_model_secs=0,
            ready_for_local_init_op=None)

        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=False))
        self.sess = supervisor.prepare_or_wait_for_session(config=sess_config)

        if use_pretrained:
            self.saverPart.restore(self.sess, self.pretrained_path)
            print('restored from pretrained_path:', self.pretrained_path)
        elif self.ckpt_path is not None:
            self.saver.restore(self.sess, self.ckpt_path)
            print('restored from ckpt_path:', self.ckpt_path)

    def _define_input(self):
        """Create the phase variable and select the builders used by the embedding GAN."""
        self.is_train_tensor = tf.Variable(self.is_train, name='phase')
        self.Generator_encoder_fn = models.GeneratorCNN_ID_Encoder
        self.Generator_fn = models.GeneratorCNN_ID_UAEAfterResidual

        # GAN helper sized for the embedding vector (7 * 32 values).
        wgan_settings = dict(
            DATA_DIR='',
            MODE='wgan',
            DIM=64,
            BATCH_SIZE=self.batch_size,
            ITERS=200000,
            LAMBDA=10,
            G_OUTPUT_DIM=7*32,
        )
        self.wgan_gp_encoder = WGAN_GP(**wgan_settings)
        self.Discriminator_encoder_fn = self._getDiscriminator(self.wgan_gp_encoder, arch='FCDis')

    def build_model(self):
        """Build the graph that maps Gaussian samples to appearance embeddings.

        Steps, in graph-construction order:
          1. "Encoder" scope: embed self.x using the first 7 body-part boxes -> self.embs.
          2. "ID_AE" scope: build the generator on the tiled encoder embedding
             (creates the generator variables; its image output is not used further here).
          3. "Gaussian_FC" scope: models.GaussianFCRes produces self.app_embs
             with the same shape as self.embs.
          4. A fully-connected discriminator compares self.embs against
             self.app_embs and yields g_loss_embs / d_loss_embs.
          5. The generator is re-applied (reuse=True) on the tiled self.app_embs
             to produce self.G for visualisation.
        """
        self._define_input()
        with tf.variable_scope("Encoder") as vs:
            pb_list = tf.split(self.part_bbox, self.part_num, axis=1)
            # NOTE(review): pv_list is computed but never used in this method.
            pv_list = tf.split(self.part_vis, self.part_num, axis=1)
            ## Part 1-7 (totally 7)
            indices = range(7)
            select_part_bbox = tf.concat([pb_list[i] for i in indices], axis=1)
            self.embs, _, self.Encoder_var = models.GeneratorCNN_ID_Encoder_BodyROI(self.x, select_part_bbox, len(indices), 32, 
                                            self.repeat_num+1, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, keep_part_prob=1.0, reuse=False)

        # Replicate the embedding img_H*img_W times along a new last axis, fold that
        # axis into [img_H, img_W], then convert with nchw_to_nhwc.
        self.embs_rep = tf.tile(tf.expand_dims(self.embs,-1), [1, 1, self.img_H*self.img_W])
        self.embs_rep = tf.reshape(self.embs_rep, [self.batch_size, -1, self.img_H, self.img_W])
        self.embs_rep = nchw_to_nhwc(self.embs_rep)

        with tf.variable_scope("ID_AE") as vs:
            # First instantiation of the generator (reuse=False); G itself is unused below.
            G, _, self.G_var = self.Generator_fn(
                    self.embs_rep, self.pose, 
                    self.channel, self.z_num, self.repeat_num-1, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, reuse=False)

        embs_shape = self.embs.get_shape().as_list()
        with tf.variable_scope("Gaussian_FC") as vs:
            self.app_embs, self.G_var_app_embs = models.GaussianFCRes(embs_shape, embs_shape[-1], repeat_num=4, hidden_num=512, data_format=self.data_format, activation_fn=LeakyReLU)

        ## Adversarial for Gaussian
        # Encoder embeddings first, Gaussian-mapped embeddings second, stacked along axis 0.
        encode_pair = tf.concat([self.embs, self.app_embs], 0)
        self.D_z_embs = self.Discriminator_encoder_fn(encode_pair, input_dim=encode_pair.get_shape().as_list()[-1], FC_DIM=512, n_layers=3, reuse=False, name='FCDis_')
        self.D_var_embs = lib.params_with_name('FCDis_Discriminator.')
        D_z_pos_embs, D_z_neg_embs = tf.split(self.D_z_embs, 2)
        self.g_loss_embs, self.d_loss_embs = self._gan_loss(self.wgan_gp_encoder, self.Discriminator_encoder_fn, 
                                            D_z_pos_embs, D_z_neg_embs, self.embs, self.app_embs)

        # Same tiling as for self.embs above, applied to the generated embedding.
        self.embs_app_rep = tf.tile(tf.expand_dims(self.app_embs,-1), [1, 1, self.img_H*self.img_W])
        self.embs_app_rep = tf.reshape(self.embs_app_rep, [self.batch_size, -1, self.img_H, self.img_W])
        self.embs_app_rep = nchw_to_nhwc(self.embs_app_rep)
        with tf.variable_scope("ID_AE") as vs:
            # Second application of the generator, sharing weights via reuse=True.
            G_gaussian_app, _, _ = self.Generator_fn(
                    self.embs_app_rep, self.pose, 
                    self.channel, self.z_num, self.repeat_num-1, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, reuse=True)
        self.G = denorm_img(G_gaussian_app, self.data_format)

        self._define_loss_optim()
        self.summary_op = tf.summary.merge([
            tf.summary.image("G", self.G),
            tf.summary.scalar("loss/g_loss_embs", self.g_loss_embs),
            tf.summary.scalar("loss/d_loss_embs", self.d_loss_embs),
            tf.summary.scalar("misc/d_lr", self.d_lr),
            tf.summary.scalar("misc/g_lr", self.g_lr),
        ])

    def _define_loss_optim(self):
        """Create the optimizers for the embedding GAN (Gaussian_FC variables vs. FC discriminator)."""
        optimizers = self._getOptimizer(
            self.wgan_gp_encoder,
            self.g_loss_embs, self.d_loss_embs,
            self.G_var_app_embs, self.D_var_embs)
        self.g_optim_embs, self.d_optim_embs, self.clip_disc_weights_embs = optimizers

    def train(self):
        """Train the embedding GAN, with periodic logging, sampling and checkpoints.

        One batch is drawn up front and reused as the fixed input for every
        preview produced by self.generate().

        Fixes over the previous version:
          * get_image_from_loader() (inherited) returns 12 arrays; the old
            8-name unpacking raised ValueError before training started.
          * generate() (inherited) expects
            (x, x_target, pose, part_bbox, part_vis, root_path, ...); the old
            call passed pose_target / part_bbox in the part_bbox / part_vis slots.
        """
        (x_fixed, x_target_fixed, pose_fixed, pose_target_fixed,
         pose_rcv_fixed, pose_rcv_target_fixed, mask_fixed, mask_target_fixed,
         part_bbox_fixed, part_bbox_target_fixed,
         part_vis_fixed, part_vis_target_fixed) = self.get_image_from_loader()
        save_image(x_fixed, '{}/x_fixed.png'.format(self.model_dir))
        save_image(x_target_fixed, '{}/x_target_fixed.png'.format(self.model_dir))
        save_image((np.amax(pose_fixed, axis=-1, keepdims=True)+1.0)*127.5, '{}/pose_fixed.png'.format(self.model_dir))
        save_image((np.amax(pose_target_fixed, axis=-1, keepdims=True)+1.0)*127.5, '{}/pose_target_fixed.png'.format(self.model_dir))
        save_image(mask_fixed, '{}/mask_fixed.png'.format(self.model_dir))
        save_image(mask_target_fixed, '{}/mask_target_fixed.png'.format(self.model_dir))

        for step in trange(self.start_step, self.max_step):
            # The generator is skipped on step 0 so the critic gets the first update.
            if step>0:
                self.sess.run([self.g_optim_embs])
            # Train critic
            if (self.wgan_gp_encoder.MODE == 'dcgan') or (self.wgan_gp_encoder.MODE == 'lsgan'):
                disc_ITERS = 1
            else:
                disc_ITERS = self.wgan_gp_encoder.CRITIC_ITERS
            for _ in xrange(disc_ITERS):
                self.sess.run(self.d_optim_embs)
                if self.wgan_gp_encoder.MODE == 'wgan':
                    self.sess.run(self.clip_disc_weights_embs)

            # Summaries are fetched only on the last step of each log_step window.
            fetch_dict = {}
            if step % self.log_step == self.log_step-1:
                fetch_dict.update({
                    "summary": self.summary_op
                })
            result = self.sess.run(fetch_dict)

            if step % self.log_step == self.log_step-1:
                self.summary_writer.add_summary(result['summary'], step)
                self.summary_writer.flush()

            if 0==step or step % (self.log_step * 3) == (self.log_step * 3)-1:
                x = process_image(x_fixed, 127.5, 127.5)
                x_target = process_image(x_target_fixed, 127.5, 127.5)
                self.generate(x, x_target, pose_fixed, part_bbox_fixed, part_vis_fixed, self.model_dir, idx=step)

            if step % self.lr_update_step == self.lr_update_step - 1:
                self.sess.run([self.g_lr_update, self.d_lr_update])

            if step % (self.log_step * 30) == (self.log_step * 30)-1:
                self.saver.save(self.sess, os.path.join(self.model_dir, 'model.ckpt'), global_step=step)

class DPIG_PoseRCV_AE_BodyROI_256(DPIG_Encoder_GAN_BodyROI_256):
    def init_net(self):
        """Build the graph, create savers and a session, then restore weights.

        If self.pretrained_path is set, only the 'Encoder' and 'ID_AE'
        variables are restored from it; otherwise, if self.ckpt_path is set,
        the full checkpoint is restored. GPU memory is allocated on demand
        (allow_growth=True).
        """
        self.build_model()

        use_pretrained = self.pretrained_path is not None
        if use_pretrained:
            encoder_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='Encoder')
            generator_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='ID_AE')
            self.saverPart = tf.train.Saver(encoder_vars+generator_vars, max_to_keep=20)

        self.saver = tf.train.Saver(max_to_keep=20)
        self.summary_writer = tf.summary.FileWriter(self.model_dir)

        # Saving and summaries are handled manually, so the Supervisor gets neither.
        supervisor = tf.train.Supervisor(
            logdir=self.model_dir,
            is_chief=True,
            saver=None,
            summary_op=None,
            summary_writer=self.summary_writer,
            global_step=self.step,
            save_model_secs=0,
            ready_for_local_init_op=None)

        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True))
        self.sess = supervisor.prepare_or_wait_for_session(config=sess_config)

        if use_pretrained:
            self.saverPart.restore(self.sess, self.pretrained_path)
            print('restored from pretrained_path:', self.pretrained_path)
        elif self.ckpt_path is not None:
            self.saver.restore(self.sess, self.ckpt_path)
            print('restored from ckpt_path:', self.ckpt_path)

    def _define_input(self):
        """Create the phase variable and select the builders used by the pose model."""
        self.is_train_tensor = tf.Variable(self.is_train, name='phase')
        self.Generator_encoder_fn = models.GeneratorCNN_ID_Encoder
        self.Generator_fn = models.GeneratorCNN_ID_UAEAfterResidual

        # GAN helper sized for a pose vector (3 values per keypoint).
        wgan_settings = dict(
            DATA_DIR='',
            MODE='dcgan',
            DIM=64,
            BATCH_SIZE=self.batch_size,
            ITERS=200000,
            LAMBDA=10,
            G_OUTPUT_DIM=self.keypoint_num*3,
        )
        self.wgan_gp = WGAN_GP(**wgan_settings)
        self.Discriminator_fn = self._getDiscriminator(self.wgan_gp, arch='FCDis')

    def build_model(self):
        """Build the pose auto-encoder graph and its reconstruction loss.

        self.pose_rcv is reshaped to [batch_size, keypoint_num, 3]; the first
        two components are rescaled by img_H / img_W to [-1, 1] and the third
        is kept as a float. The result is encoded by models.PoseEncoderFCRes,
        decoded by models.PoseDecoderFCRes into coordinates and visibility,
        and compared with the normalised input via a mean squared error.
        """
        self._define_input()
        with tf.variable_scope("PoseAE") as vs:
            ## Norm to [-1, 1]
            pose_rcv_norm = tf.reshape(self.pose_rcv, [self.batch_size, self.keypoint_num, 3])
            # Component 0 is scaled by img_H, component 1 by img_W, component 2 is left unscaled.
            R = tf.cast(tf.slice(pose_rcv_norm, [0,0,0], [-1,-1,1]), tf.float32)/float(self.img_H)*2.0 - 1
            C = tf.cast(tf.slice(pose_rcv_norm, [0,0,1], [-1,-1,1]), tf.float32)/float(self.img_W)*2.0 - 1
            V = tf.cast(tf.slice(pose_rcv_norm, [0,0,2], [-1,-1,1]), tf.float32)
            pose_rcv_norm = tf.concat([R,C,V], axis=-1)
            # The encoder consumes the pose flattened to one vector per sample.
            self.pose_embs, self.G_var_encoder = models.PoseEncoderFCRes(tf.reshape(pose_rcv_norm, [self.batch_size,-1]), 
                        z_num=32, repeat_num=4, hidden_num=512, data_format=self.data_format, activation_fn=LeakyReLU, reuse=False)

            if self.sample_pose: ## Sampling new poses during testing
                # Replace the encoded pose by Gaussian noise of the same shape (stddev 0.2).
                self.pose_embs = tf.random_normal(tf.shape(self.pose_embs), mean=0.0, stddev=0.2, dtype=tf.float32)

            G_pose_coord, G_pose_visible, self.G_var_decoder = models.PoseDecoderFCRes(self.pose_embs, self.keypoint_num, 
                        repeat_num=4, hidden_num=512, data_format=self.data_format, activation_fn=LeakyReLU, reuse=False)

            self.G_var_pose = self.G_var_encoder + self.G_var_decoder

        # Re-assemble the decoder outputs into the same [batch, keypoint, 3] layout as the input.
        G_pose_rcv = tf.concat([tf.reshape(G_pose_coord, [self.batch_size,self.keypoint_num,2]), tf.expand_dims(G_pose_visible,-1)], axis=-1)
        self.G_pose_rcv = G_pose_rcv

        self.reconstruct_loss = tf.reduce_mean(tf.square(pose_rcv_norm - G_pose_rcv))


        self._define_loss_optim()
        self.summary_op = tf.summary.merge([
            tf.summary.scalar("loss/reconstruct_loss", self.reconstruct_loss),
        ])

    def _define_loss_optim(self):
        """Create the Adam step that minimises the (20x weighted) pose reconstruction loss."""
        optimizer = tf.train.AdamOptimizer(learning_rate=self.g_lr, beta1=0.5)
        weighted_loss = self.reconstruct_loss * 20
        self.g_optim = optimizer.minimize(
            weighted_loss,
            var_list=self.G_var_pose,
            colocate_gradients_with_ops=True)

    def train(self):
        """Train the pose auto-encoder, with periodic logging, LR updates and checkpoints."""
        (x_fixed, x_target_fixed, pose_fixed, pose_target_fixed,
         pose_rcv_fixed, pose_rcv_target_fixed, mask_fixed, mask_target_fixed,
         part_bbox_fixed, part_bbox_target_fixed,
         part_vis_fixed, part_vis_target_fixed) = self.get_image_from_loader()

        # Reference images of the first batch, written once to the model directory.
        save_image(x_fixed, '{}/x_fixed.png'.format(self.model_dir))
        save_image(x_target_fixed, '{}/x_target_fixed.png'.format(self.model_dir))
        # save_image((np.amax(pose_fixed, axis=-1, keepdims=True)+1.0)*127.5, '{}/pose_fixed.png'.format(self.model_dir))
        # save_image((np.amax(pose_target_fixed, axis=-1, keepdims=True)+1.0)*127.5, '{}/pose_target_fixed.png'.format(self.model_dir))
        save_image(mask_fixed, '{}/mask_fixed.png'.format(self.model_dir))
        save_image(mask_target_fixed, '{}/mask_target_fixed.png'.format(self.model_dir))

        checkpoint_path = os.path.join(self.model_dir, 'model.ckpt')
        for step in trange(self.start_step, self.max_step):
            # No optimizer step on step 0.
            if step > 0:
                self.sess.run([self.g_optim])

            # Log on the last step of each log_step window.
            if step % self.log_step == self.log_step-1:
                result = self.sess.run({
                    "summary": self.summary_op,
                    "reconstruct_loss": self.reconstruct_loss,
                })
                self.summary_writer.add_summary(result['summary'], step)
                self.summary_writer.flush()
                print('reconstruct_loss:%f'%result['reconstruct_loss'])

            if step % self.lr_update_step == self.lr_update_step - 1:
                self.sess.run([self.g_lr_update, self.d_lr_update])

            save_period = self.log_step * 30
            if step % save_period == save_period-1:
                self.saver.save(self.sess, checkpoint_path, global_step=step)



class DPIG_subnetSamplePoseRCV_GAN_BodyROI_256(DPIG_PoseRCV_AE_BodyROI_256):
    def init_net(self):
        """Build the graph, create savers and a session, then restore checkpoints.

        Savers created:

        * ``self.saverPart`` (only if ``self.pretrained_path`` is set): variables
          collected under the 'Encoder' and 'ID_AE' scopes;
        * ``self.saverPoseAEPart`` (only if ``self.pretrained_poseAE_path`` is
          set): variables collected under the 'PoseAE' scope;
        * ``self.saver``: default saver, also used for training checkpoints.

        The session is obtained from a ``tf.train.Supervisor`` configured with
        no saver, no summary op and ``save_model_secs=0``. Restores then happen
        in the order pretrained_path, pretrained_poseAE_path, ckpt_path, so a
        later restore overrides any variable an earlier one also loaded.
        """
        self.build_model()

        if self.pretrained_path is not None:
            encoder_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='Encoder')
            id_ae_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='ID_AE')
            self.saverPart = tf.train.Saver(encoder_vars + id_ae_vars, max_to_keep=20)

        if self.pretrained_poseAE_path is not None:
            pose_ae_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='PoseAE')
            self.saverPoseAEPart = tf.train.Saver(pose_ae_vars, max_to_keep=20)

        self.saver = tf.train.Saver(max_to_keep=20)
        self.summary_writer = tf.summary.FileWriter(self.model_dir)

        supervisor = tf.train.Supervisor(
            logdir=self.model_dir,
            is_chief=True,
            saver=None,
            summary_op=None,
            summary_writer=self.summary_writer,
            global_step=self.step,
            save_model_secs=0,
            ready_for_local_init_op=None)

        growth_options = tf.GPUOptions(allow_growth=True)
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        gpu_options=growth_options)
        self.sess = supervisor.prepare_or_wait_for_session(config=session_config)

        # (saver attribute, checkpoint-path attribute, log prefix), in restore order.
        # Attributes are looked up lazily because the partial savers only exist
        # when their path is configured.
        restore_plan = (
            ('saverPart', 'pretrained_path', 'restored from pretrained_path:'),
            ('saverPoseAEPart', 'pretrained_poseAE_path', 'restored from pretrained_poseAE_path:'),
            ('saver', 'ckpt_path', 'restored from ckpt_path:'),
        )
        for saver_attr, path_attr, log_prefix in restore_plan:
            checkpoint = getattr(self, path_attr)
            if checkpoint is None:
                continue
            getattr(self, saver_attr).restore(self.sess, checkpoint)
            print(log_prefix, getattr(self, path_attr))

    def _define_input(self):
        """Set up the phase variable, generator function handles and the embedding critic.

        Creates ``self.is_train_tensor`` (a variable named 'phase' initialised
        from ``self.is_train``), stores the two generator constructors from
        ``models``, builds a ``WGAN_GP`` helper in 'wgan' mode sized for
        ``keypoint_num * 3`` outputs, and asks ``_getDiscriminator`` for the
        'FCDis' discriminator bound to that helper.
        """
        self.is_train_tensor = tf.Variable(self.is_train, name='phase')

        self.Generator_encoder_fn = models.GeneratorCNN_ID_Encoder
        self.Generator_fn = models.GeneratorCNN_ID_UAEAfterResidual

        wgan_settings = dict(
            DATA_DIR='',
            MODE='wgan',
            DIM=64,
            BATCH_SIZE=self.batch_size,
            ITERS=200000,
            LAMBDA=10,
            G_OUTPUT_DIM=self.keypoint_num * 3,
        )
        encoder_wgan = WGAN_GP(**wgan_settings)
        self.wgan_gp_encoder = encoder_wgan
        self.Discriminator_encoder_fn = self._getDiscriminator(encoder_wgan, arch='FCDis')

    def build_model(self):
        """Assemble the pose-embedding graph, its adversarial losses, optimizers and summaries.

        Construction order (op/variable creation order matters for scoping):

        1. ``_define_input()`` prepares the function handles and WGAN helper.
        2. Scope "PoseAE": ``self.pose_rcv`` is reshaped to
           ``[batch_size, keypoint_num, 3]``, its first two channels are rescaled
           with ``x / size * 2 - 1`` (size = ``img_H`` / ``img_W``), and the result
           is encoded by ``models.PoseEncoderFCRes`` into ``self.pose_embs``.
        3. Scope "PoseGaussian": ``models.GaussianFCRes`` produces
           ``self.G_pose_embs`` with the same last dimension as ``self.pose_embs``;
           its variables (``self.G_var_embs``) are the ones passed to the
           generator optimizer in ``_define_loss_optim``.
        4. Scope "PoseAE" (again): ``models.PoseDecoderFCRes`` decodes
           ``self.G_pose_embs`` into coordinates and a visibility output, which
           are combined into ``self.G_pose_rcv`` and rendered to ``self.G_pose``.
        5. The discriminator is applied to the encoder embeddings and the
           generated embeddings concatenated along axis 0, and ``_gan_loss``
           yields ``self.g_loss_embs`` / ``self.d_loss_embs``.
        6. ``_define_loss_optim()`` creates the training ops and
           ``self.summary_op`` merges the image/scalar summaries.

        The image-generation path (Encoder / ID_AE) is present only as
        commented-out code below.
        """
        self._define_input()
        with tf.variable_scope("PoseAE") as vs:
            ## Norm to [-1, 1]
            pose_rcv_norm = tf.reshape(self.pose_rcv, [self.batch_size, self.keypoint_num, 3])
            # Channel 0 is scaled by img_H, channel 1 by img_W, channel 2 is only cast
            # (presumably row, column and visibility, going by the R/C/V names).
            R = tf.cast(tf.slice(pose_rcv_norm, [0,0,0], [-1,-1,1]), tf.float32)/float(self.img_H)*2.0 - 1
            C = tf.cast(tf.slice(pose_rcv_norm, [0,0,1], [-1,-1,1]), tf.float32)/float(self.img_W)*2.0 - 1
            V = tf.cast(tf.slice(pose_rcv_norm, [0,0,2], [-1,-1,1]), tf.float32)
            pose_rcv_norm = tf.concat([R,C,V], axis=-1)

            # The encoder takes the pose flattened to [batch_size, keypoint_num*3].
            self.pose_embs, self.G_var_encoder = models.PoseEncoderFCRes(tf.reshape(pose_rcv_norm, [self.batch_size,-1]), 
                        z_num=32, repeat_num=4, hidden_num=512, data_format=self.data_format, activation_fn=LeakyReLU, reuse=False)
            
        embs_shape = self.pose_embs.get_shape().as_list()
        # with tf.variable_scope("Gaussian_FC") as vs:
        with tf.variable_scope("PoseGaussian") as vs:
            # NOTE(review): GaussianFCRes is defined in models; presumably it maps sampled
            # noise of shape embs_shape to an embedding of size embs_shape[-1] - confirm there.
            self.G_pose_embs, self.G_var_embs = models.GaussianFCRes(embs_shape, embs_shape[-1], repeat_num=4, hidden_num=512, data_format=self.data_format, activation_fn=LeakyReLU)
            # self.G_pose_embs, self.G_var_embs = models.GaussianFCRes(embs_shape, embs_shape[-1], repeat_num=6, hidden_num=1024, data_format=self.data_format, activation_fn=LeakyReLU)

            
        # Re-enter the "PoseAE" scope so the decoder variables live next to the encoder's
        # (init_net restores everything under 'PoseAE' from pretrained_poseAE_path).
        with tf.variable_scope("PoseAE") as vs:
            G_pose_coord, G_pose_visible, self.G_var_decoder = models.PoseDecoderFCRes(self.G_pose_embs, self.keypoint_num, 
                        repeat_num=4, hidden_num=512, data_format=self.data_format, activation_fn=LeakyReLU, reuse=False)
            
        # Pack decoded coordinates and visibility back into [batch_size, keypoint_num, 3].
        G_pose_rcv = tf.concat([tf.reshape(G_pose_coord, [self.batch_size,self.keypoint_num,2]), tf.expand_dims(G_pose_visible,-1)], axis=-1)
        G_pose = coord2channel_simple_rcv(G_pose_rcv, self.keypoint_num, is_normalized=True, img_H=self.img_H, img_W=self.img_W)
        # Max over the last axis, tiled to 3 channels and de-normalised, for visualisation.
        self.G_pose = denorm_img(tf.tile(tf.reduce_max(G_pose, axis=-1, keep_dims=True), [1,1,1,3]), self.data_format)
        self.G_pose_rcv = G_pose_rcv

        ## Adversarial for pose_embs
        self.pose_embs = tf.reshape(self.pose_embs, [self.batch_size,-1])
        self.G_pose_embs = tf.reshape(self.G_pose_embs, [self.batch_size,-1])
        # Encoder embeddings first, generated embeddings second, so one discriminator
        # pass scores both; tf.split below separates them again in the same order.
        encode_pair = tf.concat([self.pose_embs, self.G_pose_embs], 0)
        self.D_z_embs = self.Discriminator_encoder_fn(encode_pair, input_dim=encode_pair.get_shape().as_list()[-1], FC_DIM=512, n_layers=3, reuse=False, name='Pose_emb_')
        self.D_var_embs = lib.params_with_name('Pose_emb_Discriminator.')
        D_z_pos, D_z_neg = tf.split(self.D_z_embs, 2)
        self.g_loss_embs, self.d_loss_embs = self._gan_loss(self.wgan_gp_encoder, self.Discriminator_encoder_fn, 
                                            D_z_pos, D_z_neg, self.pose_embs, self.G_pose_embs)

        # ## Use the pose to generate person with pretrained generator
        # with tf.variable_scope("Encoder") as vs:
        #     pb_list = tf.split(self.part_bbox, self.part_num, axis=1)
        #     pv_list = tf.split(self.part_vis, self.part_num, axis=1)
        #     ## Part 1,1-7 (totally 7)
        #     indices = range(7)
        #     ## Part 1,8-16 (totally 10)
        #     # indices = [1] + range(8,17)
        #     # indices = [0] + range(7,16)
        #     select_part_bbox = tf.concat([pb_list[i] for i in indices], axis=1)
        #     select_part_vis = tf.cast(tf.concat([pv_list[i] for i in indices], axis=1), tf.float32)
        #     # pdb.set_trace()
        #     self.embs, _, _, _ = models.GeneratorCNN_ID_Encoder_BodyROIVis_FgBgFeaTwoBranch(self.x, self.mask_r6, select_part_bbox, select_part_vis, len(indices), 32, 
        #                                     self.repeat_num-1, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, keep_part_prob=1.0, reuse=False)
        #     # self.embs, _, _ = models.GeneratorCNN_ID_Encoder_BodyROIVis(self.x, select_part_bbox, select_part_vis, len(indices), 32, 
        #     #                                 self.repeat_num-1, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, keep_part_prob=1.0, reuse=False)
        #     # self.embs, _, self.Encoder_var = models.GeneratorCNN_ID_Encoder_BodyROI(self.x, self.part_bbox, 7, 32, 
        #     #                                 self.repeat_num-1, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, reuse=False)

        # self.embs_rep = tf.tile(tf.expand_dims(self.embs,-1), [1, 1, self.img_H*self.img_W])
        # self.embs_rep = tf.reshape(self.embs_rep, [self.batch_size, -1, self.img_H, self.img_W])
        # self.embs_rep = nchw_to_nhwc(self.embs_rep)

        # ## Use py code to get G_pose_inflated, so the op is out of the graph
        # self.G_pose_inflated = tf.placeholder(tf.float32, shape=G_pose.get_shape())
        # with tf.variable_scope("ID_AE") as vs:
        #     G, _, _ = self.Generator_fn(
        #             self.embs_rep, self.G_pose_inflated, 
        #             self.channel, self.z_num, self.repeat_num-1, self.conv_hidden_num, self.data_format, activation_fn=tf.nn.relu, reuse=False)
        # self.G = denorm_img(G, self.data_format)


        # Optimizers must be created after the losses above exist.
        self._define_loss_optim()
        self.summary_op = tf.summary.merge([
            tf.summary.image("G_pose", self.G_pose),
            tf.summary.scalar("loss/g_loss_embs", self.g_loss_embs),
            tf.summary.scalar("loss/d_loss_embs", self.d_loss_embs),
            tf.summary.scalar("misc/d_lr", self.d_lr),
            tf.summary.scalar("misc/g_lr", self.g_lr),
            # tf.summary.histogram("distribution/pose_emb", self.pose_embs),
            # tf.summary.histogram("distribution/G_pose_embs", self.G_pose_embs),
            # tf.summary.histogram("distribution/app_emb", self.embs),
        ])

    def _define_loss_optim(self):
        """Create the generator/critic training ops for the pose-embedding GAN.

        Delegates to ``self._getOptimizer`` with the embedding WGAN helper, the
        two embedding losses and their variable lists, and stores the three
        returned ops as ``self.g_optim_embs``, ``self.d_optim_embs`` and
        ``self.clip_disc_weights_embs``.
        """
        embedding_ops = self._getOptimizer(
            self.wgan_gp_encoder,
            self.g_loss_embs,
            self.d_loss_embs,
            self.G_var_embs,
            self.D_var_embs)
        self.g_optim_embs, self.d_optim_embs, self.clip_disc_weights_embs = embedding_ops

    def train(self):
        """Adversarially train the pose-embedding generator against its critic.

        One batch is pulled from the loader up front and written to
        ``self.model_dir`` as reference PNGs (images, pose maps, masks).
        Each step then:

        * runs the generator op (skipped on step 0);
        * runs the critic op once for 'dcgan'/'lsgan' modes, otherwise
          ``CRITIC_ITERS`` times, applying the weight-clipping op after each
          critic update when the mode is 'wgan';
        * writes summaries on the last step of each ``log_step`` window;
        * saves sample poses at steps 0, 10, 200 and on the last step of each
          ``log_step * 3`` window;
        * runs the learning-rate update ops on the last step of each
          ``lr_update_step`` window;
        * saves a checkpoint on the last step of each ``log_step * 30`` window.
        """
        (x_fixed, x_target_fixed, pose_fixed, pose_target_fixed,
         pose_rcv_fixed, pose_rcv_target_fixed, mask_fixed, mask_target_fixed,
         part_bbox_fixed, part_bbox_target_fixed,
         part_vis_fixed, part_vis_target_fixed) = self.get_image_from_loader()

        def _dump(batch, path_template):
            # Write one reference batch below the model directory.
            save_image(batch, path_template.format(self.model_dir))

        def _pose_preview(pose_maps):
            # Collapse the last axis with a max, then apply (v + 1) * 127.5.
            return (np.amax(pose_maps, axis=-1, keepdims=True) + 1.0) * 127.5

        def _ends_period(current_step, period):
            # True on the final step of every `period`-sized window.
            return current_step % period == period - 1

        _dump(x_fixed, '{}/x_fixed.png')
        _dump(x_target_fixed, '{}/x_target_fixed.png')
        _dump(_pose_preview(pose_fixed), '{}/pose_fixed.png')
        _dump(_pose_preview(pose_target_fixed), '{}/pose_target_fixed.png')
        _dump(mask_fixed, '{}/mask_fixed.png')
        _dump(mask_target_fixed, '{}/mask_target_fixed.png')

        for step in trange(self.start_step, self.max_step):
            # Use GAN for Pose Embedding: generator update first (never on step 0).
            if step > 0:
                self.sess.run([self.g_optim_embs])

            # Train critic
            wgan = self.wgan_gp_encoder
            critic_steps = 1 if wgan.MODE in ('dcgan', 'lsgan') else wgan.CRITIC_ITERS
            for _ in xrange(critic_steps):
                self.sess.run([self.d_optim_embs])
                if wgan.MODE == 'wgan':
                    self.sess.run(self.clip_disc_weights_embs)

            # The fetch dict is empty on non-logging steps; the run call is still made.
            log_now = _ends_period(step, self.log_step)
            fetches = {"summary": self.summary_op} if log_now else {}
            result = self.sess.run(fetches)
            if log_now:
                self.summary_writer.add_summary(result['summary'], step)
                self.summary_writer.flush()

            if step in (0, 10, 200) or _ends_period(step, self.log_step * 3):
                x = process_image(x_fixed, 127.5, 127.5)
                x_target = process_image(x_target_fixed, 127.5, 127.5)
                self.generate(x, x_target, pose_fixed, part_bbox_fixed, self.model_dir, idx=step)

            if _ends_period(step, self.lr_update_step):
                self.sess.run([self.g_lr_update, self.d_lr_update])

            if _ends_period(step, self.log_step * 30):
                checkpoint_prefix = os.path.join(self.model_dir, 'model.ckpt')
                self.saver.save(self.sess, checkpoint_prefix, global_step=step)

    def generate(self, x_fixed, x_target_fixed, pose_fixed, part_bbox_fixed, root_path=None, path=None, idx=None, save=True):
        """Sample poses from the current generator and optionally save preview PNGs.

        Args:
            x_fixed, x_target_fixed, pose_fixed, part_bbox_fixed: accepted for
                interface compatibility; unused while the image-generation
                path (see NOTE below) is disabled.
            root_path: directory the preview files are written into.
            path: previews are only written when this is ``None``.
            idx: value formatted into the file names (``train`` passes the step).
            save: previews are only written when this is true.

        Returns:
            The evaluated ``self.G_pose`` batch.
        """
        G_pose_rcv, G_pose = self.sess.run([self.G_pose_rcv, self.G_pose])

        # NOTE: the person-image generation step (feeding the inflated pose to
        # self.G) and its SSIM score against x_fixed are disabled in this class,
        # matching the commented-out generator in build_model.
        if path is None and save:
            # Inflate at the model's own resolution: build_model normalises and
            # rasterises these coordinates with self.img_H / self.img_W, so a
            # hard-coded 256x256 is only right when the model is 256x256.
            # The inflated map is used for the preview only, so it is computed
            # only when something is actually saved.
            G_pose_inflated = py_poseInflate(G_pose_rcv, is_normalized=True, radius=4,
                                             img_H=self.img_H, img_W=self.img_W)
            G_pose_inflated_img = np.tile(np.amax((G_pose_inflated+1)*127.5, axis=-1, keepdims=True), [1,1,1,3])

            path = os.path.join(root_path, '{}_G_pose.png'.format(idx))
            save_image(G_pose, path)
            print("[*] Samples saved: {}".format(path))

            path = os.path.join(root_path, '{}_G_pose_inflated.png'.format(idx))
            save_image(G_pose_inflated_img, path)
            print("[*] Samples saved: {}".format(path))
        return G_pose