# -*- coding: utf-8 -*- from __future__ import print_function import sys reload(sys) sys.setdefaultencoding('utf-8') import os os.environ["CUDA_VISIBLE_DEVICES"] = "0" import time import argparse import random import torch import torch.backends.cudnn as cudnn import torch.optim as optim import torch.utils.data from torch.autograd import Variable import numpy as np from warpctc_pytorch import CTCLoss import utils import dataset from keys import alphabet #Alphabet = [e.encode('utf-8') for e in alphabet] import models.efficient_densecrnn as densecrnn import distance from tensorboard_logger import configure, log_value configure("./log/densecrnn", flush_secs=5) parser = argparse.ArgumentParser() parser.add_argument('--trainroot', help='path to dataset',default='../data/newdata/train') parser.add_argument('--valroot', help='path to dataset',default='../data/newdata/val') parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) parser.add_argument('--batchSize', type=int, default=16, help='input batch size') parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network') parser.add_argument('--imgW', type=int, default=256, help='the width of the input image to network') parser.add_argument('--nh', type=int, default=256, help='size of the lstm hidden state') parser.add_argument('--niter', type=int, default=100, help='number of epochs to train for') parser.add_argument('--lr', type=float, default=0.001, help='learning rate for Critic, default=0.00005') parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.5') parser.add_argument('--cuda', action='store_true', help='enables cuda') parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use') parser.add_argument('--crnn', help="path to crnn (to continue training)",default='')#default='../pretrain-models/netCRNN.pth') #parser.add_argument('--crnn', help="path to crnn (to continue training)",default='') parser.add_argument('--alphabet', default=alphabet) parser.add_argument('--experiment', help='Where to store samples and models',default='./save_model_four') parser.add_argument('--displayInterval', type=int, default=50, help='Interval to be displayed') parser.add_argument('--n_test_disp', type=int, default=1000, help='Number of samples to display when test') parser.add_argument('--valInterval', type=int, default=500, help='Interval to be displayed') parser.add_argument('--saveInterval', type=int, default=1000, help='Interval to be displayed') parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)') parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)') parser.add_argument('--keep_ratio', action='store_true', help='whether to keep ratio for image resize') parser.add_argument('--random_sample', action='store_true', help='whether to sample the dataset with random sampler') opt = parser.parse_args() print(opt) ifUnicode=True if opt.experiment is None: opt.experiment = 'expr' os.system('mkdir {0}'.format(opt.experiment)) opt.manualSeed = random.randint(1, 10000) # fix seed print("Random Seed: ", opt.manualSeed) random.seed(opt.manualSeed) np.random.seed(opt.manualSeed) torch.manual_seed(opt.manualSeed) cudnn.benchmark = True if torch.cuda.is_available() and not opt.cuda: print("WARNING: You have a CUDA device, so you should probably run with --cuda") train_dataset = dataset.lmdbDataset(root=opt.trainroot) assert train_dataset if not opt.random_sample: sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize) else: sampler = None train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=opt.batchSize, shuffle=True, sampler=sampler, num_workers=int(opt.workers), collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio)) test_dataset = dataset.lmdbDataset( root=opt.valroot, transform=dataset.resizeNormalize((256, 32))) test_dataset = dataset.lmdbDataset( root=opt.valroot) ngpu = int(opt.ngpu) nh = int(opt.nh) alphabet = opt.alphabet nclass = len(alphabet) + 1 nc = 1 converter = utils.strLabelConverter(alphabet) criterion = CTCLoss() # custom weights initialization called on crnn def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) elif classname.find('BatchNorm') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) #crnn = crnn.CRNN(opt.imgH, nc, nclass, nh, ngpu) crnn=densecrnn.DenseCrnnEfficient(nclass=nclass,nh=nh,growth_rate=12,block_config=(3,6,12,16), compression=0.5, num_init_features=24,bn_size=4,drop_rate=0,small=True) crnn.apply(weights_init) if opt.crnn != '': print('loading pretrained model from %s' % opt.crnn) crnn.load_state_dict(torch.load(opt.crnn)) print(crnn) image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH) text = torch.IntTensor(opt.batchSize * 5) length = torch.IntTensor(opt.batchSize) if opt.cuda: crnn.cuda() image = image.cuda() criterion = criterion.cuda() image = Variable(image) text = Variable(text) length = Variable(length) # loss averager loss_avg = utils.averager() # setup optimizer if opt.adam: optimizer = optim.Adam(crnn.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) elif opt.adadelta: optimizer = optim.Adadelta(crnn.parameters(), lr=opt.lr) else: optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr) def val(net, test_dataset, criterion, max_iter=2): print('Start val') for p in crnn.parameters(): p.requires_grad = False net.eval() data_loader = torch.utils.data.DataLoader( test_dataset, batch_size=opt.batchSize, num_workers=int(opt.workers), sampler=dataset.randomSequentialSampler(test_dataset, opt.batchSize), collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio)) val_iter = iter(data_loader) i = 0 n_correct = 0 loss_avg = utils.averager() test_distance=0 max_iter = min(max_iter, len(data_loader)) for i in range(max_iter): data = val_iter.next() i += 1 cpu_images, cpu_texts = data batch_size = cpu_images.size(0) utils.loadData(image, cpu_images) if ifUnicode: cpu_texts = [ clean_txt(tx.decode('utf-8')) for tx in cpu_texts] t, l = converter.encode(cpu_texts) utils.loadData(text, t) utils.loadData(length, l) preds = crnn(image) preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size)) cost = criterion(preds, text, preds_size, length) / batch_size loss_avg.add(cost) _, preds = preds.max(2) # preds = preds.squeeze(2) preds = preds.transpose(1, 0).contiguous().view(-1) sim_preds = converter.decode(preds.data, preds_size.data, raw=False) for pred, target in zip(sim_preds, cpu_texts): if pred.strip() == target.strip(): n_correct += 1 # print(distance.levenshtein(pred.strip(), target.strip())) test_distance +=distance.nlevenshtein(pred.strip(), target.strip(),method=2) raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:opt.n_test_disp] for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts): print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt)) accuracy = n_correct / float(max_iter * opt.batchSize) test_distance=test_distance/float(max_iter * opt.batchSize) testLoss = loss_avg.val() #print('Test loss: %f, accuray: %f' % (testLoss, accuracy)) return testLoss,accuracy,test_distance def clean_txt(txt): """ filter char where not in alphabet with ' ' """ newTxt = u'' for t in txt: if t in alphabet: newTxt+=t else: newTxt+=u' ' return newTxt def trainBatch(net, criterion, optimizer,flage=False): n_correct = 0 train_distance=0 data = train_iter.next() cpu_images, cpu_texts = data##decode utf-8 to unicode if ifUnicode: cpu_texts = [ clean_txt(tx.decode('utf-8')) for tx in cpu_texts] batch_size = cpu_images.size(0) utils.loadData(image, cpu_images) t, l = converter.encode(cpu_texts) utils.loadData(text, t) utils.loadData(length, l) preds = crnn(image) preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size)) cost = criterion(preds, text, preds_size, length) / batch_size crnn.zero_grad() cost.backward() _, preds = preds.max(2) preds = preds.transpose(1, 0).contiguous().view(-1) sim_preds = converter.decode(preds.data, preds_size.data, raw=False) for pred, target in zip(sim_preds, cpu_texts): if pred.strip() == target.strip(): n_correct += 1 train_distance +=distance.nlevenshtein(pred.strip(),target.strip(),method=2) train_accuracy = n_correct / float(batch_size) train_distance=train_distance/float(batch_size) if flage: lr = 0.0001 optimizer = optim.Adadelta(crnn.parameters(), lr=lr) optimizer.step() return cost,train_accuracy,train_distance num =0 lasttestLoss = 10000 testLoss = 10000 import os def delete(path): """ 删除文件 """ import os import glob paths = glob.glob(path+'/*.pth') for p in paths: os.remove(p) numLoss = 0##判断训练参数是否下降 for epoch in range(opt.niter): train_iter = iter(train_loader) i = 0 while i < len(train_loader): #print('The step{} ........\n'.format(i)) for p in crnn.parameters(): p.requires_grad = True crnn.train() #if numLoss>50: # cost = trainBatch(crnn, criterion, optimizer,True) # numLoss = 0 #else: cost, train_accuracy, train_distance = trainBatch(crnn, criterion, optimizer) loss_avg.add(cost) i += 1 #if i % opt.displayInterval == 0: # print('[%d/%d][%d/%d] Loss: %f' % # (epoch, opt.niter, i, len(train_loader), loss_avg.val())) # loss_avg.reset() if i % opt.valInterval == 0: testLoss,accuracy,test_distance= val(crnn, test_dataset, criterion) localtime = time.asctime(time.localtime(time.time())) #print('Test loss: %f, accuray: %f' % (testLoss, accuracy)) print("time:{},epoch:{},step:{},test loss:{},test acc:{},train loss:{},train acc:{},test dis:{},train dis:{}".format(localtime,epoch,num,testLoss,accuracy,loss_avg.val(),train_accuracy,test_distance,train_distance)) log_value("test loss",float(testLoss),num) log_value("test accuracy",float(accuracy),num) log_value("train accuracy", float(train_accuracy), num) log_value("train loss",float(loss_avg.val()),num) log_value("test distanceloss", float(test_distance), num) log_value("train distanceloss",float(train_distance), num) loss_avg.reset() # do checkpointing num +=1 #lasttestLoss = min(lasttestLoss,testLoss) if lasttestLoss >testLoss: print("The step {},last lost:{}, current: {},save model!".format(num,lasttestLoss,testLoss)) lasttestLoss = testLoss #delete(opt.experiment)##删除历史模型 torch.save(crnn.state_dict(), '{}/netCRNN{}.pth'.format(opt.experiment,str(accuracy))) numLoss = 0 else: numLoss+=1