import tensorflow as tf def conditional_instance_norm(x, scope_bn, y1=None, y2=None, alpha=1): mean, var = tf.nn.moments(x, axes=[1, 2], keep_dims=True) if y1==None: beta = tf.get_variable(name=scope_bn + 'beta', shape=[x.shape[-1]], initializer=tf.constant_initializer([0.]), trainable=True) # label_nums x C gamma = tf.get_variable(name=scope_bn + 'gamma', shape=[x.shape[-1]], initializer=tf.constant_initializer([1.]), trainable=True) # label_nums x C else: beta = tf.get_variable(name=scope_bn+'beta', shape=[y1.shape[-1], x.shape[-1]], initializer=tf.constant_initializer([0.]), trainable=True) # label_nums x C gamma = tf.get_variable(name=scope_bn+'gamma', shape=[y1.shape[-1], x.shape[-1]], initializer=tf.constant_initializer([1.]), trainable=True) # label_nums x C beta1 = tf.matmul(y1, beta) gamma1 = tf.matmul(y1, gamma) beta2 = tf.matmul(y2, beta) gamma2 = tf.matmul(y2, gamma) beta = alpha * beta1 + (1. - alpha) * beta2 gamma = alpha * gamma1 + (1. - alpha) * gamma2 x = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-10) return x def conv(name, inputs, k_size, nums_in, nums_out, strides): pad_size = k_size // 2 inputs = tf.pad(inputs, [[0, 0], [pad_size, pad_size], [pad_size, pad_size], [0, 0]], mode="REFLECT") kernel = tf.get_variable(name+"W", [k_size, k_size, nums_in, nums_out], initializer=tf.truncated_normal_initializer(stddev=0.01)) bias = tf.get_variable(name+"B", [nums_out], initializer=tf.constant_initializer(0.)) return tf.nn.conv2d(inputs, kernel, [1, strides, strides, 1], "VALID") + bias def upsampling(name, inputs, nums_in, nums_out, y1, y2, alpha): inputs = tf.image.resize_nearest_neighbor(inputs, [tf.shape(inputs)[1] * 2, tf.shape(inputs)[2] * 2]) return conditional_instance_norm(conv(name, inputs, 3, nums_in, nums_out, 1), "cin"+name, y1, y2, alpha) def relu(inputs): return tf.nn.relu(inputs) def sigmoid(inputs): return tf.nn.sigmoid(inputs) def ResBlock(name, inputs, k_size, nums_in, nums_out, y1, y2, alpha): temp = inputs * 1.0 inputs = conditional_instance_norm(conv("conv1_" + name, inputs, k_size, nums_in, nums_out, 1), "cin1"+name, y1, y2, alpha) inputs = relu(inputs) inputs = conditional_instance_norm(conv("conv2_" + name, inputs, k_size, nums_in, nums_out, 1), "cin2"+name, y1, y2, alpha) return inputs + temp def content_loss(phi_content, phi_target): return tf.nn.l2_loss(phi_content["conv2_2"] - phi_target["conv2_2"]) * 2 / tf.cast(tf.size(phi_content["conv2_2"]), dtype=tf.float32) def style_loss(phi_style, phi_target): layers = ["conv1_2", "conv2_2", "conv3_3", "conv4_3"] loss = 0 for layer in layers: s_maps = phi_style[layer] G_s = gram(s_maps) t_maps = phi_target[layer] G_t = gram(t_maps) loss += tf.nn.l2_loss(G_s - G_t) * 2 / tf.cast(tf.size(G_t), dtype=tf.float32) return loss def gram(layer): shape = tf.shape(layer) num_images = shape[0] width = shape[1] height = shape[2] num_filters = shape[3] filters = tf.reshape(layer, tf.stack([num_images, -1, num_filters])) grams = tf.matmul(filters, filters, transpose_a=True) / tf.to_float(width * height * num_filters) return grams