# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to create a DSN model and add the different losses to it.

Specifically, in this file we define the:
  - Shared Encoding Similarity Loss Module, with:
    - The MMD Similarity method
    - The Correlation Similarity method
    - The Gradient Reversal (Domain-Adversarial) method
  - Difference Loss Module
  - Reconstruction Loss Module
  - Task Loss Module
"""
from functools import partial

import tensorflow as tf

import losses
import models
import utils

slim = tf.contrib.slim


################################################################################
# HELPER FUNCTIONS
################################################################################
def dsn_loss_coefficient(params):
  """The global_step-dependent weight that specifies when to kick in DSN losses.

  Before 'domain_separation_startpoint' steps the coefficient is a negligible
  1e-10 (rather than exactly 0, which keeps the scaled losses well defined in
  the graph); afterwards it is 1.0.

  Args:
    params: A dictionary of parameters. Expecting 'domain_separation_startpoint'

  Returns:
    A weight to that effectively enables or disables the DSN-related losses,
    i.e. similarity, difference, and reconstruction losses.
  """
  return tf.where(
      tf.less(slim.get_or_create_global_step(),
              params['domain_separation_startpoint']), 1e-10, 1.0)


################################################################################
# MODEL CREATION
################################################################################
def create_model(source_images, source_labels, domain_selection_mask,
                 target_images, target_labels, similarity_loss, params,
                 basic_tower_name):
  """Creates a DSN model.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: a dictionary with the name, tensor pairs. 'classes' is one-
      hot for the number of classes.
    domain_selection_mask: a boolean tensor of size [batch_size, ] which denotes
      the labeled images that belong to the source domain.
    target_images: images from the target domain, a tensor of size
      [batch_size, height width, channels].
    target_labels: a dictionary with the name, tensor pairs.
    similarity_loss: The type of method to use for encouraging the codes from
      the shared encoder to be similar.
    params: A dictionary of parameters. Expecting 'weight_decay',
      'layers_to_regularize', 'use_separation', 'domain_separation_startpoint',
      'alpha_weight', 'beta_weight', 'gamma_weight', 'recon_loss_name',
      'decoder_name', 'encoder_name'
    basic_tower_name: the name of the tower to use for the shared encoder.

  Raises:
    ValueError: if the arch is not one of the available architectures.
  """
  network = getattr(models, basic_tower_name)
  num_classes = source_labels['classes'].get_shape().as_list()[1]

  # Make sure we are using the appropriate number of classes.
  network = partial(network, num_classes=num_classes)

  # Add the classification/pose estimation loss to the source domain.
  source_endpoints = add_task_loss(source_images, source_labels, network,
                                   params)

  if similarity_loss == 'none':
    # No domain adaptation, we can stop here.
    return

  # Run the target images through the same tower (shared weights).
  with tf.variable_scope('towers', reuse=True):
    target_logits, target_endpoints = network(
        target_images, weight_decay=params['weight_decay'], prefix='target')

  # Plot target accuracy of the train set.
  target_accuracy = utils.accuracy(
      tf.argmax(target_logits, 1), tf.argmax(target_labels['classes'], 1))

  if 'quaternions' in target_labels:
    target_quaternion_loss = losses.log_quaternion_loss(
        target_labels['quaternions'], target_endpoints['quaternion_pred'],
        params)
    tf.summary.scalar('eval/Target quaternions', target_quaternion_loss)

  tf.summary.scalar('eval/Target accuracy', target_accuracy)

  source_shared = source_endpoints[params['layers_to_regularize']]
  target_shared = target_endpoints[params['layers_to_regularize']]

  # When using the semisupervised model we include labeled target data in the
  # source classifier. We do not want to include these target domain when
  # we use the similarity loss.
  indices = tf.range(0, source_shared.get_shape().as_list()[0])
  indices = tf.boolean_mask(indices, domain_selection_mask)
  add_similarity_loss(similarity_loss,
                      tf.gather(source_shared, indices),
                      tf.gather(target_shared, indices), params)

  if params['use_separation']:
    add_autoencoders(
        source_images,
        source_shared,
        target_images,
        target_shared,
        params=params,)


def add_similarity_loss(method_name,
                        source_samples,
                        target_samples,
                        params,
                        scope=None):
  """Adds a loss encouraging the shared encoding from each domain to be similar.

  Args:
    method_name: the name of the encoding similarity method to use. Valid
      options include `dann_loss', `mmd_loss' or `correlation_loss'.
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    params: a dictionary of parameters. Expecting 'gamma_weight'.
    scope: optional name scope for summary tags.

  Raises:
    ValueError: if `method_name` is not recognized.
  """
  # Ramp the loss in via the DSN schedule; gamma_weight sets its magnitude.
  weight = dsn_loss_coefficient(params) * params['gamma_weight']
  method = getattr(losses, method_name)
  method(source_samples, target_samples, weight, scope)


def add_reconstruction_loss(recon_loss_name, images, recons, weight, domain):
  """Adds a reconstruction loss.

  Args:
    recon_loss_name: The name of the reconstruction loss.
    images: A `Tensor` of size [batch_size, height, width, 3].
    recons: A `Tensor` whose size matches `images`.
    weight: A scalar coefficient for the loss.
    domain: The name of the domain being reconstructed.

  Raises:
    ValueError: If `recon_loss_name` is not recognized.
  """
  if recon_loss_name == 'sum_of_pairwise_squares':
    loss_fn = tf.contrib.losses.mean_pairwise_squared_error
  elif recon_loss_name == 'sum_of_squares':
    loss_fn = tf.contrib.losses.mean_squared_error
  else:
    raise ValueError('recon_loss_name value [%s] not recognized.' %
                     recon_loss_name)

  # The contrib loss functions add the weighted loss to the LOSSES collection
  # as a side effect; the returned tensor is only used for the summary here.
  loss = loss_fn(recons, images, weight)

  # NOTE(review): only the summary op depends on the Assert, so the finiteness
  # check fires only when the summary is evaluated — confirm this is intended.
  assert_op = tf.Assert(tf.is_finite(loss), [loss])
  with tf.control_dependencies([assert_op]):
    tf.summary.scalar('losses/%s Recon Loss' % domain, loss)


def add_autoencoders(source_data, source_shared, target_data, target_shared,
                     params):
  """Adds the encoders/decoders for our domain separation model w/ incoherence.

  Args:
    source_data: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_shared: a tensor with first dimension batch_size
    target_data: images from the target domain, a tensor of size
      [batch_size, height, width, channels]
    target_shared: a tensor with first dimension batch_size
    params: A dictionary of parameters. Expecting 'layers_to_regularize',
      'beta_weight', 'alpha_weight', 'recon_loss_name', 'decoder_name',
      'encoder_name', 'weight_decay'
  """

  def normalize_images(images):
    # Rescale to [0, 1] for image summaries.
    images -= tf.reduce_min(images)
    return images / tf.reduce_max(images)

  def concat_operation(shared_repr, private_repr):
    # Shared and private codes are combined by addition before decoding.
    return shared_repr + private_repr

  mu = dsn_loss_coefficient(params)
  # The layer to concatenate the networks at.
  concat_layer = params['layers_to_regularize']

  # The coefficient for modulating the private/shared difference loss.
  difference_loss_weight = params['beta_weight'] * mu

  # The reconstruction weight.
  recon_loss_weight = params['alpha_weight'] * mu

  # The reconstruction loss to use.
  recon_loss_name = params['recon_loss_name']

  # The decoder/encoder to use.
  decoder_name = params['decoder_name']
  encoder_name = params['encoder_name']

  _, height, width, _ = source_data.get_shape().as_list()
  code_size = source_shared.get_shape().as_list()[-1]
  weight_decay = params['weight_decay']

  # Private encoders, one per domain (the shared codes are passed in).
  encoder_fn = getattr(models, encoder_name)
  with tf.variable_scope('source_encoder'):
    source_endpoints = encoder_fn(
        source_data, code_size, weight_decay=weight_decay)

  with tf.variable_scope('target_encoder'):
    target_endpoints = encoder_fn(
        target_data, code_size, weight_decay=weight_decay)

  # A single decoder, shared (via variable reuse) by all reconstructions.
  decoder_fn = getattr(models, decoder_name)
  decoder = partial(
      decoder_fn,
      height=height,
      width=width,
      channels=source_data.get_shape().as_list()[-1],
      weight_decay=weight_decay)

  source_private = source_endpoints[concat_layer]
  target_private = target_endpoints[concat_layer]

  # Source auto-encoding: full, private-only and shared-only reconstructions.
  with tf.variable_scope('decoder'):
    source_recons = decoder(concat_operation(source_shared, source_private))

  with tf.variable_scope('decoder', reuse=True):
    source_private_recons = decoder(
        concat_operation(tf.zeros_like(source_private), source_private))
    source_shared_recons = decoder(
        concat_operation(source_shared, tf.zeros_like(source_shared)))

  # Target auto-encoding.
  with tf.variable_scope('decoder', reuse=True):
    target_recons = decoder(concat_operation(target_shared, target_private))
    target_shared_recons = decoder(
        concat_operation(target_shared, tf.zeros_like(target_shared)))
    target_private_recons = decoder(
        concat_operation(tf.zeros_like(target_private), target_private))

  losses.difference_loss(
      source_private,
      source_shared,
      weight=difference_loss_weight,
      name='Source')
  losses.difference_loss(
      target_private,
      target_shared,
      weight=difference_loss_weight,
      name='Target')

  add_reconstruction_loss(recon_loss_name, source_data, source_recons,
                          recon_loss_weight, 'source')
  add_reconstruction_loss(recon_loss_name, target_data, target_recons,
                          recon_loss_weight, 'target')

  # Add summaries. NOTE: a list comprehension is used instead of `map` because
  # in Python 3 `map` returns a lazy iterator, which tf.concat rejects.
  source_reconstructions = tf.concat(
      axis=2,
      values=[
          normalize_images(t) for t in [
              source_data, source_recons, source_shared_recons,
              source_private_recons
          ]
      ])
  target_reconstructions = tf.concat(
      axis=2,
      values=[
          normalize_images(t) for t in [
              target_data, target_recons, target_shared_recons,
              target_private_recons
          ]
      ])
  tf.summary.image(
      'Source Images:Recons:RGB',
      source_reconstructions[:, :, :, :3],
      max_outputs=10)
  tf.summary.image(
      'Target Images:Recons:RGB',
      target_reconstructions[:, :, :, :3],
      max_outputs=10)

  # If the images carry a 4th channel (presumably depth — confirm with the
  # dataset pipeline), visualize it separately.
  if source_reconstructions.get_shape().as_list()[3] == 4:
    tf.summary.image(
        'Source Images:Recons:Depth',
        source_reconstructions[:, :, :, 3:4],
        max_outputs=10)
    tf.summary.image(
        'Target Images:Recons:Depth',
        target_reconstructions[:, :, :, 3:4],
        max_outputs=10)


def add_task_loss(source_images, source_labels, basic_tower, params):
  """Adds a classification and/or pose estimation loss to the model.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: labels from the source domain, a tensor of size
      [batch_size]. or a tuple of (quaternions, class_labels)
    basic_tower: a function that creates the single tower of the model.
    params: A dictionary of parameters. Expecting 'weight_decay',
      'pose_weight'.

  Returns:
    The source endpoints.

  Raises:
    RuntimeError: if basic tower does not support pose estimation.
  """
  with tf.variable_scope('towers'):
    source_logits, source_endpoints = basic_tower(
        source_images, weight_decay=params['weight_decay'], prefix='Source')

  if 'quaternions' in source_labels:  # We have pose estimation as well
    if 'quaternion_pred' not in source_endpoints:
      raise RuntimeError('Please use a model for estimation e.g. pose_mini')

    loss = losses.log_quaternion_loss(source_labels['quaternions'],
                                      source_endpoints['quaternion_pred'],
                                      params)

    # NOTE(review): the plain Python assignment under control_dependencies
    # creates no op, so only the histogram summary depends on the Assert;
    # the loss added below is not gated by the finiteness check — confirm
    # this is intended.
    assert_op = tf.Assert(tf.is_finite(loss), [loss])
    with tf.control_dependencies([assert_op]):
      quaternion_loss = loss
      tf.summary.histogram('log_quaternion_loss_hist', quaternion_loss)
    slim.losses.add_loss(quaternion_loss * params['pose_weight'])
    tf.summary.scalar('losses/quaternion_loss', quaternion_loss)

  classification_loss = tf.losses.softmax_cross_entropy(
      source_labels['classes'], source_logits)

  tf.summary.scalar('losses/classification_loss', classification_loss)
  return source_endpoints