#!/usr/bin/env python
# -*- coding: utf-8 -*-
# References: J. Buckman, et al., "Thermometer Encoding: One hot way to resist adversarial examples," in ICLR, 2018.
# Reference Implementation (TensorFlow): https://github.com/anishathalye/obfuscated-gradients/tree/master/thermometer
# **************************************
# @Time    : 2018/11/22 18:30
# @Author  : Jiaxu Zou
# @Lab     : nesa.zju.edu.cn
# @File    : TE.py
# **************************************

import os

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

from Defenses.DefenseMethods.defenses import Defense
from RawModels.ResNet import adjust_learning_rate


def one_hot_encoding(samples=None, level=None, device=None):
    """
    helper function that encodes the samples using the one-hot encoding scheme

    :param samples: a PyTorch tensor of images with values in [0, 1], in either NCHW or NHWC layout
    :param level: the number of discretization levels
    :param device: the device on which the encoding is computed
    :return: a float tensor of shape [BatchSize, Height, Width, Channel, Level]
    """
    assert level is not None and isinstance(level, int), 'level should be specified as an integer'
    assert torch.is_tensor(samples), "input samples must be a PyTorch tensor"

    # convert NCHW inputs to NHWC so that the level dimension can be appended last
    if len(samples.shape) >= 4 and (samples.shape[1] == 1 or samples.shape[1] == 3):
        samples = samples.permute(0, 2, 3, 1)

    # discretize each pixel into one of `level` buckets and insert the last dimension
    # (the 0.99999 factor keeps a value of exactly 1.0 inside the top bucket)
    discretized_samples = torch.unsqueeze(input=(0.99999 * samples * level).long().to(device), dim=4)

    # make the last dim be the level number
    shape = discretized_samples.shape

    # scatter 1s along the last (level) dimension to obtain the one-hot encoding
    one_hot_samples = torch.zeros([shape[0], shape[1], shape[2], shape[3], level]).to(device).scatter_(-1, discretized_samples, 1)
    one_hot_samples = one_hot_samples.float()
    return one_hot_samples


def thermometer_encoding(samples=None, level=None, device=None):
    """
    helper function that encodes the samples using the thermometer encoding scheme

    :param samples: a PyTorch tensor of images with values in [0, 1], in either NCHW or NHWC layout
    :param level: the number of discretization levels
    :param device: the device on which the encoding is computed
    :return: a numpy array of shape [BatchSize, Channel * Level, Height, Width]
    """
    assert level is not None and isinstance(level, int), 'level should be specified as an integer'
    assert torch.is_tensor(samples), "input samples must be a PyTorch tensor"

    # convert NCHW inputs to NHWC, matching the layout used by one_hot_encoding
    if len(samples.shape) >= 4 and (samples.shape[1] == 1 or samples.shape[1] == 3):
        samples = samples.permute(0, 2, 3, 1)

    # the thermometer code is the cumulative sum of the one-hot code along the level dimension
    one_hot_samples = one_hot_encoding(samples=samples, level=level, device=device)
    therm_samples = torch.cumsum(one_hot_samples, dim=-1)

    # fold the level dimension into the channel dimension; the returned samples are a
    # numpy array of shape [BatchSize, Channel * Level, Height, Width]
    shape = samples.shape
    therm_samples = torch.reshape(therm_samples, (shape[0], shape[1], shape[2], shape[3] * level))
    therm_samples_numpy = therm_samples.permute(0, 3, 1, 2).cpu().numpy()
    return therm_samples_numpy
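
# A minimal usage sketch of the two encodings above (illustrative only, not part
# of the original pipeline); the batch shape and level are hypothetical example
# values. For instance, a pixel value of 0.6 with level=4 falls into bucket
# floor(0.99999 * 0.6 * 4) = 2, so its one-hot code is [0, 0, 1, 0] and its
# thermometer code (the cumulative sum) is [0, 0, 1, 1].
def _demo_encodings():
    device = torch.device('cpu')
    images = torch.rand(2, 3, 32, 32)  # a CIFAR10-like NCHW batch with values in [0, 1]

    one_hot = one_hot_encoding(samples=images, level=4, device=device)
    assert list(one_hot.shape) == [2, 32, 32, 3, 4]  # the level dim is appended last

    therm = thermometer_encoding(samples=images, level=4, device=device)
    assert list(therm.shape) == [2, 3 * 4, 32, 32]  # the level dim is folded into the channels
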

class TEDefense(Defense):

    def __init__(self, model=None, defense_name=None, dataset=None, training_parameters=None, device=None, **kwargs):
        """
        :param model:
        :param defense_name:
        :param dataset:
        :param training_parameters:
        :param device:
        :param kwargs:
        """
        super(TEDefense, self).__init__(model=model, defense_name=defense_name)

        self.model = model
        self.defense_name = defense_name
        self.device = device

        self.Dataset = dataset.upper()
        assert self.Dataset in ['MNIST', 'CIFAR10'], "The dataset must be MNIST or CIFAR10"

        # make sure to parse the parameters for the defense
        assert self._parsing_parameters(**kwargs)

        # get the training_parameters, the same as the settings of RawModels
        self.num_epochs = training_parameters['num_epochs']
        self.batch_size = training_parameters['batch_size']

        # prepare the optimizers
        if self.Dataset == 'MNIST':
            self.optimizer = optim.SGD(self.model.parameters(), lr=training_parameters['learning_rate'], momentum=training_parameters['momentum'],
                                       weight_decay=training_parameters['decay'], nesterov=True)
        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=training_parameters['lr'])

    def _parsing_parameters(self, **kwargs):
        """
        parse and set the defense-specific parameters

        :param kwargs: level, steps, attack_eps and attack_step_size
        :return: True once all parameters have been set
        """
        assert kwargs is not None, "the parameters should be specified"

        print("\nUser configurations for the {} defense".format(self.defense_name))
        for key in kwargs:
            print('\t{} = {}'.format(key, kwargs[key]))

        self.level = kwargs['level']
        self.steps = kwargs['steps']
        self.attack_eps = kwargs['attack_eps']
        self.attack_step_size = kwargs['attack_step_size']

        return True

    def lspga_generation(self, samples=None, ys=None, noise_init=True):
        """
        one type of white-box attack on discretized inputs (thermometer encoding) -- Logit-Space Projected Gradient Ascent (LS-PGA);
        the detailed pseudo-code for the LS-PGA attack is described in Algorithm 3 of the referenced paper

        :param samples: natural samples in NCHW layout with values in [0, 1]
        :param ys: true labels; if None, the model's predictions at the first step are used instead
        :param noise_init: whether to initialize the logits with standard normal noise (True) or with zeros (False)
        :return: thermometer-encoded adversarial examples as a numpy array
        """
        # STEP 1: sub-routine for getting an epsilon-discretized mask of an image
        lowest = torch.clamp(samples - self.attack_eps, 0.0, 1.0)
        highest = torch.clamp(samples + self.attack_eps, 0.0, 1.0)

        # mark every discretization bucket that is reachable within [lowest, highest]
        masked_intervals = 0.0
        for alpha in np.linspace(0., 1., self.level):
            single_one_hot = one_hot_encoding(samples=alpha * lowest + (1. - alpha) * highest, level=self.level, device=self.device)
            masked_intervals += single_one_hot
        masked = (masked_intervals > 0.0).float()
        shape = masked.shape

        # STEP 2: main loop for generating adversarial examples using LS-PGA
        # initialize the logits either with values sampled from a standard normal distribution or with zeros
        if noise_init is True:
            us_numpy = torch.randn(shape).cpu().numpy()
        else:
            us_numpy = torch.zeros_like(masked).cpu().numpy()

        inv_temp = 1.0
        sigma_rate = 1.2

        self.model.eval()
        for i in range(self.steps):
            us_logits = torch.from_numpy(us_numpy).to(self.device).float()
            us_logits.requires_grad = True

            # push the logits of unreachable buckets (mask equal to 0) down to -999999, then embed
            # the remaining logits with a temperature-controlled softmax
            us_probs = F.softmax(inv_temp * (us_logits * masked - 999999.0 * (1. - masked)), dim=-1)

            # apply the cumulative sum and reshape to obtain the relaxed thermometer embedding
            thermometer_probs = torch.cumsum(us_probs, dim=-1)
            thermometer_probs = torch.reshape(thermometer_probs, (shape[0], shape[1], shape[2], shape[3] * self.level))

            # move the channel dimension back to the second position
            thermometer_probs = thermometer_probs.permute(0, 3, 1, 2)
            logits = self.model(thermometer_probs)

            # if no labels are supplied, use the model's predictions at the first step
            if ys is None and i == 0:
                ys = torch.argmax(logits, dim=1)

            loss = F.cross_entropy(logits, ys)
            gradients = torch.autograd.grad(loss, us_logits)[0]
            signed_gradient = torch.sign(gradients).cpu().numpy()

            # ascend the loss in logit space and anneal the temperature via exponential decay with rate sigma_rate
            us_numpy += self.attack_step_size * signed_gradient
            inv_temp *= sigma_rate

        # commit to the reachable bucket with the largest final logit
        us_logits = torch.from_numpy(us_numpy).to(self.device).float()
        logits_results = us_logits * masked - 999999.0 * (1. - masked)
        logits_final = torch.argmax(logits_results, dim=-1, keepdim=True)
        one_hot_adv_samples = torch.zeros([shape[0], shape[1], shape[2], shape[3], self.level]).to(self.device).scatter_(-1, logits_final, 1)
        one_hot_adv_samples = one_hot_adv_samples.float()

        # the returned samples are a numpy array of shape [BatchSize, Channel * Level, Height, Width]
        therm_adv_samples = torch.cumsum(one_hot_adv_samples, dim=-1)
        final_adv_samples = torch.reshape(therm_adv_samples, (shape[0], shape[1], shape[2], shape[3] * self.level))
        final_adv_samples_numpy = final_adv_samples.permute(0, 3, 1, 2).cpu().numpy()
        return final_adv_samples_numpy
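
    # A small numeric illustration of the masked-softmax projection inside lspga_generation above
    # (hypothetical values, not taken from the paper): with level = 4 and a pixel whose reachable
    # buckets give the mask [1, 1, 0, 0], logits of [0.5, 1.5, 0.3, -0.2] are first pushed to
    # [0.5, 1.5, -999999, -999999], and the softmax at inv_temp = 1.0 yields roughly
    # [0.27, 0.73, 0.0, 0.0]. Since inv_temp is multiplied by sigma_rate = 1.2 at every step,
    # the softmax sharpens over the iterations, so the relaxed thermometer embedding converges
    # towards the discrete one-hot choice that the final argmax commits to.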
    def train_one_epoch_with_adv_lspga(self, train_loader=None, epoch=None, weight_regular=None):
        """
        train the model for one epoch on both thermometer-encoded natural samples and LS-PGA adversarial examples

        :param train_loader:
        :param epoch:
        :param weight_regular: unused here; kept for interface compatibility
        :return:
        """
        for index, (images, labels) in enumerate(train_loader):
            nat_images_numpy = thermometer_encoding(samples=images.to(self.device), level=self.level, device=self.device)
            nat_labels = labels.to(self.device)

            # prepare the LS-PGA perturbation
            self.model.eval()
            adv_images_numpy = self.lspga_generation(samples=images.to(self.device))

            # concatenate the natural samples and the adversarial examples
            batch_images_numpy = np.concatenate((nat_images_numpy, adv_images_numpy), axis=0)
            batch_images = torch.from_numpy(batch_images_numpy).to(self.device)
            # concatenate the true labels (each label applies to both versions of its sample)
            batch_labels = torch.cat((nat_labels, nat_labels), dim=0)

            # set the model in training mode
            self.model.train()

            # forward the nn
            logits = self.model(batch_images)
            loss = F.cross_entropy(logits, batch_labels)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            print('\rTrain Epoch {:>3}: [batch:{:>4}/{:>4}]  \tLoss={:.4f} ===> '.format(epoch, index, len(train_loader), loss.item()), end=' ')

    def thermometer_validation_evaluation(self, validation_loader, device):
        """
        validation evaluation with a slight modification for thermometer-encoded input samples

        :param validation_loader:
        :param device:
        :return: accuracy on the validation set
        """
        self.model.eval()

        total = 0.0
        correct = 0.0
        with torch.no_grad():
            for index, (inputs, labels) in enumerate(validation_loader):
                therm_inputs = thermometer_encoding(samples=inputs.to(device), level=self.level, device=device)
                therm_inputs = torch.from_numpy(therm_inputs).to(device)
                labels = labels.to(device)

                outputs = self.model(therm_inputs)
                _, predicted = torch.max(outputs.data, 1)
                total = total + labels.size(0)
                correct = correct + (predicted == labels).sum().item()
            ratio = correct / total
        print('validation set accuracy is {:.4f}'.format(ratio))
        return ratio

    def defense(self, train_loader=None, validation_loader=None):

        best_val_acc = None
        for epoch in range(self.num_epochs):

            # train the model with natural examples and the corresponding adversarial examples
            self.train_one_epoch_with_adv_lspga(train_loader=train_loader, epoch=epoch, weight_regular=1e-4)
            val_acc = self.thermometer_validation_evaluation(validation_loader=validation_loader, device=self.device)

            # adjust the learning rate for CIFAR10 training
            if self.Dataset == 'CIFAR10':
                adjust_learning_rate(optimizer=self.optimizer, epoch=epoch)

            # save the best defense-enhanced model so far
            assert os.path.exists('../DefenseEnhancedModels/{}'.format(self.defense_name))
            defense_enhanced_saver = '../DefenseEnhancedModels/{}/{}_{}_enhanced.pt'.format(self.defense_name, self.Dataset, self.defense_name)
            if not best_val_acc or round(val_acc, 4) >= round(best_val_acc, 4):
                if best_val_acc is not None:
                    os.remove(defense_enhanced_saver)
                best_val_acc = val_acc
                self.model.save(name=defense_enhanced_saver)
            else:
                print('Train Epoch {:>3}: validation dataset accuracy did not improve from {:.4f}\n'.format(epoch, best_val_acc))
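
# A minimal end-to-end usage sketch, kept as comments because it is not part of the original
# module. The model, the parameter values, and the data loaders below are hypothetical
# placeholders, not settings taken from this repository. Note that the model's first layer
# must accept Channel * level input channels, since every input it sees is thermometer-encoded.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = ...  # a network whose input layer expects 1 * level channels for MNIST
#   training_parameters = {'num_epochs': 50, 'batch_size': 100, 'learning_rate': 0.05,
#                          'momentum': 0.9, 'decay': 1e-6}
#   te_params = {'level': 16, 'steps': 40, 'attack_eps': 0.3, 'attack_step_size': 0.01}
#   te = TEDefense(model=model, defense_name='TE', dataset='MNIST',
#                  training_parameters=training_parameters, device=device, **te_params)
#   te.defense(train_loader=train_loader, validation_loader=validation_loader)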