# default modules import numpy as np import tensorflow as tf import warnings from GeneralTools.misc_fun import FLAGS ######################################################################## def kron_by_reshape(mat1, mat2, mat_shape=None): """ This function does kronecker product through reshape and perm :param mat1: 2-D tensor :param mat2: 2-D tensor :param mat_shape: shape of mat1 and mat2 :return mat3: mat3 = kronecker(mat1, mat2) """ if mat_shape is None: a, b = mat1.shape c, d = mat2.shape else: # in case of tensorflow, mat_shape must be provided a, b, c, d = mat_shape if isinstance(mat1, np.ndarray) and isinstance(mat2, np.ndarray): mat3 = np.matmul(np.reshape(mat1, [-1, 1]), np.reshape(mat2, [1, -1])) # (axb)-by-(cxd) mat3 = np.reshape(mat3, [a, b, c, d]) # a-by-b-by-c-by-d mat3 = np.transpose(mat3, axes=[0, 2, 1, 3]) # a-by-c-by-b-by-d mat3 = np.reshape(mat3, [a * c, b * d]) # (axc)-by-(bxd) elif isinstance(mat1, tf.Tensor) and isinstance(mat2, tf.Tensor): mat3 = tf.matmul(tf.reshape(mat1, [-1, 1]), tf.reshape(mat2, [1, -1])) # (axb)-by-(cxd) mat3 = tf.reshape(mat3, [a, b, c, d]) # a-by-b-by-c-by-d mat3 = tf.transpose(mat3, perm=[0, 2, 1, 3]) # a-by-c-by-b-by-d mat3 = tf.reshape(mat3, [a * c, b * d]) # (axc)-by-(bxd) else: raise AttributeError('Input should be numpy array or tensor') return mat3 ######################################################################## def scale_range(x, scale_min=-1.0, scale_max=1.0, axis=1): """ This function scales numpy matrix to range [scale_min, scale_max] """ x_min = np.amin(x, axis=axis, keepdims=True) x_range = np.amax(x, axis=axis, keepdims=True) - x_min x_range[x_range == 0.0] = 1.0 # scale to [0,1] x = (x - x_min) / x_range # scale to [scale_min, scale_max] x = x * (scale_max - scale_min) + scale_min return x ######################################################################## def mean_cov_np(x): """ This function calculates mean and covariance for 2d array x. This function is faster than separately running np.mean and np.cov :param x: 2D array, columns of x represents variables. :return: """ mu = np.mean(x, axis=0) x_centred = x - mu cov = np.matmul(x_centred.transpose(), x_centred) / (x.shape[0] - 1.0) return mu, cov ######################################################################## def mean_cov_tf(x): """ This function calculates mean and covariance for 2d array x. :param x: 2D array, columns of x represents variables. :return: """ mu = tf.reduce_mean(x, axis=0, keepdims=True) # 1-D x_centred = x - mu cov = tf.matmul(x_centred, x_centred, transpose_a=True) / (x.get_shape().as_list()[0] - 1.0) return mu, cov ######################################################################## def scale_image_range(image, scale_min=-1.0, scale_max=1.0, image_format='channels_last'): """ This function scales images per channel to [-1,1]. The max and min are calculated over all samples. Note that, in batch normalization, they also calculate the mean and std for each feature map. :param image: 4-D numpy array, either in channels_first format or channels_last format :param scale_min: :param scale_max: :param image_format :return: """ if len(image.shape) != 4: raise AttributeError('Input must be 4-D tensor.') if image_format == 'channels_last': num_instance, height, width, num_channel = image.shape pixel_channel = image.reshape((-1, num_channel)) # [pixels, channel] pixel_channel = scale_range(pixel_channel, scale_min=scale_min, scale_max=scale_max, axis=0) image = pixel_channel.reshape((num_instance, height, width, num_channel)) elif image_format == 'channels_first': # scale_range works faster when axis=1, work on this image = np.transpose(image, axes=(1, 0, 2, 3)) num_channel, num_instance, height, width = image.shape pixel_channel = image.reshape((num_channel, -1)) # [channel, pixels] pixel_channel = scale_range(pixel_channel, scale_min=scale_min, scale_max=scale_max, axis=1) image = pixel_channel.reshape((num_channel, num_instance, height, width)) image = np.transpose(image, axes=(1, 0, 2, 3)) # convert back to channels_first return image ######################################################################## def pairwise_dist(mat1, mat2=None): """ This function calculates the pairwise distance matrix dist. If mat2 is not provided, dist is defined among row vectors of mat1. The distance is formed as sqrt(mat1*mat1' - 2*mat1*mat2' + mat2*mat2') :param mat1: :param mat2: :return: """ # tf.reduce_sum() will produce result of shape (N,), which, when transposed, is still (N,) # Thus, to force mm1 and mm2 (or mm1') to have different shape, tf.expand_dims() is used mm1 = tf.expand_dims(tf.reduce_sum(tf.multiply(mat1, mat1), axis=1), axis=1) if mat2 is None: mmt = tf.multiply(tf.matmul(mat1, mat1, transpose_b=True), -2) dist = tf.sqrt(tf.add(tf.add(tf.add(mm1, tf.transpose(mm1)), mmt), FLAGS.EPSI)) else: mm2 = tf.expand_dims(tf.reduce_sum(tf.multiply(mat2, mat2), axis=1), axis=0) mrt = tf.multiply(tf.matmul(mat1, mat2, transpose_b=True), -2) dist = tf.sqrt(tf.add(tf.add(tf.add(mm1, mm2), mrt), FLAGS.EPSI)) # dist = tf.sqrt(tf.add(tf.add(mm1, mm2), mrt)) return dist ######################################################################## def slerp(p0, p1, t): """ This function calculates the spherical linear interpolation of p0 and p1 :param p0: a vector of shape (d, ) :param p1: a vector of shape (d, ) :param t: a scalar, or a vector of shape (n, ) :return: Numeric instability may occur when theta is close to zero or pi. In these cases, sin(t * theta) >> sin(theta). These cases are common, e.g. p0 = -p1. """ from numpy.linalg import norm theta = np.arccos(np.dot(p0 / norm(p0), p1 / norm(p1)), dtype=np.float32) st = np.sin(theta) # there is no dtype para for np.sin # in case t is a vector, output is a row matrix if not np.isscalar(t): p0 = np.expand_dims(p0, axis=0) p1 = np.expand_dims(p1, axis=0) t = np.expand_dims(t, axis=1) if st > 0.1: p2 = np.sin((1.0 - t) * theta) / st * p0 + np.sin(t * theta) / st * p1 else: p2 = (1.0 - t) * p0 + t * p1 return p2 def spatial_shape_after_conv(input_spatial_shape, kernel_size, strides, dilation, padding): """ This function calculates the spatial shape after conv layer. The formula is obtained from: https://www.tensorflow.org/api_docs/python/tf/nn/convolution It should be note that current function assumes PS is done before conv :param input_spatial_shape: :param kernel_size: :param strides: :param dilation: :param padding: :return: """ if isinstance(input_spatial_shape, (list, tuple)): return [spatial_shape_after_conv( one_shape, kernel_size, strides, dilation, padding) for one_shape in input_spatial_shape] else: if padding in ['same', 'SAME']: return np.int(np.ceil(input_spatial_shape / strides)) else: return np.int(np.ceil((input_spatial_shape - (kernel_size - 1) * dilation) / strides)) def spatial_shape_after_transpose_conv(input_spatial_shape, kernel_size, strides, dilation, padding): """ This function calculates the spatial shape after conv layer. Since transpose conv is often used in upsampling, scale_factor is not used here. This function has not been fully tested, and may be wrong in some cases. :param input_spatial_shape: :param kernel_size: :param strides: :param dilation: :param padding: :return: """ if isinstance(input_spatial_shape, (list, tuple)): return [spatial_shape_after_transpose_conv( one_shape, kernel_size, strides, dilation, padding) for one_shape in input_spatial_shape] else: if padding in ['same', 'SAME']: return np.int(input_spatial_shape * strides) else: return np.int(input_spatial_shape * strides + (kernel_size - 1) * dilation) ######################################################################## class MeshCode(object): def __init__(self, code_length, mesh_num=None): """ This function creates meshed code for generative models :param code_length: :param mesh_num: :return: """ self.D = code_length if mesh_num is None: self.mesh_num = (10, 10) else: self.mesh_num = mesh_num def get_batch(self, mesh_mode, name=None): if name is None: name = 'Z' if mesh_mode == 0 or mesh_mode == 'random': z_batch = self.by_random(name) elif mesh_mode == 1 or mesh_mode == 'sine': z_batch = self.by_sine(name) elif mesh_mode == 2 or mesh_mode == 'feature': z_batch = self.by_feature(name) else: raise AttributeError('mesh_mode is not supported.') return z_batch def by_random(self, name=None): """ This function generates mesh code randomly :param name: :return: """ return tf.random_normal( [self.mesh_num[0] * self.mesh_num[1], self.D], mean=0.0, stddev=1.0, name=name) def by_sine(self, z_support=None, name=None): """ This function creates mesh code by interpolating between four supporting codes :param z_support: :param name: list or tuple of two elements :return: """ if z_support is None: z_support = tf.random_normal( [4, self.D], mean=0.0, stddev=1.0) elif isinstance(z_support, np.ndarray): z_support = tf.constant(z_support, dtype=tf.float32) z0 = tf.expand_dims(z_support[0], axis=0) # create 1-by-D vector z1 = tf.expand_dims(z_support[1], axis=0) z2 = tf.expand_dims(z_support[2], axis=0) z3 = tf.expand_dims(z_support[3], axis=0) # generate phi and psi from 0 to 90 degrees mesh_phi = np.float32( # mesh_num[0]-by-1 vector np.expand_dims(np.pi / 4.0 * np.linspace(0.0, 1.0, self.mesh_num[0]), axis=1)) mesh_psi = np.float32( np.expand_dims(np.pi / 4.0 * np.linspace(0.0, 1.0, self.mesh_num[1]), axis=1)) # sample instances on the manifold z_batch = tf.identity( # mesh_num[0]*mesh_num[1]-by-1 vector kron_by_reshape( # do kronecker product tf.matmul(tf.cos(mesh_psi), z0) + tf.matmul(tf.sin(mesh_psi), z1), tf.cos(mesh_phi), mat_shape=[self.mesh_num[1], self.D, self.mesh_num[0], 1]) + kron_by_reshape( tf.matmul(tf.cos(mesh_psi), z2) + tf.matmul(tf.sin(mesh_psi), z3), tf.sin(mesh_phi), mat_shape=[self.mesh_num[1], self.D, self.mesh_num[0], 1]), name=name) return z_batch def by_feature(self, grid=2.0, name=None): """ This function creates mesh code by varying a single feature. In this case, mesh_num[0] refers to the number of features to mesh, mesh[1] refers to the number of variations in one feature :param grid: :param name: string :return: """ mesh = np.float32( # mesh_num[0]-by-1 vector np.expand_dims(np.linspace(-grid, grid, self.mesh_num[1]), axis=1)) # sample instances on the manifold z_batch = kron_by_reshape( # mesh_num[0]*mesh_num[1]-by-1 vector tf.eye(num_rows=self.mesh_num[0], num_columns=self.D), tf.constant(mesh), mat_shape=[self.mesh_num[0], self.D, self.mesh_num[1], 1]) # shuffle the columns of z_batch z_batch = tf.identity( tf.transpose(tf.random_shuffle(tf.transpose(z_batch, perm=[1, 0])), perm=[1, 0]), name=name) return z_batch def simple_grid(self, grid=None): """ This function creates simple grid meshes Note: this function returns np.ndarray :param grid: :return: """ if self.D != 2: raise AttributeError('Code length has to be two') if grid is None: grid = np.array([[-1.0, 1.0], [-1.0, 1.0]], dtype=np.float32) x = np.linspace(grid[0][0], grid[0][1], self.mesh_num[0]) y = np.linspace(grid[1][0], grid[1][1], self.mesh_num[1]) z0 = np.reshape(np.transpose(np.tile(x, (self.mesh_num[1], 1))), [-1, 1]) z1 = np.reshape(np.tile(y, (1, self.mesh_num[0])), [-1, 1]) z = np.concatenate((z0, z1), axis=1) return z, x, y def j_diagram(self, name=None): """ This function creates a j diagram using slerp This function is not finished as there is a problem with the slerp idea. :param name: :return: """ raise NotImplementedError('This function has not been implemented.') # z_support = np.random.randn(4, self.D) # z0 = tf.expand_dims(z_support[0], axis=0) # create 1-by-D vector # z1 = tf.expand_dims(z_support[1], axis=0) # z2 = tf.expand_dims(z_support[2], axis=0) # pass ######################################################################## def mat_slice(mat, row_index, col_index=None, name='slice'): """ This function gets mat[index, index] where index is either bool or int32. Note that: if index is bool, output size is typically smaller than mat unless each element in index is True if index is int32, output can be any size. :param mat: :param row_index: :param col_index: :param name; :return: """ if col_index is None: col_index = row_index with tf.name_scope(name): if row_index.dtype != col_index.dtype: raise AttributeError('dtype of row-index and col-index do not match.') if row_index.dtype == tf.int32: return tf.gather(tf.gather(mat, row_index, axis=0), col_index, axis=1) elif row_index.dtype == tf.bool: return tf.boolean_mask(tf.boolean_mask(mat, row_index, axis=0), col_index, axis=1) else: raise AttributeError('Type of index is: {}; expected either tf.int32 or tf.bool'.format(row_index.dtype)) ######################################################################## def l2normalization(w): """ This function applies l2 normalization to the input vector. If w is a matrix / tensor, the Frobenius norm is used for normalization. :param w: :return: """ # tf.norm is slightly faster than tf.sqrt(tf.reduce_sum(tf.square())) # it is important that axis=None; in this case, norm(w) = norm(vec(w)) return w / (tf.norm(w, ord='euclidean', axis=None) + FLAGS.EPSI) class SpectralNorm(object): def __init__(self, sn_def, name_scope='SN', scope_prefix='', num_iter=1): """ This class contains functions to calculate the spectral normalization of the weight matrix using power iteration. The application of spectral normal to NN is proposed in following papers: Yoshida, Y., & Miyato, T. (2017). Spectral Norm Regularization for Improving the Generalizability of Deep Learning. Miyato, T., Kataoka, T., Koyama, M., & Yoshida, Y. (2017). Spectral Normalization for Generative Adversarial Networks, Here spectral normalization is generalized for any linear ops or combination of linear ops Example of usage: Example 1. w = tf.constant(np.random.randn(3, 3, 128, 64).astype(np.float32)) sn_def = {'op': 'tc', 'input_shape': [10, 64, 64, 64], 'output_shape': [10, 128, 64, 64], 'strides': 1, 'dilation': 1, 'padding': 'SAME', 'data_format': 'NCHW'} sigma = SpectralNorm(sn_def, name_scope='SN1', num_iter=20).apply(w) Example 2. w = tf.constant(np.random.randn(3, 3, 128, 64).astype(np.float32)) w2 = tf.constant(np.random.randn(3, 3, 128, 64).astype(np.float32)) sn_def = {'op': 'tc', 'input_shape': [10, 64, 64, 64], 'output_shape': [10, 128, 64, 64], 'strides': 1, 'dilation': 1, 'padding': 'SAME', 'data_format': 'NCHW'} SN = SpectralNorm(sn_def, num_iter=20) sigma1 = SN.apply(w) sigma2 = SN.apply(w2, name_scope='SN2', num_iter=30) :param sn_def: a dictionary with keys depending on the type of kernel: type keys value options dense: 'op' 'd' - common dense layer; 'cd' - conditional dense layers; 'dcd' - dense + conditional dense; 'dck' - dense * conditional scale 'project' - same to cd, except num_out is 1 conv: 'op' 'c' - convolution; 'tc' - transpose convolution; 'cck' - convolution * conditional scale; 'tcck' - t-conv * conditional scale 'strides' integer 'dilation' integer 'padding' 'SAME' or 'VALID' 'data_format' 'NCHW' or 'NHWC' 'input_shape' list of integers in format NCHW or NHWC 'output_shape' for 'tc', output shape must be provided :param name_scope: :param scope_prefix: :param num_iter: number of power iterations per run """ self.sn_def = sn_def.copy() self.name_scope = name_scope self.scope_prefix = scope_prefix self.name_in_err = self.scope_prefix + self.name_scope self.num_iter = num_iter # initialize self.w = None self.x = None self.use_u = None self.is_initialized = False self.forward = None self.backward = None # format stride if self.sn_def['op'] in {'c', 'tc', 'cck', 'tcck'}: if self.sn_def['data_format'] in ['NCHW', 'channels_first']: self.sn_def['strides'] = (1, 1, self.sn_def['strides'], self.sn_def['strides']) else: self.sn_def['strides'] = (1, self.sn_def['strides'], self.sn_def['strides'], 1) assert 'output_shape' in self.sn_def, \ '{}: for conv, output_shape must be provided.'.format(self.name_in_err) def _init_routine(self): """ This function decides the routine to minimize memory usage :return: """ if self.is_initialized is False: # decide the routine if self.sn_def['op'] in {'d', 'project'}: # for d kernel_shape [num_in, num_out]; for project, kernel shape [num_class, num_in] assert len(self.kernel_shape) == 2, \ '{}: kernel shape {} does not have length 2'.format(self.name_in_err, self.kernel_shape) num_in, num_out = self.kernel_shape # self.use_u = True self.use_u = True if num_in <= num_out else False x_shape = [1, num_in] if self.use_u else [1, num_out] self.forward = self._dense_ if self.use_u else self._dense_t_ self.backward = self._dense_t_ if self.use_u else self._dense_ elif self.sn_def['op'] in {'cd'}: # kernel_shape [num_class, num_in, num_out] assert len(self.kernel_shape) == 3, \ '{}: kernel shape {} does not have length 3'.format(self.name_in_err, self.kernel_shape) num_class, num_in, num_out = self.kernel_shape self.use_u = True if num_in <= num_out else False x_shape = [num_class, 1, num_in] if self.use_u else [num_class, 1, num_out] self.forward = self._dense_ if self.use_u else self._dense_t_ self.backward = self._dense_t_ if self.use_u else self._dense_ elif self.sn_def['op'] in {'dck'}: # convolution * conditional scale assert isinstance(self.kernel_shape, (list, tuple)) and len(self.kernel_shape) == 2, \ '{}: kernel shape must be a list of length 2. Got {}'.format(self.name_in_err, self.kernel_shape) assert len(self.kernel_shape[0]) == 2 and len(self.kernel_shape[1]) == 2, \ '{}: kernel shape {} does not have length 2'.format(self.name_in_err, self.kernel_shape) num_in, num_out = self.kernel_shape[0] num_class = self.kernel_shape[1][0] self.use_u = True if num_in <= num_out else False x_shape = [num_class, num_in] if self.use_u else [num_class, num_out] self.forward = (lambda x: self._scalar_(self._dense_(x, index=0), index=1, offset=1.0)) \ if self.use_u else (lambda y: self._dense_t_(self._scalar_(y, index=1, offset=1.0), index=0)) self.backward = (lambda y: self._dense_t_(self._scalar_(y, index=1, offset=1.0), index=0)) \ if self.use_u else (lambda x: self._scalar_(self._dense_(x, index=0), index=1, offset=1.0)) elif self.sn_def['op'] in {'c', 'tc'}: assert len(self.kernel_shape) == 4, \ '{}: kernel shape {} does not have length 4'.format(self.name_in_err, self.kernel_shape) # self.use_u = True self.use_u = True \ if np.prod(self.sn_def['input_shape'][1:]) <= np.prod(self.sn_def['output_shape'][1:]) \ else False if self.sn_def['op'] in {'c'}: # input / output shape NCHW or NHWC x_shape = self.sn_def['input_shape'].copy() if self.use_u else self.sn_def['output_shape'].copy() x_shape[0] = 1 y_shape = self.sn_def['input_shape'].copy() y_shape[0] = 1 elif self.sn_def['op'] in {'tc'}: # tc x_shape = self.sn_def['output_shape'].copy() if self.use_u else self.sn_def['input_shape'].copy() x_shape[0] = 1 y_shape = self.sn_def['output_shape'].copy() y_shape[0] = 1 else: raise NotImplementedError('{}: {} not implemented.'.format(self.name_in_err, self.sn_def['op'])) self.forward = self._conv_ if self.use_u else (lambda y: self._conv_t_(y, x_shape=y_shape)) self.backward = (lambda y: self._conv_t_(y, x_shape=y_shape)) if self.use_u else self._conv_ elif self.sn_def['op'] in {'cck', 'tcck'}: # convolution * conditional scale assert isinstance(self.kernel_shape, (list, tuple)) and len(self.kernel_shape) == 2, \ '{}: kernel shape must be a list of length 2. Got {}'.format(self.name_in_err, self.kernel_shape) assert len(self.kernel_shape[0]) == 4 and len(self.kernel_shape[1]) == 4, \ '{}: kernel shape {} does not have length 4'.format(self.name_in_err, self.kernel_shape) self.use_u = True \ if np.prod(self.sn_def['input_shape'][1:]) <= np.prod(self.sn_def['output_shape'][1:]) \ else False num_class = self.kernel_shape[1][0] if self.sn_def['op'] in {'cck'}: # input / output shape NCHW or NHWC x_shape = self.sn_def['input_shape'].copy() if self.use_u else self.sn_def['output_shape'].copy() x_shape[0] = num_class y_shape = self.sn_def['input_shape'].copy() y_shape[0] = num_class self.forward = (lambda x: self._scalar_(self._conv_(x, index=0), index=1, offset=1.0)) \ if self.use_u \ else (lambda y: self._conv_t_(self._scalar_(y, index=1, offset=1.0), x_shape=y_shape, index=0)) self.backward = (lambda y: self._conv_t_( self._scalar_(y, index=1, offset=1.0), x_shape=y_shape, index=0)) \ if self.use_u else (lambda x: self._scalar_(self._conv_(x, index=0), index=1, offset=1.0)) elif self.sn_def['op'] in {'tcck'}: # tcck x_shape = self.sn_def['output_shape'].copy() if self.use_u else self.sn_def['input_shape'].copy() x_shape[0] = num_class y_shape = self.sn_def['output_shape'].copy() y_shape[0] = num_class self.forward = (lambda x: self._conv_(self._scalar_(x, index=1, offset=1.0), index=0)) \ if self.use_u \ else (lambda y: self._scalar_(self._conv_t_(y, x_shape=y_shape, index=0), index=1, offset=1.0)) self.backward = (lambda y: self._scalar_( self._conv_t_(y, x_shape=y_shape, index=0), index=1, offset=1.0)) \ if self.use_u else (lambda x: self._conv_(self._scalar_(x, index=1, offset=1.0), index=0)) else: raise NotImplementedError('{}: {} not implemented.'.format(self.name_in_err, self.sn_def['op'])) else: raise NotImplementedError('{}: {} is not implemented.'.format(self.name_in_err, self.sn_def['op'])) self.x = tf.get_variable( 'in_rand', shape=x_shape, dtype=tf.float32, initializer=tf.truncated_normal_initializer(), trainable=False) self.is_initialized = True def _scalar_(self, x, index=None, offset=0.0): """ This function defines a elementwise multiplication op: y = x * w, where x shape [N, C, ...] or [N, ..., C], w shape [N, C, 1,..,1] or [N, 1,...,1, C], y shape [N, C, ...] or [N, ..., C] :param x: :param index: if index is provided, self.w is a list or tuple :param offset: add a constant offset :return: """ w = self.w if index is None else self.w[index] return tf.multiply(x, w, name='scalar') if offset == 0.0 else tf.multiply(x, w + offset, name='scalar') def _dense_(self, x, index=None): """ This function defines a dense op: y = x * w, where x shape [..., a, b], w shape [..., b, c], y shape [..., a, c] :param x: :param index: if index is provided, self.w is a list or tuple :return: """ w = self.w if index is None else self.w[index] return tf.matmul(x, w, name='dense') def _dense_t_(self, y, index=None): """ Transpose version of self._dense_ :param y: :param index: if index is provided, self.w is a list or tuple :return: """ w = self.w if index is None else self.w[index] return tf.matmul(y, w, transpose_b=True, name='dense_t') def _conv_(self, x, index=None): """ This function defines a conv op: y = x \otimes w, where x shape NCHW or NHWC, w shape kkhw, y shape NCHW or NHWC :param x: :param index: if index is provided, self.w is a list or tuple :return: """ w = self.w if index is None else self.w[index] if self.sn_def['dilation'] > 1: return tf.nn.atrous_conv2d( x, w, rate=self.sn_def['dilation'], padding=self.sn_def['padding'], name='conv') else: return tf.nn.conv2d( x, w, strides=self.sn_def['strides'], padding=self.sn_def['padding'], data_format=self.sn_def['data_format'], name='conv') def _conv_t_(self, y, x_shape, index=None): """ Transpose version of self._conv_ :param y: :param x_shape: :param index: :return: """ w = self.w if index is None else self.w[index] if self.sn_def['dilation'] > 1: return tf.nn.atrous_conv2d_transpose( y, w, output_shape=x_shape, rate=self.sn_def['dilation'], padding=self.sn_def['padding'], name='conv_t') else: return tf.nn.conv2d_transpose( y, w, output_shape=x_shape, strides=self.sn_def['strides'], padding=self.sn_def['padding'], data_format=self.sn_def['data_format'], name='conv_t') def _l2_norm(self, x): if self.sn_def['op'] in {'cd'}: # x shape [num_class, 1, num_in or num_out] return tf.norm(x, ord='euclidean', axis=2, keepdims=True) # return [num_class, 1, 1] elif self.sn_def['op'] in {'dck'}: # x shape [num_class, num_in or num_out] return tf.norm(x, ord='euclidean', axis=1, keepdims=True) # return [num_class, 1] elif self.sn_def['op'] in {'cck', 'tcck'}: # x shape [num_class, num_in or num_out, H, W] or [num_class, H, W, num_in or num_out] # here i did not use tf.norm because axis cannot be (1, 2, 3) return tf.sqrt( tf.reduce_sum(tf.square(x), axis=(1, 2, 3), keepdims=True), name='norm') # return [num_class, 1, 1, 1] elif self.sn_def['op'] in {'d', 'c', 'tc', 'project'}: # x shape [1, num_in or num_out], or [1, num_in or num_out, H, W] or [1, H, W, num_in or num_out] return tf.norm(x, ord='euclidean', axis=None) # return scalar def _l2_normalize_(self, w): """ :param w: :return: """ return w / (self._l2_norm(w) + FLAGS.EPSI) def _power_iter_(self, x, step): """ This function does power iteration for one step :param x: :param step: :return: """ y = self._l2_normalize_(self.forward(x)) x_update = self._l2_normalize_(self.backward(y)) sigma = self._l2_norm(self.forward(x)) return sigma, x_update, step + 1 def __call__(self, kernel, **kwargs): """ This function calculates spectral normalization for kernel :param kernel: :param kwargs: :return: """ # check inputs if 'name_scope' in kwargs and kwargs['name_scope'] != self.name_scope: # different name_scope will initialize another SN process self.name_scope = kwargs['name_scope'] self.name_in_err = self.scope_prefix + self.name_scope if self.is_initialized: warnings.warn( '{}: a new SN process caused lost of links to the previous one.'.format(self.name_in_err)) self.is_initialized = False self.use_u = None if 'num_iter' in kwargs: self.num_iter = kwargs['num_iter'] if isinstance(kernel, (list, tuple)): # for dcd, cck, the kernel is a list of two kernels kernel_shape = [k.get_shape().as_list() for k in kernel] else: kernel_shape = kernel.get_shape().as_list() with tf.variable_scope(self.name_scope, reuse=tf.AUTO_REUSE): # In some cases, the spectral norm can be easily calculated. sigma = None if self.sn_def['op'] in {'d', 'project'} and 1 in kernel_shape: # for project op. kernel_shape = [num_class, num_in] sigma = tf.norm(kernel, ord='euclidean') elif self.sn_def['op'] in {'cd'}: if len(kernel_shape) == 2: # equivalent to [num_class, num_in, 1] sigma = tf.norm(kernel, ord='euclidean', axis=1, keepdims=True) elif kernel_shape[1] == 1 or kernel_shape[2] == 1: sigma = tf.norm(kernel, ord='euclidean', axis=(1, 2), keepdims=True) elif self.sn_def['op'] in {'dcd'}: # dense + conditional dense # kernel_cd [num_class, num_in, num_out] kernel_cd = tf.expand_dims(kernel[1], axis=2) if len(kernel_shape[1]) == 2 else kernel[1] kernel = tf.expand_dims(kernel[0], axis=0) + kernel_cd # [num_class, num_in, num_out] if 1 in kernel_shape[0]: # kernel_d shape [1, num_out] or [num_in, 1] sigma = tf.norm(kernel, ord='euclidean', axis=(1, 2), keepdims=True) # [num_class, 1, 1] else: # convert dcd to cd kernel_shape = kernel.get_shape().as_list() self.sn_def['op'] = 'cd' elif self.sn_def['op'] in {'dck'}: # dense * conditional scales if kernel_shape[0][1] == 1: sigma = tf.norm(kernel[0], ord='euclidean') * tf.abs(kernel[1]) # [num_class, 1] # initialize a random input and calculate spectral norm if sigma is None: # decide the routine self.w = kernel self.kernel_shape = kernel_shape self._init_routine() # initialize sigma if self.sn_def['op'] in {'dck'}: sigma_init = tf.zeros((self.kernel_shape[1][0], 1), dtype=tf.float32) elif self.sn_def['op'] in {'cd'}: # for cd, the sigma is a [num_class, 1, 1] sigma_init = tf.zeros((self.kernel_shape[0], 1, 1), dtype=tf.float32) elif self.sn_def['op'] in {'cck', 'tcck'}: sigma_init = tf.zeros((self.kernel_shape[1][0], 1, 1, 1), dtype=tf.float32) else: sigma_init = tf.constant(0.0, dtype=tf.float32) # do power iterations sigma, x_update, _ = tf.while_loop( cond=lambda _1, _2, i: i < self.num_iter, body=lambda _1, x, i: self._power_iter_(x, step=i), loop_vars=(sigma_init, self.x, tf.constant(0, dtype=tf.int32))) # update the random input tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, tf.assign(self.x, x_update)) return sigma def apply(self, kernel, **kwargs): return self.__call__(kernel, **kwargs) ######################################################################## def batch_norm(tensor, axis=None, keepdims=False, name='norm'): """ This function calculates the l2 norm for each instance in a batch :param tensor: shape [batch_size, ...] :param axis: the axis to calculate norm, could be integer or list/tuple of integers :param keepdims: whether to keep dimensions :param name: :return: """ with tf.name_scope(name): return tf.sqrt(tf.reduce_sum(tf.square(tensor), axis=axis, keepdims=keepdims)) ######################################################################## def get_squared_dist( x, y=None, scale=None, z_score=False, mode='xxxyyy', name='squared_dist', do_summary=False, scope_prefix=''): """ This function calculates the pairwise distance between x and x, x and y, y and y Warning: when x, y has mean far away from zero, the distance calculation is not accurate; use get_dist_ref instead :param x: batch_size-by-d matrix :param y: batch_size-by-d matrix :param scale: 1-by-d vector, the precision vector. dxy = x*scale*y :param z_score: :param mode: 'xxxyyy', 'xx', 'xy', 'xxxy' :param name: :param do_summary: :param scope_prefix: summary scope prefix :return: """ with tf.name_scope(name): # check inputs if len(x.get_shape().as_list()) > 2: raise AttributeError('get_dist: Input must be a matrix.') if y is None: mode = 'xx' if z_score: if y is None: mu = tf.reduce_mean(x, axis=0, keepdims=True) x = x - mu else: mu = tf.reduce_mean(tf.concat((x, y), axis=0), axis=0, keepdims=True) x = x - mu y = y - mu if mode in ['xx', 'xxxy', 'xxxyyy']: if scale is None: xxt = tf.matmul(x, x, transpose_b=True) # [xi_xi, xi_xj; xj_xi, xj_xj], batch_size-by-batch_size else: xxt = tf.matmul(x * scale, x, transpose_b=True) dx = tf.diag_part(xxt) # [xxt], [batch_size] dist_xx = tf.maximum(tf.expand_dims(dx, axis=1) - 2.0 * xxt + tf.expand_dims(dx, axis=0), 0.0) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.histogram(scope_prefix + name + '/dxx', dist_xx) if mode == 'xx': return dist_xx elif mode == 'xxxy': # estimate dy without yyt if scale is None: xyt = tf.matmul(x, y, transpose_b=True) dy = tf.reduce_sum(tf.multiply(y, y), axis=1) else: xyt = tf.matmul(x * scale, y, transpose_b=True) dy = tf.reduce_sum(tf.multiply(y * scale, y), axis=1) dist_xy = tf.maximum(tf.expand_dims(dx, axis=1) - 2.0 * xyt + tf.expand_dims(dy, axis=0), 0.0) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.histogram(scope_prefix + name + '/dxy', dist_xy) return dist_xx, dist_xy elif mode == 'xxxyyy': if scale is None: xyt = tf.matmul(x, y, transpose_b=True) yyt = tf.matmul(y, y, transpose_b=True) else: xyt = tf.matmul(x * scale, y, transpose_b=True) yyt = tf.matmul(y * scale, y, transpose_b=True) dy = tf.diag_part(yyt) dist_xy = tf.maximum(tf.expand_dims(dx, axis=1) - 2.0 * xyt + tf.expand_dims(dy, axis=0), 0.0) dist_yy = tf.maximum(tf.expand_dims(dy, axis=1) - 2.0 * yyt + tf.expand_dims(dy, axis=0), 0.0) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.histogram(scope_prefix + name + '/dxy', dist_xy) tf.summary.histogram(scope_prefix + name + '/dyy', dist_yy) return dist_xx, dist_xy, dist_yy elif mode == 'xy': if scale is None: dx = tf.reduce_sum(tf.multiply(x, x), axis=1) dy = tf.reduce_sum(tf.multiply(y, y), axis=1) xyt = tf.matmul(x, y, transpose_b=True) else: dx = tf.reduce_sum(tf.multiply(x * scale, x), axis=1) dy = tf.reduce_sum(tf.multiply(y * scale, y), axis=1) xyt = tf.matmul(x * scale, y, transpose_b=True) dist_xy = tf.maximum(tf.expand_dims(dx, axis=1) - 2.0 * xyt + tf.expand_dims(dy, axis=0), 0.0) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.histogram(scope_prefix + name + '/dxy', dist_xy) return dist_xy else: raise AttributeError('Mode {} not supported'.format(mode)) def get_squared_dist_ref(x, y): """ This function calculates the pairwise distance between x and x, x and y, y and y. It is more accurate than get_dist at the cost of higher memory and complexity. :param x: :param y: :return: """ with tf.name_scope('squared_dist_ref'): if len(x.get_shape().as_list()) > 2: raise AttributeError('get_dist: Input must be a matrix.') x_expand = tf.expand_dims(x, axis=2) # m-by-d-by-1 x_permute = tf.transpose(x_expand, perm=(2, 1, 0)) # 1-by-d-by-m dxx = x_expand - x_permute # m-by-d-by-m, the first page is ai - a1 dist_xx = tf.reduce_sum(tf.multiply(dxx, dxx), axis=1) # m-by-m, the first column is (ai-a1)^2 if y is None: return dist_xx else: y_expand = tf.expand_dims(y, axis=2) # m-by-d-by-1 y_permute = tf.transpose(y_expand, perm=(2, 1, 0)) dxy = x_expand - y_permute # m-by-d-by-m, the first page is ai - b1 dist_xy = tf.reduce_sum(tf.multiply(dxy, dxy), axis=1) # m-by-m, the first column is (ai-b1)^2 dyy = y_expand - y_permute # m-by-d-by-m, the first page is ai - b1 dist_yy = tf.reduce_sum(tf.multiply(dyy, dyy), axis=1) # m-by-m, the first column is (ai-b1)^2 return dist_xx, dist_xy, dist_yy ######################################################################## def squared_dist_triplet(x, y, z, name='squared_dist', do_summary=False, scope_prefix=''): """ This function calculates the pairwise distance between x and x, x and y, y and y, y and z, z and z in 'seq' mode, or any two pairs in 'all' mode :param x: :param y: :param z: :param name: :param do_summary: :param scope_prefix: :return: """ with tf.name_scope(name): x_x = tf.matmul(x, x, transpose_b=True) y_y = tf.matmul(y, y, transpose_b=True) z_z = tf.matmul(z, z, transpose_b=True) x_y = tf.matmul(x, y, transpose_b=True) y_z = tf.matmul(y, z, transpose_b=True) x_z = tf.matmul(x, z, transpose_b=True) d_x = tf.diag_part(x_x) d_y = tf.diag_part(y_y) d_z = tf.diag_part(z_z) d_x_x = tf.maximum(tf.expand_dims(d_x, axis=1) - 2.0 * x_x + tf.expand_dims(d_x, axis=0), 0.0) d_y_y = tf.maximum(tf.expand_dims(d_y, axis=1) - 2.0 * y_y + tf.expand_dims(d_y, axis=0), 0.0) d_z_z = tf.maximum(tf.expand_dims(d_z, axis=1) - 2.0 * z_z + tf.expand_dims(d_z, axis=0), 0.0) d_x_y = tf.maximum(tf.expand_dims(d_x, axis=1) - 2.0 * x_y + tf.expand_dims(d_y, axis=0), 0.0) d_y_z = tf.maximum(tf.expand_dims(d_y, axis=1) - 2.0 * y_z + tf.expand_dims(d_z, axis=0), 0.0) d_x_z = tf.maximum(tf.expand_dims(d_x, axis=1) - 2.0 * x_z + tf.expand_dims(d_z, axis=0), 0.0) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.histogram(scope_prefix + name + '/dxx', d_x_x) tf.summary.histogram(scope_prefix + name + '/dyy', d_y_y) tf.summary.histogram(scope_prefix + name + '/dzz', d_z_z) tf.summary.histogram(scope_prefix + name + '/dxy', d_x_y) tf.summary.histogram(scope_prefix + name + '/dyz', d_y_z) tf.summary.histogram(scope_prefix + name + '/dxz', d_x_z) return d_x_x, d_y_y, d_z_z, d_x_y, d_x_z, d_y_z ######################################################################## def get_dist_np(x, y): """ This function calculates the pairwise distance between x and y using numpy :param x: m-by-d array :param y: n-by-d array :return: """ x = np.array(x, dtype=np.float32) y = np.array(y, dtype=np.float32) x_expand = np.expand_dims(x, axis=2) # m-by-d-by-1 y_expand = np.expand_dims(y, axis=2) # n-by-d-by-1 y_permute = np.transpose(y_expand, axes=(2, 1, 0)) # 1-by-d-by-n dxy = x_expand - y_permute # m-by-d-by-n, the first page is ai - b1 dist_xy = np.sqrt(np.sum(np.multiply(dxy, dxy), axis=1, dtype=np.float32)) # m-by-n, the first column is (ai-b1)^2 return dist_xy ######################################################################### def get_batch_squared_dist(x_batch, y_batch=None, axis=1, mode='xx', name='squared_dist'): """ This function calculates squared pairwise distance for vectors under xi or between xi and yi where i refers to the samples in the batch :param x_batch: batch_size-a-b tensor :param y_batch: batch_size-c-d tensor :param axis: the axis to be considered as features; if axis==1, a=c; if axis=2, b=d :param mode: 'xxxyyy', 'xx', 'xy', 'xxxy' :param name: :return: dist tensor(s) """ # check inputs assert axis in [1, 2], 'axis has to be 1 or 2.' batch, a, b = x_batch.get_shape().as_list() if y_batch is not None: batch_y, c, d = y_batch.get_shape().as_list() assert batch == batch_y, 'Batch sizes do not match.' if axis == 1: assert a == c, 'Feature sizes do not match.' elif axis == 2: assert b == d, 'Feature sizes do not match.' if mode == 'xx': mode = 'xy' with tf.name_scope(name): if mode in {'xx', 'xxxyyy', 'xxxy'}: # xxt is batch-a-a if axis is 2 else batch-b-b xxt = tf.matmul(x_batch, tf.transpose(x_batch, [0, 2, 1])) \ if axis == 2 else tf.matmul(tf.transpose(x_batch, [0, 2, 1]), x_batch) # dx is batch-a if axis is 2 else batch-b dx = tf.matrix_diag_part(xxt) dist_xx = tf.maximum(tf.expand_dims(dx, axis=2) - 2.0 * xxt + tf.expand_dims(dx, axis=1), 0.0) if mode == 'xx': return dist_xx elif mode == 'xxxy': # xyt is batch-a-c if axis is 2 else batch-b-d xyt = tf.matmul(x_batch, tf.transpose(y_batch, [0, 2, 1])) \ if axis == 2 else tf.matmul(tf.transpose(x_batch, [0, 2, 1]), y_batch) # dy is batch-c if axis is 2 else batch-d dy = tf.reduce_sum(tf.multiply(y_batch, y_batch), axis=axis) dist_xy = tf.maximum(tf.expand_dims(dx, axis=2) - 2.0 * xyt + tf.expand_dims(dy, axis=1), 0.0) return dist_xx, dist_xy elif mode == 'xxxyyy': # xyt is batch-a-c if axis is 2 else batch-b-d xyt = tf.matmul(x_batch, tf.transpose(y_batch, [0, 2, 1])) \ if axis == 2 else tf.matmul(tf.transpose(x_batch, [0, 2, 1]), y_batch) # yyt is batch-c-c if axis is 2 else batch-d-d yyt = tf.matmul(y_batch, tf.transpose(y_batch, [0, 2, 1])) \ if axis == 2 else tf.matmul(tf.transpose(y_batch, [0, 2, 1]), y_batch) # dy is batch-c if axis is 2 else batch-d dy = tf.reduce_sum(tf.multiply(y_batch, y_batch), axis=axis) dist_xy = tf.maximum(tf.expand_dims(dx, axis=2) - 2.0 * xyt + tf.expand_dims(dy, axis=1), 0.0) dist_yy = tf.maximum(tf.expand_dims(dy, axis=2) - 2.0 * yyt + tf.expand_dims(dy, axis=1), 0.0) return dist_xx, dist_xy, dist_yy elif mode == 'xy': # dx is batch-a if axis is 2 else batch-b dx = tf.reduce_sum(tf.multiply(x_batch, x_batch), axis=axis) # dy is batch-c if axis is 2 else batch-d dy = tf.reduce_sum(tf.multiply(y_batch, y_batch), axis=axis) # xyt is batch-a-c if axis is 2 else batch-b-d xyt = tf.matmul(x_batch, tf.transpose(y_batch, [0, 2, 1])) \ if axis == 2 else tf.matmul(tf.transpose(x_batch, [0, 2, 1]), y_batch) dist_xy = tf.maximum(tf.expand_dims(dx, axis=2) - 2.0 * xyt + tf.expand_dims(dy, axis=1), 0.0) return dist_xy else: raise AttributeError('Mode {} not supported'.format(mode)) ####################################################################### def newton_root(x, f, df, step=None): """ This function does one iteration update on x to find the root f(x)=0. It is primarily used as the body of tf.while_loop. :param x: :param f: a function that receives x as input and outputs f(x) and other info for gradient calculation :param df: a function that receives info as inputs and outputs the gradient of f at x :param step: :return: """ fx, info2grad = f(x) gx = df(info2grad) x = x - fx / (gx + FLAGS.EPSI) if step is None: return x else: return x, step + 1 ####################################################################### def matrix_mean_wo_diagonal(matrix, num_row, num_col=None, name='mu_wo_diag'): """ This function calculates the mean of the matrix elements not in the diagonal 2018.4.9 - replace tf.diag_part with tf.matrix_diag_part tf.matrix_diag_part can be used for rectangle matrix while tf.diag_part can only be used for square matrix :param matrix: :param num_row: :type num_row: float :param num_col: :type num_col: float :param name: :return: """ with tf.name_scope(name): if num_col is None: mu = (tf.reduce_sum(matrix) - tf.reduce_sum(tf.matrix_diag_part(matrix))) / (num_row * (num_row - 1.0)) else: mu = (tf.reduce_sum(matrix) - tf.reduce_sum(tf.matrix_diag_part(matrix))) \ / (num_row * num_col - tf.minimum(num_col, num_row)) return mu ######################################################################## def row_mean_wo_diagonal(matrix, num_col, name='mu_wo_diag'): """ This function calculates the mean of each row of the matrix elements excluding the diagonal :param matrix: :param num_col: :type num_col: float :param name: :return: """ with tf.name_scope(name): return (tf.reduce_sum(matrix, axis=1) - tf.matrix_diag_part(matrix)) / (num_col - 1.0) ######################################################################### def mmd_t( dist_xx, dist_xy, dist_yy, batch_size, alpha=1.0, beta=2.0, var_target=None, name='mmd', do_summary=False, scope_prefix=''): """This function calculates the maximum mean discrepancy with t-distribution kernel The code is inspired by the Github page of following paper: Binkowski M., Sutherland D., Arbel M., Gretton A. (2018) Demystifying MMD GANs. :param dist_xx: batch_size-by-batch_size matrix :param dist_xy: :param dist_yy: :param batch_size: :param alpha: :param beta: :param var_target: if alpha is trainable, var_target contain the target for sigma :param name: :param do_summary: :param scope_prefix: :return: """ with tf.name_scope(name): log_k_xx = tf.log(dist_xx / (beta * alpha) + 1.0) # use log for better condition log_k_xy = tf.log(dist_xy / (beta * alpha) + 1.0) log_k_yy = tf.log(dist_yy / (beta * alpha) + 1.0) k_xx = tf.exp(-alpha * log_k_xx) # [1.0, k(xi, xj); k(xi, xj), 1.0] k_xy = tf.exp(-alpha * log_k_xy) k_yy = tf.exp(-alpha * log_k_yy) m = tf.constant(batch_size, tf.float32) e_kxx = matrix_mean_wo_diagonal(k_xx, m) e_kxy = matrix_mean_wo_diagonal(k_xy, m) e_kyy = matrix_mean_wo_diagonal(k_yy, m) mmd = e_kxx + e_kyy - 2.0 * e_kxy if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + '/kxx', e_kxx) tf.summary.scalar(scope_prefix + name + '/kyy', e_kyy) tf.summary.scalar(scope_prefix + name + '/kxy', e_kxy) # return e_kxx, e_kxy, e_kyy if var_target is None: return mmd else: var = e_kxx + e_kyy + 2.0 * e_kxy loss_sigma = tf.square(var - var_target) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + '/loss_sigma', loss_sigma) return mmd, loss_sigma ######################################################################### def mixture_mmd_t( dist_xx, dist_xy, dist_yy, batch_size, alpha=None, beta=2.0, var_targets=None, name='mmd', do_summary=False, scope_prefix=''): """ This function calculates the maximum mean discrepancy with a list of t-distribution kernels :param dist_xx: :param dist_xy: :param dist_yy: :param batch_size: :param alpha: [0.2, 0.5, 1, 2, 25] :type alpha: list :param beta: :param var_targets: if alpha is trainable, var_targets contain the target for each alpha :type var_targets: list :param name: :param do_summary: :param scope_prefix: :return: """ num_alpha = len(alpha) if isinstance(alpha, list) else len(var_targets) with tf.name_scope(name): mmd = 0.0 if var_targets is None: for i in range(num_alpha): mmd_i = mmd_t( dist_xx, dist_xy, dist_yy, batch_size, alpha=alpha[i], beta=beta, name='d{}'.format(i), do_summary=do_summary, scope_prefix=scope_prefix + name + '/') mmd = mmd + mmd_i return mmd else: loss_alpha = 0.0 for i in range(num_alpha): mmd_i, loss_i = mmd_t( dist_xx, dist_xy, dist_yy, batch_size, alpha=alpha[i], beta=beta, var_target=var_targets[i], name='d{}'.format(i), do_summary=do_summary, scope_prefix=scope_prefix + name + '/') mmd = mmd + mmd_i loss_alpha = loss_alpha + loss_i return mmd, loss_alpha ######################################################################### def witness_t(dist_zx, dist_zy, alpha=1.0, beta=2.0, name='witness', do_summary=False, scope_prefix=''): """ This function calculates the witness function f(z) = Ek(x, z) - Ek(y, z) based on t-distribution kernel :param dist_zx: :param dist_zy: :param alpha: :param beta: :param name: :param do_summary: :param scope_prefix: :return: """ with tf.name_scope(name): # get dist between (x, z) and (y, z) # dist_zx = get_squared_dist(z, x, mode='xy', name='dist_zx', do_summary=do_summary) # dist_zy = get_squared_dist(z, y, mode='xy', name='dist_zy', do_summary=do_summary) log_k_zx = tf.log(dist_zx / (beta * alpha) + 1.0) log_k_zy = tf.log(dist_zy / (beta * alpha) + 1.0) k_zx = tf.exp(-alpha * log_k_zx) k_zy = tf.exp(-alpha * log_k_zy) e_kx = tf.reduce_mean(k_zx, axis=1) e_ky = tf.reduce_mean(k_zy, axis=1) witness = e_kx - e_ky if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.histogram(scope_prefix + name + '/kzx', e_kx) tf.summary.histogram(scope_prefix + name + '/kzy', e_ky) return witness ######################################################################### def witness_mix_t(dist_zx, dist_zy, alpha=None, beta=2.0, name='witness', do_summary=False): """ This function calculates the witness function f(z) = Ek(x, z) - Ek(y, z) based on a list of t-distribution kernels. :param dist_zx: :param dist_zy: :param alpha: :param beta: :param name: :param do_summary: :return: """ num_alpha = len(alpha) with tf.name_scope(name): witness = 0.0 for i in range(num_alpha): wit_i = witness_t( dist_zx, dist_zy, alpha=alpha[i], beta=beta, name='d{}'.format(i), do_summary=do_summary) witness = witness + wit_i return witness ######################################################################### def cramer(dist_xx, dist_xy, dist_yy, batch_size, name='mmd', epsi=1e-16, do_summary=False, scope_prefix=''): """ This function calculates the energy distance without the need of independent samples. The energy distance is taken originall from following paper: Bellemare1, M.G., Danihelka1, I., Dabney, W., Mohamed S., Lakshminarayanan B., Hoyer S., Munos R. (2017). The Cramer Distance as a Solution to Biased Wasserstein Gradients However, the original method requires two batches to calculate the kernel. :param dist_xx: :param dist_xy: :param dist_yy: :param batch_size: :param name: :param epsi: :param do_summary: :param scope_prefix: :return: """ with tf.name_scope(name): k_xx = -tf.sqrt(dist_xx + epsi) k_xy = -tf.sqrt(dist_xy + epsi) k_yy = -tf.sqrt(dist_yy + epsi) m = tf.constant(batch_size, tf.float32) e_kxx = matrix_mean_wo_diagonal(k_xx, m) e_kxy = matrix_mean_wo_diagonal(k_xy, m) e_kyy = matrix_mean_wo_diagonal(k_yy, m) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + '/kxx', e_kxx) tf.summary.scalar(scope_prefix + name + '/kyy', e_kyy) tf.summary.scalar(scope_prefix + name + '/kxy', e_kxy) # return e_kxx, e_kxy, e_kyy return e_kxx + e_kyy - 2.0 * e_kxy ######################################################################### def mmd_g( dist_xx, dist_xy, dist_yy, batch_size, sigma=1.0, var_target=None, upper_bound=None, lower_bound=None, name='mmd', do_summary=False, scope_prefix='', custom_weights=None): """This function calculates the maximum mean discrepancy with Gaussian distribution kernel The kernel is taken from following paper: Li, C.-L., Chang, W.-C., Cheng, Y., Yang, Y., & Póczos, B. (2017). MMD GAN: Towards Deeper Understanding of Moment Matching Network. :param dist_xx: :param dist_xy: :param dist_yy: :param batch_size: :param sigma: :param var_target: if sigma is trainable, var_target contain the target for sigma :param upper_bound: bounds for pairwise distance in mmd-g. :param lower_bound: :param name: :param do_summary: :param scope_prefix: :param custom_weights: weights for loss in mmd, default is [2.0, 1.0], custom[0] - custom[1] = 1.0 :type custom_weights: list :return: """ with tf.name_scope(name): if lower_bound is None: k_xx = tf.exp(-dist_xx / (2.0 * sigma**2), name='k_xx') k_yy = tf.exp(-dist_yy / (2.0 * sigma ** 2), name='k_yy') else: k_xx = tf.exp(-tf.maximum(dist_xx, lower_bound) / (2.0 * sigma ** 2), name='k_xx_lb') k_yy = tf.exp(-tf.maximum(dist_yy, lower_bound) / (2.0 * sigma ** 2), name='k_yy_lb') if upper_bound is None: k_xy = tf.exp(-dist_xy / (2.0 * sigma**2), name='k_xy') else: k_xy = tf.exp(-tf.minimum(dist_xy, upper_bound) / (2.0 * sigma ** 2), name='k_xy_ub') m = tf.constant(batch_size, tf.float32) e_kxx = matrix_mean_wo_diagonal(k_xx, m) e_kxy = matrix_mean_wo_diagonal(k_xy, m) e_kyy = matrix_mean_wo_diagonal(k_yy, m) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + '/kxx', e_kxx) tf.summary.scalar(scope_prefix + name + '/kyy', e_kyy) tf.summary.scalar(scope_prefix + name + '/kxy', e_kxy) if var_target is None: if custom_weights is None: mmd = e_kxx + e_kyy - 2.0 * e_kxy return mmd else: # note that here kyy is for the real data! assert custom_weights[0] - custom_weights[1] == 1.0, 'w[0]-w[1] must be 1' mmd1 = e_kxx + e_kyy - 2.0 * e_kxy mmd2 = custom_weights[0] * e_kxy - e_kxx - custom_weights[1] * e_kyy return mmd1, mmd2 else: mmd = e_kxx + e_kyy - 2.0 * e_kxy var = e_kxx + e_kyy + 2.0 * e_kxy loss_sigma = tf.square(var - var_target) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + '/loss_sigma', loss_sigma) return mmd, loss_sigma ######################################################################### def mmd_g_bounded( dist_xx, dist_xy, dist_yy, batch_size, sigma=1.0, var_target=None, upper_bound=None, lower_bound=None, name='mmd', do_summary=False, scope_prefix='', custom_weights=None): """This function calculates the maximum mean discrepancy with Gaussian distribution kernel The kernel is taken from following paper: Li, C.-L., Chang, W.-C., Cheng, Y., Yang, Y., & Póczos, B. (2017). MMD GAN: Towards Deeper Understanding of Moment Matching Network. :param dist_xx: :param dist_xy: :param dist_yy: :param batch_size: :param sigma: :param var_target: if sigma is trainable, var_target contain the target for sigma :param upper_bound: :param lower_bound: :param name: :param do_summary: :param scope_prefix: :param custom_weights: weights for loss in mmd, default is [2.0, 1.0], custom[0] - custom[1] = 1.0 :type custom_weights: list :return: """ with tf.name_scope(name): k_xx = tf.exp(-dist_xx / (2.0 * sigma ** 2), name='k_xx') k_yy = tf.exp(-dist_yy / (2.0 * sigma ** 2), name='k_yy') k_xy = tf.exp(-dist_xy / (2.0 * sigma ** 2), name='k_xy') # in rep loss, custom_weights[0] - custom_weights[1] = 1 k_xx_b = tf.exp(-tf.maximum(dist_xx, lower_bound) / (2.0 * sigma ** 2), name='k_xx_lb') if custom_weights[0] > 0: k_xy_b = tf.exp(-tf.minimum(dist_xy, upper_bound) / (2.0 * sigma ** 2), name='k_xy_ub') else: k_xy_b = k_xy # no lower bound should be enforced as k_xy may be zero at equilibrium if custom_weights[1] > 0: # the original mmd-g k_yy_b = tf.exp(-tf.maximum(dist_yy, lower_bound) / (2.0 * sigma ** 2), name='k_yy_ub') else: # the repulsive mmd-g k_yy_b = tf.exp(-tf.minimum(dist_yy, upper_bound) / (2.0 * sigma ** 2), name='k_yy_ub') m = tf.constant(batch_size, tf.float32) e_kxx = matrix_mean_wo_diagonal(k_xx, m) e_kxy = matrix_mean_wo_diagonal(k_xy, m) e_kyy = matrix_mean_wo_diagonal(k_yy, m) e_kxx_b = matrix_mean_wo_diagonal(k_xx_b, m) e_kyy_b = matrix_mean_wo_diagonal(k_yy_b, m) e_kxy_b = matrix_mean_wo_diagonal(k_xy_b, m) if custom_weights[0] < 0 else e_kxy if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + '/kxx', e_kxx) tf.summary.scalar(scope_prefix + name + '/kyy', e_kyy) tf.summary.scalar(scope_prefix + name + '/kxy', e_kxy) tf.summary.scalar(scope_prefix + name + '/kxx_b', e_kxx_b) tf.summary.scalar(scope_prefix + name + '/kyy_b', e_kyy_b) if custom_weights[0] > 0: tf.summary.scalar(scope_prefix + name + '/kxy_b', e_kxy_b) if var_target is None: if custom_weights is None: mmd = e_kxx + e_kyy - 2.0 * e_kxy return mmd else: assert custom_weights[0] - custom_weights[1] == 1.0, 'w[0]-w[1] must be 1' mmd1 = e_kxx + e_kyy - 2.0 * e_kxy mmd2 = custom_weights[0] * e_kxy_b - e_kxx_b - custom_weights[1] * e_kyy_b return mmd1, mmd2 else: mmd = e_kxx + e_kyy - 2.0 * e_kxy var = e_kxx + e_kyy + 2.0 * e_kxy loss_sigma = tf.square(var - var_target) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + '/loss_sigma', loss_sigma) return mmd, loss_sigma ######################################################################### def mixture_mmd_g( dist_xx, dist_xy, dist_yy, batch_size, sigma=None, var_targets=None, name='mmd_g', do_summary=False, scope_prefix=''): """ This function calculates the maximum mean discrepancy with a list of Gaussian distribution kernel :param dist_xx: :param dist_xy: :param dist_yy: :param batch_size: :param sigma: :type sigma: list :param var_targets: if sigma is trainable, var_targets contain the target for each sigma :type var_targets: list :param name: :param do_summary: :param scope_prefix: :return: """ num_sigma = len(sigma) if isinstance(sigma, list) else len(var_targets) with tf.name_scope(name): mmd = 0.0 if var_targets is None: for i in range(num_sigma): mmd_i = mmd_g( dist_xx, dist_xy, dist_yy, batch_size, sigma=sigma[i], name='d{}'.format(i), do_summary=do_summary, scope_prefix=scope_prefix + name + '/') mmd = mmd + mmd_i return mmd else: loss_sigma = 0.0 for i in range(num_sigma): mmd_i, loss_i = mmd_g( dist_xx, dist_xy, dist_yy, batch_size, sigma=sigma[i], var_target=var_targets[i], name='d{}'.format(i), do_summary=do_summary, scope_prefix=scope_prefix + name + '/') mmd = mmd + mmd_i loss_sigma = loss_sigma + loss_i return mmd, loss_sigma ######################################################################### def witness_g(dist_zx, dist_zy, sigma=2.0, name='witness', do_summary=False, scope_prefix=''): """ This function calculates the witness function f(z) = Ek(x, z) - Ek(y, z) based on Gaussian kernel :param dist_zx: :param dist_zy: :param sigma: :param name: :param do_summary: :param scope_prefix: :return: """ with tf.name_scope(name): # get dist between (x, z) and (y, z) # dist_zx = get_squared_dist(z, x, mode='xy', name='dist_zx', do_summary=do_summary) # dist_zy = get_squared_dist(z, y, mode='xy', name='dist_zy', do_summary=do_summary) k_zx = tf.exp(-dist_zx / (2.0 * sigma), name='k_zx') k_zy = tf.exp(-dist_zy / (2.0 * sigma), name='k_zy') e_kx = tf.reduce_mean(k_zx, axis=1) e_ky = tf.reduce_mean(k_zy, axis=1) witness = e_kx - e_ky if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.histogram(scope_prefix + name + '/kzx', e_kx) tf.summary.histogram(scope_prefix + name + '/kzy', e_ky) return witness ######################################################################### def witness_mix_g(dist_zx, dist_zy, sigma=None, name='witness', do_summary=False): """ This function calculates the witness function f(z) = Ek(x, z) - Ek(y, z) based on a list of t-distribution kernels. :param dist_zx: :param dist_zy: :param sigma: :param name: :param do_summary: :return: """ num_sigma = len(sigma) with tf.name_scope(name): witness = 0.0 for i in range(num_sigma): wit_i = witness_g( dist_zx, dist_zy, sigma=sigma[i], name='d{}'.format(i), do_summary=do_summary) witness = witness + wit_i return witness def mmd_g_xn( batch_size, d, sigma, x, dist_xx=None, y_mu=0.0, y_var=1.0, name='mmd', do_summary=False, scope_prefix=''): """ This function calculates the mmd between two samples x and y. y is sampled from normal distribution with zero mean and specified variance. :param x: :param y_var: :param batch_size: :param d: :param sigma: :param y_mu: :param dist_xx: :param name: :param do_summary: :param scope_prefix: :return: """ with tf.name_scope(name): # get dist_xx if dist_xx is None: xxt = tf.matmul(x, x, transpose_b=True) dx = tf.diag_part(xxt) dist_xx = tf.maximum(tf.expand_dims(dx, axis=1) - 2.0 * xxt + tf.expand_dims(dx, axis=0), 0.0) # get dist(x, Ey) dist_xy = tf.reduce_sum(tf.multiply(x - y_mu, x - y_mu), axis=1) k_xx = tf.exp(-dist_xx / (2.0 * sigma), name='k_xx') k_xy = tf.multiply( tf.exp(-dist_xy / (2.0 * (sigma + y_var))), tf.pow(sigma / (sigma + y_var), d / 2.0), name='k_xy') m = tf.constant(batch_size, tf.float32) e_kxx = matrix_mean_wo_diagonal(k_xx, m) e_kxy = tf.reduce_mean(k_xy) e_kyy = tf.pow(sigma / (sigma + 2.0 * y_var), d / 2.0) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + '/kxx', e_kxx) tf.summary.scalar(scope_prefix + name + '/kyy', e_kyy) tf.summary.scalar(scope_prefix + name + '/kxy', e_kxy) return e_kxx + e_kyy - 2.0 * e_kxy def mixture_g_xn(batch_size, d, sigma, x, dist_xx=None, y_mu=0.0, y_var=1.0, name='mmd', do_summary=False): """ This function calculates the mmd between two samples x and y. y is sampled from normal distribution with zero mean and specified variance. A mixture of sigma is used. :param batch_size: :param d: :param sigma: :param x: :param dist_xx: :param y_mu: :param y_var: :param name: :param do_summary: :return: """ num_sigma = len(sigma) with tf.name_scope(name): mmd = 0.0 for i in range(num_sigma): mmd_i = mmd_g_xn( batch_size, d, sigma[i], x=x, dist_xx=dist_xx, y_mu=y_mu, y_var=y_var, name='d{}'.format(i), do_summary=do_summary) mmd = mmd + mmd_i return mmd ######################################################################### def rand_mmd_g(dist_all, batch_size, omega=0.5, max_iter=0, name='mmd', do_summary=False, scope_prefix=''): """ This function uses a global sigma to make e_k match the given omega which is sampled uniformly. The sigma is initialized with geometric mean of pairwise distances and updated with Newton's method. :param dist_all: :param batch_size: :param omega: :param max_iter: :param name: :param do_summary: :param scope_prefix: :return: """ with tf.name_scope(name): m = tf.constant(batch_size, tf.float32) def kernel(b): return tf.exp(-dist_all * b) def f(b): k = kernel(b) e_k = matrix_mean_wo_diagonal(k, 2 * m) return e_k - omega, k def df(k): kd = -k * dist_all # gradient of exp(-d*w) e_kd = matrix_mean_wo_diagonal(kd, 2 * m) return e_kd # initialize sigma as the geometric mean of all pairwise distances dist_mean = matrix_mean_wo_diagonal(dist_all, 2 * m) beta = -tf.log(omega) / (dist_mean + FLAGS.EPSI) # beta = 1/2/sigma # if max_iter is larger than one, do newton's update if max_iter > 0: beta, _ = tf.while_loop( cond=lambda _1, i: i < max_iter, body=lambda b, i: newton_root(b, f, df, step=i), loop_vars=(beta, tf.constant(0, dtype=tf.int32))) k_all = kernel(beta) k_xx = k_all[0:batch_size, 0:batch_size] k_xy_0 = k_all[0:batch_size, batch_size:] k_xy_1 = k_all[batch_size:, 0:batch_size] k_yy = k_all[batch_size:, batch_size:] e_kxx = matrix_mean_wo_diagonal(k_xx, m) e_kxy_0 = matrix_mean_wo_diagonal(k_xy_0, m) e_kxy_1 = matrix_mean_wo_diagonal(k_xy_1, m) e_kyy = matrix_mean_wo_diagonal(k_yy, m) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + '/kxx', e_kxx) tf.summary.scalar(scope_prefix + name + '/kyy', e_kyy) tf.summary.scalar(scope_prefix + name + '/kxy_0', e_kxy_0) tf.summary.scalar(scope_prefix + name + '/kxy_1', e_kxy_1) # tf.summary.scalar(scope_prefix + name + 'omega', omega) return e_kxx + e_kyy - e_kxy_0 - e_kxy_1 def rand_mmd_g_xy( dist_xx, dist_xy, dist_yy, batch_size=None, dist_yx=None, omega=0.5, max_iter=3, name='mmd', do_summary=False, scope_prefix=''): """ This function calculates the mmd between two samples x and y. It uses a global sigma to make e_k match the given omega which is sampled uniformly. The sigma is initialized with geometric mean of dist_xy and updated with Newton's method. :param dist_xx: :param dist_xy: :param dist_yy: :param dist_yx: optional, if dist_xy and dist_yx are not the same :param batch_size: do not provide batch_size when the diagonal part of k** also need to be considered. :param omega: :param max_iter: :param name: :param do_summary: :param scope_prefix: :return: """ with tf.name_scope(name): def kernel(dist, b): return tf.exp(-dist * b) def f(b): k = kernel(dist_xy, b) e_k = tf.reduce_mean(k) return e_k - omega, k def df(k): kd = -k * dist_xy # gradient of exp(-d*w) e_kd = tf.reduce_mean(kd) return e_kd def f_plus(b): k0 = kernel(dist_xy, b) e_k0 = tf.reduce_mean(k0) k1 = kernel(dist_yx, b) e_k1 = tf.reduce_mean(k1) return e_k0 + e_k1 - 2.0 * omega, (k0, k1) def df_plus(k): kd0 = -k[0] * dist_xy # gradient of exp(-d*w) kd1 = -k[1] * dist_yx # gradient of exp(-d*w) e_kd = tf.reduce_mean(kd0) + tf.reduce_mean(kd1) return e_kd if dist_yx is None: # initialize sigma as the geometric mean of dist_xy beta = -tf.log(omega) / tf.reduce_mean(dist_xy + FLAGS.EPSI) # beta = 1/2/sigma # if max_iter is larger than one, do newton's update if max_iter > 0: beta, _ = tf.while_loop( cond=lambda _1, i: i < max_iter, body=lambda b, i: newton_root(b, f, df, step=i), loop_vars=(beta, tf.constant(0, dtype=tf.int32))) else: # initialize sigma as the geometric mean of dist_xy and dist_yx # beta = 1/2/sigma beta = -2.0 * tf.log(omega) / (tf.reduce_mean(dist_xy) + tf.reduce_mean(dist_yx) + FLAGS.EPSI) # if max_iter is larger than one, do newton's update if max_iter > 0: beta, _ = tf.while_loop( cond=lambda _1, i: i < max_iter, body=lambda b, i: newton_root(b, f_plus, df_plus, step=i), loop_vars=(beta, tf.constant(0, dtype=tf.int32))) k_xx = kernel(dist_xx, beta) k_xy = kernel(dist_xy, beta) k_yy = kernel(dist_yy, beta) if batch_size is None: # include diagonal elements in k** e_kxx = tf.reduce_mean(k_xx) e_kxy = tf.reduce_mean(k_xy) e_kyy = tf.reduce_mean(k_yy) else: # exclude diagonal elements in k** m = tf.constant(batch_size, tf.float32) e_kxx = matrix_mean_wo_diagonal(k_xx, m) e_kxy = matrix_mean_wo_diagonal(k_xy, m) e_kyy = matrix_mean_wo_diagonal(k_yy, m) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + '/kxx', e_kxx) tf.summary.scalar(scope_prefix + name + '/kyy', e_kyy) tf.summary.scalar(scope_prefix + name + '/kxy', e_kxy) # tf.summary.scalar(scope_prefix + name + 'omega', omega) # tf.summary.histogram(scope_prefix + name + 'dxx', dist_xx) # tf.summary.histogram(scope_prefix + name + 'dxy', dist_xy) # tf.summary.histogram(scope_prefix + name + 'dyy', dist_yy) if dist_yx is None: return e_kxx + e_kyy - 2.0 * e_kxy else: k_yx = kernel(dist_yx, beta) if batch_size is None: e_kyx = tf.reduce_mean(k_yx) else: m = tf.constant(batch_size, tf.float32) e_kyx = matrix_mean_wo_diagonal(k_yx, m) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + 'kyx', e_kyx) return e_kxx + e_kyy - e_kxy - e_kyx def rand_mmd_g_xy_bounded( dist_xx, dist_xy, dist_yy, batch_size=None, dist_yx=None, omega=0.5, max_iter=3, name='mmd', beta_lb=0.125, beta_ub=2.0, do_summary=False, scope_prefix=''): """ This function calculates the mmd between two samples x and y. It uses a global sigma to make e_k match the given omega which is sampled uniformly. The sigma is initialized with geometric mean of dist_xy and updated with Newton's method. :param dist_xx: :param dist_xy: :param dist_yy: :param dist_yx: optional, if dist_xy and dist_yx are not the same :param batch_size: do not provide batch_size when the diagonal part of k** also need to be considered. :param omega: :param max_iter: :param name: :param beta_lb: lower bound for beta (upper bound for sigma) :param beta_ub: upper bound for beta (lower bound for sigma) :param do_summary: :param scope_prefix: :return: """ with tf.name_scope(name): def kernel(dist, b): return tf.exp(-dist * b) def f(b): k = kernel(dist_xy, b) e_k = tf.reduce_mean(k) return e_k - omega, k def df(k): kd = -k * dist_xy # gradient of exp(-d*w) e_kd = tf.reduce_mean(kd) return e_kd def f_plus(b): k0 = kernel(dist_xy, b) e_k0 = tf.reduce_mean(k0) k1 = kernel(dist_yx, b) e_k1 = tf.reduce_mean(k1) return e_k0 + e_k1 - 2.0 * omega, (k0, k1) def df_plus(k): kd0 = -k[0] * dist_xy # gradient of exp(-d*w) kd1 = -k[1] * dist_yx # gradient of exp(-d*w) e_kd = tf.reduce_mean(kd0) + tf.reduce_mean(kd1) return e_kd if dist_yx is None: # initialize sigma as the geometric mean of dist_xy beta = -tf.log(omega) / tf.reduce_mean(dist_xy + FLAGS.EPSI) # beta = 1/2/sigma # if max_iter is larger than one, do newton's update if max_iter > 0: beta, _ = tf.while_loop( cond=lambda _1, i: i < max_iter, body=lambda b, i: newton_root(b, f, df, step=i), loop_vars=(beta, tf.constant(0, dtype=tf.int32))) else: # initialize sigma as the geometric mean of dist_xy and dist_yx # beta = 1/2/sigma beta = -2.0 * tf.log(omega) / (tf.reduce_mean(dist_xy) + tf.reduce_mean(dist_yx) + FLAGS.EPSI) # if max_iter is larger than one, do newton's update if max_iter > 0: beta, _ = tf.while_loop( cond=lambda _1, i: i < max_iter, body=lambda b, i: newton_root(b, f_plus, df_plus, step=i), loop_vars=(beta, tf.constant(0, dtype=tf.int32))) beta = tf.clip_by_value(beta, beta_lb, beta_ub) k_xx = kernel(dist_xx, beta) k_xy = kernel(dist_xy, beta) k_yy = kernel(dist_yy, beta) k_xx_b = kernel(tf.maximum(dist_xx, 0.125/beta), beta) k_xy_b = kernel(tf.minimum(dist_xy, 2.0/beta), beta) k_yy_b = kernel(tf.maximum(dist_yy, 0.125/beta), beta) if batch_size is None: # include diagonal elements in k** e_kxx = tf.reduce_mean(k_xx) e_kxy = tf.reduce_mean(k_xy) e_kyy = tf.reduce_mean(k_yy) e_kxx_b = tf.reduce_mean(k_xx_b) e_kxy_b = tf.reduce_mean(k_xy_b) e_kyy_b = tf.reduce_mean(k_yy_b) else: # exclude diagonal elements in k** m = tf.constant(batch_size, tf.float32) e_kxx = matrix_mean_wo_diagonal(k_xx, m) e_kxy = matrix_mean_wo_diagonal(k_xy, m) e_kyy = matrix_mean_wo_diagonal(k_yy, m) e_kxx_b = matrix_mean_wo_diagonal(k_xx_b, m) e_kxy_b = matrix_mean_wo_diagonal(k_xy_b, m) e_kyy_b = matrix_mean_wo_diagonal(k_yy_b, m) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + '/kxx', e_kxx) tf.summary.scalar(scope_prefix + name + '/kyy', e_kyy) tf.summary.scalar(scope_prefix + name + '/kxy', e_kxy) tf.summary.scalar(scope_prefix + name + '/beta', beta) tf.summary.scalar(scope_prefix + name + '/kxx_b', e_kxx_b) tf.summary.scalar(scope_prefix + name + '/kyy_b', e_kyy_b) tf.summary.scalar(scope_prefix + name + '/kxy_b', e_kxy_b) # tf.summary.scalar(scope_prefix + name + '/kxy_b', e_kxy_b) # tf.summary.scalar(scope_prefix + name + 'omega', omega) # tf.summary.histogram(scope_prefix + name + 'dxx', dist_xx) # tf.summary.histogram(scope_prefix + name + 'dxy', dist_xy) # tf.summary.histogram(scope_prefix + name + 'dyy', dist_yy) if dist_yx is None: return e_kxx + e_kyy - 2.0 * e_kxy, e_kxx_b - 2.0 * e_kyy_b + e_kxy_b else: k_yx = kernel(dist_yx, beta) # k_yx_b = kernel(tf.minimum(dist_yx, upper_bound), beta) if batch_size is None: e_kyx = tf.reduce_mean(k_yx) # e_kyx_b = tf.reduce_mean(k_yx_b) else: m = tf.constant(batch_size, tf.float32) e_kyx = matrix_mean_wo_diagonal(k_yx, m) # e_kyx_b = matrix_mean_wo_diagonal(k_yx_b, m) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + 'kyx', e_kyx) # tf.summary.scalar(scope_prefix + name + 'kyx_b', e_kyx_b) return e_kxx + e_kyy - e_kxy - e_kyx def rand_mmd_g_xn( x, y_rho, batch_size, d, y_mu=0.0, dist_xx=None, omega=0.5, max_iter=0, name='mmd', do_summary=False, scope_prefix=''): """ This function calculates the mmd between two samples x and y. y is sampled from normal distribution with zero mean and specified STD. This function uses a global sigma to make e_k match the given omega which is sampled uniformly. The sigma is initialized with geometric mean of dist_xy and updated with Newton's method. :param x: :param y_rho: y_std = sqrt(y_rho / 2.0 / d) :param batch_size: :param d: number of features in x :param y_mu: :param dist_xx: :param omega: :param max_iter: :param name: :param do_summary: :param scope_prefix: :return: """ with tf.name_scope(name): # get dist_xx if dist_xx is None: xxt = tf.matmul(x, x, transpose_b=True) dx = tf.diag_part(xxt) dist_xx = tf.maximum(tf.expand_dims(dx, axis=1) - 2.0 * xxt + tf.expand_dims(dx, axis=0), 0.0) # get dist(x, Ey) dist_xy = tf.reduce_sum(tf.multiply(x - y_mu, x - y_mu), axis=1) def kernel(dist, b): return tf.exp(-dist * b) def f(b): const_f = d / (d + b * y_rho) k = tf.pow(const_f, d / 2.0) * tf.exp(-b * const_f * dist_xy) e_k = tf.reduce_mean(k) return e_k - omega, (const_f, k, e_k) def df(k): kd = -y_rho * k[0] / 2.0 * k[2] - tf.reduce_mean(tf.pow(k[0], 2) * dist_xy * k[1]) # gradient of exp(-d*w) e_kd = tf.reduce_mean(kd) return e_kd # initialize sigma as the geometric mean of dist_xy beta = -tf.log(omega) / (tf.reduce_mean(dist_xy) + y_rho / 2.0) # beta = 1/2/sigma # if max_iter is larger than one, do newton's update if max_iter > 0: beta, _ = tf.while_loop( cond=lambda _1, i: i < max_iter, body=lambda b, i: newton_root(b, f, df, step=i), loop_vars=(beta, tf.constant(0, dtype=tf.int32))) const_0 = d / (d + beta * y_rho) k_xx = kernel(dist_xx, beta) k_xy = tf.pow(const_0, d / 2.0) * tf.exp(-beta * const_0 * dist_xy) e_kxx = matrix_mean_wo_diagonal(k_xx, tf.constant(batch_size, tf.float32)) e_kxy = tf.reduce_mean(k_xy) e_kyy = tf.pow(d / (d + 2.0 * beta * y_rho), d / 2.0) if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar(scope_prefix + name + '/kxx', e_kxx) tf.summary.scalar(scope_prefix + name + '/kyy', e_kyy) tf.summary.scalar(scope_prefix + name + '/kxy', e_kxy) return e_kxx + e_kyy - 2.0 * e_kxy def get_tensor_name(tensor): """ This function return tensor name without scope :param tensor: :return: """ import re # split 'scope/name:0' into [scope, name, 0] return re.split('[/:]', tensor.name)[-2] def moving_average_update(name, shape, tensor_update, rho=0.01, initializer=None, clip_values=None, dtype=tf.float32): """ This function creates a tensor that will be updated by tensor_update using moving average :param tensor_update: update at each iteration :param name: name for the tensor :param shape: shape of tensor :param rho: :param initializer: :param clip_values: :param dtype: :return: """ if initializer is None: initializer = tf.zeros_initializer tensor = tf.get_variable( name, shape=shape, dtype=dtype, initializer=initializer, trainable=False) if clip_values is None: tf.add_to_collection( tf.GraphKeys.UPDATE_OPS, tf.assign(tensor, tensor + rho * tensor_update)) else: tf.add_to_collection( tf.GraphKeys.UPDATE_OPS, tf.assign( tensor, tf.clip_by_value( tensor + rho * tensor_update, clip_value_min=clip_values[0], clip_value_max=clip_values[1]))) return tensor def moving_average_copy(tensor, name=None, rho=0.01, initializer=None, dtype=tf.float32): """ This function creates a moving average copy of tensor :param tensor: :param name: name for the moving average :param rho: :param initializer: :param dtype: :return: """ if initializer is None: initializer = tf.zeros_initializer if name is None: name = get_tensor_name(tensor) + '_copy' tensor_copy = tf.get_variable( name, shape=tensor.get_shape().as_list(), dtype=dtype, initializer=initializer, trainable=False) tf.add_to_collection( tf.GraphKeys.UPDATE_OPS, tf.assign(tensor_copy, (1.0 - rho) * tensor_copy + rho * tensor)) return tensor_copy def slice_pairwise_distance(pair_dist, batch_size=None, indices=None): """ This function slice pair-dist into smaller pairwise distance matrices :param pair_dist: 2batch_size-by-2batch_size pairwise distance matrix :param batch_size: :param indices: :return: """ with tf.name_scope('slice_dist'): if indices is None: dist_g1 = pair_dist[0:batch_size, 0:batch_size] dist_g2 = pair_dist[batch_size:, batch_size:] dist_g1g2 = pair_dist[0:batch_size, batch_size:] else: mix_group_1 = tf.concat((indices, tf.logical_not(indices)), axis=0) mix_group_2 = tf.concat((tf.logical_not(indices), indices), axis=0) dist_g1 = mat_slice(pair_dist, mix_group_1) dist_g2 = mat_slice(pair_dist, mix_group_2) dist_g1g2 = mat_slice(pair_dist, mix_group_1, mix_group_2) return dist_g1, dist_g1g2, dist_g2 def get_mix_coin( loss, loss_threshold, batch_size=None, loss_average_update=0.01, mix_prob_update=0.01, loss_average_name='loss_ave'): """ This function generate a mix_indices to mix data from two classes :param loss: :param loss_threshold: :param batch_size: :param loss_average_update: :param mix_prob_update: :param loss_average_name: :return: """ with tf.variable_scope('coin', reuse=tf.AUTO_REUSE): # calculate moving average of loss loss_average = moving_average_copy(loss, loss_average_name, rho=loss_average_update) # update mixing probability mix_prob = moving_average_update( 'prob', [], loss_average - loss_threshold, rho=mix_prob_update, clip_values=[0.0, 0.5]) # sample mix_indices uni = tf.random_uniform([batch_size], 0.0, 1.0, dtype=tf.float32, name='uni') mix_indices = tf.greater(uni, mix_prob, name='mix_indices') # mix_indices for using original data # loss_average and mix_prob is returned so that summary can be added outside of coin variable scope return mix_indices, loss_average, mix_prob class GANLoss(object): def __init__(self, do_summary=False): """ This class defines all kinds of loss functions for generative adversarial nets Current losses include: """ # IO self.do_summary = do_summary self.score_gen = None self.score_data = None self.batch_size = None self.num_scores = None # loss self.loss_gen = None self.loss_dis = None self.dis_penalty = None self.dis_scale = None self.debug_register = None # output used for debugging # hyperparameters self.sigma = [1.0, np.sqrt(2.0), 2.0, np.sqrt(8.0), 4.0] # self.sigma = [1.0, 2.0, 4.0, 8.0, 16.0] # mmd-g, kernel scales used in original paper self.alpha = [0.2, 0.5, 1, 2, 5.0] # mmd-t, kernel scales used in original paper self.beta = 2.0 # mmd-t, kernel scales used in original paper self.omega_range = [0.05, 0.85] # rand_g parameter self.ref_normal = 1.0 # rand_g parameter # weights[0] - weights[1] = 1.0 self.repulsive_weights = [0.0, -1.0] # weights for e_kxy and -e_kyy; note that kyy is for the real data! # self.repulsive_weights = [-1.0, -2.0] # weights for e_kxy and -e_kyy def _add_summary_(self): """ This function adds summaries :return: """ if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('GANLoss/gen', self.loss_gen) tf.summary.scalar('GANLoss/dis', self.loss_dis) def _logistic_(self): """ non-saturate logistic loss :return: """ with tf.name_scope('logistic_loss'): self.loss_dis = tf.reduce_mean(tf.nn.softplus(self.score_gen) + tf.nn.softplus(-self.score_data)) self.loss_gen = tf.reduce_mean(tf.nn.softplus(-self.score_gen)) def _hinge_(self): """ hinge loss :return: """ with tf.name_scope('hinge_loss'): self.loss_dis = tf.reduce_mean( tf.nn.relu(1.0 + self.score_gen)) + tf.reduce_mean(tf.nn.relu(1.0 - self.score_data)) self.loss_gen = tf.reduce_mean(-self.score_gen) def _wasserstein_(self): """ wasserstein distance :return: """ assert self.dis_penalty is not None, 'Discriminator penalty must be provided for wasserstein GAN' with tf.name_scope('wasserstein'): self.loss_gen = tf.reduce_mean(self.score_data) - tf.reduce_mean(self.score_gen) self.loss_dis = - self.loss_gen + self.dis_penalty if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('GANLoss/dis_penalty', self.dis_penalty) return self.loss_dis, self.loss_gen def _mmd_g_(self): """ maximum mean discrepancy with gaussian kernel """ # calculate pairwise distance dist_gg, dist_gd, dist_dd = get_squared_dist( self.score_gen, self.score_data, z_score=False, do_summary=self.do_summary) # mmd self.loss_gen = mixture_mmd_g( dist_gg, dist_gd, dist_dd, self.batch_size, sigma=self.sigma, name='mmd_g', do_summary=self.do_summary) self.loss_dis = -self.loss_gen if self.dis_penalty is not None: self.loss_dis = self.loss_dis + self.dis_penalty def _mmd_g_bound_(self): """ maximum mean discrepancy with gaussian kernel and bounds on dxy :return: """ # calculate pairwise distance dist_gg, dist_gd, dist_dd = get_squared_dist( self.score_gen, self.score_data, z_score=False, do_summary=self.do_summary) # mmd self.loss_gen = mmd_g( dist_gg, dist_gd, dist_dd, self.batch_size, sigma=1.0, name='mmd_g', do_summary=self.do_summary, scope_prefix='') mmd_b = mmd_g( dist_gg, dist_gd, dist_dd, self.batch_size, sigma=1.0, upper_bound=4, lower_bound=0.25, name='mmd_g_b', do_summary=self.do_summary, scope_prefix='') self.loss_dis = -mmd_b if self.dis_penalty is not None: self.loss_dis = self.loss_dis + self.dis_penalty def _mmd_g_mix_(self, mix_threshold=1.0): """ maximum mean discrepancy with gaussian kernel and mixing score_gen and score_data if discriminator is too strong :param mix_threshold: :return: """ # calculate pairwise distance pair_dist = get_squared_dist(tf.concat((self.score_gen, self.score_data), axis=0)) dist_gg, dist_gd, dist_dd = slice_pairwise_distance(pair_dist, batch_size=self.batch_size) # mmd with tf.variable_scope('mmd_g_mix', reuse=tf.AUTO_REUSE): self.loss_gen = mixture_mmd_g( dist_gg, dist_gd, dist_dd, self.batch_size, sigma=self.sigma, name='mmd', do_summary=self.do_summary, scope_prefix='mmd_g_mix/') # mix data if self.loss_gen surpass loss_gen_threshold mix_indices, loss_average, mix_prob = get_mix_coin( self.loss_gen, mix_threshold, batch_size=self.batch_size, loss_average_name='gen_average') dist_gg_mix, dist_gd_mix, dist_dd_mix = slice_pairwise_distance(pair_dist, indices=mix_indices) # mmd for mixed data loss_mix = mixture_mmd_g( dist_gg_mix, dist_gd_mix, dist_dd_mix, self.batch_size, sigma=self.sigma, name='mmd_mix', do_summary=self.do_summary, scope_prefix='mmd_g_mix/') self.loss_dis = -loss_mix if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('GANLoss/gen_average', loss_average) tf.summary.scalar('GANLoss/mix_prob', mix_prob) tf.summary.histogram('squared_dist/dxx', dist_gg) tf.summary.histogram('squared_dist/dyy', dist_dd) tf.summary.histogram('squared_dist/dxy', dist_gd) def _single_mmd_g_mix_(self, mix_threshold=0.2): """ maximum mean discrepancy with gaussian kernel and mixing score_gen and score_data if discriminator is too strong :param mix_threshold: :return: """ # calculate pairwise distance pair_dist = get_squared_dist(tf.concat((self.score_gen, self.score_data), axis=0)) dist_gg, dist_gd, dist_dd = slice_pairwise_distance(pair_dist, batch_size=self.batch_size) # mmd with tf.variable_scope('mmd_g_mix', reuse=tf.AUTO_REUSE): self.loss_gen = mmd_g( dist_gg, dist_gd, dist_dd, self.batch_size, sigma=1.0, name='mmd', do_summary=self.do_summary, scope_prefix='mmd_g_mix/') # mix data if self.loss_gen surpass loss_gen_threshold mix_indices, loss_average, mix_prob = get_mix_coin( self.loss_gen, mix_threshold, batch_size=self.batch_size, loss_average_name='gen_average') dist_gg_mix, dist_gd_mix, dist_dd_mix = slice_pairwise_distance(pair_dist, indices=mix_indices) # mmd for mixed data loss_mix = mmd_g( dist_gg_mix, dist_gd_mix, dist_dd_mix, self.batch_size, sigma=1.0, name='mmd_mix', do_summary=self.do_summary, scope_prefix='mmd_g_mix/') self.loss_dis = -loss_mix if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('GANLoss/gen_average', loss_average) tf.summary.scalar('GANLoss/mix_prob', mix_prob) tf.summary.histogram('squared_dist/dxx', dist_gg) tf.summary.histogram('squared_dist/dyy', dist_dd) tf.summary.histogram('squared_dist/dxy', dist_gd) def _mmd_t_(self): """ maximum mean discrepancy with t-distribution kernel """ # calculate pairwise distance dist_gg, dist_gd, dist_dd = get_squared_dist( self.score_gen, self.score_data, z_score=False, do_summary=self.do_summary) # mmd self.loss_gen = mixture_mmd_t( dist_gg, dist_gd, dist_dd, self.batch_size, alpha=self.alpha, beta=self.beta, name='mmd_t', do_summary=self.do_summary) self.loss_dis = -self.loss_gen if self.dis_penalty is not None: self.loss_dis = self.loss_dis + self.dis_penalty def _rand_g_(self): """ maximum mean discrepancy with gaussian kernel and random kernel scale """ # calculate pairwise distance dist_gg, dist_gd, dist_dd = get_squared_dist( self.score_gen, self.score_data, z_score=False, do_summary=self.do_summary) # mmd with tf.name_scope('rand_g'): omega = tf.random_uniform([], self.omega_range[0], self.omega_range[1], dtype=tf.float32) \ if isinstance(self.omega_range, (list, tuple)) else self.omega_range loss_gr = rand_mmd_g_xy( dist_gg, dist_gd, dist_dd, self.batch_size, omega=omega, max_iter=3, name='mmd_gr', do_summary=self.do_summary, scope_prefix='rand_g/') loss_gn = rand_mmd_g_xn( self.score_gen, self.ref_normal, self.batch_size, self.num_scores, dist_xx=dist_gg, omega=omega, max_iter=3, name='mmd_gn', do_summary=self.do_summary, scope_prefix='rand_g/') loss_rn = rand_mmd_g_xn( self.score_data, self.ref_normal, self.batch_size, self.num_scores, dist_xx=dist_dd, omega=omega, max_iter=3, name='mmd_rn', do_summary=self.do_summary, scope_prefix='rand_g/') # final loss self.loss_gen = loss_gr self.loss_dis = loss_rn - loss_gr # self.debug_register = [omega, loss_gr, loss_gn, loss_rn] if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('rand_g/omega', omega) tf.summary.scalar('GANLoss/gr', loss_gr) tf.summary.scalar('GANLoss/gn', loss_gn) tf.summary.scalar('GANLoss/rn', loss_rn) def _rand_g_bounded_(self): """ maximum mean discrepancy with gaussian kernel and random kernel scale, and upper bounds on dxy :return: """ # calculate pairwise distance dist_gg, dist_gd, dist_dd = get_squared_dist( self.score_gen, self.score_data, z_score=False, do_summary=self.do_summary) with tf.name_scope('rand_g'): omega = tf.random_uniform([], self.omega_range[0], self.omega_range[1], dtype=tf.float32) \ if isinstance(self.omega_range, (list, tuple)) else self.omega_range loss_gr, loss_gr_b = rand_mmd_g_xy_bounded( dist_gg, dist_gd, dist_dd, self.batch_size, omega=omega, max_iter=3, name='mmd', do_summary=self.do_summary, scope_prefix='rand_g/') # loss_gn = rand_mmd_g_xn( # self.score_gen, self.ref_normal, self.batch_size, self.num_scores, dist_xx=dist_gg, omega=omega, # max_iter=3, name='mmd_gn', do_summary=self.do_summary, scope_prefix='rand_g/') # loss_rn = rand_mmd_g_xn( # self.score_data, self.ref_normal, self.batch_size, self.num_scores, dist_xx=dist_dd, omega=omega, # max_iter=3, name='mmd_rn', do_summary=self.do_summary, scope_prefix='rand_g/') # final loss self.loss_gen = loss_gr self.loss_dis = - loss_gr_b if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('rand_g/omega', omega) tf.summary.scalar('GANLoss/gr', loss_gr) # tf.summary.scalar('GANLoss/gn', loss_gn) # tf.summary.scalar('GANLoss/rn', loss_rn) def _rand_g_mix_(self, mix_threshold=0.2): """ maximum mean discrepancy with gaussian kernel and random kernel scale and mixing score_gen and score_data if discriminator is too strong """ # calculate pairwise distance pair_dist = get_squared_dist(tf.concat((self.score_gen, self.score_data), axis=0)) dist_gg, dist_gd, dist_dd = slice_pairwise_distance(pair_dist, batch_size=self.batch_size) # mmd with tf.variable_scope('rand_g_mix', reuse=tf.AUTO_REUSE): omega = tf.random_uniform([], self.omega_range[0], self.omega_range[1], dtype=tf.float32) \ if isinstance(self.omega_range, (list, tuple)) else self.omega_range loss_gr = rand_mmd_g_xy( dist_gg, dist_gd, dist_dd, self.batch_size, omega=omega, max_iter=3, name='mmd_gr', do_summary=self.do_summary, scope_prefix='rand_g_mix/') loss_gn = rand_mmd_g_xn( self.score_gen, self.ref_normal, self.batch_size, self.num_scores, dist_xx=dist_gg, omega=omega, max_iter=3, name='mmd_gn', do_summary=self.do_summary, scope_prefix='rand_g_mix/') loss_rn = rand_mmd_g_xn( self.score_data, self.ref_normal, self.batch_size, self.num_scores, dist_xx=dist_dd, omega=omega, max_iter=3, name='mmd_rn', do_summary=self.do_summary, scope_prefix='rand_g_mix/') # mix data if self.loss_gen surpass loss_gen_threshold mix_indices, loss_average, mix_prob = get_mix_coin( loss_gr, mix_threshold, batch_size=self.batch_size, loss_average_name='gr_average') dist_gg_mix, dist_gd_mix, dist_dd_mix = slice_pairwise_distance(pair_dist, indices=mix_indices) # mmd for mixed data loss_gr_mix = rand_mmd_g_xy( dist_gg_mix, dist_gd_mix, dist_dd_mix, self.batch_size, omega=omega, max_iter=3, name='mmd_gr_mix', do_summary=self.do_summary, scope_prefix='rand_g_mix/') # final loss self.loss_gen = loss_gr self.loss_dis = loss_rn - loss_gr_mix # self.debug_register = loss_rn if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('rand_g_mix/omega', omega) tf.summary.scalar('GANLoss/gr_average', loss_average) tf.summary.scalar('GANLoss/mix_prob', mix_prob) tf.summary.histogram('squared_dist/dxx', dist_gg) tf.summary.histogram('squared_dist/dyy', dist_dd) tf.summary.histogram('squared_dist/dxy', dist_gd) tf.summary.scalar('GANLoss/gr', loss_gr) tf.summary.scalar('GANLoss/gn', loss_gn) tf.summary.scalar('GANLoss/rn', loss_rn) tf.summary.scalar('GANLoss/gr_mix', loss_gr_mix) def _sym_rg_mix_(self, mix_threshold=0.2): """ symmetric version of rand_g_mix :param mix_threshold: :return: """ # calculate pairwise distance pair_dist = get_squared_dist(tf.concat((self.score_gen, self.score_data), axis=0)) dist_gg, dist_gd, dist_dd = slice_pairwise_distance(pair_dist, batch_size=self.batch_size) # mmd with tf.variable_scope('sym_rg_mix', reuse=tf.AUTO_REUSE): omega = tf.random_uniform([], self.omega_range[0], self.omega_range[1], dtype=tf.float32) \ if isinstance(self.omega_range, (list, tuple)) else self.omega_range loss_gr = rand_mmd_g_xy( dist_gg, dist_gd, dist_dd, self.batch_size, omega=omega, max_iter=3, name='mmd_gr', do_summary=self.do_summary, scope_prefix='sym_rg_mix/') loss_gn = rand_mmd_g_xn( self.score_gen, self.ref_normal, self.batch_size, self.num_scores, dist_xx=dist_gg, omega=omega, max_iter=3, name='mmd_gn', do_summary=self.do_summary, scope_prefix='sym_rg_mix/') loss_rn = rand_mmd_g_xn( self.score_data, self.ref_normal, self.batch_size, self.num_scores, dist_xx=dist_dd, omega=omega, max_iter=3, name='mmd_rn', do_summary=self.do_summary, scope_prefix='sym_rg_mix/') # mix data if self.loss_gen surpass loss_gen_threshold mix_indices, loss_average, mix_prob = get_mix_coin( loss_gr, mix_threshold, batch_size=self.batch_size, loss_average_name='gr_average') dist_gg_mix, dist_gd_mix, dist_dd_mix = slice_pairwise_distance(pair_dist, indices=mix_indices) # mmd for mixed data loss_gr_mix = rand_mmd_g_xy( dist_gg_mix, dist_gd_mix, dist_dd_mix, self.batch_size, omega=omega, max_iter=3, name='mmd_gr_mix', do_summary=self.do_summary, scope_prefix='sym_rg_mix/') # final loss self.loss_gen = loss_gr + loss_gn self.loss_dis = loss_rn - loss_gr_mix - loss_gn if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('rand_g_mix/omega', omega) tf.summary.scalar('GANLoss/gr_average', loss_average) tf.summary.scalar('GANLoss/mix_prob', mix_prob) tf.summary.histogram('squared_dist/dxx', dist_gg) tf.summary.histogram('squared_dist/dyy', dist_dd) tf.summary.histogram('squared_dist/dxy', dist_gd) tf.summary.scalar('GANLoss/gr', loss_gr) tf.summary.scalar('GANLoss/gn', loss_gn) tf.summary.scalar('GANLoss/rn', loss_rn) tf.summary.scalar('GANLoss/gr_mix', loss_gr_mix) def _sym_rand_g_(self): """ Version 2 of symmetric rand_g. This function does not use label smoothing This function does not work. :return: """ # calculate pairwise distance pair_dist = get_squared_dist(tf.concat((self.score_gen, self.score_data), axis=0)) dist_gg, dist_gd, dist_dd = slice_pairwise_distance(pair_dist, batch_size=self.batch_size) # mmd with tf.variable_scope('sym_rg_mix', reuse=tf.AUTO_REUSE): omega = tf.random_uniform([], self.omega_range[0], self.omega_range[1], dtype=tf.float32) \ if isinstance(self.omega_range, (list, tuple)) else self.omega_range loss_gr = rand_mmd_g_xy( dist_gg, dist_gd, dist_dd, self.batch_size, omega=omega, max_iter=3, name='mmd_gr', do_summary=self.do_summary, scope_prefix='sym_rg_mix/') loss_gn = rand_mmd_g_xn( self.score_gen, self.ref_normal, self.batch_size, self.num_scores, y_mu=-0.5, dist_xx=dist_gg, omega=omega, max_iter=3, name='mmd_gn', do_summary=self.do_summary, scope_prefix='sym_rg_mix/') loss_rn = rand_mmd_g_xn( self.score_data, self.ref_normal, self.batch_size, self.num_scores, y_mu=0.5, dist_xx=dist_dd, omega=omega, max_iter=3, name='mmd_rn', do_summary=self.do_summary, scope_prefix='sym_rg_mix/') self.loss_gen = loss_gr self.loss_dis = 0.5*(loss_rn + loss_gn) - loss_gr if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('sym_rg_mix/omega', omega) tf.summary.histogram('squared_dist/dxx', dist_gg) tf.summary.histogram('squared_dist/dyy', dist_dd) tf.summary.histogram('squared_dist/dxy', dist_gd) tf.summary.scalar('GANLoss/gr', loss_gr) tf.summary.scalar('GANLoss/gn', loss_gn) tf.summary.scalar('GANLoss/rn', loss_rn) def _rand_g_instance_noise_(self, mix_threshold=0.2): """ This function tests instance noise :param mix_threshold: :return: """ with tf.variable_scope('ins_noise'): sigma = tf.get_variable( 'sigma', shape=[], dtype=tf.float32, initializer=tf.zeros_initializer, trainable=False) stddev = tf.log(sigma + 1.0) # to slow down sigma increase noise_gen = tf.random_normal( self.score_gen.get_shape().as_list(), mean=0.0, stddev=stddev, name='noise_gen', dtype=tf.float32) noise_x = tf.random_normal( self.score_data.get_shape().as_list(), mean=0.0, stddev=stddev, name='noise_x', dtype=tf.float32) self.score_gen = self.score_gen + noise_gen self.score_data = self.score_data + noise_x # use rand_g loss self._rand_g_() # update sigma loss_average = moving_average_copy(self.loss_gen, 'mmd_mean') tf.add_to_collection( tf.GraphKeys.UPDATE_OPS, tf.assign( sigma, tf.clip_by_value( sigma + 0.001 * (loss_average - mix_threshold), clip_value_min=0.0, clip_value_max=1.7183))) if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('GANLoss/gr_average', loss_average) tf.summary.scalar('GANLoss/sigma', sigma) def _repulsive_mmd_g_(self): """ repulsive loss :return: """ # calculate pairwise distance dist_gg, dist_gd, dist_dd = get_squared_dist( self.score_gen, self.score_data, z_score=False, do_summary=self.do_summary) # self.loss_gen, self.loss_dis = mmd_g( # dist_gg, dist_gd, dist_dd, self.batch_size, sigma=1.6, # name='mmd_g', do_summary=self.do_summary, scope_prefix='', custom_weights=self.repulsive_weights) self.loss_gen, self.loss_dis = mmd_g( dist_gg, dist_gd, dist_dd, self.batch_size, sigma=1.0, name='mmd_g', do_summary=self.do_summary, scope_prefix='', custom_weights=self.repulsive_weights) if self.dis_penalty is not None: self.loss_dis = self.loss_dis + self.dis_penalty if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('GANLoss/dis_penalty', self.dis_penalty) if self.dis_scale is not None: self.loss_dis = (self.loss_dis - 1.0) * self.dis_scale if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('GANLoss/dis_scale', self.dis_scale) def _repulsive_mmd_g_bounded_(self): """ rmb loss :return: """ # calculate pairwise distance dist_gg, dist_gd, dist_dd = get_squared_dist( self.score_gen, self.score_data, z_score=False, do_summary=self.do_summary) self.loss_gen, self.loss_dis = mmd_g_bounded( dist_gg, dist_gd, dist_dd, self.batch_size, sigma=1.0, lower_bound=0.25, upper_bound=4.0, name='mmd_g', do_summary=self.do_summary, scope_prefix='', custom_weights=self.repulsive_weights) if self.dis_penalty is not None: self.loss_dis = self.loss_dis + self.dis_penalty if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('GANLoss/dis_penalty', self.dis_penalty) if self.dis_scale is not None: self.loss_dis = self.loss_dis * self.dis_scale if self.do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.scalar('GANLoss/dis_scale', self.dis_scale) def _test_(self): self.loss_dis = 0.0 self.loss_gen = 0.0 def __call__(self, score_gen, score_data, loss_type='logistic', **kwargs): """ This function calls one of the loss functions. :param score_gen: :param score_data: :param loss_type: :param kwargs: :return: """ # IO and hyperparameters self.score_gen = score_gen self.score_data = score_data if 'batch_size' in kwargs: self.batch_size = kwargs['batch_size'] if 'd' in kwargs: self.num_scores = kwargs['d'] if 'dis_penalty' in kwargs: self.dis_penalty = kwargs['dis_penalty'] if 'dis_scale' in kwargs: self.dis_scale = kwargs['dis_scale'] if 'sigma' in kwargs: self.sigma = kwargs['sigma'] if 'alpha' in kwargs: self.alpha = kwargs['alpha'] if 'beta' in kwargs: self.beta = kwargs['beta'] if 'omega' in kwargs: self.omega_range = kwargs['omega'] if 'ref_normal' in kwargs: self.ref_normal = kwargs['ref_normal'] if 'rep_weights' in kwargs: self.repulsive_weights = kwargs['rep_weights'] # check inputs if loss_type in {'fixed_g', 'mmd_g', 'fixed_t', 'mmd_t', 'mmd_g_mix', 'fixed_g_mix', 'rand_g', 'rand_g_mix', 'sym_rg_mix', 'instance_noise', 'ins_noise', 'sym_rg', 'rgb', 'rep', 'rep_gp', 'rmb', 'rmb_gp'}: assert self.batch_size is not None, 'GANLoss: batch_size must be provided' if loss_type in {'rand_g', 'rand_g_mix', 'sym_rg_mix', 'sym_rg'}: assert self.num_scores is not None, 'GANLoss: d must be provided' if loss_type in {'rep_gp', 'rmb_gp', 'wasserstein'}: assert self.dis_penalty is not None, 'Discriminator penalty must be provided.' if loss_type in {'rep_ds', 'rmb_ds'}: assert self.dis_scale is not None, 'Discriminator loss scale must be provided.' # loss if loss_type in {'logistic', ''}: self._logistic_() elif loss_type == 'hinge': self._hinge_() elif loss_type == 'wasserstein': self._wasserstein_() elif loss_type in {'fixed_g', 'mmd_g'}: self._mmd_g_() elif loss_type in {'mgb'}: self._mmd_g_bound_() elif loss_type in {'fixed_t', 'mmd_t'}: self._mmd_t_() elif loss_type in {'mmd_g_mix', 'fixed_g_mix'}: if 'mix_threshold' in kwargs: self._mmd_g_mix_(kwargs['mix_threshold']) else: self._mmd_g_mix_() elif loss_type in {'sgm'}: # single mmd-g mix if 'mix_threshold' in kwargs: self._single_mmd_g_mix_(kwargs['mix_threshold']) else: self._single_mmd_g_mix_() elif loss_type == 'rand_g': self._rand_g_() elif loss_type == 'rgb': self._rand_g_bounded_() elif loss_type == 'rand_g_mix': if 'mix_threshold' in kwargs: self._rand_g_mix_(kwargs['mix_threshold']) else: self._rand_g_mix_() elif loss_type == 'sym_rg_mix': if 'mix_threshold' in kwargs: self._sym_rg_mix_(kwargs['mix_threshold']) else: self._sym_rg_mix_() elif loss_type in {'sym_rg', 'sym_rand_g'}: self._sym_rand_g_() elif loss_type in {'instance_noise', 'ins_noise'}: if 'mix_threshold' in kwargs: self._rand_g_instance_noise_(kwargs['mix_threshold']) else: self._rand_g_instance_noise_() elif loss_type in {'rep', 'rep_mmd_g', 'rep_gp', 'rep_ds'}: self._repulsive_mmd_g_() elif loss_type in {'rmb', 'rep_b', 'rep_mmd_b', 'rmb_gp', 'rmb_ds'}: self._repulsive_mmd_g_bounded_() elif loss_type == 'test': self._test_() else: raise NotImplementedError('Not implemented.') self._add_summary_() return self.loss_gen, self.loss_dis def apply(self, score_gen, score_data, loss_type='logistic', **kwargs): return self.__call__(score_gen, score_data, loss_type=loss_type, **kwargs) def get_register(self): """ This function returns the registered tensor :return: """ # loss object always forgets self.debug_register after its value returned registered_info = self.debug_register self.debug_register = None return registered_info def sqrt_sym_mat_np(mat, eps=None): """ This function calculates the square root of symmetric matrix :param mat: :param eps: :return: """ if eps is None: eps = FLAGS.EPSI u, s, vh = np.linalg.svd(mat) si = np.where(s < eps, 0.0, np.sqrt(s)) return np.matmul(np.matmul(u, np.diag(si)), vh) def trace_sqrt_product_np(cov1, cov2): """ This function calculates trace(sqrt(cov1 * cov2)) This code is inspired from: https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py :param cov1: :param cov2: :return: """ sqrt_cov1 = sqrt_sym_mat_np(cov1) cov_121 = np.matmul(np.matmul(sqrt_cov1, cov2), sqrt_cov1) return np.trace(sqrt_sym_mat_np(cov_121)) def sqrt_sym_mat_tf(mat, eps=None): """ This function calculates the square root of symmetric matrix :param mat: :param eps: :return: """ if eps is None: eps = FLAGS.EPSI s, u, v = tf.svd(mat) si = tf.where(tf.less(s, eps), s, tf.sqrt(s)) return tf.matmul(tf.matmul(u, tf.diag(si)), v, transpose_b=True) def trace_sqrt_product_tf(cov1, cov2): """ This function calculates trace(sqrt(cov1 * cov2)) This code is inspired from: https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py :param cov1: :param cov2: :return: """ sqrt_cov1 = sqrt_sym_mat_tf(cov1) cov_121 = tf.matmul(tf.matmul(sqrt_cov1, cov2), sqrt_cov1) return tf.trace(sqrt_sym_mat_tf(cov_121)) def jacobian(y, x, name='jacobian'): """ This function calculates the jacobian matrix: dy/dx and returns a list :param y: batch_size-by-d matrix :param x: batch_size-by-s tensor :param name: :return: """ with tf.name_scope(name): batch_size, d = y.get_shape().as_list() if d == 1: return tf.reshape(tf.gradients(y, x)[0], [batch_size, -1]) # b-by-s else: return tf.transpose( tf.stack( [tf.reshape(tf.gradients(y[:, i], x)[0], [batch_size, -1]) for i in range(d)], axis=0), # d-b-s perm=(1, 0, 2)) # b-d-s tensor def jacobian_squared_frobenius_norm(y, x, name='J_fnorm', do_summary=False): """ This function calculates the squared frobenious norm, e.g. sum of square of all elements in Jacobian matrix :param y: batch_size-by-d matrix :param x: batch_size-by-s tensor :param name: :param do_summary: :return: """ with tf.name_scope(name): batch_size, d = y.get_shape().as_list() # sfn - squared frobenious norm if d == 1: jaco_sfn = tf.reduce_sum(tf.square(tf.reshape(tf.gradients(y, x)[0], [batch_size, -1])), axis=1) else: jaco_sfn = tf.reduce_sum( tf.stack( [tf.reduce_sum( tf.square(tf.reshape(tf.gradients(y[:, i], x)[0], [batch_size, -1])), # b-vector axis=1) for i in range(d)], axis=0), # d-by-b axis=0) # b-vector if do_summary: with tf.name_scope(None): # return to root scope to avoid scope overlap tf.summary.histogram('Jaco_sfn', jaco_sfn) return jaco_sfn