""" Module for testing kernel module. """ __author__ = 'wittawat' import autograd import autograd.numpy as np import matplotlib.pyplot as plt import kgof.data as data import kgof.density as density import kgof.util as util import kgof.kernel as kernel import kgof.goftest as gof import kgof.glo as glo import scipy.stats as stats import numpy.testing as testing import unittest class TestKGauss(unittest.TestCase): def setUp(self): pass def test_basic(self): """ Nothing special. Just test basic things. """ # sample n = 10 d = 3 with util.NumpySeedContext(seed=29): X = np.random.randn(n, d)*3 k = kernel.KGauss(sigma2=1) K = k.eval(X, X) self.assertEqual(K.shape, (n, n)) self.assertTrue(np.all(K >= 0-1e-6)) self.assertTrue(np.all(K <= 1+1e-6), 'K not bounded by 1') def test_pair_gradX_Y(self): # sample n = 11 d = 3 with util.NumpySeedContext(seed=20): X = np.random.randn(n, d)*4 Y = np.random.randn(n, d)*2 k = kernel.KGauss(sigma2=2.1) # n x d pair_grad = k.pair_gradX_Y(X, Y) loop_grad = np.zeros((n, d)) for i in range(n): for j in range(d): loop_grad[i, j] = k.gradX_Y(X[[i], :], Y[[i], :], j) testing.assert_almost_equal(pair_grad, loop_grad) def test_gradX_y(self): n = 10 with util.NumpySeedContext(seed=10): for d in [1, 3]: y = np.random.randn(d)*2 X = np.random.rand(n, d)*3 sigma2 = 1.3 k = kernel.KGauss(sigma2=sigma2) # n x d G = k.gradX_y(X, y) # check correctness K = k.eval(X, y[np.newaxis, :]) myG = -K/sigma2*(X-y) self.assertEqual(G.shape, myG.shape) testing.assert_almost_equal(G, myG) def test_gradXY_sum(self): n = 11 with util.NumpySeedContext(seed=12): for d in [3, 1]: X = np.random.randn(n, d) sigma2 = 1.4 k = kernel.KGauss(sigma2=sigma2) # n x n myG = np.zeros((n, n)) K = k.eval(X, X) for i in range(n): for j in range(n): diffi2 = np.sum( (X[i, :] - X[j, :])**2 ) #myG[i, j] = -diffi2*K[i, j]/(sigma2**2)+ d*K[i, j]/sigma2 myG[i, j] = K[i, j]/sigma2*(d - diffi2/sigma2) # check correctness G = k.gradXY_sum(X, X) self.assertEqual(G.shape, myG.shape) testing.assert_almost_equal(G, myG) def tearDown(self): pass if __name__ == '__main__': unittest.main()