import cv2
import numpy as np
import imutils
import typing
import copy
import datetime
import tempfile
import contextlib
import os
from collections import namedtuple
from scipy.spatial.distance import euclidean

Point = namedtuple("Point", ("x", "y"))


def load_grey_from_path(pic_path: str) -> np.ndarray:
    """ load grey picture (with cv2) from path """
    assert os.path.isfile(pic_path), f"picture [{pic_path}] not existed"
    raw_img = cv2.imread(pic_path)
    return load_grey_from_cv2_object(raw_img)


def load_grey_from_cv2_object(pic_object: np.ndarray) -> np.ndarray:
    """ preparation for cv2 object (force turn it into gray) """
    pic_object = pic_object.astype(np.uint8)
    try:
        # try to turn it into grey
        grey_pic = cv2.cvtColor(pic_object, cv2.COLOR_BGR2GRAY)
    except cv2.error:
        # already grey
        return pic_object
    return grey_pic


def pre_pic(pic_path: str = None, pic_object: np.ndarray = None) -> np.ndarray:
    """ this method will turn pic path and pic object into grey pic object """
    if pic_object is not None:
        return load_grey_from_cv2_object(pic_object)
    return load_grey_from_path(pic_path)


def resize_pic_scale(pic_object: np.ndarray, target_scale: np.ndarray) -> np.ndarray:
    return imutils.resize(pic_object, width=int(pic_object.shape[1] * target_scale))


def turn_grey(old: np.ndarray) -> np.ndarray:
    try:
        return cv2.cvtColor(old, cv2.COLOR_RGB2GRAY)
    except cv2.error:
        return old


def decompress_point(old: typing.Tuple, compress_rate: float) -> typing.List:
    return [int(i / compress_rate) for i in old]


def compress_frame(
    old: np.ndarray,
    compress_rate: float = None,
    target_size: typing.Tuple[int, int] = None,
    not_grey: bool = None,
    interpolation: int = None,
) -> np.ndarray:
    """
    Compress frame

    :param old:
        origin frame

    :param compress_rate:
        before_pic * compress_rate = after_pic. default to 1 (no compression)
        eg: 0.2 means 1/5 size of before_pic

    :param target_size:
        tuple. (100, 200) means compressing before_pic to 100x200

    :param not_grey:
        convert into grey if True

    :param interpolation:
    :return:
    """

    target = turn_grey(old) if not not_grey else old
    if not interpolation:
        interpolation = cv2.INTER_AREA
    # target size first
    if target_size:
        return cv2.resize(target, target_size, interpolation=interpolation)
    # else, use compress rate
    # default rate is 1 (no compression)
    if not compress_rate:
        return target
    return cv2.resize(
        target, (0, 0), fx=compress_rate, fy=compress_rate, interpolation=interpolation
    )


def fix_location(shape: typing.Sequence, location: typing.Sequence) -> typing.Sequence:
    """ location from cv2 should be left-top location, and need to fix it and make it central """
    size_y, size_x = shape
    old_x, old_y = location
    return old_x + size_x / 2, old_y + size_y / 2


def mark_point(
    pic_object: np.ndarray, location: typing.Sequence, cover: bool = None
) -> np.ndarray:
    """ draw a mark on your picture, or your picture copy. """
    if not cover:
        pic_object = copy.deepcopy(pic_object)
    distance = 50
    target_x, target_y = map(int, location)
    start_point = (target_x - distance, target_y - distance)
    end_point = (target_x + distance, target_y + distance)
    cv2.rectangle(pic_object, start_point, end_point, -1)
    return pic_object


def get_timestamp() -> str:
    return datetime.datetime.now().strftime("%Y%m%d%H%M%S")


@contextlib.contextmanager
def cv2file(pic_object: np.ndarray) -> str:
    """ save cv object to file, and return its path """
    temp_pic_file_object = tempfile.NamedTemporaryFile(
        mode="wb+", suffix=".png", delete=False
    )
    cv2.imwrite(temp_pic_file_object.name, pic_object)
    temp_pic_file_object_path = temp_pic_file_object.name
    yield temp_pic_file_object_path
    os.remove(temp_pic_file_object_path)


def point_list_filter(
    point_list: typing.Sequence, distance: float, point_limit: int = None
) -> typing.Sequence:
    """ remove some points which are too close """
    if not point_limit:
        point_limit = 20

    point_list = sorted(list(set(point_list)), key=lambda o: o[0])
    new_point_list = [point_list[0]]
    for cur_point in point_list[1:]:
        for each_confirmed_point in new_point_list:
            cur_distance = euclidean(cur_point, each_confirmed_point)
            # existed
            if cur_distance < distance:
                break
        else:
            new_point_list.append(cur_point)
            if len(new_point_list) >= point_limit:
                break
    return new_point_list


def debug_cv_object(target_object: np.ndarray, prefix: str) -> str:
    """ save target object as a temp picture, and return its path """
    mark_pic_path = f"{prefix}_{get_timestamp()}.png"
    cv2.imwrite(mark_pic_path, target_object)
    return mark_pic_path