"""pre-activation Residual network model class""" from collections import namedtuple import logging import tensorflow as tf from tensorflow.contrib.layers import convolution2d from tensorflow.contrib.layers import batch_norm from tensorflow.contrib.layers import variance_scaling_initializer from tensorflow.contrib.layers import l2_regularizer from tensorflow.contrib.layers import fully_connected from tensorflow.contrib.layers.python.layers import utils import model ######################################### # FLAGS ######################################### FLAGS = tf.app.flags.FLAGS class ResNN(model.Model): """Residual neural network model. classify web page only based on target html.""" def model_conf(self): # Configurations for each group # several residual units (aka. bottleneck blocks) form a group UnitsGroup = namedtuple( 'UnitsGroup', [ 'num_units', # number of residual units 'num_ker', # number of kernels for each convolution 'reduced_ker', # number of reduced kernels 'is_downsample' # (bool): downsample data using stride 2 # types of BottleneckBlock ?? # wide resnet kernel*k ?? ]) self.groups = [ # no more than three groups with downsampling # UnitsGroup(3, 64, 32, True), # UnitsGroup(2, 1024, 512, True), UnitsGroup(1, 256, 128, True), UnitsGroup(1, 512, 256, True), # UnitsGroup(3, 512, 256, True), # UnitsGroup(3, 256, 128, False), # UnitsGroup(3, 128, 64, True), # UnitsGroup(2, 258, 128, True), # UnitsGroup(2, 512, 256, True), # UnitsGroup(2, 1024, 512, True), ] # special first residual unit from P14 of (arxiv.org/abs/1603.05027) self.special_first = True # shortcut connection type: (arXiv:1512.03385) # 0: 0-padding and average pooling # 1: convolution projection only for increasing dimension # 2: projection for all shortcut self.shortcut = 1 # weight decay # self.weight_decay = 0.0001 self.weight_decay = 0.01 # self.weight_decay = 0.01 # the type of residual unit # 0: post-activation; 1: pre-activation self.unit_type = 1 # residual function: 0: bottleneck # 1: basic two conv self.residual_type = 1 # the middle conv window size of bottleneck: 3, 4, 5 self.bott_size = 5 # window size of first and third conv in bottleneck self.bott_size13 = 1 # RoR enable level 1 # requirement: every group is downsampling self.ror_l1 = False # RoR enable level 2 self.ror_l2 = False # whether enable dropout before FC layer self.dropout = True # whehter use dropout in residual function self.if_drop = False logging.info("ResNet hyper parameters:") logging.info(vars(self)) def BN_ReLU(self, net): # Batch Normalization and ReLU # 'gamma' is not used as the next layer is ReLU net = batch_norm(net, center=True, scale=False, activation_fn=tf.nn.relu, ) # net = tf.nn.relu(net) # activation summary ?? self._activation_summary(net) return net def conv1d(self, net, num_ker, ker_size, stride): # 1D-convolution net = convolution2d( net, num_outputs=num_ker, kernel_size=[ker_size, 1], stride=[stride, 1], padding='SAME', activation_fn=None, normalizer_fn=None, weights_initializer=variance_scaling_initializer(), weights_regularizer=l2_regularizer(self.weight_decay), biases_initializer=tf.zeros_initializer) return net def residual_unit(self, net, group_i, unit_i): """pre-activation Residual Units from https://arxiv.org/abs/1603.05027.""" name = 'group_%d/unit_%d' % (group_i, unit_i) group = self.groups[group_i] if group.is_downsample and unit_i == 0: stride1 = 2 else: stride1 = 1 def conv_pre(name, net, num_ker, kernel_size, stride, conv_i): """ 1D pre-activation convolution. args: num_ker (int): number of kernels (out_channels). ker_size (int): size of 1D kernel. stride (int) """ with tf.variable_scope(name): if not (self.special_first and group_i == unit_i == conv_i == 0): net = self.BN_ReLU(net) # 1D-convolution net = self.conv1d(net, num_ker, kernel_size, stride) return net def conv_post(name, net, num_ker, kernel_size, stride, conv_i): """ 1D post-activation convolution. args: num_ker (int): number of kernels (out_channels). ker_size (int): size of 1D kernel. stride (int) """ with tf.variable_scope(name): # 1D-convolution net = self.conv1d(net, num_ker, kernel_size, stride) net = self.BN_ReLU(net) return net ### residual function net_residual = net if self.unit_type == 0 and not self.special_first: unit_conv = conv_post elif self.unit_type == 1: unit_conv = conv_pre else: raise ValueError("wrong residual unit type:{}".format( self.unit_type)) if self.residual_type == 0: # 1x1 convolution responsible for reducing dimension net_residual = unit_conv(name + '/conv_reduce', net_residual, group.reduced_ker, self.bott_size13, stride1, 0) # 3x1 convolution bottleneck net_residual = unit_conv(name + '/conv_bottleneck', net_residual, group.reduced_ker, self.bott_size, 1, 1) # 1x1 convolution responsible for restoring dimension net_residual = unit_conv(name + '/conv_restore', net_residual, group.num_ker, self.bott_size13, 1, 2) elif self.residual_type == 1: net_residual = unit_conv(name + '/conv_one', net_residual, group.num_ker, self.bott_size, stride1, 0) # if self.if_drop and group_i == 2: if self.if_drop and unit_i == 0: with tf.name_scope("dropout"): net_residual = tf.nn.dropout(net_residual, self.dropout_keep_prob) net_residual = unit_conv(name + '/conv_two', net_residual, group.num_ker, self.bott_size, 1, 1) else: raise ValueError("residual_type error") ### shortcut connection num_ker_in = utils.last_dimension(net.get_shape(), min_rank=4) if self.shortcut == 0 and unit_i == 0: # average pooling for data downsampling if group.is_downsample: net = tf.nn.avg_pool(net, ksize=[1, 2, 1, 1], strides=[1, 2, 1, 1], padding='SAME') # zero-padding for increasing kernel numbers if group.num_ker / num_ker_in == 2: net = tf.pad(net, [[0, 0], [0, 0], [0, 0], [int(num_ker_in / 2), int(num_ker_in / 2)]]) elif group.num_ker != num_ker_in: raise ValueError("illigal kernel numbers at group {} unit {}" .format(group_i, unit_i)) elif self.shortcut == 1 and unit_i == 0 or self.shortcut == 2: with tf.variable_scope(name+'_sc'): # projection net = self.BN_ReLU(net) net = self.conv1d(net, group.num_ker, 1, stride1) ### element-wise addition net = net + net_residual return net def resnn(self, sequences): """Build the resnn model. Args: page_batch: Sequences returned from inputs_train() or inputs_eval. Returns: Logits. """ self.model_conf() # [batch_size, html_len, 1, we_dim] target_expanded = tf.expand_dims(sequences, 2) # First convolution with tf.variable_scope('conv_layer1'): net = self.conv1d(target_expanded, self.groups[0].num_ker, 7, 2) # if self.special_first: net = self.BN_ReLU(net) # Max pool net = tf.nn.max_pool(net, [1, 3, 1, 1], strides=[1, 2, 1, 1], padding='SAME') if self.ror_l1: net_l1 = net # stacking Residual Units for group_i, group in enumerate(self.groups): if self.ror_l2: net_l2 = net for unit_i in range(group.num_units): net = self.residual_unit(net, group_i, unit_i) if self.ror_l2: # this is necessary to prevent loss exploding net_l2 = self.BN_ReLU(net_l2) net_l2 = self.conv1d(net_l2, self.groups[group_i].num_ker, self.bott_size13, 2) net = net + net_l2 if self.ror_l1: net_l1 = self.BN_ReLU(net_l1) net_l1 = self.conv1d(net_l1, self.groups[-1].num_ker, self.bott_size13, 2 **len(self.groups)) net = net + net_l1 # an extra activation before average pooling if self.special_first: with tf.variable_scope('special_BN_ReLU'): net = self.BN_ReLU(net) # padding should be VALID for global average pooling # output: batch*1*1*channels net_shape = net.get_shape().as_list() net = tf.nn.avg_pool(net, ksize=[1, net_shape[1], net_shape[2], 1], strides=[1, 1, 1, 1], padding='VALID') net_shape = net.get_shape().as_list() softmax_len = net_shape[1] * net_shape[2] * net_shape[3] net = tf.reshape(net, [-1, softmax_len]) # add dropout if self.dropout: with tf.name_scope("dropout"): net = tf.nn.dropout(net, self.dropout_keep_prob) # 1D-fully connected nueral network with tf.variable_scope('FC-layer'): net = fully_connected( net, num_outputs=self.num_cats, activation_fn=None, normalizer_fn=None, weights_initializer=variance_scaling_initializer(), weights_regularizer=l2_regularizer(self.weight_decay), biases_initializer=tf.zeros_initializer, ) return net def inference(self, page_batch): """Build the resnn model. Args: page_batch: Sequences returned from inputs_train() or inputs_eval. Returns: Logits. """ # self.activation = tf.nn.relu # self.norm_decay = 0.99 target_batch, un_batch, un_len, la_batch, la_len = page_batch return self.resnn(target_batch)