# Network architectures.

from __future__ import division

from utils.ops import relu, conv2d, lrelu, instance_norm, deconv2d, tanh

import tensorflow as tf
import tensorflow.contrib.slim as slim


def discriminator(params, x_init, reuse=False):

    """ Discriminator with an adversarial head and a regression head.

    Parameters
    ----------
    params: configuration object; only its `image_size` attribute is read.
    x_init: input tensor.
    reuse: bool, reuse the variables of the 'discriminator' scope if True.

    Returns
    -------
    x_gan: tensor, outputs for adversarial training.
    x_reg: tensor reshaped to [-1, 2], outputs for gaze estimation.

    """

    num_layers = 5
    n_filters = 64
    image_size = params.image_size

    # Settings shared by every stride-2 convolution of the trunk.
    down_kwargs = dict(conv_filters_dim=4, d_h=2, d_w=2, pad=1,
                       use_bias=True)

    with tf.variable_scope('discriminator', reuse=reuse):

        # Progression noted by the original author (presumably spatial
        # size and channel count for a 64-pixel, 3-channel input):
        # 64 3 -> 32 64 -> 16 128 -> 8 256 -> 4 512 -> 2 1024
        net = lrelu(conv2d(x_init, n_filters, scope='conv_0',
                           **down_kwargs))

        # Each further layer doubles the number of output channels.
        for idx in range(1, num_layers):
            n_filters *= 2
            net = lrelu(conv2d(net, n_filters, scope='conv_%d' % idx,
                               **down_kwargs))

        # Kernel size of both heads is derived from the image size and the
        # number of stride-2 layers.
        kernel = int(image_size / 2 ** num_layers)
        head_kwargs = dict(conv_filters_dim=kernel, d_h=1, d_w=1,
                           use_bias=False)

        # The two heads differ only in output channels, padding and scope.
        x_gan = conv2d(net, 1, pad=1, scope='conv_logit_gan', **head_kwargs)
        reg_logits = conv2d(net, 2, pad=0, scope='conv_logit_reg',
                            **head_kwargs)
        x_reg = tf.reshape(reg_logits, [-1, 2])

    return x_gan, x_reg


def generator(input_, angles, reuse=False):

    """ Generator conditioned on a target gaze direction.

    The angle vector is tiled over axes 1 and 2 of `input_` and
    concatenated to it along axis 3 before entering the network
    (assumes `input_` is laid out as batch, height, width, channels --
    TODO confirm against the caller).

    Parameters
    ----------
    input_: tensor, input images.
    angles: tensor, target gaze direction; its last dimension must be
        statically known.
    reuse: bool, reuse the variables of the 'generator' scope if True.

    Returns
    -------
    x: tensor, generated image (output of the final tanh).

    """

    n_filters = 64

    # Turn the angle vector into a map matching the input's axes 1 and 2.
    angle_dim = angles.get_shape().as_list()[-1]
    angle_map = tf.reshape(angles, [-1, 1, 1, angle_dim])
    rows = tf.shape(input_)[1]
    cols = tf.shape(input_)[2]
    angle_map = tf.tile(angle_map, [1, rows, cols, 1])
    net = tf.concat([input_, angle_map], axis=3)

    with tf.variable_scope('generator', reuse=reuse):

        # input layer
        net = conv2d(net, n_filters, conv_filters_dim=7, d_h=1, d_w=1,
                     pad=3, use_bias=False, scope='conv2d_input')
        net = relu(instance_norm(net, scope='in_input'))

        # encoder: two stride-2 convolutions, channels doubled each time
        for idx in range(2):
            n_filters *= 2
            net = conv2d(net, n_filters, conv_filters_dim=4, d_h=2, d_w=2,
                         pad=1, use_bias=False, scope='conv2d_%d' % idx)
            net = relu(instance_norm(net, scope='in_conv_%d' % idx))

        # bottleneck: six blocks of conv -> IN -> ReLU -> conv -> IN whose
        # result is added back onto the block input
        res_kwargs = dict(conv_filters_dim=3, d_h=1, d_w=1, pad=1,
                          use_bias=False)
        for idx in range(6):
            branch = conv2d(net, n_filters, scope='conv_res_a_%d' % idx,
                            **res_kwargs)
            branch = relu(instance_norm(branch, 'in_res_a_%d' % idx))
            branch = conv2d(branch, n_filters, scope='conv_res_b_%d' % idx,
                            **res_kwargs)
            branch = instance_norm(branch, 'in_res_b_%d' % idx)
            net = net + branch

        # decoder: two stride-2 deconvolutions, channels halved each time
        for idx in range(2):
            n_filters //= 2
            net = deconv2d(net, n_filters, conv_filters_dim=4, d_h=2, d_w=2,
                           use_bias=False, scope='deconv_%d' % idx)
            net = relu(instance_norm(net, scope='in_decon_%d' % idx))

        # output layer: 3 channels squashed by tanh
        net = conv2d(net, 3, conv_filters_dim=7, d_h=1, d_w=1, pad=3,
                     use_bias=False, scope='output')
        return tanh(net)


def vgg_16(inputs, scope='vgg_16', reuse=False):

    """ Convolutional part of VGG-16 (conv1 .. pool5) built with TF-Slim.

    Parameters
    ----------
    inputs: tensor.
    scope: name of the variable scope.
    reuse: reuse the variables of the scope if True.

    Returns
    -------
    net: tensor, output of the last pooling layer ('pool5').
    end_points: dict, outputs collected from the conv and pooling layers.

    """

    # (conv scope, number of 3x3 convolutions, output depth, pool scope)
    stages = (
        ('conv1', 2, 64, 'pool1'),
        ('conv2', 2, 128, 'pool2'),
        ('conv3', 3, 256, 'pool3'),
        ('conv4', 3, 512, 'pool4'),
        ('conv5', 3, 512, 'pool5'),
    )

    with tf.variable_scope(scope, 'vgg_16', [inputs], reuse=reuse) as sc:
        end_points_collection = sc.original_name_scope + '_end_points'

        # Collect outputs for conv2d, fully_connected and max_pool2d.
        collected_ops = [slim.conv2d, slim.fully_connected, slim.max_pool2d]
        with slim.arg_scope(collected_ops,
                            outputs_collections=end_points_collection):
            net = inputs
            for conv_scope, repeats, depth, pool_scope in stages:
                net = slim.repeat(net, repeats, slim.conv2d, depth, [3, 3],
                                  scope=conv_scope)
                net = slim.max_pool2d(net, [2, 2], scope=pool_scope)

            # Convert end_points_collection into an end_point dict.
            end_points = slim.utils.convert_collection_to_dict(
                end_points_collection)

        return net, end_points