import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from utils.plot import plotHistogram, plotCummulative, plotSeries
from utils import train_op, torch_op
from utils.torch_op import v, npy
from utils.log import AverageMeter
from utils import log
import config
from tensorboardX import SummaryWriter
import cv2
import util
import time
import re
import glob
from opts import opts
from utils.dotdict import *
from quaternion import *
from utils.factory import trainer
from model.mymodel import Resnet18_8s, SCNet
from utils.callbacks import PeriodicCallback, OnceCallback, ScheduledCallback, CallbackLoc
import copy
from sklearn.decomposition import PCA

#**--*--**--*--**--*--**--*--**--*--**--*--**--*--**--*--**--*--**--*--
#**--*--**--*--**--*--**--*--**--*--**--*--**--*--**--*--**--*--**--*--
# here is the place for customized functions


def visNorm(vis):
    """Min-max normalize each image tensor in `vis` to [0, 1], in place.

    Constant images (max == min) are left untouched to avoid division by zero.
    Returns the (mutated) input for chaining.
    """
    # NOTE: loop variable renamed from `v` — the original shadowed the
    # imported `torch_op.v` helper.
    for i in range(len(vis)):
        if (vis[i].max().item() - vis[i].min().item()) != 0:
            vis[i] = (vis[i] - vis[i].min()) / (vis[i].max() - vis[i].min())
    return vis


def visNormV1(vis, min_, max_):
    """Normalize each tensor in `vis` with an externally supplied range.

    Values below 0 after normalization are clamped; no upper clamp is applied.
    Skips all work when the range is degenerate (max_ == min_).
    """
    for i in range(len(vis)):
        if (max_ - min_) != 0:
            vis[i] = ((vis[i] - min_) / (max_ - min_)).clamp(0, None)
    return vis


def class_to_color(classIdx, dataList):
    """Map semantic class indices to the RGB palette of the given dataset."""
    if 'scannet' in dataList:
        colors = config.scannet_color_palette[classIdx, :]
    elif 'matterport' in dataList:
        colors = config.matterport_color_palette[classIdx, :]
    elif 'suncg' in dataList:
        colors = config.suncg_color_palette[classIdx, :]
    else:
        # fix: original fell through and raised NameError on `colors`
        raise Exception(f"unknown dataset for color palette: {dataList}")
    return colors


def apply_mask(x, maskMethod, *arg):
    """Mask a panorama batch so only the simulated observed region is kept.

    Args:
        x: input batch, [n, c, h, w].
        maskMethod: 'second' keeps the second skybox face (a h-wide column);
            'kinect' keeps a central crop matching a Kinect field of view.
        *arg: extra parameters (unused here, accepted for interface parity).

    Returns:
        (masked x, binary observation mask, geometric loss weight map),
        all as torch tensors on the default device via torch_op.v.
    """
    h = x.shape[2]
    w = x.shape[3]
    tp = np.zeros([x.shape[0], 1, x.shape[2], x.shape[3]])
    geow = np.zeros([x.shape[0], 1, x.shape[2], x.shape[3]])
    if maskMethod == 'second':
        # observe the second view of the skybox: columns [h, 2h)
        tp[:, :, :h, h:2 * h] = 1
        ys, xs = np.meshgrid(range(h), range(w), indexing='ij')
        # distance (in columns) to the observed region's borders, normalized by h
        dist = np.stack((np.abs(xs - h), np.abs(xs - (2 * h)), np.abs(xs - w - h), np.abs(xs - w - (2 * h))), 0)
        dist = dist.min(0) / h
        sigmaGeom = 0.7
        # weight decays away from the observed region; zero inside it
        dist = np.exp(-dist / (2 * sigmaGeom ** 2))
        dist[:, h:2 * h] = 0
        geow = torch_op.v(np.tile(np.reshape(dist, [1, 1, dist.shape[0], dist.shape[1]]), [geow.shape[0], 1, 1, 1]))
    elif maskMethod == 'kinect':
        assert (w == 640 and h == 160)
        # half-extent of the Kinect FOV crop, in pixels
        dw = int(89.67 // 2)
        dh = int(67.25 // 2)
        tp[:, :, 80 - dh:80 + dh, 160 + 80 - dw:160 + 80 + dw] = 1
        # observed region weighted 20x, unobserved 1x
        geow = tp.copy() * 20
        geow[tp == 0] = 1
        geow = torch_op.v(geow)
    tp = torch_op.v(tp)
    x = x * tp
    return x, tp, geow

#**--*--**--*--**--*--**--*--**--*--**--*--**--*--**--*--**--*--**--*--
#**--*--**--*--**--*--**--*--**--*--**--*--**--*--**--*--**--*--**--*--


def buildDataset(args):
    """Construct train/val DataLoaders for the dataset named in args.dataList.

    Raises:
        Exception: if args.dataList matches none of suncg/scannet/matterport.
    """
    def worker_init_fn(worker_id):
        # distinct, reproducible seed per dataloader worker
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    if 'suncg' in args.dataList:
        from datasets.SUNCG import SUNCG as Dataset
    elif 'scannet' in args.dataList:
        from datasets.ScanNet import ScanNet as Dataset
    elif 'matterport' in args.dataList:
        from datasets.Matterport3D import Matterport3D as Dataset
    else:
        raise Exception("unknown dataset!")

    train_dataset = Dataset('train', config.nViews, AuthenticdepthMap=False, meta=False, rotate=False,
                            rgbd=True, hmap=False, segm=True, normal=True,
                            list_=f"./data/dataList/{args.dataList}.npy", singleView=args.single_view,
                            denseCorres=args.featurelearning, reproj=True,
                            representation=args.representation, dynamicWeighting=args.dynamicWeighting,
                            snumclass=args.snumclass)
    val_dataset = Dataset('test', nViews=config.nViews, meta=False, rotate=False,
                          rgbd=True, hmap=False, segm=True, normal=True,
                          list_=f"./data/dataList/{args.dataList}.npy", singleView=args.single_view,
                          denseCorres=args.featurelearning, reproj=True,
                          representation=args.representation, dynamicWeighting=args.dynamicWeighting,
                          snumclass=args.snumclass)
    if args.num_workers == 1:
        # single-worker mode: let DataLoader run in the main process
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True,
                                  collate_fn=util.collate_fn_cat, worker_init_fn=worker_init_fn)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True,
                                collate_fn=util.collate_fn_cat, worker_init_fn=worker_init_fn)
    else:
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.num_workers, drop_last=True,
                                  collate_fn=util.collate_fn_cat, worker_init_fn=worker_init_fn)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                num_workers=args.num_workers, drop_last=True,
                                collate_fn=util.collate_fn_cat, worker_init_fn=worker_init_fn)
    return train_loader, val_loader


class learnerParam(object):
    """Bundle of visualization/logging periods (in steps) for train and val."""

    def __init__(self, train_step_vis=600, val_step_vis=50,
                 train_step_log=100, val_step_log=10):
        self.train_step_vis = train_step_vis
        self.val_step_vis = val_step_vis
        self.train_step_log = train_step_log
        self.val_step_log = val_step_log


class learner(object):
    """Owns the completion network (netG), frozen feature network (netF),
    optimizer, loggers and the per-step train/eval logic."""

    def __init__(self, args, learnerParam):
        self.learnerParam = learnerParam
        self.args = args
        self.epochStart = 0
        self.userConfig()
        # build network
        if self.args.representation == 'skybox':
            self.netG = SCNet(args).cuda()
            # netF consumes the complete (unmasked) 7-channel pano: rgb+normal+depth
            Fargs = copy.copy(args)
            Fargs.num_input = 7
            self.netF = Resnet18_8s(Fargs).cuda()
            if 'suncg' in self.args.dataList:
                checkpoint = torch.load('./data/pretrained_model/suncg.feat.pth.tar')
            elif 'matterport' in self.args.dataList:
                checkpoint = torch.load('./data/pretrained_model/matterport.feat.pth.tar')
            elif 'scannet' in self.args.dataList:
                checkpoint = torch.load('./data/pretrained_model/scannet.feat.pth.tar')
            state_dict = checkpoint['state_dict']
            model_dict = self.netF.state_dict()
            # 1. filter out unnecessary keys
            state_dict = {k: val for k, val in state_dict.items() if k in model_dict}
            # 2. overwrite entries in the existing state dict
            model_dict.update(state_dict)
            # 3. load the new state dict
            self.netF.load_state_dict(model_dict)
            print('resume F network weights successfully')
        else:
            raise Exception("unknown representation")

        if self.args.parallel:
            if torch.cuda.device_count() > 1:
                self.netG = torch.nn.DataParallel(self.netG, device_ids=[0, 1]).cuda()
        train_op.parameters_count(self.netG, 'netG')

        # setup optimizer
        self.optimizerG = torch.optim.Adam(self.netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

        # resume if specified
        if self.args.resume:
            self.load_checkpoint()
        useScheduler = False
        if useScheduler:
            # fix: original referenced an undefined name `optimizer`
            self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizerG, [1000, 2000], 0.1
            )

    def userConfig(self):
        """
        include the task specific setup here

        Lays out the output-channel index ranges for each requested modality
        (rgb/n/d/k/s/f), fixes network hyper-parameters, and creates the
        tensorboard writer, text logger, loss meters and eval accumulators.
        """
        if self.args.featurelearning:
            assert ('f' in self.args.outputType)
        # assign a contiguous channel range per modality in the output tensor
        pointer = 0
        if 'rgb' in self.args.outputType:
            self.args.idx_rgb_start = pointer
            self.args.idx_rgb_end = pointer + 3
            pointer += 3
        if 'n' in self.args.outputType:
            self.args.idx_n_start = pointer
            self.args.idx_n_end = pointer + 3
            pointer += 3
        if 'd' in self.args.outputType:
            self.args.idx_d_start = pointer
            self.args.idx_d_end = pointer + 1
            pointer += 1
        if 'k' in self.args.outputType:
            self.args.idx_k_start = pointer
            self.args.idx_k_end = pointer + 1
            pointer += 1
        if 's' in self.args.outputType:
            self.args.idx_s_start = pointer
            self.args.idx_s_end = pointer + self.args.snumclass  # 21 class
            pointer += self.args.snumclass
        if 'f' in self.args.outputType:
            self.args.idx_f_start = pointer
            self.args.idx_f_end = pointer + self.args.featureDim
            pointer += self.args.featureDim
        self.args.num_output = pointer
        # two stacked 8-channel views (rgb+normal+depth+mask each)
        self.args.num_input = 8 * 2
        self.args.ngpu = int(1)
        self.args.nz = int(100)
        self.args.ngf = int(64)
        self.args.ndf = int(64)
        self.args.nef = int(64)
        self.args.nBottleneck = int(4000)
        self.args.wt_recon = float(0.998)
        self.args.wtlD = float(0.002)
        self.args.overlapL2Weight = 10

        # setup logger
        self.tensorboardX = SummaryWriter(log_dir=os.path.join(self.args.EXP_DIR, 'tensorboard'))
        self.logger = log.logging(self.args.EXP_DIR_LOG)
        self.logger_errG = AverageMeter()
        self.logger_errG_recon = AverageMeter()
        self.logger_errG_rgb = AverageMeter()
        self.logger_errG_d = AverageMeter()
        self.logger_errG_n = AverageMeter()
        self.logger_errG_s = AverageMeter()
        self.logger_errG_k = AverageMeter()
        self.logger_errD_fake = AverageMeter()
        self.logger_errD_real = AverageMeter()
        self.logger_errG_fl = AverageMeter()
        self.logger_errG_fl_pos = AverageMeter()
        self.logger_errG_fl_neg = AverageMeter()
        self.logger_errG_fl_f = AverageMeter()
        self.logger_errG_fc = AverageMeter()
        self.logger_errG_pn = AverageMeter()
        self.logger_errG_freq = AverageMeter()
        self.global_step = 0
        self.speed_benchmark = True
        if self.speed_benchmark:
            self.time_per_step = AverageMeter()
        # SIFT baseline used to benchmark the learned descriptor
        self.sift = cv2.xfeatures2d.SIFT_create()
        self.evalFeatRatioDL_obs, self.evalFeatRatioDL_unobs = [], []
        self.evalFeatRatioDLc_obs, self.evalFeatRatioDLc_unobs = [], []
        self.evalFeatRatioSift = []
        self.evalErrN = []
        self.evalErrD = []
        self.evalSemantic = []
        self.evalSemantic_gt = []
        self.sancheck = {}
        # semantic encoding
        if 'scannet' in self.args.dataList:
            self.colors = config.scannet_color_palette
        elif 'matterport' in self.args.dataList:
            self.colors = config.matterport_color_palette
        elif 'suncg' in self.args.dataList:
            self.colors = config.suncg_color_palette
        self.class_balance_weights = torch_op.v(np.ones([self.args.snumclass]))

    def set_mode(self, mode='train'):
        """Set netG's train/eval mode.

        NOTE(review): the original deliberately short-circuits before
        `netG.eval()` (it was marked `#!!!`), so the network stays in train
        mode (batch-norm batch statistics, dropout active) even during
        validation — confirm this is intended before "fixing" it.
        """
        if mode == 'train':
            self.netG.train()
        else:
            return  # !!! eval() intentionally skipped, see docstring
        return

    def update_lr(self):
        # NOTE(review): self.lr_scheduler only exists when useScheduler is
        # enabled in __init__ (currently False); calling this would raise.
        self.lr_scheduler.step()

    def save_checkpoint_helper(self, net, optimizer, filename, clean=True, epoch=None):
        """Save net/optimizer state, keeping only the NumRetain newest checkpoints."""
        state = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(state, filename)
        if clean:
            NumRetain = 3
            dirname = os.path.dirname(filename)
            checkpointName = filename.split('/')[-1]
            # turn the epoch number into a wildcard to glob sibling checkpoints
            num = re.findall(r'\d+', checkpointName)[0]
            checkpointName = checkpointName.replace(num, '*')
            checkpoints = glob.glob(f"{dirname}/{checkpointName}")
            checkpoints.sort()
            for i in range(len(checkpoints) - NumRetain):
                # fix: use os.remove instead of shelling out `rm` via os.system
                os.remove(checkpoints[i])

    def save_checkpoint(self, context):
        """Callback: persist netG + optimizer for the epoch in `context`."""
        epoch = context['epoch']
        self.logger('save model: {0}'.format(epoch))
        self.save_checkpoint_helper(
            self.netG, self.optimizerG,
            os.path.join(self.args.EXP_DIR_PARAMS, 'checkpoint_G_{0:04d}.pth.tar'.format(epoch)),
            clean=True, epoch=epoch)

    def load_checkpoint(self):
        """Resume netG and its optimizer from args.model or the latest checkpoint.

        Any failure is logged and training falls back to scratch (best-effort).
        """
        try:
            if self.args.model is not None:
                net_path = self.args.model
            else:
                net_path = train_op.get_latest_model(self.args.EXP_DIR_PARAMS, 'checkpoint')
            checkpoint = torch.load(net_path)
            state_dict = checkpoint['state_dict']
            self.epochStart = checkpoint['epoch'] + 1
            self.netG.load_state_dict(state_dict)
            print('resume network weights from {0} successfully'.format(net_path))
            self.optimizerG.load_state_dict(checkpoint['optimizer'])
            print('resume optimizer weights from {0} successfully'.format(net_path))
        except Exception as e:
            print(e)
            print("resume fail, start training from scratch!")

    def evalPlot(self, context):
        """Callback: render accumulated eval metrics to one image and reset them."""
        # normal angle error
        visEvalErrNc = plotCummulative(np.array(self.evalErrN), 'angular error', 'percentage', 'ours')
        visEvalErrNh = plotHistogram(np.array(self.evalErrN), 'angular error', 'probability', 'ours')
        # plane distance
        visEvalErrDc = plotCummulative(np.array(self.evalErrD), 'l1 error', 'percentage', 'ours')
        visEvalErrDh = plotHistogram(np.array(self.evalErrD), 'l1 error', 'probability', 'ours')
        # descriptive power of learned feature
        visEvalFeat = plotCummulative(
            [np.array(self.evalFeatRatioDLc_obs), np.array(self.evalFeatRatioDLc_unobs),
             np.array(self.evalFeatRatioDL_obs), np.array(self.evalFeatRatioDL_unobs),
             np.array(self.evalFeatRatioSift)],
            'ratio', 'percentage',
            ['dl_complete_obs', 'dl_complete_unobs', 'dl_partial_obs', 'dl_partial_unobs', 'sift'])
        panels = [visEvalErrNc, visEvalErrNh, visEvalErrDc, visEvalErrDh, visEvalFeat]
        # semantic class distribution
        if self.args.objectFreqLoss:
            self.evalSemantic = np.concatenate(self.evalSemantic, 0).mean(0)
            self.evalSemantic_gt = np.concatenate(self.evalSemantic_gt, 0).mean(0)
            visEvalSemantic = plotSeries(
                [range(len(self.colors)), range(len(self.colors))],
                [self.evalSemantic, self.evalSemantic_gt],
                'class', 'pixel percentage', ['ours', 'gt'])
            panels.append(visEvalSemantic)
        # fix: original concatenated visEvalSemantic unconditionally, which
        # raised NameError whenever objectFreqLoss was off
        visEval = np.concatenate(panels, 1)
        cv2.imwrite(os.path.join(self.args.EXP_DIR, f"evalMetric_epoch_{context['epoch']}.png"), visEval)
        # reset accumulators for the next eval period
        self.evalErrN = []
        self.evalErrD = []
        self.evalSemantic = []
        self.evalSemantic_gt = []

    def evalSiftDescriptor(self, rgb, denseCorres):
        """SIFT baseline: for sampled correspondences, fraction of target pixels
        whose SIFT descriptor is closer to the source than the true match is.

        Lower is better; one mean ratio appended per valid batch element.
        """
        ratios = []
        n = rgb.shape[0]
        Kn = denseCorres['idxSrc'].shape[1]
        for jj in range(n):
            if denseCorres['valid'][jj].item() == 0:
                continue
            idx = np.random.choice(range(Kn), 100)
            rs = (torch_op.npy(rgb[jj, 0, :, :, :]).transpose(1, 2, 0) * 255).astype('uint8')
            grays = cv2.cvtColor(rs, cv2.COLOR_BGR2GRAY)
            rt = (torch_op.npy(rgb[jj, 1, :, :, :]).transpose(1, 2, 0) * 255).astype('uint8')
            grayt = cv2.cvtColor(rt, cv2.COLOR_BGR2GRAY)
            step_size = 5
            # SIFT descriptors at the sampled source correspondence locations
            tp = torch_op.npy(denseCorres['idxSrc'][jj, idx, :])
            kp = [cv2.KeyPoint(coord[0], coord[1], step_size) for coord in tp]
            _, sifts = self.sift.compute(grays, kp)
            # descriptors at the matching target locations
            tp = torch_op.npy(denseCorres['idxTgt'][jj, idx, :])
            kp = [cv2.KeyPoint(coord[0], coord[1], step_size) for coord in tp]
            _, siftt = self.sift.compute(grayt, kp)
            dist = np.power(sifts - siftt, 2).sum(1)
            # dense grid of descriptors over the whole target image
            kp = [cv2.KeyPoint(x, y, step_size)
                  for y in range(0, rgb.shape[3], step_size)
                  for x in range(0, rgb.shape[4], step_size)]
            _, dense_feat = self.sift.compute(grayt, kp)
            distRest = np.power(np.expand_dims(sifts, 1) - np.expand_dims(dense_feat, 0), 2).sum(2)
            ratio = (distRest < dist[:, np.newaxis]).sum(1) / distRest.shape[1]
            ratios.append(ratio.mean())
        return ratios

    def evalDLDescriptor(self, featMaps, featMapt, denseCorres, rs, rt, mask):
        """Same ranking metric as evalSiftDescriptor but for the learned feature
        maps, split by whether the correspondence lies in the observed region
        (denseCorres['observe'] == 2) or not.

        Returns (ratiosObs, ratiosUnobs, plot_all); plot_all is always empty.
        """
        Kn = denseCorres['idxSrc'].shape[1]
        C = featMaps.shape[1]
        n = featMaps.shape[0]
        ratiosObs, ratiosUnobs = [], []
        rsnpy, rtnpy, masknpy = torch_op.npy(rs), torch_op.npy(rt), torch_op.npy(mask)
        # dim the image to illustrate mask area
        rsnpy = rsnpy * masknpy + 0.5 * rsnpy * (1 - masknpy)
        rtnpy = rtnpy * masknpy + 0.5 * rtnpy * (1 - masknpy)
        plot_all = []
        try:
            for jj in range(n):
                if denseCorres['valid'][jj].item() == 0:
                    continue
                idx = np.random.choice(range(Kn), 100)
                typeCP = torch_op.npy(denseCorres['observe'][jj, idx])
                featSrc = featMaps[jj, :, denseCorres['idxSrc'][jj, idx, 1].long(), denseCorres['idxSrc'][jj, idx, 0].long()]
                featTgt = featMapt[jj, :, denseCorres['idxTgt'][jj, idx, 1].long(), denseCorres['idxTgt'][jj, idx, 0].long()]
                dist = (featSrc - featTgt).pow(2).sum(0)
                # distance from each source feature to every target location
                distRest = (featSrc.unsqueeze(2) - featMapt[jj].view(C, 1, -1)).pow(2).sum(0)
                ratio = (distRest < dist.unsqueeze(1)).sum(1).float() / distRest.shape[1]
                ratio = torch_op.npy(ratio)
                if ((typeCP == 2).sum() > 0):
                    ratiosObs.append(ratio[typeCP == 2].mean())
                if ((typeCP < 2).sum() > 0):
                    ratiosUnobs.append(ratio[typeCP < 2].mean())
        except Exception:
            # fix: was a bare `except:` (also swallowed KeyboardInterrupt);
            # debugger trap retained for parity with the original behavior
            import ipdb; ipdb.set_trace()
        return ratiosObs, ratiosUnobs, plot_all

    def sancheck_total_traversed(self, imgsPath):
        """Sanity check: count how often each image path has been seen."""
        for view in (0, 1):
            for kk in range(len(imgsPath[view])):
                path = imgsPath[view][kk]
                if path not in self.sancheck:
                    self.sancheck[path] = 1
                else:
                    self.sancheck[path] += 1

    def contrast_loss(self, featMaps, featMapt, denseCorres):
        """Contrastive loss over dense correspondences.

        Positive pairs pull matched source/target features together; negatives
        (random target locations) are pushed beyond margin self.args.D.
        Returns (total, positive term, negative term); all zeros when the batch
        has no valid correspondences.
        """
        validCorres = torch.nonzero(denseCorres['valid'] == 1).view(-1).long()
        n = featMaps.shape[0]
        if not len(validCorres):
            # fix: original used v(np.array([0]))[0][0] for loss_fl, which
            # double-indexes a 0-dim tensor; a single index suffices
            loss_fl_pos = torch_op.v(np.array([0]))[0]
            loss_fl_neg = torch_op.v(np.array([0]))[0]
        else:
            # consistency of keypoint proposal across different view
            idxInst = torch.arange(n)[validCorres].view(-1, 1).repeat(1, denseCorres['idxSrc'].shape[1]).view(-1).long()
            featS = featMaps[idxInst, :, denseCorres['idxSrc'][validCorres, :, 1].view(-1).long(), denseCorres['idxSrc'][validCorres, :, 0].view(-1).long()]
            featT = featMapt[idxInst, :, denseCorres['idxTgt'][validCorres, :, 1].view(-1).long(), denseCorres['idxTgt'][validCorres, :, 0].view(-1).long()]
            # positive example,
            loss_fl_pos = (featS - featT).pow(2).sum(1).mean()
            # negative example, make sure does not contain positive
            Kn = denseCorres['idxSrc'].shape[1]
            C = featMaps.shape[1]
            # 100 random negative target locations per correspondence
            negIdy = torch.from_numpy(np.random.choice(range(featMaps.shape[2]), Kn * 100 * len(validCorres)))
            negIdx = torch.from_numpy(np.random.choice(range(featMaps.shape[3]), Kn * 100 * len(validCorres)))
            idx = torch.arange(n)[validCorres].view(-1, 1).repeat(1, Kn * 100).view(-1).long()
            loss_fl_neg = F.relu(
                self.args.D
                - (featS.unsqueeze(1).repeat(1, 100, 1).view(-1, C) - featMapt[idx, :, negIdy, negIdx]).pow(2).sum(1)
            ).mean()
        loss_fl = loss_fl_pos + loss_fl_neg
        return loss_fl, loss_fl_pos, loss_fl_neg

    def step(self, data, mode='train'):
        """Run one optimization (or eval) step on a batch.

        Builds masked two-view inputs (with blank / reprojected-counterpart
        variants), runs netG, accumulates reconstruction + feature losses,
        backprops in train mode, and periodically logs scalars/images.
        Returns a dotdict with a printable 'suffix' summary.
        """
        torch.cuda.empty_cache()
        if self.speed_benchmark:
            step_start = time.time()
        with torch.set_grad_enabled(mode == 'train'):
            np.random.seed()
            self.optimizerG.zero_grad()
            # per-pixel CE so it can be re-weighted spatially below
            # (note: `reduce=False` is the deprecated spelling of reduction='none')
            CEcriterion = nn.CrossEntropyLoss(weight=self.class_balance_weights, reduce=False)

            rgb, norm, depth, dataMask, Q = v(data['rgb']), v(data['norm']), v(data['depth']), v(data['dataMask']), v(data['Q'])
            proj_rgb_p, proj_n_p, proj_d_p, proj_mask_p = v(data['proj_rgb_p']), v(data['proj_n_p']), v(data['proj_d_p']), v(data['proj_mask_p'])
            proj_flow = v(data['proj_flow'])  # loaded for parity; unused below
            if 's' in self.args.outputType:
                segm = v(data['segm'])
            if self.args.dynamicWeighting:
                dynamicW = v(data['proj_box_p'])
                dynamicW[dynamicW == 0] = 0.2
                dynamicW = torch.cat((dynamicW[:, 0, :, :, :], dynamicW[:, 1, :, :, :]))
            else:
                dynamicW = 1
            errG_rgb, errG_d, errG_n, errG_k, errG_s = (torch.FloatTensor([0]), torch.FloatTensor([0]),
                                                        torch.FloatTensor([0]), torch.FloatTensor([0]),
                                                        torch.FloatTensor([0]))
            n = Q.shape[0]
            # full 7-channel panos (rgb+normal+depth) for source and target view
            complete_s = torch.cat((rgb[:, 0, :, :, :], norm[:, 0, :, :, :], depth[:, 0:1, :, :]), 1)
            complete_t = torch.cat((rgb[:, 1, :, :, :], norm[:, 1, :, :, :], depth[:, 1:2, :, :]), 1)
            view_s, mask_s, geow_s = apply_mask(complete_s.clone(), self.args.maskMethod, self.args.ObserveRatio)
            view_s = torch.cat((view_s, mask_s), 1)
            view_t, mask_t, geow_t = apply_mask(complete_t.clone(), self.args.maskMethod, self.args.ObserveRatio)
            view_t = torch.cat((view_t, mask_t), 1)
            # counterpart view reprojected into this view's frame
            view_t2s = torch.cat((proj_rgb_p[:, 0, :, :, :], proj_n_p[:, 0, :, :, :], proj_d_p[:, 0, :, :, :], proj_mask_p[:, 0, :, :, :]), 1)
            view_s2t = torch.cat((proj_rgb_p[:, 1, :, :, :], proj_n_p[:, 1, :, :, :], proj_d_p[:, 1, :, :, :], proj_mask_p[:, 1, :, :, :]), 1)
            # netG need to tolerate three type of input:
            # 0.correct s + blank t
            # 1.correct s + wrong t
            # 2.correct s + correct t
            view_s_type0 = torch.cat((view_s, torch.zeros(view_s.shape).cuda()), 1)
            view_s_type1 = torch.cat((view_s, view_t2s), 1)
            view_t_type0 = torch.cat((view_t, torch.zeros(view_t.shape).cuda()), 1)
            view_t_type1 = torch.cat((view_t, view_s2t), 1)
            if 's' in self.args.outputType:
                segm = torch.cat((segm[:, 0, :, :, :], segm[:, 1, :, :, :])).repeat(2, 1, 1, 1)
            # mask the pano
            view = torch.cat((view_s_type0, view_t_type0, view_s_type1, view_t_type1))
            mask = torch.cat((mask_s, mask_t)).repeat(2, 1, 1, 1)
            geow = torch.cat((geow_s, geow_t)).repeat(2, 1, 1, 1)
            complete = torch.cat((complete_s, complete_t)).repeat(2, 1, 1, 1)
            dataMask = torch.cat((dataMask[:, 0, :, :, :], dataMask[:, 1, :, :, :])).repeat(2, 1, 1, 1)
            fake = self.netG(view)
            with torch.set_grad_enabled(False):
                # frozen feature net on the complete input provides the target features
                fakec = self.netF(complete)

            if 'f' in self.args.outputType:
                featMapsc = fakec[:n]
                featMaptc = fakec[n:n * 2]
                # randomly train the feature loss on the type0 or type1 halves
                if np.random.rand() > 0.5:
                    featMaps = fake[:n, self.args.idx_f_start:self.args.idx_f_end, :, :]
                    featMapt = fake[n:n * 2, self.args.idx_f_start:self.args.idx_f_end, :, :]
                else:
                    featMaps = fake[n * 2:n * 3, self.args.idx_f_start:self.args.idx_f_end, :, :]
                    featMapt = fake[n * 3:n * 4, self.args.idx_f_start:self.args.idx_f_end, :, :]

            if self.args.featurelearning:
                denseCorres = data['denseCorres']
                validCorres = torch.nonzero(denseCorres['valid'] == 1).view(-1).long()
                loss_fl, loss_fl_pos, loss_fl_neg = self.contrast_loss(featMaps, featMapt, data['denseCorres'])
                # categorize each correspondence by whether it contain unobserved point
                allCorres = torch.cat((denseCorres['idxSrc'], denseCorres['idxTgt']))
                corresShape = allCorres.shape
                allCorres = allCorres.view(-1, 2).long()
                typeIdx = torch.arange(corresShape[0]).view(-1, 1).repeat(1, corresShape[1]).view(-1).long()
                typeIcorresP = mask[typeIdx, 0, allCorres[:, 1], allCorres[:, 0]]
                # 2 == both endpoints observed, <2 == at least one unobserved
                typeIcorresP = typeIcorresP.view(2, -1, corresShape[1]).sum(0)
                denseCorres['observe'] = typeIcorresP
                # feature consistency with the frozen complete-view features
                loss_fc = torch.pow((fake[:, self.args.idx_f_start:self.args.idx_f_end, :, :] - fakec.detach()) * dataMask * geow, 2).sum(1).mean()

            errG_recon = 0
            if self.args.GeometricWeight:
                total_weight = geow[:, 0:1, :, :] * dynamicW * dataMask
            else:
                total_weight = dynamicW * dataMask
            # weighted L1 per modality
            if 'rgb' in self.args.outputType:
                errG_rgb = ((fake[:, self.args.idx_rgb_start:self.args.idx_rgb_end, :, :] - complete[:, 0:3, :, :]) * total_weight).abs().mean()
                errG_recon += errG_rgb
            if 'n' in self.args.outputType:
                errG_n = ((fake[:, self.args.idx_n_start:self.args.idx_n_end, :, :] - complete[:, 3:6, :, :]) * total_weight).abs().mean()
                errG_recon += errG_n
            if 'd' in self.args.outputType:
                errG_d = ((fake[:, self.args.idx_d_start:self.args.idx_d_end, :, :] - complete[:, 6:7, :, :]) * total_weight).abs().mean()
                errG_recon += errG_d
            if 'k' in self.args.outputType:
                errG_k = ((fake[:, self.args.idx_k_start:self.args.idx_k_end, :, :] - complete[:, 7:8, :, :]) * total_weight).abs().mean()
                errG_recon += errG_k
            if 's' in self.args.outputType:
                errG_s = (CEcriterion(fake[:, self.args.idx_s_start:self.args.idx_s_end, :, :], segm.squeeze(1).long()) * total_weight).mean() * 0.1
                errG_recon += errG_s
            errG = errG_recon

            if self.args.pnloss:
                # plane-normal consistency between predicted normal and depth
                loss_pn = util.pnlayer(torch.cat((depth[:, 0:1, :, :], depth[:, 1:2, :, :])),
                                       fake[:, 3:6, :, :], fake[:, 6:7, :, :] * 4,
                                       self.args.dataList, self.args.representation) * 1e-1
                errG += loss_pn
            if self.args.featurelearning:
                errG += loss_fl + loss_fc

            if mode == 'train':
                errG.backward()
                self.optimizerG.step()

            # update running loss meters
            self.logger_errG.update(errG.data, Q.size(0))
            self.logger_errG_rgb.update(errG_rgb.data, Q.size(0))
            self.logger_errG_n.update(errG_n.data, Q.size(0))
            self.logger_errG_d.update(errG_d.data, Q.size(0))
            self.logger_errG_s.update(errG_s.data, Q.size(0))
            self.logger_errG_k.update(errG_k.data, Q.size(0))
            if self.args.pnloss:
                self.logger_errG_pn.update(loss_pn.data, Q.size(0))
            if self.args.featurelearning:
                self.logger_errG_fl.update(loss_fl.data, Q.size(0))
                self.logger_errG_fl_pos.update(loss_fl_pos.data, Q.size(0))
                self.logger_errG_fl_neg.update(loss_fl_neg.data, Q.size(0))
                self.logger_errG_fc.update(loss_fc.data, Q.size(0))
            if self.args.objectFreqLoss:
                # NOTE(review): loss_freq is never computed in this file;
                # enabling objectFreqLoss will raise NameError here
                self.logger_errG_freq.update(loss_freq.data, Q.size(0))

            suffix = (f"| errG {self.logger_errG.avg:.6f}| | errG_fl {self.logger_errG_fl.avg:.6f}"
                      f" | errG_fl_pos {self.logger_errG_fl_pos.avg:.6f} | errG_fl_neg {self.logger_errG_fl_neg.avg:.6f}"
                      f" | errG_fc {self.logger_errG_fc.avg:.6f} | errG_pn {self.logger_errG_pn.avg:.6f}"
                      f" | errG_freq {self.logger_errG_freq.avg:.6f}")

            if self.global_step % getattr(self.learnerParam, f"{mode}_step_vis") == 0:
                print(f"total image trasversed:{len(self.sancheck)}\n")
                # do logging and visualizing
                if 'n' in self.args.outputType:
                    # normalized normal
                    faken = fake[:, self.args.idx_n_start:self.args.idx_n_end, :, :]
                    faken = faken / torch.norm(faken, dim=1, keepdim=True)
                vis = []
                if 'rgb' in self.args.outputType:
                    # draw rgb: masked input, reprojected counterpart, raw prediction,
                    # prediction composited into observed region, ground truth
                    visrgb = complete[:, 0:3, :, :]
                    visrgbm = view[:, 0:3, :, :]
                    visrgbm2 = view[:, 8 + 0:8 + 3, :, :]
                    visrgbf = fake[:, self.args.idx_rgb_start:self.args.idx_rgb_end, :, :]
                    visrgbf = visNorm(visrgbf)
                    visrgbc = (fake[:, self.args.idx_rgb_start:self.args.idx_rgb_end, :, :] * (1 - mask) + visrgb * mask)
                    visrgbc = visNorm(visrgbc)
                    visrgb = torch.cat((visrgbm, visrgbm2, visrgbf, visrgbc, visrgb), 2)
                    visrgb = visNorm(visrgb)
                    vis.append(visrgb)
                if 'n' in self.args.outputType:
                    # draw normal
                    visn = complete[:, 3:6, :, :]
                    visnm = view[:, 3:6, :, :]
                    visnm2 = view[:, 8 + 3:8 + 6, :, :]
                    visnf = faken
                    visnc = (faken * (1 - mask) + visn * mask)
                    visn = torch.cat((visnm, visnm2, visnf, visnc, visn), 2)
                    visn = visNorm(visn)
                    vis.append(visn)
                if 'd' in self.args.outputType:
                    # draw depth
                    visd = complete[:, 6:7, :, :]
                    visdm = view[:, 6:7, :, :]
                    visdm2 = view[:, 8 + 6:8 + 7, :, :]
                    visdf = fake[:, self.args.idx_d_start:self.args.idx_d_end, :, :]
                    visdc = (fake[:, self.args.idx_d_start:self.args.idx_d_end, :, :] * (1 - mask) + visd * mask)
                    visd = torch.cat((visdm, visdm2, visdf, visdc, visd), 2)
                    visd = visNorm(visd)
                    visd = visd.repeat(1, 3, 1, 1)
                    vis.append(visd)
                if 'k' in self.args.outputType:
                    # draw keypoint
                    visk = complete[:, 7:8, :, :]
                    viskm = view[:, 7:8, :, :]
                    viskf = fake[:, self.args.idx_k_start:self.args.idx_k_end, :, :]
                    viskc = fake[:, self.args.idx_k_start:self.args.idx_k_end, :, :].clone()
                    viskc = viskc * (1 - mask) + (viskc.view(viskc.shape[0], -1).min(1)[0].view(-1, 1, 1, 1)) * mask
                    viskc = visNorm(viskc)
                    viskc = util.extractKeypoint(viskc)
                    viskc = (viskc * (1 - mask) + visk * mask)
                    visk = torch.cat((viskm, viskf, viskc, visk), 2)
                    visk = visk.repeat(1, 3, 1, 1)
                    vis.append(visk)
                if 's' in self.args.outputType:
                    # draw semantic (class indices colorized via the palette)
                    viss = segm
                    vissm = viss * mask[:, 0:1, :, :]
                    vissf = fake[:, self.args.idx_s_start:self.args.idx_s_end, :, :]
                    vissf = torch.argmax(vissf, 1, keepdim=True).float()
                    vissc = (vissf * (1 - mask) + viss * mask)
                    viss = torch.cat((vissm, vissf, vissc, viss), 2)
                    visstp = torch_op.npy(viss)
                    visstp = np.expand_dims(np.squeeze(visstp, 1), 3)
                    visstp = self.colors[visstp.flatten().astype('int'), :].reshape(visstp.shape[0], visstp.shape[1], visstp.shape[2], 3)
                    viss = torch_op.v(visstp.transpose(0, 3, 1, 2)) / 255.
                    vis.append(viss)
                if self.args.dynamicWeighting:
                    visdw = dynamicW.repeat(1, 3, 1, 1)
                    vis.append(visdw)
                if 'f' in self.args.outputType:
                    # draw feature error map
                    visf = fake[:, self.args.idx_f_start:self.args.idx_f_end, :, :]
                    visf = (visf - fakec).pow(2).sum(1, keepdim=True)
                    visf = visNorm(visf)
                    visf = visf.repeat(1, 3, 1, 1)
                    vis.append(visf)
                visw = total_weight.repeat(1, 3, 1, 1)
                vis.append(visw)
                # concate all vis
                vis = torch.cat(vis, 2)[::2]
                permute = [2, 1, 0]  # bgr to rgb
                vis = vis[:, permute, :, :]

                if mode != 'train':
                    with torch.set_grad_enabled(False):
                        # fix: original `if 'n' and 'd' in ...` only tested 'd'
                        # (precedence bug) and then used faken, which is only
                        # defined when 'n' is requested
                        if 'n' in self.args.outputType and 'd' in self.args.outputType:
                            # evaluate strcuture prediction
                            ## 1. normal angle
                            mask_n = (1 - mask[:, 0:1, :, :]).cpu()
                            mask_n = mask_n * dataMask.cpu()
                            evalErrN = (torch.acos(((faken.cpu() * complete[:, 3:6, :, :].cpu()).sum(1, keepdim=True)[mask_n != 0]).clamp(-1, 1)) / np.pi * 180)
                            self.evalErrN.extend(npy(evalErrN))
                            ## 2. plane distance
                            evalErrD = ((fake[:, 6:7, :, :].cpu() - complete[:, 6:7, :, :].cpu())[mask_n != 0]).abs()
                            self.evalErrD.extend(npy(evalErrD))
                        # evaluate the learned feature
                        ## 1. descriptive power of learned feature
                        if self.args.featurelearning:
                            if len(validCorres):
                                self.evalFeatRatioSift.extend(self.evalSiftDescriptor(rgb, denseCorres))
                                obs, unobs, _ = self.evalDLDescriptor(featMapsc, featMaptc, denseCorres, complete_s[:, 0:3, :, :], complete_t[:, 0:3, :, :], mask[0:1, 0:1, :, :])
                                self.evalFeatRatioDLc_obs.extend(obs)
                                self.evalFeatRatioDLc_unobs.extend(unobs)
                                obs, unobs, _ = self.evalDLDescriptor(featMaps, featMapt, denseCorres, complete_s[:, 0:3, :, :], complete_t[:, 0:3, :, :], mask[0:1, 0:1, :, :])
                                self.evalFeatRatioDL_obs.extend(obs)
                                self.evalFeatRatioDL_unobs.extend(unobs)
                        if self.args.objectFreqLoss:
                            # NOTE(review): freq_pred/freq_gt are never computed in
                            # this file; this branch would raise NameError
                            freq_pred = freq_pred / freq_pred.sum(1, keepdim=True)
                            freq_gt = freq_gt / freq_gt.sum(1, keepdim=True)
                            self.evalSemantic.append(torch_op.npy(freq_pred))
                            self.evalSemantic_gt.append(torch_op.npy(freq_gt))

                train_op.tboard_add_img(self.tensorboardX, vis, f"{mode}/loss", self.global_step)

            if self.global_step % getattr(self.learnerParam, f"{mode}_step_log") == 0:
                self.tensorboardX.add_scalars('data/errG_recon', {f"{mode}": errG_recon}, self.global_step)
                self.tensorboardX.add_scalars('data/errG_rgb', {f"{mode}": errG_rgb}, self.global_step)
                self.tensorboardX.add_scalars('data/errG_n', {f"{mode}": errG_n}, self.global_step)
                self.tensorboardX.add_scalars('data/errG_d', {f"{mode}": errG_d}, self.global_step)
                self.tensorboardX.add_scalars('data/errG_s', {f"{mode}": errG_s}, self.global_step)
                self.tensorboardX.add_scalars('data/errG_k', {f"{mode}": errG_k}, self.global_step)
                if self.args.pnloss:
                    self.tensorboardX.add_scalars('data/errG_pnloss', {f"{mode}": loss_pn}, self.global_step)
                if self.args.featurelearning:
                    self.tensorboardX.add_scalars('data/errG_fl', {f"{mode}_complete": loss_fl}, self.global_step)
                    self.tensorboardX.add_scalars('data/errG_fl_pos', {f"{mode}_complete": loss_fl_pos}, self.global_step)
                    self.tensorboardX.add_scalars('data/errG_fl_neg', {f"{mode}_complete": loss_fl_neg}, self.global_step)
                    self.tensorboardX.add_scalars('data/errG_fc', {f"{mode}": loss_fc}, self.global_step)
                if self.args.objectFreqLoss:
                    # NOTE(review): loss_freq undefined, see above
                    self.tensorboardX.add_scalars('data/errG_freq', {f"{mode}": loss_freq}, self.global_step)

        summary = {'suffix': suffix}
        self.global_step += 1
        if self.speed_benchmark:
            self.time_per_step.update(time.time() - step_start, 1)
            print(f"time elapse per step: {self.time_per_step.avg}")
        return dotdict(summary)


def main():
    """Entry point: parse args, build data/model/trainer, launch training."""
    # parse arguments, build exp dir
    opt = opts()
    args = opt.parse()
    train_op.initialize_experiment_directories(args)
    train_op.platform_specific_initialization(args)
    # build data loader
    train_loader, val_loader = buildDataset(args)
    # build learner
    lp = learnerParam()
    model = learner(args, lp)
    # build trainer and launch training
    mytrainer = trainer(
        model,
        train_loader,
        val_loader,
        max_epoch=200,
    )
    # checkpoint and metric plots every 5 epochs
    mytrainer.add_callbacks([PeriodicCallback(cb_loc=CallbackLoc.epoch_end, pstep=5, func=model.save_checkpoint)])
    mytrainer.add_callbacks([PeriodicCallback(cb_loc=CallbackLoc.epoch_end, pstep=5, func=model.evalPlot)])
    mytrainer.run()


if __name__ == '__main__':
    main()