# coding=utf8 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ResNet Train/Eval module. """ import time import sys import os # import glob import numpy as np import dataloader import json from tqdm import tqdm from collections import Counter import densenet import resnet from PIL import Image import torchvision import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.autograd import Variable from torch.utils.data import DataLoader import torch.nn.functional as F from sklearn.metrics import roc_auc_score from tools import parse from glob import glob from skimage import measure import sys os.environ["CUDA_VISIBLE_DEVICES"] = "0" reload(sys) sys.setdefaultencoding('utf8') import traceback # from moxing.framework import file args = parse.args # anchor大小 args.anchors = [8, 12, 18, 27, 40, 60] args.stride = 8 args.image_size = [1300,64] #args.image_size = [64, 64] datadir = parse.datadir #datadir = '/home/work/user-job-dir/ocr_densenet/data' data_dir_obs = 's3://densenet-214/data/dataset' model_dir='/home/work/user-job-dir/model_15' model_dir_obs='s3://densenet-214/out/model_15' ###事先执行##### #print 'datadir is exist ? ', os.path.exists(datadir) print 'datadir path', args.data_dir print 'data_dir_obs path', data_dir_obs print 'current path', os.getcwd() print '================================================' print os.listdir(os.getcwd()) # file.copy_parallel(data_dir_obs, args.data_dir) # file.copy_parallel(model_dir_obs, model_dir) print 'batch-size:',args.batch_size print '===================data_dir=============================' # print os.listdir(args.data_dir) class DenseNet121(nn.Module): """Model modified. The architecture of our model is the same as standard DenseNet121 except the classifier layer which has an additional sigmoid function. """ def __init__(self, out_size): super(DenseNet121, self).__init__() self.inplanes = 1024 self.densenet121 = densenet.densenet121(pretrained=False, small=args.small) num_ftrs = self.densenet121.classifier.in_features self.classifier_font = nn.Sequential( # 这里可以用fc做分类 # nn.Linear(num_ftrs, out_size) # 这里可以用1×1卷积做分类 nn.Conv2d(num_ftrs, out_size, kernel_size=1, bias=False) ) self.train_params = [] self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2) # 用于构建Resnet中的4个blocks def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample)) self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(self.inplanes, planes)) return nn.Sequential(*layers) # 定义数据在层之间的流动顺序 def forward(self, x, phase='train'): feats = self.densenet121(x) # (32, 1024, 2, 16) if not args.small: feats = F.max_pool2d(feats, kernel_size=2, stride=2) # (32, 1024, 1, 8) out = self.classifier_font(feats) # (32, 1824, 1, 8) out_size = out.size() # print out.size() out = out.view(out.size(0), out.size(1), -1) # (32, 1824, 8) # print out.size() if phase == 'train': out = F.adaptive_max_pool1d(out, output_size=(1)).view(out.size(0), -1) # (32, 1824) return out else: out = out.transpose(1, 2).contiguous() out = out.view(out_size[0], out_size[2], out_size[3], out_size[1]) # (32, 1, 8, 1824) return out, feats class Loss(nn.Module): def __init__(self): super(Loss, self).__init__() self.classify_loss = nn.BCELoss() self.sigmoid = nn.Sigmoid() self.regress_loss = nn.SmoothL1Loss() def forward(self, font_output, font_target, weight=None, use_hard_mining=False): font_output = self.sigmoid(font_output) font_loss = F.binary_cross_entropy(font_output, font_target, weight) # hard_mining if use_hard_mining: font_output = font_output.view(-1) font_target = font_target.view(-1) pos_index = font_target > 0.5 neg_index = font_target == 0 # pos pos_output = font_output[pos_index] pos_target = font_target[pos_index] num_hard_pos = max(len(pos_output) / 4, min(5, len(pos_output))) if len(pos_output) > 5: pos_output, pos_target = hard_mining(pos_output, pos_target, num_hard_pos, largest=False) pos_loss = self.classify_loss(pos_output, pos_target) * 0.5 # neg num_hard_neg = len(pos_output) * 2 neg_output = font_output[neg_index] neg_target = font_target[neg_index] neg_output, neg_target = hard_mining(neg_output, neg_target, num_hard_neg, largest=True) neg_loss = self.classify_loss(neg_output, neg_target) * 0.5 font_loss += pos_loss + neg_loss else: pos_loss, neg_loss = font_loss, font_loss return [font_loss, pos_loss, neg_loss] def _forward(self, font_output, font_target, weight, bbox_output=None, bbox_label=None, seg_output=None, seg_labels=None): font_output = self.sigmoid(font_output) font_loss = F.binary_cross_entropy(font_output, font_target, weight) acc = [] if bbox_output is not None: # bbox_loss = 0 bbox_output = bbox_output.view((-1, 4)) bbox_label = bbox_label.view((-1, 4)) pos_index = bbox_label[:, -1] >= 0.5 pos_index = pos_index.unsqueeze(1).expand(pos_index.size(0), 4) neg_index = bbox_label[:, -1] <= -0.5 neg_index = neg_index.unsqueeze(1).expand(neg_index.size(0), 4) # 正例 pos_label = bbox_label[pos_index].view((-1, 4)) pos_output = bbox_output[pos_index].view((-1, 4)) lx, ly, ld, lc = pos_label[:, 0], pos_label[:, 1], pos_label[:, 2], pos_label[:, 3] ox, oy, od, oc = pos_output[:, 0], pos_output[:, 1], pos_output[:, 2], pos_output[:, 3] regress_loss = [ self.regress_loss(ox, lx), self.regress_loss(oy, ly), self.regress_loss(od, ld), ] pc = self.sigmoid(oc) acc.append((pc >= 0.5).data.cpu().numpy().astype(np.float32).sum()) acc.append(len(pc)) # print pc.size(), lc.size() classify_loss = self.classify_loss(pc, lc) * 0.5 # 负例 neg_label = bbox_label[neg_index].view((-1, 4)) neg_output = bbox_output[neg_index].view((-1, 4)) lc = neg_label[:, 3] oc = neg_output[:, 3] pc = self.sigmoid(oc) acc.append((pc <= 0.5).data.cpu().numpy().astype(np.float32).sum()) acc.append(len(pc)) # print pc.size(), lc.size() classify_loss += self.classify_loss(pc, lc + 1) * 0.5 # seg_loss seg_output = seg_output.view(-1) seg_labels = seg_labels.view(-1) pos_index = seg_labels > 0.5 neg_index = seg_labels < 0.5 seg_loss = 0.5 * self.classify_loss(seg_output[pos_index], seg_labels[pos_index]) + \ 0.5 * self.classify_loss(seg_output[neg_index], seg_labels[neg_index]) seg_tpr = (seg_output[pos_index] > 0.5).data.cpu().numpy().astype(np.float32).sum() / len( seg_labels[pos_index]) seg_tnr = (seg_output[neg_index] < 0.5).data.cpu().numpy().astype(np.float32).sum() / len( seg_labels[neg_index]) # print seg_output[neg_index] # print seg_labels[neg_index] else: return font_loss if args.model == 'resnet': loss = font_loss + classify_loss + seg_loss else: loss = font_loss + classify_loss + seg_loss for reg in regress_loss: loss += reg # if args.model == 'resnet': # loss = seg_loss return [loss, font_loss, seg_loss, classify_loss] + regress_loss + acc + [seg_tpr, seg_tnr] font_num = font_target.sum(0).data.cpu().numpy() font_loss = 0 for di in range(font_num.shape[0]): if font_num[di] > 0: font_output_i = font_output[:, di] font_target_i = font_target[:, di] pos_font_index = font_target_i > 0.5 font_loss += 0.5 * self.classify_loss(font_output_i[pos_font_index], font_target_i[pos_font_index]) neg_font_index = font_target_i < 0.5 if len(font_target_i[neg_font_index]) > 0: font_loss += 0.5 * self.classify_loss(font_output_i[neg_font_index], font_target_i[neg_font_index]) font_loss = font_loss / (font_num > 0).sum() return font_loss # ''' def hard_mining(neg_output, neg_labels, num_hard, largest=True): num_hard = min(max(num_hard, 10), len(neg_output)) _, idcs = torch.topk(neg_output, min(num_hard, len(neg_output)), largest=largest) neg_output = torch.index_select(neg_output, 0, idcs) neg_labels = torch.index_select(neg_labels, 0, idcs) return neg_output, neg_labels def save_model(save_dir, phase, name, epoch, f1score, model): if not os.path.exists(save_dir): os.mkdir(save_dir) save_dir = os.path.join(save_dir, args.model) if not os.path.exists(save_dir): os.mkdir(save_dir) save_dir = os.path.join(save_dir, phase) if not os.path.exists(save_dir): os.mkdir(save_dir) state_dict = model.state_dict() for key in state_dict.keys(): state_dict[key] = state_dict[key].cpu() state_dict_all = { 'state_dict': state_dict, 'epoch': epoch, 'f1score': f1score, } saveStr = '{:s}.ckpt'.format(name) torch.save(state_dict_all, os.path.join(save_dir, saveStr)) # file.copy(os.path.join(save_dir, saveStr), os.path.join(args.save_dir_obs, saveStr)) if 'best' in name and f1score > 0.3: bestStr = '{:s}_{:s}.ckpt'.format(name, str(epoch)) torch.save(state_dict_all, os.path.join(save_dir, bestStr)) # file.copy(os.path.join(save_dir, bestStr), # os.path.join(args.save_dir_obs, bestStr)) def mkdir(path): if not os.path.exists(path): os.mkdir(path) def test(model, train_loader, phase='test'): print '\ntest {:s}_files, epoch: {:d}'.format(phase, 1) mkdir('../../../intermediate_file/recognize_result') model.eval() f1score_list = [] recall_list = [] precision_list = [] word_index_dict = json.load(open(args.word_index_json, 'r')) index_word_dict = {v: k for k, v in word_index_dict.items()} # result_file = file.File(datadir + '/data/result/{:d}_{:s}_result.csv'.format(epoch, phase), 'w') #modify here for single word # result_file = open(datadir + '/data/result/{:d}_{:s}_result_single_64.csv'.format(1, phase), 'w') #origin_csv is here result_file = open('../../../intermediate_file/recognize_result/result.csv', 'w') result_file.write('name,content\n') name_f1score_dict = dict() # 保存densenet生成的feature # feat_dir = args.data_dir.replace('dataset', 'feats') # mkdir(feat_dir) # feat_dir = os.path.join(feat_dir, phase) # print feat_dir # mkdir(feat_dir) names = [] probs_all = [] for i, data in enumerate(train_loader): if i % 50 == 0: print('step[{:d}] OK...'.format(i)) name = data[0][0].split('/')[-1].split('.seg')[0] names.append(name) images, labels = [Variable(x.cuda(async=True)) for x in data[1:3]] # print 'images.shape:',images.shape,'len(images.size()):',len(images.size()) if len(images.size()) == 5: images = images[0] probs, feats = model(images, 'test') probs_all.append(probs.data.cpu().numpy().max(2).max(1).max(0)) # pdb.set_trace() probs_temp=probs.data.cpu().numpy() #==================================== probs_temp = probs_temp.max(1) max_preds_line=np.max(probs_temp,axis=2)#max preds among 9115 words max_preds_index=np.argmax(probs_temp , axis=2) # print 'max_preds_index.shape',max_preds_index.shape # print 'max_preds_line.shape', # print 'max_preds_index:',max_preds_index.shape,'\nmax_preds_line:',max_preds_line probs_temp=probs_temp*0-float('inf') # max_preds_index[0,] for i in range(max_preds_index.shape[1]): probs_temp[0,i,max_preds_index[0,i]]=max_preds_line[0,i] preds = probs_temp > 0.0 # (-1, 8, 1824) # result_file.write(name+',') result = u'' # last_set = set() # all_set = set() # if args.feat: # # 保存所有的feat # feats = feats.data.cpu().numpy() # if i == 0: # print feats.shape # np.save(os.path.join(feat_dir, name.replace('.jpg', '.npy')), feats) # if len(feats) > 1: # feats: [-1, 1024, 1, 8] # # 多个patch # new_feats = [] # for i, feat in enumerate(feats): # if i == 0: # # 第一个patch,保存前6个 # new_feats.append(feat[:, :, :6]) # elif i == len(feats) - 1: # # 最后一个patch,保存后6个 # new_feats.append(feat[:, :, 2:]) # else: # # 保存中间4个 # new_feats.append(feat[:, :, 2:6]) # feats = np.concatenate(new_feats, 2) # 这种方法用于检测不同区域的同一个字,当同一个字同一个区域出现时,可能检测不到多次 # print 'name',name,'before preds.shape:',preds.shape,'type(preds):',type(preds) # preds = preds.max(1) # 沿着竖直方向pooling # print 'max_preds_index:',max_preds_index.shape,'\nmax_preds_line:',max_preds_line # print preds # if len(preds) > 1: # print name # print 'name',name,'after preds.shape:',preds.shape,'type(preds):',type(preds) lines_words_indexs=[] # lines_list=[] # print 'name',name for patch_i, patch_pred in enumerate(preds):#patch_pred stand for per image feature map 64*9115 # print 'patch_i:',patch_i # each_feature_map_words=set() num=0 for part_i, part_pred in enumerate(patch_pred):#part_i:0-64 part_pred:1*9115 # print 'name',name,'part_pred.shape:',part_pred.shape,'type(part_pred):',type(part_pred) # print part_i # new_set = set() # print result for idx, p in enumerate(part_pred):#idx:0-9114 if p: w='' # w = index_word_dict[idx] lines_words_indexs.append([part_i,w,idx]) # print part_i,':',w # new_set.add(w) # # if word_temp!=w # # count+=1 # if w not in all_set: # # 从没见过的字 # all_set.add(w) # result += w # # count_list.append(count) # # count=0 # # elif w!=result[-1]: # # result += w # # print 'result[-1]:',result[-1] # elif w not in last_set: # # not in last line # if patch_i == 0: #patch always 0 # # 第一个patch # 上一个部分没有这个字 # result += w # elif part_i >= preds.shape[1] / 2: # # 后续patch的后一半,不写 # 上一个部分没有这个字 # result += w # last_set = new_set words_total_list=[] word_list=[] lines_list=[] line_list=[] index_list=[] indexes_list=[] neigbour=2 for i,line_word_index in enumerate(lines_words_indexs): if i==0: line_list.append(line_word_index[0]) word_list.append(line_word_index[1]) index_list.append(line_word_index[2]) elif (line_word_index[0]-line_list[-1])<neigbour: line_list.append(line_word_index[0]) word_list.append(line_word_index[1]) index_list.append(line_word_index[2]) elif (line_word_index[0]-line_list[-1])>=neigbour: lines_list.append(line_list) words_total_list.append(word_list) indexes_list.append(index_list) line_list=[] word_list=[] index_list=[] line_list.append(line_word_index[0]) word_list.append(line_word_index[1]) index_list.append(line_word_index[2])#clustering according the line number if i==(len(lines_words_indexs)-1): lines_list.append(line_list) words_total_list.append(word_list) indexes_list.append(index_list) # for i,word in enumerate(words_total_list):#find the most common appear word in each small cluster word list # c=Counter(word) # x=c.most_common() # # print 'x',x # word=[item[0] for item in x if item[1]==x[0][1]] # words_total_list[i]=word# for i,index in enumerate(indexes_list):#find the most common appear word_index in each small cluster index list c=Counter(index) x=c.most_common() # print 'x',x index=[item[0] for item in x if item[1]==x[0][1]] indexes_list[i]=index for i,indexes in enumerate(indexes_list):#find the biggest probability word index in each small cluster index list if len(indexes)>1: # print 'lines_list[i]:',lines_list[i] # print probs_temp.shape,type(probs_temp) prob_temp=np.max(probs_temp[0,lines_list[i],:],0)#become 1*9105 result_word_index = np.argmax(prob_temp,0) indexes=[] indexes.append(result_word_index) indexes_list[i]=indexes for index in indexes_list:#map the word index to the word result result+=index_word_dict[index[0]] # probs_temp[lines_list[i]] # c=Counter(index) # x=c.most_common() # # print 'x',x # index=[item[0] for item in x if item[1]==x[0][1]] # indexes_list[i]=index # word=[] # print 'lines_list:',lines_list # print lines_list # print json.dumps(words_total_list, encoding="UTF-8", ensure_ascii=False) # print indexes_list result = result.replace(u'"', u'') if u',' in result: result = '"' + result + '"' if len(result) == 0: global_prob = probs.data.cpu().numpy().max(0).max(0).max(0) max_index = global_prob.argmax() result = index_word_dict[max_index] print 'name',name, 'result',result result_file.write(name + ',' + result + '\n') # result_file.write('\n') if phase == 'test': continue result_file.close() # import pandas as pd # re = pd.read_csv(datadir + '/data/result/{:d}_{:s}_result.csv'.format(epoch, phase)) # re.columns = ['target_file', 'text'] # submit = pd.read_csv(datadir + '/submission.csv') # submit = pd.merge(submit, re, how='left', on=['target_file']) # submit = submit.drop(['target_file'], axis=1) # submit = submit.replace(to_replace='None', value=20) # submit = submit.fillna('上') # submit.to_csv(datadir + '/predict.csv', header=True, index=None, encoding='utf-8') # file.copy(datadir + '/predict.csv', args.data_dir_obs + '/predict.csv') def get_weight(labels): labels = labels.data.cpu().numpy() weights = np.zeros_like(labels) # weight_false = 1.0 / ((labels<0.5).sum() + 10e-20) # weight_true = 1.0 / ((labels>0.5).sum() + 10e-20) weight_false = 1.0 / ((labels < 0.5).sum(0) + 10e-20) label_true = (labels > 0.5).sum(0) for i in range(labels.shape[1]): label_i = labels[:, i] weight_i = np.ones(labels.shape[0]) * weight_false[i] # weight_i = np.ones(labels.shape[0]) * weight_false if label_true[i] > 0: weight_i[label_i > 0.5] = 1.0 / label_true[i] weights[:, i] = weight_i weights *= np.ones_like(labels).sum() / (weights.sum() + 10e-20) weights[labels < -0.5] = 0 return weights def train_eval(epoch, model, train_loader, loss, optimizer, best_f1score=0, phase='train'): print '\n', epoch, phase if 'train' in phase: model.train() else: model.eval() loss_list = [] f1score_list = [] recall_list = [] precision_list = [] for i, data in enumerate(train_loader): images, labels = [Variable(x.cuda(async=True)) for x in data[1:3]] weights = torch.from_numpy(get_weight(labels)).cuda(async=True) probs = model(images) # 训练阶段 if 'train' in phase: loss_output = loss(probs, labels, weights, args.hard_mining) try: optimizer.zero_grad() loss_output[0].backward() optimizer.step() loss_list.append([x.data.cpu().numpy() for x in loss_output]) except: # pass traceback.print_exc() # 计算 f1score, recall, precision ''' x = probs.data.cpu().numpy() l = labels.data.cpu().numpy() print (get_weight(labels) * l).sum() l = 1 - l print (get_weight(labels) * l).sum() print x.max() print x.min() print x.mean() print # ''' preds = probs.data.cpu().numpy() > 0 labels = labels.data.cpu().numpy() for pred, label in zip(preds, labels): pred[label < 0] = -1 if label.sum() < 0.5: continue tp = (pred + label == 2).sum() tn = (pred + label == 0).sum() fp = (pred - label == 1).sum() fn = (pred - label == -1).sum() precision = 1.0 * tp / (tp + fp + 10e-20) recall = 1.0 * tp / (tp + fn + 10e-20) f1score = 2. * precision * recall / (precision + recall + 10e-20) precision_list.append(precision) recall_list.append(recall) f1score_list.append(f1score) if 'train' in phase and i % 50 == 0: loss_mean = np.array(loss_list).mean(0) print('step[{:d}] loss: {:3.4f} pos loss: {:3.4f} neg loss: {:3.4f}'.format(i, loss_mean[0], loss_mean[1], loss_mean[2])) # 保存中间结果到 data/middle_result,用于分析 if i == 0: images = images.data.cpu().numpy() * 128 + 128 if phase == 'pretrain': bbox_labels = bbox_labels.data.cpu().numpy() seg_labels = seg_labels.data.cpu().numpy() seg_output = seg_output.data.cpu().numpy() for ii in range(len(images)): middle_dir = os.path.join(args.save_dir, 'middle_result') if not os.path.exists(middle_dir): os.mkdir(middle_dir) middle_dir = os.path.join(middle_dir, phase) if not os.path.exists(middle_dir): os.mkdir(middle_dir) Image.fromarray(images[ii].astype(np.uint8).transpose(1, 2, 0)).save( os.path.join(middle_dir, str(ii) + '.image.png')) if phase == 'pretrain': segi = seg_labels[ii] _segi = np.array([segi, segi, segi]) * 255 segi = np.zeros([3, _segi.shape[1] * 2, _segi.shape[2] * 2]) for si in range(segi.shape[1]): for sj in range(segi.shape[2]): segi[:, si, sj] = _segi[:, si / 2, sj / 2] Image.fromarray(segi.transpose(1, 2, 0).astype(np.uint8)).save( os.path.join(middle_dir, str(ii) + '.seg.png')) segi = seg_output[ii] _segi = np.array([segi, segi, segi]) * 255 segi = np.zeros([3, _segi.shape[1] * 2, _segi.shape[2] * 2]) for si in range(segi.shape[1]): for sj in range(segi.shape[2]): segi[:, si, sj] = _segi[:, si / 2, sj / 2] Image.fromarray(segi.transpose(1, 2, 0).astype(np.uint8)).save( os.path.join(middle_dir, str(ii) + '.seg.out.png')) f1score = np.mean(f1score_list) print 'f1score', f1score print 'recall', np.mean(recall_list) print 'precision', np.mean(precision_list) if 'train' in phase: loss_mean = np.array(loss_list).mean(0) print 'loss: {:3.4f} pos loss: {:3.4f} neg loss: {:3.4f}'.format(loss_mean[0], loss_mean[1], loss_mean[2]) # 保存模型 if ('eval' in phase or 'pretrain' in phase) and best_f1score < 2: if args.small: save_dir = os.path.join(args.save_dir, 'models-small') else: save_dir = os.path.join(args.save_dir, 'models') if not os.path.exists(save_dir): os.mkdir(save_dir) if epoch % 3 == 0: save_model(save_dir, phase, str(epoch), epoch, f1score, model) if f1score > best_f1score: save_model(save_dir, phase, 'best_f1score', epoch, f1score, model) if args.model == 'resnet': tpnr = loss[11] + loss[12] # 这里用 best_f1score 也当tpnr好了,懒得改 if tpnr > best_f1score: best_f1score = tpnr save_model(save_dir, phase, 'best_tpnr', epoch, f1score, model) print 'best tpnr', best_f1score else: best_f1score = max(best_f1score, f1score) if best_f1score < 1: print '\n\t{:s}\tbest f1score {:3.4f}\n'.format(phase, best_f1score) return best_f1score def main(): word_index_dict = json.load(open(args.word_index_json, 'r')) num_classes = len(word_index_dict) image_label_dict = json.load(open(args.image_label_json, 'r')) cudnn.benchmark = True if args.model == 'densenet': # 两千多种字符,multi-label分类 model = DenseNet121(num_classes).cuda() elif args.model == 'resnet': # resnet主要用于文字区域的segmentation以及object detection操作 model = resnet.ResNet(num_classes=num_classes, args=args).cuda() else: return ##优化器 optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) # model = torch.nn.DataParallel(model).cuda() loss = Loss().cuda() if args.resume: state_dict = torch.load(args.resume) model.load_state_dict(state_dict['state_dict']) best_f1score = state_dict['f1score'] start_epoch = state_dict['epoch'] + 1 else: best_f1score = 0 if args.model == 'resnet': start_epoch = 100 else: start_epoch = 1 print 'best_f1score', best_f1score # 划分数据集 # test_filelist1 = sorted(glob(os.path.join('/media/scut214/file/HLG/Chinese_Recognition', 'crop_test1', '*'))) #data set for single word # test_filelist1 = sorted(glob(os.path.join('/media/scut214/file/HLG/Chinese_Recognition', 'crop_test1', '*'))) test_filelist1 = sorted(glob(os.path.join('../../../intermediate_file', 'images_to_recognition', '*'))) # test_filelist=os.listdir('/Net/competition/OCR/ocr_densenet/data/dataset/crop') # trainval_filelist = sorted(glob(os.path.join(args.data_dir, 'train', '*'))) # 两种输入size训练 # 修改为自身数据集的尺寸。 # train_filelist1: 长宽比小于8:1的图片,经过padding后变成 64*512 的输入 # train_filelist2: 长宽比大于8:1的图片,经过padding,crop后变成 64*1024的输入 train_filelist1, train_filelist2 = [], [] print len(test_filelist1) # 黑名单,这些图片的label是有问题的 black_list = set(json.load(open(args.black_json, 'r'))['black_list']) image_hw_ratio_dict = json.load(open(args.image_hw_ratio_json, 'r')) test_filelist=[] for f in test_filelist1: image = f.split('/')[-1] if image in black_list: continue # r = image_hw_ratio_dict[image] # if r == 1 or r==2: test_filelist.append(f) # print r # else: # train_filelist2.append(f) train_val_filelist = train_filelist1 + train_filelist2 val_filelist = train_filelist1[-2048:] train_filelist1 = train_filelist1[:-2048] train_filelist2 = train_filelist2 image_size = [1300, 64] print len(test_filelist) if args.phase in ['test', 'val', 'train_val']: # 测试输出文字检测结果 test_dataset = dataloader.DataSet( test_filelist, image_label_dict, num_classes, # transform=train_transform, args=args, image_size=image_size, phase='test') test_loader = DataLoader( dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True) # train_filelist = train_filelist1[-2048:] # train_dataset = dataloader.DataSet( # train_filelist, # image_label_dict, # num_classes, # image_size=image_size, # args=args, # phase='test') # train_loader = DataLoader( # dataset=train_dataset, # batch_size=1, # shuffle=False, # num_workers=8, # pin_memory=True) # val_dataset = dataloader.DataSet( # val_filelist, # image_label_dict, # num_classes, # image_size=image_size, # args=args, # phase='test') # val_loader = DataLoader( # dataset=val_dataset, # batch_size=1, # shuffle=False, # num_workers=8, # pin_memory=True) # train_val_dataset = dataloader.DataSet( # train_val_filelist, # image_label_dict, # num_classes, # image_size=image_size, # args=args, # phase='test') # train_val_loader = DataLoader( # dataset=train_val_dataset, # batch_size=1, # shuffle=False, # num_workers=8, # pin_memory=True) if args.phase == 'test': # test(start_epoch - 1, model, val_loader, 'val') test(model, test_loader, 'test') # test(start_epoch - 1, model, train_val_loader, 'train_val') # elif args.phase == 'val': # test(start_epoch - 1, model, train_loader, 'train')#valid set # test(start_epoch - 1, model, val_loader, 'val')#test set # elif args.phase == 'train_val': # test(start_epoch - 1, model, train_val_loader, 'train_val') return elif args.phase == 'train': train_dataset1 = dataloader.DataSet( train_filelist1, image_label_dict, num_classes, image_size=image_size, args=args, phase='train') train_loader1 = DataLoader( dataset=train_dataset1, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=True) # train_dataset2 = dataloader.DataSet( # train_filelist2, # image_label_dict, # num_classes, # image_size=(1024,64), # args=args, # phase='train') # train_loader2 = DataLoader( # dataset=train_dataset2, # batch_size=args.batch_size / 2, # shuffle=True, # num_workers=8, # pin_memory=True) val_dataset = dataloader.DataSet( val_filelist, image_label_dict, num_classes, image_size=image_size, args=args, phase='val') val_loader = DataLoader( dataset=val_dataset, batch_size=min(8, args.batch_size), shuffle=False, num_workers=8, pin_memory=True) best_f1score = 0 # eval_mode = 'pretrain-2' eval_mode = 'eval' for a in range(start_epoch, args.epochs): args.epoch = epoch if eval_mode == 'eval': if best_f1score > 0.8: args.lr = 0.0001 if best_f1score > 0.7: args.hard_mining = 1 for param_group in optimizer.param_groups: param_group['lr'] = args.lr train_eval(epoch, model, train_loader1, loss, optimizer, 2., 'train-1') best_f1score = train_eval(epoch, model, val_loader, loss, optimizer, best_f1score, 'eval-{:d}-{:d}'.format(args.batch_size, args.hard_mining)) print 'best_f1score:',best_f1score continue if __name__ == '__main__': print 'eval-{:d}-{:d}'.format(args.batch_size, args.hard_mining) main()