from __future__ import division

import pytest
from numpy.testing import assert_allclose
import torch
from torch.autograd import gradcheck, Variable

from .fused import fused_prox_jv_slow, fused_prox_jv_fast
from .fused import FusedProxFunction


def _fused_prox_jacobian(y_hat, dout=None):
    """Reference naive implementation: construct the full Jacobian.

    If ``dout`` is given, return the Jacobian-vector product instead.
    """
    dim = y_hat.shape[0]
    groups = torch.zeros(dim)
    J = torch.zeros(dim, dim)
    current_group = 0

    # Assign consecutive equal entries of y_hat to the same fused group.
    for i in range(1, dim):
        if y_hat[i] == y_hat[i - 1]:
            groups[i] = groups[i - 1]
        else:
            current_group += 1
            groups[i] = current_group

    # Within a fused group of n entries, each Jacobian entry is 1 / n;
    # entries in different groups do not interact.
    for i in range(dim):
        for j in range(dim):
            if groups[i] == groups[j]:
                n_fused = (groups == groups[i]).sum()
                J[i, j] = 1 / n_fused.to(y_hat.dtype)

    if dout is not None:
        return torch.mv(J, dout)
    else:
        return J


@pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1])
def test_jv(alpha):
    torch.manual_seed(1)
    torch.set_default_tensor_type('torch.DoubleTensor')

    for _ in range(30):
        x = Variable(torch.randn(15))
        dout = torch.randn(15)
        y_hat = FusedProxFunction(alpha=alpha)(x).data

        # Both JV implementations must agree with the naive dense Jacobian.
        ref = _fused_prox_jacobian(y_hat, dout)
        din_slow = fused_prox_jv_slow(y_hat, dout)
        din_fast = fused_prox_jv_fast(y_hat, dout)
        assert_allclose(ref.numpy(), din_slow.numpy(), atol=1e-5)
        assert_allclose(ref.numpy(), din_fast.numpy(), atol=1e-5)


@pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1])
def test_finite_diff(alpha):
    torch.manual_seed(1)
    torch.set_default_tensor_type('torch.DoubleTensor')

    for _ in range(30):
        x = Variable(torch.randn(20), requires_grad=True)
        func = FusedProxFunction(alpha=alpha)
        assert gradcheck(func, (x,), eps=1e-4, atol=1e-3)
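

# Illustrative sketch, not part of the test suite. It assumes (as the tests
# above suggest) that FusedProxFunction computes the proximal operator of the
# 1-D fused-lasso / total-variation penalty,
#     argmin_y 0.5 * ||y - x||^2 + alpha * sum_i |y[i+1] - y[i]|,
# so that runs of consecutive outputs fuse to a common value. The Jacobian is
# then block-constant: within a fused group of size n every entry is 1/n, and
# entries in different groups are zero. Run as a script to inspect a small
# example of this structure.
if __name__ == '__main__':
    torch.set_default_tensor_type('torch.DoubleTensor')
    torch.manual_seed(0)

    x = Variable(torch.randn(6))
    y_hat = FusedProxFunction(alpha=0.5)(x).data

    # With a moderately large alpha, y_hat should contain fused (equal)
    # consecutive values, and the printed Jacobian should show 1/n blocks.
    print("y_hat:", y_hat)
    print("Jacobian:")
    print(_fused_prox_jacobian(y_hat))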