import random import keras.backend as K import numpy as np from keras.preprocessing.image import transform_matrix_offset_center, apply_transform, random_channel_shift, flip_axis, \ load_img, img_to_array class ImageWithMaskFunction: def __init__(self, out_size, mask_dir, mask_suffix="_mask.gif", crop_size=None): super().__init__() self.out_size = out_size self.mask_dir = mask_dir self.mask_suffix = mask_suffix self.crop_size = crop_size def random_transform(self, x, mask, rotation_range=None, height_shift_range=None, width_shift_range=None, shear_range=None, zoom_range=None, channel_shift_range=None, horizontal_flip=None, vertical_flip=None, fill_mode='constant', cval=0): """Randomly augment a image tensor and mask. # Arguments x: 3D tensor, single image. # Returns A randomly transformed version of the input (same shape). """ # x is a single image, so it doesn't have image number at index 0 img_row_axis = 0 img_col_axis = 1 img_channel_axis = 2 # use composition of homographies # to generate final transform that needs to be applied if rotation_range: theta = np.pi / 180 * np.random.uniform(-rotation_range, rotation_range) else: theta = 0 if height_shift_range: uniform = np.random.uniform(-height_shift_range, height_shift_range) tx = uniform * x.shape[img_row_axis] tmx = uniform * mask.shape[img_row_axis] else: tx = 0 tmx = 0 if width_shift_range: random_uniform = np.random.uniform(-width_shift_range, width_shift_range) ty = random_uniform * x.shape[img_col_axis] tmy = random_uniform * mask.shape[img_col_axis] else: ty = 0 tmy = 0 if shear_range: shear = np.random.uniform(-shear_range, shear_range) else: shear = 0 if zoom_range[0] == 1 and zoom_range[1] == 1: zx, zy = 1, 1 else: zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2) transform_matrix = None transform_matrix_mask = None if theta != 0: rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) transform_matrix = rotation_matrix transform_matrix_mask = rotation_matrix if tx != 0 or ty != 0: shift_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) shift_matrix_mask = np.array([[1, 0, tmx], [0, 1, tmy], [0, 0, 1]]) transform_matrix = shift_matrix if transform_matrix is None else np.dot(transform_matrix, shift_matrix) transform_matrix_mask = shift_matrix_mask if transform_matrix_mask is None else np.dot( transform_matrix_mask, shift_matrix_mask) if shear != 0: shear_matrix = np.array([[1, -np.sin(shear), 0], [0, np.cos(shear), 0], [0, 0, 1]]) transform_matrix = shear_matrix if transform_matrix is None else np.dot(transform_matrix, shear_matrix) transform_matrix_mask = shear_matrix if transform_matrix_mask is None else np.dot(transform_matrix_mask, shear_matrix) if zx != 1 or zy != 1: zoom_matrix = np.array([[zx, 0, 0], [0, zy, 0], [0, 0, 1]]) transform_matrix = zoom_matrix if transform_matrix is None else np.dot(transform_matrix, zoom_matrix) transform_matrix_mask = zoom_matrix if transform_matrix_mask is None else np.dot(transform_matrix_mask, zoom_matrix) if transform_matrix is not None: h, w = x.shape[img_row_axis], x.shape[img_col_axis] transform_matrix = transform_matrix_offset_center(transform_matrix, h, w) x = apply_transform(x, transform_matrix, img_channel_axis, fill_mode=fill_mode, cval=cval) if transform_matrix_mask is not None: h, w = mask.shape[img_row_axis], mask.shape[img_col_axis] transform_matrix_mask = transform_matrix_offset_center(transform_matrix_mask, h, w) mask[:, :, 0:1] = apply_transform(mask[:, :, 0:1], transform_matrix_mask, img_channel_axis, fill_mode='constant', cval=0.) if channel_shift_range != 0: x = random_channel_shift(x, channel_shift_range, img_channel_axis) if horizontal_flip: if np.random.random() < 0.5: x = flip_axis(x, img_col_axis) mask = flip_axis(mask, img_col_axis) if vertical_flip: if np.random.random() < 0.5: x = flip_axis(x, img_row_axis) mask = flip_axis(mask, img_row_axis) return x, mask def mask_pred(self, batch_x, filenames, index_array, aug=False): mask_pred = np.zeros((len(batch_x), self.out_size[0], self.out_size[1], 1), dtype=K.floatx()) mask_pred[:, :, :, :] = 0. for i, j in enumerate(index_array): fname = filenames[j] mask = self.mask_dir + "/" + fname.split('/')[-1].replace(".jpg", self.mask_suffix) mask_pred[i, :, :, :] = img_to_array( load_img(mask, grayscale=True, target_size=(self.out_size[0], self.out_size[1]))) / 255. if aug: batch_x[i, :, :, :], mask_pred[i, :, :, :] = self.random_transform(x=batch_x[i, :, :, :], mask=mask_pred[i, :, :, :], height_shift_range=0.0, width_shift_range=0.0, shear_range=0.0, rotation_range=0, zoom_range=[0.95, 1.05], channel_shift_range=0.1, horizontal_flip=True) if self.crop_size: height = self.crop_size[0] width = self.crop_size[1] ori_height = self.out_size[0] ori_width = self.out_size[1] if aug: h_start = random.randint(0, ori_height - height - 1) w_start = random.randint(0, ori_width - width - 1) else: # validate on center crops h_start = (ori_height - height) // 2 w_start = (ori_width - width) // 2 MASK_CROP = mask_pred[:, h_start:h_start + height, w_start:w_start + width, :] return batch_x[:, h_start:h_start + height, w_start:w_start + width, :], MASK_CROP else: return batch_x, mask_pred def mask_pred_train(self, batch_x, filenames, index_array, l): return self.mask_pred(batch_x, filenames, index_array, True) def mask_pred_val(self, batch_x, filenames, index_array, l): return self.mask_pred(batch_x, filenames, index_array, False) def random_transform_two_masks(x, mask1, mask2, rotation_range=None, height_shift_range=None, width_shift_range=None, shear_range=None, zoom_range=None, channel_shift_range=None, horizontal_flip=None, vertical_flip=None, fill_mode='constant', cval=0): """Randomly augment a image tensor and masks. # Arguments x: 3D tensor, single image. # Returns A randomly transformed version of the input (same shape). """ # x is a single image, so it doesn't have image number at index 0 img_row_axis = 0 img_col_axis = 1 img_channel_axis = 2 # use composition of homographies # to generate final transform that needs to be applied if rotation_range: theta = np.pi / 180 * np.random.uniform(-rotation_range, rotation_range) else: theta = 0 if height_shift_range: uniform = np.random.uniform(-height_shift_range, height_shift_range) tx = uniform * x.shape[img_row_axis] tmx1 = uniform * mask1.shape[img_row_axis] tmx2 = uniform * mask2.shape[img_row_axis] else: tx = 0 tmx1 = 0 tmx2 = 0 if width_shift_range: random_uniform = np.random.uniform(-width_shift_range, width_shift_range) ty = random_uniform * x.shape[img_col_axis] tmy1 = random_uniform * mask1.shape[img_col_axis] tmy2 = random_uniform * mask2.shape[img_col_axis] else: ty = 0 tmy1 = 0 tmy2 = 0 if shear_range: shear = np.random.uniform(-shear_range, shear_range) else: shear = 0 if zoom_range[0] == 1 and zoom_range[1] == 1: zx, zy = 1, 1 else: zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2) transform_matrix = None transform_matrix_mask1 = None transform_matrix_mask2 = None if theta != 0: rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) transform_matrix = rotation_matrix transform_matrix_mask1 = rotation_matrix transform_matrix_mask2 = rotation_matrix if tx != 0 or ty != 0: shift_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) shift_matrix_mask1 = np.array([[1, 0, tmx1], [0, 1, tmy1], [0, 0, 1]]) shift_matrix_mask2 = np.array([[1, 0, tmx2], [0, 1, tmy2], [0, 0, 1]]) transform_matrix = shift_matrix if transform_matrix is None else np.dot(transform_matrix, shift_matrix) transform_matrix_mask1 = shift_matrix_mask1 if transform_matrix_mask1 is None else np.dot( transform_matrix_mask1, shift_matrix_mask1) transform_matrix_mask2 = shift_matrix_mask1 if transform_matrix_mask2 is None else np.dot( transform_matrix_mask2, shift_matrix_mask2) if shear != 0: shear_matrix = np.array([[1, -np.sin(shear), 0], [0, np.cos(shear), 0], [0, 0, 1]]) transform_matrix = shear_matrix if transform_matrix is None else np.dot(transform_matrix, shear_matrix) transform_matrix_mask1 = shear_matrix if transform_matrix_mask1 is None else np.dot(transform_matrix_mask1, shear_matrix) transform_matrix_mask2 = shear_matrix if transform_matrix_mask2 is None else np.dot(transform_matrix_mask2, shear_matrix) if zx != 1 or zy != 1: zoom_matrix = np.array([[zx, 0, 0], [0, zy, 0], [0, 0, 1]]) transform_matrix = zoom_matrix if transform_matrix is None else np.dot(transform_matrix, zoom_matrix) transform_matrix_mask1 = zoom_matrix if transform_matrix_mask1 is None else np.dot(transform_matrix_mask1, zoom_matrix) transform_matrix_mask2 = zoom_matrix if transform_matrix_mask2 is None else np.dot(transform_matrix_mask2, zoom_matrix) if transform_matrix is not None: h, w = x.shape[img_row_axis], x.shape[img_col_axis] transform_matrix = transform_matrix_offset_center(transform_matrix, h, w) x = apply_transform(x, transform_matrix, img_channel_axis, fill_mode=fill_mode, cval=cval) if transform_matrix_mask1 is not None: h, w = mask1.shape[img_row_axis], mask1.shape[img_col_axis] transform_matrix_mask1 = transform_matrix_offset_center(transform_matrix_mask1, h, w) mask1[:, :, 0:1] = apply_transform(mask1[:, :, 0:1], transform_matrix_mask1, img_channel_axis, fill_mode='constant', cval=0.) if transform_matrix_mask2 is not None: h, w = mask2.shape[img_row_axis], mask2.shape[img_col_axis] transform_matrix_mask2 = transform_matrix_offset_center(transform_matrix_mask2, h, w) mask2[:, :, 0:1] = apply_transform(mask2[:, :, 0:1], transform_matrix_mask2, img_channel_axis, fill_mode='constant', cval=0.) if channel_shift_range != 0: x = random_channel_shift(x, channel_shift_range, img_channel_axis) if horizontal_flip: if np.random.random() < 0.5: x = flip_axis(x, img_col_axis) mask1 = flip_axis(mask1, img_col_axis) mask2 = flip_axis(mask2, img_col_axis) if vertical_flip: if np.random.random() < 0.5: x = flip_axis(x, img_row_axis) mask1 = flip_axis(mask1, img_row_axis) mask2 = flip_axis(mask2, img_row_axis) return x, mask1, mask2