Python torch.load() Examples

The following are 30 code examples showing how to use torch.load(). These examples are extracted from open-source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.

You can check out related API usage in the sidebar.

You may also want to check out all available functions/classes of the module torch, or try the search function.

Example 1
Project: models   Author: kipoi   File: model_architecture.py    License: MIT License 7 votes vote down vote up
def get_model(load_weights = True):
    """Build the DeepSEA network (CPU variant) preceded by an alphabet re-coder.

    Input is presumably one-hot DNA shaped (batch, 4, 1, 1000): the conv/pool
    stack reduces length 1000 to 53 positions x 960 channels = 50880, which is
    the input size of the first Linear layer. TODO confirm against the loader.

    Args:
        load_weights: if True, load the parameters stored in
            ``model_files/deepsea_cpu.pth`` (relative to the current working
            directory).

    Returns:
        ``nn.Sequential(ReCodeAlphabet(), deepsea_cpu)``.
    """
    # nn.Threshold(0, 1e-06) replaces every value <= 0 with 1e-6 (a ReLU variant
    # carried over from the original Torch7 model).
    deepsea_cpu = nn.Sequential( # Sequential,
        nn.Conv2d(4,320,(1, 8),(1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4),(1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(320,480,(1, 8),(1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4),(1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(480,960,(1, 8),(1, 1)),
        nn.Threshold(0, 1e-06),
        nn.Dropout(0.5),
        Lambda(lambda x: x.view(x.size(0),-1)), # Reshape: flatten to (batch, -1)
        # Each Linear is preceded by a Lambda that promotes a 1-D input to a
        # batch of one.
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(50880,925)), # Linear,
        nn.Threshold(0, 1e-06),
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(925,919)), # Linear,
        nn.Sigmoid(),
    )
    if load_weights:
        # The module is built on CPU, and load_state_dict copies the tensors into
        # its parameters. Mapping storages to CPU lets a checkpoint that holds
        # CUDA tensors load on machines without a GPU.
        deepsea_cpu.load_state_dict(torch.load('model_files/deepsea_cpu.pth', map_location='cpu'))
    # ReCodeAlphabet is a project helper not shown here; presumably it re-orders
    # the one-hot alphabet into the order the weights expect. Confirm.
    return nn.Sequential(ReCodeAlphabet(), deepsea_cpu)
Example 2
def from_snapshot(self, sfile, nfile):
  """Restore network weights and data-layer state from a training snapshot.

  Args:
    sfile: path of the network state dict (passed through ``str`` to torch.load).
    nfile: pickle file holding, in order: the NumPy RNG state, the training
      data layer's ``_cur`` and ``_perm``, the validation data layer's ``_cur``
      and ``_perm``, and the last snapshot iteration.

  Returns:
    The last snapshot iteration stored in ``nfile``.
  """
  print('Restoring model snapshots from {:s}'.format(sfile))
  weights = torch.load(str(sfile))
  self.net.load_state_dict(weights)
  print('Restored.')
  # Needs to restore the other hyper-parameters/states for training, (TODO xinlei)
  # the random states are recovered as far as possible so training resumes
  # (close to) exactly; the TensorFlow state is not available.
  # NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
  with open(nfile, 'rb') as fid:
    # The six objects were dumped sequentially; read them back in the same order.
    (rng_state, train_cur, train_perm,
     val_cur, val_perm, last_iter) = [pickle.load(fid) for _ in range(6)]

  np.random.set_state(rng_state)
  self.data_layer._cur = train_cur
  self.data_layer._perm = train_perm
  self.data_layer_val._cur = val_cur
  self.data_layer_val._perm = val_perm

  return last_iter
Example 3
Project: mmdetection   Author: open-mmlab   File: regnet2mmdet.py    License: Apache License 2.0 6 votes vote down vote up
def convert(src, dst):
    """Convert keys in pycls pretrained RegNet models to mmdet style.

    Reads ``src`` (a pycls checkpoint whose weights live under
    ``'model_state'``). Each key is dispatched to a converter by pattern:
    ``'stem'`` anywhere, then ``'head'`` anywhere, then a leading ``'s'``.
    Keys that no converter handles are printed. The result is saved to ``dst``
    as ``{'state_dict': OrderedDict(...)}``.
    """
    # Load the pycls checkpoint and take its weight blobs.
    blobs = torch.load(src)['model_state']

    state_dict = OrderedDict()
    converted_names = set()
    for key, weight in blobs.items():
        if 'stem' in key:
            converter = convert_stem
        elif 'head' in key:
            converter = convert_head
        elif key.startswith('s'):
            converter = convert_reslayer
        else:
            continue
        # Each converter fills state_dict and records the names it consumed.
        converter(key, weight, state_dict, converted_names)

    # Report every source key that none of the converters handled.
    for key in (k for k in blobs if k not in converted_names):
        print(f'not converted: {key}')

    torch.save({'state_dict': state_dict}, dst)
Example 4
Project: models   Author: kipoi   File: model_architecture.py    License: MIT License 6 votes vote down vote up
def get_seqpred_model(load_weights = True):
    """Build DeepSEA (CPU variant) wrapped for reverse-complement averaging.

    The pipeline is ``ReCodeAlphabet -> ConcatenateRC -> deepsea_cpu ->
    AverageRC``. Those helpers are defined elsewhere in the project; from their
    names they presumably add the reverse-complement strand and average the two
    predictions. Confirm.

    Input is presumably one-hot DNA shaped (batch, 4, 1, 1000): the conv/pool
    stack reduces length 1000 to 53 positions x 960 channels = 50880, which is
    the input size of the first Linear layer. TODO confirm against the loader.

    Args:
        load_weights: if True, load ``model_files/deepsea_cpu.pth`` (relative
            to the current working directory).
    """
    # nn.Threshold(0, 1e-06) replaces every value <= 0 with 1e-6 (a ReLU variant
    # carried over from the original Torch7 model).
    deepsea_cpu = nn.Sequential( # Sequential,
        nn.Conv2d(4,320,(1, 8),(1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4),(1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(320,480,(1, 8),(1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4),(1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(480,960,(1, 8),(1, 1)),
        nn.Threshold(0, 1e-06),
        nn.Dropout(0.5),
        Lambda(lambda x: x.view(x.size(0),-1)), # Reshape: flatten to (batch, -1)
        # Each Linear is preceded by a Lambda that promotes a 1-D input to a
        # batch of one.
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(50880,925)), # Linear,
        nn.Threshold(0, 1e-06),
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(925,919)), # Linear,
        nn.Sigmoid(),
    )
    if load_weights:
        # The module lives on CPU and load_state_dict copies into it. Mapping to
        # CPU lets checkpoints that contain CUDA tensors load without a GPU.
        deepsea_cpu.load_state_dict(torch.load('model_files/deepsea_cpu.pth', map_location='cpu'))
    return nn.Sequential(ReCodeAlphabet(), ConcatenateRC(), deepsea_cpu, AverageRC())
Example 5
Project: pytorch_NER_BiLSTM_CNN_CRF   Author: bamtercelboo   File: test.py    License: Apache License 2.0 6 votes vote down vote up
def load_test_model(model, config):
    """Load saved weights into ``model`` for evaluation.

    :param model: initial model; its parameters are overwritten in place
    :param config: config object. A non-None ``config.t_model`` is used as the
        checkpoint path; otherwise the path is
        ``<config.save_best_model_dir>/<config.model_name>.pt``
    :return: the same ``model`` object with the weights loaded
    """
    if config.t_model is not None:
        test_model_path = config.t_model
        print("load user model from {}".format(test_model_path))
    else:
        test_model_path = os.path.join(config.save_best_model_dir,
                                       "{}.pt".format(config.model_name))
        print("load default model from {}".format(test_model_path))
    weights = torch.load(test_model_path)
    model.load_state_dict(weights)
    return model
Example 6
Project: VSE-C   Author: ExplorerFreda   File: saliency_visualization.py    License: MIT License 6 votes vote down vote up
def load_checkpoint(self, checkpoint):
        """Rebuild the VSE model from ``checkpoint`` and freeze both encoders.

        The checkpoint must hold ``'opt'`` (training options) and ``'model'``
        (state dict). External captions are disabled, the vocabulary pickle
        named after ``opt.data_name`` is loaded from ``opt.vocab_path`` and
        kept as ``self.projector``, and both encoders are put in eval mode
        with gradients disabled.
        """
        ckpt = torch.load(checkpoint)
        opt = ckpt['opt']
        opt.use_external_captions = False
        vocab_file = pjoin(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
        vocab = Vocab.from_pickle(vocab_file)
        opt.vocab_size = len(vocab)

        from model import VSE
        self.model = VSE(opt)
        self.model.load_state_dict(ckpt['model'])
        self.projector = vocab

        encoders = (self.model.img_enc, self.model.txt_enc)
        for enc in encoders:
            enc.eval()
        # Freeze the weights: the model is only used for inference here.
        for enc in encoders:
            for param in enc.parameters():
                param.requires_grad = False
Example 7
Project: cvpr2018-hnd   Author: kibok90   File: utils.py    License: MIT License 6 votes vote down vote up
def load_model(model, optimizer, scheduler, path, num_epochs, start_time=time.time()):
    """Resume from the newest checkpoint whose epoch is <= ``num_epochs``.

    Epochs ``num_epochs, num_epochs - 1, ..., 1`` are tried in turn. The first
    one for which ``{path}_model_{epoch}.pth`` exists is loaded into ``model``.
    The matching ``_optimizer_`` and ``_scheduler_`` files are loaded as well
    when ``optimizer`` / ``scheduler`` are not None.

    Args:
        model: module receiving the model state dict.
        optimizer: optimizer to restore, or None to skip.
        scheduler: scheduler to restore, or None to skip. Only its ``best``,
            ``cooldown_counter``, ``num_bad_epochs`` and ``last_epoch``
            attributes are restored.
        path: filename prefix of the checkpoint files.
        num_epochs: highest epoch to look for.
        start_time: reference timestamp for the elapsed-time print. NOTE: the
            default ``time.time()`` is evaluated once, when the function is
            defined (module import), not on each call.

    Returns:
        The epoch that was loaded, or 0 if no model checkpoint was found.
    """
    def ckpt_path(kind, ep):
        # Single source of truth for the '{path}_{kind}_{epoch}.pth' naming scheme.
        return '{path}_{kind}_{epoch:d}.pth'.format(path=path, kind=kind, epoch=ep)

    epoch = num_epochs
    while epoch > 0 and not os.path.isfile(ckpt_path('model', epoch)):
        epoch -= 1
    if epoch > 0:
        model_path = ckpt_path('model', epoch)
        model.load_state_dict(torch.load(model_path))
        if optimizer is not None:
            optimizer.load_state_dict(torch.load(ckpt_path('optimizer', epoch)))
        if scheduler is not None:
            scheduler_state_dict = torch.load(ckpt_path('scheduler', epoch))
            scheduler.best = scheduler_state_dict['best']
            scheduler.cooldown_counter = scheduler_state_dict['cooldown_counter']
            scheduler.num_bad_epochs = scheduler_state_dict['num_bad_epochs']
            scheduler.last_epoch = scheduler_state_dict['last_epoch']
        print('{epoch:4d}/{num_epochs:4d} e; '.format(epoch=epoch, num_epochs=num_epochs), end='')
        print('load {path}; '.format(path=model_path), end='')
        print('{time:8.3f} s'.format(time=time.time()-start_time))
    return epoch
Example 8
Project: cvpr2018-hnd   Author: kibok90   File: models.py    License: MIT License 6 votes vote down vote up
def init_truncated_normal(model, aux_str=''):
    """Initialise ``model`` from a cached weight file, creating the cache on a miss.

    The cache file is ``{path}/{in_features}_{out_features}{aux_str}.pth``.
    ``path`` is a module-level global defined outside this view. If the file
    exists, its state dict is loaded. Otherwise the model is initialised with
    ``truncated_normal`` (element-wise for an ``nn.ModuleList``) and its state
    dict is saved to that file.

    Args:
        model: module with ``in_features``/``out_features``, or None.
        aux_str: extra suffix that distinguishes cache files of the same shape.

    Returns:
        ``model`` (initialised in place), or None when ``model`` is None.
    """
    if model is None:
        return None
    init_path = '{path}/{in_dim:d}_{out_dim:d}{aux_str}.pth' \
                .format(path=path, in_dim=model.in_features, out_dim=model.out_features, aux_str=aux_str)
    if os.path.isfile(init_path):
        model.load_state_dict(torch.load(init_path))
        print('load init weight: {init_path}'.format(init_path=init_path))
    else:
        # NOTE(review): nn.ModuleList has no in_features/out_features, so this
        # branch is only reachable for a ModuleList subclass that defines them.
        if isinstance(model, nn.ModuleList):
            # A plain loop: the calls are made for their side effects only.
            for sub in model:
                truncated_normal(sub)
        else:
            truncated_normal(model)
        print('generate init weight: {init_path}'.format(init_path=init_path))
        torch.save(model.state_dict(), init_path)
        print('save init weight: {init_path}'.format(init_path=init_path))

    return model
Example 9
Project: pruning_yolov3   Author: zbyuan   File: utils.py    License: GNU General Public License v3.0 6 votes vote down vote up
def print_mutation(hyp, results, bucket=''):
    """Record one hyperparameter-evolution result in evolve.txt (used by train.py --evolve).

    Prints the hyperparameter names, their values and the results. Then appends
    ``results + values`` as one row to ``evolve.txt``, de-duplicates the rows
    and rewrites the file sorted by descending ``fitness``. When ``bucket`` is
    given, the file is first downloaded from and finally uploaded to
    ``gs://<bucket>/`` with ``gsutil``.

    ``results`` must be a tuple (it is used directly as %-format arguments).
    """
    names_row = ('%10s' * len(hyp)) % tuple(hyp.keys())  # hyperparam keys
    values_row = ('%10.3g' * len(hyp)) % tuple(hyp.values())  # hyperparam values
    results_row = ('%10.3g' * len(results)) % results  # results (P, R, mAP, F1, test_loss)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (names_row, values_row, results_row))

    # NOTE(review): `bucket` is interpolated into a shell command; do not pass
    # untrusted values.
    if bucket:
        os.system('gsutil cp gs://%s/evolve.txt .' % bucket)  # download evolve.txt

    with open('evolve.txt', 'a') as f:  # append result
        f.write(results_row + values_row + '\n')
    rows = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    best_first = np.argsort(-fitness(rows))
    np.savetxt('evolve.txt', rows[best_first], '%10.3g')  # save sort by fitness

    if bucket:
        os.system('gsutil cp evolve.txt gs://%s' % bucket)  # upload evolve.txt
Example 10
Project: Pytorch-Project-Template   Author: moemen95   File: dqn.py    License: MIT License 6 votes vote down vote up
def load_checkpoint(self, file_name):
        """Resume the DQN agent from ``self.config.checkpoint_dir + file_name``.

        Restores ``current_episode``, ``current_iteration``, the policy network
        weights and the optimizer state. A missing or unreadable file (any
        ``OSError``, including ``FileNotFoundError``) is logged and training
        starts from scratch.
        """
        filename = self.config.checkpoint_dir + file_name
        try:
            self.logger.info("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename)

            self.current_episode = checkpoint['episode']
            self.current_iteration = checkpoint['iteration']
            self.policy_model.load_state_dict(checkpoint['state_dict'])
            self.optim.load_state_dict(checkpoint['optimizer'])

            # This agent counts episodes, not epochs; label the log accordingly.
            self.logger.info("Checkpoint loaded successfully from '{}' at (episode {}) at (iteration {})\n"
                  .format(self.config.checkpoint_dir, checkpoint['episode'], checkpoint['iteration']))
        except OSError:
            self.logger.info("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir))
            self.logger.info("**First time to train**")
Example 11
Project: Pytorch-Project-Template   Author: moemen95   File: condensenet.py    License: MIT License 6 votes vote down vote up
def load_checkpoint(self, filename):
        """Resume the CondenseNet agent from ``self.config.checkpoint_dir + filename``.

        Restores ``current_epoch``, ``current_iteration``, the model weights and
        the optimizer state. Any ``OSError`` (e.g. a missing file) is logged as
        a first training run instead of being raised.
        """
        ckpt_file = self.config.checkpoint_dir + filename
        try:
            self.logger.info("Loading checkpoint '{}'".format(ckpt_file))
            state = torch.load(ckpt_file)

            self.current_epoch = state['epoch']
            self.current_iteration = state['iteration']
            self.model.load_state_dict(state['state_dict'])
            self.optimizer.load_state_dict(state['optimizer'])

            self.logger.info("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                             .format(self.config.checkpoint_dir, self.current_epoch, self.current_iteration))
        except OSError:
            self.logger.info("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir))
            self.logger.info("**First time to train**")
Example 12
Project: Pytorch-Project-Template   Author: moemen95   File: erfnet.py    License: MIT License 6 votes vote down vote up
def load_checkpoint(self, filename):
        """Resume the ERFNet agent from ``self.config.checkpoint_dir + filename``.

        Restores the epoch/iteration counters, the model weights and the
        optimizer state. An ``OSError`` while loading (such as a missing file)
        is logged and treated as the first training run.
        """
        path = self.config.checkpoint_dir + filename
        try:
            self.logger.info("Loading checkpoint '{}'".format(path))
            saved = torch.load(path)

            self.current_epoch = saved['epoch']
            self.current_iteration = saved['iteration']
            self.model.load_state_dict(saved['state_dict'])
            self.optimizer.load_state_dict(saved['optimizer'])

            message = "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
            self.logger.info(message.format(self.config.checkpoint_dir,
                                            self.current_epoch,
                                            self.current_iteration))
        except OSError:
            self.logger.info("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir))
            self.logger.info("**First time to train**")
Example 13
Project: Pytorch-Project-Template   Author: moemen95   File: dcgan.py    License: MIT License 6 votes vote down vote up
def load_checkpoint(self, file_name):
        """Resume the DCGAN agent from ``self.config.checkpoint_dir + file_name``.

        Restores the epoch/iteration counters, the generator and discriminator
        weights and optimizers, ``fixed_noise`` and ``manual_seed``. An
        ``OSError`` (e.g. a missing file) is logged and training starts fresh.
        """
        ckpt_file = self.config.checkpoint_dir + file_name
        try:
            self.logger.info("Loading checkpoint '{}'".format(ckpt_file))
            state = torch.load(ckpt_file)

            self.current_epoch = state['epoch']
            self.current_iteration = state['iteration']
            # Restore in the order G net, G optimizer, D net, D optimizer.
            for target, key in ((self.netG, 'G_state_dict'),
                                (self.optimG, 'G_optimizer'),
                                (self.netD, 'D_state_dict'),
                                (self.optimD, 'D_optimizer')):
                target.load_state_dict(state[key])
            self.fixed_noise = state['fixed_noise']
            self.manual_seed = state['manual_seed']

            self.logger.info("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                  .format(self.config.checkpoint_dir, self.current_epoch, self.current_iteration))
        except OSError:
            self.logger.info("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir))
            self.logger.info("**First time to train**")
Example 14
Project: dogTorch   Author: ehsanik   File: nyu_walkable_surface_dataset.py    License: MIT License 6 votes vote down vote up
def __getitem__(self, idx):
        """Return ``(image, segment, 0, 0, ['images' + fid])`` for sample ``idx``.

        ``fid = self.data_set_list[idx]``. The image is read from
        ``<root_dir>/images/<fid>`` with ``self.load_and_resize`` and the
        walkable-surface mask from ``<root_dir>/walkable/<fid>`` with
        ``self.load_and_resize_segmentation``.

        Raises:
            NotImplementedError: if ``self.read_features`` is set (see below).
        """
        fid = self.data_set_list[idx]
        if self.read_features:
            # The pre-extracted-feature path could never return. It loaded
            # features but never assigned the `image`/`segment` values returned
            # below, so it always ended in UnboundLocalError. It also indexed
            # frames_metadata with `fid + i` while the rest of the method treats
            # fid as a path string. Fail loudly with an explanation instead.
            raise NotImplementedError(
                'read_features=True is not supported: pre-extracted features '
                'are not paired with segmentation targets in this dataset')
        image = self.load_and_resize(
            os.path.join(self.root_dir, 'images', fid))
        segment = self.load_and_resize_segmentation(
            os.path.join(self.root_dir, 'walkable', fid))

        # The two 0s are just place holders. They can be replaced by any values
        return (image, segment, 0, 0, ['images' + fid])
Example 15
Project: End-to-end-ASR-Pytorch   Author: Alexander-H-Liu   File: train_lm.py    License: MIT License 6 votes vote down vote up
def set_model(self):
        """Set up the RNN language model, its loss and optimizer, optionally resuming.

        Builds ``RNNLM(self.vocab_size, **self.config['model'])`` on
        ``self.device``, a cross-entropy loss with ``ignore_index=0``
        (presumably the padding id; confirm) and ``Optimizer`` from
        ``self.config['hparas']``. When ``self.paras.load`` is set, restores the
        model and optimizer state and ``self.step`` from that checkpoint.
        """
        self.model = RNNLM(self.vocab_size, **self.config['model']).to(self.device)
        self.verbose(self.model.create_msg())
        self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)
        self.optimizer = Optimizer(self.model.parameters(), **self.config['hparas'])
        # Enable AMP if needed
        self.enable_apex()

        if not self.paras.load:
            return
        # NOTE(review): load_ckpt() is called and then the same checkpoint is
        # loaded again by hand below. If load_ckpt already restores the model,
        # optimizer and step, one of the two is redundant. Confirm.
        self.load_ckpt()
        ckpt_path = self.paras.load
        ckpt = torch.load(ckpt_path, map_location=self.device)
        self.model.load_state_dict(ckpt['model'])
        self.optimizer.load_opt_state_dict(ckpt['optimizer'])
        self.step = ckpt['global_step']
        self.verbose('Load ckpt from {}, restarting at step {}'.format(
            ckpt_path, self.step))
Example 16
Project: transferlearning   Author: jindongwang   File: main.py    License: MIT License 6 votes vote down vote up
def extract_feature(model, dataloader, save_path, load_from_disk=True, model_path=''):
    """Write per-sample features plus label to a CSV file and print test accuracy.

    Args:
        model: network exposing ``get_features(inputs)`` and a forward pass
            returning class scores. Replaced by a fresh ``models.Network``
            loaded from ``model_path`` when ``load_from_disk`` is True.
        dataloader: iterable of ``(inputs, labels)`` batches with a
            ``dataset`` attribute whose length is the number of samples.
        save_path: CSV destination. One row per sample: the features followed
            by the label (as a float), formatted ``%.6f``.
        load_from_disk: rebuild the model from ``model_path`` first.
        model_path: state-dict file used when ``load_from_disk`` is True.
    """
    if load_from_disk:
        model = models.Network(base_net=args.model_name,
                               n_class=args.num_class)
        model.load_state_dict(torch.load(model_path))
        model = model.to(DEVICE)
    model.eval()
    correct = 0
    # Collect per-batch rows and concatenate once, instead of growing one
    # tensor with torch.cat every iteration (quadratic copying).
    rows = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            feas = model.get_features(inputs)
            label_col = labels.view(labels.size(0), 1).float()
            rows.append(torch.cat((feas, label_col), dim=1).cpu())
            outputs = model(inputs)
            preds = torch.max(outputs, 1)[1]
            # Compare against 1-D labels. Comparing (batch,) preds with the
            # (batch, 1) label column broadcasts to (batch, batch) and counts
            # cross-sample matches.
            correct += torch.sum(preds == labels.view(-1).long()).item()
    n_samples = len(dataloader.dataset)
    test_acc = correct / n_samples if n_samples else 0.0
    if rows:
        fea_numpy = torch.cat(rows, dim=0).numpy()
    else:
        # No batches: keep writing an (empty) file with the expected column count.
        fea_numpy = np.zeros((0, 1 + model.base_network.output_num()))
    np.savetxt(save_path, fea_numpy, fmt='%.6f', delimiter=',')
    print('Test acc: %f' % test_acc)

# You may want to classify with 1nn after getting features 
Example 17
Project: Pytorch-Networks   Author: HaiyangLiu1997   File: utils.py    License: MIT License 6 votes vote down vote up
def load_test_checkpoints(model, save_path, logger, use_best=False):
    """Load a saved checkpoint's ``'model_state'`` into ``model`` for testing.

    The checkpoint path is ``save_path.EXPS + save_path.NAME`` followed by
    ``save_path.BESTMODEL`` when ``use_best`` is True (the path is also printed
    in that case), else by ``save_path.MODEL``. Storages are mapped to CPU when
    CUDA is unavailable.
    """
    ckpt_file = save_path.EXPS + save_path.NAME + (
        save_path.BESTMODEL if use_best else save_path.MODEL)
    if use_best:
        print(ckpt_file)
    # map_location=None is torch.load's default (restore to the saved device).
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    states = torch.load(ckpt_file, map_location=map_location)
    # NOTE: checkpoints saved from nn.DataParallel carry a 'module.' key prefix
    # that would have to be stripped before this call succeeds.
    model.load_state_dict(states['model_state'])
    logger.info('loading checkpoints success')
Example 18
def __init__(self, config):
        """Create the BERT embedding tables, LayerNorm and dropout from ``config``.

        Word, position and token-type lookups all map into
        ``config.hidden_size`` dimensions. How they are combined happens in the
        forward pass, which is not shown here.
        """
        super(BertEmbeddings, self).__init__()
        hidden = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)

        # The attribute keeps the CamelCase name `LayerNorm` (not snake_case) so
        # it matches the TensorFlow variable name and TF checkpoints can be loaded.
        self.LayerNorm = BertLayerNorm(hidden, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
Example 19
def initialize(self):
  """Initialise the network from pretrained weights for a fresh training run.

  Loads ``self.pretrained_model`` through ``self.net.load_pretrained_cnn``. If
  ``self.wsddn_premodel`` is set, its state dict is overlaid on the network's
  current state dict and loaded back.

  Returns:
    ``(lr, last_snapshot_iter, stepsizes, np_paths, ss_paths)`` where
    ``lr = cfg.TRAIN.LEARNING_RATE``, ``last_snapshot_iter = 0``,
    ``stepsizes = list(cfg.TRAIN.STEPSIZE)`` and both path lists are empty.
  """
  # Fresh train directly from ImageNet weights
  print('Loading initial model weights from {:s}'.format(self.pretrained_model))
  self.net.load_pretrained_cnn(torch.load(self.pretrained_model))
  print('Loaded.')

  if self.wsddn_premodel is not None:  # Load the pretrained WSDDN model
    wsddn_weights = torch.load(self.wsddn_premodel)
    merged = self.net.state_dict()
    # Keys present in the WSDDN checkpoint override the current values.
    merged.update(wsddn_weights)
    self.net.load_state_dict(merged)
    print('Loading pretrained WSDDN model weights from {:s}'.format(self.wsddn_premodel))
    print('Loaded.')

  # No snapshots yet: start at iteration 0 with empty snapshot file lists.
  return cfg.TRAIN.LEARNING_RATE, 0, list(cfg.TRAIN.STEPSIZE), [], []
Example 20
Project: comet-commonsense   Author: atcbosselut   File: utils.py    License: Apache License 2.0 5 votes vote down vote up
def load_existing_data_loader(data_loader, path):
    """Overwrite ``data_loader``'s attributes with those of the object saved at ``path``.

    Only instance attributes present in both ``data_loader.__dict__`` and the
    loaded object's ``__dict__`` are copied. Nothing is added or removed.
    """
    saved_loader = torch.load(path)
    saved_attrs = saved_loader.__dict__
    for name in data_loader.__dict__:
        if name in saved_attrs:
            setattr(data_loader, name, getattr(saved_loader, name))


################################################################################
#
# Code Below taken from HuggingFace pytorch-openai-lm repository
#
################################################################################ 
Example 21
Project: comet-commonsense   Author: atcbosselut   File: utils.py    License: Apache License 2.0 5 votes vote down vote up
def __init__(self, encoder_path, bpe_path):
        """Load the spaCy tokenizer, the BPE vocabulary and the merge ranks.

        Args:
            encoder_path: JSON file holding the encoder mapping; ``decoder`` is
                its inverse (presumably token -> id and id -> token; confirm).
            bpe_path: UTF-8 text file of merges, one whitespace-separated pair
                per line. The first line (presumably a header) and the last
                element after splitting on newlines (empty when the file ends
                with a newline) are skipped.
        """
        # Only spaCy's tokenizer is needed; the heavier pipeline components are disabled.
        self.nlp = spacy.load(
            'en', disable=['parser', 'tagger', 'ner', 'textcat'])
        # `with` closes both files; the original left the handles open.
        # JSON text is UTF-8 (RFC 8259), matching the merges file below.
        with open(encoder_path, encoding='utf-8') as f:
            self.encoder = json.load(f)
        self.decoder = {v: k for k, v in self.encoder.items()}
        with open(bpe_path, encoding='utf-8') as f:
            merges = f.read().split('\n')[1:-1]
        merges = [tuple(merge.split()) for merge in merges]
        # Rank of a merge = its position in the file.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}
Example 22
Project: comet-commonsense   Author: atcbosselut   File: data.py    License: Apache License 2.0 5 votes vote down vote up
def load_checkpoint(filename, gpu=True):
    """Load a checkpoint with every storage kept on CPU.

    Args:
        filename: path passed to ``torch.load``.
        gpu: unused; kept for call-site compatibility.

    Returns:
        The deserialised checkpoint, or ``None`` (after printing a message)
        when ``filename`` does not exist.
    """
    if not os.path.exists(filename):
        print("No model found at {}".format(filename))
        # Return explicitly: falling through would reference an unbound
        # `checkpoint` and raise UnboundLocalError.
        return None
    return torch.load(
        filename, map_location=lambda storage, loc: storage)
Example 23
Project: DDPAE-video-prediction   Author: jthsieh   File: base_model.py    License: MIT License 5 votes vote down vote up
def load(self, ckpt_path, epoch, load_optimizer=False):
  '''
  Load checkpoint.

  For each ``name, net`` in ``self.nets``, loads
  ``<ckpt_path>/net_{name}_{epoch}.pth`` into the net (unwrapping
  ``nn.DataParallel``). Missing files are reported and skipped. If the key
  names do not match, the checkpoint tensors are re-keyed positionally onto the
  module's keys; each pair must share the same last dotted component. With
  ``load_optimizer``, each optimizer in ``self.optimizers`` is restored from
  ``optimizer_{name}_{epoch}.pth`` the same way, without re-keying.
  '''
  for name, net in self.nets.items():
    path = os.path.join(ckpt_path, 'net_{}_{}.pth'.format(name, epoch))
    if not os.path.exists(path):
      print('{} does not exist, ignore.'.format(path))
      continue
    ckpt = torch.load(path)
    if isinstance(net, torch.nn.DataParallel):
      module = net.module
    else:
      module = net

    try:
      module.load_state_dict(ckpt)
    except RuntimeError:
      # load_state_dict raises RuntimeError on missing/unexpected keys. Catch
      # only that; a bare `except:` also swallowed KeyboardInterrupt and
      # unrelated errors.
      print('net_{} and checkpoint have different parameter names'.format(name))
      new_ckpt = OrderedDict()
      for ckpt_key, module_key in zip(ckpt.keys(), module.state_dict().keys()):
        assert ckpt_key.split('.')[-1] == module_key.split('.')[-1]
        new_ckpt[module_key] = ckpt[ckpt_key]
      module.load_state_dict(new_ckpt)

  if load_optimizer:
    for name, optimizer in self.optimizers.items():
      path = os.path.join(ckpt_path, 'optimizer_{}_{}.pth'.format(name, epoch))
      if not os.path.exists(path):
        print('{} does not exist, ignore.'.format(path))
        continue
      ckpt = torch.load(path)
      optimizer.load_state_dict(ckpt)
Example 24
Project: mmdetection   Author: open-mmlab   File: gather_models.py    License: Apache License 2.0 5 votes vote down vote up
def process_checkpoint(in_file, out_file):
    """Drop the optimizer state from a checkpoint and publish it under a hashed name.

    The trimmed checkpoint is saved to ``out_file`` and hashed with the
    ``sha256sum`` tool. The file is then renamed to
    ``<out_file without a trailing .pth>-<first 8 hex digits>.pth``.

    Returns:
        Path of the renamed file.
    """
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)
    sha = subprocess.check_output(['sha256sum', out_file]).decode()
    # Remove the literal '.pth' suffix. str.rstrip('.pth') strips any trailing
    # '.', 'p', 't', 'h' characters (e.g. 'depth.pth' -> 'de').
    stem = out_file[:-len('.pth')] if out_file.endswith('.pth') else out_file
    final_file = stem + '-{}.pth'.format(sha[:8])
    # Wait for the rename and fail loudly. Popen returned immediately, so
    # callers could use final_file before it existed.
    subprocess.run(['mv', out_file, final_file], check=True)
    return final_file
Example 25
Project: mmdetection   Author: open-mmlab   File: publish_model.py    License: Apache License 2.0 5 votes vote down vote up
def process_checkpoint(in_file, out_file):
    """Drop the optimizer state from a checkpoint and publish it under a hashed name.

    The trimmed checkpoint is saved to ``out_file`` and hashed with
    ``sha256sum``. The file is then renamed to
    ``<out_file without a trailing .pth>-<first 8 hex digits>.pth``.

    Returns:
        Path of the renamed file. Earlier versions returned None; callers
        that ignore the result are unaffected.
    """
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)
    sha = subprocess.check_output(['sha256sum', out_file]).decode()
    # Derive the stem in a separate name: rebinding `out_file` made `mv` look
    # for the suffix-less path, which was never written.
    if out_file.endswith('.pth'):
        out_file_name = out_file[:-4]
    else:
        out_file_name = out_file
    final_file = out_file_name + f'-{sha[:8]}.pth'
    # Wait for the rename (Popen did not) and raise if it fails.
    subprocess.run(['mv', out_file, final_file], check=True)
    return final_file
Example 26
Project: subword-qac   Author: clovaai   File: utils.py    License: MIT License 5 votes vote down vote up
def model_load(path, model=None, optimizer=None):
    """Load a language model (and optionally optimizer state) from directory ``path``.

    Reads ``config.json``, ``model.pt`` and, when ``optimizer`` is given,
    ``optimizer.pt`` from ``path``. All storages are kept on CPU.

    Args:
        path: directory written by the matching save routine.
        model: existing model to re-initialise in place (via ``get_model``,
            presumably unwrapping a parallel wrapper; confirm), or None to
            build a new ``LanguageModel``.
        optimizer: optimizer whose state is restored, or None/falsy to skip.

    Returns:
        The model instance holding the loaded weights.
    """
    config = LMConfig(os.path.join(path, 'config.json'))
    if model is None:
        model_to_load = LanguageModel(config)
    else:
        model_to_load = get_model(model)
        model_to_load.__init__(config)
    # Pass paths instead of open() handles: torch.load then opens and closes
    # the files itself, so no file descriptors are leaked.
    model_state_dict = torch.load(os.path.join(path, 'model.pt'), map_location=lambda s, l: s)
    model_to_load.load_state_dict(model_state_dict)
    if optimizer:
        optimizer_state_dict = torch.load(os.path.join(path, 'optimizer.pt'), map_location=lambda s, l: s)
        optimizer.load_state_dict(optimizer_state_dict)
    return model_to_load
Example 27
Project: models   Author: kipoi   File: model.py    License: MIT License 5 votes vote down vote up
def load_weights(self, weights):
        """Build a ``CNN``, load the state dict at ``weights`` and return it in eval mode.

        Storages are mapped to CPU when CUDA is unavailable. ``self`` is not used.
        """
        net = CNN()
        device = 'cpu' if not torch.cuda.is_available() else None
        state = torch.load(weights, map_location=device)
        net.load_state_dict(state)
        return net.eval()
Example 28
Project: models   Author: kipoi   File: pretrained_model_reloaded_th.py    License: MIT License 5 votes vote down vote up
def get_model(load_weights = True):
    """Build the Basset network reloaded from the Torch7 model.

    The input alphabet matches Basset's encoding, which was checked and seems
    fine (https://github.com/davek44/Basset/tree/master/src/dna_io.py#L145-L148)::

        seq = seq.replace('A','0')
        seq = seq.replace('C','1')
        seq = seq.replace('G','2')
        seq = seq.replace('T','3')

    Input is presumably one-hot DNA shaped (batch, 4, 600, 1): the conv/pool
    stack reduces length 600 to 10 positions x 200 channels = 2000, the input
    size of the first Linear layer. TODO confirm against the loader.

    Args:
        load_weights: if True, load
            ``model_files/pretrained_model_reloaded_th.pth`` (relative to the
            current working directory).

    Returns:
        The ``nn.Sequential`` network with 164 sigmoid outputs.
    """
    pretrained_model_reloaded_th = nn.Sequential( # Sequential,
        nn.Conv2d(4,300,(19, 1)),
        nn.BatchNorm2d(300),
        nn.ReLU(),
        nn.MaxPool2d((3, 1),(3, 1)),
        nn.Conv2d(300,200,(11, 1)),
        nn.BatchNorm2d(200),
        nn.ReLU(),
        nn.MaxPool2d((4, 1),(4, 1)),
        nn.Conv2d(200,200,(7, 1)),
        nn.BatchNorm2d(200),
        nn.ReLU(),
        nn.MaxPool2d((4, 1),(4, 1)),
        Lambda(lambda x: x.view(x.size(0),-1)), # Reshape: flatten to (batch, -1)
        # Each Linear is preceded by a Lambda that promotes a 1-D input to a
        # batch of one.
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(2000,1000)), # Linear,
        nn.BatchNorm1d(1000,1e-05,0.1,True),#BatchNorm1d,
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(1000,1000)), # Linear,
        nn.BatchNorm1d(1000,1e-05,0.1,True),#BatchNorm1d,
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(1000,164)), # Linear,
        nn.Sigmoid(),
    )
    if load_weights:
        # The module is built on CPU and load_state_dict copies into it. Mapping
        # to CPU lets checkpoints containing CUDA tensors load without a GPU.
        sd = torch.load('model_files/pretrained_model_reloaded_th.pth', map_location='cpu')
        pretrained_model_reloaded_th.load_state_dict(sd)
    return pretrained_model_reloaded_th
Example 29
Project: VSE-C   Author: ExplorerFreda   File: evaluation.py    License: MIT License 5 votes vote down vote up
def eval_with_extended(model_path, data_path=None, data_name=None, split='test'):
    """Evaluate image-to-text retrieval on the extended captions and save the ranks.

    Loads the checkpoint at ``model_path``, which must contain ``'opt'`` and
    ``'model'``. It enables external captions with 5 negatives, optionally
    overrides the dataset location/name, encodes the ``split`` data, prints the
    i2t recalls, and saves ``{'rt': rt}`` to ``ranks_extended.pth.tar``.

    Args:
        model_path: checkpoint path. If it contains ``'model_best'``, the ranks
            file replaces that part of the path; otherwise the ranks file is
            written to the checkpoint's directory.
        data_path: optional override for ``opt.data_path``.
        data_name: optional override for ``opt.data_name``.
        split: dataset split passed to ``get_test_loader``.
    """
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    opt.use_external_captions = True
    opt.negative_number = 5
    if data_path is not None:
        opt.data_path = data_path
    if data_name is not None:
        opt.data_name = data_name

    # load vocabulary used by the model
    # NOTE(review): pickle.load can execute arbitrary code; use trusted files only.
    with open(os.path.join(opt.vocab_path,
                           '%s_vocab.pkl' % opt.data_name), 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # construct model and load its state
    model = VSE(opt)
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.crop_size,
                                  opt.batch_size, opt.workers, opt)
    print('Computing results...')
    img_embs, cap_embs = encode_data(model, data_loader)
    # Image count is divided by 5, presumably because each image embedding is
    # repeated once per caption (5 captions per image). Confirm.
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] // 5, cap_embs.shape[0]))

    r, rt = i2t_text_only(img_embs, cap_embs, measure=opt.measure, return_ranks=True)
    ar = (r[0] + r[1] + r[2]) / 3
    print("Average i2t Recall: %.1f" % ar)
    print("Image to text: %.1f\t%.1f\t%.1f\t%.1f\t%.1f" % r)

    # str.find returns -1 when 'model_best' is absent; slicing with -1 would
    # silently chop the last character off the path. Fall back to the
    # checkpoint's directory instead.
    cut = model_path.find('model_best')
    if cut >= 0:
        prefix = model_path[:cut]
    else:
        prefix = os.path.join(os.path.dirname(model_path), '')
    torch.save({'rt': rt}, prefix + 'ranks_extended.pth.tar')
Example 30
Project: VSE-C   Author: ExplorerFreda   File: saliency_visualization.py    License: MIT License 5 votes vote down vote up
def main():
    """Build the saliency feature extractor and open an interactive IPython shell.

    In the shell, ``e(ind)`` computes the saliency for sample ``ind``. It saves
    the image saliency to ``/tmp/origin{ind}.png``, prints the caption saliency
    and returns the extractor output.
    """
    encoder = ImageEncoder(args.encoder, args.load_encoder)
    dataset = Dataset(args)
    extractor = FeatureExtractor(args.load, encoder, dataset)

    # `e` and the locals above are exposed to the embedded shell, so their names
    # are part of the interactive interface.
    def e(ind):
        result = extractor(ind)
        out_png = '/tmp/origin{}.png'.format(ind)
        saliency_img = plot_saliency(result.raw_image, result.image,
                                     result.image_embedding, result.caption_embedding)
        saliency_img.save(out_png)
        print('Image saliency saved:', out_png)
        print_txt_saliency(result.captions, 0, result.raw_caption,
                           result.image_embedding, result.caption_embedding)
        return result

    from IPython import embed
    embed()