import os
import csv
import glob
import time
import shutil
import argparse
import cv2
import numpy as np
from PIL import Image

import tensorflow as tf
slim = tf.contrib.slim

import ops
import utils
import model


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dnnet", type=str)
    parser.add_argument("--dtnet", type=str)
    parser.add_argument("--loop", dest="loop", action="store_true")
    parser.add_argument("--gpu", dest="gpu", action="store_true")

    parser.add_argument("--sample_dir", type=str, default="sample/")
    parser.add_argument("--checkpoint_dir", type=str, default="log/")
    parser.add_argument("--csv_path", type=str, default="dataset/test.csv")

    parser.add_argument("--image_size", type=int, default=256),
    parser.add_argument("--low_thres", type=float, default=500.0),
    parser.add_argument("--up_thres", type=float, default=3000.0),

    parser.add_argument("--num_threads", type=int, default=1)

    return parser.parse_args()


def build_model(im_height, im_width, config):
    # batch_size = 1
    depth_in = tf.placeholder(tf.float32, [1, im_height, im_width, 1])
    paddings = tf.constant([[0, 0], [8, 8], [8, 8], [0, 0]])  # avoid edge vanish
    depth_in_pad = tf.pad(depth_in, paddings, mode='SYMMETRIC', name='depth_in_padded')
    color = tf.placeholder(tf.float32, [1, im_height, im_width, 1])
    color_pad = tf.pad(color, paddings, mode='SYMMETRIC', name='color_in_padded')
    is_training = tf.placeholder(tf.bool, name="is_training")

    dnnet = dtnet = None
    if config.dnnet == "None":
        dnnet = None
    elif config.dnnet == "base":
        dnnet = model.base
    elif config.dnnet == "unet":
        dnnet = model.unet
    elif config.dnnet == "convResnet":
        dnnet = model.convResnet
    elif config.dnnet == "UResnet":
        dnnet = model.UResnet
    else:
        raise NotImplementedError("There is no such {} dnnet".format(config.dnnet))
    if config.dtnet == "None":
        dtnet = None
    elif config.dtnet == "hypercolumn":
        dtnet = model.hypercolumn
    else:
        raise NotImplementedError("There is no such {} dtnet".format(config.dtnet))

    if config.gpu:
        device = "/gpu:0"
    else:
        device = "/cpu:0"

    with tf.device(device):
        if dnnet:
            depth_dn, end_pts, weight_vars = dnnet(depth_in_pad, is_training, aux=None, scope="dn_net")
        else:
            depth_dn = None

        if dtnet:
            depth_dt, end_pts, weight_vars = dtnet(depth_dn if depth_dn is not None else depth_in_pad, color_pad)
        else:
            depth_dt = None

        if depth_dn is not None:
            depth_dn = tf.reshape(depth_dn[:, 8:-8, 8:-8, :], [im_height, im_width])
        if dtnet is not None:
            depth_dt = tf.reshape(depth_dt[:, 8:-8, 8:-8, :], [im_height, im_width])

    return {"depth_in": depth_in, "color": color, "is_training": is_training,
            "depth_dn": depth_dn,
            "depth_dt": depth_dt}


def wait_for_new_checkpoint(checkpoint_dir, history):
    while True:
        path = tf.train.latest_checkpoint(checkpoint_dir)
        if not path in history:
            break
        time.sleep(300)

    history.append(path)
    return path


def load_from_checkpoint(sess, path, exclude=None):
    saver = tf.train.Saver()
    saver.restore(sess, path)


def loop_body_patch(it, ckpt_path, depth_in, depth_ref, color, mask, config):
    """
    :param depth_ref: unused yet. offline quantitative evaluation of depth_dt. 
    """
    print(time.ctime())
    print("Load checkpoint: {}".format(ckpt_path))

    h, w = depth_in.shape[:2]
    low_thres = config.low_thres
    up_thres = config.up_thres
    thres_range = (up_thres - low_thres) / 2.0

    params = build_model(h, w, config)
    # ckpt_step = ckpt_path.split("/")[-1]

    sess = tf.Session()
    load_from_checkpoint(sess, ckpt_path)

    depth_dn_im, depth_dt_im = sess.run([params["depth_dn"], params["depth_dt"]],
                                        feed_dict={params["depth_in"]: depth_in.reshape(1, h, w, 1),
                                                   params["color"]: color.reshape(1, h, w, 1),
                                                   params["is_training"]: False})

    depth_dn_im = (((depth_dn_im + 1.0) * thres_range + low_thres) * mask).astype(np.uint16)
    depth_dt_im = (((depth_dt_im + 1.0) * thres_range + low_thres) * mask).astype(np.uint16)

    utils.save_image(depth_dn_im, config.sample_dir, "frame_{}_dn.png".format(it))
    utils.save_image(depth_dt_im, config.sample_dir, "frame_{}_dt.png".format(it))

    tf.reset_default_graph()
    print("saving img {}.".format(it))


def loop_body_whole(it, ckpt_path, raw_arr, gt_arr, rgb_arr, H, W, config):
    """
    forward input raw_array patches seperately, then h_stack and v_stack these patches into whole.
    """
    print(time.ctime())
    print("Load checkpoint: {}".format(ckpt_path))

    h, w = raw_arr[0][0].shape[:2]
    low_thres = config.low_thres
    up_thres = config.up_thres
    thres_range = (up_thres - low_thres) / 2.0

    params = build_model(h, w, config)
    ckpt_step = ckpt_path.split("/")[-1]

    sess = tf.Session()
    load_from_checkpoint(sess, ckpt_path)

    dn_arr = []
    for i, h_list in enumerate(raw_arr):
        dn_h_list = []
        for j in range(len(h_list)):
            depth_dn_patch, depth_dt_patch = sess.run([params["depth_dn"], params["depth_dt"]],
                                                feed_dict={params["depth_in"]: depth_in.reshape(1, h, w, 1),
                                                           params["color"]: color.reshape(1, h, w, 1),
                                                           params["is_training"]: False})
            dn_h_list.append(depth_dn_patch)
        dn_arr.append(dn_h_list)
    dn_im = utils.stack_patch(dn_arr, H, W)
    dn_im = ((dn_im + 1.0) * thres_range + low_thres).astype(np.uint16)
    utils.save_image(dn_im, config.sample_dir, "frame_{}_dn.png".format(it))

    tf.reset_default_graph()
    print("saving img {}.".format(it))


def loop_body_patch_time(sess, name, params, depth_in, color, mask, config):
    """
    Build test graph only once, more efficient when training phase is finished.
    :return: time elapsed for one batch of testing image.
    """
    h, w = depth_in.shape[:2]
    depth_in = depth_in.reshape(1, h, w, 1)
    color = color.reshape(1, h, w, 1)

    low_thres = config.low_thres
    up_thres = config.up_thres
    thres_range = (up_thres - low_thres) / 2.0

    t_start = time.time()
    feed_dict = {params["depth_in"]: depth_in,
                 params["color"]: color,
                 params["is_training"]: False}
    depth_dn_im = depth_dt_im = None
    if (params["depth_dn"] is not None) and (params["depth_dt"] is not None):
        depth_dn_im, depth_dt_im = sess.run([params["depth_dn"], params["depth_dt"]],
                                            feed_dict=feed_dict)
    elif (params["depth_dn"] is not None):
        depth_dn_im = sess.run(params["depth_dn"], feed_dict=feed_dict)
    elif (params["depth_dt"] is not None):
        depth_dt_im = sess.run(params["depth_dt"], feed_dict=feed_dict)

    print("saving img {}.".format(name))
    if depth_dn_im is not None:
        depth_dn_im = (((depth_dn_im + 1.0) * thres_range + low_thres) * mask).astype(np.uint16)
        utils.save_image(depth_dn_im, config.sample_dir, "dn_{}".format(name))
    if depth_dt_im is not None:
        depth_dt_im = (((depth_dt_im + 1.0) * thres_range + low_thres) * mask).astype(np.uint16)
        utils.save_image(depth_dt_im, config.sample_dir, "dt_{}".format(name))
    t_end = time.time()
    return (t_end - t_start)


def loop(data_info, config, split_stack=True, test_time=True):
    up_thres, low_thres = config.up_thres, config.low_thres
    all_ims = []
    for info in data_info:
        depth_in_path, depth_ref_path, color_path, mask_path = info
        name = os.path.basename(depth_in_path)
        raw = Image.open(depth_in_path)
        gt = Image.open(depth_ref_path)
        rgb = Image.open(color_path).convert('L').resize(raw.size)
        mask = Image.open(mask_path)
        assert raw.size == gt.size, 'gt size not match raw size!'

        # Do center crop here
        if not split_stack:
            if config.image_size < min(raw.size):
                raw, gt, rgb, mask = utils.center_crop(raw, gt, rgb, mask, config.image_size)
            elif config.image_size >= max(raw.size):
                raw, gt, rgb, mask = utils.center_pad(raw, gt, rgb, mask, config.image_size)
            else:
                raise NotImplementedError('invalid config.image_size.')

        mask = np.array(mask, dtype=np.float32) / 255.0
        rgb = np.array(rgb, dtype=np.float32) / (127.0 - 1.0) * mask
        thres_range = (up_thres - low_thres) / 2.0
        raw = np.clip(np.array(raw, dtype=np.float32), low_thres, up_thres)
        gt = np.clip(np.array(gt, dtype=np.float32), low_thres, up_thres)
        if config.dnnet == "bilateral":
            raw = cv2.bilateralFilter(raw, 9, 75, 75)
        raw = (raw - low_thres) / thres_range - 1.0
        gt = (gt - low_thres) / thres_range - 1.0

        if split_stack:
            raw_arr, H, W = utils.split_patch(raw, config.image_size)
            gt_arr, _, _ = utils.split_patch(gt, config.image_size)
            rgb_arr, _, _ = utils.split_patch(rgb, config.image_size)
            all_ims.append((name, raw_arr, gt_arr, rgb_arr))
        else:
            all_ims.append((name, raw, gt, rgb, mask))

    ckpt_history = list()
    if test_time and not split_stack:
        path = tf.train.latest_checkpoint(config.checkpoint_dir)
        print("Load checkpoint: {}".format(path))
        params = build_model(config.image_size, config.image_size, config)
        sess = tf.Session()
        load_from_checkpoint(sess, path)

        tt_time = 0.0
        history_len = 3
        t_history = np.zeros(history_len, dtype=np.float32)
        for i, (name, raw, _, rgb, mask) in enumerate(all_ims):
            t_elapsed = loop_body_patch_time(sess, name, params, raw, rgb, mask, config)
            t_history[i % history_len] = t_elapsed
            tt_time += t_elapsed
            avg_time = 1000 * tt_time / (i + 1)  # ms
            mv_avg_time = 1000 * np.mean(t_history)
            print('iter {} | tt_time: {:.4f}s; avg_time: {:.2f}; mv_avg_time: {:.2f}'.format(i+1, tt_time, avg_time, mv_avg_time))
        tf.reset_default_graph()
    else:
        while True:
            # Wait until new checkpoint exist when training phase is not finished.
            path = wait_for_new_checkpoint(config.checkpoint_dir, ckpt_history)
            print("Loading from checkpoint: {}".format(path))

            if split_stack:
                print('evaluating {} imgs'.format(len(all_ims)))
                for i, (name, raw_arr, gt_arr, rgb_arr) in enumerate(all_ims):
                    loop_body_whole(i, path, raw_arr, gt_arr, rgb_arr, H, W, config)
            else:
                for i, (name, raw, gt, rgb, mask) in enumerate(all_ims):
                    loop_body_patch(i, path, raw, gt, rgb, mask, config)

            if not config.loop: break


def main(config):
    print('evaluating csv file: {}'.format(config.csv_path))
    with open(config.csv_path, "r") as csvfile:
        reader = csv.reader(csvfile)
        data_info = [row for row in reader]

        if not os.path.exists(config.sample_dir):
            os.makedirs(config.sample_dir)

        loop(data_info, config, split_stack=False)


if __name__ == "__main__":
    config = parse_args()
    main(config)