""" This module implements the co-registration transformers. Credits: Copyright (c) 2017-2019 Matej Aleksandrov, Matej Batič, Andrej Burja, Eva Erzin (Sinergise) Copyright (c) 2017-2019 Grega Milčinski, Matic Lubej, Devis Peresutti, Jernej Puc, Tomislav Slijepčević (Sinergise) Copyright (c) 2017-2019 Blaž Sovdat, Nejc Vesel, Jovan Višnjić, Anže Zupanc, Lojze Žust (Sinergise) This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree. """ import logging import copy from abc import ABC, abstractmethod from enum import Enum import registration import cv2 import numpy as np from eolearn.core import EOTask, FeatureType from .coregistration_utilities import ransac, EstimateEulerTransformModel LOGGER = logging.getLogger(__name__) MAX_TRANSLATION = 20 MAX_ROTATION = np.pi / 9 class InterpolationType(Enum): """ Types of interpolation, available are NEAREST, LINEAR and CUBIC """ NEAREST = 0 LINEAR = 1 CUBIC = 3 class RegistrationTask(EOTask, ABC): """ Abstract class for multi-temporal image co-registration The task uses a temporal stack of images of the same location (i.e. a temporal-spatial feature in `EOPatch`). Starting from the latest frame and proceeding backwards it calculates a transformation between two temporally adjacent images. The transformation is used to correct the earlier image to best fit the later. The reason for such reversed order is that the latest frames are supposed to be less affected by orthorectificational inaccuracies. Each transformation is calculated using only a single channel of the images. If feature which contains masks of valid pixels is specified it is used during the calculation. At the end the transformations are applied to each of the specified features. Any additional registration parameters can be passed on to registration method class. Parameters: :param registration_feature: A feature which will be used for co-registration, e.g. feature=(FeatureType.DATA, 'bands'). By default this feature is of type FeatureType.DATA therefore also only feature name can be given e.g. feature='bands' :type registration_feature: (FeatureType, str) or str :param channel: Index of `feature`'s channel to be used in co-registration :type channel: int :param valid_mask_feature: Feature containing a mask of valid pixels for `registration_feature`. By default no mask is set. It can be set to e.g. valid_mask_feature=(FeatureType.MASK, 'IS_DATA') or valid_mask_feature='IS_DATA' if the feature is of type FeatureType.MASK :type valid_mask_feature: str or (FeatureType, str) or None :param apply_to_features: A collection of features to which co-registration will be applied to. By default this is only `registration_feature` and `valid_mask_feature` if given. Note that each feature must have same temporal dimension as `registration_feature`. :type apply_to_features: dict(FeatureType: set(str) or dict(str: str)) :param interpolation_type: Type of interpolation used. Default is `InterpolationType.CUBIC` :type interpolation_type: InterpolationType :param params: Any other registration setting which will be passed to registration method :type params: object """ def __init__(self, registration_feature, channel=0, valid_mask_feature=None, apply_to_features=..., interpolation_type=InterpolationType.CUBIC, **params): self.registration_feature = self._parse_features(registration_feature, default_feature_type=FeatureType.DATA) self.channel = channel self.valid_mask_feature = None if valid_mask_feature is None else \ self._parse_features(valid_mask_feature, default_feature_type=FeatureType.MASK) if apply_to_features is ...: apply_to_features = [next(self.registration_feature())] if valid_mask_feature: apply_to_features.append(next(self.valid_mask_feature())) self.apply_to_features = self._parse_features(apply_to_features) self.interpolation_type = interpolation_type self.params = params @abstractmethod def register(self, src, trg, trg_mask=None, src_mask=None): """ Method for registration :param src: src :param trg: trg :param trg_mask: trg_mask :param src_mask: src_mask """ raise NotImplementedError @abstractmethod def check_params(self): """ Method to validate registration parameters """ raise NotImplementedError @abstractmethod def get_params(self): """ Method to print out registration parameters used """ raise NotImplementedError @staticmethod def _get_interpolation_flag(interpolation_type): try: return { InterpolationType.CUBIC: cv2.INTER_CUBIC, InterpolationType.NEAREST: cv2.INTER_NEAREST, InterpolationType.LINEAR: cv2.INTER_LINEAR }[interpolation_type] except KeyError: raise ValueError("Unsupported interpolation method specified") def execute(self, eopatch): """ Method that estimates registrations and warps EOPatch objects """ self.check_params() self.get_params() new_eopatch = copy.deepcopy(eopatch) f_type, f_name = next(self.registration_feature(eopatch)) sliced_data = copy.deepcopy(eopatch[f_type][f_name][..., self.channel]) time_frames = sliced_data.shape[0] iflag = self._get_interpolation_flag(self.interpolation_type) for idx in range(time_frames - 1, 0, -1): # Pair-wise registration starting from the most recent frame src_mask, trg_mask = None, None if self.valid_mask_feature is not None: f_type, f_name = next(self.valid_mask_feature(eopatch)) src_mask = new_eopatch[f_type][f_name][idx - 1] trg_mask = new_eopatch[f_type][f_name][idx] # Estimate transformation warp_matrix = self.register(sliced_data[idx - 1], sliced_data[idx], src_mask=src_mask, trg_mask=trg_mask) # Check amount of deformation rflag = self.is_registration_suspicious(warp_matrix) # Flag suspicious registrations and set them to the identity if rflag: LOGGER.warning("{:s} warning in pair-wise reg {:d} to {:d}".format(self.__class__.__name__, idx - 1, idx)) warp_matrix = np.eye(2, 3) # Transform and update sliced_data sliced_data[idx - 1] = self.warp(warp_matrix, sliced_data[idx - 1], iflag) # Apply tranformation to every given feature for feature_type, feature_name in self.apply_to_features(eopatch): new_eopatch[feature_type][feature_name][idx - 1] = \ self.warp(warp_matrix, new_eopatch[feature_type][feature_name][idx - 1], iflag) return new_eopatch def warp(self, warp_matrix, img, iflag=cv2.INTER_NEAREST): """ Function to warp input image given an estimated 2D linear transformation :param warp_matrix: Linear 2x3 matrix to use to linearly warp the input images :type warp_matrix: ndarray :param img: Image to be warped with estimated transformation :type img: ndarray :param iflag: Interpolation flag, specified interpolation using during resampling of warped image :type iflag: cv2.INTER_* :return: Warped image using the linear matrix """ height, width = img.shape[:2] warped_img = np.zeros_like(img, dtype=img.dtype) # Check if image to warp is 2D or 3D. If 3D need to loop over channels if (self.interpolation_type == InterpolationType.LINEAR) or img.ndim == 2: warped_img = cv2.warpAffine(img.astype(np.float32), warp_matrix, (width, height), flags=iflag).astype(img.dtype) elif img.ndim == 3: for idx in range(img.shape[-1]): warped_img[..., idx] = cv2.warpAffine(img[..., idx].astype(np.float32), warp_matrix, (width, height), flags=iflag).astype(img.dtype) else: raise ValueError('Image has incorrect number of dimensions: {}'.format(img.ndim)) return warped_img @staticmethod def is_registration_suspicious(warp_matrix): """ Static method that check if estimated linear transformation could be unplausible This function checks whether the norm of the estimated translation or the rotation angle exceed predefined values. For the translation, a maximum translation radius of 20 pixels is flagged, while larger rotations than 20 degrees are flagged. :param warp_matrix: Input linear transformation matrix :type warp_matrix: ndarray :return: 0 if registration doesn't exceed threshold, 1 otherwise """ if warp_matrix is None: return 1 cos_theta = np.trace(warp_matrix[:2, :2]) / 2 rot_angle = np.arccos(cos_theta) transl_norm = np.linalg.norm(warp_matrix[:, 2]) return 1 if int((rot_angle > MAX_ROTATION) or (transl_norm > MAX_TRANSLATION)) else 0 class ThunderRegistration(RegistrationTask): """ Registration task implementing a translational registration using the thunder-registration package """ def register(self, src, trg, trg_mask=None, src_mask=None): """ Implementation of pair-wise registration using thunder-registration For more information on the model estimation, refer to https://github.com/thunder-project/thunder-registration This function takes two 2D single channel images and estimates a 2D translation that best aligns the pair. The estimation is done by maximising the correlation of the Fourier transforms of the images. Once, the translation is estimated, it is applied to the (multi-channel) image to warp and, possibly, ot hte ground-truth. Different interpolations schemes could be more suitable for images and ground-truth values (or masks). :param src: 2D single channel source moving image :param trg: 2D single channel target reference image :param src_mask: Mask of source image. Not used in this method. :param trg_mask: Mask of target image. Not used in this method. :return: Estimated 2D transformation matrix of shape 2x3 """ # Initialise instance of CrossCorr object ccreg = registration.CrossCorr() # padding_value = 0 # Compute translation between pair of images model = ccreg.fit(src, reference=trg) # Get translation as an array translation = [-x for x in model.toarray().tolist()[0]] # Fill in transformation matrix warp_matrix = np.eye(2, 3) warp_matrix[0, 2] = translation[1] warp_matrix[1, 2] = translation[0] # Return transformation matrix return warp_matrix def get_params(self): LOGGER.info("{:s}:This registration does not require parameters".format(self.__class__.__name__)) def check_params(self): pass class ECCRegistration(RegistrationTask): """ Registration task implementing an intensity-based method from OpenCV """ def get_params(self): LOGGER.info("{:s}:Params for this registration are:".format(self.__class__.__name__)) LOGGER.info("\t\t\t\tMaxIters: {:d}".format(self.params['MaxIters'])) LOGGER.info("\t\t\t\tgaussFiltSize: {:d}".format(self.params['gaussFiltSize'])) def check_params(self): if not isinstance(self.params.get('MaxIters'), int): LOGGER.info("{:s}:MaxIters set to 200".format(self.__class__.__name__)) self.params['MaxIters'] = 200 if not isinstance(self.params.get('gaussFilterSize'), int): LOGGER.info("{:s}:gaussFilterSize set to 1".format(self.__class__.__name__)) self.params['gaussFiltSize'] = 1 def register(self, src, trg, trg_mask=None, src_mask=None): """ Implementation of pair-wise registration and warping using Enhanced Correlation Coefficient This function estimates an Euclidean transformation (x,y translation + rotation) using the intensities of the pair of images to be registered. The similarity metric is a modification of the cross-correlation metric, which is invariant to distortions in contrast and brightness. :param src: 2D single channel source moving image :param trg: 2D single channel target reference image :param trg_mask: Mask of target image. Not used in this method. :param src_mask: Mask of source image. Not used in this method. :return: Estimated 2D transformation matrix of shape 2x3 """ # Parameters of registration warp_mode = cv2.MOTION_EUCLIDEAN # Specify the threshold of the increment # in the correlation coefficient between two iterations termination_eps = 1e-10 # Define termination criteria criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, self.params['MaxIters'], termination_eps) # Initialise warp matrix warp_matrix = np.eye(2, 3, dtype=np.float32) # Run the ECC algorithm. The results are stored in warp_matrix. _, warp_matrix = cv2.findTransformECC(src.astype(np.float32), trg.astype(np.float32), warp_matrix, warp_mode, criteria, None, self.params['gaussFiltSize']) return warp_matrix class PointBasedRegistration(RegistrationTask): """ Registration class implementing a point-based registration from OpenCV contrib package """ def get_params(self): LOGGER.info("{:s}:Params for this registration are:".format(self.__class__.__name__)) LOGGER.info("\t\t\t\tModel: {:s}".format(self.params['Model'])) LOGGER.info("\t\t\t\tDescriptor: {:s}".format(self.params['Descriptor'])) LOGGER.info("\t\t\t\tMaxIters: {:d}".format(self.params['MaxIters'])) LOGGER.info("\t\t\t\tRANSACThreshold: {:.2f}".format(self.params['RANSACThreshold'])) def check_params(self): if not (self.params.get('Model') in ['Euler', 'PartialAffine', 'Homography']): LOGGER.info("{:s}:Model set to Euler".format(self.__class__.__name__)) self.params['Model'] = 'Euler' if not (self.params.get('Descriptor') in ['SIFT', 'SURF']): LOGGER.info("{:s}:Descriptor set to SIFT".format(self.__class__.__name__)) self.params['Descriptor'] = 'SIFT' if not isinstance(self.params.get('MaxIters'), int): LOGGER.info("{:s}:RANSAC MaxIters set to 1000".format(self.__class__.__name__)) self.params['MaxIters'] = 1000 if not isinstance(self.params.get('RANSACThreshold'), float): LOGGER.info("{:s}:RANSAC threshold set to 7.0".format(self.__class__.__name__)) self.params['RANSACThreshold'] = 7.0 def register(self, src, trg, trg_mask=None, src_mask=None): """ Implementation of pair-wise registration and warping using point-based matching This function estimates a number of transforms (Euler, PartialAffine and Homography) using point-based matching. Features descriptor are first extracted from the pair of images using either SIFT or SURF descriptors. A brute-force point-matching algorithm estimates matching points and a transformation is computed. All transformations use RANSAC to robustly fit a tranform to the matching points. However, the feature extraction and point matching estimation can be very poor and unstable. In those cases, an identity transform is used to warp the images instead. :param src: 2D single channel source moving image :param trg: 2D single channel target reference image :param trg_mask: Mask of target image. Not used in this method. :param src_mask: Mask of source image. Not used in this method. :return: Estimated 2D transformation matrix of shape 2x3 """ # Initialise matrix and failed registrations flag warp_matrix = None # Initiate point detector ptdt = cv2.xfeatures2d.SIFT_create() if self.params['Descriptor'] == 'SIFT' else cv2.xfeatures2d.SURF_create() # create BFMatcher object bf_matcher = cv2.BFMatcher(cv2.NORM_L1, crossCheck=True) # find the keypoints and descriptors with SIFT kp1, des1 = ptdt.detectAndCompute(self.rescale_image(src), None) kp2, des2 = ptdt.detectAndCompute(self.rescale_image(trg), None) # Match descriptors if any are found if des1 is not None and des2 is not None: matches = bf_matcher.match(des1, des2) # Sort them in the order of their distance. matches = sorted(matches, key=lambda x: x.distance) src_pts = np.asarray([kp1[m.queryIdx].pt for m in matches], dtype=np.float32).reshape(-1, 2) trg_pts = np.asarray([kp2[m.trainIdx].pt for m in matches], dtype=np.float32).reshape(-1, 2) # Parse model and estimate matrix if self.params['Model'] == 'PartialAffine': warp_matrix = cv2.estimateRigidTransform(src_pts, trg_pts, fullAffine=False) elif self.params['Model'] == 'Euler': model = EstimateEulerTransformModel(src_pts, trg_pts) warp_matrix = ransac(src_pts.shape[0], model, 3, self.params['MaxIters'], 1, 5) elif self.params['Model'] == 'Homography': warp_matrix, _ = cv2.findHomography(src_pts, trg_pts, cv2.RANSAC, ransacReprojThreshold=self.params['RANSACThreshold'], maxIters=self.params['MaxIters']) if warp_matrix is not None: warp_matrix = warp_matrix[:2, :] return warp_matrix @staticmethod def rescale_image(image): """ Normalise and scale image in 0-255 range """ s2_min_value, s2_max_value = 0, 1 out_min_value, out_max_value = 0, 255 # Clamp values in 0-1 range image[image > s2_max_value] = s2_max_value image[image < s2_min_value] = s2_min_value # Rescale to uint8 range out_image = out_max_value + (image-s2_min_value)*(out_max_value-out_min_value)/(s2_max_value-s2_min_value) return out_image.astype(np.uint8)