# -*- coding: utf-8 -*-
"""
This module provides many types of image augmentation.

One can choose appropriate augmentation for detection, segmentation and
classification.
"""
import random

import cv2
import numpy


class Augmentor(object):
    """
    All augmentation operations are static methods of this class.

    Images are numpy arrays; colour conversions use the OpenCV BGR constants.
    Invalid input is reported by printing a message and returning ``None``
    instead of raising. Unsupported option values (blur mode, flip
    orientation, missing resize size) print a message and return the input
    image unchanged.
    """

    def __init__(self):
        pass

    @staticmethod
    def _is_uint8(image):
        """
        check that the image holds 8-bit unsigned data
        :param image: input image
        :return: True if the dtype is uint8, otherwise print a message and
                 return False
        """
        if image.dtype != numpy.uint8:
            print('Input image is not uint8!')
            return False
        return True

    @staticmethod
    def _blend_with_reference(image, reference, factor):
        """
        blend the first three channels of the image with a reference plane:
        channel * factor + reference * (1 - factor)
        :param image: input image with at least three channels
        :param reference: 2-D array (height x width) blended into each channel
        :param factor: weight of the original channel
        :return: uint8 image of the same shape as the input, values clamped
                 to [0, 255]
        """
        blended = numpy.zeros(image.shape, dtype=numpy.float32)
        for channel in range(3):
            blended[:, :, channel] = (image[:, :, channel] * factor
                                      + reference * (1 - factor))
        # Clamp before the cast so out-of-range values do not wrap around.
        numpy.clip(blended, 0, 255, out=blended)
        return blended.astype(numpy.uint8)

    @staticmethod
    def _crop_or_pad_span(image_len, target_len):
        """
        decide, for one axis, which part of the image is copied and where
        :param image_len: image size along the axis
        :param target_len: requested output size along the axis
        :return: tuple (source start, destination start, number of pixels)
        """
        if image_len >= target_len:
            # Image is large enough: take a random window of the target size.
            return random.randint(0, image_len - target_len), 0, target_len
        # Image is too small: copy all of it, centred in the output.
        # Floor division keeps the offset an int, as slicing requires.
        return 0, (target_len - image_len) // 2, image_len

    @staticmethod
    def histogram_equalisation(image):
        """
        do histogram equalisation for grayscale image
        :param image: input image with single channel 8bits
        :return: processed image, or None if the input is not a 2-D uint8
                 array
        """
        if image.ndim != 2:
            print('Input image is not grayscale!')
            return None
        if not Augmentor._is_uint8(image):
            return None
        return cv2.equalizeHist(image)

    @staticmethod
    def grayscale(image):
        """
        convert BGR image to grayscale image
        :param image: input image with BGR channels
        :return: grayscale image, or None if the input is not a 3-D uint8
                 array
        """
        if image.ndim != 3:
            # Report the reason like every other rejected input does.
            print('Input image is not BGR!')
            return None
        if not Augmentor._is_uint8(image):
            return None
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    @staticmethod
    def inversion(image):
        """
        invert the image (255-)
        :param image: input image with BGR or grayscale
        :return: inverted image, or None if the input is not uint8
        """
        if not Augmentor._is_uint8(image):
            return None
        return 255 - image

    @staticmethod
    def binarization(image, block_size=5, C=10):
        """
        convert input image to binary image
        cv2.adaptiveThreshold (mean method, binary threshold, max value 255)
        is used, for detailed information, refer to opencv docs
        :param image: input image; a 3-D input is first converted from BGR to
                      grayscale
        :param block_size: passed to cv2.adaptiveThreshold as the block size
        :param C: passed to cv2.adaptiveThreshold as the constant C
        :return: binary image
        """
        if image.ndim == 3:
            image_grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            image_grayscale = image
        return cv2.adaptiveThreshold(image_grayscale, 255,
                                     cv2.ADAPTIVE_THRESH_MEAN_C,
                                     cv2.THRESH_BINARY, block_size, C)

    @staticmethod
    def brightness(image, min_factor=0.5, max_factor=1.5):
        """
        adjust the image brightness by a random factor
        :param image: input uint8 image
        :param min_factor: lower bound of the random scaling factor
        :param max_factor: upper bound of the random scaling factor
        :return: scaled image, or None if the input is not uint8
        """
        if not Augmentor._is_uint8(image):
            return None
        factor = numpy.random.uniform(min_factor, max_factor)
        # Clamp on both sides so neither overflow nor a negative factor wraps
        # around when casting back to uint8.
        return numpy.clip(image * factor, 0, 255).astype(numpy.uint8)

    @staticmethod
    def saturation(image, min_factor=0.5, max_factor=1.5):
        """
        adjust the image saturation by blending each channel with the
        grayscale version of the image
        :param image: input uint8 image with BGR channels
        :param min_factor: lower bound of the random blending factor
        :param max_factor: upper bound of the random blending factor
        :return: adjusted image, or None if the input is not uint8
        """
        if not Augmentor._is_uint8(image):
            return None
        image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        factor = numpy.random.uniform(min_factor, max_factor)
        return Augmentor._blend_with_reference(image, image_gray, factor)

    @staticmethod
    def contrast(image, min_factor=0.5, max_factor=1.5):
        """
        adjust the image contrast by blending each channel with the mean
        gray level of the image
        :param image: input uint8 image with BGR channels
        :param min_factor: lower bound of the random blending factor
        :param max_factor: upper bound of the random blending factor
        :return: adjusted image, or None if the input is not uint8
        """
        if not Augmentor._is_uint8(image):
            return None
        image_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        gray_mean = numpy.mean(image_gray)
        mean_plane = numpy.ones((image.shape[0], image.shape[1]),
                                dtype=numpy.float32) * gray_mean
        factor = numpy.random.uniform(min_factor, max_factor)
        return Augmentor._blend_with_reference(image, mean_plane, factor)

    @staticmethod
    def blur(image, mode='random', kernel_size=3, sigma=1):
        """
        blur the image
        :param image: input uint8 image
        :param mode: options 'normalized' 'gaussian' 'median', or 'random' to
                     pick one of the three at random
        :param kernel_size: side length of the square kernel
        :param sigma: used for gaussian blur
        :return: blurred image, the unchanged input for an unsupported mode,
                 or None if the input is not uint8
        """
        if not Augmentor._is_uint8(image):
            return None
        if mode == 'random':
            mode = random.choice(['normalized', 'gaussian', 'median'])
        if mode == 'normalized':
            return cv2.blur(image, (kernel_size, kernel_size))
        if mode == 'gaussian':
            return cv2.GaussianBlur(image, (kernel_size, kernel_size),
                                    sigmaX=sigma, sigmaY=sigma)
        if mode == 'median':
            return cv2.medianBlur(image, kernel_size)
        print('Blur mode is not supported: %s.' % mode)
        return image

    @staticmethod
    def rotation(image, degree=10, mode='crop', scale=1):
        """
        rotate the image around its centre
        :param image: input uint8 image
        :param degree: rotation angle passed to cv2.getRotationMatrix2D
        :param mode: 'crop'-keep original size, 'fill'-keep full image
                     (any value other than 'crop' is treated as 'fill')
        :param scale: scale factor passed to cv2.getRotationMatrix2D
        :return: rotated image, or None if the input is not uint8
        """
        if not Augmentor._is_uint8(image):
            return None
        h, w = image.shape[:2]
        # Explicit float division: same centre on Python 2 and Python 3.
        center_x, center_y = w / 2.0, h / 2.0
        M = cv2.getRotationMatrix2D((center_x, center_y), degree, scale)
        if mode == 'crop':
            new_w, new_h = w, h
        else:
            # Bounding box of the rotated image, then shift the transform so
            # the rotated content is centred in the enlarged canvas.
            cos = numpy.abs(M[0, 0])
            sin = numpy.abs(M[0, 1])
            new_w = int(h * sin + w * cos)
            new_h = int(h * cos + w * sin)
            M[0, 2] += (new_w / 2.0) - center_x
            M[1, 2] += (new_h / 2.0) - center_y
        return cv2.warpAffine(image, M, (new_w, new_h))

    @staticmethod
    def flip(image, orientation='h'):
        """
        flip the image
        :param image: input uint8 image
        :param orientation: 'h' calls cv2.flip with flip code 1, 'v' with
                            flip code 0
        :return: flipped image, the unchanged input for an unsupported
                 orientation, or None if the input is not uint8
        """
        if not Augmentor._is_uint8(image):
            return None
        if orientation == 'h':
            return cv2.flip(image, 1)
        if orientation == 'v':
            return cv2.flip(image, 0)
        print('Unsupported orientation: %s.' % orientation)
        return image

    @staticmethod
    def resize(image, size_in_pixel=None, size_in_scale=None):
        """
        resize the image; size_in_pixel takes precedence when both are given
        :param image: input uint8 image
        :param size_in_pixel: tuple (width, height)
        :param size_in_scale: tuple (width_scale, height_scale)
        :return: resized image, the unchanged input when both sizes are None,
                 or None if the input is not uint8
        """
        if not Augmentor._is_uint8(image):
            return None
        if size_in_pixel is not None:
            return cv2.resize(image, size_in_pixel)
        if size_in_scale is not None:
            return cv2.resize(image, (0, 0),
                              fx=size_in_scale[0], fy=size_in_scale[1])
        print('size_in_pixel and size_in_scale are both None.')
        return image

    @staticmethod
    def crop(image, x, y, width, height):
        """
        cut a rectangular area out of the image
        :param image: input uint8 image
        :param x: crop area top-left x coordinate
        :param y: crop area top-left y coordinate
        :param width: crop area width
        :param height: crop area height
        :return: the cropped area (a slice of the input array), or None if
                 the input is not uint8
        """
        if not Augmentor._is_uint8(image):
            return None
        # Slicing the first two axes keeps any trailing channel axis intact.
        return image[y:y + height, x:x + width]

    @staticmethod
    def random_crop(image, width, height):
        """
        cut a randomly placed area of the given size out of the image
        Along an axis where the image is smaller than the requested size the
        whole image extent is kept and centred, and the rest is filled with
        zeros.
        :param image: input uint8 image
        :param width: crop area width
        :param height: crop area height
        :return: image of size height x width; a slice of the input when the
                 image is large enough on both axes, otherwise a new
                 zero-padded array. None if the input is not uint8
        """
        if not Augmentor._is_uint8(image):
            return None
        # x is resolved before y so the random draws happen in a fixed order.
        src_x, dst_x, span_w = Augmentor._crop_or_pad_span(image.shape[1],
                                                           width)
        src_y, dst_y, span_h = Augmentor._crop_or_pad_span(image.shape[0],
                                                           height)
        window = image[src_y:src_y + span_h, src_x:src_x + span_w]
        if span_w == width and span_h == height:
            # Image covers the requested size on both axes: no padding needed.
            return window
        # Follow the image's own channel layout instead of assuming 3 channels.
        result = numpy.zeros((height, width) + image.shape[2:],
                             dtype=numpy.uint8)
        result[dst_y:dst_y + span_h, dst_x:dst_x + span_w] = window
        return result