"""Inflate the weights of a 2D hourglass pose network into its 3D counterpart.

Every ``inflate*`` function takes the destination 3D (temporal) module first
and the source 2D module second, and copies or rescales parameters in place.
"""
import torch
from math import copysign

nChannels = 128
nStack = 2
nModules = 2
nRegFrames = 8
nJoints = 16
scheme = 1
tempKernel = 3
mult = 0.1


def inflatePose3D(model3d, model):
    """Inflate the hourglass trunk (``model3d.hg``) and the depth regressor
    (``model3d.dr``) of ``model3d`` from the 2D network ``model``."""
    inflateHourglassNet(model3d.hg, model)
    inflateDepthRegressor(model3d.dr, model)


def inflateDepthRegressor(model3d, model):
    """Inflate the depth-regression head.

    ``model3d.reg`` is read in groups of three entries (residual, residual,
    max-pool) while ``model.reg_`` supplies two residuals per group; the
    final fully connected layer is inflated from ``model.reg``.
    """
    for group in range(4):
        inflateResidual(model3d.reg[3 * group], model.reg_[2 * group])
        inflateResidual(model3d.reg[3 * group + 1], model.reg_[2 * group + 1])
        inflateMaxPool(model3d.reg[3 * group + 2], model.maxpool)
    inflateFullyConnected(model3d.fc, model.reg)


def inflateFullyConnected(model3d, model):
    """Tile the 2D fully connected layer across ``nRegFrames`` input blocks.

    The first ``nJoints`` bias entries are copied from ``model``. The input
    dimension of ``model3d.weight`` is treated as ``nRegFrames`` consecutive
    blocks of ``4 * 4 * nChannels`` columns: block 1 receives the 2D weights
    unchanged and every other block receives them scaled by ``mult``.
    """
    val = 4 * 4 * nChannels
    model3d.bias.data[:nJoints] = model.bias.data
    for frame in range(nRegFrames):
        if frame == 1:
            block = model.weight.data
        else:
            block = model.weight.data * mult
        model3d.weight.data[:nJoints, val * frame:val * (frame + 1)] = block


def inflateHourglassNet(model3d, model):
    """Inflate the stem, the stacked hourglasses and the per-stack heads."""
    # Stem: first convolution, batch norm, ReLU, three residuals, max-pool.
    inflateconv(model3d.convStart, model.conv1_)
    model3d.bnStart.bn = inflatebn(model3d.bnStart, model.bn1)
    inflaterelu(model3d.reluStart, model.relu)
    inflateResidual(model3d.res1, model.r1)
    inflateResidual(model3d.res2, model.r4)
    inflateResidual(model3d.res3, model.r5)
    inflateMaxPool(model3d.mp, model.maxpool)

    for i in range(nStack):
        inflatehourglass(model3d.hourglass[i], model.hourglass[i])

    # The 2D model keeps its post-hourglass residuals in one flat list,
    # the 3D model in a [stack][module] nesting.
    for i in range(nStack):
        for j in range(nModules):
            inflateResidual(model3d.Residual[i][j],
                            model.Residual[nModules * i + j])

    for i in range(nStack):
        inflateconv(model3d.lin1[i][0], model.lin_[i][0])
        model3d.lin1[i][1].bn = inflatebn(model3d.lin1[i][1].bn,
                                          model.lin_[i][1])
        inflaterelu(model3d.lin1[i][2], model.lin_[i][2])

    for i in range(nStack):
        inflateconv(model3d.chantojoints[i], model.tmpOut[i])
        inflateconv(model3d.lin2[i], model.ll_[i])
        inflateconv(model3d.jointstochan[i], model.tmpOut_[i])
    return


def inflatehourglass(model3d, model):
    """Recursively inflate one hourglass module.

    Recurses into ``model3d.hg`` / ``model.low2`` while
    ``model3d.numReductions > 1``; at the innermost level the residuals in
    ``model3d.num1res`` are inflated from ``model.low2_`` instead.
    """
    for i in range(nModules):
        inflateResidual(model3d.skip[i], model.up1_[i])
    inflateMaxPool(model3d.mp, model.low1)
    for i in range(nModules):
        inflateResidual(model3d.afterpool[i], model.low1_[i])

    if model3d.numReductions > 1:
        inflatehourglass(model3d.hg, model.low2)
    else:
        for i in range(nModules):
            inflateResidual(model3d.num1res[i], model.low2_[i])

    for i in range(nModules):
        inflateResidual(model3d.lowres[i], model.low3_[i])
    inflateupsampling(model3d.up, model.up2)
    return


def inflateconv(conv3d, conv):
    """Inflate a 2D convolution into the 3D convolution ``conv3d.conv``.

    The 2D kernel is repeated along a new axis 2 whose length is taken from
    the 3D weight. With ``scheme == 1`` slice ``i`` is multiplied by
    ``copysign(mult ** |center - i|, center - i)``, so the centre slice keeps
    the 2D weights and slices after the centre get a negative factor. With
    ``scheme == 3`` every slice is multiplied by ``1 / tempSize``. Any other
    ``scheme`` leaves the 3D weight untouched. The bias is taken from the 2D
    layer in every case.
    """
    target = conv3d.conv
    source_weight = conv.weight.data
    tempSize = target.weight.data.size()[2]
    center = (tempSize - 1) // 2

    if scheme == 1:
        # Build the factor on the source weight's device instead of forcing
        # CUDA, so inflation also works on CPU and on non-default GPUs.
        factor = torch.tensor(
            [copysign(mult ** abs(center - i), center - i)
             for i in range(tempSize)],
            dtype=torch.float32,
            device=source_weight.device,
        ).view(1, 1, tempSize, 1, 1)
        inflated = source_weight[:, :, None, :, :].expand_as(target.weight)
        target.weight.data = inflated.clone() * factor
    elif scheme == 3:
        inflated = source_weight[:, :, None, :, :].expand_as(target.weight)
        target.weight.data = inflated.clone() * (1. / tempSize)

    # NOTE(review): the bias tensor is assigned without a copy, so the 3D
    # layer may share storage with the 2D layer - confirm this is intended.
    target.bias.data = conv.bias.data
    target.weight.data = target.weight.data.contiguous()
    target.bias.data = target.bias.data.contiguous()
    return


def inflatebn(bn3d, bn):
    """Return the 2D batch-norm module ``bn`` for reuse by the caller.

    ``bn.track_running_stats`` is set to ``True`` and the very same object is
    returned (no copy is made). ``bn3d`` is accepted for signature symmetry
    with the other ``inflate*`` helpers and is not used.
    """
    bn.track_running_stats = True
    return bn


def inflaterelu(relu3d, relu):
    """No-op: nothing is transferred for a ReLU."""
    return


def inflateMaxPool(mp3d, mp):
    """No-op: nothing is transferred for a max-pool."""
    return


def inflateResidual(res3d, res):
    """Inflate one residual block.

    Handles the three batch-norm / ReLU / convolution stages under
    ``res3d.cb`` and, when the block's input and output channel counts
    differ, the skip-path convolution as well.
    """
    res3d.cb.cbr1.bn.bn = inflatebn(res3d.cb.cbr1.bn, res.bn)
    inflaterelu(res3d.cb.cbr1.relu, res.relu)
    inflateconv(res3d.cb.cbr1.conv, res.conv1)

    res3d.cb.cbr2.bn.bn = inflatebn(res3d.cb.cbr2.bn, res.bn1)
    inflaterelu(res3d.cb.cbr2.relu, res.relu)
    inflateconv(res3d.cb.cbr2.conv, res.conv2)

    res3d.cb.cbr3.bn.bn = inflatebn(res3d.cb.cbr3.bn, res.bn2)
    inflaterelu(res3d.cb.cbr3.relu, res.relu)
    inflateconv(res3d.cb.cbr3.conv, res.conv3)

    if res3d.inChannels != res3d.outChannels:
        inflateconv(res3d.skip.conv, res.conv4)
    return


def inflateupsampling(up3d, up):
    """No-op: nothing is transferred for an upsampling layer."""
    return