# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Domain Adaptation Loss Functions.

The following domain adaptation loss functions are defined:

- Maximum Mean Discrepancy (MMD).
  Relevant paper:
    Gretton, Arthur, et al.,
    "A kernel two-sample test."
    The Journal of Machine Learning Research, 2012.

- Correlation Loss on a batch.
"""
from functools import partial

import tensorflow as tf

import grl_op_grads  # pylint: disable=unused-import
import grl_op_shapes  # pylint: disable=unused-import
import grl_ops
import utils

slim = tf.contrib.slim


################################################################################
# SIMILARITY LOSS
################################################################################
def maximum_mean_discrepancy(x, y, kernel=utils.gaussian_kernel_matrix):
  r"""Computes the Maximum Mean Discrepancy (MMD) of two samples: x and y.

  Maximum Mean Discrepancy (MMD) is a distance measure between the samples of
  the distributions of x and y. Here we use the kernel two-sample estimate
  using the empirical mean of the two distributions.

  MMD^2(P, Q) = || \E{\phi(x)} - \E{\phi(y)} ||^2
              = \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) },

  where K = <\phi(x), \phi(y)> is the desired kernel function, in this case a
  radial basis kernel.

  Args:
    x: a tensor of shape [num_samples, num_features].
    y: a tensor of shape [num_samples, num_features].
    kernel: a function which computes the kernel in MMD. Defaults to
        utils.gaussian_kernel_matrix.

  Returns:
    a scalar tensor denoting the squared maximum mean discrepancy loss.
  """
  with tf.name_scope('MaximumMeanDiscrepancy'):
    # \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) }
    cost = tf.reduce_mean(kernel(x, x))
    cost += tf.reduce_mean(kernel(y, y))
    cost -= 2 * tf.reduce_mean(kernel(x, y))

    # We do not allow the loss to become negative.
    cost = tf.where(cost > 0, cost, 0, name='value')
  return cost


def mmd_loss(source_samples, target_samples, weight, scope=None):
  """Adds a similarity loss term, the MMD between two representations.

  This Maximum Mean Discrepancy (MMD) loss is calculated with a number of
  different Gaussian kernels.

  Args:
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the MMD loss.
    scope: optional name scope for summary tags.

  Returns:
    a scalar tensor representing the MMD loss value.
  """
  sigmas = [
      1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
      1e3, 1e4, 1e5, 1e6
  ]
  gaussian_kernel = partial(
      utils.gaussian_kernel_matrix, sigmas=tf.constant(sigmas))

  loss_value = maximum_mean_discrepancy(
      source_samples, target_samples, kernel=gaussian_kernel)
  loss_value = tf.maximum(1e-4, loss_value) * weight
  assert_op = tf.Assert(tf.is_finite(loss_value), [loss_value])
  with tf.control_dependencies([assert_op]):
    tag = 'MMD Loss'
    if scope:
      tag = scope + tag
    tf.summary.scalar(tag, loss_value)
    tf.losses.add_loss(loss_value)

  return loss_value


def correlation_loss(source_samples, target_samples, weight, scope=None):
  """Adds a similarity loss term, the correlation between two representations.

  Args:
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    weight: a scalar weight for the loss.
    scope: optional name scope for summary tags.

  Returns:
    a scalar tensor representing the correlation loss value.
  """
  with tf.name_scope('corr_loss'):
    source_samples -= tf.reduce_mean(source_samples, 0)
    target_samples -= tf.reduce_mean(target_samples, 0)

    source_samples = tf.nn.l2_normalize(source_samples, 1)
    target_samples = tf.nn.l2_normalize(target_samples, 1)

    source_cov = tf.matmul(tf.transpose(source_samples), source_samples)
    target_cov = tf.matmul(tf.transpose(target_samples), target_samples)

    corr_loss = tf.reduce_mean(tf.square(source_cov - target_cov)) * weight

  assert_op = tf.Assert(tf.is_finite(corr_loss), [corr_loss])
  with tf.control_dependencies([assert_op]):
    tag = 'Correlation Loss'
    if scope:
      tag = scope + tag
    tf.summary.scalar(tag, corr_loss)
    tf.losses.add_loss(corr_loss)

  return corr_loss


def dann_loss(source_samples, target_samples, weight, scope=None):
  """Adds the domain adversarial (DANN) loss.

  Args:
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the loss.
    scope: optional name scope for summary tags.

  Returns:
    a scalar tensor representing the domain adversarial (DANN) loss value.
  """
  with tf.variable_scope('dann'):
    batch_size = tf.shape(source_samples)[0]
    samples = tf.concat(axis=0, values=[source_samples, target_samples])
    samples = slim.flatten(samples)

    domain_selection_mask = tf.concat(
        axis=0, values=[tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))])

    # Perform the gradient reversal and be careful with the shape.
    grl = grl_ops.gradient_reversal(samples)
    grl = tf.reshape(grl, (-1, samples.get_shape().as_list()[1]))

    grl = slim.fully_connected(grl, 100, scope='fc1')
    logits = slim.fully_connected(grl, 1, activation_fn=None, scope='fc2')

    domain_predictions = tf.sigmoid(logits)

  domain_loss = tf.losses.log_loss(
      domain_selection_mask, domain_predictions, weights=weight)

  domain_accuracy = utils.accuracy(
      tf.round(domain_predictions), domain_selection_mask)

  assert_op = tf.Assert(tf.is_finite(domain_loss), [domain_loss])
  with tf.control_dependencies([assert_op]):
    tag_loss = 'losses/domain_loss'
    tag_accuracy = 'losses/domain_accuracy'
    if scope:
      tag_loss = scope + tag_loss
      tag_accuracy = scope + tag_accuracy

    tf.summary.scalar(tag_loss, domain_loss)
    tf.summary.scalar(tag_accuracy, domain_accuracy)

  return domain_loss
################################################################################
# DIFFERENCE LOSS
################################################################################
def difference_loss(private_samples, shared_samples, weight=1.0, name=''):
  """Adds the difference loss between the private and shared representations.

  Args:
    private_samples: a tensor of shape [num_samples, num_features].
    shared_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the incoherence loss.
    name: the name of the tf summary.
  """
  private_samples -= tf.reduce_mean(private_samples, 0)
  shared_samples -= tf.reduce_mean(shared_samples, 0)

  private_samples = tf.nn.l2_normalize(private_samples, 1)
  shared_samples = tf.nn.l2_normalize(shared_samples, 1)

  correlation_matrix = tf.matmul(
      private_samples, shared_samples, transpose_a=True)

  cost = tf.reduce_mean(tf.square(correlation_matrix)) * weight
  cost = tf.where(cost > 0, cost, 0, name='value')

  tf.summary.scalar('losses/Difference Loss {}'.format(name), cost)
  assert_op = tf.Assert(tf.is_finite(cost), [cost])
  with tf.control_dependencies([assert_op]):
    tf.losses.add_loss(cost)
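# Illustrative usage sketches for the similarity losses defined above (a
# minimal sketch only: the representation tensors, weights, and scope prefix
# are invented for demonstration, not part of the original training setup):
#
#   source_repr = tf.random_normal([16, 64])
#   target_repr = tf.random_normal([16, 64])
#
#   # MMD over the default bank of Gaussian kernels, or directly with a
#   # custom kernel:
#   sim_loss = mmd_loss(source_repr, target_repr, weight=0.25, scope='dsn_')
#   kernel = partial(
#       utils.gaussian_kernel_matrix, sigmas=tf.constant([2.0, 5.0, 10.0]))
#   mmd = maximum_mean_discrepancy(source_repr, target_repr, kernel=kernel)
#
#   # The adversarial alternative shares the same signature:
#   adv_loss = dann_loss(source_repr, target_repr, weight=0.25, scope='dsn_')
#
# mmd_loss and dann_loss both end up in the tf.losses collection (the former
# via tf.losses.add_loss, the latter through tf.losses.log_loss), so they are
# picked up by tf.losses.get_total_loss().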
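# A minimal sketch of the difference (orthogonality) loss between a private
# and a shared encoding (the tensor shapes and weight below are assumptions
# for illustration):
#
#   private = tf.random_normal([32, 100])
#   shared = tf.random_normal([32, 100])
#   difference_loss(private, shared, weight=0.05, name='source')
#
# Note that difference_loss returns nothing; it registers the cost via
# tf.losses.add_loss and writes a scalar summary.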
################################################################################
# TASK LOSS
################################################################################
def log_quaternion_loss_batch(predictions, labels, params):
  """A helper function to compute the error between quaternions.

  Args:
    predictions: A Tensor of size [batch_size, 4].
    labels: A Tensor of size [batch_size, 4].
    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.

  Returns:
    A Tensor of size [batch_size], denoting the error between the quaternions.
  """
  use_logging = params['use_logging']
  assertions = []
  if use_logging:
    assertions.append(
        tf.Assert(
            tf.reduce_all(
                tf.less(
                    tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1),
                    1e-4)),
            ['The l2 norm of each prediction quaternion vector should be 1.']))
    assertions.append(
        tf.Assert(
            tf.reduce_all(
                tf.less(
                    tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
            ['The l2 norm of each label quaternion vector should be 1.']))

  with tf.control_dependencies(assertions):
    product = tf.multiply(predictions, labels)
  internal_dot_products = tf.reduce_sum(product, [1])

  if use_logging:
    internal_dot_products = tf.Print(
        internal_dot_products,
        [internal_dot_products, tf.shape(internal_dot_products)],
        'internal_dot_products:')

  logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
  return logcost


def log_quaternion_loss(predictions, labels, params):
  """Computes the mean error between batches of quaternions.

  The caller is expected to add the loss to the graph.

  Args:
    predictions: A Tensor of size [batch_size, 4].
    labels: A Tensor of size [batch_size, 4].
    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.

  Returns:
    A scalar Tensor denoting the mean error between batches of quaternions.
  """
  use_logging = params['use_logging']
  logcost = log_quaternion_loss_batch(predictions, labels, params)
  logcost = tf.reduce_sum(logcost, [0])
  batch_size = params['batch_size']
  logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss')
  if use_logging:
    logcost = tf.Print(
        logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
  return logcost
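# A small usage sketch for the quaternion losses (the batch size and random
# quaternions are fabricated for illustration; l2-normalizing the vectors
# keeps the optional unit-norm assertions satisfied):
#
#   params = {'use_logging': False, 'batch_size': 8}
#   preds = tf.nn.l2_normalize(tf.random_normal([8, 4]), 1)
#   labels = tf.nn.l2_normalize(tf.random_normal([8, 4]), 1)
#   mean_error = log_quaternion_loss(preds, labels, params)
#   per_example_error = log_quaternion_loss_batch(preds, labels, params)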