"""
Copyright (C) 2017, 申瑞珉 (Ruimin Shen)

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""

import sys
import argparse
import configparser
import logging
import logging.config
import collections
import multiprocessing
import os
import shutil
import io
import hashlib
import subprocess
import pickle
import traceback
import yaml

import numpy as np
import torch.autograd
import torch.cuda
import torch.nn as nn
import torch.optim
import torch.utils.data
import torch.nn.functional as F
import torchvision.transforms
import tqdm
import humanize
import pybenchmark
import filelock
from tensorboardX import SummaryWriter

import model
import transform.augmentation
import utils.data
import utils.postprocess
import utils.train
import utils.visualize
import eval as _eval


def norm_data(data, height, width, rows, cols, keys='yx_min, yx_max'):
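    """Return a copy of data in which the bounding box keys (by default
    yx_min and yx_max) are rescaled from image pixel coordinates
    (height x width) to feature map coordinates (rows x cols)."""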
    _data = {key: data[key] for key in data}
    scale = utils.ensure_device(torch.from_numpy(np.reshape(np.array([rows / height, cols / width], dtype=np.float32), [1, 1, 2])))
    for key in keys.split(', '):
        _data[key] = _data[key] * scale
    return _data


def ensure_model(model):
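    """Move the model to GPU when CUDA is available and wrap it in
    nn.DataParallel when more than one GPU is present."""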
    if torch.cuda.is_available():
        model.cuda()
        if torch.cuda.device_count() > 1:
            logging.info('%d GPUs are used' % torch.cuda.device_count())
            model = nn.DataParallel(model).cuda()
    return model


class SummaryWorker(multiprocessing.Process):
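    """Background process that writes TensorBoard summaries.

    The training process enqueues (name, kwargs) requests through __call__;
    the worker consumes them and writes scalars, images and histograms via
    tensorboardX, so summary generation does not block the training loop.
    Each summary kind is rate limited by its own timer configured in the
    [summary] section.
    """
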
    def __init__(self, env):
        super(SummaryWorker, self).__init__()
        self.env = env
        self.config = env.config
        self.queue = multiprocessing.Queue()
        try:
            self.timer_scalar = utils.train.Timer(env.config.getfloat('summary', 'scalar'))
        except configparser.NoOptionError:
            self.timer_scalar = lambda: False
        try:
            self.timer_image = utils.train.Timer(env.config.getfloat('summary', 'image'))
        except configparser.NoOptionError:
            self.timer_image = lambda: False
        try:
            self.timer_histogram = utils.train.Timer(env.config.getfloat('summary', 'histogram'))
        except configparser.NoOptionError:
            self.timer_histogram = lambda: False
        with open(os.path.expanduser(os.path.expandvars(env.config.get('summary_histogram', 'parameters'))), 'r') as f:
            self.histogram_parameters = utils.RegexList([line.rstrip() for line in f])
        self.draw_bbox = utils.visualize.DrawBBox(env.category)
        self.draw_feature = utils.visualize.DrawFeature()

    def __call__(self, name, **kwargs):
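        """Enqueue a summary request of the given kind if its timer has
        expired; copy_<name> clones the tensors to CPU objects so the payload
        can be pickled across the process boundary."""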
        if getattr(self, 'timer_' + name)():
            kwargs = getattr(self, 'copy_' + name)(**kwargs)
            self.queue.put((name, kwargs))

    def stop(self):
        self.queue.put((None, {}))

    def run(self):
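        """Worker loop: create the SummaryWriter, attempt to log the model
        graph once, then serve queued (name, kwargs) requests until the None
        sentinel sent by stop() arrives."""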
        self.writer = SummaryWriter(os.path.join(self.env.model_dir, self.env.args.run))
        try:
            height, width = tuple(map(int, self.config.get('image', 'size').split()))
            tensor = torch.randn(1, 3, height, width)
            step, epoch, dnn = self.env.load()
            self.writer.add_graph(dnn, (torch.autograd.Variable(tensor),))
        except Exception:
            traceback.print_exc()
        while True:
            name, kwargs = self.queue.get()
            if name is None:
                break
            func = getattr(self, 'summary_' + name)
            try:
                func(**kwargs)
            except Exception:
                traceback.print_exc()

    def copy_scalar(self, **kwargs):
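        """Clone the loss tensors to CPU numpy arrays so they can be pickled
        into the worker process."""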
        step, loss_total, loss, loss_hparam = (kwargs[key] for key in 'step, loss_total, loss, loss_hparam'.split(', '))
        loss_total = loss_total.data.clone().cpu().numpy()
        loss = {key: l.data.clone().cpu().numpy() for key, l in loss.items()}
        loss_hparam = {key: l.data.clone().cpu().numpy() for key, l in loss_hparam.items()}
        return dict(
            step=step,
            loss_total=loss_total,
            loss=loss, loss_hparam=loss_hparam,
        )

    def summary_scalar(self, **kwargs):
        step, loss_total, loss, loss_hparam = (kwargs[key] for key in 'step, loss_total, loss, loss_hparam'.split(', '))
        for key, l in loss.items():
            self.writer.add_scalar('loss/' + key, l[0], step)
        if self.config.getboolean('summary_scalar', 'loss_hparam'):
            self.writer.add_scalars('loss_hparam', {key: l[0] for key, l in loss_hparam.items()}, step)
        self.writer.add_scalar('loss_total', loss_total[0], step)

    def copy_image(self, **kwargs):
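        """Copy the image batch, the ground truth boxes, the predictions and
        the positive/negative matching mask to CPU numpy arrays for the
        worker process."""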
        step, height, width, rows, cols, data, pred, debug = (kwargs[key] for key in 'step, height, width, rows, cols, data, pred, debug'.split(', '))
        data = {key: data[key].clone().cpu().numpy() for key in 'image, yx_min, yx_max, cls'.split(', ')}
        pred = {key: pred[key].data.clone().cpu().numpy() for key in 'yx_min, yx_max, iou, logits'.split(', ') if key in pred}
        matching = (debug['positive'].float() - debug['negative'].float() + 1) / 2
        matching = matching.data.clone().cpu().numpy()
        return dict(
            step=step, height=height, width=width, rows=rows, cols=cols,
            data=data, pred=pred,
            matching=matching,
        )

    def summary_image(self, **kwargs):
        step, height, width, rows, cols, data, pred, matching = (kwargs[key] for key in 'step, height, width, rows, cols, data, pred, matching'.split(', '))
        image = data['image']
        limit = min(self.config.getint('summary_image', 'limit'), image.shape[0])
        image = image[:limit, :, :, :]
        yx_min, yx_max, iou = (pred[key] for key in 'yx_min, yx_max, iou'.split(', '))
        scale = [height / rows, width / cols]
        yx_min, yx_max = (a * scale for a in (yx_min, yx_max))
        if 'logits' in pred:
            cls = np.argmax(F.softmax(torch.autograd.Variable(torch.from_numpy(pred['logits'])), -1).data.cpu().numpy(), -1)
        else:
            cls = np.zeros(iou.shape, int)
        if self.config.getboolean('summary_image', 'bbox'):
            # data
            canvas = np.copy(image)
            canvas = pybenchmark.profile('bbox/data')(self.draw_bbox_data)(canvas, *(data[key] for key in 'yx_min, yx_max, cls'.split(', ')))
            self.writer.add_image('bbox/data', torchvision.utils.make_grid(torch.from_numpy(np.stack(canvas)).permute(0, 3, 1, 2).float(), normalize=True, scale_each=True), step)
            # pred
            canvas = np.copy(image)
            canvas = pybenchmark.profile('bbox/pred')(self.draw_bbox_pred)(canvas, yx_min, yx_max, cls, iou, nms=True)
            self.writer.add_image('bbox/pred', torchvision.utils.make_grid(torch.from_numpy(np.stack(canvas)).permute(0, 3, 1, 2).float(), normalize=True, scale_each=True), step)
        if self.config.getboolean('summary_image', 'iou'):
            # bbox
            canvas = np.copy(image)
            canvas_data = self.draw_bbox_data(canvas, *(data[key] for key in 'yx_min, yx_max, cls'.split(', ')), colors=['g'])
            # data
            for i, canvas in enumerate(pybenchmark.profile('iou/data')(self.draw_bbox_iou)(list(map(np.copy, canvas_data)), yx_min, yx_max, cls, matching, rows, cols, colors=['w'])):
                canvas = np.stack(canvas)
                canvas = torch.from_numpy(canvas).permute(0, 3, 1, 2)
                canvas = torchvision.utils.make_grid(canvas.float(), normalize=True, scale_each=True)
                self.writer.add_image('iou/data%d' % i, canvas, step)
            # pred
            for i, canvas in enumerate(pybenchmark.profile('iou/pred')(self.draw_bbox_iou)(list(map(np.copy, canvas_data)), yx_min, yx_max, cls, iou, rows, cols, colors=['w'])):
                canvas = np.stack(canvas)
                canvas = torch.from_numpy(canvas).permute(0, 3, 1, 2)
                canvas = torchvision.utils.make_grid(canvas.float(), normalize=True, scale_each=True)
                self.writer.add_image('iou/pred%d' % i, canvas, step)

    def draw_bbox_data(self, canvas, yx_min, yx_max, cls, colors=None):
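        """Draw the ground truth boxes for each image in the batch; cls may be
        given as one-hot vectors (argmax is taken) or as class indices."""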
        batch_size = len(canvas)
        if len(cls.shape) == len(yx_min.shape):
            cls = np.argmax(cls, -1)
        yx_min, yx_max, cls = ([a[b] for b in range(batch_size)] for a in (yx_min, yx_max, cls))
        return [self.draw_bbox(canvas, yx_min.astype(int), yx_max.astype(int), cls, colors=colors) for canvas, yx_min, yx_max, cls in zip(canvas, yx_min, yx_max, cls)]

    def draw_bbox_pred(self, canvas, yx_min, yx_max, cls, iou, colors=None, nms=False):
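        """Draw the predicted boxes whose IoU (objectness) score exceeds the
        detection threshold, optionally suppressing overlaps with NMS."""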
        batch_size = len(canvas)
        mask = iou > self.config.getfloat('detect', 'threshold')
        yx_min, yx_max = (np.reshape(a, [a.shape[0], -1, 2]) for a in (yx_min, yx_max))
        cls, iou, mask = (np.reshape(a, [a.shape[0], -1]) for a in (cls, iou, mask))
        yx_min, yx_max, cls, iou, mask = ([a[b] for b in range(batch_size)] for a in (yx_min, yx_max, cls, iou, mask))
        yx_min, yx_max, cls, iou = ([a[m] for a, m in zip(l, mask)] for l in (yx_min, yx_max, cls, iou))
        if nms:
            overlap = self.config.getfloat('detect', 'overlap')
            keep = [pybenchmark.profile('nms')(utils.postprocess.nms)(torch.Tensor(iou), torch.Tensor(yx_min), torch.Tensor(yx_max), overlap) if iou.shape[0] > 0 else [] for yx_min, yx_max, iou in zip(yx_min, yx_max, iou)]
            keep = [np.array(k, int) for k in keep]
            yx_min, yx_max, cls = ([a[k] for a, k in zip(l, keep)] for l in (yx_min, yx_max, cls))
        return [self.draw_bbox(canvas, yx_min.astype(int), yx_max.astype(int), cls, colors=colors) for canvas, yx_min, yx_max, cls in zip(canvas, yx_min, yx_max, cls)]

    def draw_bbox_iou(self, canvas_share, yx_min, yx_max, cls, iou, rows, cols, colors=None):
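        """For each prediction slice (one per anchor), draw the boxes above
        the detection threshold on a copy of the shared canvases and overlay
        the corresponding score map, reshaped to rows x cols, as a heatmap;
        return one list of canvases per slice."""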
        batch_size = len(canvas_share)
        yx_min, yx_max = ([np.squeeze(a, -2) for a in np.split(a, a.shape[-2], -2)] for a in (yx_min, yx_max))
        cls, iou = ([np.squeeze(a, -1) for a in np.split(a, a.shape[-1], -1)] for a in (cls, iou))
        results = []
        for i, (yx_min, yx_max, cls, iou) in enumerate(zip(yx_min, yx_max, cls, iou)):
            mask = iou > self.config.getfloat('detect', 'threshold')
            yx_min, yx_max = (np.reshape(a, [a.shape[0], -1, 2]) for a in (yx_min, yx_max))
            cls, iou, mask = (np.reshape(a, [a.shape[0], -1]) for a in (cls, iou, mask))
            yx_min, yx_max, cls, iou, mask = ([a[b] for b in range(batch_size)] for a in (yx_min, yx_max, cls, iou, mask))
            yx_min, yx_max, cls = ([a[m] for a, m in zip(l, mask)] for l in (yx_min, yx_max, cls))
            canvas = [self.draw_bbox(canvas, yx_min.astype(int), yx_max.astype(int), cls, colors=colors) for canvas, yx_min, yx_max, cls in zip(np.copy(canvas_share), yx_min, yx_max, cls)]
            iou = [np.reshape(a, [rows, cols]) for a in iou]
            canvas = [self.draw_feature(_canvas, iou) for _canvas, iou in zip(canvas, iou)]
            results.append(canvas)
        return results

    def copy_histogram(self, **kwargs):
        return {
            'step': kwargs['step'],
            'state_dict': self.env.dnn.state_dict(),
        }

    def summary_histogram(self, **kwargs):
        step, state_dict = (kwargs[key] for key in 'step, state_dict'.split(', '))
        for name, var in state_dict.items():
            if self.histogram_parameters(name):
                self.writer.add_histogram(name, var, step)


class Train(object):
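    """Drive the training loop: build the model, data loader and optimizer,
    periodically save checkpoints and run evaluation, and forward summary
    requests to a background SummaryWorker.
    """
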
    def __init__(self, args, config):
        self.args = args
        self.config = config
        self.model_dir = utils.get_model_dir(config)
        self.cache_dir = utils.get_cache_dir(config)
        self.category = utils.get_category(config, self.cache_dir)
        self.anchors = torch.from_numpy(utils.get_anchors(config)).contiguous()
        logging.info('using cache directory ' + self.cache_dir)
        logging.info('tensorboard --logdir ' + self.model_dir)
        if args.delete:
            logging.warning('delete model directory: ' + self.model_dir)
            shutil.rmtree(self.model_dir, ignore_errors=True)
        os.makedirs(self.model_dir, exist_ok=True)
        with open(self.model_dir + '.ini', 'w') as f:
            config.write(f)

        self.step, self.epoch, self.dnn = self.load()
        self.inference = model.Inference(self.config, self.dnn, self.anchors)
        logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in self.inference.state_dict().values())))
        if self.args.finetune:
            path = os.path.expanduser(os.path.expandvars(self.args.finetune))
            logging.info('finetune from ' + path)
            self.finetune(self.dnn, path)
        self.inference = ensure_model(self.inference)
        self.inference.train()
        self.optimizer = eval(self.config.get('train', 'optimizer'))(filter(lambda p: p.requires_grad, self.inference.parameters()), self.args.learning_rate)

        self.saver = utils.train.Saver(self.model_dir, config.getint('save', 'keep'))
        self.timer_save = utils.train.Timer(config.getfloat('save', 'secs'), False)
        try:
            self.timer_eval = utils.train.Timer(eval(config.get('eval', 'secs')), config.getboolean('eval', 'first'))
        except configparser.NoOptionError:
            self.timer_eval = lambda: False
        self.summary_worker = SummaryWorker(self)
        self.summary_worker.start()

    def stop(self):
        self.summary_worker.stop()
        self.summary_worker.join()

    def get_loader(self):
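        """Build the training DataLoader from the cached dataset phases, with
        augmentation and resize/collate transforms; the batch size (and, when
        configured, the worker count) is scaled by the number of GPUs."""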
        paths = [os.path.join(self.cache_dir, phase + '.pkl') for phase in self.config.get('train', 'phase').split()]
        dataset = utils.data.Dataset(
            utils.data.load_pickles(paths),
            transform=transform.augmentation.get_transform(self.config, self.config.get('transform', 'augmentation').split()),
            one_hot=None if self.config.getboolean('train', 'cross_entropy') else len(self.category),
            shuffle=self.config.getboolean('data', 'shuffle'),
            dir=os.path.join(self.model_dir, 'exception'),
        )
        logging.info('num_examples=%d' % len(dataset))
        try:
            workers = self.config.getint('data', 'workers')
            if torch.cuda.is_available():
                workers = workers * torch.cuda.device_count()
        except configparser.NoOptionError:
            workers = multiprocessing.cpu_count()
        collate_fn = utils.data.Collate(
            transform.parse_transform(self.config, self.config.get('transform', 'resize_train')),
            utils.train.load_sizes(self.config),
            maintain=self.config.getint('data', 'maintain'),
            transform_image=transform.get_transform(self.config, self.config.get('transform', 'image_train').split()),
            transform_tensor=transform.get_transform(self.config, self.config.get('transform', 'tensor').split()),
            dir=os.path.join(self.model_dir, 'exception'),
        )
        return torch.utils.data.DataLoader(dataset, batch_size=self.args.batch_size * torch.cuda.device_count() if torch.cuda.is_available() else self.args.batch_size, shuffle=True, num_workers=workers, collate_fn=collate_fn, pin_memory=torch.cuda.is_available())

    def load(self):
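        """Restore the latest checkpoint from the model directory if one
        exists, otherwise build the model from scratch; return
        (step, epoch, dnn)."""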
        try:
            path, step, epoch = utils.train.load_model(self.model_dir)
            state_dict = torch.load(path, map_location=lambda storage, loc: storage)
            config_channels = model.ConfigChannels(self.config, state_dict)
        except (FileNotFoundError, ValueError):
            step, epoch = 0, 0
            config_channels = model.ConfigChannels(self.config)
        dnn = utils.parse_attr(self.config.get('model', 'dnn'))(config_channels, self.anchors, len(self.category))
        if config_channels.state_dict is not None:
            dnn.load_state_dict(config_channels.state_dict)
        return step, epoch, dnn

    def finetune(self, model, path):
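        """Load weights from a checkpoint file (or the latest checkpoint in a
        directory) into the model, skipping parameters matched by the --ignore
        regexes and warning about keys missing from the checkpoint."""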
        if os.path.isdir(path):
            path, _step, _epoch = utils.train.load_model(path)
        _state_dict = torch.load(path, map_location=lambda storage, loc: storage)
        state_dict = model.state_dict()
        ignore = utils.RegexList(self.args.ignore)
        for key, value in state_dict.items():
            try:
                if not ignore(key):
                    state_dict[key] = _state_dict[key]
            except KeyError:
                logging.warning('%s not in finetune file %s' % (key, path))
        model.load_state_dict(state_dict)

    def iterate(self, data):
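        """Run a single optimization step: forward pass, loss computation,
        backward pass, optional gradient clipping and optimizer update;
        return the tensors needed for summaries."""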
        for key in data:
            t = data[key]
            if torch.is_tensor(t):
                data[key] = utils.ensure_device(t)
        tensor = torch.autograd.Variable(data['tensor'])
        pred = pybenchmark.profile('inference')(model._inference)(self.inference, tensor)
        height, width = data['image'].size()[1:3]
        rows, cols = pred['feature'].size()[-2:]
        loss, debug = pybenchmark.profile('loss')(model.loss)(self.anchors, norm_data(data, height, width, rows, cols), pred, self.config.getfloat('model', 'threshold'))
        loss_hparam = {key: loss[key] * self.config.getfloat('hparam', key) for key in loss}
        loss_total = sum(loss_hparam.values())
        self.optimizer.zero_grad()
        loss_total.backward()
        try:
            clip = self.config.getfloat('train', 'clip')
            nn.utils.clip_grad_norm(self.inference.parameters(), clip)
        except configparser.NoOptionError:
            pass
        self.optimizer.step()
        return dict(
            height=height, width=width, rows=rows, cols=cols,
            data=data, pred=pred, debug=debug,
            loss_total=loss_total, loss=loss, loss_hparam=loss_hparam,
        )

    def __call__(self):
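        """Run the training loop under a file lock: iterate over epochs and
        batches, emit summaries, save checkpoints and evaluate on their
        timers, save on completion or interrupt, and dump the current batch
        before re-raising unexpected errors."""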
        with filelock.FileLock(os.path.join(self.model_dir, 'lock'), 0):
            try:
                try:
                    scheduler = eval(self.config.get('train', 'scheduler'))(self.optimizer)
                except configparser.NoOptionError:
                    scheduler = None
                loader = self.get_loader()
                logging.info('num_workers=%d' % loader.num_workers)
                step = self.step
                for epoch in range(0 if self.epoch is None else self.epoch, self.args.epoch):
                    if scheduler is not None:
                        scheduler.step(epoch)
                        logging.info('epoch=%d, lr=%s' % (epoch, str(scheduler.get_lr())))
                    for data in loader if self.args.quiet else tqdm.tqdm(loader, desc='epoch=%d/%d' % (epoch, self.args.epoch)):
                        kwargs = self.iterate(data)
                        step += 1
                        kwargs = {**kwargs, **dict(step=step, epoch=epoch)}
                        self.summary_worker('scalar', **kwargs)
                        self.summary_worker('image', **kwargs)
                        self.summary_worker('histogram', **kwargs)
                        if self.timer_save():
                            self.save(**kwargs)
                        if self.timer_eval():
                            self.eval(**kwargs)
                self.save(**kwargs)
                logging.info('finished')
            except KeyboardInterrupt:
                logging.warning('interrupted')
                self.save(**kwargs)
            except:
                traceback.print_exc()
                try:
                    with open(os.path.join(self.model_dir, 'data.pkl'), 'wb') as f:
                        pickle.dump(data, f)
                except UnboundLocalError:
                    pass
                raise
            finally:
                self.stop()

    def check_nan(self, **kwargs):
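        """If the total loss is NaN, dump the model weights and the offending
        batch into a step-numbered directory and raise OverflowError."""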
        step, loss_total, loss, data = (kwargs[key] for key in 'step, loss_total, loss, data'.split(', '))
        if np.isnan(loss_total.data.cpu()[0]):
            dump_dir = os.path.join(self.model_dir, str(step))
            os.makedirs(dump_dir, exist_ok=True)
            torch.save(collections.OrderedDict([(key, var.cpu()) for key, var in self.dnn.state_dict().items()]), os.path.join(dump_dir, 'model.pth'))
            torch.save(data, os.path.join(dump_dir, 'data.pth'))
            for key, l in loss.items():
                logging.warning('%s=%f' % (key, l.data.cpu()[0]))
            raise OverflowError('NaN loss detected, dump runtime information into ' + dump_dir)

    def save(self, **kwargs):
        step, epoch = (kwargs[key] for key in 'step, epoch'.split(', '))
        self.check_nan(**kwargs)
        self.saver(collections.OrderedDict([(key, var.cpu()) for key, var in self.dnn.state_dict().items()]), step, epoch)

    def eval(self, **kwargs):
        logging.info('evaluating')
        if torch.cuda.is_available():
            self.inference.cpu()
        try:
            e = _eval.Eval(self.args, self.config)
            cls_ap = e()
            self.backup_best(cls_ap, e.path)
        except Exception:
            traceback.print_exc()
        if torch.cuda.is_available():
            self.inference.cuda()

    def backup_best(self, cls_ap, path):
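        """Compare the mean AP of this evaluation with the best recorded in
        <model_dir>.pkl; if it improves, store the new metrics and copy the
        checkpoint to <model_dir>.pth."""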
        try:
            with open(self.model_dir + '.pkl', 'rb') as f:
                best = np.mean(list(pickle.load(f).values()))
        except Exception:
            best = np.finfo(np.float32).min
        metric = np.mean(list(cls_ap.values()))
        if metric > best:
            with open(self.model_dir + '.pkl', 'wb') as f:
                pickle.dump(cls_ap, f)
            shutil.copy(path, self.model_dir + '.pth')
            logging.info('best model (%f) saved into %s.*' % (metric, self.model_dir))
        else:
            logging.info('best metric %f >= %f' % (best, metric))


def main():
    args = make_args()
    config = configparser.ConfigParser()
    utils.load_config(config, args.config)
    for cmd in args.modify:
        utils.modify_config(config, cmd)
    with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
        logging.config.dictConfig(yaml.safe_load(f))
    if args.run is None:
        buffer = io.StringIO()
        config.write(buffer)
        args.run = hashlib.md5(buffer.getvalue().encode()).hexdigest()
    logging.info('cd ' + os.getcwd() + ' && ' + subprocess.list2cmdline([sys.executable] + sys.argv))
    train = Train(args, config)
    train()
    logging.info(pybenchmark.stats)


def make_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
    parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
    parser.add_argument('-b', '--batch_size', default=16, type=int, help='batch size')
    parser.add_argument('-f', '--finetune')
    parser.add_argument('-i', '--ignore', nargs='+', default=[], help='regexes of weights to ignore while finetuning')
    parser.add_argument('-lr', '--learning_rate', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-e', '--epoch', type=int, default=np.iinfo(np.int64).max, help='maximum number of epochs')
    parser.add_argument('-d', '--delete', action='store_true', help='delete model')
    parser.add_argument('-q', '--quiet', action='store_true', help='quiet mode')
    parser.add_argument('-r', '--run', help='the run name in TensorBoard')
    parser.add_argument('--logging', default='logging.yml', help='logging config')
    return parser.parse_args()


if __name__ == '__main__':
    main()