import numpy as np import tensorflow as tf from skimage import color, transform from ..base import BatchableRecordReader, RecordWriter width = 224 height = 224 depth = 3 img_shape = (width, height, depth) embedding_size = 1001 class LabImageRecordWriter(RecordWriter): img_shape = img_shape embedding_size = embedding_size def write_image(self, img_file, image, img_embedding): img = transform.resize(image, img_shape, mode="constant") lab = color.rgb2lab(img).astype(np.float32) l_channel = 2 * lab[:, :, 0] / 100 - 1 ab_channels = lab[:, :, 1:] / 127 example = tf.train.Example( features=tf.train.Features( feature={ "image_name": self._bytes_feature(img_file), "image_l": self._float32_list(l_channel.flatten()), "image_ab": self._float32_list(ab_channels.flatten()), "image_embedding": self._float32_list(img_embedding.flatten()), } ) ) self.write(example.SerializeToString()) class LabImageRecordReader(BatchableRecordReader): img_shape = img_shape embedding_size = embedding_size def _create_read_operation(self): features = tf.parse_single_example( self._tfrecord_serialized, features={ "image_name": tf.FixedLenFeature([], tf.string), "image_l": tf.FixedLenFeature([width * height], tf.float32), "image_ab": tf.FixedLenFeature([width * height * 2], tf.float32), "image_embedding": tf.FixedLenFeature([embedding_size], tf.float32), }, ) image_l = tf.reshape(features["image_l"], shape=[width, height, 1]) image_ab = tf.reshape(features["image_ab"], shape=[width, height, 2]) return { "image_name": features["image_name"], "image_l": image_l, "image_ab": image_ab, "image_embedding": features["image_embedding"], }