# SPDX-License-Identifier: MIT # Copyright © 2020 Patrick Levin """Utility functions for working with TensorFlow Graphs""" from __future__ import absolute_import from collections import namedtuple from typing import List, Union import numpy as np import tensorflow as tf from tensorflow.core.framework.graph_pb2 import GraphDef from tensorflow.core.framework.node_def_pb2 import NodeDef import tfjs_graph_converter.common as c _DTYPE_MAP: List[type] = [ None, np.float32, np.float64, np.int32, np.uint8, np.int16, np.int8, None, np.complex64, np.int64, np.bool ] NodeInfo = namedtuple('NodeInfo', 'name shape dtype tensor') def _is_op_node(node: NodeInfo) -> bool: return node.op not in (c.TFJS_NODE_CONST_KEY, c.TFJS_NODE_PLACEHOLDER_KEY) def _op_nodes(graph_def: GraphDef) -> List[NodeDef]: return [node for node in graph_def.node if _is_op_node(node)] def _map_type(type_id: int) -> type: if type_id < 0 or type_id > len(_DTYPE_MAP): raise ValueError(f'Unsupported data type: {type_id}') np_type = _DTYPE_MAP[type_id] return np_type def _get_shape(node: NodeDef) -> List[int]: def shape(attr): return attr.shape.dim def size(dim): return dim.size if dim.size > 0 else None return [size(dim) for dim in shape(node.attr[c.TFJS_ATTR_SHAPE_KEY])] def _node_info(node: NodeDef) -> NodeInfo: def dtype(n): return _map_type(n.attr[c.TFJS_ATTR_DTYPE_KEY].type) return NodeInfo(name=node.name, shape=_get_shape(node), dtype=dtype(node), tensor=node.name + ':0') def get_input_nodes(graph: Union[tf.Graph, GraphDef]) -> List[NodeInfo]: """ Return information about a graph's inputs. Args: graph: Graph or GraphDef object Returns: List of NodeInfo tuples holding name, shape, and type of the input """ if isinstance(graph, tf.Graph): graph_def = graph.as_graph_def() else: graph_def = graph def is_input(node): return node.op == c.TFJS_NODE_PLACEHOLDER_KEY return [_node_info(node) for node in graph_def.node if is_input(node)] def get_output_nodes(graph: Union[tf.Graph, GraphDef]) -> List[NodeInfo]: """ Return information about a graph's outputs. Args: graph: Graph or GraphDef object Returns: List of NodeInfo tuples holding name, shape, and type of the input; shape will be left empty """ # normalise input if isinstance(graph, tf.Graph): graph_def = graph.as_graph_def() else: graph_def = graph # visit graph nodes and test for references # assumption: all referenced nodes are declared *before* use ops = _op_nodes(graph_def) outputs = [] for i, node in enumerate(ops): has_ref = False for test in ops[i+1:]: if node.name in test.input: has_ref = True break if not has_ref: outputs.append(node) return [_node_info(node) for node in outputs] def get_input_tensors(graph: Union[tf.Graph, GraphDef]) -> List[str]: """ Return the names of the graph's input tensors. Args: graph: Graph or GraphDef object Returns: List of tensor names """ return [node.tensor for node in get_input_nodes(graph)] def get_output_tensors(graph: Union[tf.Graph, GraphDef]) -> List[str]: """ Return the names of the graph's output tensors. Args: graph: Graph or GraphDef object Returns: List of tensor names """ return [node.tensor for node in get_output_nodes(graph)]