import functools

import sonnet as snt
import tensorflow as tf
import tensorflow.contrib.slim as slim

from tensorflow.contrib.slim.nets import resnet_v2, resnet_v1, vgg

from luminoth.models.base import truncated_vgg
from luminoth.utils.checkpoint_downloader import get_checkpoint_file


# Default RGB means used commonly.
_R_MEAN = 123.68
_G_MEAN = 116.78
_B_MEAN = 103.94

VALID_ARCHITECTURES = {
    'resnet_v1_50',
    'resnet_v1_101',
    'resnet_v1_152',
    'resnet_v2_50',
    'resnet_v2_101',
    'resnet_v2_152',
    'vgg_16',
    'truncated_vgg_16',
}


class BaseNetwork(snt.AbstractModule):
    """
    Convolutional Neural Network used for image classification, whose
    architecture can be any of the `VALID_ARCHITECTURES`.

    This class wraps the `tf.slim` implementations of these models, with some
    helpful additions.
    """

    def __init__(self, config, name='base_network'):
        """Validate the configured architecture and store the config.

        Args:
            config: Mapping-like object read through `.get()`. Must contain an
                `architecture` key whose value is in `VALID_ARCHITECTURES`.
            name: Name given to the underlying Sonnet module.

        Raises:
            ValueError: If the configured architecture is not supported.
        """
        super(BaseNetwork, self).__init__(name=name)

        architecture = config.get('architecture')
        if architecture not in VALID_ARCHITECTURES:
            raise ValueError('Invalid architecture: "{}"'.format(
                architecture
            ))

        self._architecture = architecture
        self._config = config

        # Optional scope used by `_get_base_network_vars` to look the
        # variables up in the graph collection instead of in this module.
        self.pretrained_weights_scope = None

    @property
    def arg_scope(self):
        """Return the slim arg_scope matching the configured architecture.

        The keyword arguments are taken from the `arg_scope` entry of the
        config (empty when the entry is absent).

        Raises:
            ValueError: If the architecture belongs to no known family.
        """
        arg_scope_kwargs = self._config.get('arg_scope', {})

        if self.vgg_type:
            return vgg.vgg_arg_scope(**arg_scope_kwargs)

        if self.truncated_vgg_type:
            return truncated_vgg.vgg_arg_scope(**arg_scope_kwargs)

        if self.resnet_type:
            # It's the same arg_scope for v1 or v2.
            return resnet_v2.resnet_utils.resnet_arg_scope(**arg_scope_kwargs)

        raise ValueError('Invalid architecture: "{}"'.format(
            self._config.get('architecture')
        ))

    def network(self, is_training=False):
        """Return a callable that builds the configured architecture.

        The callable is a `functools.partial` over the network function named
        by the architecture, with the architecture-specific keyword arguments
        already bound.

        Args:
            is_training: Whether the network is being built for training. For
                ResNet v1 the value forwarded to the network is additionally
                gated by the `train_batch_norm` config entry.

        Raises:
            ValueError: If the architecture belongs to no known family.
        """
        if self.vgg_type:
            return functools.partial(
                getattr(vgg, self._architecture),
                is_training=is_training,
                spatial_squeeze=self._config.get('spatial_squeeze', False),
            )
        elif self.truncated_vgg_type:
            return functools.partial(
                getattr(truncated_vgg, self._architecture),
                is_training=is_training,
            )
        elif self.resnet_v1_type:
            output_stride = self._config.get('output_stride')
            train_batch_norm = (
                is_training and self._config.get('train_batch_norm')
            )
            return functools.partial(
                getattr(resnet_v1, self._architecture),
                is_training=train_batch_norm,
                num_classes=None,
                global_pool=False,
                output_stride=output_stride
            )
        elif self.resnet_v2_type:
            output_stride = self._config.get('output_stride')
            return functools.partial(
                getattr(resnet_v2, self._architecture),
                is_training=is_training,
                num_classes=self._config.get('num_classes'),
                output_stride=output_stride,
            )

        # Fail loudly instead of returning None, mirroring `arg_scope`.
        raise ValueError('Invalid architecture: "{}"'.format(
            self._architecture
        ))

    @property
    def vgg_type(self):
        """Whether the architecture name starts with `vgg`."""
        return self._architecture.startswith('vgg')

    @property
    def vgg_16_type(self):
        """Whether the architecture name starts with `vgg_16`."""
        return self._architecture.startswith('vgg_16')

    @property
    def truncated_vgg_type(self):
        """Whether the architecture name starts with `truncated_vgg`."""
        return self._architecture.startswith('truncated_vgg')

    @property
    def truncated_vgg_16_type(self):
        """Whether the architecture name starts with `truncated_vgg_16`."""
        return self._architecture.startswith('truncated_vgg_16')

    @property
    def resnet_type(self):
        """Whether the architecture name starts with `resnet`."""
        return self._architecture.startswith('resnet')

    @property
    def resnet_v1_type(self):
        """Whether the architecture name starts with `resnet_v1`."""
        return self._architecture.startswith('resnet_v1')

    @property
    def resnet_v2_type(self):
        """Whether the architecture name starts with `resnet_v2`."""
        return self._architecture.startswith('resnet_v2')

    @property
    def default_image_size(self):
        """Return the default input image size of the architecture.

        Returns None when the architecture matches none of the families
        checked below.
        """
        # Usually 224, but depends on the architecture.
        if self.vgg_16_type:
            return vgg.vgg_16.default_image_size
        if self.truncated_vgg_16_type:
            # `truncated_vgg_16` is resolved from the `truncated_vgg` module
            # everywhere else in this class (see `network`); slim's `vgg`
            # module is not where it is looked up.
            # NOTE(review): assumes `truncated_vgg_16` exposes
            # `default_image_size` like the slim nets do -- confirm.
            return truncated_vgg.truncated_vgg_16.default_image_size
        if self.resnet_v1_type:
            return resnet_v1.resnet_v1.default_image_size
        if self.resnet_v2_type:
            return resnet_v2.resnet_v2.default_image_size

    def _build(self, inputs, is_training=False):
        """Preprocess `inputs` and run them through the network.

        Args:
            inputs: Images to feed to the network.
            is_training: Forwarded to `network`.

        Returns:
            A dict with the network output under `net` and the second value
            returned by the network function under `end_points`.
        """
        inputs = self.preprocess(inputs)
        with slim.arg_scope(self.arg_scope):
            net, end_points = self.network(is_training=is_training)(inputs)

        return {
            'net': net,
            'end_points': end_points,
        }

    def preprocess(self, inputs):
        """Apply the architecture-specific preprocessing to `inputs`.

        Channel means are subtracted when `vgg_type` or `resnet_type` is
        true; otherwise the inputs are returned untouched.
        """
        if self.vgg_type or self.resnet_type:
            inputs = self._subtract_channels(inputs)

        return inputs

    def _subtract_channels(self, inputs,
                           means=(_R_MEAN, _G_MEAN, _B_MEAN)):
        """Subtract channels from images.

        It is common for CNNs to subtract the mean of all images from each
        channel. In the case of RGB images we first calculate the mean from
        each of the channels (Red, Green, Blue) and subtract those values
        for training and for inference.

        Args:
            inputs: A Tensor of images we want to normalize. Its shape is
                (1, height, width, num_channels).
            means: A Tensor of shape (num_channels,) with the means to be
                subtracted from each channels on the inputs.

        Returns:
            outputs: A Tensor of images normalized with the means.
                Its shape is the same as the input.
        """
        # `means` is wrapped in a list so the subtracted operand gets an
        # extra leading dimension before being combined with `inputs`.
        return inputs - [means]

    def _normalize(self, inputs):
        """Normalize between -1.0 to 1.0.

        Args:
            inputs: A Tensor of images we want to normalize. Its shape is
                (1, height, width, num_channels).

        Returns:
            outputs: A Tensor of images normalized between -1 and 1.
                Its shape is the same as the input.
        """
        inputs = inputs / 255.
        inputs = (inputs - 0.5) * 2.
        return inputs

    def get_checkpoint_file(self):
        """Return the result of the module-level `get_checkpoint_file` helper
        for the configured architecture."""
        return get_checkpoint_file(self._architecture)

    def _get_base_network_vars(self):
        """Returns a list of all the base network's variables."""
        if self.pretrained_weights_scope:
            # We may have defined the base network in a particular scope
            module_variables = tf.get_collection(
                tf.GraphKeys.MODEL_VARIABLES,
                scope=self.pretrained_weights_scope
            )
        else:
            module_variables = snt.get_variables_in_module(
                self, tf.GraphKeys.MODEL_VARIABLES
            )

        assert len(module_variables) > 0
        return module_variables

    def get_trainable_vars(self):
        """
        Returns a list of the variables that are trainable.

        If a value for `fine_tune_from` is specified in the config, only the
        variables starting from the first that contains this string in its name
        will be trainable. For example, specifying `vgg_16/fc6` for a VGG16
        will set only the variables in the fully connected layers to be
        trainable.
        If `fine_tune_from` is None, then all the variables will be trainable.

        Returns:
            trainable_variables: a tuple of `tf.Variable`.

        Raises:
            ValueError: If no variable name contains `fine_tune_from`.
        """
        all_variables = snt.get_variables_in_module(self)

        fine_tune_from = self._config.get('fine_tune_from')
        if fine_tune_from is None:
            return all_variables

        # Get the index of the first trainable variable
        var_iter = enumerate(v.name for v in all_variables)
        try:
            index = next(i for i, name in var_iter if fine_tune_from in name)
        except StopIteration:
            raise ValueError(
                '"{}" is an invalid value of fine_tune_from for this '
                'architecture.'.format(fine_tune_from)
            )

        return all_variables[index:]

    def get_base_network_checkpoint_vars(self):
        """Returns the vars which the base network checkpoint will load into.

        We return a dict which maps a variable name to a variable object. This
        is needed because the base network may be created inside a particular
        scope, which the checkpoint may not contain. Therefore we must map
        each variable to its unscoped name in order to be able to find them in
        the checkpoint file.
        """
        # +1 also drops the separator that follows the scope name.
        variable_scope_len = len(self.variable_scope.name) + 1
        var_list = self._get_base_network_vars()
        return {
            var.op.name[variable_scope_len:]: var
            for var in var_list
        }