Python torch.nn.DataParallel() Examples

The following are 30 code examples of torch.nn.DataParallel(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.nn, or try the search function.
Example #1
Source File: netbase.py    From person-reid-lib with MIT License 7 votes vote down vote up
def __init__(self, nClass, nCam, model_client, use_flow, task_dir, raw_model_dir, is_image_dataset, recorder):
        """Build the model, its DataParallel wrappers, the optimizer and the LR scheduler.

        Args:
            nClass: stored as ``self.nClass`` and passed to ``model_client``.
            nCam: stored as ``self.nCam`` and passed to ``model_client``.
            model_client: callable invoked as
                ``model_client(nClass, nCam, use_flow, is_image_dataset,
                raw_model_dir, logger)`` to create the model.
            use_flow: forwarded unchanged to ``model_client``.
            task_dir: stored as ``self.task_dir``.
            raw_model_dir: forwarded unchanged to ``model_client``.
            is_image_dataset: stored as ``self.is_image_dataset`` and
                forwarded to ``model_client``.
            recorder: object exposing ``visual`` and ``logger`` attributes.
        """
        self.nClass = nClass
        self.nCam = nCam
        self.recorder = recorder
        self.visual = self.recorder.visual
        self.logger = self.recorder.logger
        self._mode = 'Train'
        self.is_image_dataset = is_image_dataset
        self.task_dir = task_dir

        self.model = model_client(self.nClass, self.nCam, use_flow, self.is_image_dataset, raw_model_dir, self.logger)
        # Two wrappers are created: one around the whole model and a second
        # one around its ``feature`` sub-module, attached to the first wrapper
        # as the attribute ``feature``. Both are moved to CUDA.
        self.model_parallel = DataParallel(self.model).cuda()
        self.model_parallel.feature = DataParallel(self.model.feature).cuda()

        self.net_info = []
        # NOTE(review): ``line_name``, ``lr_decay_step`` and ``gamma`` are read
        # below but never assigned in this method -- presumably they are set by
        # ``const_options()`` / ``init_options()``; confirm before reordering.
        self.const_options()
        self.init_options()
        self.loss_mean = AverageMeter(len(self.line_name))

        self.net_info.extend(self.model.net_info)
        self.optimizer = self.init_optimizer()
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.lr_decay_step, gamma=self.gamma)
        self.idx = 0
        self.best_performance = 0.0
Example #2
Source File: train.py    From pytorch-multigpu with MIT License 7 votes vote down vote up
def main():
    """Prepare the CIFAR-10 training data, build the network and start training.

    The network returned by ``pyramidnet()`` is wrapped in ``nn.DataParallel``
    and moved to CUDA when available (CPU otherwise). Batch size, number of
    loader workers and learning rate are read from the module-level ``args``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('==> Preparing data..')
    transforms_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

    dataset_train = CIFAR10(root='../data', train=True, download=True,
                            transform=transforms_train)

    train_loader = DataLoader(dataset_train, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_worker)

    print('==> Making model..')

    net = pyramidnet()
    net = nn.DataParallel(net)
    net = net.to(device)
    # Count only the parameters that will actually be optimized.
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('The number of parameters of model is', num_params)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    # Alternative optimizer kept for reference:
    # optimizer = optim.SGD(net.parameters(), lr=args.lr,
    #                       momentum=0.9, weight_decay=1e-4)

    train(net, criterion, optimizer, train_loader, device)
Example #3
Source File: train.py    From UDA_pytorch with Apache License 2.0 6 votes vote down vote up
def eval(self, evaluate, model_file, model):
        """Run ``evaluate`` over a data iterator and collect its results.

        When ``model_file`` is truthy, ``self.model`` is switched to eval mode,
        loaded from that file, moved to ``self.device`` (and wrapped in
        ``nn.DataParallel`` if ``self.cfg.data_parallel``); it replaces the
        ``model`` argument and ``self.sup_iter`` is iterated. Otherwise the
        given ``model`` is used on a deep copy of ``self.eval_iter``.

        Args:
            evaluate: callable ``evaluate(model, batch) -> (accuracy, result)``.
            model_file: checkpoint to load, or a falsy value.
            model: model to evaluate when ``model_file`` is falsy.

        Returns:
            List with the ``result`` of every batch, in iteration order.
        """
        if model_file:
            self.model.eval()
            self.load(model_file, None)
            model = self.model.to(self.device)
            if self.cfg.data_parallel:
                model = nn.DataParallel(model)
            progress = tqdm(self.sup_iter)
        else:
            progress = tqdm(deepcopy(self.eval_iter))

        collected = []
        for batch in progress:
            on_device = [tensor.to(self.device) for tensor in batch]

            with torch.no_grad():
                accuracy, result = evaluate(model, on_device)
            collected.append(result)

            progress.set_description('Eval Acc=%5.3f' % accuracy)
        return collected
Example #4
Source File: model.py    From VSE-C with MIT License 6 votes vote down vote up
def get_cnn(self, arch, pretrained):
        """Instantiate the CNN named ``arch`` and parallelize it over GPUs.

        The constructor is looked up by name in ``models.__dict__``. For
        architectures whose name starts with ``alexnet`` or ``vgg`` only the
        ``features`` sub-module is wrapped in ``nn.DataParallel``; every other
        architecture is wrapped as a whole. The result is moved to CUDA.

        Args:
            arch: name of the model constructor inside ``models``.
            pretrained: when truthy, the constructor is called with
                ``pretrained=True``.

        Returns:
            The CUDA-resident (partially or fully) data-parallel model.
        """
        if pretrained:
            message = "=> using pre-trained model '{}'".format(arch)
            print(message)
            model = models.__dict__[arch](pretrained=True)
        else:
            message = "=> creating model '{}'".format(arch)
            print(message)
            model = models.__dict__[arch]()

        if arch.startswith(('alexnet', 'vgg')):
            model.features = nn.DataParallel(model.features)
            model.cuda()
            return model

        return nn.DataParallel(model).cuda()
Example #5
Source File: model_saver.py    From ITDD with MIT License 6 votes vote down vote up
def _save(self, step):
        """Write a checkpoint for ``step`` to ``<base_path>_step_<step>.pt``.

        ``nn.DataParallel`` wrappers around the model and around its generator
        are unwrapped first. Model entries whose key contains ``'generator'``
        are left out of the ``'model'`` state dict; the generator is stored
        separately under ``'generator'``.

        Args:
            step: training step number used in the checkpoint file name.

        Returns:
            Tuple ``(checkpoint, checkpoint_path)``.
        """
        real_model = (self.model.module
                      if isinstance(self.model, nn.DataParallel)
                      else self.model)
        real_generator = (real_model.generator.module
                          if isinstance(real_model.generator, nn.DataParallel)
                          else real_model.generator)

        # The generator is saved on its own, so strip its entries here.
        model_state_dict = {k: v for k, v in real_model.state_dict().items()
                            if 'generator' not in k}
        generator_state_dict = real_generator.state_dict()
        checkpoint = {
            'model': model_state_dict,
            'generator': generator_state_dict,
            'vocab': onmt.inputters.save_fields_to_vocab(self.fields),
            'opt': self.model_opt,
            'optim': self.optim,
        }

        # Build the path once so the logged name and the written file can
        # never drift apart; let the logger do the formatting lazily.
        checkpoint_path = '%s_step_%d.pt' % (self.base_path, step)
        logger.info("Saving checkpoint %s", checkpoint_path)
        torch.save(checkpoint, checkpoint_path)
        return checkpoint, checkpoint_path
Example #6
Source File: trainer.py    From BigGAN-pytorch with Apache License 2.0 6 votes vote down vote up
def build_model(self):
        """Create generator and discriminator, their Adam optimizers and the loss.

        Both networks are moved to ``self.device``. When ``self.parallel`` is
        set they are wrapped in ``nn.DataParallel`` using the comma-separated
        GPU ids in ``self.gpus``. Only parameters with ``requires_grad`` set
        are handed to the optimizers. Both networks are printed at the end.
        """
        # code_dim=100, n_class=1000
        self.G = Generator(self.z_dim, self.n_class, chn=self.chn).to(self.device)
        self.D = Discriminator(self.n_class, chn=self.chn).to(self.device)
        if self.parallel:
            print('use parallel...')
            print('gpuids ', self.gpus)
            device_ids = list(map(int, self.gpus.split(',')))

            self.G = nn.DataParallel(self.G, device_ids=device_ids)
            self.D = nn.DataParallel(self.D, device_ids=device_ids)

        # Custom weight initialisation is currently disabled:
        # self.G.apply(weights_init)
        # self.D.apply(weights_init)

        # Optimizers: skip frozen parameters.
        trainable_g = (param for param in self.G.parameters() if param.requires_grad)
        self.g_optimizer = torch.optim.Adam(trainable_g, self.g_lr, [self.beta1, self.beta2])
        trainable_d = (param for param in self.D.parameters() if param.requires_grad)
        self.d_optimizer = torch.optim.Adam(trainable_d, self.d_lr, [self.beta1, self.beta2])

        # Loss
        self.c_loss = torch.nn.CrossEntropyLoss()

        # Print networks
        for network in (self.G, self.D):
            print(network)
Example #7
Source File: serve.py    From robosat with MIT License 6 votes vote down vote up
def net_from_chkpt_(self):
        """Rebuild the UNet from ``self.checkpoint`` and return it in eval mode.

        The checkpoint's storages are placed on CUDA or CPU according to
        ``self.cuda``. The network is sized from the number of entries in
        ``self.dataset["common"]["classes"]``, moved to ``self.device`` and
        wrapped in ``nn.DataParallel`` before the ``"state_dict"`` entry of
        the checkpoint is loaded into it.

        Returns:
            The data-parallel network, switched to eval mode.
        """
        def place_storage(storage, _location):
            if self.cuda:
                return storage.cuda()
            return storage.cpu()

        # https://github.com/pytorch/pytorch/issues/7178
        chkpt = torch.load(self.checkpoint, map_location=place_storage)

        class_count = len(self.dataset["common"]["classes"])
        net = nn.DataParallel(UNet(class_count).to(self.device))

        if self.cuda:
            torch.backends.cudnn.benchmark = True

        net.load_state_dict(chkpt["state_dict"])
        net.eval()
        return net
Example #8
Source File: run_networks.py    From OpenLongTailRecognition-OLTR with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def init_models(self, optimizer=True):
        """Create every network listed in ``self.config['networks']``.

        Each network is built by the ``create_model`` function of its
        ``def_file`` (called with the values of ``params`` followed by
        ``self.test_mode``), wrapped in ``nn.DataParallel`` and moved to
        ``self.device``. If the entry has a truthy ``'fix'`` value, every
        parameter whose name contains neither ``'modulatedatt'`` nor ``'fc'``
        is frozen. One optimizer parameter group per network is appended to
        ``self.model_optim_params_list``.

        Args:
            optimizer: not used by the code visible in this block.
        """
        networks_defs = self.config['networks']
        self.networks = {}
        self.model_optim_params_list = []

        print("Using", torch.cuda.device_count(), "GPUs.")

        # Parameters whose name contains one of these markers stay trainable.
        trainable_markers = ('modulatedatt', 'fc')

        for key, val in networks_defs.items():

            # Networks
            model_args = [*val['params'].values(), self.test_mode]
            self.networks[key] = source_import(val['def_file']).create_model(*model_args)
            self.networks[key] = nn.DataParallel(self.networks[key]).to(self.device)

            if 'fix' in val and val['fix']:
                print('Freezing feature weights except for modulated attention weights (if exist).')
                for param_name, param in self.networks[key].named_parameters():
                    if not any(marker in param_name for marker in trainable_markers):
                        param.requires_grad = False

            # Optimizer list
            optim_params = val['optim_params']
            param_group = {'params': self.networks[key].parameters()}
            for option in ('lr', 'momentum', 'weight_decay'):
                param_group[option] = optim_params[option]
            self.model_optim_params_list.append(param_group)
Example #9
Source File: vq-wav2vec_featurize.py    From fairseq with MIT License 5 votes vote down vote up
def load_model(self):
        """Load the Wav2Vec model stored in ``self.checkpoint``.

        Storages are kept where deserialization put them (identity
        ``map_location``). ``self.quantize_location`` is taken from the ``vq``
        attribute of the saved args, defaulting to ``"encoder"``. The model is
        put in eval mode, cast to float and moved to CUDA.

        Returns:
            The model, wrapped in ``nn.DataParallel`` when
            ``self.data_parallel`` is truthy.
        """
        def keep_storage(storage, _location):
            return storage

        cp = torch.load(self.checkpoint, map_location=keep_storage)
        saved_args = cp["args"]

        model = Wav2VecModel.build_model(saved_args, None)

        self.quantize_location = getattr(saved_args, "vq", "encoder")

        model.load_state_dict(cp["model"])
        model.eval().float()
        model.cuda()

        return nn.DataParallel(model) if self.data_parallel else model
Example #10
Source File: model.py    From emmental with MIT License 5 votes vote down vote up
def update_task(self, task: EmmentalTask) -> None:
        """Replace an existing task's modules and settings in the MTL network.

        Every module in ``task.module_pool`` overwrites the entry of the same
        key in ``self.module_pool`` (wrapped in ``nn.DataParallel`` when the
        ``dataparallel`` model option is on). The task's flow, loss function,
        output function, scorer and weight are then registered under
        ``task.name`` and the model is moved to its device.

        Args:
          task: A task to update.
        """
        # Update module_pool with the task's modules
        for key, module in task.module_pool.items():
            wrap = Meta.config["model_config"]["dataparallel"]
            self.module_pool[key] = nn.DataParallel(module) if wrap else module

        name = task.name
        # Task flow, loss function and output function
        self.task_flows[name] = task.task_flow
        self.loss_funcs[name] = task.loss_func
        self.output_funcs[name] = task.output_func
        # Scorer and weight
        self.scorers[name] = task.scorer
        self.weights[name] = task.weight

        # Move model to specified device
        self._move_to_device()
Example #11
Source File: base_trainer.py    From centerpose with MIT License 5 votes vote down vote up
def set_device(self, gpus, chunk_sizes, device):
        """Move model, loss and optimizer state to ``device``.

        With ``self.cfg.TRAIN.DISTRIBUTE`` set, the model is moved first and
        then wrapped in ``DistributedDataParallel`` bound to
        ``self.local_rank``; otherwise it is wrapped in ``nn.DataParallel`` and
        then moved. Every tensor held in the optimizer state is transferred
        to ``device`` as well.

        Args:
            gpus: not used by this method.
            chunk_sizes: not used by this method.
            device: target device for model, loss and optimizer state.
        """
        if not self.cfg.TRAIN.DISTRIBUTE:
            self.model = nn.DataParallel(self.model).to(device)
        else:
            self.model = self.model.to(device)
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                find_unused_parameters=True,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
            )
        self.loss.to(device)

        for param_state in self.optimizer.state.values():
            for key, value in param_state.items():
                if not isinstance(value, torch.Tensor):
                    continue
                param_state[key] = value.to(device=device, non_blocking=True)
Example #12
Source File: resnet110_fixup_0_0_1.py    From pipeline with MIT License 5 votes vote down vote up
def __init__(self):
        """Configure the parent class with a data-parallel fixup ResNet-110.

        The network is created with ``fixup_coeff=0.01`` and handed to the
        parent constructor together with the save path, 100 epochs and a
        batch size of 128.
        """
        network = resnet110(use_fixup=True, fixup_coeff=0.01)

        super().__init__(
            model=DataParallel(network),
            model_save_path=MODEL_SAVE_PATH,
            epoch_count=100,
            batch_size=128,
        )
Example #13
Source File: base.py    From pipeline with MIT License 5 votes vote down vote up
def __init__(self, num_layers, fixup_coeff=1, normalization_type=BATCH_NORM, batch_size=128):
        """Configure the parent class with a data-parallel Wide ResNet.

        Args:
            num_layers: depth passed to the network constructor; also used in
                the save path.
            fixup_coeff: coefficient for the fixup variant; only used when
                ``normalization_type`` is not ``self.BATCH_NORM``.
            normalization_type: ``self.BATCH_NORM`` selects
                ``WideResNetBatchNorm``; any other value selects
                ``WideResNetFixup``.
            batch_size: forwarded to the parent constructor.
        """
        uses_batch_norm = normalization_type == self.BATCH_NORM
        if uses_batch_norm:
            network = WideResNetBatchNorm(depth=num_layers, num_classes=10)
            label = "batchnorm"
        else:
            network = WideResNetFixup(depth=num_layers, num_classes=10, fixup_coeff=fixup_coeff)
            label = "fixup_coeff_{}".format(fixup_coeff)

        super().__init__(
            model=DataParallel(network),
            model_save_path=MODEL_SAVE_PATH.format(label, num_layers),
            epoch_count=1,
            batch_size=batch_size,
        )
Example #14
Source File: multi_gpu.py    From MoePhoto with Apache License 2.0 5 votes vote down vote up
def model_convert(path, scale, gpus=1):
    """Build the ``Net2x``/``Net3x``/``Net4x`` model for ``scale`` and load its weights.

    The model is wrapped in ``nn.DataParallel`` over GPUs ``0..gpus-1`` when
    ``gpus > 1`` and CUDA is available, moved to a single GPU when CUDA is
    available, and kept on the CPU otherwise. If ``path`` is an existing file
    its content is loaded as the model's state dict; otherwise a message is
    printed and the freshly initialised model is returned.

    Args:
        path: checkpoint file holding a state dict.
        scale: upscaling factor selecting the network class; 2, 3 or 4.
        gpus: number of GPUs to spread the model over.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``scale`` is not 2, 3 or 4.
    """
    # Validate up front: previously an unsupported scale left ``Net`` unbound
    # and surfaced as an UnboundLocalError at ``Net()``.
    if scale not in (2, 3, 4):
        raise ValueError("unsupported scale {!r}; expected 2, 3 or 4".format(scale))

    if scale == 2:
        from models import Net2x as Net
    elif scale == 3:
        from models import Net3x as Net
    else:
        from models import Net4x as Net
    model = Net()

    use_cuda = torch.cuda.is_available()
    if use_cuda and gpus > 1:
        model = nn.DataParallel(model, device_ids=list(range(gpus))).cuda()
    elif use_cuda:
        model = model.cuda()
    else:
        model = model.cpu()

    # optionally resume from a checkpoint
    if os.path.isfile(path):
        print("=> loading checkpoint '{}'".format(path))
        # Without CUDA, remap the storages to the CPU so that a checkpoint
        # saved from a GPU can still be deserialized.
        weights = torch.load(path, map_location=None if use_cuda else 'cpu')
        model.load_state_dict(weights)
        # NOTE: models previously saved by the multi-GPU loader appear to
        # contain only the weights, without the ``module.`` key prefix, so no
        # key rewriting is done here. NOTE(review): confirm this still holds
        # when the model is wrapped in DataParallel (gpus > 1).
    else:
        print("=> no checkpoint found at '{}'".format(path))
    return model
Example #15
Source File: train.py    From Single-Human-Parsing-LIP with MIT License 5 votes vote down vote up
def build_network(snapshot, backend):
    """Create the data-parallel network for ``backend``, optionally from a snapshot.

    Args:
        snapshot: path to a saved state dict, or ``None`` to start fresh. The
            file's base name must have the form ``<prefix>_<epoch>`` with an
            integer epoch.
        backend: key into ``models`` (matched in lower case).

    Returns:
        Tuple ``(net, epoch)``: the CUDA-resident network and the epoch parsed
        from the snapshot name (0 when no snapshot is given).
    """
    net = nn.DataParallel(models[backend.lower()]())
    epoch = 0
    if snapshot is not None:
        _prefix, epoch_token = os.path.basename(snapshot).split('_')
        epoch = int(epoch_token)
        state = torch.load(snapshot)
        net.load_state_dict(state)
        print("Snapshot for epoch {} loaded from {}".format(epoch, snapshot))
    return net.cuda(), epoch
Example #16
Source File: SRGAN_model.py    From BasicSR with Apache License 2.0 5 votes vote down vote up
def print_network(self):
        """Log structure and parameter count of the G, D and F networks.

        The generator is always reported. The discriminator is reported only
        in training mode, and the perceptual network F only when, in addition,
        ``self.cri_fea`` is set. For a network wrapped in ``DataParallel`` or
        ``DistributedDataParallel`` both the wrapper and the wrapped class
        names are shown. Messages are emitted only when ``self.rank <= 0``.
        """
        parallel_wrappers = (nn.DataParallel, DistributedDataParallel)

        def report(net, template):
            # Describe one network and log it with the given headline template.
            description, param_count = self.get_network_description(net)
            wrapper_name = net.__class__.__name__
            if isinstance(net, parallel_wrappers):
                structure = '{} - {}'.format(wrapper_name, net.module.__class__.__name__)
            else:
                structure = '{}'.format(wrapper_name)
            if self.rank <= 0:
                logger.info(template.format(structure, param_count))
                logger.info(description)

        # Generator
        report(self.netG, 'Network G structure: {}, with parameters: {:,d}')
        if not self.is_train:
            return

        # Discriminator
        report(self.netD, 'Network D structure: {}, with parameters: {:,d}')

        if self.cri_fea:  # F, Perceptual Network
            report(self.netF, 'Network F structure: {}, with parameters: {:,d}')
Example #17
Source File: resnet110_fixup_0_1.py    From pipeline with MIT License 5 votes vote down vote up
def __init__(self):
        """Configure the parent class with a data-parallel fixup ResNet-110.

        The network is created with ``fixup_coeff=0.1`` and handed to the
        parent constructor together with the save path, 100 epochs and a
        batch size of 128.
        """
        network = resnet110(use_fixup=True, fixup_coeff=0.1)

        super().__init__(
            model=DataParallel(network),
            model_save_path=MODEL_SAVE_PATH,
            epoch_count=100,
            batch_size=128,
        )
Example #18
Source File: checkpointing.py    From visdial-challenge-starter-pytorch with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def _model_state_dict(self):
        """Return the model's state dict, looking through a DataParallel wrapper.

        When ``self.model`` is an ``nn.DataParallel`` instance the state dict
        of the wrapped ``module`` is returned, so the keys carry no
        ``module.`` prefix.
        """
        model = self.model
        unwrapped = model.module if isinstance(model, nn.DataParallel) else model
        return unwrapped.state_dict()
Example #19
Source File: test.py    From ScanSSD with MIT License 5 votes vote down vote up
def test_gtdb(args):
    """Load a trained SSD network and evaluate it on the GTDB test split.

    Args:
        args: options object; this function reads ``cuda``, ``cfg``,
            ``model_type``, ``trained_model``, ``test_data`` and
            ``visual_threshold`` from it and also passes it on to the
            helpers it calls.
    """

    gpu_id = 0
    if args.cuda:
        gpu_id = helpers.get_freer_gpu()
        torch.cuda.set_device(gpu_id)

    # load net
    num_classes = 2 # +1 background

    # initialize SSD
    net = build_ssd(args, 'test', exp_cfg[args.cfg], gpu_id, args.model_type, num_classes)

    logging.debug(net)
    # NOTE(review): this move happens even when ``args.cuda`` is false, in
    # which case ``gpu_id`` is still 0 -- confirm CPU-only runs are supported.
    net.to(gpu_id)
    net = nn.DataParallel(net)
    # Storages tagged 'cuda:1' in the checkpoint are remapped to 'cuda:0'.
    net.load_state_dict(torch.load(args.trained_model, map_location={'cuda:1':'cuda:0'}))
    net.eval()
    logging.debug('Finished loading model!')

    dataset = GTDBDetection(args, args.test_data, split='test',
                            transform=BaseTransform(args.model_type, (246,246,246)),
                            target_transform=GTDBAnnotationTransform())

    if args.cuda:
        net = net.to(gpu_id)
        cudnn.benchmark = True

    # evaluation
    test_net_batch(args, net, gpu_id, dataset,
                   BaseTransform(args.model_type, (246,246,246)),
                   thresh=args.visual_threshold)
Example #20
Source File: metal_model.py    From metal with Apache License 2.0 5 votes vote down vote up
def forward(self, X, task_names):
        """Compute the head output of every requested task.

        For each task the input is passed through that task's input module,
        middle module and head module, in that order. Results are cached per
        module object, so a module shared by several tasks is evaluated only
        once per call.

        Args:
            X: a [batch_size, ...] batch from a DataLoader
            task_names: iterable with the names of the tasks to evaluate
        Returns:
            output_dict: {task_name (str): output (Tensor)}
        """
        batch = move_to_device(X, self.config["device"])
        cache = {}
        # TODO: Replace this naive caching scheme with a more intelligent and feature-
        # complete approach where arbitrary DAGs of modules are specified and we only
        # cache things that will be reused by another task
        for task_name in task_names:
            # ``.module`` reaches the module behind each DataParallel wrapper
            input_module = self.input_modules[task_name].module
            if input_module not in cache:
                cache[input_module] = input_module(batch)

            middle_module = self.middle_modules[task_name].module
            if middle_module not in cache:
                cache[middle_module] = middle_module(cache[input_module])

            head_module = self.head_modules[task_name].module
            if head_module not in cache:
                cache[head_module] = head_module(cache[middle_module])

        return {name: cache[self.head_modules[name].module] for name in task_names}
Example #21
Source File: metal_model.py    From metal with Apache License 2.0 5 votes vote down vote up
def _build(self, tasks):
        """Register the input, middle and head modules of every task.

        Each module is wrapped in ``nn.DataParallel`` and stored in an
        ``nn.ModuleDict`` keyed by task name. The tasks' ``loss_hat_func`` and
        ``output_hat_func`` callables are collected in plain dicts.

        Args:
            tasks: re-iterable collection of task objects.
        """
        # TODO: Allow more flexible specification of network structure
        def parallel_modules(attribute):
            # One DataParallel-wrapped module per task, keyed by task name.
            return nn.ModuleDict(
                {task.name: nn.DataParallel(getattr(task, attribute)) for task in tasks}
            )

        self.input_modules = parallel_modules("input_module")
        self.middle_modules = parallel_modules("middle_module")
        self.head_modules = parallel_modules("head_module")

        self.loss_hat_funcs = {t.name: t.loss_hat_func for t in tasks}
        self.output_hat_funcs = {t.name: t.output_hat_func for t in tasks}
Example #22
Source File: resnet50_fixup_128.py    From pipeline with MIT License 5 votes vote down vote up
def __init__(self, model_save_path=MODEL_SAVE_PATH):
        """Configure the parent class with a data-parallel ResNet-50.

        Mixup is enabled, the batch size is 128 and the learning rate 0.1.

        Args:
            model_save_path: forwarded to the parent constructor.
        """
        super().__init__(
            model=DataParallel(resnet50()),
            model_save_path=model_save_path,
            use_mixup=True,
            batch_size=128,
            learning_rate=0.1,
        )
Example #23
Source File: eval.py    From Single-Human-Parsing-LIP with MIT License 5 votes vote down vote up
def build_network(snapshot, backend):
    """Create the data-parallel network for ``backend``, optionally from a snapshot.

    Args:
        snapshot: path to a saved state dict, or ``None`` to start fresh. The
            file's base name must have the form ``<prefix>_<epoch>``, where
            the epoch part is an integer or the literal ``last``.
        backend: key into ``models`` (matched in lower case).

    Returns:
        Tuple ``(net, epoch)``: the CUDA-resident network and the epoch taken
        from the snapshot name -- an int, the string ``'last'``, or 0 when no
        snapshot is given.
    """
    net = nn.DataParallel(models[backend.lower()]())
    epoch = 0
    if snapshot is not None:
        _prefix, epoch_token = os.path.basename(snapshot).split('_')
        epoch = epoch_token if epoch_token == 'last' else int(epoch_token)
        state = torch.load(snapshot)
        net.load_state_dict(state)
        logging.info("Snapshot for epoch {} loaded from {}".format(epoch, snapshot))
    return net.cuda(), epoch
Example #24
Source File: inference.py    From Single-Human-Parsing-LIP with MIT License 5 votes vote down vote up
def build_network(snapshot, backend):
    """Create the data-parallel network for ``backend``, optionally from a snapshot.

    Args:
        snapshot: path to a saved state dict, or ``None`` to start fresh. The
            file's base name must have the form ``<prefix>_<epoch>``, where
            the epoch part is an integer or the literal ``last``.
        backend: key into ``models`` (matched in lower case).

    Returns:
        Tuple ``(net, epoch)``: the CUDA-resident network and the epoch taken
        from the snapshot name -- an int, the string ``'last'``, or 0 when no
        snapshot is given.
    """
    net = nn.DataParallel(models[backend.lower()]())
    epoch = 0
    if snapshot is not None:
        _prefix, epoch_token = os.path.basename(snapshot).split('_')
        epoch = epoch_token if epoch_token == 'last' else int(epoch_token)
        state = torch.load(snapshot)
        net.load_state_dict(state)
        logging.info("Snapshot for epoch {} loaded from {}".format(epoch, snapshot))
    return net.cuda(), epoch
Example #25
Source File: __init__.py    From OISR-PyTorch with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def __init__(self, args, ckp):
        """Build the underlying network from ``args`` and load its weights.

        The network comes from ``make_model(args)`` of the module
        ``model.<args.model in lower case>``. It is moved to the CPU or to
        CUDA depending on ``args.cpu``, converted to half precision when
        ``args.precision == 'half'`` and wrapped in ``nn.DataParallel`` when
        running on more than one GPU. Weights are then loaded through
        ``self.load`` and the model is printed to ``ckp.log_file``.

        Args:
            args: options object with the attributes read below.
            ckp: checkpoint helper providing ``get_path`` and ``log_file``.
        """
        super(Model, self).__init__()
        print('Making model...')

        # Run-time options mirrored onto the instance.
        self.scale = args.scale
        self.idx_scale = 0
        self.input_large = (args.model == 'VDSR')
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        device_name = 'cpu' if args.cpu else 'cuda'
        self.device = torch.device(device_name)
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models

        # Build the network and place it on the target device.
        model_module = import_module('model.' + args.model.lower())
        self.model = model_module.make_model(args).to(self.device)
        if args.precision == 'half':
            self.model.half()

        multi_gpu = (not args.cpu) and args.n_GPUs > 1
        if multi_gpu:
            self.model = nn.DataParallel(self.model, range(args.n_GPUs))

        self.load(
            ckp.get_path('model'),
            pre_train=args.pre_train,
            resume=args.resume,
            cpu=args.cpu,
        )
        print(self.model, file=ckp.log_file)
Example #26
Source File: __init__.py    From OISR-PyTorch with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def __init__(self, args, ckp):
        """Build the underlying network from ``args`` and load its weights.

        The network comes from ``make_model(args)`` of the module
        ``model.<args.model in lower case>``. It is moved to the CPU or to
        CUDA depending on ``args.cpu``, converted to half precision when
        ``args.precision == 'half'`` and wrapped in ``nn.DataParallel`` when
        running on more than one GPU. Weights are then loaded through
        ``self.load`` and the model is printed to ``ckp.log_file``.

        Args:
            args: options object with the attributes read below.
            ckp: checkpoint helper providing ``get_path`` and ``log_file``.
        """
        super(Model, self).__init__()
        print('Making model...')

        # Run-time options mirrored onto the instance.
        self.scale = args.scale
        self.idx_scale = 0
        self.input_large = (args.model == 'VDSR')
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        device_name = 'cpu' if args.cpu else 'cuda'
        self.device = torch.device(device_name)
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models

        # Build the network and place it on the target device.
        model_module = import_module('model.' + args.model.lower())
        self.model = model_module.make_model(args).to(self.device)
        if args.precision == 'half':
            self.model.half()

        multi_gpu = (not args.cpu) and args.n_GPUs > 1
        if multi_gpu:
            self.model = nn.DataParallel(self.model, range(args.n_GPUs))

        self.load(
            ckp.get_path('model'),
            pre_train=args.pre_train,
            resume=args.resume,
            cpu=args.cpu,
        )
        print(self.model, file=ckp.log_file)
Example #27
Source File: __init__.py    From OISR-PyTorch with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def __init__(self, args, ckp):
        """Build the underlying network from ``args`` and load its weights.

        The network comes from ``make_model(args)`` of the module
        ``model.<args.model in lower case>``. It is moved to the CPU or to
        CUDA depending on ``args.cpu``, converted to half precision when
        ``args.precision == 'half'`` and wrapped in ``nn.DataParallel`` when
        running on more than one GPU. Weights are then loaded through
        ``self.load`` and the model is printed to ``ckp.log_file``.

        Args:
            args: options object with the attributes read below.
            ckp: checkpoint helper providing ``get_path`` and ``log_file``.
        """
        super(Model, self).__init__()
        print('Making model...')

        # Run-time options mirrored onto the instance.
        self.scale = args.scale
        self.idx_scale = 0
        self.input_large = (args.model == 'VDSR')
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        device_name = 'cpu' if args.cpu else 'cuda'
        self.device = torch.device(device_name)
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models

        # Build the network and place it on the target device.
        model_module = import_module('model.' + args.model.lower())
        self.model = model_module.make_model(args).to(self.device)
        if args.precision == 'half':
            self.model.half()

        multi_gpu = (not args.cpu) and args.n_GPUs > 1
        if multi_gpu:
            self.model = nn.DataParallel(self.model, range(args.n_GPUs))

        self.load(
            ckp.get_path('model'),
            pre_train=args.pre_train,
            resume=args.resume,
            cpu=args.cpu,
        )
        print(self.model, file=ckp.log_file)
Example #28
Source File: __init__.py    From OISR-PyTorch with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def __init__(self, args, ckp):
        """Build the underlying network from ``args`` and load its weights.

        The network comes from ``make_model(args)`` of the module
        ``model.<args.model in lower case>``. It is moved to the CPU or to
        CUDA depending on ``args.cpu``, converted to half precision when
        ``args.precision == 'half'`` and wrapped in ``nn.DataParallel`` when
        running on more than one GPU. Weights are then loaded through
        ``self.load`` and the model is printed to ``ckp.log_file``.

        Args:
            args: options object with the attributes read below.
            ckp: checkpoint helper providing ``get_path`` and ``log_file``.
        """
        super(Model, self).__init__()
        print('Making model...')

        # Run-time options mirrored onto the instance.
        self.scale = args.scale
        self.idx_scale = 0
        self.input_large = (args.model == 'VDSR')
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        device_name = 'cpu' if args.cpu else 'cuda'
        self.device = torch.device(device_name)
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models

        # Build the network and place it on the target device.
        model_module = import_module('model.' + args.model.lower())
        self.model = model_module.make_model(args).to(self.device)
        if args.precision == 'half':
            self.model.half()

        multi_gpu = (not args.cpu) and args.n_GPUs > 1
        if multi_gpu:
            self.model = nn.DataParallel(self.model, range(args.n_GPUs))

        self.load(
            ckp.get_path('model'),
            pre_train=args.pre_train,
            resume=args.resume,
            cpu=args.cpu,
        )
        print(self.model, file=ckp.log_file)
Example #29
Source File: models.py    From open-solution-salt-identification with MIT License 5 votes vote down vote up
def load(self, filepath):
        """Load a state dict from ``filepath`` into the data-parallel model.

        The model is switched to eval mode and wrapped in ``nn.DataParallel``
        if it is not wrapped already. With CUDA available the weights are
        loaded while the model sits on the CPU and the model is moved to CUDA
        afterwards; without CUDA the storages are kept as deserialized.

        Args:
            filepath: file containing the saved state dict.

        Returns:
            ``self``, to allow call chaining.
        """
        self.model.eval()

        already_parallel = isinstance(self.model, nn.DataParallel)
        if not already_parallel:
            self.model = nn.DataParallel(self.model)

        if not torch.cuda.is_available():
            self.model.load_state_dict(torch.load(filepath, map_location=lambda storage, loc: storage))
            return self

        self.model.cpu()
        self.model.load_state_dict(torch.load(filepath))
        self.model = self.model.cuda()
        return self
Example #30
Source File: models.py    From open-solution-salt-identification with MIT License 5 votes vote down vote up
def fit(self, datagen, validation_datagen=None, meta_valid=None):
        """Train ``self.model`` for ``self.training_config['epochs']`` epochs.

        The model weights are initialised, the model is wrapped in
        ``nn.DataParallel`` (unless it already is) and moved to CUDA when
        available. The registered callbacks are invoked around the whole
        run, every epoch and every batch.

        Args:
            datagen: a ``(batch_gen, steps)`` pair; ``batch_gen`` is iterated
                once per epoch.
            validation_datagen: forwarded to ``self.callbacks.set_params``.
            meta_valid: forwarded to ``self.callbacks.set_params``.

        Returns:
            ``self``, to allow call chaining.
        """
        self._initialize_model_weights()

        if not isinstance(self.model, nn.DataParallel):
            self.model = nn.DataParallel(self.model)

        if torch.cuda.is_available():
            self.model = self.model.cuda()

        self.callbacks.set_params(self, validation_datagen=validation_datagen, meta_valid=meta_valid)
        self.callbacks.on_train_begin()

        batch_gen, steps = datagen
        for epoch_id in range(self.training_config['epochs']):
            self.callbacks.on_epoch_begin()
            for batch_id, data in enumerate(batch_gen):
                self.callbacks.on_batch_begin()
                metrics = self._fit_loop(data)
                self.callbacks.on_batch_end(metrics=metrics)
                # NOTE(review): ``batch_id`` starts at 0 and this check runs
                # after the batch was fitted, so the break fires only once
                # ``steps + 1`` batches have been processed -- confirm whether
                # ``steps - 1`` was intended.
                if batch_id == steps:
                    break
            self.callbacks.on_epoch_end()
            # Callbacks may request that training stop early.
            if self.callbacks.training_break():
                break
        self.callbacks.on_train_end()
        return self