# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet model.

Related papers:
https://arxiv.org/pdf/1603.05027v2.pdf
https://arxiv.org/pdf/1512.03385v1.pdf
https://arxiv.org/pdf/1605.07146v1.pdf
"""
from collections import namedtuple

import numpy as np
import tensorflow as tf
from tensorflow.python.training import moving_averages


HParams = namedtuple('HParams',
                     'batch_size, num_classes, min_lrn_rate, lrn_rate, '
                     'num_residual_units, use_bottleneck, weight_decay_rate, '
                     'relu_leakiness, optimizer')


class ResNet(object):
    """ResNet model, built as one tower per GPU with shared variables."""

    def __init__(self, hps, mode, image_size=32, use_wide_resnet=False,
                 nr_gpu=1):
        """Store the configuration and create the input placeholders.

        Args:
          hps: An ``HParams`` tuple of hyperparameters.
          mode: When ``'train'``, batch norm uses batch statistics and records
            moving-average update ops; any other value makes batch norm use
            the stored moving statistics.
          image_size: Height and width of the square, 3-channel input images.
          use_wide_resnet: When True (and ``hps.use_bottleneck`` is False),
            use the wide filter widths [16, 160, 320, 640].
          nr_gpu: Number of towers to build; one placeholder pair is created
            per tower.
        """
        self.hps = hps
        self.batch_size = self.hps.batch_size
        # One image / label placeholder per GPU tower.
        self.input_image = [
            tf.placeholder(tf.float32,
                           shape=(self.batch_size, image_size, image_size, 3))
            for _ in range(nr_gpu)]
        self.input_label = [
            tf.placeholder(tf.int32, shape=(self.batch_size, 1))
            for _ in range(nr_gpu)]
        self.mode = mode
        # Fed True to enable random crop / flip augmentation.
        self.needImgAug = tf.placeholder(tf.bool, shape=())
        self.image_size = image_size
        self.nr_gpu = nr_gpu
        # Batch-norm moving-average update ops collected while building.
        self._extra_train_ops = []
        self.lrn_rate = tf.placeholder(tf.float32, shape=())
        self.use_wide_resnet = use_wide_resnet

    def build_graph(self):
        """Build a whole graph for the model.

        Also records the model's variables (those whose names start with
        ``'I2L/'``) in ``self.trainable_variables`` / ``self.all_variables``.
        """
        with tf.variable_scope('I2L'):
            self.global_step = tf.contrib.framework.get_or_create_global_step()
            self._build_model()
        self.trainable_variables = [v for v in tf.trainable_variables()
                                    if v.name.startswith('I2L/')]
        self.all_variables = [v for v in tf.global_variables()
                              if v.name.startswith('I2L/')]

    def _stride_arr(self, stride):
        """Map a stride scalar to the stride array for tf.nn.conv2d."""
        return [1, stride, stride, 1]

    def _PreprocessImages(self):
        """Standardize (and optionally augment) each tower's image batch.

        Sets ``self.image`` to a list with one preprocessed batch per tower.
        When ``self.needImgAug`` is fed True, each image is padded to
        ``image_size + 4``, randomly cropped back to ``image_size`` and
        randomly flipped left/right before per-image standardization.
        """
        def _aug_one_img(img):
            img = tf.image.resize_image_with_crop_or_pad(
                img, self.image_size + 4, self.image_size + 4)
            img = tf.random_crop(img, [self.image_size, self.image_size, 3])
            img = tf.image.random_flip_left_right(img)
            return img

        def _deal_one_img(img):
            img = tf.cond(self.needImgAug,
                          lambda: _aug_one_img(img),
                          lambda: img)
            img = tf.image.per_image_standardization(img)
            return img

        self.image = [tf.map_fn(_deal_one_img, X) for X in self.input_image]

    def _make_1hot_labels(self):
        """Convert each tower's integer labels into a dense one-hot matrix.

        Sets ``self.labels`` to a list of float tensors of shape
        ``[batch_size, num_classes]``, one per tower.
        """
        self.labels = []
        for L in self.input_label:
            labels = tf.reshape(L, [self.batch_size, 1])
            indices = tf.reshape(tf.range(0, self.batch_size, 1),
                                 [self.batch_size, 1])
            labels = tf.sparse_to_dense(
                tf.concat([indices, labels], 1),
                [self.batch_size, self.hps.num_classes], 1.0, 0.0)
            self.labels.append(labels)

    def _build_basic_structure(self, x, y):
        """Build one ResNet tower and its loss.

        Args:
          x: Preprocessed image batch for this tower.
          y: One-hot labels of shape ``[batch_size, num_classes]``.

        Returns:
          A tuple ``(nlls, cost, predictions)``: per-example softmax
          cross-entropy, mean cross-entropy plus L2 weight decay, and the
          softmax class probabilities.
        """
        with tf.variable_scope('init'):
            x = self._conv('init_conv', x, 3, 3, 16, self._stride_arr(1))

        strides = [1, 2, 2]
        activate_before_residual = [True, False, False]
        if self.hps.use_bottleneck:
            res_func = self._bottleneck_residual
            filters = [16, 64, 128, 256]
        else:
            res_func = self._residual
            if self.use_wide_resnet:
                # Wide residual network widths from
                # https://arxiv.org/pdf/1605.07146v1.pdf
                filters = [16, 160, 320, 640]
            else:
                filters = [16, 16, 32, 64]

        # Three stages; the first unit of each stage may change width and
        # stride, the remaining units keep both fixed.
        for stage in range(3):
            with tf.variable_scope('unit_%d_0' % (stage + 1)):
                x = res_func(x, filters[stage], filters[stage + 1],
                             self._stride_arr(strides[stage]),
                             activate_before_residual[stage])
            for i in range(1, self.hps.num_residual_units):
                with tf.variable_scope('unit_%d_%d' % (stage + 1, i)):
                    x = res_func(x, filters[stage + 1], filters[stage + 1],
                                 self._stride_arr(1), False)

        with tf.variable_scope('unit_last'):
            x = self._batch_norm('final_bn', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._global_avg_pool(x)

        with tf.variable_scope('logit'):
            logits = self._fully_connected(x, self.hps.num_classes)
            predictions_ = tf.nn.softmax(logits)

        with tf.variable_scope('costs'):
            xent = tf.nn.softmax_cross_entropy_with_logits(labels=y,
                                                           logits=logits)
            nlls_ = xent
            cost_ = tf.reduce_mean(xent, name='xent')
            cost_ += self._decay()

        return nlls_, cost_, predictions_

    def _build_model(self):
        """Build the core model within the graph, one tower per GPU.

        Sets ``self.nlls``, ``self.cost`` and ``self.predictions`` to lists
        indexed by tower. Towers after the first reuse the first tower's
        variables.
        """
        self._PreprocessImages()
        self._make_1hot_labels()

        self.nlls = [None for _ in range(self.nr_gpu)]
        self.cost = [None for _ in range(self.nr_gpu)]
        self.predictions = [None for _ in range(self.nr_gpu)]
        for i in range(self.nr_gpu):
            with tf.variable_scope('I2L', reuse=True if i >= 1 else None):
                with tf.device('/gpu:%d' % i):
                    nll_, cost_, predicted_ = self._build_basic_structure(
                        self.image[i], self.labels[i])
                    self.nlls[i] = nll_
                    self.cost[i] = cost_
                    self.predictions[i] = predicted_

    def _make_optimizer(self):
        """Return the optimizer selected by ``hps.optimizer``.

        Raises:
          ValueError: If ``hps.optimizer`` is neither ``'sgd'`` nor ``'mom'``.
        """
        if self.hps.optimizer == 'sgd':
            return tf.train.GradientDescentOptimizer(self.lrn_rate)
        if self.hps.optimizer == 'mom':
            return tf.train.MomentumOptimizer(self.lrn_rate, 0.9)
        raise ValueError('Unknown optimizer: %r' % (self.hps.optimizer,))

    def _build_train_op(self):
        """Build training specific ops for the graph.

        Note: replaces the ``self.lrn_rate`` placeholder with a constant
        ``hps.lrn_rate`` and differentiates with respect to every trainable
        variable in the graph.
        """
        self.lrn_rate = tf.constant(self.hps.lrn_rate, tf.float32)

        trainable_variables = tf.trainable_variables()
        grads = tf.gradients(self.cost, trainable_variables)

        optimizer = self._make_optimizer()
        apply_op = optimizer.apply_gradients(
            zip(grads, trainable_variables),
            global_step=self.global_step, name='train_step')

        train_ops = [apply_op] + self._extra_train_ops
        self.train_op = tf.group(*train_ops)

    def Update(self, grads):
        """Build the op that applies externally computed gradients.

        Args:
          grads: Gradients aligned one-to-one with ``self.trainable_variables``
            (populated by ``build_graph``).

        Sets ``self.update_ops`` to a group of the optimizer step and the
        batch-norm moving-average updates in ``self._extra_train_ops``.
        """
        optimizer = self._make_optimizer()
        apply_op = optimizer.apply_gradients(
            zip(grads, self.trainable_variables),
            global_step=self.global_step, name='train_step')

        train_ops = [apply_op] + self._extra_train_ops
        self.update_ops = tf.group(*train_ops)

    # TODO(xpan): Consider batch_norm in contrib/layers/python/layers/layers.py
    def _batch_norm(self, name, x):
        """Batch normalization over axes [0, 1, 2] (all but the last).

        In ``'train'`` mode the batch statistics are used and ops updating
        the moving mean/variance (decay 0.9) are appended to
        ``self._extra_train_ops``; otherwise the moving statistics are used.
        """
        with tf.variable_scope(name):
            params_shape = [x.get_shape()[-1]]

            beta = tf.get_variable(
                'beta', params_shape, tf.float32,
                initializer=tf.constant_initializer(0.0, tf.float32))
            gamma = tf.get_variable(
                'gamma', params_shape, tf.float32,
                initializer=tf.constant_initializer(1.0, tf.float32))
            moving_mean = tf.get_variable(
                'moving_mean', params_shape, tf.float32,
                initializer=tf.constant_initializer(0.0, tf.float32),
                trainable=False)
            moving_variance = tf.get_variable(
                'moving_variance', params_shape, tf.float32,
                initializer=tf.constant_initializer(1.0, tf.float32),
                trainable=False)

            if self.mode == 'train':
                mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')
                self._extra_train_ops.append(
                    moving_averages.assign_moving_average(
                        moving_mean, mean, 0.9))
                self._extra_train_ops.append(
                    moving_averages.assign_moving_average(
                        moving_variance, variance, 0.9))
            else:
                mean, variance = moving_mean, moving_variance

            # Epsilon used to be 1e-5. Maybe 0.001 solves NaN problem in
            # deeper net.
            y = tf.nn.batch_normalization(
                x, mean, variance, beta, gamma, 0.001)
            y.set_shape(x.get_shape())
            return y

    def _residual(self, x, in_filter, out_filter, stride,
                  activate_before_residual=False):
        """Residual unit with 2 sub layers (pre-activation).

        When ``in_filter != out_filter`` the shortcut is average-pooled by
        ``stride`` and zero-padded on the channel axis to ``out_filter``.
        """
        if activate_before_residual:
            with tf.variable_scope('shared_activation'):
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)
                orig_x = x
        else:
            with tf.variable_scope('residual_only_activation'):
                orig_x = x
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)

        with tf.variable_scope('sub1'):
            x = self._conv('conv1', x, 3, in_filter, out_filter, stride)

        with tf.variable_scope('sub2'):
            x = self._batch_norm('bn2', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv('conv2', x, 3, out_filter, out_filter, [1, 1, 1, 1])

        with tf.variable_scope('sub_add'):
            if in_filter != out_filter:
                orig_x = tf.nn.avg_pool(orig_x, stride, stride, 'VALID')
                pad = (out_filter - in_filter) // 2
                orig_x = tf.pad(
                    orig_x, [[0, 0], [0, 0], [0, 0], [pad, pad]])
            x += orig_x

        tf.logging.info('image after unit %s', x.get_shape())
        return x

    def _bottleneck_residual(self, x, in_filter, out_filter, stride,
                             activate_before_residual=False):
        """Bottleneck residual unit with 3 sub layers.

        The inner convolutions use ``out_filter // 4`` channels. Integer
        division keeps the channel count an int under Python 3 (a float
        shape would be rejected by ``tf.get_variable``). When
        ``in_filter != out_filter`` the shortcut is a strided 1x1
        projection.
        """
        bottleneck = out_filter // 4

        if activate_before_residual:
            with tf.variable_scope('common_bn_relu'):
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)
                orig_x = x
        else:
            with tf.variable_scope('residual_bn_relu'):
                orig_x = x
                x = self._batch_norm('init_bn', x)
                x = self._relu(x, self.hps.relu_leakiness)

        with tf.variable_scope('sub1'):
            x = self._conv('conv1', x, 1, in_filter, bottleneck, stride)

        with tf.variable_scope('sub2'):
            x = self._batch_norm('bn2', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv('conv2', x, 3, bottleneck, bottleneck,
                           [1, 1, 1, 1])

        with tf.variable_scope('sub3'):
            x = self._batch_norm('bn3', x)
            x = self._relu(x, self.hps.relu_leakiness)
            x = self._conv('conv3', x, 1, bottleneck, out_filter,
                           [1, 1, 1, 1])

        with tf.variable_scope('sub_add'):
            if in_filter != out_filter:
                orig_x = self._conv('project', orig_x, 1, in_filter,
                                    out_filter, stride)
            x += orig_x

        tf.logging.info('image after unit %s', x.get_shape())
        return x

    def _weight_decay_of(self, variables):
        """Return ``weight_decay_rate * sum(l2_loss)`` over the 'DW' variables."""
        costs = [tf.nn.l2_loss(var) for var in variables
                 if var.op.name.find(r'DW') > 0]
        return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs))

    def _decay(self):
        """L2 weight decay loss over every trainable 'DW' variable in the graph."""
        return self._weight_decay_of(tf.trainable_variables())

    def GetWeightDecay(self):
        """L2 weight decay loss over this model's 'DW' variables.

        Uses ``self.trainable_variables`` (populated by ``build_graph``).
        Previously called ``tf.mul``, which no longer exists in TF >= 1.0.
        """
        return self._weight_decay_of(self.trainable_variables)

    def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
        """Convolution with 'SAME' padding and He-style normal init."""
        with tf.variable_scope(name):
            n = filter_size * filter_size * out_filters
            kernel = tf.get_variable(
                'DW', [filter_size, filter_size, in_filters, out_filters],
                tf.float32, initializer=tf.random_normal_initializer(
                    stddev=np.sqrt(2.0 / n)))
            return tf.nn.conv2d(x, kernel, strides, padding='SAME')

    def _relu(self, x, leakiness=0.0):
        """Relu, with optional leaky support."""
        return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')

    def _fully_connected(self, x, out_dim):
        """FullyConnected layer for final output."""
        x = tf.reshape(x, [self.batch_size, -1])
        w = tf.get_variable(
            'DW', [x.get_shape()[1], out_dim],
            initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
        b = tf.get_variable('biases', [out_dim],
                            initializer=tf.constant_initializer())
        return tf.nn.xw_plus_b(x, w, b)

    def _global_avg_pool(self, x):
        """Average a 4-D tensor over axes 1 and 2."""
        assert x.get_shape().ndims == 4
        return tf.reduce_mean(x, [1, 2])