from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import itertools
import os
import sys
import time
import argparse
import datetime
import pickle
import numpy as np
import math
import random

sys.path.append('wide-resnet.pytorch')
import config as cf

from sklearn.metrics import confusion_matrix
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler

sys.path.append('../../')
import dl2lib as dl2

from resnet import ResNet18
from vgg import VGG

parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
parser = dl2.add_default_parser_args(parser)
parser.add_argument('--seed', default=42, type=int, help='Random seed to use.')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--net_type', default='wide-resnet', type=str, help='model')
parser.add_argument('--depth', default=28, type=int, help='depth of model')
parser.add_argument('--epochs', default=100, type=int, help='number of training epochs')
parser.add_argument('--growing', default=0, type=int, help='growing schedule flag (0/1; unused in this script)')
parser.add_argument('--widen_factor', default=10, type=int, help='width of model')
parser.add_argument('--dropout', default=0.3, type=float, help='dropout rate')
parser.add_argument('--dataset', default='cifar100', type=str, help='dataset = [cifar10/cifar100]')
parser.add_argument('--exp_name', default='', type=str, help='experiment name')
parser.add_argument('--resume_from', type=str, default=None, help='resume from checkpoint')
parser.add_argument('--testOnly', action='store_true', help='Test mode with the saved model')
parser.add_argument('--constraint', type=str, choices=['DL2', 'none'], default='none',
                    help='constraint system to use')
parser.add_argument('--constraint-weight', '--constraint_weight', type=float, default=0.6,
                    help='weight for constraint loss')
parser.add_argument('--num_labeled', default=1000, type=int,
                    help='Number of labeled examples (per class!).')
parser.add_argument('--skip_labled', default=0, type=int,
                    help='skip the labeled cross-entropy loss and train on constraints only (0/1)')
parser.add_argument('--decrease-eps-weight', default=1.0, type=float,
                    help='multiplicative per-epoch decay of the constraint slack eps')
parser.add_argument('--c-eps', default=0.05, type=float,
                    help='initial slack eps for the group constraint')
parser.add_argument('--increase-constraint-weight', default=1.0, type=float,
                    help='multiplicative per-epoch growth of the constraint weight')
args = parser.parse_args()
args.growing = bool(args.growing)
args.skip_labled = bool(args.skip_labled)

# Download CIFAR-100 once so the meta file with the label names is available.
torchvision.datasets.CIFAR100(root='../../data/cifar100', train=True, download=True)
with open('../../data/cifar100/cifar-100-python/meta', 'rb') as f:
    meta = pickle.load(f)  # on Python 3, pickle.load(f, encoding='latin1') may be needed
coarse = meta['coarse_label_names']
fine = meta['fine_label_names']
label_idx = {label: i for i, label in enumerate(fine)}
group_idx = {label: i for i, label in enumerate(coarse)}

# g maps a coarse group index to its five fine label indices; group maps a
# fine label index to its coarse group; pairs holds all ordered same-group pairs.
g = {}
group = [0 for i in range(100)]
pairs = []
print(group_idx)
with open('groups.txt') as f:
    for line in f:
        tokens = line.rstrip('\n').split('\t')
        large_group = tokens[0]
        tokens[1] = tokens[1].replace(',', '').strip()
        labels = tokens[1].split(' ')
        assert len(labels) == 5, labels
        for label in labels:
            assert label in fine, label
            group[label_idx[label]] = group_idx[large_group]
        g[group_idx[large_group]] = [label_idx[label] for label in labels]
        for x in labels:
            for y in labels:
                if x != y:
                    pairs.append((label_idx[x], label_idx[y]))
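
# Hedged sanity check (illustrative addition, never called by the script): the
# parsed superclass structure should partition CIFAR-100's 100 fine labels
# into 20 coarse groups of 5, giving 5*4 = 20 ordered pairs per group.
def _check_superclass_partition():
    assert len(g) == 20 and all(len(v) == 5 for v in g.values())
    assert sorted(lbl for v in g.values() for lbl in v) == list(range(100))
    assert len(pairs) == 20 * 5 * 4  # 400 ordered same-group pairs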
# Hyperparameter settings
use_cuda = torch.cuda.is_available()
print(use_cuda)
best_acc = 0
best_model = None
start_epoch, num_epochs, batch_size, optim_type = cf.start_epoch, cf.num_epochs, cf.batch_size, cf.optim_type
num_epochs = args.epochs

# Data upload
print('\n[Phase 1] : Data Preparation')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
])  # meanstd transformation

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
])

if args.dataset == 'cifar10':
    print("| Preparing CIFAR-10 dataset...")
    sys.stdout.write("| ")
    trainset = torchvision.datasets.CIFAR10(root='../../data/cifar10', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='../../data/cifar10', train=False, download=False, transform=transform_test)
    num_classes = 10
elif args.dataset == 'cifar100':
    print("| Preparing CIFAR-100 dataset...")
    sys.stdout.write("| ")
    trainset = torchvision.datasets.CIFAR100(root='../../data/cifar100', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR100(root='../../data/cifar100', train=False, download=False, transform=transform_test)
    num_classes = 100

# Split the train set per class into labeled, unlabeled and validation indices.
num_train = len(trainset)
per_class = [[] for _ in range(100)]
for i in range(num_train):
    per_class[trainset[i][1]].append(i)

train_lab_idx = []
train_unlab_idx = []
valid_idx = []
np.random.seed(args.seed)
torch.manual_seed(args.seed)
for i in range(100):
    np.random.shuffle(per_class[i])
    # 20% of each class goes to validation; the next num_labeled examples
    # stay labeled and the remainder is treated as unlabeled.
    split = int(np.floor(0.2 * len(per_class[i])))
    train_lab_idx += per_class[i][split:split + args.num_labeled]
    train_unlab_idx += per_class[i][split + args.num_labeled:]
    valid_idx += per_class[i][:split]
print('Total train[labeled]: ', len(train_lab_idx))
print('Total train[unlabeled]: ', len(train_unlab_idx))
print('Total valid: ', len(valid_idx))

train_labeled_sampler = SubsetRandomSampler(train_lab_idx)
train_unlabeled_sampler = SubsetRandomSampler(train_unlab_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
unlab_batch = batch_size if args.constraint != 'none' else 1

trainloader_lab = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, sampler=train_labeled_sampler, num_workers=2)
trainloader_unlab = torch.utils.data.DataLoader(
    trainset, batch_size=unlab_batch, sampler=train_unlabeled_sampler, num_workers=2)
validloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, sampler=valid_sampler, num_workers=2)


def getNetwork(args):
    if args.net_type == 'resnet':
        net = ResNet18(100)
        file_name = 'resnet18'
    elif args.net_type == 'vgg':
        net = VGG('VGG16', 100)
        file_name = 'vgg'
    else:
        assert False, 'unsupported net_type: ' + args.net_type
    file_name += '_' + str(args.seed) + '_' + args.exp_name
    return net, file_name
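
# Example invocations (flags as defined by the parser above; the script name
# is a placeholder for this file):
#   python main.py --dataset cifar100 --net_type resnet --constraint DL2 \
#       --constraint-weight 0.6 --num_labeled 1000 --epochs 100
#   python main.py --net_type resnet --resume_from <checkpoint name> --testOnly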
# Test only option
if args.testOnly:
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
    print('\n[Test Phase] : Model setup')
    assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
    _, file_name = getNetwork(args)
    checkpoint = torch.load('./checkpoint/' + args.resume_from + '.t7')
    net = checkpoint['net']
    if use_cuda:
        net.cuda()
        net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
        cudnn.benchmark = True

    net.eval()
    test_loss = 0
    correct = 0
    constraint_correct = 0
    total = 0
    conf_mat = np.zeros((100, 100))
    group_ok = 0
    np.set_printoptions(threshold=np.inf)
    softmax = torch.nn.Softmax(dim=1)
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        with torch.no_grad():
            outputs = net(inputs)
            probs = softmax(outputs)
        # The constraint is satisfied if, for every coarse group, the
        # probability mass assigned to the group is within eps of 0 or 1.
        eps = 0.05
        dl2_one_group = []
        for i in range(20):
            gsum = 0
            for j in g[i]:
                gsum += probs[:, j]
            dl2_one_group.append(dl2.Or([dl2.GT(gsum, 1.0 - eps), dl2.LT(gsum, eps)]))
        constraint = dl2.And(dl2_one_group)
        constraint_correct += constraint.satisfy(args).sum()

        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()
        conf_mat += confusion_matrix(targets.data.cpu().numpy(), predicted.cpu().numpy(), labels=np.arange(100))
        n_batch = predicted.size(0)
        for i in range(n_batch):
            # Count predictions that land in the correct coarse group.
            if group[predicted[i].item()] == group[targets[i].item()]:
                group_ok += 1

    # print('Confusion matrix:')
    # print(conf_mat)
    acc = 100.0 * float(correct) / total
    c_acc = 100.0 * float(constraint_correct) / total
    group_acc = 100.0 * float(group_ok) / total
    print("| Test Result\tAcc@1: %.2f%%" % (acc))
    print("| Test Result\tCAcc: %.2f%%" % (c_acc))
    print("| Test Result\tGroupAcc: %.2f%%" % (group_acc))
    sys.exit(0)

# Model
print('\n[Phase 2] : Model setup')
if args.resume_from is not None:
    # Load checkpoint
    print('| Resuming from checkpoint...')
    assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
    _, file_name = getNetwork(args)
    checkpoint = torch.load('./checkpoint/' + args.resume_from + '.t7')
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('| Building net type [' + args.net_type + ']...')
    net, file_name = getNetwork(args)
    # net.apply(conv_init)

if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
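
# Hedged sketch of the DL2 group constraint that train() builds inline below:
# for each of the 20 CIFAR-100 superclasses, the probability mass the model
# assigns to the group should be (close to) either 0 or 1, so each prediction
# commits to a single group. This helper is illustrative only and is not
# called by the training loop; it assumes the dl2lib operators used elsewhere
# in this file (And, Or, EQ, GT, LT).
def group_constraint(probs, eps=0.0):
    # probs: (batch, 100) softmax outputs; eps: slack around the targets 0/1.
    terms = []
    for i in range(20):
        gsum = sum(probs[:, j] for j in g[i])  # mass assigned to group i
        if eps > 0.0:
            terms.append(dl2.Or([dl2.GT(gsum, 1.0 - eps), dl2.LT(gsum, eps)]))
        else:
            terms.append(dl2.Or([dl2.EQ(gsum, 1.0), dl2.EQ(gsum, 0.0)]))
    return dl2.And(terms)  # differentiable via .loss(args), as used below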
# Training
def train(epoch):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    # Note: the optimizer is re-created every epoch, so Adam's moment
    # estimates are reset at each epoch boundary.
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    softmax = torch.nn.Softmax(dim=1)
    print('\n=> Training Epoch #%d, LR=%.4f' % (epoch, args.lr))
    if args.skip_labled:
        tl = [None] * 200000
    else:
        tl = trainloader_lab
    for batch_idx, (lab, ulab) in enumerate(zip(tl, trainloader_unlab)):
        inputs_u, targets_u = ulab
        inputs_u, targets_u = Variable(inputs_u), Variable(targets_u)
        n_u = inputs_u.size()[0]
        if use_cuda:
            inputs_u, targets_u = inputs_u.cuda(), targets_u.cuda()  # GPU settings
        if lab is None:
            n = 0
            all_outputs = net(inputs_u)
        else:
            inputs, targets = lab
            inputs, targets = Variable(inputs), Variable(targets)
            n = inputs.size()[0]
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings
            # Run labeled and unlabeled examples through the net in one batch.
            all_outputs = net(torch.cat([inputs, inputs_u], dim=0))
        optimizer.zero_grad()
        outputs_u = all_outputs[n:, ]
        logits_u = F.log_softmax(outputs_u, dim=1)
        probs_u = softmax(outputs_u)
        outputs = all_outputs[:n, ]
        if args.skip_labled:
            ce_loss = 0
        else:
            ce_loss = criterion(outputs, targets)  # Loss

        constraint_loss = 0
        if args.constraint == 'DL2':
            # The slack decays and the constraint weight grows geometrically
            # with the epoch; note eps is computed here but the EQ-based
            # constraint below does not use it.
            eps = args.c_eps * args.decrease_eps_weight ** epoch
            dl2_one_group = []
            for i in range(20):
                gsum = 0
                for j in g[i]:
                    gsum += probs_u[:, j]
                dl2_one_group.append(dl2.Or([dl2.EQ(gsum, 1.0), dl2.EQ(gsum, 0.0)]))
            dl2_one_group = dl2.And(dl2_one_group)
            dl2_loss = dl2_one_group.loss(args).mean()
            constraint_loss = dl2_loss
            loss = ce_loss + (args.constraint_weight * args.increase_constraint_weight ** epoch) * dl2_loss
        else:
            loss = ce_loss

        loss.backward()  # Backward propagation
        optimizer.step()  # Optimizer update

        train_loss += loss.item()
        if args.skip_labled:
            total = 1
            correct = 0
        else:
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum()

        sys.stdout.write('\r')
        sys.stdout.write('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tCE Loss: %.4f, Constraint Loss: %.4f Acc@1: %.3f%%'
                         % (epoch, num_epochs, batch_idx + 1, (len(train_lab_idx) // batch_size) + 1,
                            loss.item(), float(constraint_loss), 100. * float(correct) / total))
        sys.stdout.flush()
    return 100. * float(correct) / total


def save(acc, e, net, best=False):
    state = {
        'net': net.module if use_cuda else net,
        'acc': acc,
        'epoch': epoch,  # note: reads the global epoch, not the e argument
    }
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    if best:
        # For best checkpoints, round the (global) epoch down to the nearest
        # multiple of 100 for the file name.
        e = int(100 * math.floor(float(epoch) / 100))
        save_point = './checkpoint/' + file_name + '_' + str(e) + '_best' + '.t7'
    else:
        save_point = './checkpoint/' + file_name + '_' + str(e) + '_' + '.t7'
    torch.save(state, save_point)
    return net
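
# Hedged sketch (illustrative, not called anywhere): reloading a checkpoint
# written by save() above. The stored dict holds the module itself plus
# accuracy and epoch, matching the torch.load() usage in the resume path.
def _load_checkpoint(name):
    state = torch.load('./checkpoint/' + name + '.t7')
    return state['net'], state['acc'], state['epoch']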
def test(epoch):
    global best_acc, best_model
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(validloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        with torch.no_grad():
            outputs = net(inputs)
            loss = criterion(outputs, targets)
        test_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

    # Save checkpoint when best model
    acc = 100. * float(correct) / total
    print("\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" % (epoch, loss.item(), acc))
    if acc > best_acc:
        # print('| Saving Best model...\t\t\tTop1 = %.2f%%' %(acc))
        # Note: save() returns a live reference to the network, not a frozen
        # copy, so best_model tracks the current weights as training continues.
        best_model = save(acc, num_epochs, net, best=True)
        best_acc = acc


print('\n[Phase 3] : Training model')
print('| Training Epochs = ' + str(num_epochs))
print('| Initial Learning Rate = ' + str(args.lr))
print('| Optimizer = ' + str(optim_type))

elapsed_time = 0
for epoch in range(start_epoch, start_epoch + num_epochs):
    start_time = time.time()
    acc = train(epoch)
    if epoch % 100 == 0:
        save(acc, epoch, net)
    test(epoch)
    epoch_time = time.time() - start_time
    elapsed_time += epoch_time
    print('| Elapsed time : %d:%02d:%02d' % (cf.get_hms(elapsed_time)))

if best_model is not None:
    print('.')
    save(best_acc, 'overall', best_model)

print('\n[Phase 4] : Testing model')
print('* Test results : Acc@1 = %.2f%%' % (best_acc))