Python torch.load() Examples
The following are 30 code examples of torch.load().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.

Example #1
Source File: model_architecture.py From models with MIT License | 8 votes |
def get_model(load_weights=True, weights_path='model_files/deepsea_cpu.pth'):
    """Build the DeepSEA network, optionally loading pretrained weights.

    The network applies three (1, 8) convolutions to a 4-channel input,
    followed by two linear layers and a sigmoid that produce 919 outputs.

    Args:
        load_weights: If True, load the state dict stored at ``weights_path``.
        weights_path: Path of the saved ``deepsea_cpu`` state dict.

    Returns:
        ``nn.Sequential(ReCodeAlphabet(), deepsea_cpu)``.
    """
    deepsea_cpu = nn.Sequential(  # Sequential,
        nn.Conv2d(4, 320, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),  # values <= 0 are replaced by 1e-6
        nn.MaxPool2d((1, 4), (1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(320, 480, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4), (1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(480, 960, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.Dropout(0.5),
        Lambda(lambda x: x.view(x.size(0), -1)),  # Reshape,
        nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x),
                      nn.Linear(50880, 925)),  # Linear,
        nn.Threshold(0, 1e-06),
        nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x),
                      nn.Linear(925, 919)),  # Linear,
        nn.Sigmoid(),
    )
    if load_weights:
        # map_location='cpu' lets a checkpoint that holds CUDA tensors load on a CPU-only host.
        deepsea_cpu.load_state_dict(torch.load(weights_path, map_location='cpu'))
    return nn.Sequential(ReCodeAlphabet(), deepsea_cpu)
Example #2
Source File: model_architecture.py From models with MIT License | 6 votes |
def get_seqpred_model(load_weights=True, weights_path='model_files/deepsea_cpu.pth'):
    """Build DeepSEA wrapped with ``ConcatenateRC`` before and ``AverageRC`` after.

    The core network is the same as in ``get_model``. The wrapper names
    suggest that predictions are averaged over the sequence and its reverse
    complement (NOTE(review): inferred from names only; confirm against those classes).

    Args:
        load_weights: If True, load the state dict stored at ``weights_path``.
        weights_path: Path of the saved ``deepsea_cpu`` state dict.

    Returns:
        ``nn.Sequential(ReCodeAlphabet(), ConcatenateRC(), deepsea_cpu, AverageRC())``.
    """
    deepsea_cpu = nn.Sequential(  # Sequential,
        nn.Conv2d(4, 320, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),  # values <= 0 are replaced by 1e-6
        nn.MaxPool2d((1, 4), (1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(320, 480, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4), (1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(480, 960, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.Dropout(0.5),
        Lambda(lambda x: x.view(x.size(0), -1)),  # Reshape,
        nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x),
                      nn.Linear(50880, 925)),  # Linear,
        nn.Threshold(0, 1e-06),
        nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x),
                      nn.Linear(925, 919)),  # Linear,
        nn.Sigmoid(),
    )
    if load_weights:
        # map_location='cpu' lets a checkpoint that holds CUDA tensors load on a CPU-only host.
        deepsea_cpu.load_state_dict(torch.load(weights_path, map_location='cpu'))
    return nn.Sequential(ReCodeAlphabet(), ConcatenateRC(), deepsea_cpu, AverageRC())
Example #3
Source File: utils.py From Pytorch-Networks with MIT License | 6 votes |
def load_test_checkpoints(model, save_path, logger, use_best=False):
    """Restore ``model`` weights from a saved checkpoint for testing.

    Args:
        model: Module whose state dict is replaced by ``states['model_state']``.
        save_path: Config object. The checkpoint path is
            ``EXPS + NAME + BESTMODEL`` when ``use_best`` is true, otherwise
            ``EXPS + NAME + MODEL``.
        logger: Logger that receives the success message.
        use_best: Load the best-model checkpoint instead of the regular one.
    """
    file_name = save_path.BESTMODEL if use_best else save_path.MODEL
    ckpt_file = save_path.EXPS + save_path.NAME + file_name
    if use_best:
        print(ckpt_file)
    # Without CUDA, remap any GPU-saved tensors onto the CPU; None keeps torch.load's default.
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    states = torch.load(ckpt_file, map_location=map_location)
    model.load_state_dict(states['model_state'])
    logger.info('loading checkpoints success')
Example #4
Source File: regnet2mmdet.py From mmdetection with Apache License 2.0 | 6 votes |
def convert(src, dst):
    """Convert keys in pycls pretrained RegNet models to mmdet style.

    Args:
        src: Path of the pycls checkpoint; weights are read from its
            ``'model_state'`` entry.
        dst: Path where ``{'state_dict': converted_state_dict}`` is saved.
    """
    # Load the pycls checkpoint onto the CPU so GPU-saved weights convert on any host.
    regnet_model = torch.load(src, map_location='cpu')
    blobs = regnet_model['model_state']
    # Rename to mmdet style. The helpers are expected to add every key they
    # handle to converted_names (see the check below).
    state_dict = OrderedDict()
    converted_names = set()
    for key, weight in blobs.items():
        if 'stem' in key:
            convert_stem(key, weight, state_dict, converted_names)
        elif 'head' in key:
            convert_head(key, weight, state_dict, converted_names)
        elif key.startswith('s'):
            convert_reslayer(key, weight, state_dict, converted_names)
    # Report any key that no helper claimed.
    for key in blobs:
        if key not in converted_names:
            print(f'not converted: {key}')
    # Save the converted checkpoint.
    checkpoint = {'state_dict': state_dict}
    torch.save(checkpoint, dst)
Example #5
Source File: main.py From transferlearning with MIT License | 6 votes |
def extract_feature(model, dataloader, save_path, load_from_disk=True, model_path=''):
    """Extract features for every sample, save them with labels, and print accuracy.

    Each row written to ``save_path`` holds the model's feature vector followed
    by the sample's label (as a float), comma-separated with 6 decimals.

    Args:
        model: Network used when ``load_from_disk`` is False.
        dataloader: Iterable of ``(inputs, labels)`` batches; its ``dataset``
            length is the accuracy denominator.
        save_path: Output text file passed to ``np.savetxt``.
        load_from_disk: If True, build a fresh ``models.Network`` from ``args``
            and load ``model_path`` into it instead of using ``model``.
        model_path: State-dict file read when ``load_from_disk`` is True.
    """
    if load_from_disk:
        model = models.Network(base_net=args.model_name, n_class=args.num_class)
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model = model.to(DEVICE)
    model.eval()
    correct = 0
    rows = []  # one (batch, feat_dim + 1) tensor per batch; concatenated once at the end
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            feas = model.get_features(inputs)
            label_col = labels.view(labels.size(0), 1).float()
            rows.append(torch.cat((feas, label_col), dim=1))
            outputs = model(inputs)
            preds = torch.max(outputs, 1)[1]
            # Compare with the original labels, not the (N, 1) column: (N,) == (N, 1)
            # broadcasts to an (N, N) matrix and miscounts correct predictions.
            # Assumes labels arrive 1-D (N,) — TODO confirm against the dataloader.
            correct += torch.sum(preds == labels.long())
    test_acc = correct.double() / len(dataloader.dataset)
    fea_numpy = torch.cat(rows, dim=0).cpu().numpy()
    np.savetxt(save_path, fea_numpy, fmt='%.6f', delimiter=',')
    print('Test acc: %f' % test_acc)
    # You may want to classify with 1nn after getting features
Example #6
Source File: train_val.py From Collaborative-Learning-for-Weakly-Supervised-Object-Detection with MIT License | 6 votes |
def from_snapshot(self, sfile, nfile):
    """Restore network weights plus data-layer and NumPy RNG state.

    ``sfile`` holds the network state dict. ``nfile`` holds six objects pickled
    in order: the NumPy RNG state, train ``_cur``, train ``_perm``, val
    ``_cur``, val ``_perm``, and the iteration of the last snapshot.

    Returns:
        The last-snapshot iteration read from ``nfile``.
    """
    print('Restoring model snapshots from {:s}'.format(sfile))
    weights = torch.load(str(sfile))
    self.net.load_state_dict(weights)
    print('Restored.')
    # The other training hyper-parameters/states must be restored by hand
    # (TODO xinlei). The random states are recovered as exactly as possible,
    # but the Tensorflow state is currently not available.
    # NOTE: pickle.load executes arbitrary code; only open trusted snapshot files.
    with open(nfile, 'rb') as handle:
        rng_state, cur, perm, cur_val, perm_val, last_iter = (
            pickle.load(handle) for _ in range(6))

    np.random.set_state(rng_state)
    layer_states = ((self.data_layer, cur, perm),
                    (self.data_layer_val, cur_val, perm_val))
    for layer, layer_cur, layer_perm in layer_states:
        layer._cur = layer_cur
        layer._perm = layer_perm
    return last_iter
Example #7
Source File: train_lm.py From End-to-end-ASR-Pytorch with MIT License | 6 votes |
def set_model(self):
    ''' Setup the RNN language model, its loss and its optimizer '''
    # Model
    lm = RNNLM(self.vocab_size, **self.config['model'])
    self.model = lm.to(self.device)
    self.verbose(self.model.create_msg())

    # Losses: targets equal to token id 0 are excluded from the loss.
    self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)

    # Optimizer
    self.optimizer = Optimizer(self.model.parameters(), **self.config['hparas'])

    # Enable AMP if needed
    self.enable_apex()

    # Resume from a pre-trained checkpoint only when --load was given.
    if not self.paras.load:
        return
    self.load_ckpt()
    ckpt = torch.load(self.paras.load, map_location=self.device)
    self.model.load_state_dict(ckpt['model'])
    self.optimizer.load_opt_state_dict(ckpt['optimizer'])
    self.step = ckpt['global_step']
    self.verbose('Load ckpt from {}, restarting at step {}'.format(self.paras.load, self.step))
Example #8
Source File: test.py From pytorch_NER_BiLSTM_CNN_CRF with Apache License 2.0 | 6 votes |
def load_test_model(model, config):
    """
    Load trained weights into ``model`` for evaluation.

    :param model: initial model
    :param config: config; ``config.t_model`` is an explicit weight path. When it
        is None, ``<config.save_best_model_dir>/<config.model_name>.pt`` is used.
    :return: the same model, with the loaded state dict
    """
    user_path = config.t_model
    if user_path is not None:
        test_model_path = user_path
        print("load user model from {}".format(test_model_path))
    else:
        default_name = "{}.pt".format(config.model_name)
        test_model_path = os.path.join(config.save_best_model_dir, default_name)
        print("load default model from {}".format(test_model_path))
    model.load_state_dict(torch.load(test_model_path))
    return model
Example #9
Source File: saliency_visualization.py From VSE-C with MIT License | 6 votes |
def load_checkpoint(self, checkpoint):
    """Rebuild the VSE model from a checkpoint and freeze its encoders.

    The checkpoint's ``'opt'`` entry supplies the model options. The
    vocabulary is read from ``<opt.vocab_path>/<opt.data_name>_vocab.pkl``.
    Afterwards ``self.model`` holds the VSE model with its image and text
    encoders in eval mode and gradients disabled, and ``self.projector``
    holds the vocabulary.
    """
    ckpt = torch.load(checkpoint)
    opt = ckpt['opt']
    opt.use_external_captions = False
    vocab_file = pjoin(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
    vocab = Vocab.from_pickle(vocab_file)
    opt.vocab_size = len(vocab)

    from model import VSE
    self.model = VSE(opt)
    self.model.load_state_dict(ckpt['model'])
    self.projector = vocab

    encoders = (self.model.img_enc, self.model.txt_enc)
    for enc in encoders:
        enc.eval()
    for enc in encoders:
        for param in enc.parameters():
            param.requires_grad = False
Example #10
Source File: nyu_walkable_surface_dataset.py From dogTorch with MIT License | 6 votes |
def __getitem__(self, idx):
    """Return the sample at position ``idx`` of ``self.data_set_list``.

    Returns:
        ``(image, segment, 0, 0, ['images' + fid])``. The two zeros are
        placeholder values.

    NOTE(review): only the ``else`` branch can succeed as written. When
    ``self.read_features`` is true, the stacked features are bound to
    ``input`` but never returned, and ``image``/``segment`` are not assigned,
    so the final ``return`` raises NameError. ``fid + i`` also fails if
    ``fid`` is a str, as ``'images' + fid`` implies. Confirm what the feature
    branch should return.
    """
    fid = self.data_set_list[idx]
    if self.read_features:
        # Collect one saved feature tensor per frame of the sequence.
        features = []
        for i in range(self.sequence_length):
            feature_path = os.path.join(
                self.features_dir,
                self.frames_metadata[fid + i]['cur_frame'] + '.pytar')
            features.append(torch.load(feature_path))
        input = torch.stack(features)  # shadows the builtin; unused below
    else:
        image = self.load_and_resize(
            os.path.join(self.root_dir, 'images', fid))
        segment = self.load_and_resize_segmentation(
            os.path.join(self.root_dir, 'walkable', fid))
    # The two 0s are just placeholders. They can be replaced by any values.
    return (image, segment, 0, 0, ['images' + fid])
Example #11
Source File: dcgan.py From Pytorch-Project-Template with MIT License | 6 votes |
def load_checkpoint(self, file_name):
    """
    Restore the DCGAN training state from a checkpoint.

    :param file_name: checkpoint name, appended directly to ``config.checkpoint_dir``
    :return: None. If loading raises OSError (e.g. the file is missing), the
        failure is logged as a first training run and no state is changed.
    """
    ckpt_path = self.config.checkpoint_dir + file_name
    try:
        self.logger.info("Loading checkpoint '{}'".format(ckpt_path))
        ckpt = torch.load(ckpt_path)

        self.current_epoch = ckpt['epoch']
        self.current_iteration = ckpt['iteration']
        # Restore generator/discriminator networks and their optimizers, in this order.
        restore_plan = (('netG', 'G_state_dict'), ('optimG', 'G_optimizer'),
                        ('netD', 'D_state_dict'), ('optimD', 'D_optimizer'))
        for attr_name, ckpt_key in restore_plan:
            getattr(self, attr_name).load_state_dict(ckpt[ckpt_key])
        self.fixed_noise = ckpt['fixed_noise']
        self.manual_seed = ckpt['manual_seed']

        done_msg = "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
        self.logger.info(done_msg.format(self.config.checkpoint_dir, ckpt['epoch'], ckpt['iteration']))
    except OSError:
        self.logger.info("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir))
        self.logger.info("**First time to train**")
Example #12
Source File: erfnet.py From Pytorch-Project-Template with MIT License | 6 votes |
def load_checkpoint(self, filename):
    """
    Restore model/optimizer state and training counters from a checkpoint.

    :param filename: checkpoint name, appended directly to ``config.checkpoint_dir``
    :return: None. An OSError while loading (e.g. a missing file) is logged
        as a first training run instead of being raised.
    """
    ckpt_path = self.config.checkpoint_dir + filename
    try:
        self.logger.info("Loading checkpoint '{}'".format(ckpt_path))
        state = torch.load(ckpt_path)
        self.current_epoch = state['epoch']
        self.current_iteration = state['iteration']
        self.model.load_state_dict(state['state_dict'])
        self.optimizer.load_state_dict(state['optimizer'])
        done_msg = "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
        self.logger.info(done_msg.format(self.config.checkpoint_dir, state['epoch'], state['iteration']))
    except OSError:
        ckpt_dir = self.config.checkpoint_dir
        self.logger.info("No checkpoint exists from '{}'. Skipping...".format(ckpt_dir))
        self.logger.info("**First time to train**")
Example #13
Source File: condensenet.py From Pytorch-Project-Template with MIT License | 6 votes |
def load_checkpoint(self, filename):
    """
    Resume CondenseNet training from ``config.checkpoint_dir + filename``.

    Restores ``current_epoch``, ``current_iteration``, the model weights and
    the optimizer state. If loading raises OSError, this only logs that
    training starts from scratch.

    :param filename: checkpoint file name (concatenated, not path-joined)
    :return: None
    """
    full_path = self.config.checkpoint_dir + filename
    try:
        self.logger.info("Loading checkpoint '{}'".format(full_path))
        saved = torch.load(full_path)
        epoch, iteration = saved['epoch'], saved['iteration']
        self.current_epoch = epoch
        self.current_iteration = iteration
        for target, key in ((self.model, 'state_dict'), (self.optimizer, 'optimizer')):
            target.load_state_dict(saved[key])
        self.logger.info(
            "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n".format(
                self.config.checkpoint_dir, epoch, iteration))
    except OSError:
        self.logger.info("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir))
        self.logger.info("**First time to train**")
Example #14
Source File: dqn.py From Pytorch-Project-Template with MIT License | 6 votes |
def load_checkpoint(self, file_name):
    """
    Resume DQN training state from a checkpoint.

    :param file_name: checkpoint name, appended directly to ``config.checkpoint_dir``
    :return: None. If loading raises OSError (e.g. the file does not exist),
        this logs that training starts from scratch and changes no state.
    """
    filename = self.config.checkpoint_dir + file_name
    try:
        self.logger.info("Loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)

        self.current_episode = checkpoint['episode']
        self.current_iteration = checkpoint['iteration']
        self.policy_model.load_state_dict(checkpoint['state_dict'])
        self.optim.load_state_dict(checkpoint['optimizer'])

        # The counter reported here is checkpoint['episode'], so label it as an episode.
        self.logger.info("Checkpoint loaded successfully from '{}' at (episode {}) at (iteration {})\n"
                         .format(self.config.checkpoint_dir, checkpoint['episode'], checkpoint['iteration']))
    except OSError:
        self.logger.info("No checkpoint exists from '{}'. Skipping...".format(self.config.checkpoint_dir))
        self.logger.info("**First time to train**")
Example #15
Source File: utils.py From pruning_yolov3 with GNU General Public License v3.0 | 6 votes |
def print_mutation(hyp, results, bucket=''):
    """Print one evolution result and merge it into evolve.txt.

    Appends the row ``results + hyperparameter values`` to ``evolve.txt``
    (for use with train.py --evolve), de-duplicates the file's rows, and
    rewrites them sorted by descending ``fitness``. When ``bucket`` is set,
    the file is downloaded from and uploaded back to ``gs://<bucket>``.

    ``results`` must be a tuple (e.g. P, R, mAP, F1, test_loss), because it is
    used directly as a %-format argument tuple.
    """
    n_hyp = len(hyp)
    header = ('%10s' * n_hyp) % tuple(hyp.keys())         # hyperparam keys
    hyp_vals = ('%10.3g' * n_hyp) % tuple(hyp.values())   # hyperparam values
    res_vals = ('%10.3g' * len(results)) % results        # results
    print('\n%s\n%s\nEvolved fitness: %s\n' % (header, hyp_vals, res_vals))

    # NOTE(review): bucket is interpolated into a shell command; only pass trusted names.
    if bucket:
        os.system('gsutil cp gs://%s/evolve.txt .' % bucket)  # download evolve.txt

    with open('evolve.txt', 'a') as log:  # append result
        log.write(res_vals + hyp_vals + '\n')
    rows = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    best_first = np.argsort(-fitness(rows))
    np.savetxt('evolve.txt', rows[best_first], '%10.3g')  # save sort by fitness

    if bucket:
        os.system('gsutil cp evolve.txt gs://%s' % bucket)  # upload evolve.txt
Example #16
Source File: models.py From cvpr2018-hnd with MIT License | 6 votes |
def init_truncated_normal(model, aux_str=''):
    """Initialise ``model`` from a cached weight file, creating the file if missing.

    The cache path is ``{path}/{in_features}_{out_features}{aux_str}.pth``,
    where ``path`` is a module-level setting. If that file exists, its state
    dict is loaded. Otherwise ``truncated_normal`` initialises the weights
    (once per submodule for an ``nn.ModuleList``) and the result is saved there.

    Returns:
        ``model``, or None when ``model`` is None.
    """
    if model is None:
        return None
    # NOTE(review): a plain nn.ModuleList has no in_features/out_features, so this
    # format call would raise before the ModuleList branch below is reached;
    # presumably callers pass a subclass that defines them — confirm.
    init_path = '{path}/{in_dim:d}_{out_dim:d}{aux_str}.pth'.format(
        path=path, in_dim=model.in_features, out_dim=model.out_features, aux_str=aux_str)
    if os.path.isfile(init_path):
        model.load_state_dict(torch.load(init_path))
        print('load init weight: {init_path}'.format(init_path=init_path))
    else:
        if isinstance(model, nn.ModuleList):
            # Plain loop: the calls are made only for their side effects.
            for sub in model:
                truncated_normal(sub)
        else:
            truncated_normal(model)
        print('generate init weight: {init_path}'.format(init_path=init_path))
        torch.save(model.state_dict(), init_path)
        print('save init weight: {init_path}'.format(init_path=init_path))
    return model
Example #17
Source File: utils.py From cvpr2018-hnd with MIT License | 6 votes |
def load_model(model, optimizer, scheduler, path, num_epochs, start_time=None):
    """Resume from the newest saved epoch no later than ``num_epochs``.

    Looks for ``{path}_model_{epoch}.pth``, counting down from ``num_epochs``.
    If one is found, it loads the model and, when provided, the matching
    ``_optimizer_`` and ``_scheduler_`` files. For the scheduler only
    ``best``, ``cooldown_counter``, ``num_bad_epochs`` and ``last_epoch`` are
    restored.

    Args:
        start_time: Reference time for the elapsed-time printout. Defaults to
            the time of this call. The old ``time.time()`` default was
            evaluated once at import, so it measured time since import.

    Returns:
        The loaded epoch, or 0 if no checkpoint was found.
    """
    if start_time is None:
        start_time = time.time()

    def _ckpt_file(kind, epoch):
        # e.g. '<path>_model_3.pth', '<path>_optimizer_3.pth'
        return '{path}_{kind}_{epoch:d}.pth'.format(path=path, kind=kind, epoch=epoch)

    epoch = num_epochs
    while epoch > 0 and not os.path.isfile(_ckpt_file('model', epoch)):
        epoch -= 1
    if epoch > 0:
        model_path = _ckpt_file('model', epoch)
        model.load_state_dict(torch.load(model_path))
        if optimizer is not None:
            optimizer.load_state_dict(torch.load(_ckpt_file('optimizer', epoch)))
        if scheduler is not None:
            scheduler_state_dict = torch.load(_ckpt_file('scheduler', epoch))
            for attr in ('best', 'cooldown_counter', 'num_bad_epochs', 'last_epoch'):
                setattr(scheduler, attr, scheduler_state_dict[attr])
        print('{epoch:4d}/{num_epochs:4d} e; '.format(epoch=epoch, num_epochs=num_epochs), end='')
        print('load {path}; '.format(path=model_path), end='')
        print('{time:8.3f} s'.format(time=time.time() - start_time))
    return epoch
Example #18
Source File: network_factory.py From Pytorch-Networks with MIT License | 5 votes |
def get_network(net_name, logger=None, cfg=None):
    """Instantiate the network registered as ``net_name`` in ``NET_LUT``.

    When ``cfg.PRETRAIN`` is not None, the loader registered in ``LOAD_LUT``
    is called as ``load_func(net, cfg.PRETRAIN_PATH, cfg.PRETRAIN)``.

    Raises:
        KeyError: if ``net_name`` is not registered in ``NET_LUT``.
    """
    # NET_LUT.get returns None for an unknown name instead of raising (presumably
    # NET_LUT is a dict), so the old try/except never fired and the failure
    # surfaced later as "'NoneType' object is not callable".
    net_class = NET_LUT.get(net_name)
    if net_class is None:
        message = "network type error, {} not exist".format(net_name)
        if logger is not None:
            logger.error(message)
        raise KeyError(message)
    net_instance = net_class(cfg=cfg, logger=logger)
    if cfg is not None and cfg.PRETRAIN is not None:
        load_func = LOAD_LUT.get(net_name)
        load_func(net_instance, cfg.PRETRAIN_PATH, cfg.PRETRAIN)
        if logger is not None:
            logger.info("load {} pretrain weight success".format(net_name))
    return net_instance
Example #19
Source File: solver.py From End-to-end-ASR-Pytorch with MIT License | 5 votes |
def load_ckpt(self):
    ''' Load ckpt if --load option is specified.

    In 'train' mode, tensors are mapped to self.device and the step counter
    and optimizer state are restored. In any other mode, tensors are mapped
    to the CPU and the model (and embedding decoder, if any) is switched to
    eval mode. The last float-valued checkpoint entry is reported as the
    recorded metric.
    '''
    ckpt_path = self.paras.load
    if not ckpt_path:
        return

    # Load weights
    location = self.device if self.mode == 'train' else 'cpu'
    ckpt = torch.load(ckpt_path, map_location=location)
    self.model.load_state_dict(ckpt['model'])
    if self.emb_decoder is not None:
        self.emb_decoder.load_state_dict(ckpt['emb_decoder'])

    # Task-dependent items: a later float entry overrides an earlier one.
    metric, score = "None", 0.0
    for name, value in ckpt.items():
        if type(value) is float:
            metric, score = name, value

    if self.mode != 'train':
        self.model.eval()
        if self.emb_decoder is not None:
            self.emb_decoder.eval()
        self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(ckpt_path, metric, score))
        return

    self.step = ckpt['global_step']
    self.optimizer.load_opt_state_dict(ckpt['optimizer'])
    self.verbose('Load ckpt from {}, restarting at step {} (recorded {} = {:.2f} %)'.format(
        ckpt_path, self.step, metric, score))
Example #20
Source File: decode.py From End-to-end-ASR-Pytorch with MIT License | 5 votes |
def __init__(self, asr, emb_decoder, beam_size, min_len_ratio, max_len_ratio,
             lm_path='', lm_config='', lm_weight=0.0, ctc_weight=0.0):
    """Configure beam-search decoding on top of a trained ASR model.

    Args:
        asr: Trained ASR model; attention must be enabled (``asr.enable_att``).
        emb_decoder: Optional embedding decoder; used only when not None.
        beam_size: Beam width.
        min_len_ratio, max_len_ratio: Stored as given (names suggest output
            length bounds relative to the input — confirm in the decode loop).
        lm_path: Checkpoint whose ``'model'`` entry holds the RNNLM weights.
        lm_config: YAML file whose ``'model'`` section gives the RNNLM kwargs.
        lm_weight: LM weight; the LM is loaded only when > 0.
        ctc_weight: CTC weight; CTC scoring is enabled only when > 0, and then
            the ASR model must have been trained with CTC.
    """
    super().__init__()
    # Setup
    self.beam_size = beam_size
    self.min_len_ratio = min_len_ratio
    self.max_len_ratio = max_len_ratio
    self.asr = asr

    # ToDo : implement pure ctc decode
    assert self.asr.enable_att

    # Additional decoding modules
    self.apply_ctc = ctc_weight > 0
    if self.apply_ctc:
        assert self.asr.ctc_weight > 0, 'ASR was not trained with CTC decoder'
        self.ctc_w = ctc_weight
        self.ctc_beam_size = int(CTC_BEAM_RATIO * self.beam_size)

    self.apply_lm = lm_weight > 0
    if self.apply_lm:
        self.lm_w = lm_weight
        self.lm_path = lm_path
        # Close the config file promptly instead of leaking the handle.
        with open(lm_config, 'r') as config_file:
            lm_config = yaml.load(config_file, Loader=yaml.FullLoader)
        self.lm = RNNLM(self.asr.vocab_size, **lm_config['model'])
        self.lm.load_state_dict(torch.load(self.lm_path, map_location='cpu')['model'])
        self.lm.eval()

    self.apply_emb = emb_decoder is not None
    if self.apply_emb:
        self.emb_decoder = emb_decoder
Example #21
Source File: solver.py From End-to-end-ASR-Pytorch with MIT License | 5 votes |
def load_data(self):
    '''
    Abstract hook called by main to load all data.

    Subclasses must set up the data-related attributes
    (e.g. self.tr_set, self.dev_set) and return nothing.
    '''
    raise NotImplementedError
Example #22
Source File: tools.py From transferlearning with MIT License | 5 votes |
def print_model_parm_nums(model):
    """Print the total number of parameters in ``model``, in millions.

    Args:
        model: Any object exposing ``parameters()``, e.g. a ``torch.nn.Module``.

    Returns:
        The total element count. The return value is new; existing callers
        that ignore it are unaffected.
    """
    total = sum(param.nelement() for param in model.parameters())
    print(' + Number of params: %.2fM' % (total / 1e6))
    return total
Example #23
Source File: utils.py From Pytorch-Networks with MIT License | 5 votes |
def load_checkpoints(model, opt, save_path, logger, lrs=None):
    """Resume model, optimizer and (optionally) LR-scheduler state if possible.

    The checkpoint path is ``save_path.EXPS + save_path.NAME + save_path.MODEL``.
    Loading is best-effort: any failure is logged and training starts at epoch 0.
    Objects restored before the failure keep their loaded state.

    Returns:
        The stored epoch, or 0 when no usable checkpoint was loaded.
    """
    try:
        states = torch.load(save_path.EXPS + save_path.NAME + save_path.MODEL)
        model.load_state_dict(states['model_state'])
        opt.load_state_dict(states['opt_state'])
        current_epoch = states['epoch']
        if lrs is not None:
            lrs.load_state_dict(states['lrs'])
        logger.info('loading checkpoints success')
    except Exception as err:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit still propagate, and record why the checkpoint was skipped.
        current_epoch = 0
        logger.info("no checkpoints")
        logger.debug("checkpoint not loaded: %r", err)
    return current_epoch
Example #24
Source File: network_factory.py From Pytorch-Networks with MIT License | 5 votes |
def load_regnet_weight(model, pretrain_path, sub_name):
    """Load pycls RegNet weights into ``model``, skipping the classifier head.

    The file is ``pretrain_path + WEIGHT_LUT[sub_name]``. Every entry of its
    ``'model_state'`` except ``head.fc.weight``/``head.fc.bias`` is loaded
    with ``strict=False``, so missing or extra keys are tolerated.
    """
    from collections import OrderedDict
    import torch
    checkpoints = torch.load(pretrain_path + WEIGHT_LUT[sub_name])
    head_keys = ('head.fc.weight', 'head.fc.bias')
    states_no_module = OrderedDict(
        (key, value)
        for key, value in checkpoints['model_state'].items()
        if key not in head_keys
    )
    model.load_state_dict(states_no_module, strict=False)
Example #25
Source File: utils.py From pruning_yolov3 with GNU General Public License v3.0 | 5 votes |
def create_backbone(f='weights/last.pt', out='weights/backbone.pt'):
    """Create a backbone checkpoint from a *.pt training checkpoint.

    Usage: ``from utils.utils import *; create_backbone()``.

    Clears ``optimizer`` and ``training_results``, sets ``epoch`` to -1, marks
    every entry of ``x['model']`` as requiring grad where possible, and saves
    the result to ``out``. ``out`` is new and defaults to the previous
    hard-coded path.
    """
    x = torch.load(f)
    x['optimizer'] = None
    x['training_results'] = None
    x['epoch'] = -1
    for p in x['model'].values():
        try:
            p.requires_grad = True
        except (RuntimeError, AttributeError):
            # Integer tensors (e.g. counters) raise RuntimeError, and some
            # non-tensor values reject the attribute; leave those unchanged.
            # Other errors are no longer silently swallowed.
            pass
    torch.save(x, out)
Example #26
Source File: utils.py From pruning_yolov3 with GNU General Public License v3.0 | 5 votes |
def strip_optimizer(f='weights/last.pt'):
    """Drop the optimizer state from checkpoint ``f``, rewriting it in place.

    Usage: ``from utils.utils import *; strip_optimizer()``. Stripping the
    optimizer makes *.pt files lighter (the original note says by about 2/3).
    """
    ckpt = torch.load(f)
    ckpt.update(optimizer=None)
    torch.save(ckpt, f)
Example #27
Source File: model.py From graph-neural-networks with GNU General Public License v3.0 | 5 votes |
def load(self, label = '', **kwargs):
    """Load the architecture and optimizer state dicts.

    If ``loadFiles`` is passed as a keyword, it must be an
    ``(architecture_file, optimizer_file)`` pair. Otherwise the files are
    ``<saveDir>/savedModels/<name>Archit<label>.ckpt`` and
    ``<saveDir>/savedModels/<name>Optim<label>.ckpt``.
    """
    if 'loadFiles' in kwargs:
        architLoadFile, optimLoadFile = kwargs['loadFiles']
    else:
        saveModelDir = os.path.join(self.saveDir, 'savedModels')
        architLoadFile, optimLoadFile = (
            os.path.join(saveModelDir, self.name + part + label + '.ckpt')
            for part in ('Archit', 'Optim'))
    self.archit.load_state_dict(torch.load(architLoadFile))
    self.optim.load_state_dict(torch.load(optimLoadFile))
Example #28
Source File: pretrain.py From OpenNRE with MIT License | 5 votes |
def get_model(model_name, root_path=default_root_path):
    """Build a pretrained OpenNRE relation-extraction model by name.

    Supported names are ``'wiki80_cnn_softmax'`` and ``'wiki80_bert_softmax'``.
    Required resources are downloaded under ``root_path``, then the
    checkpoint ``pretrain/nre/<model_name>.pth.tar`` is loaded on the CPU.

    Raises:
        NotImplementedError: for any other ``model_name``.
    """
    def _load_json(rel_path):
        # Read a JSON file under root_path, closing the handle promptly.
        with open(os.path.join(root_path, rel_path)) as fp:
            return json.load(fp)

    check_root()
    ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar')
    if model_name == 'wiki80_cnn_softmax':
        download_pretrain(model_name, root_path=root_path)
        download('glove', root_path=root_path)
        download('wiki80', root_path=root_path)
        word2id = _load_json('pretrain/glove/glove.6B.50d_word2id.json')
        word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy'))
        rel2id = _load_json('benchmark/wiki80/wiki80_rel2id.json')
        sentence_encoder = encoder.CNNEncoder(
            token2id=word2id, max_length=40, word_size=50, position_size=5,
            hidden_size=230, blank_padding=True, kernel_size=3, padding_size=1,
            word2vec=word2vec, dropout=0.5)
    elif model_name == 'wiki80_bert_softmax':
        download_pretrain(model_name, root_path=root_path)
        download('bert_base_uncased', root_path=root_path)
        download('wiki80', root_path=root_path)
        rel2id = _load_json('benchmark/wiki80/wiki80_rel2id.json')
        sentence_encoder = encoder.BERTEncoder(
            max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased'))
    else:
        raise NotImplementedError
    # Both variants share the same softmax classifier and checkpoint format.
    m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
    m.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict'])
    return m
Example #29
Source File: modeling.py From cmrc2019 with Creative Commons Attribution Share Alike 4.0 International | 5 votes |
def __init__(self, config):
    """Create the word, position and token-type embedding tables, plus LayerNorm and dropout.

    Args:
        config: Must provide ``vocab_size``, ``hidden_size``,
            ``max_position_embeddings``, ``type_vocab_size`` and
            ``hidden_dropout_prob``.
    """
    super(BertEmbeddings, self).__init__()
    # NOTE: attribute names become state_dict keys; keep them unchanged so saved weights still load.
    self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
    self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
    self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

    # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
    # any TensorFlow checkpoint file
    self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)
Example #30
Source File: dog_multi_image_dataset.py From dogTorch with MIT License | 5 votes |
def _read_labels(json_file, imus, sequence_length): """Returns a list of all frames, and a list of where each data point (whose length is sequence_length) in the list of frames.""" with open(json_file, 'r') as fp: dataset_meta = json.load(fp) frames = [] idx_to_fid = [] centroids = { 'absolute_centroids': torch.Tensor(dataset_meta['absolute_centroids']), 'difference_centroids': torch.Tensor(dataset_meta['difference_centroids']), } for clip in dataset_meta['clips']: frame_clips = [{ 'cur_frame': frame_meta['filename'], 'prev_frame': frame_meta['prev-frame'], 'labels': torch.LongTensor( [frame_meta['imu-diff-clusters'][imu] for imu in imus]), 'diffs': torch.FloatTensor( [frame_meta['imu-diff-values'][imu] for imu in imus]), 'absolute_cur_imus': torch.FloatTensor( [frame_meta['absolute_cur_imus'][imu] for imu in imus]), 'absolute_prev_imus': torch.FloatTensor( [frame_meta['absolute_prev_imus'][imu] for imu in imus]), } for frame_meta in clip['frames']] for i in range(len(frame_clips) - sequence_length + 1): idx_to_fid.append(i + len(frames)) frames += frame_clips return frames, idx_to_fid, centroids