Python torch.set_rng_state() Examples

The following are 30 code examples of torch.set_rng_state(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source file: checkpoint.py, from torchgpipe (Apache License 2.0), 6 votes
def restore_rng_states(device: torch.device,
                       rng_states: Deque[RNGStates],
                       ) -> Generator[None, None, None]:
    """Replay, inside this context, the RNG states that :func:`save_rng_states`
    captured; used by :meth:`Recompute.backward`.

    The rightmost ``(cpu_state, gpu_state)`` entry is popped from
    ``rng_states`` and installed inside a :func:`torch.random.fork_rng`
    scope, so the RNG state that was active before entering is put back on
    exit.

    .. seealso:: :ref:`Referential Transparency`

    """
    saved_cpu, saved_gpu = rng_states.pop()

    # Fork the generator of ``device`` only when it is a CUDA device;
    # fork_rng always forks the CPU generator.
    forked_devices: List[torch.device] = [device] if device.type == 'cuda' else []

    with torch.random.fork_rng(forked_devices):
        torch.set_rng_state(saved_cpu)
        if saved_gpu is not None:
            torch.cuda.set_rng_state(saved_gpu, device)
        yield
Example #2
Source file: checkpoint.py, from pytorch-memonger (MIT License), 6 votes
def backward(ctx, *args):
        """Recompute the checkpointed forward pass and backpropagate ``args`` through it.

        The forward function stored on ``ctx`` is re-run on detached copies of
        the saved inputs, optionally under the RNG state recorded during
        forward. The gradients that accumulate on those copies are returned.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        saved_inputs = ctx.saved_tensors

        # Fork the surrounding RNG state, replay the state that was active
        # during forward (so random ops repeat identically), and let fork_rng
        # restore the surrounding state when the block exits.
        # NOTE(review): ``preserve_rng_state`` is resolved outside this block.
        if ctx.had_cuda_in_fwd:
            fork_devices = [torch.cuda.current_device()]
        else:
            fork_devices = []
        with torch.random.fork_rng(devices=fork_devices, enabled=preserve_rng_state):
            if preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_rng_state)
                if ctx.had_cuda_in_fwd:
                    torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
            replay_inputs = detach_variable(saved_inputs)
            with torch.enable_grad():
                recomputed = ctx.run_function(*replay_inputs)

        # A single tensor output is normalized to a 1-tuple for autograd.backward.
        recomputed = (recomputed,) if isinstance(recomputed, torch.Tensor) else recomputed
        torch.autograd.backward(recomputed, args)
        input_grads = tuple(t.grad for t in replay_inputs)
        # Leading None presumably fills the gradient slot of the non-tensor
        # forward argument (run_function) -- confirm against forward().
        return (None,) + input_grads
Example #3
Source file: runner.py, from skeltorch (MIT License), 6 votes
def load_states(self, epoch, device):
        """Restores the experiment state stored in the checkpoint of ``epoch``.

        Restores, in order: the model and optimizer state dicts, the Python,
        NumPy and torch RNG states, the counters, the loss histories, and
        finally whatever ``load_states_others`` reads.

        Args:
            epoch (int): ``--epoch`` command argument.
            device (str): ``--device`` command argument.
        """
        checkpoint_data = self.experiment.checkpoint_load(epoch, device)

        # A DataParallel wrapper keeps the actual network in ``.module``.
        wrapped = isinstance(self.model, torch.nn.DataParallel)
        (self.model.module if wrapped else self.model).load_state_dict(checkpoint_data['model'])
        self.optimizer.load_state_dict(checkpoint_data['optimizer'])

        # random_states layout: (python, numpy, torch CPU, torch CUDA or None).
        saved_rng = checkpoint_data['random_states']
        random.setstate(saved_rng[0])
        np.random.set_state(saved_rng[1])
        torch.set_rng_state(saved_rng[2].cpu())
        cuda_restorable = torch.cuda.is_available() and saved_rng[3] is not None
        if cuda_restorable:
            torch.cuda.set_rng_state(saved_rng[3].cpu())

        self.counters = checkpoint_data['counters']
        if 'losses' not in checkpoint_data:
            self.losses_epoch = checkpoint_data['losses_epoch']
            self.losses_it = checkpoint_data['losses_it']
        else:
            # Compatibility purposes until next release: older checkpoints
            # store only the per-epoch losses, under the 'losses' key.
            self.losses_epoch = checkpoint_data['losses']
        self.load_states_others(checkpoint_data)
Example #4
Source file: test_revop.py, from memcnn (MIT License), 6 votes
def test_get_set_device_states(device, enabled):
    """Replaying a saved RNG state must reproduce the same random draw.

    When ``enabled`` the saved state (CPU via ``torch.set_rng_state``, CUDA
    via ``set_device_states``) is restored inside a forked RNG scope, so the
    second draw must equal the first; otherwise the draws must differ.
    """
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('This test requires a GPU to be available')
    shape = (1, 1, 10, 10)
    x = torch.ones(shape, device=device)

    devices, states = get_device_states(x)
    expected_count = 1 if device == 'cuda' else 0
    assert len(states) == expected_count
    assert len(devices) == expected_count

    saved_cpu_state = torch.get_rng_state()
    first = x * torch.rand(shape, device=device)
    with torch.random.fork_rng(devices=devices, enabled=True):
        if enabled and device == 'cpu':
            torch.set_rng_state(saved_cpu_state)
        elif enabled:
            set_device_states(devices=devices, states=states)
        second = x * torch.rand(shape, device=device)
    assert torch.equal(first, second) == enabled
Example #5
Source file: env.py, from detectron2 (Apache License 2.0), 6 votes
def seed_all_rng(seed=None):
    """
    Set the random seed for the RNG in torch, numpy and python.

    Args:
        seed (int): the seed to use. It must be accepted by
            ``np.random.seed`` (an integer in ``[0, 2**32 - 1]``). If None, a
            seed is derived from the process id, the current time and two
            bytes of ``os.urandom``; it is logged so the run can be reproduced.
            The derived seed is not cryptographically strong.
    """
    if seed is None:
        seed = (
            os.getpid()
            + int(datetime.now().strftime("%S%f"))
            + int.from_bytes(os.urandom(2), "big")
        )
        logger = logging.getLogger(__name__)
        logger.info("Using a generated random seed %s", seed)
    np.random.seed(seed)
    # torch.manual_seed reseeds the default generator directly (on all
    # devices), so no get_state()/set_rng_state() round-trip is needed.
    torch.manual_seed(seed)
    random.seed(seed)


# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path 
Example #6
Source file: checkpoint.py, from torchgpipe (Apache License 2.0), 6 votes
def restore_rng_states(device: torch.device,
                       rng_states: Deque[RNGStates],
                       ) -> Generator[None, None, None]:
    """Replay, inside this context, the RNG states that :func:`save_rng_states`
    captured; used by :meth:`Recompute.backward`.

    The rightmost ``(cpu_state, gpu_state)`` entry is popped from
    ``rng_states`` and installed inside a :func:`torch.random.fork_rng`
    scope, so the RNG state that was active before entering is put back on
    exit.

    .. seealso:: :ref:`Referential Transparency`

    """
    saved_cpu, saved_gpu = rng_states.pop()

    # Fork the generator of ``device`` only when it is a CUDA device;
    # fork_rng always forks the CPU generator.
    forked_devices: List[torch.device] = [device] if device.type == 'cuda' else []

    with torch.random.fork_rng(forked_devices):
        torch.set_rng_state(saved_cpu)
        if saved_gpu is not None:
            torch.cuda.set_rng_state(saved_gpu, device)
        yield
Example #7
Source file: checkpoint.py, from torchgpipe (Apache License 2.0), 6 votes
def restore_rng_states(device: torch.device,
                       rng_states: Deque[RNGStates],
                       ) -> Generator[None, None, None]:
    """Replay, inside this context, the RNG states that :func:`save_rng_states`
    captured; used by :meth:`Recompute.backward`.

    The rightmost ``(cpu_state, gpu_state)`` entry is popped from
    ``rng_states`` and installed inside a :func:`torch.random.fork_rng`
    scope, so the RNG state that was active before entering is put back on
    exit.

    .. seealso:: :ref:`Referential Transparency`

    """
    saved_cpu, saved_gpu = rng_states.pop()

    # Fork the generator of ``device`` only when it is a CUDA device;
    # fork_rng always forks the CPU generator.
    forked_devices: List[torch.device] = [device] if device.type == 'cuda' else []

    with torch.random.fork_rng(forked_devices):
        torch.set_rng_state(saved_cpu)
        if saved_gpu is not None:
            torch.cuda.set_rng_state(saved_gpu, device)
        yield
Example #8
Source file: checkpoint.py, from torchgpipe (Apache License 2.0), 6 votes
def restore_rng_states(device: torch.device,
                       rng_states: Deque[RNGStates],
                       ) -> Generator[None, None, None]:
    """Replay, inside this context, the RNG states that :func:`save_rng_states`
    captured; used by :meth:`Recompute.backward`.

    The rightmost ``(cpu_state, gpu_state)`` entry is popped from
    ``rng_states`` and installed inside a :func:`torch.random.fork_rng`
    scope, so the RNG state that was active before entering is put back on
    exit.

    .. seealso:: :ref:`Referential Transparency`

    """
    saved_cpu, saved_gpu = rng_states.pop()

    # Fork the generator of ``device`` only when it is a CUDA device;
    # fork_rng always forks the CPU generator.
    forked_devices: List[torch.device] = [device] if device.type == 'cuda' else []

    with torch.random.fork_rng(forked_devices):
        torch.set_rng_state(saved_cpu)
        if saved_gpu is not None:
            torch.cuda.set_rng_state(saved_gpu, device)
        yield
Example #9
Source file: checkpoint.py, from torchgpipe (Apache License 2.0), 6 votes
def restore_rng_states(device: torch.device,
                       rng_states: Deque[RNGStates],
                       ) -> Generator[None, None, None]:
    """Replay, inside this context, the RNG states that :func:`save_rng_states`
    captured; used by :meth:`Recompute.backward`.

    The rightmost ``(cpu_state, gpu_state)`` entry is popped from
    ``rng_states`` and installed inside a :func:`torch.random.fork_rng`
    scope, so the RNG state that was active before entering is put back on
    exit.

    .. seealso:: :ref:`Referential Transparency`

    """
    saved_cpu, saved_gpu = rng_states.pop()

    # Fork the generator of ``device`` only when it is a CUDA device;
    # fork_rng always forks the CPU generator.
    forked_devices: List[torch.device] = [device] if device.type == 'cuda' else []

    with torch.random.fork_rng(forked_devices):
        torch.set_rng_state(saved_cpu)
        if saved_gpu is not None:
            torch.cuda.set_rng_state(saved_gpu, device)
        yield
Example #10
Source file: checkpoint.py, from torchgpipe (Apache License 2.0), 6 votes
def restore_rng_states(device: torch.device,
                       rng_states: Deque[RNGStates],
                       ) -> Generator[None, None, None]:
    """Replay, inside this context, the RNG states that :func:`save_rng_states`
    captured; used by :meth:`Recompute.backward`.

    The rightmost ``(cpu_state, gpu_state)`` entry is popped from
    ``rng_states`` and installed inside a :func:`torch.random.fork_rng`
    scope, so the RNG state that was active before entering is put back on
    exit.

    .. seealso:: :ref:`Referential Transparency`

    """
    saved_cpu, saved_gpu = rng_states.pop()

    # Fork the generator of ``device`` only when it is a CUDA device;
    # fork_rng always forks the CPU generator.
    forked_devices: List[torch.device] = [device] if device.type == 'cuda' else []

    with torch.random.fork_rng(forked_devices):
        torch.set_rng_state(saved_cpu)
        if saved_gpu is not None:
            torch.cuda.set_rng_state(saved_gpu, device)
        yield
Example #11
Source file: checkpoint.py, from torchgpipe (Apache License 2.0), 6 votes
def restore_rng_states(device: torch.device,
                       rng_states: Deque[RNGStates],
                       ) -> Generator[None, None, None]:
    """Replay, inside this context, the RNG states that :func:`save_rng_states`
    captured; used by :meth:`Recompute.backward`.

    The rightmost ``(cpu_state, gpu_state)`` entry is popped from
    ``rng_states`` and installed inside a :func:`torch.random.fork_rng`
    scope, so the RNG state that was active before entering is put back on
    exit.

    .. seealso:: :ref:`Referential Transparency`

    """
    saved_cpu, saved_gpu = rng_states.pop()

    # Fork the generator of ``device`` only when it is a CUDA device;
    # fork_rng always forks the CPU generator.
    forked_devices: List[torch.device] = [device] if device.type == 'cuda' else []

    with torch.random.fork_rng(forked_devices):
        torch.set_rng_state(saved_cpu)
        if saved_gpu is not None:
            torch.cuda.set_rng_state(saved_gpu, device)
        yield
Example #12
Source file: env.py, from detectron2 (Apache License 2.0), 6 votes
def seed_all_rng(seed=None):
    """
    Set the random seed for the RNG in torch, numpy and python.

    Args:
        seed (int): the seed to use. It must be accepted by
            ``np.random.seed`` (an integer in ``[0, 2**32 - 1]``). If None, a
            seed is derived from the process id, the current time and two
            bytes of ``os.urandom``; it is logged so the run can be reproduced.
            The derived seed is not cryptographically strong.
    """
    if seed is None:
        seed = (
            os.getpid()
            + int(datetime.now().strftime("%S%f"))
            + int.from_bytes(os.urandom(2), "big")
        )
        logger = logging.getLogger(__name__)
        logger.info("Using a generated random seed %s", seed)
    np.random.seed(seed)
    # torch.manual_seed reseeds the default generator directly (on all
    # devices), so no get_state()/set_rng_state() round-trip is needed.
    torch.manual_seed(seed)
    random.seed(seed)


# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path 
Example #13
Source file: fakedata.py, from Global-Second-order-Pooling-Convolutional-Networks (MIT License), 6 votes
def __getitem__(self, index):
        """Return a synthetic sample that is reproducible for ``index``.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        # Create a random image that is consistent with the index id by
        # temporarily reseeding the global CPU RNG. The caller's CPU RNG state
        # is restored in ``finally`` so it survives even if generation raises.
        # NOTE(review): torch.manual_seed also reseeds CUDA generators, which
        # are not restored here; consider a dedicated torch.Generator.
        rng_state = torch.get_rng_state()
        try:
            torch.manual_seed(index + self.random_offset)
            img = torch.randn(*self.image_size)
            target = torch.Tensor(1).random_(0, self.num_classes)[0]
        finally:
            torch.set_rng_state(rng_state)

        # convert to PIL Image
        img = transforms.ToPILImage()(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
Example #14
Source file: env.py, from SegmenTron (Apache License 2.0), 6 votes
def seed_all_rng(seed=None):
    """
    Set the random seed for the RNG in torch, numpy and python.

    Args:
        seed (int): the seed to use. It must be accepted by
            ``np.random.seed`` (an integer in ``[0, 2**32 - 1]``). If None, a
            seed is derived from the process id, the current time and two
            bytes of ``os.urandom``; it is logged so the run can be reproduced.
            The derived seed is not cryptographically strong.
    """
    if seed is None:
        seed = (
            os.getpid()
            + int(datetime.now().strftime("%S%f"))
            + int.from_bytes(os.urandom(2), "big")
        )
        logger = logging.getLogger(__name__)
        logger.info("Using a generated random seed %s", seed)
    np.random.seed(seed)
    # torch.manual_seed reseeds the default generator directly (on all
    # devices), so no get_state()/set_rng_state() round-trip is needed.
    torch.manual_seed(seed)
    random.seed(seed)
Example #15
Source file: utils.py, from attn2d (MIT License), 5 votes
def with_torch_seed(seed):
    """Temporarily reseed torch via ``set_torch_seed``, then restore RNG states.

    Generator intended for use as a context manager (presumably wrapped with
    ``contextlib.contextmanager`` at the definition site -- confirm). The CPU
    RNG state, and the current CUDA device's RNG state when CUDA is
    available, are saved on entry. They are restored on exit even if the body
    of the ``with`` block raises.

    Args:
        seed (int): seed passed to ``set_torch_seed``.
    """
    assert isinstance(seed, int)
    rng_state = torch.get_rng_state()
    # Querying the CUDA generator on a machine/build without CUDA raises,
    # so only save (and later restore) it when CUDA is usable.
    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    try:
        set_torch_seed(seed)
        yield
    finally:
        torch.set_rng_state(rng_state)
        if cuda_rng_state is not None:
            torch.cuda.set_rng_state(cuda_rng_state)
Example #16
Source file: test_precise_bn.py, from fvcore (Apache License 2.0), 5 votes
def setUp(self) -> None:
        # Seed torch's global RNG so every test starts from the same state.
        # torch.manual_seed reseeds the default generator directly, so no
        # get_state()/set_rng_state() round-trip is needed.
        torch.manual_seed(42)
Example #17
Source file: base_model.py, from memory-augmented-self-play (MIT License), 5 votes
def _load_metadata(self, checkpoint):
        """Restore the NumPy, Python and torch CPU RNG states saved in ``checkpoint``.

        ``checkpoint`` is indexed with the keys NP_RANDOM_STATE,
        PYTHON_RANDOM_STATE and PYTORCH_RANDOM_STATE, defined outside this block.
        """
        np.random.set_state(checkpoint[NP_RANDOM_STATE])
        random.setstate(checkpoint[PYTHON_RANDOM_STATE])
        # torch.set_rng_state only accepts a CPU ByteTensor; move the saved
        # state to CPU in case it was deserialized onto a GPU (e.g. through a
        # torch.load map_location). .cpu() is a no-op for a CPU tensor.
        torch.set_rng_state(checkpoint[PYTORCH_RANDOM_STATE].cpu())
Example #18
Source file: test_kissgp_gp_classification.py, from gpytorch (MIT License), 5 votes
def tearDown(self):
        # Put back the global CPU RNG state saved on this test instance
        # (presumably in setUp); do nothing if none was saved.
        try:
            saved_state = self.rng_state
        except AttributeError:
            return
        torch.set_rng_state(saved_state)
Example #19
Source file: base_policy.py, from memory-augmented-self-play (MIT License), 5 votes
def _load_metadata(self, checkpoint):
        """Restore the NumPy, Python and torch CPU RNG states saved in ``checkpoint``.

        ``checkpoint`` is indexed with the keys NP_RANDOM_STATE,
        PYTHON_RANDOM_STATE and PYTORCH_RANDOM_STATE, defined outside this block.
        """
        np.random.set_state(checkpoint[NP_RANDOM_STATE])
        random.setstate(checkpoint[PYTHON_RANDOM_STATE])
        # torch.set_rng_state only accepts a CPU ByteTensor; move the saved
        # state to CPU in case it was deserialized onto a GPU (e.g. through a
        # torch.load map_location). .cpu() is a no-op for a CPU tensor.
        torch.set_rng_state(checkpoint[PYTORCH_RANDOM_STATE].cpu())
Example #20
Source file: test_general_multitask_gaussian_likelihood.py, from gpytorch (MIT License), 5 votes
def tearDown(self):
        # Put back the global CPU RNG state saved on this test instance
        # (presumably in setUp); do nothing if none was saved.
        try:
            saved_state = self.rng_state
        except AttributeError:
            return
        torch.set_rng_state(saved_state)
Example #21
Source file: test_pivoted_cholesky.py, from gpytorch (MIT License), 5 votes
def tearDown(self):
        # Put back the global CPU RNG state saved on this test instance
        # (presumably in setUp); do nothing if none was saved.
        try:
            saved_state = self.rng_state
        except AttributeError:
            return
        torch.set_rng_state(saved_state)
Example #22
Source file: test_weight_init.py, from fvcore (Apache License 2.0), 5 votes
def setUp(self) -> None:
        # Seed torch's global RNG so every test starts from the same state.
        # torch.manual_seed reseeds the default generator directly, so no
        # get_state()/set_rng_state() round-trip is needed.
        torch.manual_seed(42)
Example #23
Source file: test_pivoted_cholesky.py, from gpytorch (MIT License), 5 votes
def tearDown(self):
        # Put back the global CPU RNG state saved on this test instance
        # (presumably in setUp); do nothing if none was saved.
        try:
            saved_state = self.rng_state
        except AttributeError:
            return
        torch.set_rng_state(saved_state)
Example #24
Source file: test_pivoted_cholesky.py, from gpytorch (MIT License), 5 votes
def tearDown(self):
        # Put back the global CPU RNG state saved on this test instance
        # (presumably in setUp); do nothing if none was saved.
        try:
            saved_state = self.rng_state
        except AttributeError:
            return
        torch.set_rng_state(saved_state)
Example #25
Source file: test_quadrature.py, from gpytorch (MIT License), 5 votes
def tearDown(self):
        # Put back the global CPU RNG state saved on this test instance
        # (presumably in setUp); do nothing if none was saved.
        try:
            saved_state = self.rng_state
        except AttributeError:
            return
        torch.set_rng_state(saved_state)
Example #26
Source file: test_batch_multitask_gp_regression.py, from gpytorch (MIT License), 5 votes
def tearDown(self):
        # Put back the global CPU RNG state saved on this test instance
        # (presumably in setUp); do nothing if none was saved.
        try:
            saved_state = self.rng_state
        except AttributeError:
            return
        torch.set_rng_state(saved_state)
Example #27
Source file: test_simple_gp_classification.py, from gpytorch (MIT License), 5 votes
def tearDown(self):
        # Put back the global CPU RNG state saved on this test instance
        # (presumably in setUp); do nothing if none was saved.
        try:
            saved_state = self.rng_state
        except AttributeError:
            return
        torch.set_rng_state(saved_state)
Example #28
Source file: test_sgpr_regression.py, from gpytorch (MIT License), 5 votes
def tearDown(self):
        # Put back the global CPU RNG state saved on this test instance
        # (presumably in setUp); do nothing if none was saved.
        try:
            saved_state = self.rng_state
        except AttributeError:
            return
        torch.set_rng_state(saved_state)
Example #29
Source file: test_independent_multitask_gp_regression.py, from gpytorch (MIT License), 5 votes
def tearDown(self):
        # Put back the global CPU RNG state saved on this test instance
        # (presumably in setUp); do nothing if none was saved.
        try:
            saved_state = self.rng_state
        except AttributeError:
            return
        torch.set_rng_state(saved_state)
Example #30
Source file: test_kissgp_dkl_regression.py, from gpytorch (MIT License), 5 votes
def tearDown(self):
        # Put back the global CPU RNG state saved on this test instance
        # (presumably in setUp); do nothing if none was saved.
        try:
            saved_state = self.rng_state
        except AttributeError:
            return
        torch.set_rng_state(saved_state)