from __future__ import print_function

import argparse
import codecs
import errno
import os
import os.path
import pickle
import random
# import matplotlib.pyplot as plt
# import sklearn as skl
from collections import defaultdict

import numpy as np
import scipy as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from PIL import Image
from torch.autograd import Variable
from torchvision import datasets, transforms

import util


class RandomCIFAR10(data.Dataset):
    """Dataset of random CIFAR-10-shaped noise images with random labels.

    Every access draws a fresh ``torch.randn(3, 32, 32)`` image and a uniform
    random integer label in ``[0, 9]``, so the same index yields different
    data on each call.

    :param data_dir: Unused. Presumably accepted so the constructor matches
        other dataset classes (confirm with callers).
    :param train: Unused.
    :param transform: Stored but never applied.
    :param target_transform: Stored but never applied.
    :param download: Unused.
    :param length: Value reported by ``len()``; defaults to 1000.
    """

    def __init__(self, data_dir, train=False, transform=None,
                 target_transform=None, download=False, length=1000):
        self.transform = transform
        self.target_transform = target_transform
        self.length = length

    def __getitem__(self, index):
        """Return a fresh random ``(image, label)`` pair; *index* is ignored."""
        img = torch.randn(3, 32, 32)
        target = random.randint(0, 9)
        # transform / target_transform were left disabled (commented out) in
        # the original code, presumably because they expect PIL input rather
        # than a tensor. TODO confirm before enabling them.
        return img, target

    def __len__(self):
        return self.length


class Adv(data.Dataset):
    """Map-style dataset over pre-computed adversarial examples in a pickle.

    The pickle is expected to contain a dict with:

    * ``"adv_input"``: an array whose first axis indexes samples;
    * ``"adv_labels"``: a 2-D array reduced to integer labels with
      ``argmax(axis=1)`` (presumably one-hot or score rows; confirm).
    """

    def __init__(self, transform=None, target_transform=None,
                 filename="adv_set_e_2.p", transp=False, length=14):
        """
        :param transform: Optional callable applied to each image tensor.
        :param target_transform: Optional callable applied to each label.
        :param filename: Path of the pickle file to load.
        :param transp: Treated as a boolean. If false, each sample is
            reshaped to ``(3, 32, 32)``. If true, it is used as stored and
            cast to a float32 tensor. Set ``transp=False`` for PGD-based
            attacks.
        :param length: Maximum number of samples exposed through ``len()``.
            Defaults to 14 (the previous hard-coded size) and is clamped to
            the number of samples in the file. ``None`` exposes every sample.
        """
        self.transform = transform
        self.target_transform = target_transform
        # SECURITY: pickle.load can execute arbitrary code. Only load files
        # from a trusted source.
        with open(filename, "rb") as fh:
            self.adv_dict = pickle.load(fh)
        self.adv_flat = self.adv_dict["adv_input"]
        self.num_adv = np.shape(self.adv_flat)[0]
        self.shuff = transp
        # Decode the labels once, instead of re-running argmax over the whole
        # label array on every __getitem__ call.
        self.adv_targets = np.argmax(self.adv_dict["adv_labels"], axis=1)
        self.length = length
        # Retained for backward compatibility: counts items fetched so far.
        # It is no longer used to choose which sample is returned.
        self.sample_num = 0

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample *index*.

        The sample is selected by *index*, not by an internal counter. This
        keeps shuffling samplers, repeated epochs and multi-worker
        DataLoaders correct.

        :raises IndexError: if ``index >= len(self)``. Raising here also lets
            plain ``for x in dataset`` iteration terminate.
        """
        size = len(self)
        if index >= size:
            raise IndexError(
                "index %d out of range for dataset of size %d" % (index, size))
        img = self.adv_flat[index, :]
        if not self.shuff:
            # Reshape the flat sample into a (3, 32, 32) tensor. NOTE(review):
            # unlike the other branch, the dtype is left as stored (no float32
            # cast). Confirm this asymmetry is intended.
            img = torch.from_numpy(np.reshape(img, (3, 32, 32)))
        else:
            img = torch.from_numpy(img).type(torch.FloatTensor)
        target = self.adv_targets[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        self.sample_num += 1
        return img, target

    def __len__(self):
        if self.length is None:
            return self.num_adv
        return min(self.length, self.num_adv)