__all__ = ['oth_ibppose'] import torch from torch import nn import torch.nn.functional as F class Residual(nn.Module): """Residual Block modified by us""" def __init__(self, ins, outs, bn=True, relu=True): super(Residual, self).__init__() self.relu_flag = relu self.convBlock = nn.Sequential( nn.Conv2d(ins, outs//2, 1, bias=False), nn.BatchNorm2d(outs//2), nn.LeakyReLU(negative_slope=0.01, inplace=True), nn.Conv2d(outs // 2, outs // 2, 3, 1, 1, bias=False), nn.BatchNorm2d(outs // 2), nn.LeakyReLU(negative_slope=0.01, inplace=True), nn.Conv2d(outs // 2, outs, 1, bias=False), nn.BatchNorm2d(outs), ) if ins != outs: self.skipConv = nn.Sequential( nn.Conv2d(ins, outs, 1, bias=False), nn.BatchNorm2d(outs) ) self.relu = nn.LeakyReLU(negative_slope=0.01, inplace=True) self.ins = ins self.outs = outs def forward(self, x): residual = x x = self.convBlock(x) if self.ins != self.outs: residual = self.skipConv(residual) x += residual # Bn layer is in the middle, so we can do in-plcae += here if self.relu_flag: x = self.relu(x) return x else: return x class Conv(nn.Module): # conv block used in hourglass def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=True, relu=True, dropout=False, dialated=1): super(Conv, self).__init__() self.inp_dim = inp_dim self.relu = None self.bn = None self.dropout = dropout if relu: self.relu = nn.LeakyReLU(negative_slope=0.01, inplace=True) # 换成 Leak Relu减缓神经元死亡现象 if bn: self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size - 1) // 2, bias=False, dilation=1) # Different form TF, momentum default in Pytorch is 0.1, which means the decay rate of old running value self.bn = nn.BatchNorm2d(out_dim) else: self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size - 1) // 2, bias=True, dilation=1) def forward(self, x): # examine the input channel equals the conve kernel channel assert x.size()[1] == self.inp_dim, "input channel {} dese not fit kernel channel {}".format(x.size()[1], self.inp_dim) if self.dropout: # comment these two lines if we do not want to use Dropout layers # p: probability of an element to be zeroed x = F.dropout(x, p=0.2, training=self.training, inplace=False) # 直接注释掉这一行,如果我们不想使用Dropout x = self.conv(x) if self.bn is not None: x = self.bn(x) if self.relu is not None: x = self.relu(x) return x class DilatedConv(nn.Module): """ Dilated convolutional layer of stride=1 only! """ def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=True, relu=True, dropout=False, dialation=3): super(DilatedConv, self).__init__() self.inp_dim = inp_dim self.relu = None self.bn = None self.dropout = dropout if relu: self.relu = nn.LeakyReLU(negative_slope=0.01, inplace=True) # 换成 Leak Relu减缓神经元死亡现象 if bn: self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=dialation, bias=False, dilation=dialation) # Different form TF, momentum default in Pytorch is 0.1, which means the decay rate of old running value self.bn = nn.BatchNorm2d(out_dim) else: self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=dialation, bias=True, dilation=dialation) def forward(self, x): # examine the input channel equals the conve kernel channel assert x.size()[1] == self.inp_dim, "input channel {} dese not fit kernel channel {}".format(x.size()[1], self.inp_dim) if self.dropout: # comment these two lines if we do not want to use Dropout layers # p: probability of an element to be zeroed x = F.dropout(x, p=0.2, training=self.training, inplace=False) # 直接注释掉这一行,如果我们不想使用Dropout x = self.conv(x) if self.bn is not None: x = self.bn(x) if self.relu is not None: x = self.relu(x) return x class Backbone(nn.Module): """ Input Tensor: a batch of images with shape (N, C, H, W) """ def __init__(self, nFeat=256, inplanes=3, resBlock=Residual, dilatedBlock=DilatedConv): super(Backbone, self).__init__() self.nFeat = nFeat self.resBlock = resBlock self.inplanes = inplanes self.conv1 = nn.Conv2d(self.inplanes, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.LeakyReLU(negative_slope=0.01, inplace=True) self.res1 = self.resBlock(64, 128) self.pool = nn.MaxPool2d(2, 2) self.res2 = self.resBlock(128, 128) self.dilation = nn.Sequential( dilatedBlock(128, 128, dialation=3), dilatedBlock(128, 128, dialation=3), dilatedBlock(128, 128, dialation=4), dilatedBlock(128, 128, dialation=4), dilatedBlock(128, 128, dialation=5), dilatedBlock(128, 128, dialation=5), ) def forward(self, x): # head x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.res1(x) x = self.pool(x) x = self.res2(x) x1 = self.dilation(x) concat_merge = torch.cat([x, x1], dim=1) # (N, C1+C2, H, W) return concat_merge class Hourglass(nn.Module): """Instantiate an n order Hourglass Network block using recursive trick.""" def __init__(self, depth, nFeat, increase=128, bn=False, resBlock=Residual, convBlock=Conv): super(Hourglass, self).__init__() self.depth = depth # oder number self.nFeat = nFeat # input and output channels self.increase = increase # increased channels while the depth grows self.bn = bn self.resBlock = resBlock self.convBlock = convBlock # will execute when instantiate the Hourglass object, prepare network's parameters self.hg = self._make_hour_glass() self.downsample = nn.MaxPool2d(2, 2) # no learning parameters, can be used any times repeatedly self.upsample = nn.Upsample(scale_factor=2, mode='nearest') # no learning parameters # FIXME: 改成反卷积? def _make_single_residual(self, depth_id): # the innermost conve layer, return as a layer item return self.resBlock(self.nFeat + self.increase * (depth_id + 1), self.nFeat + self.increase * (depth_id + 1), bn=self.bn) # ########### Index: 4 def _make_lower_residual(self, depth_id): # return as a list pack_layers = [self.resBlock(self.nFeat + self.increase * depth_id, self.nFeat + self.increase * depth_id, bn=self.bn), # ######### Index: 0 self.resBlock(self.nFeat + self.increase * depth_id, self.nFeat + self.increase * (depth_id + 1), # ######### Index: 1 bn=self.bn), self.resBlock(self.nFeat + self.increase * (depth_id + 1), self.nFeat + self.increase * depth_id, # ######### Index: 2 bn=self.bn), self.convBlock(self.nFeat + self.increase * depth_id, self.nFeat + self.increase * depth_id, # ######### Index: 3 bn=self.bn), # 添加一个Conv精细化上采样的特征图? ] return pack_layers def _make_hour_glass(self): """ pack conve layers modules of hourglass block :return: conve layers packed in n hourglass blocks """ hg = [] for i in range(self.depth): # skip path; up_residual_block; down_residual_block_path, # 0 ~ n-2 (except the outermost n-1 order) need 3 residual blocks res = self._make_lower_residual(i) # type:list if i == (self.depth - 1): # the deepest path (i.e. the longest path) need 4 residual blocks res.append(self._make_single_residual(i)) # list append an element hg.append(nn.ModuleList(res)) # pack conve layers of every oder of hourglass block return nn.ModuleList(hg) def _hour_glass_forward(self, depth_id, x, up_fms): """ built an hourglass block whose order is depth_id :param depth_id: oder number of hourglass block :param x: input tensor :return: output tensor through an hourglass block """ up1 = self.hg[depth_id][0](x) low1 = self.downsample(x) low1 = self.hg[depth_id][1](low1) if depth_id == (self.depth - 1): # except for the highest-order hourglass block low2 = self.hg[depth_id][4](low1) else: # call the lower-order hourglass block recursively low2 = self._hour_glass_forward(depth_id + 1, low1, up_fms) low3 = self.hg[depth_id][2](low2) up_fms.append(low2) # ######################## # if we don't consider 8*8 scale # if depth_id < self.depth - 1: # self.up_fms.append(low2) up2 = self.upsample(low3) deconv1 = self.hg[depth_id][3](up2) # deconv2 = self.hg[depth_id][4](deconv1) # up1 += deconv2 # out = self.hg[depth_id][5](up1) # relu after residual add return up1 + deconv1 def forward(self, x): """ :param: x a input tensor warpped wrapped as a list :return: 5 different scales of feature maps, 128*128, 64*64, 32*32, 16*16, 8*8 """ up_fms = [] # collect feature maps produced by low2 at every scale feature_map = self._hour_glass_forward(0, x, up_fms) return [feature_map] + up_fms[::-1] class SELayer(nn.Module): def __init__(self, inp_dim, reduction=16): """ Squeeze and Excitation :param inp_dim: the channel of input tensor :param reduction: channel compression ratio :return output the tensor with the same shape of input """ # assert inp_dim > reduction, f"Make sure your input channel bigger than reduction which equals to {reduction}" super(SELayer, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(inp_dim, inp_dim // reduction), nn.LeakyReLU(inplace=True), # Relu nn.Linear(inp_dim // reduction, inp_dim), nn.Sigmoid()) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y # def forward(self, x): # 去掉Selayer # return x class Merge(nn.Module): """Change the channel dimension of the input tensor""" def __init__(self, x_dim, y_dim, bn=False): super(Merge, self).__init__() self.conv = Conv(x_dim, y_dim, 1, relu=False, bn=bn) def forward(self, x): return self.conv(x) class Features(nn.Module): """Input: feature maps produced by hourglass block Return: 5 different scales of feature maps, 128*128, 64*64, 32*32, 16*16, 8*8""" def __init__(self, inp_dim, increase=128, bn=False): super(Features, self).__init__() # Regress 5 different scales of heatmaps per stack self.before_regress = nn.ModuleList( [nn.Sequential(Conv(inp_dim + i * increase, inp_dim, 3, bn=bn, dropout=False), Conv(inp_dim, inp_dim, 3, bn=bn, dropout=False), # ##################### Channel Attention layer ##################### SELayer(inp_dim), ) for i in range(5)]) def forward(self, fms): assert len(fms) == 5, "hourglass output {} tensors,but 5 scale heatmaps are supervised".format(len(fms)) return [self.before_regress[i](fms[i]) for i in range(5)] class PoseNet(nn.Module): def __init__(self, nstack, inp_dim, oup_dim, bn=False, increase=128, init_weights=True, **kwargs): """ Pack or initialize the trainable parameters of the network :param nstack: number of stack :param inp_dim: input tensor channels fed into the hourglass block :param oup_dim: channels of regressed feature maps :param bn: use batch normalization :param increase: increased channels once down-sampling :param kwargs: """ super(PoseNet, self).__init__() # self.pre = nn.Sequential( # Conv(3, 64, 7, 2, bn=bn), # Conv(64, 128, bn=bn), # nn.MaxPool2d(2, 2), # Conv(128, 128, bn=bn), # Conv(128, inp_dim, bn=bn) # ) self.pre = Backbone(nFeat=inp_dim) # It doesn't affect the results regardless of which self.pre is used self.hourglass = nn.ModuleList([Hourglass(4, inp_dim, increase, bn=bn) for _ in range(nstack)]) self.features = nn.ModuleList([Features(inp_dim, increase=increase, bn=bn) for _ in range(nstack)]) # predict 5 different scales of heatmpas per stack, keep in mind to pack the list using ModuleList. # Notice: nn.ModuleList can only identify Module subclass! Thus, we must pack the inner layers in ModuleList. # TODO: change the outs layers, Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) self.outs = nn.ModuleList( [nn.ModuleList([Conv(inp_dim, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in range(nstack)]) # TODO: change the merge layers, Merge(inp_dim + j * increase, inp_dim + j * increase) self.merge_features = nn.ModuleList( [nn.ModuleList([Merge(inp_dim, inp_dim + j * increase, bn=bn) for j in range(5)]) for i in range(nstack - 1)]) self.merge_preds = nn.ModuleList( [nn.ModuleList([Merge(oup_dim, inp_dim + j * increase, bn=bn) for j in range(5)]) for i in range(nstack - 1)]) self.nstack = nstack if init_weights: self._initialize_weights() def forward(self, imgs): # Input Tensor: a batch of images within [0,1], shape=(N, H, W, C). Pre-processing was done in data generator # x = imgs.permute(0, 3, 1, 2) # Permute the dimensions of images to (N, C, H, W) x = imgs x = self.pre(x) pred = [] # loop over stack for i in range(self.nstack): preds_instack = [] # return 5 scales of feature maps hourglass_feature = self.hourglass[i](x) if i == 0: # cache for smaller feature maps produced by hourglass block features_cache = [torch.zeros_like(hourglass_feature[scale]) for scale in range(5)] else: # residual connection across stacks # python里面的+=, ,*=也是in-place operation,需要注意 hourglass_feature = [hourglass_feature[scale] + features_cache[scale] for scale in range(5)] # feature maps before heatmap regression features_instack = self.features[i](hourglass_feature) for j in range(5): # handle 5 scales of heatmaps preds_instack.append(self.outs[i][j](features_instack[j])) if i != self.nstack - 1: if j == 0: x = x + self.merge_preds[i][j](preds_instack[j]) + self.merge_features[i][j]( features_instack[j]) # input tensor for next stack features_cache[j] = self.merge_preds[i][j](preds_instack[j]) + self.merge_features[i][j]( features_instack[j]) else: # reset the res caches features_cache[j] = self.merge_preds[i][j](preds_instack[j]) + self.merge_features[i][j]( features_instack[j]) pred.append(preds_instack) # returned list shape: [nstack * [batch*128*128, batch*64*64, batch*32*32, batch*16*16, batch*8*8]]z # return pred return pred[-1][0] def _initialize_weights(self): for m in self.modules(): # 卷积的初始化方法 if isinstance(m, nn.Conv2d): # TODO: 使用正态分布进行初始化(0, 0.01) 网络权重看看 # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels # He kaiming 初始化, 方差为2/n. math.sqrt(2. / n) 或者直接使用现成的nn.init中的函数。在这里会梯度爆炸 m.weight.data.normal_(0, 0.001) # # math.sqrt(2. / n) # torch.nn.init.kaiming_normal_(m.weight) # bias都初始化为0 if m.bias is not None: # 当有BN层时,卷积层Con不加bias! m.bias.data.zero_() # batchnorm使用全1初始化 bias全0 elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() elif isinstance(m, nn.Linear): torch.nn.init.normal_(m.weight.data, 0, 0.01) # todo: 0.001? # m.weight.data.normal_(0, 0.01) m.bias.data.zero_() def oth_ibppose(pretrained=False, num_classes=3, in_channels=3, **kwargs): model = PoseNet(4, 256, 50, bn=True, **kwargs) return model def _calc_width(net): import numpy as np net_params = filter(lambda p: p.requires_grad, net.parameters()) weight_count = 0 for param in net_params: weight_count += np.prod(param.size()) return weight_count def load_model(net, file_path, ignore_extra=True): """ Load model state dictionary from a file. Parameters ---------- net : Module Network in which weights are loaded. file_path : str Path to the file. ignore_extra : bool, default True Whether to silently ignore parameters from the file that are not present in this Module. """ import torch if ignore_extra: pretrained_state = torch.load(file_path) model_dict = net.state_dict() pretrained_state = {k: v for k, v in pretrained_state.items() if k in model_dict} net.load_state_dict(pretrained_state) else: net.load_state_dict(torch.load(file_path)) def _test(): import torch pretrained = False models = [ oth_ibppose, ] for model in models: net = model(pretrained=pretrained) # net.train() net.eval() weight_count = _calc_width(net) print("m={}, {}".format(model.__name__, weight_count)) assert (model != oth_ibppose or weight_count == 128998760) x = torch.randn(14, 3, 256, 256) y = net(x) y.sum().backward() assert (tuple(y.size()) == (14, 50, 64, 64)) if __name__ == "__main__": _test()