from __future__ import print_function import argparse import os import random import torch import torch.nn as nn import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.optim as optim import torch.utils.data from torch.autograd import Variable import torch.nn.functional as F import time from dataloader import nomalLoader as lsn from dataloader import trainLoaderN as DA from submodels import * ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) parser = argparse.ArgumentParser(description='deepCpmpletion') parser.add_argument('--model', default='normal', help='select model') parser.add_argument('--datatype', default='png', help='datapath') parser.add_argument('--datapath', default='', help='datapath') parser.add_argument('--epochs', type=int, default=40, help='number of epochs to train') parser.add_argument('--loadmodel', default='', help='load model') parser.add_argument('--savemodel', default='my', help='save model') parser.add_argument('--no-cuda', action='store_true', default=False, help='enables CUDA training') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') args = parser.parse_args() datapath = args.datapath args.cuda = not args.no_cuda and torch.cuda.is_available() torch.manual_seed(args.seed) if args.cuda: torch.cuda.manual_seed(args.seed) all_left_img = [] all_normal = [] all_gts = [] if args.model == 'normal': all_left_img, all_normal, all_gts = lsn.dataloader(datapath) print(len(all_left_img)) TrainImgLoader = torch.utils.data.DataLoader( DA.myImageFloder(all_left_img,all_normal,all_gts ,True, args.model), batch_size = 12, shuffle=True, num_workers=8, drop_last=True) model = s2dN() if args.cuda: model = nn.DataParallel(model) model.cuda() para_optim = [] if args.loadmodel is not None: state_dict = torch.load(args.loadmodel)["state_dict"] model.load_state_dict(state_dict) print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999)) def nomal_loss(pred, targetN,mask1): valid_mask = (mask1 > 0.0).detach() pred_n = pred.permute(0,2,3,1) pred_n = pred_n[valid_mask] target_n = targetN[valid_mask] pred_n = pred_n.contiguous().view(-1,3) pred_n = F.normalize(pred_n) target_n = target_n.contiguous().view(-1, 3) loss_function = nn.CosineEmbeddingLoss() loss = loss_function(pred_n, target_n, Variable(torch.Tensor(pred_n.size(0)).cuda().fill_(1.0))) return loss def train(inputl,sparse,mask,mask1,gt1): model.train() inputl = Variable(torch.FloatTensor(inputl)) gt1 = Variable(torch.FloatTensor(gt1)) sparse = Variable(torch.FloatTensor(sparse)) mask = Variable(torch.FloatTensor(mask)) mask1 = Variable(torch.FloatTensor(mask1)) if args.cuda: inputl,gt1 = inputl.cuda(),gt1.cuda() sparse=sparse.cuda() mask1 = mask1.cuda() mask = mask.cuda() optimizer.zero_grad() pred = model(inputl,sparse,mask) loss = nomal_loss(pred, gt1,mask1) loss.backward() optimizer.step() return loss.data[0] def adjust_learning_rate(optimizer, epoch): if epoch <= 10: learning_rate = 0.001 if epoch>10 and epoch <= 20: learning_rate = 0.0005 if epoch>20 and epoch <=30: learning_rate = 0.00025 if epoch>30 and epoch <= 40: learning_rate = 0.000125 print(learning_rate) for param_group in optimizer.param_groups: param_group['lr'] = learning_rate def main(): start_full_time = time.time() for epoch in range(1, args.epochs+1): total_train_loss = 0 adjust_learning_rate(optimizer,epoch) ## training ## for batch_idx, (imgL_crop,sparse_n,mask,mask1,data_in1) in enumerate(TrainImgLoader): start_time = time.time() loss= train(imgL_crop,sparse_n,mask,mask1,data_in1) print('%s Iter %d / %d training loss = %.4f, time = %.2f' % (args.model, batch_idx, epoch, loss, time.time() - start_time)) total_train_loss += loss print('epoch %d total training loss = %.10f' %(epoch, total_train_loss/len(TrainImgLoader))) #SAVE if epoch % 1 == 0: savefilename = args.savemodel+'.tar' torch.save({ 'epoch': epoch, 'state_dict': model.state_dict(), 'train_loss': total_train_loss/len(TrainImgLoader), }, savefilename) print('full finetune time = %.2f HR' %((time.time() - start_full_time)/3600)) if __name__ == '__main__': main()