Python torch.autograd.Function() Examples

The following are 6 code examples of torch.autograd.Function. You can go to the original project or source file by following the link above each example. You may also want to check out all available functions/classes of the module torch.autograd, or try the search function.
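As a quick refresher before the examples: a custom autograd function subclasses torch.autograd.Function, implements static forward and backward methods, and is invoked through .apply. The minimal sketch below (the class name SquareFn and the toy squaring operation are illustrative, not taken from any of the projects that follow) shows the typical shape:

import torch
from torch.autograd import Function

class SquareFn(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)        # stash tensors needed by backward
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * 2 * x      # d(x*x)/dx = 2x, chained with grad_output

x = torch.randn(3, requires_grad=True)
SquareFn.apply(x).sum().backward()      # call via .apply, never forward() directly
print(torch.allclose(x.grad, 2 * x))    # True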
Example #1
Source File: losses.py    From soccerontable with BSD 2-Clause "Simplified" License
def forward(self, input, target, mask):
        # mask out unwanted target entries before applying the wrapped criterion
        self.loss = self.criterion(input, target * mask)
        return self.loss


# class DepthTo3D(Function):
#
#     def forward(self, input, pix_inv, R_inv, T):
#         # bs, sx, sy are the batch size and spatial dimensions
#         self.save_for_backward(input, pix_inv, R_inv, T)
#         return torch.bmm(R_inv, input.resize(bs, 1, sx * sy).repeat(1, 3, 1) * pix_inv - T.repeat(1, 1, sx * sy)).resize(bs, 3, sx, sy)
#
#     def backward(self):
#
#         input, pix_inv, R_inv, T = self.saved_tensors
#
#         # left unfinished in the original source
#         return grad_input,
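Note that the commented-out DepthTo3D draft above is written against the legacy Function API (a plain forward method that saves state on self). A rough sketch of the same back-projection in the modern static-method style, assuming the depth input has shape (bs, 1, sx, sy), pix_inv has shape (bs, 3, sx*sy), R_inv has shape (bs, 3, 3) and T has shape (bs, 3, 1), with pix_inv, R_inv and T treated as non-differentiable constants:

import torch
from torch.autograd import Function

class DepthTo3D(Function):
    @staticmethod
    def forward(ctx, input, pix_inv, R_inv, T):
        bs, _, sx, sy = input.shape
        ctx.save_for_backward(pix_inv, R_inv)
        ctx.dims = (bs, sx, sy)
        # back-project: R_inv @ (depth * pix_inv - T) for every pixel
        rays = input.reshape(bs, 1, sx * sy).repeat(1, 3, 1) * pix_inv - T.repeat(1, 1, sx * sy)
        return torch.bmm(R_inv, rays).reshape(bs, 3, sx, sy)

    @staticmethod
    def backward(ctx, grad_output):
        pix_inv, R_inv = ctx.saved_tensors
        bs, sx, sy = ctx.dims
        g = grad_output.reshape(bs, 3, sx * sy)
        # chain rule through the bmm and the elementwise product with pix_inv
        grad_depth = (torch.bmm(R_inv.transpose(1, 2), g) * pix_inv).sum(1)
        # one gradient per forward argument; the constants get None
        return grad_depth.reshape(bs, 1, sx, sy), None, None, None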
Example #2
Source File: test_operators.py    From onnx-fb-universe with MIT License
def test_symbolic_mismatch(self):
        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                # The inside of this function should never be invoked, because
                # we will fail due to an argument mismatch first.
                assert False

            @staticmethod
            def forward(ctx, x, y):
                return x + y

        x = Variable(torch.randn(2, 2).fill_(1.0))
        y = Variable(torch.randn(2, 2).fill_(1.0))
        # NB: Don't use expect test here, the type error wobbles depending
        # on Python version
        with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
            export_to_string(FuncModule(MyFun().apply), (x, y))

    # TODO: Do an nn style test for these 
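For contrast, a hypothetical corrected MyFun in which symbolic and forward take the same number of arguments would translate without error (g.op("Add", ...) here mirrors the usage in Example #4 below):

class MyFun(Function):
    @staticmethod
    def symbolic(g, x, y):
        # arity now matches forward, so the exporter can translate the node
        return g.op("Add", x, y)

    @staticmethod
    def forward(ctx, x, y):
        return x + y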
Example #3
Source File: test_operators.py    From onnx-fb-universe with MIT License
def test_at_op(self):
        x = Variable(torch.randn(3, 4))

        class MyFun(Function):

            @staticmethod
            def symbolic(g, x):
                return g.at("add", x, x)

            @staticmethod
            def forward(ctx, x):
                return x + x

        class MyModule(Module):
            def forward(self, x):
                return MyFun.apply(x)

        self.assertONNX(MyModule(), x) 
Example #4
Source File: test_verify.py    From onnx-fb-universe with MIT License
def test_result_different(self):
        class BrokenAdd(Function):
            @staticmethod
            def symbolic(g, a, b):
                return g.op("Add", a, b)

            @staticmethod
            def forward(ctx, a, b):
                return a.sub(b) # yahaha! you found me!

        class MyModel(Module):
            def forward(self, x, y):
                return BrokenAdd.apply(x, y)

        x = Variable(torch.Tensor([1,2]))
        y = Variable(torch.Tensor([3,4]))
        self.assertVerifyExpectFail(MyModel(), (x, y), backend) 
Example #5
Source File: miscs.py    From homura with Apache License 2.0
def straight_backprop(function):
    """ Wrap ``function`` so that its backward pass is the identity,
    i.e. gradients are passed straight through unchanged (a straight-through estimator)

    >>> straight_backprop_relu = straight_backprop(F.relu)
    >>> straight_backprop_relu(tensor)

    :param function: original function
    :return: modified function
    """

    class _StraightBackprop(Function):
        @staticmethod
        def forward(ctx, inputs):
            return function(inputs)

        @staticmethod
        def backward(ctx, grad_outputs):
            return grad_outputs

    return _StraightBackprop.apply 
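A quick behavioural check (the tensor values are illustrative): the wrapped ReLU computes the same outputs as F.relu, but its gradient is 1 everywhere, including where the input is negative:

import torch
import torch.nn.functional as F

straight_backprop_relu = straight_backprop(F.relu)

x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
straight_backprop_relu(x).sum().backward()
print(x.grad)   # tensor([1., 1., 1.]) -- identity backward

y = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)
F.relu(y).sum().backward()
print(y.grad)   # tensor([0., 1., 1.]) -- ordinary ReLU gradient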
Example #6
Source File: test_gridpooling.py    From occupancy_networks with MIT License
def grid_pooling_auto(pts, feat):
    # x_grids, y_grids, z_grids, len_cell, C, dtype, dtype_long and
    # pts_in_cell are module-level globals defined elsewhere in the test file
    xv_value, yv_value, zv_value = np.meshgrid(x_grids[:-1], y_grids[:-1], z_grids[:-1], indexing='ij')
    xv_value = xv_value.flatten()
    yv_value = yv_value.flatten()
    zv_value = zv_value.flatten()

    feat_cell = Variable(torch.zeros((len(x_grids)-1) * (len(y_grids)-1) * (len(z_grids)-1), C).type(dtype))
    for i_, (x_, y_, z_) in enumerate(zip(xv_value, yv_value, zv_value)):
        # indices of the points that fall inside the current grid cell
        pts_index = pts_in_cell(pts.unsqueeze(0), [x_, y_, z_,
            x_+len_cell, y_+len_cell, z_+len_cell])
        if len(pts_index) > 0:
            pts_index = torch.LongTensor(pts_index).type(dtype_long)
            pts_feat = feat[pts_index, :]
            # max-pool the features of the points in this cell
            m = nn.MaxPool1d(len(pts_index))
            pts_feat = m(pts_feat.t().unsqueeze(0))
            feat_cell[i_, :] = pts_feat.squeeze()
    return feat_cell

#class GridPooling(Function):
#    # W, H, D, N, C, dtype, dtype_long and forward_utils are module-level
#    # globals; this commented-out draft uses the legacy Function API
#    def forward(self, points, feat_points):
#        feat_cells = torch.zeros(W*H*D, C).type(dtype)
#        indices = -1 * torch.ones(W*H*D, C).type(dtype_long)
#        shape = torch.LongTensor([W, H, D]).type(dtype_long)
#        forward_utils.grid_pooling_forward(points, feat_points, shape, feat_cells, indices)
#        self.saved_indices = indices
#        return feat_cells
#
#    def backward(self, grad_output):
#        grad_points = torch.zeros(N, C).type(torch.FloatTensor)
#        forward_utils.grid_pooling_backward(grad_output, self.saved_indices, grad_points)
#        return None, grad_points
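The commented-out class above likewise saves state on self, which only works with the legacy API. A sketch of the same wrapper in the modern static-method style, keeping the forward_utils calls and the W, H, D, N, C, dtype, dtype_long globals exactly as in the original comment:

class GridPooling(Function):
    @staticmethod
    def forward(ctx, points, feat_points):
        feat_cells = torch.zeros(W*H*D, C).type(dtype)
        indices = -1 * torch.ones(W*H*D, C).type(dtype_long)
        shape = torch.LongTensor([W, H, D]).type(dtype_long)
        forward_utils.grid_pooling_forward(points, feat_points, shape, feat_cells, indices)
        ctx.save_for_backward(indices)    # argmax indices, needed by backward
        return feat_cells

    @staticmethod
    def backward(ctx, grad_output):
        indices, = ctx.saved_tensors
        grad_points = torch.zeros(N, C).type(torch.FloatTensor)
        forward_utils.grid_pooling_backward(grad_output, indices, grad_points)
        return None, grad_points          # no gradient w.r.t. point coordinates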