from os import listdir, path
import re
import numpy
import matplotlib
# The non-interactive backend must be selected before pyplot is imported.
matplotlib.use('Agg')
from skimage.io import imread
matplotlib.rcParams.update({'font.size': 2})
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import AxesGrid
import theano
import my_code.occluded_args as args
from my_code.predict import load_column, model_runid
import pdb


class OcclusionStudy(object):
    """Build occluded copies of a test image and turn per-patch outputs into heatmaps.

    Each "patch center" n corresponds to one copy of the test image in which a
    square `occlusion_patch_size` x `occlusion_patch_size` patch is replaced by
    the image's per-channel mean.
    """

    def __init__(self, data_stream, occlusion_patch_size, test_imagepath=None):
        """
        Args:
            data_stream: object exposing `image_shape`, `cache_size` and
                `feed_image(image_name=..., image_dir=...)`. Only
                `image_shape[0]` is used for the patch grid, so a square image
                is assumed (NOTE(review): confirm non-square inputs never occur).
            occlusion_patch_size: odd side length of the occluding patch.
            test_imagepath: extensionless path of the image to study. It may
                be set later, before calling `buffer_occluded_dataset`.

        Raises:
            ValueError: if the patch does not fit inside the image.
        """
        assert(occlusion_patch_size % 2 == 1)  # odd size => patch has a center pixel
        self.ds = data_stream
        self.occlusion_patch_size = occlusion_patch_size
        self.test_imagepath = test_imagepath
        # Patches must lie fully inside the image, so there are
        # image_dim - 2*pad = image_dim - patch_size + 1 positions per axis.
        # Floor division keeps this an int on both Python 2 and 3.
        patched_img_pad = (occlusion_patch_size - 1) // 2
        patched_img_dim = data_stream.image_shape[0] - patched_img_pad * 2
        if patched_img_dim <= 0:
            raise ValueError("occlusion patch size %d does not fit in an image of size %d"
                             % (occlusion_patch_size, data_stream.image_shape[0]))
        self.num_patch_centers = patched_img_dim ** 2  # number of images in test set
        # (row, col) of the top-left corner of patch n, in row-major order.
        self.patch_starts = [divmod(n, patched_img_dim) for n in range(self.num_patch_centers)]

    def nth_patch(self, n):
        """Return (i_start, i_end, j_start, j_end) bounds of patch `n` (end-exclusive)."""
        i_start, j_start = self.patch_starts[n]
        i_end = i_start + self.occlusion_patch_size
        j_end = j_start + self.occlusion_patch_size
        return (i_start, i_end, j_start, j_end)

    def _occlude_into(self, x_block, row, n, img, fill_value):
        """Write `img` into `x_block[row]` with patch `n` overwritten by `fill_value`."""
        i_start, i_end, j_start, j_end = self.nth_patch(n)
        x_block[row, ...] = img
        x_block[row, i_start:i_end, j_start:j_end, :] = fill_value

    def buffer_occluded_dataset(self):
        """Yield cache-sized batches of occluded copies of `self.test_imagepath`.

        Yields:
            (x, idxs). `x` is the (cache_size,) + image_shape buffer with axis 3
            rolled to position 1. The indexing [row, i, j, :] implies an
            (H, W, C) image_shape, so `x` is (cache_size, C, H, W). `idxs` is
            an int32 array giving the patch index of every row.

        The final partial block is padded to `cache_size` rows. The padding
        comes from the previous full block when there is one. Otherwise every
        padding row is a copy of the first leftover patch. In both cases each
        padded row's image matches its index.

        NOTE: `numpy.rollaxis` returns a view of one reused buffer, so consumers
        must use or copy `x` before advancing the generator.
        """
        assert(self.test_imagepath)
        img = self.ds.feed_image(image_name=self.test_imagepath, image_dir='')
        channel_means = img.mean(axis=(0, 1))
        cache_size = self.ds.cache_size
        x_cache_block = numpy.zeros((cache_size,) + tuple(self.ds.image_shape),
                                    dtype=theano.config.floatX)
        n_full_cache_blocks, n_leftovers = divmod(self.num_patch_centers, cache_size)
        idxs_to_full_dataset = []

        for ith_cache_block in range(n_full_cache_blocks):
            block_start = ith_cache_block * cache_size
            idxs_to_full_dataset = list(range(block_start, block_start + cache_size))
            for ci, n in enumerate(idxs_to_full_dataset):
                self._occlude_into(x_cache_block, ci, n, img, channel_means)
            yield numpy.rollaxis(x_cache_block, 3, 1), numpy.array(idxs_to_full_dataset, dtype='int32')

        if n_leftovers:
            leftover_start = n_full_cache_blocks * cache_size
            if not idxs_to_full_dataset:
                # No previous full block to pad with: fill the padding rows with
                # the first leftover patch so each row still matches its index.
                idxs_to_full_dataset = [leftover_start] * cache_size
                for ci in range(n_leftovers, cache_size):
                    self._occlude_into(x_cache_block, ci, leftover_start, img, channel_means)
            for ci, n in enumerate(range(leftover_start, self.num_patch_centers)):
                self._occlude_into(x_cache_block, ci, n, img, channel_means)
                idxs_to_full_dataset[ci] = n
            yield numpy.rollaxis(x_cache_block, 3, 1), numpy.array(idxs_to_full_dataset, dtype='int32')

    def _original_image_file(self):
        """Return the on-disk file behind the extensionless `self.test_imagepath`.

        '.png' is tried first, then '.jpeg', because `file_iter` yields stems
        for both. If neither exists, the '.png' path is returned so `imread`
        reports the missing file.
        """
        for ext in ('.png', '.jpeg'):
            candidate = self.test_imagepath + ext
            if path.isfile(candidate):
                return candidate
        return self.test_imagepath + '.png'

    def accumulate_patches_into_heatmaps(self, all_test_output, outpath_prefix=''):
        """Average per-patch outputs into per-class heatmaps and save a figure.

        Every pixel's value for class k is the mean of `all_test_output[n, k]`
        over all patches n that cover that pixel. The figure shows the original
        image next to one colorbar-annotated heatmap per class.

        Args:
            all_test_output: 2-D array indexed [patch_index, class].
            outpath_prefix: prefix of the saved file name.
        """
        outpath = "plots/%s_%s.png" % (outpath_prefix, path.splitext(path.basename(self.test_imagepath))[0])
        n_classes = all_test_output.shape[1]
        heat_shape = self.ds.image_shape[:2]

        # Patch coverage counts are the same for every class; compute them once.
        normalizer = numpy.zeros(heat_shape)
        for n in range(self.num_patch_centers):
            i_start, i_end, j_start, j_end = self.nth_patch(n)
            normalizer[i_start:i_end, j_start:j_end] += 1

        # http://matplotlib.org/examples/axes_grid/demo_axes_grid.html
        fig = plt.figure()
        try:
            grid = AxesGrid(fig, 143,  # similar to subplot(143)
                            nrows_ncols=(1, 1))
            grid[0].imshow(imread(self._original_image_file()))

            # 2x2 as before for <= 4 classes; grow the square grid beyond that.
            side = max(2, int(numpy.ceil(numpy.sqrt(n_classes))))
            grid = AxesGrid(fig, 144,  # similar to subplot(144)
                            nrows_ncols=(side, side),
                            axes_pad=0.15,
                            label_mode="1",
                            share_all=True,
                            cbar_location="right",
                            cbar_mode="each",
                            cbar_size="7%",
                            cbar_pad="2%",
                            )
            for klass in range(n_classes):
                accumulator = numpy.zeros(heat_shape)
                for n in range(self.num_patch_centers):
                    i_start, i_end, j_start, j_end = self.nth_patch(n)
                    accumulator[i_start:i_end, j_start:j_end] += all_test_output[n, klass]
                im = grid[klass].imshow(accumulator / normalizer, interpolation="nearest", vmin=0, vmax=1)
                grid.cbar_axes[klass].colorbar(im)
            grid.axes_llc.set_xticks([])
            grid.axes_llc.set_yticks([])
            print("Saving figure as: %s" % outpath)
            fig.savefig(outpath, dpi=600, bbox_inches='tight')
        finally:
            # Called once per image in a loop; release the figure's memory.
            plt.close(fig)


def file_iter(test_path):
    """Yield extensionless paths of the images found at `test_path`.

    If `test_path` is a directory, one stem is yielded per '.jpeg'/'.png' file
    in it, in sorted order. If it is a file, its own stem is yielded.

    Raises:
        ValueError: if `test_path` is neither a file nor a directory that
            contains at least one image.
    """
    e = ValueError("'%s' is neither a file nor a directory of images" % test_path)
    if path.isdir(test_path):
        images = sorted(path.splitext(f)[0] for f in listdir(test_path)
                        if re.search(r'\.(jpeg|png)$', f))
        if not images:
            raise e
        for image in images:
            # path.join works whether or not test_path ends with a separator.
            yield path.join(test_path, image)
    elif path.isfile(test_path):
        yield path.splitext(test_path)[0]
    else:
        raise e


def plot_occluded_activations(model_file, test_path, patch_size, **kwargs):
    """Run the occlusion study for every image at `test_path` and save heatmaps.

    Args:
        model_file: model to load via `load_column`. Must be truthy.
        test_path: an image file or a directory of images (see `file_iter`).
        patch_size: odd side length of the occluding patch.
        **kwargs: forwarded to `load_column`.
    """
    assert(model_file)
    runid = model_runid(model_file)
    column = load_column(model_file, **kwargs)
    study = OcclusionStudy(column.ds, patch_size)
    try:
        for image_stem in file_iter(test_path):
            study.test_imagepath = image_stem
            _, all_test_output = column.test(override_buffer=study.buffer_occluded_dataset,
                                             override_num_examples=study.num_patch_centers)
            study.accumulate_patches_into_heatmaps(all_test_output, runid)
    except KeyboardInterrupt:
        print("[ERROR] User terminated Occlusion Study")
    print("Done")


if __name__ == '__main__':
    cli_args = args.get()
    plot_occluded_activations(model_file=cli_args.model_file,
                              test_path=cli_args.test_path,
                              patch_size=cli_args.patch_size,
                              train_dataset=cli_args.train_dataset,
                              center=cli_args.center,
                              normalize=cli_args.normalize,
                              train_flip=cli_args.train_flip,
                              test_dataset=None,
                              random_seed=cli_args.random_seed,
                              valid_dataset_size=cli_args.valid_dataset_size,
                              filter_shape=cli_args.filter_shape,
                              cuda_convnet=cli_args.cuda_convnet)