"""Extract CNN features from a fine-tuned ResNet-101 and save them as .mat files.

Training-set features (with labels) or test-set features (without labels) are
written with ``scipy.io.savemat`` under the keys ``"features"`` (and
``"label"`` for the training set).
"""
import argparse
import os

import numpy as np
import scipy.io as sio
import torch

from resnet50 import resnet50, resnet101, resnet34
from utils.DataPrepare_ex import load_data

# Set after importing torch: this still takes effect because CUDA is only
# initialised lazily, on first device use.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

TRAIN_FEATURE_PATH = "/home/xd133/ZJL_Fusai/Feature_1029/train.mat"
TEST_FEATURE_PATH = "/home/xd133/ZJL_Fusai/Feature_1029/test.mat"
FEATURE_DIM = 2048
NUM_CLASSES = 365


def _extract_features(model, loader, device, feat_dim=FEATURE_DIM,
                      with_labels=False, verbose=False):
    """Run ``model`` over every batch of ``loader`` and collect per-sample outputs.

    Rows are written at a running offset (number of samples seen so far), so
    the result is correct for any batch size, including a short final batch.

    Args:
        model: network whose output's first element is used as the feature
            batch. NOTE(review): assumed to be a (batch, feat_dim) tensor —
            confirm against resnet50.py.
        loader: DataLoader whose samples are ``(data, label, ...)`` sequences.
        device: device the input batches are moved to.
        feat_dim: width of each feature row.
        with_labels: also collect ``sample[1]`` into an (N, 1) float tensor.
        verbose: print the batch index after each batch.

    Returns:
        ``(feats, labels)`` CPU float tensors; ``labels`` is None unless
        ``with_labels`` is true.
    """
    num_samples = len(loader.dataset)
    # zeros rather than empty: any row the loader never reaches (e.g. with
    # drop_last=True) is deterministic instead of uninitialised memory.
    feats = torch.zeros(num_samples, feat_dim)
    labels = torch.zeros(num_samples, 1) if with_labels else None

    model.eval()
    offset = 0
    with torch.no_grad():
        for batch_idx, sample in enumerate(loader):
            data = sample[0].to(device)
            cnn_feat = model(data)[0]
            end = offset + cnn_feat.shape[0]
            feats[offset:end, :] = cnn_feat.cpu()
            if with_labels:
                # Assignment casts the label values to the float buffer dtype.
                labels[offset:end, 0] = torch.from_numpy(np.array(sample[1]))
            offset = end
            if verbose:
                print("the batchidx is %d" % batch_idx)

    if offset != num_samples:
        print("warning: filled %d of %d feature rows" % (offset, num_samples))
    return feats, labels


def _save_mat(path, content):
    """Write ``content`` to ``path`` with savemat, creating the parent directory if needed."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    sio.savemat(path, content)


def trainFeature(model, train_loader, device, out_path=TRAIN_FEATURE_PATH,
                 feat_dim=FEATURE_DIM):
    """Extract training-set features and labels and save them to ``out_path``.

    The .mat file contains ``"features"`` (N x feat_dim) and ``"label"`` (N x 1).
    """
    feats, labels = _extract_features(model, train_loader, device,
                                      feat_dim=feat_dim, with_labels=True)
    _save_mat(out_path, {"features": feats.numpy(), "label": labels.numpy()})


def testFeature(model, test_loader, device, out_path=TEST_FEATURE_PATH,
                feat_dim=FEATURE_DIM):
    """Extract test-set features (no labels) and save them to ``out_path``.

    The batch index is printed after each batch as a progress indicator.
    """
    feats, _ = _extract_features(model, test_loader, device,
                                 feat_dim=feat_dim, verbose=True)
    _save_mat(out_path, {"features": feats.numpy()})


def _str2bool(value):
    """Parse a command-line boolean.

    argparse's ``type=bool`` treats every non-empty string (even "False") as
    True, so the accepted spellings are listed explicitly.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("", "0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


def main():
    """Parse CLI arguments, build and load the model, and run one extraction pass."""
    parser = argparse.ArgumentParser(description='Pytorch baseline')
    parser.add_argument('--batch_size', type=int, default=64)
    # nargs='?' + const=True: both "--test" and "--test True/False" work.
    parser.add_argument('--test', type=_str2bool, nargs='?', const=True,
                        default=False, help='only test?')
    parser.add_argument('--weights', type=str,
                        default='/home/xd133/ZJL_Fusai/output_1027_3_3/cls99.pt',
                        help='pretrained model path')
    parser.add_argument('--train_out', type=str, default=TRAIN_FEATURE_PATH,
                        help='output .mat path for training-set features')
    parser.add_argument('--test_out', type=str, default=TEST_FEATURE_PATH,
                        help='output .mat path for test-set features')
    args = parser.parse_args()

    device = torch.device('cuda')
    # Build the loaders once instead of calling load_data per split.
    loaders = load_data(batch_size=args.batch_size, alldata=True)

    model = resnet101(num_classes=NUM_CLASSES)
    model.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=1)
    model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)
    # Move after replacing the heads so the new layers land on the device too.
    model.to(device)
    model.load_state_dict(torch.load(args.weights, map_location=device))

    # Choose which split to extract: the test set when --test is given,
    # otherwise the training set (with labels).
    if args.test:
        testFeature(model, loaders['test'], device, out_path=args.test_out)
    else:
        trainFeature(model, loaders['train'], device, out_path=args.train_out)


if __name__ == '__main__':
    main()