import os
import pickle

import cv2
import numpy as np

from ikalog.inputs.filters import Filter, WarpFilterModel
from ikalog.utils import *

class WarpCalibrationException(Exception):
    """Base class for errors raised while calibrating the warp filter."""


class WarpCalibrationNotFound(WarpCalibrationException):
    """Raised when the calibration image cannot be located reliably."""


class WarpCalibrationUnacceptableSize(WarpCalibrationException):
    """Raised when a detected calibration region is rejected.

    Attributes:
        shape: The ``(width, height)`` of the rejected region, as supplied
            by the raiser.
    """

    def __init__(self, shape):
        # Forward ``shape`` to Exception so ``args``/``str()`` are populated
        # and the exception can be pickled and unpickled (unpickling calls
        # ``cls(*args)``, which needs ``shape``).
        super().__init__(shape)
        self.shape = shape


class WarpFilter(Filter):

    def filter_matches(self, kp1, kp2, matches, ratio=0.75):
        """Keep only unambiguous k-NN matches using the ratio test.

        Args:
            kp1: Keypoints of the query image, indexed by ``DMatch.queryIdx``.
            kp2: Keypoints of the train image, indexed by ``DMatch.trainIdx``.
            matches: Per-query candidate lists, as returned by
                ``knnMatch(..., k=2)`` (best candidate first).
            ratio: A match is kept only when its distance is below ``ratio``
                times the distance of the second-best candidate.

        Returns:
            ``(p1, p2, kp_pairs)``: float32 arrays of the matched point
            coordinates in each image, and a list of ``(kp1, kp2)`` keypoint
            pairs.
        """
        mkp1, mkp2 = [], []
        for candidates in matches:
            # Queries with fewer than two candidates cannot be ratio-tested.
            if len(candidates) != 2:
                continue
            best, second = candidates
            if best.distance < second.distance * ratio:
                mkp1.append(kp1[best.queryIdx])
                mkp2.append(kp2[best.trainIdx])
        p1 = np.float32([kp.pt for kp in mkp1])
        p2 = np.float32([kp.pt for kp in mkp2])
        # A list rather than a one-shot zip iterator, so callers can take
        # len() and iterate it more than once.
        kp_pairs = list(zip(mkp1, mkp2))
        return p1, p2, kp_pairs

    def set_bbox(self, x, y, w, h):
        """Warp an axis-aligned rectangle of the input frame onto the output.

        The rectangle ``(x, y, w, h)`` is stored as ``self.pts1`` and mapped
        onto the destination corners in ``self.pts2``. The resulting
        perspective matrix is stored in ``self.M``.

        Returns:
            Always ``True``.
        """
        # Corner order TL, TR, BR, BL matches the order reset() uses for
        # self.pts2, so corresponding corners line up.
        self.pts1 = np.float32(
            [[x, y], [x + w, y], [x + w, y + h], [x, y + h]]
        )

        IkaUtils.dprint('pts1: %s' % [self.pts1])
        IkaUtils.dprint('pts2: %s' % [self.pts2])

        self.M = cv2.getPerspectiveTransform(self.pts1, self.pts2)
        return True

    def calibrateWarp(self, capture_image, validation_func=None, min_matches=1000):
        """Locate the calibration image inside ``capture_image`` and set the warp.

        Features detected in the capture image are matched against the stored
        calibration-image features, and a RANSAC homography is estimated. The
        calibration image's corners are then projected through it, and
        ``self.M`` is set to map that quadrilateral onto ``self.pts2``.

        Args:
            capture_image: BGR image to calibrate against.
            validation_func: Optional callable that receives the projected
                corners (4x2 float32, TL/TR/BR/BL). A falsy result rejects
                the calibration.
            min_matches: Minimum number of ratio-test matches (homography
                input points) required to accept the result. Defaults to 1000.

        Returns:
            ``True`` on success.

        Raises:
            WarpCalibrationNotFound: If there are too few features or matches,
                or no homography could be estimated.
            WarpCalibrationUnacceptableSize: If ``validation_func`` rejects
                the projected region.
        """
        capture_image_gray = cv2.cvtColor(capture_image, cv2.COLOR_BGR2GRAY)

        capture_image_keypoints, capture_image_descriptors = \
            self.detector.detectAndCompute(capture_image_gray, None)
        print('caputure_image - %d features' % (len(capture_image_keypoints)))

        # With no features the descriptors may be None, which knnMatch
        # rejects. This is the same "not enough matches" situation as below.
        if capture_image_descriptors is None or len(capture_image_descriptors) == 0:
            self.calibration_requested = False
            raise WarpCalibrationNotFound()

        # Calibration descriptors are the query set and capture descriptors
        # the train set. queryIdx therefore indexes calibration keypoints and
        # trainIdx capture keypoints. k=2 provides the ratio-test candidates.
        raw_matches = self.matcher.knnMatch(
            self.calibration_image_descriptors,
            capture_image_descriptors,
            k=2,
        )
        p1, p2, _ = self.filter_matches(
            self.calibration_image_keypoints,
            capture_image_keypoints,
            raw_matches,
        )

        # A homography needs at least four point correspondences.
        if len(p1) < 4:
            print('%d matches found, not enough for homography estimation' % len(p1))
            self.calibration_requested = False
            raise WarpCalibrationNotFound()

        H, status = cv2.findHomography(p1, p2, cv2.RANSAC, 5.0)

        # findHomography yields no matrix (None) when estimation fails.
        # Check before touching ``status`` so the report below cannot crash.
        if H is None:
            self.calibration_requested = False
            raise WarpCalibrationNotFound()

        print('%d / %d  inliers/matched' % (np.sum(status), len(status)))

        # ``status`` has one entry per matched point. This is a match-count
        # threshold, not an inlier-count threshold.
        if len(status) < min_matches:
            raise WarpCalibrationNotFound()

        calibration_image_height, calibration_image_width = self.calibration_image_size

        corners = np.float32(
            [[0, 0],
             [calibration_image_width, 0],
             [calibration_image_width, calibration_image_height],
             [0, calibration_image_height]]
        )

        # Project the calibration image's corners into capture-image coordinates.
        pts1 = np.float32(cv2.perspectiveTransform(
            corners.reshape(1, -1, 2), H).reshape(-1, 2))

        IkaUtils.dprint('pts1: %s' % [pts1])
        IkaUtils.dprint('pts2: %s' % [self.pts2])

        if validation_func is not None and not validation_func(pts1):
            w = int(pts1[1][0] - pts1[0][0])
            h = int(pts1[2][1] - pts1[1][1])
            raise WarpCalibrationUnacceptableSize((w, h))

        self.M = cv2.getPerspectiveTransform(pts1, self.pts2)
        return True

    def tuples2keyPoints(self, tuples):
        """Rebuild ``cv2.KeyPoint`` objects from plain tuples.

        This is the inverse of ``keyPoints2tuples``. ``loadModelFromFile``
        uses it to restore keypoints that were stored as tuples.

        Args:
            tuples: Iterable of ``(pt, size, angle, response, octave,
                class_id)``, where ``pt`` is an ``(x, y)`` pair.

        Returns:
            A list of ``cv2.KeyPoint``.
        """
        return [
            cv2.KeyPoint(pt[0], pt[1], size, angle, response, octave, class_id)
            for pt, size, angle, response, octave, class_id in tuples
        ]

    def keyPoints2tuples(self, points):
        """Convert keypoints into plain tuples.

        Each tuple is ``(pt, size, angle, response, octave, class_id)``, the
        exact six-field layout that ``tuples2keyPoints`` unpacks.

        Args:
            points: Iterable of ``cv2.KeyPoint``-like objects.

        Returns:
            A list of tuples.
        """
        return [
            (point.pt, point.size, point.angle, point.response, point.octave,
             point.class_id)
            for point in points
        ]

    def loadModelFromFile(self, file):
        """Load calibration-image features from a pickle file.

        The pickled object is indexed as ``[size, keypoint_tuples,
        descriptors]``. Keypoint tuples are turned back into keypoints with
        ``tuples2keyPoints``.

        Args:
            file: Path of the pickle file to read.
        """
        # SECURITY: pickle.load can execute arbitrary code. Only load model
        # files from a trusted source.
        with open(file, 'rb') as f:
            model = pickle.load(f)
        self.calibration_image_size = model[0]
        self.calibration_image_keypoints = self.tuples2keyPoints(model[1])
        self.calibration_image_descriptors = model[2]

    def saveModelToFile(self, file):
        """Write the calibration-image features to ``file`` with pickle.

        The payload is ``[calibration_image_size, keypoint_tuples,
        calibration_image_descriptors]``, the layout that
        ``loadModelFromFile`` indexes. Keypoints are first converted to
        tuples with ``keyPoints2tuples``.

        Args:
            file: Path of the pickle file to (over)write.
        """
        with open(file, 'wb') as f:
            pickle.dump([
                self.calibration_image_size,
                self.keyPoints2tuples(self.calibration_image_keypoints),
                self.calibration_image_descriptors,
            ], f)

    def initializeCalibration(self):
        """Adopt the feature detector, matcher and reference features of a WarpFilterModel.

        Raises:
            Exception: If the freshly built model reports it is not trained.
        """
        model = WarpFilterModel()
        if not model.trained:
            raise Exception('Could not intialize WarpFilterModel')

        # calibrateWarp() uses self.detector and self.matcher.
        self.detector, self.norm, self.matcher = (
            model.detector,
            model.norm,
            model.matcher,
        )

        # Reference features that capture frames are matched against.
        (self.calibration_image_size,
         self.calibration_image_keypoints,
         self.calibration_image_descriptors) = (
            model.calibration_image_size,
            model.calibration_image_keypoints,
            model.calibration_image_descriptors,
        )


    def reset(self):
        """Reset the warp to the identity transform for a 1280x720 output.

        ``self.pts2`` holds the destination corners (TL, TR, BR, BL), which
        set_bbox and calibrateWarp later map onto.
        """
        out_w, out_h = 1280, 720

        self.pts2 = np.float32([
            [0, 0],
            [out_w, 0],
            [out_w, out_h],
            [0, out_h],
        ])
        # Mapping the corners onto themselves gives the identity homography.
        self.M = cv2.getPerspectiveTransform(self.pts2, self.pts2)

    def pre_execute(self, frame):
        """Hook consulted by execute(). A falsy result passes the frame through unwarped.

        This default implementation always allows the warp.
        """
        return True

    def execute(self, frame):
        """Return ``frame`` warped by ``self.M`` into a 1280x720 image.

        If the filter is disabled or pre_execute() returns a falsy value, the
        input frame is returned unchanged.
        """
        if self.enabled and self.pre_execute(frame):
            return cv2.warpPerspective(frame, self.M, (1280, 720))
        return frame

    def __init__(self, parent, debug=False):
        super().__init__(parent, debug=debug)