# -*- coding: utf-8 -*-
"""Train and evaluate a ResNet classifier on the 'nsfw' TFRecord dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re

from absl import app as absl_app
from absl import flags
import tensorflow as tf  # pylint: disable=g-bad-import-order

from resnet import resnet_model
from resnet import resnet_run_loop

_NUM_CHANNELS = 3
_NUM_CLASSES = 5

# Number of examples in each split of the dataset.
_NUM_IMAGES = {
    'train': 230944,
    'validation': 19448,
}

DATASET_NAME = 'nsfw'

# Images are resized to _IMAGE_SIZE x _IMAGE_SIZE before entering the model.
_IMAGE_SIZE = 64


###############################################################################
# Data processing
###############################################################################
def get_filenames(is_training, data_dir):
    """Return the TFRecord files found under `data_dir` for one split.

    The directory tree is walked recursively; a file is selected when its
    base name starts with ``nsfw_train_`` (training) or ``nsfw_validation_``
    (evaluation) and contains ``.tfrecord`` afterwards.

    Args:
        is_training: Whether to return training files (True) or validation
            files (False).
        data_dir: Root directory to search.

    Returns:
        A sorted list of matching file paths (empty if nothing matches).
    """
    if is_training:
        pattern = re.compile(r'nsfw_train_.*\.tfrecord')
    else:
        pattern = re.compile(r'nsfw_validation_.*\.tfrecord')

    file_names = []
    for top, _, files in os.walk(data_dir):
        file_names.extend(
            os.path.join(top, name) for name in files if pattern.match(name))
    # os.walk order depends on the filesystem; sort for reproducible input.
    return sorted(file_names)


def preprocess_image(image, is_training):
    """Preprocess a single image of layout [height, width, depth].

    Args:
        image: A 3-D image tensor.
        is_training: When True, random pad/crop and horizontal-flip
            augmentation is applied before standardization.

    Returns:
        The standardized (and, for training, augmented) image tensor.
    """
    if is_training:
        # Pad the image with four extra pixels on each side.
        image = tf.image.resize_image_with_crop_or_pad(
            image, _IMAGE_SIZE + 8, _IMAGE_SIZE + 8)

        # Randomly crop a [_IMAGE_SIZE, _IMAGE_SIZE] section of the image.
        image = tf.random_crop(
            image, [_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS])

        # Randomly flip the image horizontally.
        image = tf.image.random_flip_left_right(image)

    # Standardize the pixels of each image individually.
    image = tf.image.per_image_standardization(image)
    return image


def parse_record(raw_record, is_training):
    """Parse one serialized tf.Example into an (image, label) pair.

    Args:
        raw_record: A scalar string tensor holding a serialized tf.Example
            with the features listed in `image_feature_description`.
        is_training: Forwarded to `preprocess_image` to toggle augmentation.

    Returns:
        A tuple (image, label): the preprocessed float32 image of shape
        [_IMAGE_SIZE, _IMAGE_SIZE, 3] and the int64 class label.
    """
    image_feature_description = {
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/width': tf.FixedLenFeature([], tf.int64),
        'image/format': tf.FixedLenFeature([], tf.string),
        'image/class/label': tf.FixedLenFeature([], tf.int64),
        'image/encoded': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(raw_record, image_feature_description)

    image = parsed['image/encoded']
    image = tf.image.decode_image(image, channels=3)
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # decode_image leaves the static shape unknown; resize_images needs the
    # rank and channel count to be known.
    image.set_shape([None, None, 3])
    image = tf.image.resize_images(image, [_IMAGE_SIZE, _IMAGE_SIZE])
    image = preprocess_image(image, is_training)

    label = parsed['image/class/label']
    return image, label


def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
                           parse_record_fn, num_epochs=1, num_gpus=None,
                           examples_per_epoch=None, dtype=tf.float32):
    """Shuffle, repeat, parse and batch a dataset of raw records.

    Args:
        dataset: A Dataset of raw (serialized) records.
        is_training: Whether the input is for training; enables shuffling
            and is forwarded to `parse_record_fn`.
        batch_size: Number of examples per batch.
        shuffle_buffer: Buffer size used when shuffling.
        parse_record_fn: Callable `(raw_record, is_training)` returning an
            (image, label) pair.
        num_epochs: Number of times to repeat the dataset.
        num_gpus: Number of GPUs used for training, if any.
        examples_per_epoch: Number of examples in one epoch; required for
            the multi-GPU truncation below.
        dtype: Not used by this function; kept for interface compatibility.

    Returns:
        A Dataset of (image, label) batches.
    """
    # Prefetch a batch of raw records at a time to smooth out read latency.
    dataset = dataset.prefetch(buffer_size=batch_size)
    if is_training:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)

    dataset = dataset.repeat(num_epochs)

    if is_training and num_gpus and examples_per_epoch:
        total_examples = num_epochs * examples_per_epoch
        # Force the number of batches to be divisible by the number of
        # devices. This prevents some devices from receiving batches while
        # others do not, which can lead to a lockup.
        total_batches = total_examples // batch_size // num_gpus * num_gpus
        # Dataset transformations return a new dataset, so the result must
        # be assigned for the truncation to take effect.
        dataset = dataset.take(total_batches * batch_size)

    # Parse the raw records into images and labels. num_parallel_batches is
    # kept at 1 because batch_size is almost always much greater than the
    # number of CPU cores.
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            lambda value: parse_record_fn(value, is_training),
            batch_size=batch_size,
            num_parallel_batches=1,
            drop_remainder=False))

    # Operations between the final prefetch and the get_next call to the
    # iterator happen synchronously at run time. Prefetch again here to
    # background all of the above processing and keep it out of the critical
    # training path. AUTOTUNE lets the runtime pick the buffer size.
    dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
    return dataset


def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None,
             dtype=tf.float32):
    """Build the input Dataset for the Estimator.

    Args:
        is_training: Whether the input is for training.
        data_dir: Directory searched (recursively) for TFRecord files.
        batch_size: Number of examples per batch.
        num_epochs: Number of times to repeat the dataset.
        num_gpus: Number of GPUs used for training, if any.
        dtype: Forwarded to `process_record_dataset`.

    Returns:
        A Dataset yielding (image, label) batches.
    """
    filenames = get_filenames(is_training, data_dir)
    print(filenames)
    dataset = tf.data.TFRecordDataset(filenames=filenames)

    dataset = process_record_dataset(
        dataset=dataset,
        is_training=is_training,
        batch_size=batch_size,
        shuffle_buffer=500,
        parse_record_fn=parse_record,
        num_epochs=num_epochs,
        num_gpus=num_gpus,
        examples_per_epoch=_NUM_IMAGES['train'] if is_training else None,
        dtype=dtype
    )
    return dataset


###############################################################################
# Running the model
###############################################################################
class Model(resnet_model.Model):
    """ResNet model with defaults for the small-image 'nsfw' dataset."""

    def __init__(self, resnet_size, data_format=None,
                 num_classes=_NUM_CLASSES,
                 resnet_version=resnet_model.DEFAULT_VERSION,
                 dtype=resnet_model.DEFAULT_DTYPE):
        """Configure a non-bottleneck ResNet with three block groups.

        Args:
            resnet_size: The number of convolutional layers needed in the
                model; must be of the form 6n + 2.
            data_format: Either 'channels_first' or 'channels_last',
                specifying which data format to use when setting up the
                model.
            num_classes: The number of output classes needed from the model.
            resnet_version: Integer representing which version of the ResNet
                network to use. Valid values: [1, 2].
            dtype: The TensorFlow dtype to use for calculations.

        Raises:
            ValueError: if invalid resnet_size is chosen.
        """
        if resnet_size % 6 != 2:
            raise ValueError('resnet_size must be 6n + 2:', resnet_size)

        num_blocks = (resnet_size - 2) // 6

        super(Model, self).__init__(
            resnet_size=resnet_size,
            bottleneck=False,
            num_classes=num_classes,
            num_filters=16,
            kernel_size=3,
            conv_stride=1,
            first_pool_size=None,
            first_pool_stride=None,
            block_sizes=[num_blocks] * 3,
            block_strides=[1, 2, 2],
            final_size=64,
            resnet_version=resnet_version,
            data_format=data_format,
            dtype=dtype
        )


def cifar10_model_fn(features, labels, mode, params):
    """Model function handed to the Estimator for the 'nsfw' dataset.

    The name is kept for compatibility with existing callers.

    Args:
        features: Batch of images; reshaped to
            [-1, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS].
        labels: Batch of class labels.
        mode: Estimator mode, forwarded to `resnet_model_fn`.
        params: Dict providing 'batch_size', 'resnet_size', 'data_format',
            'resnet_version', 'loss_scale', 'dtype' and 'fine_tune'.

    Returns:
        Whatever `resnet_run_loop.resnet_model_fn` returns for this mode.
    """
    features = tf.reshape(
        features, [-1, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS])

    learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
        batch_size=params['batch_size'], batch_denom=128,
        num_images=_NUM_IMAGES['train'], boundary_epochs=[10, 20, 30],
        decay_rates=[1, 0.1, 0.01, 0.001])

    # We use a weight decay of 0.0002, which performs better
    # than the 0.0001 that was originally suggested.
    weight_decay = 2e-4

    # All variables, including batch-normalization ones, are included when
    # computing the regularized loss (the filter accepts every variable).
    def loss_filter_fn(_):
        return True

    return resnet_run_loop.resnet_model_fn(
        features=features,
        labels=labels,
        mode=mode,
        model_class=Model,
        resnet_size=params['resnet_size'],
        weight_decay=weight_decay,
        learning_rate_fn=learning_rate_fn,
        momentum=0.9,
        data_format=params['data_format'],
        resnet_version=params['resnet_version'],
        loss_scale=params['loss_scale'],
        loss_filter_fn=loss_filter_fn,
        dtype=params['dtype'],
        fine_tune=params['fine_tune']
    )


def set_defaults(**kwargs):
    """Override the default value of each named absl flag."""
    for key, value in kwargs.items():
        flags.FLAGS.set_default(name=key, value=value)


def define_flower_flags():
    """Define the ResNet flags and install this script's default values."""
    resnet_run_loop.define_resnet_flags()
    flags.adopt_module_key_flags(resnet_run_loop)
    set_defaults(data_dir='',
                 model_dir='',
                 resnet_size='32',
                 train_epochs=50,
                 epochs_between_evals=50,
                 batch_size=128)


def run_flower(flags_obj):
    """Run the ResNet training and eval loop.

    Args:
        flags_obj: An object containing parsed flag values.
    """
    input_function = input_fn
    resnet_run_loop.resnet_main(
        flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
        shape=[_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS])


def main(_):
    """absl entry point: run training/evaluation with the parsed flags."""
    run_flower(flags.FLAGS)


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    define_flower_flags()
    absl_app.run(main)