# coding=utf-8
# Copyright 2020 The Meta-Dataset Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Module responsible for decoding image/feature examples."""

import gin.tf
import tensorflow.compat.v1 as tf


def read_single_example(example_string):
  """Parses the record string into its raw 'image' and 'label' features.

  Args:
    example_string: a serialized Example protocol buffer.

  Returns:
    A dict with key 'image' (tf.string, the still-encoded image bytes) and key
    'label' (tf.int64).
  """
  return tf.parse_single_example(
      example_string,
      features={
          'image': tf.FixedLenFeature([], dtype=tf.string),
          'label': tf.FixedLenFeature([], tf.int64)
      })


def read_example_and_parse_image(example_string):
  """Reads the record string and decodes the image it contains.

  Args:
    example_string: a serialized Example protocol buffer, as accepted by
      `read_single_example`.

  Returns:
    The dict produced by `read_single_example`, with 'image' replaced by the
    decoded image tensor whose static shape is set to [None, None, 3].
  """
  parsed_example = read_single_example(example_string)
  # Always request three channels so every decoded image shares one layout.
  image_decoded = tf.image.decode_image(parsed_example['image'], channels=3)
  # decode_image does not give a fully known static shape; pin it to a 3-D
  # [height, width, 3] image so downstream ops (resize, crop) can rely on it.
  # NOTE(review): animated GIFs presumably decode to a 4-D tensor and would
  # fail here -- confirm that the datasets never contain them.
  image_decoded.set_shape([None, None, 3])
  parsed_example['image'] = image_decoded
  return parsed_example


@gin.configurable
class ImageDecoder(object):
  """Image decoder producing resized float images rescaled to [-1, 1]."""

  # Dtype of the image tensors this decoder returns.
  out_type = tf.float32

  def __init__(self, image_size=None, data_augmentation=None):
    """Class constructor.

    Args:
      image_size: int, desired image size. The extracted image will be resized
        to `[image_size, image_size]`. Must be set before decoding.
      data_augmentation: A DataAugmentation object with parameters for
        perturbing the images. The attributes read are
        `enable_gaussian_noise`, `gaussian_noise_std`, `enable_jitter` and
        `jitter_amount`. None disables augmentation.
    """
    self.image_size = image_size
    self.data_augmentation = data_augmentation

  def __call__(self, example_string):
    """Processes a single example string.

    Extracts and processes the image, and ignores the label. We assume that
    the image has three channels.

    Args:
      example_string: str, an Example protocol buffer.

    Returns:
      image_rescaled: the image, resized to `image_size x image_size` and
      rescaled to [-1, 1]. Note that Gaussian data augmentation may cause
      values to go beyond this range.

    Raises:
      ValueError: if `image_size` was never set.
    """
    return self.decode_with_label(example_string)[0]

  def decode_with_label(self, example_string):
    """Processes a single example string, returning both image and label.

    Extracts and processes the image, and also returns the label. We assume
    that the image has three channels.

    Args:
      example_string: str, an Example protocol buffer.

    Returns:
      image_rescaled: the image, resized to `image_size x image_size` and
        rescaled to [-1, 1]. Note that Gaussian data augmentation may cause
        values to go beyond this range.
      label: tf.int32 scalar, the example's label.

    Raises:
      ValueError: if `image_size` was never set (it defaults to None).
    """
    # Without an explicit size, resize_images would fail with an opaque
    # conversion error; fail early with an actionable message instead.
    if self.image_size is None:
      raise ValueError(
          'ImageDecoder.image_size must be set before decoding images.')

    ex_decoded = read_example_and_parse_image(example_string)
    image_decoded = ex_decoded['image']
    image_resized = tf.image.resize_images(
        image_decoded, [self.image_size, self.image_size],
        method=tf.image.ResizeMethod.BILINEAR,
        align_corners=True)
    image_resized = tf.cast(image_resized, tf.float32)
    # Map pixel values from [0, 255] to [-1, 1].
    image = 2 * (image_resized / 255.0 - 0.5)

    if self.data_augmentation is not None:
      if self.data_augmentation.enable_gaussian_noise:
        image = image + tf.random_normal(
            tf.shape(image)) * self.data_augmentation.gaussian_noise_std

      if self.data_augmentation.enable_jitter:
        # Random translation: reflect-pad each spatial border by `j` pixels,
        # then randomly crop back to the target size.
        j = self.data_augmentation.jitter_amount
        paddings = tf.constant([[j, j], [j, j], [0, 0]])
        image = tf.pad(image, paddings, 'REFLECT')
        image = tf.image.random_crop(image,
                                     [self.image_size, self.image_size, 3])

    return image, tf.cast(ex_decoded['label'], dtype=tf.int32)


@gin.configurable
class FeatureDecoder(object):
  """Feature decoder for pre-computed embedding vectors."""

  # Dtype of the feature tensors this decoder returns.
  out_type = tf.float32

  def __init__(self, feat_len):
    """Class constructor.

    Args:
      feat_len: The expected length of the feature vectors.
    """
    self.feat_len = feat_len

  def __call__(self, example_string):
    """Processes a single example string.

    Extracts and processes the feature, and ignores the label.

    Args:
      example_string: str, an Example protocol buffer.

    Returns:
      feat: The feature tensor, a float32 vector of length `feat_len`.
    """
    # 'image/class/label' is declared (without a default value) even though
    # it is discarded; parsing therefore presumably requires each record to
    # carry it.
    feat = tf.parse_single_example(
        example_string,
        features={
            'image/embedding':
                tf.FixedLenFeature([self.feat_len], dtype=tf.float32),
            'image/class/label':
                tf.FixedLenFeature([], tf.int64)
        })['image/embedding']
    return feat


@gin.configurable
class StringDecoder(object):
  """Simple decoder that reads the image without decoding."""

  # Dtype of the tensors this decoder returns (raw encoded image bytes).
  out_type = tf.string

  def __init__(self):
    """Class constructor."""

  def __call__(self, example_string):
    """Processes a single example string.

    Extracts the image as string, and ignores the label.

    Args:
      example_string: str, an Example protocol buffer.

    Returns:
      img_string: tf.Tensor of type tf.string.
    """
    img_string = read_single_example(example_string)['image']
    return img_string