from __future__ import print_function
from __future__ import division

import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
# from warpctc_pytorch import CTCLoss
from torch.nn import CTCLoss

import utils
import dataset_new
import models.crnn as crnn
import params

# Character set recognised by the model: Devanagari letters, matras, digits,
# ZWNJ/ZWJ and parentheses.
alpha = '.:ँंःअआइईउऊऋएऐऑओऔकखगघचछजझञटठडढणतथदधनपफबभमयरलळवशषसहािीुूृॅेैॉोौ्ॐड़ढ़०१२३४५६७८९\u200c\u200d()'

parser = argparse.ArgumentParser()
parser.add_argument('--trainroot', required=True, help='path to train dataset')
parser.add_argument('--valroot', required=True, help='path to val dataset')
args = parser.parse_args()

if not os.path.exists(params.expr_dir):
    os.makedirs(params.expr_dir)

# Seed all RNGs for reproducibility.
random.seed(params.manualSeed)
np.random.seed(params.manualSeed)
torch.manual_seed(params.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not params.cuda:
    print("WARNING: You have a CUDA device, so you should probably set cuda in params.py to True")

# -------------------------------------------------------------------------------------------------
# train and validation data
train_dataset = dataset_new.RecogDataset(data_file=args.trainroot)
assert train_dataset
if not params.random_sample:
    sampler = dataset_new.randomSequentialSampler(train_dataset, params.batchSize)
else:
    sampler = None
# DataLoader forbids combining shuffle=True with an explicit sampler,
# so shuffle only when no sampler is given.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=params.batchSize,
    shuffle=(sampler is None), sampler=sampler,
    num_workers=int(params.workers))
test_dataset = dataset_new.RecogDataset(data_file=args.valroot)

# -------------------------------------------------------------------------------------------------
# net init
# custom weights initialization called on crnn
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

nclass = len(alpha) + 1  # +1 for the CTC blank symbol
print(nclass)
crnn = crnn.CRNN(params.imgH, params.nc, nclass, params.nh)
crnn.apply(weights_init)
if params.pretrained != '':
    print('loading pretrained model from %s' % params.pretrained)
    if params.multi_gpu:
        crnn = torch.nn.DataParallel(crnn)
    crnn.load_state_dict(torch.load(params.pretrained))
print(crnn)

# -------------------------------------------------------------------------------------------------
converter = utils.strLabelConverter(alpha)
criterion = CTCLoss()

# Placeholder buffers; utils.loadData resizes them to match each batch.
image = torch.FloatTensor(params.batchSize, 3, params.imgH, params.imgH)
text = torch.IntTensor(params.batchSize * 5)
length = torch.IntTensor(params.batchSize)

if params.cuda and torch.cuda.is_available():
    crnn.cuda()
    # Avoid double-wrapping when a pretrained model was already wrapped above.
    if params.multi_gpu and not isinstance(crnn, torch.nn.DataParallel):
        crnn = torch.nn.DataParallel(crnn, device_ids=range(params.ngpu))
    image = image.cuda()
    criterion = criterion.cuda()

# loss averager
loss_avg = utils.averager()

# setup optimizer
if params.adam:
    optimizer = optim.Adam(crnn.parameters(), lr=params.lr, betas=(params.beta1, 0.999))
elif params.adadelta:
    optimizer = optim.Adadelta(crnn.parameters())
else:
    optimizer = optim.RMSprop(crnn.parameters(), lr=params.lr)


def val(net, dataset, criterion, max_iter=100):
    for p in net.parameters():
        p.requires_grad = False
    net.eval()
    data_loader = torch.utils.data.DataLoader(
        dataset, shuffle=True, batch_size=params.batchSize,
        num_workers=int(params.workers))
    val_iter = iter(data_loader)

    n_correct = 0
    n_total = 0
    loss_avg = utils.averager()

    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        cpu_images, cpu_texts = next(val_iter)
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)
        utils.loadData(text, t)
        utils.loadData(length, l)

        preds = net(image)
        preds_size = torch.IntTensor([preds.size(0)] * batch_size)
        # torch.nn.CTCLoss expects log-probabilities, as in trainBatch below.
        cost = criterion(preds.log_softmax(2), text, preds_size, length)
        loss_avg.add(cost)

        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        for pred, target in zip(sim_preds, cpu_texts):
            if pred == target:
                n_correct += 1
        n_total += batch_size

    # Show a few raw/decoded predictions from the last batch.
    raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:params.n_test_disp]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
        print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    # Normalise by the number of samples actually seen (the last batch may be short).
    accuracy = n_correct / float(n_total)
    print('loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
    return loss_avg.val(), accuracy


def trainBatch(net, criterion, optimizer):
    cpu_images, cpu_texts = next(train_iter)
    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    optimizer.zero_grad()
    preds = net(image)
    preds = preds.log_softmax(2)
    preds_size = torch.IntTensor([preds.size(0)] * batch_size)
    cost = criterion(preds, text, preds_size, length)
    cost.backward()
    optimizer.step()
    return cost


if __name__ == "__main__":
    with open(params.expr_dir + "/log.txt", "w") as f:
        prevvalloss = -1
        for epoch in range(params.nepoch):
            train_iter = iter(train_loader)
            i = 0
            while i < len(train_loader):
                for p in crnn.parameters():
                    p.requires_grad = True
                crnn.train()

                cost = trainBatch(crnn, criterion, optimizer)
                loss_avg.add(cost)
                i += 1

                if i % params.displayInterval == 0:
                    print('[%d/%d][%d/%d] Loss: %f' %
                          (epoch + 1, params.nepoch, i, len(train_loader), loss_avg.val()))
                    loss_avg.reset()

            print("end of epoch [%d/%d]" % (epoch + 1, params.nepoch))
            print("start testing on val set")
            valloss, valaccuracy = val(crnn, test_dataset, criterion)
            # Optionally evaluate on the training set to check for overfitting:
            # trainloss, trainaccuracy = val(crnn, train_dataset, criterion)
            line = str(valloss) + "\t" + str(valaccuracy) + "\n"
            print(line)
            f.write(line)
            f.flush()
            if valloss < prevvalloss or prevvalloss == -1:
                # Checkpoint only when validation loss improves.
                torch.save(crnn.state_dict(),
                           '{0}/netCRNN_{1}_{2}.pth'.format(params.expr_dir, epoch, valloss))
                prevvalloss = valloss
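
# -------------------------------------------------------------------------------------------------
# Usage sketch. The two ground-truth paths below come from stray string literals in the
# original script; they are examples only and will not exist on other machines — point
# --trainroot and --valroot at your own annotation files:
#
#   python train.py --trainroot /home/azhar/crnn-pytorch/data/train_gt.txt \
#                   --valroot /home/azhar/crnn-pytorch/data/validate_gt.txt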