# coding=utf-8 # Copyright 2018 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Autoencoders.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensor2tensor.layers import common_attention from tensor2tensor.layers import common_hparams from tensor2tensor.layers import common_layers from tensor2tensor.layers import discretization from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model import tensorflow as tf def lrelu(input_, leak=0.2, name="lrelu"): return tf.maximum(input_, leak * input_, name=name) def reverse_gradient(x): return -x + tf.stop_gradient(2 * x) @registry.register_model class AutoencoderBasic(t2t_model.T2TModel): """A basic autoencoder, try with image_mnist_rev or image_cifar10_rev.""" def __init__(self, *args, **kwargs): super(AutoencoderBasic, self).__init__(*args, **kwargs) self._cur_bottleneck_tensor = None self.is1d = None def bottleneck(self, x): with tf.variable_scope("bottleneck"): hparams = self.hparams x = tf.layers.dense(x, hparams.bottleneck_bits, name="bottleneck") if hparams.mode == tf.estimator.ModeKeys.TRAIN: noise = 2.0 * tf.random_uniform(common_layers.shape_list(x)) - 1.0 return tf.tanh(x) + noise * hparams.bottleneck_noise, 0.0 return tf.tanh(x), 0.0 def discriminator(self, x, is_training): """Discriminator architecture based on InfoGAN. Args: x: input images, shape [bs, h, w, channels] is_training: boolean, are we in train or eval model. Returns: out_logit: the output logits (before sigmoid). """ hparams = self.hparams with tf.variable_scope( "discriminator", initializer=tf.random_normal_initializer(stddev=0.02)): batch_size, height, width = common_layers.shape_list(x)[:3] # Mapping x from [bs, h, w, c] to [bs, 1] net = tf.layers.conv2d( x, 64, (4, 4), strides=(2, 2), padding="SAME", name="d_conv1") # [bs, h/2, w/2, 64] net = lrelu(net) net = tf.layers.conv2d( net, 128, (4, 4), strides=(2, 2), padding="SAME", name="d_conv2") # [bs, h/4, w/4, 128] if hparams.discriminator_batchnorm: net = tf.layers.batch_normalization( net, training=is_training, momentum=0.999, name="d_bn2") net = lrelu(net) size = height * width net = tf.reshape(net, [batch_size, size * 8]) # [bs, h * w * 8] net = tf.layers.dense(net, 1024, name="d_fc3") # [bs, 1024] if hparams.discriminator_batchnorm: net = tf.layers.batch_normalization( net, training=is_training, momentum=0.999, name="d_bn3") net = lrelu(net) return net def unbottleneck(self, x, res_size, reuse=None): with tf.variable_scope("unbottleneck", reuse=reuse): x = tf.layers.dense(x, res_size, name="dense") return x def make_even_size(self, x): if not self.is1d: return common_layers.make_even_size(x) shape1 = x.get_shape().as_list()[1] if shape1 is not None and shape1 % 2 == 0: return x x, _ = common_layers.pad_to_same_length( x, x, final_length_divisible_by=2, axis=1) return x def encoder(self, x): with tf.variable_scope("encoder"): hparams = self.hparams kernel, strides = self._get_kernel_and_strides() # Down-convolutions. for i in range(hparams.num_hidden_layers): x = self.make_even_size(x) x = tf.layers.conv2d( x, hparams.hidden_size * 2**(i + 1), kernel, strides=strides, padding="SAME", activation=common_layers.belu, name="conv_%d" % i) x = common_layers.layer_norm(x) return x def decoder(self, x): with tf.variable_scope("decoder"): hparams = self.hparams kernel, strides = self._get_kernel_and_strides() # Up-convolutions. for i in range(hparams.num_hidden_layers): j = hparams.num_hidden_layers - i - 1 x = tf.layers.conv2d_transpose( x, hparams.hidden_size * 2**j, kernel, strides=strides, padding="SAME", activation=common_layers.belu, name="deconv_%d" % j) x = common_layers.layer_norm(x) return x def body(self, features): hparams = self.hparams is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN if hparams.mode != tf.estimator.ModeKeys.PREDICT: x = features["targets"] shape = common_layers.shape_list(x) is1d = shape[2] == 1 self.is1d = is1d # Run encoder. x = self.encoder(x) # Bottleneck (mix during early training, not too important but stable). b, b_loss = self.bottleneck(x) self._cur_bottleneck_tensor = b b = self.unbottleneck(b, common_layers.shape_list(x)[-1]) b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training) if hparams.gan_loss_factor != 0.0: # Add a purely sampled batch on which we'll compute the GAN loss. g = self.unbottleneck( self.sample(), common_layers.shape_list(x)[-1], reuse=True) b = tf.concat([g, b], axis=0) # With probability bottleneck_max_prob use the bottleneck, otherwise x. if hparams.bottleneck_max_prob < -1.0: x = tf.where( tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x) else: x = b else: if self._cur_bottleneck_tensor is None: b = self.sample() else: b = self._cur_bottleneck_tensor res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers res_size = min(res_size, hparams.max_hidden_size) x = self.unbottleneck(b, res_size) # Run decoder. x = self.decoder(x) if hparams.mode == tf.estimator.ModeKeys.PREDICT: return x, {"bottleneck_loss": 0.0} # Cut to the right size and mix before returning. res = x[:, :shape[1], :shape[2], :] # Add GAN loss if requested. gan_loss = 0.0 if hparams.gan_loss_factor != 0.0: # Split back if we added a purely sampled batch. res_gan, res = tf.split(res, 2, axis=0) num_channels = self.hparams.problem.num_channels res_rgb = common_layers.convert_real_to_rgb( tf.nn.sigmoid(tf.layers.dense(res_gan, num_channels, name="gan_rgb"))) tf.summary.image( "gan", common_layers.tpu_safe_image_summary(res_rgb), max_outputs=1) orig_rgb = tf.to_float(features["targets_raw"]) def discriminate(x): return self.discriminator(x, is_training=is_training) gan_loss = common_layers.sliced_gan_loss(orig_rgb, reverse_gradient(res_rgb), discriminate, self.hparams.num_sliced_vecs) gan_loss *= hparams.gan_loss_factor # Mix the final result and return. res = common_layers.mix(res, features["targets"], hparams.bottleneck_warmup_steps // 2, is_training) return res, {"bottleneck_loss": b_loss, "gan_loss": -gan_loss} def sample(self, features=None, shape=None): del features, shape hp = self.hparams div_x = 2**hp.num_hidden_layers div_y = 1 if self.is1d else 2**hp.num_hidden_layers size = [ hp.batch_size, hp.sample_height // div_x, hp.sample_width // div_y, hp.bottleneck_bits ] # Sample in [-1, 1] as the bottleneck is under tanh. return 2.0 * tf.random_uniform(size) - 1.0 def encode(self, x): """Auto-encode x and return the bottleneck.""" features = {"targets": x} self(features) # pylint: disable=not-callable res = tf.maximum(0.0, self._cur_bottleneck_tensor) # Be 0/1 and not -1/1. self._cur_bottleneck_tensor = None return res def infer(self, features, *args, **kwargs): # pylint: disable=arguments-differ """Produce predictions from the model by sampling.""" del args, kwargs # Inputs and features preparation needed to handle edge cases. if not features: features = {} inputs_old = None if "inputs" in features and len(features["inputs"].shape) < 4: inputs_old = features["inputs"] features["inputs"] = tf.expand_dims(features["inputs"], 2) # Sample and decode. # TODO(lukaszkaiser): is this a universal enough way to get channels? try: num_channels = self.hparams.problem.num_channels except AttributeError: num_channels = 1 if "targets" not in features: features["targets"] = tf.zeros( [self.hparams.batch_size, 1, 1, num_channels], dtype=tf.int32) logits, _ = self(features) # pylint: disable=not-callable samples = tf.argmax(logits, axis=-1) # Restore inputs to not confuse Estimator in edge cases. if inputs_old is not None: features["inputs"] = inputs_old # Return samples. return samples def decode(self, bottleneck): """Auto-decode from the bottleneck and return the result.""" # Get the shape from bottleneck and num channels. shape = common_layers.shape_list(bottleneck) try: num_channels = self.hparams.problem.num_channels except AttributeError: num_channels = 1 dummy_targets = tf.zeros(shape[:-1] + [num_channels]) # Set the bottleneck to decode. if len(shape) > 4: bottleneck = tf.squeeze(bottleneck, axis=[1]) bottleneck = 2 * bottleneck - 1 # Be -1/1 instead of 0/1. self._cur_bottleneck_tensor = bottleneck # Run decoding. res = self.infer({"targets": dummy_targets}) self._cur_bottleneck_tensor = None return res def _get_kernel_and_strides(self): hparams = self.hparams kernel = (hparams.kernel_height, hparams.kernel_width) kernel = (hparams.kernel_height, 1) if self.is1d else kernel strides = (2, 1) if self.is1d else (2, 2) return (kernel, strides) @registry.register_model class AutoencoderAutoregressive(AutoencoderBasic): """Autoencoder with an autoregressive part.""" def body(self, features): hparams = self.hparams is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN # Run the basic autoencoder part first. basic_result, losses = super(AutoencoderAutoregressive, self).body(features) if hparams.autoregressive_mode == "none": assert not hparams.autoregressive_forget_base return basic_result, losses shape = common_layers.shape_list(basic_result) basic1d = tf.reshape(basic_result, [shape[0], -1, shape[3]]) # During autoregressive inference, don't resample. if hparams.mode == tf.estimator.ModeKeys.PREDICT: if hasattr(hparams, "sampled_basic1d_tensor"): basic1d = hparams.sampled_basic1d_tensor else: hparams.sampled_basic1d_tensor = basic1d # Prepare inputs for autoregressive modes. if common_layers.shape_list(features["targets"])[1] == 1: # This happens on the first step of predicitions. assert hparams.mode == tf.estimator.ModeKeys.PREDICT features["targets"] = tf.zeros_like(basic_result) targets_dropout = common_layers.mix( features["targets"], tf.zeros_like(basic_result), hparams.bottleneck_warmup_steps, is_training, max_prob=1.0 - hparams.autoregressive_dropout, broadcast_last=True) # Sometimes it's useful to look at non-autoregressive evals. if (hparams.mode == tf.estimator.ModeKeys.EVAL and hparams.autoregressive_eval_pure_autoencoder): targets_dropout = tf.zeros_like(basic_result) # Now combine the basic reconstruction with shifted targets. targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[3]]) targets_shifted = common_layers.shift_right_3d(targets1d) concat1d = tf.concat([basic1d, targets_shifted], axis=-1) # The forget_base hparam sets purely-autoregressive mode, no autoencoder. if hparams.autoregressive_forget_base: concat1d = tf.reshape(features["targets"], [shape[0], -1, shape[3]]) concat1d = common_layers.shift_right_3d(concat1d) # The autoregressive part depends on the mode. if hparams.autoregressive_mode == "conv3": res = common_layers.conv1d( concat1d, shape[3], 3, padding="LEFT", activation=common_layers.belu, name="autoregressive_conv3") return tf.reshape(res, shape), losses if hparams.autoregressive_mode == "conv5": res = common_layers.conv1d( concat1d, shape[3], 5, padding="LEFT", activation=common_layers.belu, name="autoregressive_conv5") return tf.reshape(res, shape), losses if hparams.autoregressive_mode == "sru": res = common_layers.conv1d( concat1d, shape[3], 3, padding="LEFT", activation=common_layers.belu, name="autoregressive_sru_conv3") res = common_layers.sru(res) return tf.reshape(res, shape), losses raise ValueError( "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode) def infer(self, features, *args, **kwargs): """Produce predictions from the model by sampling.""" # Inputs and features preparation needed to handle edge cases. if not features: features = {} inputs_old = None if "inputs" in features and len(features["inputs"].shape) < 4: inputs_old = features["inputs"] features["inputs"] = tf.expand_dims(features["inputs"], 2) # Sample first. try: num_channels = self.hparams.problem.num_channels except AttributeError: num_channels = 1 if "targets" not in features: features["targets"] = tf.zeros( [self.hparams.batch_size, 1, 1, num_channels], dtype=tf.int32) logits, _ = self(features) # pylint: disable=not-callable samples = common_layers.sample_with_temperature(logits, 0.0) shape = common_layers.shape_list(samples) # Sample again if requested for the autoregressive part. extra_samples = self.hparams.autoregressive_decode_steps self.hparams.autoregressive_dropout = 0.2 for i in range(extra_samples): if i == extra_samples - 2: self.hparams.autoregressive_dropout -= 0.1 self.hparams.sampling_temp /= 2 if i == extra_samples - 1: self.hparams.autoregressive_dropout -= 0.1 self.hparams.sampling_temp = 0.0 features["targets"] = samples old_samples1d = tf.reshape(samples, [shape[0], -1, shape[3]]) with tf.variable_scope(tf.get_variable_scope(), reuse=True): logits, _ = self(features) # pylint: disable=not-callable samples = common_layers.sample_with_temperature( logits, self.hparams.sampling_temp) samples1d = tf.reshape(samples, [shape[0], -1, shape[3]]) samples1d = tf.concat( [old_samples1d[:, :i, :], samples1d[:, i:, :]], axis=1) samples = tf.reshape(samples1d, shape) # Restore inputs to not confuse Estimator in edge cases. if inputs_old is not None: features["inputs"] = inputs_old # Return samples. return samples @registry.register_model class AutoencoderResidual(AutoencoderAutoregressive): """Residual autoencoder.""" def dropout(self, x): if self.hparams.dropout <= 0.0: return x # For simple dropout just do this: # return tf.nn.dropout(x, 1.0 - self.hparams.dropout) is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN return common_layers.mix( tf.zeros_like(x), x, self.hparams.bottleneck_warmup_steps, is_training, max_prob=self.hparams.dropout, broadcast_last=True) def encoder(self, x): with tf.variable_scope("encoder"): hparams = self.hparams kernel, strides = self._get_kernel_and_strides() residual_kernel = (hparams.residual_kernel_height, hparams.residual_kernel_width) residual_kernel1d = (hparams.residual_kernel_height, 1) residual_kernel = residual_kernel1d if self.is1d else residual_kernel residual_conv = tf.layers.conv2d if hparams.residual_use_separable_conv: residual_conv = tf.layers.separable_conv2d # Input embedding with a non-zero bias for uniform inputs. x = tf.layers.dense( x, hparams.hidden_size, name="embed", activation=common_layers.belu, bias_initializer=tf.random_normal_initializer(stddev=0.01)) x = common_attention.add_timing_signal_nd(x) # Down-convolutions. for i in range(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % i): x = self.make_even_size(x) x = self.dropout(x) filters = hparams.hidden_size * 2**(i + 1) filters = min(filters, hparams.max_hidden_size) x = tf.layers.conv2d( x, filters, kernel, strides=strides, padding="SAME", activation=common_layers.belu, name="strided") y = x for r in range(hparams.num_residual_layers): residual_filters = filters if r < hparams.num_residual_layers - 1: residual_filters = int( filters * hparams.residual_filter_multiplier) y = residual_conv( y, residual_filters, residual_kernel, padding="SAME", activation=common_layers.belu, name="residual_%d" % r) x += tf.nn.dropout(y, 1.0 - hparams.residual_dropout) x = common_layers.layer_norm(x) return x def decoder(self, x): with tf.variable_scope("decoder"): hparams = self.hparams kernel, strides = self._get_kernel_and_strides() residual_kernel = (hparams.residual_kernel_height, hparams.residual_kernel_width) residual_kernel1d = (hparams.residual_kernel_height, 1) residual_kernel = residual_kernel1d if self.is1d else residual_kernel residual_conv = tf.layers.conv2d if hparams.residual_use_separable_conv: residual_conv = tf.layers.separable_conv2d # Up-convolutions. for i in range(hparams.num_hidden_layers): j = hparams.num_hidden_layers - i - 1 filters = hparams.hidden_size * 2**j filters = min(filters, hparams.max_hidden_size) with tf.variable_scope("layer_%d" % i): j = hparams.num_hidden_layers - i - 1 filters = hparams.hidden_size * 2**j x = tf.layers.conv2d_transpose( x, filters, kernel, strides=strides, padding="SAME", activation=common_layers.belu, name="strided") y = x for r in range(hparams.num_residual_layers): residual_filters = filters if r < hparams.num_residual_layers - 1: residual_filters = int( filters * hparams.residual_filter_multiplier) y = residual_conv( y, residual_filters, residual_kernel, padding="SAME", activation=common_layers.belu, name="residual_%d" % r) x += tf.nn.dropout(y, 1.0 - hparams.residual_dropout) x = common_layers.layer_norm(x) x = common_attention.add_timing_signal_nd(x) return x @registry.register_model class AutoencoderBasicDiscrete(AutoencoderAutoregressive): """Discrete autoencoder.""" def bottleneck(self, x): hparams = self.hparams x = tf.tanh(tf.layers.dense(x, hparams.bottleneck_bits, name="bottleneck")) d = x + tf.stop_gradient(2.0 * tf.to_float(tf.less(0.0, x)) - 1.0 - x) if hparams.mode == tf.estimator.ModeKeys.TRAIN: noise = tf.random_uniform(common_layers.shape_list(x)) noise = 2.0 * tf.to_float(tf.less(hparams.bottleneck_noise, noise)) - 1.0 d *= noise x = common_layers.mix(d, x, hparams.discretize_warmup_steps, hparams.mode == tf.estimator.ModeKeys.TRAIN) return x, 0.0 def sample(self, features=None): del features hp = self.hparams div_x = 2**hp.num_hidden_layers div_y = 1 if self.is1d else 2**hp.num_hidden_layers size = [ hp.batch_size, hp.sample_height // div_x, hp.sample_width // div_y, hp.bottleneck_bits ] rand = tf.random_uniform(size) return 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0 @registry.register_model class AutoencoderResidualDiscrete(AutoencoderResidual): """Discrete residual autoencoder.""" def variance_loss(self, b): part = tf.random_uniform(common_layers.shape_list(b)) selection = tf.to_float(tf.less(part, tf.random_uniform([]))) selection_size = tf.reduce_sum(selection) part_avg = tf.abs(tf.reduce_sum(b * selection)) / (selection_size + 1) return part_avg def bottleneck(self, x, bottleneck_bits=None): # pylint: disable=arguments-differ if bottleneck_bits is not None: old_bottleneck_bits = self.hparams.bottleneck_bits self.hparams.bottleneck_bits = bottleneck_bits res, loss = discretization.parametrized_bottleneck(x, self.hparams) if bottleneck_bits is not None: self.hparams.bottleneck_bits = old_bottleneck_bits return res, loss def unbottleneck(self, x, res_size, reuse=None): with tf.variable_scope("unbottleneck", reuse=reuse): return discretization.parametrized_unbottleneck(x, res_size, self.hparams) def sample(self, features=None): del features hp = self.hparams div_x = 2**hp.num_hidden_layers div_y = 1 if self.is1d else 2**hp.num_hidden_layers size = [ hp.batch_size, hp.sample_height // div_x, hp.sample_width // div_y, hp.bottleneck_bits ] rand = tf.random_uniform(size) res = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0 # If you want to set some first bits to a fixed value, do this: # fixed = tf.zeros_like(rand) - 1.0 # nbits = 3 # res = tf.concat([fixed[:, :, :, :nbits], res[:, :, :, nbits:]], axis=-1) return res @registry.register_model class AutoencoderOrderedDiscrete(AutoencoderResidualDiscrete): """Ordered discrete autoencoder.""" def bottleneck(self, x): # pylint: disable=arguments-differ hparams = self.hparams if hparams.unordered: return super(AutoencoderOrderedDiscrete, self).bottleneck(x) noise = hparams.bottleneck_noise hparams.bottleneck_noise = 0.0 # We'll add noise below. x, loss = discretization.parametrized_bottleneck(x, hparams) hparams.bottleneck_noise = noise if hparams.mode == tf.estimator.ModeKeys.TRAIN: # We want a number p such that p^bottleneck_bits = 1 - noise. # So log(p) * bottleneck_bits = log(noise) log_p = tf.log(1 - float(noise) / 2) / float(hparams.bottleneck_bits) # Probabilities of flipping are p, p^2, p^3, ..., p^bottleneck_bits. noise_mask = 1.0 - tf.exp(tf.cumsum(tf.zeros_like(x) + log_p, axis=-1)) # Having the no-noise mask, we can make noise just uniformly at random. ordered_noise = tf.random_uniform(tf.shape(x)) # We want our noise to be 1s at the start and random {-1, 1} bits later. ordered_noise = tf.to_float(tf.less(noise_mask, ordered_noise)) # Now we flip the bits of x on the noisy positions (ordered and normal). x *= 2.0 * ordered_noise - 1 return x, loss @registry.register_model class AutoencoderStacked(AutoencoderResidualDiscrete): """A stacked autoencoder.""" def stack(self, b, size, bottleneck_bits, name): with tf.variable_scope(name + "_stack"): unb = self.unbottleneck(b, size) enc = self.encoder(unb) b, _ = self.bottleneck(enc, bottleneck_bits=bottleneck_bits) return b def unstack(self, b, size, bottleneck_bits, name): with tf.variable_scope(name + "_unstack"): unb = self.unbottleneck(b, size) dec = self.decoder(unb) pred = tf.layers.dense(dec, bottleneck_bits, name="pred") pred_shape = common_layers.shape_list(pred) pred1 = tf.reshape(pred, pred_shape[:-1] + [-1, 2]) x, y = tf.split(pred1, 2, axis=-1) x = tf.squeeze(x, axis=[-1]) y = tf.squeeze(y, axis=[-1]) gt = 2.0 * tf.to_float(tf.less(x, y)) - 1.0 gtc = tf.tanh(y - x) gt += gtc - tf.stop_gradient(gtc) return gt, pred1 def stack_loss(self, b, b_pred, name): with tf.variable_scope(name): labels_discrete = tf.to_int32((b + 1.0) * 0.5) loss = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels_discrete, logits=b_pred) return tf.reduce_mean(loss) def full_stack(self, b, x_size, bottleneck_bits, losses, is_training, i): stack1_b = self.stack(b, x_size, bottleneck_bits, "step%d" % i) if i > 1: stack1_b = self.full_stack(stack1_b, 2 * x_size, 2 * bottleneck_bits, losses, is_training, i - 1) b1, b_pred = self.unstack(stack1_b, x_size, bottleneck_bits, "step%d" % i) losses["stack%d_loss" % i] = self.stack_loss(b, b_pred, "step%d" % i) b_shape = common_layers.shape_list(b) if is_training: condition = tf.less(tf.random_uniform([]), 0.5) condition = tf.reshape(condition, [1] * len(b.shape)) condition = tf.tile(condition, b.shape) b1 = tf.where(condition, b, b1) return tf.reshape(b1, b_shape) def body(self, features): hparams = self.hparams num_stacks = hparams.num_hidden_layers hparams.num_hidden_layers = 1 is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN if hparams.mode != tf.estimator.ModeKeys.PREDICT: x = features["targets"] shape = common_layers.shape_list(x) is1d = shape[2] == 1 self.is1d = is1d x, _ = common_layers.pad_to_same_length( x, x, final_length_divisible_by=2**num_stacks, axis=1) if not is1d: x, _ = common_layers.pad_to_same_length( x, x, final_length_divisible_by=2**num_stacks, axis=2) # Run encoder. x = self.encoder(x) x_size = common_layers.shape_list(x)[-1] # Bottleneck (mix during early training, not too important but stable). b, b_loss = self.bottleneck(x) losses = {"bottleneck0_loss": b_loss} b = self.full_stack(b, 2 * x_size, 2 * hparams.bottleneck_bits, losses, is_training, num_stacks - 1) b = self.unbottleneck(b, x_size) b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training) # With probability bottleneck_max_prob use the bottleneck, otherwise x. if hparams.bottleneck_max_prob < 1.0: x = tf.where( tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x) else: x = b else: b = self.sample() res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers res_size = min(res_size, hparams.max_hidden_size) x = self.unbottleneck(b, res_size) # Run decoder. x = self.decoder(x) if hparams.mode == tf.estimator.ModeKeys.PREDICT: return x # Cut to the right size and mix before returning. res = x[:, :shape[1], :shape[2], :] res = common_layers.mix(res, features["targets"], hparams.bottleneck_warmup_steps // 2, is_training) hparams.num_hidden_layers = num_stacks return res, losses @registry.register_hparams def autoencoder_basic(): """Basic autoencoder model.""" hparams = common_hparams.basic_params1() hparams.optimizer = "Adam" hparams.learning_rate_constant = 0.0002 hparams.learning_rate_warmup_steps = 500 hparams.learning_rate_schedule = "constant * linear_warmup" hparams.label_smoothing = 0.0 hparams.batch_size = 128 hparams.hidden_size = 64 hparams.num_hidden_layers = 5 hparams.initializer = "uniform_unit_scaling" hparams.initializer_gain = 1.0 hparams.weight_decay = 0.0 hparams.kernel_height = 4 hparams.kernel_width = 4 hparams.dropout = 0.1 hparams.add_hparam("max_hidden_size", 1024) hparams.add_hparam("bottleneck_bits", 128) hparams.add_hparam("bottleneck_noise", 0.1) hparams.add_hparam("bottleneck_warmup_steps", 3000) hparams.add_hparam("bottleneck_max_prob", 1.0) hparams.add_hparam("sample_height", 32) hparams.add_hparam("sample_width", 32) hparams.add_hparam("discriminator_batchnorm", True) hparams.add_hparam("num_sliced_vecs", 4096) hparams.add_hparam("gan_loss_factor", 0.0) return hparams @registry.register_hparams def autoencoder_autoregressive(): """Autoregressive autoencoder model.""" hparams = autoencoder_basic() hparams.add_hparam("autoregressive_forget_base", False) hparams.add_hparam("autoregressive_mode", "none") hparams.add_hparam("autoregressive_dropout", 0.4) hparams.add_hparam("autoregressive_decode_steps", 0) hparams.add_hparam("autoregressive_eval_pure_autoencoder", False) return hparams @registry.register_hparams def autoencoder_residual(): """Residual autoencoder model.""" hparams = autoencoder_autoregressive() hparams.optimizer = "Adafactor" hparams.clip_grad_norm = 1.0 hparams.learning_rate_constant = 0.5 hparams.learning_rate_warmup_steps = 500 hparams.learning_rate_schedule = "constant * linear_warmup * rsqrt_decay" hparams.dropout = 0.05 hparams.num_hidden_layers = 5 hparams.hidden_size = 64 hparams.max_hidden_size = 1024 hparams.add_hparam("num_residual_layers", 2) hparams.add_hparam("residual_kernel_height", 3) hparams.add_hparam("residual_kernel_width", 3) hparams.add_hparam("residual_filter_multiplier", 2.0) hparams.add_hparam("residual_dropout", 0.2) hparams.add_hparam("residual_use_separable_conv", int(True)) return hparams @registry.register_hparams def autoencoder_basic_discrete(): """Basic autoencoder model.""" hparams = autoencoder_autoregressive() hparams.num_hidden_layers = 5 hparams.hidden_size = 64 hparams.bottleneck_bits = 4096 hparams.bottleneck_noise = 0.1 hparams.bottleneck_warmup_steps = 3000 hparams.add_hparam("discretize_warmup_steps", 5000) return hparams @registry.register_hparams def autoencoder_residual_discrete(): """Residual discrete autoencoder model.""" hparams = autoencoder_residual() hparams.bottleneck_bits = 4096 hparams.bottleneck_noise = 0.1 hparams.bottleneck_warmup_steps = 3000 hparams.add_hparam("discretize_warmup_steps", 5000) hparams.add_hparam("bottleneck_kind", "tanh_discrete") hparams.add_hparam("isemhash_noise_dev", 0.5) hparams.add_hparam("isemhash_mix_prob", 0.5) hparams.add_hparam("isemhash_filter_size_multiplier", 2.0) hparams.add_hparam("vq_beta", 0.25) hparams.add_hparam("vq_decay", 0.999) hparams.add_hparam("vq_epsilon", 1e-5) return hparams @registry.register_hparams def autoencoder_residual_discrete_big(): """Residual discrete autoencoder model, big version.""" hparams = autoencoder_residual_discrete() hparams.hidden_size = 128 hparams.max_hidden_size = 4096 hparams.bottleneck_noise = 0.1 hparams.dropout = 0.1 hparams.residual_dropout = 0.4 return hparams @registry.register_hparams def autoencoder_ordered_discrete(): """Ordered discrete autoencoder model.""" hparams = autoencoder_residual_discrete() hparams.bottleneck_noise = 1.0 hparams.gan_loss_factor = 0.0 hparams.dropout = 0.1 hparams.residual_dropout = 0.3 hparams.add_hparam("unordered", False) return hparams @registry.register_hparams def autoencoder_ordered_text(): """Ordered discrete autoencoder model for text.""" hparams = autoencoder_ordered_discrete() hparams.learning_rate_constant = 2.0 hparams.learning_rate_warmup_steps = 2000 hparams.bottleneck_bits = 1024 hparams.batch_size = 2048 hparams.autoregressive_mode = "sru" hparams.hidden_size = 256 hparams.max_hidden_size = 4096 hparams.bottleneck_warmup_steps = 10000 hparams.discretize_warmup_steps = 15000 return hparams @registry.register_hparams def autoencoder_ordered_discrete_vq(): """Ordered discrete autoencoder model with VQ bottleneck.""" hparams = autoencoder_ordered_discrete() hparams.bottleneck_kind = "vq" hparams.bottleneck_bits = 16 return hparams @registry.register_hparams def autoencoder_discrete_pong(): """Discrete autoencoder model for compressing pong frames.""" hparams = autoencoder_ordered_discrete() hparams.num_hidden_layers = 2 hparams.bottleneck_bits = 24 hparams.dropout = 0.1 hparams.batch_size = 2 hparams.bottleneck_noise = 0.2 hparams.max_hidden_size = 1024 hparams.unordered = True return hparams @registry.register_hparams def autoencoder_discrete_cifar(): """Discrete autoencoder model for compressing cifar.""" hparams = autoencoder_ordered_discrete() hparams.bottleneck_noise = 0.0 hparams.bottleneck_bits = 90 hparams.unordered = True hparams.num_hidden_layers = 2 hparams.hidden_size = 256 hparams.num_residual_layers = 4 hparams.batch_size = 32 hparams.learning_rate_constant = 1.0 hparams.dropout = 0.1 return hparams @registry.register_ranged_hparams def autoencoder_discrete_pong_range(rhp): """Narrow tuning grid.""" rhp.set_float("dropout", 0.0, 0.2) rhp.set_discrete("max_hidden_size", [1024, 2048]) @registry.register_hparams def autoencoder_stacked(): """Stacked autoencoder model.""" hparams = autoencoder_residual_discrete() hparams.bottleneck_bits = 128 return hparams