# coding=utf-8 # Copyright 2019 The SEED Authors # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utility functions/classes.""" import collections import threading import timeit from absl import logging import numpy as np import tensorflow as tf import tensorflow_probability as tfp from tensorflow.python.distribute import values as values_lib from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import tensor_conversion_registry # `observation` is the observation *after* a transition. When `done` is True, # `observation` will be the observation *after* the reset. EnvOutput = collections.namedtuple( 'EnvOutput', 'reward done observation abandoned episode_step') Settings = collections.namedtuple( 'Settings', 'strategy inference_devices training_strategy encode decode') def init_learner(num_training_tpus): """Performs common learner initialization.""" if tf.config.experimental.list_logical_devices('TPU'): resolver = tf.distribute.cluster_resolver.TPUClusterResolver('') topology = tf.tpu.experimental.initialize_tpu_system(resolver) strategy = tf.distribute.experimental.TPUStrategy(resolver) training_da = tf.tpu.experimental.DeviceAssignment.build( topology, num_replicas=num_training_tpus) training_strategy = tf.distribute.experimental.TPUStrategy( resolver, device_assignment=training_da) inference_devices = list(set(strategy.extended.worker_devices) - set(training_strategy.extended.worker_devices)) return Settings(strategy, inference_devices, training_strategy, tpu_encode, tpu_decode) else: tf.device('/cpu').__enter__() any_gpu = tf.config.experimental.list_logical_devices('GPU') device_name = '/device:GPU:0' if any_gpu else '/device:CPU:0' strategy = tf.distribute.OneDeviceStrategy(device=device_name) enc = lambda x: x dec = lambda x, s=None: x if s is None else tf.nest.pack_sequence_as(s, x) return Settings(strategy, [device_name], strategy, enc, dec) class UnrollStore(tf.Module): """Utility module for combining individual actor steps into unrolls.""" def __init__(self, num_actors, unroll_length, timestep_specs, num_overlapping_steps=0, name='UnrollStore'): super(UnrollStore, self).__init__(name=name) with self.name_scope: self._full_length = num_overlapping_steps + unroll_length + 1 def create_unroll_variable(spec): z = tf.zeros( [num_actors, self._full_length] + spec.shape.dims, dtype=spec.dtype) return tf.Variable(z, trainable=False, name=spec.name) self._unroll_length = unroll_length self._num_overlapping_steps = num_overlapping_steps self._state = tf.nest.map_structure(create_unroll_variable, timestep_specs) # For each actor, the index into the actor dimension of the tensors in # self._state where we should add the next element. self._index = tf.Variable( tf.fill([num_actors], tf.constant(num_overlapping_steps, tf.int32)), trainable=False, name='index') @property def unroll_specs(self): return tf.nest.map_structure(lambda v: tf.TensorSpec(v.shape[1:], v.dtype), self._state) @tf.function @tf.Module.with_name_scope def append(self, actor_ids, values): """Appends values and returns completed unrolls. Args: actor_ids: 1D tensor with the list of actor IDs for which we append data. There must not be duplicates. values: Values to add for each actor. This is a structure (in the tf.nest sense) of tensors following "timestep_specs", with a batch front dimension which must be equal to the length of 'actor_ids'. Returns: A pair of: - 1D tensor of the actor IDs of the completed unrolls. - Completed unrolls. This is a structure of tensors following 'timestep_specs', with added front dimensions: [num_completed_unrolls, num_overlapping_steps + unroll_length + 1]. """ tf.debugging.assert_equal( tf.shape(actor_ids), tf.shape(tf.unique(actor_ids)[0]), message='Duplicate actor ids') tf.nest.map_structure( lambda s: tf.debugging.assert_equal( tf.shape(actor_ids)[0], tf.shape(s)[0], message='Batch dimension must be same size as number of actors.'), values) curr_indices = self._index.sparse_read(actor_ids) unroll_indices = tf.stack([actor_ids, curr_indices], axis=-1) for s, v in zip(tf.nest.flatten(self._state), tf.nest.flatten(values)): s.scatter_nd_update(unroll_indices, v) # Intentionally not protecting against out-of-bounds to make it possible to # detect completed unrolls. self._index.scatter_add(tf.IndexedSlices(1, actor_ids)) return self._complete_unrolls(actor_ids) @tf.function @tf.Module.with_name_scope def reset(self, actor_ids): """Resets state. Note, this is only intended to be called when actors need to be reset after preemptions. Not at episode boundaries. Args: actor_ids: The actors that need to have their state reset. """ self._index.scatter_update( tf.IndexedSlices(self._num_overlapping_steps, actor_ids)) # The following code is the equivalent of: # s[actor_ids, :j] = 0 j = self._num_overlapping_steps repeated_actor_ids = tf.reshape( tf.tile(tf.expand_dims(tf.cast(actor_ids, tf.int64), -1), [1, j]), [-1]) repeated_range = tf.tile(tf.range(j, dtype=tf.int64), [tf.shape(actor_ids)[0]]) indices = tf.stack([repeated_actor_ids, repeated_range], axis=-1) for s in tf.nest.flatten(self._state): z = tf.zeros(tf.concat([tf.shape(repeated_actor_ids), s.shape[2:]], axis=0), s.dtype) s.scatter_nd_update(indices, z) def _complete_unrolls(self, actor_ids): # Actor with unrolls that are now complete and should be returned. actor_indices = self._index.sparse_read(actor_ids) actor_ids = tf.gather( actor_ids, tf.where(tf.equal(actor_indices, self._full_length))[:, 0]) actor_ids = tf.cast(actor_ids, tf.int64) unrolls = tf.nest.map_structure(lambda s: s.sparse_read(actor_ids), self._state) # Store last transitions as the first in the next unroll. # The following code is the equivalent of: # s[actor_ids, :j] = s[actor_ids, -j:] j = self._num_overlapping_steps + 1 repeated_start_range = tf.tile(tf.range(j, dtype=tf.int64), [tf.shape(actor_ids)[0]]) repeated_end_range = tf.tile( tf.range(self._full_length - j, self._full_length, dtype=tf.int64), [tf.shape(actor_ids)[0]]) repeated_actor_ids = tf.reshape( tf.tile(tf.expand_dims(actor_ids, -1), [1, j]), [-1]) start_indices = tf.stack([repeated_actor_ids, repeated_start_range], -1) end_indices = tf.stack([repeated_actor_ids, repeated_end_range], -1) for s in tf.nest.flatten(self._state): s.scatter_nd_update(start_indices, s.gather_nd(end_indices)) self._index.scatter_update( tf.IndexedSlices(1 + self._num_overlapping_steps, actor_ids)) return actor_ids, unrolls class PrioritizedReplay(tf.Module): """Prioritized Replay Buffer. This buffer is not threadsafe. Make sure you call insert() and sample() from a single thread. """ def __init__(self, size, specs, importance_sampling_exponent, name='PrioritizedReplay'): super(PrioritizedReplay, self).__init__(name=name) self._priorities = tf.Variable(tf.zeros([size]), dtype=tf.float32) self._buffer = tf.nest.map_structure( lambda ts: tf.Variable(tf.zeros([size] + ts.shape, dtype=ts.dtype)), specs) self.num_inserted = tf.Variable(0, dtype=tf.int64) self._importance_sampling_exponent = importance_sampling_exponent @tf.function @tf.Module.with_name_scope def insert(self, values, priorities): """FIFO insertion/removal. Args: values: The batched values to insert. The tensors must be of the same shape and dtype as the `specs` provided in the constructor, except including a batch dimension. priorities: <float32>[batch_size] tensor with the priorities of the elements we insert. Returns: The indices of the inserted values. """ tf.nest.assert_same_structure(values, self._buffer) values = tf.nest.map_structure(tf.convert_to_tensor, values) append_size = tf.nest.flatten(values)[0].shape[0] start_index = self.num_inserted end_index = start_index + append_size # Wrap around insertion. size = self._priorities.shape[0] insert_indices = tf.range(start_index, end_index) % size tf.nest.map_structure( lambda b, v: b.batch_scatter_update( tf.IndexedSlices(v, insert_indices)), self._buffer, values) self.num_inserted.assign_add(append_size) self._priorities.batch_scatter_update( tf.IndexedSlices(priorities, insert_indices)) return insert_indices @tf.function @tf.Module.with_name_scope def sample(self, num_samples, priority_exp): r"""Samples items from the replay buffer, using priorities. Args: num_samples: int, number of replay items to sample. priority_exp: Priority exponent. Every item i in the replay buffer will be sampled with probability: priority[i] ** priority_exp / sum(priority[j] ** priority_exp, j \in [0, num_items)) Set this to 0 in order to get uniform sampling. Returns: Tuple of: - indices: An int64 tensor of shape [num_samples] with the indices in the replay buffer of the sampled items. - weights: A float32 tensor of shape [num_samples] with the normalized weights of the sampled items. - sampled_values: A nested structure following the spec passed in the contructor, where each tensor has an added front batch dimension equal to 'num_samples'. """ tf.debugging.assert_greater_equal( self.num_inserted, tf.constant(0, tf.int64), message='Cannot sample if replay buffer is empty') size = self._priorities.shape[0] limit = tf.minimum(tf.cast(size, tf.int64), self.num_inserted) if priority_exp == 0: indices = tf.random.uniform([num_samples], maxval=limit, dtype=tf.int64) weights = tf.ones_like(indices, dtype=tf.float32) else: prob = self._priorities[:limit]**priority_exp prob /= tf.reduce_sum(prob) indices = tf.random.categorical([tf.math.log(prob)], num_samples)[0] # Importance weights. weights = (((1. / tf.cast(limit, tf.float32)) / tf.gather(prob, indices)) ** self._importance_sampling_exponent) weights /= tf.reduce_max(weights) # Normalize. sampled_values = tf.nest.map_structure( lambda b: b.sparse_read(indices), self._buffer) return indices, weights, sampled_values @tf.function @tf.Module.with_name_scope def update_priorities(self, indices, priorities): """Updates the priorities of the items with the given indices. Args: indices: <int64>[batch_size] tensor with the indices of the items to update. If duplicate indices are provided, the priority that will be set among possible ones is not specified. priorities: <float32>[batch_size] tensor with the new priorities. """ self._priorities.batch_scatter_update(tf.IndexedSlices(priorities, indices)) class HindsightExperienceReplay(PrioritizedReplay): """Replay Buffer with Hindsight Experience Replay. Hindsight goals are sampled uniformly from subsequent steps in the same window (`future` strategy from https://arxiv.org/pdf/1707.01495). They are not guaranteed to come from the same episode. This buffer is not threadsafe. Make sure you call insert() and sample() from a single thread. """ def __init__(self, size, specs, importance_sampling_exponent, compute_reward_fn, unroll_length, substitution_probability, name='HindsightExperienceReplay'): super(HindsightExperienceReplay, self).__init__( size, specs, importance_sampling_exponent, name) self._compute_reward_fn = compute_reward_fn self._unroll_length = unroll_length self._substitution_probability = substitution_probability @tf.Module.with_name_scope def sample(self, num_samples, priority_exp): indices, weights, sampled_values = super( HindsightExperienceReplay, self).sample(num_samples, priority_exp) observation = sampled_values.env_outputs.observation batch_size, time_horizon = observation['achieved_goal'].shape[:2] def compute_goal_reward(): # reward[batch][time] is the reward on transition from timestep time-1 # to time. This function outputs incorrect rewards for the last transition # in each episode but we filter such cases later. goal_reward = self._compute_reward_fn( achieved_goal=observation['achieved_goal'][:, 1:], desired_goal=observation['desired_goal'][:, :-1]) return tf.concat(values=[goal_reward[:, :1] * np.nan, goal_reward], axis=1) # Substitute goals. old_goal_reward = compute_goal_reward() assert old_goal_reward.shape == observation['achieved_goal'].shape[:-1] goal_ind = tf.concat( values=[tf.random.uniform((batch_size, 1), min(t + 1, time_horizon - 1), time_horizon, dtype=tf.int32) for t in range(time_horizon)], axis=1) substituted_goal = tf.gather(observation['achieved_goal'], goal_ind, axis=1, batch_dims=1) mask = tf.cast(tfp.distributions.Bernoulli( probs=self._substitution_probability * tf.ones(goal_ind.shape)).sample(), observation['desired_goal'].dtype) # We don't substitute goals for the last states in each episodes because we # don't store the next states for them. mask *= tf.cast(~sampled_values.env_outputs.done, observation['desired_goal'].dtype) mask = mask[..., tf.newaxis] observation['desired_goal'] = ( mask * substituted_goal + (1 - mask) * observation['desired_goal']) # Substitude reward new_goal_reward = compute_goal_reward() assert new_goal_reward.shape == observation['achieved_goal'].shape[:-1] sampled_values = sampled_values._replace( env_outputs=sampled_values.env_outputs._replace( reward=sampled_values.env_outputs.reward + (new_goal_reward - old_goal_reward) * tf.cast( ~sampled_values.env_outputs.done, tf.float32) )) # Subsample unrolls of length unroll_length + 1. assert time_horizon >= self._unroll_length + 1 unroll_begin_ind = tf.random.uniform( (batch_size,), 0, time_horizon - self._unroll_length, dtype=tf.int32) unroll_inds = unroll_begin_ind[:, tf.newaxis] + tf.math.cumsum( tf.ones((batch_size, self._unroll_length + 1), tf.int32), axis=1, exclusive=True) subsampled_values = tf.nest.map_structure( lambda t: tf.gather(t, unroll_inds, axis=1, batch_dims=1), sampled_values) if hasattr(sampled_values, 'agent_state'): # do not subsample the state subsampled_values = subsampled_values._replace( agent_state=sampled_values.agent_state) return indices, weights, subsampled_values class Aggregator(tf.Module): """Utility module for keeping state and statistics for individual actors.""" def __init__(self, num_actors, specs, name='Aggregator'): """Inits an Aggregator. Args: num_actors: int, number of actors. specs: Structure (as defined by tf.nest) of tf.TensorSpecs that will be stored for each actor. name: Name of the scope for the operations. """ super(Aggregator, self).__init__(name=name) def create_variable(spec): z = tf.zeros([num_actors] + spec.shape.dims, dtype=spec.dtype) return tf.Variable(z, trainable=False, name=spec.name) self._state = tf.nest.map_structure(create_variable, specs) @tf.Module.with_name_scope def reset(self, actor_ids): """Fills the tensors for the given actors with zeros.""" with tf.name_scope('Aggregator_reset'): for s in tf.nest.flatten(self._state): s.scatter_update(tf.IndexedSlices(0, actor_ids)) @tf.Module.with_name_scope def add(self, actor_ids, values): """In-place adds values to the state associated to the given actors. Args: actor_ids: 1D tensor with the list of actor IDs we want to add values to. values: A structure of tensors following the input spec, with an added first dimension that must either have the same size as 'actor_ids', or should not exist (in which case, the value is broadcasted to all actor ids). """ tf.nest.assert_same_structure(values, self._state) for s, v in zip(tf.nest.flatten(self._state), tf.nest.flatten(values)): s.scatter_add(tf.IndexedSlices(v, actor_ids)) @tf.Module.with_name_scope def read(self, actor_ids): """Reads the values corresponding to a list of actors. Args: actor_ids: 1D tensor with the list of actor IDs we want to read. Returns: A structure of tensors with the same shapes as the input specs. A dimension is added in front of each tensor, with size equal to the number of actor_ids provided. """ return tf.nest.map_structure(lambda s: s.sparse_read(actor_ids), self._state) @tf.Module.with_name_scope def replace(self, actor_ids, values): """Replaces the state associated to the given actors. Args: actor_ids: 1D tensor with the list of actor IDs. values: A structure of tensors following the input spec, with an added first dimension that must either have the same size as 'actor_ids', or should not exist (in which case, the value is broadcasted to all actor ids). """ tf.nest.assert_same_structure(values, self._state) for s, v in zip(tf.nest.flatten(self._state), tf.nest.flatten(values)): s.scatter_update(tf.IndexedSlices(v, actor_ids)) class ProgressLogger(object): """Helper class for performing periodic logging of the training progress.""" def __init__(self, summary_writer=None, initial_period=0.01, period_factor=1.01, max_period=10.0): """Constructs ProgressLogger. Args: summary_writer: Tensorflow summary writer to use. initial_period: Initial logging period in seconds (how often logging happens). period_factor: Factor by which logging period is multiplied after each iteration (exponential back-off). max_period: Maximal logging period in seconds (the end of exponential back-off). """ self.summary_writer = summary_writer self.period = initial_period self.period_factor = period_factor self.max_period = max_period # Array of strings with names of values to be logged. self.log_keys = [] self.log_keys_set = set() self.step_cnt = tf.Variable(-1, dtype=tf.int64) self.ready_values = tf.Variable([-1.0], dtype=tf.float32, shape=tf.TensorShape(None)) self.logger_thread = None self.logging_callback = None self.terminator = None self.last_log_time = timeit.default_timer() self.last_log_step = 0 def start(self, logging_callback=None): assert self.logger_thread is None self.logging_callback = logging_callback self.terminator = threading.Event() self.logger_thread = threading.Thread(target=self._logging_loop) self.logger_thread.start() def shutdown(self): assert self.logger_thread self.terminator.set() self.logger_thread.join() self.logger_thread = None def log_session(self): return [] def log(self, session, name, value): # this is a python op so it happens only when this tf.function is compiled if name not in self.log_keys_set: self.log_keys.append(name) self.log_keys_set.add(name) # this is a TF op. session.append(value) def log_session_from_dict(self, dic): session = self.log_session() for key in dic: self.log(session, key, dic[key]) return session def step_end(self, session, strategy=None, step_increment=1): logs = [] for value in session: if strategy: value = tf.reduce_mean(tf.cast( strategy.experimental_local_results(value)[0], tf.float32)) logs.append(value) self.ready_values.assign(logs) self.step_cnt.assign_add(step_increment) def _log(self): """Perform single round of logging.""" logging_time = timeit.default_timer() step_cnt = self.step_cnt.read_value() values = self.ready_values.read_value().numpy() if values[0] == -1: return assert len(values) == len( self.log_keys ), 'Mismatch between number of keys and values to log: %r vs %r' % ( values, self.log_keys) if self.summary_writer: self.summary_writer.set_as_default() tf.summary.experimental.set_step(step_cnt.numpy()) if self.logging_callback: self.logging_callback() for key, value in zip(self.log_keys, values): tf.summary.scalar(key, value) dt = logging_time - self.last_log_time df = tf.cast(step_cnt - self.last_log_step, tf.float32) tf.summary.scalar('speed/steps_per_sec', df / dt) self.last_log_time, self.last_log_step = logging_time, step_cnt def _logging_loop(self): last_log_try = timeit.default_timer() while not self.terminator.isSet(): self._log() now = timeit.default_timer() elapsed = now - last_log_try last_log_try = now self.period = min(self.period_factor * self.period, self.max_period) self.terminator.wait(timeout=max(0, self.period - elapsed)) class StructuredFIFOQueue(tf.queue.FIFOQueue): """A tf.queue.FIFOQueue that supports nests and tf.TensorSpec.""" def __init__(self, capacity, specs, shared_name=None, name='structured_fifo_queue'): self._specs = specs self._flattened_specs = tf.nest.flatten(specs) dtypes = [ts.dtype for ts in self._flattened_specs] shapes = [ts.shape for ts in self._flattened_specs] super(StructuredFIFOQueue, self).__init__(capacity, dtypes, shapes) def dequeue(self, name=None): result = super(StructuredFIFOQueue, self).dequeue(name=name) return tf.nest.pack_sequence_as(self._specs, result) def dequeue_many(self, batch_size, name=None): result = super(StructuredFIFOQueue, self).dequeue_many( batch_size, name=name) return tf.nest.pack_sequence_as(self._specs, result) def enqueue(self, vals, name=None): tf.nest.assert_same_structure(vals, self._specs) return super(StructuredFIFOQueue, self).enqueue( tf.nest.flatten(vals), name=name) def enqueue_many(self, vals, name=None): tf.nest.assert_same_structure(vals, self._specs) return super(StructuredFIFOQueue, self).enqueue_many( tf.nest.flatten(vals), name=name) def batch_apply(fn, inputs): """Folds time into the batch dimension, runs fn() and unfolds the result. Args: fn: Function that takes as input the n tensors of the tf.nest structure, with shape [time*batch, <remaining shape>], and returns a tf.nest structure of batched tensors. inputs: tf.nest structure of n [time, batch, <remaining shape>] tensors. Returns: tf.nest structure of [time, batch, <fn output shape>]. Structure is determined by the output of fn. """ time_to_batch_fn = lambda t: tf.reshape(t, [-1] + t.shape[2:].as_list()) batched = tf.nest.map_structure(time_to_batch_fn, inputs) output = fn(*batched) prefix = [int(tf.nest.flatten(inputs)[0].shape[0]), -1] batch_to_time_fn = lambda t: tf.reshape(t, prefix + t.shape[1:].as_list()) return tf.nest.map_structure(batch_to_time_fn, output) def make_time_major(x): """Transposes the batch and time dimensions of a nest of Tensors. If an input tensor has rank < 2 it returns the original tensor. Retains as much of the static shape information as possible. Args: x: A nest of Tensors. Returns: x transposed along the first two dimensions. """ def transpose(t): t_static_shape = t.shape if t_static_shape.rank is not None and t_static_shape.rank < 2: return t t_rank = tf.rank(t) t_t = tf.transpose(t, tf.concat(([1, 0], tf.range(2, t_rank)), axis=0)) t_t.set_shape( tf.TensorShape([t_static_shape[1], t_static_shape[0]]).concatenate(t_static_shape[2:])) return t_t return tf.nest.map_structure( lambda t: tf.xla.experimental.compile(transpose, [t])[0], x) class TPUEncodedUInt8Spec(tf.TypeSpec): """Type specification for composite tensor TPUEncodedUInt8.""" def __init__(self, encoded_shape, original_shape): self._value_specs = (tf.TensorSpec(encoded_shape, tf.uint32),) self.original_shape = original_shape @property def _component_specs(self): return self._value_specs def _to_components(self, value): return (value.encoded,) def _from_components(self, components): return TPUEncodedUInt8(components[0], self.original_shape) def _serialize(self): return self._value_specs[0].shape, self.original_shape def _to_legacy_output_types(self): return self._value_specs[0].dtype def _to_legacy_output_shapes(self): return self._value_specs[0].shape @property def value_type(self): return TPUEncodedUInt8 class TPUEncodedUInt8(composite_tensor.CompositeTensor): def __init__(self, encoded, shape): self.encoded = encoded self.original_shape = shape self._spec = TPUEncodedUInt8Spec(encoded.shape, tf.TensorShape(shape)) @property def _type_spec(self): return self._spec tensor_conversion_registry.register_tensor_conversion_function( TPUEncodedUInt8, lambda value, *unused_args, **unused_kwargs: value.encoded) class TPUEncodedF32Spec(tf.TypeSpec): """Type specification for composite tensor TPUEncodedF32Spec.""" def __init__(self, encoded_shape, original_shape): self._value_specs = (tf.TensorSpec(encoded_shape, tf.float32),) self.original_shape = original_shape @property def _component_specs(self): return self._value_specs def _to_components(self, value): return (value.encoded,) def _from_components(self, components): return TPUEncodedF32(components[0], self.original_shape) def _serialize(self): return self._value_specs[0].shape, self.original_shape def _to_legacy_output_types(self): return self._value_specs[0].dtype def _to_legacy_output_shapes(self): return self._value_specs[0].shape @property def value_type(self): return TPUEncodedF32 class TPUEncodedF32(composite_tensor.CompositeTensor): def __init__(self, encoded, shape): self.encoded = encoded self.original_shape = shape self._spec = TPUEncodedF32Spec(encoded.shape, tf.TensorShape(shape)) @property def _type_spec(self): return self._spec tensor_conversion_registry.register_tensor_conversion_function( TPUEncodedF32, lambda value, *unused_args, **unused_kwargs: value.encoded) def num_divisible(v, m): return sum([1 for x in v if x % m == 0]) def tpu_encode(ts): """Encodes a nest of Tensors in a suitable way for TPUs. TPUs do not support tf.uint8, tf.uint16 and other data types. Furthermore, the speed of transfer and device reshapes depend on the shape of the data. This function tries to optimize the data encoding for a number of use cases. Should be used on CPU before sending data to TPU and in conjunction with `tpu_decode` after the data is transferred. Args: ts: A tf.nest of Tensors. Returns: A tf.nest of encoded Tensors. """ def visit(t): num_elements = t.shape.num_elements() # We need a multiple of 128 elements: encoding reduces the number of # elements by a factor 4 (packing uint8s into uint32s), and first thing # decode does is to reshape with a 32 minor-most dimension. if (t.dtype == tf.uint8 and num_elements is not None and num_elements % 128 == 0): # For details of these transformations, see b/137182262. x = tf.xla.experimental.compile( lambda x: tf.transpose(x, list(range(1, t.shape.rank)) + [0]), [t])[0] x = tf.reshape(x, [-1, 4]) x = tf.bitcast(x, tf.uint32) x = tf.reshape(x, [-1]) return TPUEncodedUInt8(x, t.shape) elif t.dtype == tf.uint8: logging.warning('Inefficient uint8 transfer with shape: %s', t.shape) return tf.cast(t, tf.bfloat16) elif t.dtype == tf.uint16: return tf.cast(t, tf.int32) elif (t.dtype == tf.float32 and t.shape.rank > 1 and not (num_divisible(t.shape.dims, 128) >= 1 and num_divisible(t.shape.dims, 8) >= 2)): x = tf.reshape(t, [-1]) return TPUEncodedF32(x, t.shape) else: return t return tf.nest.map_structure(visit, ts) def tpu_decode(ts, structure=None): """Decodes a nest of Tensors encoded with tpu_encode. Args: ts: A nest of Tensors or TPUEncodedUInt8 composite tensors. structure: If not None, a nest of Tensors or TPUEncodedUInt8 composite tensors (possibly within PerReplica's) that are only used to recreate the structure of `ts` which then should be a list without composite tensors. Returns: A nest of decoded tensors packed as `structure` if available, otherwise packed as `ts`. """ def visit(t, s): s = s.values[0] if isinstance(s, values_lib.PerReplica) else s if isinstance(s, TPUEncodedUInt8): x = t.encoded if isinstance(t, TPUEncodedUInt8) else t x = tf.reshape(x, [-1, 32, 1]) x = tf.broadcast_to(x, x.shape[:-1] + [4]) x = tf.reshape(x, [-1, 128]) x = tf.bitwise.bitwise_and(x, [0xFF, 0xFF00, 0xFF0000, 0xFF000000] * 32) x = tf.bitwise.right_shift(x, [0, 8, 16, 24] * 32) rank = s.original_shape.rank perm = [rank - 1] + list(range(rank - 1)) inverted_shape = np.array(s.original_shape)[np.argsort(perm)] x = tf.reshape(x, inverted_shape) x = tf.transpose(x, perm) return x elif isinstance(s, TPUEncodedF32): x = t.encoded if isinstance(t, TPUEncodedF32) else t x = tf.reshape(x, s.original_shape) return x else: return t return tf.nest.map_structure(visit, ts, structure or ts) def split_structure(structure, prefix_length, axis=0): """Splits in two a tf.nest structure of tensors along the first axis.""" flattened = tf.nest.flatten(structure) split = [tf.split(x, [prefix_length, tf.shape(x)[axis] - prefix_length], axis=axis) for x in flattened] flattened_prefix = [pair[0] for pair in split] flattened_suffix = [pair[1] for pair in split] return (tf.nest.pack_sequence_as(structure, flattened_prefix), tf.nest.pack_sequence_as(structure, flattened_suffix)) class nullcontext(object): def __init__(self, *args, **kwds): del args # unused del kwds # unused def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): pass