# Copyright 2018 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Functions to create the implementation graph.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import copy import hashlib # GOOGLE-INITIALIZATION import tensorflow as tf from tensorflow_transform import analyzer_nodes from tensorflow_transform import graph_tools from tensorflow_transform import impl_helper from tensorflow_transform import nodes from tensorflow_transform import tf_utils from tensorflow_transform.beam import analyzer_cache from tensorflow_transform.beam import beam_nodes from tensorflow_transform.beam import combiner_packing_util # Used for debugging only. This will point to the most recent graph built. _ANALYSIS_GRAPH = None def _serialize_op_attr(op_attr): """Deterministicly serializes tf.Operation attrs since it is a map.""" sorted_attributes = sorted(op_attr.items(), key=lambda kv: kv[0]) if 'f' in op_attr: # This is a tf.Function node, and it includes attributes that are # inconsistent across runs such as _gradient_op_type, config_proto, so we # only keep input and output types since other information will arrive from # the FuncGraph attributes. sorted_attributes = [ kv for kv in sorted_attributes if kv[0] in ('Tin', 'Tout') ] result = [] for key, attr_value in sorted_attributes: result.append(key) attr_value = copy.deepcopy(attr_value) if attr_value.list.func: raise ValueError( 'Unable to serialize op attributes that contain a `list.func` field') if attr_value.HasField('func'): # There should be a separate call for the FuncGraph attributes. attr_value.ClearField('func') result.append(attr_value.SerializeToString()) return result def _describe_path_as_analyzer_cache_hash(x, parents=None): """Constructs a hash to describe a unique TF graph path. Note: We do not rely on names for hashing since it can be fragile. Args: x: One of (None, tf.Operation, tf.Tensor, str), the current TF graph node. parents: (Optional) a list of bytes, results of previous calls to this function, where x was an ancestor to the current node x. Returns: A bytes hash of the path from x to its sources. None if x is None. """ # This may happen in cases where tensors are outputs of previous analyzers, # we don't need to describe a path for those. if x is None: assert parents is None return None parents = parents or [] if any(p is None for p in parents): return None if isinstance(x, tf.Operation): values = _serialize_op_attr(x.node_def.attr) elif isinstance(x, tf.Tensor): # No need to add x.op to the hash since that should be included in parents. values = [tf.compat.as_str_any(x.value_index)] else: assert isinstance(x, (str, bytes)) values = [x] h = hashlib.sha1() for value in values: encoded = tf.compat.as_bytes(value) h.update(encoded) for p in parents: h.update(p) return h.digest() def _tensor_name(tensor): """Get a name of a tensor without trailing ":0" when relevant.""" # tensor.name is unicode in Python 3 and bytes in Python 2 so convert to # bytes here. name = str(tensor.name) return name[:-2] if name.endswith(':0') else name class _ReadyVisitor(nodes.Visitor): """Visitor to determine if a node is ready to run.""" def __init__(self, graph_analyzer): self._graph_analyzer = graph_analyzer def visit(self, operation_def, input_values): if isinstance(operation_def, analyzer_nodes.TensorSource): is_ready = all(self._graph_analyzer.ready_to_run(tensor) for tensor in operation_def.tensors) else: is_ready = all(input_values) return (is_ready,) * operation_def.num_outputs def validate_value(self, value): assert isinstance(value, bool) class _TranslateVisitor(nodes.Visitor): """Visitor that translates the operation graph. The original graph is defined by the user in the preprocessing_fn. The translated graph represents a Beam pipeline. """ def __init__(self): self.phase = None self.extracted_values_dict = None self.intermediate_output_signature = None def visit(self, operation_def, input_values): if isinstance(operation_def, analyzer_nodes.TensorSource): tensors = operation_def.tensors label = operation_def.label # Add tensor to signature so it gets produced by the SavedModel. for tensor in tensors: self.intermediate_output_signature[_tensor_name(tensor)] = tensor keys = tuple(map(_tensor_name, tensors)) output = nodes.apply_operation( beam_nodes.ExtractFromDict, self.extracted_values_dict, keys=keys, label=label) return (output,) else: return nodes.OperationNode(operation_def, input_values).outputs def validate_value(self, value): assert isinstance(value, nodes.ValueNode) class _OptimizationView( collections.namedtuple('_OptimizationView', [ 'prefer_fine_grained_view', 'flattened_view', 'fine_grained_view', 'hashed_path' ])): """A container for operation outputs during _OptimizeVisitor traversal. This is used in order to maintain both a flattened view, and a fine grained view that can be used for caching. `prefer_fine_grained_view` is a hint that means that if True, the `fine_grained_view` should be used. It should be set to true if the upstream view has cacheing operations that haven't been flattened yet. """ def __init__(self, prefer_fine_grained_view, flattened_view, fine_grained_view, hashed_path): if prefer_fine_grained_view and not fine_grained_view: raise ValueError( 'Cannot prefer fine_grained_view when one is not provided') del hashed_path self._validate_flattened_view(flattened_view) self._validate_fine_grained_view(fine_grained_view) super(_OptimizationView, self).__init__() def _validate_flattened_view(self, view): assert view is self.flattened_view assert view is not None assert isinstance(view, nodes.ValueNode), view def _validate_fine_grained_view(self, view): assert view is self.fine_grained_view if view is None: return assert isinstance(view, collections.OrderedDict), view for value in view.values(): assert isinstance(value, nodes.ValueNode), value class _OptimizeVisitor(nodes.Visitor): """Visitor optimizes the operation graph (see nodes.py). This operates on the translated graph which is emitted by the `_TranslateVisitor`, and performs optimizations. Namely, when enabled, this enables reading and writing from/to analyzer accumulator cache to avoid recomputing them over already seen datasets. This type of optimization requires also creating a partitioned view of the input data, according to the `is_partitionable` annotation. """ def __init__(self, dataset_keys, cache_dict, tensor_keys_to_paths, cache_output_nodes): """Init method for _OptimizeVisitor. Args: dataset_keys: An iterable of strings which are keys for a partitioned dataset. cache_dict: A dictionary of input cache that can be used in place of a cacheable accumulate operation. A dictionary from dataset_keys to dictionaries of cache keys to PCollections. This can be None if there is no cache. tensor_keys_to_paths: A dictionary from a tensor key to a unique TF graph path hash. cache_output_nodes: A dictionary from (dataset_key, cache_key) to encoded cache ValueNode. This is the output cache for this graph. """ self._sorted_dataset_keys = sorted(dataset_keys) self._cache_dict = cache_dict self._tensor_keys_to_paths = tensor_keys_to_paths self.cache_output_nodes = cache_output_nodes def _validate_operation_def(self, operation_def): if operation_def.cache_coder is not None: if not operation_def.is_partitionable: raise ValueError('Non partitionable OperationDefs cannot be cacheable') if operation_def.is_partitionable or operation_def.cache_coder is not None: if operation_def.num_outputs != 1: raise ValueError('Cacheable OperationDefs must have exactly 1 output') def _make_next_hashed_path(self, parent_hashed_paths, operation_def): # Making a copy of parent_hashed_paths. paths_to_hash = list(parent_hashed_paths) paths_to_hash.append(tf.compat.as_bytes(operation_def.__class__.__name__)) if isinstance(operation_def, beam_nodes.ExtractFromDict): for key in operation_def.keys: path = self._tensor_keys_to_paths[key] paths_to_hash.append(path) else: for attr in sorted( [x for x in dir(operation_def) if x not in operation_def._fields]): if attr.startswith('_') or callable(getattr(operation_def, attr)): continue paths_to_hash.append( tf.compat.as_bytes(str((attr, getattr(operation_def, attr))))) for field in operation_def._fields: paths_to_hash.append( tf.compat.as_bytes( str((field, operation_def.get_field_str(field))))) hash_container = hashlib.sha1() for path in paths_to_hash: if path is None: return None hash_container.update(path) return hash_container.digest() def visit(self, operation_def, input_values): self._validate_operation_def(operation_def) if (isinstance(operation_def, beam_nodes.ApplySavedModel) and operation_def.phase == 0): return self._visit_apply_savedmodel_operation(operation_def, input_values) # When self._cache_dict is None this means that we shouldn't do any cacheing # for this pipeline, and so there's no need to create any fine grained # views. if self._cache_dict is not None and operation_def.is_partitionable: return self._visit_partitionable_operation(operation_def, input_values) if input_values and any(v.fine_grained_view and v.prefer_fine_grained_view for v in input_values): # We can 'flatten' the cached outputs of the parent operation since this # operation doesn't support partitioning. disaggregated_input_values = [] for view in input_values: disaggregated_input_values.extend(view.fine_grained_view.values()) # Checking that all cache has the same size. assert len({len(value) for value in disaggregated_input_values}) == 1 next_inputs = nodes.apply_multi_output_operation( beam_nodes.Flatten, *disaggregated_input_values, label='FlattenCache[{}]'.format(operation_def.label)) else: # Parent operation output is not cacheable, therefore we can just use # a flattened view. next_inputs = tuple(v.flattened_view for v in input_values) flattened_view = nodes.OperationNode(operation_def, next_inputs).outputs return tuple( _OptimizationView( # pylint: disable=g-complex-comprehension prefer_fine_grained_view=False, flattened_view=flat, fine_grained_view=None, hashed_path=None) for flat in flattened_view) def _visit_partitionable_operation(self, operation_def, upstream_views): # This is a hint for whether or not the `fine_grained_view` should be used # downstream. It should be set to true if either the upstream view has # cacheing operations that haven't been flattened yet, or the current # operation is cacheable. all_fine_grained_views_available = all( v.fine_grained_view for v in upstream_views) prefer_fine_grained_view = ( any(v.prefer_fine_grained_view for v in upstream_views) or all_fine_grained_views_available and operation_def.cache_coder is not None) next_hashed_path = self._make_next_hashed_path( [v.hashed_path for v in upstream_views], operation_def) if all_fine_grained_views_available: fine_grained_views = (self._apply_operation_on_fine_grained_view( operation_def, tuple(v.fine_grained_view for v in upstream_views), next_hashed_path),) else: fine_grained_views = (None,) * operation_def.num_outputs flattened_views = nodes.OperationNode( operation_def, tuple(v.flattened_view for v in upstream_views)).outputs assert len(fine_grained_views) == len(flattened_views) return tuple( _OptimizationView( # pylint: disable=g-complex-comprehension prefer_fine_grained_view=prefer_fine_grained_view, flattened_view=flat, fine_grained_view=fine, hashed_path=next_hashed_path) for flat, fine in zip(flattened_views, fine_grained_views)) def _apply_operation_on_fine_grained_view(self, operation_def, fine_grained_views, next_hashed_path): """Applies a shardable operation on a fine grained view. This also updates `cache_output_nodes` when necessary. Args: operation_def: A shardable `OperationDef`. fine_grained_views: A tuple of `_OptimizationView.fine_grained_view`s. next_hashed_path: The hashed path for the currently processed operation_def. Returns: The resulting list of `_OptimizationView.fine_grained_view`s. """ result_fine_grained_view = collections.OrderedDict() cache_entry_key = analyzer_cache.make_cache_entry_key( tf.compat.as_bytes(operation_def.label) + b'-' + next_hashed_path) for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys): # We use an index for the label in order to make beam labels more stable. infix = 'AnalysisIndex{}'.format(dataset_idx) if (operation_def.cache_coder and self._cache_dict.get( dataset_key, {}).get(cache_entry_key) is not None): decode_cache = analyzer_nodes.DecodeCache( dataset_key, cache_entry_key, coder=operation_def.cache_coder, label='DecodeCache[{}][{}]'.format(operation_def.label, infix)) (op_output,) = nodes.OperationNode(decode_cache, tuple()).outputs else: value_nodes = tuple(v[dataset_key] for v in fine_grained_views) (op_output,) = nodes.OperationNode( operation_def._replace( label='{}[{}]'.format(operation_def.label, infix)), value_nodes).outputs if operation_def.cache_coder: encode_cache = nodes.apply_operation( analyzer_nodes.EncodeCache, op_output, coder=operation_def.cache_coder, label='EncodeCache[{}][{}]'.format(operation_def.label, infix)) self.cache_output_nodes[(dataset_key, cache_entry_key)] = encode_cache result_fine_grained_view[dataset_key] = op_output return result_fine_grained_view def _visit_apply_savedmodel_operation(self, operation_def, upstream_views): if any(v.fine_grained_view for v in upstream_views): raise ValueError( 'Was not expecting a fine_grained_view input for ApplySavedModel') (saved_model_path_upstream_view, input_upstream_view) = upstream_views fine_grained_view = collections.OrderedDict() for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys): infix = 'AnalysisIndex{}'.format(dataset_idx) input_node = nodes.apply_operation( beam_nodes.ExtractInputForSavedModel, dataset_key=dataset_key, label='ExtractInputForSavedModel[{}]'.format(infix)) # We use an index for the label in order to make beam labels more stable. (fine_grained_view[dataset_key],) = ( nodes.OperationNode( operation_def._replace( label='{}[{}]'.format(operation_def.label, infix)), (saved_model_path_upstream_view.flattened_view, input_node)).outputs) (flattened_view,) = nodes.OperationNode( operation_def, (saved_model_path_upstream_view.flattened_view, input_upstream_view.flattened_view)).outputs return (_OptimizationView( prefer_fine_grained_view=False, flattened_view=flattened_view, fine_grained_view=fine_grained_view, hashed_path=b'APPLY_SAVEDMODEL'),) def validate_value(self, value): assert isinstance(value, _OptimizationView), value if value.fine_grained_view: assert set(value.fine_grained_view.keys()) == set( self._sorted_dataset_keys), ('{} != {}'.format( value.fine_grained_view.keys(), self._sorted_dataset_keys)) def _perform_cache_optimization(saved_model_future, dataset_keys, tensor_keys_to_paths, cache_dict): """Performs cache optimization on the given graph.""" cache_output_nodes = {} optimize_visitor = _OptimizeVisitor(dataset_keys or {}, cache_dict, tensor_keys_to_paths, cache_output_nodes) optimize_traverser = nodes.Traverser(optimize_visitor) optimized = optimize_traverser.visit_value_node( saved_model_future).flattened_view if cache_dict is None: assert not cache_output_nodes cache_output_nodes = None return optimized, cache_output_nodes class _InspectVisitor(nodes.Visitor): """A visitor that inspects the graph and looks for dataset keys in use.""" def __init__(self, required_dataset_keys_output): self._required_dataset_keys = required_dataset_keys_output def visit(self, operation_def, input_values): if isinstance(operation_def, beam_nodes.ExtractInputForSavedModel): self._required_dataset_keys.add(operation_def.dataset_key) return nodes.OperationNode(operation_def, input_values).outputs def validate_value(self, value): assert isinstance(value, nodes.ValueNode) def _build_analysis_graph_for_inspection( preprocessing_fn, specs, dataset_keys, input_cache): """Builds the analysis graph for inspection.""" with tf.compat.v1.Graph().as_default() as graph: with tf.compat.v1.name_scope('inputs'): input_signature = impl_helper.batched_placeholders_from_specs(specs) # TODO(b/34288791): This needs to be exactly the same as in impl.py copied_inputs = impl_helper.copy_tensors(input_signature) output_signature = preprocessing_fn(copied_inputs) transform_fn_future, cache_dict = build( graph, input_signature, output_signature, dataset_keys=dataset_keys, cache_dict=input_cache) return transform_fn_future, cache_dict def get_analysis_dataset_keys( preprocessing_fn, specs, dataset_keys, input_cache): """Computes the dataset keys that are required in order to perform analysis. Args: preprocessing_fn: A tf.transform preprocessing_fn. specs: A dict of feature name to feature specification or tf.TypeSpecs. dataset_keys: A set of strings which are dataset keys, they uniquely identify these datasets across analysis runs. input_cache: A cache dictionary. Returns: A set of dataset keys that are required for analysis. """ transform_fn_future, _ = _build_analysis_graph_for_inspection( preprocessing_fn, specs, dataset_keys, input_cache) result = set() inspect_visitor = _InspectVisitor(result) inspect_traverser = nodes.Traverser(inspect_visitor) _ = inspect_traverser.visit_value_node(transform_fn_future) # If None is present this means that a flattened version of the entire dataset # is required, therefore this will be returning all of the given dataset_keys. if any(k.is_flattened_dataset_key() for k in result): result = dataset_keys return result def get_analysis_cache_entry_keys(preprocessing_fn, feature_spec, dataset_keys): """Computes the cache entry keys that would be useful for analysis. Args: preprocessing_fn: A tf.transform preprocessing_fn. feature_spec: A dict of feature name to feature specification. dataset_keys: A set of strings which are dataset keys, they uniquely identify these datasets across analysis runs. Returns: A set of cache entry keys which would be useful for analysis. """ _, cache_dict = _build_analysis_graph_for_inspection( preprocessing_fn, feature_spec, dataset_keys, {}) return set([cache_key for _, cache_key in cache_dict.keys()]) def build(graph, input_signature, output_signature, dataset_keys=None, cache_dict=None): """Returns a list of `Phase`s describing how to execute the pipeline. The default graph is assumed to contain some `Analyzer`s which must be executed by doing a full pass over the dataset, and passing the inputs for that analyzer into some implementation, then taking the results and replacing the `Analyzer`s outputs with constants in the graph containing these results. The execution plan is described by a list of `Phase`s. Each phase contains a list of `Analyzer`s, which are the `Analyzer`s which are ready to run in that phase, together with a list of ops, which are the table initializers that are ready to run in that phase. An `Analyzer` or op is ready to run when all its dependencies in the graph have been computed. Thus if the graph is constructed by def preprocessing_fn(input) x = inputs['x'] scaled_0 = x - tft.min(x) scaled_0_1 = scaled_0 / tft.max(scaled_0) Then the first phase will contain the analyzer corresponding to the call to `min`, because `x` is an input and so is ready to compute in the first phase, while the second phase will contain the analyzer corresponding to the call to `max` since `scaled_1` depends on the result of the call to `tft.min` which is computed in the first phase. More generally, we define a level for each op and each `Analyzer` by walking the graph, assigning to each operation the max level of its inputs, to each `Tensor` the level of its operation, unless it's the output of an `Analyzer` in which case we assign the level of its `Analyzer` plus one. Args: graph: A `tf.Graph`. input_signature: A dict whose keys are strings and values are `Tensor`s or `SparseTensor`s. output_signature: A dict whose keys are strings and values are `Tensor`s or `SparseTensor`s. dataset_keys: (Optional) A set of strings which are dataset keys, they uniquely identify these datasets across analysis runs. cache_dict: (Optional): A cache dictionary. Returns: A pair of: * list of `Phase`s * A dictionary of output cache `ValueNode`s. Raises: ValueError: if the graph cannot be analyzed. """ tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS) graph.clear_collection(analyzer_nodes.TENSOR_REPLACEMENTS) phase = 0 tensor_bindings = [] sink_tensors_ready = { tf_utils.hashable_tensor_or_op(tensor_sink.tensor): False for tensor_sink in tensor_sinks } translate_visitor = _TranslateVisitor() translate_traverser = nodes.Traverser(translate_visitor) analyzers_input_signature = {} graph_analyzer = None extracted_input_node = nodes.apply_operation( beam_nodes.ExtractInputForSavedModel, dataset_key=analyzer_cache._make_flattened_dataset_key(), # pylint: disable=protected-access label='ExtractInputForSavedModel[FlattenedDataset]') while not all(sink_tensors_ready.values()): infix = 'Phase{}'.format(phase) # Determine which table init ops are ready to run in this phase # Determine which keys of pending_tensor_replacements are ready to run # in this phase, based in whether their dependencies are ready. graph_analyzer = graph_tools.InitializableGraphAnalyzer( graph, input_signature, list(sink_tensors_ready.items()), _describe_path_as_analyzer_cache_hash) ready_traverser = nodes.Traverser(_ReadyVisitor(graph_analyzer)) # Now create and apply a SavedModel with all tensors in tensor_bindings # bound, which outputs all the tensors in the required tensor tuples. intermediate_output_signature = collections.OrderedDict() saved_model_future = nodes.apply_operation( beam_nodes.CreateSavedModel, *tensor_bindings, table_initializers=tuple(graph_analyzer.ready_table_initializers), output_signature=intermediate_output_signature, label='CreateSavedModelForAnalyzerInputs[{}]'.format(infix)) extracted_values_dict = nodes.apply_operation( beam_nodes.ApplySavedModel, saved_model_future, extracted_input_node, phase=phase, label='ApplySavedModel[{}]'.format(infix)) translate_visitor.phase = phase translate_visitor.intermediate_output_signature = ( intermediate_output_signature) translate_visitor.extracted_values_dict = extracted_values_dict for tensor, value_node, is_asset_filepath in tensor_sinks: hashable_tensor = tf_utils.hashable_tensor_or_op(tensor) # Don't compute a binding/sink/replacement that's already been computed if sink_tensors_ready[hashable_tensor]: continue if not ready_traverser.visit_value_node(value_node): continue translated_value_node = translate_traverser.visit_value_node(value_node) name = _tensor_name(tensor) tensor_bindings.append( nodes.apply_operation( beam_nodes.CreateTensorBinding, translated_value_node, tensor=str(tensor.name), is_asset_filepath=is_asset_filepath, label='CreateTensorBinding[{}]'.format(name))) sink_tensors_ready[hashable_tensor] = True analyzers_input_signature.update(intermediate_output_signature) phase += 1 # We need to make sure that the representation of this output_signature is # deterministic. output_signature = collections.OrderedDict( sorted(output_signature.items(), key=lambda t: t[0])) # TODO(KesterTong): check all table initializers are ready, check all output # tensors are ready. saved_model_future = nodes.apply_operation( beam_nodes.CreateSavedModel, *tensor_bindings, table_initializers=tuple( graph.get_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)), output_signature=output_signature, label='CreateSavedModel') tensor_keys_to_paths = { tensor_key: graph_analyzer.get_unique_path(analyzers_input_signature[tensor_key]) for tensor_key in analyzers_input_signature } (optimized_saved_model_future, output_cache_value_nodes) = _perform_cache_optimization( saved_model_future, dataset_keys, tensor_keys_to_paths, cache_dict) (optimized_saved_model_future, output_cache_value_nodes) = ( combiner_packing_util.perform_combiner_packing_optimization( optimized_saved_model_future, output_cache_value_nodes, phase)) global _ANALYSIS_GRAPH _ANALYSIS_GRAPH = optimized_saved_model_future return optimized_saved_model_future, output_cache_value_nodes