# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tensorflow Example proto decoder for object detection.

Provides a decoder that turns a string tensor holding a serialized
tensorflow.Example proto into a dictionary of object detection tensors.
"""

import tensorflow as tf

from object_detection.core import data_decoder
from object_detection.core import standard_fields as fields

slim_example_decoder = tf.contrib.slim.tfexample_decoder


class TfExampleDecoder(data_decoder.DataDecoder):
  """Decodes serialized tensorflow.Example protos into detection tensors."""

  def __init__(self):
    """Builds the feature parsing spec and the item handler mapping.

    Sets two attributes:
      keys_to_features: maps tf.Example feature keys to the parsing spec
        (fixed or variable length) used to read them.
      items_to_handlers: maps `fields.InputDataFields` names to the slim item
        handlers that produce the corresponding output tensors.
    """

    def string_scalar(default):
      # Scalar string feature that falls back to `default` when absent.
      return tf.FixedLenFeature((), tf.string, default_value=default)

    def int_scalar(default):
      # Scalar int64 feature that falls back to `default` when absent.
      return tf.FixedLenFeature((), tf.int64, default_value=default)

    float_list = tf.VarLenFeature(tf.float32)
    int_list = tf.VarLenFeature(tf.int64)

    self.keys_to_features = {
        # Image data and metadata.
        'image/encoded': string_scalar(''),
        'image/format': string_scalar('jpeg'),
        'image/filename': string_scalar(''),
        'image/key/sha256': string_scalar(''),
        'image/source_id': string_scalar(''),
        'image/height': int_scalar(1),
        'image/width': int_scalar(1),
        # Per-object boxes, classes and attributes.
        'image/object/bbox/xmin': float_list,
        'image/object/bbox/xmax': float_list,
        'image/object/bbox/ymin': float_list,
        'image/object/bbox/ymax': float_list,
        'image/object/class/label': int_list,
        'image/object/area': float_list,
        'image/object/is_crowd': int_list,
        'image/object/difficult': int_list,
        # Per-instance masks and their classes.
        'image/segmentation/object': int_list,
        'image/segmentation/object/class': int_list,
    }

    input_fields = fields.InputDataFields
    as_tensor = slim_example_decoder.Tensor

    image_handler = slim_example_decoder.Image(
        image_key='image/encoded', format_key='image/format', channels=3)
    # Box coordinates are assembled in [ymin, xmin, ymax, xmax] order.
    boxes_handler = slim_example_decoder.BoundingBox(
        ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')
    # Masks are stored flat; height and width are needed to reshape them.
    masks_handler = slim_example_decoder.ItemHandlerCallback(
        ['image/segmentation/object', 'image/height', 'image/width'],
        self._reshape_instance_masks)

    self.items_to_handlers = {
        input_fields.image: image_handler,
        input_fields.source_id: as_tensor('image/source_id'),
        input_fields.key: as_tensor('image/key/sha256'),
        input_fields.filename: as_tensor('image/filename'),
        # Per-object boxes, classes and attributes.
        input_fields.groundtruth_boxes: boxes_handler,
        input_fields.groundtruth_classes: as_tensor('image/object/class/label'),
        input_fields.groundtruth_area: as_tensor('image/object/area'),
        input_fields.groundtruth_is_crowd: as_tensor('image/object/is_crowd'),
        input_fields.groundtruth_difficult: as_tensor('image/object/difficult'),
        # Per-instance masks and their classes.
        input_fields.groundtruth_instance_masks: masks_handler,
        input_fields.groundtruth_instance_classes: as_tensor(
            'image/segmentation/object/class'),
    }

  def Decode(self, tf_example_string_tensor):
    """Parses one serialized tf.Example into a dictionary of tensors.

    Args:
      tf_example_string_tensor: a string tensor holding a single serialized
        tensorflow example proto. It is reshaped to a scalar before parsing.

    Returns:
      A dictionary keyed by `fields.InputDataFields` names:
      image - 3D uint8 tensor with static shape [None, None, 3].
      source_id - string tensor with the original image id.
      key - string tensor with the sha256 hash key.
      filename - string tensor with the original dataset filename.
      groundtruth_boxes - 2D float32 tensor of shape [None, 4] with box
        corners in [ymin, xmin, ymax, xmax] order.
      groundtruth_classes - 1D int64 tensor of shape [None] with the class of
        each box.
      groundtruth_area - 1D float32 tensor of shape [None] with the object
        mask area in pixels squared.
      groundtruth_is_crowd - 1D bool tensor of shape [None] telling whether
        each box encloses a crowd (cast to bool here).
      groundtruth_difficult - 1D int64 tensor of shape [None] flagging
        `difficult` instances. Unlike is_crowd it is not cast to bool.
      groundtruth_instance_masks - 3D bool tensor of shape
        [num_instances, height, width] with the instance masks.
      groundtruth_instance_classes - 1D int64 tensor of shape [None] with the
        class of each instance mask.
    """
    input_fields = fields.InputDataFields

    # The slim decoder expects a scalar string tensor.
    scalar_example = tf.reshape(tf_example_string_tensor, shape=[])
    example_decoder = slim_example_decoder.TFExampleDecoder(
        self.keys_to_features, self.items_to_handlers)

    item_names = example_decoder.list_items()
    decoded_items = example_decoder.decode(scalar_example, items=item_names)
    tensor_dict = dict(zip(item_names, decoded_items))

    # Pin the channel dimension; height and width stay dynamic.
    tensor_dict[input_fields.image].set_shape([None, None, 3])

    # is_crowd is stored as int64 in the proto but exposed as a boolean flag.
    tensor_dict[input_fields.groundtruth_is_crowd] = tf.cast(
        tensor_dict[input_fields.groundtruth_is_crowd], dtype=tf.bool)
    return tensor_dict

  def _reshape_instance_masks(self, keys_to_tensors):
    """Turns flat instance mask values into a 3-D boolean mask tensor.

    The flat mask values are reshaped to [num_instances, height, width], with
    num_instances inferred from the number of values, and cast to boolean
    type to save memory.

    Args:
      keys_to_tensors: a dictionary from feature keys to tensors. It must hold
        'image/segmentation/object', 'image/height' and 'image/width'.

    Returns:
      A 3-D boolean tensor of shape [num_instances, height, width].
    """
    flat_masks = keys_to_tensors['image/segmentation/object']
    if isinstance(flat_masks, tf.SparseTensor):
      # Variable length features arrive sparse; reshape needs a dense tensor.
      flat_masks = tf.sparse_tensor_to_dense(flat_masks)

    # -1 lets reshape infer the number of instances.
    target_shape = tf.stack(
        [-1, keys_to_tensors['image/height'], keys_to_tensors['image/width']])
    target_shape = tf.cast(target_shape, tf.int32)

    shaped_masks = tf.reshape(flat_masks, target_shape)
    return tf.cast(shaped_masks, tf.bool)