# coding=utf-8 # Copyright 2019 The Interval Bound Propagation Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Helper to keep track of the different losses.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import sonnet as snt import tensorflow.compat.v1 as tf # Used to pick the least violated specification. _BIG_NUMBER = 1e25 ScalarMetrics = collections.namedtuple('ScalarMetrics', [ 'nominal_accuracy', 'verified_accuracy', 'attack_accuracy', 'attack_success']) ScalarLosses = collections.namedtuple('ScalarLosses', [ 'nominal_cross_entropy', 'attack_cross_entropy', 'verified_loss']) class Losses(snt.AbstractModule): """Helper to compute our losses.""" def __init__(self, predictor, specification=None, pgd_attack=None, interval_bounds_loss_type='xent', interval_bounds_hinge_margin=10., label_smoothing=0.): super(Losses, self).__init__(name='losses') self._predictor = predictor self._specification = specification self._attack = pgd_attack # Loss type can be any combination of: # xent: cross-entropy loss # hinge: hinge loss # softplus: softplus loss # with # all: using all specifications. # most: using only the specification that is the most violated. # least: using only the specification that is the least violated. # random_n: using a random subset of the specifications. # E.g.: "xent_max" or "hinge_random_3". tokens = interval_bounds_loss_type.split('_', 1) if len(tokens) == 1: loss_type, loss_mode = tokens[0], 'all' else: loss_type, loss_mode = tokens if loss_mode.startswith('random'): loss_mode, num_samples = loss_mode.split('_', 1) self._interval_bounds_loss_n = int(num_samples) if loss_type not in ('xent', 'hinge', 'softplus'): raise ValueError('interval_bounds_loss_type must be either "xent", ' '"hinge" or "softplus".') if loss_mode not in ('all', 'most', 'random', 'least'): raise ValueError('interval_bounds_loss_type must be followed by either ' '"all", "most", "random_N" or "least".') self._interval_bounds_loss_type = loss_type self._interval_bounds_loss_mode = loss_mode self._interval_bounds_hinge_margin = interval_bounds_hinge_margin self._label_smoothing = label_smoothing def _build(self, labels): self._build_nominal_loss(labels) self._build_verified_loss(labels) self._build_attack_loss(labels) def _build_nominal_loss(self, labels): """Build natural cross-entropy loss on clean data.""" # Cross-entropy. nominal_logits = self._predictor.logits if self._label_smoothing > 0: num_classes = nominal_logits.shape[1].value one_hot_labels = tf.one_hot(labels, num_classes) smooth_positives = 1. - self._label_smoothing smooth_negatives = self._label_smoothing / num_classes one_hot_labels = one_hot_labels * smooth_positives + smooth_negatives nominal_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2( labels=one_hot_labels, logits=nominal_logits) self._one_hot_labels = one_hot_labels else: nominal_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=nominal_logits) self._cross_entropy = tf.reduce_mean(nominal_cross_entropy) # Accuracy. nominal_correct_examples = tf.equal(labels, tf.argmax(nominal_logits, 1)) self._nominal_accuracy = tf.reduce_mean( tf.cast(nominal_correct_examples, tf.float32)) def _get_specification_bounds(self): """Get upper bounds on specification. Used for building verified loss.""" ibp_bounds = self._specification(self._predictor.modules) # Compute verified accuracy using IBP bounds. v = tf.reduce_max(ibp_bounds, axis=1) self._interval_bounds_accuracy = tf.reduce_mean( tf.cast(v <= 0., tf.float32)) return ibp_bounds def _build_verified_loss(self, labels): """Build verified loss using an upper bound on specification.""" if not self._specification: self._verified_loss = tf.constant(0.) self._interval_bounds_accuracy = tf.constant(0.) return # Interval bounds. bounds = self._get_specification_bounds() # Select specifications. if self._interval_bounds_loss_mode == 'all': pass # Keep bounds the way it is. elif self._interval_bounds_loss_mode == 'most': bounds = tf.reduce_max(bounds, axis=1, keepdims=True) elif self._interval_bounds_loss_mode == 'random': idx = tf.random.uniform( [tf.shape(bounds)[0], self._interval_bounds_loss_n], 0, tf.shape(bounds)[1], dtype=tf.int32) bounds = tf.batch_gather(bounds, idx) else: assert self._interval_bounds_loss_mode == 'least' # This picks the least violated contraint. mask = tf.cast(bounds < 0., tf.float32) smallest_violation = tf.reduce_min( bounds + mask * _BIG_NUMBER, axis=1, keepdims=True) has_violations = tf.less( tf.reduce_sum(mask, axis=1, keepdims=True) + .5, tf.cast(tf.shape(bounds)[1], tf.float32)) largest_bounds = tf.reduce_max(bounds, axis=1, keepdims=True) bounds = tf.where(has_violations, smallest_violation, largest_bounds) if self._interval_bounds_loss_type == 'xent': v = tf.concat( [bounds, tf.zeros([tf.shape(bounds)[0], 1], dtype=bounds.dtype)], axis=1) l = tf.concat( [tf.zeros_like(bounds), tf.ones([tf.shape(bounds)[0], 1], dtype=bounds.dtype)], axis=1) self._verified_loss = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits_v2( labels=tf.stop_gradient(l), logits=v)) elif self._interval_bounds_loss_type == 'softplus': self._verified_loss = tf.reduce_mean( tf.nn.softplus(bounds + self._interval_bounds_hinge_margin)) else: assert self._interval_bounds_loss_type == 'hinge' self._verified_loss = tf.reduce_mean( tf.maximum(bounds, -self._interval_bounds_hinge_margin)) def _build_attack_loss(self, labels): """Build adversarial loss using PGD attack.""" # PGD attack. if not self._attack: self._attack_accuracy = tf.constant(0.) self._attack_success = tf.constant(1.) self._attack_cross_entropy = tf.constant(0.) return if not isinstance(self._predictor.inputs, tf.Tensor): raise ValueError('Multiple inputs is not supported.') self._attack(self._predictor.inputs, labels) correct_examples = tf.equal(labels, tf.argmax(self._attack.logits, 1)) self._attack_accuracy = tf.reduce_mean( tf.cast(correct_examples, tf.float32)) self._attack_success = tf.reduce_mean( tf.cast(self._attack.success, tf.float32)) if self._label_smoothing > 0: attack_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2( labels=self._one_hot_labels, logits=self._attack.logits) else: attack_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=self._attack.logits) self._attack_cross_entropy = tf.reduce_mean(attack_cross_entropy) @property def scalar_metrics(self): self._ensure_is_connected() return ScalarMetrics(self._nominal_accuracy, self._interval_bounds_accuracy, self._attack_accuracy, self._attack_success) @property def scalar_losses(self): self._ensure_is_connected() return ScalarLosses(self._cross_entropy, self._attack_cross_entropy, self._verified_loss)