import numpy as np
import tensorflow as tf  # pylint: ignore-module
#import builtins
import functools
import copy
import os
import collections

# ================================================================
# Import all names into common namespace
# ================================================================

clip = tf.clip_by_value

# Make consistent with numpy
# ----------------------------------------

def sum(x, axis=None, keepdims=False):
    return tf.reduce_sum(x, reduction_indices=None if axis is None else [axis], keep_dims=keepdims)

def mean(x, axis=None, keepdims=False):
    return tf.reduce_mean(x, reduction_indices=None if axis is None else [axis], keep_dims=keepdims)

def var(x, axis=None, keepdims=False):
    # keep the reduced axis so the mean broadcasts against x for any axis
    meanx = mean(x, axis=axis, keepdims=True)
    return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims)

def std(x, axis=None, keepdims=False):
    return tf.sqrt(var(x, axis=axis, keepdims=keepdims))

def max(x, axis=None, keepdims=False):
    return tf.reduce_max(x, reduction_indices=None if axis is None else [axis], keep_dims=keepdims)

def min(x, axis=None, keepdims=False):
    return tf.reduce_min(x, reduction_indices=None if axis is None else [axis], keep_dims=keepdims)

def concatenate(arrs, axis=0):
    return tf.concat(axis, arrs)

def argmax(x, axis=None):
    return tf.argmax(x, dimension=axis)

def switch(condition, then_expression, else_expression):
    '''Switches between two operations depending on a scalar value (int or bool).
    Note that both `then_expression` and `else_expression`
    should be symbolic tensors of the *same shape*.

    # Arguments
        condition: scalar tensor.
        then_expression: TensorFlow operation.
        else_expression: TensorFlow operation.
    '''
    x_shape = copy.copy(then_expression.get_shape())
    x = tf.cond(tf.cast(condition, 'bool'),
                lambda: then_expression,
                lambda: else_expression)
    x.set_shape(x_shape)
    return x

# Extras
# ----------------------------------------

def l2loss(params):
    if len(params) == 0:
        return tf.constant(0.0)
    else:
        return tf.add_n([sum(tf.square(p)) for p in params])

def lrelu(x, leak=0.2):
    f1 = 0.5 * (1 + leak)
    f2 = 0.5 * (1 - leak)
    return f1 * x + f2 * abs(x)

def categorical_sample_logits(X):
    # https://github.com/tensorflow/tensorflow/issues/456
    U = tf.random_uniform(tf.shape(X))
    return argmax(X - tf.log(-tf.log(U)), axis=1)

# ================================================================
# Global session
# ================================================================

def get_session():
    return tf.get_default_session()

def single_threaded_session():
    tf_config = tf.ConfigProto(
        inter_op_parallelism_threads=1,
        intra_op_parallelism_threads=1)
    return tf.Session(config=tf_config)

def make_session(num_cpu):
    tf_config = tf.ConfigProto(
        inter_op_parallelism_threads=num_cpu,
        intra_op_parallelism_threads=num_cpu)
    return tf.Session(config=tf_config)

ALREADY_INITIALIZED = set()

def initialize():
    new_variables = set(tf.all_variables()) - ALREADY_INITIALIZED
    get_session().run(tf.initialize_variables(new_variables))
    ALREADY_INITIALIZED.update(new_variables)

def eval(expr, feed_dict=None):
    if feed_dict is None:
        feed_dict = {}
    return get_session().run(expr, feed_dict=feed_dict)

def set_value(v, val):
    get_session().run(v.assign(val))

def load_state(fname):
    saver = tf.train.Saver()
    saver.restore(get_session(), fname)

def save_state(fname):
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    saver = tf.train.Saver()
    saver.save(get_session(), fname)
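# ----------------------------------------
# Minimal usage sketch for the session helpers above. The placeholder shape
# and the `_example_session_usage` name are illustrative assumptions, not part
# of the module's interface.
def _example_session_usage():
    x = tf.placeholder(tf.float32, [None, 3])
    y = mean(tf.square(x), axis=1)
    with single_threaded_session().as_default():
        initialize()  # harmless here: no variables have been created yet
        return eval(y, feed_dict={x: np.zeros((2, 3), dtype=np.float32)})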
# ================================================================
# Model components
# ================================================================

def normc_initializer(std=1.0):
    def _initializer(shape, dtype=None, partition_info=None):  # pylint: disable=W0613
        out = np.random.randn(*shape).astype(np.float32)
        out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
        return tf.constant(out)
    return _initializer

def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32,
           collections=None, summary_tag=None):
    with tf.variable_scope(name):
        stride_shape = [1, stride[0], stride[1], 1]
        filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters]

        # there are "num input feature maps * filter height * filter width"
        # inputs to each hidden unit
        fan_in = intprod(filter_shape[:3])
        # each unit in the lower layer receives a gradient from:
        # "num output feature maps * filter height * filter width" / pooling size
        fan_out = intprod(filter_shape[:2]) * num_filters
        # initialize weights with random weights
        w_bound = np.sqrt(6. / (fan_in + fan_out))

        w = tf.get_variable("W", filter_shape, dtype, tf.random_uniform_initializer(-w_bound, w_bound),
                            collections=collections)
        b = tf.get_variable("b", [1, 1, 1, num_filters], initializer=tf.zeros_initializer,
                            collections=collections)

        if summary_tag is not None:
            tf.image_summary(summary_tag,
                             tf.transpose(tf.reshape(w, [filter_size[0], filter_size[1], -1, 1]),
                                          [2, 0, 1, 3]),
                             max_images=10)

        return tf.nn.conv2d(x, w, stride_shape, pad) + b

def dense(x, size, name, weight_init=None, bias=True):
    w = tf.get_variable(name + "/w", [x.get_shape()[1], size], initializer=weight_init)
    ret = tf.matmul(x, w)
    if bias:
        b = tf.get_variable(name + "/b", [size], initializer=tf.zeros_initializer)
        return ret + b
    else:
        return ret

def wndense(x, size, name, init_scale=1.0):
    v = tf.get_variable(name + "/V", [int(x.get_shape()[1]), size],
                        initializer=tf.random_normal_initializer(0, 0.05))
    g = tf.get_variable(name + "/g", [size], initializer=tf.constant_initializer(init_scale))
    b = tf.get_variable(name + "/b", [size], initializer=tf.constant_initializer(0.0))

    # use weight normalization (Salimans & Kingma, 2016)
    x = tf.matmul(x, v)
    scaler = g / tf.sqrt(sum(tf.square(v), axis=0, keepdims=True))
    return tf.reshape(scaler, [1, size]) * x + tf.reshape(b, [1, size])

def densenobias(x, size, name, weight_init=None):
    return dense(x, size, name, weight_init=weight_init, bias=False)

def dropout(x, pkeep, phase=None, mask=None):
    mask = tf.floor(pkeep + tf.random_uniform(tf.shape(x))) if mask is None else mask
    if phase is None:
        return mask * x
    else:
        return switch(phase, mask * x, pkeep * x)

def batchnorm(x, name, phase, updates, gamma=0.96):
    k = x.get_shape()[1]
    runningmean = tf.get_variable(name + "/mean", shape=[1, k],
                                  initializer=tf.constant_initializer(0.0), trainable=False)
    runningvar = tf.get_variable(name + "/var", shape=[1, k],
                                 initializer=tf.constant_initializer(1e-4), trainable=False)
    testy = (x - runningmean) / tf.sqrt(runningvar)

    mean_ = mean(x, axis=0, keepdims=True)
    var_ = mean(tf.square(x), axis=0, keepdims=True)
    std = tf.sqrt(var_)
    trainy = (x - mean_) / std

    updates.extend([
        tf.assign(runningmean, runningmean * gamma + mean_ * (1 - gamma)),
        tf.assign(runningvar, runningvar * gamma + var_ * (1 - gamma))
    ])

    y = switch(phase, trainy, testy)

    out = y * tf.get_variable(name + "/scaling", shape=[1, k],
                              initializer=tf.constant_initializer(1.0), trainable=True) \
        + tf.get_variable(name + "/translation", shape=[1, k],
                          initializer=tf.constant_initializer(0.0), trainable=True)
    return out
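# ----------------------------------------
# Minimal sketch of wiring the layer helpers above into a two-layer MLP.
# The layer names and sizes are arbitrary examples, not module conventions.
def _example_mlp(x):
    h = lrelu(dense(x, 64, "h1", weight_init=normc_initializer(1.0)))
    return dense(h, 1, "out", weight_init=normc_initializer(0.1))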
# ================================================================
# Basic Stuff
# ================================================================

def function(inputs, outputs, updates=None, givens=None):
    if isinstance(outputs, list):
        return _Function(inputs, outputs, updates, givens=givens)
    elif isinstance(outputs, (dict, collections.OrderedDict)):
        f = _Function(inputs, outputs.values(), updates, givens=givens)
        return lambda *inputs: type(outputs)(zip(outputs.keys(), f(*inputs)))
    else:
        f = _Function(inputs, [outputs], updates, givens=givens)
        return lambda *inputs: f(*inputs)[0]

class _Function(object):
    def __init__(self, inputs, outputs, updates, givens, check_nan=False):
        assert all(len(i.op.inputs) == 0 for i in inputs), "inputs should all be placeholders"
        self.inputs = inputs
        updates = updates or []
        self.update_group = tf.group(*updates)
        self.outputs_update = list(outputs) + [self.update_group]
        self.givens = {} if givens is None else givens
        self.check_nan = check_nan

    def __call__(self, *inputvals):
        assert len(inputvals) == len(self.inputs)
        feed_dict = dict(zip(self.inputs, inputvals))
        feed_dict.update(self.givens)
        results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
        if self.check_nan:
            if any(np.isnan(r).any() for r in results):
                raise RuntimeError("Nan detected")
        return results

def mem_friendly_function(nondata_inputs, data_inputs, outputs, batch_size):
    if isinstance(outputs, list):
        return _MemFriendlyFunction(nondata_inputs, data_inputs, outputs, batch_size)
    else:
        f = _MemFriendlyFunction(nondata_inputs, data_inputs, [outputs], batch_size)
        return lambda *inputs: f(*inputs)[0]

class _MemFriendlyFunction(object):
    def __init__(self, nondata_inputs, data_inputs, outputs, batch_size):
        self.nondata_inputs = nondata_inputs
        self.data_inputs = data_inputs
        self.outputs = list(outputs)
        self.batch_size = batch_size

    def __call__(self, *inputvals):
        assert len(inputvals) == len(self.nondata_inputs) + len(self.data_inputs)
        nondata_vals = inputvals[0:len(self.nondata_inputs)]
        data_vals = inputvals[len(self.nondata_inputs):]
        feed_dict = dict(zip(self.nondata_inputs, nondata_vals))
        n = data_vals[0].shape[0]
        for v in data_vals[1:]:
            assert v.shape[0] == n
        for i_start in range(0, n, self.batch_size):
            slice_vals = [v[i_start:min(i_start + self.batch_size, n)] for v in data_vals]
            for (var, val) in zip(self.data_inputs, slice_vals):
                feed_dict[var] = val
            results = tf.get_default_session().run(self.outputs, feed_dict=feed_dict)
            if i_start == 0:
                sum_results = results
            else:
                for i in range(len(results)):
                    sum_results[i] = sum_results[i] + results[i]
        for i in range(len(results)):
            sum_results[i] = sum_results[i] / n
        return sum_results
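# ----------------------------------------
# Minimal sketch of `function`: placeholders plus an output expression become
# a plain Python callable (theano.function-style). The tensors and values
# below are illustrative only.
def _example_function_usage():
    x = tf.placeholder(tf.float32, [None])
    y = tf.placeholder(tf.float32, [None])
    dot = function([x, y], sum(x * y))
    with single_threaded_session().as_default():
        initialize()
        return dot(np.array([1., 2.]), np.array([3., 4.]))  # -> 11.0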
# ================================================================
# Modules
# ================================================================

class Module(object):
    def __init__(self, name):
        self.name = name
        self.first_time = True
        self.scope = None
        self.cache = {}

    def __call__(self, *args):
        if args in self.cache:
            print("(%s) retrieving value from cache" % self.name)
            return self.cache[args]
        with tf.variable_scope(self.name, reuse=not self.first_time):
            scope = tf.get_variable_scope().name
            if self.first_time:
                self.scope = scope
                print("(%s) running function for the first time" % self.name)
            else:
                assert self.scope == scope, "Tried calling function with a different scope"
                print("(%s) running function on new inputs" % self.name)
            self.first_time = False
            out = self._call(*args)
        self.cache[args] = out
        return out

    def _call(self, *args):
        raise NotImplementedError

    @property
    def trainable_variables(self):
        assert self.scope is not None, "need to call module once before getting variables"
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)

    @property
    def variables(self):
        assert self.scope is not None, "need to call module once before getting variables"
        return tf.get_collection(tf.GraphKeys.VARIABLES, self.scope)


def module(name):
    def wrapper(f):
        class WrapperModule(Module):
            def _call(self, *args):
                return f(*args)
        return WrapperModule(name)
    return wrapper

# ================================================================
# Graph traversal
# ================================================================

VARIABLES = {}

def get_parents(node):
    return node.op.inputs

def topsorted(outputs):
    """
    Topological sort via non-recursive depth-first search
    """
    assert isinstance(outputs, (list, tuple))
    marks = {}
    out = []
    stack = []  # pylint: disable=W0621
    # i: node
    # jidx = number of children visited so far from that node
    # marks: state of each node, which is one of
    #   0: haven't visited
    #   1: have visited, but not done visiting children
    #   2: done visiting children
    for x in outputs:
        stack.append((x, 0))
    while stack:
        (i, jidx) = stack.pop()
        if jidx == 0:
            m = marks.get(i, 0)
            if m == 0:
                marks[i] = 1
            elif m == 1:
                raise ValueError("not a dag")
            else:
                continue
        ps = get_parents(i)
        if jidx == len(ps):
            marks[i] = 2
            out.append(i)
        else:
            stack.append((i, jidx + 1))
            j = ps[jidx]
            stack.append((j, 0))
    return out

# ================================================================
# Flat vectors
# ================================================================

def var_shape(x):
    out = [k.value for k in x.get_shape()]
    assert all(isinstance(a, int) for a in out), \
        "shape function assumes that shape is fully known"
    return out

def numel(x):
    return intprod(var_shape(x))

def intprod(x):
    return int(np.prod(x))

def flatgrad(loss, var_list):
    grads = tf.gradients(loss, var_list)
    return tf.concat(0, [tf.reshape(grad, [numel(v)]) for (v, grad) in zip(var_list, grads)])

class SetFromFlat(object):
    def __init__(self, var_list, dtype=tf.float32):
        shapes = list(map(var_shape, var_list))
        total_size = np.sum([intprod(shape) for shape in shapes])

        self.theta = theta = tf.placeholder(dtype, [total_size])
        start = 0
        assigns = []
        for (shape, v) in zip(shapes, var_list):
            size = intprod(shape)
            assigns.append(tf.assign(v, tf.reshape(theta[start:start + size], shape)))
            start += size
        self.op = tf.group(*assigns)

    def __call__(self, theta):
        get_session().run(self.op, feed_dict={self.theta: theta})

class GetFlat(object):
    def __init__(self, var_list):
        self.op = tf.concat(0, [tf.reshape(v, [numel(v)]) for v in var_list])

    def __call__(self):
        return get_session().run(self.op)

# ================================================================
# Misc
# ================================================================

def fancy_slice_2d(X, inds0, inds1):
    """
    like numpy X[inds0, inds1]
    XXX this implementation is bad
    """
    inds0 = tf.cast(inds0, tf.int64)
    inds1 = tf.cast(inds1, tf.int64)
    shape = tf.cast(tf.shape(X), tf.int64)
    ncols = shape[1]
    Xflat = tf.reshape(X, [-1])
    return tf.gather(Xflat, inds0 * ncols + inds1)

def scope_vars(scope, trainable_only):
    """
    Get variables inside a scope
    The scope can be specified as a string
    """
    return tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only else tf.GraphKeys.VARIABLES,
        scope=scope if isinstance(scope, str) else scope.name
    )
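# ----------------------------------------
# Minimal sketch of the flat-vector helpers: read all parameters as one flat
# numpy vector, then write a modified vector back. Assumes a default session
# in which the variables in `var_list` have already been initialized.
def _example_flat_roundtrip(var_list):
    get_flat = GetFlat(var_list)
    set_from_flat = SetFromFlat(var_list)
    theta = get_flat()                    # concatenated parameter vector
    set_from_flat(np.zeros_like(theta))   # overwrite every parameter with zeros
    return theta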
def lengths_to_mask(lengths_b, max_length):
    """
    Turns a vector of lengths into a boolean mask

    Args:
        lengths_b: an integer vector of lengths
        max_length: maximum length to fill the mask

    Returns:
        a boolean array of shape (batch_size, max_length)
        row[i] consists of True repeated lengths_b[i] times, followed by False
    """
    lengths_b = tf.convert_to_tensor(lengths_b)
    assert lengths_b.get_shape().ndims == 1
    mask_bt = tf.expand_dims(tf.range(max_length), 0) < tf.expand_dims(lengths_b, 1)
    return mask_bt

def in_session(f):
    @functools.wraps(f)
    def newfunc(*args, **kwargs):
        with tf.Session():
            f(*args, **kwargs)
    return newfunc

_PLACEHOLDER_CACHE = {}  # name -> (placeholder, dtype, shape)

def get_placeholder(name, dtype, shape):
    print("calling get_placeholder", name)
    if name in _PLACEHOLDER_CACHE:
        out, dtype1, shape1 = _PLACEHOLDER_CACHE[name]
        assert dtype1 == dtype and shape1 == shape
        return out
    else:
        out = tf.placeholder(dtype=dtype, shape=shape, name=name)
        _PLACEHOLDER_CACHE[name] = (out, dtype, shape)
        return out

def get_placeholder_cached(name):
    return _PLACEHOLDER_CACHE[name][0]

def flattenallbut0(x):
    return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])

def reset():
    global _PLACEHOLDER_CACHE
    global VARIABLES
    _PLACEHOLDER_CACHE = {}
    VARIABLES = {}
    tf.reset_default_graph()
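# ----------------------------------------
# Minimal sketch of the placeholder cache: separately built graph pieces can
# share one input tensor by requesting the same name. The "ob" name and the
# shape are arbitrary examples.
def _example_placeholder_cache():
    ob1 = get_placeholder("ob", tf.float32, [None, 4])
    ob2 = get_placeholder_cached("ob")
    assert ob1 is ob2
    return flattenallbut0(ob1)  # flatten all but the batch dimension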