"""Train a multiple-instance-learning (MIL) image classifier on MS-COCO.

A pretrained CNN backbone produces a spatial feature map; every spatial
location is treated as an "instance" and per-instance class probabilities
are combined with the noisy-OR rule  P(class) = 1 - prod_i (1 - p_i).
"""
import argparse
import json
import os

import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader

from dataloader import CocoDataset
import pretrainedmodels
from pretrainedmodels import utils

# Input image geometry; overwritten in main() for backbones that expect 299x299.
C, H, W = 3, 224, 224


class MILModel(nn.Module):
    """MIL head on top of a CNN backbone.

    Args:
        cnn_model: pretrainedmodels backbone exposing ``.features(x)``.
        dim_hidden: channel count of the backbone feature map.
        num_classes: number of output labels.
    """

    def __init__(self, cnn_model, dim_hidden, num_classes):
        super(MILModel, self).__init__()
        self.cnn_model = cnn_model
        self.num_classes = num_classes
        self.dim_hidden = dim_hidden
        self.linear = nn.Linear(dim_hidden, num_classes)

    def forward(self, x):
        """Return per-class noisy-OR probabilities, shape (batch, num_classes)."""
        # (b, c, fh, fw) -> (b, fh*fw, c): one row per spatial "instance".
        # (renamed from b, x, y, h: the original shadowed the input `x`)
        feature_map = self.cnn_model.features(x)
        feature_map = feature_map.permute(0, 2, 3, 1)
        b, fh, fw, c = feature_map.size()
        feature_map = feature_map.contiguous().view(b, fh * fw, c)
        # BUG FIX: the original fed raw, unbounded linear outputs into the
        # noisy-OR product, which is only meaningful over probabilities in
        # [0, 1] — apply a sigmoid to get per-instance probabilities first.
        instance_probs = torch.sigmoid(self.linear(feature_map))
        # Noisy-OR over instances, vectorized (replaces the Python-level
        # loop `for i in range(x * y): probs = probs * logits[:, i, :]`).
        probs = 1 - torch.prod(1 - instance_probs, dim=1)
        return probs


def train(dataloader, model, crit, optimizer, lr_scheduler, load_image_fn, params):
    """Run the training loop and periodically checkpoint the model.

    Args:
        dataloader: yields dicts with 'image_ids' and 'labels'.
        model: MILModel instance (already on GPU).
        crit: loss on probabilities vs. multi-hot labels.
        optimizer: optimizer over model.parameters().
        lr_scheduler: per-epoch LR scheduler.
        load_image_fn: pretrainedmodels LoadTransformImage callable.
        params: parsed CLI namespace (paths, epochs, checkpoint cadence).
    """
    model.train()
    model = nn.DataParallel(model)
    with open(params.coco_path) as f:
        images_path = json.load(f)  # image_id -> relative file path
    for epoch in range(params.epochs):
        iteration = 0
        for data in dataloader:
            iteration += 1
            image_ids, image_labels = data['image_ids'], data['labels']
            # Load and preprocess this batch's images on the fly.
            images = torch.zeros(image_labels.shape[0], C, H, W)
            for i, image_id in enumerate(image_ids):
                image_path = os.path.join(
                    params.coco_dir, images_path[image_id])
                images[i] = load_image_fn(image_path)
            probs = model(images.cuda())
            loss = crit(probs, image_labels.cuda())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # BUG FIX: loss.data[0] raises on PyTorch >= 0.4; use .item().
            train_loss = loss.item()
            torch.cuda.synchronize()
            print("iter %d (epoch %d), train_loss = %.6f" %
                  (iteration, epoch, train_loss))
        # BUG FIX: step the schedule once per epoch AFTER the optimizer
        # updates (PyTorch >= 1.1 ordering); the original stepped at the
        # top of the epoch, skipping the initial learning rate.
        lr_scheduler.step()
        if epoch % params.save_checkpoint_every == 0:
            checkpoint_path = os.path.join(
                params.checkpoint_path, 'cnn_model_%d.pth' % (epoch))
            # BUG FIX: save the wrapped module's state_dict so checkpoint
            # keys carry no DataParallel 'module.' prefix.
            torch.save(model.module.state_dict(), checkpoint_path)
            print("model saved to %s" % (checkpoint_path))


def main(args):
    """Build backbone + MIL head, data pipeline, and launch training."""
    global C, H, W
    with open(args.coco_labels) as f:
        coco_labels = json.load(f)
    num_classes = coco_labels['num_classes']
    if args.model == 'inception_v3':
        C, H, W = 3, 299, 299
        model = pretrainedmodels.inceptionv3(pretrained='imagenet')
    elif args.model == 'resnet152':
        C, H, W = 3, 224, 224
        model = pretrainedmodels.resnet152(pretrained='imagenet')
    elif args.model == 'inception_v4':
        C, H, W = 3, 299, 299
        model = pretrainedmodels.inceptionv4(
            num_classes=1000, pretrained='imagenet')
    else:
        # BUG FIX: original did args['model'] (Namespace is not
        # subscriptable) and then fell through to a NameError on the
        # undefined `model` — fail loudly instead.
        raise ValueError("doesn't support %s" % (args.model))

    load_image_fn = utils.LoadTransformImage(model)
    dim_feats = model.last_linear.in_features
    model = MILModel(model, dim_feats, num_classes)
    model = model.cuda()

    dataset = CocoDataset(coco_labels)
    dataloader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True)
    optimizer = optim.Adam(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay)
    exp_lr_scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=args.learning_rate_decay_every,
        gamma=args.learning_rate_decay_rate)
    # The model outputs noisy-OR probabilities already in [0, 1], so use
    # BCE on probabilities; MultiLabelSoftMarginLoss would apply a second
    # sigmoid to values that are already probabilities.
    crit = nn.BCELoss()
    # makedirs(exist_ok=True) also handles nested checkpoint paths,
    # unlike the original os.mkdir.
    os.makedirs(args.checkpoint_path, exist_ok=True)

    train(dataloader, model, crit, optimizer, exp_lr_scheduler,
          load_image_fn, args)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--coco_path', type=str,
                        default='data/coco_path.json', help='')
    parser.add_argument('--coco_labels', type=str,
                        default='data/coco_labels.json',
                        help='path to processed coco caption json')
    parser.add_argument('--coco_dir', type=str,
                        default='data/mscoco/train2014')
    parser.add_argument('--epochs', type=int, default=200,
                        help='number of epochs')
    parser.add_argument('--checkpoint_path', type=str,
                        help='path to trained model')
    parser.add_argument("--gpu", dest='gpu', type=str, default='0',
                        help='Set CUDA_VISIBLE_DEVICES environment variable, '
                             'optional')
    parser.add_argument("--model", dest="model", type=str,
                        default='resnet152',
                        help='the CNN model you want to use to extract_feats')
    parser.add_argument('--save_checkpoint_every', type=int, default=20,
                        help='how often to save a model checkpoint (in epoch)?')
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--learning_rate', type=float, default=1e-5,
                        help='learning rate')
    parser.add_argument('--learning_rate_decay_every', type=int, default=2,
                        help='every how many epoch thereafter to drop LR?')
    parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8)
    parser.add_argument('--optim_alpha', type=float, default=0.9,
                        help='alpha for adam')
    parser.add_argument('--optim_beta', type=float, default=0.999,
                        help='beta used for adam')
    parser.add_argument('--optim_epsilon', type=float, default=1e-8,
                        help='epsilon that goes into denominator for smoothing')
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='weight_decay')
    args = parser.parse_args()
    # Must be set before any CUDA context is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    main(args)