import math
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import Parameter
from torchsummary import summary

from config import device, num_classes

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def _init_weights(root):
    """Re-initialise the submodules of *root* in place.

    Conv2d weights and Linear weights get Xavier-normal init and Linear biases
    are zeroed; BatchNorm1d/BatchNorm2d layers are reset to weight=1, bias=0.
    Other modules (e.g. PReLU) keep their constructor defaults.
    """
    for m in root.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight)
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0)


class BasicBlock(nn.Module):
    """Two-conv residual block: conv-BN-PReLU, conv-BN, add shortcut, PReLU.

    Args:
        inplanes: input channel count.
        planes: output channel count.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to form the shortcut.
        use_se: accepted so ``ResNet._make_layer`` can build this block; no
            squeeze-and-excitation is applied here.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=False):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        # Previously `nn.PReLU()(inplace=True)`, which invoked the new module
        # with a bogus keyword argument and raised; PReLU has no in-place mode.
        # One PReLU (one learned slope) is shared by both activations.
        self.relu = nn.PReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``planes * 4`` channels.

    Args:
        inplanes: input channel count.
        planes: bottleneck width; the block outputs ``planes * expansion``.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut.
        use_se: accepted so ``ResNet._make_layer`` can build this block; no
            squeeze-and-excitation is applied here.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=False):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        # A single PReLU (one learned slope) is shared by all three activations.
        self.relu = nn.PReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SEBlock(nn.Module):
    """Squeeze-and-excitation: rescale each channel by a learned gate in (0, 1).

    Args:
        channel: number of input/output channels.
        reduction: bottleneck ratio of the gating MLP.
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.PReLU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class IRBlock(nn.Module):
    """Residual block used by the face ResNets in this file.

    Applies BN to the input first, then conv-BN-PReLU, a (possibly strided)
    conv-BN, optional squeeze-and-excitation, the shortcut add and a PReLU.

    Args:
        inplanes: input channel count.
        planes: output channel count.
        stride: stride of the second convolution.
        downsample: optional module applied to the input to form the shortcut.
        use_se: whether to apply an ``SEBlock`` before the shortcut add.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        # Shared by both activations in forward().
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        residual = x
        out = self.bn0(x)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.prelu(out)

        return out


class ResNet(nn.Module):
    """ResNet backbone that maps images to 512-d embeddings.

    Args:
        block: residual block class (e.g. ``IRBlock``); must expose
            ``expansion`` and accept a ``use_se`` keyword.
        layers: number of blocks in each of the four stages.
        use_se: forwarded to every block.
    """

    def __init__(self, block, layers, use_se=True):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.use_se = use_se
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # The last stage emits 512 * expansion channels (was hard-coded 512,
        # which only matched blocks with expansion == 1).
        out_channels = 512 * block.expansion
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout()
        # Sized for 112x112 inputs: the unpadded stem conv (112 -> 110), the
        # 2x2 max-pool (-> 55) and three stride-2 stages (-> 28 -> 14 -> 7)
        # leave a 7x7 map. Other input sizes need a different flatten size.
        self.fc = nn.Linear(out_channels * 7 * 7, 512)
        self.bn3 = nn.BatchNorm1d(512)

        _init_weights(self)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of *blocks* blocks; only the first may stride."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample, use_se=self.use_se)]
        # Later blocks consume the first block's widened output; this used to
        # be `planes`, which is wrong whenever block.expansion != 1.
        self.inplanes = planes * block.expansion
        layers.extend(
            block(self.inplanes, planes, use_se=self.use_se)
            for _ in range(1, blocks)
        )

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn2(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        # Embeddings are returned un-normalised; ArcMarginModel normalises
        # its input itself.
        x = self.bn3(x)

        return x


def _resnet(arch, layers, args, **kwargs):
    """Build an IRBlock ResNet, optionally loading ``model_urls[arch]``.

    ``args`` must provide ``use_se`` and ``pretrained`` attributes.
    """
    model = ResNet(IRBlock, layers, use_se=args.use_se, **kwargs)
    if args.pretrained:
        # NOTE(review): model_urls point at torchvision ImageNet checkpoints.
        # This network's keys (e.g. IRBlock's bn0/prelu, the 512-d fc/bn3)
        # presumably do not match them, so strict load_state_dict is likely
        # to raise -- confirm before relying on pretrained=True.
        model.load_state_dict(model_zoo.load_url(model_urls[arch]))
    return model


def resnet18(args, **kwargs):
    """IRBlock ResNet with [2, 2, 2, 2] blocks per stage."""
    return _resnet('resnet18', [2, 2, 2, 2], args, **kwargs)


def resnet34(args, **kwargs):
    """IRBlock ResNet with [3, 4, 6, 3] blocks per stage."""
    return _resnet('resnet34', [3, 4, 6, 3], args, **kwargs)


def resnet50(args, **kwargs):
    """IRBlock ResNet with [3, 4, 6, 3] blocks per stage (same depth as resnet34)."""
    return _resnet('resnet50', [3, 4, 6, 3], args, **kwargs)


def resnet101(args, **kwargs):
    """IRBlock ResNet with [3, 4, 23, 3] blocks per stage."""
    return _resnet('resnet101', [3, 4, 23, 3], args, **kwargs)


def resnet152(args, **kwargs):
    """IRBlock ResNet with [3, 8, 36, 3] blocks per stage."""
    return _resnet('resnet152', [3, 8, 36, 3], args, **kwargs)


def resnet_face18(use_se=True, **kwargs):
    """IRBlock ResNet with [2, 2, 2, 2] blocks per stage, never pretrained."""
    model = ResNet(IRBlock, [2, 2, 2, 2], use_se=use_se, **kwargs)
    return model


class MobileNet(nn.Module):
    """MobileNet-v1 style backbone that maps images to 512-d embeddings.

    Args:
        alpha: width multiplier applied to every layer's channel count.
    """

    # (in_channels, out_channels, stride) of each depthwise-separable stage
    # that follows the stem, before scaling by alpha.
    _DW_STAGES = (
        (32, 64, 1),
        (64, 128, 2),
        (128, 128, 1),
        (128, 256, 2),
        (256, 256, 1),
        (256, 512, 2),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 512, 1),
        (512, 1024, 2),
        (1024, 1024, 1),
    )

    def __init__(self, alpha):
        super(MobileNet, self).__init__()
        self.alpha = alpha

        def conv_bn(inp, oup, stride):
            # Plain 3x3 conv -> BN -> ReLU.
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        def conv_dw(inp, oup, stride):
            # Depthwise 3x3 conv (groups=inp) followed by a pointwise 1x1 conv,
            # each with BN and ReLU.
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),

                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        def width(channels):
            return int(channels * self.alpha)

        stages = [conv_bn(3, width(32), 2)]
        stages += [conv_dw(width(i), width(o), s) for i, o, s in self._DW_STAGES]
        self.model = nn.Sequential(*stages)

        # Was hard-coded to 1024, which mismatched the feature map whenever
        # alpha != 1.0.
        out_channels = width(1024)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout()
        # Five stride-2 stages (padding 1) take a 112x112 input down to 4x4.
        self.fc = nn.Linear(out_channels * 4 * 4, 512)
        self.bn3 = nn.BatchNorm1d(512)

        _init_weights(self)

    def forward(self, x):
        x = self.model(x)
        x = self.bn2(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bn3(x)
        return x


class ArcMarginModel(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    Produces ``s * cos(theta + m)`` for the target class and ``s * cos(theta)``
    for all other classes, where theta is the angle between the normalised
    embedding and the normalised class weight.

    Args:
        args: must provide ``emb_size``, ``easy_margin``, ``margin_m`` (m, in
            radians) and ``margin_s`` (the scale s).
    """

    def __init__(self, args):
        super(ArcMarginModel, self).__init__()

        self.weight = Parameter(torch.FloatTensor(num_classes, args.emb_size))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = args.easy_margin
        self.m = args.margin_m
        self.s = args.margin_s

        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        # Beyond theta > pi - m, cos(theta + m) is no longer monotonic in
        # theta; forward() switches to cosine - mm in that region.
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

    def forward(self, input, label):
        """Return scaled margin logits of shape (batch, num_classes).

        Args:
            input: embedding batch; rows are L2-normalised internally.
            label: integer class index per row.
        """
        x = F.normalize(input)
        W = F.normalize(self.weight)
        cosine = F.linear(x, W)
        # Rounding can push |cosine| slightly above 1, making 1 - cos^2
        # negative and sqrt return NaN; clamp into [0, 1] first.
        sine = torch.sqrt(torch.clamp(1.0 - torch.pow(cosine, 2), 0.0, 1.0))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # Allocate on cosine's device/dtype rather than the global config
        # device, so the head also works when placed elsewhere.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output


if __name__ == "__main__":
    # To inspect a ResNet instead, build one from parsed CLI args, e.g.
    # model = resnet152(parse_args()).to(device)
    model = MobileNet(1.0).to(device)
    summary(model, (3, 112, 112))