# Copyright 2017 Max Planck Society # Distributed under the BSD-3 Software license, # (See accompanying file ./LICENSE.txt or copy at # https://opensource.org/licenses/BSD-3-Clause) """This class implements POT training. """ import collections import logging import os import time import tensorflow as tf import utils from utils import ProgressBar from utils import TQDM import numpy as np import ops from metrics import Metrics slim = tf.contrib.slim def vgg_16(inputs, is_training=False, dropout_keep_prob=0.5, scope='vgg_16', fc_conv_padding='VALID', reuse=None): inputs = inputs * 255.0 inputs -= tf.constant([123.68, 116.779, 103.939], dtype=tf.float32) with tf.variable_scope(scope, 'vgg_16', [inputs], reuse=reuse) as sc: end_points_collection = sc.name + '_end_points' end_points = {} # Collect outputs for conv2d, fully_connected and max_pool2d. with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], outputs_collections=end_points_collection): end_points['pool0'] = inputs net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') net = slim.max_pool2d(net, [2, 2], scope='pool1') end_points['pool1'] = net net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') net = slim.max_pool2d(net, [2, 2], scope='pool2') end_points['pool2'] = net net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3') net = slim.max_pool2d(net, [2, 2], scope='pool3') end_points['pool3'] = net net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4') net = slim.max_pool2d(net, [2, 2], scope='pool4') end_points['pool4'] = net net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5') net = slim.max_pool2d(net, [2, 2], scope='pool5') end_points['pool5'] = net # # Use conv2d instead of fully_connected layers. # net = slim.conv2d(net, 4096, [7, 7], padding=fc_conv_padding, scope='fc6') # net = slim.dropout(net, dropout_keep_prob, is_training=is_training, # scope='dropout6') # net = slim.conv2d(net, 4096, [1, 1], scope='fc7') # net = slim.dropout(net, dropout_keep_prob, is_training=is_training, # scope='dropout7') # net = slim.conv2d(net, num_classes, [1, 1], # activation_fn=None, # normalizer_fn=None, # scope='fc8') # Convert end_points_collection into a end_point dict. 
# end_points = slim.utils.convert_collection_to_dict(end_points_collection) return net, end_points def compute_moments(_inputs, moments=[2, 3]): """From an image input, compute moments""" _inputs_sq = tf.square(_inputs) _inputs_cube = tf.pow(_inputs, 3) height = int(_inputs.get_shape()[1]) width = int(_inputs.get_shape()[2]) channels = int(_inputs.get_shape()[3]) def ConvFlatten(x, kernel_size): # w_sum = tf.ones([kernel_size, kernel_size, channels, 1]) / (kernel_size * kernel_size * channels) w_sum = tf.eye(num_rows=channels, num_columns=channels, batch_shape=[kernel_size * kernel_size]) w_sum = tf.reshape(w_sum, [kernel_size, kernel_size, channels, channels]) w_sum = w_sum / (kernel_size * kernel_size) sum_ = tf.nn.conv2d(x, w_sum, strides=[1, 1, 1, 1], padding='VALID') size = prod_dim(sum_) assert size == (height - kernel_size + 1) * (width - kernel_size + 1) * channels, size return tf.reshape(sum_, [-1, size]) outputs = [] for size in [3, 4, 5]: mean = ConvFlatten(_inputs, size) square = ConvFlatten(_inputs_sq, size) var = square - tf.square(mean) if 2 in moments: outputs.append(var) if 3 in moments: cube = ConvFlatten(_inputs_cube, size) skewness = cube - 3.0 * mean * var - tf.pow(mean, 3) # Unnormalized outputs.append(skewness) return tf.concat(outputs, 1) def prod_dim(tensor): return np.prod([int(d) for d in tensor.get_shape()[1:]]) def flatten(tensor): return tf.reshape(tensor, [-1, prod_dim(tensor)]) class Pot(object): """A base class for running individual POTs. """ def __init__(self, opts, data, weights): # Create a new session with session.graph = default graph self._session = tf.Session() self._trained = False self._data = data self._data_weights = np.copy(weights) # Latent noise sampled ones to apply decoder while training self._noise_for_plots = opts['pot_pz_std'] * utils.generate_noise(opts, 1000) # Placeholders self._real_points_ph = None self._noise_ph = None # Init ops self._additional_init_ops = [] self._init_feed_dict = {} # Main operations # Optimizers with self._session.as_default(), self._session.graph.as_default(): logging.error('Building the graph...') self._build_model_internal(opts) # Make sure AdamOptimizer, if used in the Graph, is defined before # calling global_variables_initializer(). init = tf.global_variables_initializer() self._session.run(init) self._session.run(self._additional_init_ops, self._init_feed_dict) def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): # Cleaning the whole default Graph logging.error('Cleaning the graph...') tf.reset_default_graph() logging.error('Closing the session...') # Finishing the session self._session.close() def train(self, opts): """Train a POT model. """ with self._session.as_default(), self._session.graph.as_default(): self._train_internal(opts) self._trained = True def sample(self, opts, num=100): """Sample points from the trained POT model. """ assert self._trained, 'Can not sample from the un-trained POT' with self._session.as_default(), self._session.graph.as_default(): return self._sample_internal(opts, num) def train_mixture_discriminator(self, opts, fake_images): """Train classifier separating true data from points in fake_images. Return: prob_real: probabilities of the points from training data being the real points according to the trained mixture classifier. Numpy vector of shape (self._data.num_points,) prob_fake: probabilities of the points from fake_images being the real points according to the trained mixture classifier. 
                Numpy vector of shape (len(fake_images),)
        """
        with self._session.as_default(), self._session.graph.as_default():
            return self._train_mixture_discriminator_internal(opts, fake_images)

    def _run_batch(self, opts, operation, placeholder, feed,
                   placeholder2=None, feed2=None):
        """Wrapper around session.run to process huge data.

        It is assumed that (a) the first dimension of placeholder enumerates
        separate points, and (b) the operation is applied independently to
        every point, i.e. we can split it point-wise and then merge the
        results. The second placeholder is meant either for the is_train flag
        of batch-norm or for the dropout keep probabilities.

        TODO: write a util function which will be called both from this
        method and from the MNIST classification evaluation.
        """
        assert len(feed.shape) > 0, 'Empty feed.'
        num_points = feed.shape[0]
        batch_size = opts['tf_run_batch_size']
        batches_num = int(np.ceil((num_points + 0.) / batch_size))
        result = []
        for idx in xrange(batches_num):
            if idx == batches_num - 1:
                if feed2 is None:
                    res = self._session.run(
                        operation,
                        feed_dict={placeholder: feed[idx * batch_size:]})
                else:
                    res = self._session.run(
                        operation,
                        feed_dict={placeholder: feed[idx * batch_size:],
                                   placeholder2: feed2})
            else:
                if feed2 is None:
                    res = self._session.run(
                        operation,
                        feed_dict={placeholder: feed[idx * batch_size:
                                                     (idx + 1) * batch_size]})
                else:
                    res = self._session.run(
                        operation,
                        feed_dict={placeholder: feed[idx * batch_size:
                                                     (idx + 1) * batch_size],
                                   placeholder2: feed2})
            if len(res.shape) == 1:
                # convert (n,) vector to (n,1) array
                res = np.reshape(res, [-1, 1])
            result.append(res)
        result = np.vstack(result)
        assert len(result) == num_points
        return result

    def _build_model_internal(self, opts):
        """Build a TensorFlow graph with all the necessary ops.

        """
        assert False, 'POT base class has no build_model method defined.'

    def _train_internal(self, opts):
        assert False, 'POT base class has no train method defined.'

    def _sample_internal(self, opts, num):
        assert False, 'POT base class has no sample method defined.'

    def _train_mixture_discriminator_internal(self, opts, fake_images):
        assert False, 'POT base class has no mixture discriminator method defined.'


class ImagePot(Pot):
    """A simple POT implementation, suitable for pictures.

    """

    def __init__(self, opts, data, weights):
        # One more placeholder for batch norm
        self._is_training_ph = None
        Pot.__init__(self, opts, data, weights)

    def dcgan_like_arch(self, opts, noise, is_training, reuse, keep_prob):
        output_shape = self._data.data_shape
        num_units = opts['g_num_filters']
        batch_size = tf.shape(noise)[0]
        num_layers = opts['g_num_layers']
        if opts['g_arch'] == 'dcgan':
            height = output_shape[0] / 2**num_layers
            width = output_shape[1] / 2**num_layers
        elif opts['g_arch'] == 'dcgan_mod':
            height = output_shape[0] / 2**(num_layers-1)
            width = output_shape[1] / 2**(num_layers-1)
        else:
            assert False
        h0 = ops.linear(
            opts, noise, num_units * height * width, scope='h0_lin')
        h0 = tf.reshape(h0, [-1, height, width, num_units])
        h0 = tf.nn.relu(h0)
        layer_x = h0
        for i in xrange(num_layers-1):
            scale = 2**(i+1)
            if opts['g_stride1_deconv']:
                # Sylvain, I'm worried about this part!
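                # The stride-1 deconvolution below keeps the current resolution
                # (height * scale / 2) and the current channel count
                # (num_units / scale * 2) unchanged; the strided deconvolution
                # that follows then doubles the spatial size and halves the
                # number of filters. These shapes rely on Python 2 integer
                # division, so output_shape and num_units are implicitly
                # assumed to be divisible by the corresponding powers of two.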
_out_shape = [batch_size, height * scale / 2, width * scale / 2, num_units / scale * 2] layer_x = ops.deconv2d( opts, layer_x, _out_shape, d_h=1, d_w=1, scope='h%d_deconv_1x1' % i) layer_x = tf.nn.relu(layer_x) _out_shape = [batch_size, height * scale, width * scale, num_units / scale] layer_x = ops.deconv2d(opts, layer_x, _out_shape, scope='h%d_deconv' % i) if opts['batch_norm']: layer_x = ops.batch_norm(opts, layer_x, is_training, reuse, scope='bn%d' % i) layer_x = tf.nn.relu(layer_x) if opts['dropout']: _keep_prob = tf.minimum( 1., 0.9 - (0.9 - keep_prob) * float(i + 1) / (num_layers - 1)) layer_x = tf.nn.dropout(layer_x, _keep_prob) _out_shape = [batch_size] + list(output_shape) if opts['g_arch'] == 'dcgan': last_h = ops.deconv2d( opts, layer_x, _out_shape, scope='hlast_deconv') elif opts['g_arch'] == 'dcgan_mod': last_h = ops.deconv2d( opts, layer_x, _out_shape, d_h=1, d_w=1, scope='hlast_deconv') else: assert False if opts['input_normalize_sym']: return tf.nn.tanh(last_h) else: return tf.nn.sigmoid(last_h) def began_dec(self, opts, noise, is_training, reuse, keep_prob): """ Architecture reported here: https://arxiv.org/pdf/1703.10717.pdf """ output_shape = self._data.data_shape num_units = opts['g_num_filters'] num_layers = opts['g_num_layers'] batch_size = tf.shape(noise)[0] h0 = ops.linear( opts, noise, num_units * 8 * 8, scope='h0_lin') h0 = tf.reshape(h0, [-1, 8, 8, num_units]) layer_x = h0 for i in xrange(num_layers): if i % 3 < 2: # Don't change resolution layer_x = ops.conv2d(opts, layer_x, num_units, d_h=1, d_w=1, scope='h%d_conv' % i) layer_x = tf.nn.elu(layer_x) else: if i != num_layers - 1: # Upsampling by factor of 2 with NN scale = 2 ** (i / 3 + 1) layer_x = ops.upsample_nn(layer_x, [scale * 8, scale * 8], scope='h%d_upsample' % i, reuse=reuse) # Skip connection append = ops.upsample_nn(h0, [scale * 8, scale * 8], scope='h%d_skipup' % i, reuse=reuse) layer_x = tf.concat([layer_x, append], axis=3) last_h = ops.conv2d(opts, layer_x, output_shape[-1], d_h=1, d_w=1, scope='hlast_conv') if opts['input_normalize_sym']: return tf.nn.tanh(last_h) else: return tf.nn.sigmoid(last_h) def conv_up_res(self, opts, noise, is_training, reuse, keep_prob): output_shape = self._data.data_shape num_units = opts['g_num_filters'] batch_size = tf.shape(noise)[0] num_layers = opts['g_num_layers'] data_height = output_shape[0] data_width = output_shape[1] data_channels = output_shape[2] height = data_height / 2**num_layers width = data_width / 2**num_layers h0 = ops.linear( opts, noise, num_units * height * width, scope='h0_lin') h0 = tf.reshape(h0, [-1, height, width, num_units]) h0 = tf.nn.relu(h0) layer_x = h0 for i in xrange(num_layers-1): layer_x = tf.image.resize_nearest_neighbor(layer_x, (2 * height, 2 * width)) layer_x = ops.conv2d(opts, layer_x, num_units / 2, d_h=1, d_w=1, scope='conv2d_%d' % i) height *= 2 width *= 2 num_units /= 2 if opts['g_3x3_conv'] > 0: before = layer_x for j in range(opts['g_3x3_conv']): layer_x = ops.conv2d(opts, layer_x, num_units, d_h=1, d_w=1, scope='conv2d_3x3_%d_%d' % (i, j), conv_filters_dim=3) layer_x = tf.nn.relu(layer_x) layer_x += before # Residual connection. 
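            # Each iteration of this loop doubles the spatial resolution with a
            # nearest-neighbor resize followed by a stride-1 convolution that
            # halves the number of feature maps; the optional 3x3 convolutions
            # above form a small residual block on top of that. The dropout
            # below interpolates the keep probability between 0.9 and the
            # requested keep_prob, reaching keep_prob only at the deepest layer.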
if opts['batch_norm']: layer_x = ops.batch_norm(opts, layer_x, is_training, reuse, scope='bn%d' % i) layer_x = tf.nn.relu(layer_x) if opts['dropout']: _keep_prob = tf.minimum( 1., 0.9 - (0.9 - keep_prob) * float(i + 1) / (num_layers - 1)) layer_x = tf.nn.dropout(layer_x, _keep_prob) layer_x = tf.image.resize_nearest_neighbor(layer_x, (2 * height, 2 * width)) layer_x = ops.conv2d(opts, layer_x, data_channels, d_h=1, d_w=1, scope='last_conv2d_%d' % i) if opts['input_normalize_sym']: return tf.nn.tanh(layer_x) else: return tf.nn.sigmoid(layer_x) def ali_deconv(self, opts, noise, is_training, reuse, keep_prob): output_shape = self._data.data_shape batch_size = tf.shape(noise)[0] noise_size = int(noise.get_shape()[1]) data_height = output_shape[0] data_width = output_shape[1] data_channels = output_shape[2] noise = tf.reshape(noise, [-1, 1, 1, noise_size]) num_units = opts['g_num_filters'] layer_params = [] layer_params.append([4, 1, num_units]) layer_params.append([4, 2, num_units / 2]) layer_params.append([4, 1, num_units / 4]) layer_params.append([4, 2, num_units / 8]) layer_params.append([5, 1, num_units / 8]) # For convolution: (n - k) / stride + 1 = s # For transposed: (s - 1) * stride + k = n layer_x = noise height = 1 width = 1 for i, (kernel, stride, channels) in enumerate(layer_params): height = (height - 1) * stride + kernel width = height layer_x = ops.deconv2d( opts, layer_x, [batch_size, height, width, channels], d_h=stride, d_w=stride, scope='h%d_deconv' % i, conv_filters_dim=kernel, padding='VALID') if opts['batch_norm']: layer_x = ops.batch_norm(opts, layer_x, is_training, reuse, scope='bn%d' % i) layer_x = ops.lrelu(layer_x, 0.1) assert height == data_height assert width == data_width # Then two 1x1 convolutions. layer_x = ops.conv2d(opts, layer_x, num_units / 8, d_h=1, d_w=1, scope='conv2d_1x1', conv_filters_dim=1) if opts['batch_norm']: layer_x = ops.batch_norm(opts, layer_x, is_training, reuse, scope='bnlast') layer_x = ops.lrelu(layer_x, 0.1) layer_x = ops.conv2d(opts, layer_x, data_channels, d_h=1, d_w=1, scope='conv2d_1x1_2', conv_filters_dim=1) if opts['input_normalize_sym']: return tf.nn.tanh(layer_x) else: return tf.nn.sigmoid(layer_x) def generator(self, opts, noise, is_training=False, reuse=False, keep_prob=1.): """ Decoder actually. 
""" output_shape = self._data.data_shape num_units = opts['g_num_filters'] with tf.variable_scope("GENERATOR", reuse=reuse): # if not opts['convolutions']: if opts['g_arch'] == 'mlp': layer_x = noise for i in range(opts['g_num_layers']): layer_x = ops.linear(opts, layer_x, num_units, 'h%d_lin' % i) layer_x = tf.nn.relu(layer_x) if opts['batch_norm']: layer_x = ops.batch_norm( opts, layer_x, is_training, reuse, scope='bn%d' % i) out = ops.linear(opts, layer_x, np.prod(output_shape), 'h%d_lin' % (i + 1)) out = tf.reshape(out, [-1] + list(output_shape)) if opts['input_normalize_sym']: return tf.nn.tanh(out) else: return tf.nn.sigmoid(out) elif opts['g_arch'] in ['dcgan', 'dcgan_mod']: return self.dcgan_like_arch(opts, noise, is_training, reuse, keep_prob) elif opts['g_arch'] == 'conv_up_res': return self.conv_up_res(opts, noise, is_training, reuse, keep_prob) elif opts['g_arch'] == 'ali': return self.ali_deconv(opts, noise, is_training, reuse, keep_prob) elif opts['g_arch'] == 'began': return self.began_dec(opts, noise, is_training, reuse, keep_prob) else: raise ValueError('%s unknown' % opts['g_arch']) def discriminator(self, opts, input_, prefix='DISCRIMINATOR', reuse=False): """Discriminator for the GAN objective """ num_units = opts['d_num_filters'] num_layers = opts['d_num_layers'] nowozin_trick = opts['gan_p_trick'] # No convolutions as GAN happens in the latent space with tf.variable_scope(prefix, reuse=reuse): hi = input_ for i in range(num_layers): hi = ops.linear(opts, hi, num_units, scope='h%d_lin' % (i+1)) hi = tf.nn.relu(hi) hi = ops.linear(opts, hi, 1, scope='final_lin') if nowozin_trick: # We are doing GAN between our model Qz and the true Pz. # We know analytical form of the true Pz. # The optimal discriminator for D_JS(Pz, Qz) is given by: # Dopt(x) = log dPz(x) - log dQz(x) # And we know exactly dPz(x). So add log dPz(x) explicitly # to the discriminator and let it learn only the remaining # dQz(x) term. This appeared in the AVB paper. assert opts['latent_space_distr'] == 'normal' sigma2_p = float(opts['pot_pz_std']) ** 2 normsq = tf.reduce_sum(tf.square(input_), 1) hi = hi - normsq / 2. / sigma2_p \ - 0.5 * tf.log(2. * np.pi) \ - 0.5 * opts['latent_space_dim'] * np.log(sigma2_p) return hi def pz_sampler(self, opts, input_, prefix='PZ_SAMPLER', reuse=False): """Transformation to be applied to the sample from Pz We are trying to match Qz to phi(Pz), where phi is defined by this function """ dim = opts['latent_space_dim'] with tf.variable_scope(prefix, reuse=reuse): matrix = tf.get_variable( "W", [dim, dim], tf.float32, tf.constant_initializer(np.identity(dim))) bias = tf.get_variable( "b", [dim], initializer=tf.constant_initializer(0.)) return tf.matmul(input_, matrix) + bias def get_batch_size(self, opts, input_): return tf.cast(tf.shape(input_)[0], tf.float32)# opts['batch_size'] def moments_stats(self, opts, input_): """ Compute estimates of the first 4 moments of the coordinates based on the sample in input_. Compare them to the desired population values and return a corresponding loss. 
""" input_ = input_ / float(opts['pot_pz_std']) # If Pz = Qz then input_ should now come from # a product of pz_dim Gaussians N(0, 1) # Thus first moments should be 0 p1 = tf.reduce_mean(input_, 0) center_inp = input_ - p1 # Broadcasting # Second centered and normalized moments should be 1 p2 = tf.sqrt(1e-5 + tf.reduce_mean(tf.square(center_inp), 0)) normed_inp = center_inp / p2 # Third central moment should be 0 # p3 = tf.pow(1e-5 + tf.abs(tf.reduce_mean(tf.pow(center_inp, 3), 0)), 1.0 / 3.0) p3 = tf.abs(tf.reduce_mean(tf.pow(center_inp, 3), 0)) # 4th central moment of any uni-variate Gaussian = 3 * sigma^4 # p4 = tf.pow(1e-5 + tf.reduce_mean(tf.pow(center_inp, 4), 0) / 3.0, 1.0 / 4.0) p4 = tf.reduce_mean(tf.pow(center_inp, 4), 0) / 3. def zero_t(v): return tf.sqrt(1e-5 + tf.reduce_mean(tf.square(v))) def one_t(v): # The function below takes its minimum value 1. at v = 1. return tf.sqrt(1e-5 + tf.reduce_mean(tf.maximum(tf.square(v), 1.0 / (1e-5 + tf.square(v))))) return tf.stack([zero_t(p1), one_t(p2), zero_t(p3), one_t(p4)]) def discriminator_test(self, opts, input_): """Deterministic discriminator using simple tests.""" if opts['z_test'] == 'cramer': test_v = self.discriminator_cramer_test(opts, input_) elif opts['z_test'] == 'anderson': test_v = self.discriminator_anderson_test(opts, input_) elif opts['z_test'] == 'moments': test_v = tf.reduce_mean(self.moments_stats(opts, input_)) / 10.0 elif opts['z_test'] == 'lks': test_v = self.discriminator_lks_test(opts, input_) else: raise ValueError('%s Unknown' % opts['z_test']) return test_v def discriminator_cramer_test(self, opts, input_): """Deterministic discriminator using Cramer von Mises Test. """ add_dim = opts['z_test_proj_dim'] if add_dim > 0: dim = int(input_.get_shape()[1]) proj = np.random.rand(dim, add_dim) proj = proj - np.mean(proj, 0) norms = np.sqrt(np.sum(np.square(proj), 0) + 1e-5) proj = tf.constant(proj / norms, dtype=tf.float32) projected_x = tf.matmul(input_, proj) # Shape [batch_size, add_dim]. # Shape [batch_size, z_dim+add_dim] all_dims_x = tf.concat([input_, projected_x], 1) else: all_dims_x = input_ # top_k can only sort on the last dimension and we want to sort the # first one (batch_size). batch_size = self.get_batch_size(opts, all_dims_x) transposed = tf.transpose(all_dims_x, perm=[1, 0]) values, indices = tf.nn.top_k(transposed, k=tf.cast(batch_size, tf.int32)) values = tf.reverse(values, [1]) #values = tf.Print(values, [values], "sorted values") normal_dist = tf.contrib.distributions.Normal(0., float(opts['pot_pz_std'])) # normal_cdf = normal_dist.cdf(values) #normal_cdf = tf.Print(normal_cdf, [normal_cdf], "normal_cdf") expected = (2 * tf.range(1, batch_size+1, 1, dtype="float") - 1) / (2.0 * batch_size) #expected = tf.Print(expected, [expected], "expected") # We don't use the constant. # constant = 1.0 / (12.0 * batch_size * batch_size) # stat = constant + tf.reduce_sum(tf.square(expected - normal_cdf), 1) / batch_size stat = tf.reduce_sum(tf.square(expected - normal_cdf), 1) / batch_size stat = tf.reduce_mean(stat) #stat = tf.Print(stat, [stat], "stat") return stat def discriminator_anderson_test(self, opts, input_): """Deterministic discriminator using the Anderson Darling test. 
""" # A-D test says to normalize data before computing the statistic # Because true mean and variance are known, we are supposed to use # the population parameters for that, but wiki says it's better to # still use the sample estimates while normalizing means = tf.reduce_mean(input_, 0) input_ = input_ - means # Broadcasting stds = tf.sqrt(1e-5 + tf.reduce_mean(tf.square(input_), 0)) input_= input_ / stds # top_k can only sort on the last dimension and we want to sort the # first one (batch_size). batch_size = self.get_batch_size(opts, input_) transposed = tf.transpose(input_, perm=[1, 0]) values, indices = tf.nn.top_k(transposed, k=tf.cast(batch_size, tf.int32)) values = tf.reverse(values, [1]) normal_dist = tf.contrib.distributions.Normal(0., float(opts['pot_pz_std'])) normal_cdf = normal_dist.cdf(values) # ln_normal_cdf is of shape (z_dim, batch_size) ln_normal_cdf = tf.log(normal_cdf) ln_one_normal_cdf = tf.log(1.0 - normal_cdf) w1 = 2 * tf.range(1, batch_size + 1, 1, dtype="float") - 1 w2 = 2 * tf.range(batch_size - 1, -1, -1, dtype="float") + 1 stat = -batch_size - tf.reduce_sum(w1 * ln_normal_cdf + \ w2 * ln_one_normal_cdf, 1) / batch_size # stat is of shape (z_dim) stat = tf.reduce_mean(tf.square(stat)) return stat def discriminator_lks_lin_test(self, opts, input_): """Deterministic discriminator using Kernel Stein Discrepancy test refer to LKS test on page 3 of https://arxiv.org/pdf/1705.07673.pdf The statistic basically reads: \[ \frac{2}{n}\sum_{i=1}^n \left( frac{<x_{2i}, x_{2i - 1}>}{\sigma_p^4} + d/\sigma_k^2 - \|x_{2i} - x_{2i - 1}\|^2\left(\frac{1}{\sigma_p^2\sigma_k^2} + \frac{1}{\sigma_k^4}\right) \right) \exp( - \|x_{2i} - x_{2i - 1}\|^2/2/\sigma_k^2) \] """ # To check the typical sizes of the test for Pz = Qz, uncomment # input_ = opts['pot_pz_std'] * utils.generate_noise(opts, 100000) batch_size = self.get_batch_size(opts, input_) batch_size = tf.cast(batch_size, tf.int32) half_size = batch_size / 2 # s1 = tf.slice(input_, [0, 0], [half_size, -1]) # s2 = tf.slice(input_, [half_size, 0], [half_size, -1]) s1 = input_[:half_size, :] s2 = input_[half_size:, :] dotprods = tf.reduce_sum(tf.multiply(s1, s2), axis=1) distances = tf.reduce_sum(tf.square(s1 - s2), axis=1) sigma2_p = opts['pot_pz_std'] ** 2 # var = std ** 2 # Median heuristic for the sigma^2 of Gaussian kernel # sigma2_k = tf.nn.top_k(distances, half_size).values[half_size - 1] # Maximum heuristic for the sigma^2 of Gaussian kernel # sigma2_k = tf.nn.top_k(distances, 1).values[0] sigma2_k = opts['latent_space_dim'] * sigma2_p if opts['verbose'] == 2: sigma2_k = tf.Print(sigma2_k, [tf.nn.top_k(distances, 1).values[0]], 'Maximal squared pairwise distance:') sigma2_k = tf.Print(sigma2_k, [tf.reduce_mean(distances)], 'Average squared pairwise distance:') sigma2_k = tf.Print(sigma2_k, [sigma2_k], 'Kernel width:') # sigma2_k = tf.Print(sigma2_k, [sigma2_k], 'Kernel width:') res = dotprods / sigma2_p ** 2 \ - distances * (1. / sigma2_p / sigma2_k + 1. 
/ sigma2_k ** 2) \ + opts['latent_space_dim'] / sigma2_k res = tf.multiply(res, tf.exp(- distances / 2./ sigma2_k)) stat = tf.reduce_mean(res) return stat def discriminator_lks_test(self, opts, input_): """Deterministic discriminator using Kernel Stein Discrepancy test refer to the quadratic test of https://arxiv.org/pdf/1705.07673.pdf The statistic basically reads: \[ \frac{1}{n^2 - n}\sum_{i \neq j} \left( frac{<x_i, x__j>}{\sigma_p^4} + d/\sigma_k^2 - \|x_i - x_j\|^2\left(\frac{1}{\sigma_p^2\sigma_k^2} + \frac{1}{\sigma_k^4}\right) \right) \exp( - \|x_i - x_j\|^2/2/\sigma_k^2) \] """ n = self.get_batch_size(opts, input_) n = tf.cast(n, tf.int32) half_size = (n * n - n) / 2 nf = tf.cast(n, tf.float32) norms = tf.reduce_sum(tf.square(input_), axis=1, keep_dims=True) dotprods = tf.matmul(input_, input_, transpose_b=True) distances = norms + tf.transpose(norms) - 2. * dotprods sigma2_p = opts['pot_pz_std'] ** 2 # var = std ** 2 # Median heuristic for the sigma^2 of Gaussian kernel # sigma2_k = tf.nn.top_k(tf.reshape(distances, [-1]), half_size).values[half_size - 1] # Maximal heuristic for the sigma^2 of Gaussian kernel # sigma2_k = tf.nn.top_k(tf.reshape(distances, [-1]), 1).values[0] sigma2_k = opts['latent_space_dim'] * sigma2_p if opts['verbose'] == 2: sigma2_k = tf.Print(sigma2_k, [tf.nn.top_k(tf.reshape(distances, [-1]), 1).values[0]], 'Maximal squared pairwise distance:') sigma2_k = tf.Print(sigma2_k, [tf.reduce_mean(distances)], 'Average squared pairwise distance:') sigma2_k = tf.Print(sigma2_k, [sigma2_k], 'Kernel width:') res = dotprods / sigma2_p ** 2 \ - distances * (1. / sigma2_p / sigma2_k + 1. / sigma2_k ** 2) \ + opts['latent_space_dim'] / sigma2_k res = tf.multiply(res, tf.exp(- distances / 2./ sigma2_k)) res = tf.multiply(res, 1. - tf.eye(n)) stat = tf.reduce_sum(res) / (nf * nf - nf) # stat = tf.reduce_sum(res) / (nf * nf) return stat def discriminator_mmd_test(self, opts, sample_qz, sample_pz): """U statistic for MMD(Qz, Pz) with the RBF kernel """ sigma2_p = opts['pot_pz_std'] ** 2 # var = std ** 2 kernel = 'IM' n = self.get_batch_size(opts, sample_qz) n = tf.cast(n, tf.int32) nf = tf.cast(n, tf.float32) half_size = (n * n - n) / 2 # Pz norms_pz = tf.reduce_sum(tf.square(sample_pz), axis=1, keep_dims=True) dotprods_pz = tf.matmul(sample_pz, sample_pz, transpose_b=True) distances_pz = norms_pz + tf.transpose(norms_pz) - 2. * dotprods_pz # Qz norms_qz = tf.reduce_sum(tf.square(sample_qz), axis=1, keep_dims=True) dotprods_qz = tf.matmul(sample_qz, sample_qz, transpose_b=True) distances_qz = norms_qz + tf.transpose(norms_qz) - 2. * dotprods_qz # Pz vs Qz dotprods = tf.matmul(sample_qz, sample_pz, transpose_b=True) distances = norms_qz + tf.transpose(norms_pz) - 2. 
* dotprods if opts['verbose'] == 2: distances = tf.Print(distances, [tf.nn.top_k(tf.reshape(distances_qz, [-1]), 1).values[0]], 'Maximal Qz squared pairwise distance:') distances = tf.Print(distances, [tf.reduce_mean(distances_qz)], 'Average Qz squared pairwise distance:') distances = tf.Print(distances, [tf.nn.top_k(tf.reshape(distances_pz, [-1]), 1).values[0]], 'Maximal Pz squared pairwise distance:') distances = tf.Print(distances, [tf.reduce_mean(distances_pz)], 'Average Pz squared pairwise distance:') distances = tf.Print(distances, [tf.nn.top_k(tf.reshape(distances, [-1]), 1).values[0]], 'Maximal squared pairwise distance:') distances = tf.Print(distances, [tf.nn.top_k(tf.reshape(distances, [-1]), n * n).values[n * n - 1]], 'Minimal squared pairwise distance:') distances = tf.Print(distances, [tf.reduce_mean(distances)], 'Average squared pairwise distance:') if kernel == 'RBF': # RBF kernel # Median heuristic for the sigma^2 of Gaussian kernel sigma2_k = tf.nn.top_k(tf.reshape(distances, [-1]), half_size).values[half_size - 1] sigma2_k += tf.nn.top_k(tf.reshape(distances_qz, [-1]), half_size).values[half_size - 1] # Maximal heuristic for the sigma^2 of Gaussian kernel # sigma2_k = tf.nn.top_k(tf.reshape(distances_qz, [-1]), 1).values[0] # sigma2_k += tf.nn.top_k(tf.reshape(distances, [-1]), 1).values[0] # sigma2_k = opts['latent_space_dim'] * sigma2_p sigma2_k = tf.Print(sigma2_k, [sigma2_k], 'Kernel width:') res1 = tf.exp( - distances_qz / 2. / sigma2_k) res1 += tf.exp( - distances_pz / 2. / sigma2_k) res1 = tf.multiply(res1, 1. - tf.eye(n)) res1 = tf.reduce_sum(res1) / (nf * nf - nf) res2 = tf.exp( - distances / 2. / sigma2_k) res2 = tf.reduce_sum(res2) * 2. / (nf * nf) stat = res1 - res2 # stat = tf.reduce_sum(res) / (nf * nf) elif kernel == 'IM': # C = tf.nn.top_k(tf.reshape(distances, [-1]), half_size).values[half_size - 1] # C += tf.nn.top_k(tf.reshape(distances_qz, [-1]), half_size).values[half_size - 1] C = 2 * opts['latent_space_dim'] * sigma2_p res1 = C / (C + distances_qz) res1 += C / (C + distances_pz) res1 = tf.multiply(res1, 1. - tf.eye(n)) res1 = tf.reduce_sum(res1) / (nf * nf - nf) res1 = tf.Print(res1, [res1], 'First two terms:') res2 = C / (C + distances) res2 = tf.reduce_sum(res2) * 2. / (nf * nf) res2 = tf.Print(res2, [res2], 'Negative term:') stat = res1 - res2 # stat = tf.reduce_sum(res) / (nf * nf) return stat def correlation_loss(self, opts, input_): """ Independence test based on Pearson's correlation. Keep in mind that this captures only linear dependancies. However, for multivariate Gaussian independence is equivalent to zero correlation. 
""" batch_size = self.get_batch_size(opts, input_) dim = int(input_.get_shape()[1]) transposed = tf.transpose(input_, perm=[1, 0]) mean = tf.reshape(tf.reduce_mean(transposed, axis=1), [-1, 1]) centered_transposed = transposed - mean # Broadcasting mean cov = tf.matmul(centered_transposed, centered_transposed, transpose_b=True) cov = cov / (batch_size - 1) #cov = tf.Print(cov, [cov], "cov") sigmas = tf.sqrt(tf.diag_part(cov) + 1e-5) #sigmas = tf.Print(sigmas, [sigmas], "sigmas") sigmas = tf.reshape(sigmas, [1, -1]) sigmas = tf.matmul(sigmas, sigmas, transpose_a=True) #sigmas = tf.Print(sigmas, [sigmas], "sigmas") # Pearson's correlation corr = cov / sigmas triangle = tf.matrix_set_diag(tf.matrix_band_part(corr, 0, -1), tf.zeros(dim)) #triangle = tf.Print(triangle, [triangle], "triangle") loss = tf.reduce_sum(tf.square(triangle)) / ((dim * dim - dim) / 2.0) #loss = tf.Print(loss, [loss], "Correlation loss") return loss def encoder(self, opts, input_, is_training=False, reuse=False, keep_prob=1.): if opts['e_add_noise']: def add_noise(x): shape = tf.shape(x) return x + tf.truncated_normal(shape, 0.0, 0.01) def do_nothing(x): return x input_ = tf.cond(is_training, lambda: add_noise(input_), lambda: do_nothing(input_)) num_units = opts['e_num_filters'] num_layers = opts['e_num_layers'] with tf.variable_scope("ENCODER", reuse=reuse): if not opts['convolutions']: hi = input_ for i in range(num_layers): hi = ops.linear(opts, hi, num_units, scope='h%d_lin' % i) if opts['batch_norm']: hi = ops.batch_norm(opts, hi, is_training, reuse, scope='bn%d' % i) hi = tf.nn.relu(hi) if opts['e_is_random']: latent_mean = ops.linear( opts, hi, opts['latent_space_dim'], 'h%d_lin' % (i + 1)) log_latent_sigmas = ops.linear( opts, hi, opts['latent_space_dim'], 'h%d_lin_sigma' % (i + 1)) return latent_mean, log_latent_sigmas else: return ops.linear(opts, hi, opts['latent_space_dim'], 'h%d_lin' % (i + 1)) elif opts['e_arch'] == 'dcgan': return self.dcgan_encoder(opts, input_, is_training, reuse, keep_prob) elif opts['e_arch'] == 'ali': return self.ali_encoder(opts, input_, is_training, reuse, keep_prob) elif opts['e_arch'] == 'began': return self.began_encoder(opts, input_, is_training, reuse, keep_prob) else: raise ValueError('%s Unknown' % opts['e_arch']) def dcgan_encoder(self, opts, input_, is_training=False, reuse=False, keep_prob=1.): num_units = opts['e_num_filters'] num_layers = opts['e_num_layers'] layer_x = input_ for i in xrange(num_layers): scale = 2**(num_layers-i-1) layer_x = ops.conv2d(opts, layer_x, num_units / scale, scope='h%d_conv' % i) if opts['batch_norm']: layer_x = ops.batch_norm(opts, layer_x, is_training, reuse, scope='bn%d' % i) layer_x = tf.nn.relu(layer_x) if opts['dropout']: _keep_prob = tf.minimum( 1., 0.9 - (0.9 - keep_prob) * float(i + 1) / num_layers) layer_x = tf.nn.dropout(layer_x, _keep_prob) if opts['e_3x3_conv'] > 0: before = layer_x for j in range(opts['e_3x3_conv']): layer_x = ops.conv2d(opts, layer_x, num_units / scale, d_h=1, d_w=1, scope='conv2d_3x3_%d_%d' % (i, j), conv_filters_dim=3) layer_x = tf.nn.relu(layer_x) layer_x += before # Residual connection. 
if opts['e_is_random']: latent_mean = ops.linear( opts, layer_x, opts['latent_space_dim'], scope='hlast_lin') log_latent_sigmas = ops.linear( opts, layer_x, opts['latent_space_dim'], scope='hlast_lin_sigma') return latent_mean, log_latent_sigmas else: return ops.linear(opts, layer_x, opts['latent_space_dim'], scope='hlast_lin') def ali_encoder(self, opts, input_, is_training=False, reuse=False, keep_prob=1.): num_units = opts['e_num_filters'] layer_params = [] layer_params.append([5, 1, num_units / 8]) layer_params.append([4, 2, num_units / 4]) layer_params.append([4, 1, num_units / 2]) layer_params.append([4, 2, num_units]) layer_params.append([4, 1, num_units * 2]) # For convolution: (n - k) / stride + 1 = s # For transposed: (s - 1) * stride + k = n layer_x = input_ height = int(layer_x.get_shape()[1]) width = int(layer_x.get_shape()[2]) assert height == width for i, (kernel, stride, channels) in enumerate(layer_params): height = (height - kernel) / stride + 1 width = height # print((height, width)) layer_x = ops.conv2d( opts, layer_x, channels, d_h=stride, d_w=stride, scope='h%d_conv' % i, conv_filters_dim=kernel, padding='VALID') if opts['batch_norm']: layer_x = ops.batch_norm(opts, layer_x, is_training, reuse, scope='bn%d' % i) layer_x = ops.lrelu(layer_x, 0.1) assert height == 1 assert width == 1 # Then two 1x1 convolutions. layer_x = ops.conv2d(opts, layer_x, num_units * 2, d_h=1, d_w=1, scope='conv2d_1x1', conv_filters_dim=1) if opts['batch_norm']: layer_x = ops.batch_norm(opts, layer_x, is_training, reuse, scope='bnlast') layer_x = ops.lrelu(layer_x, 0.1) layer_x = ops.conv2d(opts, layer_x, num_units / 2, d_h=1, d_w=1, scope='conv2d_1x1_2', conv_filters_dim=1) if opts['e_is_random']: latent_mean = ops.linear( opts, layer_x, opts['latent_space_dim'], scope='hlast_lin') log_latent_sigmas = ops.linear( opts, layer_x, opts['latent_space_dim'], scope='hlast_lin_sigma') return latent_mean, log_latent_sigmas else: return ops.linear(opts, layer_x, opts['latent_space_dim'], scope='hlast_lin') def began_encoder(self, opts, input_, is_training=False, reuse=False, keep_prob=1.): num_units = opts['e_num_filters'] assert num_units == opts['g_num_filters'], 'BEGAN requires same number of filters in encoder and decoder' num_layers = opts['e_num_layers'] layer_x = ops.conv2d(opts, input_, num_units, scope='h_first_conv') for i in xrange(num_layers): if i % 3 < 2: if i != num_layers - 2: ii = i - (i / 3) scale = (ii + 1 - ii / 2) else: ii = i - (i / 3) scale = (ii - (ii - 1) / 2) layer_x = ops.conv2d(opts, layer_x, num_units * scale, d_h=1, d_w=1, scope='h%d_conv' % i) layer_x = tf.nn.elu(layer_x) else: if i != num_layers - 1: layer_x = ops.downsample(layer_x, scope='h%d_maxpool' % i, reuse=reuse) # Tensor should be [N, 8, 8, filters] right now if opts['e_is_random']: latent_mean = ops.linear( opts, layer_x, opts['latent_space_dim'], scope='hlast_lin') log_latent_sigmas = ops.linear( opts, layer_x, opts['latent_space_dim'], scope='hlast_lin_sigma') return latent_mean, log_latent_sigmas else: return ops.linear(opts, layer_x, opts['latent_space_dim'], scope='hlast_lin') def _data_augmentation(self, opts, real_points, is_training): if not opts['data_augm']: return real_points height = int(real_points.get_shape()[1]) width = int(real_points.get_shape()[2]) depth = int(real_points.get_shape()[3]) # logging.error("real_points shape", real_points.get_shape()) def _distort_func(image): # tf.image.per_image_standardization(image), should we? # Pad with zeros. 
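            # Augmentation pipeline (applied only when is_training is True):
            # pad the image by a few pixels, take a random crop back to the
            # original size, randomly flip it horizontally, and then jitter
            # brightness, contrast, hue and saturation, clipping back to the
            # [0, 1] range after every photometric step since the inputs are
            # assumed to lie in [0, 1].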
image = tf.image.resize_image_with_crop_or_pad( image, height+4, width+4) image = tf.random_crop(image, [height, width, depth]) image = tf.image.random_flip_left_right(image) image = tf.image.random_brightness(image, max_delta=0.1) image = tf.minimum(tf.maximum(image, 0.0), 1.0) image = tf.image.random_contrast(image, lower=0.8, upper=1.3) image = tf.minimum(tf.maximum(image, 0.0), 1.0) image = tf.image.random_hue(image, 0.08) image = tf.minimum(tf.maximum(image, 0.0), 1.0) image = tf.image.random_saturation(image, lower=0.8, upper=1.3) image = tf.minimum(tf.maximum(image, 0.0), 1.0) return image def _regular_func(image): # tf.image.per_image_standardization(image)? return image distorted_images = tf.cond( is_training, lambda: tf.map_fn(_distort_func, real_points, parallel_iterations=100), lambda: tf.map_fn(_regular_func, real_points, parallel_iterations=100)) return distorted_images def _recon_loss_using_disc_encoder( self, opts, reconstructed_training, encoded_training, real_points, is_training_ph, keep_prob_ph): """Build an additional loss using the encoder as discriminator.""" reconstructed_reencoded_sg = self.encoder( opts, tf.stop_gradient(reconstructed_training), is_training=is_training_ph, keep_prob=keep_prob_ph, reuse=True) if opts['e_is_random']: reconstructed_reencoded_sg = reconstructed_reencoded_sg[0] reconstructed_reencoded = self.encoder( opts, reconstructed_training, is_training=is_training_ph, keep_prob=keep_prob_ph, reuse=True) if opts['e_is_random']: reconstructed_reencoded = reconstructed_reencoded[0] # Below line enforces the forward to be reconstructed_reencoded and backwards to NOT change the encoder.... crazy_hack = reconstructed_reencoded - reconstructed_reencoded_sg +\ tf.stop_gradient(reconstructed_reencoded_sg) encoded_training_sg = self.encoder( opts, tf.stop_gradient(real_points), is_training=is_training_ph, keep_prob=keep_prob_ph, reuse=True) if opts['e_is_random']: encoded_training_sg = encoded_training_sg[0] adv_fake_layer = ops.linear(opts, reconstructed_reencoded_sg, 1, scope='adv_layer') adv_true_layer = ops.linear(opts, encoded_training_sg, 1, scope='adv_layer', reuse=True) adv_fake = tf.nn.sigmoid_cross_entropy_with_logits( logits=adv_fake_layer, labels=tf.zeros_like(adv_fake_layer)) adv_true = tf.nn.sigmoid_cross_entropy_with_logits( logits=adv_true_layer, labels=tf.ones_like(adv_true_layer)) adv_fake = tf.reduce_mean(adv_fake) adv_true = tf.reduce_mean(adv_true) adv_c_loss = adv_fake + adv_true emb_c = tf.reduce_sum(tf.square(crazy_hack - tf.stop_gradient(encoded_training)), 1) emb_c_loss = tf.reduce_mean(tf.sqrt(emb_c + 1e-5)) # Normalize the loss, so that it does not depend on how good the # discriminator is. 
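        # Dividing by tf.stop_gradient(emb_c_loss) keeps the forward value of
        # this term constant (equal to 1) while rescaling its gradient by the
        # inverse of the current loss magnitude, so the effective weight of
        # the embedding term does not depend on the scale of the encoder
        # features.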
emb_c_loss = emb_c_loss / tf.stop_gradient(emb_c_loss) return adv_c_loss, emb_c_loss def _recon_loss_using_disc_conv(self, opts, reconstructed_training, real_points, is_training, keep_prob): """Build an additional loss using a discriminator in X space.""" def _conv_flatten(x, kernel_size): height = int(x.get_shape()[1]) width = int(x.get_shape()[2]) channels = int(x.get_shape()[3]) w_sum = tf.eye(num_rows=channels, num_columns=channels, batch_shape=[kernel_size * kernel_size]) w_sum = tf.reshape(w_sum, [kernel_size, kernel_size, channels, channels]) w_sum = w_sum / (kernel_size * kernel_size) sum_ = tf.nn.conv2d(x, w_sum, strides=[1, 1, 1, 1], padding='SAME') size = prod_dim(sum_) assert size == height * width * channels, size return tf.reshape(sum_, [-1, size]) def _gram_scores(tensor, kernel_size): assert len(tensor.get_shape()) == 4, tensor ttensor = tf.transpose(tensor, [3, 1, 2, 0]) rand_indices = tf.random_shuffle(tf.range(ttensor.get_shape()[0])) shuffled = tf.gather(ttensor, rand_indices) shuffled = tf.transpose(shuffled, [3, 1, 2, 0]) cross_p = _conv_flatten(tensor * shuffled, kernel_size) # shape [batch_size, height * width * channels] diag_p = _conv_flatten(tf.square(tensor), kernel_size) # shape [batch_size, height * width * channels] return cross_p, diag_p def _architecture(inputs, reuse=None): with tf.variable_scope('DISC_X_LOSS', reuse=reuse): num_units = opts['adv_c_num_units'] num_layers = 1 filter_sizes = opts['adv_c_patches_size'] if isinstance(filter_sizes, int): filter_sizes = [filter_sizes] else: filter_sizes = [int(n) for n in filter_sizes.split(',')] embedded_outputs = [] linear_outputs = [] for filter_size in filter_sizes: layer_x = inputs for i in xrange(num_layers): # scale = 2**(num_layers-i-1) layer_x = ops.conv2d(opts, layer_x, num_units, d_h=1, d_w=1, scope='h%d_conv%d' % (i, filter_size), conv_filters_dim=filter_size, padding='SAME') # if opts['batch_norm']: # layer_x = ops.batch_norm(opts, layer_x, is_training, reuse, scope='bn%d_%d' % (i, filter_size)) layer_x = ops.lrelu(layer_x, 0.1) last = ops.conv2d( opts, layer_x, 1, d_h=1, d_w=1, scope="last_lin%d" % filter_size, conv_filters_dim=1, l2_norm=True) if opts['cross_p_w'] > 0.0 or opts['diag_p_w'] > 0.0: cross_p, diag_p = _gram_scores(layer_x, filter_size) embedded_outputs.append(cross_p * opts['cross_p_w']) embedded_outputs.append(diag_p * opts['diag_p_w']) fl = flatten(layer_x) # fl = tf.Print(fl, [fl], "fl") embedded_outputs.append(fl) size = int(last.get_shape()[1]) linear_outputs.append(tf.reshape(last, [-1, size * size])) if len(embedded_outputs) > 1: embedded_outputs = tf.concat(embedded_outputs, 1) else: embedded_outputs = embedded_outputs[0] if len(linear_outputs) > 1: linear_outputs = tf.concat(linear_outputs, 1) else: linear_outputs = linear_outputs[0] return embedded_outputs, linear_outputs reconstructed_embed_sg, adv_fake_layer = _architecture(tf.stop_gradient(reconstructed_training), reuse=None) reconstructed_embed, _ = _architecture(reconstructed_training, reuse=True) # Below line enforces the forward to be reconstructed_embed and backwards to NOT change the discriminator.... 
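        # The two branches above compute identical features, but the *_sg one
        # sees tf.stop_gradient(reconstructed_training). The combination
        # a - a_sg + tf.stop_gradient(a_sg) therefore equals a in the forward
        # pass, while in the backward pass the gradients w.r.t. the
        # discriminator weights cancel and only the gradient flowing into the
        # reconstruction (and hence the encoder/decoder) survives; the
        # discriminator itself is trained only through adv_fake / adv_true.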
crazy_hack = reconstructed_embed-reconstructed_embed_sg+tf.stop_gradient(reconstructed_embed_sg) real_p_embed_sg, adv_true_layer = _architecture(tf.stop_gradient(real_points), reuse=True) real_p_embed, _ = _architecture(real_points, reuse=True) adv_fake = tf.nn.sigmoid_cross_entropy_with_logits( logits=adv_fake_layer, labels=tf.zeros_like(adv_fake_layer)) adv_true = tf.nn.sigmoid_cross_entropy_with_logits( logits=adv_true_layer, labels=tf.ones_like(adv_true_layer)) adv_fake = tf.reduce_mean(adv_fake) adv_true = tf.reduce_mean(adv_true) adv_c_loss = adv_fake + adv_true emb_c = tf.reduce_mean(tf.square(crazy_hack - tf.stop_gradient(real_p_embed)), 1) real_points_shuffle = tf.stop_gradient(tf.random_shuffle(real_p_embed)) emb_c_shuffle = tf.reduce_mean(tf.square(real_points_shuffle - tf.stop_gradient(reconstructed_embed)), 1) raw_emb_c_loss = tf.reduce_mean(emb_c) shuffled_emb_c_loss = tf.reduce_mean(emb_c_shuffle) emb_c_loss = raw_emb_c_loss / shuffled_emb_c_loss emb_c_loss = emb_c_loss * 40 return adv_c_loss, emb_c_loss def _recon_loss_using_disc_conv_eb(self, opts, reconstructed_training, real_points, is_training, keep_prob): """Build an additional loss using a discriminator in X space, using Energy Based approach.""" def copy3D(height, width, channels): m = np.zeros([height, width, channels, height, width, channels]) for i in xrange(height): for j in xrange(width): for c in xrange(channels): m[i, j, c, i, j, c] = 1.0 return tf.constant(np.reshape(m, [height, width, channels, -1]), dtype=tf.float32) def _architecture(inputs, reuse=None): dim = opts['adv_c_patches_size'] height = int(inputs.get_shape()[1]) width = int(inputs.get_shape()[2]) channels = int(inputs.get_shape()[3]) with tf.variable_scope('DISC_X_LOSS', reuse=reuse): num_units = opts['adv_c_num_units'] num_layers = 1 layer_x = inputs for i in xrange(num_layers): # scale = 2**(num_layers-i-1) layer_x = ops.conv2d(opts, layer_x, num_units, d_h=1, d_w=1, scope='h%d_conv' % i, conv_filters_dim=dim, padding='SAME') # if opts['batch_norm']: # layer_x = ops.batch_norm(opts, layer_x, is_training, reuse, scope='bn%d' % i) layer_x = ops.lrelu(layer_x, 0.1) #tf.nn.relu(layer_x) copy_w = copy3D(dim, dim, channels) duplicated = tf.nn.conv2d(inputs, copy_w, strides=[1, 1, 1, 1], padding='SAME') decoded = ops.conv2d( opts, layer_x, channels * dim * dim, d_h=1, d_w=1, scope="decoder", conv_filters_dim=1, padding='SAME') reconstruction = tf.reduce_mean(tf.square(tf.stop_gradient(duplicated) - decoded), [1, 2, 3]) assert len(reconstruction.get_shape()) == 1 return flatten(layer_x), reconstruction reconstructed_embed_sg, adv_fake_layer = _architecture(tf.stop_gradient(reconstructed_training), reuse=None) reconstructed_embed, _ = _architecture(reconstructed_training, reuse=True) # Below line enforces the forward to be reconstructed_embed and backwards to NOT change the discriminator.... 
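        # In this energy-based variant, copy_w unfolds every
        # adv_c_patches_size x adv_c_patches_size patch of the input into the
        # channel axis (an im2col-style convolution), and `decoded` tries to
        # reconstruct those patches from the single-layer feature map. The
        # mean patch-reconstruction error plays the role of an energy, and
        # adv_c_loss = log(adv_true) - log(adv_fake) favours low energy on
        # real images and high energy on reconstructions.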
crazy_hack = reconstructed_embed-reconstructed_embed_sg+tf.stop_gradient(reconstructed_embed_sg) real_p_embed_sg, adv_true_layer = _architecture(tf.stop_gradient(real_points), reuse=True) real_p_embed, _ = _architecture(real_points, reuse=True) adv_fake = tf.reduce_mean(adv_fake_layer) adv_true = tf.reduce_mean(adv_true_layer) adv_c_loss = tf.log(adv_true) - tf.log(adv_fake) emb_c = tf.reduce_sum(tf.square(crazy_hack - tf.stop_gradient(real_p_embed)), 1) emb_c_loss = tf.reduce_mean(emb_c) return adv_c_loss, emb_c_loss def _recon_loss_using_vgg(self, opts, reconstructed_training, real_points, is_training, keep_prob): """Build an additional loss using a pretrained VGG in X space.""" def _architecture(_inputs, reuse=None): _, end_points = vgg_16(_inputs, is_training=is_training, dropout_keep_prob=keep_prob, reuse=reuse) layer_name = opts['vgg_layer'] if layer_name == 'concat': outputs = [] for ln in ['pool1', 'pool2', 'pool3']: output = end_points[ln] output = flatten(output) outputs.append(output) output = tf.concat(outputs, 1) elif layer_name.startswith('concat_w'): weights = layer_name.split(',')[1:] assert len(weights) == 5 outputs = [] for lnum in range(5): num = lnum + 1 ln = 'pool%d' % num output = end_points[ln] output = flatten(output) # We sqrt the weight here because we use L2 after. outputs.append(np.sqrt(float(weights[lnum])) * output) output = tf.concat(outputs, 1) else: output = end_points[layer_name] output = flatten(output) if reuse is None: variables_to_restore = slim.get_variables_to_restore(include=['vgg_16']) path = os.path.join(opts['data_dir'], 'vgg_16.ckpt') # '/tmpp/models/vgg_16.ckpt' init_assign_op, init_feed_dict = slim.assign_from_checkpoint(path, variables_to_restore) self._additional_init_ops += [init_assign_op] self._init_feed_dict.update(init_feed_dict) return output reconstructed_embed_sg = _architecture(tf.stop_gradient(reconstructed_training), reuse=None) reconstructed_embed = _architecture(reconstructed_training, reuse=True) # Below line enforces the forward to be reconstructed_embed and backwards to NOT change the discriminator.... crazy_hack = reconstructed_embed-reconstructed_embed_sg+tf.stop_gradient(reconstructed_embed_sg) real_p_embed = _architecture(real_points, reuse=True) emb_c = tf.reduce_mean(tf.square(crazy_hack - tf.stop_gradient(real_p_embed)), 1) emb_c_loss = tf.reduce_mean(tf.sqrt(emb_c + 1e-5)) # emb_c_loss = tf.Print(emb_c_loss, [emb_c_loss], "emb_c_loss") # # Normalize the loss, so that it does not depend on how good the # # discriminator is. # emb_c_loss = emb_c_loss / tf.stop_gradient(emb_c_loss) return emb_c_loss def _recon_loss_using_moments(self, opts, reconstructed_training, real_points, is_training, keep_prob): """Build an additional loss using moments.""" def _architecture(_inputs): return compute_moments(_inputs, moments=[2]) # TODO reconstructed_embed = _architecture(reconstructed_training) real_p_embed = _architecture(real_points) emb_c = tf.reduce_mean(tf.square(reconstructed_embed - tf.stop_gradient(real_p_embed)), 1) # emb_c = tf.Print(emb_c, [emb_c], "emb_c") emb_c_loss = tf.reduce_mean(emb_c) return emb_c_loss * 100.0 * 100.0 # TODO: constant. 
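    # The VGG-moments variant below combines the two previous ideas: it runs
    # the images through the pretrained vgg_16 network (restored from
    # vgg_16.ckpt in opts['data_dir']) and then matches per-patch variances of
    # the chosen activation layer, analogously to _recon_loss_using_moments
    # above.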
def _recon_loss_using_vgg_moments(self, opts, reconstructed_training, real_points, is_training, keep_prob): """Build an additional loss using a pretrained VGG in X space.""" def _architecture(_inputs, reuse=None): _, end_points = vgg_16(_inputs, is_training=is_training, dropout_keep_prob=keep_prob, reuse=reuse) layer_name = opts['vgg_layer'] output = end_points[layer_name] # output = flatten(output) output /= 255.0 # the vgg_16 method scales everything by 255.0, so we divide back here. variances = compute_moments(output, moments=[2]) if reuse is None: variables_to_restore = slim.get_variables_to_restore(include=['vgg_16']) path = os.path.join(opts['data_dir'], 'vgg_16.ckpt') # '/tmpp/models/vgg_16.ckpt' init_assign_op, init_feed_dict = slim.assign_from_checkpoint(path, variables_to_restore) self._additional_init_ops += [init_assign_op] self._init_feed_dict.update(init_feed_dict) return variances reconstructed_embed_sg = _architecture(tf.stop_gradient(reconstructed_training), reuse=None) reconstructed_embed = _architecture(reconstructed_training, reuse=True) # Below line enforces the forward to be reconstructed_embed and backwards to NOT change the discriminator.... crazy_hack = reconstructed_embed-reconstructed_embed_sg+tf.stop_gradient(reconstructed_embed_sg) real_p_embed = _architecture(real_points, reuse=True) emb_c = tf.reduce_mean(tf.square(crazy_hack - tf.stop_gradient(real_p_embed)), 1) emb_c_loss = tf.reduce_mean(emb_c) # emb_c_loss = tf.Print(emb_c_loss, [emb_c_loss], "emb_c_loss") # # Normalize the loss, so that it does not depend on how good the # # discriminator is. # emb_c_loss = emb_c_loss / tf.stop_gradient(emb_c_loss) return emb_c_loss # TODO: constant. def add_least_gaussian2d_ops(self, opts): """ Add ops searching for the 2d plane in z_dim hidden space corresponding to the 'least Gaussian' look of the sample """ with tf.variable_scope('leastGaussian2d'): # Projection matrix which we are going to tune sample_ph = tf.placeholder( tf.float32, [None, opts['latent_space_dim']], name='sample_ph') v = tf.get_variable( "proj_v", [opts['latent_space_dim'], 1], tf.float32, tf.random_normal_initializer(stddev=1.)) u = tf.get_variable( "proj_u", [opts['latent_space_dim'], 1], tf.float32, tf.random_normal_initializer(stddev=1.)) npoints = tf.cast(tf.shape(sample_ph)[0], tf.int32) # First we need to make sure projection matrix is orthogonal v_norm = tf.nn.l2_normalize(v, 0) dotprod = tf.reduce_sum(tf.multiply(u, v_norm)) u_ort = u - dotprod * v_norm u_norm = tf.nn.l2_normalize(u_ort, 0) Mproj = tf.concat([v_norm, u_norm], 1) sample_proj = tf.matmul(sample_ph, Mproj) a = tf.eye(npoints) - tf.ones([npoints, npoints]) / tf.cast(npoints, tf.float32) b = tf.matmul(sample_proj, tf.matmul(a, a), transpose_a=True) b = tf.matmul(b, sample_proj) # Sample covariance matrix covhat = b / (tf.cast(npoints, tf.float32) - 1) # covhat = tf.Print(covhat, [covhat], 'Cov:') with tf.variable_scope('leastGaussian2d'): gcov = opts['pot_pz_std'] * opts['pot_pz_std'] * tf.eye(2) # l2 distance between sample cov and the Gaussian cov projloss = tf.reduce_sum(tf.square(covhat - gcov)) # Also account for the first moment, i.e. 
expected value projloss += tf.reduce_sum(tf.square(tf.reduce_mean(sample_proj, 0))) # We are maximizing projloss = -projloss optim = tf.train.AdamOptimizer(0.001, 0.9) optim = optim.minimize(projloss, var_list=[v, u]) self._proj_u = u_norm self._proj_v = v_norm self._proj_sample_ph = sample_ph self._proj_covhat = covhat self._proj_loss = projloss self._proj_optim = optim def least_gaussian_2d(self, opts, X): """ Given a sample X of shape (n_points, n_z) find 2d plain such that projection looks least gaussian. """ with self._session.as_default(), self._session.graph.as_default(): sample_ph = self._proj_sample_ph optim = self._proj_optim loss = self._proj_loss u = self._proj_u v = self._proj_v covhat = self._proj_covhat proj_mat = tf.concat([v, u], 1).eval() dot_prod = -1 best_of_runs = 10e5 # Any positive value would do updated = False for _start in xrange(3): # We will run 3 times from random inits loss_prev = 10e5 # Any positive value would do proj_vars = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES, scope="leastGaussian2d") self._session.run(tf.variables_initializer(proj_vars)) step = 0 for _ in xrange(5000): self._session.run(optim, feed_dict={sample_ph:X}) step += 1 if step % 10 == 0: loss_cur = loss.eval(feed_dict={sample_ph: X}) rel_imp = abs(loss_cur - loss_prev) / abs(loss_prev) if rel_imp < 1e-2: break loss_prev = loss_cur loss_final = loss.eval(feed_dict={sample_ph: X}) if loss_final < best_of_runs: updated = True best_of_runs = loss_final proj_mat = tf.concat([v, u], 1).eval() dot_prod = tf.reduce_sum(tf.multiply(u, v)).eval() if not updated: logging.error('WARNING: possible bug in the worst 2d projection') return proj_mat, dot_prod def _build_model_internal(self, opts): """Build the Graph corresponding to POT implementation. """ data_shape = self._data.data_shape additional_losses = collections.OrderedDict() # Placeholders real_points_ph = tf.placeholder( tf.float32, [None] + list(data_shape), name='real_points_ph') noise_ph = tf.placeholder( tf.float32, [None] + [opts['latent_space_dim']], name='noise_ph') enc_noise_ph = tf.placeholder( tf.float32, [None] + [opts['latent_space_dim']], name='enc_noise_ph') lr_decay_ph = tf.placeholder(tf.float32) is_training_ph = tf.placeholder(tf.bool, name='is_training_ph') keep_prob_ph = tf.placeholder(tf.float32, name='keep_prob_ph') # Operations if opts['pz_transform']: assert opts['z_test'] == 'gan', 'Pz transforms are currently allowed only for POT+GAN' noise = self.pz_sampler(opts, noise_ph) else: noise = noise_ph real_points = self._data_augmentation( opts, real_points_ph, is_training_ph) if opts['e_is_random']: # If encoder is random we map the training points # to the expectation of Q(Z|X) and then add the scaled # Gaussian noise corresponding to the learned sigmas enc_train_mean, enc_log_sigmas = self.encoder( opts, real_points, is_training=is_training_ph, keep_prob=keep_prob_ph) # enc_log_sigmas = tf.Print(enc_log_sigmas, [tf.reduce_max(enc_log_sigmas), # tf.reduce_min(enc_log_sigmas), # tf.reduce_mean(enc_log_sigmas)], 'Log sigmas:') # enc_log_sigmas = tf.Print(enc_log_sigmas, [tf.slice(enc_log_sigmas, [0,0], [1,-1])], 'Log sigmas:') # stds = tf.sqrt(tf.exp(enc_log_sigmas) + 1e-05) stds = tf.sqrt(tf.nn.relu(enc_log_sigmas) + 1e-05) # stds = tf.Print(stds, [stds[0], stds[1], stds[2], stds[3]], 'Stds: ') # stds = tf.Print(stds, [enc_train_mean[0], enc_train_mean[1], enc_train_mean[2]], 'Means: ') scaled_noise = tf.multiply(stds, enc_noise_ph) encoded_training = enc_train_mean + scaled_noise else: encoded_training = self.encoder( 
opts, real_points, is_training=is_training_ph, keep_prob=keep_prob_ph) reconstructed_training = self.generator( opts, encoded_training, is_training=is_training_ph, keep_prob=keep_prob_ph) reconstructed_training.set_shape(real_points.get_shape()) if opts['recon_loss'] == 'l2': # c(x,y) = ||x - y||_2 loss_reconstr = tf.reduce_sum( tf.square(real_points - reconstructed_training), axis=1) # sqrt(x + delta) guarantees the direvative 1/(x + delta) is finite loss_reconstr = tf.reduce_mean(tf.sqrt(loss_reconstr + 1e-08)) elif opts['recon_loss'] == 'l2f': # c(x,y) = ||x - y||_2 loss_reconstr = tf.reduce_sum( tf.square(real_points - reconstructed_training), axis=[1, 2, 3]) loss_reconstr = tf.reduce_mean(tf.sqrt(1e-08 + loss_reconstr)) * 0.2 elif opts['recon_loss'] == 'l2sq': # c(x,y) = ||x - y||_2^2 loss_reconstr = tf.reduce_sum( tf.square(real_points - reconstructed_training), axis=[1, 2, 3]) loss_reconstr = tf.reduce_mean(loss_reconstr) * 0.05 elif opts['recon_loss'] == 'l1': # c(x,y) = ||x - y||_1 loss_reconstr = tf.reduce_mean(tf.reduce_sum( tf.abs(real_points - reconstructed_training), axis=[1, 2, 3])) * 0.02 else: assert False # Pearson independence test of coordinates in Z space loss_z_corr = self.correlation_loss(opts, encoded_training) # Perform a Qz = Pz goodness of fit test based on Stein Discrepancy if opts['z_test'] == 'gan': # Pz = Qz test based on GAN in the Z space d_logits_Pz = self.discriminator(opts, noise) d_logits_Qz = self.discriminator(opts, encoded_training, reuse=True) d_loss_Pz = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits=d_logits_Pz, labels=tf.ones_like(d_logits_Pz))) d_loss_Qz = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits=d_logits_Qz, labels=tf.zeros_like(d_logits_Qz))) d_loss_Qz_trick = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits=d_logits_Qz, labels=tf.ones_like(d_logits_Qz))) d_loss = opts['pot_lambda'] * (d_loss_Pz + d_loss_Qz) if opts['pz_transform']: loss_match = d_loss_Qz_trick - d_loss_Pz else: loss_match = d_loss_Qz_trick elif opts['z_test'] == 'mmd': # Pz = Qz test based on MMD(Pz, Qz) loss_match = self.discriminator_mmd_test(opts, encoded_training, noise) d_loss = None d_logits_Pz = None d_logits_Qz = None elif opts['z_test'] == 'lks': # Pz = Qz test without adversarial training # based on Kernel Stein Discrepancy # Uncomment next line to check for the real Pz # loss_match = self.discriminator_test(opts, noise_ph) loss_match = self.discriminator_test(opts, encoded_training) d_loss = None d_logits_Pz = None d_logits_Qz = None else: # Pz = Qz test without adversarial training # (a) Check for multivariate Gaussianity # by checking Gaussianity of all the 1d projections # (b) Run Pearson's test of coordinate independance loss_match = self.discriminator_test(opts, encoded_training) loss_match = loss_match + opts['z_test_corr_w'] * loss_z_corr d_loss = None d_logits_Pz = None d_logits_Qz = None g_mom_stats = self.moments_stats(opts, encoded_training) loss = opts['reconstr_w'] * loss_reconstr + opts['pot_lambda'] * loss_match # Optionally, add one more cost function based on the embeddings # add a discriminator in the X space, reusing the encoder or a new model. 
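        # The base objective is
        #     loss = reconstr_w * loss_reconstr + pot_lambda * loss_match,
        # i.e. a reconstruction cost plus a penalty forcing Qz to match Pz.
        # Depending on opts['adv_c_loss'], the branches below add an extra
        # adversarial term (weighted by adv_c_loss_w) and/or an embedding
        # reconstruction term (weighted by emb_c_loss_w), computed in X space
        # by one of the _recon_loss_using_* helpers.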
        if opts['adv_c_loss'] == 'encoder':
            adv_c_loss, emb_c_loss = self._recon_loss_using_disc_encoder(
                opts, reconstructed_training, encoded_training,
                real_points, is_training_ph, keep_prob_ph)
            loss += opts['adv_c_loss_w'] * adv_c_loss + opts['emb_c_loss_w'] * emb_c_loss
            additional_losses['adv_c'], additional_losses['emb_c'] = adv_c_loss, emb_c_loss
        elif opts['adv_c_loss'] == 'conv':
            adv_c_loss, emb_c_loss = self._recon_loss_using_disc_conv(
                opts, reconstructed_training, real_points, is_training_ph, keep_prob_ph)
            additional_losses['adv_c'], additional_losses['emb_c'] = adv_c_loss, emb_c_loss
            loss += opts['adv_c_loss_w'] * adv_c_loss + opts['emb_c_loss_w'] * emb_c_loss
        elif opts['adv_c_loss'] == 'conv_eb':
            adv_c_loss, emb_c_loss = self._recon_loss_using_disc_conv_eb(
                opts, reconstructed_training, real_points, is_training_ph, keep_prob_ph)
            additional_losses['adv_c'], additional_losses['emb_c'] = adv_c_loss, emb_c_loss
            loss += opts['adv_c_loss_w'] * adv_c_loss + opts['emb_c_loss_w'] * emb_c_loss
        elif opts['adv_c_loss'] == 'vgg':
            emb_c_loss = self._recon_loss_using_vgg(
                opts, reconstructed_training, real_points, is_training_ph, keep_prob_ph)
            loss += opts['emb_c_loss_w'] * emb_c_loss
            additional_losses['emb_c'] = emb_c_loss
        elif opts['adv_c_loss'] == 'moments':
            emb_c_loss = self._recon_loss_using_moments(
                opts, reconstructed_training, real_points, is_training_ph, keep_prob_ph)
            loss += opts['emb_c_loss_w'] * emb_c_loss
            additional_losses['emb_c'] = emb_c_loss
        elif opts['adv_c_loss'] == 'vgg_moments':
            emb_c_loss = self._recon_loss_using_vgg_moments(
                opts, reconstructed_training, real_points, is_training_ph, keep_prob_ph)
            loss += opts['emb_c_loss_w'] * emb_c_loss
            additional_losses['emb_c'] = emb_c_loss
        else:
            assert opts['adv_c_loss'] == 'none'

        # Add ops to pretrain the encoder so that Qz matches
        # the mean and covariance of Pz
        loss_pretrain = None
        if opts['e_pretrain']:
            # Next two vectors are zdim-dimensional
            mean_pz = tf.reduce_mean(noise, axis=0, keep_dims=True)
            mean_qz = tf.reduce_mean(encoded_training, axis=0, keep_dims=True)
            mean_loss = tf.reduce_mean(tf.square(mean_pz - mean_qz))
            cov_pz = tf.matmul(noise - mean_pz,
                               noise - mean_pz, transpose_a=True)
            cov_pz /= opts['e_pretrain_bsize'] - 1.
            cov_qz = tf.matmul(encoded_training - mean_qz,
                               encoded_training - mean_qz, transpose_a=True)
            cov_qz /= opts['e_pretrain_bsize'] - 1.
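            # The pretraining objective assembled below matches the first two
            # sample moments of Qz and Pz.  In numpy-like notation (a sketch
            # of what is being minimized, not extra graph ops):
            #
            #     mean_loss = mean( (mean(p, 0) - mean(q, 0)) ** 2 )
            #     cov_loss  = mean( (cov(p) - cov(q)) ** 2 )
            #
            # with p ~ Pz, q = Q(x) and cov(.) the (zdim x zdim) sample
            # covariance.  Driving both to zero gives the encoder a sensible
            # starting point before the adversarial / MMD matching starts.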
            cov_loss = tf.reduce_mean(tf.square(cov_pz - cov_qz))
            loss_pretrain = mean_loss + cov_loss

        # Also add ops to find the least Gaussian 2d projection;
        # this is handy when visually inspecting whether Qz = Pz
        self.add_least_gaussian2d_ops(opts)

        # Optimizer ops
        t_vars = tf.trainable_variables()
        # Updates for the discriminator
        d_vars = [var for var in t_vars if 'DISCRIMINATOR/' in var.name]
        # Updates for everything but the adversary
        # (encoder, decoder and possibly pz-transform)
        all_vars = [var for var in t_vars if 'DISCRIMINATOR/' not in var.name]
        # Encoder and generator (decoder) variables
        eg_vars = [var for var in t_vars
                   if 'GENERATOR/' in var.name or 'ENCODER/' in var.name]
        # Encoder variables separately, in case we want to pretrain
        e_vars = [var for var in t_vars if 'ENCODER/' in var.name]

        logging.error('Param num in G and E: %d' % \
            np.sum([np.prod([int(d) for d in v.get_shape()]) for v in eg_vars]))
        for v in eg_vars:
            print v.name, [int(d) for d in v.get_shape()]

        if len(d_vars) > 0:
            d_optim = ops.optimizer(opts, net='d', decay=lr_decay_ph).minimize(
                loss=d_loss, var_list=d_vars)
        else:
            d_optim = None
        optim = ops.optimizer(opts, net='g', decay=lr_decay_ph).minimize(
            loss=loss, var_list=all_vars)
        pretrain_optim = None
        if opts['e_pretrain']:
            pretrain_optim = ops.optimizer(opts, net='g').minimize(
                loss=loss_pretrain, var_list=e_vars)

        generated_images = self.generator(
            opts, noise, is_training=is_training_ph,
            reuse=True, keep_prob=keep_prob_ph)

        self._real_points_ph = real_points_ph
        self._real_points = real_points
        self._noise_ph = noise_ph
        self._noise = noise
        self._enc_noise_ph = enc_noise_ph
        self._lr_decay_ph = lr_decay_ph
        self._is_training_ph = is_training_ph
        self._keep_prob_ph = keep_prob_ph
        self._optim = optim
        self._d_optim = d_optim
        self._pretrain_optim = pretrain_optim
        self._loss = loss
        self._loss_reconstruct = loss_reconstr
        self._loss_match = loss_match
        self._loss_z_corr = loss_z_corr
        self._loss_pretrain = loss_pretrain
        self._additional_losses = additional_losses
        self._g_mom_stats = g_mom_stats
        self._d_loss = d_loss
        self._generated = generated_images
        self._Qz = encoded_training
        self._reconstruct_x = reconstructed_training

        saver = tf.train.Saver(max_to_keep=10)
        tf.add_to_collection('real_points_ph', self._real_points_ph)
        tf.add_to_collection('noise_ph', self._noise_ph)
        tf.add_to_collection('enc_noise_ph', self._enc_noise_ph)
        if opts['pz_transform']:
            tf.add_to_collection('noise', self._noise)
        tf.add_to_collection('is_training_ph', self._is_training_ph)
        tf.add_to_collection('keep_prob_ph', self._keep_prob_ph)
        tf.add_to_collection('encoder', self._Qz)
        tf.add_to_collection('decoder', self._generated)
        if d_logits_Pz is not None:
            tf.add_to_collection('disc_logits_Pz', d_logits_Pz)
        if d_logits_Qz is not None:
            tf.add_to_collection('disc_logits_Qz', d_logits_Qz)
        self._saver = saver
        logging.error("Building Graph Done.")

    def pretrain(self, opts):
        steps_max = 200
        batch_size = opts['e_pretrain_bsize']
        for step in xrange(steps_max):
            train_size = self._data.num_points
            data_ids = np.random.choice(train_size, min(train_size, batch_size),
                                        replace=False)
            batch_images = self._data.data[data_ids].astype(np.float)
            batch_noise = opts['pot_pz_std'] *\
                utils.generate_noise(opts, batch_size)
            # Noise for the random encoder (if present)
            batch_enc_noise = utils.generate_noise(opts, batch_size)
            # Update encoder
            [_, loss_pretrain] = self._session.run(
                [self._pretrain_optim, self._loss_pretrain],
                feed_dict={self._real_points_ph: batch_images,
                           self._noise_ph: batch_noise,
                           self._enc_noise_ph: batch_enc_noise,
                           self._is_training_ph: True,
                           self._keep_prob_ph: opts['dropout_keep_prob']})
            if opts['verbose'] == 2:
                logging.error('Step %d/%d, loss=%f' % (step, steps_max, loss_pretrain))
            if loss_pretrain < 0.1:
                break

    def _train_internal(self, opts):
        """Train a POT model.
        """
        logging.error(opts)
        batches_num = self._data.num_points / opts['batch_size']
        train_size = self._data.num_points
        num_plot = 320
        sample_prev = np.zeros([num_plot] + list(self._data.data_shape))
        l2s = []
        losses = []
        losses_rec = []
        losses_match = []
        wait = 0
        start_time = time.time()
        counter = 0
        decay = 1.
        logging.error('Training POT')

        # Optionally we first pretrain the encoder so that Qz matches
        # the mean and covariance of Pz
        if opts['e_pretrain']:
            logging.error('Pretraining the encoder')
            self.pretrain(opts)
            logging.error('Pretraining the encoder done')

        for _epoch in xrange(opts["gan_epoch_num"]):
            if opts['decay_schedule'] == "manual":
                if _epoch == 30:
                    decay = decay / 2.
                if _epoch == 50:
                    decay = decay / 5.
                if _epoch == 100:
                    decay = decay / 10.
            elif opts['decay_schedule'] != "plateau":
                assert type(1.0 * opts['decay_schedule']) == float
                decay = 1.0 * 10 ** (-_epoch / float(opts['decay_schedule']))
            if _epoch > 0 and _epoch % opts['save_every_epoch'] == 0:
                ckpt_dir = os.path.join(opts['work_dir'], opts['ckpt_dir'])
                self._saver.save(self._session,
                                 os.path.join(ckpt_dir, 'trained-pot'),
                                 global_step=counter)
            for _idx in xrange(batches_num):
                data_ids = np.random.choice(train_size, opts['batch_size'],
                                            replace=False, p=self._data_weights)
                batch_images = self._data.data[data_ids].astype(np.float)
                # Noise for the Pz=Qz GAN
                batch_noise = opts['pot_pz_std'] *\
                    utils.generate_noise(opts, opts['batch_size'])
                # Noise for the random encoder (if present)
                batch_enc_noise = utils.generate_noise(opts, opts['batch_size'])
                # Update generator (decoder) and encoder
                [_, loss, loss_rec, loss_match] = self._session.run(
                    [self._optim, self._loss,
                     self._loss_reconstruct, self._loss_match],
                    feed_dict={self._real_points_ph: batch_images,
                               self._noise_ph: batch_noise,
                               self._enc_noise_ph: batch_enc_noise,
                               self._lr_decay_ph: decay,
                               self._is_training_ph: True,
                               self._keep_prob_ph: opts['dropout_keep_prob']})

                if opts['decay_schedule'] == "plateau":
                    # For the first 30 epochs do nothing
                    if _epoch >= 30:
                        # If no significant progress was made in the last
                        # 10 epochs, then decrease the learning rate.
                        if loss < min(losses[-20 * batches_num:]):
                            wait = 0
                        else:
                            wait += 1
                        if wait > 10 * batches_num:
                            decay = max(decay / 1.4, 1e-6)
                            logging.error('Reduction in learning rate: %f' % decay)
                            wait = 0

                losses.append(loss)
                losses_rec.append(loss_rec)
                losses_match.append(loss_match)
                if opts['verbose'] >= 2:
                    # logging.error('loss after %d steps : %f' % (counter, losses[-1]))
                    logging.error('loss match after %d steps : %f' % (counter, losses_match[-1]))

                # Update discriminator in Z space (if any).
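                # The update scheme alternates, per minibatch, between
                # (1) one step on the encoder/decoder objective
                #         reconstr_w * c(x, G(Q(x))) + pot_lambda * loss_match
                #     (already executed above), and
                # (2) opts['d_steps'] steps on the Z-space adversary d_loss,
                #     which only exists when opts['z_test'] == 'gan'.
                # With opts['d_new_minibatch'] a fresh minibatch is drawn for
                # every adversary step; otherwise the same batch is reused.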
                if self._d_optim is not None:
                    for _st in range(opts['d_steps']):
                        if opts['d_new_minibatch']:
                            d_data_ids = np.random.choice(
                                train_size, opts['batch_size'],
                                replace=False, p=self._data_weights)
                            d_batch_images = self._data.data[d_data_ids].astype(np.float)
                            d_batch_enc_noise = utils.generate_noise(opts, opts['batch_size'])
                        else:
                            d_batch_images = batch_images
                            d_batch_enc_noise = batch_enc_noise
                        _ = self._session.run(
                            [self._d_optim, self._d_loss],
                            feed_dict={self._real_points_ph: d_batch_images,
                                       self._noise_ph: batch_noise,
                                       self._enc_noise_ph: d_batch_enc_noise,
                                       self._lr_decay_ph: decay,
                                       self._is_training_ph: True,
                                       self._keep_prob_ph: opts['dropout_keep_prob']})

                counter += 1
                now = time.time()
                rec_test = None
                if opts['verbose'] and counter % 500 == 0:
                    # Printing (training and test) loss values
                    test = self._data.test_data[:200]
                    [loss_rec_test, rec_test, g_mom_stats,
                     loss_z_corr, additional_losses] = self._session.run(
                        [self._loss_reconstruct, self._reconstruct_x,
                         self._g_mom_stats, self._loss_z_corr,
                         self._additional_losses],
                        feed_dict={self._real_points_ph: test,
                                   self._enc_noise_ph: utils.generate_noise(opts, len(test)),
                                   self._is_training_ph: False,
                                   self._noise_ph: batch_noise,
                                   self._keep_prob_ph: 1e5})
                    debug_str = 'Epoch: %d/%d, batch:%d/%d, batch/sec:%.2f' % (
                        _epoch + 1, opts['gan_epoch_num'], _idx + 1,
                        batches_num, float(counter) / (now - start_time))
                    debug_str += ' [L=%.5f, Recon=%.5f, GanL=%.5f, Recon_test=%.5f' % (
                        loss, loss_rec, loss_match, loss_rec_test)
                    debug_str += ',' + ', '.join(
                        ['%s=%.2g' % (k, v) for (k, v) in additional_losses.items()])
                    logging.error(debug_str)
                    if opts['verbose'] >= 2:
                        logging.error(g_mom_stats)
                        logging.error(loss_z_corr)
                    if counter % opts['plot_every'] == 0:
                        # Plotting the test reconstructions
                        metrics = Metrics()
                        merged = np.vstack([rec_test[:8 * 10], test[:8 * 10]])
                        r_ptr = 0
                        w_ptr = 0
                        for _ in range(8 * 10):
                            merged[w_ptr] = test[r_ptr]
                            merged[w_ptr + 1] = rec_test[r_ptr]
                            r_ptr += 1
                            w_ptr += 2
                        metrics.make_plots(
                            opts, counter, None, merged,
                            prefix='test_reconstr_e%04d_mb%05d_' % (_epoch, _idx))

                if opts['verbose'] and counter % opts['plot_every'] == 0:
                    # Plotting intermediate results
                    metrics = Metrics()
                    # --Random samples from the model
                    points_to_plot, sample_pz = self._session.run(
                        [self._generated, self._noise],
                        feed_dict={
                            self._noise_ph: self._noise_for_plots[0:num_plot],
                            self._is_training_ph: False,
                            self._keep_prob_ph: 1e5})
                    Qz_num = 320
                    sample_Qz = self._session.run(
                        self._Qz,
                        feed_dict={
                            self._real_points_ph: self._data.data[:Qz_num],
                            self._enc_noise_ph: utils.generate_noise(opts, Qz_num),
                            self._is_training_ph: False,
                            self._keep_prob_ph: 1e5})
                    # Searching for the least Gaussian 2d projection
                    proj_mat, check = self.least_gaussian_2d(opts, sample_Qz)
                    # Projecting samples from Qz and Pz on this 2d plane
                    metrics.Qz = np.dot(sample_Qz, proj_mat)
                    # metrics.Pz = np.dot(self._noise_for_plots, proj_mat)
                    metrics.Pz = np.dot(sample_pz, proj_mat)
                    if self._data.labels is not None:
                        metrics.Qz_labels = self._data.labels[:Qz_num]
                    else:
                        metrics.Qz_labels = None
                    metrics.l2s = losses[:]
                    metrics.losses_match = [opts['pot_lambda'] * el for el in losses_match]
                    metrics.losses_rec = [opts['reconstr_w'] * el for el in losses_rec]
                    to_plot = [points_to_plot, 0 * batch_images[:16], batch_images]
                    if rec_test is not None:
                        to_plot += [0 * batch_images[:16], rec_test[:64]]
                    metrics.make_plots(
                        opts, counter, None, np.vstack(to_plot),
                        prefix='sample_e%04d_mb%05d_' % (_epoch, _idx) if rec_test is None \
                        else 'sample_with_test_e%04d_mb%05d_' % (_epoch, _idx))
                    # --Reconstructions for the train and test points
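                    # The block below builds an interleaved grid: even rows of
                    # `merged` hold the original points and odd rows hold the
                    # corresponding reconstructions, i.e.
                    #     merged = [x_0, G(Q(x_0)), x_1, G(Q(x_1)), ...]
                    # so that Metrics.make_plots renders each original next to
                    # its reconstruction.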
                    num_real_p = 8 * 10
                    reconstructed, real_p = self._session.run(
                        [self._reconstruct_x, self._real_points],
                        feed_dict={
                            self._real_points_ph: self._data.data[:num_real_p],
                            self._enc_noise_ph: utils.generate_noise(opts, num_real_p),
                            self._is_training_ph: True,
                            self._keep_prob_ph: 1e5})
                    points = real_p
                    merged = np.vstack([reconstructed, points])
                    r_ptr = 0
                    w_ptr = 0
                    for _ in range(8 * 10):
                        merged[w_ptr] = points[r_ptr]
                        merged[w_ptr + 1] = reconstructed[r_ptr]
                        r_ptr += 1
                        w_ptr += 2
                    metrics.make_plots(
                        opts, counter, None, merged,
                        prefix='reconstr_e%04d_mb%05d_' % (_epoch, _idx))
                    sample_prev = points_to_plot[:]

        if _epoch > 0:
            ckpt_dir = os.path.join(opts['work_dir'], opts['ckpt_dir'])
            self._saver.save(self._session,
                             os.path.join(ckpt_dir, 'trained-pot-final'),
                             global_step=counter)

    def _sample_internal(self, opts, num):
        """Sample points from the trained POT model.
        """
        # noise = opts['pot_pz_std'] * utils.generate_noise(opts, num)
        # sample = self._run_batch(
        #     opts, self._generated, self._noise_ph, noise, self._is_training_ph, False)
        sample = None
        return sample
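
    # Usage sketch (not executed here; assumes the calling script prepares
    # `opts`, a data object and per-point `weights`, as elsewhere in this
    # code base):
    #
    #     with Pot(opts, data, weights) as pot:
    #         pot.train(opts)
    #         samples = pot.sample(opts, num=100)
    #
    # Note that sample() currently returns None until the commented lines in
    # _sample_internal above are enabled.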