# -*- coding: utf-8 -*-
import os
import pickle

import numpy as np
import cv2


class SIFT(object):
    """Thin wrapper around OpenCV SIFT: extraction, matching, filtering, I/O.

    Feature files written by ``dump`` / read by ``load`` live in the
    ``sift`` sub-directory of the ``root`` passed to the constructor.
    """

    def __init__(self, root):
        """Create the extractor and remember where feature files live.

        Args:
            root: Base directory; feature files are placed under
                ``<root>/sift``.
        """
        self.path = os.path.join(root, 'sift')
        # SIFT is part of the main module since OpenCV 4.4; older builds only
        # expose it through the contrib ``xfeatures2d`` namespace.
        create = getattr(cv2, 'SIFT_create', None)
        if create is None:
            create = cv2.xfeatures2d.SIFT_create
        self.extractor = create()

    def extract(self, gray, rootsift=True):
        """Detect keypoints in an image and compute their descriptors.

        Args:
            gray: Image handed to ``detectAndCompute`` (presumably grayscale,
                judging by the parameter name).
            rootsift: When true, descriptors are converted with
                ``rootsift`` before being returned.

        Returns:
            Tuple ``(kp, des)`` as produced by ``detectAndCompute``; ``des``
            may be ``None`` when no keypoint is found.
        """
        # Compute all keypoints of the image and their descriptors.
        kp, des = self.extractor.detectAndCompute(gray, None)
        if rootsift:
            des = self.rootsift(des)
        return kp, des

    def match(self, des_q, des_t):
        """Match query descriptors against target descriptors.

        Uses a FLANN 2-nearest-neighbour search followed by Lowe's ratio
        test (ratio 0.7).

        Args:
            des_q: Query descriptors, or ``None``.
            des_t: Target descriptors, or ``None``.

        Returns:
            List of ``(query_index, target_index)`` pairs that pass the
            ratio test. Empty when either side has no descriptors or the
            target has fewer than two (the ratio test needs a runner-up).
        """
        ratio = 0.7  # Lowe's ratio test threshold.
        if des_q is None or des_t is None:
            return []
        if len(des_q) == 0 or len(des_t) < 2:
            return []

        flann = cv2.FlannBasedMatcher()
        # For every descriptor in des_q find the two best matches in des_t.
        two_nn = flann.knnMatch(des_q, des_t, k=2)
        # Keep the best matches that are clearly better than the runner-up
        # and return the corresponding index pairs. The matcher may return
        # fewer than two neighbours for a query, so guard the unpacking.
        return [
            (pair[0].queryIdx, pair[0].trainIdx)
            for pair in two_nn
            if len(pair) == 2 and pair[0].distance < ratio * pair[1].distance
        ]

    def filter(self, pt_qt):
        """Flag geometrically consistent point pairs with a RANSAC homography.

        Args:
            pt_qt: Sequence of ``(query_point, target_point)`` pairs, each
                point being an ``(x, y)`` coordinate.

        Returns:
            List with one entry per input pair taken from the mask returned
            by ``cv2.findHomography``. All zeros when there are fewer than
            four pairs or no mask is produced; ``[]`` for empty input.
        """
        count = len(pt_qt)
        if count == 0:
            return []
        # A homography needs at least four correspondences; with fewer,
        # OpenCV either raises or yields no mask, so report no inliers.
        if count < 4:
            return [0] * count

        pt_q, pt_t = zip(*pt_qt)
        # Estimate the transform between the matched coordinates and get the
        # inlier mask (RANSAC, reprojection threshold 3).
        M, mask = cv2.findHomography(np.float32(pt_q).reshape(-1, 1, 2),
                                     np.float32(pt_t).reshape(-1, 1, 2),
                                     cv2.RANSAC, 3)
        if mask is None:
            return [0] * count
        return mask.ravel().tolist()

    def draw(self, img_q, img_t, pt_qt):
        """Plot both images side by side with their matched points connected.

        Args:
            img_q: Query image, shown in the left axes.
            img_t: Target image, shown in the right axes.
            pt_qt: Iterable of ``(query_point, target_point)`` pairs; each
                pair is drawn as two red crosses joined by a green line.
        """
        import matplotlib
        # NOTE(review): 'Agg' is a non-interactive backend, so plt.show()
        # below presumably opens no window - confirm this is intended.
        matplotlib.use('Agg')
        from matplotlib import pyplot as plt
        from matplotlib.patches import ConnectionPatch

        fig, (ax_q, ax_t) = plt.subplots(1, 2, figsize=(8, 4))
        for pt_q, pt_t in pt_qt:
            con = ConnectionPatch(pt_t, pt_q,
                                  coordsA='data', coordsB='data',
                                  axesA=ax_t, axesB=ax_q,
                                  color='g', linewidth=0.5)
            ax_t.add_artist(con)
            ax_q.plot(pt_q[0], pt_q[1], 'rx')
            ax_t.plot(pt_t[0], pt_t[1], 'rx')
        ax_q.imshow(img_q)
        ax_t.imshow(img_t)
        ax_q.axis('off')
        ax_t.axis('off')
        plt.subplots_adjust(wspace=0, hspace=0)
        plt.show()

    @classmethod
    def rootsift(cls, des, eps=1e-7):
        """Convert descriptors to RootSIFT form.

        Each row is L1-normalised (divided by its sum) and then square-rooted
        element-wise. The input array is left untouched.

        Args:
            des: 2-D descriptor array (one descriptor per row), or ``None``.
            eps: Small constant added to the row sums to avoid division by
                zero.

        Returns:
            A new array with the converted descriptors, or ``None`` when
            ``des`` is ``None``.
        """
        if des is None:
            return des
        # Divide out of place so the caller's array is not modified (and so
        # integer-typed input does not fail on in-place true division).
        normalised = des / (des.sum(axis=1, keepdims=True) + eps)
        return np.sqrt(normalised)

    def dump(self, kp, des, filename):
        """Pickle keypoints and descriptors to ``<self.path>/<filename>``.

        ``cv2.KeyPoint`` objects are flattened to plain tuples
        ``(pt, size, angle, response, octave, class_id)`` before pickling.

        Args:
            kp: Iterable of ``cv2.KeyPoint`` objects.
            des: Descriptors stored alongside the keypoints.
            filename: File name inside ``self.path``.
        """
        tmp = [
            (point.pt, point.size, point.angle,
             point.response, point.octave, point.class_id)
            for point in kp
        ]
        with open(os.path.join(self.path, filename), 'wb') as sift_pkl:
            pickle.dump((tmp, des), sift_pkl)

    def load(self, filename):
        """Load keypoints and descriptors written by ``dump``.

        Args:
            filename: File name inside ``self.path``.

        Returns:
            Tuple ``(kp, des)`` where ``kp`` is a list of rebuilt
            ``cv2.KeyPoint`` objects.
        """
        # SECURITY: pickle.load executes arbitrary code from the file; only
        # load feature files that come from a trusted source.
        with open(os.path.join(self.path, filename), 'rb') as sift_pkl:
            tmp, des = pickle.load(sift_pkl)
        # Positional arguments are used because the keyword names of the
        # KeyPoint constructor differ between OpenCV versions
        # (``_size`` in older bindings, ``size`` in newer ones).
        kp = [
            cv2.KeyPoint(t[0][0], t[0][1], t[1], t[2], t[3], t[4], t[5])
            for t in tmp
        ]
        return kp, des