# Channel-pruning pre-processing script.
#
# Loads a trained model, collects the absolute values of every BatchNorm2d
# scale (weight), picks a global threshold at the requested percentile, zeroes
# the BatchNorm weight and bias of every channel at or below that threshold,
# saves the masked model together with the resulting per-layer channel
# configuration, and finally evaluates the masked model on the test split.
import os
import argparse

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms

import models

# Prune settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
parser.add_argument('--dataset', type=str, default='cifar100',
                    help='training dataset (default: cifar100)')
parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
                    help='input batch size for testing (default: 100)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--percent', type=float, default=0.5,
                    help='scale sparse rate (default: 0.5)')
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to raw trained model (default: none)')
parser.add_argument('--save', default='.', type=str, metavar='PATH',
                    help='path to save prune model (default: none)')
parser.add_argument('--depth', default=19, type=int,
                    help='depth of resnet and densenet')
parser.add_argument('--arch', default='vgg', type=str,
                    help='architecture to use')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if not os.path.exists(args.save):
    os.makedirs(args.save)

model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)
if args.cuda:
    model.cuda()

if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
        # host; load_state_dict then copies into the model's own device.
        checkpoint = torch.load(args.model,
                                map_location=None if args.cuda else 'cpu')
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        # Do not abort (keeps the original best-effort behavior), but make it
        # visible that the freshly initialized weights are being pruned.
        print("=> no checkpoint found at '{}'".format(args.model))

print(model)

# Gather |scale| of every BatchNorm2d channel into one flat CPU tensor.
bn_scales = [m.weight.data.abs().cpu()
             for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
if not bn_scales:
    raise ValueError("Model has no BatchNorm2d layers; nothing to prune.")
bn = torch.cat(bn_scales)
total = bn.shape[0]

# Global threshold: the value found at position `percent * total` of the
# ascending-sorted scales. The index is clamped so that --percent 1.0 (or a
# negative value) cannot index outside the tensor.
y, i = torch.sort(bn)
thre_index = min(max(int(total * args.percent), 0), total - 1)
# A plain Python float can be compared against tensors on any device.
thre = float(y[thre_index])

pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        # The mask is derived from the layer's own weights, so it already
        # lives on the same device as the layer (CPU or GPU).
        mask = weight_copy.gt(thre).float()
        remaining = int(torch.sum(mask))
        pruned += mask.shape[0] - remaining
        # Zero both scale and shift of the channels below the threshold.
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(remaining)
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
              format(k, mask.shape[0], remaining))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

torch.save({'cfg': cfg, 'state_dict': model.state_dict()},
           os.path.join(args.save, 'pruned.pth.tar'))

pruned_ratio = pruned / float(total)

print('Pre-processing Successful!')


# simple test model after Pre-processing prune (simple set BN scales to zeros)
def test():
    """Evaluate the module-level ``model`` on the test split of ``args.dataset``.

    Reads ``args.dataset``, ``args.cuda`` and ``args.test_batch_size`` from the
    module-level ``args`` and prints a one-line accuracy summary.

    Returns:
        Fraction of correctly classified test samples as a Python float.

    Raises:
        ValueError: if ``args.dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    if args.dataset == 'cifar10':
        dataset_cls, root = datasets.CIFAR10, './data.cifar10'
    elif args.dataset == 'cifar100':
        dataset_cls, root = datasets.CIFAR100, './data.cifar100'
    else:
        raise ValueError("No valid dataset is given.")

    # Both datasets use the same normalization constants.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010))])
    test_loader = torch.utils.data.DataLoader(
        dataset_cls(root, train=False, transform=transform),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model.eval()
    correct = 0
    # no_grad replaces the removed `volatile=True` flag: no autograd graph is
    # built during evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            # Accumulate as a Python int so the ratio below is a plain float.
            correct += int(pred.eq(target.view_as(pred)).sum())

    num_samples = len(test_loader.dataset)
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, num_samples, 100. * correct / num_samples))
    return correct / float(num_samples)


acc = test()

print(cfg)

savepath = os.path.join(args.save, "prune.txt")
with open(savepath, "w") as fp:
    fp.write("Configuration: \n")
    fp.write(str(cfg)+"\n")
    fp.write("Test accuracy: \n")
    fp.write(str(acc))