from IPython import embed import torch from torch.autograd import Variable from torchvision import models import cv2 import sys import numpy as np import torchvision import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import dataset from prune import * import argparse from operator import itemgetter from heapq import nsmallest import time import torch.utils.model_zoo as model_zoo import mmd import math from tools import * BATCH = 16 target_name = 'webcam' class DANNet(nn.Module): def __init__(self): super(DANNet, self).__init__() model = models.vgg16(pretrained=True) #False self.features = model.features for param in self.features.parameters(): #NOTE: prune:True // finetune:False param.requires_grad = True self.classifier = nn.Sequential( nn.Dropout(), nn.Linear(25088, 4096), nn.ReLU(inplace=True), nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(inplace=True), ) self.cls_fc = nn.Linear(4096, 31) def forward(self, source, target): loss = 0 source = self.features(source) source = source.view(source.size(0), -1) source = self.classifier(source) if self.training == True: target = self.features(target) target = target.view(target.size(0), -1) target = self.classifier(target) loss += mmd.mmd_rbf_noaccelerate(source, target) source = self.cls_fc(source) return source, loss class FilterPrunner: def __init__(self, model): self.model = model self.reset() def reset(self): # self.activations = [] # self.gradients = [] # self.grad_index = 0 # self.activation_to_layer = {} self.filter_ranks_1 = {} self.filter_ranks_2 = {} def forward(self, x, x_target): # NOTE: whether to add target data loss = 0 self.activations1 = [] self.activations2 = [] self.gradients = [] self.grad_index_1 = 0 self.grad_index_2 = 0 self.activation_to_layer_1 = {} self.activation_to_layer_2 = {} activation1_index = 0 for layer, (name, module) in enumerate(self.model.features._modules.items()): x = module(x) if isinstance(module, torch.nn.modules.conv.Conv2d): x.register_hook(self.compute_rank_1) self.activations1.append(x) self.activation_to_layer_1[activation1_index] = layer activation1_index += 1 activation2_index = 0 for layer, (name, module) in enumerate(self.model.features._modules.items()): x_target = module(x_target) if isinstance(module, torch.nn.modules.conv.Conv2d): x_target.register_hook(self.compute_rank_2) self.activations2.append(x_target) self.activation_to_layer_2[activation2_index] = layer activation2_index += 1 x = self.model.classifier(x.view(x.size(0), -1)) x_target = self.model.classifier(x_target.view(x_target.size(0), -1)) loss += mmd.mmd_rbf_noaccelerate(x, x_target) source_pred = self.model.cls_fc(x) return source_pred, loss def compute_rank_1(self, grad): activation1_index = len(self.activations1) - self.grad_index_1 - 1 activation1 = self.activations1[activation1_index] values1 = torch.sum((activation1 * grad), dim = 0).sum(dim=1).sum(dim=1)[:].data # Normalize the rank by the filter dimensions values1 = \ values1 / (activation1.size(0) * activation1.size(2) * activation1.size(3)) if activation1_index not in self.filter_ranks_1: self.filter_ranks_1[activation1_index] = \ torch.FloatTensor(activation1.size(1)).zero_().cuda() self.filter_ranks_1[activation1_index] = values1 self.grad_index_1 += 1 def compute_rank_2(self, grad): activation2_index = len(self.activations2) - self.grad_index_2 - 1 activation2 = self.activations2[activation2_index] values2 = torch.sum((activation2 * grad), dim = 0).sum(dim=1).sum(dim=1)[:].data values2 = \ values2 / (activation2.size(0) * activation2.size(2) * activation2.size(3)) if activation2_index not in self.filter_ranks_2: self.filter_ranks_2[activation2_index] = \ torch.FloatTensor(activation2.size(1)).zero_().cuda() self.filter_ranks_2[activation2_index] = values2 self.grad_index_2 += 1 def lowest_ranking_filters(self, num): data_1 = [] for i in sorted(self.filter_ranks_1.keys()): for j in range(self.filter_ranks_1[i].size(0)): data_1.append((self.activation_to_layer_1[i], j, self.filter_ranks_1[i][j])) data_2 = [] for i in sorted(self.filter_ranks_2.keys()): for j in range(self.filter_ranks_2[i].size(0)): data_2.append((self.activation_to_layer_2[i], j, self.filter_ranks_2[i][j])) data_3 = [] data_3.extend(data_1) data_3.extend(data_2) dic = {} c = nsmallest(num*2, data_3, itemgetter(2)) for i in range(len(c)): nm = str(c[i][0]) + '_' + str(c[i][1]) if dic.get(nm)!=None: dic[nm] = min(dic[nm], c[i][2].item()) else: dic[nm] = c[i][2].item() newc = [] for i in range(len(list(dic.items()))): lyer = int(list(dic.items())[i][0].split('_')[0]) filt = int(list(dic.items())[i][0].split('_')[1]) val = torch.tensor(list(dic.items())[i][1]) newc.append((lyer, filt, val)) return nsmallest(num, newc, itemgetter(2)) def normalize_ranks_per_layer(self): for i in self.filter_ranks_1: v = torch.abs(self.filter_ranks_1[i]) v = v / np.sqrt(torch.sum(v * v)).cuda() self.filter_ranks_1[i] = v.cpu() for i in self.filter_ranks_2: v = torch.abs(self.filter_ranks_2[i]) v = v / np.sqrt(torch.sum(v * v)).cuda() self.filter_ranks_2[i] = v.cpu() def get_prunning_plan(self, num_filters_to_prune): filters_to_prune = self.lowest_ranking_filters(num_filters_to_prune) # After each of the k filters are prunned, # the filter index of the next filters change since the model is smaller. filters_to_prune_per_layer = {} for (l, f, _) in filters_to_prune: if l not in filters_to_prune_per_layer: filters_to_prune_per_layer[l] = [] filters_to_prune_per_layer[l].append(f) for l in filters_to_prune_per_layer: filters_to_prune_per_layer[l] = sorted(filters_to_prune_per_layer[l]) for i in range(len(filters_to_prune_per_layer[l])): filters_to_prune_per_layer[l][i] = filters_to_prune_per_layer[l][i] - i filters_to_prune = [] for l in filters_to_prune_per_layer: for i in filters_to_prune_per_layer[l]: filters_to_prune.append((l, i)) return filters_to_prune class PrunningFineTuner_VGGnet: def __init__(self, train_path, test_path, model): self.source_loader = dataset.loader(train_path) self.target_train_loader = dataset.loader(test_path) self.target_test_loader = dataset.test_loader(test_path) self.model = model self.criterion = torch.nn.CrossEntropyLoss() self.prunner = FilterPrunner(self.model) self.model.train() self.len_source_loader = len(self.source_loader) self.len_target_loader = len(self.target_train_loader) self.len_source_dataset = len(self.source_loader.dataset) self.len_target_dataset = len(self.target_test_loader.dataset) self.max_correct = 0 self.littlemax_correct = 0 self.cur_model = None def test(self): self.model.eval() test_loss = 0 correct = 0 for data, target in self.target_test_loader: data, target = data.cuda(), target.cuda() data, target = Variable(data, volatile=True), Variable(target) s_output, t_output = self.model(data, data) test_loss += F.nll_loss(F.log_softmax(s_output, dim = 1), target, size_average=False).item() # sum up batch loss pred = s_output.data.max(1)[1] # get the index of the max log-probability correct += pred.eq(target.data.view_as(pred)).cpu().sum() test_loss /= self.len_target_dataset print('\n{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format( target_name, test_loss, correct, self.len_target_dataset, 100. * correct / self.len_target_dataset)) return correct def train(self, optimizer = None, epoches = 10, save_name=None): for i in range(epoches): print("Epoch: ", i+1) self.train_epoch(optimizer, i+1, epoches+1) cur_correct = self.test() if cur_correct >= self.littlemax_correct: self.littlemax_correct = cur_correct self.cur_model = self.model print("write cur bset model") if cur_correct > self.max_correct: self.max_correct = cur_correct if save_name: torch.save(self.model, str(save_name)) print('amazon to webcam max correct: {} max accuracy{: .2f}%\n'.format( self.max_correct, 100.0 * self.max_correct / self.len_target_dataset)) print("Finished fine tuning.") def train_epoch(self, optimizer = None, epoch = 0, epoches = 0, rank_filters = False): LEARNING_RATE = 0.01 / math.pow((1 + 10 * (epoch - 1) / epoches), 0.75) # 10* optimizer = torch.optim.SGD([ {'params': self.model.features.parameters()}, {'params': self.model.classifier.parameters()}, {'params': self.model.cls_fc.parameters(), 'lr': LEARNING_RATE}, ], lr=LEARNING_RATE / 5, momentum=0.9, weight_decay=5e-4) iter_source = iter(self.source_loader) iter_target = iter(self.target_train_loader) self.model.train() for i in range(1, self.len_source_loader): data_source, label_source = iter_source.next() data_target, _ = iter_target.next() if len(data_target < BATCH): iter_target = iter(self.target_train_loader) data_target, _ = iter_target.next() data_source, label_source = data_source.cuda(), label_source.cuda() data_target = data_target.cuda() data_source, label_source = Variable(data_source), Variable(label_source) data_target = Variable(data_target) self.model.zero_grad() if rank_filters: # prune # add cls_loss and mmd_loss pred, loss_mmd = self.prunner.forward(data_source, data_target) loss_cls = F.nll_loss(F.log_softmax(pred, dim=1), label_source) gamma = 2 / (1 + math.exp(-10 * (epoch) / epoches)) - 1 loss = loss_cls + gamma * loss_mmd loss.backward() print('prune loss: {:.5f} {:.5f}'.format(loss_cls.item(), loss_mmd.item())) else: label_source_pred, loss_mmd = self.model(data_source, data_target) loss_cls = F.nll_loss(F.log_softmax(label_source_pred, dim=1), label_source) gamma = 2 / (1 + math.exp(-10 * (epoch) / epoches)) - 1 loss = loss_cls + gamma * loss_mmd loss.backward() optimizer.step() if i % 50 == 0: print('Train Epoch:{} [{}/{}({:.0f}%)]\tlr:{:.5f}\tLoss: {:.6f}\tsoft_Loss: {:.6f}\tmmd_Loss: {:.6f}'.format( epoch, i * len(data_source), self.len_source_dataset, 100. * i / self.len_source_loader, LEARNING_RATE, loss.item(), loss_cls.item(), loss_mmd.item())) def get_candidates_to_prune(self, num_filters_to_prune): self.prunner.reset() self.train_epoch(epoch = 1, epoches = 10, rank_filters = True) self.prunner.normalize_ranks_per_layer() return self.prunner.get_prunning_plan(num_filters_to_prune) def total_num_filters(self): filters = 0 for name, module in self.model.features._modules.items(): if isinstance(module, torch.nn.modules.conv.Conv2d): filters = filters + module.out_channels return filters def prune(self, perc_ind, perchan): #Get the accuracy before prunning self.test() self.model.train() #Make sure all the layers are trainable for param in self.model.features.parameters(): param.requires_grad = True number_of_filters = self.total_num_filters() # the total nums of channels in convs num_filters_to_prune_per_iteration = perchan # 20 iterations = int(float(number_of_filters) / num_filters_to_prune_per_iteration) perc = perc_ind/10 # 2.0/10 iterations = int(iterations * perc) # set the percentage of prunning 80% print("Number of prunning iterations to reduce filters", iterations) for _ in range(iterations): print("Ranking filters.. ") prune_targets = self.get_candidates_to_prune(num_filters_to_prune_per_iteration) layers_prunned = {} for layer_index, filter_index in prune_targets: if layer_index not in layers_prunned: layers_prunned[layer_index] = 0 layers_prunned[layer_index] = layers_prunned[layer_index] + 1 print("Layers that will be prunned", layers_prunned) print("Prunning filters.. ") if self.cur_model: print("load cur best") model = self.cur_model.cpu() else: model = self.model.cpu() for layer_index, filter_index in prune_targets: print(layer_index, filter_index) model = prune_vgg16_conv_layer(model, layer_index, filter_index) self.model = model.cuda() message = str(100*float(self.total_num_filters()) / number_of_filters) + "%" print("Filters prunned", str(message)) self.test() print("Fine tuning to recover from prunning iteration.") optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9) self.littlemax_correct = 0 self.train(optimizer, epoches = 5) #10 print("Finished. Going to fine tune the model a bit more") self.max_correct = 0 self.train(optimizer, epoches = 20, save_name = "model_prunned_clsmmd_{:.1f}".format(perc)) def total_num_channels(model): filters = 0 for name, module in model.features._modules.items(): if isinstance(module, torch.nn.Conv2d): print(name, module, module.out_channels) filters = filters + module.out_channels print('total nums of channels in convs: %d'%(filters)) def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--train", dest="train", action="store_true") parser.add_argument("--prune", dest="prune", action="store_true") parser.add_argument("--train_path", type = str, default = "train") parser.add_argument("--test_path", type = str, default = "test") parser.set_defaults(train=False) parser.set_defaults(prune=False) args = parser.parse_args() return args if __name__ == '__main__': args = get_args() args.train_path = '/home/xxx/datasets/office_31/amazon' args.test_path = '/home/xxx/datasets/office_31/webcam' if args.train: model = DANNet().cuda() elif args.prune: #model = torch.load("./model_prunned_clsmmd_0.6").cuda() fine_tuner = PrunningFineTuner_VGGnet(args.train_path, args.test_path, model) fine_tuner.test() total_num_channels(model) print_model_parm_flops(model) print_model_parm_nums(model) #embed() fine_tuner = PrunningFineTuner_VGGnet(args.train_path, args.test_path, model) if args.train: #model = torch.load("./model_prunned_clsmmd_0.4").cuda() fine_tuner = PrunningFineTuner_VGGnet(args.train_path, args.test_path, model) fine_tuner.test() fine_tuner.train(epoches = 10, save_name = 'model_t') elif args.prune: for perc_ind, perchan in zip([2.0], [32]): fine_tuner = PrunningFineTuner_VGGnet(args.train_path, args.test_path, model) fine_tuner.prune(perc_ind, perchan)