import torch
import torch.nn as nn
import torch.optim as optim
import time
import torch.nn.functional as F
#import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pydicom
from torch.utils.checkpoint import checkpoint

fold = 1  # int(sys.argv[1])
list_train = torch.Tensor([1, 4, 5, 6, 9, 10]).long()
if fold == 2:
    list_train = torch.Tensor([2, 3, 5, 6, 7, 8, 9]).long()
if fold == 3:
    list_train = torch.Tensor([1, 2, 3, 4, 7, 8, 10]).long()

B = 10  # len(list_train)
H = 233; W = 168; D = 286

imgs = torch.zeros(B, 1, H, W, D)
segs = torch.zeros(B, H, W, D).long()

# remap raw label IDs 0..11 onto 10 training classes
# (raw labels 8 and 9 are merged into background, 10 and 11 become 8 and 9)
label_select = torch.Tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 8, 9]).long()

for i in range(B):
    # case_train = int(list_train[i])
    imgs[i, 0, :, :, :] = torch.from_numpy(nib.load('/share/data_zoe2/heinrich/DatenPMBV/img' + str(i + 1) + 'v2.nii.gz').get_data()) / 500.0
    segs[i, :, :, :] = label_select[torch.from_numpy(nib.load('/share/data_zoe2/heinrich/DatenPMBV/seg' + str(i + 1) + 'v2.nii.gz').get_data()).long()]

#img00 = torch.from_numpy(nib.load('/share/data_zoe2/heinrich/DatenPMBV/img10v2.nii.gz').get_data()).unsqueeze(0).unsqueeze(0)
#img50 = torch.from_numpy(nib.load('/share/data_zoe2/heinrich/DatenPMBV/img5v2.nii.gz').get_data()).unsqueeze(0).unsqueeze(0)
#seg00 = torch.from_numpy(nib.load('/share/data_zoe2/heinrich/DatenPMBV/seg10v2.nii.gz').get_data()).long().unsqueeze(0)
#seg50 = torch.from_numpy(nib.load('/share/data_zoe2/heinrich/DatenPMBV/seg5v2.nii.gz').get_data()).long().unsqueeze(0)
#seg00 = label_select[seg00]
#seg50 = label_select[seg50]


def dice_coeff(outputs, labels, max_label):
    """Dice overlap for labels 1..max_label-1 (background excluded)."""
    dice = torch.FloatTensor(max_label - 1).fill_(0)
    for label_num in range(1, max_label):
        iflat = (outputs == label_num).view(-1).float()
        tflat = (labels == label_num).view(-1).float()
        intersection = torch.mean(iflat * tflat)
        dice[label_num - 1] = (2. * intersection) / (1e-8 + torch.mean(iflat) + torch.mean(tflat))
    return dice
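# --- Optional sanity check for dice_coeff (my own minimal sketch, not part of
# the original pipeline): for two identical label maps, every per-label Dice
# score should be (numerically almost exactly) 1. The toy volume is an assumption.
_toy = torch.arange(27).view(3, 3, 3) % 3  # contains labels 0, 1 and 2
assert torch.allclose(dice_coeff(_toy, _toy, 3), torch.ones(2), atol=1e-6)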
print(np.unique(segs.view(-1).numpy()))
# mask_train[i,:,:,:] = torch.from_numpy(nib.load('/share/data_zoe2/heinrich/DatenPMBV/mask'+str(case_train)+'v2.nii.gz').get_data())#.long()

# initial overlap between two (unregistered) cases as a baseline
d0 = dice_coeff(segs[9, :, :, :], segs[4, :, :, :], 10)
print(d0.mean(), d0)

# coarse feature grid: one third of the input resolution in every dimension
o_m = H // 3
o_n = W // 3
o_o = D // 3
print('numel_o', o_m * o_n * o_o)
ogrid_xyz = F.affine_grid(torch.eye(3, 4).unsqueeze(0), (1, 1, o_m, o_n, o_o)).view(1, 1, -1, 1, 3).cuda()


def init_weights(m):
    if isinstance(m, (nn.Linear, nn.Conv3d, nn.ConvTranspose2d, nn.Conv2d, nn.Conv1d)):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)


def countParameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params


class OBELISK(nn.Module):
    def __init__(self):
        super(OBELISK, self).__init__()
        channels = 16
        # two sets of learnable 3D sampling offsets (used as difference pairs)
        self.offsets = nn.Parameter(torch.randn(2, channels * 2, 3) * 0.05)
        self.layer0 = nn.Conv3d(1, 4, 5, stride=2, bias=False, padding=2)
        self.batch0 = nn.BatchNorm3d(4)
        self.layer1 = nn.Conv3d(channels * 8, channels * 4, 1, bias=False, groups=1)
        self.batch1 = nn.BatchNorm3d(channels * 4)
        self.layer2 = nn.Conv3d(channels * 4, channels * 4, 3, bias=False, padding=1)
        self.batch2 = nn.BatchNorm3d(channels * 4)
        self.layer3 = nn.Conv3d(channels * 4, channels * 1, 1)

    def forward(self, input_img):
        img_in = F.avg_pool3d(input_img, 3, padding=1, stride=2)
        img_in = F.relu(self.batch0(self.layer0(img_in)))
        # sample feature differences at the learnt offset pairs around the coarse grid
        sampled = F.grid_sample(img_in, ogrid_xyz + self.offsets[0, :, :].view(1, -1, 1, 1, 3)).view(1, -1, o_m, o_n, o_o)
        sampled -= F.grid_sample(img_in, ogrid_xyz + self.offsets[1, :, :].view(1, -1, 1, 1, 3)).view(1, -1, o_m, o_n, o_o)
        x = F.relu(self.batch1(self.layer1(sampled)))
        x = F.relu(self.batch2(self.layer2(x)))
        features = self.layer3(x)
        return features


disp_range = 0.4  # 0.25
displacement_width = 15  # 11 # 17
shift_xyz = F.affine_grid(disp_range * torch.eye(3, 4).unsqueeze(0),
                          (1, 1, displacement_width, displacement_width, displacement_width)).view(1, 1, -1, 1, 3).cuda()
#_,_,H,W,D = img00.size()
grid_size = 32  # 25 # 30
grid_xyz = F.affine_grid(torch.eye(3, 4).unsqueeze(0), (1, 1, grid_size, grid_size, grid_size)).view(1, -1, 1, 1, 3).cuda()

# leftover memory-footprint / checkpointing notes:
# print('moving_unfold',torch.numel(moving_unfold)*4e-6,'MBytes')
# print('deeds_cost',torch.numel(deeds_cost)*4e-6,'MBytes')
# minconv_seq = nn.Sequential(pad1,max1,avg1,avg1)
# cost = checkpoint(deeds_cost,minconv_seq)
# print('cost',torch.numel(cost)*4e-6,'MBytes')
# avg_seq = nn.Sequential(pad2,avg1,avg1)
# >>> input_var = checkpoint_sequential(model, chunks, input_var)
# cost_avg = checkpoint(avg_seq,cost_permute)
# print('cost_avg',torch.numel(cost_avg)*4e-6,'MBytes')
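# --- Shape sketch for the OBELISK feature extractor (my own illustration; the
# helper below is hypothetical and the dummy input is an assumption): each of
# the two offset sets samples the 4-channel half-resolution features at
# channels*2 = 32 points, so the difference tensor has 4*32 = channels*8 = 128
# channels on the coarse o_m x o_n x o_o grid, compressed to 16 output channels.
def _obelisk_shape_check():
    # call manually on a CUDA machine; ogrid_xyz already lives on the GPU
    _net = OBELISK().cuda().eval()
    with torch.no_grad():
        _feat = _net(torch.randn(1, 1, H, W, D).cuda())
    print(_feat.shape)  # expected: torch.Size([1, 16, o_m, o_n, o_o])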
def augmentAffine(img_in, seg_in, strength=0.05):
    """
    3D affine augmentation of an image and segmentation mini-batch on the GPU.
    (The affine transform is centred; trilinear interpolation and zero-padding
    are used for sampling.)
    :input: img_in batch (torch.cuda.FloatTensor), seg_in batch (torch.cuda.LongTensor)
    :return: augmented BxCxDxHxW image batch (torch.cuda.FloatTensor),
             augmented BxDxHxW seg batch (torch.cuda.LongTensor)
    """
    B, C, D, H, W = img_in.size()
    affine_matrix = (torch.eye(3, 4).unsqueeze(0) + torch.randn(B, 3, 4) * strength).to(img_in.device)
    meshgrid = F.affine_grid(affine_matrix, torch.Size((B, 1, D, H, W)))
    img_out = F.grid_sample(img_in, meshgrid, padding_mode='border')
    seg_out = F.grid_sample(seg_in.float().unsqueeze(1), meshgrid, mode='nearest').long().squeeze(1)
    return img_out, seg_out


class deeds(nn.Module):
    def __init__(self):
        super(deeds, self).__init__()
        self.alpha = nn.Parameter(torch.Tensor([1, .1, 1, 1, .1, 1]))
        self.pad1 = nn.ReplicationPad3d(3)
        self.avg1 = nn.AvgPool3d(3, stride=1)
        self.max1 = nn.MaxPool3d(3, stride=1)
        self.pad2 = nn.ReplicationPad3d(2)

    def forward(self, feat00, feat50):
        # deeds correlation layer (slightly unrolled)
        deeds_cost = torch.zeros(1, grid_size ** 3, displacement_width, displacement_width, displacement_width).cuda()
        xyz8 = grid_size ** 2
        for i in range(grid_size):
            moving_unfold = F.grid_sample(feat50, grid_xyz[:, i * xyz8:(i + 1) * xyz8, :, :, :] + shift_xyz, padding_mode='border')
            fixed_grid = F.grid_sample(feat00, grid_xyz[:, i * xyz8:(i + 1) * xyz8, :, :, :])
            deeds_cost[:, i * xyz8:(i + 1) * xyz8, :, :, :] = self.alpha[1] + self.alpha[0] * torch.sum(
                torch.pow(fixed_grid - moving_unfold, 2), 1).view(1, -1, displacement_width, displacement_width, displacement_width)

        # remove mean (not really necessary)
        # deeds_cost = deeds_cost.view(-1,displacement_width**3) - deeds_cost.view(-1,displacement_width**3).mean(1,keepdim=True)[0]
        deeds_cost = deeds_cost.view(1, -1, displacement_width, displacement_width, displacement_width)

        # approximate min convolution / displacement compatibility
        cost = self.avg1(self.avg1(-self.max1(-self.pad1(deeds_cost))))

        # grid-based mean field inference (one iteration)
        cost_permute = cost.permute(2, 3, 4, 0, 1).view(1, displacement_width ** 3, grid_size, grid_size, grid_size)
        cost_avg = self.avg1(self.avg1(self.pad2(cost_permute))).permute(0, 2, 3, 4, 1).view(1, -1, displacement_width, displacement_width, displacement_width)

        # second pass
        cost = self.alpha[4] + self.alpha[2] * deeds_cost + self.alpha[3] * cost_avg
        cost = self.avg1(self.avg1(-self.max1(-self.pad1(cost))))

        # grid-based mean field inference (one iteration)
        cost_permute = cost.permute(2, 3, 4, 0, 1).view(1, displacement_width ** 3, grid_size, grid_size, grid_size)
        cost_avg = self.avg1(self.avg1(self.pad2(cost_permute))).permute(0, 2, 3, 4, 1).view(grid_size ** 3, displacement_width ** 3)

        # probabilistic and continuous output: softmax over displacements,
        # then the expected displacement per grid point
        cost_soft = F.softmax(-self.alpha[5] * cost_avg, 1)
        pred_xyz = torch.sum(cost_soft.unsqueeze(2) * shift_xyz.view(1, -1, 3), 1)
        return cost_soft, pred_xyz


net = OBELISK()
net.apply(init_weights)
net.cuda()
net.train()

# inverse-sqrt-frequency class weighting (background down-weighted)
class_weight = torch.sqrt(1.0 / (torch.bincount(segs.view(-1)).float()))
class_weight = class_weight / class_weight.mean()
class_weight[0] = 0.15
class_weight = class_weight.cuda()
print('inv sqrt class_weight', class_weight)
criterion = nn.CrossEntropyLoss(class_weight)

t0 = time.time()

reg = deeds()
reg.cuda()
print('alpha_before', reg.alpha)

list_train = torch.Tensor([1, 4, 5, 6, 9, 10]).long() - 1
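# --- A toy check of the "approximate min convolution" trick inside
# deeds.forward (my own illustration, not part of the original pipeline):
# min-pooling is expressed through max-pooling via min_pool(x) == -max_pool(-x).
_x = torch.randn(1, 1, 5, 5, 5)
_min_center = -F.max_pool3d(-_x, 3, stride=1, padding=1)[0, 0, 2, 2, 2]
assert torch.equal(_min_center, _x[0, 0, 1:4, 1:4, 1:4].min())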
#img00.requires_grad = True
#img50.requires_grad = True
iterations = 1000
lambda_weight = 2  # 2.5 # 1.5
run_labelloss = torch.zeros(iterations)
run_diffloss = torch.zeros(iterations)

optimizer = optim.Adam(list(net.parameters()) + list(reg.parameters()), lr=0.005)

for i in range(iterations):
    # draw a random fixed/moving pair from the training fold
    idx = list_train[torch.randperm(6)].view(2, 3)[:, 0]
    #print(idx)
    optimizer.zero_grad()

    # one-hot label volume of the moving image
    label_moving = torch.zeros(size=(1, 10, H, W, D)).cuda()
    label_moving = label_moving.scatter_(1, segs[idx[1]:idx[1] + 1, :, :, :].unsqueeze(1).cuda(), 1).detach()

    img00_in = imgs[idx[0]:idx[0] + 1, :, :, :, :].cuda()
    img50 = imgs[idx[1]:idx[1] + 1, :, :, :, :].cuda()
    img00, seg50 = augmentAffine(img00_in, segs[idx[0]:idx[0] + 1, :, :, :].cuda(), 0.0375)
    # checkpointing requires inputs with gradients
    img00.requires_grad = True
    img50.requires_grad = True

    # one-hot label volume of the (augmented) fixed image
    label_fixed = torch.zeros(size=(1, 10, H, W, D)).cuda()
    label_fixed = label_fixed.scatter_(1, seg50.unsqueeze(1), 1).detach()

    # get features (regular grid); 00 is fixed, 50 is moving
    feat00 = checkpoint(net, img00)  # net(img00)
    feat50 = checkpoint(net, img50)  # net(img50)

    # run differentiable deeds (regular grid)
    cost_soft, pred_xyz = checkpoint(reg, feat00, feat50)  # reg(feat00,feat50)
    pred_xyz = pred_xyz.view(1, grid_size, grid_size, grid_size, 3)

    # evaluate diffusion regularisation loss (squared forward differences)
    diffloss = lambda_weight * ((pred_xyz[0, :, 1:, :, :] - pred_xyz[0, :, :-1, :, :]) ** 2).mean() \
             + lambda_weight * ((pred_xyz[0, 1:, :, :, :] - pred_xyz[0, :-1, :, :, :]) ** 2).mean() \
             + lambda_weight * ((pred_xyz[0, :, :, 1:, :] - pred_xyz[0, :, :, :-1, :]) ** 2).mean()
    run_diffloss[i] = diffloss.item()

    # evaluate non-local label loss: expected warped moving labels vs. fixed labels
    nonlocal_label = (F.grid_sample(label_moving, grid_xyz + shift_xyz, padding_mode='border')
                      * cost_soft.view(1, -1, grid_size ** 3, displacement_width ** 3, 1)).sum(3, keepdim=True)
    fixed_label = F.grid_sample(label_fixed, grid_xyz, padding_mode='border').detach()
    labelloss = ((nonlocal_label - fixed_label) ** 2) * class_weight.view(1, -1, 1, 1, 1)
    labelloss = labelloss.mean()
    #labelloss = criterion(nonlocal_label,fixed_label)
    run_labelloss[i] = labelloss.item()

    (labelloss + diffloss).backward()
    optimizer.step()

    if i % 50 == 49:
        print('epoch', i, 'time', time.time() - t0)
        #print('grad',reg.layer1.weight.grad.norm().item())
        loss_avg = F.avg_pool1d(run_labelloss.view(1, 1, -1), 5, stride=1).squeeze().numpy()[:i]
        print('run_labelloss', loss_avg[-1])
        loss_avg = F.avg_pool1d(run_diffloss.view(1, 1, -1), 5, stride=1).squeeze().numpy()[:i]
        print('run_diffloss', loss_avg[-1])
        #plt.plot(F.avg_pool1d(run_labelloss.view(1,1,-1),5,stride=1).squeeze().numpy()[:i])
        #plt.plot(F.avg_pool1d(run_diffloss.view(1,1,-1),5,stride=1).squeeze().numpy()[:i])
        #plt.show()
        #plt.imshow(pred_xyz[0,:,12,:,0].cpu().data.numpy()); plt.colorbar(); plt.show()
        torch.save(net.cpu().state_dict(), '/data_supergrover2/heinrich/dense_reg3_feat_epoch' + str(i) + '.pth')
        torch.save(reg.cpu().state_dict(), '/data_supergrover2/heinrich/dense_reg3_deeds_epoch' + str(i) + '.pth')
        net.cuda()
        reg.cuda()

# torch.cuda.synchronize()
print('time', time.time() - t0)
print('grad_alpha', reg.alpha.grad.norm())
print('grad_obelisk', net.layer1.weight.grad.norm())
print('alpha_after', reg.alpha)
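# --- Inference sketch (my own addition, not part of the original script): the
# grid displacements pred_xyz can be upsampled to a dense field and used to
# warp the moving image onto the fixed one. The interpolation mode and
# align-corner choices below are assumptions; img00/img50 are the last pair
# from the training loop.
with torch.no_grad():
    _, pred_t = reg(net(img00), net(img50))
    # displacements as a (1,3,gs,gs,gs) volume, upsampled to full resolution
    flow = pred_t.view(1, grid_size, grid_size, grid_size, 3).permute(0, 4, 1, 2, 3)
    flow = F.interpolate(flow, size=(H, W, D), mode='trilinear', align_corners=False)
    # add the dense displacements to an identity sampling grid and warp
    identity = F.affine_grid(torch.eye(3, 4).unsqueeze(0), (1, 1, H, W, D)).cuda()
    warped50 = F.grid_sample(img50, identity + flow.permute(0, 2, 3, 4, 1), padding_mode='border')
    print('warped moving image', warped50.shape)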