import importlib import os import subprocess from PIL import Image import numpy as np import torch import torch.nn.functional as F import torch.utils.data from torchvision import datasets, transforms, models from advex_uar.common.loader import StridedImageFolder from advex_uar.eval.cifar10c import CIFAR10C from advex_uar.train.trainer import Metric, accuracy, correct IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] def norm_to_pil_image(img): img_new = torch.Tensor(img) for t, m, s in zip(img_new, IMAGENET_MEAN, IMAGENET_STD): t.mul_(s).add_(m) img_new.mul_(255) np_img = np.rollaxis(np.uint8(img_new.numpy()), 0, 3) return Image.fromarray(np_img, mode='RGB') class Accumulator(object): def __init__(self, name): self.name = name self.vals = [] def update(self, val): self.vals.append(val) @property def avg(self): total_sum = sum([torch.sum(v) for v in self.vals]) total_size = sum([v.size()[0] for v in self.vals]) return total_sum / total_size class BaseEvaluator(): def __init__(self, **kwargs): default_attr = dict( # eval options model=None, batch_size=32, stride=10, dataset_path=None, # val dir for imagenet, base dir for CIFAR-10-C nb_classes=None, # attack options attack=None, # Communication options fp16_allreduce=False, # Logging options logger=None) default_attr.update(kwargs) for k in default_attr: setattr(self, k, default_attr[k]) if self.dataset not in ['imagenet', 'imagenet-c', 'cifar-10', 'cifar-10-c']: raise NotImplementedError self.cuda = True if self.cuda: self.model.cuda() self.attack = self.attack() self._init_loaders() def _init_loaders(self): raise NotImplementedError def evaluate(self): self.model.eval() std_loss = Accumulator('std_loss') adv_loss = Accumulator('adv_loss') std_corr = Accumulator('std_corr') adv_corr = Accumulator('adv_corr') std_logits = Accumulator('std_logits') adv_logits = Accumulator('adv_logits') seen_classes = [] adv_images = Accumulator('adv_images') first_batch_images = Accumulator('first_batch_images') for batch_idx, (data, target) in enumerate(self.val_loader): if self.cuda: data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True) with torch.no_grad(): output = self.model(data) std_logits.update(output.cpu()) loss = F.cross_entropy(output, target, reduction='none').cpu() std_loss.update(loss) corr = correct(output, target) corr = corr.view(corr.size()[0]).cpu() std_corr.update(corr) rand_target = torch.randint( 0, self.nb_classes - 1, target.size(), dtype=target.dtype, device='cuda') rand_target = torch.remainder(target + rand_target + 1, self.nb_classes) data_adv = self.attack(self.model, data, rand_target, avoid_target=False, scale_eps=False) for idx in range(target.size()[0]): if target[idx].cpu() not in seen_classes: seen_classes.append(target[idx].cpu()) orig_image = norm_to_pil_image(data[idx].detach().cpu()) adv_image = norm_to_pil_image(data_adv[idx].detach().cpu()) adv_images.update((orig_image, adv_image, target[idx].cpu())) if batch_idx == 0: for idx in range(target.size()[0]): orig_image = norm_to_pil_image(data[idx].detach().cpu()) adv_image = norm_to_pil_image(data_adv[idx].detach().cpu()) first_batch_images.update((orig_image, adv_image)) with torch.no_grad(): output_adv = self.model(data_adv) adv_logits.update(output_adv.cpu()) loss = F.cross_entropy(output_adv, target, reduction='none').cpu() adv_loss.update(loss) corr = correct(output_adv, target) corr = corr.view(corr.size()[0]).cpu() adv_corr.update(corr) run_output = {'std_loss':std_loss.avg, 'std_acc':std_corr.avg, 'adv_loss':adv_loss.avg, 'adv_acc':adv_corr.avg} print('Batch', batch_idx) print(run_output) if batch_idx % 20 == 0: self.logger.log(run_output, batch_idx) summary_dict = {'std_acc':std_corr.avg.item(), 'adv_acc':adv_corr.avg.item()} self.logger.log_summary(summary_dict) for orig_img, adv_img, target in adv_images.vals: self.logger.log_image(orig_img, 'orig_{}.png'.format(target)) self.logger.log_image(adv_img, 'adv_{}.png'.format(target)) for idx, imgs in enumerate(first_batch_images.vals): orig_img, adv_img = imgs self.logger.log_image(orig_img, 'init_orig_{}.png'.format(idx)) self.logger.log_image(adv_img, 'init_adv_{}.png'.format(idx)) self.logger.end() print(std_loss.avg, std_corr.avg, adv_loss.avg, adv_corr.avg) class CIFAR10Evaluator(BaseEvaluator): def _init_loaders(self): normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) self.val_dataset = datasets.CIFAR10( root='./', download=True, train=False, transform=transforms.Compose([ transforms.ToTensor(), normalize,])) self.val_loader = torch.utils.data.DataLoader( self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=8, pin_memory=True) class ImagenetEvaluator(BaseEvaluator): def _init_loaders(self): valdir = self.dataset_path normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) self.val_dataset = StridedImageFolder( valdir, transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize,]), stride=self.stride) self.val_sampler = torch.utils.data.SequentialSampler(self.val_dataset) self.val_loader = torch.utils.data.DataLoader( self.val_dataset, batch_size=self.batch_size, sampler=self.val_sampler, num_workers=1, pin_memory=True, shuffle=False) class ImagenetCEvaluator(BaseEvaluator): def __init__(self, corruption_type=None, corruption_name=None, corruption_level=None, **kwargs): self.corruption_type = corruption_type self.corruption_name = corruption_name self.corruption_level = corruption_level super().__init__(**kwargs) def _init_loaders(self): valdir = os.path.join(self.dataset_path, 'imagenet-c', self.corruption_type, self.corruption_name, self.corruption_level) normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) self.val_dataset = StridedImageFolder( valdir, transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize,]), stride=self.stride) self.val_sampler = torch.utils.data.SequentialSampler(self.val_dataset) self.val_loader = torch.utils.data.DataLoader( self.val_dataset, batch_size=self.batch_size, sampler=self.val_sampler, num_workers=1, pin_memory=True, shuffle=False) class CIFAR10CEvaluator(BaseEvaluator): def __init__(self, corruption_type=None, corruption_name=None, corruption_level=None, **kwargs): self.corruption_type = corruption_type self.corruption_name = corruption_name self.corruption_level = corruption_level super().__init__(**kwargs) def _init_loaders(self): valdir = os.path.join(self.dataset_path, 'CIFAR-10-C') transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) self.val_dataset = CIFAR10C(valdir, transform=transform, corruption_name=self.corruption_name, corruption_level=self.corruption_level) self.val_sampler = torch.utils.data.SequentialSampler(self.val_dataset) self.val_loader = torch.utils.data.DataLoader( self.val_dataset, batch_size=self.batch_size, sampler=self.val_sampler, num_workers=1, pin_memory=True, shuffle=False)