import unittest from functools import partial import torch from torch.autograd import gradcheck, Variable import pyinn as P from pyinn.modules import Conv2dDepthwise import torch.nn.functional as F def ncrelu_ref(input): return torch.cat([F.relu(input), -F.relu(-input)], 1) def cdgmm_ref(A, B): C = Variable(A.data.new(A.size())) A_r = A[..., 0].contiguous().view(-1, A.size(-2)) A_i = A[..., 1].contiguous().view(-1, A.size(-2)) B_r = B[..., 0].contiguous().view(-1).unsqueeze(0).expand_as(A_i) B_i = B[..., 1].contiguous().view(-1).unsqueeze(0).expand_as(A_r) C[..., 0] = A_r * B_r - A_i * B_i C[..., 1] = A_r * B_i + A_i * B_r return C class TestPYINN(unittest.TestCase): def testNCReLU(self): for dtype in [torch.cuda.FloatTensor, torch.cuda.DoubleTensor]: x = Variable(torch.randn(2,5,3,1).type(dtype), requires_grad=True) #go = Variable(torch.randn(2,10,3,1).cuda(), requires_grad=False) go = torch.randn(2,10,3,1).type(dtype) self.assertEqual((ncrelu_ref(x).data - P.ncrelu(x).data).abs().sum(), 0) ncrelu_ref(x).backward(go) gref = x.grad.data.clone() x.grad.data.zero_() P.ncrelu(x).backward(go) g = x.grad.data.clone() self.assertLess((g - gref).abs().sum(), 1e-8) def testDGMM(self): inputs = Variable(torch.randn(16, 8).cuda()) x = Variable(torch.randn(8).cuda()) c_ref = inputs.mm(torch.diag(x)) c_out = P.dgmm(inputs, x) self.assertEqual((c_ref.data - c_out.data).abs().max(), 0, 'DGMM left') # transposed c_ref = torch.diag(x).mm(inputs.t()) c_out = P.dgmm(inputs.t().contiguous(), x) self.assertEqual((c_ref.data - c_out.data).abs().max(), 0, 'DGMM right') # grad wrt inputs inputs.requires_grad, x.requires_grad = True, False P.dgmm(inputs, x).sum().backward() g_out = inputs.grad.data.clone() inputs.grad.data.zero_() inputs.mm(torch.diag(x)).sum().backward() g_ref = inputs.grad.data.clone() self.assertEqual((g_ref - g_out).abs().max(), 0) # grad wrt x inputs.requires_grad, x.requires_grad = False, True P.dgmm(inputs, x).sum().backward() g_out = x.grad.data.clone() x.grad.data.zero_() inputs.mm(torch.diag(x)).sum().backward() g_ref = x.grad.data.clone() self.assertLess((g_ref - g_out).abs().max(), 1e-6) # grad wrt inputs and x inputs.requires_grad, x.requires_grad = True, True x.grad.data.zero_() inputs.grad.data.zero_() P.dgmm(inputs, x).sum().backward() g_x_out = x.grad.data.clone() g_inputs_out = inputs.grad.data.clone() x.grad.data.zero_() inputs.grad.data.zero_() inputs.mm(torch.diag(x)).sum().backward() g_x_ref = x.grad.data.clone() g_x_inputs_out = inputs.grad.data.clone() self.assertLess((g_ref - g_out).abs().max(), 1e-6) self.assertLess((g_x_ref - g_x_out).abs().max(), 1e-6) def testCDGMM(self): inputs = Variable(torch.randn(16, 8, 2).cuda()) x = Variable(torch.randn(8, 2).cuda()) c_ref = cdgmm_ref(inputs, x) c_out = P.cdgmm(inputs, x) self.assertLess((c_ref.data - c_out.data).abs().max(), 1e-6, 'CDGMM left') # grad wrt inputs inputs.requires_grad, x.requires_grad = True, False P.cdgmm(inputs, x).sum().backward() g_out = inputs.grad.data.clone() inputs.grad.data.zero_() cdgmm_ref(inputs, x).sum().backward() g_ref = inputs.grad.data.clone() self.assertLess((g_out - g_ref).abs().max(), 1e-6, 'CDGMM grad wrt A') # grad wrt x # inputs.requires_grad, x.requires_grad = False, True # P.cdgmm(inputs, x).sum().backward() # g_out = x.grad.data.clone() # x.grad.data.zero_() # cdgmm_ref(inputs, x).sum().backward() # g_ref = x.grad.data.clone() # self.assertEqual((g_ref - g_out).abs().max(), 0) def testCDGMMscat(self): shapes = [((1, 3, 40, 40, 2), (40, 40, 2)), ((1, 3, 20, 20, 2), (20, 20, 2))] def cdgmm_ref(A, B): C = Variable(A.data.new(A.size())) A_r = A[..., 0].contiguous().view(-1, A.size(-2)*A.size(-3)) A_i = A[..., 1].contiguous().view(-1, A.size(-2)*A.size(-3)) B_r = B[...,0].contiguous().view(B.size(-2)*B.size(-3)).unsqueeze(0).expand_as(A_i) B_i = B[..., 1].contiguous().view(B.size(-2)*B.size(-3)).unsqueeze(0).expand_as(A_r) C[..., 0] = A_r * B_r - A_i * B_i C[..., 1] = A_r * B_i + A_i * B_r return C def cdgmm_scat(A, B): A_ = A.view(-1, A.size(-2)*A.size(-3), 2) B_ = B.view(-1, 2) return P.cdgmm(A_, B_).view_as(A) for shape in shapes: inputs = Variable(torch.randn(*shape[0]).cuda()) x = Variable(torch.randn(*shape[1]).cuda()) c_ref = cdgmm_ref(inputs, x) c = cdgmm_scat(inputs, x) self.assertLess((c_ref.data - c.data).abs().max(), 1e-6, 'CDGMM left') inputs.requires_grad, x.requires_grad = True, False cdgmm_scat(inputs, x).sum().backward() g_out = inputs.grad.data.clone() inputs.grad.data.zero_() cdgmm_ref(inputs, x).sum().backward() g_ref = inputs.grad.data.clone() self.assertLess((g_out - g_ref).abs().max(), 1e-6, 'CDGMM grad wrt A') def test_im2col(self): src = Variable(torch.randn(8,7,7).cuda()) k = 1 pad = 0 s = (1,1) dst = P.im2col(src, k, s, pad) back = P.col2im(dst, k, s, pad) self.assertEqual((src - back).data.abs().max(), 0) def test_im2col_batch(self): src = Variable(torch.randn(4,8,7,7).cuda()) k = 1 pad = 0 s = (1,1) dst = P.im2col(src, k, s, pad) back = P.col2im(dst, k, s, pad) self.assertEqual((src - back).data.abs().max(), 0) def test_conv2d_depthwise(self): n = 6 x = Variable(torch.randn(1,n,5,5).double().cuda(), requires_grad=True) w = Variable(torch.randn(n,1,3,3).double().cuda(), requires_grad=True) y_fast = P.conv2d_depthwise(x, w, padding=1) y_ref = F.conv2d(x, w, padding=1, groups=n) go = torch.randn(y_fast.size()).double().cuda() self.assertLess((y_fast - y_ref).data.abs().max(), 1e-9) x.requires_grad = True w.requires_grad = True y_fast.backward(go) gx_fast = x.grad.data.clone() gw_fast = w.grad.data.clone() x.grad.data.zero_() w.grad.data.zero_() y_ref.backward(go) gx_ref = x.grad.data.clone() gw_ref = w.grad.data.clone() self.assertTrue(gradcheck(partial(P.conv2d_depthwise, padding=1), (x, w,))) def test_conv2d_depthwise_multigpu(self): n = 6 a0 = Variable(torch.randn(1,n,5,5).cuda(0), requires_grad=True) a1 = Variable(torch.randn(1,n,5,5).cuda(1), requires_grad=True) w0 = Variable(torch.randn(n,1,3,3).double().cuda(0), requires_grad=True) w1 = Variable(torch.randn(n,1,3,3).double().cuda(1), requires_grad=True) y0 = P.conv2d_depthwise(a0, w0, padding=1) go = torch.randn(y0.size()).double().cuda() y0.backward(go) y1 = P.conv2d_depthwise(a1, w1, padding=1) y1.backward(go.cuda(1)) def test_modules(self): module = Conv2dDepthwise(channels=8, kernel_size=3) x = Variable(torch.randn(1,8,5,5)) y = module(x) y_cuda = module.cuda()(x.cuda()) self.assertLess((y - y_cuda.cpu()).data.abs().max(), 1e-6) if __name__ == '__main__': unittest.main()