""" Extensions called during training to generate samples and diagnostic plots and printouts. """ import matplotlib matplotlib.use('agg') import matplotlib.pyplot as plt import numpy as np import os import theano.tensor as T import theano from blocks.extensions import SimpleExtension import viz import sampler class PlotSamples(SimpleExtension): def __init__(self, model, algorithm, X, path, n_samples=49, **kwargs): """ Generate samples from the model. The do() function is called as an extension during training. Generates 3 types of samples: - Sample from generative model - Sample from image denoising posterior distribution (default signal to noise of 1) - Sample from image inpainting posterior distribution (inpaint left half of image) """ super(PlotSamples, self).__init__(**kwargs) self.model = model self.path = path n_samples = np.min([n_samples, X.shape[0]]) self.X = X[:n_samples].reshape( (n_samples, model.n_colors, model.spatial_width, model.spatial_width)) self.n_samples = n_samples X_noisy = T.tensor4('X noisy samp', dtype=theano.config.floatX) t = T.matrix('t samp', dtype=theano.config.floatX) self.get_mu_sigma = theano.function([X_noisy, t], model.get_mu_sigma(X_noisy, t), allow_input_downcast=True) def do(self, callback_name, *args): import sys sys.setrecursionlimit(10000000) print "generating samples" base_fname_part1 = self.path + '/samples-' base_fname_part2 = '_batch%06d'%self.main_loop.status['iterations_done'] sampler.generate_samples(self.model, self.get_mu_sigma, n_samples=self.n_samples, inpaint=False, denoise_sigma=None, X_true=None, base_fname_part1=base_fname_part1, base_fname_part2=base_fname_part2) sampler.generate_samples(self.model, self.get_mu_sigma, n_samples=self.n_samples, inpaint=True, denoise_sigma=None, X_true=self.X, base_fname_part1=base_fname_part1, base_fname_part2=base_fname_part2) sampler.generate_samples(self.model, self.get_mu_sigma, n_samples=self.n_samples, inpaint=False, denoise_sigma=1, X_true=self.X, base_fname_part1=base_fname_part1, base_fname_part2=base_fname_part2) class PlotParameters(SimpleExtension): def __init__(self, model, blocks_model, path, **kwargs): super(PlotParameters, self).__init__(**kwargs) self.path = path self.model = model self.blocks_model = blocks_model def do(self, callback_name, *args): import sys sys.setrecursionlimit(10000000) print "plotting parameters" for param in self.blocks_model.parameters: param_name = param.name filename_safe_name = '-'.join(param_name.split('/')[2:]).replace(' ', '_') base_fname_part1 = self.path + '/params-' + filename_safe_name base_fname_part2 = '_batch%06d'%self.main_loop.status['iterations_done'] viz.plot_parameter(param.get_value(), base_fname_part1, base_fname_part2, title=param_name, n_colors=self.model.n_colors) class PlotGradients(SimpleExtension): def __init__(self, model, blocks_model, algorithm, X, path, **kwargs): super(PlotGradients, self).__init__(**kwargs) self.path = path self.X = X self.model = model self.blocks_model = blocks_model gradients = [] for param_name in sorted(self.blocks_model.parameters.keys()): gradients.append(algorithm.gradients[self.blocks_model.parameters[param_name]]) self.grad_f = theano.function(algorithm.inputs, gradients, allow_input_downcast=True) def do(self, callback_name, *args): print "plotting gradients" grad_vals = self.grad_f(self.X) keynames = sorted(self.blocks_model.parameters.keys()) for ii in xrange(len(keynames)): param_name = keynames[ii] val = grad_vals[ii] filename_safe_name = '-'.join(param_name.split('/')[2:]).replace(' ', '_') 
class PlotInternalState(SimpleExtension):
    def __init__(self, model, blocks_model, state, features, X, path, **kwargs):
        super(PlotInternalState, self).__init__(**kwargs)
        self.path = path
        self.X = X
        self.model = model
        self.blocks_model = blocks_model
        # Compile a function mapping input features to the requested
        # internal state variables of the network.
        self.internal_state_f = theano.function([features], state,
            allow_input_downcast=True)
        self.internal_state_names = []
        for var in state:
            self.internal_state_names.append(var.name)

    def do(self, callback_name, *args):
        print "plotting internal state of network"
        state = self.internal_state_f(self.X)
        for ii in xrange(len(state)):
            param_name = self.internal_state_names[ii]
            val = state[ii]
            filename_safe_name = param_name.replace(' ', '_').replace('/', '-')
            base_fname_part1 = self.path + '/state-' + filename_safe_name
            base_fname_part2 = '_batch%06d' % self.main_loop.status['iterations_done']
            viz.plot_parameter(val, base_fname_part1, base_fname_part2,
                title="state " + param_name, n_colors=self.model.n_colors)


class PlotMonitors(SimpleExtension):
    def __init__(self, path, burn_in_iters=0, **kwargs):
        super(PlotMonitors, self).__init__(**kwargs)
        self.path = path
        self.burn_in_iters = burn_in_iters

    def do(self, callback_name, *args):
        print "plotting monitors"
        try:
            df = self.main_loop.log.to_dataframe()
        except AttributeError:
            # This started breaking after a Blocks update.
            print "Failed to generate monitoring plots due to Blocks interface change."
            return

        # Throw out the first burn_in_iters iterations, since the objective
        # is often much larger during that period.
        iter_number = df.index[-1]
        if iter_number > self.burn_in_iters:
            df = df.loc[self.burn_in_iters:]

        cols = [col for col in df.columns
                if col.startswith(('cost', 'train', 'test'))]
        df = df[cols].interpolate(method='linear')
        # If no rows survive, there is nothing to plot.
        if len(df) == 0:
            return

        try:
            axs = df.plot(subplots=True, legend=False,
                          figsize=(5, len(cols)*2))
        except TypeError:
            # This started breaking after a different Blocks update.
            print "Failed to generate monitoring plots due to Blocks interface change."
            return
        for ax, cname in zip(axs, cols):
            ax.set_title(cname)
        fn = os.path.join(self.path, 'monitors_subplots_batch%06d.png' %
            self.main_loop.status['iterations_done'])
        plt.savefig(fn, bbox_inches='tight')

        plt.clf()
        df.plot(subplots=False, figsize=(15, 10))
        plt.gcf().tight_layout()
        fn = os.path.join(self.path, 'monitors_batch%06d.png' %
            self.main_loop.status['iterations_done'])
        plt.savefig(fn, bbox_inches='tight')
        plt.close('all')
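
# The LogLikelihood extension below consumes a Fuel-style stream: an object
# whose get_epoch_iterator() yields tuples whose first entry is a batch of
# flattened images. The helper below is an illustrative sketch only -- the
# training script builds its streams elsewhere, and MNIST is just a
# stand-in dataset.
def _example_test_stream(batch_size=100):
    from fuel.datasets import MNIST
    from fuel.schemes import ShuffledScheme
    from fuel.streams import DataStream
    from fuel.transformers import Flatten
    dataset = MNIST(('test',))
    stream = DataStream.default_stream(dataset,
        iteration_scheme=ShuffledScheme(dataset.num_examples, batch_size))
    # Flatten (batch, 1, 28, 28) image batches into the (batch, n_pixels)
    # matrix that the model's cost function expects.
    return Flatten(stream, which_sources=('features',))
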
class LogLikelihood(SimpleExtension):
    def __init__(self, model, test_stream, rescale, num_eval_batches=10000, **kwargs):
        """
        Compute and print the log likelihood lower bound on the test
        dataset. The do() function is called as an extension during
        training.
        """
        super(LogLikelihood, self).__init__(**kwargs)
        self.model = model
        self.test_stream = test_stream
        self.rescale = rescale
        self.num_eval_batches = num_eval_batches
        features = T.matrix('features', dtype=theano.config.floatX)
        cost = self.model.cost(features)
        self.L_gap_func = theano.function([features], cost,
            allow_input_downcast=True)

    def print_stats(self, L_gap):
        larr = np.array(L_gap)
        mn = np.mean(larr)
        sd = np.std(larr, ddof=1)
        stderr = sd / np.sqrt(len(L_gap))
        # The log likelihood lower bound, K, is reported for the data after
        # Z-scoring it. The Z-score rescale is the multiplicative factor by
        # which the data was rescaled to give it standard deviation 1.
        print "eval batch=%05d (K-L_null)=%g bits/pix standard error=%g bits/pix Z-score rescale %g"%(
            len(L_gap), mn, stderr, self.rescale)

    def do(self, callback_name, *args):
        L_gap = []
        Xiter = None
        for kk in xrange(self.num_eval_batches):
            try:
                X = next(Xiter)[0]
            except (TypeError, StopIteration):
                # Start (or restart, once exhausted) an epoch iterator over
                # the test set.
                Xiter = self.test_stream.get_epoch_iterator()
                X = next(Xiter)[0]
            lg = -self.L_gap_func(X)
            L_gap.append(lg)
            if np.mod(kk, 1000) == 999:
                self.print_stats(L_gap)
        self.print_stats(L_gap)


def decay_learning_rate(iteration, old_value):
    # TODO the numbers in this function should not be hard coded.
    # This is called once per epoch. It multiplies the learning rate by a
    # constant factor chosen so that the rate falls by 10x every 1000
    # epochs, down to a floor of min_value.
    min_value = 1e-4
    decay_rate = np.exp(np.log(0.1)/1000.)
    new_value = decay_rate*old_value
    if new_value < min_value:
        new_value = min_value
    print "learning rate %g"%new_value
    return np.float32(new_value)
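
# decay_learning_rate is written for use with Blocks' SharedVariableModifier
# extension, which calls it as f(iterations_done, old_value) and writes the
# returned value back into the shared variable. The sketch below shows one
# way to wire it up; it assumes the step rule exposes a `learning_rate`
# shared variable (true of e.g. blocks.algorithms.Scale).
def _example_learning_rate_schedule(algorithm):
    from blocks.extensions.training import SharedVariableModifier
    # Fire once per epoch rather than the default once per batch, matching
    # the "called every epoch" assumption in decay_learning_rate above.
    return SharedVariableModifier(
        algorithm.step_rule.learning_rate, decay_learning_rate,
        after_batch=False, after_epoch=True)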