import torch
import sys
import numpy as np
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from model.resnet_cifar10 import BasicBlock
from pruner.fp_mbnetv2 import FilterPrunerMBNetV2
from pruner.fp_resnet import FilterPrunerResNet
import argparse


def measure_model(model, pruner, img_size, device='cuda'):
    """Return ``(flops, size)`` of ``model`` as reported by ``pruner``.

    The pruner is reset, ``model`` is switched to eval mode, and one all-zero
    image of shape ``(1, 3, img_size, img_size)`` is passed to
    ``pruner.forward``; ``pruner.cur_flops`` / ``pruner.cur_size`` are read
    afterwards (presumably filled in during that forward pass).

    Args:
        model: network to measure (left in eval mode).
        pruner: pruner object wrapping ``model``.
        img_size: side length of the square dummy input.
        device: device the dummy input is created on (``'cuda'`` by default,
            matching the previous hard-coded behaviour).

    Returns:
        ``(pruner.cur_flops, pruner.cur_size)``.
    """
    pruner.reset()
    model.eval()
    dummy = torch.zeros((1, 3, img_size, img_size), device=device)
    pruner.forward(dummy)
    return pruner.cur_flops, pruner.cur_size


def save_checkpoint(state, is_best, filename='checkpoint'):
    """Save ``state`` to ``'<filename>_best.pth.tar'`` if ``is_best`` is true.

    Nothing is written when ``is_best`` is false.
    """
    if is_best:
        torch.save(state, '{}_best.pth.tar'.format(filename))


def get_valid_flops(model, cbns, out_maps):
    """Estimate conv FLOPs counting only channels with a non-zero BN gamma.

    Convs and BNs from ``cbns`` are paired positionally. A conv that feeds a
    residual chain (a key of ``residual_chain`` below) takes its output width
    from the union of alive channels of all chain BNs; other convs use their
    own BN's alive count. Per conv the contribution is
    ``H * W * valid_output * valid_input * kh * kw / groups``.

    Args:
        model: network whose ``BasicBlock`` modules define the residual chain.
        cbns: ``(convs, bns)`` as produced by ``get_cbns``.
        out_maps: per-conv output map sizes; ``out_maps[idx][0]`` and
            ``out_maps[idx][1]`` are multiplied (assumed ``(H, W)``).

    Returns:
        The estimated FLOP count (a number; may be float due to ``/ groups``).
    """
    convs, bns = cbns

    # Walk modules in order. For each BasicBlock, map the conv visited most
    # recently before it (stem conv or the previous block's last conv) to the
    # block's ``m.conv[3]``; the keys are convs whose outputs join the chain.
    last_conv = None
    residual_chain = {}
    for m in model.modules():
        if isinstance(m, BasicBlock):
            residual_chain[last_conv] = m.conv[3]
            last_conv = m.conv[3]
        if isinstance(m, nn.Conv2d):
            last_conv = m

    # Union (logical OR) of non-zero gammas over every chain BN. The mask is
    # sized from those BNs themselves: sizing it from the last conv's input
    # width (as before) could be too short and crash in ``np.zeros`` with a
    # negative length. Extra trailing zeros never change the sums below.
    chain_gammas = [
        bns[convs.index(chain_conv)].weight.data.cpu().numpy()
        for chain_conv in residual_chain.values()
    ]
    mask = np.zeros(max((len(g) for g in chain_gammas), default=0), dtype=bool)
    for gamma in chain_gammas:
        mask[:len(gamma)] |= gamma != 0

    flops = 0
    for idx, (conv, bn) in enumerate(zip(convs, bns)):
        if conv in residual_chain:
            # Chain convs share the chain's alive-channel mask for outputs.
            valid_output = np.sum(mask[:bn.weight.size(0)])
            if idx == 0:
                valid_input = conv.weight.size(1)
            else:
                valid_input = (torch.abs(bns[idx - 1].weight) > 0).sum().item()
        else:
            valid_output = (torch.abs(bn.weight) > 0).sum().item()
            if idx == 0:
                # The first conv reads the network input; the old code fell
                # through to ``bns[-1]`` (the last BN) via index wrap-around.
                valid_input = conv.weight.size(1)
            else:
                # NOTE(review): input width of a non-chain conv is taken from
                # the chain mask sliced to the previous BN's width (unchanged
                # from the original logic) — confirm this is intended for convs
                # whose input is not the residual chain.
                valid_input = np.sum(mask[:bns[idx - 1].weight.size(0)])
        kernel_area = conv.weight.size(2) * conv.weight.size(3)
        flops += (out_maps[idx][0] * out_maps[idx][1] * valid_output
                  * valid_input * kernel_area / conv.groups)
    return flops
def get_cbns(model):
    """Collect the Conv2d and BatchNorm2d modules of ``model``.

    Returns:
        ``(convs, bns)``: two lists in ``model.modules()`` order. Code in this
        file pairs them positionally with ``zip``, which assumes each conv is
        followed by its own BN — NOTE(review): confirm for every model used.
    """
    modules = list(model.modules())
    convs = [m for m in modules if isinstance(m, nn.Conv2d)]
    bns = [m for m in modules if isinstance(m, nn.BatchNorm2d)]
    return convs, bns


def _model_device(model):
    """Return the device of ``model``'s first parameter, or CPU if it has none."""
    first_param = next(iter(model.parameters()), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def regularizer(model, constraint='size', cbns=None, maps=None):
    """Return a differentiable penalty on adjacent BN scale factors (gammas).

    For each consecutive BN pair ``(bns[i], bns[i+1])`` this adds
    ``cost * (sum|g_i| * alive(g_{i+1}) + sum|g_{i+1}| * alive(g_i))``, where
    ``alive`` counts non-zero gammas and ``cost`` depends on ``constraint``:

    * ``'size'``:  kernel area of ``convs[i+1]``.
    * ``'flops'``: ``2 * maps[i+1][0] * maps[i+1][1] * kernel area`` (maps
      entries assumed to be ``(H, W)``, as in ``get_valid_flops``).

    Args:
        model: network; used to build ``cbns`` when it is None and to pick
            the device of the accumulator.
        constraint: ``'size'`` or ``'flops'``.
        cbns: optional ``(convs, bns)`` from ``get_cbns``.
        maps: per-conv output map sizes; required for ``'flops'``.

    Returns:
        A 1-element tensor on the model's device.

    Raises:
        ValueError: ``constraint`` is not ``'size'`` or ``'flops'``.
        AssertionError: ``constraint == 'flops'`` and ``maps`` is None.
    """
    if constraint not in ('size', 'flops'):
        raise ValueError('Unknown constraint: {!r}'.format(constraint))
    if constraint == 'flops':
        assert maps is not None, 'Output Map is None!'
    # Previously G was only created when cbns was given, so cbns=None raised
    # UnboundLocalError.
    if cbns is None:
        cbns = get_cbns(model)
    convs, bns = cbns

    G = torch.zeros(1, device=_model_device(model))
    for idx in range(min(len(convs), len(bns)) - 1):
        gamma_prev = torch.abs(bns[idx].weight)
        gamma_now = torch.abs(bns[idx + 1].weight)
        alive_prev = (gamma_prev > 0).sum().type_as(gamma_prev)
        alive_now = (gamma_now > 0).sum().type_as(gamma_now)
        next_weight = convs[idx + 1].weight
        kernel_area = next_weight.size(2) * next_weight.size(3)
        if constraint == 'size':
            cost = kernel_area
        else:
            # Height * width of the output map (was maps[...][0] squared).
            cost = 2 * maps[idx + 1][0] * maps[idx + 1][1] * kernel_area
        G = G + cost * (gamma_prev.sum() * alive_now + gamma_now.sum() * alive_prev)
    return G


def num_alive_filters(model):
    """Count BatchNorm2d channels of ``model`` whose gamma is non-zero."""
    return sum(
        (torch.abs(m.weight) > 0).sum().item()
        for m in model.modules()
        if isinstance(m, nn.BatchNorm2d)
    )


# Truncate small beta and enforce depth-wise in-out numbers
def truncate_smallbeta(model, cbns, thres=0.01):
    """Zero BN gamma/beta channels with ``|gamma| < thres``, in place.

    Convs that feed a residual chain (same chain construction as
    ``get_valid_flops``) keep exactly the channels alive in the union of all
    thresholded chain BNs; every other BN is thresholded on its own.

    Args:
        model: network whose ``BasicBlock`` modules define the residual chain.
        cbns: ``(convs, bns)`` from ``get_cbns``; built from ``model`` if None.
        thres: magnitude below which a gamma counts as dead (default 0.01).
    """
    if cbns is None:
        cbns = get_cbns(model)
    convs, bns = cbns

    last_conv = None
    residual_chain = {}
    for m in model.modules():
        if isinstance(m, BasicBlock):
            residual_chain[last_conv] = m.conv[3]
            last_conv = m.conv[3]
        if isinstance(m, nn.Conv2d):
            last_conv = m

    # Compute alive masks without writing into ``.cpu().numpy()``: for CPU
    # models that array aliases the parameter, so the old code zeroed chain
    # gammas in place only on CPU.
    chain_alive = []
    for chain_conv in residual_chain.values():
        gamma = bns[convs.index(chain_conv)].weight.data.cpu().numpy()
        chain_alive.append((gamma != 0) & ~(np.abs(gamma) < thres))
    # Sized from the chain BNs so no chain member can overflow the mask.
    mask = np.zeros(max((len(a) for a in chain_alive), default=0), dtype=bool)
    for alive in chain_alive:
        mask[:len(alive)] |= alive

    with torch.no_grad():
        for conv, bn in zip(convs, bns):
            n_ch = bn.weight.size(0)
            if conv in residual_chain:
                keep = np.zeros(n_ch, dtype=bool)
                n = min(n_ch, len(mask))
                keep[:n] = mask[:n]
                drop = torch.from_numpy(~keep).to(bn.weight.device)
            else:
                drop = torch.abs(bn.weight) < thres
            bn.weight[drop] = 0
            bn.bias[drop] = 0


def test(model, loader):
    """Evaluate ``model`` on ``loader``.

    Batches are moved to the device of the model's parameters.

    Returns:
        ``(top1_percent, mean_loss)`` where ``mean_loss`` is the per-sample
        average cross-entropy over the whole loader.
    """
    model.eval()
    device = _model_device(model)
    criterion = torch.nn.CrossEntropyLoss()
    total = 0
    correct = 0
    total_loss = 0.0
    with torch.no_grad():
        for batch, label in loader:
            batch, label = batch.to(device), label.to(device)
            out = model(batch)
            n = batch.size(0)
            total += n
            # CrossEntropyLoss averages over the batch; weight by batch size so
            # dividing by ``total`` gives a true per-sample mean.
            total_loss += criterion(out, label).item() * n
            _, pred = out.max(dim=1)
            correct += pred.eq(label).sum().item()
    return float(correct) / total * 100, total_loss / total


# Keep pytest from collecting this evaluation helper as a test case when the
# module is imported from a test suite (its name starts with ``test``).
test.__test__ = False


def train_epoch(model, optim, criterion, loader, lbda=None, cbns=None, maps=None, constraint=None):
    """Run one training epoch, optionally adding the BN-gamma regularizer.

    Args:
        model: network to train (batches go to its parameters' device).
        optim: optimizer stepping ``model``'s parameters.
        criterion: loss applied to ``(model(batch), label)``.
        loader: iterable of ``(batch, label)`` with ``len()``.
        lbda: regularizer weight, used when ``constraint`` is truthy.
        cbns, maps, constraint: forwarded to ``regularizer``; when
            ``constraint`` is truthy, ``truncate_smallbeta`` also runs.
    """
    model.train()
    device = _model_device(model)
    total = 0
    correct = 0
    for i, (batch, label) in enumerate(loader):
        optim.zero_grad()
        batch, label = batch.to(device), label.to(device)
        total += batch.size(0)
        out = model(batch)
        _, pred = out.max(dim=1)
        correct += pred.eq(label).sum().item()
        loss = criterion(out, label)
        if constraint:
            loss = loss + lbda * regularizer(model, constraint, cbns, maps)
        loss.backward()
        optim.step()
        if (i % 100 == 0) or (i == len(loader) - 1):
            print('Train | Batch ({}/{}) | Top-1: {:.2f} ({}/{})'.format(
                i + 1, len(loader), float(correct) / total * 100, correct, total))
    # NOTE(review): the collapsed source made the original nesting ambiguous;
    # truncation is applied once per epoch here — confirm this is intended.
    if constraint:
        truncate_smallbeta(model, cbns)


def train(model, train_loader, val_loader, epochs=10, lr=1e-2, name=''):
    """Fine-tune ``model`` on CUDA with SGD + multi-step LR decay.

    LR is multiplied by 0.2 at 30%, 60% and 80% of ``epochs``. After every
    epoch the model is evaluated on ``val_loader`` and pickled to
    ``ckpt/<name>_best.t7``.

    Returns:
        The trained model.
    """
    import os

    model = model.to('cuda')
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
    milestones = [int(epochs * 0.3), int(epochs * 0.6), int(epochs * 0.8)]
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.2)
    criterion = torch.nn.CrossEntropyLoss()
    # torch.save fails if the directory is missing.
    os.makedirs('ckpt', exist_ok=True)
    for e in range(epochs):
        train_epoch(model, optimizer, criterion, train_loader)
        top1, _ = test(model, val_loader)
        print('Epoch {} | Top-1: {:.2f}'.format(e, top1))
        # NOTE(review): overwritten every epoch without tracking the best
        # accuracy, so this file holds the latest model despite its name.
        torch.save(model, 'ckpt/{}_best.t7'.format(name))
        scheduler.step()
    return model


def train_mask(model, train_loader, val_loader, pruner, epochs=10, lr=1e-2, lbda=1.3*1e-8, cbns=None, maps=None,
               constraint='flops'):
    """Train ``model`` with the BN-gamma regularizer to sparsify channels.

    After each epoch prints the alive-filter count, ``pruner.get_valid_flops()``
    (in millions) and validation top-1.

    Returns:
        The trained model.
    """
    model = model.to('cuda')
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
    criterion = torch.nn.CrossEntropyLoss()
    for e in range(epochs):
        print('Epoch {}'.format(e))
        train_epoch(model, optimizer, criterion, train_loader, lbda, cbns, maps, constraint)
        top1, _ = test(model, val_loader)
        print('#Filters: {}, #FLOPs: {:.2f}M | Top-1: {:.2f}'.format(
            num_alive_filters(model), pruner.get_valid_flops() / 1000000., top1))
    return model


def prune_model(model, cbns, pruner):
    """Physically remove every filter segment ``pruner`` reports as removable.

    ``model`` and ``cbns`` are unused here and kept for call compatibility; all
    work goes through ``pruner``.
    """
    filters_to_prune_per_layer = pruner.get_valid_filters()
    # Materialize: targets are iterated twice (summary, then pruning), which
    # would silently prune nothing if a generator were returned.
    prune_targets = list(pruner.pack_pruning_target(
        filters_to_prune_per_layer, get_segment=True, progressive=True))

    layers_prunned = {}
    for layer_index, filter_index in prune_targets:
        # Segment length, counting both endpoints.
        count = filter_index[1] - filter_index[0] + 1
        layers_prunned[layer_index] = layers_prunned.get(layer_index, 0) + count
    print('Layers that will be prunned: {}'.format(sorted(layers_prunned.items())))

    print('Prunning filters..')
    for layer_index, filter_index in prune_targets:
        pruner.prune_conv_layer_segment(layer_index, filter_index)


def get_args():
    """Parse command-line arguments for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--datapath", type=str, default='/data')
    parser.add_argument("--dataset", type=str, default='torchvision.datasets.CIFAR10')
    parser.add_argument("--epoch", type=int, default=60)
    parser.add_argument("--name", type=str, default='ft_mbnetv2')
    parser.add_argument("--model", type=str, default='ft_mbnetv2')
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--lbda", type=float, default=3e-9)
    # argparse %-formats help strings, so a literal percent sign must be '%%'
    # (a bare '%' made --help raise ValueError).
    parser.add_argument("--prune_away", type=float, default=0.5,
                        help='The constraint level in portion to the original network, e.g. 0.5 is prune away 50%%')
    parser.add_argument("--constraint", type=str, default='flops')
    # NOTE(review): --large_input is parsed but not read anywhere in this file.
    parser.add_argument("--large_input", action='store_true', default=False)
    parser.add_argument("--no_grow", action='store_true', default=False)
    # Default must name an imported class; 'FilterPrunnerResNet' was a typo
    # that made eval() raise NameError.
    parser.add_argument("--pruner", type=str, default='FilterPrunerResNet',
                        help='Different network require different pruner implementation')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = get_args()
    print(args)

    # NOTE(review): torch.load unpickles a whole nn.Module — only load trusted
    # files. Recent PyTorch releases default to weights_only=True, which
    # refuses pickled modules; pass weights_only=False there if required.
    model = torch.load(args.model)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # NOTE(review): eval() on command-line strings executes arbitrary code;
    # acceptable for a local research script, never for untrusted input.
    dataset_cls = eval(args.dataset)
    train_set = dataset_cls(args.datapath, True, transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
    # Same training split without augmentation; only the held-out indices
    # (valid_sampler) are drawn from it.
    val_set = dataset_cls(args.datapath, True, transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ]))
    num_train = len(train_set)
    indices = list(range(num_train))
    split = int(np.floor(0.1 * num_train))
    np.random.seed(98)
    np.random.shuffle(indices)
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    test_set = dataset_cls(args.datapath, False, transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ]))

    # Train only on the 90% split; shuffling the full training set (as before)
    # leaked the validation images into training.
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, sampler=train_sampler, num_workers=0, pin_memory=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, sampler=valid_sampler, num_workers=0, pin_memory=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=125, shuffle=False, num_workers=0, pin_memory=False
    )

    # Check CIFAR100 first: 'CIFAR10' is a substring of 'CIFAR100', so the old
    # order gave CIFAR-100 only 10 classes.
    if 'CIFAR100' in args.dataset:
        train_set.num_classes = 100
    elif 'CIFAR10' in args.dataset:
        train_set.num_classes = 10
    elif not hasattr(train_set, 'num_classes'):
        # Assumes the dataset exposes a ``classes`` list (torchvision style).
        train_set.num_classes = len(train_set.classes)

    pruner = eval(args.pruner)(model, 'l2_weight', num_cls=train_set.num_classes)
    flops, num_params = measure_model(pruner.model, pruner, 32)
    maps = pruner.omap_size
    cbns = get_cbns(pruner.model)
    print('Before Pruning | FLOPs: {:.3f}M | #Params: {:.3f}M'.format(flops/1000000., num_params/1000000.))
    train_mask(pruner.model, train_loader, val_loader, pruner, epochs=args.epoch, lr=1e-3, lbda=args.lbda,
               cbns=cbns, maps=maps, constraint=args.constraint)

    target = int((1.-args.prune_away)*flops)
    print('Target ({}): {:.3f}M'.format(args.constraint, target/1000000.))
    prune_model(pruner.model, cbns, pruner)
    flops, num_params = measure_model(pruner.model, pruner, 32)
    print('After Pruning | FLOPs: {:.3f}M | #Params: {:.3f}M'.format(flops/1000000., num_params/1000000.))

    if args.no_grow:
        # Train the pruned network held by the pruner (as the growth branch does).
        train(pruner.model, train_loader, test_loader, epochs=args.epoch, lr=args.lr,
              name='{}_pregrow'.format(args.name))
    elif flops < target:
        ratio = pruner.get_uniform_ratio(target)
        print(ratio)
        pruner.uniform_grow(ratio)
        flops, num_params = measure_model(pruner.model, pruner, 32)
        print('After Growth | FLOPs: {:.3f}M | #Params: {:.3f}M'.format(flops/1000000., num_params/1000000.))
        train(pruner.model, train_loader, test_loader, epochs=args.epoch, lr=args.lr, name=args.name)
    else:
        print('Over constraint ({:.3f}M > {:.3f}M), no growth'.format(flops/1000000., target/1000000.))