##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: chenyuru
## This source code is licensed under the MIT-style license
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import sys
import time
import datetime
import numpy as np
import scipy.io
import shutil
from tensorboardX import SummaryWriter
from trainers import Trainer, DataPrefetcher
from utils import predict_multi_scale, predict_whole_img, compute_errors, \
                    display_figure, colored_depthmap, merge_images, measure_list
import torch
from torch.nn import DataParallel
import matplotlib.pyplot as plt
from tqdm import tqdm
from copy import deepcopy
import json

class DepthEstimationTrainer(Trainer):
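    """Trainer for depth estimation models: training, validation, and testing loops."""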
    def __init__(self, params, net, datasets, criterion, optimizer, scheduler, 
                 sets=['train', 'val', 'test'], verbose=100, stat=False, 
                 eval_func=compute_errors, 
                 disp_func=display_figure):
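        """
        Build a depth-estimation trainer on top of the generic Trainer.

        Sets up the working/log/result directories, dumps the run's
        hyper-parameters to JSON in train mode, and optionally prints the
        model complexity when `stat` is True.
        """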
        self.time = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        self.params = params
        self.verbose = verbose
        self.eval_func = eval_func
        self.disp_func = disp_func
        # Init directories; fail early instead of hitting a NameError below
        if params.workdir is not None:
            workdir = os.path.expanduser(params.workdir)
        else:
            raise ValueError('params.workdir must be specified')
        if params.logdir is None:
            logdir = os.path.join(workdir, 'log_{}_{}'.format(params.encoder+params.decoder, params.dataset))
        else:
            logdir = os.path.join(workdir, params.logdir)
        resdir = None
        if self.params.mode == 'test':
            if params.resdir is None:
                resdir = os.path.join(logdir, 'res')
            else:
                resdir = os.path.join(logdir, params.resdir)
        # Call the constructor of the parent class (Trainer)
        super().__init__(net, datasets, optimizer, scheduler, criterion,
                         batch_size=params.batch, batch_size_val=params.batch_val,
                         max_epochs=params.epochs, threads=params.threads, eval_freq=params.eval_freq,
                         use_gpu=params.gpu, resume=params.resume, mode=params.mode,
                         sets=sets, workdir=workdir, logdir=logdir, resdir=resdir)
        self.params.logdir = self.logdir
        self.params.resdir = self.resdir
        # Dump the run's hyper-parameters to JSON for reproducibility
        if self.params.mode == 'train':
            with open(os.path.join(self.logdir, 'params_{}.json'.format(self.time)), 'w') as f:
                json.dump(vars(self.params), f)
        # When stat is True, print the model complexity and exit
        if stat:
            from torchstat import stat as model_stat  # avoid shadowing the `stat` argument
            input_size = self.datasets[sets[0]].input_size
            model_stat(self.net, (3, *input_size))
            # Optional: draw the model graph with tensorwatch, e.g.
            #   import tensorwatch as tw
            #   tw.draw_model(self.net, (1, 3, *input_size))
            exit()
        self.print('###### Experiment Parameters ######')
        for k, v in vars(self.params).items():
            self.print('{0:<22s} : {1:}'.format(k, v))
        
    def train(self):
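        """Run the full training loop, evaluating every `eval_freq` epochs and
        copying the best checkpoint to `best.pkl` at the end."""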
        torch.backends.cudnn.benchmark = True
        if self.logdir:
            self.writer = SummaryWriter(self.logdir)
        else:
            raise Exception("Log dir doesn't exist!")
        # Calculate the total number of training steps
        self.n_train = len(self.trainset)
        self.steps_per_epoch = int(np.ceil(self.n_train / self.batch_size))
        self.verbose = min(self.verbose, self.steps_per_epoch)
        self.n_steps = self.max_epochs * self.steps_per_epoch
        self.print("{0:<22s} : {1:} ".format('trainset sample', self.n_train))
        # Estimate the model's parameter count and memory footprint (float32)
        para = sum(p.numel() for p in self.net.parameters())
        memory = para * 4 / (1024 ** 2)
        self.print('Model {} : params: {:,}, Memory {:.3f}MB'.format(type(self.net).__name__, para, memory))
        # GO!!!!!!!!!
        start_time = time.time()
        self.train_total_time = 0
        self.time_sofar = 0
        for epoch in range(self.start_epoch, self.max_epochs + 1):
            # Train one epoch
            total_loss = self.train_epoch(epoch)
            torch.cuda.empty_cache()
            # Decay learning rate; 'plateau' (ReduceLROnPlateau) requires a
            # metric, so feed it the epoch's training loss
            if self.params.scheduler == 'plateau':
                self.scheduler.step(total_loss)
            elif self.params.scheduler == 'step':
                self.scheduler.step()
            # Evaluate the model
            if self.eval_freq and epoch % self.eval_freq == 0:
                measures = self.eval(epoch)
                torch.cuda.empty_cache()
                for k in sorted(list(measures.keys())):
                    self.writer.add_scalar(k, measures[k], epoch)
        self.print("Finished training! Best epoch {} best acc {:.4f}".format(self.best_epoch, self.best_acc))
        self.print("Spend time: {:.2f}h".format((time.time() - start_time) / 3600))
        # Keep a copy of the best checkpoint under a fixed name
        net_type = type(self.net).__name__
        best_pkl = os.path.join(self.logdir, '{}_{:03d}.pkl'.format(net_type, self.best_epoch))
        best_copy = os.path.join(self.logdir, 'best.pkl')
        shutil.copyfile(best_pkl, best_copy)
        return

    def train_epoch(self, epoch):
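        """Train for a single epoch; returns the last batch's total loss as a float."""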
        self.net.train()
        device = torch.device('cuda:0' if self.use_gpu else 'cpu')
        self.net.to(device)
        self.criterion.to(device)
        # Iterate over data.
        prefetcher = DataPrefetcher(self.trainloader)
        data = prefetcher.next()
        step = 0
        while data is not None:
            images, labels = data[0].to(device), data[1].to(device)
            before_op_time = time.time()
            self.optimizer.zero_grad()
            output = self.net(images)
            loss1, loss2, loss3, total_loss = self.criterion(output, labels, epoch)
            total_loss.backward()
            self.optimizer.step()
            fps = images.shape[0] / (time.time() - before_op_time)
            time_sofar = self.train_total_time / 3600
            # Estimate remaining time from the average step time so far;
            # guard against division by zero on the very first step
            time_left = (self.n_steps / max(self.global_step, 1) - 1.0) * time_sofar
            lr = self.optimizer.param_groups[0]['lr']
            if self.verbose > 0 and (step + 1) % (self.steps_per_epoch // self.verbose) == 0:
                print_str = 'Epoch[{:>2}/{:>2}] | Step[{:>4}/{:>4}] | fps {:4.2f} | Loss1 {:7.3f} | Loss2 {:7.3f} | Loss3 {:7.3f} | elapsed {:.2f}h | left {:.2f}h | lr {:.3e}'. \
                        format(epoch, self.max_epochs, step + 1, self.steps_per_epoch, fps, loss1, loss2, loss3, time_sofar, time_left, lr)
                if self.params.classifier == 'OHEM':
                    ratio = self.criterion.AppearanceLoss.ohem_ratio
                    print_str += ' | OHEM {:.4f}'.format(ratio)
                    self.writer.add_scalar('OHEM', ratio, self.global_step)
                self.print(print_str)
            self.writer.add_scalar('loss1', loss1.item(), self.global_step)
            self.writer.add_scalar('loss2', loss2.item(), self.global_step)
            self.writer.add_scalar('loss3', loss3.item(), self.global_step)
            self.writer.add_scalar('total_loss', total_loss.item(), self.global_step)
            self.writer.add_scalar('lr', lr, self.global_step)
            # Decay Learning Rate
            if self.params.scheduler == 'poly':
                self.scheduler.step()
            self.global_step += 1
            self.train_total_time += time.time() - before_op_time
            data = prefetcher.next()
            step += 1
        # Return a plain float so no autograd graph is kept alive across epochs
        return total_loss.item()

    def eval(self, epoch):
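        """Evaluate on the validation set and save a checkpoint keyed on the 'a1' accuracy."""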
        torch.backends.cudnn.benchmark = True
        self.n_val = len(self.valset)
        self.print("{0:<22s} : {1:} ".format('valset sample', self.n_val))
        self.print("<-------------Evaluate the model-------------->")
        # Evaluate one epoch
        measures, fps = self.eval_epoch(epoch)
        acc = measures['a1']
        self.print('Epoch {}, fps {:4.2f} | {}'.format(epoch, fps, measures))
        # Save the checkpoint
        self.save(epoch, acc)
        return measures

    def eval_epoch(self, epoch):
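        """Run one pass over the validation set; returns averaged error measures
        and the inference fps, and logs one randomly chosen batch's visuals."""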
        device = torch.device('cuda:0' if self.use_gpu else 'cpu')
        self.net.to(device)
        self.criterion.to(device)
        self.net.eval()
        val_total_time = 0
        #measure_list = ['a1', 'a2', 'a3', 'rmse', 'rmse_log', 'log10', 'abs_rel', 'sq_rel']
        measures = {key: 0 for key in measure_list}
        with torch.no_grad():
            sys.stdout.flush()
            tbar = tqdm(self.valloader)
            rand = np.random.randint(len(self.valloader))
            for step, data in enumerate(tbar):
                images, labels = data[0].to(device), data[1].to(device)
                # forward
                before_op_time = time.time()
                y = self.net(images)
                depths = self.net.inference(y)
                duration = time.time() - before_op_time
                val_total_time += duration
                # accuracy
                new = self.eval_func(labels, depths)
                for k, v in new.items():
                    measures[k] += v.item()
                # display images
                if step == rand and self.disp_func is not None:
                    visuals = {'inputs': images, 'sim_map': y['sim_map'], 'labels': labels, 'depths': depths}
                    self.disp_func(self.writer, visuals, epoch)
                print_str = 'Test step [{}/{}].'.format(step + 1, len(self.valloader))
                tbar.set_description(print_str)
        fps = self.n_val / val_total_time
        measures = {key: round(value/self.n_val, 5) for key, value in measures.items()}
        return measures, fps

    def test(self):
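        """Run single- or multi-scale inference on the test set, save colored
        depth maps to `resdir`, and report averaged error measures."""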
        n_test = len(self.testset)
        device = torch.device('cuda:0' if self.use_gpu else 'cpu')
        self.net.to(device)
        self.net.eval()
        self.print("<-------------Test the model-------------->")
        colormaps = {'nyu': plt.cm.jet, 'kitti': plt.cm.plasma}
        cm = colormaps[self.params.dataset]
        #measure_list = ['a1', 'a2', 'a3', 'rmse', 'rmse_log', 'log10', 'abs_rel', 'sq_rel']
        measures = {key: 0 for key in measure_list}
        test_total_time = 0
        with torch.no_grad():
            for step, data in enumerate(self.testloader):
                images, labels = data[0].to(device), data[1].to(device)
                before_op_time = time.time()
                # Optionally average predictions over multiple scales
                scales = [1, 1.25] if self.params.use_ms else [1]
                depths = predict_multi_scale(self.net, images, scales, self.params.classes, self.params.use_flip)
                duration = time.time() - before_op_time
                test_total_time += duration
                # accuracy
                new = self.eval_func(labels, depths)
                print_str = "Test step [{}/{}], a1: {:.5f}, rmse: {:.5f}.".format(step + 1, n_test, new['a1'], new['rmse'])
                self.print(print_str)
                images = images.cpu().numpy().squeeze().transpose(1, 2, 0)
                labels = labels.cpu().numpy().squeeze()
                depths = depths.cpu().numpy().squeeze()
                labels = colored_depthmap(labels, cmap=cm).squeeze()
                depths = colored_depthmap(depths, cmap=cm).squeeze()
                plt.imsave(os.path.join(self.resdir, '{:04}_rgb.png'.format(step)), images)
                plt.imsave(os.path.join(self.resdir, '{:04}_gt.png'.format(step)), labels)
                plt.imsave(os.path.join(self.resdir, '{:04}_depth.png'.format(step)), depths)
                for k, v in new.items():
                    measures[k] += v.item()
        fps = n_test / test_total_time
        measures = {key: round(value / n_test, 5) for key, value in measures.items()}
        self.print('Testing done, fps {:4.2f} | {}'.format(fps, measures))
        return measures


if __name__ == '__main__':
    pass
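    # Hypothetical usage sketch (the actual entry script builds `params` via
    # argparse, plus the net/datasets/criterion/optimizer/scheduler objects):
    #
    #   trainer = DepthEstimationTrainer(params, net, datasets, criterion,
    #                                    optimizer, scheduler)
    #   if params.mode == 'train':
    #       trainer.train()
    #   elif params.mode == 'test':
    #       trainer.test()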