#coding:utf8 from config import opt import models import os import tqdm from data.dataset import ZhihuData,ZhihuALLData import torch as t import time import fire import torchnet as tnt from torch.utils import data from torch.autograd import Variable from utils.visualize import Visualizer from utils import get_score#,get_optimizer vis = Visualizer(opt.env) ''' 训练 stack-attention 不过这些模型比较差 没用过拟合 ''' def hook():pass def val(model,dataset): dataset.train(False) model.eval() dataloader = data.DataLoader(dataset, batch_size = opt.batch_size, shuffle = False, num_workers = opt.num_workers, pin_memory = True ) predict_label_and_marked_label_list=[] for ii,((title,content),label) in tqdm.tqdm(enumerate(dataloader)): title,content,label = (Variable(title[0].cuda()),Variable(title[1].cuda())),(Variable(content[0].cuda()),Variable(content[1].cuda())),Variable(label.cuda()) score = model(title,content) # !TODO: 优化此处代码 # 1. append # 2. for循环 # 3. topk 代替sort predict = score.data.topk(5,dim=1)[1].cpu().tolist() true_target = label.data.float().topk(5,dim=1)#[1].cpu().tolist()#sort(dim=1,descending=True) true_index=true_target[1][:,:5] true_label=true_target[0][:,:5] tmp= [] for jj in range(label.size(0)): true_index_=true_index[jj] true_label_=true_label[jj] true=true_index_[true_label_>0] tmp.append((predict[jj],true.tolist())) predict_label_and_marked_label_list.extend(tmp) del score dataset.train(True) model.train() scores,prec_,recall_,_ss=get_score(predict_label_and_marked_label_list) return (scores,prec_,recall_,_ss) def main(**kwargs): #动态加全职衰减 origin_weight_decay=1e-5 opt.parse(kwargs,print_=False) if opt.debug:import ipdb;ipdb.set_trace() ################################### # opt.model_names=['MultiCNNTextBNDeep','CNNText_inception', # #'RCNN', # 'LSTMText','CNNText_inception'] # opt.model_paths = ['checkpoints/MultiCNNTextBNDeep_word_0.410330780091','checkpoints/CNNText_tmp_word_0.41096749885', # #'checkpoints/RCNN_word_0.411511574999', # 'checkpoints/LSTMText_word_0.411994005382','checkpoints/CNNText_tmp_char_0.402429167301'] ###################################### # opt.model_names=['MultiCNNTextBNDeep', # #'CNNText_inception', # #'RCNN', # 'LSTMText_boost', # #'CNNText_inception_boost' # ] # opt.model_paths = ['checkpoints/MultiCNNTextBNDeep_word_0.410330780091', # # 'checkpoints/CNNText_tmp_word_0.41096749885', # #'checkpoints/RCNN_word_0.411511574999', # 'checkpoints/LSTMText_word_0.381833388089', # #'checkpoints/CNNText_tmp_0.376364647145' # ] #####################################################3 opt.model_names=['MultiCNNTextBNDeep', #'RCNN', 'LSTMText', 'CNNText_inception', 'CNNText_inception-boost' ] opt.model_paths = ['checkpoints/MultiCNNTextBNDeep_word_0.410330780091', 'checkpoints/LSTMText_word_0.381833388089', 'checkpoints/CNNText_tmp_0.380390420742', #'checkpoints/RCNN_word_0.411511574999', 'checkpoints/CNNText_tmp_0.376364647145' ] # opt.model_path='checkpoints/BoostModel_word_0.412524727048' #************************************** #############################################3 opt.model_names=['MultiCNNTextBNDeep', 'LSTMText', 'MultiCNNTextBNDeep-boost' ] opt.model_paths = ['checkpoints/MultiCNNTextBNDeep_word_0.410330780091', 'checkpoints/LSTMText_word_0.411994005382', None ] opt.model_path='checkpoints/BoostModel2_word_0.410618920827' #********************************************* # opt.model_names=['MultiCNNTextBNDeep','LSTMText','CNNText_inception','RCNN'] # opt.model_paths = ['checkpoints/MultiCNNTextBNDeep_0.37125473788','checkpoints/LSTMText_word_0.381833388089','checkpoints/CNNText_tmp_0.376364647145','checkpoints/RCNN_char_0.3456599248'] model = getattr(models,opt.model)(opt).cuda() # if opt.model_path: # model.load(opt.model_path) print(model) opt.parse(kwargs,print_=True) vis.reinit(opt.env) pre_loss=1.0 lr,lr2=opt.lr,opt.lr2 loss_function = getattr(models,opt.loss)() if opt.all:dataset = ZhihuALLData(opt.train_data_path,opt.labels_path,type_=opt.type_) else :dataset = ZhihuData(opt.train_data_path,opt.labels_path,type_=opt.type_) dataloader = data.DataLoader(dataset, batch_size = opt.batch_size, shuffle = opt.shuffle, num_workers = opt.num_workers, pin_memory = True ) optimizer = model.get_optimizer(opt.lr,opt.lr2,0) loss_meter = tnt.meter.AverageValueMeter() score_meter=tnt.meter.AverageValueMeter() best_score = 0 # pre_score = 0 for epoch in range(opt.max_epoch): loss_meter.reset() score_meter.reset() for ii,((title,content),label) in tqdm.tqdm(enumerate(dataloader)): title,content,label = (Variable(title[0].cuda()),Variable(title[1].cuda())),(Variable(content[0].cuda()),Variable(content[1].cuda())),Variable(label.cuda()) optimizer.zero_grad() score = model(title,content) loss = loss_function(score,opt.weight*label.float()) loss_meter.add(loss.data[0]) loss.backward() optimizer.step() if ii%opt.plot_every ==opt.plot_every-1: if os.path.exists(opt.debug_file): import ipdb ipdb.set_trace() predict = score.data.topk(5,dim=1)[1].cpu().tolist()#(dim=1,descending=True)[1][:,:5].tolist() true_target = label.data.float().topk(5,dim=1)#[1].cpu().tolist()#sort(dim=1,descending=True) true_index=true_target[1][:,:5] true_label=true_target[0][:,:5] predict_label_and_marked_label_list=[] for jj in range(label.size(0)): true_index_=true_index[jj] true_label_=true_label[jj] true=true_index_[true_label_>0] predict_label_and_marked_label_list.append((predict[jj],true.tolist())) score_,prec_,recall_,_ss=get_score(predict_label_and_marked_label_list) score_meter.add(score_) vis.vis.text('prec:%s,recall:%s,score:%s,a:%s' %(prec_,recall_,score_,_ss),win='tmp') vis.plot('scores', score_meter.value()[0]) #eval() vis.plot('loss', loss_meter.value()[0]) # 随机展示一个输出的分布 k = t.randperm(label.size(0))[0] output=t.nn.functional.sigmoid(score) # vis.vis.histogram( # output.data[k].view(-1).cpu(), win=u'output_hist', opts=dict # (title='output_hist')) # print "epoch:%4d/%4d,time: %.8f,loss: %.8f " %(epoch,ii,time.time()-start,loss_meter.value()[0]) if ii%opt.decay_every == opt.decay_every-1: del loss scores,prec_,recall_ ,_ss= val(model,dataset) if scores>best_score: best_score = scores best_path = model.save(name = str(scores),new=True) vis.log({' epoch:':epoch,' lr: ':lr,'scores':scores,'prec':prec_,'recall':recall_,'ss':_ss,'scores_train':score_meter.value()[0],'loss':loss_meter.value()[0]}) if scores < best_score: model.load(best_path,change_opt=False) #lr = lr*opt.lr_decay #optimizer = model.get_optimizer(lr) lr = lr * opt.lr_decay # 第二种降低学习率的方法:不会有moment等的丢失 if lr2==0:lr2=2e-4 else : lr2 = lr2*opt.lr_decay optimizer = model.get_optimizer(lr,lr2,0) origin_weight_decay=5*origin_weight_decay # optimizer = model.get_optimizer(lr,lr2,0,weight_decay=origin_weight_decay) # origin_weight_decay=5*origin_weight_decay # for param_group in optimizer.param_groups: # param_group['lr'] *= opt.lr_decay # if param_group['lr'] ==0: # param_group['lr'] = 1e-4 pre_loss = loss_meter.value()[0] # pre_score = score_meter.value()[0] # pre_score = scores loss_meter.reset() score_meter.reset() if lr < opt.min_lr: break # model.save() if __name__=="__main__": fire.Fire()