from __future__ import division

import numbers
import os
import random

import numpy as np
import cv2

import itertools

from commons import minmax
from configs import ADNetConf
from networks import ADNetwork

class Coordinate:
    """Integer 2-D pair (a point or a width/height).

    ``x`` and ``y`` are rounded to ``int`` on construction, so every
    arithmetic result is rounded too.  Operands may be another Coordinate, a
    scalar, or a 2-element tuple/list/ndarray.  A Coordinate also unpacks and
    indexes like a pair (``x, y = c``; ``c[0]``, ``c[1]``).
    """

    @staticmethod
    def get_imgwh(img):
        """Return an image's size as ``Coordinate(width, height)``.

        Width is read from ``img.shape[1]`` and height from ``img.shape[0]``.
        """
        return Coordinate(x=img.shape[1], y=img.shape[0])

    def __init__(self, x, y):
        self.x = int(round(x))
        self.y = int(round(y))

    def __repr__(self):
        return 'x=%d, y=%d' % (self.x, self.y)

    @staticmethod
    def _pair(other):
        """Split a supported operand into ``(ox, oy)``; return None otherwise."""
        if isinstance(other, Coordinate):
            return other.x, other.y
        if isinstance(other, numbers.Number):
            # A scalar applies to both components.
            return other, other
        if isinstance(other, (tuple, list, np.ndarray)):
            return other[0], other[1]
        return None

    def __add__(self, other):
        pair = self._pair(other)
        if pair is None:
            return NotImplemented
        return Coordinate(self.x + pair[0], self.y + pair[1])

    def __sub__(self, other):
        # Subtract component-wise instead of adding ``other * -1``: for a
        # tuple/list that expression is sequence repetition (an empty
        # sequence), not negation.
        pair = self._pair(other)
        if pair is None:
            return NotImplemented
        return Coordinate(self.x - pair[0], self.y - pair[1])

    def __mul__(self, other):
        pair = self._pair(other)
        if pair is None:
            return NotImplemented
        return Coordinate(self.x * pair[0], self.y * pair[1])

    def __floordiv__(self, other):
        return self.__truediv__(other)

    def __truediv__(self, other):
        """Divide component-wise; always floors, so ``/`` behaves like ``//``."""
        pair = self._pair(other)
        if pair is None:
            return NotImplemented
        return Coordinate(self.x // pair[0], self.y // pair[1])

    def __iter__(self):
        yield self.x
        yield self.y

    def __getitem__(self, item):
        if item == 0:
            return self.x
        if item == 1:
            return self.y
        raise IndexError('Coordinate index out of range: %r' % (item,))

    def __eq__(self, other):
        if isinstance(other, Coordinate):
            return self.x == other.x and self.y == other.y
        if isinstance(other, (tuple, list, np.ndarray)):
            return self.x == other[0] and self.y == other[1]
        return NotImplemented

    def __ne__(self, other):
        # Needed for Python 2, which does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def max(self, val):
        """Raise each component to at least ``val``, in place (returns None)."""
        self.x = max(self.x, val)
        self.y = max(self.y, val)

class BoundingBox:
    """Axis-aligned box: top-left corner ``xy`` and size ``wh``, both Coordinates.

    All components are rounded to int on construction (via Coordinate).
    """
    # Colors for draw(); passed straight to cv2.rectangle, so they follow the
    # image's channel order (B, G, R for images loaded with cv2.imread).
    COLOR_GT = (0, 255, 0)
    COLOR_PREDICT = (255, 0, 0)
    COLOR_NEGATIVE = (0, 0, 255)

    @staticmethod
    def read_vid_gt(path):
        """Read ground-truth boxes from a comma-separated ``x,y,w,h`` file.

        :param path: the annotation file, or a directory that contains
            ``groundtruth_rect.txt``.
        :return: list of BoundingBox, one per non-blank line, in file order.
            Values are parsed as floats and rounded, so both integer and
            fractional annotations are accepted.
        """
        if os.path.isdir(path):
            path = os.path.join(path, 'groundtruth_rect.txt')

        boxes = []
        with open(path, 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                x, y, w, h = [float(v) for v in line.split(',')]
                boxes.append(BoundingBox(x, y, w, h))
        return boxes

    @staticmethod
    def get_action_labels(samples, gt_box):
        """Return ``get_action_label(sample, gt_box)`` for every sample, in order."""
        # TODO : vectorize everything
        return [BoundingBox.get_action_label(sample, gt_box) for sample in samples]

    @staticmethod
    def get_action_label(sample, gt_box):
        """Return the index of the action whose result best overlaps ``gt_box``.

        Every action in ``range(ADNetwork.NUM_ACTIONS)`` is applied to ``sample``
        (without clipping to an image) and scored by IoU with ``gt_box``:

        * if the IoU of ``ADNetwork.ACTION_IDX_STOP`` exceeds the configured
          ``predict.stop_iou``, that stop index is returned;
        * else, if ``max(ious[:-2])`` is within a factor 0.99999 of the stop
          IoU, the argmax over all actions is returned;
        * otherwise the argmax over all but the last two actions is returned.
        """
        ious = []
        for i in range(ADNetwork.NUM_ACTIONS):
            moved_box = sample.do_action(imgwh=None, action_idx=i)
            ious.append(gt_box.iou(moved_box))

        if ious[ADNetwork.ACTION_IDX_STOP] > ADNetConf.get()['predict']['stop_iou']:
            return ADNetwork.ACTION_IDX_STOP
        if max(ious[:-2]) * 0.99999 <= ious[ADNetwork.ACTION_IDX_STOP]:
            return np.argmax(ious)
        return np.argmax(ious[:-2])

    def __init__(self, x, y, w, h):
        self.xy = Coordinate(x, y)
        self.wh = Coordinate(w, h)
        self.feat = None

    def __repr__(self):
        return 'x=%d, y=%d, w=%d, h=%d' % (self.xy.x, self.xy.y, self.wh.x, self.wh.y)

    def __eq__(self, other):
        if not isinstance(other, BoundingBox):
            return NotImplemented
        return self.xy == other.xy and self.wh == other.wh

    def __ne__(self, other):
        # Needed for Python 2, which does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __add__(self, other):
        """Add component-wise: ``other`` is a BoundingBox or an (x, y, w, h) sequence."""
        if isinstance(other, (tuple, list, np.ndarray)):
            xy = self.xy + other[:2]
            wh = self.wh + other[2:]
            return BoundingBox(xy.x, xy.y, wh.x, wh.y)
        if isinstance(other, BoundingBox):
            xy = self.xy + other.xy
            wh = self.wh + other.wh
            return BoundingBox(xy.x, xy.y, wh.x, wh.y)
        return NotImplemented

    def __mul__(self, other):
        """Multiply component-wise by an (x, y, w, h) tuple/list/ndarray."""
        if isinstance(other, (tuple, list, np.ndarray)):
            # Hand Coordinate plain tuples so list/ndarray inputs work as well.
            xy = self.xy * tuple(other[:2])
            wh = self.wh * tuple(other[2:])
            return BoundingBox(xy.x, xy.y, wh.x, wh.y)
        return NotImplemented

    def __floordiv__(self, other):
        return self.__truediv__(other)

    def __truediv__(self, other):
        """Floor-divide every component by a scalar."""
        if isinstance(other, numbers.Number):
            xy = self.xy // other
            wh = self.wh // other
            return BoundingBox(xy.x, xy.y, wh.x, wh.y)
        return NotImplemented

    def get_xy2(self):
        """Return ``xy + wh``, the corner opposite ``xy``."""
        return self.xy + self.wh

    def fit_image(self, imgwh):
        """Clip this box in place to an image of size ``imgwh`` (a Coordinate).

        The corner is clamped to be non-negative.  The size is clamped to
        ``[10, img - 10]`` and then shrunk so the box ends at the image edge.
        """
        self.xy.x = max(0, self.xy.x)
        self.xy.y = max(0, self.xy.y)

        self.wh.x = max(10, min(self.wh.x, imgwh.x - 10))
        self.wh.y = max(10, min(self.wh.y, imgwh.y - 10))
        # NOTE(review): the corner is not clamped at the right/bottom edge, so a
        # corner beyond the image yields a non-positive size here.
        self.wh.x = min(self.wh.x, imgwh.x - self.xy.x)
        self.wh.y = min(self.wh.y, imgwh.y - self.xy.y)

    def draw(self, img, color=(255, 255, 255)):
        """Draw a 1-px rectangle outline for this box onto ``img`` in place."""
        cv2.rectangle(img, tuple(self.xy), tuple(self.get_xy2()), color, 1)

    def iou(self, other):
        """Return the intersection-over-union of this box and ``other``.

        :param other: a BoundingBox or an ``(x, y, w, h, ...)`` tuple/list/ndarray
            (entries after the fourth are ignored).
        :return: IoU as a float; 0.0 when the boxes do not overlap.
        :raises TypeError: if ``other`` has an unsupported type.
        """
        if isinstance(other, BoundingBox):
            other_x = other.xy.x
            other_y = other.xy.y
            other_w = other.wh.x
            other_h = other.wh.y
        elif isinstance(other, (tuple, list, np.ndarray)):
            other_x, other_y, other_w, other_h = other[:4]
        else:
            raise TypeError('cannot compute IoU with %r' % (type(other),))

        # Corners of the intersection rectangle.
        xA = max(self.xy.x, other_x)
        yA = max(self.xy.y, other_y)
        xB = min(self.xy.x + self.wh.x, other_x + other_w)
        yB = min(self.xy.y + self.wh.y, other_y + other_h)

        if xA >= xB or yA >= yB:
            return 0.0

        interArea = (xB - xA) * (yB - yA)
        boxAArea = self.wh.x * self.wh.y
        boxBArea = other_w * other_h

        # Union = both areas minus the doubly-counted intersection.
        return interArea / float(boxAArea + boxBArea - interArea)

    def do_action(self, imgwh, action_idx):
        """Return a new box produced by applying action ``action_idx`` to this box.

        Indices below 8 translate the box by the first two entries of
        ``ADNetwork.ACTIONS[action_idx]`` times a fraction of its size
        (``action_move.x``/``.y`` from the config).  Index 8 returns an
        unchanged copy.  Higher indices resize the box around its centre using
        ``action_move.w``/``.h`` and the last two entries of the action row.

        :param imgwh: image size as a Coordinate; when not None the result is
            clipped with ``fit_image``.
        :param action_idx: index into ``ADNetwork.ACTIONS``.
        :return: a new BoundingBox; ``self`` is not modified.
        """
        move_conf = ADNetConf.get()['action_move']
        action_ratios = tuple(move_conf[k] for k in 'xywh')

        if action_idx < 8:
            # NOTE(review): Coordinate rounds the delta to int, so for very small
            # boxes a move can round to zero.
            deltas_xy = self.wh * action_ratios[:2]
            actual_deltas = ADNetwork.ACTIONS[action_idx][:2] * (deltas_xy.x, deltas_xy.y)
            moved_xy = self.xy + actual_deltas
            new_box = BoundingBox(moved_xy.x, moved_xy.y, self.wh.x, self.wh.y)
        elif action_idx == 8:
            new_box = BoundingBox(self.xy.x, self.xy.y, self.wh.x, self.wh.y)
        else:
            deltas_wh = self.wh * action_ratios[2:]
            # assumes ADNetwork.ACTIONS rows are numpy arrays, so this product
            # is element-wise — TODO confirm
            deltas_wh_scaled = ADNetwork.ACTIONS[action_idx][2:] * (deltas_wh.x, deltas_wh.y)
            # Shift the corner by half the size change so the centre stays put.
            moved_xy = self.xy + -1 * deltas_wh_scaled / 2
            moved_wh = self.wh + deltas_wh_scaled
            new_box = BoundingBox(moved_xy.x, moved_xy.y, moved_wh.x, moved_wh.y)

        if imgwh is not None:
            new_box.fit_image(imgwh)
        return new_box

    def gen_noise_samples(self, imgwh, noise_type, num, **kwargs):
        """Generate ``num`` boxes by perturbing this box.

        :param imgwh: image size as a Coordinate; only used by ``'whole'``.
        :param noise_type: one of the following.

            * ``'gaussian'``: small clipped-normal shifts and scales.
            * ``'uniform'``: shifts up to ``uniform_translation_f`` times the
              mean side, scale exponents up to ``uniform_scale_f``.
            * ``'whole'``: positions from a grid over the image, at 10 scales
              ``1.05 ** -5 .. 1.05 ** 4``.
        :param num: number of boxes to return.
        :keyword gaussian_translation_f: default 0.1.
        :keyword uniform_translation_f: default 1.
        :keyword uniform_scale_f: default 10.
        :return: list of BoundingBox.  It is empty for ``'whole'`` when the box
            is too large for the grid to have any cell.
        :raises ValueError: on an unknown ``noise_type``.
        """
        center_xy = self.xy + self.wh * 0.5
        mean_wh = sum(self.wh) / 2.0

        gaussian_translation_f = kwargs.get('gaussian_translation_f', 0.1)
        uniform_translation_f = kwargs.get('uniform_translation_f', 1)
        uniform_scale_f = kwargs.get('uniform_scale_f', 10)

        if noise_type == 'whole':
            # max(1, ...) keeps the grid step positive for boxes under 5 px.
            grid_x = range(self.wh.x // 2, imgwh.x - self.wh.x // 2, max(1, self.wh.x // 5))
            grid_y = range(self.wh.y // 2, imgwh.y - self.wh.y // 2, max(1, self.wh.y // 5))
            scales = range(-5, 5)
            if not grid_x or not grid_y:
                return []
            # Drawing each axis independently is the same as drawing uniformly
            # from itertools.product(grid_x, grid_y, scales), without building
            # every candidate box.
            # NOTE(review): the grid sits half a box inside each border, which
            # looks like centre positions, yet (dx, dy) is used as the top-left
            # corner and boxes are not clipped.  Confirm intent.
            samples = []
            for _ in range(num):
                dx = random.choice(grid_x)
                dy = random.choice(grid_y)
                ds = random.choice(scales)
                samples.append(BoundingBox(dx, dy, self.wh.x * (1.05 ** ds), self.wh.y * (1.05 ** ds)))
            return samples

        if noise_type not in ('gaussian', 'uniform'):
            raise ValueError('unknown noise_type: %r' % (noise_type,))

        samples = []
        for _ in range(num):
            if noise_type == 'gaussian':
                dx = gaussian_translation_f * mean_wh * minmax(0.5 * random.normalvariate(0, 1), -1, 1)
                dy = gaussian_translation_f * mean_wh * minmax(0.5 * random.normalvariate(0, 1), -1, 1)
                dwh = 1.05 ** (3 * minmax(0.5 * random.normalvariate(0, 1), -1, 1))
            else:
                dx = uniform_translation_f * mean_wh * random.uniform(-1.0, 1.0)
                dy = uniform_translation_f * mean_wh * random.uniform(-1.0, 1.0)
                dwh = 1.05 ** (uniform_scale_f * random.uniform(-1.0, 1.0))
            new_cxy = center_xy + (dx, dy)
            new_wh = self.wh * dwh
            # Re-centre the rescaled box on the shifted centre.
            samples.append(BoundingBox(new_cxy.x - new_wh.x / 2.0, new_cxy.y - new_wh.y / 2.0, new_wh.x, new_wh.y))
        return samples

    def get_posneg_samples(self, imgwh, pos_size, neg_size, use_whole=True, **kwargs):
        """Sample positive and negative training boxes around this box.

        Positive candidates are Gaussian perturbations with IoU above
        ``pos_thresh``.  Negative candidates are uniform perturbations with IoU
        below ``neg_thresh``, plus whole-image grid boxes when ``use_whole`` is
        set.  Each returned list is drawn with replacement from its candidates.

        :keyword pos_thresh: defaults to ``initial_finetune.pos_thresh`` from config.
        :keyword neg_thresh: defaults to ``initial_finetune.neg_thresh`` from config.
        Remaining keyword arguments are forwarded to ``gen_noise_samples``.
        :return: ``(pos_samples, neg_samples)``.  A list is empty when no
            candidate passed its threshold.
        """
        # Only read the config when the caller did not supply a threshold.
        pos_thresh = kwargs.get('pos_thresh')
        if pos_thresh is None:
            pos_thresh = ADNetConf.get()['initial_finetune']['pos_thresh']
        neg_thresh = kwargs.get('neg_thresh')
        if neg_thresh is None:
            neg_thresh = ADNetConf.get()['initial_finetune']['neg_thresh']

        # Forward options with ** ; ``kwargs=kwargs`` would nest them under a
        # single 'kwargs' key, and gen_noise_samples would ignore them.
        gaussian_samples = self.gen_noise_samples(imgwh, 'gaussian', pos_size * 2, **kwargs)
        pos_candidates = [x for x in gaussian_samples if x.iou(self) > pos_thresh]

        num_uniform = neg_size if use_whole else neg_size * 2
        uniform_samples = self.gen_noise_samples(imgwh, 'uniform', num_uniform, **kwargs)
        neg_candidates = [x for x in uniform_samples if x.iou(self) < neg_thresh]

        if use_whole:
            whole_samples = self.gen_noise_samples(imgwh, 'whole', neg_size, **kwargs)
            neg_candidates += [x for x in whole_samples if x.iou(self) < neg_thresh]

        # random.choice raises on an empty list; return no samples instead.
        pos_samples = [random.choice(pos_candidates) for _ in range(pos_size)] if pos_candidates else []
        neg_samples = [random.choice(neg_candidates) for _ in range(neg_size)] if neg_candidates else []
        return pos_samples, neg_samples

if __name__ == '__main__':
    # --- IoU sanity checks against hand-computed values -------------------
    box_a, box_b = BoundingBox(0, 0, 100, 100), BoundingBox(0, 0, 50, 10)
    assert box_a.iou(box_b) == 0.05

    box_a, box_b = BoundingBox(0, 0, 10, 10), BoundingBox(5, 7, 7, 10)
    assert 0.096 < box_a.iou(box_b) < 0.097

    # --- noise-sample demo on the first freeman1 frame --------------------
    seq_dir = './data/freeman1/'
    gt_box = BoundingBox.read_vid_gt(seq_dir)[0]
    gt_box.wh.x = gt_box.wh.y = 30  # force a fixed 30x30 target

    imgpath = os.path.join(seq_dir, 'img', '0001.jpg')
    img = cv2.imread(imgpath)

    show_noise_samples = False  # flip to True to visualise every noise type
    if show_noise_samples:
        for random_type in ('gaussian', 'uniform', 'whole'):
            gaussian_boxes = gt_box.gen_noise_samples(Coordinate.get_imgwh(img), random_type, 20)
            gt_box.draw(img, BoundingBox.COLOR_GT)
            for box in gaussian_boxes:
                box.draw(img, BoundingBox.COLOR_PREDICT)
            cv2.imshow(random_type, img)

    # --- positive/negative sample demo ------------------------------------
    pos, neg = gt_box.get_posneg_samples(Coordinate.get_imgwh(img), 1, 10)
    img = cv2.imread(imgpath)  # reload a clean frame
    for box in pos:
        box.draw(img, BoundingBox.COLOR_PREDICT)
    # Negatives are intentionally not drawn here (BoundingBox.COLOR_NEGATIVE).
    gt_box.draw(img, BoundingBox.COLOR_GT)
    actions = BoundingBox.get_action_labels(pos, gt_box)
    # NOTE(review): no cv2.waitKey() follows in the visible code, so the
    # windows may close before they render.
    cv2.imshow('posneg samples', img)