import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import zoom
from skimage import color
from sklearn.cluster import KMeans


def create_temp_directory(path_template, N=1e8):
    """Create a fresh directory whose name is drawn from ``path_template``.

    ``path_template`` is a %-format string with one integer slot; a random
    integer in ``[0, N)`` is substituted until an unused path is found.

    Returns the path of the directory that was created.
    """
    print(path_template)
    upper = int(N)  # the default is a float; randint expects integer bounds
    while True:
        cur_path = path_template % np.random.randint(0, upper)
        if os.path.exists(cur_path):
            continue
        print('Creating directory: %s' % cur_path)
        try:
            os.mkdir(cur_path)
        except FileExistsError:
            # Something else created the path between the check and mkdir.
            continue
        return cur_path


def lab2rgb_transpose(img_l, img_ab):
    ''' INPUTS
            img_l      1xXxX     [0,100]
            img_ab     2xXxX     [-100,100]
        OUTPUTS
            returned value is XxXx3 (uint8 RGB) '''
    pred_lab = np.concatenate((img_l, img_ab), axis=0).transpose((1, 2, 0))
    pred_rgb = (np.clip(color.lab2rgb(pred_lab), 0, 1) * 255).astype('uint8')
    return pred_rgb


def rgb2lab_transpose(img_rgb):
    ''' INPUTS
            img_rgb XxXx3
        OUTPUTS
            returned value is 3xXxX '''
    return color.rgb2lab(img_rgb).transpose((2, 0, 1))


class ColorizeImageBase():
    """Shared image bookkeeping for the colorization back-ends.

    Subclasses must define ``l_norm``, ``ab_norm``, ``l_mean``, ``ab_mean``
    and ``mask_mult`` (read by the normalisation code here) and implement
    ``prep_net`` / ``net_forward``.
    """

    def __init__(self, Xd=256, Xfullres_max=10000):
        self.Xd = Xd
        self.img_l_set = False
        self.net_set = False
        self.Xfullres_max = Xfullres_max  # maximum size of maximum dimension
        # this will be true whenever image is just loaded
        # net_forward can set this to False if they want
        self.img_just_set = False

    def prep_net(self):
        """Load the network; must be overridden by subclasses."""
        raise NotImplementedError("Should be implemented by subclass")

    # ***** Image prepping *****
    def load_image(self, input_path):
        """Read an image from disk and prepare full-res and XdxXd Lab copies.

        Raises IOError if the file cannot be read as an image.
        """
        bgr = cv2.imread(input_path, 1)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising
            raise IOError('Could not read image: %s' % input_path)
        im = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        self.img_rgb_fullres = im.copy()
        self._set_img_lab_fullres_()

        im = cv2.resize(im, (self.Xd, self.Xd))
        self.img_rgb = im.copy()

        self.img_l_set = True

        # convert into lab space
        self._set_img_lab_()
        self._set_img_lab_mc_()

    def set_image(self, input_image):
        """Use an in-memory RGB image (XxXx3) as the current image.

        Unlike ``load_image`` the image is not resized here.
        """
        self.img_rgb_fullres = input_image.copy()
        self._set_img_lab_fullres_()

        self.img_l_set = True

        self.img_rgb = input_image
        # convert into lab space
        self._set_img_lab_()
        self._set_img_lab_mc_()

    def net_forward(self, input_ab, input_mask):
        """Validate state and stash the normalised user inputs.

        INPUTS
            ab      2xXxX   input color patches (non-normalized)
            mask    1xXxX   input mask, indicating which points have been provided
        assumes self.img_l_mc has been set

        Returns 0 on success, -1 if no image or no net has been set.
        """
        if not self.img_l_set:
            print('I need to have an image!')
            return -1
        if not self.net_set:
            print('I need to have a net!')
            return -1

        self.input_ab = input_ab
        self.input_ab_mc = (input_ab - self.ab_mean) / self.ab_norm
        self.input_mask = input_mask
        self.input_mask_mult = input_mask * self.mask_mult
        return 0

    def get_result_PSNR(self, result=-1, return_SE_map=False):
        """PSNR of ``result`` against ``self.img_rgb``.

        ``result=-1`` (the default) means "use the last forward output".
        With ``return_SE_map`` the per-element squared-error map is
        returned as well.
        """
        if np.array((result)).flatten()[0] == -1:
            cur_result = self.get_img_forward()
        else:
            cur_result = result.copy()
        SE_map = (1. * self.img_rgb - cur_result) ** 2
        cur_MSE = np.mean(SE_map)
        cur_PSNR = 20 * np.log10(255. / np.sqrt(cur_MSE))
        if return_SE_map:
            return (cur_PSNR, SE_map)
        return cur_PSNR

    def get_img_forward(self):
        # get image with point estimate
        return self.output_rgb

    def get_img_gray(self):
        # Get black and white image
        return lab2rgb_transpose(self.img_l, np.zeros((2, self.Xd, self.Xd)))

    def get_img_gray_fullres(self):
        # Get black and white image
        height, width = self.img_l_fullres.shape[1:3]
        return lab2rgb_transpose(self.img_l_fullres, np.zeros((2, height, width)))

    def get_img_fullres(self):
        # This assumes self.img_l_fullres, self.output_ab are set.
        # Typically, this means that set_image() and net_forward()
        # have been called.
        # bilinear upsample
        zoom_factor = self._fullres_zoom_factor(self.output_ab)
        output_ab_fullres = zoom(self.output_ab, zoom_factor, order=1)
        return lab2rgb_transpose(self.img_l_fullres, output_ab_fullres)

    def get_input_img_fullres(self):
        zoom_factor = self._fullres_zoom_factor(self.input_ab)
        input_ab_fullres = zoom(self.input_ab, zoom_factor, order=1)
        return lab2rgb_transpose(self.img_l_fullres, input_ab_fullres)

    def get_input_img(self):
        return lab2rgb_transpose(self.img_l, self.input_ab)

    def get_img_mask(self):
        # Get black and white image
        return lab2rgb_transpose(100. * (1 - self.input_mask), np.zeros((2, self.Xd, self.Xd)))

    def get_img_mask_fullres(self):
        # Get black and white image
        zoom_factor = self._fullres_zoom_factor(self.input_ab)
        input_mask_fullres = zoom(self.input_mask, zoom_factor, order=0)
        return lab2rgb_transpose(
            100. * (1 - input_mask_fullres),
            np.zeros((2, input_mask_fullres.shape[1], input_mask_fullres.shape[2])))

    def get_sup_img(self):
        return lab2rgb_transpose(50 * self.input_mask, self.input_ab)

    def get_sup_fullres(self):
        # Only the user inputs are zoomed here, so the factor is derived from
        # input_ab; this also works before any forward output exists.
        zoom_factor = self._fullres_zoom_factor(self.input_ab)
        input_mask_fullres = zoom(self.input_mask, zoom_factor, order=0)
        input_ab_fullres = zoom(self.input_ab, zoom_factor, order=0)
        return lab2rgb_transpose(50 * input_mask_fullres, input_ab_fullres)

    # ***** Private functions *****
    def _fullres_zoom_factor(self, arr):
        # Per-axis zoom taking a CxHxW array to the full-resolution H and W.
        return (1,
                1. * self.img_l_fullres.shape[1] / arr.shape[1],
                1. * self.img_l_fullres.shape[2] / arr.shape[2])

    def _set_img_lab_fullres_(self):
        # adjust full resolution image so its maximum dimension is within Xfullres_max
        Xfullres = self.img_rgb_fullres.shape[0]
        Yfullres = self.img_rgb_fullres.shape[1]
        if Xfullres > self.Xfullres_max or Yfullres > self.Xfullres_max:
            zoom_factor = 1. * self.Xfullres_max / max(Xfullres, Yfullres)
            self.img_rgb_fullres = zoom(self.img_rgb_fullres, (zoom_factor, zoom_factor, 1), order=1)

        self.img_lab_fullres = color.rgb2lab(self.img_rgb_fullres).transpose((2, 0, 1))
        self.img_l_fullres = self.img_lab_fullres[[0], :, :]
        self.img_ab_fullres = self.img_lab_fullres[1:, :, :]

    def _set_img_lab_(self):
        # set self.img_lab from self.img_rgb
        self.img_lab = color.rgb2lab(self.img_rgb).transpose((2, 0, 1))
        self.img_l = self.img_lab[[0], :, :]
        self.img_ab = self.img_lab[1:, :, :]

    def _set_img_lab_mc_(self):
        # set self.img_lab_mc from self.img_lab
        # lab image, mean centered [XxYxX]
        norm = np.array((self.l_norm, self.ab_norm, self.ab_norm))[:, np.newaxis, np.newaxis]
        offset = np.array((self.l_mean / self.l_norm,
                           self.ab_mean / self.ab_norm,
                           self.ab_mean / self.ab_norm))[:, np.newaxis, np.newaxis]
        self.img_lab_mc = self.img_lab / norm - offset
        self._set_img_l_()

    def _set_img_l_(self):
        self.img_l_mc = self.img_lab_mc[[0], :, :]
        self.img_l_set = True

    def _set_img_ab_(self):
        self.img_ab_mc = self.img_lab_mc[[1, 2], :, :]

    def _set_out_ab_(self):
        self.output_lab = rgb2lab_transpose(self.output_rgb)
        self.output_ab = self.output_lab[1:, :, :]


class _ColorDistributionMixin(object):
    """Helpers shared by back-ends that predict a per-pixel ab distribution.

    Host classes provide ``dist_ab_set``, ``dist_ab``, ``dist_ab_full``,
    ``dist_ab_grid``, ``dist_entropy``, ``in_hull``, ``pts_in_hull``,
    ``A``, ``B`` and ``Xd``.
    """

    def _set_dist_ab_(self, dist_ab):
        # Store the predicted distribution and scatter it onto the full grid.
        self.dist_ab = dist_ab
        self.dist_ab_set = True

        # full grid, ABxXxX, AB = 529
        self.dist_ab_full[self.in_hull, :, :] = self.dist_ab

        # gridded, AxBxXxX, A = 23
        self.dist_ab_grid = self.dist_ab_full.reshape((self.A, self.B, self.Xd, self.Xd))

    def get_ab_reccs(self, h, w, K=5, N=25000, return_conf=False):
        ''' Recommended colors at point (h,w)
        Call this after calling net_forward

        Returns the K cluster centers sorted by occupancy (most populated
        first), plus the per-cluster fraction of samples when
        ``return_conf`` is true. Returns 0 if no prediction is available. '''
        if not self.dist_ab_set:
            print('Need to set prediction first')
            return 0

        # randomly sample from pdf
        cmf = np.cumsum(self.dist_ab[:, h, w])  # CMF
        cmf = cmf / cmf[-1]

        # randomly sample N points
        rnd_pts = np.random.uniform(low=0, high=1.0, size=N)
        inds = np.digitize(rnd_pts, bins=cmf)
        # digitize yields len(cmf) for a sample that is not below the last
        # edge; clamp so the lookup below can never go out of range.
        inds = np.minimum(inds, cmf.shape[0] - 1)
        rnd_pts_ab = self.pts_in_hull[inds, :]

        # run k-means
        kmeans = KMeans(n_clusters=K).fit(rnd_pts_ab)

        # sort by cluster occupancy
        k_label_cnt = np.bincount(kmeans.labels_, minlength=K)
        k_inds = np.argsort(k_label_cnt, axis=0)[::-1]

        cluster_per = 1. * k_label_cnt[k_inds] / N  # percentage of points within cluster
        cluster_centers = kmeans.cluster_centers_[k_inds, :]  # cluster centers

        if return_conf:
            return cluster_centers, cluster_per
        return cluster_centers

    def compute_entropy(self):
        # compute the distribution entropy (really slow right now)
        # Stored as sum(p * log p), i.e. the negated entropy;
        # plot_dist_entropy flips the sign.
        p = self.dist_ab
        # Bins with p == 0 contribute 0 (the limit of p*log p) instead of NaN.
        safe_p = np.where(p > 0, p, 1.)
        self.dist_entropy = np.sum(p * np.log(safe_p), axis=0)

    def plot_dist_grid(self, h, w):
        # Plots distribution at a given point
        plt.figure()
        plt.imshow(self.dist_ab_grid[:, :, h, w], extent=[-110, 110, 110, -110], interpolation='nearest')
        plt.colorbar()
        plt.ylabel('a')
        plt.xlabel('b')

    def plot_dist_entropy(self):
        # Plots the entropy map of the predicted distribution
        plt.figure()
        plt.imshow(-self.dist_entropy, interpolation='nearest')
        plt.colorbar()


class ColorizeImageTorch(ColorizeImageBase):
    """PyTorch back-end returning a point-estimate colorization."""

    def __init__(self, Xd=256, maskcent=False):
        print('ColorizeImageTorch instantiated')
        ColorizeImageBase.__init__(self, Xd)
        self.l_norm = 1.
        self.ab_norm = 1.
        self.l_mean = 50.
        self.ab_mean = 0.
        self.mask_mult = 1.
        self.mask_cent = .5 if maskcent else 0

        # Load grid properties
        self.pts_in_hull = np.array(np.meshgrid(np.arange(-110, 120, 10),
                                                np.arange(-110, 120, 10))).reshape((2, 529)).T

    # ***** Net preparation *****
    def prep_net(self, gpu_id=None, path='', dist=False):
        """Build the generator and load weights from ``path``.

        The network is moved to CUDA when ``gpu_id`` is not None.
        """
        import torch
        import models.pytorch.model as model
        print('path = %s' % path)
        print('Model set! dist mode? ', dist)
        self.net = model.SIGGRAPHGenerator(dist=dist)
        # When no GPU is requested, map storages to CPU so checkpoints saved
        # from a GPU can still be deserialised on a CPU-only machine.
        map_location = 'cpu' if gpu_id is None else None
        state_dict = torch.load(path, map_location=map_location)
        if hasattr(state_dict, '_metadata'):
            del state_dict._metadata

        # patch InstanceNorm checkpoints prior to 0.4
        for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
            self.__patch_instance_norm_state_dict(state_dict, self.net, key.split('.'))
        self.net.load_state_dict(state_dict)
        if gpu_id is not None:
            self.net.cuda()
        self.net.eval()
        self.net_set = True

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        # Walk ``keys`` down the module tree and drop InstanceNorm entries
        # that the current module does not carry.
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            is_instance_norm = module.__class__.__name__.startswith('InstanceNorm')
            if is_instance_norm and key in ('running_mean', 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if is_instance_norm and key == 'num_batches_tracked':
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    # ***** Call forward *****
    def net_forward(self, input_ab, input_mask):
        """Run the network and return the colorized XxXx3 RGB image.

        INPUTS
            ab      2xXxX   input color patches (non-normalized)
            mask    1xXxX   input mask, indicating which points have been provided
        assumes self.img_l_mc has been set

        Returns -1 if no image or no net has been set.
        """
        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
            return -1

        prediction = self.net.forward(self.img_l_mc, self.input_ab_mc,
                                      self.input_mask_mult, self.mask_cent)
        output_ab = prediction[0, :, :, :].cpu().data.numpy()
        self.output_rgb = lab2rgb_transpose(self.img_l, output_ab)
        self._set_out_ab_()
        return self.output_rgb


class ColorizeImageTorchDist(_ColorDistributionMixin, ColorizeImageTorch):
    """PyTorch back-end that also keeps the predicted ab distribution."""

    def __init__(self, Xd=256, maskcent=False):
        ColorizeImageTorch.__init__(self, Xd, maskcent)
        self.dist_ab_set = False
        self.pts_grid = np.array(np.meshgrid(np.arange(-110, 120, 10),
                                             np.arange(-110, 120, 10))).reshape((2, 529)).T
        self.in_hull = np.ones(529, dtype=bool)
        self.AB = self.pts_grid.shape[0]  # 529
        self.A = int(np.sqrt(self.AB))  # 23
        self.B = int(np.sqrt(self.AB))  # 23
        self.dist_ab_full = np.zeros((self.AB, self.Xd, self.Xd))
        self.dist_ab_grid = np.zeros((self.A, self.B, self.Xd, self.Xd))
        self.dist_entropy = np.zeros((self.Xd, self.Xd))

    def prep_net(self, gpu_id=None, path='', dist=True, S=.2):
        # S is accepted for signature parity but is not used here.
        ColorizeImageTorch.prep_net(self, gpu_id=gpu_id, path=path, dist=dist)

    def net_forward(self, input_ab, input_mask):
        """Run the network, store the distribution, return the first output.

        INPUTS
            ab      2xXxX   input color patches (non-normalized)
            mask    1xXxX   input mask, indicating which points have been provided
        assumes self.img_l_mc has been set

        Returns -1 if no image or no net has been set. Note that this
        returns the network's first output as a numpy array and does not
        set ``output_rgb``.
        """
        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
            return -1

        # set distribution
        (function_return, dist_ab) = self.net.forward(self.img_l_mc, self.input_ab_mc,
                                                      self.input_mask_mult, self.mask_cent)
        function_return = function_return[0, :, :, :].cpu().data.numpy()
        self._set_dist_ab_(dist_ab[0, :, :, :].cpu().data.numpy())

        return function_return


class ColorizeImageCaffe(ColorizeImageBase):
    """Caffe back-end returning a point-estimate colorization."""

    def __init__(self, Xd=256):
        print('ColorizeImageCaffe instantiated')
        ColorizeImageBase.__init__(self, Xd)
        self.l_norm = 1.
        self.ab_norm = 1.
        self.l_mean = 50.
        self.ab_mean = 0.
        self.mask_mult = 110.

        self.pred_ab_layer = 'pred_ab'  # predicted ab layer

        # Load grid properties
        self.pts_in_hull_path = './data/color_bins/pts_in_hull.npy'
        self.pts_in_hull = np.load(self.pts_in_hull_path)  # 313x2, in-gamut

    # ***** Net preparation *****
    def prep_net(self, gpu_id, prototxt_path='', caffemodel_path=''):
        """Load the Caffe net; ``gpu_id == -1`` selects CPU mode."""
        import caffe
        print('gpu_id = %d, net_path = %s, model_path = %s' % (gpu_id, prototxt_path, caffemodel_path))
        if gpu_id == -1:
            caffe.set_mode_cpu()
        else:
            caffe.set_device(gpu_id)
            caffe.set_mode_gpu()
        self.gpu_id = gpu_id
        self.net = caffe.Net(prototxt_path, caffemodel_path, caffe.TEST)
        self.net_set = True

        # automatically set cluster centers
        pred_shape = self.net.params[self.pred_ab_layer][0].data[...].shape
        if len(pred_shape) == 4 and pred_shape[1] == 313:
            print('Setting ab cluster centers in layer: %s' % self.pred_ab_layer)
            self.net.params[self.pred_ab_layer][0].data[:, :, 0, 0] = self.pts_in_hull.T

        # automatically set upsampling kernel
        for layer in self.net._layer_names:
            if layer[-3:] == '_us':
                print('Setting upsampling layer kernel: %s' % layer)
                self.net.params[layer][0].data[:, 0, :, :] = np.array(
                    ((.25, .5, .25, 0),
                     (.5, 1., .5, 0),
                     (.25, .5, .25, 0),
                     (0, 0, 0, 0)))[np.newaxis, :, :]

    # ***** Call forward *****
    def net_forward(self, input_ab, input_mask):
        """Run the network and return the colorized XxXx3 RGB image.

        INPUTS
            ab      2xXxX   input color patches (non-normalized)
            mask    1xXxX   input mask, indicating which points have been provided
        assumes self.img_l_mc has been set

        Returns -1 if no image or no net has been set.
        """
        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
            return -1

        net_input_prepped = np.concatenate((self.img_l_mc, self.input_ab_mc, self.input_mask_mult), axis=0)

        self.net.blobs['data_l_ab_mask'].data[...] = net_input_prepped
        self.net.forward()

        # return prediction
        self.output_rgb = lab2rgb_transpose(self.img_l, self.net.blobs[self.pred_ab_layer].data[0, :, :, :])

        self._set_out_ab_()
        return self.output_rgb


class ColorizeImageCaffeGlobDist(ColorizeImageCaffe):
    # Caffe colorization, with additional global histogram as input
    def __init__(self, Xd=256):
        ColorizeImageCaffe.__init__(self, Xd)
        self.glob_mask_mult = 1.
        self.glob_layer = 'glob_ab_313_mask'

    def net_forward(self, input_ab, input_mask, glob_dist=-1):
        """Forward pass optionally conditioned on a global ab histogram.

        glob_dist is 313 array, or -1 to run without it.
        Returns the RGB result, or -1 if no image or no net has been set.
        """
        # Only touch the net when it exists; otherwise the parent call below
        # reports the problem and returns -1.
        if self.net_set:
            if np.array(glob_dist).flatten()[0] == -1:  # run without this, zero it out
                self.net.blobs[self.glob_layer].data[0, :-1, 0, 0] = 0.
                self.net.blobs[self.glob_layer].data[0, -1, 0, 0] = 0.
            else:  # run conditioned on global histogram
                self.net.blobs[self.glob_layer].data[0, :-1, 0, 0] = glob_dist
                self.net.blobs[self.glob_layer].data[0, -1, 0, 0] = self.glob_mask_mult

        # The parent already stores output_rgb / output_ab on success, so its
        # return value (image or -1) is passed straight through.
        return ColorizeImageCaffe.net_forward(self, input_ab, input_mask)


class ColorizeImageCaffeDist(_ColorDistributionMixin, ColorizeImageCaffe):
    # caffe model which includes distribution prediction
    def __init__(self, Xd=256):
        ColorizeImageCaffe.__init__(self, Xd)
        self.dist_ab_set = False
        self.scale_S_layer = 'scale_S'
        self.dist_ab_S_layer = 'dist_ab_S'  # softened distribution layer
        self.pts_grid = np.load('./data/color_bins/pts_grid.npy')  # 529x2, all points
        self.in_hull = np.load('./data/color_bins/in_hull.npy')  # 529 bool
        self.AB = self.pts_grid.shape[0]  # 529
        self.A = int(np.sqrt(self.AB))  # 23
        self.B = int(np.sqrt(self.AB))  # 23
        self.dist_ab_full = np.zeros((self.AB, self.Xd, self.Xd))
        self.dist_ab_grid = np.zeros((self.A, self.B, self.Xd, self.Xd))
        self.dist_entropy = np.zeros((self.Xd, self.Xd))

    def prep_net(self, gpu_id, prototxt_path='', caffemodel_path='', S=.2):
        """Load the net and write ``S`` into the ``scale_S`` layer."""
        ColorizeImageCaffe.prep_net(self, gpu_id, prototxt_path=prototxt_path, caffemodel_path=caffemodel_path)
        self.S = S
        self.net.params[self.scale_S_layer][0].data[...] = S

    def net_forward(self, input_ab, input_mask):
        """Run the network, store the distribution, return the RGB result.

        INPUTS
            ab      2xXxX   input color patches (non-normalized)
            mask    1xXxX   input mask, indicating which points have been provided
        assumes self.img_l_mc has been set

        Returns -1 if no image or no net has been set.
        """
        function_return = ColorizeImageCaffe.net_forward(self, input_ab, input_mask)
        if np.array(function_return).flatten()[0] == -1:  # errored out
            return -1

        # set distribution
        # in-gamut, CxXxX, C = 313
        self._set_dist_ab_(self.net.blobs[self.dist_ab_S_layer].data[0, :, :, :])

        return function_return