from models import *
from utils.utils import *
import torch
import numpy as np
from copy import deepcopy
from test import test
from terminaltables import AsciiTable
import time
from utils.prune_utils import *
import argparse
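
# utils/prune_utils.py is not shown here; for reference, below is a minimal
# sketch of what the two helpers used in this script are assumed to do,
# inferred from their call sites (the repo's actual implementations may
# differ): gather_bn_weights collects the absolute BN gamma values of every
# prunable layer into a single 1-D tensor, and obtain_bn_mask builds a
# per-channel 0/1 keep-mask from a gamma threshold.
def gather_bn_weights_sketch(module_list, prune_idx):
    # Each prunable entry of module_list is a Sequential whose second element
    # is the BatchNorm2d layer; gamma is its weight vector.
    return torch.cat([module_list[idx][1].weight.data.abs() for idx in prune_idx])


def obtain_bn_mask_sketch(bn_module, thre):
    # Keep a channel iff |gamma| >= threshold; a float mask of 0s/1s so it can
    # be multiplied into the gamma tensor in place.
    return bn_module.weight.data.abs().ge(thre).float()
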
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
    parser.add_argument('--data', type=str, default='data/coco.data', help='*.data file path')
    parser.add_argument('--weights', type=str, default='weights/last.pt', help='sparse model weights')
    parser.add_argument('--percent', type=float, default=0.8, help='channel prune percent')
    parser.add_argument('--img_size', type=int, default=416, help='inference size (pixels)')
    opt = parser.parse_args()
    print(opt)

    img_size = opt.img_size
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Darknet(opt.cfg, (img_size, img_size)).to(device)

    if opt.weights.endswith('.pt'):
        model.load_state_dict(torch.load(opt.weights)['model'])
    else:
        load_darknet_weights(model, opt.weights)
    print('\nloaded weights from', opt.weights)

    eval_model = lambda model: test(opt.cfg, opt.data, weights=opt.weights, batch_size=16,
                                    img_size=img_size, iou_thres=0.5, conf_thres=0.001,
                                    nms_thres=0.5, save_json=False, model=model)
    obtain_num_parameters = lambda model: sum(param.nelement() for param in model.parameters())

    print("\nlet's test the original model first:")
    with torch.no_grad():
        origin_model_metric = eval_model(model)
    origin_nparameters = obtain_num_parameters(model)

    CBL_idx, Conv_idx, prune_idx = parse_module_defs(model.module_defs)

    bn_weights = gather_bn_weights(model.module_list, prune_idx)
    sorted_bn = torch.sort(bn_weights)[0]

    # Upper bound on the pruning threshold so that no layer loses all of its
    # channels: the minimum, over all prunable BN layers, of each layer's
    # maximum |gamma|.
    highest_thre = []
    for idx in prune_idx:
        highest_thre.append(model.module_list[idx][1].weight.data.abs().max().item())
    highest_thre = min(highest_thre)

    # Convert the position of highest_thre in the sorted gamma list into a percentile.
    percent_limit = (sorted_bn == highest_thre).nonzero().item() / len(bn_weights)

    print(f'Suggested gamma threshold should be less than {highest_thre:.4f}.')
    print(f'The corresponding prune ratio is {percent_limit:.3f}, but you can set it higher.')

    #%%
    def prune_and_eval(model, sorted_bn, percent=.0):
        model_copy = deepcopy(model)
        thre_index = int(len(sorted_bn) * percent)
        thre = sorted_bn[thre_index]

        print(f'Gamma values less than {thre:.4f} are set to zero!')

        remain_num = 0
        for idx in prune_idx:
            bn_module = model_copy.module_list[idx][1]
            mask = obtain_bn_mask(bn_module, thre)
            remain_num += int(mask.sum())
            bn_module.weight.data.mul_(mask)

        print("let's test the current model!")
        with torch.no_grad():
            mAP = eval_model(model_copy)[0][2]

        print(f'Number of channels has been reduced from {len(sorted_bn)} to {remain_num}')
        print(f'Prune ratio: {1 - remain_num / len(sorted_bn):.3f}')
        print(f"mAP of the 'pruned' model is {mAP:.4f}")

        return thre

    percent = opt.percent
    print('the required prune percent is', percent)
    threshold = prune_and_eval(model, sorted_bn, percent)

    #%%
    def obtain_filters_mask(model, thre, CBL_idx, prune_idx):
        pruned = 0
        total = 0
        num_filters = []
        filters_mask = []
        for idx in CBL_idx:
            bn_module = model.module_list[idx][1]
            if idx in prune_idx:
                mask = obtain_bn_mask(bn_module, thre).cpu().numpy()
                remain = int(mask.sum())
                pruned = pruned + mask.shape[0] - remain

                if remain == 0:
                    # Every channel of this layer fell below the threshold;
                    # instead of raising, keep the channel(s) with the largest
                    # |gamma| so the layer is never emptied.
                    # print("Channels would be all pruned!")
                    # raise Exception
                    max_value = bn_module.weight.data.abs().max()
                    mask = obtain_bn_mask(bn_module, max_value).cpu().numpy()
                    remain = int(mask.sum())
                    pruned = pruned + mask.shape[0] - remain

                print(f'layer index: {idx:>3d} \t total channel: {mask.shape[0]:>4d} \t '
                      f'remaining channel: {remain:>4d}')
            else:
                mask = np.ones(bn_module.weight.data.shape)
                remain = mask.shape[0]

            total += mask.shape[0]
            num_filters.append(remain)
            filters_mask.append(mask.copy())

        prune_ratio = pruned / total
        print(f'Prune channels: {pruned}\tPrune ratio: {prune_ratio:.3f}')

        return num_filters, filters_mask

    num_filters, filters_mask = obtain_filters_mask(model, threshold, CBL_idx, prune_idx)

    #%%
    CBLidx2mask = {idx: mask.astype('float32') for idx, mask in zip(CBL_idx, filters_mask)}

    pruned_model = prune_model_keep_size2(model, CBL_idx, CBL_idx, CBLidx2mask)

    print("\nnow prune the model but keep its size (the pruned BN betas are folded "
          "into the following layers as offsets); let's see how the mAP goes")
    with torch.no_grad():
        eval_model(pruned_model)

    #%%
    compact_module_defs = deepcopy(model.module_defs)
    for idx, num in zip(CBL_idx, num_filters):
        assert compact_module_defs[idx]['type'] == 'convolutional'
        compact_module_defs[idx]['filters'] = str(num)

    #%%
    compact_model = Darknet([model.hyperparams.copy()] + compact_module_defs, (img_size, img_size)).to(device)
    compact_nparameters = obtain_num_parameters(compact_model)

    init_weights_from_loose_model(compact_model, pruned_model, CBL_idx, Conv_idx, CBLidx2mask)

    #%%
    random_input = torch.rand((1, 3, img_size, img_size)).to(device)

    def obtain_avg_forward_time(input, model, repeat=200):
        model.eval()
        start = time.time()
        with torch.no_grad():
            for i in range(repeat):
                output = model(input)[0]
        avg_infer_time = (time.time() - start) / repeat
        return avg_infer_time, output

    print('\ntesting avg forward time...')
    pruned_forward_time, pruned_output = obtain_avg_forward_time(random_input, pruned_model)
    compact_forward_time, compact_output = obtain_avg_forward_time(random_input, compact_model)

    # The compact model must reproduce the masked-but-full-size model's outputs;
    # a large element-wise difference indicates a pruning bug.
    diff = (pruned_output - compact_output).abs().gt(0.001).sum().item()
    if diff > 0:
        print('Something wrong with the pruned model!')

    #%%
    # Evaluate the final pruned model on the test set and count its parameters.
    print('testing the mAP of final pruned model')
    with torch.no_grad():
        compact_model_metric = eval_model(compact_model)

    #%%
    # Compare parameter counts and metrics before and after pruning.
    metric_table = [
        ["Metric", "Before", "After"],
        ["mAP", f'{origin_model_metric[0][2]:.6f}', f'{compact_model_metric[0][2]:.6f}'],
        ["Parameters", f"{origin_nparameters}", f"{compact_nparameters}"],
        ["Inference", f'{pruned_forward_time:.4f}', f'{compact_forward_time:.4f}']
    ]
    print(AsciiTable(metric_table).table)

    #%%
    # Write the pruned cfg file and save the compact model's weights.
    pruned_cfg_name = opt.cfg.replace('/', f'/prune_{percent}_')
    pruned_cfg_file = write_cfg(pruned_cfg_name, [model.hyperparams.copy()] + compact_module_defs)
    print(f'Config file has been saved: {pruned_cfg_file}')

    compact_model_name = opt.weights.replace('/', f'/prune_{percent}_')
    if compact_model_name.endswith('.pt'):
        compact_model_name = compact_model_name.replace('.pt', '.weights')
    save_weights(compact_model, compact_model_name)
    print(f'Compact model has been saved: {compact_model_name}')
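
#%%
# Example invocation (a sketch: "prune.py" is an assumed name for whatever this
# file is called in your checkout; the paths are this script's argparse
# defaults):
#
#   python prune.py --cfg cfg/yolov3.cfg --data data/coco.data \
#       --weights weights/last.pt --percent 0.8 --img_size 416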