#!/usr/bin/env python

# ----------------------------------------------------------------
# 3D-Conv-2D-Pool-UNet Testing Indoor Cases
# Written by Haiyang Jiang
# Mar 20th 2019
# ----------------------------------------------------------------

import os, time

import scipy.io
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
from skvideo.io import vwrite, vread


from network import network
from config import *


import sys
# Test case selected by the first command-line argument, read as an integer.
# A missing or non-integer argument falls back to 3.
try:
    test_case = int(sys.argv[1]) if len(sys.argv) > 1 else 3
except ValueError:
    test_case = 3


# Pick the list of clips to process and the sub-directory of RESULT_DIR that
# receives the results: 0 = training list, 1 = validation list, 2 = test
# list; any other value selects the customized list.
if test_case == 0:
    file_list, directory = FILE_LIST, 'train_set_results/'
elif test_case == 1:
    file_list, directory = VALID_LIST, 'validation_set_results/'
elif test_case == 2:
    file_list, directory = TEST_LIST, 'test_set_results/'
else:
    file_list, directory = CUSOMIZED_LIST, 'customized_test_results/'

# FILE_LIST is rebound so that everything below reads the selected list.
TEST_RESULT_DIR = RESULT_DIR + directory
FILE_LIST = file_list


# The list file holds one clip per line: "<id> <gt_path> <in_path>",
# whitespace-separated. Blank lines (e.g. a trailing empty line at EOF) are
# skipped instead of making the column lookups below raise IndexError.
with open(FILE_LIST) as f:
    text = f.readlines()

_list_fields = [line.split() for line in text if line.strip()]
train_ids = [fields[0] for fields in _list_fields]
in_paths = [fields[2] for fields in _list_fields]
gt_paths = [fields[1] for fields in _list_fields]


def equalize_histogram(image, number_bins=256):
    """Histogram-equalize `image` using its global intensity distribution.

    All values are binned into `number_bins` equal-width bins spanning the
    data range. Each value is then mapped to the normalized cumulative count
    of its bin, by linear interpolation over the left bin edges. The result
    lies within [0, number_bins - 1].

    Args:
        image: numeric array of any shape (flattened for the histogram).
        number_bins: number of histogram bins; also sets the output range.

    Returns:
        A float64 array with the same shape as `image`. An empty `image`
        yields an empty result.
    """
    if image.size == 0:
        return np.zeros(image.shape, dtype=np.float64)

    values = image.ravel()  # a view when possible; flatten() always copies
    image_histogram, bins = np.histogram(values, number_bins)
    # Convert to float before dividing. Under Python 2 (which this file
    # targets) `/` on integer arrays truncates and would quantize the CDF.
    cdf = image_histogram.cumsum().astype(np.float64)
    cdf = (number_bins - 1) * cdf / cdf[-1]  # normalize

    image_equalized = np.interp(values, bins[:-1], cdf)

    return image_equalized.reshape(image.shape)


def process_video(sess, in_image, out_image, in_file, raw, out_file=None):
    input_patch = raw

    if DEBUG:
        print '[DEBUG] (begining of preocess_video) input_patch.shape:', input_patch.shape

    i = 0
    j = 0
    k = 0
    step = 1 - OVERLAP
    output = np.zeros([input_patch.shape[0], input_patch.shape[1] * 2, input_patch.shape[2] * 2, 3], dtype='uint16')
    i_range, j_range, k_range = input_patch.shape[0:3]
    weights = np.zeros(output.shape, dtype='uint8')
    
    # 16 bit
    max_val = 65535.0
    scaling_factor = max_val
    val_type = 'uint16'
    
    input_patch = equalize_histogram(input_patch, int(max_val) + 1)
    
    
    done = False
    while i < i_range:
        if i + TEST_CROP_FRAME > i_range:
            if done:
                break
            i = i_range - TEST_CROP_FRAME
            done = True
        print '[INFO] processing frame', i
        j = 0
        while j < j_range:
            k = 0
            while k < k_range:
                temp = input_patch[i: i + TEST_CROP_FRAME, j: j + TEST_CROP_HEIGHT, k: k + TEST_CROP_WIDTH, :]
                network_input = np.float32(np.expand_dims(temp, axis=0))
                network_input = np.minimum(network_input / scaling_factor, 1.0)
                if DEBUG:
                    print '[DEBUG] network_input.shape:', network_input.shape
                network_output = sess.run(out_image, feed_dict={in_image: network_input})
                if DEBUG:
                    print '[DEBUG] network_output.shape:', network_output.shape

                if i + TEST_CROP_FRAME > i_range:
                    temp = network_output[0, :i_range - i, :, :, :]
                else:
                    temp = network_output[0, :, :, :, :]
                network_output = np.minimum(np.maximum(temp, 0), 1)
                output[i: i + TEST_CROP_FRAME, j * 2: (j + TEST_CROP_HEIGHT) * 2, k * 2: (k + TEST_CROP_WIDTH) * 2, :] += (network_output * OUT_MAX).astype('uint16')
                weights[i: i + TEST_CROP_FRAME, j * 2: (j + TEST_CROP_HEIGHT) * 2, k * 2: (k + TEST_CROP_WIDTH) * 2, :] += 1
                k += int(TEST_CROP_WIDTH * step)
            j += int(TEST_CROP_HEIGHT * step)
        i += int(TEST_CROP_FRAME * step)

    output = (output / weights).astype('uint8')

    if out_file is None:
        out_file = os.path.basename(in_file)[:-4] + '.mp4'
        if DEBUG:
            print '[DEBUG] out_file:', out_file
    print '[PROCESS] Processing done. Saving...',
    t0 = time.time()
    vwrite(TEST_RESULT_DIR + out_file, output)
    t1 = time.time()
    print 'done. ({:.3f}s)'.format(t1 - t0)


def main():
    sess = tf.Session()
    in_image = tf.placeholder(tf.float32, [None, TEST_CROP_FRAME, None, None, 4])
    gt_image = tf.placeholder(tf.float32, [None, TEST_CROP_FRAME, None, None, 3])
    out_image = network(in_image)

    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
    if ckpt:
        print('loaded ' + ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
    if not os.path.isdir(TEST_RESULT_DIR):
        os.makedirs(TEST_RESULT_DIR)

    for i, file0 in enumerate(in_paths):
        t0 = time.time()
        # raw = vread(file0)
        raw = np.load(file0)
        if raw.shape[0] > MAX_FRAME:
            print 'Video with shape', raw.shape, 'is too large. Splitted.'
            count = 0
            begin_frame = 0
            while begin_frame < raw.shape[0]:
                t1 = time.time()
                print 'processing segment %d ...' % (count + 1),
                new_filename = '.'.join(file0.split('.')[:-1] + [str(count)] + file0.split('.')[-1::])
                process_video(sess, in_image, out_image, new_filename, raw[begin_frame: begin_frame + MAX_FRAME, :, :, :])
                count += 1
                begin_frame += MAX_FRAME
                print '\t{}s'.format(time.time() - t1)
        else:
            process_video(sess, in_image, out_image, file0, raw, out_file=train_ids[i] + '.mp4')
        print train_ids[i], '\t{}s'.format(time.time() - t0)


if __name__ == '__main__':
    t0 = time.time()
    main()
    print 'total time: {}s'.format(time.time() - t0)