import math

import numpy as np
import tensorflow as tf
from contextlib import contextmanager
from tensorflow.python.framework import ops

# get_name, int_shape and get_var_maybe_avg are expected to come from utils
from utils import *

slim = tf.contrib.slim
rng = np.random.RandomState([2016, 6, 1])
ln = tf.contrib.layers.layer_norm
bn = slim.batch_norm


def conv_cond_concat(x, y):
    """Concatenate conditioning vector on feature map axis."""
    x_shapes = x.get_shape()
    y_shapes = y.get_shape()
    return tf.concat(
        [x, y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)


def lrelu(x, leak=0.2, name="lrelu"):
    with tf.variable_scope(name):
        f1 = 0.5 * (1 + leak)
        f2 = 0.5 * (1 - leak)
        return f1 * x + f2 * abs(x)


def sin_and_cos(x, name="ignored"):
    return tf.concat([tf.sin(x), tf.cos(x)], len(x.get_shape()) - 1)


def maxout(x, k=2):
    shape = [int(e) for e in x.get_shape()]
    ax = len(shape)
    ch = shape[-1]
    assert ch % k == 0
    shape[-1] = ch // k
    shape.append(k)
    x = tf.reshape(x, shape)
    return tf.reduce_max(x, ax)


def offset_maxout(x, k=2):
    shape = [int(e) for e in x.get_shape()]
    ax = len(shape)
    ch = shape[-1]
    assert ch % k == 0
    shape[-1] = ch // k
    shape.append(k)
    x = tf.reshape(x, shape)
    # offset by the expected max of k standard normals, estimated by
    # simulation, so the activation is roughly zero-centered at init
    ofs = rng.randn(1000, k).max(axis=1).mean()
    return tf.reduce_max(x, ax) - ofs


def lrelu_sq(x):
    """Concatenate lrelu and a clipped square along the channel axis."""
    dim = len(x.get_shape()) - 1
    return tf.concat([lrelu(x), tf.minimum(tf.abs(x), tf.square(x))], dim)


def nin(input_, output_size, name=None, mean=0., stddev=0.02,
        bias_start=0.0, with_w=False):
    """Network-in-network (1x1 convolution) expressed as a reshaped fc layer.

    `fc` takes `is_training` via slim.arg_scope, so call this inside one.
    """
    s = list(map(int, input_.get_shape()))
    input_ = tf.reshape(input_, [np.prod(s[:-1]), s[-1]])
    input_ = fc(input_, output_size, act=None, norm=None)
    return tf.reshape(input_, s[:-1] + [output_size])


@contextmanager
def variables_on_cpu():
    old_fn = tf.get_variable

    def new_fn(*args, **kwargs):
        with tf.device("/cpu:0"):
            return old_fn(*args, **kwargs)

    tf.get_variable = new_fn
    yield
    tf.get_variable = old_fn


@contextmanager
def variables_on_gpu0():
    old_fn = tf.get_variable

    def new_fn(*args, **kwargs):
        with tf.device("/gpu:0"):
            return old_fn(*args, **kwargs)

    tf.get_variable = new_fn
    yield
    tf.get_variable = old_fn


def avg_grads(tower_grads):
    """Calculate the average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer
            list is over towers; each inner list holds the (gradient, variable)
            pairs computed on that tower.

    Returns:
        List of (gradient, variable) pairs where the gradient has been
        averaged across all towers.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Each grad_and_vars looks like:
        #   ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN, var0_gpuN))
        grads = []
        for g, _ in grad_and_vars:
            # Add a leading 'tower' dimension which we average over below.
            expanded_g = tf.expand_dims(g, 0)
            grads.append(expanded_g)

        # Average over the 'tower' dimension.
        grad = tf.reduce_mean(tf.concat(grads, 0), 0)

        # The variables are redundant because they are shared across towers,
        # so just return the first tower's pointer to the variable.
        v = grad_and_vars[0][1]
        average_grads.append((grad, v))
    return average_grads
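
# Illustrative sketch, not part of the original module: how avg_grads is
# typically wired into a multi-tower loop. `loss_fn`, the scope name, and
# the optimizer settings are placeholders, and tf.AUTO_REUSE assumes TF >= 1.4.
def _demo_avg_grads(loss_fn, tower_inputs, devices):
    opt = tf.train.AdamOptimizer(2e-4)
    tower_grads = []
    for dev, inp in zip(devices, tower_inputs):
        with tf.device(dev), tf.variable_scope("model", reuse=tf.AUTO_REUSE):
            # one (gradient, variable) list per tower; variables are shared
            tower_grads.append(opt.compute_gradients(loss_fn(inp)))
    # average variable-wise across towers, then apply a single update
    return opt.apply_gradients(avg_grads(tower_grads))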
def decayer(x, name="decayer"):
    with tf.variable_scope(name):
        scale = tf.get_variable("scale", [1],
                                initializer=tf.constant_initializer(1.))
        decay_scale = tf.get_variable("decay_scale", [1],
                                      initializer=tf.constant_initializer(1.))
        relu = tf.nn.relu(x)
        return scale * relu / (1. + tf.abs(decay_scale) * tf.square(decay_scale))


def decayer2(x, name="decayer"):
    """Like decayer, but with one scale/decay_scale pair per channel."""
    with tf.variable_scope(name):
        scale = tf.get_variable("scale", [int(x.get_shape()[-1])],
                                initializer=tf.constant_initializer(1.))
        decay_scale = tf.get_variable("decay_scale", [int(x.get_shape()[-1])],
                                      initializer=tf.constant_initializer(1.))
        relu = tf.nn.relu(x)
        return scale * relu / (1. + tf.abs(decay_scale) * tf.square(decay_scale))


def masked_relu(x, name="ignored"):
    """Gated activation: relu on the first half of the channels, gated by a
    sigmoid of the second half."""
    shape = [int(e) for e in x.get_shape()]
    prefix = [0] * (len(shape) - 1)
    most = shape[:-1]
    assert shape[-1] % 2 == 0
    half = shape[-1] // 2
    first_half = tf.slice(x, prefix + [0], most + [half])
    second_half = tf.slice(x, prefix + [half], most + [half])
    return tf.nn.relu(first_half) * tf.nn.sigmoid(second_half)


def make_z(shape, minval=-1.0, maxval=1.0, name="z"):
    z = tf.random_uniform(shape, minval=minval, maxval=maxval,
                          name=name, dtype=tf.float32)
    # z = tf.random_normal(shape, name=name, stddev=0.5, dtype=tf.float32)
    return z


def get_sample_zs(model):
    assert model.sample_size > model.batch_size
    assert model.sample_size % model.batch_size == 0

    if model.config.multigpu:
        batch_size = model.batch_size // len(model.devices)
    else:
        batch_size = model.batch_size
    steps = model.sample_size // batch_size
    assert steps > 0

    sample_zs = []
    for _ in range(steps):
        cur_zs = model.sess.run(model.z)
        sample_zs.append(cur_zs)
    sample_zs = np.concatenate(sample_zs, axis=0)
    assert sample_zs.shape[0] == model.sample_size
    return sample_zs


def batch_to_grid(images, width=8):
    """Tile the first width**2 images of a batch into a single image grid."""
    images = tf.squeeze(images[:width**2])
    images_list = tf.unstack(images, num=width**2, axis=0)
    conc = tf.concat(images_list, axis=1)
    sp = tf.split(conc, width, axis=1)
    grid = tf.expand_dims(tf.concat(sp, axis=0), axis=0)
    if len(grid.get_shape().as_list()) < 4:
        grid = tf.expand_dims(grid, axis=-1)
    return grid


@tf.contrib.framework.add_arg_scope
def fc(x, out_dim, is_training, act=tf.nn.relu, norm=bn,
       init=tf.truncated_normal_initializer(stddev=0.02)):
    if norm == bn:
        return slim.fully_connected(
            x, out_dim, activation_fn=act, normalizer_fn=norm,
            weights_initializer=init,
            normalizer_params={'is_training': is_training})
    else:
        return slim.fully_connected(
            x, out_dim, activation_fn=act, normalizer_fn=norm,
            weights_initializer=init)


@tf.contrib.framework.add_arg_scope
def deconv2d(x, out_dim, k, s, is_training, act=tf.nn.relu, norm=bn,
             init=tf.truncated_normal_initializer(stddev=0.02)):
    if norm == bn:
        return slim.conv2d_transpose(
            x, out_dim, k, s, activation_fn=act, normalizer_fn=norm,
            weights_initializer=init,
            normalizer_params={'is_training': is_training})
    else:
        return slim.conv2d_transpose(
            x, out_dim, k, s, activation_fn=act, normalizer_fn=norm,
            weights_initializer=init)


@tf.contrib.framework.add_arg_scope
def conv2d(x, out_dim, k, s, is_training, act=tf.nn.relu, norm=bn,
           init=tf.truncated_normal_initializer(stddev=0.02)):
    if norm == bn:
        return slim.conv2d(
            x, out_dim, k, s, activation_fn=act, normalizer_fn=norm,
            weights_initializer=init,
            normalizer_params={'is_training': is_training})
    else:
        return slim.conv2d(
            x, out_dim, k, s, activation_fn=act, normalizer_fn=norm,
            weights_initializer=init)
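
# Illustrative sketch, not part of the original module: fc/conv2d/deconv2d
# are decorated with add_arg_scope, so `is_training` can be supplied once
# for a whole subgraph instead of at every call site. The layer sizes and
# kernel/stride choices below are placeholders.
def _demo_arg_scope(z, is_training):
    with slim.arg_scope([fc, conv2d, deconv2d], is_training=is_training):
        h = fc(z, 4 * 4 * 64)
        h = tf.reshape(h, [-1, 4, 4, 64])
        h = deconv2d(h, 32, 4, 2)                                # 4x4 -> 8x8
        return deconv2d(h, 3, 4, 2, act=tf.nn.tanh, norm=None)   # 8x8 -> 16x16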
def preprocess_image(image, dataset, use_augmentation=False):
    image = tf.divide(image, 255.)
    if use_augmentation:
        image = tf.image.random_brightness(image, max_delta=16. / 255.)
        image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
        image = tf.minimum(tf.maximum(image, 0.0), 1.0)
    if ('mnist' not in dataset) and ('fashion' not in dataset):
        # rescale from [0, 1] to [-1, 1]
        image = tf.subtract(image * 2., 1.)
    return image


def conv_mean_pool(x, out_dim, k=3, act=tf.nn.relu, norm=bn,
                   init=tf.truncated_normal_initializer(stddev=0.02)):
    h = conv2d(x, out_dim, k=k, s=1, act=act, norm=norm, init=init)
    # 2x2 mean pooling, written as the average of the four phases
    return tf.add_n([h[:, ::2, ::2, :], h[:, 1::2, ::2, :],
                     h[:, ::2, 1::2, :], h[:, 1::2, 1::2, :]]) / 4.


def resize_conv2d(x, out_dim, k=3, scale=2, act=tf.nn.relu, norm=bn,
                  init=tf.truncated_normal_initializer(stddev=0.02)):
    # nearest-neighbor 2x upsampling via depth_to_space, then convolution
    h = tf.concat([x, x, x, x], axis=3)
    h = tf.depth_to_space(h, 2)
    return conv2d(h, out_dim, k=k, s=1, act=act, norm=norm, init=init)


def residual_block(x, resample=None, labels=None, act=tf.nn.relu, norm=bn,
                   init=tf.truncated_normal_initializer(stddev=0.02)):
    c_dim = x.get_shape().as_list()[-1]
    if resample == 'down':
        h = conv2d(x, c_dim, 3, 1, act=act, init=init)
        h = conv_mean_pool(h, c_dim, 3, act=None, norm=None, init=init)
        h += conv_mean_pool(x, c_dim, 1, act=None, norm=None, init=init)
        h = act(norm(h))
    elif resample == 'up':
        h = resize_conv2d(x, c_dim, 3, act=act, init=init)
        h = conv2d(h, c_dim, 3, 1, act=None, norm=None, init=init)
        h += resize_conv2d(x, c_dim, 1, act=None, norm=None, init=init)
        h = act(norm(h))
    elif resample is None:
        h = conv2d(x, c_dim, 3, 1, act=act, init=init)
        h = conv2d(h, c_dim, 3, 1, act=None, norm=None, init=init)
        h += x
        h = act(norm(h))
    else:
        raise ValueError('invalid resample value: %r' % (resample,))
    return h


def get_vars_maybe_avg(var_names, ema, **kwargs):
    """Utility for retrieving polyak-averaged parameters."""
    vars = []
    for vn in var_names:
        vars.append(get_var_maybe_avg(vn, ema, **kwargs))
    return vars


def concat_relu(x):
    """Concatenated ReLU (http://arxiv.org/abs/1603.05201)."""
    axis = len(x.get_shape().as_list()) - 1
    return tf.nn.relu(tf.concat([x, -x], axis))


def concat_elu(x):
    """Like concatenated ReLU (http://arxiv.org/abs/1603.05201), but with ELU."""
    axis = len(x.get_shape().as_list()) - 1
    return tf.nn.elu(tf.concat([x, -x], axis))


def fc_wn(x, num_units, nonlinearity=None, init_scale=1., counters={},
          init=False, ema=None, **kwargs):
    """Fully connected layer with weight normalization (Salimans & Kingma, 2016)."""
    name = get_name('dense', counters)
    with tf.variable_scope(name):
        if init:
            # data-dependent initialization of parameters
            V = tf.get_variable('V', [int(x.get_shape()[1]), num_units],
                                tf.float32,
                                tf.random_normal_initializer(0, 0.05),
                                trainable=True)
            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0])
            x_init = tf.matmul(x, V_norm)
            m_init, v_init = tf.nn.moments(x_init, [0])
            scale_init = init_scale / tf.sqrt(v_init + 1e-10)
            g = tf.get_variable('g', dtype=tf.float32,
                                initializer=scale_init, trainable=True)
            b = tf.get_variable('b', dtype=tf.float32,
                                initializer=-m_init * scale_init, trainable=True)
            x_init = tf.reshape(scale_init, [1, num_units]) \
                * (x_init - tf.reshape(m_init, [1, num_units]))
            if nonlinearity is not None:
                x_init = nonlinearity(x_init)
            return x_init
        else:
            V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema)
            tf.assert_variables_initialized([V, g, b])

            # use weight normalization (Salimans & Kingma, 2016)
            x = tf.matmul(x, V)
            scaler = g / tf.sqrt(tf.reduce_sum(tf.square(V), [0]))
            x = tf.reshape(scaler, [1, num_units]) * x \
                + tf.reshape(b, [1, num_units])

            # apply nonlinearity
            if nonlinearity is not None:
                x = nonlinearity(x)
            return x
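
# Illustrative sketch, not part of the original module: the *_wn layers are
# meant to be built twice with matching scopes -- once with init=True on a
# warm-up batch (data-dependent init of g and b), then with init=False under
# variable reuse for training. The scope name and width are placeholders.
def _demo_wn_init(x_init, x, num_units=128):
    with tf.variable_scope('wn_demo'):
        fc_wn(x_init, num_units, nonlinearity=tf.nn.relu,
              counters={}, init=True)
    with tf.variable_scope('wn_demo', reuse=True):
        return fc_wn(x, num_units, nonlinearity=tf.nn.relu,
                     counters={}, init=False)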
def conv2d_wn(x, num_filters, filter_size=[3, 3], stride=[1, 1], pad='SAME',
              nonlinearity=None, init_scale=1., counters={}, init=False,
              ema=None, **kwargs):
    """Convolutional layer with weight normalization."""
    name = get_name('conv2d', counters)
    with tf.variable_scope(name):
        if init:
            # data-dependent initialization of parameters
            V = tf.get_variable('V',
                                filter_size + [int(x.get_shape()[-1]), num_filters],
                                tf.float32,
                                tf.random_normal_initializer(0, 0.05),
                                trainable=True)
            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0, 1, 2])
            x_init = tf.nn.conv2d(x, V_norm, [1] + stride + [1], pad)
            m_init, v_init = tf.nn.moments(x_init, [0, 1, 2])
            scale_init = init_scale / tf.sqrt(v_init + 1e-8)
            g = tf.get_variable('g', dtype=tf.float32,
                                initializer=scale_init, trainable=True)
            b = tf.get_variable('b', dtype=tf.float32,
                                initializer=-m_init * scale_init, trainable=True)
            x_init = tf.reshape(scale_init, [1, 1, 1, num_filters]) \
                * (x_init - tf.reshape(m_init, [1, 1, 1, num_filters]))
            if nonlinearity is not None:
                x_init = nonlinearity(x_init)
            return x_init
        else:
            V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema)
            tf.assert_variables_initialized([V, g, b])

            # use weight normalization (Salimans & Kingma, 2016)
            W = tf.reshape(g, [1, 1, 1, num_filters]) \
                * tf.nn.l2_normalize(V, [0, 1, 2])

            # calculate convolutional layer output
            x = tf.nn.bias_add(tf.nn.conv2d(x, W, [1] + stride + [1], pad), b)

            # apply nonlinearity
            if nonlinearity is not None:
                x = nonlinearity(x)
            return x


def deconv2d_wn(x, num_filters, filter_size=[3, 3], stride=[1, 1], pad='SAME',
                nonlinearity=None, init_scale=1., counters={}, init=False,
                ema=None, **kwargs):
    """Transposed convolutional layer with weight normalization."""
    name = get_name('deconv2d', counters)
    xs = int_shape(x)
    if pad == 'SAME':
        target_shape = [xs[0], xs[1] * stride[0], xs[2] * stride[1], num_filters]
    else:
        target_shape = [xs[0], xs[1] * stride[0] + filter_size[0] - 1,
                        xs[2] * stride[1] + filter_size[1] - 1, num_filters]
    with tf.variable_scope(name):
        if init:
            # data-dependent initialization of parameters; note the filter
            # layout is [h, w, out_channels, in_channels] for conv2d_transpose
            V = tf.get_variable('V',
                                filter_size + [num_filters, int(x.get_shape()[-1])],
                                tf.float32,
                                tf.random_normal_initializer(0, 0.05),
                                trainable=True)
            V_norm = tf.nn.l2_normalize(V.initialized_value(), [0, 1, 3])
            x_init = tf.nn.conv2d_transpose(x, V_norm, target_shape,
                                            [1] + stride + [1], padding=pad)
            m_init, v_init = tf.nn.moments(x_init, [0, 1, 2])
            scale_init = init_scale / tf.sqrt(v_init + 1e-8)
            g = tf.get_variable('g', dtype=tf.float32,
                                initializer=scale_init, trainable=True)
            b = tf.get_variable('b', dtype=tf.float32,
                                initializer=-m_init * scale_init, trainable=True)
            x_init = tf.reshape(scale_init, [1, 1, 1, num_filters]) \
                * (x_init - tf.reshape(m_init, [1, 1, 1, num_filters]))
            if nonlinearity is not None:
                x_init = nonlinearity(x_init)
            return x_init
        else:
            V, g, b = get_vars_maybe_avg(['V', 'g', 'b'], ema)
            tf.assert_variables_initialized([V, g, b])

            # use weight normalization (Salimans & Kingma, 2016)
            W = tf.reshape(g, [1, 1, num_filters, 1]) \
                * tf.nn.l2_normalize(V, [0, 1, 3])

            # calculate transposed convolutional layer output
            x = tf.nn.conv2d_transpose(x, W, target_shape,
                                       [1] + stride + [1], padding=pad)
            x = tf.nn.bias_add(x, b)

            # apply nonlinearity
            if nonlinearity is not None:
                x = nonlinearity(x)
            return x


def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d_wn,
                 init=False, counters={}, ema=None, dropout_p=0., **kwargs):
    # default `conv` must be the weight-normalized conv2d_wn: the slim-based
    # conv2d has an incompatible signature (requires k, s and no init_scale)
    xs = int_shape(x)
    num_filters = xs[-1]

    c1 = conv(nonlinearity(x), num_filters)
    if a is not None:
        # add short-cut connection if auxiliary input 'a' is given
        c1 += nin(nonlinearity(a), num_filters)
    c1 = nonlinearity(c1)
    if dropout_p > 0:
        c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)
    c2 = conv(c1, num_filters * 2, init_scale=0.1)

    # add projection of h vector if included: conditional generation
    if h is not None:
        with tf.variable_scope(get_name('conditional_weights', counters)):
            hw = get_var_maybe_avg('hw', ema,
                                   shape=[int_shape(h)[-1], 2 * num_filters],
                                   dtype=tf.float32,
                                   initializer=tf.random_normal_initializer(0, 0.05),
                                   trainable=True)
        if init:
            hw = hw.initialized_value()
        c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])

    a, b = tf.split(c2, 2, axis=3)
    c3 = a * tf.nn.sigmoid(b)
    return x + c3
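
# Illustrative sketch, not part of the original module: stacking gated
# residual blocks with a shared `counters` dict so get_name() assigns each
# weight-normalized conv a unique scope. Depth and dropout are placeholders.
def _demo_gated_stack(x, depth=3, init=False):
    counters = {}
    for _ in range(depth):
        x = gated_resnet(x, conv=conv2d_wn, init=init,
                         counters=counters, dropout_p=0.5)
    return x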
def cross_entropy(y, smooth=1e-3):
    # clip predictions away from 0 and 1 so the log stays finite
    y = tf.minimum(1. - smooth, tf.maximum(smooth, y))
    return tf.reduce_mean(-tf.log(y))


ops_with_bn = [fc, conv2d, deconv2d]
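
# Illustrative sketch, not part of the original module: ops_with_bn collects
# the batch-norm wrappers so is_training can be set once via slim.arg_scope;
# cross_entropy then gives the standard GAN losses. The discriminator
# architecture below is a placeholder.
def _demo_discriminator(images, is_training):
    with slim.arg_scope(ops_with_bn, is_training=is_training):
        h = conv2d(images, 64, 4, 2, act=lrelu)
        h = conv2d(h, 128, 4, 2, act=lrelu)
        return fc(slim.flatten(h), 1, act=tf.nn.sigmoid, norm=None)


def _demo_gan_losses(d_real, d_fake):
    # d_real/d_fake are sigmoid outputs in (0, 1)
    d_loss = cross_entropy(d_real) + cross_entropy(1. - d_fake)  # -log D(x) - log(1 - D(G(z)))
    g_loss = cross_entropy(d_fake)                               # -log D(G(z))
    return d_loss, g_loss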