"""
Copyright (C) 2017, 申瑞珉 (Ruimin Shen)

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""

import argparse
import configparser
import logging
import logging.config
import os
import time
import yaml

import numpy as np
import torch.autograd
import torch.cuda
import torch.optim
import torch.utils.data
import torch.nn.functional as F
import humanize
import pybenchmark
import cv2

import transform
import model
import utils
import utils.postprocess
import utils.train
import utils.visualize


def get_logits(pred):
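    """Return the classification logits from the prediction dict.

    When the model has no classification head, fall back to a constant
    tensor with a single class so the downstream softmax assigns it
    probability one.
    """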
    if 'logits' in pred:
        return pred['logits'].contiguous()
    else:
        size = pred['iou'].size()
        return torch.autograd.Variable(utils.ensure_device(torch.ones(*size, 1)))


def filter_visible(config, iou, yx_min, yx_max, prob):
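    """Keep only the detections whose confidence exceeds the configured threshold.

    With 'detect/fix' enabled the confidence is the IoU score multiplied by the
    best class probability; otherwise the IoU score alone is compared against
    'detect/threshold'.
    """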
    prob_cls, cls = torch.max(prob, -1)
    if config.getboolean('detect', 'fix'):
        mask = (iou * prob_cls) > config.getfloat('detect', 'threshold_cls')
    else:
        mask = iou > config.getfloat('detect', 'threshold')
    iou, prob_cls, cls = (t[mask].view(-1) for t in (iou, prob_cls, cls))
    _mask = torch.unsqueeze(mask, -1).repeat(1, 2)  # workaround: boolean indexing in old PyTorch does not broadcast the mask
    yx_min, yx_max = (t[_mask].view(-1, 2) for t in (yx_min, yx_max))
    num = prob.size(-1)
    _mask = torch.unsqueeze(mask, -1).repeat(1, num)  # workaround: boolean indexing in old PyTorch does not broadcast the mask
    prob = prob[_mask].view(-1, num)
    return iou, yx_min, yx_max, prob, prob_cls, cls


def postprocess(config, iou, yx_min, yx_max, prob):
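    """Threshold the predictions, apply non-maximum suppression, and compute final scores.

    Returns (iou, yx_min, yx_max, cls, score) for the surviving boxes, or None
    when non-maximum suppression keeps nothing.
    """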
    iou, yx_min, yx_max, prob, prob_cls, cls = filter_visible(config, iou, yx_min, yx_max, prob)
    keep = pybenchmark.profile('nms')(utils.postprocess.nms)(iou, yx_min, yx_max, config.getfloat('detect', 'overlap'))
    if keep:
        keep = utils.ensure_device(torch.LongTensor(keep))
        iou, yx_min, yx_max, prob, prob_cls, cls = (t[keep] for t in (iou, yx_min, yx_max, prob, prob_cls, cls))
        if config.getboolean('detect', 'fix'):
            score = torch.unsqueeze(iou, -1) * prob
            mask = score > config.getfloat('detect', 'threshold_cls')
            indices, cls = torch.unbind(mask.nonzero(), -1)
            yx_min, yx_max = (t[indices] for t in (yx_min, yx_max))
            score = score[mask]
        else:
            score = iou
        return iou, yx_min, yx_max, cls, score


class Detect(object):
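    """Load a trained detector and run it on a camera stream or video file.

    Results are either shown in a window or written to an output video, and
    individual frames can be dumped to disk on a key press.
    """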
    def __init__(self, args, config):
        self.args = args
        self.config = config
        self.cache_dir = utils.get_cache_dir(config)
        self.model_dir = utils.get_model_dir(config)
        self.category = utils.get_category(config, self.cache_dir if os.path.exists(self.cache_dir) else None)
        self.draw_bbox = utils.visualize.DrawBBox(self.category, colors=args.colors, thickness=args.thickness)
        self.anchors = torch.from_numpy(utils.get_anchors(config)).contiguous()
        self.height, self.width = tuple(map(int, config.get('image', 'size').split()))
        self.path, self.step, self.epoch = utils.train.load_model(self.model_dir)
        state_dict = torch.load(self.path, map_location=lambda storage, loc: storage)
        self.dnn = utils.parse_attr(config.get('model', 'dnn'))(model.ConfigChannels(config, state_dict), self.anchors, len(self.category))
        self.dnn.load_state_dict(state_dict)
        self.inference = model.Inference(config, self.dnn, self.anchors)
        self.inference.eval()
        if torch.cuda.is_available():
            self.inference.cuda()
        logging.info(humanize.naturalsize(sum(var.cpu().numpy().nbytes for var in self.inference.state_dict().values())))
        self.cap = self.create_cap()
        self.keys = set(args.keys)
        self.resize = transform.parse_transform(config, config.get('transform', 'resize_test'))
        self.transform_image = transform.get_transform(config, config.get('transform', 'image_test').split())
        self.transform_tensor = transform.get_transform(config, config.get('transform', 'tensor').split())

    def __del__(self):
        cv2.destroyAllWindows()
        try:
            self.writer.release()
        except AttributeError:
            pass
        self.cap.release()

    def create_cap(self):
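        """Open the input as a camera index if it parses as an integer, otherwise as a video file path."""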
        try:
            cap = int(self.args.input)
        except ValueError:
            cap = os.path.expanduser(os.path.expandvars(self.args.input))
            assert os.path.exists(cap)
        return cv2.VideoCapture(cap)

    def create_writer(self, height, width):
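        """Create a video writer for the output path, matching the capture's frame rate and using the requested (or original) FOURCC code."""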
        fps = self.cap.get(cv2.CAP_PROP_FPS)
        logging.info('cap fps=%f' % fps)
        path = os.path.expanduser(os.path.expandvars(self.args.output))
        if self.args.fourcc:
            fourcc = cv2.VideoWriter_fourcc(*self.args.fourcc.upper())
        else:
            fourcc = int(self.cap.get(cv2.CAP_PROP_FOURCC))
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        return cv2.VideoWriter(path, fourcc, fps, (width, height))

    def get_image(self):
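        """Read the next frame from the capture and optionally crop it to the configured region."""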
        ret, image_bgr = self.cap.read()
        if self.args.crop:
            if not hasattr(self, 'crop_ymin'):
                # assumption: the --crop arguments are fractions of the frame size (ymin ymax xmin xmax)
                height, width = image_bgr.shape[:2]
                self.crop_ymin, self.crop_ymax, self.crop_xmin, self.crop_xmax = (int(v * s) for v, s in zip(self.args.crop, (height, height, width, width)))
            image_bgr = image_bgr[self.crop_ymin:self.crop_ymax, self.crop_xmin:self.crop_xmax]
        return image_bgr

    def __call__(self):
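        """Process one frame: preprocess, run inference, post-process, draw or save the detections, and dump a snapshot when one of the configured keys is pressed."""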
        image_bgr = self.get_image()
        image_resized = self.resize(image_bgr, self.height, self.width)
        image = self.transform_image(image_resized)
        tensor = self.transform_tensor(image)
        tensor = utils.ensure_device(tensor.unsqueeze(0))
        pred = pybenchmark.profile('inference')(model._inference)(self.inference, torch.autograd.Variable(tensor, volatile=True))
        rows, cols = pred['feature'].size()[-2:]
        iou = pred['iou'].data.contiguous().view(-1)
        yx_min, yx_max = (pred[key].data.view(-1, 2) for key in ('yx_min', 'yx_max'))
        logits = get_logits(pred)
        prob = F.softmax(logits, -1).data.view(-1, logits.size(-1))
        ret = postprocess(self.config, iou, yx_min, yx_max, prob)
        image_result = image_bgr.copy()
        if ret is not None:
            iou, yx_min, yx_max, cls, score = ret
            try:
                scale = self.scale
            except AttributeError:
                scale = utils.ensure_device(torch.from_numpy(np.array(image_result.shape[:2], np.float32) / np.array([rows, cols], np.float32)))
                self.scale = scale
            yx_min, yx_max = ((t * scale).cpu().numpy().astype(int) for t in (yx_min, yx_max))
            image_result = self.draw_bbox(image_result, yx_min, yx_max, cls)
        if self.args.output:
            if not hasattr(self, 'writer'):
                self.writer = self.create_writer(*image_result.shape[:2])
            self.writer.write(image_result)
        else:
            cv2.imshow('detection', image_result)
        if cv2.waitKey(0 if self.args.pause else 1) in self.keys:
            root = os.path.join(self.model_dir, 'snapshot')
            os.makedirs(root, exist_ok=True)
            path = os.path.join(root, time.strftime(self.args.format))
            cv2.imwrite(path, image_bgr)
            logging.warning('image dumped into ' + path)


def main():
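    """Load the configuration and logging settings, then run the detector until the capture closes or the user interrupts."""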
    args = make_args()
    config = configparser.ConfigParser()
    utils.load_config(config, args.config)
    for cmd in args.modify:
        utils.modify_config(config, cmd)
    with open(os.path.expanduser(os.path.expandvars(args.logging)), 'r') as f:
        logging.config.dictConfig(yaml.safe_load(f))
    detect = Detect(args, config)
    try:
        while detect.cap.isOpened():
            detect()
    except KeyboardInterrupt:
        logging.warning('interrupted')
    finally:
        logging.info(pybenchmark.stats)


def make_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', nargs='+', default=['config.ini'], help='config file')
    parser.add_argument('-m', '--modify', nargs='+', default=[], help='modify config')
    parser.add_argument('-i', '--input', default=-1, help='camera index or input video path')
    parser.add_argument('-k', '--keys', nargs='+', type=int, default=[ord(' ')], help='keys to dump images')
    parser.add_argument('-o', '--output', help='output video file')
    parser.add_argument('-f', '--format', default='%Y-%m-%d_%H-%M-%S.jpg', help='dump file name format')
    parser.add_argument('--crop', nargs='+', type=float, default=[], help='fractional crop bounds: ymin ymax xmin xmax')
    parser.add_argument('--pause', action='store_true', help='wait for a key press after each frame')
    parser.add_argument('--fourcc', default='XVID', help='4-character code of codec used to compress the frames, such as XVID, MJPG')
    parser.add_argument('--thickness', default=3, type=int, help='line thickness of the drawn bounding boxes')
    parser.add_argument('--colors', nargs='+', default=[])
    parser.add_argument('--logging', default='logging.yml', help='logging config')
    return parser.parse_args()


if __name__ == '__main__':
    main()