Tensorflow-DatasetAPI

A simple TensorFlow Dataset API tutorial for reading images

Usage

1. glob images

ex) trainA_dataset = glob('./dataset/{}/*.*'.format(dataset_name + '/trainA'))

trainA_dataset = ['./dataset/cat/trainA/a.jpg', 
                  './dataset/cat/trainA/b.png', 
                  './dataset/cat/trainA/c.jpeg', 
                  ...]

2. Use from_tensor_slices

trainA = tf.data.Dataset.from_tensor_slices(trainA_dataset)

3. Use map for preprocessing


    def image_processing(filename):
        """Load *filename*, decode it as a 3-channel JPEG, resize to
        256x256, and scale pixel values from [0, 255] to [-1, 1]."""
        raw = tf.read_file(filename)
        # decode_jpeg keeps a static rank on the result; decode_image does
        # not, which breaks resize_images downstream — hence the original
        # warning "DO NOT USE decode_image".
        decoded = tf.image.decode_jpeg(raw, channels=3)  # force RGB

        resized = tf.image.resize_images(decoded, [256, 256])
        normalized = tf.cast(resized, tf.float32) / 127.5 - 1

        return normalized

trainA = trainA.map(image_processing, num_parallel_calls=8)

class ImageData:
    """Preprocessing helper whose ``image_processing`` method is meant to be
    passed to ``tf.data.Dataset.map``.

    Reads an image file, decodes it as JPEG with ``channels`` channels,
    resizes it to ``load_size`` x ``load_size``, scales pixels to [-1, 1],
    and (optionally) applies flip/enlarge-and-crop augmentation.
    """

    def __init__(self, batch_size, load_size, channels, augment_flag):
        self.batch_size = batch_size    # kept for callers; not read in this class
        self.load_size = load_size      # output height/width in pixels
        self.channels = channels        # decoded channel count (3 for RGB)
        self.augment_flag = augment_flag
        # Enlarge-then-crop margin for augmentation: +30 px at 256, else +15 px.
        self.augment_size = load_size + (30 if load_size == 256 else 15)

    def image_processing(self, filename):
        """Read, decode, resize and normalize one image file.

        Args:
            filename: scalar string tensor holding the image path.
        Returns:
            float32 tensor of shape [load_size, load_size, channels]
            with values in [-1, 1].
        """
        x = tf.read_file(filename)

        # decode_jpeg keeps a static rank; tf.image.decode_image does not,
        # which breaks resize_images — hence "DO NOT USE decode_image".
        x_decode = tf.image.decode_jpeg(x, channels=self.channels)

        img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
        img = tf.cast(img, tf.float32) / 127.5 - 1  # scale [0, 255] -> [-1, 1]

        if self.augment_flag :
            # NOTE(review): random.random() runs once at graph-construction
            # time, not per image — the whole dataset is either augmented or
            # not for a given run. A per-image coin flip would need
            # tf.random_uniform + tf.cond; confirm which behavior is intended.
            p = random.random()
            if p > 0.5:
                img = self.augmentation(img)

        return img

    def augmentation(self, image):
        """Random horizontal flip, enlarge to ``augment_size``, then random-crop
        back to the image's original shape (same seed ties the two random ops)."""
        seed = random.randint(0, 2 ** 31 - 1)

        ori_image_shape = tf.shape(image)
        image = tf.image.random_flip_left_right(image, seed=seed)
        image = tf.image.resize_images(image, [self.augment_size, self.augment_size])
        image = tf.random_crop(image, ori_image_shape, seed=seed)

        return image

Image_Data_Class = ImageData(batch_size, img_size, img_ch, augment_flag)
trainA = trainA.map(Image_Data_Class.image_processing, num_parallel_calls=8)


* Personally recommend `num_parallel_calls` = `8 or 16`

***

### 4. Set `prefetch` & `batch_size`
```python

trainA = trainA.shuffle(buffer_size=10000).prefetch(buffer_size=batch_size).batch(batch_size).repeat()
trainA = trainA.shuffle(10000).prefetch(batch_size).apply(batch_and_drop_remainder(batch_size)).repeat()

trainA = trainA.apply(shuffle_and_repeat(dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, batch_size))
```


***

### 5. Set `Iterator`
```python

trainA_iterator = trainA.make_one_shot_iterator()

data_A = trainA_iterator.get_next()
logit = network(data_A)
...
```

6. Run Model


def train() :
    for epoch ...
        for iteration ...

7. See Code

Author

Junho Kim