# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Zhipeng Zhang and Houwen Peng
# Email: houwen.peng@microsoft.com
# Details: siamfc dataset generator
# ------------------------------------------------------------------------------


from __future__ import division

import cv2
import json
import torch
import random
import logging
import numpy as np
import torchvision.transforms as transforms
from scipy.ndimage.filters import gaussian_filter
from os.path import join
from easydict import EasyDict as edict
from torch.utils.data import Dataset

import sys
sys.path.append('../')
from utils.utils import *
from core.config import config

sample_random = random.Random()
# sample_random.seed(123456)


class SiamFCDataset(Dataset):
    def __init__(self, cfg):
        super(SiamFCDataset, self).__init__()
        # pair information
        self.template_size = cfg.SIAMFC.TRAIN.TEMPLATE_SIZE
        self.search_size = cfg.SIAMFC.TRAIN.SEARCH_SIZE
        self.size = (self.search_size - self.template_size) // cfg.SIAMFC.TRAIN.STRIDE + 1   # from cross-correlation

        # aug information
        self.color = cfg.SIAMFC.DATASET.COLOR
        self.flip = cfg.SIAMFC.DATASET.FLIP
        self.rotation = cfg.SIAMFC.DATASET.ROTATION
        self.blur = cfg.SIAMFC.DATASET.BLUR
        self.shift = cfg.SIAMFC.DATASET.SHIFT
        self.scale = cfg.SIAMFC.DATASET.SCALE

        self.transform_extra = transforms.Compose(
            [transforms.ToPILImage(), ] +
            ([transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), ] if self.color > random.random() else [])
            + ([transforms.RandomHorizontalFlip(), ] if self.flip > random.random() else [])
            + ([transforms.RandomRotation(degrees=10), ] if self.rotation > random.random() else [])
        )

        # train data information
        if cfg.SIAMFC.TRAIN.WHICH_USE == 'VID':
            self.anno = cfg.SIAMFC.DATASET.VID.ANNOTATION
            self.num_use = cfg.SIAMFC.TRAIN.PAIRS
            self.root = cfg.SIAMFC.DATASET.VID.PATH
        elif cfg.SIAMFC.TRAIN.WHICH_USE == 'GOT10K':
            self.anno = cfg.SIAMFC.DATASET.GOT10K.ANNOTATION
            self.num_use = cfg.SIAMFC.TRAIN.PAIRS
            self.root = cfg.SIAMFC.DATASET.GOT10K.PATH
        else:
            raise ValueError('not supported training dataset')

        self.labels = json.load(open(self.anno, 'r'))
        self.videos = list(self.labels.keys())
        self.num = len(self.videos)   # video number
        self.frame_range = 100
        self.pick = self._shuffle()

    def __len__(self):
        return self.num_use

    def __getitem__(self, index):
        """
        pick a vodeo/frame --> pairs --> data aug --> label
        """
        index = self.pick[index]
        template, search = self._get_pairs(index)

        template_image = cv2.imread(template[0])
        search_image = cv2.imread(search[0])

        template_box = self._toBBox(template_image, template[1])
        search_box = self._toBBox(search_image, search[1])

        template, _, _ = self._augmentation(template_image, template_box, self.template_size)
        search, bbox, dag_param = self._augmentation(search_image, search_box, self.search_size)

        # from PIL image to numpy
        template = np.array(template)
        search = np.array(search)

        out_label = self._dynamic_label([self.size, self.size], dag_param.shift)

        template, search = map(lambda x: np.transpose(x, (2, 0, 1)).astype(np.float32), [template, search])

        return template, search, out_label, np.array(bbox, np.float32)  # self.label 15*15/17*17

    # ------------------------------------
    # function groups for selecting pairs
    # ------------------------------------
    def _shuffle(self):
        """
        shuffel to get random pairs index
        """
        lists = list(range(0, self.num))
        m = 0
        pick = []
        while m < self.num_use:
            sample_random.shuffle(lists)
            pick += lists
            m += self.num

        self.pick = pick[:self.num_use]
        return self.pick

    def _get_image_anno(self, video, track, frame):
        """
        get image and annotation
        """
        frame = "{:06d}".format(frame)
        image_path = join(self.root, video, "{}.{}.x.jpg".format(frame, track))
        image_anno = self.labels[video][track][frame]

        return image_path, image_anno

    def _get_pairs(self, index):
        """
        get training pairs
        """
        video_name = self.videos[index]
        video = self.labels[video_name]
        track = random.choice(list(video.keys()))
        track_info = video[track]
        try:
            frames = track_info['frames']
        except:
            frames = list(track_info.keys())

        template_frame = random.randint(0, len(frames)-1)

        left = max(template_frame - self.frame_range, 0)
        right = min(template_frame + self.frame_range, len(frames)-1) + 1
        search_range = frames[left:right]
        template_frame = int(frames[template_frame])
        search_frame = int(random.choice(search_range))

        return self._get_image_anno(video_name, track, template_frame), \
               self._get_image_anno(video_name, track, search_frame)

    def _posNegRandom(self):
        """
        random number from [-1, 1]
        """
        return random.random() * 2 - 1.0

    def _toBBox(self, image, shape):
        imh, imw = image.shape[:2]
        if len(shape) == 4:
            w, h = shape[2] - shape[0], shape[3] - shape[1]
        else:
            w, h = shape
        context_amount = 0.5
        exemplar_size = self.template_size
        wc_z = w + context_amount * (w + h)
        hc_z = h + context_amount * (w + h)
        s_z = np.sqrt(wc_z * hc_z)
        scale_z = exemplar_size / s_z
        w = w * scale_z
        h = h * scale_z
        cx, cy = imw // 2, imh // 2
        bbox = center2corner(Center(cx, cy, w, h))
        return bbox

    def _crop_hwc(self, image, bbox, out_sz, padding=(0, 0, 0)):
        """
        crop image
        """
        bbox = [float(x) for x in bbox]
        a = (out_sz - 1) / (bbox[2] - bbox[0])
        b = (out_sz - 1) / (bbox[3] - bbox[1])
        c = -a * bbox[0]
        d = -b * bbox[1]
        mapping = np.array([[a, 0, c],
                            [0, b, d]]).astype(np.float)
        crop = cv2.warpAffine(image, mapping, (out_sz, out_sz), borderMode=cv2.BORDER_CONSTANT, borderValue=padding)
        return crop

    def _draw(self, image, box, name):
        """
        draw image for debugging
        """
        draw_image = image.copy()
        x1, y1, x2, y2 = map(lambda x:int(round(x)), box)
        cv2.rectangle(draw_image, (x1, y1), (x2, y2), (0,255,0))
        cv2.circle(draw_image, (int(round(x1 + x2)/2), int(round(y1 + y2) /2)), 3, (0, 0, 255))
        cv2.putText(draw_image, '[x: {}, y: {}]'.format(int(round(x1 + x2)/2), int(round(y1 + y2) /2)), (int(round(x1 + x2)/2) - 3, int(round(y1 + y2) /2) -3), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
        cv2.imwrite(name, draw_image)

    # ------------------------------------
    # function for data augmentation
    # ------------------------------------
    def _augmentation(self, image, bbox, size):
        """
        data augmentation for input pairs
        """
        shape = image.shape
        crop_bbox = center2corner((shape[0] // 2, shape[1] // 2, size, size))
        param = edict()

        param.shift = (self._posNegRandom() * self.shift, self._posNegRandom() * self.shift)   # shift
        param.scale = ((1.0 + self._posNegRandom() * self.scale), (1.0 + self._posNegRandom() * self.scale))  # scale change

        crop_bbox, _ = aug_apply(Corner(*crop_bbox), param, shape)

        x1, y1 = crop_bbox.x1, crop_bbox.y1
        bbox = BBox(bbox.x1 - x1, bbox.y1 - y1, bbox.x2 - x1, bbox.y2 - y1)

        scale_x, scale_y = param.scale
        bbox = Corner(bbox.x1 / scale_x, bbox.y1 / scale_y, bbox.x2 / scale_x, bbox.y2 / scale_y)

        image = self._crop_hwc(image, crop_bbox, size)   # shift and scale

        if self.blur > random.random():
            image = gaussian_filter(image, sigma=(1, 1, 0))

        image = self.transform_extra(image)        # other data augmentation
        return image, bbox, param

    # ------------------------------------
    # function for creating training label
    # ------------------------------------
    def _dynamic_label(self, fixedLabelSize, c_shift, rPos=2, rNeg=0):
        if isinstance(fixedLabelSize, int):
            fixedLabelSize = [fixedLabelSize, fixedLabelSize]

        assert (fixedLabelSize[0] % 2 == 1)

        d_label = self._create_dynamic_logisticloss_label(fixedLabelSize, c_shift, rPos, rNeg)

        return d_label

    def _create_dynamic_logisticloss_label(self, label_size, c_shift, rPos=2, rNeg=0):
        if isinstance(label_size, int):
            sz = label_size
        else:
            sz = label_size[0]

        # the real shift is -param['shifts']
        sz_x = sz // 2 + round(-c_shift[0]) // 8  # 8 is strides
        sz_y = sz // 2 + round(-c_shift[1]) // 8

        x, y = np.meshgrid(np.arange(0, sz) - np.floor(float(sz_x)),
                           np.arange(0, sz) - np.floor(float(sz_y)))

        dist_to_center = np.abs(x) + np.abs(y)  # Block metric
        label = np.where(dist_to_center <= rPos,
                         np.ones_like(y),
                         np.where(dist_to_center < rNeg,
                                  0.5 * np.ones_like(y),
                                  np.zeros_like(y)))
        return label