# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import tensorflow as tf
import numpy as np
import os
import sys
import time
import cv2

from six.moves import xrange
from scipy import misc, io
from tensorflow.contrib import slim

import matplotlib.pyplot as plt
from network import pyramid_processing, pyramid_processing_bidirection, get_shape
from datasets import BasicDataset
from utils import average_gradients, lrelu, occlusion, rgb_bgr
from data_augmentation import flow_resize
from flowlib import flow_to_color, write_flo
from warp import tf_warp

class DDFlowModel(object):
    def __init__(self, batch_size=8, iter_steps=1000000, initial_learning_rate=1e-4, decay_steps=2e5,
                 decay_rate=0.5, is_scale=True, num_input_threads=4, buffer_size=5000,
                 beta1=0.9, num_gpus=1, save_checkpoint_interval=5000, write_summary_interval=200,
                 display_log_interval=50, allow_soft_placement=True, log_device_placement=False,
                 regularizer_scale=1e-4, cpu_device='/cpu:0', save_dir='KITTI', checkpoint_dir='checkpoints',
                 model_name='model', sample_dir='sample', summary_dir='summary', training_mode="no_distillation",
                 is_restore_model=False, restore_model='./models/KITTI/no_census_no_occlusion',
                 dataset_config=None, distillation_config=None):
        """Store the training configuration and create the output directory tree.

        Args:
            batch_size: total batch size; must be divisible by `num_gpus`.
            iter_steps: last training step executed by `train`.
            initial_learning_rate, decay_steps, decay_rate: arguments of the
                staircase exponential learning-rate decay built in `train`.
            is_scale: forwarded to the network builders as `is_scale`.
            num_input_threads: forwarded as `num_parallel_calls` to the
                dataset iterators.
            buffer_size: forwarded as `buffer_size` to the training iterators.
            beta1: `beta1` of the Adam optimizer built in `train`.
            num_gpus: number of GPU towers; with exactly one GPU the shared
                device is '/gpu:0', otherwise `cpu_device`.
            save_checkpoint_interval, write_summary_interval,
            display_log_interval: step periods used by `train`.
            allow_soft_placement, log_device_placement: session config flags.
            regularizer_scale: scale of the L2 weight regularizer.
            cpu_device: device string used as shared device when
                `num_gpus != 1`.
            save_dir: root output directory.
            checkpoint_dir, sample_dir, summary_dir: sub-directory names
                created below `save_dir`.
            model_name: sub-directory name below the checkpoint and sample
                directories (also used for the training summary log dir).
            training_mode: 'no_distillation' or 'distillation'.
            is_restore_model: when true, `train` restores `restore_model`.
            restore_model: checkpoint path restored by `train`.
            dataset_config: dict of dataset options (keys such as 'crop_h',
                'crop_w', 'data_list_file', 'img_dir' are read elsewhere in
                this class). Defaults to a fresh empty dict.
            distillation_config: dict of distillation options (keys such as
                'fake_flow_occ_dir', 'target_h', 'target_w' are read elsewhere
                in this class). Defaults to a fresh empty dict.
        """
        self.batch_size = batch_size
        self.iter_steps = iter_steps
        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.is_scale = is_scale
        self.num_input_threads = num_input_threads
        self.buffer_size = buffer_size
        self.beta1 = beta1
        self.num_gpus = num_gpus
        self.save_checkpoint_interval = save_checkpoint_interval
        self.write_summary_interval = write_summary_interval
        self.display_log_interval = display_log_interval
        self.allow_soft_placement = allow_soft_placement
        self.log_device_placement = log_device_placement
        self.regularizer_scale = regularizer_scale
        self.training_mode = training_mode
        self.is_restore_model = is_restore_model
        self.restore_model = restore_model
        # A `{}` default in the signature would be one dict shared by every
        # instance; create a fresh dict per instance instead.
        self.dataset_config = {} if dataset_config is None else dataset_config
        self.distillation_config = {} if distillation_config is None else distillation_config
        self.shared_device = '/gpu:0' if self.num_gpus == 1 else cpu_device
        assert(np.mod(batch_size, num_gpus) == 0)
        # np.maximum guards the division when num_gpus is 0.
        self.batch_size_per_gpu = int(batch_size / np.maximum(num_gpus, 1))

        self.save_dir = save_dir
        self.checkpoint_dir = '/'.join([self.save_dir, checkpoint_dir])
        self.model_name = model_name
        self.sample_dir = '/'.join([self.save_dir, sample_dir])
        self.summary_dir = '/'.join([self.save_dir, summary_dir])

        # Create the whole output tree, parents first. The explicit existence
        # check (rather than `exist_ok=True`) keeps Python 2 compatibility.
        required_dirs = [
            self.save_dir,
            self.checkpoint_dir,
            '/'.join([self.checkpoint_dir, self.model_name]),
            self.sample_dir,
            '/'.join([self.sample_dir, self.model_name]),
            self.summary_dir,
            '/'.join([self.summary_dir, 'train']),
            '/'.join([self.summary_dir, 'test']),
        ]
        for directory in required_dirs:
            if not os.path.exists(directory):
                os.makedirs(directory)
    
    def create_dataset_and_iterator(self, training_mode='no_distillation'):
        """Create the training dataset and its shuffled batch iterator.

        Args:
            training_mode: 'no_distillation' builds a plain image-pair
                dataset; 'distillation' additionally passes
                `fake_flow_occ_dir` from the distillation config and uses the
                distillation iterator.

        Returns:
            A `(dataset, iterator)` tuple.

        Raises:
            ValueError: if `training_mode` is not one of the two known modes.
        """
        if training_mode not in ('no_distillation', 'distillation'):
            raise ValueError('Invalid training_mode. Training_mode should be one of {no_distillation, distillation}')

        with_distillation = training_mode == 'distillation'

        # Options common to both modes; each tower gets its own share of the
        # global batch.
        dataset_kwargs = {
            'crop_h': self.dataset_config['crop_h'],
            'crop_w': self.dataset_config['crop_w'],
            'batch_size': self.batch_size_per_gpu,
            'data_list_file': self.dataset_config['data_list_file'],
            'img_dir': self.dataset_config['img_dir'],
        }
        if with_distillation:
            dataset_kwargs['fake_flow_occ_dir'] = self.distillation_config['fake_flow_occ_dir']
        dataset = BasicDataset(**dataset_kwargs)

        if with_distillation:
            make_iterator = dataset.create_batch_distillation_iterator
        else:
            make_iterator = dataset.create_batch_iterator
        iterator = make_iterator(data_list=dataset.data_list,
                                 batch_size=dataset.batch_size,
                                 shuffle=True,
                                 buffer_size=self.buffer_size,
                                 num_parallel_calls=self.num_input_threads)
        return dataset, iterator
    
    def epe_loss(self, diff, mask):
        """Return the masked mean of the L2 norm of `diff` along its last axis.

        Args:
            diff: tensor whose last axis holds the vector components the norm
                is taken over.
            mask: tensor multiplied element-wise with the norm map; its sum is
                used as the normalizer (presumably 1 for valid pixels, 0
                otherwise - confirm with callers).

        Returns:
            Scalar tensor: sum of masked norms divided by `sum(mask) + 1e-6`.
        """
        per_position_norm = tf.norm(diff, axis=-1, keepdims=True)
        masked_total = tf.reduce_sum(tf.multiply(per_position_norm, mask))
        # The small epsilon keeps the division finite for an all-zero mask.
        normalizer = tf.reduce_sum(mask) + 1e-6
        return masked_total / normalizer
    
    def abs_robust_loss(self, diff, mask, q=0.4):
        """Return the masked, normalized robust penalty `(|diff| + 0.01) ** q`.

        Args:
            diff: residual tensor.
            mask: tensor multiplied element-wise with the penalty; its sum
                enters the normalizer.
            q: exponent of the penalty.

        Returns:
            Scalar tensor: masked penalty sum divided by
            `2 * sum(mask) + 1e-6`.
        """
        penalty = tf.pow((tf.abs(diff) + 0.01), q)
        masked_penalty_total = tf.reduce_sum(tf.multiply(penalty, mask))
        # NOTE(review): the constant factor 2 presumably stands for the two
        # flow channels; it is applied to image residuals as well - confirm.
        normalizer = tf.reduce_sum(mask) * 2 + 1e-6
        return masked_penalty_total / normalizer
    
    def create_mask(self, tensor, paddings):
        """Build a border mask shaped like `tensor` with a single channel.

        The mask is 1 in the interior and 0 inside the border described by
        `paddings`, and it is excluded from gradient computation.

        Args:
            tensor: tensor whose dimensions 0, 1 and 2 give the batch size and
                the two spatial extents of the mask.
            paddings: `[[before_1, after_1], [before_2, after_2]]` border
                widths for dimensions 1 and 2.

        Returns:
            Tensor of shape `[dim0, dim1, dim2, 1]`.
        """
        with tf.variable_scope('create_mask'):
            dims = tf.shape(tensor)
            interior_dim1 = dims[1] - (paddings[0][0] + paddings[0][1])
            interior_dim2 = dims[2] - (paddings[1][0] + paddings[1][1])

            # Ones for the interior, zero-padded back to the full extent.
            single_mask = tf.pad(tf.ones([interior_dim1, interior_dim2]), paddings)
            # Replicate over the batch, then append the channel axis.
            batched_mask = tf.tile(tf.expand_dims(single_mask, 0), [dims[0], 1, 1])
            return tf.stop_gradient(tf.expand_dims(batched_mask, 3))
        
    def census_loss(self, img1, img2_warped, mask, max_distance=3):
        """Return a soft census-transform loss between two images.

        Both images are converted to grayscale, every pixel is described by
        the normalized differences to its `(2 * max_distance + 1) ** 2`
        neighbours, and the soft Hamming distance between both descriptors is
        passed to `abs_robust_loss`. A border of `max_distance` pixels is
        masked out in addition to `mask`.

        Args:
            img1: first image batch (passed to `tf.image.rgb_to_grayscale`).
            img2_warped: second image batch, treated the same way as `img1`.
            mask: weighting mask combined with the border mask.
            max_distance: radius of the census neighbourhood.

        Returns:
            Scalar loss tensor.
        """
        window = 2 * max_distance + 1
        num_neighbours = window * window
        # One-hot kernel bank: output channel k copies the k-th neighbour of
        # the window, so a convolution extracts all patches at once.
        identity_kernel = np.eye(num_neighbours).reshape((window, window, 1, num_neighbours))

        with tf.variable_scope('census_loss'):
            def _soft_census(image):
                gray = tf.image.rgb_to_grayscale(image) * 255
                # conv2d is used in place of tf.extract_image_patches, which
                # the original author noted fails when rows_in is None.
                kernel = tf.constant(identity_kernel, dtype=tf.float32)
                neighbours = tf.nn.conv2d(gray, kernel, strides=[1, 1, 1, 1], padding='SAME')

                delta = neighbours - gray
                return delta / tf.sqrt(0.81 + tf.square(delta))

            def _soft_hamming(desc_a, desc_b):
                squared = tf.square(desc_a - desc_b)
                return tf.reduce_sum(squared / (0.1 + squared), 3, keepdims=True)

            descriptor_1 = _soft_census(img1)
            descriptor_2 = _soft_census(img2_warped)
            distance = _soft_hamming(descriptor_1, descriptor_2)

            border = [[max_distance, max_distance], [max_distance, max_distance]]
            border_mask = self.create_mask(mask, border)
            return self.abs_robust_loss(distance, mask * border_mask)
    
 
    def compute_losses(self, batch_img1, batch_img2, flow_fw, flow_bw, mask_fw, mask_bw, train=True, is_scale=True):
        """Compute photometric and census losses in both flow directions.

        Args:
            batch_img1, batch_img2: the two image batches.
            flow_fw, flow_bw: dicts whose 'full_res' entry is the forward /
                backward flow used for warping.
            mask_fw, mask_bw: masks used for the 'occlusion' loss variants.
            train: forwarded to `get_shape`.
            is_scale: accepted for interface symmetry; not used here.

        Returns:
            Dict with keys 'abs_robust_mean' and 'census'; each maps to a dict
            with the keys 'no_occlusion' (all-ones masks) and 'occlusion'
            (the given masks). Every entry is the sum of the forward and the
            backward term.
        """
        shape = get_shape(batch_img1, train=train)
        height, width = shape[1], shape[2]
        img1_warp = tf_warp(batch_img1, flow_bw['full_res'], height, width)
        img2_warp = tf_warp(batch_img2, flow_fw['full_res'], height, width)

        # Photometric residuals shared by both mask variants.
        residual_fw = batch_img1 - img2_warp
        residual_bw = batch_img2 - img1_warp

        mask_variants = [
            ('no_occlusion', tf.ones_like(mask_fw), tf.ones_like(mask_bw)),
            ('occlusion', mask_fw, mask_bw),
        ]

        abs_robust_mean = {}
        for variant, weight_fw, weight_bw in mask_variants:
            abs_robust_mean[variant] = (self.abs_robust_loss(residual_fw, weight_fw) +
                                        self.abs_robust_loss(residual_bw, weight_bw))

        census_loss = {}
        for variant, weight_fw, weight_bw in mask_variants:
            census_loss[variant] = (self.census_loss(batch_img1, img2_warp, weight_fw, max_distance=3) +
                                    self.census_loss(batch_img2, img1_warp, weight_bw, max_distance=3))

        return {'abs_robust_mean': abs_robust_mean, 'census': census_loss}
        
    def add_loss_summary(self, losses, keys=['abs_robust_mean'], prefix=None):
        """Register a scalar summary for every loss in the selected groups.

        Args:
            losses: dict mapping a group name to a dict of named loss tensors.
            keys: group names to summarize.
            prefix: optional leading component of the summary name.

        Summary names are '<prefix>/<group>/<loss>' when `prefix` is truthy
        and '<group>/<loss>' otherwise.
        """
        for group in keys:
            if prefix:
                group_tag = '%s/%s' % (prefix, group)
            else:
                group_tag = '%s' % (group,)
            for loss_key, loss_value in losses[group].items():
                tf.summary.scalar('%s/%s' % (group_tag, loss_key), loss_value)
    
    def build_no_data_distillation(self, iterator, regularizer_scale=1e-4, train=True, trainable=True, is_scale=True):
        """Build the bidirectional flow network and its losses.

        Args:
            iterator: iterator whose `get_next()` yields `(img1, img2)`.
            regularizer_scale: scale of the L2 weight regularizer.
            train, trainable, is_scale: forwarded to the network builder and
                to `compute_losses`.

        Returns:
            `(losses, regularizer_loss)`: the loss dict of `compute_losses`
            and the sum of all registered regularization losses.
        """
        batch_img1, batch_img2 = iterator.get_next()
        weight_regularizer = slim.l2_regularizer(scale=regularizer_scale)
        flow_fw, flow_bw = pyramid_processing_bidirection(
            batch_img1, batch_img2, train=train, trainable=trainable, reuse=None,
            regularizer=weight_regularizer, is_scale=is_scale)

        # The masks are the complement of the estimated occlusion maps.
        occ_fw, occ_bw = occlusion(flow_fw['full_res'], flow_bw['full_res'])
        non_occluded_fw = 1. - occ_fw
        non_occluded_bw = 1. - occ_bw

        losses = self.compute_losses(batch_img1, batch_img2, flow_fw, flow_bw,
                                     non_occluded_fw, non_occluded_bw,
                                     train=train, is_scale=is_scale)

        regularizer_loss = tf.add_n(tf.losses.get_regularization_losses())
        return losses, regularizer_loss
    
    def build_data_distillation(self, iterator, regularizer_scale=1e-4, train=True, trainable=True, is_scale=True):
        """Build the network on a random crop and add the distillation loss.

        The iterator yields image pairs together with previously generated
        flow and occlusion maps. One random `target_h` x `target_w` window is
        cropped from all of them, the network is run on the cropped images,
        and its flow is compared with the cropped reference flow at pixels the
        network marks as occluded while the reference does not.

        Args:
            iterator: iterator whose `get_next()` yields
                `(img1, img2, flow_fw, flow_bw, occ_fw, occ_bw)`.
            regularizer_scale: scale of the L2 weight regularizer.
            train, trainable, is_scale: forwarded to the network builder and
                to `compute_losses`.

        Returns:
            `(losses, regularizer_loss)`: the loss dict of `compute_losses`
            extended by `losses['data_distillation']['distillation']`, and the
            sum of all registered regularization losses.
        """
        batch_img1, batch_img2, flow_fw, flow_bw, occ_fw, occ_bw = iterator.get_next()
        regularizer = slim.l2_regularizer(scale=regularizer_scale)
        h = self.dataset_config['crop_h']
        w = self.dataset_config['crop_w']
        target_h = self.distillation_config['target_h']
        target_w = self.distillation_config['target_w']
        # `maxval` is exclusive, so add 1: offsets 0 .. (h - target_h) are all
        # valid crops. Without it the last offset is never drawn and
        # target == crop yields the empty integer range [0, 0).
        offset_h = tf.random_uniform([], minval=0, maxval=h - target_h + 1, dtype=tf.int32)
        offset_w = tf.random_uniform([], minval=0, maxval=w - target_w + 1, dtype=tf.int32)

        def _crop(tensor):
            # Same window for every tensor so images, flows and occlusion
            # maps stay pixel-aligned.
            return tf.image.crop_to_bounding_box(tensor, offset_h, offset_w, target_h, target_w)

        batch_img1_cropped_patch = _crop(batch_img1)
        batch_img2_cropped_patch = _crop(batch_img2)
        flow_fw_cropped_patch = _crop(flow_fw)
        flow_bw_cropped_patch = _crop(flow_bw)
        occ_fw_cropped_patch = _crop(occ_fw)
        occ_bw_cropped_patch = _crop(occ_bw)

        flow_fw_patch, flow_bw_patch = pyramid_processing_bidirection(
            batch_img1_cropped_patch, batch_img2_cropped_patch,
            train=train, trainable=trainable, reuse=None, regularizer=regularizer, is_scale=is_scale)

        occ_fw_patch, occ_bw_patch = occlusion(flow_fw_patch['full_res'], flow_bw_patch['full_res'])
        mask_fw_patch = 1. - occ_fw_patch
        mask_bw_patch = 1. - occ_bw_patch

        losses = self.compute_losses(batch_img1_cropped_patch, batch_img2_cropped_patch,
                                     flow_fw_patch, flow_bw_patch, mask_fw_patch, mask_bw_patch,
                                     train=train, is_scale=is_scale)

        # Positive only where the patch prediction is occluded and the
        # reference occlusion map is not; clipping removes the opposite case.
        valid_mask_fw = tf.clip_by_value(occ_fw_patch - occ_fw_cropped_patch, 0., 1.)
        valid_mask_bw = tf.clip_by_value(occ_bw_patch - occ_bw_cropped_patch, 0., 1.)
        distillation_fw = self.abs_robust_loss(flow_fw_cropped_patch - flow_fw_patch['full_res'], valid_mask_fw)
        distillation_bw = self.abs_robust_loss(flow_bw_cropped_patch - flow_bw_patch['full_res'], valid_mask_bw)
        losses['data_distillation'] = {'distillation': (distillation_fw + distillation_bw) / 2}

        l2_regularizer = tf.losses.get_regularization_losses()
        regularizer_loss = tf.add_n(l2_regularizer)
        return losses, regularizer_loss
    
    def build(self, iterator, regularizer_scale=1e-4, train=True, trainable=True, is_scale=True, training_mode='no_distillation'):
        """Dispatch graph construction to the builder of `training_mode`.

        Args:
            iterator: data iterator handed to the selected builder.
            regularizer_scale, train, trainable, is_scale: forwarded
                unchanged to the selected builder.
            training_mode: 'no_distillation' or 'distillation'.

        Returns:
            `(losses, regularizer_loss)` as returned by the selected builder.

        Raises:
            ValueError: if `training_mode` is not one of the two known modes.
        """
        if training_mode == 'no_distillation':
            builder = self.build_no_data_distillation
        elif training_mode == 'distillation':
            builder = self.build_data_distillation
        else:
            raise ValueError('Invalid training_mode. Training_mode should be one of {no_distillation, distillation}')

        losses, regularizer_loss = builder(iterator=iterator,
                                           regularizer_scale=regularizer_scale,
                                           train=train,
                                           trainable=trainable,
                                           is_scale=is_scale)
        return losses, regularizer_loss
                    
    def create_train_op(self, optim, iterator, global_step, regularizer_scale=1e-4, train=True, trainable=True, is_scale=True, training_mode='no_distillation'):
        """Build the model on every tower and return the training op.

        Args:
            optim: optimizer used to compute and apply the gradients.
            iterator: data iterator handed to `build`.
            global_step: step variable incremented by the training op.
            regularizer_scale, train, trainable, is_scale, training_mode:
                forwarded to `build`.

        Returns:
            `(train_op, losses, regularizer_loss)`. With several GPUs the
            losses and the regularizer loss are averaged over the towers.
        """
        build_kwargs = dict(regularizer_scale=regularizer_scale, train=train, trainable=trainable,
                            is_scale=is_scale, training_mode=training_mode)

        # NOTE: the optimized loss is hard-coded to
        # losses['abs_robust_mean']['no_occlusion']; the regularizer loss is
        # returned and summarized but not added to it.
        # `<= 1` also covers num_gpus == 0, which the constructor tolerates;
        # the tower branch would otherwise index an empty tower list.
        if self.num_gpus <= 1:
            losses, regularizer_loss = self.build(iterator, **build_kwargs)
            optim_loss = losses['abs_robust_mean']['no_occlusion']
            train_op = optim.minimize(optim_loss, var_list=tf.trainable_variables(), global_step=global_step)
        else:
            tower_grads = []
            tower_losses = []
            tower_regularizer_losses = []
            with tf.variable_scope(tf.get_variable_scope()):
                for i in range(self.num_gpus):
                    with tf.device('/gpu:%d' % i):
                        with tf.name_scope('tower_{}'.format(i)):
                            losses_, regularizer_loss_ = self.build(iterator, **build_kwargs)
                            optim_loss = losses_['abs_robust_mean']['no_occlusion']

                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()

                            # Use the optimizer passed in, not `self.optim`,
                            # so the method does not depend on `train` having
                            # set that attribute first.
                            grads = optim.compute_gradients(optim_loss, var_list=tf.trainable_variables())
                            tower_grads.append(grads)
                            tower_losses.append(losses_)
                            tower_regularizer_losses.append(regularizer_loss_)

            grads = average_gradients(tower_grads)
            train_op = optim.apply_gradients(grads, global_step=global_step)

            # Average into fresh dicts so the per-tower loss dicts are left
            # untouched.
            losses = {}
            for key, first_tower_group in tower_losses[0].items():
                averaged_group = {}
                for loss_key in first_tower_group:
                    total = first_tower_group[loss_key]
                    for i in range(1, self.num_gpus):
                        total = total + tower_losses[i][key][loss_key]
                    averaged_group[loss_key] = total / self.num_gpus
                losses[key] = averaged_group

            regularizer_loss = 0.
            for tower_regularizer_loss in tower_regularizer_losses:
                regularizer_loss += tower_regularizer_loss
            regularizer_loss /= self.num_gpus

        self.add_loss_summary(losses, keys=losses.keys())
        tf.summary.scalar('regularizer_loss', regularizer_loss)

        return train_op, losses, regularizer_loss
    
    def train(self):
        """Build the training graph and run the optimization loop.

        Creates the dataset, the learning-rate schedule, the Adam optimizer
        and the training op in a fresh graph, optionally restores
        `self.restore_model`, then runs steps 1 .. `self.iter_steps` while
        periodically printing losses, writing summaries and saving
        checkpoints. The session stays available as `self.sess` afterwards;
        the summary writer is closed when the loop ends or fails.
        """
        with tf.Graph().as_default(), tf.device(self.shared_device):
            self.global_step = tf.Variable(0, trainable=False)
            self.dataset, self.iterator = self.create_dataset_and_iterator(training_mode=self.training_mode)
            self.lr_decay = tf.train.exponential_decay(self.initial_learning_rate, self.global_step, decay_steps=self.decay_steps, decay_rate=self.decay_rate, staircase=True)
            tf.summary.scalar('learning_rate', self.lr_decay)
            self.optim = tf.train.AdamOptimizer(self.lr_decay, self.beta1)
            self.train_op, self.losses, self.regularizer_loss = self.create_train_op(optim=self.optim, iterator=self.iterator,
                global_step=self.global_step, regularizer_scale=self.regularizer_scale, train=True, trainable=True, is_scale=self.is_scale, training_mode=self.training_mode)

            merge_summary = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(logdir='/'.join([self.summary_dir, 'train', self.model_name]))
            try:
                self.trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
                # Only trainable variables and the step counter are saved and
                # restored (optimizer slots are not).
                self.saver = tf.train.Saver(var_list=self.trainable_vars + [self.global_step], max_to_keep=500)

                self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=self.allow_soft_placement, log_device_placement=self.log_device_placement))
                self.sess.run(tf.global_variables_initializer())
                self.sess.run(tf.local_variables_initializer())

                if self.is_restore_model:
                    self.saver.restore(self.sess, self.restore_model)

                # The step counter is reset even after a restore, so the
                # learning-rate schedule always starts from step 0.
                self.sess.run(tf.assign(self.global_step, 0))
                start_step = self.sess.run(self.global_step)
                self.sess.run(self.iterator.initializer)
                start_time = time.time()
                for step in range(start_step+1, self.iter_steps+1):
                    _, abs_robust_mean_no_occlusion, census_occlusion = self.sess.run([self.train_op,
                        self.losses['abs_robust_mean']['no_occlusion'], self.losses['census']['occlusion']])
                    if np.mod(step, self.display_log_interval) == 0:
                        print('step: %d time: %.6fs, abs_robust_mean_no_occlusion: %.6f, census_occlusion: %.6f' %
                            (step, time.time() - start_time, abs_robust_mean_no_occlusion, census_occlusion))

                    if np.mod(step, self.write_summary_interval) == 0:
                        # Evaluated in its own session call, i.e. separately
                        # from the training step above.
                        summary_str = self.sess.run(merge_summary)
                        summary_writer.add_summary(summary_str, global_step=step)

                    if np.mod(step, self.save_checkpoint_interval) == 0:
                        self.saver.save(self.sess, '/'.join([self.checkpoint_dir, self.model_name, 'model']), global_step=step,
                                        write_meta_graph=False, write_state=False)
            finally:
                # Flush pending events and release the file handle even if
                # training is interrupted.
                summary_writer.close()
                    
    def test(self, restore_model, save_dir):
        """Run the flow network on the configured data list and save results.

        For every entry of the data list a colour visualization
        (`flow_est_color_<name>.png`) and the raw flow (`flow_est_<name>.flo`)
        are written to `save_dir`; `<name>` is the third column of the data
        list.

        Args:
            restore_model: checkpoint path to restore the variables from.
            save_dir: output directory, created if it does not exist.
        """
        dataset = BasicDataset(data_list_file=self.dataset_config['data_list_file'], img_dir=self.dataset_config['img_dir'])
        save_name_list = dataset.data_list[:, 2]
        iterator = dataset.create_one_shot_iterator(dataset.data_list, num_parallel_calls=self.num_input_threads)
        batch_img1, batch_img2 = iterator.get_next()
        flow_est = pyramid_processing(batch_img1, batch_img2, train=False, trainable=False, regularizer=None, is_scale=True)
        flow_est_color = flow_to_color(flow_est['full_res'], mask=None, max_flow=256)

        restore_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        saver = tf.train.Saver(var_list=restore_vars)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # The context manager closes the session (and frees its resources)
        # once inference finishes or fails.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(iterator.initializer)
            saver.restore(sess, restore_model)
            for i in range(dataset.data_num):
                np_flow_est, np_flow_est_color = sess.run([flow_est['full_res'], flow_est_color])
                # Index 0: only the first element of each fetched batch is saved.
                misc.imsave('%s/flow_est_color_%s.png' % (save_dir, save_name_list[i]), np_flow_est_color[0])
                write_flo('%s/flow_est_%s.flo' % (save_dir, save_name_list[i]), np_flow_est[0])
                print('Finish %d/%d' % (i, dataset.data_num))
    
    def generate_fake_flow_occlusion(self, restore_model, save_dir):
        """Predict flow and occlusion maps and store them as 16-bit PNGs.

        For every entry of the data list the forward and backward predictions
        are written to `flow_occ_fw_<name>.png` and `flow_occ_bw_<name>.png`
        in `save_dir`. Each image holds the two flow channels encoded as
        `flow * 64 + 32768` followed by the occlusion map, cast to uint16.

        Args:
            restore_model: checkpoint path to restore the variables from.
            save_dir: output directory, created if it does not exist.
        """
        dataset = BasicDataset(data_list_file=self.dataset_config['data_list_file'], img_dir=self.dataset_config['img_dir'])
        save_name_list = dataset.data_list[:, 2]
        iterator = dataset.create_one_shot_iterator(dataset.data_list, num_parallel_calls=self.num_input_threads)
        batch_img1, batch_img2 = iterator.get_next()
        flow_fw, flow_bw = pyramid_processing_bidirection(batch_img1, batch_img2,
            train=False, trainable=False, reuse=None, regularizer=None, is_scale=True)
        occ_fw, occ_bw = occlusion(flow_fw['full_res'], flow_bw['full_res'])

        def _encode(flow_full_res, occ):
            # Scale and offset the flow so it fits an unsigned 16-bit image,
            # then append the occlusion map as last channel.
            encoded_flow = flow_full_res * 64. + 32768
            return tf.cast(tf.concat([encoded_flow, occ], -1), tf.uint16)

        flow_occ_fw = _encode(flow_fw['full_res'], occ_fw)
        flow_occ_bw = _encode(flow_bw['full_res'], occ_bw)

        restore_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        saver = tf.train.Saver(var_list=restore_vars)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # The context manager closes the session (and frees its resources)
        # once generation finishes or fails.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(iterator.initializer)
            saver.restore(sess, restore_model)
            for i in range(dataset.data_num):
                # Only the tensors that are written to disk are fetched.
                np_flow_occ_fw, np_flow_occ_bw = sess.run([flow_occ_fw, flow_occ_bw])

                # opencv read and save image as bgr format, here we change rgb to bgr
                np_flow_occ_fw = rgb_bgr(np_flow_occ_fw[0])
                np_flow_occ_bw = rgb_bgr(np_flow_occ_bw[0])
                np_flow_occ_fw = np_flow_occ_fw.astype(np.uint16)
                np_flow_occ_bw = np_flow_occ_bw.astype(np.uint16)

                cv2.imwrite('%s/flow_occ_fw_%s.png' % (save_dir, save_name_list[i]), np_flow_occ_fw)
                cv2.imwrite('%s/flow_occ_bw_%s.png' % (save_dir, save_name_list[i]), np_flow_occ_bw)
                print('Finish %d/%d' % (i, dataset.data_num))