Python torch.enable_grad() Examples

The following are 30 code examples of torch.enable_grad(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: trainers.py    From homura with Apache License 2.0 7 votes vote down vote up
def train(self,
          data_loader: Iterable or DataLoader,
          mode: str = TRAIN):
    """Run one training epoch over ``data_loader``.

    Flags the trainer as training, increments the epoch counter, switches the
    model (and the loss function, when it has a ``train`` attribute) to
    training mode and runs ``_loop`` with gradient tracking enabled.

    :param data_loader: Source of batches for this epoch.
    :param mode: Name of this loop, forwarded to ``_loop``. Default is `train`.
    """
    self._is_train = True
    self._epoch += 1
    self.model.train()
    if hasattr(self.loss_f, "train"):
        self.loss_f.train()

    # enable_grad() turns autograd back on even if the caller disabled it.
    with torch.enable_grad():
        self._loop(data_loader, mode=mode)

    step_scheduler = self.scheduler is not None and self.update_scheduler_by_epoch
    if step_scheduler:
        self.scheduler.step()

    has_distributed_sampler = (
        isinstance(data_loader, DataLoader)
        and isinstance(data_loader.sampler, DistributedSampler)
    )
    if has_distributed_sampler:
        # Tell the sampler which epoch this is.
        data_loader.sampler.set_epoch(self.epoch)
Example #2
Source File: pix2pix.py    From jdit with Apache License 2.0 6 votes vote down vote up
def compute_valid(self):
    """Collect the values reported for a validation epoch.

    Override this method to compute your own ``valid_epoch`` values and
    return a ``dict`` of whatever you want to visualize.

    .. note::

        This method is under ``torch.no_grad():``, so it never computes
        gradients. Wrap your operations in ``torch.enable_grad():`` if you
        need them.

    Example::

        d_fake = self.netD(self.fake.detach())
        d_real = self.netD(self.ground_truth)
        var_dic = {}
        var_dic["WD"] = w_distance = (d_real.mean() - d_fake.mean()).detach()
        return var_dic

    """
    _, g_loss_vars = self.compute_g_loss()
    _, d_loss_vars = self.compute_d_loss()
    # On duplicate keys the value coming from compute_d_loss() wins.
    return dict(g_loss_vars, **d_loss_vars)
Example #3
Source File: attacks.py    From ss-ood with MIT License 6 votes vote down vote up
def forward(self, model, bx, by, by_prime, curr_batch_size):
    """Craft adversarial images with signed-gradient ascent inside an L-inf ball.

    :param model: the classifier's forward method; it is called on
        ``adv_bx * 2 - 1`` and must return ``(logits, pen)``
    :param bx: batch of images (never modified by this method)
    :param by: true labels for the first ``curr_batch_size`` rows of the logits
    :param by_prime: labels for ``model.module.rot_pred``; only read when
        ``self.attack_rotations`` is truthy
    :param curr_batch_size: number of leading rows scored against ``by``
    :return: perturbed batch of images, within ``self.epsilon`` of ``bx`` and
        clamped to [0, 1]
    """
    # Out-of-place add on purpose: ``bx.detach()`` shares storage with ``bx``,
    # so an in-place ``+=`` would write the noise into the caller's batch and
    # move the centre of the epsilon-ball used for the projection below.
    adv_bx = bx.detach() + torch.zeros_like(bx).uniform_(-self.epsilon, self.epsilon)

    for _ in range(self.num_steps):
        adv_bx.requires_grad_()
        with torch.enable_grad():
            logits, pen = model(adv_bx * 2 - 1)
            loss = F.cross_entropy(logits[:curr_batch_size], by, reduction='sum')
            if self.attack_rotations:
                loss += F.cross_entropy(model.module.rot_pred(pen[curr_batch_size:]), by_prime, reduction='sum') / 8.
        grad = torch.autograd.grad(loss, adv_bx)[0]

        # Ascent step along the gradient sign.
        adv_bx = adv_bx.detach() + self.step_size * torch.sign(grad.detach())

        # Project back into the epsilon-ball around the clean batch, then
        # into the valid pixel range.
        adv_bx = torch.min(torch.max(adv_bx, bx - self.epsilon), bx + self.epsilon).clamp(0, 1)

    return adv_bx
Example #4
Source File: train_trades_mnist_binary.py    From TRADES with MIT License 6 votes vote down vote up
def perturb_hinge(net, x_nat):
    """Perturb ``x_nat`` by signed-gradient ascent on a hinge loss.

    The loss is ``mean(clamp(1 - net(x) * net(x_nat) / args.beta, min=0))``.
    Step count, step size, radius and ``beta`` are read from the module-level
    ``args``.

    Args:
        net: model; its output is squeezed along dim 1.
        x_nat: clean inputs. The perturbation is created on ``x_nat``'s
            device, so CPU and any CUDA device both work.

    Returns:
        Perturbed inputs within ``args.epsilon`` of ``x_nat``, clamped to
        [0, 1].

    Side effects:
        ``net`` is switched to eval mode for the attack and is left in
        training mode on return.
    """
    # Perturb function based on (E[\phi(f(x)f(x'))])
    net.eval()
    # init with random noise; moved to x_nat's device rather than assuming CUDA
    x = x_nat.detach() + 0.001 * torch.randn(x_nat.shape).to(x_nat.device).detach()
    # net(x_nat) does not depend on x, so evaluate it once without a graph
    # instead of on every step; the gradient w.r.t. x is unchanged.
    with torch.no_grad():
        nat_margin = net(x_nat).squeeze(1) / args.beta
    for _ in range(args.num_steps):
        x.requires_grad_()
        with torch.enable_grad():
            # perturb based on hinge loss
            loss = torch.mean(torch.clamp(1 - net(x).squeeze(1) * nat_margin, min=0))
        grad = torch.autograd.grad(loss, [x])[0]
        x = x.detach() + args.step_size * torch.sign(grad.detach())
        x = torch.min(torch.max(x, x_nat - args.epsilon), x_nat + args.epsilon)
        x = torch.clamp(x, 0.0, 1.0)
    net.train()
    return x
Example #5
Source File: iresblock.py    From residual-flows with MIT License 6 votes vote down vote up
def forward(ctx, estimator_fn, gnet, x, n_power_series, vareps, coeff_fn, training, *g_params):
        """Evaluate ``gnet(x)`` and the ``estimator_fn`` estimate, caching gradients.

        Args:
            ctx: context object; receives ``training``, ``g`` and ``x`` as
                attributes and, when training, the saved tensors.
            estimator_fn: callable invoked as
                ``estimator_fn(g, x, n_power_series, vareps, coeff_fn, training)``.
            gnet: callable applied to the re-attached ``x``.
            x: input tensor; it is detached and marked as requiring grad.
            n_power_series, vareps, coeff_fn: forwarded to ``estimator_fn``.
            training: when truthy, gradients of ``logdetgrad.sum()`` w.r.t.
                ``x`` and ``g_params`` are computed here and saved.
            *g_params: extra tensors to differentiate against
                (presumably the parameters of ``gnet`` - confirm with caller).

        Returns:
            ``(safe_detach(g), safe_detach(logdetgrad))``.
        """
        ctx.training = training
        # Build the graph explicitly, regardless of the ambient grad mode.
        with torch.enable_grad():
            # Fresh leaf: the graph built below starts at this tensor.
            x = x.detach().requires_grad_(True)
            g = gnet(x)
            # Kept on ctx (with their graph) for use outside this method.
            ctx.g = g
            ctx.x = x
            logdetgrad = estimator_fn(g, x, n_power_series, vareps, coeff_fn, training)

            if training:
                # retain_graph keeps the graph alive after this call;
                # allow_unused lets inputs that logdetgrad does not depend on
                # come back as None instead of raising.
                grad_x, *grad_params = torch.autograd.grad(
                    logdetgrad.sum(), (x,) + g_params, retain_graph=True, allow_unused=True
                )
                if grad_x is None:
                    grad_x = torch.zeros_like(x)
                # Layout: [grad_x, *g_params, *grad_params].
                ctx.save_for_backward(grad_x, *g_params, *grad_params)

        return safe_detach(g), safe_detach(logdetgrad)
Example #6
Source File: pl.py    From neuralsort with MIT License 6 votes vote down vote up
def rsample(self, sample_shape, log_score=True):
    """Draw relaxed samples by perturbing the scores with Gumbel noise.

    sample_shape: sequence whose first entry is the number of samples to
        draw; only ``sample_shape[0]`` is read.
    log_score: when true, ``self.scores`` are perturbed as they are;
        otherwise their logarithm is taken first.

    Returns ``self.relaxed_sort`` applied to the perturbed scores, viewed as
    ``(n_samples, -1, self.n, self.n)`` and squeezed.
    """
    with torch.enable_grad():  # torch.distributions turns off autograd
        n_samples = sample_shape[0]
        # Sample the noise where the scores live instead of assuming CUDA.
        device = self.scores.device

        def sample_gumbel(samples_shape, eps=1e-20):
            U = torch.zeros(samples_shape, device=device).uniform_()
            # eps keeps both logarithms away from log(0).
            return -torch.log(-torch.log(U + eps) + eps)

        noise_shape = [n_samples, 1, self.n, 1]
        if not log_score:
            log_s_perturb = torch.log(self.scores.unsqueeze(0)) + sample_gumbel(noise_shape)
        else:
            log_s_perturb = self.scores.unsqueeze(0) + sample_gumbel(noise_shape)
        log_s_perturb = log_s_perturb.view(-1, self.n, 1)
        P_hat = self.relaxed_sort(log_s_perturb)
        P_hat = P_hat.view(n_samples, -1, self.n, self.n)

        return P_hat.squeeze()
Example #7
Source File: InnerCosFunction.py    From Shift-Net_pytorch with MIT License 6 votes vote down vote up
def backward(ctx, grad_output):
    """Add the gradient of the stored criterion to the incoming gradient.

    Recomputes ``ctx.criterion(former_in_mask, target) * ctx.strength`` on the
    first ``ctx.c // 2`` channels of the saved input (multiplied by the saved
    mask), differentiates it w.r.t. those masked features, and adds the result
    to the matching channels of the incoming gradient.

    Args:
        ctx: context holding ``saved_tensors`` (input, target, mask), ``c``,
            ``criterion`` and ``strength``. ``ctx.loss`` is overwritten.
        grad_output: gradient w.r.t. the forward output. It is NOT modified;
            autograd may hand the same tensor to other consumers.

    Returns:
        A 5-tuple: the gradient for the first forward input followed by four
        ``None`` placeholders.
    """
    half = ctx.c // 2
    with torch.enable_grad():
        features, target, mask = ctx.saved_tensors
        former = features.narrow(1, 0, half)
        former_in_mask = torch.mul(former, mask)
        if former_in_mask.size() != target.size():  # For the last iteration of one epoch
            target = target.narrow(0, 0, 1).expand_as(former_in_mask).type_as(former_in_mask)

        # Detached leaf: backward() below only fills this tensor's .grad.
        former_in_mask_clone = former_in_mask.clone().detach().requires_grad_(True)
        ctx.loss = ctx.criterion(former_in_mask_clone, target) * ctx.strength
        ctx.loss.backward()

    # Work on a copy: gradient arguments must never be modified in place.
    grad_input = grad_output.clone()
    grad_input[:, 0:half, :, :] += former_in_mask_clone.grad

    return grad_input, None, None, None, None
Example #8
Source File: rcgan_trainer.py    From rGAN with MIT License 6 votes vote down vote up
def cond_gradient_penalty(self, x_real, x_fake, y, netD, index=0):
    """Gradient penalty of a conditional critic at random interpolates.

    Mixes ``x_real`` and ``x_fake`` with one uniform weight per sample,
    evaluates ``netD(x_hat, y)[index]`` and penalises the squared deviation of
    the per-sample gradient L2 norm from 1.

    Args:
        x_real, x_fake: batches with the same 4-D shape; both are detached.
        y: conditioning input, forwarded to ``netD`` unchanged.
        netD: callable returning an indexable collection of outputs.
        index: which entry of ``netD``'s result to differentiate.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)``; it remains
        differentiable (``create_graph=True``).
    """
    device = x_real.device
    with torch.enable_grad():
        # One mixing weight per sample, broadcast over the remaining dims.
        mix = torch.rand(x_real.size(0), 1, 1, 1).to(device).expand_as(x_real)
        x_hat = mix * x_real.detach() + (1 - mix) * x_fake.detach()
        x_hat.requires_grad = True

        critic_out = netD(x_hat, y)[index]
        seed = torch.ones(critic_out.size()).to(device)
        grad = torch.autograd.grad(
            outputs=critic_out,
            inputs=x_hat,
            grad_outputs=seed,
            retain_graph=True,
            create_graph=True,
            only_inputs=True,
        )[0]

        per_sample_norm = grad.contiguous().view(grad.size(0), -1).norm(p=2, dim=1)
        loss_gp = ((per_sample_norm - 1) ** 2).mean()
    return loss_gp
Example #9
Source File: racgan_trainer.py    From rGAN with MIT License 6 votes vote down vote up
def gradient_penalty(self, x_real, x_fake, netD, index=0):
    """Gradient penalty of a critic at random interpolates.

    Mixes ``x_real`` and ``x_fake`` with one uniform weight per sample,
    evaluates ``netD(x_hat)[index]`` and penalises the squared deviation of
    the per-sample gradient L2 norm from 1.

    Args:
        x_real, x_fake: batches with the same 4-D shape; both are detached.
        netD: callable returning an indexable collection of outputs.
        index: which entry of ``netD``'s result to differentiate.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)``; it remains
        differentiable (``create_graph=True``).
    """
    device = x_real.device
    with torch.enable_grad():
        alpha = torch.rand(x_real.size(0), 1, 1, 1).to(device)
        alpha = alpha.expand_as(x_real)
        x_hat = alpha * x_real.detach() + (1 - alpha) * x_fake.detach()
        x_hat.requires_grad = True
        output = netD(x_hat)[index]
        grad_output = torch.ones(output.size()).to(device)
        grad = torch.autograd.grad(outputs=output,
                                   inputs=x_hat,
                                   grad_outputs=grad_output,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
        # reshape, not view: autograd may return a non-contiguous gradient
        # (e.g. after a permute inside netD), on which view() raises.
        grad = grad.reshape(grad.size(0), -1)
        loss_gp = ((grad.norm(p=2, dim=1) - 1)**2).mean()
    return loss_gp
Example #10
Source File: imitator_training.py    From ReAgent with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def train(self, training_batch, train=True):
    """Score the imitator on one batch and, optionally, take a training step.

    Args:
        training_batch: object whose ``training_input`` exposes
            ``state.float_features`` and ``action``.
        train: when true, a cross-entropy loss is back-propagated and
            ``_maybe_run_optimizer`` is invoked; otherwise only predictions
            are computed.

    Returns:
        Whatever ``self._imitator_accuracy`` returns for the predicted and
        actual action indices.
    """
    learning_input = training_batch.training_input

    with torch.enable_grad():
        action_preds = self.imitator(learning_input.state.float_features)
        # Classification label is index of action with value 1
        _, pred_action_idxs = torch.max(action_preds, dim=1)
        _, actual_action_idxs = torch.max(learning_input.action, dim=1)

        if train:
            bcq_loss = torch.nn.CrossEntropyLoss()(action_preds, actual_action_idxs)
            bcq_loss.backward()
            self._maybe_run_optimizer(self.imitator_optimizer, self.minibatches_per_step)

    return self._imitator_accuracy(pred_action_idxs, actual_action_idxs)
Example #11
Source File: abstract.py    From rising with MIT License 6 votes vote down vote up
def __call__(self, *args, **kwargs) -> Any:
    """
    Delegate to the parent ``__call__`` under the matching autograd mode.

    Gradient tracking is enabled when ``self.grad`` is truthy and disabled
    otherwise.

    Args:
        *args: forwarded positional arguments
        **kwargs: forwarded keyword arguments

    Returns:
        Any: transformed data

    """
    context = torch.enable_grad() if self.grad else torch.no_grad()
    with context:
        return super().__call__(*args, **kwargs)
Example #12
Source File: checkpoint.py    From pytorch-memonger with MIT License 6 votes vote down vote up
def backward(ctx, *args):
        """Re-run the checkpointed function and back-propagate through it.

        Args:
            ctx: context holding ``saved_tensors``, ``run_function``,
                ``had_cuda_in_fwd`` and the forward RNG states.
            *args: gradients w.r.t. the outputs of ``run_function``.

        Returns:
            ``(None,)`` followed by the ``.grad`` of each detached input.

        Raises:
            RuntimeError: if ``torch.autograd._is_checkpoint_valid()`` is false.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        # NOTE(review): `preserve_rng_state` is not defined in this method; it
        # must come from an enclosing or module scope - confirm.
        rng_devices = [torch.cuda.current_device()] if ctx.had_cuda_in_fwd else []
        with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
            if preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_rng_state)
                if ctx.had_cuda_in_fwd:
                    torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
            detached_inputs = detach_variable(inputs)
            # Recompute the forward pass with a graph.
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        # Normalise a single-tensor result to a tuple for autograd.backward.
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        # The leading None is the gradient slot of the first forward argument
        # (presumably run_function - confirm against forward()).
        return (None,) + tuple(inp.grad for inp in detached_inputs)
Example #13
Source File: attack_whitebox.py    From ME-Net with MIT License 6 votes vote down vote up
def forward(self, inputs, targets):
    """Run the wrapped model, on PGD-perturbed inputs when attacking.

    When the module-level ``args.attack`` flag is false the model is simply
    evaluated on ``inputs``. Otherwise ``self.num_steps`` signed-gradient
    ascent steps of size ``self.step_size`` are taken on the summed
    cross-entropy loss, each followed by a projection into the
    ``self.epsilon`` ball around ``inputs`` and a clamp to [0, 1]. A uniform
    random start is used when ``self.rand`` is truthy.

    Args:
        inputs: clean input batch.
        targets: class labels for ``inputs``.

    Returns:
        ``(logits, x)``: the model output and the inputs it was computed on.
    """
    if not args.attack:
        return self.model(inputs), inputs

    x = inputs.detach()
    if self.rand:
        x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)
    for _ in range(self.num_steps):
        x.requires_grad_()
        with torch.enable_grad():
            logits = self.model(x)
            # reduction='sum' is the supported spelling of the deprecated
            # size_average=False; the loss value is identical.
            loss = F.cross_entropy(logits, targets, reduction='sum')
        grad = torch.autograd.grad(loss, [x])[0]
        x = x.detach() + self.step_size * torch.sign(grad.detach())
        x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon)
        x = torch.clamp(x, 0, 1)

    return self.model(x), x
Example #14
Source File: test_dependency.py    From torchgpipe with Apache License 2.0 6 votes vote down vote up
def test_fork_join_enable_grad():
    """fork() and join() under enable_grad yield new tensors tracked by autograd."""
    source = torch.rand(1, requires_grad=True)

    with torch.enable_grad():
        forked, p = fork(source)

    assert p.requires_grad
    assert forked is not source
    x = forked

    # Both fork outputs hang off the Fork function's backward node.
    assert x.requires_grad
    assert p.requires_grad
    assert x.grad_fn.__class__ is Fork._backward_cls
    assert p.grad_fn.__class__ is Fork._backward_cls

    with torch.enable_grad():
        joined = join(x, p)

    assert joined is not x
    x = joined

    assert x.requires_grad
    assert x.grad_fn.__class__ is Join._backward_cls
Example #15
Source File: train_adv.py    From ME-Net with MIT License 6 votes vote down vote up
def forward(self, inputs, targets):
    """Run the wrapped model, on PGD-perturbed inputs when attacking.

    When the module-level ``args.attack`` flag is false the model is simply
    evaluated on ``inputs``. Otherwise ``self.num_steps`` signed-gradient
    ascent steps of size ``self.step_size`` are taken on the summed
    cross-entropy loss, each followed by a projection into the
    ``self.epsilon`` ball around ``inputs`` and a clamp to [0, 1]. A uniform
    random start is used when ``self.rand`` is truthy.

    Args:
        inputs: clean input batch.
        targets: class labels for ``inputs``.

    Returns:
        ``(logits, x)``: the model output and the inputs it was computed on.
    """
    if not args.attack:
        return self.model(inputs), inputs

    x = inputs.detach()
    if self.rand:
        x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)
    for _ in range(self.num_steps):
        x.requires_grad_()
        with torch.enable_grad():
            logits = self.model(x)
            # reduction='sum' is the supported spelling of the deprecated
            # size_average=False; the loss value is identical.
            loss = F.cross_entropy(logits, targets, reduction='sum')
        grad = torch.autograd.grad(loss, [x])[0]
        x = x.detach() + self.step_size * torch.sign(grad.detach())
        x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon)
        x = torch.clamp(x, 0, 1)

    return self.model(x), x
Example #16
Source File: deq.py    From deq with MIT License 5 votes vote down vote up
def backward(ctx, grad):
            """Estimate the backward gradient by solving a system with ``broyden``.

            Args:
                ctx: context holding ``saved_tensors`` (``z1ss``, ``uss``,
                    ``z0``), ``args`` and ``func``.
                grad: incoming gradient, of shape (bsz x d_model x seq_len).

            Returns:
                A tuple ``(None, dl_df_est, None, None, *[None] * len(args))``
                where ``dl_df_est`` is the ``'result'`` entry returned by
                ``broyden``.
            """
            torch.cuda.empty_cache()

            # grad should have dimension (bsz x d_model x seq_len)
            bsz, d_model, seq_len = grad.size()
            grad = grad.clone()
            z1ss, uss, z0 = ctx.saved_tensors
            args = ctx.args
            # The last two entries of args are the solver threshold and the
            # training step (train_step is not used below).
            threshold, train_step = args[-2:]

            func = ctx.func
            # Fresh copies; only z1ss is differentiated against.
            z1ss = z1ss.clone().detach().requires_grad_()
            uss = uss.clone().detach()
            z0 = z0.clone().detach()

            # Build the graph of RootFind.g once; g() below re-uses it.
            with torch.enable_grad():
                y = RootFind.g(func, z1ss, uss, z0, *args)

            def g(x):
                # Vector-Jacobian product of y w.r.t. z1ss against x, plus
                # the incoming gradient.
                y.backward(x, retain_graph=True)   # Retain for future calls to g
                JTx = z1ss.grad.clone().detach()
                # Reset so the next call does not accumulate.
                z1ss.grad.zero_()
                return JTx + grad

            # Tolerance scales with the square root of the element count.
            eps = 2e-10 * np.sqrt(bsz * seq_len * d_model)
            dl_df_est = torch.zeros_like(grad)

            result_info = broyden(g, dl_df_est, threshold=threshold, eps=eps, name="backward")
            dl_df_est = result_info['result']
            
            # Final backward with a zero vector and retain_graph=False
            # (presumably only to release the retained graph - confirm).
            y.backward(torch.zeros_like(dl_df_est), retain_graph=False)

            grad_args = [None for _ in range(len(args))]
            return (None, dl_df_est, None, None, *grad_args)
Example #17
Source File: norm_flows.py    From ffjord with MIT License 5 votes vote down vote up
def _detgrad(self, z):
        """Computes |det df/dz|"""
        with torch.enable_grad():
            z = z.requires_grad_(True)
            h = self.activation(torch.mm(z, self.w.view(self.nd, 1)) + self.b)
            psi = grad(h, z, grad_outputs=torch.ones_like(h), create_graph=True, only_inputs=True)[0]
        u_dot_psi = torch.mm(psi, self.u.view(self.nd, 1))
        detgrad = 1 + u_dot_psi
        return detgrad 
Example #18
Source File: iresblock.py    From residual-flows with MIT License 5 votes vote down vote up
def backward(ctx, grad_g, grad_logdetgrad):
        """Combine the precomputed log-det gradients with the gradient through ``g``.

        Args:
            ctx: context holding ``training``, ``g``, ``x`` and the saved
                tensors ``[grad_x, *g_params, *grad_params]``.
            grad_g: gradient w.r.t. the first forward output.
            grad_logdetgrad: gradient w.r.t. the second forward output; only
                its first element is used.

        Returns:
            ``(None, None, grad_x, None, None, None, None)`` followed by one
            gradient per entry of ``g_params``.

        Raises:
            ValueError: if ``ctx.training`` is false.
        """
        training = ctx.training
        if not training:
            raise ValueError('Provide training=True if using backward.')

        with torch.enable_grad():
            grad_x, *params_and_grad = ctx.saved_tensors
            g, x = ctx.g, ctx.x

            # Precomputed gradients.
            # The saved list is split in half: first the parameters, then
            # their precomputed gradients.
            g_params = params_and_grad[:len(params_and_grad) // 2]
            grad_params = params_and_grad[len(params_and_grad) // 2:]

            # Gradient through g w.r.t. x and the parameters.
            dg_x, *dg_params = torch.autograd.grad(g, [x] + g_params, grad_g, allow_unused=True)

        # Update based on gradient from logdetgrad.
        dL = grad_logdetgrad[0].detach()
        with torch.no_grad():
            # In-place scaling of the saved tensors.
            grad_x.mul_(dL)
            grad_params = tuple([g.mul_(dL) if g is not None else None for g in grad_params])

        # Update based on gradient from g.
        with torch.no_grad():
            grad_x.add_(dg_x)
            # NOTE(review): dg.add_ is called without checking dg for None,
            # although allow_unused=True above can yield None - confirm.
            grad_params = tuple([dg.add_(djac) if djac is not None else dg for dg, djac in zip(dg_params, grad_params)])

        return (None, None, grad_x, None, None, None, None) + grad_params
Example #19
Source File: pgd_attack_mnist.py    From TRADES with MIT License 5 votes vote down vote up
def _pgd_whitebox(model,
                  X,
                  y,
                  epsilon=args.epsilon,
                  num_steps=args.num_steps,
                  step_size=args.step_size):
    """Count misclassifications before and after a white-box PGD attack.

    Args:
        model: classifier returning logits.
        X: clean input batch.
        y: labels for ``X``.
        epsilon: L-inf radius of the perturbation.
        num_steps: number of ascent steps.
        step_size: step length applied to the gradient sign.

    Returns:
        ``(err, err_pgd)``: number of wrong predictions on ``X`` and on the
        attacked inputs. ``err_pgd`` is also printed.
    """
    clean_logits = model(X)
    err = (clean_logits.data.max(1)[1] != y.data).float().sum()

    X_pgd = Variable(X.data, requires_grad=True)
    if args.random:
        random_noise = torch.FloatTensor(*X_pgd.shape).uniform_(-epsilon, epsilon).to(device)
        X_pgd = Variable(X_pgd.data + random_noise, requires_grad=True)

    for _ in range(num_steps):
        # Throwaway optimizer, used only to clear X_pgd's gradient.
        opt = optim.SGD([X_pgd], lr=1e-3)
        opt.zero_grad()

        with torch.enable_grad():
            loss = nn.CrossEntropyLoss()(model(X_pgd), y)
        loss.backward()

        # Ascent step, then projection onto the epsilon-ball around X and
        # onto the [0, 1] range.
        moved = X_pgd.data + step_size * X_pgd.grad.data.sign()
        delta = torch.clamp(moved - X.data, -epsilon, epsilon)
        X_pgd = Variable(torch.clamp(X.data + delta, 0, 1.0), requires_grad=True)

    err_pgd = (model(X_pgd).data.max(1)[1] != y.data).float().sum()
    print('err pgd (white-box): ', err_pgd)
    return err, err_pgd
Example #20
Source File: mutator.py    From nni with MIT License 5 votes vote down vote up
def forward(ctx, x, binary_gates, run_func, backward_func):
    """Run ``run_func`` on a detached copy of ``x`` with autograd enabled.

    Both callables are stashed on ``ctx``; the detached input and the output
    (with its graph) are saved for backward. ``binary_gates`` is not read
    here. Returns ``output.data``.
    """
    ctx.run_func, ctx.backward_func = run_func, backward_func

    x_detached = detach_variable(x)
    with torch.enable_grad():
        result = run_func(x_detached)
    ctx.save_for_backward(x_detached, result)
    return result.data
Example #21
Source File: power.py    From tensorgrad with Apache License 2.0 5 votes vote down vote up
def backward(ctx, grad):
    """Accumulate vector-Jacobian products of ``step`` until they fall below ``ctx.tol``.

    The incoming gradient is pushed repeatedly through ``step(A, x)`` w.r.t.
    ``x``. Each product whose norm exceeds ``ctx.tol`` is added to the
    running total. The total is then pulled back once w.r.t. ``A``.

    Returns ``(dA, None, None)``.
    """
    A, x = detach_variable(ctx.saved_tensors)
    total = grad
    vjp = grad
    while True:
        with torch.enable_grad():
            vjp = torch.autograd.grad(step(A, x), x, grad_outputs=vjp)[0]
        # Written as "not (... > tol)" so a NaN norm also ends the loop.
        if not (torch.norm(vjp) > ctx.tol):
            break
        total = total + vjp
    with torch.enable_grad():
        total = torch.autograd.grad(step(A, x), A, grad_outputs=total)[0]
    return total, None, None
Example #22
Source File: lazy_evaluated_kernel_tensor.py    From gpytorch with MIT License 5 votes vote down vote up
def _quad_form_derivative(self, left_vecs, right_vecs):
        """Compute the quadratic-form derivative w.r.t. ``x1`` and ``x2`` in chunks.

        Args:
            left_vecs: left vectors; split along dim -2 alongside ``x1``.
            right_vecs: right vectors; passed whole to every chunk.

        Returns:
            ``(x1.grad, x2.grad)`` for detached copies of ``self.x1`` and
            ``self.x2``.

        Raises:
            RuntimeError: if ``beta_features.checkpoint_kernel.value()`` is
                falsy.
        """
        # This _quad_form_derivative computes the kernel in chunks
        # It is only used when we are using kernel checkpointing
        # It won't be called if checkpointing is off
        split_size = beta_features.checkpoint_kernel.value()
        if not split_size:
            raise RuntimeError(
                "Should not have ended up in LazyEvaluatedKernelTensor._quad_form_derivative without kernel "
                "checkpointing. This is probably a bug in GPyTorch."
            )

        # Fresh leaves so gradients accumulate here, not on self.x1/self.x2.
        x1 = self.x1.detach().requires_grad_(True)
        x2 = self.x2.detach().requires_grad_(True)

        # Break objects into chunks
        sub_x1s = torch.split(x1, split_size, dim=-2)
        sub_left_vecss = torch.split(left_vecs, split_size, dim=-2)
        # Compute the gradient in chunks
        for sub_x1, sub_left_vecs in zip(sub_x1s, sub_left_vecss):
            # Detach each chunk in place so it is its own leaf with its own .grad.
            sub_x1.detach_().requires_grad_(True)
            with torch.enable_grad(), settings.lazily_evaluate_kernels(False):
                sub_kernel_matrix = lazify(
                    self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params)
                )
            sub_grad_outputs = tuple(sub_kernel_matrix._quad_form_derivative(sub_left_vecs, right_vecs))
            sub_kernel_outputs = tuple(sub_kernel_matrix.representation())
            # Back-propagate this chunk into sub_x1.grad and x2.grad
            # (x2.grad accumulates across chunks).
            torch.autograd.backward(sub_kernel_outputs, sub_grad_outputs)

        # Stitch the per-chunk gradients back together along the split dim.
        x1.grad = torch.cat([sub_x1.grad.data for sub_x1 in sub_x1s], dim=-2)
        return x1.grad, x2.grad
Example #23
Source File: torch_agent.py    From KBRD with MIT License 5 votes vote down vote up
def batch_act(self, observations):
    """
    Act on a batch of observations (a list of message dicts, one per example).

    The observations must already have gone through ``observe``.

    Subclasses normally leave this method alone and override ``train_step``
    and ``eval_step`` instead: ``train_step`` runs when any observation in
    the batch carries labels, ``eval_step`` (under ``torch.no_grad()``)
    otherwise.
    """
    # one reply stub per observation, tagged with this agent's id
    batch_reply = [{'id': self.getID()} for _ in range(len(observations))]

    # any labels in the batch means we train on it
    self.is_training = any('labels' in obs for obs in observations)

    # turn the vectorized observations into a batch
    batch = self.batchify(observations)

    if not self.is_training:
        # autograd is off here to save memory and compute;
        # use `with torch.enable_grad()` to get gradients back.
        with torch.no_grad():
            output = self.eval_step(batch)
    else:
        output = self.train_step(batch)

    if output is None:
        self.replies['batch_reply'] = None
        return batch_reply

    self.match_batch(batch_reply, batch.valid_indices, output)
    self.replies['batch_reply'] = batch_reply
    self._save_history(observations, batch_reply)  # save model predictions

    return batch_reply
Example #24
Source File: main.py    From NJUNMT-pytorch with MIT License 5 votes vote down vote up
def compute_forward(model,
                    critic,
                    seqs_x,
                    seqs_y,
                    eval=False,
                    normalization=1.0,
                    norm_by_words=False
                    ):
    """
    Compute the loss for one batch and, when training, back-propagate it.

    ``seqs_y`` is shifted into decoder inputs (all but the last position) and
    labels (all but the first position).

    :type model: nn.Module

    :type critic: NMTCriterion

    :param eval: if true, run in eval mode under ``torch.no_grad()`` without
        a backward pass.
    :param normalization: forwarded to ``critic``.
    :param norm_by_words: training only; divide each per-sequence loss by its
        count of non-``PAD`` labels before summing.
    :return: the loss as a Python float.
    """
    y_inp = seqs_y[:, :-1].contiguous()
    y_label = seqs_y[:, 1:].contiguous()

    words_norm = y_label.ne(PAD).float().sum(1)

    if eval:
        model.eval()
        critic.eval()
        # Loss only, no graph.
        with torch.no_grad():
            log_probs = model(seqs_x, y_inp)
            loss = critic(inputs=log_probs, labels=y_label, normalization=normalization, reduce=True)
        return loss.item()

    model.train()
    critic.train()
    with torch.enable_grad():
        log_probs = model(seqs_x, y_inp)
        loss = critic(inputs=log_probs, labels=y_label, reduce=False, normalization=normalization)
        loss = loss.div(words_norm).sum() if norm_by_words else loss.sum()
    torch.autograd.backward(loss)
    return loss.item()
Example #25
Source File: attack_pgd.py    From semisup-adv with MIT License 5 votes vote down vote up
def pgd(model,
        X,
        y,
        epsilon=8 / 255,
        num_steps=20,
        step_size=0.01,
        random_start=True):
    """Run a PGD attack and record per-step correctness.

    Args:
        model: classifier returning logits.
        X: clean input batch.
        y: labels for ``X``.
        epsilon: L-inf radius of the perturbation.
        num_steps: number of ascent steps.
        step_size: step length applied to the gradient sign.
        random_start: start from uniform noise in ``[-epsilon, epsilon]``
            instead of zeros.

    Returns:
        ``(is_correct_natural, is_correct_adv)``: a float array with one
        entry per sample for the clean inputs, and a ``(batch, num_steps)``
        float array with one column per attack step.
    """
    natural_logits = model(X)
    is_correct_natural = (natural_logits.max(1)[1] == y).float().cpu().numpy()

    if random_start:
        perturbation = torch.rand_like(X, requires_grad=True)
        perturbation.data = perturbation.data * 2 * epsilon - epsilon
    else:
        perturbation = torch.zeros_like(X, requires_grad=True)

    opt = optim.SGD([perturbation], lr=1e-3)  # This is just to clear the grad
    criterion = nn.CrossEntropyLoss()
    per_step_correct = []

    for _ in range(num_steps):
        opt.zero_grad()

        with torch.enable_grad():
            loss = criterion(model(X + perturbation), y)
        loss.backward()

        stepped = perturbation + step_size * perturbation.grad.detach().sign()
        perturbation.data = stepped.clamp(-epsilon, epsilon)
        # clip X+delta to [0,1]
        perturbation.data = torch.min(torch.max(perturbation.detach(), -X), 1 - X)

        X_pgd = Variable(torch.clamp(X.data + perturbation.data, 0, 1.0),
                         requires_grad=False)
        hits = (model(X_pgd).max(1)[1] == y).float().cpu().numpy()
        per_step_correct.append(np.reshape(hits, [-1, 1]))

    is_correct_adv = np.concatenate(per_step_correct, axis=1)
    return is_correct_natural, is_correct_adv
Example #26
Source File: torch_agent.py    From neural_chat with MIT License 5 votes vote down vote up
def batch_act(self, observations):
    """
    Act on a batch of observations (a list of message dicts, one per example).

    The observations must already have gone through ``observe``.

    Subclasses normally leave this method alone and override ``train_step``
    and ``eval_step`` instead: ``train_step`` runs when any observation in
    the batch carries labels, ``eval_step`` (under ``torch.no_grad()``)
    otherwise.
    """
    # one reply stub per observation, tagged with this agent's id
    batch_reply = [Message({'id': self.getID()}) for _ in range(len(observations))]

    # any labels in the batch means we train on it
    self.is_training = any('labels' in obs for obs in observations)

    # turn the vectorized observations into a batch
    batch = self.batchify(observations)

    if not self.is_training:
        # autograd is off here to save memory and compute;
        # use `with torch.enable_grad()` to get gradients back.
        with torch.no_grad():
            output = self.eval_step(batch)
    else:
        output = self.train_step(batch)

    if output is None:
        self.replies['batch_reply'] = None
        return batch_reply

    self.match_batch(batch_reply, batch.valid_indices, output)
    self.replies['batch_reply'] = batch_reply
    return batch_reply
Example #27
Source File: utils.py    From elastic with Apache License 2.0 5 votes vote down vote up
def backward(ctx, *output_grads):
    """Recompute ``ctx.run_function`` and differentiate it w.r.t. inputs and params.

    Every entry of ``ctx.input_tensors`` is replaced by a detached copy that
    keeps the original ``requires_grad`` flag. The function is then re-run
    with autograd enabled.

    Returns ``(None, None)`` followed by the gradients for
    ``ctx.input_tensors + ctx.input_params`` (``None`` for unused ones).
    """
    for position, original in enumerate(ctx.input_tensors):
        detached = original.detach()
        detached.requires_grad = original.requires_grad
        ctx.input_tensors[position] = detached

    with torch.enable_grad():
        output_tensors = ctx.run_function(*ctx.input_tensors)

    wrt = ctx.input_tensors + ctx.input_params
    input_grads = torch.autograd.grad(output_tensors, wrt, output_grads, allow_unused=True)
    return (None, None) + input_grads
Example #28
Source File: ProxProp.py    From proxprop with MIT License 5 votes vote down vote up
def backward(ctx, grad_z):
        """Backward pass that returns parameter differences in the gradient slots.

        Args:
            ctx: context whose ``saved_variables`` hold
                ``(input, *params, output)`` and whose ``optimization_layer``
                provides the shape helpers and solver settings.
            grad_z: gradient w.r.t. the layer output.

        Returns:
            A tuple ``(grad_input, *grad_params, None)`` where each entry of
            ``grad_params`` is ``param - updated_param``.
        """
        # NOTE(review): saved_variables looks like the legacy spelling of
        # saved_tensors - confirm against the PyTorch version in use.
        input = ctx.saved_variables[0]
        params = list(ctx.saved_variables[1:-1])
        output = ctx.saved_variables[-1]
        grad_input = None
        grad_params = [None] * len(params)
        layer = ctx.optimization_layer

        # explicit gradient step on z
        z_updated = output - grad_z

        # prox step or gradient step on the network parameters
        if layer.optimization_mode == 'prox_exact':
            A, Y, Z = layer.to_exact_solve_shape(input, z_updated, *params)
        else:
            A = input.detach()
            Y = z_updated.detach()
            Z = layer.to_cg_shape(params).detach()

        X_tensor = optimization_step(A, Y, Z, 1./layer.tau_prox, layer.apply_cg, mode=layer.optimization_mode)

        # Convert the solver result back into per-parameter tensors.
        if 'prox_exact' == layer.optimization_mode:
            params_udpated = list(layer.from_exact_solve_shape(X_tensor).values())
        else:
            params_udpated = list(layer.from_cg_shape(X_tensor).values())

        # write difference in grad fields
        grad_params = [x[0] - x[1] for x in zip(params,params_udpated)]

        # explicit gradient step on a
        input.requires_grad_()
        # Re-run the layer with a graph to get the gradient w.r.t. the input.
        with torch.enable_grad():
            out_temp = ctx.optimization_layer.apply_forward(input)
            grad_temp = torch.autograd.grad(out_temp, input, grad_z)
        grad_input = grad_temp[0]
        return tuple([grad_input] + grad_params + [None])
Example #29
Source File: metric_optimization_torch.py    From pysaliency with MIT License 5 votes vote down vote up
def step(self, closure=None):
    """Performs a single optimization step.

    The gradient is first capped element-wise at ``p / lr`` and then has its
    component along the all-ones direction removed, before the usual
    ``p -= lr * grad`` update.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        for p in group['params']:
            d_p = p.grad
            if d_p is None:
                continue
            learning_rate = group['lr']

            # all-ones constraint direction, divided by its squared norm
            ones = torch.ones_like(d_p)
            scaled_ones = ones / torch.sum(torch.pow(ones, 2))

            # first step: make sure we are not running into negative values
            capped = torch.min(d_p, p / learning_rate)

            # second step: Make sure that the gradient does not walk
            # out of the constraint
            projected = capped - torch.sum(capped * ones) * scaled_ones

            p.add_(projected, alpha=-group['lr'])

    return loss
Example #30
Source File: NeuralIntegral.py    From UMNN with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def computeIntegrand(x, h, integrand, x_tot):
    """Return flattened gradients of ``integrand.forward(x, h)``.

    ``x_tot`` is used as the upstream gradient. The first result is taken
    w.r.t. ``integrand.parameters()`` (graph kept and differentiable), the
    second w.r.t. ``h``.
    """
    with torch.enable_grad():
        f = integrand.forward(x, h)
        param_grads = torch.autograd.grad(
            f, integrand.parameters(), x_tot, create_graph=True, retain_graph=True
        )
        g_param = _flatten(param_grads)
        h_grads = torch.autograd.grad(f, h, x_tot)
        g_h = _flatten(h_grads)

    return g_param, g_h