# -*- coding: utf-8 -*-
"""
Created on 2018-03-13 15:29:29
@Author: Ben
@Version : 0.0.1

Feature-based stitching of two images: detect keypoints, match them,
estimate a projective transform, warp both images onto a common canvas and
blend them.
"""

from enum import Enum
from typing import List, Tuple, Union
import unittest
import os
import random

import cv2
import numpy as np

import k_means
import ransac
import blend


def show_image(image: np.ndarray) -> None:
    """Display a BGR image through PIL's default viewer.

    Args:
        image (np.ndarray): image in OpenCV BGR channel order; it is
            converted to RGB before being handed to PIL.
    """
    # Imported lazily so that PIL is only required when something is shown.
    from PIL import Image
    Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).show()


class Method(Enum):
    """Feature detector factories; each value is the OpenCV factory callable.

    NOTE(review): SURF and SIFT are looked up under ``cv2.xfeatures2d``, so
    importing this module presumably requires an OpenCV build that ships that
    sub-module -- confirm against the deployed OpenCV version.
    """

    SURF = cv2.xfeatures2d.SURF_create
    SIFT = cv2.xfeatures2d.SIFT_create
    ORB = cv2.ORB_create


# Palette used when drawing the lines that connect matched points.
colors = ((123, 234, 12), (23, 44, 240), (224, 120, 34), (21, 234, 190),
          (80, 160, 200), (243, 12, 100), (25, 90, 12), (123, 10, 140))


class Area:
    """A region described by a list of points (containment not implemented)."""

    def __init__(self, *points):
        """Store the given points as a list in ``self.points``."""
        self.points = list(points)

    def is_inside(self, x: Union[float, Tuple[float, float]],
                  y: float = None) -> bool:
        """Placeholder for a point-in-area test.

        Args:
            x: x coordinate, or an ``(x, y)`` tuple.
            y: y coordinate; ignored when ``x`` is a tuple.

        Raises:
            NotImplementedError: always.
        """
        if isinstance(x, tuple):
            x, y = x
        raise NotImplementedError()


class Matcher():
    """Detect features in two images and match them."""

    def __init__(self, image1: np.ndarray, image2: np.ndarray,
                 method: Enum = Method.SIFT, threshold=800) -> None:
        """Store the two images and choose a descriptor matcher.

        Args:
            image1 (np.ndarray): first image.
            image2 (np.ndarray): second image.
            method (Enum, optional): Defaults to Method.SIFT. Feature
                detection method (SURF, SIFT or ORB).
            threshold (int, optional): Defaults to 800. Passed as the single
                positional argument to the detector factory; its exact
                meaning depends on the chosen detector.
        """
        self.image1 = image1
        self.image2 = image2
        self.method = method
        self.threshold = threshold

        self._keypoints1: List[cv2.KeyPoint] = None
        self._descriptors1: np.ndarray = None
        self._keypoints2: List[cv2.KeyPoint] = None
        self._descriptors2: np.ndarray = None

        if self.method == Method.ORB:
            # ORB matching errors out unless the Hamming norm is selected.
            self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        else:
            self.matcher = cv2.FlannBasedMatcher()

        self.match_points = []

        self.image_points1 = np.array([])
        self.image_points2 = np.array([])

    def compute_keypoint(self) -> None:
        """Run the configured detector on both images.

        Fills ``_keypoints1/_descriptors1`` and ``_keypoints2/_descriptors2``.
        """
        feature = self.method.value(self.threshold)
        self._keypoints1, self._descriptors1 = feature.detectAndCompute(
            self.image1, None)
        self._keypoints2, self._descriptors2 = feature.detectAndCompute(
            self.image2, None)

    def match(self, max_match_lenth=20, threshold=0.04, show_match=False):
        """Match the descriptors of both images and keep the best pairs.

        ORB descriptors are matched with OpenCV's BFMatcher, every other
        method with FlannBasedMatcher (see ``__init__``). The retained pairs
        are stored in ``match_points`` and their coordinates in
        ``image_points1`` / ``image_points2`` as float32 arrays.

        Args:
            max_match_lenth (int, optional): Defaults to 20. Maximum number
                of matches to keep.
            threshold (float, optional): Defaults to 0.04. Currently unused;
                kept for interface compatibility.
            show_match (bool, optional): Defaults to False. Whether to
                display the drawn matches.

        Raises:
            AssertionError: if fewer than 4 matches survive the filtering
                (including the case where no match is found at all).
        """
        self.compute_keypoint()

        # A detector that finds no keypoints yields None descriptors; treat
        # that as "no matches" instead of passing None to the matcher.
        if self._descriptors1 is None or self._descriptors2 is None:
            raw_matches = []
        else:
            raw_matches = self.matcher.match(
                self._descriptors1, self._descriptors2)

        # Sort by descriptor distance so the best matches come first.
        self.match_points = sorted(raw_matches, key=lambda m: m.distance)
        if not self.match_points:
            raise AssertionError(
                "no feature matches found between the two images")

        match_len = min(len(self.match_points), max_match_lenth)

        # Floor of 20 in case the best distance is 0.
        max_distance = max(2 * self.match_points[0].distance, 20)

        # Cut the list at the first match that is too far from the best one.
        for i, candidate in enumerate(self.match_points[:match_len]):
            if candidate.distance > max_distance:
                match_len = i
                break

        if match_len > 0:
            # Index match_len - 1 is the last retained match; indexing at
            # match_len overruns the list when every candidate is kept.
            print('max distance: ',
                  self.match_points[match_len - 1].distance)
        print("Min distance: ", self.match_points[0].distance)
        print('match_len: ', match_len)

        # Explicit raise instead of `assert` so the check survives `python -O`.
        if match_len < 4:
            raise AssertionError(
                "at least 4 matches are required, got %d" % match_len)
        self.match_points = self.match_points[:match_len]

        if show_match:
            img3 = cv2.drawMatches(self.image1, self._keypoints1,
                                   self.image2, self._keypoints2,
                                   self.match_points, None, flags=0)
            show_image(img3)

        # Collect the pixel coordinates of every retained pair.
        self.image_points1 = np.float32(
            [self._keypoints1[m.queryIdx].pt for m in self.match_points])
        self.image_points2 = np.float32(
            [self._keypoints2[m.trainIdx].pt for m in self.match_points])


def get_weighted_points(image_points: np.ndarray):
    """Return the points with the one farthest from their mean duplicated.

    Args:
        image_points (np.ndarray): array of points, one per row.

    Returns:
        np.ndarray: ``image_points`` with one extra row appended, a copy of
        the row whose Euclidean distance to the column-wise mean is largest.
    """
    average = np.average(image_points, axis=0)

    max_index = np.argmax(np.linalg.norm((image_points - average), axis=1))
    return np.append(image_points, np.array([image_points[max_index]]), axis=0)


class Stitcher:
    """Stitch two images using matched feature points."""

    def __init__(self, image1: np.ndarray, image2: np.ndarray,
                 method: Enum = Method.SIFT, use_kmeans=False):
        """Store the images and build the feature matcher.

        Args:
            image1 (np.ndarray): first image (the one that gets warped).
            image2 (np.ndarray): second image (only translated).
            method (Enum, optional): Defaults to Method.SIFT. Feature
                detection method forwarded to ``Matcher``.
            use_kmeans (bool): whether to reduce the matched points with
                ``k_means.get_group_center`` before estimating the transform.
        """
        self.image1 = image1
        self.image2 = image2
        self.method = method
        self.use_kmeans = use_kmeans
        self.matcher = Matcher(image1, image2, method=method)
        # Transform applied to image1; identity until stich() estimates it.
        self.M = np.eye(3)
        # Translation applied to image2; identity until stich() computes it,
        # so the coordinate helpers are usable before stich() has run.
        self.adjustM = np.eye(3)
        self.image_points1 = np.array([])
        self.image_points2 = np.array([])

        self.image = None

    def stich(self, show_result=True, max_match_lenth=40,
              show_match_point=True, use_partial=False,
              use_new_match_method=False, use_gauss_blend=True):
        """Match, align and blend the two images; result is in ``self.image``.

        Args:
            show_result (bool, optional): Defaults to True. Whether to
                display the stitched image.
            max_match_lenth (int, optional): Defaults to 40. Maximum number
                of matches, forwarded to ``Matcher.match``.
            show_match_point (bool, optional): Defaults to True. Whether to
                display the matches and draw the matched points onto the
                result.
            use_partial (bool, optional): Defaults to False. Calls the
                deprecated ``partial_transform``, which always raises.
            use_new_match_method (bool, optional): Defaults to False. Use
                ``ransac.GeneticTransform`` instead of ``cv2.findHomography``.
            use_gauss_blend (bool, optional): Defaults to True. Use
                ``blend.gaussian_blend`` instead of ``blend.direct_blend``.

        Raises:
            AssertionError: propagated from ``Matcher.match`` when fewer
                than 4 matches are found.
            ValueError: if no transform could be estimated.
        """
        self.matcher.match(max_match_lenth=max_match_lenth,
                           show_match=show_match_point)

        if self.use_kmeans:
            self.image_points1, self.image_points2 = k_means.get_group_center(
                self.matcher.image_points1, self.matcher.image_points2)
        else:
            self.image_points1, self.image_points2 = (
                self.matcher.image_points1, self.matcher.image_points2)

        if use_new_match_method:
            self.M = ransac.GeneticTransform(
                self.image_points1, self.image_points2).run()
        else:
            self.M, _ = cv2.findHomography(
                self.image_points1, self.image_points2, method=cv2.RANSAC)

        # Fail here with a clear message rather than later inside np.dot.
        if self.M is None:
            raise ValueError(
                "could not estimate a transform from the matched points")

        print("Good points and average distance: ",
              ransac.GeneticTransform.get_value(
                  self.image_points1, self.image_points2, self.M))

        left, right, top, bottom = self.get_transformed_size()
        width = int(max(right, self.image2.shape[1]) - min(left, 0))
        height = int(max(bottom, self.image2.shape[0]) - min(top, 0))
        print(width, height)
        if width * height > 8000 * 5000:
            # Shrink the canvas instead of raising MemoryError.
            # NOTE(review): dividing both sides by `factor` reduces the area
            # by factor**2, i.e. below the 8000*5000 cap, and the images are
            # not rescaled -- confirm whether sqrt(factor) was intended.
            factor = width * height / (8000 * 5000)
            width = int(width / factor)
            height = int(height / factor)

        if use_partial:
            self.partial_transform()

        # Translation that moves negative coordinates back onto the canvas.
        self.adjustM = np.array(
            [[1, 0, max(-left, 0)],  # horizontal shift
             [0, 1, max(-top, 0)],   # vertical shift
             [0, 0, 1]
             ], dtype=np.float64)
        self.M = np.dot(self.adjustM, self.M)
        transformed_1 = cv2.warpPerspective(
            self.image1, self.M, (width, height))
        transformed_2 = cv2.warpPerspective(
            self.image2, self.adjustM, (width, height))

        self.image = self.blend(transformed_1, transformed_2,
                                use_gauss_blend=use_gauss_blend)

        if show_match_point:
            for point1, point2 in zip(self.image_points1,
                                      self.image_points2):
                point1 = self.get_transformed_position(tuple(point1))
                point1 = tuple(map(int, point1))
                point2 = self.get_transformed_position(
                    tuple(point2), M=self.adjustM)
                point2 = tuple(map(int, point2))

                cv2.line(self.image, point1, point2,
                         random.choice(colors), 3)
                cv2.circle(self.image, point1, 10, (20, 20, 255), 5)
                cv2.circle(self.image, point2, 8, (20, 200, 20), 5)
        if show_result:
            show_image(self.image)

    def partial_transform(self):
        """Deprecated, should not be used.

        The former experimental implementation followed an unconditional
        ``raise`` and could never execute, so it has been dropped.

        Raises:
            DeprecationWarning: always.
        """
        raise DeprecationWarning("Out of work, should not be used")

    def blend(self, image1: np.ndarray, image2: np.ndarray,
              use_gauss_blend=True) -> np.ndarray:
        """Blend two already-warped images.

        Args:
            image1 (np.ndarray): first warped image.
            image2 (np.ndarray): second warped image.
            use_gauss_blend (bool, optional): Defaults to True. Selects
                ``blend.gaussian_blend`` (mask_blend=10) over
                ``blend.direct_blend`` (mask_blend=0).

        Returns:
            np.ndarray: the blended result.
        """
        mask = self.generate_mask(image1, image2)
        print("Blending")
        if use_gauss_blend:
            result = blend.gaussian_blend(image1, image2, mask, mask_blend=10)
        else:
            result = blend.direct_blend(image1, image2, mask, mask_blend=0)

        return result

    def generate_mask(self, image1: np.ndarray, image2: np.ndarray):
        """Build the blending mask.

        The dividing line is the perpendicular bisector of the segment that
        joins the two transformed image centers.

        Args:
            image1 (np.ndarray): first warped image; its shape is the mask
                shape.
            image2 (np.ndarray): second warped image.

        Returns:
            np.ndarray: boolean array with the shape of ``image1``.
        """
        print("Generating mask")
        # Centers are (x, y).
        center1 = self.image1.shape[1] / 2, self.image1.shape[0] / 2
        center1 = self.get_transformed_position(center1)
        center2 = self.image2.shape[1] / 2, self.image2.shape[0] / 2
        center2 = self.get_transformed_position(center2, M=self.adjustM)
        # Perpendicular bisector:
        #   y = -(x2-x1)/(y2-y1) * [x-(x1+x2)/2] + (y1+y2)/2
        x1, y1 = center1
        x2, y2 = center2

        # Array indices come as (y, x); the inequality is multiplied through
        # by (y2 - y1) so no division is needed.
        def function(y, x, *z):
            return (y2 - y1) * y < -(x2 - x1) * (x - (x1 + x2) / 2) \
                + (y2 - y1) * (y1 + y2) / 2

        mask = np.fromfunction(function, image1.shape)

        # mask = (mask & ~image2) | (mask & image1) | (image1 & ~image2);
        # `+` on boolean arrays acts as logical OR.
        mask = np.logical_and(mask, np.logical_not(image2)) \
            + np.logical_and(mask, image1)\
            + np.logical_and(image1, np.logical_not(image2))

        return mask

    def get_transformed_size(self) -> Tuple[float, float, float, float]:
        """Compute the bounds of image1 after transformation by ``self.M``.

        The bounds are used to translate the result so that the whole image
        stays on the canvas.

        Returns:
            Tuple[float, float, float, float]: left, right, top and bottom
            bounds, in that order.
        """
        img_height, img_width = self.image1.shape[0], self.image1.shape[1]
        # Corners as (x, y).
        corners = [(0, 0), (img_width, 0),
                   (img_width, img_height), (0, img_height)]

        # Transform every corner once and reuse the results.
        transformed = [self.get_transformed_position(c) for c in corners]
        xs = [point[0] for point in transformed]
        ys = [point[1] for point in transformed]

        return min(xs), max(xs), min(ys), max(ys)

    def get_transformed_position(self, x: Union[float, Tuple[float, float]],
                                 y: float = None,
                                 M=None) -> Tuple[float, float]:
        """Return the coordinates of a point under a projective transform.

        Args:
            x (Union[float, Tuple[float, float]]): x coordinate, or an
                ``(x, y)`` pair (tuple, list or array).
            y (float, optional): Defaults to None. y coordinate; ignored
                when ``x`` is a pair.
            M (np.ndarray, optional): Defaults to None. 3x3 matrix to apply;
                ``self.M`` is used when omitted.

        Returns:
            Tuple[float, float]: the transformed ``(x, y)``.
        """
        if isinstance(x, (tuple, list, np.ndarray)):
            x, y = x
        p = np.array([x, y, 1])[np.newaxis].T
        if M is None:
            M = self.M
        pa = np.dot(M, p)
        # Divide by the homogeneous coordinate.
        return pa[0, 0] / pa[2, 0], pa[1, 0] / pa[2, 0]


class Test(unittest.TestCase):
    """Unit tests; several of them open an image viewer."""

    def _test_matcher(self):
        # Leading underscore keeps unittest from collecting this test.
        image1 = np.random.randint(100, 256, size=(400, 400, 3), dtype='uint8')
        image2 = np.copy(image1)
        for method in Method:
            matcher = Matcher(image1, image2, method)
            matcher.match(show_match=True)

    def test_transform_coord(self):
        stitcher = Stitcher(None, None, None, None)
        self.assertEqual((0, 0), stitcher.get_transformed_position(0, 0))
        self.assertEqual((10, 20), stitcher.get_transformed_position(10, 20))

        # Pure translation.
        stitcher.M[0, 2] = 20
        stitcher.M[1, 2] = 10
        self.assertEqual((20, 10), stitcher.get_transformed_position(0, 0))
        self.assertEqual((30, 30), stitcher.get_transformed_position(10, 20))

        # Shear terms only.
        stitcher.M = np.eye(3)
        stitcher.M[0, 1] = 2
        stitcher.M[1, 0] = 4
        self.assertEqual((0, 0), stitcher.get_transformed_position(0, 0))
        self.assertEqual((50, 60), stitcher.get_transformed_position(10, 20))

    def test_get_transformed_size(self):
        image1 = np.empty((500, 400, 3), dtype='uint8')
        image1[:, :] = 255, 150, 100
        image1[:, 399] = 10, 20, 200
        image2 = np.empty((400, 400, 3), dtype='uint8')
        image2[:, :] = 50, 150, 255

        stitcher = Stitcher(image1, image2, None, None)
        stitcher.M[0, 2] = -20
        stitcher.M[1, 2] = 10
        stitcher.M[0, 1] = .2
        stitcher.M[1, 0] = .1
        left, right, top, bottom = stitcher.get_transformed_size()
        print(stitcher.get_transformed_size())
        width = int(max(right, image2.shape[1]) - min(left, 0))
        height = int(max(bottom, image2.shape[0]) - min(top, 0))
        print(width, height)
        show_image(cv2.warpPerspective(image1, stitcher.M, (width, height)))

    def test_stich(self):
        # NOTE(review): this passes point arrays where Stitcher expects
        # `method` and `use_kmeans`; it looks written against an older
        # constructor signature -- confirm and update or disable.
        image1 = np.empty((500, 400, 3), dtype='uint8')
        image1[:, :] = 255, 150, 100
        image1[:, 399] = 10, 20, 200
        image2 = np.empty((400, 400, 3), dtype='uint8')
        image2[:, :] = 50, 150, 255

        points = np.float32([[0, 0], [20, 20], [12, 12], [40, 20]])
        stitcher = Stitcher(image1, image2, points, points)
        stitcher.M[0, 2] = 20
        stitcher.M[1, 2] = 10
        stitcher.M[0, 1] = .2
        stitcher.M[1, 0] = .1
        stitcher.stich()


def main():
    """Run the unit tests."""
    unittest.main()


if __name__ == "__main__":
    import time
    # main()
    # abspath first: dirname(__file__) is '' when the script is started from
    # its own directory, and os.chdir('') raises.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    start_time = time.time()
    img1 = cv2.imread("../resource/29-left.jpg")
    img2 = cv2.imread("../resource/29-right.jpg")
    # cv2.imread signals failure by returning None instead of raising.
    if img1 is None or img2 is None:
        raise FileNotFoundError(
            "could not read the input images under ../resource/")
    stitcher = Stitcher(img1, img2, Method.SIFT, False)
    stitcher.stich(max_match_lenth=40, use_partial=False,
                   use_new_match_method=1, use_gauss_blend=0)

    # cv2.imwrite('../resource/19-sift-gf.jpg', stitcher.image)

    print("Time: ", time.time() - start_time)
    print("M: ", stitcher.M)