#!/usr/bin/env python3

from config import get_logging_config, args, train_dir
from config import config as net_config

import time
import os
import sys
import socket
import logging
import logging.config
import subprocess

import tensorflow as tf
import numpy as np

import matplotlib
matplotlib.use('Agg')
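# non-interactive backend so the script can run on machines without a display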

from vgg import VGG
from resnet import ResNet
from utils import print_variables
from utils_tf import yxyx_to_xywh, data_augmentation
from datasets import get_dataset
from boxer import PriorBoxGrid

slim = tf.contrib.slim
streaming_mean_iou = tf.contrib.metrics.streaming_mean_iou

logging.config.dictConfig(get_logging_config(args.run_name))
log = logging.getLogger()


def objective(location, confidence, refine_ph, classes_ph,
              pos_mask, seg_logits, seg_gt, dataset, config):
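    """Builds the full training objective.

    Depending on args.detect / args.segment, combines the detection loss
    (hard-negative-mined classification plus smooth-L1 box regression),
    the segmentation cross-entropy, and the weight-decay term.
    """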
    def smooth_l1(x, y):
        abs_diff = tf.abs(x-y)
        return tf.reduce_sum(tf.where(abs_diff < 1,
                                      0.5*abs_diff*abs_diff,
                                      abs_diff - 0.5),
                             1)

    def segmentation_loss(seg_logits, seg_gt, config):
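        # drop pixels whose label falls outside the valid class range
        # (e.g. a void/ignore label such as 255 in VOC)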
        mask = seg_gt <= dataset.num_classes
        seg_logits = tf.boolean_mask(seg_logits, mask)
        seg_gt = tf.boolean_mask(seg_gt, mask)
        seg_predictions = tf.argmax(seg_logits, axis=1)

        seg_loss_local = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=seg_logits,
                                                                        labels=seg_gt)
        seg_loss = tf.reduce_mean(seg_loss_local)
        tf.summary.scalar('loss/segmentation', seg_loss)

        mean_iou, update_mean_iou = streaming_mean_iou(seg_predictions, seg_gt,
                                                       dataset.num_classes)
        tf.summary.scalar('accuracy/mean_iou', mean_iou)
        return seg_loss, mean_iou, update_mean_iou

    def detection_loss(location, confidence, refine_ph, classes_ph, pos_mask):
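        # hard negative mining: keep at most 3x as many negative anchors as
        # positive ones, choosing the negatives with the highest loss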
        neg_mask = tf.logical_not(pos_mask)
        number_of_positives = tf.reduce_sum(tf.to_int32(pos_mask))
        true_number_of_negatives = tf.minimum(3 * number_of_positives,
                                              tf.shape(pos_mask)[1] - number_of_positives)
        # max is to avoid the case where no positive boxes were sampled
        number_of_negatives = tf.maximum(1, true_number_of_negatives)
        num_pos_float = tf.to_float(tf.maximum(1, number_of_positives))
        normalizer = tf.to_float(tf.add(number_of_positives, number_of_negatives))
        tf.summary.scalar('batch/size', normalizer)

        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=confidence,
                                                                       labels=classes_ph)
        pos_class_loss = tf.reduce_sum(tf.boolean_mask(cross_entropy, pos_mask))
        tf.summary.scalar('loss/class_pos', pos_class_loss / num_pos_float)

        top_k_worst, top_k_inds = tf.nn.top_k(tf.boolean_mask(cross_entropy, neg_mask),
                                              number_of_negatives)
        # multiplication is to avoid the case where no positive boxes were sampled
        neg_class_loss = tf.reduce_sum(top_k_worst) * \
                         tf.cast(tf.greater(true_number_of_negatives, 0), tf.float32)
        class_loss = (neg_class_loss + pos_class_loss) / num_pos_float
        tf.summary.scalar('loss/class_neg', neg_class_loss / tf.to_float(number_of_negatives))
        tf.summary.scalar('loss/class', class_loss)

        # cond is to avoid the case where no positive boxes were sampled
        bbox_loss = tf.cond(tf.equal(tf.reduce_sum(tf.cast(pos_mask, tf.int32)), 0),
                            lambda: 0.0,
                            lambda: tf.reduce_mean(smooth_l1(tf.boolean_mask(location, pos_mask),
                                                             tf.boolean_mask(refine_ph, pos_mask))))
        tf.summary.scalar('loss/bbox', bbox_loss)

        inferred_class = tf.cast(tf.argmax(confidence, 2), tf.int32)
        positive_matches = tf.equal(tf.boolean_mask(inferred_class, pos_mask),
                                    tf.boolean_mask(classes_ph, pos_mask))
        hard_matches = tf.equal(tf.boolean_mask(inferred_class, neg_mask),
                                tf.boolean_mask(classes_ph, neg_mask))
        hard_matches = tf.gather(hard_matches, top_k_inds)
        train_acc = ((tf.reduce_sum(tf.to_float(positive_matches)) +
                    tf.reduce_sum(tf.to_float(hard_matches))) / normalizer)
        tf.summary.scalar('accuracy/train', train_acc)

        # precision/recall/F1 over the sampled anchors; these can be NaN
        # when no anchor is predicted as foreground
        recognized_class = tf.argmax(confidence, 2)
        tp = tf.reduce_sum(tf.to_float(tf.logical_and(recognized_class > 0, pos_mask)))
        fp = tf.reduce_sum(tf.to_float(tf.logical_and(recognized_class > 0, neg_mask)))
        fn = tf.reduce_sum(tf.to_float(tf.logical_and(tf.equal(recognized_class, 0), pos_mask)))
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2*(precision * recall)/(precision + recall)
        tf.summary.scalar('metrics/train/precision', precision)
        tf.summary.scalar('metrics/train/recall', recall)
        tf.summary.scalar('metrics/train/f1', f1)
        return class_loss, bbox_loss, train_acc, number_of_positives

    # defaults, overwritten below depending on which tasks are enabled
    the_loss = 0
    train_acc = tf.constant(1)
    mean_iou = tf.constant(1)
    update_mean_iou = tf.constant(1)

    if args.segment:
        seg_loss, mean_iou, update_mean_iou = segmentation_loss(seg_logits, seg_gt, config)
        the_loss += seg_loss

    if args.detect:
        class_loss, bbox_loss, train_acc, number_of_positives =\
            detection_loss(location, confidence, refine_ph, classes_ph, pos_mask)
        det_loss = class_loss + bbox_loss
        the_loss += det_loss

    regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    wd_loss = tf.add_n(regularization_losses)
    tf.summary.scalar('loss/weight_decay', wd_loss)
    the_loss += wd_loss

    tf.summary.scalar('loss/full', the_loss)
    return the_loss, train_acc, mean_iou, update_mean_iou


def extract_batch(dataset, config):
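    """Builds the input pipeline on the CPU: reads images and annotations,
    applies data augmentation, encodes ground truth into prior-box targets,
    and returns a shuffled batch of size args.batch_size.
    """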
    with tf.device("/cpu:0"):
        bboxer = PriorBoxGrid(config)
        data_provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset, num_readers=2,
            common_queue_capacity=512, common_queue_min=32)
        if args.segment:
            im, bbox, gt, seg = data_provider.get(['image', 'object/bbox', 'object/label',
                                                   'image/segmentation'])
        else:
            im, bbox, gt = data_provider.get(['image', 'object/bbox', 'object/label'])
            seg = tf.expand_dims(tf.zeros(tf.shape(im)[:2]), 2)
        # scale pixel values to [0, 1] and convert boxes
        # from (ymin, xmin, ymax, xmax) to (x, y, w, h)
        im = tf.to_float(im) / 255
        bbox = yxyx_to_xywh(tf.clip_by_value(bbox, 0.0, 1.0))

        im, bbox, gt, seg = data_augmentation(im, bbox, gt, seg, config)
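        # match ground-truth boxes to the prior-box grid: inds marks positive
        # anchors, cats holds their target classes, refine the box regression targets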
        inds, cats, refine = bboxer.encode_gt_tf(bbox, gt)

        return tf.train.shuffle_batch([im, inds, refine, cats, seg],
                                      args.batch_size, 2048, 64, num_threads=4)


def train(dataset, net, config):
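    """Builds the model, objective, learning-rate schedule and optimizer,
    then runs the training loop with periodic summaries and checkpointing.
    """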
    image_ph, inds_ph, refine_ph, classes_ph, seg_gt = extract_batch(dataset, config)

    net.create_trunk(image_ph)

    if args.detect:
        net.create_multibox_head(dataset.num_classes)
        confidence = net.outputs['confidence']
        location = net.outputs['location']
        tf.summary.histogram('location', location)
        tf.summary.histogram('confidence', confidence)
    else:
        location, confidence = None, None

    if args.segment:
        net.create_segmentation_head(dataset.num_classes)
        seg_logits = net.outputs['segmentation']
        tf.summary.histogram('segmentation', seg_logits)
    else:
        seg_logits = None

    loss, train_acc, mean_iou, update_mean_iou = objective(location, confidence, refine_ph,
                                                           classes_ph, inds_ph, seg_logits,
                                                           seg_gt, dataset, config)

    ### setting up the learning rate ###
    global_step = slim.get_or_create_global_step()
    learning_rate = args.learning_rate

    learning_rates = [args.warmup_lr, learning_rate]
    steps = [args.warmup_step]

    # after warmup, decay the learning rate by a factor of 10 at every step listed in args.lr_decay
    if len(args.lr_decay) > 0:
        for i, step in enumerate(args.lr_decay):
            steps.append(step)
            learning_rates.append(learning_rate * 10**(-i - 1))

    learning_rate = tf.train.piecewise_constant(tf.to_int32(global_step),
                                                steps, learning_rates)

    tf.summary.scalar('learning_rate', learning_rate)
    #######

    if args.optimizer == 'adam':
        opt = tf.train.AdamOptimizer(learning_rate)
    elif args.optimizer == 'nesterov':
        opt = tf.train.MomentumOptimizer(learning_rate, 0.9, use_nesterov=True)
    else:
        raise ValueError('unknown optimizer: %s' % args.optimizer)

    train_vars = tf.trainable_variables()
    print_variables('train', train_vars)

    train_op = slim.learning.create_train_op(
        loss, opt,
        global_step=global_step,
        variables_to_train=train_vars,
        summarize_gradients=True)

    summary_op = tf.summary.merge_all()
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=1000, keep_checkpoint_every_n_hours=1)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)) as sess:
        summary_writer = tf.summary.FileWriter(train_dir, sess.graph)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        if args.random_trunk_init:
            print("Training from scratch")
        else:
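            # initialize the trunk from ImageNet-pretrained weights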
            init_assign_op, init_feed_dict, init_vars = net.get_imagenet_init(opt)
            print_variables('init from ImageNet', init_vars)
            sess.run(init_assign_op, feed_dict=init_feed_dict)

        ckpt = tf.train.get_checkpoint_state(train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            if args.ckpt == 0:
                ckpt_to_restore = ckpt.model_checkpoint_path
            else:
                ckpt_to_restore = train_dir+'/model.ckpt-%i' % args.ckpt
            log.info("Restoring model %s..." % ckpt_to_restore)
            saver.restore(sess, ckpt_to_restore)

        starting_step = sess.run(global_step)
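        # freeze the graph so any accidental op creation in the loop below raises an error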
        tf.get_default_graph().finalize()

        log.info("Launching prefetch threads")
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        log.info("Starting training...")
        for step in range(starting_step, args.max_iterations+1):
            start_time = time.time()
            try:
                train_loss, acc, iou, _, lr = sess.run([train_op, train_acc, mean_iou,
                                                        update_mean_iou, learning_rate])
            except (tf.errors.OutOfRangeError, tf.errors.CancelledError):
                break
            duration = time.time() - start_time

            num_examples_per_step = args.batch_size
            examples_per_sec = num_examples_per_step / duration
            sec_per_batch = float(duration)

            format_str = ('step %d, loss = %.2f, acc = %.2f, iou = %f, '
                          '-log10(lr) = %.3f (%.1f examples/sec; %.3f sec/batch)')
            log.info(format_str % (step, train_loss, acc, iou, -np.log10(lr),
                                   examples_per_sec, sec_per_batch))

            if step % 100 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            if step % 1000 == 0 and step > 0:
                summary_writer.flush()
                log.debug("Saving checkpoint...")
                checkpoint_path = os.path.join(train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

        summary_writer.close()

        coord.request_stop()
        coord.join(threads)


def main(argv=None):  # pylint: disable=unused-argument
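    """Picks the backbone and dataset from the command-line arguments and launches training."""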
    assert args.detect or args.segment, "Either detect or segment should be True"
    if args.trunk == 'resnet50':
        net = ResNet
        depth = 50
    elif args.trunk == 'vgg16':
        net = VGG
        depth = 16
    else:
        raise ValueError('unknown trunk: %s' % args.trunk)

    net = net(config=net_config, depth=depth, training=True, weight_decay=args.weight_decay)

    if args.dataset == 'voc07':
        dataset = get_dataset('voc07_trainval')
    elif args.dataset == 'voc12-trainval':
        dataset = get_dataset('voc12-train-segmentation', 'voc12-val')
    elif args.dataset == 'voc12-train':
        dataset = get_dataset('voc12-train-segmentation')
    elif args.dataset == 'voc12-val':
        dataset = get_dataset('voc12-val-segmentation')
    elif args.dataset == 'voc07+12':
        dataset = get_dataset('voc07_trainval', 'voc12_train', 'voc12_val')
    elif args.dataset == 'voc07+12-segfull':
        dataset = get_dataset('voc07-trainval-segmentation', 'voc12-train-segmentation', 'voc12-val')
    elif args.dataset == 'voc07+12-segmentation':
        dataset = get_dataset('voc07-trainval-segmentation', 'voc12-train-segmentation')
    elif args.dataset == 'coco':
        # corresponds to the coco trainval35k split
        dataset = get_dataset('coco-train2014-*', 'coco-valminusminival2014-*')
    elif args.dataset == 'coco-seg':
        # corresponds to the coco trainval35k split
        dataset = get_dataset('coco-seg-train2014-*', 'coco-seg-valminusminival2014-*')
    else:
        raise ValueError('unknown dataset: %s' % args.dataset)

    train(dataset, net, net_config)

if __name__ == '__main__':
    exec_string = ' '.join(sys.argv)
    log.debug("Executing a command: %s", exec_string)
    cur_commit = subprocess.check_output("git log -n 1 --pretty=format:%H".split()).decode('ascii').strip()
    cur_branch = subprocess.check_output("git rev-parse --abbrev-ref HEAD".split()).decode('ascii').strip()
    git_diff = subprocess.check_output('git diff --no-color'.split()).decode('ascii')
    log.debug("on branch %s with the following diff from HEAD (%s):", cur_branch, cur_commit)
    log.debug(git_diff)
    hostname = socket.gethostname()
    if 'gpuhost' in hostname:
        gpu_id = os.environ["CUDA_VISIBLE_DEVICES"]
        nvidiasmi = subprocess.check_output('nvidia-smi').decode('ascii')
        log.debug("Currently we are on %s and use gpu%s:" % (hostname, gpu_id))
        log.debug(nvidiasmi)
    tf.app.run()