# -*- coding:utf-8 -*-
# Created Time: Thu 13 Apr 2017 04:07:50 PM CST
# Author: Taihong Xiao <xiaotaihong@126.com>

import tensorflow as tf
import os
from model import Model
from dataset import config, Dataset
import numpy as np
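# NOTE: scipy.misc.imsave is deprecated and was removed in SciPy 1.2; this script
# assumes an older SciPy (imageio.imwrite is a drop-in replacement on newer versions)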
from scipy import misc
import argparse


def run(config, dataset, model, gpu):
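    """Run the training loop: write TensorBoard summaries, save checkpoints,
    and dump sample images.

    Args:
        config: training configuration (directories, learning rate schedules, max_iter).
        dataset: Dataset instance providing the input pipeline.
        model: Model built with is_train=True.
        gpu: comma-separated GPU id string, e.g. '0,1'.
    """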
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    # queue-based input pipeline: `batches` and `labels` yield one minibatch per sess.run
    batches, labels = dataset.input()

    saver = tf.train.Saver()

    # image summary
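    # (naming assumed from model.py: Axs/Bes are the two input batches, Aes/Bxs the
    # attribute-swapped outputs, Axs2/Bes2 the reconstructions)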
    image_summary_op = []
    image_summary_op += [tf.summary.image('Ax_{}'.format(i), model.Axs[i], max_outputs=30) for i in range(model.n_feat)]
    image_summary_op += [tf.summary.image('Be_{}'.format(i), model.Bes[i], max_outputs=30) for i in range(model.n_feat)]
    image_summary_op += [tf.summary.image('Ax2_{}'.format(i), model.Axs2[i], max_outputs=30) for i in range(model.n_feat)]
    image_summary_op += [tf.summary.image('Be2_{}'.format(i), model.Bes2[i], max_outputs=30) for i in range(model.n_feat)]
    image_summary_op += [tf.summary.image('Ae_{}'.format(i), model.Aes[i], max_outputs=30) for i in range(model.n_feat)]
    image_summary_op += [tf.summary.image('Bx_{}'.format(i), model.Bxs[i], max_outputs=30) for i in range(model.n_feat)]

    # G loss summary
    for key, value in model.G_loss.items():
        tf.summary.scalar(key, value)

    loss_G_nodecay_op = tf.summary.scalar('loss_G_nodecay', model.loss_G_nodecay)
    loss_G_decay_op = tf.summary.scalar('loss_G_decay', model.loss_G_decay)
    loss_G_op = tf.summary.scalar('loss_G', model.loss_G)

    # D loss summary
    for key, value in model.D_loss.items():
        tf.summary.scalar(key, value)

    loss_D_op = tf.summary.scalar('loss_D', model.loss_D)

    # learning rate summary
    g_lr_op = tf.summary.scalar('g_learning_rate', model.g_lr)
    d_lr_op = tf.summary.scalar('d_learning_rate', model.d_lr)

    merged_op = tf.summary.merge_all()

    # start training
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # resume from the latest checkpoint in model_dir, if any
    ckpt = tf.train.get_checkpoint_state(config.model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)

    # passing the graph here already records it; a separate add_graph call would duplicate it
    writer = tf.summary.FileWriter(config.log_dir, sess.graph)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
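    # launch background threads that fill the input queues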

    for i in range(config.max_iter):
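        # run extra discriminator steps every 500 iterations (100 instead of 1),
        # a WGAN-style critic schedule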
        d_num = 100 if i % 500 == 0 else 1

        batch_images, batch_labels = sess.run([batches, labels])
        feed_dict = {model.g_lr: config.g_lr(epoch=i),
                     model.d_lr: config.d_lr(epoch=i),
                    }
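        # batch_images interleaves the two domains per feature: even indices feed
        # Axs, odd indices feed Bes (pairing assumed from dataset.py)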
        for j in range(model.n_feat):
            feed_dict[model.Axs[j]] = batch_images[2*j]
            feed_dict[model.Bes[j]] = batch_images[2*j+1]
            feed_dict[model.label_Axs[j]] = batch_labels[2*j]
            feed_dict[model.label_Bes[j]] = batch_labels[2*j+1]

        # update D with clipping
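        # model.clip_d clamps the discriminator weights after each update (WGAN weight clipping)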
        for j in range(d_num):
            _, loss_D_sum, _ = sess.run([model.d_opt, model.loss_D, model.clip_d], feed_dict=feed_dict)

        # update G
        _, loss_G_sum = sess.run([model.g_opt, model.loss_G], feed_dict=feed_dict)

        print('iter: {:06d},   g_loss: {},   d_loss: {}'.format(i, loss_G_sum, loss_D_sum))

        if i % 20 == 0:
            merged_summary = sess.run(merged_op, feed_dict=feed_dict)
            writer.add_summary(merged_summary, i)

        if i % 500 == 0:
            saver.save(sess, os.path.join(config.model_dir, 'model_{:06d}.ckpt'.format(i)))

            img_Axs, img_Bes, img_Aes, img_Bxs, img_Axs2, img_Bes2 = sess.run(
                [model.Axs, model.Bes, model.Aes, model.Bxs, model.Axs2, model.Bes2],
                feed_dict=feed_dict)

            # save the first 5 samples of each feature as one side-by-side strip
            # (assumes a batch size of at least 5)
            for k in range(model.n_feat):
                for j in range(5):
                    img = np.concatenate((img_Axs[k][j], img_Bes[k][j], img_Aes[k][j], img_Bxs[k][j], img_Axs2[k][j], img_Bes2[k][j]), axis=1)
                    misc.imsave(os.path.join(config.sample_img_dir, 'iter_{:06d}_{}_{}.jpg'.format(i, j, model.feature_list[k])), img)

    writer.close()
    saver.save(sess, os.path.join(config.model_dir, 'model.ckpt'))

    coord.request_stop()
    coord.join(threads)

def main():
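    """Parse command-line arguments and launch DNA-GAN training on CelebA."""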
    parser = argparse.ArgumentParser(description='DNA-GAN training script', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '-a', '--attributes',
        nargs='+',
        required=True,
        type=str,
        help='Specify attribute names for training. \nAll attributes can be found in list_attr_celeba.txt'
    )
    parser.add_argument(
        '-g', '--gpu',
        default='0',
        type=str,
        help='Specify GPU id. \ndefault: %(default)s. \nUse comma to separate several ids, for example: 0,1'
    )
    args = parser.parse_args()

    celebA = Dataset(args.attributes)
    DNA_GAN = Model(args.attributes, is_train=True)
    run(config, celebA, DNA_GAN, gpu=args.gpu)


if __name__ == "__main__":
    main()