from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as functional
from inplace_abn import ABN

from seamseg.utils.misc import try_index


def _pooled_roi_features(roi_size, in_channels):
    """Return the flattened feature count of an ROI after 2x2 average pooling.

    `avg_pool2d(x, 2)` floors each spatial dimension independently, so the
    count is `(H // 2) * (W // 2) * C` rather than `H * W * C / 4`; the two
    only agree when both `H` and `W` are even.
    """
    return (int(roi_size[0]) // 2) * (int(roi_size[1]) // 2) * int(in_channels)


def _reset_norm(mod):
    """Reset the affine parameters of a normalization module to weight=1, bias=0.

    Parameters that are missing or `None` are skipped.
    """
    if getattr(mod, "weight", None) is not None:
        nn.init.constant_(mod.weight, 1.)
    if getattr(mod, "bias", None) is not None:
        nn.init.constant_(mod.bias, 0.)


class FPNROIHead(nn.Module):
    """ROI head module for FPN

    Two fully-connected layers (each followed by `norm_act`) feeding a
    classification layer and a class-specific bounding box regression layer.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    classes : dict
        Dictionary with the number of classes in the dataset -- expected keys: "total", "stuff", "thing"
    roi_size : tuple of int
        `(height, width)` of the ROIs extracted from the input feature map, these will be average-pooled 2x2 before
        feeding to the rest of the head (odd sizes are floored by the pooling)
    hidden_channels : int
        Number of channels in the hidden layers
    norm_act : callable
        Function to create normalization + activation modules
    """

    def __init__(self, in_channels, classes, roi_size, hidden_channels=1024, norm_act=ABN):
        super(FPNROIHead, self).__init__()

        self.fc = nn.Sequential(OrderedDict([
            ("fc1", nn.Linear(_pooled_roi_features(roi_size, in_channels), hidden_channels, bias=False)),
            ("bn1", norm_act(hidden_channels)),
            ("fc2", nn.Linear(hidden_channels, hidden_channels, bias=False)),
            ("bn2", norm_act(hidden_channels))
        ]))
        # One extra classification output on top of the "thing" classes
        self.roi_cls = nn.Linear(hidden_channels, classes["thing"] + 1)
        # Four regression outputs per "thing" class
        self.roi_bbx = nn.Linear(hidden_channels, classes["thing"] * 4)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all linear layers (Xavier normal) and normalization layers (weight=1, bias=0)"""
        gain = nn.init.calculate_gain(self.fc.bn1.activation, self.fc.bn1.activation_param)

        for name, mod in self.named_modules():
            if isinstance(mod, nn.Linear):
                # Output layers use small fixed gains, hidden layers use the activation's gain
                if "roi_cls" in name:
                    nn.init.xavier_normal_(mod.weight, .01)
                elif "roi_bbx" in name:
                    nn.init.xavier_normal_(mod.weight, .001)
                else:
                    nn.init.xavier_normal_(mod.weight, gain)
            elif isinstance(mod, ABN):
                _reset_norm(mod)

            if hasattr(mod, "bias") and mod.bias is not None:
                nn.init.constant_(mod.bias, 0.)

    def forward(self, x):
        """ROI head module for FPN

        Parameters
        ----------
        x : torch.Tensor
            A tensor of input features with shape N x C x H x W

        Returns
        -------
        cls_logits : torch.Tensor
            A tensor of classification logits with shape N x (num_thing + 1)
        bbx_logits : torch.Tensor
            A tensor of class-specific bounding box regression logits with shape N x num_thing x 4
        """
        x = functional.avg_pool2d(x, 2)

        # Run head
        x = self.fc(x.view(x.size(0), -1))
        return self.roi_cls(x), self.roi_bbx(x).view(x.size(0), -1, 4)


class FPNMaskHead(nn.Module):
    """ROI + mask head module for FPN

    Combines a fully-connected branch producing classification and bounding box regression logits with a
    convolutional branch producing class-specific mask logits at twice the ROI resolution.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    classes : dict
        Dictionary with the number of classes in the dataset -- expected keys: "total", "stuff", "thing"
    roi_size : tuple of int
        `(height, width)` of the ROIs extracted from the input feature map, these will be average-pooled 2x2 before
        feeding to the fully-connected branch (odd sizes are floored by the pooling)
    fc_hidden_channels : int
        Number of channels in the hidden layers of the fully-connected branch
    conv_hidden_channels : int
        Number of channels in the hidden layers of the convolutional branch
    norm_act : callable
        Function to create normalization + activation modules
    """

    def __init__(self, in_channels, classes, roi_size, fc_hidden_channels=1024, conv_hidden_channels=256,
                 norm_act=ABN):
        super(FPNMaskHead, self).__init__()

        # ROI section
        self.fc = nn.Sequential(OrderedDict([
            ("fc1", nn.Linear(_pooled_roi_features(roi_size, in_channels), fc_hidden_channels, bias=False)),
            ("bn1", norm_act(fc_hidden_channels)),
            ("fc2", nn.Linear(fc_hidden_channels, fc_hidden_channels, bias=False)),
            ("bn2", norm_act(fc_hidden_channels))
        ]))
        self.roi_cls = nn.Linear(fc_hidden_channels, classes["thing"] + 1)
        self.roi_bbx = nn.Linear(fc_hidden_channels, classes["thing"] * 4)

        # Mask section: four 3x3 convolutions followed by a 2x transposed-convolution upsampling
        self.conv = nn.Sequential(OrderedDict([
            ("conv1", nn.Conv2d(in_channels, conv_hidden_channels, 3, padding=1, bias=False)),
            ("bn1", norm_act(conv_hidden_channels)),
            ("conv2", nn.Conv2d(conv_hidden_channels, conv_hidden_channels, 3, padding=1, bias=False)),
            ("bn2", norm_act(conv_hidden_channels)),
            ("conv3", nn.Conv2d(conv_hidden_channels, conv_hidden_channels, 3, padding=1, bias=False)),
            ("bn3", norm_act(conv_hidden_channels)),
            ("conv4", nn.Conv2d(conv_hidden_channels, conv_hidden_channels, 3, padding=1, bias=False)),
            ("bn4", norm_act(conv_hidden_channels)),
            ("conv_up", nn.ConvTranspose2d(conv_hidden_channels, conv_hidden_channels, 2, stride=2, bias=False)),
            ("bn_up", norm_act(conv_hidden_channels))
        ]))
        self.roi_msk = nn.Conv2d(conv_hidden_channels, classes["thing"], 1)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all linear / convolutional layers (Xavier normal) and normalization layers"""
        gain = nn.init.calculate_gain(self.fc.bn1.activation, self.fc.bn1.activation_param)

        for name, mod in self.named_modules():
            if isinstance(mod, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                # Output layers use small fixed gains, hidden layers use the activation's gain
                if "roi_cls" in name or "roi_msk" in name:
                    nn.init.xavier_normal_(mod.weight, .01)
                elif "roi_bbx" in name:
                    nn.init.xavier_normal_(mod.weight, .001)
                else:
                    nn.init.xavier_normal_(mod.weight, gain)
            elif isinstance(mod, ABN):
                _reset_norm(mod)

            if hasattr(mod, "bias") and mod.bias is not None:
                nn.init.constant_(mod.bias, 0.)

    def forward(self, x, do_cls_bbx=True, do_msk=True):
        """ROI + mask head module for FPN

        Parameters
        ----------
        x : torch.Tensor
            A tensor of input features with shape N x C x H x W
        do_cls_bbx : bool
            Whether to compute or not the class and bounding box regression predictions
        do_msk : bool
            Whether to compute or not the mask predictions

        Returns
        -------
        cls_logits : torch.Tensor or None
            A tensor of classification logits with shape N x (num_thing + 1), or `None` if `do_cls_bbx` is false
        bbx_logits : torch.Tensor or None
            A tensor of class-specific bounding box regression logits with shape N x num_thing x 4, or `None` if
            `do_cls_bbx` is false
        msk_logits : torch.Tensor or None
            A tensor of class-specific mask logits with shape N x num_thing x (H * 2) x (W * 2), or `None` if
            `do_msk` is false
        """
        # Run fully-connected head
        if do_cls_bbx:
            x_fc = functional.avg_pool2d(x, 2)
            x_fc = self.fc(x_fc.view(x_fc.size(0), -1))

            cls_logits = self.roi_cls(x_fc)
            bbx_logits = self.roi_bbx(x_fc).view(x_fc.size(0), -1, 4)
        else:
            cls_logits = None
            bbx_logits = None

        # Run convolutional head
        if do_msk:
            x = self.conv(x)
            msk_logits = self.roi_msk(x)
        else:
            msk_logits = None

        return cls_logits, bbx_logits, msk_logits


class FPNSemanticHeadDeeplab(nn.Module):
    """Semantic segmentation head for FPN-style networks, extending Deeplab v3 for FPN bodies

    Each selected pyramid level is processed by its own `_MiniDL` module, the results are resized to the
    resolution of the first selected level, concatenated along the channel dimension and mapped to class
    logits by a 1x1 convolution.

    Parameters
    ----------
    in_channels : int
        Number of input channels of each pyramid level
    min_level : int
        Index of the first pyramid level to use
    levels : int
        Number of consecutive pyramid levels to use, starting from `min_level`
    num_classes : int
        Number of output channels of the final 1x1 convolution
    hidden_channels : int
        Number of output channels of each per-level module
    dilation : int
        Dilation of the dilated 3x3 convolution in each per-level module
    pooling_size : int or tuple of int
        Window size `(height, width)` of the pooling branch, clamped to the input size at run time
    norm_act : callable
        Function to create normalization + activation modules
    interpolation : str
        Interpolation mode passed to `torch.nn.functional.interpolate`
    """

    class _MiniDL(nn.Module):
        """Per-level module: 3x3, dilated 3x3 and pooled 1x1 branches, concatenated and fused by a 1x1 conv"""

        def __init__(self, in_channels, out_channels, dilation, pooling_size, norm_act):
            super(FPNSemanticHeadDeeplab._MiniDL, self).__init__()
            self.pooling_size = pooling_size

            self.conv1_3x3 = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False)
            self.conv1_dil = nn.Conv2d(in_channels, out_channels, 3, dilation=dilation, padding=dilation, bias=False)
            self.conv1_glb = nn.Conv2d(in_channels, out_channels, 1, bias=False)
            self.bn1 = norm_act(out_channels * 3)

            self.conv2 = nn.Conv2d(out_channels * 3, out_channels, 1, bias=False)
            self.bn2 = norm_act(out_channels)

        def _global_pooling(self, x):
            """Average-pool `x` with stride 1 and replicate-pad the result back to the input spatial size"""
            # Clamp the pooling window so it never exceeds the feature map
            pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]),
                            min(try_index(self.pooling_size, 1), x.shape[3]))
            # `functional.pad` expects (left, right, top, bottom); the two sides of each dimension add up
            # to `window - 1`, with the extra element going to the right / bottom for even windows
            padding = (
                (pooling_size[1] - 1) // 2,
                (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1,
                (pooling_size[0] - 1) // 2,
                (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1
            )

            pool = functional.avg_pool2d(x, pooling_size, stride=1)
            pool = functional.pad(pool, pad=padding, mode="replicate")
            return pool

        def forward(self, x):
            x = torch.cat([
                self.conv1_3x3(x),
                self.conv1_dil(x),
                self.conv1_glb(self._global_pooling(x)),
            ], dim=1)
            x = self.bn1(x)
            x = self.conv2(x)
            x = self.bn2(x)
            return x

    def __init__(self,
                 in_channels,
                 min_level,
                 levels,
                 num_classes,
                 hidden_channels=128,
                 dilation=6,
                 pooling_size=(64, 64),
                 norm_act=ABN,
                 interpolation="bilinear"):
        super(FPNSemanticHeadDeeplab, self).__init__()
        self.min_level = min_level
        self.levels = levels
        self.interpolation = interpolation

        self.output = nn.ModuleList([
            self._MiniDL(in_channels, hidden_channels, dilation, pooling_size, norm_act) for _ in range(levels)
        ])
        self.conv_sem = nn.Conv2d(hidden_channels * levels, num_classes, 1)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all convolutions (Xavier normal) and normalization layers (weight=1, bias=0)"""
        gain = nn.init.calculate_gain(self.output[0].bn1.activation, self.output[0].bn1.activation_param)
        for name, mod in self.named_modules():
            if isinstance(mod, nn.Conv2d):
                # The final classifier uses a fixed gain, every other convolution the activation's gain
                if "conv_sem" not in name:
                    nn.init.xavier_normal_(mod.weight, gain)
                else:
                    nn.init.xavier_normal_(mod.weight, .1)
            elif isinstance(mod, ABN):
                _reset_norm(mod)

            if hasattr(mod, "bias") and mod.bias is not None:
                nn.init.constant_(mod.bias, 0.)

    def forward(self, xs):
        """Compute semantic segmentation logits from a sequence of pyramid features

        Parameters
        ----------
        xs : sequence of torch.Tensor
            Pyramid feature maps; levels `min_level` to `min_level + levels - 1` are used

        Returns
        -------
        sem_logits : torch.Tensor
            A tensor of logits with `num_classes` channels at the resolution of level `min_level`
        """
        # Copy the selected levels into a fresh list so that tuples are accepted and the caller's
        # sequence is never modified by the per-level assignments below
        xs = list(xs[self.min_level:self.min_level + self.levels])

        ref_size = xs[0].shape[-2:]
        interp_params = {"mode": self.interpolation}
        if self.interpolation == "bilinear":
            interp_params["align_corners"] = False

        for i, output in enumerate(self.output):
            xs[i] = output(xs[i])
            if i > 0:
                # Bring every coarser level to the resolution of the first one
                xs[i] = functional.interpolate(xs[i], size=ref_size, **interp_params)

        xs = torch.cat(xs, dim=1)
        xs = self.conv_sem(xs)

        return xs