import random

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as tt
from PIL import Image

from .registry import TRANSFORMS

# Interpolation names accepted by the cv2-based transforms, mapped to cv2 flags.
CV2_MODE = {
    'bilinear': cv2.INTER_LINEAR,
    'nearest': cv2.INTER_NEAREST,
    'cubic': cv2.INTER_CUBIC,
    'area': cv2.INTER_AREA,
}

# Border handling names accepted by the cv2-based transforms.
CV2_BORDER_MODE = {
    'constant': cv2.BORDER_CONSTANT,
    'reflect': cv2.BORDER_REFLECT,
    'reflect101': cv2.BORDER_REFLECT101,
    'replicate': cv2.BORDER_REPLICATE,
}

# ``F.interpolate`` only accepts ``align_corners`` for these modes; passing it
# with e.g. 'nearest' or 'area' raises a ValueError.
_ALIGN_CORNERS_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')


def _pad_to_min_size(image, mask, height, width, image_value, mask_value,
                     channel):
    """Pad ``image``/``mask`` on the bottom/right to at least (height, width).

    The original content is kept in the top-left corner; new pixels are filled
    with ``image_value`` (one entry per channel) and ``mask_value``. Inputs that
    are already large enough are copied unchanged in size.

    Returns:
        Tuple ``(padded_image, padded_mask)`` of freshly allocated arrays.
    """
    h, w, _ = image.shape
    target_height = max(h, height)
    target_width = max(w, width)

    image_pad_value = np.reshape(
        np.array(image_value, dtype=image.dtype), [1, 1, channel])
    mask_pad_value = np.reshape(
        np.array(mask_value, dtype=mask.dtype), [1, 1])

    new_image = np.tile(image_pad_value, (target_height, target_width, 1))
    new_mask = np.tile(mask_pad_value, (target_height, target_width))
    new_image[:h, :w, :] = image
    new_mask[:h, :w] = mask

    # Sanity check: padding must not add or remove any labelled pixel.
    assert (np.count_nonzero(mask != mask_value)
            == np.count_nonzero(new_mask != mask_value))
    return new_image, new_mask


class Compose:
    """Apply a sequence of ``(image, mask) -> (image, mask)`` transforms."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, mask):
        for t in self.transforms:
            image, mask = t(image, mask)
        return image, mask


@TRANSFORMS.register_module
class FactorScale:
    """Rescale an image/mask pair by a fixed factor using ``F.interpolate``.

    Args:
        scale_factor: Multiplier applied to height and width. ``1.0`` is a
            no-op.
        mode: Interpolation mode passed to ``F.interpolate`` for the image.
            The mask is always resized with ``'nearest'``.
    """

    def __init__(self, scale_factor=1.0, mode='bilinear'):
        self.mode = mode
        self.scale_factor = scale_factor

    def rescale(self, image, mask):
        """Resize ``image`` (H, W, C) and ``mask`` (H, W) by ``scale_factor``.

        Returns:
            Tuple ``(new_image, new_mask)`` as numpy arrays; the inputs are
            returned untouched when ``scale_factor == 1.0``.
        """
        h, w, _ = image.shape
        if self.scale_factor == 1.0:
            return image, mask

        new_h = int(h * self.scale_factor)
        new_w = int(w * self.scale_factor)

        # F.interpolate expects (N, C, H, W).
        torch_image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
        torch_mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0)

        # align_corners is only legal for the linear-family modes.
        align_corners = True if self.mode in _ALIGN_CORNERS_MODES else None
        torch_image = F.interpolate(torch_image, size=(new_h, new_w),
                                    mode=self.mode,
                                    align_corners=align_corners)
        torch_mask = F.interpolate(torch_mask, size=(new_h, new_w),
                                   mode='nearest')

        # Drop only the axes added above: a bare ``squeeze()`` would also
        # remove a size-1 channel/height/width axis and break the permute.
        new_image = torch_image.squeeze(0).permute(1, 2, 0).numpy()
        new_mask = torch_mask.squeeze(0).squeeze(0).numpy()
        return new_image, new_mask

    def __call__(self, image, mask):
        return self.rescale(image, mask)


@TRANSFORMS.register_module
class SizeScale(FactorScale):
    """Rescale so that the longer image edge equals ``target_size``.

    Args:
        target_size: Desired length of the longer edge after scaling.
        mode: Interpolation mode for the image (see ``FactorScale``).
    """

    def __init__(self, target_size, mode='bilinear'):
        self.target_size = target_size
        super().__init__(mode=mode)

    def __call__(self, image, mask):
        h, w, _ = image.shape
        long_edge = max(h, w)
        # The factor is recomputed per sample and stored on the instance
        # because ``rescale`` reads it from ``self``.
        self.scale_factor = self.target_size / long_edge
        return self.rescale(image, mask)


@TRANSFORMS.register_module
class RandomScale(FactorScale):
    """Rescale by a factor drawn at random from ``[min_scale, max_scale]``.

    Args:
        min_scale: Smallest scale factor.
        max_scale: Largest scale factor.
        scale_step: If ``0``, sample uniformly from the continuous range;
            otherwise sample from evenly spaced values between the bounds.
        mode: Interpolation mode for the image (see ``FactorScale``).
    """

    def __init__(self, min_scale, max_scale, scale_step=0.0, mode='bilinear'):
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.scale_step = scale_step
        super().__init__(mode=mode)

    @staticmethod
    def get_scale_factor(min_scale, max_scale, scale_step):
        """Draw one scale factor.

        Note that the continuous case uses the ``random`` module while the
        stepped case uses ``np.random``, so both generators must be seeded
        for reproducible runs.
        """
        if min_scale == max_scale:
            return min_scale

        if scale_step == 0:
            return random.uniform(min_scale, max_scale)

        # The small tolerance guards against float error such as
        # (1.2 - 0.9) / 0.1 == 2.9999999999999996, which would otherwise
        # truncate and silently drop one step.
        num_steps = int((max_scale - min_scale) / scale_step + 1 + 1e-6)
        scale_factors = np.linspace(min_scale, max_scale, num_steps)
        return np.random.choice(scale_factors).item()

    def __call__(self, image, mask):
        self.scale_factor = self.get_scale_factor(
            self.min_scale, self.max_scale, self.scale_step)
        return self.rescale(image, mask)


@TRANSFORMS.register_module
class RandomCrop:
    """Crop a random ``height`` x ``width`` window, padding first if needed.

    Args:
        height: Crop height.
        width: Crop width.
        image_value: Per-channel fill value used when the image is smaller
            than the crop; its length defines the expected channel count.
        mask_value: Fill value used for padded mask pixels.
    """

    def __init__(self, height, width, image_value, mask_value):
        self.height = height
        self.width = width
        self.image_value = image_value
        self.mask_value = mask_value
        self.channel = len(image_value)

    def __call__(self, image, mask):
        new_image, new_mask = _pad_to_min_size(
            image, mask, self.height, self.width,
            self.image_value, self.mask_value, self.channel)
        target_height, target_width = new_mask.shape

        # Top-left corner of the window; y is drawn before x.
        y1 = int(random.uniform(0, target_height - self.height + 1))
        y2 = y1 + self.height
        x1 = int(random.uniform(0, target_width - self.width + 1))
        x2 = x1 + self.width

        new_image = new_image[y1:y2, x1:x2, :]
        new_mask = new_mask[y1:y2, x1:x2]
        return new_image, new_mask


@TRANSFORMS.register_module
class PadIfNeeded:
    """Pad an image/mask pair on the bottom/right up to ``height`` x ``width``.

    Args:
        height: Target height; inputs must not be taller.
        width: Target width; inputs must not be wider.
        image_value: Per-channel fill value for padded image pixels; its
            length defines the expected channel count.
        mask_value: Fill value for padded mask pixels.

    Raises:
        AssertionError: If the input is larger than the target size.
    """

    def __init__(self, height, width, image_value, mask_value):
        self.height = height
        self.width = width
        self.image_value = image_value
        self.mask_value = mask_value
        self.channel = len(image_value)

    def __call__(self, image, mask):
        h, w, _ = image.shape
        assert h <= self.height and w <= self.width
        return _pad_to_min_size(
            image, mask, self.height, self.width,
            self.image_value, self.mask_value, self.channel)


@TRANSFORMS.register_module
class HorizontalFlip:
    """Mirror image and mask left-right with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, mask):
        # ``< p`` so that ``p`` is the probability of flipping, consistent
        # with RandomRotate and GaussianBlur (the previous ``> p`` inverted it).
        if random.random() < self.p:
            image = cv2.flip(image, 1)
            mask = cv2.flip(mask, 1)
        return image, mask


@TRANSFORMS.register_module
class RandomRotate:
    """Rotate image and mask about the image centre with probability ``p``.

    Args:
        p: Probability of applying the rotation.
        degrees: Either a number ``d`` (angle drawn from ``[-d, d]``) or an
            explicit ``(low, high)`` pair.
        mode: Key of ``CV2_MODE`` used to interpolate the image; the mask
            always uses nearest-neighbour.
        border_mode: Key of ``CV2_BORDER_MODE``.
        image_value: Border value handed to ``cv2.warpAffine`` for the image.
        mask_value: Border value handed to ``cv2.warpAffine`` for the mask.
    """

    def __init__(self, p=0.5, degrees=30, mode='bilinear',
                 border_mode='reflect101', image_value=None, mask_value=None):
        self.p = p
        self.degrees = ((-degrees, degrees)
                        if isinstance(degrees, (int, float)) else degrees)
        self.mode = CV2_MODE[mode]
        self.border_mode = CV2_BORDER_MODE[border_mode]
        self.image_value = image_value
        self.mask_value = mask_value

    def __call__(self, image, mask):
        if random.random() < self.p:
            h, w, _ = image.shape
            angle = random.uniform(*self.degrees)
            matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0)
            image = cv2.warpAffine(image, M=matrix, dsize=(w, h),
                                   flags=self.mode,
                                   borderMode=self.border_mode,
                                   borderValue=self.image_value)
            mask = cv2.warpAffine(mask, M=matrix, dsize=(w, h),
                                  flags=cv2.INTER_NEAREST,
                                  borderMode=self.border_mode,
                                  borderValue=self.mask_value)
        return image, mask


@TRANSFORMS.register_module
class GaussianBlur:
    """Blur the image (mask untouched) with probability ``p``.

    Args:
        p: Probability of applying the blur.
        ksize: Kernel size, either an int (square kernel) or a pair.
    """

    def __init__(self, p=0.5, ksize=7):
        self.p = p
        self.ksize = (ksize, ksize) if isinstance(ksize, int) else ksize

    def __call__(self, image, mask):
        if random.random() < self.p:
            # sigmaX=0 lets OpenCV derive sigma from the kernel size.
            image = cv2.GaussianBlur(image, ksize=self.ksize, sigmaX=0)
        return image, mask
@TRANSFORMS.register_module
class Normalize:
    """Standardise an image channel-wise: ``(image - mean) / std``.

    Args:
        mean: Per-channel mean; its length defines the expected channel count.
        std: Per-channel standard deviation.

    The defaults are on a 0-255 intensity scale (presumably ImageNet
    statistics -- confirm against the training configuration).
    """

    def __init__(self, mean=(123.675, 116.280, 103.530),
                 std=(58.395, 57.120, 57.375)):
        self.mean = mean
        self.std = std
        self.channel = len(mean)

    def __call__(self, image, mask):
        # Floating-point images keep their dtype. Integer images are promoted
        # to float32: casting mean/std to an integer dtype would truncate them
        # and the integer reciprocal of std would be 0, zeroing the output.
        if np.issubdtype(image.dtype, np.floating):
            dtype = image.dtype
        else:
            dtype = np.float32
        image = image.astype(dtype, copy=False)

        mean = np.reshape(np.array(self.mean, dtype=dtype),
                          [1, 1, self.channel])
        std = np.reshape(np.array(self.std, dtype=dtype),
                         [1, 1, self.channel])
        # Multiply by the reciprocal instead of dividing per pixel.
        denominator = np.reciprocal(std, dtype=dtype)

        new_image = (image - mean) * denominator
        new_mask = mask
        return new_image, new_mask


@TRANSFORMS.register_module
class ColorJitter(tt.ColorJitter):
    """torchvision ``ColorJitter`` adapted to numpy ``(image, mask)`` pairs.

    The image is converted to a PIL image, jittered by the parent class, and
    returned as a float32 numpy array. The mask passes through unchanged.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__(brightness=brightness, contrast=contrast,
                         saturation=saturation, hue=hue)

    def __call__(self, image=None, mask=None):
        # Clip before the uint8 cast so out-of-range values (e.g. overshoot
        # from an earlier interpolation) saturate instead of wrapping around.
        new_image = Image.fromarray(np.clip(image, 0, 255).astype(np.uint8))
        new_image = super().__call__(new_image)
        new_image = np.array(new_image).astype(np.float32)
        return new_image, mask


@TRANSFORMS.register_module
class ToTensor:
    """Convert numpy image (H, W, C) and mask to torch tensors.

    The image is permuted to channel-first (C, H, W); the mask keeps its
    layout. Both tensors are created with ``torch.from_numpy``.
    """

    def __call__(self, image, mask):
        image = torch.from_numpy(image).permute(2, 0, 1)
        mask = torch.from_numpy(mask)
        return image, mask