#coding=utf-8
"""Cut random fixed-size patches out of image/label pairs and save them.

For every file name in ``image_sets`` the script reads the image from
``<basePath>train\\`` and the label mask of the same name from
``<basePath>label\\``, crops random ``img_w`` x ``img_h`` windows at the same
position from both, optionally augments them, and writes the results under
``<basePath>src``.
"""
import cv2
import random
import os
import numpy as np
from tqdm import tqdm

# Size (in pixels) of every generated patch.
img_w = 64
img_h = 64

# Root folder of the project data. Backslashes are doubled so the literal
# contains no invalid escape sequences; the resulting string is unchanged.
basePath = "C:\\Users\\Administrator\\Desktop\\Project\\"

# File names that must exist in both the 'train' and the 'label' folder.
image_sets = ['1.png']


def gamma_transform(img, gamma):
    """Apply gamma correction to ``img`` through a 256-entry lookup table.

    Args:
        img: image passed to ``cv2.LUT`` (assumed 8-bit -- TODO confirm).
        gamma: exponent applied to the intensities normalised to [0, 1].

    Returns:
        The gamma-corrected image produced by ``cv2.LUT``.
    """
    gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
    gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
    return cv2.LUT(img, gamma_table)


def random_gamma_transform(img, gamma_vari):
    """Apply gamma correction with a randomly drawn gamma.

    The gamma is sampled log-uniformly from ``[1 / gamma_vari, gamma_vari]``.

    Args:
        img: image to transform.
        gamma_vari: upper bound of the gamma range; must be positive.

    Returns:
        The gamma-corrected image.
    """
    log_gamma_vari = np.log(gamma_vari)
    alpha = np.random.uniform(-log_gamma_vari, log_gamma_vari)
    gamma = np.exp(alpha)
    return gamma_transform(img, gamma)


def rotate(xb, yb, angle):
    """Rotate an image and its label by ``angle`` degrees.

    The rotation centre is ``(img_w / 2, img_h / 2)`` and the output size is
    ``(img_w, img_h)``, i.e. the module-level patch size.

    Returns:
        Tuple ``(xb, yb)`` of the rotated image and label.
    """
    M_rotate = cv2.getRotationMatrix2D((img_w / 2, img_h / 2), angle, 1)
    xb = cv2.warpAffine(xb, M_rotate, (img_w, img_h))
    yb = cv2.warpAffine(yb, M_rotate, (img_w, img_h))
    return xb, yb


def blur(img):
    """Return ``img`` smoothed with a 3x3 box filter."""
    return cv2.blur(img, (3, 3))


def add_noise(img):
    """Set 200 randomly chosen pixels of ``img`` to 255.

    The array is modified in place and also returned; pass a copy if the
    original data must be preserved.
    """
    for _ in range(200):
        row = np.random.randint(0, img.shape[0])
        col = np.random.randint(0, img.shape[1])
        img[row][col] = 255
    return img


def data_augment(xb, yb):
    """Randomly augment an image patch ``xb`` and its label patch ``yb``.

    Rotations and the horizontal flip are applied to both arrays; gamma,
    blur and noise are applied to the image only.

    Returns:
        Tuple ``(xb, yb)`` of the augmented image and label.
    """
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 90)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 180)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 270)
    if np.random.random() < 0.25:
        # flipCode > 0: flip around the y axis (horizontal mirror).
        xb = cv2.flip(xb, 1)
        yb = cv2.flip(yb, 1)

    if np.random.random() < 0.25:
        # NOTE(review): gamma_vari=1.0 gives log(1.0) == 0, so the sampled
        # gamma is always 1 and this step looks like a no-op -- confirm
        # whether a larger value was intended.
        xb = random_gamma_transform(xb, 1.0)
    if np.random.random() < 0.25:
        xb = blur(xb)
    if np.random.random() < 0.2:
        # add_noise writes in place. xb may still be a slice view of the
        # caller's full image, so work on a copy to avoid burning noise
        # pixels into the source that later patches are cut from.
        xb = add_noise(xb.copy())
    return xb, yb


def creat_dataset(image_num=2000, mode='original'):
    """Generate random patches from every image listed in ``image_sets``.

    Args:
        image_num: total number of patches, split evenly between the images.
        mode: pass ``'augment'`` to run ``data_augment`` on every patch; any
            other value saves the raw crops.

    Raises:
        FileNotFoundError: if an image or its label cannot be read.
        ValueError: if an image is smaller than the patch size.
    """
    print('creating dataset...')
    # True division on purpose: with a fractional share the loop below runs
    # until ``count`` reaches the next whole number.
    image_each = image_num / len(image_sets)

    # cv2.imwrite does not create missing folders, so make sure they exist.
    for sub_dir in ('visualize', 'train', 'label'):
        os.makedirs(basePath + 'src//' + sub_dir, exist_ok=True)

    g_count = 0
    for image_name in tqdm(image_sets):
        count = 0
        imgPath = basePath + 'train\\' + image_name
        src_img = cv2.imread(imgPath)  # 3 channels
        labelPath = basePath + 'label\\' + image_name
        label_img = cv2.imread(labelPath, cv2.IMREAD_GRAYSCALE)  # 1 channel
        # cv2.imread returns None instead of raising when reading fails.
        if src_img is None:
            raise FileNotFoundError(imgPath)
        if label_img is None:
            raise FileNotFoundError(labelPath)

        X_height, X_width, _ = src_img.shape
        if X_width < img_w or X_height < img_h:
            raise ValueError(
                'image %s is smaller than the %dx%d patch size'
                % (imgPath, img_w, img_h))
        print("\n")
        while count < image_each:
            # random.randint includes both ends, so the largest valid start
            # offset is (size - patch size).
            random_width = random.randint(0, X_width - img_w)
            random_height = random.randint(0, X_height - img_h)
            src_roi = src_img[random_height: random_height + img_h,
                              random_width: random_width + img_w, :]
            label_roi = label_img[random_height: random_height + img_h,
                                  random_width: random_width + img_w]
            if mode == 'augment':
                src_roi, label_roi = data_augment(src_roi, label_roi)

            # Scaled copy of the label, saved separately for visual checks.
            visualize = label_roi * 50

            cv2.imwrite((basePath + 'src//visualize//%d.png' % g_count), visualize)
            cv2.imwrite((basePath + 'src//train//%d.png' % g_count), src_roi)
            cv2.imwrite((basePath + 'src//label//%d.png' % g_count), label_roi)
            count += 1
            g_count += 1
            if count % 100 == 0:
                print("已经生成"+ str(count) +"张图片")


if __name__ == '__main__':
    creat_dataset(mode='augment')