""" Template matching. """ from typing import Callable, Union, Tuple from pathlib import Path import cv2 as cv import numpy as np import logging from matplotlib import pyplot as plt logger = logging.getLogger('tm') # the template matching method TM_METHOD = cv.TM_CCOEFF_NORMED class TM: def __init__(self, feed: Callable, threshold: float = 0.85): """ :param feed: the screencap feed function :param threshold: the default threshold of matching. """ self.feed = feed self.threshold = threshold # template image set self.images = {} self.load_images() # the screencap image. Needs to be updated before matching. self.screen = None def load_image(self, im: Path, name=''): """ Load an image (in png format). May override default images. :param im: path to the image. :param name: specify the name of the image in the dict. If not given, use the filename as default. """ assert im.is_file() and im.name.endswith('.png') name = name or im.name[:-4] self.images[name] = cv.imread(str(im), cv.IMREAD_COLOR) # self.images[name] = cv.cvtColor(self.images[name], cv.COLOR_BGR2RGB) # plt.figure(name) # plt.imshow(self.images[name]) # plt.show() logger.debug('Loaded image {}'.format(name)) def load_images(self): """ Load template images from directory. """ im_dir = Path(__file__).absolute().parent / 'images' for im in im_dir.glob('*.png'): self.load_image(im) logger.info('Images loaded successfully.') def getsize(self, im: str) -> Tuple[int, int]: """ Return the size of given image. :param im: the name of image :return: the size in (width, height) """ h, w, _ = self.images[im].shape return w, h def update_screen(self): """ Update the screencap image from feed. """ self.screen = self.feed() logger.debug('Screen updated.') def probability(self, im: str) -> float: """ Return the probability of the existence of given image. :param im: the name of the image. :return: the probability (confidence). """ assert self.screen is not None try: template = self.images[im] except KeyError: logger.error('Unexpected image name {}'.format(im)) return 0.0 res = cv.matchTemplate(self.screen, template, TM_METHOD) _, max_val, _, max_loc = cv.minMaxLoc(res) logger.debug('max_val = {}, max_loc = {}'.format(max_val, max_loc)) return max_val def find(self, im: str, threshold: float = None) -> Tuple[int, int]: """ Find the template image on screen and return its top-left coords. Return None if the matching value is less than `threshold`. :param im: the name of the image :param threshold: the threshold of matching. If not given, will be set to the default threshold. :return: the top-left coords of the result. Return (-1, -1) if not found. """ threshold = threshold or self.threshold assert self.screen is not None try: template = self.images[im] except KeyError: logger.error('Unexpected image name {}'.format(im)) return -1, -1 res = cv.matchTemplate(self.screen, template, TM_METHOD) _, max_val, _, max_loc = cv.minMaxLoc(res) logger.debug('max_val = {}, max_loc = {}'.format(max_val, max_loc)) return max_loc if max_val >= threshold else (-1, -1) def exists(self, im: str, threshold: float = None) -> bool: """ Check if a given image exists on screen. :param im: the name of the image :param threshold: the threshold of matching. If not given, will be set to the default threshold. """ threshold = threshold or self.threshold return self.probability(im) >= threshold