"""SharpMask: DeepMask trunk plus a stack of top-down refinement modules."""
import os
from collections import namedtuple

import torch.nn as nn

from models.DeepMask import DeepMask
from utils.load_helper import load_pretrain

Config = namedtuple('Config', ['iSz', 'oSz', 'gSz', 'km', 'ks'])
default_config = Config(iSz=160, oSz=56, gSz=160, km=32, ks=32)


class RefineModule(nn.Module):
    """Merge two inputs through separate layers, then post-process the sum.

    ``forward`` takes an indexable pair ``x``: ``layer1`` is applied to
    ``x[0]``, ``layer2`` to ``x[1]``, the two results are added and the sum
    is passed through ``layer3``.
    """

    def __init__(self, l1, l2, l3):
        super(RefineModule, self).__init__()
        self.layer1 = l1
        self.layer2 = l2
        self.layer3 = l3

    def forward(self, x):
        """Return ``layer3(layer1(x[0]) + layer2(x[1]))``."""
        x1 = self.layer1(x[0])
        x2 = self.layer2(x[1])
        return self.layer3(x1 + x2)


class SharpMask(nn.Module):
    """DeepMask model extended with top-down refinement modules.

    The constructor builds a ``DeepMask`` from ``config``, loads its weights
    from ``exps/deepmask/train/model_best.pth.tar`` and reuses its ``trunk``,
    ``crop_trick``, ``scoreBranch`` and ``maskBranch``. On top of that it
    creates the refinement stack stored in ``self.refs``.

    Args:
        config: object exposing at least ``km`` and ``ks``; it is also handed
            unchanged to ``DeepMask``.
        context: when true, ``crop_trick`` is applied inside ``forward`` and
            the horizontal nets for the deeper skip positions get a negative
            ``ZeroPad2d``.
    """

    def __init__(self, config=default_config, context=True):
        super(SharpMask, self).__init__()
        self.context = context  # with context
        self.km, self.ks = config.km, config.ks
        # Indices (within the first trunk child) whose outputs feed the
        # horizontal nets; see forward().
        self.skpos = [6, 5, 4, 2]

        deepmask = DeepMask(config)
        deepmask_resume = os.path.join('exps', 'deepmask', 'train', 'model_best.pth.tar')
        # NOTE: kept as an assert so the raised exception type is unchanged;
        # be aware that asserts are stripped when Python runs with -O.
        assert os.path.exists(deepmask_resume), "Please train DeepMask first"
        deepmask = load_pretrain(deepmask, deepmask_resume)
        self.trunk = deepmask.trunk
        self.crop_trick = deepmask.crop_trick
        self.scoreBranch = deepmask.scoreBranch
        self.maskBranchDM = deepmask.maskBranch
        self.fSz = deepmask.fSz

        self.refs = self.createTopDownRefinement()  # create refinement modules

        # Parameter counts of the horizontal / vertical nets, in millions.
        nph = sum(p.numel() for h in self.neths for p in h.parameters()) / 1e+06
        npv = sum(p.numel() for h in self.netvs for p in h.parameters()) / 1e+06
        print('| number of paramaters net h: {:.3f} M'.format(nph))
        print('| number of paramaters net v: {:.3f} M'.format(npv))
        print('| number of paramaters total: {:.3f} M'.format(nph + npv))

    def forward(self, x):
        """Run the trunk, then the refinement stack.

        Args:
            x: input fed to the first trunk sub-layer (presumably an image
                batch - confirm against the DeepMask trunk).

        Returns:
            A pair ``(refined_output, self.scoreBranch(trunk_output))``.
        """
        skips = []
        for i, stage in enumerate(self.trunk.children()):
            for j, layer in enumerate(stage.children()):
                x = layer(x)
                # After the last sub-layer of the first trunk child, apply the
                # crop trick when running with context.
                if i == 0 and j == (len(stage) - 1) and self.context:
                    x = self.crop_trick(x)
                # Remember the activations at the configured skip positions.
                if i == 0 and j in self.skpos:
                    skips.append(x)

        # forward refinement modules
        current = self.refs[0](x)
        # refs[0] and refs[-1] are plain layers; everything in between is a
        # RefineModule fed with the most recently stored skip first.
        for k in range(len(self.refs) - 2):
            skip = skips[-(k + 1)]
            current = self.refs[k + 1]((skip, current))
        current = self.refs[-1](current)
        return current, self.scoreBranch(x)

    def train(self, mode=True):
        """Set training mode, keeping everything but ``refs`` in eval mode.

        With ``mode`` true, all direct children are first switched to eval
        mode and only the children of ``self.refs`` are then put into
        training mode. With ``mode`` false, every child is switched to eval.

        Returns:
            ``self``.
        """
        self.training = mode
        if mode:
            for module in self.children():
                module.train(False)
            # Must run after the loop above: it re-enables the refinement
            # modules that were just put into eval mode.
            for module in self.refs.children():
                module.train(mode)
        else:
            for module in self.children():
                module.train(mode)
        return self

    def createHorizontal(self):
        """Build one horizontal (skip-connection) net per entry of ``skpos``.

        Returns:
            ``nn.ModuleList`` of ``nn.Sequential`` nets, each made of three
            reflection-padded 3x3 convolutions.
        """
        # Per level: (input channels, padding used when context is enabled).
        # The padding values are non-positive; a negative ZeroPad2d is
        # presumably meant to trim the border (the original local was named
        # `crop`) - confirm.
        level_specs = {0: (1024, 0), 1: (512, -2), 2: (256, -4), 3: (64, -8)}

        neths = nn.ModuleList()
        nhu1, nhu2, crop = 0, 0, 0
        for i in range(len(self.skpos)):
            n_inps = self.ks // 2 ** i
            if i in level_specs:
                nhu1, context_crop = level_specs[i]
                nhu2 = 64
                crop = context_crop if self.context else 0

            h = []
            if crop != 0:
                h.append(nn.ZeroPad2d(crop))
            h.append(nn.ReflectionPad2d(1))
            h.append(nn.Conv2d(nhu1, nhu2, 3))
            h.append(nn.ReLU(inplace=True))
            h.append(nn.ReflectionPad2d(1))
            h.append(nn.Conv2d(nhu2, n_inps, 3))
            h.append(nn.ReLU(inplace=True))
            h.append(nn.ReflectionPad2d(1))
            h.append(nn.Conv2d(n_inps, n_inps // 2, 3))
            neths.append(nn.Sequential(*h))
        return neths

    def createVertical(self):
        """Build the vertical (top-down) nets.

        Returns:
            ``nn.ModuleList`` whose first element is a
            ``ConvTranspose2d(512, km, fSz)`` followed by one two-convolution
            ``nn.Sequential`` per entry of ``skpos``; each of those halves
            its channel count.
        """
        netvs = nn.ModuleList()
        netvs.append(nn.ConvTranspose2d(512, self.km, self.fSz))

        for i in range(len(self.skpos)):
            n_inps = self.km // 2 ** i
            netv = [
                nn.ReflectionPad2d(1),
                nn.Conv2d(n_inps, n_inps, 3),
                nn.ReLU(inplace=True),
                nn.ReflectionPad2d(1),
                nn.Conv2d(n_inps, n_inps // 2, 3),
            ]
            netvs.append(nn.Sequential(*netv))
        return netvs

    def refinement(self, neth, netv):
        """Wrap ``neth`` and ``netv`` in a ``RefineModule`` with ReLU + 2x upsampling."""
        merge = nn.Sequential(nn.ReLU(inplace=True),
                              nn.UpsamplingNearest2d(scale_factor=2))
        return RefineModule(neth, netv, merge)

    def createTopDownRefinement(self):
        """Assemble the full refinement stack.

        Side effects:
            Sets ``self.neths`` and ``self.netvs``.

        Returns:
            ``nn.ModuleList`` laid out as ``[netvs[0], RefineModule * len(skpos),
            final 3x3 conv to 1 channel]``.
        """
        # create horizontal nets
        self.neths = self.createHorizontal()

        # create vertical nets
        self.netvs = self.createVertical()

        refs = nn.ModuleList()
        refs.append(self.netvs[0])
        for i in range(len(self.skpos)):
            refs.append(self.refinement(self.neths[i], self.netvs[i + 1]))
        # Each refinement module halves the channel count, so the final conv
        # takes km / 2**(number of refinement modules) input channels.
        refs.append(nn.Sequential(nn.ReflectionPad2d(1),
                                  nn.Conv2d(self.km // 2 ** (len(refs) - 1), 1, 3)))
        return refs


if __name__ == '__main__':
    import torch

    config = Config(iSz=160, oSz=56, gSz=160, km=32, ks=32)
    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SharpMask(config).to(device)

    # training mode
    x = torch.rand(32, 3, config.iSz + 32, config.iSz + 32, device=device)
    # forward() takes only the input batch and returns a (mask, score) pair.
    pred_mask, pred_cls = model(x)
    print("Output (training mode)", pred_mask.shape, pred_cls.shape)