# Copyright 2018 Changan Wang

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import numpy as np
#from scipy.misc import imread, imsave, imshow, imresize
import tensorflow as tf

from net import hourglass as hg
from utility import train_helper
from utility import mertric

from preprocessing import preprocessing
from preprocessing import dataset
import config

# hardware related configuration
tf.app.flags.DEFINE_integer(
    'num_readers', 16,#16
    'The number of parallel readers that read data from the dataset.')
tf.app.flags.DEFINE_integer(
    'num_preprocessing_threads', 48,#48
    'The number of threads used to create the batches.')
tf.app.flags.DEFINE_integer(
    'num_cpu_threads', 0,
    'The number of cpu cores used to train.')
tf.app.flags.DEFINE_float(
    'gpu_memory_fraction', 1., 'GPU memory fraction to use.')
# scaffold related configuration
tf.app.flags.DEFINE_string(
    'data_dir', '../Datasets/tfrecords',#'/media/rs/0E06CD1706CD0127/Kapok/Chi/Datasets/tfrecords',
    'The directory where the dataset input data is stored.')
tf.app.flags.DEFINE_string(
    'dataset_name', '{}_????', 'The pattern of the dataset name to load.')
tf.app.flags.DEFINE_string(
    'model_dir', './logs_hg/',
    'The parent directory where the model will be stored.')
tf.app.flags.DEFINE_integer(
    'log_every_n_steps', 10,
    'The frequency with which logs are print.')
tf.app.flags.DEFINE_integer(
    'save_summary_steps', 100,
    'The frequency with which summaries are saved, in seconds.')
tf.app.flags.DEFINE_integer(
    'save_checkpoints_secs', 3600,
    'The frequency with which the model is saved, in seconds.')
# model related configuration
tf.app.flags.DEFINE_integer(
    'train_image_size', 384,
    'The size of the input image for the model to use.')
tf.app.flags.DEFINE_integer(
    'heatmap_size', 96,
    'The size of the output heatmap of the model.')
tf.app.flags.DEFINE_float(
    'heatmap_sigma', 1.,
    'The sigma of Gaussian which generate the target heatmap.')
tf.app.flags.DEFINE_integer('feats_channals', 256, 'Number of features in the hourglass.')
tf.app.flags.DEFINE_integer('num_stacks', 4, 'Number of hourglasses to stack.')#8
tf.app.flags.DEFINE_integer('num_modules', 1, 'Number of residual modules at each location in the hourglass.')
tf.app.flags.DEFINE_float(
    'bbox_border', 25.,
    'The nearest distance of the crop border to al keypoints.')
tf.app.flags.DEFINE_integer(
    'train_epochs', 50,
    'The number of epochs to use for training.')
tf.app.flags.DEFINE_integer(
    'epochs_per_eval', 20,
    'The number of training epochs to run between evaluations.')
tf.app.flags.DEFINE_integer(
    'batch_size', 6,
    'Batch size for training and evaluation.')
tf.app.flags.DEFINE_boolean(
    'use_ohkm', True,
    'Wether we will use the ohkm for hard keypoints.')
tf.app.flags.DEFINE_string(
    'data_format', 'channels_first', # 'channels_first' or 'channels_last'
    'A flag to override the data format used in the model. channels_first '
    'provides a performance boost on GPU but is not always compatible '
    'with CPU. If left unspecified, the data format will be chosen '
    'automatically based on whether TensorFlow was built for CPU or GPU.')
# optimizer related configuration
tf.app.flags.DEFINE_integer(
    'tf_random_seed', 20180406, 'Random seed for TensorFlow initializers.')
tf.app.flags.DEFINE_float(
    'weight_decay', 0.00000, 'The weight decay on the model weights.')
tf.app.flags.DEFINE_float(
    'mse_weight', 1., 'The weight decay on the model weights.')
tf.app.flags.DEFINE_float(
    'momentum', 0.0,#0.9
    'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
tf.app.flags.DEFINE_float('learning_rate', 5e-3, 'Initial learning rate.')#2.5e-4
tf.app.flags.DEFINE_float(
    'end_learning_rate', 0.000001,
    'The minimal end learning rate used by a polynomial decay learning rate.')
tf.app.flags.DEFINE_float(
    'warmup_learning_rate', 0.00001,
    'The start warm-up learning rate to avoid NAN.')
tf.app.flags.DEFINE_integer(
    'warmup_steps', 100,
    'The total steps to warm-up.')
# for learning rate piecewise_constant decay
tf.app.flags.DEFINE_string(
    'decay_boundaries', '2, 3',
    'Learning rate decay boundaries by global_step (comma-separated list).')
tf.app.flags.DEFINE_string(
    'lr_decay_factors', '1, 0.5, 0.1',
    'The values of learning_rate decay factor for each segment between boundaries (comma-separated list).')
# checkpoint related configuration
tf.app.flags.DEFINE_string(
    'checkpoint_path', None,
    'The path to a checkpoint from which to fine-tune.')
tf.app.flags.DEFINE_string(
    'checkpoint_model_scope', 'all',
    'Model scope in the checkpoint. None if the same as the trained model.')
tf.app.flags.DEFINE_string(
    #'blouse', 'dress', 'outwear', 'skirt', 'trousers', 'all'
    'model_scope', 'all',
    'Model scope name used to replace the name_scope in checkpoint.')
tf.app.flags.DEFINE_string(
    'checkpoint_exclude_scopes', None,
    'Comma-separated list of scopes of variables to exclude when restoring from a checkpoint.')
tf.app.flags.DEFINE_boolean(
    'ignore_missing_vars', True,
    'When restoring a checkpoint would ignore missing variables.')
tf.app.flags.DEFINE_boolean(
    'run_on_cloud', True,
    'Wether we will train on cloud.')
tf.app.flags.DEFINE_boolean(
    'seq_train', False,
    'Wether we will train a sequence model.')
tf.app.flags.DEFINE_string(
    'model_to_train', 'blouse, dress, outwear, skirt, trousers', #'all, blouse, dress, outwear, skirt, trousers', 'skirt, dress, outwear, trousers',
    'The sub-model to train (comma-separated list).')

FLAGS = tf.app.flags.FLAGS
#--model_scope=blouse --checkpoint_path=./logs/all --data_format=channels_last --batch_size=1
def input_pipeline(is_training=True, model_scope=FLAGS.model_scope, num_epochs=FLAGS.epochs_per_eval):
    if 'all' in model_scope:
        lnorm_table = tf.contrib.lookup.HashTable(tf.contrib.lookup.KeyValueTensorInitializer(tf.constant(config.global_norm_key, dtype=tf.int64),
                                                                tf.constant(config.global_norm_lvalues, dtype=tf.int64)), 0)
        rnorm_table = tf.contrib.lookup.HashTable(tf.contrib.lookup.KeyValueTensorInitializer(tf.constant(config.global_norm_key, dtype=tf.int64),
                                                                tf.constant(config.global_norm_rvalues, dtype=tf.int64)), 1)
    else:
        lnorm_table = tf.contrib.lookup.HashTable(tf.contrib.lookup.KeyValueTensorInitializer(tf.constant(config.local_norm_key, dtype=tf.int64),
                                                                tf.constant(config.local_norm_lvalues, dtype=tf.int64)), 0)
        rnorm_table = tf.contrib.lookup.HashTable(tf.contrib.lookup.KeyValueTensorInitializer(tf.constant(config.local_norm_key, dtype=tf.int64),
                                                                tf.constant(config.local_norm_rvalues, dtype=tf.int64)), 1)

    preprocessing_fn = lambda org_image, classid, shape, key_x, key_y, key_v: preprocessing.preprocess_image(org_image, classid, shape, FLAGS.train_image_size, FLAGS.train_image_size, key_x, key_y, key_v, (lnorm_table, rnorm_table), is_training=is_training, data_format=('NCHW' if FLAGS.data_format=='channels_first' else 'NHWC'), category=(model_scope if 'all' not in model_scope else '*'), bbox_border=FLAGS.bbox_border, heatmap_sigma=FLAGS.heatmap_sigma, heatmap_size=FLAGS.heatmap_size)

    images, shape, classid, targets, key_v, isvalid, norm_value = dataset.slim_get_split(FLAGS.data_dir, preprocessing_fn, FLAGS.batch_size, FLAGS.num_readers, FLAGS.num_preprocessing_threads, num_epochs=num_epochs, is_training=is_training, file_pattern=FLAGS.dataset_name, category=(model_scope if 'all' not in model_scope else '*'), reader=None)

    return images, {'targets': targets, 'key_v': key_v, 'shape': shape, 'classid': classid, 'isvalid': isvalid, 'norm_value': norm_value}

if config.PRED_DEBUG:
  from scipy.misc import imread, imsave, imshow, imresize
  def save_image_with_heatmap(image, height, width, heatmap_size, targets, pred_heatmap, indR, indG, indB):
      if not hasattr(save_image_with_heatmap, "counter"):
          save_image_with_heatmap.counter = 0  # it doesn't exist yet, so initialize it
      save_image_with_heatmap.counter += 1

      img_to_save = np.array(image.tolist()) + 128
      #print(img_to_save.shape)

      img_to_save = img_to_save.astype(np.uint8)

      heatmap0 = np.sum(targets[indR, ...], axis=0).astype(np.uint8)
      heatmap1 = np.sum(targets[indG, ...], axis=0).astype(np.uint8)
      heatmap2 = np.sum(targets[indB, ...], axis=0).astype(np.uint8) if len(indB) > 0 else np.zeros((heatmap_size, heatmap_size), dtype=np.float32)

      img_to_save = imresize(img_to_save, (height, width), interp='lanczos')
      heatmap0 = imresize(heatmap0, (height, width), interp='lanczos')
      heatmap1 = imresize(heatmap1, (height, width), interp='lanczos')
      heatmap2 = imresize(heatmap2, (height, width), interp='lanczos')

      img_to_save = img_to_save/2
      img_to_save[:,:,0] = np.clip((img_to_save[:,:,0] + heatmap0 + heatmap2), 0, 255)
      img_to_save[:,:,1] = np.clip((img_to_save[:,:,1] + heatmap1 + heatmap2), 0, 255)
      #img_to_save[:,:,2] = np.clip((img_to_save[:,:,2]/4. + heatmap2), 0, 255)
      file_name = 'targets_{}.jpg'.format(save_image_with_heatmap.counter)
      imsave(os.path.join(config.DEBUG_DIR, file_name), img_to_save.astype(np.uint8))

      pred_heatmap = np.array(pred_heatmap.tolist())
      #print(pred_heatmap.shape)
      for ind in range(pred_heatmap.shape[0]):
        img = pred_heatmap[ind]
        img = img - img.min()
        img *= 255.0/img.max()
        file_name = 'heatmap_{}_{}.jpg'.format(save_image_with_heatmap.counter, ind)
        imsave(os.path.join(config.DEBUG_DIR, file_name), img.astype(np.uint8))
      return save_image_with_heatmap.counter

def get_keypoint(image, targets, predictions, heatmap_size, height, width, category, clip_at_zero=True, data_format='channels_last', name=None):
    predictions = tf.reshape(predictions, [1, -1, heatmap_size*heatmap_size])

    pred_max = tf.reduce_max(predictions, axis=-1)
    pred_indices = tf.argmax(predictions, axis=-1)
    pred_x, pred_y = tf.cast(tf.floormod(pred_indices, heatmap_size), tf.float32), tf.cast(tf.floordiv(pred_indices, heatmap_size), tf.float32)

    width, height = tf.cast(width, tf.float32), tf.cast(height, tf.float32)
    pred_x, pred_y = pred_x * width / tf.cast(heatmap_size, tf.float32), pred_y * height / tf.cast(heatmap_size, tf.float32)

    if clip_at_zero:
      pred_x, pred_y =  pred_x * tf.cast(pred_max>0, tf.float32), pred_y * tf.cast(pred_max>0, tf.float32)
      pred_x = pred_x * tf.cast(pred_max>0, tf.float32) + tf.cast(pred_max<=0, tf.float32) * (width / 2.)
      pred_y = pred_y * tf.cast(pred_max>0, tf.float32) + tf.cast(pred_max<=0, tf.float32) * (height / 2.)

    if config.PRED_DEBUG:
      pred_indices_ = tf.squeeze(pred_indices)
      image_ = tf.squeeze(image) * 255.
      pred_heatmap = tf.one_hot(pred_indices_, heatmap_size*heatmap_size, on_value=1., off_value=0., axis=-1, dtype=tf.float32)

      pred_heatmap = tf.reshape(pred_heatmap, [-1, heatmap_size, heatmap_size])
      if data_format == 'channels_first':
        image_ = tf.transpose(image_, perm=(1, 2, 0))
      save_image_op = tf.py_func(save_image_with_heatmap,
                                  [image_, height, width,
                                  heatmap_size,
                                  tf.reshape(pred_heatmap * 255., [-1, heatmap_size, heatmap_size]),
                                  tf.reshape(predictions, [-1, heatmap_size, heatmap_size]),
                                  config.left_right_group_map[category][0],
                                  config.left_right_group_map[category][1],
                                  config.left_right_group_map[category][2]],
                                  tf.int64, stateful=True)
      with tf.control_dependencies([save_image_op]):
        pred_x, pred_y = pred_x * 1., pred_y * 1.
    return pred_x, pred_y

def keypoint_model_fn(features, labels, mode, params):
    targets = labels['targets']
    shape = labels['shape']
    classid = labels['classid']
    key_v = labels['key_v']
    isvalid = labels['isvalid']
    norm_value = labels['norm_value']

    cur_batch_size = tf.shape(features)[0]
    #features= tf.ones_like(features)

    with tf.variable_scope(params['model_scope'], default_name=None, values=[features], reuse=tf.AUTO_REUSE):
        pred_outputs = hg.create_model(features, params['num_stacks'], params['feats_channals'],
                            config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')], params['num_modules'],
                            (mode == tf.estimator.ModeKeys.TRAIN), params['data_format'])

    if params['data_format'] == 'channels_last':
        pred_outputs = [tf.transpose(pred_outputs[ind], [0, 3, 1, 2], name='outputs_trans_{}'.format(ind)) for ind in list(range(len(pred_outputs)))]

    score_map = pred_outputs[-1]

    pred_x, pred_y = get_keypoint(features, targets, score_map, params['heatmap_size'], params['train_image_size'], params['train_image_size'], (params['model_scope'] if 'all' not in params['model_scope'] else '*'), clip_at_zero=True, data_format=params['data_format'])

    # this is important!!!
    targets = 255. * targets

    # print(key_v)
    #targets = tf.reshape(255.*tf.one_hot(tf.ones_like(key_v,tf.int64)*(32*64+32), params['heatmap_size']*params['heatmap_size']), [cur_batch_size,-1,params['heatmap_size'],params['heatmap_size']])
    #norm_value = tf.ones_like(norm_value)
    # score_map = tf.reshape(tf.one_hot(tf.ones_like(key_v,tf.int64)*(31*64+31), params['heatmap_size']*params['heatmap_size']), [cur_batch_size,-1,params['heatmap_size'],params['heatmap_size']])

    #with tf.control_dependencies([pred_x, pred_y]):
    ne_mertric = mertric.normalized_error(targets, score_map, norm_value, key_v, isvalid,
                             cur_batch_size,
                             config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')],
                             params['heatmap_size'],
                             params['train_image_size'])

    # last_pred_mse = tf.metrics.mean_squared_error(score_map, targets,
    #                             weights=1.0 / tf.cast(cur_batch_size, tf.float32),
    #                             name='last_pred_mse')

    # filter all invisible keypoint maybe better for this task
    # all_visible = tf.logical_and(key_v>0, isvalid>0)
    # targets = tf.boolean_mask(targets, all_visible)
    # pred_outputs = [tf.boolean_mask(pred_outputs[ind], all_visible, name='boolean_mask_{}'.format(ind)) for ind in list(range(len(pred_outputs)))]
    all_visible = tf.expand_dims(tf.expand_dims(tf.cast(tf.logical_and(key_v>0, isvalid>0), tf.float32), axis=-1), axis=-1)
    targets = targets * all_visible
    pred_outputs = [pred_outputs[ind] * all_visible for ind in list(range(len(pred_outputs)))]

    sq_diff = tf.reduce_sum(tf.squared_difference(targets, pred_outputs[-1]), axis=-1)
    last_pred_mse = tf.metrics.mean_absolute_error(sq_diff, tf.zeros_like(sq_diff), name='last_pred_mse')

    metrics = {'normalized_error': ne_mertric, 'last_pred_mse':last_pred_mse}
    predictions = {'normalized_error': ne_mertric[1]}
    ne_mertric = tf.identity(ne_mertric[1], name='ne_mertric')

    base_learning_rate = params['learning_rate']
    mse_loss_list = []
    if params['use_ohkm']:
        base_learning_rate = 1.5 * base_learning_rate
        for pred_ind in list(range(len(pred_outputs) - 1)):
            mse_loss_list.append(0.6 * tf.losses.mean_squared_error(targets, pred_outputs[pred_ind],
                                weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                                scope='loss_{}'.format(pred_ind),
                                loss_collection=None,#tf.GraphKeys.LOSSES,
                                # mean all elements of all pixels in all batch
                                reduction=tf.losses.Reduction.MEAN))# SUM, SUM_OVER_BATCH_SIZE, default mean by all elements

        temp_loss = tf.reduce_mean(tf.reshape(tf.losses.mean_squared_error(targets, pred_outputs[-1], weights=1.0, loss_collection=None, reduction=tf.losses.Reduction.NONE), [cur_batch_size, config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')], -1]), axis=-1)

        num_topk = config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')] // 2
        gather_col = tf.nn.top_k(temp_loss, k=num_topk, sorted=True)[1]
        gather_row = tf.reshape(tf.tile(tf.reshape(tf.range(cur_batch_size), [-1, 1]), [1, num_topk]), [-1, 1])
        gather_indcies = tf.stop_gradient(tf.stack([gather_row, tf.reshape(gather_col, [-1, 1])], axis=-1))

        select_targets = tf.gather_nd(targets, gather_indcies)
        select_heatmap = tf.gather_nd(pred_outputs[-1], gather_indcies)

        mse_loss_list.append(tf.losses.mean_squared_error(select_targets, select_heatmap,
                                weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                                scope='loss_{}'.format(len(pred_outputs) - 1),
                                loss_collection=None,#tf.GraphKeys.LOSSES,
                                # mean all elements of all pixels in all batch
                                reduction=tf.losses.Reduction.MEAN))
    else:
        for pred_ind in list(range(len(pred_outputs))):
            mse_loss_list.append(tf.losses.mean_squared_error(targets, pred_outputs[pred_ind],
                                weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                                scope='loss_{}'.format(pred_ind),
                                loss_collection=None,#tf.GraphKeys.LOSSES,
                                reduction=tf.losses.Reduction.MEAN))# SUM, SUM_OVER_BATCH_SIZE, default mean by all elements

    mse_loss = tf.multiply(params['mse_weight'], tf.add_n(mse_loss_list), name='mse_loss')
    tf.summary.scalar('mse', mse_loss)
    tf.losses.add_loss(mse_loss)

    # bce_loss_list = []
    # for pred_ind in list(range(len(pred_outputs))):
    #     bce_loss_list.append(tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred_outputs[pred_ind], labels=targets, name='loss_{}'.format(pred_ind)), name='loss_mean_{}'.format(pred_ind)))

    # mse_loss = tf.multiply(params['mse_weight'] / params['num_stacks'], tf.add_n(bce_loss_list), name='mse_loss')
    # tf.summary.scalar('mse', mse_loss)
    # tf.losses.add_loss(mse_loss)

    # Add weight decay to the loss. We exclude the batch norm variables because
    # doing so leads to a small improvement in accuracy.
    loss = mse_loss + params['weight_decay'] * tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'batch_normalization' not in v.name])
    total_loss = tf.identity(loss, name='total_loss')
    tf.summary.scalar('loss', total_loss)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, predictions=predictions, eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        lr_values = [params['warmup_learning_rate']] + [base_learning_rate * decay for decay in params['lr_decay_factors']]
        learning_rate = tf.train.piecewise_constant(tf.cast(global_step, tf.int32),
                                                    [params['warmup_steps']] + [int(float(ep)*params['steps_per_epoch']) for ep in params['decay_boundaries']],
                                                    lr_values)
        truncated_learning_rate = tf.maximum(learning_rate, tf.constant(params['end_learning_rate'], dtype=learning_rate.dtype), name='learning_rate')
        tf.summary.scalar('lr', truncated_learning_rate)

        optimizer = tf.train.MomentumOptimizer(learning_rate=truncated_learning_rate,
                                                momentum=params['momentum'])

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
                          mode=mode,
                          predictions=predictions,
                          loss=loss,
                          train_op=train_op,
                          eval_metric_ops=metrics,
                          scaffold=tf.train.Scaffold(init_fn=train_helper.get_init_fn_for_scaffold_(params['checkpoint_path'], params['model_dir'], params['checkpoint_exclude_scopes'], params['model_scope'], params['checkpoint_model_scope'], params['ignore_missing_vars'])))

def parse_comma_list(args):
    return [float(s.strip()) for s in args.split(',')]

def sub_loop(model_fn, model_scope, model_dir, run_config, train_epochs, epochs_per_eval, lr_decay_factors, decay_boundaries, checkpoint_path=None, checkpoint_exclude_scopes='', checkpoint_model_scope='', ignore_missing_vars=True):
    steps_per_epoch = config.split_size[(model_scope if 'all' not in model_scope else '*')]['train'] // FLAGS.batch_size
    fashionAI = tf.estimator.Estimator(
        model_fn=model_fn, model_dir=model_dir, config=run_config,
        params={
            'checkpoint_path': checkpoint_path,
            'model_dir': model_dir,
            'checkpoint_exclude_scopes': checkpoint_exclude_scopes,
            'model_scope': model_scope,
            'checkpoint_model_scope': checkpoint_model_scope,
            'ignore_missing_vars': ignore_missing_vars,
            'train_image_size': FLAGS.train_image_size,
            'heatmap_size': FLAGS.heatmap_size,
            'feats_channals': FLAGS.feats_channals,
            'num_stacks': FLAGS.num_stacks,
            'num_modules': FLAGS.num_modules,
            'data_format': FLAGS.data_format,
            'steps_per_epoch': steps_per_epoch,
            'batch_size': FLAGS.batch_size,
            'use_ohkm': FLAGS.use_ohkm,
            'weight_decay': FLAGS.weight_decay,
            'mse_weight': FLAGS.mse_weight,
            'momentum': FLAGS.momentum,
            'learning_rate': FLAGS.learning_rate,
            'end_learning_rate': FLAGS.end_learning_rate,
            'warmup_learning_rate': FLAGS.warmup_learning_rate,
            'warmup_steps': FLAGS.warmup_steps,
            'decay_boundaries': parse_comma_list(decay_boundaries),
            'lr_decay_factors': parse_comma_list(lr_decay_factors),
        })

    tf.gfile.MakeDirs(model_dir)
    tf.logging.info('Starting to train model {}.'.format(model_scope))
    for _ in range(train_epochs // epochs_per_eval):
        tensors_to_log = {
            'lr': 'learning_rate',
            'loss': 'total_loss',
            'mse': 'mse_loss',
            'ne': 'ne_mertric',
        }

        logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=FLAGS.log_every_n_steps, formatter=lambda dicts: '{}:'.format(model_scope) + (', '.join(['%s=%.6f' % (k, v) for k, v in dicts.items()])))

        tf.logging.info('Starting a training cycle.')
        fashionAI.train(input_fn=lambda : input_pipeline(True, model_scope, epochs_per_eval), hooks=[logging_hook], max_steps=(steps_per_epoch*train_epochs))

        tf.logging.info('Starting to evaluate.')
        eval_results = fashionAI.evaluate(input_fn=lambda : input_pipeline(False, model_scope, 1))
        tf.logging.info(eval_results)
    tf.logging.info('Finished model {}.'.format(model_scope))

def main(_):
    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction)
    sess_config = tf.ConfigProto(allow_soft_placement = True, log_device_placement = False, intra_op_parallelism_threads = FLAGS.num_cpu_threads, inter_op_parallelism_threads = FLAGS.num_cpu_threads, gpu_options = gpu_options)

    # Set up a RunConfig to only save checkpoints once per training cycle.
    run_config = tf.estimator.RunConfig().replace(
                                        save_checkpoints_secs=FLAGS.save_checkpoints_secs).replace(
                                        save_checkpoints_steps=None).replace(
                                        save_summary_steps=FLAGS.save_summary_steps).replace(
                                        keep_checkpoint_max=5).replace(
                                        tf_random_seed=FLAGS.tf_random_seed).replace(
                                        log_step_count_steps=FLAGS.log_every_n_steps).replace(
                                        session_config=sess_config)

    if FLAGS.seq_train:
        detail_params = {
            'all': {
                'model_dir' : os.path.join(FLAGS.model_dir, 'all'),
                'train_epochs': 6,
                'epochs_per_eval': 4,
                'lr_decay_factors': '1, 0.5, 0.1',
                'decay_boundaries': '3, 4',
                'model_scope': 'all',
                'checkpoint_path': None,
                'checkpoint_model_scope': '',
                'checkpoint_exclude_scopes': '',
                'ignore_missing_vars': True,
            },
            'blouse': {
                'model_dir' : os.path.join(FLAGS.model_dir, 'blouse'),
                'train_epochs': 50,
                'epochs_per_eval': 30,
                'lr_decay_factors': '1, 0.5, 0.1',
                'decay_boundaries': '15, 30',
                'model_scope': 'blouse',
                'checkpoint_path': os.path.join(FLAGS.model_dir, 'all'),
                'checkpoint_model_scope': 'all',
                'checkpoint_exclude_scopes': 'blouse/hg_heatmap',
                'ignore_missing_vars': True,
            },
            'dress': {
                'model_dir' : os.path.join(FLAGS.model_dir, 'dress'),
                'train_epochs': 50,
                'epochs_per_eval': 30,
                'lr_decay_factors': '1, 0.5, 0.1',
                'decay_boundaries': '15, 30',
                'model_scope': 'dress',
                'checkpoint_path': os.path.join(FLAGS.model_dir, 'all'),
                'checkpoint_model_scope': 'all',
                'checkpoint_exclude_scopes': 'dress/hg_heatmap',
                'ignore_missing_vars': True,
            },
            'outwear': {
                'model_dir' : os.path.join(FLAGS.model_dir, 'outwear'),
                'train_epochs': 50,
                'epochs_per_eval': 30,
                'lr_decay_factors': '1, 0.5, 0.1',
                'decay_boundaries': '15, 30',
                'model_scope': 'outwear',
                'checkpoint_path': os.path.join(FLAGS.model_dir, 'all'),
                'checkpoint_model_scope': 'all',
                'checkpoint_exclude_scopes': 'outwear/hg_heatmap',
                'ignore_missing_vars': True,
            },
            'skirt': {
                'model_dir' : os.path.join(FLAGS.model_dir, 'skirt'),
                'train_epochs': 50,
                'epochs_per_eval': 30,
                'lr_decay_factors': '1, 0.5, 0.1',
                'decay_boundaries': '15, 30',
                'model_scope': 'skirt',
                'checkpoint_path': os.path.join(FLAGS.model_dir, 'all'),
                'checkpoint_model_scope': 'all',
                'checkpoint_exclude_scopes': 'skirt/hg_heatmap',
                'ignore_missing_vars': True,
            },
            'trousers': {
                'model_dir' : os.path.join(FLAGS.model_dir, 'trousers'),
                'train_epochs': 50,
                'epochs_per_eval': 30,
                'lr_decay_factors': '1, 0.5, 0.1',
                'decay_boundaries': '15, 30',
                'model_scope': 'trousers',
                'checkpoint_path': os.path.join(FLAGS.model_dir, 'all'),
                'checkpoint_model_scope': 'all',
                'checkpoint_exclude_scopes': 'trousers/hg_heatmap',
                'ignore_missing_vars': True,
            },
        }
    else:
        detail_params = {
            'blouse': {
                'model_dir' : os.path.join(FLAGS.model_dir, 'blouse'),
                'train_epochs': 40,
                'epochs_per_eval': 15,
                'lr_decay_factors': '1, 0.5, 0.1',
                'decay_boundaries': '10, 20',
                'model_scope': 'blouse',
                'checkpoint_path': None,
                'checkpoint_model_scope': '',
                'checkpoint_exclude_scopes': '',
                'ignore_missing_vars': True,
            },
            'dress': {
                'model_dir' : os.path.join(FLAGS.model_dir, 'dress'),
                'train_epochs': 40,
                'epochs_per_eval': 15,
                'lr_decay_factors': '1, 0.5, 0.1',
                'decay_boundaries': '10, 20',
                'model_scope': 'dress',
                'checkpoint_path': None,
                'checkpoint_model_scope': '',
                'checkpoint_exclude_scopes': '',
                'ignore_missing_vars': True,
            },
            'outwear': {
                'model_dir' : os.path.join(FLAGS.model_dir, 'outwear'),
                'train_epochs': 40,
                'epochs_per_eval': 15,
                'lr_decay_factors': '1, 0.5, 0.1',
                'decay_boundaries': '10, 20',
                'model_scope': 'outwear',
                'checkpoint_path': None,
                'checkpoint_model_scope': '',
                'checkpoint_exclude_scopes': '',
                'ignore_missing_vars': True,
            },
            'skirt': {
                'model_dir' : os.path.join(FLAGS.model_dir, 'skirt'),
                'train_epochs': 40,
                'epochs_per_eval': 15,
                'lr_decay_factors': '1, 0.5, 0.1',
                'decay_boundaries': '10, 20',
                'model_scope': 'skirt',
                'checkpoint_path': None,
                'checkpoint_model_scope': '',
                'checkpoint_exclude_scopes': '',
                'ignore_missing_vars': True,
            },
            'trousers': {
                'model_dir' : os.path.join(FLAGS.model_dir, 'trousers'),
                'train_epochs': 40,
                'epochs_per_eval': 15,
                'lr_decay_factors': '1, 0.5, 0.1',
                'decay_boundaries': '10, 20',
                'model_scope': 'trousers',
                'checkpoint_path': None,
                'checkpoint_model_scope': '',
                'checkpoint_exclude_scopes': '',
                'ignore_missing_vars': True,
            },
        }
    model_to_train = [s.strip() for s in FLAGS.model_to_train.split(',')]
    for m in model_to_train:
        sub_loop(keypoint_model_fn, m, detail_params[m]['model_dir'], run_config, detail_params[m]['train_epochs'], detail_params[m]['epochs_per_eval'], detail_params[m]['lr_decay_factors'], detail_params[m]['decay_boundaries'], detail_params[m]['checkpoint_path'], detail_params[m]['checkpoint_exclude_scopes'], detail_params[m]['checkpoint_model_scope'], detail_params[m]['ignore_missing_vars'])

if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()