from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tensorflow.contrib.slim as slim

# Each model version maps to the list of modules shared between the two
# sibling networks; all other modules get their own (unshared) variables.
model_params = {
    'conv1': ['conv1'],
    'conv12': ['conv1', 'conv2'],
    'conv13': ['conv1', 'conv2', 'conv3'],
    'conv14': ['conv1', 'conv2', 'conv3', 'conv4'],
    'conv15': ['conv1', 'conv2', 'conv3', 'conv4', 'conv5'],
    'all': ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc'],
    'conv45+fc': ['conv4', 'conv5', 'fc'],
    'conv5+fc': ['conv5', 'fc'],
    'fc': ['fc'],
}

batch_norm_params = {
    # Decay for the moving averages.
    'decay': 0.995,
    # Epsilon to prevent 0s in variance.
    'epsilon': 0.001,
    # Force in-place updates of mean and variance estimates.
    'updates_collections': None,
    # Moving averages end up in the trainable variables collection.
    'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES],
}

batch_norm_params_last = {
    # Decay for the moving averages.
    'decay': 0.995,
    # Epsilon to prevent 0s in variance.
    'epsilon': 10e-8,  # i.e. 1e-7
    # Do not use beta and gamma.
    'center': False,
    'scale': False,
    # Force in-place updates of mean and variance estimates.
    'updates_collections': None,
    # Moving averages end up in the trainable variables collection.
    'variables_collections': [tf.GraphKeys.TRAINABLE_VARIABLES],
}


def parametric_relu(x):
    """PReLU with one learnable slope per channel, shared across H and W."""
    num_channels = x.shape[-1].value
    with tf.variable_scope('p_re_lu'):
        alpha = tf.get_variable('alpha', (1, 1, num_channels),
                                initializer=tf.constant_initializer(0.0),
                                dtype=tf.float32)
        return tf.nn.relu(x) + alpha * tf.minimum(0.0, x)


# activation = lambda x: tf.keras.layers.PReLU(shared_axes=[1,2]).apply(x)
activation = parametric_relu


def se_module(input_net, ratio=16, reuse=None, scope=None):
    """Squeeze-and-Excitation block: global average pooling followed by a
    two-layer bottleneck whose sigmoid output rescales the input channels."""
    with tf.variable_scope(scope, 'SE', [input_net], reuse=reuse):
        h, w, c = tuple([dim.value for dim in input_net.shape[1:4]])
        assert c % ratio == 0
        hidden_units = int(c / ratio)
        squeeze = slim.avg_pool2d(input_net, [h, w], padding='VALID')
        excitation = slim.flatten(squeeze)
        excitation = slim.fully_connected(excitation, hidden_units, scope='se_fc1',
                                          weights_regularizer=None,
                                          # weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
                                          weights_initializer=slim.xavier_initializer(),
                                          activation_fn=tf.nn.relu)
        excitation = slim.fully_connected(excitation, c, scope='se_fc2',
                                          weights_regularizer=None,
                                          # weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
                                          weights_initializer=slim.xavier_initializer(),
                                          activation_fn=tf.nn.sigmoid)
        excitation = tf.reshape(excitation, [-1, 1, 1, c])
        output_net = input_net * excitation
        return output_net


def conv_module(net, num_res_layers, num_kernels, reuse=None, scope=None):
    with tf.variable_scope(scope, 'conv', [net], reuse=reuse):
        # Every 2 conv layers constitute a residual block.
        if scope == 'conv1':
            # conv1 is a plain stack of VALID convolutions followed by pooling.
            for i in range(len(num_kernels)):
                with tf.variable_scope('layer_%d' % i, reuse=reuse):
                    net = slim.conv2d(net, num_kernels[i], kernel_size=3, stride=1, padding='VALID',
                                      weights_initializer=slim.xavier_initializer())
                    # net = activation(net)
                    print('| ---- layer_%d' % i)
            net = slim.max_pool2d(net, 2, stride=2, padding='VALID')
        else:
            shortcut = net
            for i in range(num_res_layers):
                with tf.variable_scope('layer_%d' % i, reuse=reuse):
                    net = slim.conv2d(net, num_kernels[0], kernel_size=3, stride=1, padding='SAME',
                                      weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                                      biases_initializer=None)
                    # net = activation(net)
                    print('| ---- layer_%d' % i)
                # Close each residual block with an SE module and a shortcut.
                if i % 2 == 1:
                    net = se_module(net)
                    net = net + shortcut
                    shortcut = net
                    print('| shortcut')
            # Pooling for conv2 - conv4.
            if len(num_kernels) > 1:
                with tf.variable_scope('expand', reuse=reuse):
                    # net = slim.batch_norm(net, **batch_norm_params)
                    net = slim.conv2d(net, num_kernels[1], kernel_size=3, stride=1, padding='VALID',
                                      weights_initializer=slim.xavier_initializer())
                    # net = activation(net)
                    net = slim.max_pool2d(net, 2, stride=2, padding='VALID')
                    print('- expand')
        return net


def build_scope(images, bottleneck_layer_size, shared_modules,
                scope_name, shared_scope_name, reuse=tf.AUTO_REUSE):
    """Builds one branch of the sibling network. Modules listed in
    shared_modules are created under shared_scope_name (and therefore
    shared between branches); all other modules live under scope_name."""
    get_scope = lambda x: shared_scope_name if x in shared_modules else scope_name
    with tf.variable_scope(get_scope('conv1'), reuse=reuse):
        print(tf.get_variable_scope().name)
        net = conv_module(images, 0, [32, 64], scope='conv1')
        print('module_1 shape:', [dim.value for dim in net.shape])
    with tf.variable_scope(get_scope('conv2'), reuse=reuse):
        print(tf.get_variable_scope().name)
        net = conv_module(net, 2, [64, 128], scope='conv2')
        print('module_2 shape:', [dim.value for dim in net.shape])
    with tf.variable_scope(get_scope('conv3'), reuse=reuse):
        print(tf.get_variable_scope().name)
        net = conv_module(net, 4, [128, 256], scope='conv3')
        print('module_3 shape:', [dim.value for dim in net.shape])
    with tf.variable_scope(get_scope('conv4'), reuse=reuse):
        print(tf.get_variable_scope().name)
        net = conv_module(net, 10, [256, 512], scope='conv4')
        print('module_4 shape:', [dim.value for dim in net.shape])
    with tf.variable_scope(get_scope('conv5'), reuse=reuse):
        print(tf.get_variable_scope().name)
        net = conv_module(net, 6, [512], scope='conv5')
        print('module_5 shape:', [dim.value for dim in net.shape])
    with tf.variable_scope(get_scope('fc'), reuse=reuse):
        print(tf.get_variable_scope().name)
        net = slim.flatten(net)
        prelogits = slim.fully_connected(net, bottleneck_layer_size, scope='Bottleneck',
                                         weights_initializer=slim.xavier_initializer(),
                                         activation_fn=None)
    return prelogits


def inference(images_A, images_B, keep_probability=1.0, phase_train=True,
              bottleneck_layer_size=512, weight_decay=0.0,
              reuse=None, model_version=None):
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        weights_regularizer=slim.l2_regularizer(weight_decay),
                        activation_fn=activation,
                        normalizer_fn=None,
                        normalizer_params=None):
        with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=phase_train):
            with tf.variable_scope('FaceResNet', [images_A, images_B], reuse=reuse):
                shared_modules = model_params[model_version]
                print('input shape:', [dim.value for dim in images_A.shape])
                prelogits_A = build_scope(images_A, bottleneck_layer_size,
                                          shared_modules, 'NetA', 'SharedNet')
                prelogits_B = build_scope(images_B, bottleneck_layer_size,
                                          shared_modules, 'NetB', 'SharedNet')
    return prelogits_A, prelogits_B
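

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# builds the sibling pair for one sharing configuration and prints the output
# shapes. The 112x96 RGB input size and the 'conv45+fc' model_version are
# assumptions for illustration only; any key of model_params is a valid
# model_version.
if __name__ == '__main__':
    images_A = tf.placeholder(tf.float32, [None, 112, 96, 3], name='images_A')
    images_B = tf.placeholder(tf.float32, [None, 112, 96, 3], name='images_B')
    # With 'conv45+fc', conv4/conv5/fc share variables under 'SharedNet',
    # while conv1-conv3 are built separately under 'NetA' and 'NetB'.
    prelogits_A, prelogits_B = inference(images_A, images_B,
                                         phase_train=False,
                                         bottleneck_layer_size=512,
                                         weight_decay=5e-4,
                                         model_version='conv45+fc')
    print('prelogits_A shape:', prelogits_A.shape)  # (?, 512)
    print('prelogits_B shape:', prelogits_B.shape)  # (?, 512)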