from torchvision import datasets import argparse, os import numpy as np parser = argparse.ArgumentParser() parser.add_argument("--seed", "-s", default=1, type=int, help="random seed") parser.add_argument("--dataset", "-d", default="svhn", type=str, help="dataset name : [svhn, cifar10]") parser.add_argument("--nlabels", "-n", default=1000, type=int, help="the number of labeled data") args = parser.parse_args() COUNTS = { "svhn": {"train": 73257, "test": 26032, "valid": 7326, "extra": 531131}, "cifar10": {"train": 50000, "test": 10000, "valid": 5000, "extra": 0}, "imagenet_32": { "train": 1281167, "test": 50000, "valid": 50050, "extra": 0, }, } _DATA_DIR = "./data" def split_l_u(train_set, n_labels): # NOTE: this function assume that train_set is shuffled. images = train_set["images"] labels = train_set["labels"] classes = np.unique(labels) n_labels_per_cls = n_labels // len(classes) l_images = [] l_labels = [] u_images = [] u_labels = [] for c in classes: cls_mask = (labels == c) c_images = images[cls_mask] c_labels = labels[cls_mask] l_images += [c_images[:n_labels_per_cls]] l_labels += [c_labels[:n_labels_per_cls]] u_images += [c_images[n_labels_per_cls:]] u_labels += [np.zeros_like(c_labels[n_labels_per_cls:]) - 1] # dammy label l_train_set = {"images": np.concatenate(l_images, 0), "labels": np.concatenate(l_labels, 0)} u_train_set = {"images": np.concatenate(u_images, 0), "labels": np.concatenate(u_labels, 0)} return l_train_set, u_train_set def _load_svhn(): splits = {} for split in ["train", "test", "extra"]: tv_data = datasets.SVHN(_DATA_DIR, split, download=True) data = {} data["images"] = tv_data.data data["labels"] = tv_data.labels splits[split] = data return splits.values() def _load_cifar10(): splits = {} for train in [True, False]: tv_data = datasets.CIFAR10(_DATA_DIR, train, download=True) data = {} data["images"] = tv_data.data data["labels"] = np.array(tv_data.targets) splits["train" if train else "test"] = data return splits.values() def gcn(images, multiplier=55, eps=1e-10): # global contrast normalization images = images.astype(np.float) images -= images.mean(axis=(1,2,3), keepdims=True) per_image_norm = np.sqrt(np.square(images).sum((1,2,3), keepdims=True)) per_image_norm[per_image_norm < eps] = 1 return multiplier * images / per_image_norm def get_zca_normalization_param(images, scale=0.1, eps=1e-10): n_data, height, width, channels = images.shape images = images.reshape(n_data, height*width*channels) image_cov = np.cov(images, rowvar=False) U, S, _ = np.linalg.svd(image_cov + scale * np.eye(image_cov.shape[0])) zca_decomp = np.dot(U, np.dot(np.diag(1/np.sqrt(S + eps)), U.T)) mean = images.mean(axis=0) return mean, zca_decomp def zca_normalization(images, mean, decomp): n_data, height, width, channels = images.shape images = images.reshape(n_data, -1) images = np.dot((images - mean), decomp) return images.reshape(n_data, height, width, channels) rng = np.random.RandomState(args.seed) validation_count = COUNTS[args.dataset]["valid"] extra_set = None # In general, there won't be extra data. if args.dataset == "svhn": train_set, test_set, extra_set = _load_svhn() elif args.dataset == "cifar10": train_set, test_set = _load_cifar10() train_set["images"] = gcn(train_set["images"]) test_set["images"] = gcn(test_set["images"]) mean, zca_decomp = get_zca_normalization_param(train_set["images"]) train_set["images"] = zca_normalization(train_set["images"], mean, zca_decomp) test_set["images"] = zca_normalization(test_set["images"], mean, zca_decomp) # N x H x W x C -> N x C x H x W train_set["images"] = np.transpose(train_set["images"], (0,3,1,2)) test_set["images"] = np.transpose(test_set["images"], (0,3,1,2)) # permute index of training set indices = rng.permutation(len(train_set["images"])) train_set["images"] = train_set["images"][indices] train_set["labels"] = train_set["labels"][indices] if extra_set is not None: extra_indices = rng.permutation(len(extra_set["images"])) extra_set["images"] = extra_set["images"][extra_indices] extra_set["labels"] = extra_set["labels"][extra_indices] # split training set into training and validation train_images = train_set["images"][validation_count:] train_labels = train_set["labels"][validation_count:] validation_images = train_set["images"][:validation_count] validation_labels = train_set["labels"][:validation_count] validation_set = {"images": validation_images, "labels": validation_labels} train_set = {"images": train_images, "labels": train_labels} # split training set into labeled data and unlabeled data l_train_set, u_train_set = split_l_u(train_set, args.nlabels) if not os.path.exists(os.path.join(_DATA_DIR, args.dataset)): os.mkdir(os.path.join(_DATA_DIR, args.dataset)) np.save(os.path.join(_DATA_DIR, args.dataset, "l_train"), l_train_set) np.save(os.path.join(_DATA_DIR, args.dataset, "u_train"), u_train_set) np.save(os.path.join(_DATA_DIR, args.dataset, "val"), validation_set) np.save(os.path.join(_DATA_DIR, args.dataset, "test"), test_set) if extra_set is not None: np.save(os.path.join(_DATA_DIR, args.dataset, "extra"), extra_set)