"""Script to spot cat faces in videos and draw bounding boxes around them.
Expects file 'model.best.tar' to exist (generated by train.py).
Writes outputs to outputs/videos/ by default."""
from __future__ import print_function, division
import argparse
import numpy as np
import os
from collections import defaultdict
from scipy import misc
from model import Model2
from common import to_aspect_ratio_add, draw_heatmap, imresize_sidelen
from bbs import RectangleOnImage
import cv2
import torch
from torch.autograd import Variable
from skimage import morphology
#from sklearn.cluster import DBSCAN
import imgaug as ia
import time

# let cuDNN benchmark and pick the fastest conv algorithms; beneficial here
# because every frame is resized to the same fixed input size
torch.backends.cudnn.benchmark = True

# base directory under which per-video output directories are created
WRITE_TO_BASEDIR = "outputs/videos/"
# GPU device id; any value < 0 keeps model and inputs on the CPU
GPU = 0

def main():
    """Find bounding boxes in a video and save annotated frames to disk.

    Loads the trained model from 'model.best.tar', reads the video given via
    --video frame by frame and writes one annotated JPG per processed frame
    into the output directory.

    Raises:
        IOError: If the video file given via --video does not exist.
    """
    parser = argparse.ArgumentParser(description="Process a video")
    parser.add_argument("--video", help="Filepath to the video", required=True)
    parser.add_argument("--out_dir", help="Directory name in which to save results")
    parser.add_argument("--start_frame", default=1, type=int, help="Frame number to start at (1 to N).")
    # type=int is required here: the value is compared against the integer
    # frame counter below; without it argparse returns a string
    parser.add_argument("--end_frame", default=None, type=int, help="Frame number to end at (1 to N).")
    parser.add_argument("--conf", default=0.5, type=float, help="Confidence threshold for BBs")
    parser.add_argument("--size", default=400, type=int, help="Input image size when feeding into the model")
    args = parser.parse_args()

    # load trained model
    checkpoint = torch.load("model.best.tar")
    model = Model2()
    model.load_state_dict(checkpoint["state_dict"])
    if GPU >= 0:
        model.cuda(GPU)
    model.eval()
    # free checkpoint memory before processing frames
    del checkpoint

    # check if video file exists (explicit raise instead of assert, which is
    # stripped when running with -O)
    video_fp = args.video
    if not os.path.isfile(video_fp):
        raise IOError("Video file not found: %s" % (video_fp,))

    # convert video filename to output directory path
    video_fn = os.path.basename(video_fp)
    if args.out_dir is not None:
        write_to_dir = os.path.join(WRITE_TO_BASEDIR, args.out_dir)
    else:
        write_to_dir = os.path.join(WRITE_TO_BASEDIR, os.path.splitext(video_fn)[0])

    # create output directory if necessary
    if not os.path.exists(write_to_dir):
        os.makedirs(write_to_dir)

    # start reading video
    vidcap = cv2.VideoCapture(video_fp)
    success, img = vidcap.read()  # img => uint8 0-255 BGR
    frame_idx = 0

    # skip ahead if --start_frame=<int> was used; also stop skipping when the
    # video ends early so this cannot loop past the last frame
    while success and frame_idx + 1 < args.start_frame:
        success, img = vidcap.read()
        frame_idx += 1

    # handle frames of video
    while success:
        # end if --end_frame=<int> was used and that frame was reached
        if args.end_frame is not None and frame_idx >= args.end_frame:
            break

        # find BBs in frame (model expects RGB, cv2 delivers BGR)
        time_start = time.time()
        img_rgb = img[:, :, ::-1]
        time_model = process_frame(frame_idx, img_rgb, model, write_to_dir, args.conf, input_size=args.size)
        time_req = time.time() - time_start

        # output message and forward to next frame
        print("Frame %05d in %03dms (model: %03dms)." % (frame_idx, time_req*1000, time_model*1000))
        success, img = vidcap.read()
        frame_idx += 1

def process_frame(frame_idx, img, model, write_to_dir, conf_threshold, input_size=224):
    """Detect bounding boxes in one video frame, draw them onto the frame
    and write the result as a JPG to the output directory.

    Returns the time (in seconds) that the model itself required.
    """
    # run the detector on this frame
    bbs, time_model = find_bbs(img, model, conf_threshold, input_size=input_size)

    # draw only confident, non-degenerate boxes onto a copy of the frame
    img_drawn = np.copy(img)
    for bb, score in bbs:
        confident = score > conf_threshold
        large_enough = bb.width > 2 and bb.height > 2
        if confident and large_enough:
            img_drawn = bb.draw_on_image(img_drawn, color=[0, 255, 0], thickness=3)

    # write annotated frame, e.g. outputs/videos/<name>/00042.jpg
    out_fp = os.path.join(write_to_dir, "%05d.jpg" % (frame_idx,))
    misc.imsave(out_fp, img_drawn)

    return time_model

def find_bbs(img, model, conf_threshold, input_size):
    """Locate bounding boxes in a single RGB image.

    Returns a tuple of (list of (rectangle, score) pairs projected onto the
    original image, model runtime in seconds).
    """
    # square the image by padding its shorter sides
    img_square, paddings = to_aspect_ratio_add(img, 1.0, return_paddings=True)

    # resize the squared image to the model's expected input size;
    # "linear" interpolation seems to be enough here for 400x400 or larger
    # images, change to "area" or "cubic" for marginally better quality
    img_small = ia.imresize_single_image(
        img_square, (input_size, input_size), interpolation="linear"
    )

    # HWC uint8 [0, 255] => NCHW float32 [0, 1]
    batch_np = (np.array([img_small]) / 255.0).astype(np.float32).transpose(0, 3, 1, 2)
    batch = Variable(torch.from_numpy(batch_np), volatile=True)
    if GPU >= 0:
        batch = batch.cuda(GPU)

    # apply model, timing only the forward pass
    start = time.time()
    heatmaps = model(batch)
    duration = time.time() - start

    # convert the predicted heatmaps into bounding boxes
    result = ModelResult(heatmaps, batch_np, img, paddings)
    return result.get_bbs(), duration

class ModelResult(object):
    """Class that handles the transformation from heatmaps (model output) to
    bounding boxes."""

    def __init__(self, outputs, inputs_np, img, paddings):
        """Initialize the result wrapper.

        Args:
            outputs: Model output (torch Variable/tensor) of shape (1, C, H, W).
            inputs_np: Model input as a numpy array of shape (1, 3, H, W).
            img: Original (unpadded, unresized) input image, HWC.
            paddings: Tuple (pad_top, pad_right, pad_bottom, pad_left) that was
                added to `img` to square it before resizing.
        """
        self.inputs = inputs_np
        self.outputs = outputs.cpu().data.numpy()
        assert self.inputs.ndim == 4
        assert self.outputs.ndim == 4
        # only single-image batches are supported
        assert self.inputs.shape[0] == 1
        assert self.outputs.shape[0] == 1
        self.img = img
        self.paddings = paddings
        # recursion depth used when shrinking BBs (see _shrink())
        self.shrink_depth = 1
        # minimum fraction of the original activation a shrunk BB must retain
        self.shrink_threshold = 0.9
        # heatmap values below this threshold are treated as background
        self.heatmap_activation_threshold = 0.25

    def get_bbs(self):
        """Convert model outputs to bounding boxes.

        Returns:
            List of (RectangleOnImage, score) tuples, projected onto the
            original (unpadded, unresized) input image.
        """
        outputs_pred = self.outputs

        # generate shape of model input image
        # (=> original image + padding to square it + resize)
        # note: self.inputs has form NCHW
        img_pad_rs_shape = (
            self.inputs.shape[2],
            self.inputs.shape[3],
            3
        )

        # generate shape of original image after padding (no resize)
        img_pad_shape = (
            self.img.shape[0] + self.paddings[0] + self.paddings[2],
            self.img.shape[1] + self.paddings[1] + self.paddings[3],
            3
        )

        # Convert only the first heatmap's output to BBs. The other heatmaps
        # (top left corner, top center, ...) are currently ignored.
        rects = self._heatmap_to_rects(outputs_pred[0, 0, ...], img_pad_rs_shape)
        bbs = self._rects_reverse_projection(
            rects, self.img.shape, img_pad_shape,
            self.paddings[0], self.paddings[1],
            self.paddings[2], self.paddings[3]
        )

        return bbs

    def _heatmap_to_rects(self, grid_pred, bb_img):
        """Convert a heatmap to rectangles / bounding box candidates.

        Args:
            grid_pred: Single heatmap, shape (H, W) or (1, H, W).
            bb_img: Shape (tuple or array) the rectangles should live on.

        Returns:
            List of (rectangle, score) tuples.
        """
        grid_pred = np.squeeze(grid_pred)  # (1, H, W) => (H, W)

        # remove low activations
        grid_thresh = grid_pred >= self.heatmap_activation_threshold

        # find connected components
        grid_labeled, num_labels = morphology.label(
            grid_thresh, background=0, connectivity=1, return_num=True
        )

        # for each connected component,
        # - draw a bounding box around it,
        # - shrink the bounding box to optimal size
        # - estimate a score/confidence value
        bbs = []
        for label in range(1, num_labels+1):
            (yy, xx) = np.nonzero(grid_labeled == label)
            min_y, max_y = np.min(yy), np.max(yy)
            min_x, max_x = np.min(xx), np.max(xx)
            rect = RectangleOnImage(x1=min_x, x2=max_x+1, y1=min_y, y2=max_y+1, shape=grid_labeled)
            # shrink and score in one step (the score of the unshrunk rect is
            # recomputed inside _shrink(), no need to compute it here)
            rect_shrunk, activation_shrunk = self._shrink(grid_pred, rect)
            rect_rs_shrunk = rect_shrunk.on(bb_img)
            bbs.append((rect_rs_shrunk, activation_shrunk))

        return bbs

    def _shrink(self, heatmap, rect):
        """Shrink a rectangle to get rid of some low activations.

        The model often generates areas of high activations, with a few
        pixels of medium activations on the side. When drawing a bounding box
        around these activations, the medium ones can force the bounding box
        to become significantly larger than it should be. This function tries
        to shrink those bounding boxes, while retaining most of the activation.

        This function is implemented in a (slow) recursive way. Using dynamic
        programming would probably be faster.

        Returns:
            Tuple of (shrunk rectangle, score of that rectangle).
        """
        assert rect.width >= 1 and rect.height >= 1
        score_orig = self._rect_to_score(rect, heatmap)
        # _rect_to_score() returns 0 for broken heatmap extractions; in that
        # case there is no activation to retain and dividing by score_orig
        # below would raise a ZeroDivisionError
        if score_orig <= 0:
            return rect, score_orig

        candidates = self._shrink_candidates(rect, depth=self.shrink_depth)
        candidates_scored = []
        for candidate in candidates:
            score = self._rect_to_score(candidate, heatmap)
            if score / score_orig >= self.shrink_threshold:
                candidates_scored.append((candidate, score, candidate.area))

        # pick the smallest-area candidate that retains enough activation;
        # the unshrunk rect itself is always among the candidates (its score
        # ratio is 1.0), so candidates_scored is never empty here
        best = min(candidates_scored, key=lambda t: t[2])
        return (best[0], best[1])

    def _shrink_candidates(self, rect, depth):
        """Recursive function called by _shrink() to generate bounding box
        candidates that are smaller than the input bounding box.

        Args:
            rect: Rectangle to shrink.
            depth: Remaining recursion depth; 0 stops the recursion.

        Returns:
            List of candidate rectangles, always including `rect` itself.
        """
        result = [rect]

        if depth > 0:
            if rect.width > 1:
                rect_left = rect.copy(x1=rect.x1+1)
                rect_right = rect.copy(x2=rect.x2-1)
                result.extend(self._shrink_candidates(rect_left, depth=depth-1))
                result.extend(self._shrink_candidates(rect_right, depth=depth-1))

            if rect.height > 1:
                rect_top = rect.copy(y1=rect.y1+1)
                rect_bottom = rect.copy(y2=rect.y2-1)
                result.extend(self._shrink_candidates(rect_top, depth=depth-1))
                result.extend(self._shrink_candidates(rect_bottom, depth=depth-1))

        return result

    def _rects_reverse_projection(self, rects, img_shape, img_pad_shape, pad_top, pad_right, pad_bottom, pad_left):
        """Input images into the model are padded to make them squared. They
        are also resized to a smaller size. This function is supposed to
        remove both effects, i.e. to project the found bounding boxes from
        the padded and resized image to the unpadded und unresized (original)
        input image.

        Returns:
            List of (rectangle, score) tuples on the original image.
        """
        result = []
        for (rect, score) in rects:
            # project from resized padded (squared) image to unresized one
            rect_large = rect.on(img_pad_shape)
            # move rectangles to remove paddings
            rect_large_unpadded = rect_large.shift(top=-pad_top, left=-pad_left)
            # positions of corners are now correct, so switch underlying shape
            rect_large_unpadded = rect_large_unpadded.copy(shape=img_shape)
            result.append((rect_large_unpadded, score))
        return result

    def _rect_to_score(self, rect, heatmap):
        """Compute a score for a given rectangle (i.e. the confidence value).
        Currently this is done via an average of the corresponding activations
        in the heatmap.

        Returns 0 (with a warning) if the extracted sub-heatmap is degenerate.
        """
        subheatmap = rect.extract_from_image(heatmap)
        if subheatmap.ndim == 2 and subheatmap.shape[0] > 0 and subheatmap.shape[1] > 0:
            return np.average(subheatmap)
        else:
            print("[WARN] Broken heatmap extracted for rectangle", rect)
            return 0

def debug_frame(frame_idx, img, model, conf_threshold, input_size=224):
    """Corresponding function to process_frame() that effectively does the same,
    but shows some debug information.

    Probably doesn't work currently as some functions were moved into a class.

    NOTE(review): heatmap_to_rects(), rects_reverse_projection(),
    make_rects_full_size() and group_rects() are not defined at module level
    anymore (they live on ModelResult now, partly as commented-out code), so
    calling this function raises a NameError. Kept for reference; port it to
    ModelResult before use.
    """
    # pad to square, resize to model input size and run the forward pass
    img_orig_shape = img.shape
    img_pad, (pad_top, pad_right, pad_bottom, pad_left) = to_aspect_ratio_add(img, 1.0, return_paddings=True)
    img_rs = misc.imresize(img_pad, (input_size, input_size))
    inputs = (np.array([img_rs])/255.0).astype(np.float32).transpose(0, 3, 1, 2)
    inputs = torch.from_numpy(inputs)
    inputs = Variable(inputs)
    if GPU >= 0:
        inputs = inputs.cuda(GPU)
    outputs_pred = model(inputs)
    outputs_pred = outputs_pred.data.cpu().numpy()
    print("outputs_pred", np.min(outputs_pred), np.average(outputs_pred), np.max(outputs_pred))
    # rectangles on the padded+resized image (for visualization only)
    hm_idx_to_rects_pad = []
    for i in range(outputs_pred.shape[1]):
        hms = heatmap_to_rects(outputs_pred[0, i, ...], img_rs)
        hm_idx_to_rects_pad.append(hms)
    # rectangles projected back onto the original (unpadded) image
    hm_idx_to_rects = []
    for i in range(outputs_pred.shape[1]):
        hms = heatmap_to_rects(outputs_pred[0, i, ...], img_rs)
        hms_rev = rects_reverse_projection(hms, img.shape, img_pad.shape, pad_top, pad_right, pad_bottom, pad_left)
        hm_idx_to_rects.append(hms_rev)

    rects_full_size_vis = make_rects_full_size(hm_idx_to_rects, img_orig_shape, keep_grouping=True)
    rects_full_size = make_rects_full_size(hm_idx_to_rects, img_orig_shape, keep_grouping=False)
    groups = group_rects(rects_full_size)

    # average each group of rectangles into one final BB, keep it if its
    # average score exceeds the confidence threshold
    final_bbs = []
    #for label, rects in cluster.iteritems():
    for label in groups:
        rects = groups[label]
        score_avg = sum([score for (rect, score) in rects]) / (1+9)
        if score_avg > conf_threshold:
            x1 = np.average([rect.x1 for (rect, score) in rects])
            x2 = np.average([rect.x2 for (rect, score) in rects])
            y1 = np.average([rect.y1 for (rect, score) in rects])
            y2 = np.average([rect.y2 for (rect, score) in rects])
            final_bbs.append((RectangleOnImage(x1=x1, y1=y1, x2=x2, y2=y2, shape=img_orig_shape), score_avg))

    # build a visualization grid: one row per debug view, stacked vertically
    img_rs_nopad = imresize_sidelen(img, 200, pick_func=max)
    rows = []

    # heatmaps
    row = [misc.imresize(img_rs, (img_rs_nopad.shape[0], img_rs_nopad.shape[1]))]
    for i in range(outputs_pred.shape[1]):
        hm = draw_heatmap(img_rs, outputs_pred[0, i])
        row.append(misc.imresize(hm, (img_rs_nopad.shape[0], img_rs_nopad.shape[1])))
    rows.append(np.hstack(row))

    # heatmaps => rects (padded image)
    #print("pad", pad_top, pad_right, pad_bottom, pad_left)
    #print("hm_idx_to_rects_pad", hm_idx_to_rects_pad)
    row = [misc.imresize(img_rs, (img_rs_nopad.shape[0], img_rs_nopad.shape[1]))]
    for rects in hm_idx_to_rects_pad:
        img_cp = np.copy(row[0])
        for (rect, score) in rects:
            img_cp = rect.draw_on_image(img_cp, color=[0, 255, 0])
        row.append(img_cp)
    rows.append(np.hstack(row))

    # heatmaps => rects (unpadded/original image)
    #print("hm_idx_to_rects", [[(r.on(img_rs_nopad), s) for (r, s) in rects] for rects in hm_idx_to_rects])
    row = [img_rs_nopad]
    for rects in hm_idx_to_rects:
        img_cp = np.copy(img_rs_nopad)
        for (rect, score) in rects:
            img_cp = rect.draw_on_image(img_cp, color=[0, 255, 0])
        row.append(img_cp)
    rows.append(np.hstack(row))

    # heatmaps => rects full size
    row = [img_rs_nopad]
    for rects in rects_full_size_vis:
        img_cp = np.copy(img_rs_nopad)
        for (rect, score) in rects:
            img_cp = rect.draw_on_image(img_cp, color=[0, 255, 0])
        row.append(img_cp)
    rows.append(np.hstack(row))

    # clustered rects (one random color per cluster)
    img_cp = np.copy(img_rs_nopad)
    for label in groups:
        col = np.random.randint(0, 255, size=(3,))
        rects = groups[label]
        for (rect, score) in rects:
            img_cp = rect.draw_on_image(img_cp, color=col)
    row = np.hstack([img_rs_nopad, img_cp])
    # pad this row on the right so it matches the width of the heatmap rows
    diff = img_rs_nopad.shape[1] * (1+1+9) - row.shape[1]
    row = np.pad(row, ((0, 0), (0, diff), (0, 0)), mode="constant", constant_values=0)
    rows.append(row)

    # final rects
    img_cp = np.copy(img_rs_nopad)
    for (rect, score) in final_bbs:
        col = np.random.randint(0, 255, size=(3,))
        img_cp = rect.draw_on_image(img_cp, color=col)
    row = np.hstack([img_rs_nopad, img_cp])
    # pad this row on the right so it matches the width of the heatmap rows
    diff = img_rs_nopad.shape[1] * (1+1+9) - row.shape[1]
    row = np.pad(row, ((0, 0), (0, diff), (0, 0)), mode="constant", constant_values=0)
    rows.append(row)

    #print([r.shape for r in rows])
    misc.imshow(np.vstack(rows))

# run only when executed as a script, not when imported as a module
if __name__ == "__main__":
    main()