import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo

import data
import util
import chamfer


# network
class AtlasNet(torch.nn.Module):
    """Image-to-points network: a ResNet-18 encoder followed by point decoders.

    The encoder maps an image to a 1024-dimensional code.  Each of the
    ``opt.num_prim`` decoders (``PointGenCon``) receives that code concatenated
    with per-point UV samples and outputs 3 values per sample.

    Args:
        opt: project options object.  This class reads ``device``, ``sphere``,
            ``num_prim``, ``num_points``, ``num_meshgrid``, ``imagenet_enc``
            and ``pretrained_dec`` from it.
        eval_enc: if true, freeze the encoder parameters and put the encoder
            in eval mode; otherwise it is trainable and in train mode.
        eval_dec: same as ``eval_enc`` but for the decoders.
    """

    def __init__(self, opt, eval_enc=False, eval_dec=False):
        super(AtlasNet, self).__init__()
        # define UV
        self.UV_sphere, self.faces_sphere = data.get_icosahedron(opt)
        self.UV_regular, self.faces_regular = self.get_regular_patch_grid(opt)
        self.faces_regular = self.duplicate_faces_original(opt)
        # define and load pretrained weights
        self.define_weights(opt)
        if opt.pretrained_dec is not None:
            self.load_pretrained_decoder(opt)
        for p in self.encoder.parameters():
            p.requires_grad_(not eval_enc)
        for p in self.decoder.parameters():
            p.requires_grad_(not eval_dec)
        # train(False) is exactly what eval() does.
        self.encoder.train(not eval_enc)
        self.decoder.train(not eval_dec)

    def define_weights(self, opt):
        """Create the encoder and the list of decoders and move them to ``opt.device``."""
        embed_size = 1024
        self.encoder = resnet18(pretrained=opt.imagenet_enc, num_classes=embed_size)
        # The decoder input is the image code plus the UV coordinates
        # (3 extra channels in sphere mode, 2 otherwise).
        code_size = embed_size + 3 if opt.sphere else embed_size + 2
        self.decoder = torch.nn.ModuleList(
            [PointGenCon(code_size=code_size) for _ in range(opt.num_prim)]
        )
        self.encoder = self.encoder.to(opt.device)
        self.decoder = self.decoder.to(opt.device)

    def load_pretrained_decoder(self, opt):
        """Load decoder weights from the checkpoint file ``opt.pretrained_dec``.

        Only checkpoint entries whose key starts with ``"decoder"`` are used;
        that prefix and the single separator character following it are
        stripped before the weights are handed to ``self.decoder``.
        """
        print(util.magenta("loading pretrained decoder ({})...".format(opt.pretrained_dec)))
        weight_dict = torch.load(opt.pretrained_dec, map_location=opt.device)
        # Remove the "decoder" prefix plus its separator.  Matching on the
        # prefix (instead of "decoder" appearing anywhere in the key) makes
        # sure the fixed-length strip below only hits keys it is valid for.
        prefix = "decoder"
        strip = len(prefix) + 1
        decoder_weight_dict = {
            key[strip:]: value
            for key, value in weight_dict.items()
            if key.startswith(prefix)
        }
        self.decoder.load_state_dict(decoder_weight_dict)

    def decoder_forward(self, opt, code, regular=False):
        """Decode a batch of codes into points.

        Args:
            opt: options object (``sphere``, ``num_prim``, ``num_points``,
                ``device`` are read here).
            code: tensor whose first dimension is the batch; a trailing axis
                is appended and broadcast over the UV samples, so it is
                expected to be (batch, embed) -- TODO confirm with callers.
            regular: when not in sphere mode, use the fixed regular UV grid
                instead of uniformly random UV samples.

        Returns:
            Tensor of shape (batch, total number of samples over all
            primitives, 3).
        """
        batch_size = code.shape[0]
        # Deterministic UV samples are identical for every primitive, so they
        # are built once.  Random samples must be drawn per primitive.
        if opt.sphere:
            fixed_UV = self.UV_sphere.repeat(batch_size, 1, 1).permute(0, 2, 1)
        elif regular:
            fixed_UV = self.UV_regular.repeat(batch_size, 1, 1).permute(0, 2, 1)
        else:
            fixed_UV = None
        points_list = []
        for p in range(opt.num_prim):
            if fixed_UV is not None:
                UV = fixed_UV
            else:
                UV = torch.rand(batch_size, 2, opt.num_points, device=opt.device)
            # Broadcast the code to every UV sample; torch.cat materializes
            # the result, so a view (expand) is sufficient here.
            code_per_point = code[..., None].expand(-1, -1, UV.shape[2])
            concat = torch.cat([UV, code_per_point], dim=1)
            points_prim = self.decoder[p](concat)
            points_list.append(points_prim)
        # Decoders emit (batch, 3, samples); stack primitives along the sample
        # axis and move the coordinate axis last.
        points = torch.cat(points_list, dim=-1).permute(0, 2, 1)
        return points

    def forward(self, opt, image, regular=False):
        """Encode ``image`` and decode the resulting code into points."""
        # Call the module (not .forward) so that registered hooks run.
        code = self.encoder(image)
        points = self.decoder_forward(opt, code, regular=regular)
        return points

    def get_regular_patch_grid(self, opt):
        """Build a regular triangulated grid on the unit UV square.

        The grid has ``opt.num_meshgrid`` cells per side, hence
        ``(N+1)**2`` vertices and ``2*N**2`` triangles.

        Returns:
            ``(UV, faces)``: ``UV`` is a float32 tensor of shape
            ``((N+1)**2, 2)`` with coordinates in [0, 1]; ``faces`` is an
            int32 tensor of shape ``(2*N**2, 3)`` holding vertex indices.
            Both live on ``opt.device``.
        """
        N = opt.num_meshgrid
        # vertices (UV space)
        U, V = np.meshgrid(range(N + 1), range(N + 1))
        U = (U.astype(np.float32) / N).reshape([-1])
        V = (V.astype(np.float32) / N).reshape([-1])
        UV = np.stack([U, V], axis=-1)
        UV = torch.tensor(UV, dtype=torch.float32, device=opt.device)
        # faces: every grid cell is split into an upper and a lower triangle
        J, I = np.meshgrid(range(N), range(N))
        face_upper = np.stack(
            [I * (N + 1) + J, I * (N + 1) + J + 1, (I + 1) * (N + 1) + J], axis=-1
        ).reshape([-1, 3])
        face_lower = np.stack(
            [I * (N + 1) + J + 1, (I + 1) * (N + 1) + J + 1, (I + 1) * (N + 1) + J], axis=-1
        ).reshape([-1, 3])
        faces = np.concatenate([face_upper, face_lower], axis=0)
        faces = torch.tensor(faces, dtype=torch.int32, device=opt.device)
        return UV, faces

    def duplicate_faces_original(self, opt):
        """Replicate the single-patch faces once per primitive.

        The vertex indices of primitive ``p`` are offset by
        ``p * (opt.num_meshgrid+1)**2`` (the vertex count of one patch).

        NOTE: this overwrites ``self.faces_regular`` in place and also returns
        it, so it must be called exactly once on the single-patch faces.
        """
        verts_per_patch = (opt.num_meshgrid + 1) ** 2
        faces_list = [self.faces_regular + verts_per_patch * p for p in range(opt.num_prim)]
        self.faces_regular = torch.cat(faces_list, dim=0)
        return self.faces_regular


# ---------- AtlasNet decoder blackbox below ----------

class PointGenCon(torch.nn.Module):
    """Point decoder: four 1x1 ``Conv1d`` layers mapping ``code_size`` channels to 3.

    The first three convolutions are each followed by batch norm and ReLU;
    the last one is followed by ``tanh``, so outputs lie in (-1, 1).

    Args:
        code_size: number of input channels.
    """

    def __init__(self, code_size):
        super(PointGenCon, self).__init__()
        self.bottleneck_size = code_size
        self.conv1 = torch.nn.Conv1d(self.bottleneck_size, self.bottleneck_size, 1)
        self.conv2 = torch.nn.Conv1d(self.bottleneck_size, self.bottleneck_size // 2, 1)
        self.conv3 = torch.nn.Conv1d(self.bottleneck_size // 2, self.bottleneck_size // 4, 1)
        self.conv4 = torch.nn.Conv1d(self.bottleneck_size // 4, 3, 1)
        self.th = torch.nn.Tanh()
        self.bn1 = torch.nn.BatchNorm1d(self.bottleneck_size)
        self.bn2 = torch.nn.BatchNorm1d(self.bottleneck_size // 2)
        self.bn3 = torch.nn.BatchNorm1d(self.bottleneck_size // 4)

    def forward(self, x):
        """Map ``x`` of shape (batch, code_size, samples) to (batch, 3, samples)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.th(self.conv4(x))
        return x


# ---------- ResNet blackbox below ----------

def resnet18(pretrained=False, **kwargs):
    """Build a ResNet-18 and optionally load the published ImageNet weights.

    Args:
        pretrained: if true, download the ResNet-18 checkpoint and load every
            top-level block from it except ``fc`` (whose size depends on
            ``num_classes`` and is therefore left at its fresh initialization).
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).

    Returns:
        The constructed ``ResNet`` instance.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        print(util.magenta("loading pretrained encoder..."))
        weight_dict = model_zoo.load_url("https://download.pytorch.org/models/resnet18-5c106cde.pth")
        # Group the weights by their top-level block name (the exact first
        # dotted component), keyed by the remainder of the name.
        blocks = {}
        for key, value in weight_dict.items():
            block_name, _, sub_key = key.partition(".")
            blocks.setdefault(block_name, {})[sub_key] = value
        for block_name, block_weight_dict in blocks.items():
            if block_name == "fc":
                continue
            getattr(model, block_name).load_state_dict(block_weight_dict)
    return model


class BasicBlock(torch.nn.Module):
    """Residual block made of two 3x3 conv + batch-norm layers.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to produce the
            residual when its shape differs from the block output.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(planes)
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = torch.nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + residual)``."""
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class ResNet(torch.nn.Module):
    """ResNet backbone with a final linear layer of ``num_classes`` outputs.

    Args:
        block: residual block class; must expose an ``expansion`` attribute.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        # Running channel count, updated by _make_layer as stages are built.
        self.inplanes = 64
        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Fixed 7x7 pooling window; the linear layer below expects the pooled
        # map to flatten to 512*expansion values (presumably 224x224 inputs --
        # TODO confirm against the data pipeline).
        self.avgpool = torch.nn.AvgPool2d(7)
        self.fc = torch.nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                # Normal init with std sqrt(2/n), n = kernel area * out channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, np.sqrt(2. / n))
            elif isinstance(m, torch.nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        downsample = None
        # A 1x1 projection is needed whenever the residual would not match
        # the block output in resolution or channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = torch.nn.Sequential(
                torch.nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                torch.nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run the backbone and return a (batch, num_classes) tensor."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        return x


# ---------- chamfer distance blackbox below ----------

class ChamferDistance(torch.autograd.Function):
    """Autograd wrapper around the compiled ``chamfer`` extension (CUDA only)."""

    @staticmethod
    def forward(ctx, opt, p1, p2):
        """Fill and return the two distance buffers for point sets ``p1`` and ``p2``.

        Args:
            opt: options object; only ``opt.device`` is read.
            p1, p2: tensors indexed as (batch, points, ...).

        Returns:
            ``(dist1, dist2)`` of shapes (batch, points in p1) and
            (batch, points in p2), as written by ``chamfer.forward``.

        Raises:
            NotImplementedError: if ``opt.device`` is not a CUDA device.
        """
        # str() lets opt.device be either a string or a torch.device.
        if "cuda" not in str(opt.device):
            raise NotImplementedError("CPU version not implemented")
        batch_size = p1.shape[0]
        num_p1_points = p1.shape[1]
        num_p2_points = p2.shape[1]
        dist1 = torch.zeros(batch_size, num_p1_points, device=opt.device)
        dist2 = torch.zeros(batch_size, num_p2_points, device=opt.device)
        idx1 = torch.zeros(batch_size, num_p1_points, dtype=torch.int32, device=opt.device)
        idx2 = torch.zeros(batch_size, num_p2_points, dtype=torch.int32, device=opt.device)
        p1 = p1.contiguous()
        p2 = p2.contiguous()
        chamfer.forward(p1, p2, dist1, dist2, idx1, idx2)
        ctx.opt = opt
        # Only the points and the index buffers are needed by backward.
        ctx.save_for_backward(p1, p2, idx1, idx2)
        return dist1, dist2

    @staticmethod
    def backward(ctx, grad_dist1, grad_dist2):
        """Return gradients for ``(opt, p1, p2)``; ``opt`` gets ``None``.

        Raises:
            NotImplementedError: if ``opt.device`` is not a CUDA device.
        """
        opt = ctx.opt
        if "cuda" not in str(opt.device):
            raise NotImplementedError("CPU version not implemented")
        p1, p2, idx1, idx2 = ctx.saved_tensors
        grad_p1 = torch.zeros_like(p1)
        grad_p2 = torch.zeros_like(p2)
        chamfer.backward(p1, p2, grad_p1, grad_p2, grad_dist1, grad_dist2, idx1, idx2)
        return None, grad_p1, grad_p2