# -*- coding: utf-8 -*- """ Feature models of interest """ from torch import nn from torchvision import models # feature model based on resnet architecture class resnet_model(nn.Module): def __init__(self, model_type='resnet50', layer_type='layer4'): super().__init__() # get model if model_type == 'resnet50': original_model = models.resnet50(pretrained=True) elif model_type == 'resnet101': original_model = models.resnet101(pretrained=True) else: raise NameError('Unknown model_type passed') # get requisite layer if layer_type == 'layer2': num_layers = 6 pool_size = 28 elif layer_type == 'layer3': num_layers = 7 pool_size = 14 elif layer_type == 'layer4': num_layers = 8 pool_size = 7 else: raise NameError('Uknown layer_type passed') self.features = nn.Sequential(*list(original_model.children())[:num_layers]) self.avgpool = nn.AvgPool2d(pool_size, stride=1) def forward(self, x): x = self.features(x) x = self.avgpool(x) x = x.view(x.size(0), -1) return x class vgg_model(nn.Module): def __init__(self, model_type='vgg13', layer_type='fc6'): super().__init__() # get model if model_type == 'vgg13': self.original_model = models.vgg13_bn(pretrained=True) elif model_type == 'vgg16': self.original_model = models.vgg16_bn(pretrained=True) else: raise NameError('Unknown model_type passed') self.features = self.original_model.features if layer_type == 'fc6': self.classifier = nn.Sequential(*list(self.original_model.classifier.children())[:2]) elif layer_type == 'fc7': self.classifier = nn.Sequential(*list(self.original_model.classifier.children())[:-2]) else: raise NameError('Uknown layer_type passed') def forward(self, x): x = self.features(x) x = x.view(x.size(0), -1) x = self.classifier(x) return x