# coding=utf-8 # Copyright 2019 The Interval Bound Propagation Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Naive bound calculation for common neural network layers.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from interval_bound_propagation.src import bounds as basic_bounds from interval_bound_propagation.src import relative_bounds import sonnet as snt import tensorflow.compat.v1 as tf class SimplexBounds(basic_bounds.AbstractBounds): """Specifies a bounding simplex within an embedding space.""" def __init__(self, vertices, nominal, r): """Initialises the simplex bounds. Args: vertices: Tensor of shape (num_vertices, *input_shape) or of shape (batch_size, num_vertices, *input_shape) containing the vertices in embedding space. nominal: Tensor of shape (batch_size, *input_shape) specifying the unperturbed inputs in embedding space, where `*input_shape` denotes either (embedding_size,) for flat input (e.g. bag-of-words) or (input_length, embedding_channels) for sequence input. r: Scalar specifying the dilation factor of the simplex. The dilated simplex will have vertices `nominal + r * (vertices-nominal)`. """ super(SimplexBounds, self).__init__() self._vertices = vertices self._nominal = nominal self._r = r @property def vertices(self): return self._vertices @property def nominal(self): return self._nominal @property def r(self): return self._r @property def shape(self): return self.nominal.shape.as_list() @classmethod def convert(cls, bounds): if not isinstance(bounds, cls): raise ValueError('Cannot convert "{}" to "{}"'.format(bounds, cls.__name__)) return bounds def apply_batch_reshape(self, wrapper, shape): reshape = snt.BatchReshape(shape) if self.vertices.shape.ndims == self.nominal.shape.ndims: reshape_vertices = reshape else: reshape_vertices = snt.BatchReshape(shape, preserve_dims=2) return SimplexBounds(reshape_vertices(self.vertices), reshape(self.nominal), self.r) def apply_linear(self, wrapper, w, b): mapped_centres = tf.matmul(self.nominal, w) mapped_vertices = tf.tensordot(self.vertices, w, axes=1) lb, ub = _simplex_bounds(mapped_vertices, mapped_centres, self.r, -2) nominal_out = tf.matmul(self.nominal, w) if b is not None: nominal_out += b return relative_bounds.RelativeIntervalBounds(lb, ub, nominal_out) def apply_conv1d(self, wrapper, w, b, padding, stride): mapped_centres = tf.nn.conv1d(self.nominal, w, padding=padding, stride=stride) if self.vertices.shape.ndims == 3: # `self.vertices` has no batch dimension; its shape is # (num_vertices, input_length, embedding_channels). mapped_vertices = tf.nn.conv1d(self.vertices, w, padding=padding, stride=stride) elif self.vertices.shape.ndims == 4: # `self.vertices` has shape # (batch_size, num_vertices, input_length, embedding_channels). # Vertices are different for each example in the batch, # e.g. for word perturbations. mapped_vertices = snt.BatchApply( lambda x: tf.nn.conv1d(x, w, padding=padding, stride=stride))( self.vertices) else: raise ValueError('"vertices" must have either 3 or 4 dimensions.') lb, ub = _simplex_bounds(mapped_vertices, mapped_centres, self.r, -3) nominal_out = tf.nn.conv1d(self.nominal, w, padding=padding, stride=stride) if b is not None: nominal_out += b return relative_bounds.RelativeIntervalBounds(lb, ub, nominal_out) def apply_conv2d(self, wrapper, w, b, padding, strides): mapped_centres = tf.nn.convolution(self.nominal, w, padding=padding, strides=strides) if self.vertices.shape.ndims == 4: # `self.vertices` has no batch dimension; its shape is # (num_vertices, input_height, input_width, input_channels). mapped_vertices = tf.nn.convolution(self.vertices, w, padding=padding, strides=strides) elif self.vertices.shape.ndims == 5: # `self.vertices` has shape # (batch_size, num_vertices, input_height, input_width, input_channels). # Vertices are different for each example in the batch. mapped_vertices = snt.BatchApply( lambda x: tf.nn.convolution(x, w, padding=padding, strides=strides))( self.vertices) else: raise ValueError('"vertices" must have either 4 or 5 dimensions.') lb, ub = _simplex_bounds(mapped_vertices, mapped_centres, self.r, -4) nominal_out = tf.nn.convolution(self.nominal, w, padding=padding, strides=strides) if b is not None: nominal_out += b return relative_bounds.RelativeIntervalBounds(lb, ub, nominal_out) def apply_increasing_monotonic_fn(self, wrapper, fn, *args, **parameters): if fn.__name__ in ('add', 'reduce_mean', 'reduce_sum', 'avg_pool'): if self.vertices.shape.ndims == self.nominal.shape.ndims: vertices_fn = fn else: vertices_fn = snt.BatchApply(fn, n_dims=2) return SimplexBounds( vertices_fn(self.vertices, *[bounds.vertices for bounds in args]), fn(self.nominal, *[bounds.nominal for bounds in args]), self.r) elif fn.__name__ == 'quotient': return SimplexBounds( self.vertices / tf.expand_dims(parameters['denom'], axis=1), fn(self.nominal), self.r) else: return super(SimplexBounds, self).apply_increasing_monotonic_fn( wrapper, fn, *args, **parameters) def _simplex_bounds(mapped_vertices, mapped_centres, r, axis): """Calculates naive bounds on the given layer-mapped vertices. Args: mapped_vertices: Tensor of shape (num_vertices, *output_shape) or of shape (batch_size, num_vertices, *output_shape) containing the vertices in the layer's output space. mapped_centres: Tensor of shape (batch_size, *output_shape) containing the layer's nominal outputs. r: Scalar in [0, 1) specifying the radius (in vocab space) of the simplex. axis: Index of the `num_vertices` dimension of `mapped_vertices`. Returns: lb_out: Tensor of shape (batch_size, *output_shape) with lower bounds on the outputs of the affine layer. ub_out: Tensor of shape (batch_size, *output_shape) with upper bounds on the outputs of the affine layer. """ # Use the negative of r, instead of the complement of r, as # we're shifting the input domain to be centred at the origin. lb_out = -r * mapped_centres + r * tf.reduce_min(mapped_vertices, axis=axis) ub_out = -r * mapped_centres + r * tf.reduce_max(mapped_vertices, axis=axis) return lb_out, ub_out