Python torch.save() Examples

The following are 30 code examples showing how to use torch.save(). These examples are extracted from open-source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.

You may check out the related API usage on the sidebar.

You may also want to check out all available functions/classes of the module torch, or try the search function.

Example 1
Project: DDPAE-video-prediction   Author: jthsieh   File: base_model.py    License: MIT License 6 votes vote down vote up
def save(self, ckpt_path, epoch):
    """
    Write one checkpoint file per network and per optimizer.

    Every entry of ``self.nets`` is saved as ``net_<name>_<epoch>.pth`` and
    every entry of ``self.optimizers`` as ``optimizer_<name>_<epoch>.pth``,
    all inside ``ckpt_path``. Networks wrapped in ``torch.nn.DataParallel``
    are unwrapped first, so the state dict of the inner module is stored.
    """
    for net_name, network in self.nets.items():
        # save the wrapped module's weights, not the DataParallel wrapper's
        unwrapped = network.module if isinstance(network, torch.nn.DataParallel) else network
        net_file = 'net_{}_{}.pth'.format(net_name, epoch)
        torch.save(unwrapped.state_dict(), os.path.join(ckpt_path, net_file))

    for opt_name, opt in self.optimizers.items():
        opt_file = 'optimizer_{}_{}.pth'.format(opt_name, epoch)
        torch.save(opt.state_dict(), os.path.join(ckpt_path, opt_file))
Example 2
Project: mmdetection   Author: open-mmlab   File: regnet2mmdet.py    License: Apache License 2.0 6 votes vote down vote up
def convert(src, dst):
    """Convert keys in pycls pretrained RegNet models to mmdet style."""
    # the source checkpoint keeps its weights under the 'model_state' entry
    source_weights = torch.load(src)['model_state']

    # route every recognised key to the converter for its part of the network
    state_dict = OrderedDict()
    converted_names = set()
    for name, tensor in source_weights.items():
        if 'stem' in name:
            handler = convert_stem
        elif 'head' in name:
            handler = convert_head
        elif name.startswith('s'):
            handler = convert_reslayer
        else:
            # keys matching none of the rules are reported below
            continue
        handler(name, tensor, state_dict, converted_names)

    # report every key the converters did not mark as handled
    for name in source_weights:
        if name not in converted_names:
            print(f'not converted: {name}')

    # store the converted weights under 'state_dict' and write them out
    torch.save({'state_dict': state_dict}, dst)
Example 3
Project: pytorch_NER_BiLSTM_CNN_CRF   Author: bamtercelboo   File: utils.py    License: Apache License 2.0 6 votes vote down vote up
def save_model_all(model, save_dir, model_name, epoch):
    """
    Save the model's state dict to ``<save_dir>/<model_name>_epoch_<epoch>.pt``.

    :param model:  nn model
    :param save_dir: directory to save the model in (created if missing)
    :param model_name:  model name, used as the file-name prefix
    :param epoch:  epoch, embedded in the file name
    :return:  None
    """
    # exist_ok avoids a failure when the directory appears between a check
    # and the creation call
    os.makedirs(save_dir, exist_ok=True)
    save_prefix = os.path.join(save_dir, model_name)
    save_path = '{}_epoch_{}.pt'.format(save_prefix, epoch)
    print("save all model to {}".format(save_path))
    # the context manager closes the file even when torch.save raises
    with open(save_path, mode="wb") as output:
        torch.save(model.state_dict(), output)
Example 4
Project: pytorch_NER_BiLSTM_CNN_CRF   Author: bamtercelboo   File: utils.py    License: Apache License 2.0 6 votes vote down vote up
def save_best_model(model, save_dir, model_name, best_eval):
    """
    Save the model's state dict when the current dev score is at least the best one.

    When ``best_eval.current_dev_score >= best_eval.best_dev_score`` the
    weights are written to ``<save_dir>/<model_name>.pt`` (overwriting any
    previous file) and ``best_eval.early_current_patience`` is reset to 0.
    Otherwise nothing happens.

    :param model:  nn model
    :param save_dir:  directory to save the model in (created if missing)
    :param model_name:  model name, used as the file name without extension
    :param best_eval:  object holding the current/best dev scores and the patience counter
    :return:  None
    """
    if best_eval.current_dev_score >= best_eval.best_dev_score:
        # exist_ok avoids a failure when the directory appears between a
        # check and the creation call
        os.makedirs(save_dir, exist_ok=True)
        model_name = "{}.pt".format(model_name)
        save_path = os.path.join(save_dir, model_name)
        print("save best model to {}".format(save_path))
        # the context manager closes the file even when torch.save raises
        with open(save_path, mode="wb") as output:
            torch.save(model.state_dict(), output)
        best_eval.early_current_patience = 0


# adjust lr 
Example 5
Project: DeepLab_v3_plus   Author: songdejia   File: util.py    License: MIT License 6 votes vote down vote up
def save_checkpoint(state, weights_dir=''):
    """Save a training-state dict as ``model-<epoch>.pth.tar`` inside ``weights_dir``.

    Arguments:
        state {dict} -- checkpoint contents; must contain an ``'epoch'`` entry
            that ``int()`` accepts, which is used to build the file name

    Keyword Arguments:
        weights_dir {str} -- target directory, created if missing; the default
            ``''`` saves into the current working directory (default: {''})
    """
    # os.makedirs('') raises, so only create a directory when one was given
    if weights_dir:
        os.makedirs(weights_dir, exist_ok=True)

    epoch = state['epoch']

    # zero-padded epoch keeps the checkpoint files sorted by name
    file_path = os.path.join(weights_dir, 'model-{:04d}.pth.tar'.format(int(epoch)))
    torch.save(state, file_path)
    

#############################################
# loss function
############################################# 
Example 6
Project: cvpr2018-hnd   Author: kibok90   File: models.py    License: MIT License 6 votes vote down vote up
def init_truncated_normal(model, aux_str=''):
    """
    Initialise ``model`` from a cached weight file, creating that file if needed.

    The cache file name is built from the outer-scope ``path``, the model's
    ``in_features`` / ``out_features`` and ``aux_str``. If the file exists its
    weights are loaded into ``model``; otherwise ``truncated_normal`` is
    applied (to each sub-module for an ``nn.ModuleList``) and the resulting
    state dict is saved there. Returns ``model``, or ``None`` when ``model``
    is ``None``.
    """
    if model is None:
        return None

    init_path = '{path}/{in_dim:d}_{out_dim:d}{aux_str}.pth'.format(
        path=path, in_dim=model.in_features, out_dim=model.out_features, aux_str=aux_str)

    if not os.path.isfile(init_path):
        # no cached weights yet: initialise, then cache for later runs
        if isinstance(model, nn.ModuleList):
            for sub in model:
                truncated_normal(sub)
        else:
            truncated_normal(model)
        print('generate init weight: {init_path}'.format(init_path=init_path))
        torch.save(model.state_dict(), init_path)
        print('save init weight: {init_path}'.format(init_path=init_path))
        return model

    model.load_state_dict(torch.load(init_path))
    print('load init weight: {init_path}'.format(init_path=init_path))
    return model
Example 7
Project: pruning_yolov3   Author: zbyuan   File: utils.py    License: GNU General Public License v3.0 6 votes vote down vote up
def print_mutation(hyp, results, bucket=''):
    """
    Print mutation results and append them to evolve.txt (for use with train.py --evolve).

    ``hyp`` is a mapping of hyperparameter names to values and ``results`` is
    a tuple of result numbers (P, R, mAP, F1, test_loss). After appending,
    evolve.txt is rewritten with duplicate rows removed and rows ordered by
    descending ``fitness``. When ``bucket`` is non-empty, evolve.txt is first
    downloaded from and finally uploaded to that bucket with ``gsutil``.
    """
    n_hyp = len(hyp)
    keys_row = '%10s' * n_hyp % tuple(hyp.keys())  # hyperparam keys
    values_row = '%10.3g' * n_hyp % tuple(hyp.values())  # hyperparam values
    results_row = '%10.3g' * len(results) % results  # results (P, R, mAP, F1, test_loss)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (keys_row, values_row, results_row))

    use_bucket = bool(bucket)
    if use_bucket:
        # fetch the shared copy first so the new row is appended to it
        os.system('gsutil cp gs://%s/evolve.txt .' % bucket)

    with open('evolve.txt', 'a') as log:
        log.write(results_row + values_row + '\n')

    # drop duplicate rows, then store them ordered by descending fitness
    rows = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)
    order = np.argsort(-fitness(rows))
    np.savetxt('evolve.txt', rows[order], '%10.3g')

    if use_bucket:
        os.system('gsutil cp evolve.txt gs://%s' % bucket)
Example 8
Project: Pytorch-Project-Template   Author: moemen95   File: condensenet.py    License: MIT License 6 votes vote down vote up
def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=0):
    """
    Save the latest training checkpoint.

    The checkpoint holds the current epoch and iteration plus the model and
    optimizer state dicts. The target path is ``self.config.checkpoint_dir``
    concatenated directly with ``filename``.

    :param filename: name of the file that will contain the state
    :param is_best: when truthy, the file is also copied to 'model_best.pth.tar'
    :return: None
    """
    snapshot = dict(
        epoch=self.current_epoch,
        iteration=self.current_iteration,
        state_dict=self.model.state_dict(),
        optimizer=self.optimizer.state_dict(),
    )
    checkpoint_path = self.config.checkpoint_dir + filename
    torch.save(snapshot, checkpoint_path)
    if not is_best:
        return
    # keep a separate copy of the best model seen so far
    shutil.copyfile(checkpoint_path, self.config.checkpoint_dir + 'model_best.pth.tar')
Example 9
Project: Pytorch-Project-Template   Author: moemen95   File: erfnet.py    License: MIT License 6 votes vote down vote up
def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=0):
    """
    Save the latest training checkpoint.

    The checkpoint holds the current iteration, the model and optimizer state
    dicts, and the epoch stored as ``self.current_epoch + 1``. The target path
    is ``self.config.checkpoint_dir`` concatenated directly with ``filename``.

    :param filename: name of the file that will contain the state
    :param is_best: when truthy, the file is also copied to 'model_best.pth.tar'
    :return: None
    """
    snapshot = dict(
        epoch=self.current_epoch + 1,
        iteration=self.current_iteration,
        state_dict=self.model.state_dict(),
        optimizer=self.optimizer.state_dict(),
    )
    checkpoint_path = self.config.checkpoint_dir + filename
    torch.save(snapshot, checkpoint_path)
    if not is_best:
        return
    # keep a separate copy of the best model seen so far
    shutil.copyfile(checkpoint_path, self.config.checkpoint_dir + 'model_best.pth.tar')
Example 10
Project: Pytorch-Project-Template   Author: moemen95   File: dcgan.py    License: MIT License 6 votes vote down vote up
def save_checkpoint(self, file_name="checkpoint.pth.tar", is_best = 0):
    """
    Save the latest GAN training checkpoint.

    The checkpoint holds the current epoch and iteration, the state dicts of
    the generator and discriminator with their optimizers, the fixed noise
    and the manual seed. The target path is ``self.config.checkpoint_dir``
    concatenated directly with ``file_name``.

    :param file_name: name of the file that will contain the state
    :param is_best: when truthy, the file is also copied to 'model_best.pth.tar'
    :return: None
    """
    snapshot = dict(
        epoch=self.current_epoch,
        iteration=self.current_iteration,
        G_state_dict=self.netG.state_dict(),
        G_optimizer=self.optimG.state_dict(),
        D_state_dict=self.netD.state_dict(),
        D_optimizer=self.optimD.state_dict(),
        fixed_noise=self.fixed_noise,
        manual_seed=self.manual_seed,
    )
    checkpoint_path = self.config.checkpoint_dir + file_name
    torch.save(snapshot, checkpoint_path)
    if not is_best:
        return
    # keep a separate copy of the best model seen so far
    shutil.copyfile(checkpoint_path, self.config.checkpoint_dir + 'model_best.pth.tar')
Example 11
Project: ACAN   Author: miraiaroha   File: trainer.py    License: MIT License 6 votes vote down vote up
def _save_checkpoint(self, epoch, acc):
        """
        Saves a checkpoint of the network and other variables.
        Only save the best and latest epoch.
        """
        net_type = type(self.net).__name__
        if epoch - self.eval_freq != self.best_epoch:
            pre_save = os.path.join(self.logdir, '{}_{:03d}.pkl'.format(net_type, epoch - self.eval_freq))
            if os.path.isfile(pre_save):
                os.remove(pre_save)
        cur_save = os.path.join(self.logdir, '{}_{:03d}.pkl'.format(net_type, epoch))
        state = {
            'epoch': epoch,
            'acc': acc,
            'net_type': net_type,
            'net': self.net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            #'scheduler': self.scheduler.state_dict(),
            'use_gpu': self.use_gpu,
            'save_time': datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        }
        torch.save(state, cur_save)
        return True 
Example 12
Project: End-to-end-ASR-Pytorch   Author: Alexander-H-Liu   File: solver.py    License: MIT License 6 votes vote down vote up
def save_checkpoint(self, f_name, metric, score, show_msg=True):
    '''
    Ckpt saver
        f_name   - <str> the name of ckpt file (w/o prefix) to store, overwrite if existed
        metric   - <str> name of the metric; used as the key that stores score
        score    - <float> The value of metric used to evaluate model
        show_msg - <bool> report the saved checkpoint through self.verbose
    '''
    ckpt_path = os.path.join(self.ckpdir, f_name)
    full_dict = {
        "model": self.model.state_dict(),
        "optimizer": self.optimizer.get_opt_state_dict(),
        "global_step": self.step,
        metric: score
    }
    # Additional modules to save
    emb_decoder = self.emb_decoder
    if emb_decoder is not None:
        full_dict['emb_decoder'] = emb_decoder.state_dict()

    torch.save(full_dict, ckpt_path)
    if not show_msg:
        return
    message = "Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".format(
        human_format(self.step), metric, score, ckpt_path)
    self.verbose(message)
Example 13
Project: transferlearning   Author: jindongwang   File: finetune.py    License: MIT License 6 votes vote down vote up
def train(self, optimizer = None, epoches = 10, save_name=None):
    """
    Run ``epoches`` rounds of ``train_epoch`` followed by ``test``.

    After each round, ``self.cur_model`` and ``self.littlemax_correct`` are
    updated when the test result is at least the previous ``littlemax_correct``.
    When the result strictly exceeds ``self.max_correct`` that value is
    updated too and, if ``save_name`` is truthy, ``self.model`` is saved with
    ``torch.save`` to ``str(save_name)``.
    """
    for epoch_no in range(1, epoches + 1):
        print("Epoch: ", epoch_no)
        self.train_epoch(optimizer, epoch_no, epoches+1)
        cur_correct = self.test()

        if cur_correct >= self.littlemax_correct:
            self.littlemax_correct = cur_correct
            self.cur_model = self.model
            print("write cur bset model")

        improved = cur_correct > self.max_correct
        if improved:
            self.max_correct = cur_correct
            if save_name:
                torch.save(self.model, str(save_name))

        accuracy = 100.0 * self.max_correct / self.len_target_dataset
        print('amazon to webcam max correct: {} max accuracy{: .2f}%\n'.format(
            self.max_correct, accuracy))

    print("Finished fine tuning.")
Example 14
Project: L3C-PyTorch   Author: fab-jul   File: saver.py    License: GNU General Public License v3.0 6 votes vote down vote up
def save(self, modules, global_step, force=False):
    """
    Save iff (force given or global_step % keep_tmp_itr == 0).

    Every ``keep_every``-th saved checkpoint triggers ``_remove_previous``
    with the checkpoint just written and resets the counter.

    :param modules: dictionary name -> nn.Module
    :param global_step: current step
    :param force: save even when global_step is not a multiple of keep_tmp_itr
    :return: bool, Whether previous checkpoints were removed
    """
    due = force or (global_step % self.keep_tmp_itr == 0)
    if not due:
        return False

    assert self._out_dir is not None
    current_ckpt_p = self._save(modules, global_step)
    self.ckpts_since_last_permanent += 1

    if self.ckpts_since_last_permanent != self.keep_every:
        return False

    self._remove_previous(current_ckpt_p)
    self.ckpts_since_last_permanent = 0
    return True
Example 15
def __init__(self, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None, wsddn_premodel=None):
    self.net = network
    self.imdb = imdb
    self.roidb = roidb
    self.valroidb = valroidb
    self.output_dir = output_dir
    self.tbdir = tbdir
    # Simply put '_val' at the end to save the summaries from the validation set
    self.tbvaldir = tbdir + '_val'
    if not os.path.exists(self.tbvaldir):
      os.makedirs(self.tbvaldir)
    self.pretrained_model = pretrained_model
    self.wsddn_premodel = wsddn_premodel 
Example 16
def snapshot(self, iter):
    """
    Write the network weights and the training meta information for ``iter``.

    Two files are created in ``self.output_dir`` (which is created when
    missing), both named from ``cfg.TRAIN.SNAPSHOT_PREFIX`` and the iteration:
    a ``.pth`` file with the network state dict and a ``.pkl`` file with six
    consecutively pickled objects: the numpy random state, the position and
    permutation of the training data layer, the position and permutation of
    the validation data layer, and ``iter``. Every 10000 iterations the
    ``.pth`` file is additionally copied to ``<file>.<iter>_cache``.

    Returns the tuple ``(pth_path, pkl_path)``.
    """
    net = self.net

    if not os.path.exists(self.output_dir):
        os.makedirs(self.output_dir)

    # both output files share the same '<prefix>_iter_<n>' stem
    stem = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter)

    # Store the model snapshot
    filename = os.path.join(self.output_dir, stem + '.pth')
    torch.save(net.state_dict(), filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    if iter % 10000 == 0:
        shutil.copyfile(filename, filename + '.{:d}_cache'.format(iter))

    # Also store some meta information, random state, etc.
    nfilename = os.path.join(self.output_dir, stem + '.pkl')
    meta = (
        np.random.get_state(),       # current state of numpy random
        self.data_layer._cur,        # current position in the database
        self.data_layer._perm,       # current shuffled indexes of the database
        self.data_layer_val._cur,    # current position in the validation database
        self.data_layer_val._perm,   # current shuffled indexes of the validation database
        iter,
    )

    # Dump the meta info; a reader has to unpickle the objects in this order
    with open(nfilename, 'wb') as fid:
        for item in meta:
            pickle.dump(item, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename
Example 17
Project: controllable-text-attribute-transfer   Author: Nrgeup   File: main.py    License: Apache License 2.0 5 votes vote down vote up
def preparation():
    """
    Set up the save directory, log/output file names and training data on ``args``.

    The save directory is ``save/<timestamp>/``, where the timestamp is the
    checkpoint name when loading from a checkpoint and the current Unix time
    otherwise. The directory is created when missing. The data path is chosen
    from ``args.task`` and the results of ``prepare_data`` are stored on
    ``args``.

    Raises TypeError for an unknown ``args.task``.
    """
    # set model save path
    if not args.if_load_from_checkpoint:
        timestamp = str(int(time.time()))
        print("create new model save path: %s" % timestamp)
    else:
        timestamp = args.checkpoint_name
    save_dir = 'save/%s/' % timestamp
    args.current_save_path = save_dir
    args.log_file = save_dir + time.strftime("log_%Y_%m_%d_%H_%M_%S.txt", time.localtime())
    args.output_file = save_dir + time.strftime("output_%Y_%m_%d_%H_%M_%S.txt", time.localtime())
    print("create log file at path: %s" % args.log_file)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        add_log("Path: %s is created" % save_dir)
    else:
        add_log("Load checkpoint model from Path: %s" % save_dir)

    # set task type
    data_dirs = {
        'yelp': '../../data/yelp/processed_files/',
        'amazon': '../../data/amazon/processed_files/',
    }
    if args.task in data_dirs:
        args.data_path = data_dirs[args.task]
    elif args.task != 'imagecaption':
        raise TypeError('Wrong task type!')
    # NOTE(review): 'imagecaption' assigns no data_path here, so args.data_path
    # presumably has to be set elsewhere before the call below - confirm.

    # prepare data
    (args.id_to_word, args.vocab_size,
     args.train_file_list, args.train_label_list) = prepare_data(
        data_path=args.data_path, max_num=args.word_dict_max_num, task_type=args.task
    )
Example 18
Project: controllable-text-attribute-transfer   Author: Nrgeup   File: main.py    License: Apache License 2.0 5 votes vote down vote up
def preparation():
    """
    Set up the save directory, log/output file names and training data on ``args``.

    The save directory is ``save/<timestamp>/``, where the timestamp is the
    checkpoint name when loading from a checkpoint and the current Unix time
    otherwise. The directory is created when missing. The data path is chosen
    from ``args.task`` and the results of ``prepare_data`` are stored on
    ``args``.

    Raises TypeError for an unknown ``args.task``.
    """
    # set model save path
    if not args.if_load_from_checkpoint:
        timestamp = str(int(time.time()))
        print("create new model save path: %s" % timestamp)
    else:
        timestamp = args.checkpoint_name
    save_dir = 'save/%s/' % timestamp
    args.current_save_path = save_dir
    args.log_file = save_dir + time.strftime("log_%Y_%m_%d_%H_%M_%S.txt", time.localtime())
    args.output_file = save_dir + time.strftime("output_%Y_%m_%d_%H_%M_%S.txt", time.localtime())
    print("create log file at path: %s" % args.log_file)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        add_log("Path: %s is created" % save_dir)
    else:
        add_log("Load checkpoint model from Path: %s" % save_dir)

    # set task type
    data_dirs = {
        'yelp': '../../data/yelp/processed_files/',
        'amazon': '../../data/amazon/processed_files/',
        'imagecaption': '../../data/imagecaption/processed_files/',
    }
    if args.task not in data_dirs:
        raise TypeError('Wrong task type!')
    args.data_path = data_dirs[args.task]

    # prepare data
    (args.id_to_word, args.vocab_size,
     args.train_file_list, args.train_label_list) = prepare_data(
        data_path=args.data_path, max_num=args.word_dict_max_num, task_type=args.task
    )
Example 19
Project: controllable-text-attribute-transfer   Author: Nrgeup   File: main.py    License: Apache License 2.0 5 votes vote down vote up
def preparation():
    """
    Set up the save directory, log/output file names and training data on ``args``.

    The save directory is ``save/<timestamp>/``, where the timestamp is the
    checkpoint name when loading from a checkpoint and the current Unix time
    otherwise. The directory is created when missing. The data path is chosen
    from ``args.task`` and the results of ``prepare_data`` are stored on
    ``args``.

    Raises TypeError for an unknown ``args.task``.
    """
    # set model save path
    if not args.if_load_from_checkpoint:
        timestamp = str(int(time.time()))
        print("create new model save path: %s" % timestamp)
    else:
        timestamp = args.checkpoint_name
    save_dir = 'save/%s/' % timestamp
    args.current_save_path = save_dir
    args.log_file = save_dir + time.strftime("log_%Y_%m_%d_%H_%M_%S.txt", time.localtime())
    args.output_file = save_dir + time.strftime("output_%Y_%m_%d_%H_%M_%S.txt", time.localtime())
    print("create log file at path: %s" % args.log_file)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        add_log("Path: %s is created" % save_dir)
    else:
        add_log("Load checkpoint model from Path: %s" % save_dir)

    # set task type
    data_dirs = {
        'yelp': '../../data/yelp/processed_files/',
        'amazon': '../../data/amazon/processed_files/',
    }
    if args.task in data_dirs:
        args.data_path = data_dirs[args.task]
    elif args.task != 'imagecaption':
        raise TypeError('Wrong task type!')
    # NOTE(review): 'imagecaption' assigns no data_path here, so args.data_path
    # presumably has to be set elsewhere before the call below - confirm.

    # prepare data
    (args.id_to_word, args.vocab_size,
     args.train_file_list, args.train_label_list) = prepare_data(
        data_path=args.data_path, max_num=args.word_dict_max_num, task_type=args.task
    )
Example 20
Project: comet-commonsense   Author: atcbosselut   File: data.py    License: Apache License 2.0 5 votes vote down vote up
def save_checkpoint(state, filename):
    """Announce the target on stdout, then write ``state`` to ``filename`` with ``torch.save``."""
    announcement = "Saving model to {}".format(filename)
    print(announcement)
    torch.save(state, filename)
Example 21
Project: mmdetection   Author: open-mmlab   File: gather_models.py    License: Apache License 2.0 5 votes vote down vote up
def process_checkpoint(in_file, out_file):
    """Strip the optimizer state from a checkpoint and publish it under a hashed name.

    The checkpoint is loaded from ``in_file`` onto the CPU, its ``'optimizer'``
    entry (if any) is removed, and the result is saved to ``out_file``. The
    saved file is then renamed so that the first 8 hex digits of its SHA-256
    digest are appended before the ``.pth`` extension.

    Args:
        in_file (str): path of the checkpoint to read.
        out_file (str): intermediate output path; it no longer exists after
            the call because the file is moved to the final name.

    Returns:
        str: path of the final, hash-suffixed checkpoint file.
    """
    # imported locally so the block does not depend on module-level imports
    import hashlib
    import shutil

    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.
    torch.save(checkpoint, out_file)

    # hash in chunks so large checkpoints are not read into memory at once
    digest = hashlib.sha256()
    with open(out_file, 'rb') as saved:
        for chunk in iter(lambda: saved.read(1 << 20), b''):
            digest.update(chunk)
    sha = digest.hexdigest()

    # str.rstrip('.pth') strips a character set, not a suffix (it would turn
    # 'model_depth.pth' into 'model_de'), so cut the extension explicitly
    suffix = '.pth'
    stem = out_file[:-len(suffix)] if out_file.endswith(suffix) else out_file
    final_file = stem + '-{}.pth'.format(sha[:8])

    # a synchronous move guarantees the file exists when the path is returned
    shutil.move(out_file, final_file)
    return final_file
Example 22
Project: mmdetection   Author: open-mmlab   File: regnet2mmdet.py    License: Apache License 2.0 5 votes vote down vote up
def main():
    """Read the ``src`` and ``dst`` paths from the command line and run ``convert``."""
    cli = argparse.ArgumentParser(description='Convert model keys')
    positional = (
        ('src', 'src detectron model path'),
        ('dst', 'save path'),
    )
    for arg_name, arg_help in positional:
        cli.add_argument(arg_name, help=arg_help)
    options = cli.parse_args()
    convert(options.src, options.dst)
Example 23
Project: mmdetection   Author: open-mmlab   File: detectron2pytorch.py    License: Apache License 2.0 5 votes vote down vote up
def convert(src, dst, depth):
    """Convert keys in detectron pretrained ResNet models to pytorch style."""
    # number of residual blocks per stage for the requested depth
    if depth not in arch_settings:
        raise ValueError('Only support ResNet-50 and ResNet-101 currently')
    block_nums = arch_settings[depth]

    # the weights are either under 'blobs' or the loaded object itself
    caffe_model = mmcv.load(src, encoding='latin1')
    blobs = caffe_model['blobs'] if 'blobs' in caffe_model else caffe_model

    state_dict = OrderedDict()
    converted_names = set()

    # first convolution and its batch norm
    convert_conv_fc(blobs, state_dict, 'conv1', 'conv1', converted_names)
    convert_bn(blobs, state_dict, 'res_conv1_bn', 'bn1', converted_names)

    # residual stages: source names start at res2, target names at layer1
    for i, num_blocks in enumerate(block_nums, start=1):
        for j in range(num_blocks):
            src_block = f'res{i + 1}_{j}'
            dst_block = f'layer{i}.{j}'
            if j == 0:
                # only the first block of a stage has the downsample branch
                convert_conv_fc(blobs, state_dict, f'{src_block}_branch1',
                                f'{dst_block}.downsample.0', converted_names)
                convert_bn(blobs, state_dict, f'{src_block}_branch1_bn',
                           f'{dst_block}.downsample.1', converted_names)
            for k, letter in enumerate('abc', start=1):
                branch = f'{src_block}_branch2{letter}'
                convert_conv_fc(blobs, state_dict, branch,
                                f'{dst_block}.conv{k}', converted_names)
                convert_bn(blobs, state_dict, f'{branch}_bn',
                           f'{dst_block}.bn{k}', converted_names)

    # report every blob that was not converted
    for key in blobs:
        if key not in converted_names:
            print(f'Not Convert: {key}')

    # store the converted weights under 'state_dict' and write them out
    torch.save({'state_dict': state_dict}, dst)
Example 24
Project: mmdetection   Author: open-mmlab   File: detectron2pytorch.py    License: Apache License 2.0 5 votes vote down vote up
def main():
    """Read ``src``, ``dst`` and the integer ``depth`` from the command line and run ``convert``."""
    cli = argparse.ArgumentParser(description='Convert model keys')
    for arg_name, arg_help in (('src', 'src detectron model path'), ('dst', 'save path')):
        cli.add_argument(arg_name, help=arg_help)
    cli.add_argument('depth', type=int, help='ResNet model depth')
    options = cli.parse_args()
    convert(options.src, options.dst, options.depth)
Example 25
Project: subword-qac   Author: clovaai   File: utils.py    License: MIT License 5 votes vote down vote up
def model_save(path, model, optimizer=None):
    """Write the model's config and weights, and optionally the optimizer state, into ``path``.

    Three files may be created inside the directory ``path``:
    ``config.json`` (the string form of the model's ``config``),
    ``model.pt`` (the model state dict) and, when ``optimizer`` is truthy,
    ``optimizer.pt`` (the optimizer state dict). The model is first passed
    through ``get_model``.
    """
    model_to_save = get_model(model)
    # context managers make sure every file is flushed and closed, even when
    # writing fails part-way
    with open(os.path.join(path, 'config.json'), 'w') as config_file:
        config_file.write(str(model_to_save.config))
    with open(os.path.join(path, 'model.pt'), 'wb') as model_file:
        torch.save(model_to_save.state_dict(), model_file)
    if optimizer:
        with open(os.path.join(path, 'optimizer.pt'), 'wb') as optimizer_file:
            torch.save(optimizer.state_dict(), optimizer_file)
Example 26
Project: nmp_qc   Author: priba   File: utils.py    License: MIT License 5 votes vote down vote up
def save_checkpoint(state, is_best, directory):
    """
    Save ``state`` as ``checkpoint.pth`` inside ``directory``.

    The directory is created when it does not exist. When ``is_best`` is
    truthy the checkpoint is also copied to ``model_best.pth`` in the same
    directory.
    """
    if not os.path.isdir(directory):
        os.makedirs(directory)

    checkpoint_file, best_model_file = (
        os.path.join(directory, name) for name in ('checkpoint.pth', 'model_best.pth')
    )
    torch.save(state, checkpoint_file)
    if not is_best:
        return
    shutil.copyfile(checkpoint_file, best_model_file)
Example 27
Project: models   Author: kipoi   File: model_architecture.py    License: MIT License 5 votes vote down vote up
def save_seqpred_model_weights(fname = 'model_files/deepsea_variant_effects.pth'):
    """Save the state dict of the model returned by ``get_model()`` to ``fname``."""
    weights = get_model().state_dict()
    torch.save(weights, fname)
Example 28
Project: models   Author: kipoi   File: model_architecture.py    License: MIT License 5 votes vote down vote up
def save_seqpred_model_weights(fname = 'model_files/deepsea_predict.pth'):
    """Save the state dict of the model returned by ``get_seqpred_model()`` to ``fname``."""
    weights = get_seqpred_model().state_dict()
    torch.save(weights, fname)

# final models to be imported by Kipoi 
Example 29
Project: Random-Erasing   Author: zhunzhong07   File: fashionmnist.py    License: Apache License 2.0 5 votes vote down vote up
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """
    Save ``state`` to ``<checkpoint>/<filename>``.

    When ``is_best`` is truthy the saved file is also copied to
    ``model_best.pth.tar`` in the same directory. The directory is not
    created here.
    """
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if not is_best:
        return
    best_path = os.path.join(checkpoint, 'model_best.pth.tar')
    shutil.copyfile(filepath, best_path)
Example 30
Project: Random-Erasing   Author: zhunzhong07   File: cifar.py    License: Apache License 2.0 5 votes vote down vote up
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """
    Save ``state`` to ``<checkpoint>/<filename>``.

    When ``is_best`` is truthy the saved file is also copied to
    ``model_best.pth.tar`` in the same directory. The directory is not
    created here.
    """
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if not is_best:
        return
    best_path = os.path.join(checkpoint, 'model_best.pth.tar')
    shutil.copyfile(filepath, best_path)