Python torch.load() Examples

The following are 30 code examples of torch.load(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: model_architecture.py    From models with MIT License 8 votes vote down vote up
def get_model(load_weights = True):
    """Build the DeepSEA network behind an alphabet re-coding module.

    Args:
        load_weights: if True, load the parameters stored in
            ``model_files/deepsea_cpu.pth`` into the network.

    Returns:
        ``nn.Sequential(ReCodeAlphabet(), deepsea_cpu)``.
    """
    deepsea_cpu = nn.Sequential( # Sequential,
        # nn.Threshold(0, 1e-06): values <= 0 are replaced with 1e-06.
        nn.Conv2d(4,320,(1, 8),(1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4),(1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(320,480,(1, 8),(1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4),(1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(480,960,(1, 8),(1, 1)),
        nn.Threshold(0, 1e-06),
        nn.Dropout(0.5),
        # Flatten every dimension after dim 0 (presumably the batch dim).
        Lambda(lambda x: x.view(x.size(0),-1)), # Reshape,
        # Each Linear is preceded by a Lambda that turns a 1-D input into
        # shape (1, N). 50880 presumably equals 960 channels x 53 positions
        # for the expected input length — TODO confirm.
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(50880,925)), # Linear,
        nn.Threshold(0, 1e-06),
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(925,919)), # Linear,
        nn.Sigmoid(),
    )
    if load_weights:
        # map_location='cpu' lets a checkpoint saved from CUDA tensors load on
        # hosts without a GPU; load_state_dict copies the values into the
        # module's existing parameters either way.
        deepsea_cpu.load_state_dict(
            torch.load('model_files/deepsea_cpu.pth', map_location='cpu'))
    return nn.Sequential(ReCodeAlphabet(), deepsea_cpu)
Example #2
Source File: model_architecture.py    From models with MIT License 6 votes vote down vote up
def get_seqpred_model(load_weights = True):
    """Build the DeepSEA network wrapped for sequence prediction.

    The returned pipeline is ``ReCodeAlphabet -> ConcatenateRC -> DeepSEA ->
    AverageRC``. From their names, the RC modules presumably add the
    reverse-complement strand and average the two predictions — verify.

    Args:
        load_weights: if True, load the parameters stored in
            ``model_files/deepsea_cpu.pth`` into the DeepSEA network.

    Returns:
        The ``nn.Sequential`` pipeline described above.
    """
    deepsea_cpu = nn.Sequential( # Sequential,
        # nn.Threshold(0, 1e-06): values <= 0 are replaced with 1e-06.
        nn.Conv2d(4,320,(1, 8),(1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4),(1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(320,480,(1, 8),(1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4),(1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(480,960,(1, 8),(1, 1)),
        nn.Threshold(0, 1e-06),
        nn.Dropout(0.5),
        # Flatten every dimension after dim 0 (presumably the batch dim).
        Lambda(lambda x: x.view(x.size(0),-1)), # Reshape,
        # Each Linear is preceded by a Lambda that turns a 1-D input into
        # shape (1, N).
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(50880,925)), # Linear,
        nn.Threshold(0, 1e-06),
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(925,919)), # Linear,
        nn.Sigmoid(),
    )
    if load_weights:
        # map_location='cpu' lets a checkpoint saved from CUDA tensors load on
        # hosts without a GPU; load_state_dict copies the values into the
        # module's existing parameters either way.
        deepsea_cpu.load_state_dict(
            torch.load('model_files/deepsea_cpu.pth', map_location='cpu'))
    return nn.Sequential(ReCodeAlphabet(), ConcatenateRC(), deepsea_cpu, AverageRC())
Example #3
Source File: utils.py    From Pytorch-Networks with MIT License 6 votes vote down vote up
def load_test_checkpoints(model, save_path, logger, use_best=False):
    """Load saved model weights for evaluation.

    Args:
        model: module whose state is restored from ``states['model_state']``.
        save_path: config object; the checkpoint path is
            ``save_path.EXPS + save_path.NAME + save_path.BESTMODEL`` when
            ``use_best`` is true, otherwise ``... + save_path.MODEL``.
        logger: logger used to report success.
        use_best: load the best-model checkpoint (and print its path) instead
            of the latest one.
    """
    file_name = save_path.BESTMODEL if use_best else save_path.MODEL
    ckpt_path = save_path.EXPS + save_path.NAME + file_name
    if use_best:
        print(ckpt_path)
    # Without CUDA, remap any GPU tensors in the checkpoint onto the CPU;
    # map_location=None keeps torch.load's default behaviour.
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    states = torch.load(ckpt_path, map_location=map_location)
    # NOTE: checkpoints saved from nn.DataParallel carry a 'module.' key
    # prefix and will not load here; strip that prefix first if needed.
    model.load_state_dict(states['model_state'])
    logger.info('loading checkpoints success')
Example #4
Source File: regnet2mmdet.py    From mmdetection with Apache License 2.0 6 votes vote down vote up
def convert(src, dst):
    """Convert keys in pycls pretrained RegNet models to mmdet style.

    Args:
        src: path of the pycls checkpoint; weights are read from its
            ``'model_state'`` entry.
        dst: path where ``{'state_dict': converted_weights}`` is saved.

    Keys containing 'stem' or 'head', or starting with 's', are handed to the
    matching ``convert_*`` helper. Any key not recorded in ``converted_names``
    is reported on stdout.
    """
    # Load the pycls checkpoint onto the CPU so the conversion also works on
    # machines without a GPU.
    regnet_model = torch.load(src, map_location='cpu')
    blobs = regnet_model['model_state']
    # convert to pytorch style
    state_dict = OrderedDict()
    converted_names = set()
    for key, weight in blobs.items():
        if 'stem' in key:
            convert_stem(key, weight, state_dict, converted_names)
        elif 'head' in key:
            convert_head(key, weight, state_dict, converted_names)
        elif key.startswith('s'):
            convert_reslayer(key, weight, state_dict, converted_names)

    # check if all layers are converted
    for key in blobs:
        if key not in converted_names:
            print(f'not converted: {key}')
    # save checkpoint
    checkpoint = dict()
    checkpoint['state_dict'] = state_dict
    torch.save(checkpoint, dst)
Example #5
Source File: main.py    From transferlearning with MIT License 6 votes vote down vote up
def extract_feature(model, dataloader, save_path, load_from_disk=True, model_path=''):
    """Extract per-sample features, save them as CSV and print test accuracy.

    Each CSV row is the vector returned by ``model.get_features`` for one
    sample, followed by that sample's label as a float.

    Args:
        model: network exposing ``get_features`` and
            ``base_network.output_num()``. It is ignored (rebuilt from the
            global ``args``) when ``load_from_disk`` is true.
        dataloader: yields ``(inputs, labels)`` batches; ``dataloader.dataset``
            must support ``len``.
        save_path: output path for ``np.savetxt``.
        load_from_disk: rebuild the model and load ``model_path`` weights.
        model_path: weight file used when ``load_from_disk`` is true.
    """
    if load_from_disk:
        model = models.Network(base_net=args.model_name,
                               n_class=args.num_class)
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        model = model.to(DEVICE)
    model.eval()
    correct = 0
    feature_rows = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            feas = model.get_features(inputs)
            label_col = labels.view(labels.size(0), 1).float()
            feature_rows.append(torch.cat((feas, label_col), dim=1))
            outputs = model(inputs)
            preds = torch.max(outputs, 1)[1]
            # Compare against the flat labels: comparing (N,) predictions with
            # an (N, 1) column would broadcast to an (N, N) matrix.
            correct += torch.sum(preds == labels.view(-1).long()).item()
    # Concatenate once instead of growing a tensor inside the loop.
    if feature_rows:
        fea_all = torch.cat(feature_rows, dim=0)
    else:
        fea_all = torch.zeros(0, 1 + model.base_network.output_num())
    dataset_size = len(dataloader.dataset)
    test_acc = correct / dataset_size if dataset_size else 0.0
    np.savetxt(save_path, fea_all.cpu().numpy(), fmt='%.6f', delimiter=',')
    print('Test acc: %f' % test_acc)

# You may want to classify with 1nn after getting features 
Example #6
Source File: train_val.py    From Collaborative-Learning-for-Weakly-Supervised-Object-Detection with MIT License 6 votes vote down vote up
def from_snapshot(self, sfile, nfile):
    """Restore network weights and data-layer iteration state from a snapshot.

    Args:
        sfile: weight file loaded with ``torch.load`` into ``self.net``.
        nfile: pickle file holding, in order: the NumPy RNG state, the
            training layer's cursor and permutation, the validation layer's
            cursor and permutation, and the snapshot iteration.

    Returns:
        The iteration number stored as the last object in ``nfile``.
    """
    print('Restoring model snapshots from {:s}'.format(sfile))
    weights = torch.load(str(sfile))
    self.net.load_state_dict(weights)
    print('Restored.')
    # TODO(xinlei): the remaining training hyper-parameters/states are
    # recovered as far as possible so training can resume exactly; the
    # Tensorflow state is not available.
    with open(nfile, 'rb') as handle:
        # The objects were pickled back-to-back, so unpickle in the same order.
        # NOTE: pickle.load must only be used on trusted snapshot files.
        (rng_state, train_cur, train_perm,
         val_cur, val_perm, snapshot_iter) = [pickle.load(handle) for _ in range(6)]

        np.random.set_state(rng_state)
        self.data_layer._cur = train_cur
        self.data_layer._perm = train_perm
        self.data_layer_val._cur = val_cur
        self.data_layer_val._perm = val_perm

    return snapshot_iter
Example #7
Source File: train_lm.py    From End-to-end-ASR-Pytorch with MIT License 6 votes vote down vote up
def set_model(self):
    """Set up the RNN language model, its loss and optimizer.

    When ``self.paras.load`` is set, training state (model, optimizer and
    ``self.step``) is resumed from that checkpoint.
    """
    self.model = RNNLM(self.vocab_size, **self.config['model']).to(self.device)
    self.verbose(self.model.create_msg())

    # Targets with index 0 are excluded from the loss (presumably padding —
    # verify against the tokenizer).
    self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)

    self.optimizer = Optimizer(self.model.parameters(), **self.config['hparas'])
    self.enable_apex()  # enables AMP if configured

    if not self.paras.load:
        return

    # NOTE(review): if self.load_ckpt is the base-solver loader (which already
    # restores model, optimizer and step), the explicit load below repeats that
    # work — confirm and keep only one of them.
    self.load_ckpt()
    ckpt = torch.load(self.paras.load, map_location=self.device)
    self.model.load_state_dict(ckpt['model'])
    self.optimizer.load_opt_state_dict(ckpt['optimizer'])
    self.step = ckpt['global_step']
    self.verbose('Load ckpt from {}, restarting at step {}'.format(
        self.paras.load, self.step))
Example #8
Source File: test.py    From pytorch_NER_BiLSTM_CNN_CRF with Apache License 2.0 6 votes vote down vote up
def load_test_model(model, config):
    """Load saved weights into ``model`` for testing.

    Uses ``config.t_model`` when given. Otherwise falls back to
    ``<config.save_best_model_dir>/<config.model_name>.pt``.

    :param model: initial model; its state dict is overwritten in place
    :param config: config providing ``t_model``, ``save_best_model_dir``
        and ``model_name``
    :return: the same ``model`` object, with weights loaded
    """
    if config.t_model is not None:
        test_model_path = config.t_model
        print("load user model from {}".format(test_model_path))
    else:
        default_name = "{}.pt".format(config.model_name)
        test_model_path = os.path.join(config.save_best_model_dir, default_name)
        print("load default model from {}".format(test_model_path))
    model.load_state_dict(torch.load(test_model_path))
    return model
Example #9
Source File: saliency_visualization.py    From VSE-C with MIT License 6 votes vote down vote up
def load_checkpoint(self, checkpoint):
    """Rebuild the VSE model from a checkpoint file and freeze its encoders.

    The checkpoint's ``'opt'`` options are reused: ``use_external_captions``
    is forced off and ``vocab_size`` is set from the pickled vocabulary.
    Afterwards ``self.model`` holds the restored VSE model and
    ``self.projector`` the vocabulary. Both encoders are put in eval mode
    with ``requires_grad`` disabled.
    """
    state = torch.load(checkpoint)
    opt = state['opt']
    opt.use_external_captions = False
    vocab_file = pjoin(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
    vocab = Vocab.from_pickle(vocab_file)
    opt.vocab_size = len(vocab)

    from model import VSE
    self.model = VSE(opt)
    self.model.load_state_dict(state['model'])
    self.projector = vocab

    encoders = (self.model.img_enc, self.model.txt_enc)
    for encoder in encoders:
        encoder.eval()
    for encoder in encoders:
        for param in encoder.parameters():
            param.requires_grad = False
Example #10
Source File: nyu_walkable_surface_dataset.py    From dogTorch with MIT License 6 votes vote down vote up
def __getitem__(self, idx):
    """Return ``(image, segment, 0, 0, ['images' + fid])`` for sample ``idx``.

    ``fid`` is ``self.data_set_list[idx]``. The image and the walkable-surface
    segmentation are loaded from ``<root_dir>/images/<fid>`` and
    ``<root_dir>/walkable/<fid>``.

    Raises:
        NotImplementedError: if ``self.read_features`` is set. That path
            cannot produce the ``image``/``segment`` values this method
            returns, so it only ever failed with a NameError after loading
            every feature file.
    """
    fid = self.data_set_list[idx]
    if self.read_features:
        # The feature path also indexes frames_metadata with `fid + i`, while
        # fid is used below as a path component; it needs a real redesign.
        raise NotImplementedError(
            'read_features is not supported by this dataset')
    image = self.load_and_resize(
        os.path.join(self.root_dir, 'images', fid))
    segment = self.load_and_resize_segmentation(
        os.path.join(self.root_dir, 'walkable', fid))

    # The two 0s are just place holders. They can be replaced by any values
    return (image, segment, 0, 0, ['images' + fid])
Example #11
Source File: dcgan.py    From Pytorch-Project-Template with MIT License 6 votes vote down vote up
def load_checkpoint(self, file_name):
    """Resume the GAN from ``self.config.checkpoint_dir + file_name``.

    Restores the epoch/iteration counters, both networks and both optimizers,
    the fixed noise and the manual seed. An ``OSError`` (e.g. missing file) is
    logged and skipped so training starts fresh.
    """
    filename = self.config.checkpoint_dir + file_name
    log = self.logger.info
    try:
        log("Loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)

        self.current_epoch = checkpoint['epoch']
        self.current_iteration = checkpoint['iteration']
        restore_plan = (
            (self.netG, 'G_state_dict'),
            (self.optimG, 'G_optimizer'),
            (self.netD, 'D_state_dict'),
            (self.optimD, 'D_optimizer'),
        )
        for target, key in restore_plan:
            target.load_state_dict(checkpoint[key])
        self.fixed_noise = checkpoint['fixed_noise']
        self.manual_seed = checkpoint['manual_seed']

        log("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
            .format(self.config.checkpoint_dir, checkpoint['epoch'], checkpoint['iteration']))
    except OSError:
        log("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir))
        log("**First time to train**")
Example #12
Source File: erfnet.py    From Pytorch-Project-Template with MIT License 6 votes vote down vote up
def load_checkpoint(self, filename):
    """Resume training from ``self.config.checkpoint_dir + filename``.

    Restores ``current_epoch``, ``current_iteration`` and the model and
    optimizer state dicts. An ``OSError`` (e.g. the file is missing) is
    logged and training starts from scratch.
    """
    ckpt_file = self.config.checkpoint_dir + filename
    try:
        self.logger.info("Loading checkpoint '{}'".format(ckpt_file))
        ckpt = torch.load(ckpt_file)

        counters = (('current_epoch', 'epoch'), ('current_iteration', 'iteration'))
        for attr_name, ckpt_key in counters:
            setattr(self, attr_name, ckpt[ckpt_key])
        self.model.load_state_dict(ckpt['state_dict'])
        self.optimizer.load_state_dict(ckpt['optimizer'])

        done_msg = ("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
                    .format(self.config.checkpoint_dir, ckpt['epoch'], ckpt['iteration']))
        self.logger.info(done_msg)
    except OSError:
        self.logger.info("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir))
        self.logger.info("**First time to train**")
Example #13
Source File: condensenet.py    From Pytorch-Project-Template with MIT License 6 votes vote down vote up
def load_checkpoint(self, filename):
    """Resume from ``self.config.checkpoint_dir + filename`` if it can be read.

    On success the epoch/iteration counters and the model and optimizer
    states are restored. An ``OSError`` (e.g. missing file) is logged and
    skipped.
    """
    ckpt_dir = self.config.checkpoint_dir
    path = ckpt_dir + filename
    info = self.logger.info
    try:
        info("Loading checkpoint '{}'".format(path))
        state = torch.load(path)

        self.current_epoch = state['epoch']
        self.current_iteration = state['iteration']
        self.model.load_state_dict(state['state_dict'])
        self.optimizer.load_state_dict(state['optimizer'])

        info("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
             .format(ckpt_dir, state['epoch'], state['iteration']))
    except OSError:
        info("No checkpoint exists from '{}'. Skipping...".format(ckpt_dir))
        info("**First time to train**")
Example #14
Source File: dqn.py    From Pytorch-Project-Template with MIT License 6 votes vote down vote up
def load_checkpoint(self, file_name):
    """Resume the DQN agent from ``self.config.checkpoint_dir + file_name``.

    Restores ``current_episode``, ``current_iteration``, the policy network
    and the optimizer. An ``OSError`` (e.g. missing file) is logged and
    skipped so training starts fresh.
    """
    filename = self.config.checkpoint_dir + file_name
    try:
        self.logger.info("Loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)

        self.current_episode = checkpoint['episode']
        self.current_iteration = checkpoint['iteration']
        self.policy_model.load_state_dict(checkpoint['state_dict'])
        self.optim.load_state_dict(checkpoint['optimizer'])

        # The stored counter is an episode count, so label it as such.
        self.logger.info("Checkpoint loaded successfully from '{}' at (episode {}) at (iteration {})\n"
                         .format(self.config.checkpoint_dir, checkpoint['episode'], checkpoint['iteration']))
    except OSError:
        self.logger.info("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir))
        self.logger.info("**First time to train**")
Example #15
Source File: utils.py    From pruning_yolov3 with GNU General Public License v3.0 6 votes vote down vote up
def print_mutation(hyp, results, bucket=''):
    """Print an evolution result and append it to ``evolve.txt``.

    Used with train.py --evolve. The file is de-duplicated and re-sorted by
    ``fitness`` after each append.

    Args:
        hyp: dict of hyperparameters; keys and values are printed in order.
        results: tuple of result metrics (P, R, mAP, F1, test_loss), each
            formatted with ``'%10.3g'``. It must be a tuple for ``%``
            formatting.
        bucket: optional Google Cloud Storage bucket. When given,
            ``evolve.txt`` is downloaded from it first and uploaded afterwards.
    """
    import subprocess

    def _gsutil_cp(src, dst):
        # Pass an argument list (no shell) so the bucket name cannot inject
        # shell syntax. Failures stay non-fatal: a non-zero exit status is
        # ignored and a missing gsutil binary only prints a message.
        try:
            subprocess.run(['gsutil', 'cp', src, dst], check=False)
        except OSError as err:
            print('gsutil cp %s %s failed: %s' % (src, dst, err))

    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.3g' * len(results) % results  # results (P, R, mAP, F1, test_loss)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        _gsutil_cp('gs://%s/evolve.txt' % bucket, '.')  # download evolve.txt

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    np.savetxt('evolve.txt', x[np.argsort(-fitness(x))], '%10.3g')  # save sort by fitness

    if bucket:
        _gsutil_cp('evolve.txt', 'gs://%s' % bucket)  # upload evolve.txt
Example #16
Source File: models.py    From cvpr2018-hnd with MIT License 6 votes vote down vote up
def init_truncated_normal(model, aux_str=''):
    """Initialise ``model`` from a cached weight file, creating it on first use.

    The cache file is ``'{path}/{in_features}_{out_features}{aux_str}.pth'``,
    where ``path`` is a module-level name defined elsewhere. If the file
    exists, its state dict is loaded into ``model``. Otherwise
    ``truncated_normal`` is applied (to each submodule of an
    ``nn.ModuleList``, else to ``model``) and the result is saved there.

    Returns:
        ``model``, or None when ``model`` is None.
    """
    if model is None:
        return None
    # NOTE(review): in_features/out_features are read before the ModuleList
    # check below; a plain nn.ModuleList has neither attribute, so that branch
    # only works for a ModuleList subclass that defines them — verify.
    init_path = '{path}/{in_dim:d}_{out_dim:d}{aux_str}.pth' \
                .format(path=path, in_dim=model.in_features, out_dim=model.out_features, aux_str=aux_str)
    if os.path.isfile(init_path):
        model.load_state_dict(torch.load(init_path))
        print('load init weight: {init_path}'.format(init_path=init_path))
    else:
        if isinstance(model, nn.ModuleList):
            # Plain loop: truncated_normal is called only for its effect on
            # each submodule, so no result list is needed.
            for sub in model:
                truncated_normal(sub)
        else:
            truncated_normal(model)
        print('generate init weight: {init_path}'.format(init_path=init_path))
        torch.save(model.state_dict(), init_path)
        print('save init weight: {init_path}'.format(init_path=init_path))

    return model
Example #17
Source File: utils.py    From cvpr2018-hnd with MIT License 6 votes vote down vote up
def load_model(model, optimizer, scheduler, path, num_epochs, start_time=None):
    """Resume from the newest saved epoch at or below ``num_epochs``.

    Looks for ``'{path}_model_{epoch}.pth'`` from ``num_epochs`` down to 1.
    When one is found, the model weights are loaded. The matching
    ``_optimizer_`` and ``_scheduler_`` files are also loaded when
    ``optimizer``/``scheduler`` are given; the scheduler's ``best``,
    ``cooldown_counter``, ``num_bad_epochs`` and ``last_epoch`` are copied
    over.

    Args:
        start_time: reference time for the elapsed-time printout. Defaults to
            the time of this call (a ``time.time()`` default would be frozen
            at import time).

    Returns:
        The epoch that was loaded, or 0 if no model file exists.
    """
    if start_time is None:
        start_time = time.time()

    def _ckpt_path(kind, ep):
        # File naming shared by model/optimizer/scheduler checkpoints.
        return '{path}_{kind}_{epoch:d}.pth'.format(path=path, kind=kind, epoch=ep)

    epoch = num_epochs
    while epoch > 0 and not os.path.isfile(_ckpt_path('model', epoch)):
        epoch -= 1
    if epoch > 0:
        model_path = _ckpt_path('model', epoch)
        model.load_state_dict(torch.load(model_path))
        if optimizer is not None:
            optimizer.load_state_dict(torch.load(_ckpt_path('optimizer', epoch)))
        if scheduler is not None:
            scheduler_state_dict = torch.load(_ckpt_path('scheduler', epoch))
            scheduler.best = scheduler_state_dict['best']
            scheduler.cooldown_counter = scheduler_state_dict['cooldown_counter']
            scheduler.num_bad_epochs = scheduler_state_dict['num_bad_epochs']
            scheduler.last_epoch = scheduler_state_dict['last_epoch']
        print('{epoch:4d}/{num_epochs:4d} e; '.format(epoch=epoch, num_epochs=num_epochs), end='')
        print('load {path}; '.format(path=model_path), end='')
        print('{time:8.3f} s'.format(time=time.time()-start_time))
    return epoch
Example #18
Source File: network_factory.py    From Pytorch-Networks with MIT License 5 votes vote down vote up
def get_network(net_name, logger=None, cfg=None):
    """Instantiate the network registered as ``net_name`` in ``NET_LUT``.

    When ``cfg.PRETRAIN`` is not None, the loader registered in ``LOAD_LUT``
    is called as ``load_func(net, cfg.PRETRAIN_PATH, cfg.PRETRAIN)``.

    Raises:
        ValueError: if ``net_name`` has no entry in ``NET_LUT``, or pretrained
            weights are requested but ``LOAD_LUT`` has no loader for it.
    """
    # dict.get returns None instead of raising, so check explicitly; otherwise
    # the failure would surface later as "'NoneType' object is not callable".
    net_class = NET_LUT.get(net_name)
    if net_class is None:
        msg = "network tpye error, {} not exist".format(net_name)
        if logger is not None:
            logger.error(msg)
        raise ValueError(msg)
    net_instance = net_class(cfg=cfg, logger=logger)
    if cfg.PRETRAIN is not None:
        load_func = LOAD_LUT.get(net_name)
        if load_func is None:
            raise ValueError("no pretrain loader registered for {}".format(net_name))
        load_func(net_instance, cfg.PRETRAIN_PATH, cfg.PRETRAIN)
        if logger is not None:
            logger.info("load {} pretrain weight success".format(net_name))
    return net_instance
Example #19
Source File: solver.py    From End-to-end-ASR-Pytorch with MIT License 5 votes vote down vote up
def load_ckpt(self):
    """Load the checkpoint given by ``--load`` (``self.paras.load``), if any.

    Always restores ``self.model`` and, when present, ``self.emb_decoder``.
    The last float-valued checkpoint entry is reported as the recorded metric.
    In 'train' mode the tensors are mapped to ``self.device`` and ``self.step``
    and the optimizer are restored. In any other mode the tensors are mapped
    to the CPU and the modules are switched to eval mode.
    """
    ckpt_path = self.paras.load
    if not ckpt_path:
        return

    location = self.device if self.mode == 'train' else 'cpu'
    ckpt = torch.load(ckpt_path, map_location=location)
    self.model.load_state_dict(ckpt['model'])
    has_emb = self.emb_decoder is not None
    if has_emb:
        self.emb_decoder.load_state_dict(ckpt['emb_decoder'])
    # if self.amp:
    #    amp.load_state_dict(ckpt['amp'])

    # Task-dependent metric: the last entry whose value is exactly a float.
    float_entries = [(k, v) for k, v in ckpt.items() if type(v) is float]
    metric, score = float_entries[-1] if float_entries else ("None", 0.0)

    if self.mode != 'train':
        self.model.eval()
        if has_emb:
            self.emb_decoder.eval()
        self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(self.paras.load, metric, score))
        return

    self.step = ckpt['global_step']
    self.optimizer.load_opt_state_dict(ckpt['optimizer'])
    self.verbose('Load ckpt from {}, restarting at step {} (recorded {} = {:.2f} %)'.format(
                  self.paras.load, self.step, metric, score))
Example #20
Source File: decode.py    From End-to-end-ASR-Pytorch with MIT License 5 votes vote down vote up
def __init__(self, asr, emb_decoder, beam_size, min_len_ratio, max_len_ratio,
             lm_path='', lm_config='', lm_weight=0.0, ctc_weight=0.0):
    """Configure beam decoding around a trained ASR model.

    Args:
        asr: ASR model; must have attention enabled (``asr.enable_att``).
        emb_decoder: optional embedding decoder, used when not None.
        beam_size: beam width.
        min_len_ratio, max_len_ratio: stored output-length ratio limits.
        lm_path, lm_config: RNNLM checkpoint and YAML config, used only when
            ``lm_weight > 0``.
        lm_weight: weight of the language model score (0 disables the LM).
        ctc_weight: weight of the CTC score (0 disables CTC rescoring).
    """
    super().__init__()
    # Setup
    self.beam_size = beam_size
    self.min_len_ratio = min_len_ratio
    self.max_len_ratio = max_len_ratio
    self.asr = asr

    # ToDo : implement pure ctc decode
    assert self.asr.enable_att

    # Additional decoding modules
    self.apply_ctc = ctc_weight > 0
    if self.apply_ctc:
        assert self.asr.ctc_weight > 0, 'ASR was not trained with CTC decoder'
        self.ctc_w = ctc_weight
        self.ctc_beam_size = int(CTC_BEAM_RATIO * self.beam_size)

    self.apply_lm = lm_weight > 0
    if self.apply_lm:
        self.lm_w = lm_weight
        self.lm_path = lm_path
        # `with` closes the config file; the bare open() call leaked the handle.
        # NOTE: FullLoader is not safe_load — only use trusted config files.
        with open(lm_config, 'r') as config_file:
            lm_config = yaml.load(config_file, Loader=yaml.FullLoader)
        self.lm = RNNLM(self.asr.vocab_size, **lm_config['model'])
        self.lm.load_state_dict(torch.load(
            self.lm_path, map_location='cpu')['model'])
        self.lm.eval()

    self.apply_emb = emb_decoder is not None
    if self.apply_emb:
        self.emb_decoder = emb_decoder
Example #21
Source File: solver.py    From End-to-end-ASR-Pytorch with MIT License 5 votes vote down vote up
def load_data(self):
    """Load every dataset the solver needs; subclasses must override this.

    Main calls it once. Afterwards the data-related attributes (e.g.
    ``self.tr_set``, ``self.dev_set``) are expected to be set. Returns nothing.
    """
    raise NotImplementedError
Example #22
Source File: tools.py    From transferlearning with MIT License 5 votes vote down vote up
def print_model_parm_nums(model):
    """Print the total number of elements in ``model.parameters()``, in millions."""
    # Generator expression: no intermediate list is needed for the sum.
    total = sum(param.nelement() for param in model.parameters())
    print('  + Number of params: %.2fM' % (total / 1e6))
Example #23
Source File: utils.py    From Pytorch-Networks with MIT License 5 votes vote down vote up
def load_checkpoints(model, opt, save_path, logger, lrs=None):
    """Resume training state from ``save_path.EXPS + NAME + MODEL`` if possible.

    Restores the model, the optimizer and, when ``lrs`` is given, the LR
    scheduler. Resuming is best-effort: any ordinary exception (missing file,
    missing key, shape mismatch) logs "no checkpoints" plus the reason and
    returns 0.

    Returns:
        The stored epoch on success, otherwise 0.
    """
    try:
        states = torch.load(save_path.EXPS + save_path.NAME + save_path.MODEL)
        model.load_state_dict(states['model_state'])
        opt.load_state_dict(states['opt_state'])
        current_epoch = states['epoch']
        if lrs is not None:
            lrs.load_state_dict(states['lrs'])
    except Exception as err:
        # `except Exception` rather than a bare except, so KeyboardInterrupt /
        # SystemExit still propagate; the reason is logged instead of vanishing.
        logger.info("no checkpoints")
        logger.info("({}: {})".format(type(err).__name__, err))
        return 0
    logger.info('loading checkpoints success')
    return current_epoch
Example #24
Source File: network_factory.py    From Pytorch-Networks with MIT License 5 votes vote down vote up
def load_regnet_weight(model,pretrain_path,sub_name):
    """Load a pycls RegNet checkpoint into ``model``, skipping the classifier.

    The file is ``pretrain_path + WEIGHT_LUT[sub_name]``. Every entry of its
    ``'model_state'`` except ``head.fc.weight`` / ``head.fc.bias`` is loaded.
    """
    from collections import OrderedDict
    import torch
    head_keys = ('head.fc.weight', 'head.fc.bias')
    checkpoint = torch.load(pretrain_path + WEIGHT_LUT[sub_name])
    backbone_state = OrderedDict(
        (name, tensor)
        for name, tensor in checkpoint['model_state'].items()
        if name not in head_keys
    )
    # strict=False so the omitted head keys (and any other key mismatch) do
    # not raise.
    model.load_state_dict(backbone_state, strict=False)
Example #25
Source File: utils.py    From pruning_yolov3 with GNU General Public License v3.0 5 votes vote down vote up
def create_backbone(f='weights/last.pt'):  # from utils.utils import *; create_backbone()
    """Create ``weights/backbone.pt`` from the checkpoint ``f``.

    Clears ``optimizer`` and ``training_results``, sets ``epoch`` to -1 and
    marks every entry of ``x['model']`` as requiring grad where that is
    possible.
    """
    x = torch.load(f)
    x['optimizer'] = None
    x['training_results'] = None
    x['epoch'] = -1
    for p in x['model'].values():
        try:
            p.requires_grad = True
        except (RuntimeError, AttributeError):
            # Integer tensors (e.g. BatchNorm counters) cannot require grad
            # (RuntimeError); non-tensor entries have no such attribute.
            # Only these expected cases are skipped.
            pass
    torch.save(x, 'weights/backbone.pt')
Example #26
Source File: utils.py    From pruning_yolov3 with GNU General Public License v3.0 5 votes vote down vote up
def strip_optimizer(f='weights/last.pt'):  # from utils.utils import *; strip_optimizer()
    """Rewrite checkpoint ``f`` in place with its ``'optimizer'`` entry set to None.

    Dropping the optimizer state gives a much lighter file (about 2/3
    smaller, per the original author).
    """
    checkpoint = torch.load(f)
    checkpoint.update(optimizer=None)
    torch.save(checkpoint, f)
Example #27
Source File: model.py    From graph-neural-networks with GNU General Public License v3.0 5 votes vote down vote up
def load(self, label = '', **kwargs):
    """Load architecture and optimizer state dicts from checkpoint files.

    Pass ``loadFiles=(archit_file, optim_file)`` to choose the files
    explicitly. Otherwise they default to
    ``<saveDir>/savedModels/<name>Archit<label>.ckpt`` and
    ``<saveDir>/savedModels/<name>Optim<label>.ckpt``.
    """
    if 'loadFiles' not in kwargs:
        model_dir = os.path.join(self.saveDir, 'savedModels')
        architLoadFile, optimLoadFile = (
            os.path.join(model_dir, self.name + part + label + '.ckpt')
            for part in ('Archit', 'Optim')
        )
    else:
        architLoadFile, optimLoadFile = kwargs['loadFiles']
    self.archit.load_state_dict(torch.load(architLoadFile))
    self.optim.load_state_dict(torch.load(optimLoadFile))
Example #28
Source File: pretrain.py    From OpenNRE with MIT License 5 votes vote down vote up
def get_model(model_name, root_path=default_root_path):
    """Build a pretrained OpenNRE model by name.

    Supported names are 'wiki80_cnn_softmax' and 'wiki80_bert_softmax'. The
    required resources are downloaded under ``root_path`` and the weights are
    loaded on the CPU from ``pretrain/nre/<model_name>.pth.tar``.

    Raises:
        NotImplementedError: for any other ``model_name``.
    """
    def _load_json(rel_path):
        # Read a JSON file under root_path and close the handle promptly.
        with open(os.path.join(root_path, rel_path)) as fp:
            return json.load(fp)

    check_root()
    ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar')
    if model_name == 'wiki80_cnn_softmax':
        download_pretrain(model_name, root_path=root_path)
        download('glove', root_path=root_path)
        download('wiki80', root_path=root_path)
        word2id = _load_json('pretrain/glove/glove.6B.50d_word2id.json')
        word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy'))
        rel2id = _load_json('benchmark/wiki80/wiki80_rel2id.json')
        sentence_encoder = encoder.CNNEncoder(token2id=word2id,
                                              max_length=40,
                                              word_size=50,
                                              position_size=5,
                                              hidden_size=230,
                                              blank_padding=True,
                                              kernel_size=3,
                                              padding_size=1,
                                              word2vec=word2vec,
                                              dropout=0.5)
        m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
        m.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict'])
        return m
    elif model_name == 'wiki80_bert_softmax':
        download_pretrain(model_name, root_path=root_path)
        download('bert_base_uncased', root_path=root_path)
        download('wiki80', root_path=root_path)
        rel2id = _load_json('benchmark/wiki80/wiki80_rel2id.json')
        sentence_encoder = encoder.BERTEncoder(
            max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased'))
        m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
        m.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict'])
        return m
    else:
        raise NotImplementedError
Example #29
Source File: modeling.py    From cmrc2019 with Creative Commons Attribution Share Alike 4.0 International 5 votes vote down vote up
def __init__(self, config):
        """Create the BERT embedding layers.

        Builds word, position and token-type embedding tables (sizes taken
        from ``config.vocab_size``, ``config.max_position_embeddings`` and
        ``config.type_vocab_size``, each of width ``config.hidden_size``),
        followed by a LayerNorm (eps=1e-12) and dropout with
        ``config.hidden_dropout_prob``.
        """
        super(BertEmbeddings, self).__init__()
        # Submodules register in assignment order, which fixes the
        # state_dict / parameters() order, so keep this sequence as-is.
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
Example #30
Source File: dog_multi_image_dataset.py    From dogTorch with MIT License 5 votes vote down vote up
def _read_labels(json_file, imus, sequence_length):
    """Parse the dataset JSON into frame records and sequence start indices.

    Args:
        json_file: path of the dataset metadata JSON.
        imus: IMU keys; each per-frame tensor holds one value per key, in
            this order.
        sequence_length: number of consecutive frames in one data point.

    Returns:
        frames: one dict per frame across all clips, in file order.
        idx_to_fid: for every window of ``sequence_length`` consecutive
            frames that fits inside a single clip, the index in ``frames``
            of the window's first frame.
        centroids: the 'absolute_centroids' and 'difference_centroids'
            arrays as ``torch.Tensor``.
    """
    with open(json_file, 'r') as fp:
        dataset_meta = json.load(fp)

    centroids = {
        name: torch.Tensor(dataset_meta[name])
        for name in ('absolute_centroids', 'difference_centroids')
    }

    def per_imu(values, tensor_type):
        # Gather one value per requested IMU into a tensor of tensor_type.
        return tensor_type([values[imu] for imu in imus])

    frames, idx_to_fid = [], []
    for clip in dataset_meta['clips']:
        clip_frames = [
            {
                'cur_frame': meta['filename'],
                'prev_frame': meta['prev-frame'],
                'labels': per_imu(meta['imu-diff-clusters'], torch.LongTensor),
                'diffs': per_imu(meta['imu-diff-values'], torch.FloatTensor),
                'absolute_cur_imus': per_imu(meta['absolute_cur_imus'], torch.FloatTensor),
                'absolute_prev_imus': per_imu(meta['absolute_prev_imus'], torch.FloatTensor),
            }
            for meta in clip['frames']
        ]
        offset = len(frames)
        # Windows must not cross clip boundaries.
        idx_to_fid.extend(range(offset, offset + len(clip_frames) - sequence_length + 1))
        frames.extend(clip_frames)
    return frames, idx_to_fid, centroids