# Copyright 2017 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for creating Sparsely-Gated Mixture-of-Experts Layers. See the most recent draft of our ICLR paper: https://openreview.net/pdf?id=B1ckMDqlg """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import math # Dependency imports import six from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import zip # pylint: disable=redefined-builtin import tensorflow as tf from tensorflow.python.framework import function def NoisyTopKGatingParams(): """Hyperparams defining NoisyTopK Gating Network. Returns: a tf.contrib.training.HParams object """ return tf.contrib.training.HParams( gating_class=NoisyTopKGating, num_experts=16, # The number of experts k=2, # 'The number of experts to use per example input_size=None, # size of input to MoE. Set by MoE class dtype=tf.float32, # floating point data type initializer=tf.zeros_initializer(), # initializer for weight matrices noisy_gating=True, # Add tunable noise (necessary for load-balancing) noise_epsilon=1e-2, # Added to noise stddev for numerical stability ) def FeedForwardExpertParams(): """Hyperparameters defining feed-forward expert networks. Returns: a tf.contrib.training.HParams object """ return tf.contrib.training.HParams( # The class that implements the expert network expert_class=FeedForwardExpert, input_size=None, # Size of input to MoE. Set by MoE class. # List of hidden layer sizes, or None for no hidden layers. # The length of this list determines the number of hidden layers hidden_layer_sizes=None, output_size=None, # Size of output from MoE. Set by MoE class. dtype=tf.float32, # Floating point data type) # Activation function applied at each hidden layer) hidden_activation=tf.nn.relu, initializer=None, # Optional initializer for weight matrices.) # If autoscale=True, At each hidden/output layer, multiply by # rsqrt(prev_layer_size / input_size). This scaling happens # before application of hidden_activation) autoscale=True,) def _SetInputOutputSizes(hp, input_size, output_size): """Fill in the input_size and output_size hyperparameters. This is used by LocalMixtureOfExperts and DistributedMixtureOfExperts to fill in the input_size and output_size on the gating parameters and expert parameters so that the user does not have to set them in multiple places. Args: hp: a hyperparameters input_size: an integer output_size: an integer """ if hp.input_size is None: hp.input_size = input_size else: assert hp.input_size == input_size if output_size is not None: if hp.output_size is None: hp.output_size = output_size else: assert hp.output_size == output_size class FeedForwardExpert(object): """An object representing a feed forward network (used as an expert). """ def __init__(self, hp, name): """Creates a FeedForwardExpert. Args: hp: hyperparameters. Call FeedForwardExpertParams() to create these. name: a string. """ self._hp = hp hidden_layer_sizes = hp.hidden_layer_sizes or [] num_layers = 1 + len(hidden_layer_sizes) layer_sizes = [hp.input_size] + hidden_layer_sizes + [hp.output_size] self._layer_sizes = layer_sizes self._w = [] for layer in range(num_layers): shape = layer_sizes[layer:layer + 2] self._w.append( tf.get_variable('%s_layer_%d' % (name, layer), shape, hp.dtype, hp.initializer)) def Eval(self, x): """Evaluate the FeedForwardExpert on the given input. Args: x: a `Tensor` of shape `[batch_size, hp.input_size]` Returns: a `Tensor` of shape `[batch_size, hp.output_size]` """ hp = self._hp num_layers = len(self._w) for i in xrange(num_layers): x = tf.matmul(x, self._w[i]) if hp.autoscale and self._layer_sizes[i] != hp.input_size: x *= (self._layer_sizes[i] / hp.input_size)**-0.5 if i + 1 < num_layers and hp.hidden_activation: x = hp.hidden_activation(x) return x @property def vars(self): return self._w @function.Defun( python_grad_func=lambda x, dy: tf.convert_to_tensor(dy), shape_func=lambda op: [op.inputs[0].get_shape()]) def ConvertGradientToTensor(x): """Identity operation whose gradient is converted to a `Tensor`. Currently, the gradient to `tf.concat` is particularly expensive to compute if dy is an `IndexedSlices` (a lack of GPU implementation forces the gradient operation onto CPU). This situation occurs when the output of the `tf.concat` is eventually passed to `tf.gather`. It is sometimes faster to convert the gradient to a `Tensor`, so as to get the cheaper gradient for `tf.concat`. To do this, replace `tf.concat(x)` with `ConvertGradientToTensor(tf.concat(x))`. Args: x: A `Tensor`. Returns: The input `Tensor`. """ return x class Parallelism(object): """Helper class for creating sets of parallel function calls. The purpose of this class is to replace this code: e = [] f = [] for i in xrange(len(devices)): with tf.device(devices[i]): e_, f_ = func(a[i], b[i], c) e.append(e_) f.append(f_) with this code: e, f = expert_utils.Parallelism(devices)(func, a, b, c) """ def __init__(self, device_names_or_functions, reuse=None, caching_devices=None, daisy_chain_variables=False): """Create a Parallelism. Args: device_names_or_functions: A list of of length n, containing device names or device functions (see `tf.device`) reuse: True or None. Whether to reuse variables created in the first replica in the subsequent replicas. caching_devices: Either `None`, or a list of length n containing device names. daisy_chain_variables: a boolean - if true, then copies variables in a daisy chain between devices. Returns: a Parallelism. """ assert device_names_or_functions self._devices = device_names_or_functions self._n = len(device_names_or_functions) self._reuse = reuse self._caching_devices = self._MaybeRepeat(caching_devices) self._daisy_chain_variables = daisy_chain_variables def __call__(self, fn, *args, **kwargs): """A parallel set of function calls (using the specified devices). Args: fn: a function or a list of n functions. *args: additional args. Each arg should either be not a list, or a list of length n. **kwargs: additional keyword args. Each arg should either be not a list, or a list of length n. Returns: either a single list of length n (if fn does not return a tuple), or a tuple of lists of length n (if fn returns a tuple). """ # Construct lists or args and kwargs for each function. if args: my_args = TransposeListOfLists([self._MaybeRepeat(arg) for arg in args]) else: my_args = [[] for _ in xrange(self.n)] my_kwargs = [{} for _ in xrange(self.n)] for k, v in six.iteritems(kwargs): vals = self._MaybeRepeat(v) for i in xrange(self.n): my_kwargs[i][k] = vals[i] # Construct lists of functions. fns = self._MaybeRepeat(fn) # Now make the parallel call. outputs = [] cache = {} for i in xrange(self.n): def DaisyChainGetter(getter, name, *args, **kwargs): """Get a variable and cache in a daisy chain.""" device_var_key = (self._devices[i], name) if device_var_key in cache: # if we have the variable on the correct device, return it. return cache[device_var_key] if name in cache: # if we have it on a different device, copy it from the last device v = tf.identity(cache[name]) else: var = getter(name, *args, **kwargs) v = tf.identity(var._ref()) # pylint: disable=protected-access # update the cache cache[name] = v cache[device_var_key] = v return v # Variable scope will not reset caching_device on reused variables, # so we make a custom getter that uses identity to cache the variable. # pylint: disable=cell-var-from-loop def CachingGetter(getter, name, *args, **kwargs): v = getter(name, *args, **kwargs) key = (self._caching_devices[i], name) if key in cache: return cache[key] with tf.device(self._caching_devices[i]): ret = tf.identity(v._ref()) # pylint: disable=protected-access cache[key] = ret return ret if self._daisy_chain_variables: custom_getter = DaisyChainGetter elif self._caching_devices: custom_getter = CachingGetter else: custom_getter = None # pylint: enable=cell-var-from-loop with tf.name_scope('parallel_%d' % i): with tf.variable_scope( tf.get_variable_scope(), reuse=True if i > 0 and self._reuse else None, caching_device=self._caching_devices[i], custom_getter=custom_getter): with tf.device(self._devices[i]): outputs.append(fns[i](*my_args[i], **my_kwargs[i])) if isinstance(outputs[0], tuple): outputs = list(zip(*outputs)) outputs = tuple([list(o) for o in outputs]) return outputs @property def n(self): return self._n @property def devices(self): return self._devices def _MaybeRepeat(self, x): """Utility function for processing arguments that are singletons or lists. Args: x: either a list of self.n elements, or not a list. Returns: a list of self.n elements. """ if isinstance(x, list): assert len(x) == self.n return x else: return [x] * self.n def Parallel(device_names_or_functions, fn, *args): """Deprecated interface. Use `Parallelism(device_names_or_functions)(fn, *args)` instead. Args: device_names_or_functions: A list of length n. fn: a function or a list of n functions. *args: additional args. Each arg should either be not a list, or a list of length n. Returns: either a single list of length n (if fn does not return a tuple), or a tuple of lists of length n (if fn returns a tuple). """ return Parallelism(device_names_or_functions)(fn, *args) def _RowwiseUnsortedSegmentSum(values, indices, n): """UnsortedSegmentSum on each row. Args: values: a `Tensor` with shape `[batch_size, k]`. indices: an integer `Tensor` with shape `[batch_size, k]`. n: an integer. Returns: A `Tensor` with the same type as `values` and shape `[batch_size, n]`. """ batch, k = tf.unstack(tf.shape(indices), num=2) indices_flat = tf.reshape(indices, [-1]) + tf.div(tf.range(batch * k), k) * n ret_flat = tf.unsorted_segment_sum( tf.reshape(values, [-1]), indices_flat, batch * n) return tf.reshape(ret_flat, [batch, n]) def _NormalDistributionCDF(x, stddev): """Evaluates the CDF of the normal distribution. Normal distribution with mean 0 and standard deviation stddev, evaluated at x=x. input and output `Tensor`s have matching shapes. Args: x: a `Tensor` stddev: a `Tensor` with the same shape as `x`. Returns: a `Tensor` with the same shape as `x`. """ return 0.5 * (1.0 + tf.erf(x / (math.sqrt(2) * stddev + 1e-20))) def _ProbInTopK(clean_values, noisy_values, noise_stddev, noisy_top_values, k): """Helper function to NoisyTopKGating. Computes the probability that value is in top k, given different random noise. This gives us a way of backpropagating from a loss that balances the number of times each expert is in the top k experts per example. In the case of no noise, pass in None for noise_stddev, and the result will not be differentiable. Args: clean_values: a `Tensor` of shape [batch, n]. noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus normally distributed noise with standard deviation noise_stddev. noise_stddev: a `Tensor` of shape [batch, n], or None noisy_top_values: a `Tensor` of shape [batch, m]. 'values' Output of tf.top_k(noisy_top_values, m). m >= k+1 k: an integer. Returns: a `Tensor` of shape [batch, n]. """ batch = tf.shape(clean_values)[0] m = tf.shape(noisy_top_values)[1] top_values_flat = tf.reshape(noisy_top_values, [-1]) # we want to compute the threshold that a particular value would have to # exceed in order to make the top k. This computation differs depending # on whether the value is already in the top k. threshold_positions_if_in = tf.range(batch) * m + k threshold_if_in = tf.expand_dims( tf.gather(top_values_flat, threshold_positions_if_in), 1) is_in = tf.greater(noisy_values, threshold_if_in) if noise_stddev is None: return tf.to_float(is_in) threshold_positions_if_out = threshold_positions_if_in - 1 threshold_if_out = tf.expand_dims( tf.gather(top_values_flat, threshold_positions_if_out), 1) # is each value currently in the top k. prob_if_in = _NormalDistributionCDF(clean_values - threshold_if_in, noise_stddev) prob_if_out = _NormalDistributionCDF(clean_values - threshold_if_out, noise_stddev) prob = tf.where(is_in, prob_if_in, prob_if_out) return prob def CVSquared(x): """The squared coefficient of variation of a sample. Useful as a loss to encourage a positive distribution to be more uniform. Epsilons added for numerical stability. Returns 0 for an empty Tensor. Args: x: a `Tensor`. Returns: a `Scalar`. """ epsilon = 1e-10 float_size = tf.to_float(tf.size(x)) + epsilon mean = tf.reduce_sum(x) / float_size variance = tf.reduce_sum(tf.square(x - mean)) / float_size return variance / (tf.square(mean) + epsilon) def MaxOverload(load): """The load of the hardest-hit device relative to average. This is useful for monitoring the performance of MoEs. The load of an expert is the number of examples assigned to that expert. The load of a device is the sum of the loads of all experts on that device. The input to this function is generally the 'load' output of DistributedMixtureOfExperts.Eval(), which is either a 1d or 2d `Tensor` of per-expert loads. In either case, the fist dimension corresponds to devices. This function sums over all dimensions other than dimension zero, then computes the ratio of the maxmium value to the mean value. Args: load: a 1d or 2d `Tensor`. Returns: a `Scalar`. """ per_device_load = tf.reduce_sum(tf.reshape(load, [tf.shape(load)[0], -1]), 1) return (tf.reduce_max(per_device_load) / (tf.reduce_mean(per_device_load) + 1e-10)) def _GatesToLoad(gates): """Compute the true load per expert, given the gates. The load is the number of examples for which the corresponding gate is >0. Args: gates: a `Tensor` of shape [batch_size, n] Returns: a float32 `Tensor` of shape [n] """ return tf.reduce_sum(tf.to_float(gates > 0), 0) def _MyTopK(x, k): """GPU-compatible version of top-k that works for very small constant k. Calls argmax repeatedly. Args: x: a 2d Tensor. k: a small integer. Returns: values: a Tensor of shape [batch_size, k] indices: a int32 Tensor of shape [batch_size, k] """ if k > 10: return tf.nn.top_k(x, k) values = [] indices = [] depth = tf.shape(x)[1] for i in xrange(k): values.append(tf.reduce_max(x, 1)) argmax = tf.argmax(x, 1) indices.append(argmax) if i + 1 < k: x += tf.one_hot(argmax, depth, -1e9) return tf.stack(values, axis=1), tf.to_int32(tf.stack(indices, axis=1)) class NoisyTopKGating(object): """Noisy top-k gating network. See paper: https://arxiv.org/abs/1701.06538. """ def __init__(self, hp, name): """Create a NoisyTopKGating network. Args: hp: a hyperparameters created by NoisyTopKGatingParams() name: a string """ self._vars = [] self._hp = hp self._w_gate = tf.get_variable('%s_gate' % name, [hp.input_size, hp.num_experts], hp.dtype, hp.initializer) self._vars.append(self._w_gate) if hp.noisy_gating: self._w_noise = tf.get_variable('%s_noise' % name, [hp.input_size, hp.num_experts], hp.dtype, hp.initializer) self._vars.append(self._w_noise) def Eval(self, x, train=True, summaries=False): """Compute noisy top-k gating. Args: x: a `Tensor` of shape `[batch_size, input_size]`. train: a boolean `Scalar`. Setting this to false turns off noise. summaries: a boolean. Whether to add summaries. Returns: gates: a `Tensor` of shape `[batch_size, n]` load: a `Tensor` of shape `[n]`. If we are using noise, this is a smooth approximation of the load, and you can define a loss in terms of it to help with load-balancing. """ with tf.variable_scope('NoisyTopKGating'): hp = self._hp clean_logits = tf.matmul(x, self._w_gate) if hp.noisy_gating: raw_noise_stddev = tf.matmul(x, self._w_noise) noise_stddev = ((tf.nn.softplus(raw_noise_stddev) + hp.noise_epsilon) * (tf.to_float(train))) noisy_logits = clean_logits + ( tf.random_normal(tf.shape(clean_logits)) * noise_stddev) logits = noisy_logits if summaries: tf.summary.histogram('noisy_logits', noisy_logits) tf.summary.histogram('noise_stddev', noise_stddev) else: logits = clean_logits top_logits, top_indices = _MyTopK(logits, min(hp.k + 1, hp.num_experts)) top_k_logits = tf.slice(top_logits, [0, 0], [-1, hp.k]) top_k_indices = tf.slice(top_indices, [0, 0], [-1, hp.k]) top_k_gates = tf.nn.softmax(top_k_logits) # This will be a `Tensor` of shape `[batch_size, n]`, with zeros in the # positions corresponding to all but the top k experts per example. gates = _RowwiseUnsortedSegmentSum(top_k_gates, top_k_indices, hp.num_experts) if hp.noisy_gating and hp.k < hp.num_experts: load = tf.reduce_sum( _ProbInTopK(clean_logits, noisy_logits, noise_stddev, top_logits, hp.k), 0) else: load = _GatesToLoad(gates) if summaries: tf.summary.histogram('importance', tf.reduce_sum(gates, 0)) tf.summary.histogram('load', load) return gates, load @property def vars(self): return self._vars class LocalMixtureOfExperts(object): """A MoE on a single device. """ def __init__(self, gating_hp, expert_hp, input_size, output_size, name): """Create a LocalMixtureOfExperts. Args: gating_hp: hyperparameters for the gating network. e.g. NoisyTopKGatingParams() expert_hp: hyperparameters for the expert networks. e.g. FeedForwardExpertParams() input_size: an integer. output_size: an integer. name: a string. """ self._name = name _SetInputOutputSizes(gating_hp, input_size, None) _SetInputOutputSizes(expert_hp, input_size, output_size) self._gating_hp = gating_hp self._gating = gating_hp.gating_class(gating_hp, name + '_gating') self._expert_hp = expert_hp self._experts = [ expert_hp.expert_class(expert_hp, name + '_%d' % i) for i in xrange(gating_hp.num_experts) ] def Eval(self, x, train=True, per_example_multiplier=None, summaries=False, identifiers=None): """Evaluate mixture of experts. We provide a convenient debugging tool for determining the set of examples that we passed to each expert. The caller may provide a `Tensor` of "identifiers", of any type whose first dimension matches the number of input examples. The function will then return a list "expert_to_identifiers", with one `Tensor` for each expert containing the identifiers for all examples assigned to that expert. A parallel list of `Tensor`s, "expert_to_gates", is also returned, containing the corresponding gate values. Args: x: a `Tensor` of shape `[batch_size, input_size]` train: a boolean Scalar. Are we in training mode? per_example_multiplier: an optional `Tensor` of shape `[batch_size]` which gets multiplied into the gate values. If this LocalMixtureOfExperts represents one secondary MoE in a hierarchical MoE, then we pass in in the gate values from the primary gating function here. This causes the computed values (`y`, `importance` and `expert_to_gates`) to also reflect the primary gate values. summaries: an boolean. Enable summaries. identifiers: an optional `Tensor` whose first dimension is equal to batch_size. Returns: y: a `Tensor` of shape `[batch_size, output_size]`. Output of the MoE. importance: a `Tensor` of shape `[n]`. Batchwise sum of gates. load: a `Tensor` of shape `[n]`. Smooth estimator of the number of examples passed to each expert. This is useful for load-balancing, as any gradient on this `Tensor` will back-propagate to the gating network. expert_to_identifiers: if `identifiers` was passed in, a list of length `num_experts`. Each element is a `Tensor` whose shape matches that of `identifiers` in all but the first dimension. Contains the slices of `identifiers` corresponding to the batch elements that were dispatched to that expert. expert_to_gates: A list of length `num_experts`. Each element contains a 1-dimensional tensor """ gating_hp = self._gating_hp gates, load = self._gating.Eval(x, train, summaries) if per_example_multiplier is not None: gates *= tf.expand_dims(per_example_multiplier, 1) dispatcher = SparseDispatcher(gating_hp.num_experts, gates) expert_input = dispatcher.Dispatch(x) expert_output = [ self._experts[i].Eval(expert_input[i]) for i in xrange(gating_hp.num_experts) ] y = dispatcher.Combine(expert_output) if identifiers is not None: expert_to_identifiers = dispatcher.Dispatch(identifiers) else: expert_to_identifiers = None return (y, tf.reduce_sum(gates, 0), load, expert_to_identifiers, dispatcher.ExpertToGates()) @property def vars(self): ret = [] for x in self._experts: ret.extend(x.vars) ret.extend(self._gating.vars) return ret class DistributedMixtureOfExperts(object): """Distributed (optionally Hierarchical) Mixture of Experts. This class implements the scheme described in our paper. See link at the top of this file. The model is trained synchronously using one large TF graph using multiple devices. The conventional (non-MoE) layers use data-parallelism, with each device processing a subset of the training batch. We call these datashards. The MoE layer (this object) uses model parallelism. Each expert is assigned to a particular device, which hosts the expert parameters and performs the expert computation for all examples assigned to that expert. In the case of a hierarchical MoE, each second-level MoE is assigned to a device. """ def __init__(self, primary_gating_hp, secondary_gating_hp, expert_hp, input_size, output_size, expert_devices, name): """Create a DistributedMixtureOfExperts. If `secondary_gating_hp` is `None`, then this is a flat MoE with `primary_gating_hp.num_experts` experts. Otherwise, this is a hierarchical MoE with `primary_gating_hp.num_experts` groups of `secondary_gating_hp.num_experts` experts. The assignemnt of experts (or groups of experts) to devices is by round-robin. So to make equal use of all the devices, one should set `primary_gating_hp.num_experts` to the number of devices or a multiple thereof. Args: primary_gating_hp: hyperparameters for the primary gating network. e.g. NoisyTopKGatingParams(). secondary_gating_hp: hyperparameters for the secondary gating network. e.g. NoisyTopKGatingParams(). None indicates a flat MoE. expert_hp: hyperparameters for the expert networks. e.g. FeedForwardExpertParams() input_size: an integer. output_size: an integer. expert_devices: a list of device strings. The devices to be used for the experts. name: a string. """ self._name = name # fill in the missing values in the hyperparameters _SetInputOutputSizes(primary_gating_hp, input_size, None) _SetInputOutputSizes(expert_hp, input_size, output_size) self._is_hierarchical = secondary_gating_hp is not None self._primary_gating_hp = primary_gating_hp self._primary_gating = primary_gating_hp.gating_class( primary_gating_hp, name + '_primary_gating') n1 = self._primary_gating_hp.num_experts # round robin assignment of experts to devices. expert_devices = [ expert_devices[i % len(expert_devices)] for i in xrange(n1) ] self._expert_devices = expert_devices self._all_vars = [] self._all_vars.extend(self._primary_gating.vars) if self._is_hierarchical: # hierarchical MoE self._secondary_moe = [] for i in xrange(n1): with tf.device(expert_devices[i]): secondary_moe = LocalMixtureOfExperts(secondary_gating_hp, expert_hp, input_size, output_size, '%s_secondary_%d' % (name, i)) self._secondary_moe.append(secondary_moe) self._all_vars.extend(secondary_moe.vars) else: # flat MoE self._experts = [] for i in xrange(n1): with tf.device(expert_devices[i]): expert = expert_hp.expert_class(expert_hp, name + '_%d' % i) self._experts.append(expert) self._all_vars.extend(expert.vars) def Eval(self, datashard_devices, xs, train=True, summaries=False, identifiers=None, shadow_xs=None): """Evaluate MoE on given inputs. This class is designed for the case where the rest of the model is using data parallelism. We receive an array of input `Tensor`s, one per datashard, and we produce a list of output Tensors, one per datashard. We provide a convenient debugging tool for determining the set of examples that we passed to each expert. The caller may provide a `Tensor` of "identifiers", of any type whose first dimension matches the number of input examples. The function will then return a list "expert_to_identifiers", with one `Tensor` for each expert containing the identifiers for all examples assigned to that expert. A parallel list of `Tensor`s, "expert_to_gates", is also returned, containing the corresponding gate values. Args: datashard_devices: a `list` of device strings of length `num_datashards`. Which devices to use for the output tensors. xs: A `list` of `Tensor`s of length `num_datashards`. Each has shape `[batch_size[d], input_size]. train: a boolean `Scalar`. When train=`True`, noise is added to the gating function. summaries: a boolean. Whether to write summaries. identifiers: an optional list of tensors. Each tensor has shape [<batch_size[datashard]>, extra_dims] shadow_xs: Optional `list` of `Tensor`s of length `num_datashards`. Each has shape `[batch_size[d], input_size]. Shadow_xs is useful if you want to dispatch a transformed version of xs to the experts, but you want untransformed xs for the gating network. Returns: ys: the output (a list of one tensor per datashard). Each has shape `[batch_size[d], output_size]. importance: a `Tensor` of shape `[n]` for a flat MoE or `[n1, n2]` for a hierarchical MoE. Batchwise sum of gates. load: a `Tensor` of shape `[n]` for a flat MoE or `[n1, n2]` for a hierarchical MoE. Smooth estimator of the number of examples passed to each expert. This is useful for load-balancing, as any gradient on this `Tensor` will back-propagate to the gating network. expert_to_identifiers: if `identifiers` was passed in, a list of length `num_experts`. Each element is a `Tensor` whose shape matches that of `identifiers` in all but the first dimension. Contains the slices of `identifiers` corresponding to the batch elements that were dispatched to that expert. expert_to_gates: a list of one tensor per expert. Each tensor has shape [<num_examples[expert]>] """ n1 = self._primary_gating_hp.num_experts epsilon = 1e-10 assert len(datashard_devices) == len(xs) num_datashards = len(xs) expert_devices = self._expert_devices has_identifiers = identifiers is not None # pylint: disable=unbalanced-tuple-unpacking primary_gates, primary_smooth_load = Parallel( datashard_devices, self._primary_gating.Eval, xs, train, [summaries] + [False] * (num_datashards - 1)) primary_importance = tf.add_n( Parallel(datashard_devices, tf.reduce_sum, primary_gates, 0)) primary_smooth_load = tf.add_n(primary_smooth_load) primary_true_load = tf.add_n( Parallel(datashard_devices, _GatesToLoad, primary_gates)) primary_dispatcher = DistributedSparseDispatcher( datashard_devices, expert_devices, primary_gates) if shadow_xs is None: secondary_input = primary_dispatcher.Dispatch(xs) else: secondary_input = primary_dispatcher.Dispatch(shadow_xs) primary_expert_to_identifiers = (primary_dispatcher.Dispatch(identifiers) if has_identifiers else None) primary_expert_to_gates = primary_dispatcher.ExpertToGates() if not self._is_hierarchical: # one-level distributed mixture of experts secondary_output = Parallel(expert_devices, lambda a, b: a.Eval(b), self._experts, secondary_input) ys = primary_dispatcher.Combine(secondary_output) return (ys, primary_importance, primary_smooth_load, primary_expert_to_identifiers, primary_expert_to_gates) # two-level hierarchical MoE (secondary_output, secondary_importance, secondary_load, secondary_expert_to_identifiers, secondary_expert_to_gates) = (Parallel( expert_devices, [m.Eval for m in self._secondary_moe], secondary_input, train, primary_expert_to_gates, [summaries] + [False] * (n1 - 1), primary_expert_to_identifiers)) # pylint: enable=unbalanced-tuple-unpacking ys = primary_dispatcher.Combine(secondary_output, multiply_by_gates=False) importance = tf.stack(secondary_importance) load = tf.stack(secondary_load) * tf.expand_dims(primary_smooth_load / ( primary_true_load + epsilon), 1) expert_to_identifiers = [] if identifiers is not None: for el in secondary_expert_to_identifiers: expert_to_identifiers.extend(el) expert_to_gates = [] for el in secondary_expert_to_gates: expert_to_gates.extend(el) return (ys, importance, load, expert_to_identifiers, expert_to_gates) @property def vars(self): return self._all_vars class SparseDispatcher(object): """Helper for implementing a mixture of experts. Example use: gates: a float32 `Tensor` with shape `[batch_size, num_experts]` inputs: a float32 `Tensor` with shape `[batch_size, input_size]` experts: a list of length `num_experts` containing sub-networks. dispatcher = SparseDispatcher(num_experts, gates) expert_inputs = dispatcher.Dispatch(inputs) expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)] outputs = dispatcher.Combine(expert_outputs) The preceding code sets the output for a particular example b to: output[b] = Sum_i(gates[b, i] * experts[i](inputs[b])) This class takes advantage of sparsity in the gate matrix by including in the `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. """ def __init__(self, num_experts, gates): """Create a SparseDispatcher. Args: num_experts: an integer. gates: a `Tensor` of shape `[batch_size, num_experts]`. Returns: a SparseDispatcher """ self._gates = gates self._num_experts = num_experts where = tf.to_int32(tf.where(tf.transpose(gates) > 0)) self._expert_index, self._batch_index = tf.unstack(where, num=2, axis=1) self._part_sizes_tensor = tf.reduce_sum(tf.to_int32(gates > 0), [0]) self._nonzero_gates = tf.gather( tf.reshape(self._gates, [-1]), self._batch_index * num_experts + self._expert_index) def Dispatch(self, inp): """Create one input Tensor for each expert. The `Tensor` for a expert `i` contains the slices of `inp` corresponding to the batch elements `b` where `gates[b, i] > 0`. Args: inp: a `Tensor` of shape '[batch_size, <extra_input_dims>]` Returns: a list of `num_experts` `Tensor`s with shapes `[expert_batch_size_i, <extra_input_dims>]`. """ inp = tf.gather(inp, self._batch_index) return tf.split(inp, self._part_sizes_tensor, 0) def Combine(self, expert_out, multiply_by_gates=True): """Sum together the expert output, weighted by the gates. The slice corresponding to a particular batch element `b` is computed as the sum over all experts `i` of the expert output, weighted by the corresponding gate values. If `multiply_by_gates` is set to False, the gate values are ignored. Args: expert_out: a list of `num_experts` `Tensor`s, each with shape `[expert_batch_size_i, <extra_output_dims>]`. multiply_by_gates: a boolean Returns: a `Tensor` with shape `[batch_size, <extra_output_dims>]`. """ # see comments on ConvertGradientToTensor stitched = ConvertGradientToTensor(tf.concat(expert_out, 0)) if multiply_by_gates: stitched *= tf.expand_dims(self._nonzero_gates, 1) combined = tf.unsorted_segment_sum(stitched, self._batch_index, tf.shape(self._gates)[0]) return combined def ExpertToGates(self): """Gate values corresponding to the examples in the per-expert `Tensor`s. Returns: a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32` and shapes `[expert_batch_size_i]` """ return tf.split(self._nonzero_gates, self._part_sizes_tensor, 0) @property def part_sizes(self): return self._part_sizes_tensor class DistributedSparseDispatcher(object): """A distributed version of SparseDispatcher. Instead of one batch of input examples, we simultaneously process num_datashards batches of input examples. The per-expert `Tensor`s contain a combination of examples from the different datashards. Each datashard is associated with a particular device and each expert is associated with a particular device. All per-datashard and per-expert `Tensor`s are created on those devices. There is no single-device bottleneck. """ def __init__(self, datashard_devices, expert_devices, gates): """Create a DistributedSparseDispatcher. Args: datashard_devices: a list of num_datashards device strings. expert_devices: a list of num_experts device strings. gates: a list of num_datashards `Tensor`s of shapes `[batch_size[d], num_experts]`. Returns: a DistributedSparseDispatcher """ self._gates = gates self._num_experts = len(expert_devices) assert len(gates) == len(datashard_devices) self._num_datashards = len(gates) self._datashard_devices = datashard_devices self._expert_devices = expert_devices self._dispatchers = Parallel(self._datashard_devices, SparseDispatcher, self._num_experts, gates) def Dispatch(self, inp): """Create one input Tensor for each expert. Args: inp: a list of length num_datashards `Tensor`s with shapes `[batch_size[d], <extra_input_dims>]`. Returns: a list of `num_experts` `Tensor`s with shapes `[num_examples[i], <extra_input_dims>]`. """ dispatched = Parallel(self._datashard_devices, lambda a, b: a.Dispatch(b), self._dispatchers, inp) ret = Parallel(self._expert_devices, tf.concat, TransposeListOfLists(dispatched), 0) if ret[0].dtype == tf.float32: # see comments on ConvertGradientToTensor ret = Parallel(self._expert_devices, ConvertGradientToTensor, ret) return ret def Combine(self, expert_out, multiply_by_gates=True): """Sum together the expert output, multiplied by the corresponding gates. Args: expert_out: a list of `num_experts` `Tensor`s, each with shape `[expert_batch_size_i, <extra_output_dims>]`. multiply_by_gates: a boolean. Returns: a list of num_datashards `Tensor`s with shapes `[batch_size[d], <extra_output_dims>]`. """ expert_part_sizes = tf.unstack( tf.stack([ self._dispatchers[d].part_sizes for d in xrange(self._num_datashards) ]), num=self._num_experts, axis=1) # list of lists of shape [num_experts][num_datashards] expert_output_parts = Parallel(self._expert_devices, tf.split, expert_out, expert_part_sizes) expert_output_parts_t = TransposeListOfLists(expert_output_parts) ret = [] for d in xrange(self._num_datashards): with tf.device(self._datashard_devices[d]): ret.append(self._dispatchers[d].Combine( # see comments on ConvertGradientToTensor ConvertGradientToTensor(tf.concat(expert_output_parts_t[d], 0)), multiply_by_gates=multiply_by_gates)) return ret def ExpertToGates(self): """Gate values corresponding to the examples in the per-expert `Tensor`s. Returns: a list of `num_experts` one-dimensional `Tensor`s of type `tf.float32`. """ return Parallel(self._expert_devices, tf.concat, TransposeListOfLists( Parallel(self._datashard_devices, [ self._dispatchers[d].ExpertToGates for d in xrange(self._num_datashards) ])), 0) def TransposeListOfLists(lol): """Transpose a list of equally-sized python lists. Args: lol: a list of lists Returns: a list of lists """ assert lol, 'cannot pass the empty list' return [list(x) for x in zip(*lol)] class DistributedSingleDispatcher(object): """Dispatches to experts according to gates. Each example goes to one expert. Unlike SparseDispatcher, the gates are one-dimensional `Tensor`s of integer expert ids. There are no weights. """ def __init__(self, data_parallelism, model_parallelism, gates): """Constructs a Dispatcher. Args: data_parallelism: a Parallelism object. model_parallelism: a Parallelism object. gates: a list of 1d integer `Tensor`s, one per datashard. Says which expert to use for each batch element. Returns: a DistributedSingleDispatcher """ gates = data_parallelism(tf.to_int32, gates) self._gates = gates self._data_parallelism = data_parallelism self._model_parallelism = model_parallelism # Compute the sizes number of examples going from each datashard to each # expert. def _PartSizes(gates): return tf.unsorted_segment_sum( tf.ones_like(gates), gates, model_parallelism.n) part_sizes_by_datashard = data_parallelism(_PartSizes, gates) self._part_sizes_by_expert = tf.unstack( tf.stack(part_sizes_by_datashard), num=model_parallelism.n, axis=1) # These indices will be used to combine the output on the datashards. def _StitchIndices(gates): return tf.dynamic_partition( tf.range(tf.size(gates)), gates, model_parallelism.n) self._stitch_indices = data_parallelism(_StitchIndices, gates) def Dispatch(self, d_tensors): """Reshuffles input `Tensor`s to produce output `Tensor`s. The dimensions of all input and output `Tensor`s match, except for dimension 0. In dimension 0, the input `Tensor`s match the corresponding `gates` `Tensor`s which were passed to the constructor. Args: d_tensors: a list of `Tensor`s, one per datashard. Returns: a list of `Tensor`s, one per expert. """ parts = self._data_parallelism(tf.dynamic_partition, d_tensors, self._gates, self._model_parallelism.n) parts_by_expert = TransposeListOfLists(parts) x_tensors = self._model_parallelism(tf.concat, parts_by_expert, 0) return x_tensors def Combine(self, x_tensors): """Reshuffles per-expert `Tensor`s to produce per-datashard `Tensor`s. Dispatch must have been called at least once first. The dimensions of all input and output `Tensor`s match, except for dimension 0. In dimension 0, the input `Tensor`s match the corresponding outputs of `Dispatch`, and the output `Tensor`s match the corresponding `gates` `Tensor`s which were passed to the constructor. Args: x_tensors: a list of `Tensor`s, one per expert. Returns: a list of `Tensor`s, one per datashard. """ parts = self._model_parallelism(tf.split, x_tensors, self._part_sizes_by_expert) d_tensors = self._data_parallelism(tf.dynamic_stitch, self._stitch_indices, TransposeListOfLists(parts)) return d_tensors def ParallelEmbeddingLookup(params, ids, data_parallelism): """Mod-sharded embedding lookup with multiple datashards. TODO(noam): does this work when vocab_size is not a multiple of `num_shards`? Args: params: A list of `num_shards` `Tensors`, each with shapes `[vocab_size / num_params, depth]`. ids: A list of `num_datashards` one-dimensional ineger `Tensors`, with shapes `[batch_size[i]]` data_parallelism: A Parallelism object. Returns: a list of `num_datashards` `Tensors`, each with shape `[batch_size[i], depth]`. """ param_devices = [x.device for x in params] model_parallelism = Parallelism(param_devices) num_shards = len(param_devices) # pylint: disable=unbalanced-tuple-unpacking ids, unique_idx = data_parallelism(tf.unique, ids) # pylint: enable=unbalanced-tuple-unpacking gates = data_parallelism(tf.mod, ids, num_shards) ids_div = data_parallelism(tf.div, ids, num_shards) dispatcher = DistributedSingleDispatcher(data_parallelism, model_parallelism, gates) x_ids_div = dispatcher.Dispatch(ids_div) params = model_parallelism(ConvertGradientToTensor, params) x_emb = model_parallelism(tf.gather, params, x_ids_div) r_emb = dispatcher.Combine(x_emb) r_emb = data_parallelism(tf.gather, r_emb, unique_idx) return r_emb def SampledSoftmaxLoss(features, sampler, num_classes, target_classes, target_params, sampled_classes, sampled_params): """Loss for training softmax classifiers on large label vocabulary. This function assumes that we have already chosen the sampled classes and fetched the parameters for the target classes and the sampled classes. Args: features: a Tensor with shape [batch_size, hidden_size] sampler: a candidate sampler object (see learning/brain/google/python/ops/candidate_sampling.py) num_classes: an integer target_classes: an integer Tensor with shape [batch_size] target_params: a Tensor with shape [batch_size, hidden_size] The parameters corresponding to the target classes. sampled_classes: an integer tensor with shape [num_sampled_classes] sampled_params: a Tensor with shape [num_sampled_classes, hidden_size] The parameters corresponding to the sampled classes. Returns: a Tensor with shape [batch_size] """ sampled_logits = (tf.matmul(features, sampled_params, transpose_b=True) - sampler.log_expected_count(sampled_classes)) target_logits = (tf.reduce_sum(target_params * features, 1) - sampler.log_expected_count(target_classes)) sampled_log_denominator = tf.reduce_logsumexp( sampled_logits, [1], name='SampledLogDenominator') sampled_classes_mask = tf.unsorted_segment_sum( tf.fill(tf.shape(sampled_classes), float('-inf')), sampled_classes, num_classes) target_log_denominator = ( target_logits + tf.gather(sampled_classes_mask, target_classes)) combined_log_denominator = tf.reduce_logsumexp( tf.stack([sampled_log_denominator, target_log_denominator]), [0]) loss = combined_log_denominator - target_logits return loss def ParallelSampledSoftmaxLoss(params, features, target_classes, sampler, num_classes, data_parallelism, target_weights=None): """Computes sampled softmax loss across many datashards. This is used during training to efficiently train a softmax classifier layer. Args: params: A list of num_param_shards Tensors, each with shape [num_classes / num_param_shards, num_features]. The parameters are assumed to be mod-sharded by class. features: a list of num_datashards Tensors, each with shape [batch_size_i, num_features] target_classes: A list of num_datashards integer Tensors each with shape [batch_size_i] sampler: a candidate sampler object (see learning/brain/google/python/ops/candidate_sampling.py) num_classes: an Integer data_parallelism: a Parallelism object target_weights: an optional list of num_datashards Tensors each with shape [batch_size_i] Returns: a Scalar. """ sampled_classes = data_parallelism(sampler.sample) sampled_params = ParallelEmbeddingLookup(params, sampled_classes, data_parallelism) target_params = ParallelEmbeddingLookup(params, target_classes, data_parallelism) ret = data_parallelism(SampledSoftmaxLoss, features, sampler, num_classes, target_classes, target_params, sampled_classes, sampled_params) if target_weights is not None: ret = data_parallelism(tf.multiply, ret, target_weights) ret = data_parallelism(tf.reduce_sum, ret) ret = tf.add_n(ret) return ret