Python torch.save() Examples
The following are 30 code examples of torch.save().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.

Example #1
Source File: base_model.py From DDPAE-video-prediction with MIT License | 7 votes |
def save(self, ckpt_path, epoch):
    """Write one checkpoint file per network and one per optimizer.

    Networks wrapped in ``torch.nn.DataParallel`` are unwrapped first, so the
    state dict saved is that of the inner module.

    :param ckpt_path: directory the ``.pth`` files are written into
    :param epoch: epoch number embedded in every file name
    """
    for net_name, wrapped in self.nets.items():
        bare = wrapped.module if isinstance(wrapped, torch.nn.DataParallel) else wrapped
        target = os.path.join(ckpt_path, 'net_{}_{}.pth'.format(net_name, epoch))
        torch.save(bare.state_dict(), target)

    for opt_name, opt in self.optimizers.items():
        target = os.path.join(ckpt_path, 'optimizer_{}_{}.pth'.format(opt_name, epoch))
        torch.save(opt.state_dict(), target)
Example #2
Source File: utils.py From pytorch_NER_BiLSTM_CNN_CRF with Apache License 2.0 | 7 votes |
def save_model_all(model, save_dir, model_name, epoch):
    """Save ``model``'s state dict to ``<save_dir>/<model_name>_epoch_<epoch>.pt``.

    :param model: nn model whose ``state_dict()`` is saved
    :param save_dir: directory to save into; created if missing
    :param model_name: file-name prefix
    :param epoch: epoch number appended to the file name
    :return: None
    """
    os.makedirs(save_dir, exist_ok=True)
    save_prefix = os.path.join(save_dir, model_name)
    save_path = '{}_epoch_{}.pt'.format(save_prefix, epoch)
    print("save all model to {}".format(save_path))
    # The context manager closes the handle even if torch.save raises;
    # the previous explicit open()/close() pair leaked it in that case.
    with open(save_path, mode="wb") as output:
        torch.save(model.state_dict(), output)
Example #3
Source File: regnet2mmdet.py From mmdetection with Apache License 2.0 | 6 votes |
def convert(src, dst):
    """Convert keys in pycls pretrained RegNet models to mmdet style.

    :param src: path of the pycls checkpoint (weights under ``'model_state'``)
    :param dst: path the ``{'state_dict': ...}`` checkpoint is written to
    """
    blobs = torch.load(src)['model_state']

    # Route each key to the converter for its part of the network;
    # keys matching none of the patterns are skipped (and reported below).
    state_dict = OrderedDict()
    converted_names = set()
    for key, weight in blobs.items():
        if 'stem' in key:
            handler = convert_stem
        elif 'head' in key:
            handler = convert_head
        elif key.startswith('s'):
            handler = convert_reslayer
        else:
            continue
        handler(key, weight, state_dict, converted_names)

    # Report any key that no converter claimed.
    for key in (k for k in blobs if k not in converted_names):
        print(f'not converted: {key}')

    torch.save({'state_dict': state_dict}, dst)
Example #4
Source File: utils.py From pytorch_NER_BiLSTM_CNN_CRF with Apache License 2.0 | 6 votes |
def save_best_model(model, save_dir, model_name, best_eval):
    """Save ``model``'s state dict when the current dev score ties or beats the best.

    The file ``<save_dir>/<model_name>.pt`` is (over)written and the
    early-stopping patience counter on ``best_eval`` is reset to 0.

    :param model: nn model whose ``state_dict()`` is saved
    :param save_dir: directory to save into; created if missing
    :param model_name: file name without the ``.pt`` extension
    :param best_eval: tracker exposing ``current_dev_score``, ``best_dev_score``
        and ``early_current_patience``
    :return: None
    """
    if best_eval.current_dev_score >= best_eval.best_dev_score:
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, "{}.pt".format(model_name))
        print("save best model to {}".format(save_path))
        # The context manager guarantees the handle is closed even if
        # torch.save raises, unlike the previous open()/close() pair.
        with open(save_path, mode="wb") as output:
            torch.save(model.state_dict(), output)
        best_eval.early_current_patience = 0
Example #5
Source File: util.py From DeepLab_v3_plus with MIT License | 6 votes |
def save_checkpoint(state, weights_dir=''):
    """Save a training-state dict as ``model-<epoch:04d>.pth.tar``.

    Arguments:
        state {dict} -- checkpoint contents; must contain an ``'epoch'`` entry
            accepted by ``int()``, which is used to build the file name.

    Keyword Arguments:
        weights_dir {str} -- directory to write into, created if missing;
            the default ``''`` writes into the current working directory.
    """
    # os.makedirs('') raises FileNotFoundError, so only create a real directory
    # (previously the default argument crashed here).
    if weights_dir and not os.path.exists(weights_dir):
        os.makedirs(weights_dir, exist_ok=True)
    epoch = state['epoch']
    file_path = os.path.join(weights_dir, 'model-{:04d}.pth.tar'.format(int(epoch)))
    torch.save(state, file_path)


#############################################
# loss function
#############################################
Example #6
Source File: models.py From cvpr2018-hnd with MIT License | 6 votes |
def init_truncated_normal(model, aux_str=''):
    """Load cached initial weights for ``model``, or generate and cache them.

    The cache file is ``{path}/{in_features}_{out_features}{aux_str}.pth``;
    ``path`` is not defined in this function (presumably a module-level global
    elsewhere in the file). If the file exists, its state dict is loaded into
    ``model``. Otherwise ``truncated_normal`` initializes the model (each
    element in turn for an ``nn.ModuleList``), and the result is saved.

    :param model: module exposing ``in_features``/``out_features``, or ``None``
    :param aux_str: extra suffix that distinguishes cache files
    :return: ``model`` (``None`` if ``model`` was ``None``)
    """
    if model is None:
        return None
    # NOTE(review): a stock nn.ModuleList has no in_features/out_features, so the
    # ModuleList branch below presumably relies on a subclass defining them — confirm.
    init_path = '{path}/{in_dim:d}_{out_dim:d}{aux_str}.pth'.format(
        path=path, in_dim=model.in_features, out_dim=model.out_features, aux_str=aux_str)
    if os.path.isfile(init_path):
        model.load_state_dict(torch.load(init_path))
        print('load init weight: {init_path}'.format(init_path=init_path))
    else:
        if isinstance(model, nn.ModuleList):
            # Plain loop: the previous list comprehension built and discarded a list
            # purely for its side effects.
            for sub in model:
                truncated_normal(sub)
        else:
            truncated_normal(model)
        print('generate init weight: {init_path}'.format(init_path=init_path))
        torch.save(model.state_dict(), init_path)
        print('save init weight: {init_path}'.format(init_path=init_path))
    return model
Example #7
Source File: utils.py From pruning_yolov3 with GNU General Public License v3.0 | 6 votes |
def print_mutation(hyp, results, bucket=''):
    """Print hyperparameter-evolution results and append them to evolve.txt.

    For use with ``train.py --evolve``. When ``bucket`` is given, evolve.txt is
    first downloaded from ``gs://<bucket>/evolve.txt`` and finally uploaded back
    with ``gsutil``. After appending, the file is rewritten with duplicate rows
    removed and rows ordered by descending ``fitness``.

    :param hyp: dict of hyperparameter name -> value
    :param results: tuple of result metrics (P, R, mAP, F1, test_loss)
    :param bucket: optional GCS bucket name
    """
    keys_line = '%10s' * len(hyp) % tuple(hyp.keys())
    vals_line = '%10.3g' * len(hyp) % tuple(hyp.values())
    res_line = '%10.3g' * len(results) % results
    print('\n%s\n%s\nEvolved fitness: %s\n' % (keys_line, vals_line, res_line))

    # NOTE(review): bucket is interpolated into a shell command; pass only trusted names.
    if bucket:
        os.system('gsutil cp gs://%s/evolve.txt .' % bucket)  # download evolve.txt

    with open('evolve.txt', 'a') as fh:
        fh.write(res_line + vals_line + '\n')  # results first, then hyperparameters

    rows = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # drop duplicate rows
    order = np.argsort(-fitness(rows))  # best fitness first
    np.savetxt('evolve.txt', rows[order], '%10.3g')

    if bucket:
        os.system('gsutil cp evolve.txt gs://%s' % bucket)  # upload evolve.txt
Example #8
Source File: condensenet.py From Pytorch-Project-Template with MIT License | 6 votes |
def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=0):
    """Save the latest training checkpoint, and optionally mark it as the best one.

    :param filename: file name appended directly to ``self.config.checkpoint_dir``
        (no separator is inserted, so that directory string should end with one)
    :param is_best: if truthy, also copy the file to 'model_best.pth.tar'
    :return: None
    """
    ckpt_dir = self.config.checkpoint_dir
    latest = ckpt_dir + filename
    torch.save({
        'epoch': self.current_epoch,
        'iteration': self.current_iteration,
        'state_dict': self.model.state_dict(),
        'optimizer': self.optimizer.state_dict(),
    }, latest)
    if not is_best:
        return
    shutil.copyfile(latest, ckpt_dir + 'model_best.pth.tar')
Example #9
Source File: erfnet.py From Pytorch-Project-Template with MIT License | 6 votes |
def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=0):
    """Persist epoch/iteration counters, model weights and optimizer state.

    :param filename: file name appended directly to ``self.config.checkpoint_dir``
    :param is_best: if truthy, the saved file is also copied to 'model_best.pth.tar'
    :return: None
    """
    base = self.config.checkpoint_dir
    path = base + filename

    state = {}
    # Stored as current_epoch + 1 (presumably the epoch to resume from — confirm).
    state['epoch'] = self.current_epoch + 1
    state['iteration'] = self.current_iteration
    state['state_dict'] = self.model.state_dict()
    state['optimizer'] = self.optimizer.state_dict()
    torch.save(state, path)

    if is_best:
        shutil.copyfile(path, base + 'model_best.pth.tar')
Example #10
Source File: dcgan.py From Pytorch-Project-Template with MIT License | 6 votes |
def save_checkpoint(self, file_name="checkpoint.pth.tar", is_best = 0):
    """Save generator/discriminator weights, both optimizers, and sampling state.

    :param file_name: file name appended directly to ``self.config.checkpoint_dir``
    :param is_best: if truthy, also copy the file to 'model_best.pth.tar'
    :return: None
    """
    state = dict(
        epoch=self.current_epoch,
        iteration=self.current_iteration,
        G_state_dict=self.netG.state_dict(),
        G_optimizer=self.optimG.state_dict(),
        D_state_dict=self.netD.state_dict(),
        D_optimizer=self.optimD.state_dict(),
        fixed_noise=self.fixed_noise,
        manual_seed=self.manual_seed,
    )
    out_path = self.config.checkpoint_dir + file_name
    torch.save(state, out_path)
    if is_best:
        shutil.copyfile(out_path, self.config.checkpoint_dir + 'model_best.pth.tar')
Example #11
Source File: trainer.py From ACAN with MIT License | 6 votes |
def _save_checkpoint(self, epoch, acc):
    """Save a checkpoint of the network and related state for ``epoch``.

    Only the best and the latest checkpoints are kept: the file from the
    previous evaluation (``epoch - self.eval_freq``) is deleted unless that
    epoch is ``self.best_epoch``.

    :param epoch: current epoch, used in the file name ``<NetType>_<epoch:03d>.pkl``
    :param acc: accuracy stored alongside the weights
    :return: True
    """
    net_type = type(self.net).__name__

    def ckpt_path(ep):
        return os.path.join(self.logdir, '{}_{:03d}.pkl'.format(net_type, ep))

    prev_epoch = epoch - self.eval_freq
    if prev_epoch != self.best_epoch:
        stale = ckpt_path(prev_epoch)
        if os.path.isfile(stale):
            os.remove(stale)

    target = ckpt_path(epoch)
    payload = {
        'epoch': epoch,
        'acc': acc,
        'net_type': net_type,
        'net': self.net.state_dict(),
        'optimizer': self.optimizer.state_dict(),
        'use_gpu': self.use_gpu,
        'save_time': datetime.datetime.now().strftime('%Y%m%d_%H%M%S'),
    }
    torch.save(payload, target)
    return True
Example #12
Source File: solver.py From End-to-end-ASR-Pytorch with MIT License | 6 votes |
def save_checkpoint(self, f_name, metric, score, show_msg=True):
    """Save model/optimizer state and the current step to ``self.ckpdir/f_name``.

    An existing file with the same name is overwritten.

    :param f_name: <str> name of the checkpoint file (without directory prefix)
    :param metric: <str> name of the metric; used as a key in the saved dict
    :param score: <float> value of ``metric`` used to evaluate the model
    :param show_msg: if True, report the save through ``self.verbose``
    """
    ckpt_path = os.path.join(self.ckpdir, f_name)
    payload = dict(model=self.model.state_dict(),
                   optimizer=self.optimizer.get_opt_state_dict(),
                   global_step=self.step)
    payload[metric] = score

    # Additional modules to save (AMP state saving is currently disabled).
    if self.emb_decoder is not None:
        payload['emb_decoder'] = self.emb_decoder.state_dict()

    torch.save(payload, ckpt_path)
    if show_msg:
        msg = "Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".format(
            human_format(self.step), metric, score, ckpt_path)
        self.verbose(msg)
Example #13
Source File: finetune.py From transferlearning with MIT License | 6 votes |
def train(self, optimizer = None, epoches = 10, save_name=None):
    """Fine-tune for ``epoches`` epochs, tracking and optionally saving the best model.

    After each epoch ``self.test()`` is called, and its result is compared as a
    number of correct predictions. ``littlemax_correct``/``cur_model`` follow
    the best-so-far result (ties included). ``max_correct`` follows the strict
    maximum; when it improves and ``save_name`` is given, the whole model is
    saved with ``torch.save``.
    """
    for epoch in range(1, epoches + 1):
        print("Epoch: ", epoch)
        self.train_epoch(optimizer, epoch, epoches + 1)
        correct = self.test()

        if correct >= self.littlemax_correct:
            self.littlemax_correct = correct
            # NOTE(review): this keeps a reference, not a copy, so cur_model keeps
            # changing as training continues; use copy.deepcopy if a snapshot is meant.
            self.cur_model = self.model
            print("write cur bset model")

        if correct > self.max_correct:
            self.max_correct = correct
            if save_name:
                torch.save(self.model, str(save_name))

        accuracy = 100.0 * self.max_correct / self.len_target_dataset
        print('amazon to webcam max correct: {} max accuracy{: .2f}%\n'.format(
            self.max_correct, accuracy))
    print("Finished fine tuning.")
Example #14
Source File: saver.py From L3C-PyTorch with GNU General Public License v3.0 | 6 votes |
def save(self, modules, global_step, force=False):
    """Save iff (force given or global_step % keep_tmp_itr == 0).

    Every ``self.keep_every``-th saved checkpoint triggers removal of the
    previous checkpoints through ``self._remove_previous``.

    :param modules: dictionary name -> nn.Module
    :param global_step: current step
    :param force: save regardless of ``global_step``
    :return: bool, whether previous checkpoints were removed
    """
    if not (force or global_step % self.keep_tmp_itr == 0):
        return False

    assert self._out_dir is not None
    current_ckpt_p = self._save(modules, global_step)
    self.ckpts_since_last_permanent += 1

    if self.ckpts_since_last_permanent != self.keep_every:
        return False

    self._remove_previous(current_ckpt_p)
    self.ckpts_since_last_permanent = 0
    return True
Example #15
Source File: train_val.py From Collaborative-Learning-for-Weakly-Supervised-Object-Detection with MIT License | 5 votes |
def __init__(self, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None, wsddn_premodel=None):
    """Store the training setup and ensure the validation summary directory exists.

    :param network: network to train
    :param imdb: image database
    :param roidb: training ROI database
    :param valroidb: validation ROI database
    :param output_dir: directory for snapshots
    :param tbdir: summary directory; validation summaries go to ``tbdir + '_val'``
    :param pretrained_model: optional pretrained model reference
    :param wsddn_premodel: optional pretrained WSDDN model reference
    """
    self.net = network
    self.imdb, self.roidb, self.valroidb = imdb, roidb, valroidb
    self.output_dir, self.tbdir = output_dir, tbdir

    # Simply put '_val' at the end to save the summaries from the validation set
    val_dir = tbdir + '_val'
    self.tbvaldir = val_dir
    if not os.path.exists(val_dir):
        os.makedirs(val_dir)

    self.pretrained_model, self.wsddn_premodel = pretrained_model, wsddn_premodel
Example #16
Source File: train_val.py From Collaborative-Learning-for-Weakly-Supervised-Object-Detection with MIT License | 5 votes |
def snapshot(self, iter):
    """Save network weights and data-layer/random state for iteration ``iter``.

    Writes two files into ``self.output_dir``:

    * ``<SNAPSHOT_PREFIX>_iter_<iter>.pth`` with the network state dict.
      Every 10000 iterations it is also copied to ``<that name>.<iter>_cache``.
    * ``<SNAPSHOT_PREFIX>_iter_<iter>.pkl`` with meta information (see below).

    :param iter: current iteration (shadows the builtin; name kept for callers)
    :return: (weights filename, meta filename)
    """
    if not os.path.exists(self.output_dir):
        os.makedirs(self.output_dir)

    # Store the model snapshot
    prefix = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter)
    filename = os.path.join(self.output_dir, prefix + '.pth')
    torch.save(self.net.state_dict(), filename)
    print('Wrote snapshot to: {:s}'.format(filename))
    if iter % 10000 == 0:
        shutil.copyfile(filename, filename + '.{:d}_cache'.format(iter))

    # Also store some meta information, random state, etc. The objects are
    # pickled back to back, so a reader must unpickle them in this order.
    nfilename = os.path.join(self.output_dir, prefix + '.pkl')
    meta = (
        np.random.get_state(),     # current state of numpy random
        self.data_layer._cur,      # current position in the database
        self.data_layer._perm,     # current shuffled indexes of the database
        self.data_layer_val._cur,  # current position in the validation database
        self.data_layer_val._perm, # current shuffled indexes of the validation database
        iter,
    )
    with open(nfilename, 'wb') as fid:
        for item in meta:
            pickle.dump(item, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename
Example #17
Source File: main.py From controllable-text-attribute-transfer with Apache License 2.0 | 5 votes |
def preparation():
    """Set up save/log paths, resolve the dataset path, and load vocabulary/data.

    Reads and mutates the ``args`` namespace, which is not defined in this
    function (presumably a module-level global). Sets ``current_save_path``,
    ``log_file``, ``output_file``, ``data_path``, ``id_to_word``,
    ``vocab_size``, ``train_file_list`` and ``train_label_list``.

    :raises TypeError: if ``args.task`` is not 'yelp', 'amazon' or 'imagecaption'
    """
    # set model save path
    if args.if_load_from_checkpoint:
        timestamp = args.checkpoint_name
    else:
        timestamp = str(int(time.time()))
        print("create new model save path: %s" % timestamp)
    args.current_save_path = 'save/%s/' % timestamp
    args.log_file = args.current_save_path + time.strftime("log_%Y_%m_%d_%H_%M_%S.txt", time.localtime())
    args.output_file = args.current_save_path + time.strftime("output_%Y_%m_%d_%H_%M_%S.txt", time.localtime())
    print("create log file at path: %s" % args.log_file)

    if os.path.exists(args.current_save_path):
        add_log("Load checkpoint model from Path: %s" % args.current_save_path)
    else:
        os.makedirs(args.current_save_path)
        add_log("Path: %s is created" % args.current_save_path)

    # set task type
    if args.task == 'yelp':
        args.data_path = '../../data/yelp/processed_files/'
    elif args.task == 'amazon':
        args.data_path = '../../data/amazon/processed_files/'
    elif args.task == 'imagecaption':
        # This branch used to be a bare `pass`, leaving args.data_path unset for
        # prepare_data below. Keep a value supplied elsewhere; otherwise use the
        # same directory layout as the other tasks.
        if getattr(args, 'data_path', None) is None:
            args.data_path = '../../data/imagecaption/processed_files/'
    else:
        raise TypeError('Wrong task type!')

    # prepare data
    args.id_to_word, args.vocab_size, \
        args.train_file_list, args.train_label_list = prepare_data(
            data_path=args.data_path, max_num=args.word_dict_max_num, task_type=args.task
        )
    return
Example #18
Source File: main.py From controllable-text-attribute-transfer with Apache License 2.0 | 5 votes |
def preparation():
    """Prepare output paths and the dataset for the task selected in ``args``.

    ``args`` is not defined in this function (presumably a module-level global).
    This populates ``args.current_save_path``, ``args.log_file``,
    ``args.output_file`` and ``args.data_path``, then the vocabulary and
    training file/label lists returned by ``prepare_data``.

    :raises TypeError: if ``args.task`` is not a supported task name
    """
    if args.if_load_from_checkpoint:
        timestamp = args.checkpoint_name
    else:
        timestamp = str(int(time.time()))
        print("create new model save path: %s" % timestamp)

    save_dir = 'save/%s/' % timestamp
    args.current_save_path = save_dir
    args.log_file = save_dir + time.strftime("log_%Y_%m_%d_%H_%M_%S.txt", time.localtime())
    args.output_file = save_dir + time.strftime("output_%Y_%m_%d_%H_%M_%S.txt", time.localtime())
    print("create log file at path: %s" % args.log_file)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        add_log("Path: %s is created" % save_dir)
    else:
        add_log("Load checkpoint model from Path: %s" % save_dir)

    data_paths = {
        'yelp': '../../data/yelp/processed_files/',
        'amazon': '../../data/amazon/processed_files/',
        'imagecaption': '../../data/imagecaption/processed_files/',
    }
    if args.task not in data_paths:
        raise TypeError('Wrong task type!')
    args.data_path = data_paths[args.task]

    (args.id_to_word, args.vocab_size,
     args.train_file_list, args.train_label_list) = prepare_data(
        data_path=args.data_path,
        max_num=args.word_dict_max_num,
        task_type=args.task,
    )
Example #19
Source File: main.py From controllable-text-attribute-transfer with Apache License 2.0 | 5 votes |
def preparation():
    """Set up save/log paths, resolve the dataset path, and load vocabulary/data.

    Reads and mutates the ``args`` namespace, which is not defined in this
    function (presumably a module-level global). Sets ``current_save_path``,
    ``log_file``, ``output_file``, ``data_path``, ``id_to_word``,
    ``vocab_size``, ``train_file_list`` and ``train_label_list``.

    :raises TypeError: if ``args.task`` is not 'yelp', 'amazon' or 'imagecaption'
    """
    # set model save path
    if args.if_load_from_checkpoint:
        timestamp = args.checkpoint_name
    else:
        timestamp = str(int(time.time()))
        print("create new model save path: %s" % timestamp)
    args.current_save_path = 'save/%s/' % timestamp
    args.log_file = args.current_save_path + time.strftime("log_%Y_%m_%d_%H_%M_%S.txt", time.localtime())
    args.output_file = args.current_save_path + time.strftime("output_%Y_%m_%d_%H_%M_%S.txt", time.localtime())
    print("create log file at path: %s" % args.log_file)

    if os.path.exists(args.current_save_path):
        add_log("Load checkpoint model from Path: %s" % args.current_save_path)
    else:
        os.makedirs(args.current_save_path)
        add_log("Path: %s is created" % args.current_save_path)

    # set task type
    if args.task == 'yelp':
        args.data_path = '../../data/yelp/processed_files/'
    elif args.task == 'amazon':
        args.data_path = '../../data/amazon/processed_files/'
    elif args.task == 'imagecaption':
        # This branch used to be a bare `pass`, leaving args.data_path unset for
        # prepare_data below. Keep a value supplied elsewhere; otherwise use the
        # same directory layout as the other tasks.
        if getattr(args, 'data_path', None) is None:
            args.data_path = '../../data/imagecaption/processed_files/'
    else:
        raise TypeError('Wrong task type!')

    # prepare data
    args.id_to_word, args.vocab_size, \
        args.train_file_list, args.train_label_list = prepare_data(
            data_path=args.data_path, max_num=args.word_dict_max_num, task_type=args.task
        )
    return
Example #20
Source File: data.py From comet-commonsense with Apache License 2.0 | 5 votes |
def save_checkpoint(state, filename):
    """Announce the destination, then serialize ``state`` there with ``torch.save``."""
    message = "Saving model to {}".format(filename)
    print(message)
    torch.save(state, filename)
Example #21
Source File: gather_models.py From mmdetection with Apache License 2.0 | 5 votes |
def process_checkpoint(in_file, out_file):
    """Strip a checkpoint for release and rename it with its SHA-256 prefix.

    Any ``'optimizer'`` entry is removed to shrink the file. The result is saved
    to ``out_file``, which is then renamed to
    ``<out_file without .pth>-<first 8 hex digits of its sha256>.pth``.

    :param in_file: checkpoint to read (loaded with ``map_location='cpu'``)
    :param out_file: intermediate output path, normally ending in ``.pth``
    :return: path of the final, hash-suffixed file (exists when this returns)
    """
    import hashlib
    import os

    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)

    # Hash in-process rather than shelling out to `sha256sum` (same hex digest,
    # no external tool required).
    digest = hashlib.sha256()
    with open(out_file, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            digest.update(chunk)
    sha = digest.hexdigest()

    # str.rstrip('.pth') strips any trailing '.', 'p', 't', 'h' characters
    # ('x_depth.pth' -> 'x_de'); remove the exact suffix instead.
    stem = out_file[:-len('.pth')] if out_file.endswith('.pth') else out_file
    final_file = stem + '-{}.pth'.format(sha[:8])
    # The previous un-awaited `mv` subprocess could return before the file was
    # moved; os.replace finishes first (both paths share a directory).
    os.replace(out_file, final_file)
    return final_file
Example #22
Source File: regnet2mmdet.py From mmdetection with Apache License 2.0 | 5 votes |
def main():
    """Command-line entry point: parse ``src`` and ``dst``, then run ``convert``."""
    parser = argparse.ArgumentParser(description='Convert model keys')
    for arg_name, help_text in (('src', 'src detectron model path'), ('dst', 'save path')):
        parser.add_argument(arg_name, help=help_text)
    opts = parser.parse_args()
    convert(opts.src, opts.dst)
Example #23
Source File: detectron2pytorch.py From mmdetection with Apache License 2.0 | 5 votes |
def convert(src, dst, depth):
    """Convert keys in detectron pretrained ResNet models to pytorch style.

    :param src: path of the detectron (caffe) model, read with ``mmcv.load``
    :param dst: path the ``{'state_dict': ...}`` checkpoint is written to
    :param depth: ResNet depth; must be a key of ``arch_settings``
    :raises ValueError: if ``depth`` is not in ``arch_settings``
    """
    if depth not in arch_settings:
        raise ValueError('Only support ResNet-50 and ResNet-101 currently')
    block_nums = arch_settings[depth]

    # The weights may be nested under a 'blobs' key.
    caffe_model = mmcv.load(src, encoding='latin1')
    if 'blobs' in caffe_model:
        blobs = caffe_model['blobs']
    else:
        blobs = caffe_model

    state_dict = OrderedDict()
    converted_names = set()

    # Stem: first conv and its batch norm.
    convert_conv_fc(blobs, state_dict, 'conv1', 'conv1', converted_names)
    convert_bn(blobs, state_dict, 'res_conv1_bn', 'bn1', converted_names)

    # Residual stages: detectron 'res{stage+1}_{block}' -> torch 'layer{stage}.{block}'.
    for layer_idx, num_blocks in enumerate(block_nums, start=1):
        for blk in range(num_blocks):
            src_prefix = f'res{layer_idx + 1}_{blk}'
            dst_prefix = f'layer{layer_idx}.{blk}'
            if blk == 0:
                # Only block 0 of each stage has the branch1 shortcut (downsample).
                convert_conv_fc(blobs, state_dict, f'{src_prefix}_branch1',
                                f'{dst_prefix}.downsample.0', converted_names)
                convert_bn(blobs, state_dict, f'{src_prefix}_branch1_bn',
                           f'{dst_prefix}.downsample.1', converted_names)
            for conv_idx, letter in enumerate('abc', start=1):
                convert_conv_fc(blobs, state_dict, f'{src_prefix}_branch2{letter}',
                                f'{dst_prefix}.conv{conv_idx}', converted_names)
                convert_bn(blobs, state_dict, f'{src_prefix}_branch2{letter}_bn',
                           f'{dst_prefix}.bn{conv_idx}', converted_names)

    # Report every blob that was not mapped.
    for key in blobs:
        if key not in converted_names:
            print(f'Not Convert: {key}')

    torch.save({'state_dict': state_dict}, dst)
Example #24
Source File: detectron2pytorch.py From mmdetection with Apache License 2.0 | 5 votes |
def main():
    """Command-line entry point: parse ``src``, ``dst`` and ``depth``, then run ``convert``."""
    parser = argparse.ArgumentParser(description='Convert model keys')
    positional = [
        ('src', dict(help='src detectron model path')),
        ('dst', dict(help='save path')),
        ('depth', dict(type=int, help='ResNet model depth')),
    ]
    for arg_name, options in positional:
        parser.add_argument(arg_name, **options)
    cli = parser.parse_args()
    convert(cli.src, cli.dst, cli.depth)
Example #25
Source File: utils.py From subword-qac with MIT License | 5 votes |
def model_save(path, model, optimizer=None):
    """Save a model's config, weights and (optionally) optimizer state into ``path``.

    Writes ``config.json`` containing ``str(model.config)``, ``model.pt`` with
    the model state dict and, when ``optimizer`` is truthy, ``optimizer.pt``.
    ``get_model`` (not defined here) is applied to ``model`` first;
    presumably it unwraps a parallel wrapper — confirm.

    :param path: existing directory to write into
    :param model: model (possibly wrapped) to save
    :param optimizer: optional optimizer whose state dict is saved
    """
    model_to_save = get_model(model)
    # Context managers flush and close each file deterministically; the
    # previous anonymous open() handles were never closed explicitly.
    with open(os.path.join(path, 'config.json'), 'w') as f:
        f.write(str(model_to_save.config))
    with open(os.path.join(path, 'model.pt'), 'wb') as f:
        torch.save(model_to_save.state_dict(), f)
    if optimizer:
        with open(os.path.join(path, 'optimizer.pt'), 'wb') as f:
            torch.save(optimizer.state_dict(), f)
Example #26
Source File: utils.py From nmp_qc with MIT License | 5 votes |
def save_checkpoint(state, is_best, directory):
    """Save ``state`` as ``directory/checkpoint.pth``, copying it to ``model_best.pth`` if best.

    ``directory`` is created (including parents) when it does not exist.
    """
    os.makedirs(directory, exist_ok=True)
    latest, best = (os.path.join(directory, name)
                    for name in ('checkpoint.pth', 'model_best.pth'))
    torch.save(state, latest)
    if is_best:
        shutil.copyfile(latest, best)
Example #27
Source File: model_architecture.py From models with MIT License | 5 votes |
def save_seqpred_model_weights(fname = 'model_files/deepsea_variant_effects.pth'):
    """Build the model with ``get_model()`` and save its state dict to ``fname``."""
    weights = get_model().state_dict()
    torch.save(weights, fname)
Example #28
Source File: model_architecture.py From models with MIT License | 5 votes |
def save_seqpred_model_weights(fname = 'model_files/deepsea_predict.pth'):
    """Build the model with ``get_seqpred_model()`` and save its state dict to ``fname``."""
    torch.save(get_seqpred_model().state_dict(), fname)


# final models to be imported by Kipoi
Example #29
Source File: fashionmnist.py From Random-Erasing with Apache License 2.0 | 5 votes |
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Save ``state`` to ``checkpoint/filename``; if best, also copy it to ``model_best.pth.tar``.

    The ``checkpoint`` directory is not created here and must already exist.
    """
    target = os.path.join(checkpoint, filename)
    torch.save(state, target)
    if not is_best:
        return
    shutil.copyfile(target, os.path.join(checkpoint, 'model_best.pth.tar'))
Example #30
Source File: cifar.py From Random-Erasing with Apache License 2.0 | 5 votes |
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Write the training state and keep a separate copy of the best one.

    :param state: object passed to ``torch.save``
    :param is_best: if truthy, the file is copied to ``checkpoint/model_best.pth.tar``
    :param checkpoint: existing directory to write into (not created here)
    :param filename: name of the checkpoint file inside ``checkpoint``
    """
    best_path = os.path.join(checkpoint, 'model_best.pth.tar')
    latest_path = os.path.join(checkpoint, filename)
    torch.save(state, latest_path)
    if is_best:
        shutil.copyfile(latest_path, best_path)