Python torch.serialization Examples

The following are 3 code examples that use torch.serialization. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source file: serial.py (from CrypTen, MIT License) · 5 votes
def restricted_loads(s):
    """Deserialize a pickled payload through the restricted unpickler.

    Args:
        s: The pickled payload as a bytes-like object (it is wrapped in
            ``io.BytesIO``).

    Returns:
        The object produced by ``RestrictedUnpickler(...).load()``. If that
        object is a tensor or a ``torch.nn.Module``, it is first passed to
        ``_check_hooks_are_valid`` together with the attribute name
        ``"_backward_hooks"`` (helper defined elsewhere; presumably it rejects
        disallowed hooks -- verify against its definition).
    """
    payload = io.BytesIO(s)
    obj = RestrictedUnpickler(payload).load()

    # Tensors and nn.Modules are the only result types whose backward hooks
    # are vetted here.
    is_hookable = torch.is_tensor(obj) or isinstance(obj, torch.nn.Module)
    if is_hookable:
        _check_hooks_are_valid(obj, "_backward_hooks")
    return obj


# Adapt torch.load to use RestrictedUnpickler - patched for torch.storage._load_from_bytes
# (Adapted from https://github.com/pytorch/pytorch/blob/master/torch/serialization.py#L602-L773) 
Example #2
Source file: pytorch_bind.py (from trains, Apache License 2.0) · 5 votes
def _patch_model_io():
        """Wrap PyTorch's save/load entry points with PatchPyTorchModelIO hooks.

        Does nothing if patching was already performed, or if ``torch`` has
        not been imported by the application (it is never imported here just
        to be patched). The ``__patched`` flag is set *before* any patching is
        attempted, so a failed attempt is not retried.

        Wrapped entry points:
            * ``torch.save`` / ``torch.load``
            * ``torch.serialization._save`` / ``_legacy_save`` (save handler)
            * ``torch.serialization._load`` / ``_legacy_load`` (load handler)
              -- each private helper only if the installed torch version
              defines it.

        Patching is best-effort: any exception raised while patching is
        swallowed so it can never break the host program.
        """
        if PatchPyTorchModelIO.__patched:
            return

        # Only patch a torch the application already loaded; do not import it ourselves.
        if 'torch' not in sys.modules:
            return

        PatchPyTorchModelIO.__patched = True

        # noinspection PyBroadException
        try:
            import torch
            torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
            torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)

            # no need to worry about recursive calls, _patched_call takes care of that
            if hasattr(torch, 'serialization'):
                serialization = torch.serialization
                # Private helpers differ between torch versions; wrap whichever exist,
                # in the same order as the public entry points above.
                private_hooks = (
                    ('_save', PatchPyTorchModelIO._save),
                    ('_load', PatchPyTorchModelIO._load),
                    ('_legacy_save', PatchPyTorchModelIO._save),
                    ('_legacy_load', PatchPyTorchModelIO._load),
                )
                for attr_name, handler in private_hooks:
                    if hasattr(serialization, attr_name):
                        setattr(
                            serialization,
                            attr_name,
                            _patched_call(getattr(serialization, attr_name), handler),
                        )
        except Exception:
            # Deliberately best-effort (this also covers ImportError): a failure
            # to patch must not affect the user's program.
            pass
Example #3
Source file: dynamic_simultaneous_translation.py (from attn2d, MIT License) · 5 votes
def build_model(self, args):
        """Build the model and optionally initialise it from a pretrained checkpoint.

        Args:
            args: Parsed task arguments. ``args.pretrained`` is either ``None``
                or a path to a checkpoint readable by ``torch.load`` whose
                loaded object has a ``'model'`` entry.

        Returns:
            The model produced by the parent class's ``build_model``. When
            ``args.pretrained`` is set, the checkpoint's ``'model'`` state is
            passed with the model to ``self.adapt_state`` before returning.

        Raises:
            ValueError: If ``args.pretrained`` is set but the path does not exist.
        """
        model = super().build_model(args)
        if args.pretrained is not None:  # load pretrained model
            if not os.path.exists(args.pretrained):
                # Single-line literal: a backslash continuation inside the string
                # would copy the next line's indentation into the message.
                raise ValueError(
                    'Could not load pretrained weights - from {}'.format(
                        args.pretrained))
            # Restore every storage onto the CPU regardless of the device it was
            # saved from (same as default_restore_location(storage, 'cpu')).
            saved_state = torch.load(args.pretrained, map_location='cpu')
            self.adapt_state(saved_state['model'], model)

        return model