# coding=utf-8 # Copyright 2020 The Mesh TensorFlow Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """SIMD Mesh implementation (for TPU/XLA).""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import math import os from mesh_tensorflow import ops_with_redefined_builtins as mtf from mesh_tensorflow import tpu_variables from mesh_tensorflow import utils from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow.compat.v1 as tf from tensorflow.python.tpu.ops import tpu_ops # pylint: disable=g-direct-tensorflow-import class SimdMeshImpl(mtf.MeshImpl): """Mesh implementation for TPU using SIMD and MPI operations.""" def __init__(self, shape, layout, devices=None, device_assignment=None, logical_to_physical=None, allreduce_in_bfloat16_max_group_size=32, ): """Create a SimdMeshImpl. Args: shape: an input to mtf.convert_to_shape() layout: an input to mtf.convert_to_layout_rules() devices: deprecated device_assignment: a tf.tpu.experimental.DeviceAssignment - devices must be asssigned in lexicographic order logical_to_physical: an optional permutation representing the mapping from logical cores to "physical" cores, where the physical cores are listed in lexicographic order in the physical mesh, and the logical cores are listed in lexicographic order in the logical mesh. Default is lexicographic order. allreduce_in_bfloat16_max_group_size: an integer. Allreduces of bfloat16 tensors are done in float32 if the group size exceeds this value. """ super(SimdMeshImpl, self).__init__(shape, layout) if devices is not None: tf.logging.warning("SimdMeshImpl ignoring devices %s" % devices) self._device_assignment = device_assignment tf.logging.info("SimdMeshImpl init: {0} {1}".format(shape, layout)) tf.logging.info("Device Assignment: {0}".format(device_assignment)) if logical_to_physical is None: # TODO(noam): maybe use auto_logical_to_physical_tpu() here logical_to_physical = list(range(self.size)) if sorted(logical_to_physical) != list(range(self.size)): raise ValueError( "logical_to_physical must be a permutation on range(shape.size)" " shape=%s logical_to_physical=%s" % (shape, logical_to_physical)) self._logical_to_physical = logical_to_physical self._physical_to_logical = [None] * self.size for logical, physical in enumerate(self._logical_to_physical): self._physical_to_logical[physical] = logical self._pnum_tensor = None self.graph_device_function_stacks = [] self.copy_master_to_slice_ops = [] self._allreduce_in_bfloat16_max_group_size = ( allreduce_in_bfloat16_max_group_size) @property def pnum_tensor(self): if self._pnum_tensor is not None: return self._pnum_tensor with utils.outside_all_rewrites(): tf.logging.info("Create pnum_tensor") self._pnum_tensor = tpu_ops.tpu_replicated_input( self._physical_to_logical, name="pnum_constants") return self._pnum_tensor def l2p(self, logical_pnum): return self._logical_to_physical[logical_pnum] def p2l(self, physical_pnum): return self._physical_to_logical[physical_pnum] class LaidOutTensor(object): """One Slice.""" def __init__(self, tensor_list): assert isinstance(tensor_list, list) self._tensor_list = tensor_list def __repr__(self): return "[" + ",".join([str(t) for t in self._tensor_list]) + "]" @property def tensor_list(self): return self._tensor_list @property def one_slice(self): return self._tensor_list[0] @classmethod def from_tensor_list(cls, tensor_list): return cls(tensor_list) @property def all_slices(self): return self._tensor_list @property def slice_shape(self): return self.one_slice.shape.as_list() def to_laid_out_tensor(self): return self class LaidOutVariable(object): """Maintains slice-variables and copy operations.""" def __init__(self, variable, mesh_impl): """Create a LaidOutVariable. Args: variable: a Variable (Operation) mesh_impl: a MeshImpl """ self._variable = variable self._mesh_impl = mesh_impl shape = variable.outputs[0].shape slice_shape = mesh_impl.slice_shape(shape) base_name = variable.name slices = [] slices_with_master_dtype = [] with tf.device(variable.master_device), utils.outside_all_rewrites(): zero_tensor = tf.zeros(slice_shape, dtype=variable.slice_dtype) # pylint: disable=protected-access init_device_stack = tf.get_default_graph()._device_function_stack if not mesh_impl.graph_device_function_stacks: for pnum in xrange(mesh_impl.size): tpu_device = mesh_impl.device_assignment.tpu_device(replica=pnum) with tf.device(tpu_device): mesh_impl.graph_device_function_stacks.append( tf.get_default_graph()._device_function_stack.copy()) for physical_pnum in xrange(mesh_impl.size): slice_var_name = base_name + "_slice_%d" % physical_pnum # Use tf.Variable instead of tf.get_variable since latter adds lots of # useless operations to the TF graph. Use tf.get_variable only if # in a AUTO_REUSE scope. # Note: Repeatedly 'with tf.device():' slows down the graph # construction. Therefore we directly use the cached device_stack here. tf.get_default_graph()._device_function_stack = ( mesh_impl.graph_device_function_stacks[physical_pnum]) if tf.get_variable_scope().reuse == tf.AUTO_REUSE: slice_var = tf.get_variable( initializer=zero_tensor, trainable=self._variable.trainable, collections=["TPU_VAR"], dtype=variable.slice_dtype, name=slice_var_name) else: slice_var = tf.Variable( initial_value=zero_tensor, trainable=self._variable.trainable, collections=["TPU_VAR"], dtype=variable.slice_dtype, name=slice_var_name, expected_shape=slice_shape) slices.append(slice_var) # Restore the initial stack tf.get_default_graph()._device_function_stack = init_device_stack # pylint: enable=protected-access self._laid_out_tensor = mesh_impl.LaidOutTensor( [tpu_variables.ReplicatedVariable(base_name, slices)]) with tf.device(variable.master_device), utils.outside_all_rewrites(): if os.environ.get("MTF_SEQUENCE_MODE", "") == "1": if mesh_impl.copy_master_to_slice_ops: with tf.control_dependencies( [mesh_impl.copy_master_to_slice_ops[-1]]): self._copy_master_to_slices = self._gen_copy_master_to_slices_op( variable.get_master(), shape, slices, slice_shape) else: self._copy_master_to_slices = self._gen_copy_master_to_slices_op( variable.get_master(), shape, slices, slice_shape) mesh_impl.copy_master_to_slice_ops.append(self._copy_master_to_slices) else: self._copy_master_to_slices = self._gen_copy_master_to_slices_op( variable.get_master(), shape, slices, slice_shape) slices_with_master_dtype = [ tf.cast(s, variable.master_dtype) for s in slices] slices_with_master_dtype = [ slices_with_master_dtype[mesh_impl.l2p(logical_pnum)] for logical_pnum in range(mesh_impl.size)] self._copy_slices_to_master = variable.assign_to_master( mesh_impl.combine_slices(slices_with_master_dtype, shape, device=variable.master_device)) def _gen_copy_master_to_slices_op(self, master_variable, master_shape, slices, slice_shape): """Generate ops which slices master and assign to slices. Args: master_variable: The master variable. master_shape: The shape of master variable. slices: The list of slice-variables in physical order. slice_shape: The shape of the slice variable. Returns: A grouped tf.assign ops. """ mesh_impl = self._mesh_impl master_layout = mesh_impl.tensor_layout(master_shape) # For handling case: master is float32 and slices are bfloat16. if master_variable.dtype != slices[0].dtype: master_variable = tf.cast(master_variable, slices[0].dtype) assign_ops = [] if master_layout.is_fully_replicated: assign_ops = [tf.assign(t, master_variable) for t in slices] else: slice_dict = {} for logical_pnum in xrange(len(slices)): slice_begin = mesh_impl.slice_begin(master_shape, logical_pnum) slice_begin_tuple = tuple(slice_begin) # Reuse the same slice if slice_begin doesn't change. if slice_begin_tuple not in slice_dict: slice_dict[slice_begin_tuple] = tf.slice(master_variable, slice_begin, slice_shape) physical_pnum = mesh_impl.l2p(logical_pnum) assign_ops.append( tf.assign(slices[physical_pnum], slice_dict[slice_begin_tuple])) return tf.group(assign_ops) def assign_to_slices(self, assign_fn, values, assign_to_tensor_list=None): """Assign to the slice variables. Args: assign_fn: a function from (mtf.Variable, tf.Variable, tf.Tensor) -> tf.Operation values: a list of tf.Tensor assign_to_tensor_list: an optional list of tf.Variable Returns: a tf.operation """ if assign_to_tensor_list is None: assign_to_tensor_list = self._laid_out_tensor.all_slices # Handle both N -> 1 and N -> N cases. num_slices = min(len(assign_to_tensor_list), len(values)) devices = [""] * num_slices return tf.group( mtf.parallel(devices, assign_fn, [self._variable] * len(devices), assign_to_tensor_list[:num_slices], values[:num_slices])) @property def laid_out_tensor(self): return self._laid_out_tensor @property def copy_master_to_slices(self): return self._copy_master_to_slices @property def copy_slices_to_master(self): return self._copy_slices_to_master def laid_out_pnum(self): """Returns a LaidOutTensor containing the logical processor number. Returns: a LaidOutTensor where each slice is an integer scalar """ return self.LaidOutTensor([self.pnum_tensor]) def _create_group_assignment(self, mesh_axes): """Create group assignment for XLA cross replica ops (physical pnums).""" partitioning = {} for logical_pnum in xrange(self.size): group = mtf.pnum_to_group(self.shape, mesh_axes, logical_pnum) if group not in partitioning: partitioning[group] = [] partitioning[group].append(self.l2p(logical_pnum)) group_assignment = [] for group, physical_pnums in partitioning.items(): group_assignment.append(physical_pnums) return group_assignment def allreduce(self, x, mesh_axes, reduction_fn_string): """Grouped allreduce, (summed across the given dimensions). Args: x: a LaidOutTensor mesh_axes: a list of integers reduction_fn_string: "SUM" Returns: a LaidOutTensor Raises: ValueError: if the reduction is not yet implemented. """ if not mesh_axes: return x x = x.to_laid_out_tensor() if reduction_fn_string == "SUM": group_assignment = self._create_group_assignment(mesh_axes) group_size = len(group_assignment[0]) tf_in = x.one_slice dtype = tf_in.dtype if dtype == tf.float32: cast_to_float32 = False elif dtype == tf.bfloat16: cast_to_float32 = ( group_size > self._allreduce_in_bfloat16_max_group_size) else: tf.logging.info("Casting %s to float32 for allreduce" % tf_in.dtype) cast_to_float32 = True if cast_to_float32: tf_in = tf.cast(tf_in, tf.float32) tf_out = tpu_ops.cross_replica_sum(tf_in, group_assignment) if cast_to_float32: tf_out = tf.cast(tf_out, dtype) return self.LaidOutTensor([tf_out]) else: for axis in mesh_axes: x = self.allconcat(x, axis, 0, stack=True) x = self.LaidOutTensor( [mtf.reduction_fn(reduction_fn_string)(x.one_slice, 0)]) return x def allconcat(self, x, mesh_axis, concat_axis, stack=False): """Grouped allconcat (like MPI allgather followed by concat). TODO(noam): inefficient - replace with a XLA allconcat when available Args: x: a LaidOutTensor mesh_axis: an integer - the mesh axis along which to group concat_axis: an integer (the Tensor axis along which to concatenate) stack: a boolean - whether to stack instead of concat Returns: a LaidOutTensor """ x = x.to_laid_out_tensor() coord = self.laid_out_pcoord(mesh_axis) t = x.one_slice old_shape = t.shape.as_list() num_parts = self.shape[mesh_axis].size t = tf.expand_dims(t, concat_axis) t *= tf.reshape( tf.one_hot(coord.one_slice, num_parts, dtype=t.dtype), [num_parts if i == concat_axis else 1 for i in xrange(len(old_shape) + 1)]) if not stack: new_shape = old_shape[:] new_shape[concat_axis] *= num_parts t = tf.reshape(t, new_shape) return self.allreduce(self.LaidOutTensor([t]), [mesh_axis], "SUM") def alltoall(self, x, mesh_axis, split_axis, concat_axis): """Grouped alltoall (like MPI alltoall with splitting and concatenation). Args: x: a LaidOutTensor mesh_axis: an integer the mesh axis along which to group split_axis: an integer (the Tensor axis along which to split) concat_axis: an integer (the Tensor axis along which to concatenate) Returns: a LaidOutTensor """ x = x.to_laid_out_tensor() t = x.one_slice group_assignment = self._create_group_assignment([mesh_axis]) dtype = t.dtype if dtype == tf.float32: # There seems to be a bug with float32 alltoall. # Do it in bfloat16 until the bug is fixed. # TODO(noam): file a bug t = tf.to_bfloat16(t) t = tpu_ops.all_to_all( t, concat_dimension=concat_axis, split_dimension=split_axis, split_count=len(group_assignment[0]), group_assignment=group_assignment) t = tf.cast(t, dtype) x = self.LaidOutTensor([t]) return x def receive(self, x, mesh_axis, source_pcoord): """Collective receive in groups. Each group contains the processors that differ only in mesh_axis. ```python group_size = self.shape[mesh_axis].size ``` Args: x: a LaidOutTensor mesh_axis: an integer source_pcoord: a list of optional integers. Each element is either None or an integer in [0, group_size). If source_pcoord[k] is None, then the output for the k-th processor in each group is a zero tensor. If source_pcoord[k] is not None, then the output for the k-th processor in each group is equal to the input for the source_pcoord[k]-th processor in that group. Returns: a LaidOutTensor """ x = x.to_laid_out_tensor() t = x.one_slice source_target_pairs = [] for pnum in xrange(self.size): coord = mtf.pnum_to_processor_coordinates(self.shape, pnum) k = coord[mesh_axis] if source_pcoord[k] is not None: coord[mesh_axis] = source_pcoord[k] source_pnum = mtf.processor_coordinates_to_pnum(self.shape, coord) source_target_pairs.append( [self.l2p(source_pnum), self.l2p(pnum)]) if not source_target_pairs: ret = tf.zeros_like(t, t.dtype) elif t.dtype in [tf.float32, tf.bfloat16, tf.int32]: ret = tpu_ops.collective_permute(t, source_target_pairs) else: # If t is not one of the allowed types, cast and cast back. ret = tf.cast(tpu_ops.collective_permute( tf.cast(t, tf.float32), source_target_pairs), t.dtype) return self.LaidOutTensor([ret]) def slice(self, tf_tensor, tensor_shape): """"Slice out the corresponding part of tensor given the pnum variable.""" tensor_layout = self.tensor_layout(tensor_shape) if tensor_layout.is_fully_replicated: return self.LaidOutTensor([tf_tensor]) else: slice_shape = self.slice_shape(tensor_shape) slice_begins = [ self.slice_begin(tensor_shape, pnum) for pnum in xrange(self.size) ] slice_begins_tensor = tf.stack(slice_begins) # slice on source device selected_slice_begin = tf.gather(slice_begins_tensor, self.pnum_tensor) return self.LaidOutTensor( [tf.slice(tf_tensor, selected_slice_begin, slice_shape)]) def slicewise(self, fn, *inputs): """Execute a function in parallel on all slices. Args: fn: a function from tf.Tensors to tf.Tensor or a tuple of tf.Tensors. *inputs: a list of inputs. Each input is either a LaidOutTensor or is convertible to a tf.Tensor. Returns: a LaidOutTensor, or a tuple of LaidOutTensors if fn returns a tuple. """ # convert all inputs to LaidOutTensor where possible inputs = mtf.convert_args_to_laid_out_tensors(inputs) ret = fn(*[ x.one_slice if isinstance(x, self.LaidOutTensor) else x for x in inputs]) if isinstance(ret, tuple): return tuple([self.LaidOutTensor([t]) for t in ret]) else: return self.LaidOutTensor([ret]) @property def device_assignment(self): return self._device_assignment @property def devices(self): return self._devices def random(self, shape, tf_fn, kwargs): """Call a random tf operation (e.g. random_uniform). Args: shape: a Shape tf_fn: a function such as tf.random.uniform kwargs: kwargs to pass to tf_fn, except for seed Returns: a LaidOutTensor """ # TODO(noam): can we make things better with stateless_random? slice_shape = self.slice_shape(shape) x = tf_fn(slice_shape, **kwargs) # TPU does not have seeds enabled. Sync up the # random choices by zeroing out all but the first core per group of # identical slices, then allreducing by group. layout = self.tensor_layout(shape) # we need to sync across these axes. mesh_axes = [i for i in xrange(self.ndims) if i not in layout.tensor_axis_to_mesh_axis] multiplier = 1.0 for axis in mesh_axes: multiplier *= tf.cast( tf.equal(self.laid_out_pcoord(axis).one_slice, 0), x.dtype) x *= multiplier x = self.LaidOutTensor([x]) x = self.allreduce(x, mesh_axes, "SUM") return x def export_to_tf_tensor(self, x, laid_out_x): """Turn a Tensor into a tf.Tensor. Args: x: a Tensor laid_out_x: a LaidOutTensor Returns: a tf.Tensor """ tensor_layout = self.tensor_layout(x.shape) if not tensor_layout.is_fully_replicated: raise NotImplementedError( "SimdMeshImpl only supports export_to_tf_tensor of fully-replicated " "Tensors. Try reshaping to new dimension names. " " x.shape = %s tensor_layout=%s" % (x.shape, tensor_layout)) return laid_out_x.one_slice def import_tf_tensor(self, x, tf_x): """Import a tf.Tensor, producing a LaidOutTensor. Args: x: a Tensor tf_x: a tf.Tensor Returns: a LaidOutTensor """ return self.slice(tf_x, x.shape) @property def supports_control_dependencies(self): return False def einsum(self, equation, *slices): """Override this for custom einsum implementation. Args: equation: a string *slices: a list of tf.Tensor Returns: a tf.Tensor """ return tf.einsum(equation, *slices) def _ring_2d(m, n): """Ring-order of a mxn mesh. If m and n are both even, then we generate a ring like this: 0 -- 1 -- 2 -- 3 | | | | 15-- 6 -- 5 -- 4 | | | | 14-- 7 -- 8 -- 9 | | | | 13-- 12-- 11-- 10 Args: m: an integer n: an integer Returns: a list of mxn pairs """ if m == 1: return [(0, i) for i in range(n)] if n == 1: return [(i, 0) for i in range(m)] if m % 2 != 0: tf.logging.warning("Odd dimension") return [(i % m, i // m) for i in range(n * m)] ret = [(0, 0)] for i in range(m // 2): for j in range(1, n): ret.append((2 * i, j)) for j in range(n-1, 0, -1): ret.append((2 * i + 1, j)) for i in range(m-1, 0, -1): ret.append((i, 0)) return ret def _logical_1d_to_physical_subspace_auto(sizes_and_strides, physical_shape): """Maps logical 1d mesh to subspace of physical nd mesh. We are mapping a 1d logical mesh to a subspace (a strided slice containing the origin) of a n-dimensional physical mesh. output[i] contains the coordinate-tuple in the physical mesh for the i-th logical processor. sizes_and_strides is a list of (size, stride) pairs specifying the dimensions of the strided slice. For example, sizes_and_strides=[(2, 16), (4, 1)] would represent the slice containing [(0, 0), (0, 1), (0, 2), (0, 3), (16, 0), (16, 1), (16, 2), (16, 3)] This function heuristically picks an order, with the goal of optimizing allreduce performance. Args: sizes_and_strides: a list of n (size, stride) pairs physical_shape: ignored Returns: a list of coordinate-lists """ del physical_shape ndims = len(sizes_and_strides) sizes = [p[0] for p in sizes_and_strides] strides = [p[1] for p in sizes_and_strides] n = mtf.list_product(sizes) if ndims >= 2 and sizes[0] > 1 and sizes[1] > 1: ring = _ring_2d(sizes[0], sizes[1]) ret = [] sizes_combined = [sizes[0] * sizes[1]] + sizes[2:] for logical_pnum in range(n): logical_coord = mtf.pnum_to_processor_coordinates( sizes_combined, logical_pnum) ret.append(list(ring[logical_coord[0]]) + logical_coord[1:]) else: ret = [mtf.pnum_to_processor_coordinates(sizes, logical_pnum) for logical_pnum in range(n)] # multiply by strides ret = [[x * stride for x, stride in zip(pcoord, strides)] for pcoord in ret] return ret def _logical_to_physical_v1( sizes_and_strides, physical_shape, fn_1d=_logical_1d_to_physical_subspace_auto): """Maps logical m-dimensional mesh to physical n-dimensional mesh. Also see comments to _logical_1d_to_physical_subspace_auto. We are mapping a m-dimensonal logical mesh to a n-dimensional physical mesh. output[i] contains the coordinate-tuple in the physical mesh for the i-th logical processor (if the logical processors are ordered lexicographically). sizes_and_strides is a list of m lists of n (size, stride) pairs. sizes_and_strides[i] specifies the subspace (strided slice containing the origin) of the physical mesh covered by axis i of the logical mesh. See comments to _logical_1d_to_physical_subspace_auto for more detail. For example, say we have a physical mesh with shape [4, 4, 2] and a logical mesh with shape [4, 8]. We want to divide the physical mesh into 4 tiles, each with shape [2, 2, 2]. The first logical dimension corresponds to which tile, and the second logical dimension corresponds to position within a tile. This would correspond to: physical_shape=[4, 4, 2] sizes_and_strides=[[(2, 2), (2, 2), (1, 2)], [(2, 1), (2, 1), (2, 1)]] physical_shape can be inferred from sizes_and_strides, but is passed in for error checking. Args: sizes_and_strides: a list of m list of n (size, stride) pairs physical_shape: a list of integers fn_1d: a function like _logical_1d_to_physical_subspace_auto Returns: a list of coordinate-lists """ pndims = len(physical_shape) logical_shape = [ mtf.list_product([p[0] for p in l]) for l in sizes_and_strides] n = mtf.list_product(physical_shape) if n != mtf.list_product(logical_shape): raise ValueError( "logical size and physical size must match " "- got sizes_and_strides=%s physical_shape=%s" % (sizes_and_strides, physical_shape)) dimension_layouts = [fn_1d(l, physical_shape) for l in sizes_and_strides] tf.logging.info("physical_shape: %s" % physical_shape) tf.logging.info("sizes_and_strides: %s" % sizes_and_strides) for i, l in enumerate(dimension_layouts): tf.logging.info("dimension_layout %s: %s" % (i, l)) ret = [] for logical_pnum in range(n): logical_coordinates = mtf.pnum_to_processor_coordinates( logical_shape, logical_pnum) physical_coordinates = [0] * pndims for logical_axis, logical_coord in enumerate(logical_coordinates): for physical_axis in range(pndims): physical_coordinates[physical_axis] += ( dimension_layouts[logical_axis][logical_coord][physical_axis]) ret.append(physical_coordinates) # verify that we have indeed covered all the processors l2p = [mtf.processor_coordinates_to_pnum(physical_shape, c) for c in ret] if sorted(l2p) != list(range(n)): raise ValueError( "logical_to_physical produced something that was not a permutation." " sizes_and_strides=%s physical_shape=%s ret=%s" % (sizes_and_strides, physical_shape, ret)) return ret class HierarchicalTiling(object): """One kind of mapping of a logical mesh to a physical mesh.""" def __init__(self, spec, physical_shape): """Constructs a HierarchicalTiling. spec is a list corresponding to the logical dimensions. spec[i] corresponds to the i-th logical dimension and consists of a name and a list of integers, the list being the shape of logical axis i when it is physically projected to the physical mesh and then compacted. Striding information is omitted. By convention, the earlier dimensions get more strided. so the axis corresponding to the last dimension always gets projected to the tile specified by its shape. Args: spec: a list of (string, list-of-integers) pairs physical_shape: a list of integers """ self._names = [p[0] for p in spec] logical_ndims = len(spec) physical_ndims = len(physical_shape) projected_shapes = [p[1] for p in spec] if logical_ndims > 0 and projected_shapes[0] is None: # fill in missing value projected_shapes[0] = list(physical_shape) for s in projected_shapes[1:]: for i, x in enumerate(s): projected_shapes[0][i] //= x # compute strides, and verify that the spec is valid. products = [1] * physical_ndims sizes_and_strides = [] for s in reversed(projected_shapes): sizes_and_strides.append( [(size, stride) for size, stride in zip(s, products)]) for i, x in enumerate(s): products[i] *= x if products != physical_shape: raise ValueError("mesh spec multiplies to the wrong size" "spec=%s physical_shape=%s products=%s" % (spec, physical_shape, products)) sizes_and_strides.reverse() self._physical_coordinates = _logical_to_physical_v1( sizes_and_strides, physical_shape) self._logical_to_physical = [ mtf.processor_coordinates_to_pnum(physical_shape, c) for c in self._physical_coordinates] self._mesh_shape = mtf.Shape( [mtf.Dimension(name, mtf.list_product(s)) for name, s in zip(self._names, projected_shapes)]) @property def logical_to_physical(self): """List of physical processor numbers.""" return list(self._logical_to_physical) @property def mesh_shape(self): return self._mesh_shape @classmethod def spec_to_mesh_shape(cls, spec, num_processors): """Compute mesh shape even without knowing the physical shape. This is useful in cases where the mesh shape must be computed before you know the physical_shape. Args: spec: a list of (string, list-of-integers) pairs num_processors: an integer Returns: a mtf.Shape """ logical_ndims = len(spec) names = [p[0] for p in spec] sizes = [p[1] for p in spec] sizes = [None if s is None else mtf.list_product(s) for s in sizes] if logical_ndims > 0 and sizes[0] is None: sizes[0] = num_processors // mtf.list_product(sizes[1:]) if mtf.list_product(sizes) != num_processors: raise ValueError("product of spec must be num_processors" " spec=%s num_processors=%s" % (spec, num_processors)) return mtf.Shape( [mtf.Dimension(name, s) for name, s in zip(names, sizes)]) def physical_shape_3d_from_topology_proto_4d(mesh_shape): """Convert a 4d shape that we get from TPU estimator to a 3d shape. Args: mesh_shape: a list of length 4 Returns: a list of length 3 """ if len(mesh_shape) != 4 or mesh_shape[2] != 1: raise ValueError("Expected a 4d shape [x, y, 1, core]") return [mesh_shape[1], mesh_shape[0], mesh_shape[3]] def auto_logical_to_physical_tpu(logical_shape, physical_shape, return_coordinates=False): """Set up a mapping from logical to physical cores for TPU. We will try to set up a mapping so that allreduce operations are relatively fast, prioritizing the later dimensions in the mesh_shape. Example: auto_logical_to_physical_tpu( logical_shape=[16, 8], physical_shape=[8, 8, 1, 2]) Heuristics in this function subject to change. Args: logical_shape: a list of integers physical_shape: a list of integers - typically [X, Y, 1, cores] return_coordinates: a boolean - return a list of integer lists (coordinates) instead of a list of processor indices Returns: logical_to_physical: a permutation of range(product(physical_shape))) """ tf.logging.info("auto_logical_to_physical_tpu " "logical_shape=%s physical_shape=%s" % (logical_shape, physical_shape)) if mtf.list_product(logical_shape) != mtf.list_product(physical_shape): raise ValueError( "physical and logical shapes must have the same product " "physical_shape=%s logical_shape=%s" % (physical_shape, logical_shape)) # drop logical dimensions of size 1 logical_shape = [i for i in logical_shape if i != 1] num_cores = mtf.list_product(logical_shape) # For physical shapes different from what we are used to [2^a, 2^b, 2], # return a simple default value (a lexicographic ordering) def _default_value(): default = list(range(num_cores)) if return_coordinates: default = [mtf.pnum_to_processor_coordinates(i) for i in default] return default if len(physical_shape) == 4 and physical_shape[2] == 1: physical_shape = physical_shape_3d_from_topology_proto_4d(physical_shape) elif len(physical_shape) != 3: tf.logging.warning("Unrecognized format for tpu physical shape") return _default_value() # physical_shape is a triple of rows, cols, cores p0, p1, p2 = physical_shape if p2 != 2: return _default_value for dimsize in [p0, p1]: # if dimsize not a power of 2, give up if dimsize & (dimsize - 1): return _default_value() # At this point, the physical shape has at least 1x1x2=2 cores, so there # must be at least one logical dimension. assert logical_shape if len(logical_shape) == 1: # ring of p0 x p1 chips ring = _ring_2d(p0, p1) logical_to_physical = [] for logical_pnum in range(num_cores): core_on_chip = logical_pnum % 2 chip_num = logical_pnum // 2 i, j = ring[chip_num] logical_to_physical.append((i, j, core_on_chip)) else: # We have a p0 x p1 rectangle of chips, which we will tile with rectangular # tiles. The first logical dimension correspond to the number of tiles, # and the other logical dimensions will correspond to position within a # tile. num_tiles = logical_shape[0] tile_chips = num_cores // num_tiles // p2 # If we can, we make each tile occupy exactly one row or column of chips. # Otherwise, we make each tile approximately square. if len(logical_shape) == 2 and tile_chips == p0: t0, t1 = [tile_chips, 1] elif len(logical_shape) == 2 and tile_chips == p1: t0, t1 = [1, tile_chips] else: # try to make the tile approximately square lg_tile_chips = int(math.log(tile_chips, 2)) t0 = 2 ** (lg_tile_chips // 2) # make sure that the tile fits in the mesh - i.e. # t0 <= p0 # t1 == tile_chips // t0 <= p1 t0 = min(t0, p0) t0 = max(t0, tile_chips // p1) t1 = tile_chips // t0 # recursive call to find mapping for one tile tile_logical_to_physical = auto_logical_to_physical_tpu( logical_shape[1:], [t0, t1, p2], return_coordinates=True) tiles_ring = _ring_2d(p0 // t0, p1 // t1) logical_to_physical = [] for logical_pnum in range(num_cores): logical_tile_num = logical_pnum // (t0 * t1 * p2) logical_pos_in_tile = logical_pnum % (t0 * t1 * p2) logical_to_physical.append(( tiles_ring[logical_tile_num][0] * t0 + tile_logical_to_physical[logical_pos_in_tile][0], tiles_ring[logical_tile_num][1] * t1 + tile_logical_to_physical[logical_pos_in_tile][1], tile_logical_to_physical[logical_pos_in_tile][2])) tf.logging.info("auto_logical_to_physical_tpu logical_to_physical = %s" % logical_to_physical) if return_coordinates: return logical_to_physical else: return [mtf.processor_coordinates_to_pnum(physical_shape, coord) for coord in logical_to_physical]