Python torch.load() Examples
The following are 30 code examples showing how to use torch.load(). These examples are extracted from open-source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
You may check out the related API usage on the sidebar.
You may also want to check out all available functions/classes of the module `torch`, or try the search function.
Example 1
Project: models Author: kipoi File: model_architecture.py License: MIT License | 7 votes |
def get_model(load_weights=True):
    """Build the DeepSEA CPU network, optionally restoring pretrained weights.

    Args:
        load_weights: when true, load ``model_files/deepsea_cpu.pth`` (relative
            to the current working directory) into the DeepSEA stack.

    Returns:
        ``nn.Sequential(ReCodeAlphabet(), <DeepSEA stack>)``.
    """
    def as_matrix(x):
        # Promote a 1-D tensor to a single-row matrix before a Linear layer.
        return x.view(1, -1) if 1 == len(x.size()) else x

    def dense(n_in, n_out):
        # Kept as a nested Sequential(Lambda, Linear) so the state_dict keys
        # (e.g. "12.1.weight") still match the saved checkpoint.
        return nn.Sequential(Lambda(as_matrix), nn.Linear(n_in, n_out))

    layers = [
        nn.Conv2d(4, 320, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4), (1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(320, 480, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4), (1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(480, 960, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.Dropout(0.5),
        Lambda(lambda x: x.view(x.size(0), -1)),  # flatten per sample
        dense(50880, 925),
        nn.Threshold(0, 1e-06),
        dense(925, 919),
        nn.Sigmoid(),
    ]
    deepsea_cpu = nn.Sequential(*layers)
    if load_weights:
        deepsea_cpu.load_state_dict(torch.load('model_files/deepsea_cpu.pth'))
    return nn.Sequential(ReCodeAlphabet(), deepsea_cpu)
Example 2
Project: Collaborative-Learning-for-Weakly-Supervised-Object-Detection Author: Sunarker File: train_val.py License: MIT License | 6 votes |
def from_snapshot(self, sfile, nfile):
    """Restore network weights and training bookkeeping from a snapshot pair.

    Args:
        sfile: path to the saved ``state_dict`` of ``self.net``.
        nfile: pickle file holding six consecutive objects, in order: the
            NumPy RNG state, the train cursor and permutation, the validation
            cursor and permutation, and the snapshot iteration.

    Returns:
        The snapshot iteration read from ``nfile``.
    """
    print('Restoring model snapshots from {:s}'.format(sfile))
    self.net.load_state_dict(torch.load(str(sfile)))
    print('Restored.')
    # Only the NumPy RNG and the data-layer cursors can be restored; the
    # TensorFlow random state is not available (TODO xinlei).
    # NOTE(review): pickle.load can execute arbitrary code; use trusted files only.
    with open(nfile, 'rb') as fid:
        (rng_state, cur, perm,
         cur_val, perm_val, last_snapshot_iter) = [pickle.load(fid) for _ in range(6)]
    np.random.set_state(rng_state)
    restore_plan = (
        (self.data_layer, cur, perm),
        (self.data_layer_val, cur_val, perm_val),
    )
    for layer, position, order in restore_plan:
        layer._cur = position
        layer._perm = order
    return last_snapshot_iter
Example 3
Project: mmdetection Author: open-mmlab File: regnet2mmdet.py License: Apache License 2.0 | 6 votes |
def convert(src, dst):
    """Convert keys in pycls pretrained RegNet models to mmdet style.

    Args:
        src: path to a pycls checkpoint whose weights live under ``'model_state'``.
        dst: output path; receives ``{'state_dict': OrderedDict(...)}``.

    Keys that no converter handles are reported with ``not converted: <key>``.
    """
    # Load the pycls checkpoint onto CPU so the conversion does not require
    # the device it was saved from (the other mmdet tools do the same).
    regnet_model = torch.load(src, map_location='cpu')
    blobs = regnet_model['model_state']
    # convert to pytorch style
    state_dict = OrderedDict()
    converted_names = set()
    for key, weight in blobs.items():
        if 'stem' in key:
            convert_stem(key, weight, state_dict, converted_names)
        elif 'head' in key:
            convert_head(key, weight, state_dict, converted_names)
        elif key.startswith('s'):
            convert_reslayer(key, weight, state_dict, converted_names)
    # check if all layers are converted
    for key in blobs:
        if key not in converted_names:
            print(f'not converted: {key}')
    # save checkpoint
    checkpoint = dict()
    checkpoint['state_dict'] = state_dict
    torch.save(checkpoint, dst)
Example 4
Project: models Author: kipoi File: model_architecture.py License: MIT License | 6 votes |
def get_seqpred_model(load_weights=True):
    """Build the DeepSEA network wrapped for reverse-complement averaging.

    Args:
        load_weights: when true, load ``model_files/deepsea_cpu.pth`` (relative
            to the current working directory) into the DeepSEA stack.

    Returns:
        ``nn.Sequential(ReCodeAlphabet(), ConcatenateRC(), <DeepSEA stack>, AverageRC())``.
    """
    def as_matrix(x):
        # Promote a 1-D tensor to a single-row matrix before a Linear layer.
        return x.view(1, -1) if 1 == len(x.size()) else x

    def dense(n_in, n_out):
        # Nested Sequential(Lambda, Linear) keeps the checkpoint's key layout.
        return nn.Sequential(Lambda(as_matrix), nn.Linear(n_in, n_out))

    layers = [
        nn.Conv2d(4, 320, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4), (1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(320, 480, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4), (1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(480, 960, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.Dropout(0.5),
        Lambda(lambda x: x.view(x.size(0), -1)),  # flatten per sample
        dense(50880, 925),
        nn.Threshold(0, 1e-06),
        dense(925, 919),
        nn.Sigmoid(),
    ]
    deepsea_cpu = nn.Sequential(*layers)
    if load_weights:
        deepsea_cpu.load_state_dict(torch.load('model_files/deepsea_cpu.pth'))
    return nn.Sequential(ReCodeAlphabet(), ConcatenateRC(), deepsea_cpu, AverageRC())
Example 5
Project: pytorch_NER_BiLSTM_CNN_CRF Author: bamtercelboo File: test.py License: Apache License 2.0 | 6 votes |
def load_test_model(model, config):
    """Load the weights to evaluate into ``model``.

    ``config.t_model`` wins when it is set; otherwise the path defaults to
    ``<config.save_best_model_dir>/<config.model_name>.pt``.

    :param model: initial model
    :param config: config
    :return: loaded model
    """
    user_path = config.t_model
    if user_path is not None:
        test_model_path = user_path
        print("load user model from {}".format(test_model_path))
    else:
        file_name = "{}.pt".format(config.model_name)
        test_model_path = os.path.join(config.save_best_model_dir, file_name)
        print("load default model from {}".format(test_model_path))
    model.load_state_dict(torch.load(test_model_path))
    return model
Example 6
Project: VSE-C Author: ExplorerFreda File: saliency_visualization.py License: MIT License | 6 votes |
def load_checkpoint(self, checkpoint):
    """Rebuild the VSE model from ``checkpoint`` and freeze it for inference.

    Reads the options (``'opt'``) and weights (``'model'``) from the
    checkpoint, loads the vocabulary named by those options into
    ``self.projector``, and puts both encoders in eval mode with gradients off.
    """
    # NOTE(review): 'opt' is an arbitrary pickled object; recent torch versions
    # may need weights_only=False here — confirm for the torch in use.
    state = torch.load(checkpoint)
    opt = state['opt']
    opt.use_external_captions = False
    vocab_file = pjoin(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
    vocab = Vocab.from_pickle(vocab_file)
    opt.vocab_size = len(vocab)
    from model import VSE
    self.model = VSE(opt)
    self.model.load_state_dict(state['model'])
    self.projector = vocab
    encoders = (self.model.img_enc, self.model.txt_enc)
    for enc in encoders:
        enc.eval()
    for enc in encoders:
        for p in enc.parameters():
            p.requires_grad = False
Example 7
Project: cvpr2018-hnd Author: kibok90 File: utils.py License: MIT License | 6 votes |
def load_model(model, optimizer, scheduler, path, num_epochs, start_time=None):
    """Restore the latest saved epoch that is not after ``num_epochs``.

    Searches ``{path}_model_{epoch}.pth`` from ``num_epochs`` downward. When a
    model file is found, loads it, plus ``{path}_optimizer_{epoch}.pth`` and
    ``{path}_scheduler_{epoch}.pth`` when ``optimizer`` / ``scheduler`` are
    given, then prints a one-line progress report.

    Args:
        model: module receiving the model state dict.
        optimizer: optimizer to restore, or ``None`` to skip.
        scheduler: object whose ``best``, ``cooldown_counter``,
            ``num_bad_epochs`` and ``last_epoch`` attributes are restored, or
            ``None`` to skip.
        path: checkpoint path prefix.
        num_epochs: highest epoch to look for.
        start_time: reference timestamp for the elapsed-time print. Defaults
            to the time of the call. A ``time.time()`` default would be
            evaluated once at import and would report time since import.

    Returns:
        The loaded epoch, or 0 when no checkpoint was found.
    """
    if start_time is None:
        start_time = time.time()

    def checkpoint_path(kind, epoch):
        # Matches the '{path}_<kind>_{epoch:d}.pth' naming used when saving.
        return '{path}_{kind}_{epoch:d}.pth'.format(path=path, kind=kind, epoch=epoch)

    epoch = num_epochs
    while epoch > 0 and not os.path.isfile(checkpoint_path('model', epoch)):
        epoch -= 1
    if epoch > 0:
        model_path = checkpoint_path('model', epoch)
        model.load_state_dict(torch.load(model_path))
        if optimizer is not None:
            optimizer.load_state_dict(torch.load(checkpoint_path('optimizer', epoch)))
        if scheduler is not None:
            scheduler_state_dict = torch.load(checkpoint_path('scheduler', epoch))
            for attr in ('best', 'cooldown_counter', 'num_bad_epochs', 'last_epoch'):
                setattr(scheduler, attr, scheduler_state_dict[attr])
        print('{epoch:4d}/{num_epochs:4d} e; '.format(epoch=epoch, num_epochs=num_epochs), end='')
        print('load {path}; '.format(path=model_path), end='')
        print('{time:8.3f} s'.format(time=time.time()-start_time))
    return epoch
Example 8
Project: cvpr2018-hnd Author: kibok90 File: models.py License: MIT License | 6 votes |
def init_truncated_normal(model, aux_str=''):
    """Initialise ``model`` from a cached weight file, creating the file if absent.

    The cache file is ``{path}/{in_features}_{out_features}{aux_str}.pth``.
    If it exists, its state dict is loaded into ``model``. Otherwise
    ``truncated_normal`` initialises the model (each element of an
    ``nn.ModuleList`` individually) and the result is saved to that file.

    Returns:
        ``model`` (``None`` if ``model`` is ``None``).
    """
    if model is None:
        return None
    # NOTE(review): `path` is a module-level name defined elsewhere in the file.
    init_path = '{path}/{in_dim:d}_{out_dim:d}{aux_str}.pth' \
        .format(path=path, in_dim=model.in_features, out_dim=model.out_features, aux_str=aux_str)
    if os.path.isfile(init_path):
        model.load_state_dict(torch.load(init_path))
        print('load init weight: {init_path}'.format(init_path=init_path))
    else:
        if isinstance(model, nn.ModuleList):
            # Plain loop: truncated_normal is called only for its side effect.
            for sub in model:
                truncated_normal(sub)
        else:
            truncated_normal(model)
        print('generate init weight: {init_path}'.format(init_path=init_path))
        torch.save(model.state_dict(), init_path)
        print('save init weight: {init_path}'.format(init_path=init_path))
    return model
Example 9
Project: pruning_yolov3 Author: zbyuan File: utils.py License: GNU General Public License v3.0 | 6 votes |
def print_mutation(hyp, results, bucket=''):
    """Report one hyperparameter-evolution result and record it in evolve.txt.

    Prints the hyperparameter names, values and results. Appends
    ``results + hyperparameter values`` as one row to ``evolve.txt``, then
    de-duplicates the rows and rewrites the file sorted by descending
    ``fitness``. When ``bucket`` is set, the file is synced from/to
    ``gs://<bucket>`` with gsutil before and after.

    Args:
        hyp: dict of hyperparameter name -> value.
        results: tuple of result metrics (used with %-formatting, so it must
            be a tuple).
        bucket: optional GCS bucket name.
    """
    n_hyp = len(hyp)
    keys_row = '%10s' * n_hyp % tuple(hyp.keys())
    values_row = '%10.3g' * n_hyp % tuple(hyp.values())
    results_row = '%10.3g' * len(results) % results
    print('\n%s\n%s\nEvolved fitness: %s\n' % (keys_row, values_row, results_row))

    # NOTE(review): `bucket` is interpolated into a shell command; it must come
    # from a trusted source.
    if bucket:
        os.system('gsutil cp gs://%s/evolve.txt .' % bucket)

    with open('evolve.txt', 'a') as f:
        f.write(results_row + values_row + '\n')
    rows = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)
    best_first = np.argsort(-fitness(rows))
    np.savetxt('evolve.txt', rows[best_first], '%10.3g')

    if bucket:
        os.system('gsutil cp evolve.txt gs://%s' % bucket)
Example 10
Project: Pytorch-Project-Template Author: moemen95 File: dqn.py License: MIT License | 6 votes |
def load_checkpoint(self, file_name):
    """Resume the agent from ``self.config.checkpoint_dir + file_name``.

    Restores ``current_episode``, ``current_iteration``, the policy network
    weights and the optimizer state. If opening the file raises ``OSError``
    (e.g. it does not exist), this is logged and training starts from scratch.
    """
    ckpt_dir = self.config.checkpoint_dir
    filename = ckpt_dir + file_name
    try:
        self.logger.info("Loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        for attr, key in (('current_episode', 'episode'), ('current_iteration', 'iteration')):
            setattr(self, attr, checkpoint[key])
        for target, key in ((self.policy_model, 'state_dict'), (self.optim, 'optimizer')):
            target.load_state_dict(checkpoint[key])
        self.logger.info("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                         .format(ckpt_dir, checkpoint['episode'], checkpoint['iteration']))
    except OSError:
        self.logger.info("No checkpoint exists from '{}'. Skipping...".format(ckpt_dir))
        self.logger.info("**First time to train**")
Example 11
Project: Pytorch-Project-Template Author: moemen95 File: condensenet.py License: MIT License | 6 votes |
def load_checkpoint(self, filename):
    """Resume training from ``self.config.checkpoint_dir + filename``.

    Restores ``current_epoch``, ``current_iteration``, the model weights and
    the optimizer state. If opening the file raises ``OSError`` (e.g. it does
    not exist), this is logged and training starts from scratch.
    """
    ckpt_dir = self.config.checkpoint_dir
    filename = ckpt_dir + filename
    try:
        self.logger.info("Loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        for attr, key in (('current_epoch', 'epoch'), ('current_iteration', 'iteration')):
            setattr(self, attr, checkpoint[key])
        for target, key in ((self.model, 'state_dict'), (self.optimizer, 'optimizer')):
            target.load_state_dict(checkpoint[key])
        self.logger.info("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                         .format(ckpt_dir, checkpoint['epoch'], checkpoint['iteration']))
    except OSError:
        self.logger.info("No checkpoint exists from '{}'. Skipping...".format(ckpt_dir))
        self.logger.info("**First time to train**")
Example 12
Project: Pytorch-Project-Template Author: moemen95 File: erfnet.py License: MIT License | 6 votes |
def load_checkpoint(self, filename):
    """Resume training from ``self.config.checkpoint_dir + filename``.

    Restores the epoch/iteration counters, the model weights and the optimizer
    state. An ``OSError`` while loading (e.g. no such file) is logged and
    training starts from scratch.
    """
    ckpt_dir = self.config.checkpoint_dir
    filename = ckpt_dir + filename
    try:
        self.logger.info("Loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        for attr, key in (('current_epoch', 'epoch'), ('current_iteration', 'iteration')):
            setattr(self, attr, checkpoint[key])
        for target, key in ((self.model, 'state_dict'), (self.optimizer, 'optimizer')):
            target.load_state_dict(checkpoint[key])
        self.logger.info("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                         .format(ckpt_dir, checkpoint['epoch'], checkpoint['iteration']))
    except OSError:
        self.logger.info("No checkpoint exists from '{}'. Skipping...".format(ckpt_dir))
        self.logger.info("**First time to train**")
Example 13
Project: Pytorch-Project-Template Author: moemen95 File: dcgan.py License: MIT License | 6 votes |
def load_checkpoint(self, file_name):
    """Resume GAN training from ``self.config.checkpoint_dir + file_name``.

    Restores the epoch/iteration counters, the generator and discriminator
    weights and optimizers, ``fixed_noise`` and ``manual_seed``. An
    ``OSError`` while loading (e.g. no such file) is logged and training
    starts from scratch.
    """
    ckpt_dir = self.config.checkpoint_dir
    filename = ckpt_dir + file_name
    try:
        self.logger.info("Loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        for attr, key in (('current_epoch', 'epoch'), ('current_iteration', 'iteration')):
            setattr(self, attr, checkpoint[key])
        restore_plan = (
            (self.netG, 'G_state_dict'),
            (self.optimG, 'G_optimizer'),
            (self.netD, 'D_state_dict'),
            (self.optimD, 'D_optimizer'),
        )
        for target, key in restore_plan:
            target.load_state_dict(checkpoint[key])
        for attr in ('fixed_noise', 'manual_seed'):
            setattr(self, attr, checkpoint[attr])
        self.logger.info("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                         .format(ckpt_dir, checkpoint['epoch'], checkpoint['iteration']))
    except OSError:
        self.logger.info("No checkpoint exists from '{}'. Skipping...".format(ckpt_dir))
        self.logger.info("**First time to train**")
Example 14
Project: dogTorch Author: ehsanik File: nyu_walkable_surface_dataset.py License: MIT License | 6 votes |
def __getitem__(self, idx):
    """Return sample ``idx`` as ``(input, segment, 0, 0, ['images' + fid])``.

    ``input`` is the stack of ``self.sequence_length`` feature tensors loaded
    with ``torch.load`` when ``self.read_features`` is set. Otherwise it is
    the result of ``self.load_and_resize`` on ``<root_dir>/images/<fid>``.
    ``segment`` comes from ``self.load_and_resize_segmentation`` on
    ``<root_dir>/walkable/<fid>``.
    """
    fid = self.data_set_list[idx]
    if self.read_features:
        features = []
        for i in range(self.sequence_length):
            # NOTE(review): `fid + i` needs a numeric `fid`, but the paths and
            # the returned label below treat `fid` as a string; verify how
            # frames_metadata is keyed.
            feature_path = os.path.join(
                self.features_dir,
                self.frames_metadata[fid + i]['cur_frame'] + '.pytar')
            features.append(torch.load(feature_path))
        # The stacked features used to be bound to `input` (shadowing the
        # builtin) and never returned, so the return raised NameError on
        # `image`. Use them as the sample input instead.
        image = torch.stack(features)
    else:
        image = self.load_and_resize(
            os.path.join(self.root_dir, 'images', fid))
    segment = self.load_and_resize_segmentation(
        os.path.join(self.root_dir, 'walkable', fid))
    # The two 0s are just place holders. They can be replaced by any values
    return (image, segment, 0, 0, ['images' + fid])
Example 15
Project: End-to-end-ASR-Pytorch Author: Alexander-H-Liu File: train_lm.py License: MIT License | 6 votes |
def set_model(self):
    ''' Set up the RNN language model, its loss and its optimizer '''
    self.model = RNNLM(self.vocab_size, **self.config['model']).to(self.device)
    self.verbose(self.model.create_msg())

    # Cross-entropy that ignores target index 0 (presumably padding — confirm).
    self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
    self.optimizer = Optimizer(self.model.parameters(), **self.config['hparas'])

    # Enable AMP if needed
    self.enable_apex()

    if not self.paras.load:
        return
    # NOTE(review): load_ckpt() (defined elsewhere) is called and the same
    # checkpoint is then loaded again below; check whether one is redundant.
    self.load_ckpt()
    ckpt = torch.load(self.paras.load, map_location=self.device)
    self.model.load_state_dict(ckpt['model'])
    self.optimizer.load_opt_state_dict(ckpt['optimizer'])
    self.step = ckpt['global_step']
    self.verbose('Load ckpt from {}, restarting at step {}'.format(
        self.paras.load, self.step))
Example 16
Project: transferlearning Author: jindongwang File: main.py License: MIT License | 6 votes |
def extract_feature(model, dataloader, save_path, load_from_disk=True, model_path=''):
    """Dump per-sample features plus label to CSV and print test accuracy.

    Args:
        model: network exposing ``get_features`` and ``base_network.output_num()``.
            Replaced by a fresh ``models.Network`` loaded from ``model_path``
            when ``load_from_disk`` is true.
        dataloader: yields ``(inputs, labels)`` batches.
        save_path: CSV destination; each row is the feature vector followed by
            the label, formatted ``%.6f``.
        load_from_disk: whether to rebuild ``model`` from ``model_path``.
        model_path: state-dict path used when ``load_from_disk`` is true.
    """
    if load_from_disk:
        model = models.Network(base_net=args.model_name, n_class=args.num_class)
        model.load_state_dict(torch.load(model_path))
    # Inputs are moved to DEVICE below, so the model must live there too.
    model = model.to(DEVICE)
    model.eval()
    correct = 0
    feature_chunks = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            feas = model.get_features(inputs)
            label_col = labels.view(labels.size(0), 1).float()
            # Collect per-batch rows; a single cat at the end avoids the
            # quadratic re-copying of repeated cat inside the loop.
            feature_chunks.append(torch.cat((feas, label_col), dim=1))
            outputs = model(inputs)
            preds = torch.max(outputs, 1)[1]
            # Compare against flat labels. Comparing the (N,) predictions with
            # the (N, 1) label column would broadcast to (N, N) and over-count.
            correct += torch.sum(preds == labels.view(-1).long()).item()
    test_acc = correct / len(dataloader.dataset)
    if feature_chunks:
        fea_all = torch.cat(feature_chunks, dim=0)
    else:
        fea_all = torch.zeros(0, 1 + model.base_network.output_num())
    fea_numpy = fea_all.cpu().numpy()
    np.savetxt(save_path, fea_numpy, fmt='%.6f', delimiter=',')
    print('Test acc: %f' % test_acc)
    # You may want to classify with 1nn after getting features
Example 17
Project: Pytorch-Networks Author: HaiyangLiu1997 File: utils.py License: MIT License | 6 votes |
def load_test_checkpoints(model, save_path, logger, use_best=False):
    """Load ``states['model_state']`` from the regular or best checkpoint.

    The checkpoint path is ``save_path.EXPS + save_path.NAME`` followed by
    ``save_path.BESTMODEL`` (printed first) when ``use_best`` is true, or
    ``save_path.MODEL`` otherwise. Tensors are mapped to CPU when CUDA is
    unavailable.
    """
    prefix = save_path.EXPS + save_path.NAME
    if use_best:
        ckpt_file = prefix + save_path.BESTMODEL
        print(ckpt_file)
    else:
        ckpt_file = prefix + save_path.MODEL
    if torch.cuda.is_available():
        states = torch.load(ckpt_file)
    else:
        states = torch.load(ckpt_file, map_location=torch.device('cpu'))
    model.load_state_dict(states['model_state'])
    logger.info('loading checkpoints success')
Example 18
Project: cmrc2019 Author: ymcui File: modeling.py License: Creative Commons Attribution Share Alike 4.0 International | 5 votes |
def __init__(self, config):
    """Create the word, position and token-type embedding tables, LayerNorm and dropout."""
    super(BertEmbeddings, self).__init__()
    hidden = config.hidden_size
    self.word_embeddings = nn.Embedding(config.vocab_size, hidden)
    self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)
    self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)
    # The attribute is deliberately named `LayerNorm` (not snake_case) to
    # mirror the TensorFlow variable name so TF checkpoints can be loaded.
    self.LayerNorm = BertLayerNorm(hidden, eps=1e-12)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)
Example 19
Project: Collaborative-Learning-for-Weakly-Supervised-Object-Detection Author: Sunarker File: train_val.py License: MIT License | 5 votes |
def initialize(self):
    """Load the initial weights and return the starting training schedule.

    Always loads ``self.pretrained_model`` via ``self.net.load_pretrained_cnn``.
    When ``self.wsddn_premodel`` is set, its entries are merged over the
    current state dict and loaded into ``self.net``.

    Returns:
        ``(lr, last_snapshot_iter, stepsizes, np_paths, ss_paths)``: the
        configured learning rate, iteration 0, the configured step sizes as a
        list, and two empty path lists.
    """
    np_paths, ss_paths = [], []
    # Fresh training directly from ImageNet weights.
    # NOTE(review): per the original author, RGB weights are converted to BGR
    # (and VGG16 fc6/fc7 reshaped) during this load; see load_pretrained_cnn.
    print('Loading initial model weights from {:s}'.format(self.pretrained_model))
    self.net.load_pretrained_cnn(torch.load(self.pretrained_model))
    print('Loaded.')
    if self.wsddn_premodel is not None:
        wsddn_pre = torch.load(self.wsddn_premodel)
        merged = self.net.state_dict()
        merged.update(wsddn_pre)
        self.net.load_state_dict(merged)
        print('Loading pretrained WSDDN model weights from {:s}'.format(self.wsddn_premodel))
        print('Loaded.')
    return cfg.TRAIN.LEARNING_RATE, 0, list(cfg.TRAIN.STEPSIZE), np_paths, ss_paths
Example 20
Project: comet-commonsense Author: atcbosselut File: utils.py License: Apache License 2.0 | 5 votes |
def load_existing_data_loader(data_loader, path):
    """Overwrite attributes of ``data_loader`` in place with a saved loader's values.

    Only attributes present in both ``data_loader.__dict__`` and the
    ``__dict__`` of the object loaded from ``path`` are copied; the rest of
    ``data_loader`` is left unchanged.
    """
    old_data_loader = torch.load(path)
    saved_attrs = old_data_loader.__dict__
    shared = [attr for attr in data_loader.__dict__.keys() if attr in saved_attrs]
    for attr in shared:
        setattr(data_loader, attr, getattr(old_data_loader, attr))


################################################################################
#
# Code Below taken from HuggingFace pytorch-openai-lm repository
#
################################################################################
Example 21
Project: comet-commonsense Author: atcbosselut File: utils.py License: Apache License 2.0 | 5 votes |
def __init__(self, encoder_path, bpe_path):
    """Load the spaCy pipeline, the BPE vocabulary and the BPE merge ranks.

    Args:
        encoder_path: JSON file mapping token -> id; ``self.decoder`` is its inverse.
        bpe_path: UTF-8 text file with one whitespace-separated merge pair per
            line. Its first and last lines are skipped.
    """
    self.nlp = spacy.load(
        'en', disable=['parser', 'tagger', 'ner', 'textcat'])
    # Context managers close the files (previously the handles were leaked).
    # JSON is UTF-8 by specification, so don't rely on the locale encoding.
    with open(encoder_path, encoding='utf-8') as f:
        self.encoder = json.load(f)
    self.decoder = {v: k for k, v in self.encoder.items()}
    with open(bpe_path, encoding='utf-8') as f:
        # Drop the first and last lines (presumably a header line and the
        # empty string after the final newline — confirm the file format).
        merges = f.read().split('\n')[1:-1]
    merges = [tuple(merge.split()) for merge in merges]
    # Lower rank == earlier merge; later duplicates overwrite earlier ones.
    self.bpe_ranks = {merge: rank for rank, merge in enumerate(merges)}
    self.cache = {}
Example 22
Project: comet-commonsense Author: atcbosselut File: data.py License: Apache License 2.0 | 5 votes |
def load_checkpoint(filename, gpu=True):
    """Load a checkpoint with every tensor kept on CPU storage.

    Args:
        filename: checkpoint path.
        gpu: unused; kept so existing callers keep working.

    Returns:
        The deserialized checkpoint, or ``None`` (after printing a message)
        when ``filename`` does not exist. Previously that case raised
        ``UnboundLocalError``.
    """
    if not os.path.exists(filename):
        print("No model found at {}".format(filename))
        return None
    return torch.load(filename, map_location=lambda storage, loc: storage)
Example 23
Project: DDPAE-video-prediction Author: jthsieh File: base_model.py License: MIT License | 5 votes |
def load(self, ckpt_path, epoch, load_optimizer=False):
    '''
    Load checkpoint.

    Restores each network in ``self.nets`` from
    ``<ckpt_path>/net_<name>_<epoch>.pth`` (unwrapping ``DataParallel``)
    and, when ``load_optimizer`` is true, each optimizer in
    ``self.optimizers`` from ``<ckpt_path>/optimizer_<name>_<epoch>.pth``.
    Missing files are reported and skipped.
    '''
    for name, net in self.nets.items():
        path = os.path.join(ckpt_path, 'net_{}_{}.pth'.format(name, epoch))
        if not os.path.exists(path):
            print('{} does not exist, ignore.'.format(path))
            continue
        ckpt = torch.load(path)
        if isinstance(net, torch.nn.DataParallel):
            module = net.module
        else:
            module = net
        try:
            module.load_state_dict(ckpt)
        except RuntimeError:
            # load_state_dict raises RuntimeError on missing/unexpected keys.
            # Retry by pairing entries positionally. (A bare `except:` here
            # also swallowed KeyboardInterrupt/SystemExit.)
            print('net_{} and checkpoint have different parameter names'.format(name))
            new_ckpt = OrderedDict()
            for ckpt_key, module_key in zip(ckpt.keys(), module.state_dict().keys()):
                # Only the last name component (e.g. 'weight') has to agree.
                assert ckpt_key.split('.')[-1] == module_key.split('.')[-1], \
                    'cannot match checkpoint key {!r} to {!r}'.format(ckpt_key, module_key)
                new_ckpt[module_key] = ckpt[ckpt_key]
            module.load_state_dict(new_ckpt)

    if load_optimizer:
        for name, optimizer in self.optimizers.items():
            path = os.path.join(ckpt_path, 'optimizer_{}_{}.pth'.format(name, epoch))
            if not os.path.exists(path):
                print('{} does not exist, ignore.'.format(path))
                continue
            ckpt = torch.load(path)
            optimizer.load_state_dict(ckpt)
Example 24
Project: mmdetection Author: open-mmlab File: gather_models.py License: Apache License 2.0 | 5 votes |
def process_checkpoint(in_file, out_file):
    """Strip optimizer state from a checkpoint and publish it under a hashed name.

    Args:
        in_file: checkpoint to process (loaded onto CPU).
        out_file: intermediate output path. A trailing ``.pth`` is replaced by
            ``-<first 8 hex chars of its SHA-256>.pth`` to form the final name.

    Returns:
        Path of the published checkpoint.
    """
    import hashlib
    import shutil

    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)
    # Hash in-process in chunks (same digest as `sha256sum`, no external tool).
    digest = hashlib.sha256()
    with open(out_file, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    sha = digest.hexdigest()
    # Remove the '.pth' suffix. str.rstrip('.pth') strips any trailing
    # '.', 'p', 't', 'h' characters instead (e.g. 'depth.pth' -> 'de').
    stem = out_file[:-len('.pth')] if out_file.endswith('.pth') else out_file
    final_file = stem + '-{}.pth'.format(sha[:8])
    # Move synchronously; the async `Popen(['mv', ...])` could still be
    # running when the final path was returned.
    shutil.move(out_file, final_file)
    return final_file
Example 25
Project: mmdetection Author: open-mmlab File: publish_model.py License: Apache License 2.0 | 5 votes |
def process_checkpoint(in_file, out_file):
    """Strip optimizer state from a checkpoint and publish it under a hashed name.

    Args:
        in_file: checkpoint to process (loaded onto CPU).
        out_file: intermediate output path. A trailing ``.pth`` is replaced by
            ``-<first 8 hex chars of its SHA-256>.pth`` to form the final name.

    Returns:
        Path of the published checkpoint.
    """
    import hashlib
    import shutil

    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)
    # Hash in-process in chunks (same digest as `sha256sum`, no external tool).
    digest = hashlib.sha256()
    with open(out_file, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    sha = digest.hexdigest()
    # Derive the stem without rebinding `out_file`. Rebinding it made the move
    # below target the suffix-less path, which does not exist.
    stem = out_file[:-4] if out_file.endswith('.pth') else out_file
    final_file = stem + f'-{sha[:8]}.pth'
    # Move synchronously so the file is in place when we return.
    shutil.move(out_file, final_file)
    return final_file
Example 26
Project: subword-qac Author: clovaai File: utils.py License: MIT License | 5 votes |
def model_load(path, model=None, optimizer=None):
    """Load a language model (and optionally optimizer state) from directory ``path``.

    Args:
        path: directory containing ``config.json``, ``model.pt`` and, when an
            optimizer is given, ``optimizer.pt``.
        model: existing model (or wrapper resolved by ``get_model``) to
            re-initialise with the stored config and load into. When ``None``,
            a new ``LanguageModel`` is built.
        optimizer: optimizer to restore, or ``None``/falsy to skip.

    Returns:
        The model that received the weights.
    """
    config = LMConfig(os.path.join(path, 'config.json'))
    if model is None:
        model_to_load = LanguageModel(config)
    else:
        model_to_load = get_model(model)
        model_to_load.__init__(config)

    def keep_on_cpu(storage, loc):
        # Keep each storage as deserialized (CPU) regardless of the saving device.
        return storage

    # Pass paths, not open() handles: torch.load then opens and closes the
    # files itself (the previous handles were never closed).
    model_state_dict = torch.load(os.path.join(path, 'model.pt'), map_location=keep_on_cpu)
    model_to_load.load_state_dict(model_state_dict)
    if optimizer:
        optimizer_state_dict = torch.load(os.path.join(path, 'optimizer.pt'), map_location=keep_on_cpu)
        optimizer.load_state_dict(optimizer_state_dict)
    return model_to_load
Example 27
Project: models Author: kipoi File: model.py License: MIT License | 5 votes |
def load_weights(self, weights):
    """Build a ``CNN``, load the state dict stored at ``weights``, and return it in eval mode.

    Tensors are mapped to CPU when CUDA is unavailable.
    """
    net = CNN()
    location = None if torch.cuda.is_available() else 'cpu'
    state = torch.load(weights, map_location=location)
    net.load_state_dict(state)
    return net.eval()
Example 28
Project: models Author: kipoi File: pretrained_model_reloaded_th.py License: MIT License | 5 votes |
def get_model(load_weights=True):
    """Build the Basset network (PyTorch port), optionally with pretrained weights.

    The input alphabet appears to match Basset's one-hot order
    (https://github.com/davek44/Basset/tree/master/src/dna_io.py#L145-L148):
    A -> 0, C -> 1, G -> 2, T -> 3.

    Args:
        load_weights: when true, load
            ``model_files/pretrained_model_reloaded_th.pth`` (relative to the
            current working directory).
    """
    def as_matrix(x):
        # Promote a 1-D tensor to a single-row matrix before a Linear layer.
        return x.view(1, -1) if 1 == len(x.size()) else x

    def dense(n_in, n_out):
        # Nested Sequential(Lambda, Linear) keeps the checkpoint's key layout.
        return nn.Sequential(Lambda(as_matrix), nn.Linear(n_in, n_out))

    layers = [
        nn.Conv2d(4, 300, (19, 1)),
        nn.BatchNorm2d(300),
        nn.ReLU(),
        nn.MaxPool2d((3, 1), (3, 1)),
        nn.Conv2d(300, 200, (11, 1)),
        nn.BatchNorm2d(200),
        nn.ReLU(),
        nn.MaxPool2d((4, 1), (4, 1)),
        nn.Conv2d(200, 200, (7, 1)),
        nn.BatchNorm2d(200),
        nn.ReLU(),
        nn.MaxPool2d((4, 1), (4, 1)),
        Lambda(lambda x: x.view(x.size(0), -1)),  # flatten per sample
        dense(2000, 1000),
        nn.BatchNorm1d(1000, 1e-05, 0.1, True),
        nn.ReLU(),
        nn.Dropout(0.3),
        dense(1000, 1000),
        nn.BatchNorm1d(1000, 1e-05, 0.1, True),
        nn.ReLU(),
        nn.Dropout(0.3),
        dense(1000, 164),
        nn.Sigmoid(),
    ]
    pretrained_model_reloaded_th = nn.Sequential(*layers)
    if load_weights:
        sd = torch.load('model_files/pretrained_model_reloaded_th.pth')
        pretrained_model_reloaded_th.load_state_dict(sd)
    return pretrained_model_reloaded_th
Example 29
Project: VSE-C Author: ExplorerFreda File: evaluation.py License: MIT License | 5 votes |
def eval_with_extended(model_path, data_path=None, data_name=None, split='test'):
    """Evaluate a VSE checkpoint on the extended-caption test set.

    Prints the average image-to-text recall and per-K recalls, and saves the
    text ranks to ``ranks_extended.pth.tar``. That file goes next to the
    ``model_best`` part of ``model_path`` if present, otherwise into the
    checkpoint's directory.

    Args:
        model_path: checkpoint containing ``'opt'`` and ``'model'``.
        data_path: optional override for ``opt.data_path``.
        data_name: optional override for ``opt.data_name``.
        split: dataset split to evaluate.
    """
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    opt.use_external_captions = True
    opt.negative_number = 5
    if data_path is not None:
        opt.data_path = data_path
    if data_name is not None:
        opt.data_name = data_name

    # load vocabulary used by the model
    # NOTE(review): pickle.load executes arbitrary code; use trusted files only.
    with open(os.path.join(opt.vocab_path, '%s_vocab.pkl' % opt.data_name), 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # construct model
    model = VSE(opt)
    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size,
                                  opt.batch_size, opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs = encode_data(model, data_loader)
    print('Images: %d, Captions: %d' % (img_embs.shape[0] // 5, cap_embs.shape[0]))

    r, rt = i2t_text_only(img_embs, cap_embs, measure=opt.measure, return_ranks=True)
    ar = (r[0] + r[1] + r[2]) / 3
    print("Average i2t Recall: %.1f" % ar)
    print("Image to text: %.1f\t%.1f\t%.1f\t%.1f\t%.1f" % r)

    marker = model_path.find('model_best')
    if marker == -1:
        # find() returns -1 here; slicing with it would drop the last character
        # of model_path, so fall back to the checkpoint's directory.
        prefix = os.path.join(os.path.dirname(model_path), '')
    else:
        prefix = model_path[:marker]
    torch.save({'rt': rt}, prefix + 'ranks_extended.pth.tar')
Example 30
Project: VSE-C Author: ExplorerFreda File: saliency_visualization.py License: MIT License | 5 votes |
def main():
    """Build the saliency feature extractor and open an IPython shell to explore it.

    Inside the shell, ``e(ind)`` extracts sample ``ind``, saves its image
    saliency plot to ``/tmp/origin<ind>.png``, prints the text saliency, and
    returns the extraction result.
    """
    # The names below (encoder, dataset, extractor, e) stay as they are because
    # embed() exposes this frame's locals to the interactive user.
    encoder = ImageEncoder(args.encoder, args.load_encoder)
    dataset = Dataset(args)
    extractor = FeatureExtractor(args.load, encoder, dataset)

    def e(ind):
        result = extractor(ind)
        png_path = '/tmp/origin{}.png'.format(ind)
        saliency_image = plot_saliency(result.raw_image, result.image,
                                       result.image_embedding, result.caption_embedding)
        saliency_image.save(png_path)
        print('Image saliency saved:', png_path)
        print_txt_saliency(result.captions, 0, result.raw_caption,
                           result.image_embedding, result.caption_embedding)
        return result

    from IPython import embed
    embed()