import torch.nn.functional as F
import torch
import random
import numpy as np
from fastNLP import Const
from fastNLP import CrossEntropyLoss
from fastNLP import AccuracyMetric
from fastNLP import Tester
import os
from fastNLP import logger


def should_mask(name, t=''):
    """Decide whether the parameter called ``name`` takes part in pruning.

    Biases, embedding weights, initial states (``c0``/``h0``) and the output
    layers of other tasks (``t`` is the current task name) are excluded.
    """
    if 'bias' in name:
        return False
    if 'embedding' in name:
        parts = name.split('.')
        if parts[-1] != 'weight':
            return False
        if 'embedding' in parts[-2]:
            return False
    if 'c0' in name:
        return False
    if 'h0' in name:
        return False
    if 'output' in name and t not in name:
        return False
    return True


def get_init_mask(model):
    """Build an all-ones mask (nothing pruned yet) for every maskable parameter."""
    init_masks = {}
    for name, param in model.named_parameters():
        if should_mask(name):
            init_masks[name + '.mask'] = torch.ones_like(param)
    return init_masks


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed + 100)
    torch.manual_seed(seed + 200)
    torch.cuda.manual_seed_all(seed + 300)


def get_parameters_size(model):
    result = {}
    for name, p in model.state_dict().items():
        result[name] = p.size()
    return result


def prune_by_proportion_model(model, proportion, task):
    """Layer-wise magnitude pruning: for each maskable parameter, keep the top
    ``proportion`` of its still-alive weights by absolute value."""
    for name, p in model.named_parameters():
        if not should_mask(name, task):
            continue
        tensor = p.data.cpu().numpy()
        # Only weights that survived earlier rounds count towards the percentile.
        index = np.nonzero(model.mask[task][name + '.mask'].data.cpu().numpy())
        alive = tensor[index]
        percentile_value = np.percentile(np.abs(alive), (1 - proportion) * 100)
        prune_by_threshold_parameter(p, model.mask[task][name + '.mask'], percentile_value)


def prune_by_proportion_model_global(model, proportion, task):
    """Global magnitude pruning: pool the alive weights of all maskable
    parameters and apply one shared magnitude threshold."""
    alive = None
    for name, p in model.named_parameters():
        if not should_mask(name, task):
            continue
        tensor = p.data.cpu().numpy()
        index = np.nonzero(model.mask[task][name + '.mask'].data.cpu().numpy())
        if alive is None:
            alive = tensor[index]
        else:
            alive = np.concatenate([alive, tensor[index]], axis=0)
    percentile_value = np.percentile(np.abs(alive), (1 - proportion) * 100)
    for name, p in model.named_parameters():
        if should_mask(name, task):
            prune_by_threshold_parameter(p, model.mask[task][name + '.mask'], percentile_value)


def prune_by_threshold_parameter(p, mask, threshold):
    """Zero the mask entries of all weights whose magnitude is at or below ``threshold``."""
    p_abs = torch.abs(p)
    new_mask = (p_abs > threshold).float()
    # Multiply in place so weights pruned in earlier rounds stay pruned.
    mask[:] *= new_mask
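
# A minimal standalone sketch (the tensors below are hypothetical, not taken from
# the pipeline) of what prune_by_threshold_parameter does: weights whose magnitude
# does not exceed the threshold are zeroed in the mask, and because the mask is
# updated multiplicatively, a weight pruned in an earlier round can never revive.
def _demo_threshold_pruning():
    p = torch.tensor([0.05, -0.4, 0.2, -0.01])
    mask = torch.tensor([1.0, 1.0, 0.0, 1.0])  # the third weight was pruned earlier
    prune_by_threshold_parameter(p, mask, threshold=0.1)
    print(mask)  # tensor([0., 1., 0., 0.]): only -0.4 clears the threshold and was still alive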
def one_time_train_and_prune_single_task(trainer, PRUNE_PER,
                                         optimizer_init_state_dict=None,
                                         model_init_state_dict=None,
                                         is_global=None):
    # One lottery-ticket style round: rewind model and optimizer to their
    # initial states, train to convergence, then prune the current task's mask.
    trainer.optimizer.load_state_dict(optimizer_init_state_dict)
    trainer.model.load_state_dict(model_init_state_dict)
    trainer.train(load_best_model=True)
    if is_global:
        prune_by_proportion_model_global(trainer.model, PRUNE_PER, trainer.model.now_task)
    else:
        prune_by_proportion_model(trainer.model, PRUNE_PER, trainer.model.now_task)


def iterative_train_and_prune_single_task(get_trainer, args, model, train_set,
                                          dev_set, test_set, device, save_path=None):
    """Iteratively train and prune a single task.

    :param get_trainer: factory returning a fresh fastNLP Trainer for each round
    :param args: must provide ``args.prune`` (final proportion of weights to keep)
        and ``args.iter`` (number of train-and-prune rounds)
    :param save_path: a directory which will be filled with the per-round masks
    """
    import math
    import copy
    PRUNE = args.prune
    ITER = args.iter
    trainer = get_trainer(args, model, train_set, dev_set, test_set, device)
    optimizer_init_state_dict = copy.deepcopy(trainer.optimizer.state_dict())
    model_init_state_dict = copy.deepcopy(trainer.model.state_dict())
    if save_path is not None and not os.path.exists(save_path):
        os.makedirs(save_path)
    mask_count = 0
    model = trainer.model
    task = trainer.model.now_task
    for name, p in model.mask[task].items():
        mask_count += torch.sum(p).item()
    init_mask_count = mask_count
    logger.info('init mask count:{}'.format(mask_count))
    # Prune the same proportion each round so that exactly PRUNE of the weights
    # survive after ITER rounds: prune_per_iter ** ITER == PRUNE.
    prune_per_iter = math.pow(PRUNE, 1 / ITER)
    for i in range(ITER):
        trainer = get_trainer(args, model, train_set, dev_set, test_set, device)
        one_time_train_and_prune_single_task(trainer, prune_per_iter,
                                             optimizer_init_state_dict,
                                             model_init_state_dict)
        if save_path is not None:
            with open(os.path.join(save_path, task + '_mask_' + str(i) + '.pkl'), 'wb') as f:
                torch.save(model.mask[task], f)
        mask_count = 0
        for name, p in model.mask[task].items():
            mask_count += torch.sum(p).item()
        logger.info('{}th training mask count: {} / {} = {}%'.format(
            i, mask_count, init_mask_count, mask_count / init_mask_count * 100))


def get_appropriate_cuda(task_scale='s'):
    if task_scale not in {'s', 'm', 'l'}:
        logger.info('task scale wrong!')
        exit(2)
    import pynvml
    pynvml.nvmlInit()
    total_cuda_num = pynvml.nvmlDeviceGetCount()
    # First pass: return any almost-idle GPU.
    for i in range(total_cuda_num):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # i is the GPU id
        memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle)
        logger.info('{} mem:{} util:{}'.format(i, memInfo.used / memInfo.total, utilizationInfo.gpu))
        # NVML reports utilization as a percentage in [0, 100].
        if memInfo.used / memInfo.total < 0.15 and utilizationInfo.gpu < 20:
            logger.info('{} {}'.format(i, memInfo.used / memInfo.total))
            return 'cuda:' + str(i)
    # Second pass: fall back to the GPU with the most free memory, provided it
    # clears the per-scale minimum (thresholds in MB; NVML reports bytes).
    if task_scale == 's':
        max_memory = 2000
    elif task_scale == 'm':
        max_memory = 6000
    else:
        max_memory = 9000
    max_id = -1
    for i in range(total_cuda_num):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        free_mb = memInfo.free / (1024 ** 2)
        if max_memory < free_mb:
            max_memory = free_mb
            max_id = i
    if max_id == -1:
        logger.info('no appropriate gpu, wait!')
        exit(2)
    return 'cuda:' + str(max_id)


def print_mask(mask_dict):
    def seq_mul(*X):
        res = 1
        for x in X:
            res *= x
        return res

    for name, p in mask_dict.items():
        total_size = seq_mul(*p.size())
        # Count the surviving (nonzero) entries of the mask.
        unmasked_size = torch.nonzero(p).size(0)
        print(name, ':', unmasked_size, '/', total_size, '=', unmasked_size / total_size * 100, '%')
    print()
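
# A minimal standalone sketch (hypothetical numbers) of the geometric schedule used
# by iterative_train_and_prune_single_task: keeping PRUNE ** (1 / ITER) of the
# remaining weights in each round leaves PRUNE of the original weights at the end.
def _demo_prune_schedule(PRUNE=0.1, ITER=5):
    import math
    prune_per_iter = math.pow(PRUNE, 1 / ITER)
    remaining = 1.0
    for i in range(ITER):
        remaining *= prune_per_iter
        print('after round {}: {:.1%} of the weights remain'.format(i + 1, remaining))
    # `remaining` now equals PRUNE up to floating-point error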
def check_words_same(dataset_1, dataset_2, field_1, field_2):
    if len(dataset_1[field_1]) != len(dataset_2[field_2]):
        logger.info('CHECK: example num not same!')
        return False
    for i, words in enumerate(dataset_1[field_1]):
        if len(dataset_1[field_1][i]) != len(dataset_2[field_2][i]):
            logger.info('CHECK: {}th example length not same'.format(i))
            logger.info('1:{}'.format(dataset_1[field_1][i]))
            logger.info('2:{}'.format(dataset_2[field_2][i]))
            return False
    logger.info('CHECK: totally same!')
    return True


def get_now_time():
    from datetime import datetime, timezone, timedelta
    # Build an aware UTC timestamp, then convert to UTC+8; a naive utcnow()
    # would be misread as local time by astimezone().
    dt = datetime.now(timezone.utc)
    tzutc_8 = timezone(timedelta(hours=8))
    local_dt = dt.astimezone(tzutc_8)
    result = "_{}_{}_{}__{}_{}_{}".format(local_dt.year, local_dt.month, local_dt.day,
                                          local_dt.hour, local_dt.minute, local_dt.second)
    return result


def get_bigrams(words):
    """Pair every token with its right neighbour; the last token is paired with '<end>'."""
    result = []
    for i, w in enumerate(words):
        if i != len(words) - 1:
            result.append(words[i] + words[i + 1])
        else:
            result.append(words[i] + '<end>')
    return result


def print_info(*inp, islog=False, sep=' '):
    # islog=True routes the message to the fastNLP logger (the two branches
    # were previously inverted); otherwise print to stdout.
    if islog:
        logger.info(sep.join(map(str, inp)))
    else:
        print(*inp, sep=sep)


def better_init_rnn(rnn, coupled=False):
    import torch.nn as nn
    # Coupled LSTM variants stack 3 gate blocks; standard LSTMs stack 4.
    repeat_size = 3 if coupled else 4
    if hasattr(rnn, 'num_layers'):
        for i in range(rnn.num_layers):
            nn.init.orthogonal_(getattr(rnn, 'weight_ih_l' + str(i)).data)
            # Recurrent weights become identity blocks stacked once per gate,
            # matching the (repeat_size * hidden_size, hidden_size) shape.
            weight_hh_data = torch.eye(rnn.hidden_size).repeat(repeat_size, 1)
            with torch.no_grad():
                getattr(rnn, 'weight_hh_l' + str(i)).set_(weight_hh_data)
            nn.init.constant_(getattr(rnn, 'bias_ih_l' + str(i)).data, val=0)
            nn.init.constant_(getattr(rnn, 'bias_hh_l' + str(i)).data, val=0)
        if rnn.bidirectional:
            for i in range(rnn.num_layers):
                nn.init.orthogonal_(getattr(rnn, 'weight_ih_l' + str(i) + '_reverse').data)
                weight_hh_data = torch.eye(rnn.hidden_size).repeat(repeat_size, 1)
                with torch.no_grad():
                    getattr(rnn, 'weight_hh_l' + str(i) + '_reverse').set_(weight_hh_data)
                nn.init.constant_(getattr(rnn, 'bias_ih_l' + str(i) + '_reverse').data, val=0)
                nn.init.constant_(getattr(rnn, 'bias_hh_l' + str(i) + '_reverse').data, val=0)
    else:
        # A single RNN cell (e.g. a custom LSTM cell) exposing weight_ih / weight_hh.
        nn.init.orthogonal_(rnn.weight_ih.data)
        weight_hh_data = torch.eye(rnn.hidden_size).repeat(repeat_size, 1)
        with torch.no_grad():
            rnn.weight_hh.set_(weight_hh_data)
        print('rnn param size:{},{}'.format(rnn.weight_hh.size(), type(rnn)))
        if rnn.bias:
            # The biases are set to zero vectors.
            nn.init.constant_(rnn.bias_ih.data, val=0)
            nn.init.constant_(rnn.bias_hh.data, val=0)
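
# A minimal sketch (hypothetical sizes) of better_init_rnn applied to a standard
# nn.LSTM: input-to-hidden weights become orthogonal, recurrent weights become
# four stacked identity matrices (one per gate), and all biases become zero.
def _demo_better_init_rnn():
    import torch.nn as nn
    lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=1)
    better_init_rnn(lstm)
    assert torch.equal(lstm.weight_hh_l0.data, torch.eye(20).repeat(4, 1))
    assert torch.all(lstm.bias_ih_l0.data == 0) and torch.all(lstm.bias_hh_l0.data == 0)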