Python torch.is_grad_enabled() Examples

The following are 21 code examples of torch.is_grad_enabled(), collected from open-source projects. You can go to the original project or source file by following the links above each example, or check out all available functions and classes of the module torch.
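Before diving into the examples, here is a minimal sketch of what torch.is_grad_enabled() reports: the current grad mode, which the torch.no_grad(), torch.enable_grad(), and torch.set_grad_enabled() contexts control.

import torch

print(torch.is_grad_enabled())          # True by default

with torch.no_grad():
    print(torch.is_grad_enabled())      # False inside no_grad()
    with torch.enable_grad():
        print(torch.is_grad_enabled())  # True again; the contexts nest

torch.set_grad_enabled(False)           # also works as a plain call
print(torch.is_grad_enabled())          # False
torch.set_grad_enabled(True)            # restore the default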
Example #1
Source File: test_async.py    From mmdetection with Apache License 2.0
async def test_simple_inference(self):
            if not torch.cuda.is_available():
                import pytest

                pytest.skip('test requires GPU and torch+cuda')

            ori_grad_enabled = torch.is_grad_enabled()
            # resolve the repo root relative to this test file
            root_dir = os.path.dirname(os.path.dirname(__file__))
            model_config = os.path.join(
                root_dir, 'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py')
            detector = MaskRCNNDetector(model_config)
            await detector.init()
            img_path = os.path.join(root_dir, 'demo/demo.jpg')
            bboxes, _ = await detector.apredict(img_path)
            self.assertTrue(bboxes)
            # The async inference detector changes the grad-enabled flag,
            # so restore it here to avoid influencing other tests.
            torch.set_grad_enabled(ori_grad_enabled)
Example #2
Source File: kfac.py    From dal with MIT License
def _save_input(self, module, input):
        if torch.is_grad_enabled() and self.steps % self.Ts == 0:
            classname = module.__class__.__name__
            layer_info = None
            if classname == 'Conv2d':
                layer_info = (module.kernel_size, module.stride,
                              module.padding)

            aa = compute_cov_a(input[0].data, classname, layer_info,
                               self.fast_cnn)

            # Initialize buffers
            if self.steps == 0:
                self.m_aa[module] = aa.clone()

            update_running_stat(aa, self.m_aa[module], self.stat_decay) 
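The hook above records input covariance statistics only while grad is enabled, i.e. on training forward passes. As a rough sketch of how such a hook gets wired up (the registration code is not part of this excerpt, so treat the names as illustrative), K-FAC optimizers typically attach it as a forward pre-hook:

# Illustrative registration sketch; `model` and `kfac` are placeholders.
for module in model.modules():
    if module.__class__.__name__ in ('Linear', 'Conv2d'):
        module.register_forward_pre_hook(kfac._save_input)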
Example #3
Source File: kfac.py    From gym-miniworld with Apache License 2.0
def _save_input(self, module, input):
        if torch.is_grad_enabled() and self.steps % self.Ts == 0:
            classname = module.__class__.__name__
            layer_info = None
            if classname == 'Conv2d':
                layer_info = (module.kernel_size, module.stride,
                              module.padding)

            aa = compute_cov_a(input[0].data, classname, layer_info,
                               self.fast_cnn)

            # Initialize buffers
            if self.steps == 0:
                self.m_aa[module] = aa.clone()

            update_running_stat(aa, self.m_aa[module], self.stat_decay) 
Example #4
Source File: kfac.py    From pytorch-pommerman-rl with MIT License
def _save_input(self, module, input):
        if torch.is_grad_enabled() and self.steps % self.Ts == 0:
            classname = module.__class__.__name__
            layer_info = None
            if classname == 'Conv2d':
                layer_info = (module.kernel_size, module.stride,
                              module.padding)

            aa = compute_cov_a(input[0].data, classname, layer_info,
                               self.fast_cnn)

            # Initialize buffers
            if self.steps == 0:
                self.m_aa[module] = aa.clone()

            update_running_stat(aa, self.m_aa[module], self.stat_decay) 
Example #5
Source File: sequence_labeling.py    From nested-ner-tacl2020-transformers with GNU General Public License v3.0
def _get_rnn_output(self, input_ids: Tensor, input_mask: Tensor,
                        first_subtokens: List[List[int]], last_subtokens: List[List[int]], mask: Tensor = None) \
            -> Tensor:
        # [batch, length, word_dim]
        with torch.set_grad_enabled(self.fine_tune and torch.is_grad_enabled()):
            sequence_output = self.bert(input_ids, attention_mask=input_mask)
            if self.fine_tune:
                sequence_output = sequence_output[0]
            else:
                sequence_output = torch.cat(tuple(sequence_output[2][-self.bert_layers:]), 2).detach()
            batch, _, word_dim = sequence_output.size()
            input = sequence_output.new_zeros((batch, max([len(fst) for fst in first_subtokens]), word_dim))
            for i, subtokens_list_tuple in enumerate(zip(first_subtokens, last_subtokens)):
                for j, subtokens_tuple in enumerate(zip(subtokens_list_tuple[0], subtokens_list_tuple[1])):
                    input[i, j, :] = torch.mean(sequence_output[i, subtokens_tuple[0]:subtokens_tuple[1], :], dim=0)
        # output from rnn [batch, length, hidden_size]
        output, hn = self.rnn(input, mask)

        # apply dropout for the output of rnn
        # [batch, length, hidden_size] --> [batch, hidden_size, length] --> [batch, length, hidden_size]
        output = self.dropout_out(output.transpose(1, 2)).transpose(1, 2)

        return output 
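The torch.set_grad_enabled(self.fine_tune and torch.is_grad_enabled()) context above composes two conditions: gradients flow through BERT only if fine-tuning is requested and the caller has not already disabled grad (e.g. under torch.no_grad() during evaluation). A minimal sketch of the same pattern with a placeholder encoder:

# Sketch of the conditional-gradient pattern; `encoder` is a placeholder.
def encode(encoder, x, fine_tune):
    with torch.set_grad_enabled(fine_tune and torch.is_grad_enabled()):
        out = encoder(x)
    return out if fine_tune else out.detach()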
Example #6
Source File: kfac.py    From pytorch-a2c-ppo-acktr-gail with MIT License
def _save_input(self, module, input):
        if torch.is_grad_enabled() and self.steps % self.Ts == 0:
            classname = module.__class__.__name__
            layer_info = None
            if classname == 'Conv2d':
                layer_info = (module.kernel_size, module.stride,
                              module.padding)

            aa = compute_cov_a(input[0].data, classname, layer_info,
                               self.fast_cnn)

            # Initialize buffers
            if self.steps == 0:
                self.m_aa[module] = aa.clone()

            update_running_stat(aa, self.m_aa[module], self.stat_decay) 
Example #7
Source File: kfac.py    From carla-rl with MIT License
def _save_input(self, module, input):
        if torch.is_grad_enabled() and self.steps % self.Ts == 0:
            classname = module.__class__.__name__
            layer_info = None
            if classname == 'Conv2d':
                layer_info = (module.kernel_size, module.stride,
                              module.padding)

            aa = compute_cov_a(input[0].data, classname, layer_info,
                               self.fast_cnn)

            # Initialize buffers
            if self.steps == 0:
                self.m_aa[module] = aa.clone()

            update_running_stat(aa, self.m_aa[module], self.stat_decay) 
Example #8
Source File: kfac.py    From midlevel-reps with MIT License
def _save_input(self, module, input):
        if torch.is_grad_enabled() and self.steps % self.Ts == 0:
            classname = module.__class__.__name__
            layer_info = None
            if classname == 'Conv2d':
                layer_info = (module.kernel_size, module.stride,
                              module.padding)

            aa = compute_cov_a(input[0].data, classname, layer_info,
                               self.fast_cnn)

            # Initialize buffers
            if self.steps == 0:
                self.m_aa[module] = aa.clone()

            update_running_stat(aa, self.m_aa[module], self.stat_decay) 
Example #9
Source File: distributed.py    From mmcv with Apache License 2.0
def val_step(self, *inputs, **kwargs):
        """val_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, except that it calls
        ``self.module.val_step()`` instead of ``self.module.forward()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """
        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.val_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.val_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if TORCH_VERSION > '1.2':
                self.require_forward_param_sync = False
        return output 
Example #10
Source File: dependency.py    From torchgpipe with Apache License 2.0
def join(input: Tensor, phony: Tensor) -> Tensor:
    """Merges two autograd lanes."""
    if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
        input = Join.apply(input, phony)

    return input 
Example #11
Source File: dependency.py    From torchgpipe with Apache License 2.0
def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
    """Branches out from an autograd lane of the given tensor."""
    if torch.is_grad_enabled() and input.requires_grad:
        input, phony = Fork.apply(input)
    else:
        phony = get_phony(input.device, requires_grad=False)

    return input, phony 
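In torchgpipe, fork() and join() are used as a pair to insert an artificial dependency edge into the autograd graph, so the backward pass respects the pipeline's execution order without changing any tensor values. A hedged usage sketch:

# Usage sketch: make b's backward depend on a's autograd lane.
a, phony = fork(a)   # phony is an empty tensor carrying a's lane
b = join(b, phony)   # b now (artificially) depends on a in autograd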
Example #12
Source File: test_worker.py    From torchgpipe with Apache License 2.0
def test_grad_mode(grad_mode):
    def detect_grad_enabled():
        x = torch.rand(1, requires_grad=torch.is_grad_enabled())
        return Batch(x)

    with torch.set_grad_enabled(grad_mode):
        with spawn_workers([torch.device('cpu')]) as (in_queues, out_queues):
            task = Task(CPUStream, compute=detect_grad_enabled, finalize=None)
            in_queues[0].put(task)

            ok, (_, batch) = out_queues[0].get()

            assert ok
            assert batch[0].requires_grad == grad_mode 
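This test passes because torchgpipe's workers copy the caller's grad mode into each worker thread. In recent PyTorch versions the flag behind torch.is_grad_enabled() is thread-local, so a freshly spawned thread does not inherit an enclosing no_grad() block; a small sketch of that behavior:

import threading
import torch

def report():
    # A new thread starts with grad enabled, whatever the spawning context.
    print('worker grad enabled:', torch.is_grad_enabled())

with torch.no_grad():
    print('main grad enabled:', torch.is_grad_enabled())  # False
    t = threading.Thread(target=report)
    t.start()
    t.join()  # the worker prints True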
Example #13
Source File: inference.py    From mmfashion with Apache License 2.0
def async_inference_detector(model, img):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        img (str/ndarray): Either an image file path or a loaded image.

    Returns:
        Awaitable detection results.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]

    # We don't restore the `torch.is_grad_enabled()` value during concurrent
    # inference, since execution can overlap.
    torch.set_grad_enabled(False)
    result = await model.aforward_test(rescale=True, **data)
    return result
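As the comment explains, the saved grad state is deliberately not restored here because concurrent coroutines may interleave on the same thread. In ordinary synchronous code you would instead use a with torch.no_grad(): block, or an explicit save/restore like this sketch (model and input_data are placeholders):

ori_grad_enabled = torch.is_grad_enabled()
torch.set_grad_enabled(False)
try:
    result = model(input_data)
finally:
    torch.set_grad_enabled(ori_grad_enabled)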


Example #14
Source File: distributed.py    From mmcv with Apache License 2.0
def train_step(self, *inputs, **kwargs):
        """train_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, except that it calls
        ``self.module.train_step()`` instead of ``self.module.forward()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """
        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.train_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.train_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if TORCH_VERSION > '1.2':
                self.require_forward_param_sync = False
        return output 
Example #15
Source File: ops.py    From GCA-Matting with MIT License
def forward(self, *args):
        # if torch.is_grad_enabled() and self.module.training:
        if self.module.training:
            self._update_u_v()
        else:
            self._noupdate_u_v()
        return self.module.forward(*args) 
Example #16
Source File: util_torch.py    From netharn with Apache License 2.0
def __init__(self, flag):
        import warnings
        warnings.warn('Deprecated use torch.set_grad_enabled instead',
                      DeprecationWarning)
        if tuple(map(int, torch.__version__.split('.')[0:2])) < (0, 4):
            self.prev = None
            self.flag = flag
        else:
            self.prev = torch.is_grad_enabled()
            self.flag = flag 
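Only the __init__ of this deprecated grad-state helper is shown. A plausible completion of the context manager, following the save/restore pattern the attributes suggest (hypothetical code for torch >= 0.4, not taken from netharn):

# Hypothetical __enter__/__exit__ completing the helper above.
def __enter__(self):
    torch.set_grad_enabled(self.flag)
    return self

def __exit__(self, exc_type, exc_value, traceback):
    if self.prev is not None:
        torch.set_grad_enabled(self.prev)
    return False  # do not suppress exceptions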
Example #17
Source File: inference.py    From mmdetection with Apache License 2.0
def async_inference_detector(model, img):
    """Async inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        img (str/ndarray): Either an image file path or a loaded image.

    Returns:
        Awaitable detection results.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]

    # We don't restore the `torch.is_grad_enabled()` value during concurrent
    # inference, since execution can overlap.
    torch.set_grad_enabled(False)
    result = await model.aforward_test(rescale=True, **data)
    return result 
Example #18
Source File: utils.py    From apex with BSD 3-Clause "New" or "Revised" License
def cached_cast(cast_fn, x, cache):
    if is_nested(x):
        return type(x)([cached_cast(cast_fn, y, cache) for y in x])
    if x in cache:
        cached_x = cache[x]
        if x.requires_grad and cached_x.requires_grad:
            # Make sure x is actually cached_x's autograd parent.
            if cached_x.grad_fn.next_functions[1][0].variable is not x:
                raise RuntimeError("x and cache[x] both require grad, but x is not "
                                   "cache[x]'s parent.  This is likely an error.")
        # During eval, it's possible to end up caching casted weights with
        # requires_grad=False.  On the next training iter, if cached_x is found
        # and reused from the cache, it will not actually have x as its parent.
        # Therefore, we choose to invalidate the cache (and force refreshing the cast)
        # if x.requires_grad and cached_x.requires_grad do not match.
        #
        # During eval (i.e. running under with torch.no_grad()) the invalidation
        # check would cause the cached value to be dropped every time, because
        # cached_x would always be created with requires_grad=False, while x would
        # still have requires_grad=True.  This would render the cache effectively
        # useless during eval.  Therefore, if we are running under the no_grad()
        # context manager (torch.is_grad_enabled=False) we elide the invalidation
        # check, and use the cached value even though its requires_grad flag doesn't
        # match.  During eval, we don't care that there's no autograd-graph
        # connection between x and cached_x.
        if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
            del cache[x]
        else:
            return cached_x

    casted_x = cast_fn(x)
    cache[x] = casted_x
    return casted_x 
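A hedged usage sketch of cached_cast: with a plain dict as the cache, casting the same tensor twice returns the identical cached copy instead of re-casting (the lambda stands in for apex's real cast functions, and is_nested is assumed importable from the same module):

cache = {}
w = torch.randn(4, 4)  # requires_grad=False keeps the sketch simple
w16_a = cached_cast(lambda t: t.half(), w, cache)
w16_b = cached_cast(lambda t: t.half(), w, cache)
assert w16_a is w16_b  # the second call hits the cache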
Example #19
Source File: test_vgsl.py    From kraken with Apache License 2.0
def test_helper_train(self):
        """
        Tests train/eval mode helper methods
        """
        rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]')
        rnn.train()
        self.assertTrue(torch.is_grad_enabled())
        self.assertTrue(rnn.nn.training)
        rnn.eval()
        self.assertFalse(torch.is_grad_enabled())
        self.assertFalse(rnn.nn.training) 
Example #20
Source File: parallel.py    From pytorch-meta with MIT License
def scatter(self, inputs, kwargs, device_ids):
        try:
            params = kwargs.pop('params')
        except KeyError:
            return super(DataParallel, self).scatter(inputs, kwargs, device_ids)

        inputs_, kwargs_ = scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
        # Add params argument unchanged back in kwargs
        replicas = self._replicate_params(params, inputs_, device_ids,
                                          detach=not torch.is_grad_enabled())
        kwargs_ = tuple(dict(params=replica, **kwarg)
                        for (kwarg, replica) in zip(kwargs_, replicas))
        return inputs_, kwargs_ 
Example #21
Source File: EncodingDataParallel.py    From torch-toolbox with BSD 3-Clause "New" or "Revised" License
def replicate(self, module, device_ids):
        return replicate(module, device_ids, not torch.is_grad_enabled()) 