"""Tests comparing ``mici.matrices`` classes against dense NumPy references.

Each concrete ``Test*`` class builds a dictionary ``matrix_pairs`` mapping a
key to a ``(matrix, np_matrix)`` tuple, where ``matrix`` is an instance of a
``mici.matrices`` class and ``np_matrix`` is the NumPy array it is compared
against. The abstract ``*TestCase`` base classes define the checks.

Most test methods are generators which yield ``(check_function, *args)``
tuples, one per combination of inputs.
NOTE(review): this presumably targets a runner supporting yield-based test
generators (e.g. nose) - confirm before migrating to another runner.
"""

from functools import partial, wraps, reduce
from copy import copy, deepcopy
import numpy as np
import numpy.linalg as nla
import scipy.linalg as sla
import numpy.testing as npt
import mici.matrices as matrices

AUTOGRAD_AVAILABLE = True
try:
    import autograd.numpy as anp
    from autograd import grad
    from autograd.core import primitive, defvjp
except ImportError:
    AUTOGRAD_AVAILABLE = False
    import warnings
    warnings.warn(
        'Autograd not available. Skipping gradient tests.')

# Seed used for all pseudo-random number generators
SEED = 3046987125
# Number of scalar multipliers generated per test case
NUM_SCALAR = 4
# Number of vectors generated per matrix size in square matrix test cases
NUM_VECTOR = 4
# Matrix sizes test matrices are generated for
SIZES = {1, 2, 5, 10}
# Absolute tolerance used in comparisons of values which may be close to zero
ATOL = 1e-10


def iterate_over_matrices(test):
    """Wrap `test(matrix)` as a generator method over all matrices.

    The returned method yields `(test, matrix)` for each pair in
    `self.matrix_pairs`.
    """

    @wraps(test)
    def iterated_test(self):
        for matrix, _ in self.matrix_pairs.values():
            yield (test, matrix)

    return iterated_test


def iterate_over_matrix_pairs(test):
    """Wrap `test(matrix, np_matrix)` as a generator method over all pairs.

    The returned method yields `(test, matrix, np_matrix)` for each pair in
    `self.matrix_pairs`.
    """

    @wraps(test)
    def iterated_test(self):
        for matrix_pair in self.matrix_pairs.values():
            yield (test, *matrix_pair)

    return iterated_test


def iterate_over_matrix_pairs_vectors(test):
    """Wrap `test(matrix, np_matrix, vector)` as a generator method.

    The returned method yields one test per matrix pair and per vector in
    `self.vectors[np_matrix.shape[0]]`.
    """

    @wraps(test)
    def iterated_test(self):
        for matrix, np_matrix in self.matrix_pairs.values():
            for vector in self.vectors[np_matrix.shape[0]]:
                yield test, matrix, np_matrix, vector

    return iterated_test


def iterate_over_matrix_pairs_premultipliers(test):
    """Wrap `test(matrix, np_matrix, pre)` as a generator method.

    The returned method yields one test per matrix pair and per array in
    `self.premultipliers[np_matrix.shape[0]]`.
    """

    @wraps(test)
    def iterated_test(self):
        for matrix, np_matrix in self.matrix_pairs.values():
            for pre in self.premultipliers[np_matrix.shape[0]]:
                yield test, matrix, np_matrix, pre

    return iterated_test


def iterate_over_matrix_pairs_postmultipliers(test):
    """Wrap `test(matrix, np_matrix, post)` as a generator method.

    The returned method yields one test per matrix pair and per array in
    `self.postmultipliers[np_matrix.shape[1]]`.
    """

    @wraps(test)
    def iterated_test(self):
        for matrix, np_matrix in self.matrix_pairs.values():
            for post in self.postmultipliers[np_matrix.shape[1]]:
                yield test, matrix, np_matrix, post

    return iterated_test


def iterate_over_matrix_pairs_scalars(test):
    """Wrap `test(matrix, np_matrix, scalar)` as a generator method.

    The returned method yields one test per matrix pair and per scalar in
    `self.scalars`.
    """

    @wraps(test)
    def iterated_test(self):
        for matrix, np_matrix in self.matrix_pairs.values():
            for scalar in self.scalars:
                yield test, matrix, np_matrix, scalar

    return iterated_test


def iterate_over_matrix_pairs_scalars_postmultipliers(test):
    """Wrap `test(matrix, np_matrix, scalar, post)` as a generator method.

    The returned method yields one test per matrix pair, per scalar in
    `self.scalars` and per array in
    `self.postmultipliers[np_matrix.shape[1]]`.
    """

    @wraps(test)
    def iterated_test(self):
        for matrix, np_matrix in self.matrix_pairs.values():
            for scalar in self.scalars:
                for post in self.postmultipliers[np_matrix.shape[1]]:
                    yield test, matrix, np_matrix, scalar, post

    return iterated_test


def iterate_over_matrix_pairs_scalars_premultipliers(test):
    """Wrap `test(matrix, np_matrix, scalar, pre)` as a generator method.

    The returned method yields one test per matrix pair, per scalar in
    `self.scalars` and per array in
    `self.premultipliers[np_matrix.shape[0]]`.
    """

    @wraps(test)
    def iterated_test(self):
        for matrix, np_matrix in self.matrix_pairs.values():
            for scalar in self.scalars:
                for pre in self.premultipliers[np_matrix.shape[0]]:
                    yield test, matrix, np_matrix, scalar, pre

    return iterated_test


class MatrixTestCase:
    """Base test case holding matrix pairs and random multipliers.

    Checks decorated with one of the `iterate_over_*` decorators are written
    without a `self` parameter: the decorator supplies the generator method
    taking `self` and passes the check only the values it iterates over.
    """

    def __init__(self, matrix_pairs, rng=None):
        """
        Args:
            matrix_pairs: Dictionary with `(matrix, np_matrix)` tuple values.
            rng: Optional `numpy.random.RandomState` used to generate the
                scalars and multipliers. A new generator seeded with `SEED`
                is created if `None`.
        """
        self.matrix_pairs = matrix_pairs
        self.rng = np.random.RandomState(SEED) if rng is None else rng
        # Ensure a mix of positive and negative scalar multipliers
        self.scalars = np.abs(self.rng.standard_normal(NUM_SCALAR))
        self.scalars[NUM_SCALAR // 2:] = -self.scalars[NUM_SCALAR // 2:]
        # Multipliers are keyed by the matrix dimension they must conform to
        self.premultipliers = {
            shape_0: self._generate_premultipliers(shape_0)
            for shape_0 in set(m.shape[0] for _, m in matrix_pairs.values())}
        self.postmultipliers = {
            shape_1: self._generate_postmultipliers(shape_1)
            for shape_1 in set(m.shape[1] for _, m in matrix_pairs.values())}

    def _generate_premultipliers(self, size):
        """Generate a 1D array and 2D arrays with last dimension `size`."""
        return (
            [self.rng.standard_normal((size,))] +
            [self.rng.standard_normal((s, size)) for s in [1, size, 2 * size]]
        )

    def _generate_postmultipliers(self, size):
        """Generate a 1D array and 2D arrays with first dimension `size`."""
        return (
            [self.rng.standard_normal((size,))] +
            [self.rng.standard_normal((size, s)) for s in [1, size, 2 * size]]
        )

    @iterate_over_matrices
    def test_self_equality(matrix):
        assert matrix == matrix

    @iterate_over_matrices
    def test_hashable(matrix):
        assert hash(matrix) == hash(matrix)

    @iterate_over_matrices
    def test_copy_equality(matrix):
        matrix_copy = copy(matrix)
        assert matrix == matrix_copy
        assert hash(matrix) == hash(matrix_copy)

    @iterate_over_matrices
    def test_deepcopy_equality(matrix):
        matrix_copy = deepcopy(matrix)
        assert matrix == matrix_copy
        assert hash(matrix) == hash(matrix_copy)

    @iterate_over_matrix_pairs
    def test_shape(matrix, np_matrix):
        # A shape of (None, None) is accepted in place of the reference shape
        assert (
            matrix.shape == (None, None) or matrix.shape == np_matrix.shape)

    @iterate_over_matrix_pairs_postmultipliers
    def test_lmult(matrix, np_matrix, post):
        npt.assert_allclose(matrix @ post, np_matrix @ post)

    @iterate_over_matrix_pairs_premultipliers
    def test_rmult(matrix, np_matrix, pre):
        npt.assert_allclose(pre @ matrix, pre @ np_matrix)

    @iterate_over_matrix_pairs_postmultipliers
    def test_neg_lmult(matrix, np_matrix, post):
        npt.assert_allclose((-matrix) @ post, -np_matrix @ post)

    @iterate_over_matrix_pairs_postmultipliers
    def test_lmult_rmult_trans(matrix, np_matrix, post):
        npt.assert_allclose(matrix @ post, (post.T @ matrix.T).T)

    @iterate_over_matrix_pairs_premultipliers
    def test_rmult_lmult_trans(matrix, np_matrix, pre):
        npt.assert_allclose(pre @ matrix, (matrix.T @ pre.T).T)

    @iterate_over_matrix_pairs_scalars_postmultipliers
    def test_lmult_scalar_lmult(matrix, np_matrix, scalar, post):
        npt.assert_allclose(
            (scalar * matrix) @ post, scalar * np_matrix @ post)

    @iterate_over_matrix_pairs_scalars_postmultipliers
    def test_rdiv_scalar_lmult(matrix, np_matrix, scalar, post):
        npt.assert_allclose(
            (matrix / scalar) @ post, (np_matrix / scalar) @ post)

    @iterate_over_matrix_pairs_scalars_postmultipliers
    def test_rmult_scalar_lmult(matrix, np_matrix, scalar, post):
        npt.assert_allclose(
            (matrix * scalar) @ post, (np_matrix * scalar) @ post)

    @iterate_over_matrix_pairs_scalars_premultipliers
    def test_lmult_scalar_rmult(matrix, np_matrix, scalar, pre):
        npt.assert_allclose(
            pre @ (scalar * matrix), pre @ (scalar * np_matrix))

    @iterate_over_matrix_pairs_scalars_premultipliers
    def test_rmult_scalar_rmult(matrix, np_matrix, scalar, pre):
        npt.assert_allclose(
            pre @ (matrix * scalar), pre @ (np_matrix * scalar))


class ExplicitShapeMatrixTestCase(MatrixTestCase):
    """Checks which compare the `array` attribute against the reference."""

    def test_matrix_inequality_different_shapes(self):
        matrices_ = [matrix for matrix, _ in self.matrix_pairs.values()]
        for matrix_1, matrix_2 in zip(matrices_[:-1], matrices_[1:]):
            assert matrix_1 != matrix_2
            # Technically hashes could collide, but assume sufficiently low
            # probability for this to happen on small set of comparisons that
            # equal hashes for different matrices is indicative of issue with
            # hash implementation
            assert hash(matrix_1) != hash(matrix_2)

    @iterate_over_matrix_pairs
    def test_array(matrix, np_matrix):
        npt.assert_allclose(matrix.array, np_matrix)

    @iterate_over_matrix_pairs
    def test_array_transpose(matrix, np_matrix):
        npt.assert_allclose(matrix.T.array, np_matrix.T)

    @iterate_over_matrix_pairs
    def test_array_transpose_transpose(matrix, np_matrix):
        npt.assert_allclose(matrix.T.T.array, np_matrix)

    @iterate_over_matrix_pairs
    def test_array_numpy(matrix, np_matrix):
        # Matrix object is passed directly rather than via `array` attribute
        npt.assert_allclose(matrix, np_matrix)

    @iterate_over_matrix_pairs
    def test_diagonal(matrix, np_matrix):
        npt.assert_allclose(matrix.diagonal, np_matrix.diagonal())

    @iterate_over_matrix_pairs_scalars
    def test_lmult_scalar_array(matrix, np_matrix, scalar):
        npt.assert_allclose((scalar * matrix).array, scalar * np_matrix)

    @iterate_over_matrix_pairs_scalars
    def test_rmult_scalar_array(matrix, np_matrix, scalar):
        npt.assert_allclose((matrix * scalar).array, np_matrix * scalar)

    @iterate_over_matrix_pairs_scalars
    def test_rdiv_scalar_array(matrix, np_matrix, scalar):
        npt.assert_allclose((matrix / scalar).array, np_matrix / scalar)

    @iterate_over_matrix_pairs
    def test_neg_array(matrix, np_matrix):
        npt.assert_allclose((-matrix).array, -np_matrix)


class SquareMatrixTestCase(MatrixTestCase):
    """Adds random vectors per matrix size and a quadratic form check."""

    def __init__(self, matrix_pairs, rng=None):
        super().__init__(matrix_pairs, rng)
        # NUM_VECTOR vectors for each distinct first dimension of references
        self.vectors = {
            size: self.rng.standard_normal((NUM_VECTOR, size))
            for size in set(m.shape[0] for _, m in matrix_pairs.values())}

    @iterate_over_matrix_pairs_vectors
    def test_quadratic_form(matrix, np_matrix, vector):
        npt.assert_allclose(
            vector @ matrix @ vector, vector @ np_matrix @ vector)


class ExplicitShapeSquareMatrixTestCase(SquareMatrixTestCase):
    """Adds a check of `log_abs_det` against `numpy.linalg.slogdet`."""

    @iterate_over_matrix_pairs
    def test_log_abs_det(matrix, np_matrix):
        npt.assert_allclose(
            matrix.log_abs_det, nla.slogdet(np_matrix)[1], atol=ATOL)


class SymmetricMatrixTestCase(SquareMatrixTestCase):
    """Checks that multiplication is consistent with matrix being symmetric."""

    @iterate_over_matrix_pairs
    def test_symmetry_identity(matrix, np_matrix):
        assert matrix is matrix.T

    @iterate_over_matrix_pairs_postmultipliers
    def test_symmetry_lmult(matrix, np_matrix, post):
        npt.assert_allclose(matrix @ post, (post.T @ matrix).T)

    @iterate_over_matrix_pairs_premultipliers
    def test_symmetry_rmult(matrix, np_matrix, pre):
        npt.assert_allclose(pre @ matrix, (matrix @ pre.T).T)


class ExplicitShapeSymmetricMatrixTestCase(
        SymmetricMatrixTestCase, ExplicitShapeSquareMatrixTestCase):
    """Adds array symmetry and eigendecomposition checks."""

    @iterate_over_matrix_pairs
    def test_symmetry_array(matrix, np_matrix):
        npt.assert_allclose(matrix.array, matrix.T.array)

    @iterate_over_matrix_pairs
    def test_eigval(matrix, np_matrix):
        # Ensure eigenvalues in ascending order
        npt.assert_allclose(
            np.sort(matrix.eigval), nla.eigh(np_matrix)[0])

    @iterate_over_matrix_pairs
    def test_eigvec(matrix, np_matrix):
        # Ensure eigenvectors correspond to ascending eigenvalue ordering
        eigval_order = np.argsort(matrix.eigval)
        eigvec = matrix.eigvec.array[:, eigval_order]
        np_eigvec = nla.eigh(np_matrix)[1]
        # Account for eigenvector sign ambiguity when checking for equivalence
        assert np.all(
            np.isclose(eigvec, np_eigvec) | np.isclose(eigvec, -np_eigvec))


class InvertibleMatrixTestCase(MatrixTestCase):
    """Checks of the `inv` attribute against `numpy.linalg.solve`.

    `test_quadratic_form_inv` reads `self.vectors`, which this class does not
    itself define, so concrete test classes must also derive from
    `SquareMatrixTestCase`.
    """

    @iterate_over_matrix_pairs_postmultipliers
    def test_lmult_inv(matrix, np_matrix, post):
        npt.assert_allclose(matrix.inv @ post, nla.solve(np_matrix, post))

    @iterate_over_matrix_pairs_premultipliers
    def test_rmult_inv(matrix, np_matrix, pre):
        npt.assert_allclose(pre @ matrix.inv, nla.solve(np_matrix.T, pre.T).T)

    @iterate_over_matrix_pairs_scalars_postmultipliers
    def test_lmult_scalar_inv_lmult(matrix, np_matrix, scalar, post):
        npt.assert_allclose(
            (scalar * matrix.inv) @ post, nla.solve(np_matrix / scalar, post))

    @iterate_over_matrix_pairs_scalars_postmultipliers
    def test_inv_lmult_scalar_lmult(matrix, np_matrix, scalar, post):
        npt.assert_allclose(
            (scalar * matrix).inv @ post, nla.solve(scalar * np_matrix, post))

    @iterate_over_matrix_pairs_vectors
    def test_quadratic_form_inv(matrix, np_matrix, vector):
        npt.assert_allclose(
            vector @ matrix.inv @ vector,
            vector @ nla.solve(np_matrix, vector))


class ExplicitShapeInvertibleMatrixTestCase(
        ExplicitShapeSquareMatrixTestCase, InvertibleMatrixTestCase):
    """Adds checks of the array and log-determinant of the inverse."""

    @iterate_over_matrix_pairs
    def test_array_inv(matrix, np_matrix):
        npt.assert_allclose(matrix.inv.array, nla.inv(np_matrix), atol=ATOL)

    @iterate_over_matrix_pairs
    def test_array_inv_inv(matrix, np_matrix):
        npt.assert_allclose(matrix.inv.inv.array, np_matrix, atol=ATOL)

    @iterate_over_matrix_pairs
    def test_log_abs_det_inv(matrix, np_matrix):
        npt.assert_allclose(
            matrix.inv.log_abs_det, -nla.slogdet(np_matrix)[1], atol=ATOL)


class PositiveDefiniteMatrixTestCase(
        SymmetricMatrixTestCase, InvertibleMatrixTestCase):
    """Checks of positive-definiteness and of the `sqrt` attribute."""

    @iterate_over_matrix_pairs_vectors
    def test_pos_def(matrix, np_matrix, vector):
        assert vector @ matrix @ vector > 0

    @iterate_over_matrix_pairs_postmultipliers
    def test_lmult_sqrt(matrix, np_matrix, post):
        npt.assert_allclose(
            matrix.sqrt @ (matrix.sqrt.T @ post), np_matrix @ post)

    @iterate_over_matrix_pairs_premultipliers
    def test_rmult_sqrt(matrix, np_matrix, pre):
        npt.assert_allclose(
            (pre @ matrix.sqrt) @ matrix.sqrt.T, pre @ np_matrix)

    @iterate_over_matrix_pairs
    def test_inv_is_posdef(matrix, np_matrix):
        assert isinstance(matrix.inv, matrices.PositiveDefiniteMatrix)

    @iterate_over_matrix_pairs
    def test_pos_scalar_multiple_is_posdef(matrix, np_matrix):
        assert isinstance(matrix * 2, matrices.PositiveDefiniteMatrix)


class ExplicitShapePositiveDefiniteMatrixTestCase(
        PositiveDefiniteMatrixTestCase,
        ExplicitShapeInvertibleMatrixTestCase,
        ExplicitShapeSymmetricMatrixTestCase):
    """Adds a check that the `sqrt` factor reproduces the reference array."""

    @iterate_over_matrix_pairs
    def test_sqrt_array(matrix, np_matrix):
        npt.assert_allclose((matrix.sqrt @ matrix.sqrt.T).array, np_matrix)


class DifferentiableMatrixTestCase(MatrixTestCase):
    """Checks of matrix gradient attributes against Autograd gradients.

    The gradient checks are only defined when Autograd could be imported.
    `test_grad_quadratic_form_inv` reads `self.vectors`, so concrete test
    classes must also derive from `SquareMatrixTestCase`.
    """

    def __init__(self, matrix_pairs, get_param, param_func, rng=None):
        """
        Args:
            matrix_pairs: Dictionary with `(matrix, np_matrix)` tuple values.
            get_param: Function called as `get_param(matrix)` returning the
                parameter value gradients are evaluated at.
            param_func: Function called as `param_func(param, matrix)`
                returning the array differentiated with respect to `param`.
            rng: Optional `numpy.random.RandomState` passed to the parent.
        """
        super().__init__(matrix_pairs, rng)
        self.get_param = get_param
        self.param_func = param_func

    if AUTOGRAD_AVAILABLE:

        def grad_log_abs_det(self, matrix):
            """Compute reference gradient of log-abs-determinant."""
            param = self.get_param(matrix)
            return grad(
                lambda p: anp.linalg.slogdet(
                    self.param_func(p, matrix))[1])(param)

        def grad_quadratic_form_inv(self, matrix):
            """Return function of vector computing the reference gradient of
            the quadratic form `v @ solve(param_func(param, matrix), v)`."""
            param = self.get_param(matrix)
            return lambda v: grad(
                lambda p: v @ anp.linalg.solve(
                    self.param_func(p, matrix), v))(param)

        def check_grad_log_abs_det(self, matrix, grad_log_abs_det):
            # Use non-zero atol to allow for floating point errors in gradients
            # analytically equal to zero
            npt.assert_allclose(
                matrix.grad_log_abs_det, grad_log_abs_det, atol=ATOL)

        def test_grad_log_abs_det(self):
            for matrix, _ in self.matrix_pairs.values():
                yield (self.check_grad_log_abs_det,
                       matrix, self.grad_log_abs_det(matrix))

        def check_grad_quadratic_form_inv(
                self, matrix, vector, grad_quadratic_form_inv):
            # Use non-zero atol to allow for floating point errors in gradients
            # analytically equal to zero
            npt.assert_allclose(
                matrix.grad_quadratic_form_inv(vector),
                grad_quadratic_form_inv(vector), atol=ATOL)

        def test_grad_quadratic_form_inv(self):
            for matrix, np_matrix in self.matrix_pairs.values():
                # Reference gradient function depends only on the matrix so
                # construct it once rather than once per vector
                grad_func = self.grad_quadratic_form_inv(matrix)
                for vector in self.vectors[np_matrix.shape[0]]:
                    yield (self.check_grad_quadratic_form_inv,
                           matrix, vector, grad_func)


class TestImplicitIdentityMatrix(
        SymmetricMatrixTestCase, InvertibleMatrixTestCase):

    def __init__(self):
        super().__init__({sz: (
            matrices.IdentityMatrix(None), np.identity(sz)) for sz in SIZES})


class TestIdentityMatrix(ExplicitShapePositiveDefiniteMatrixTestCase):

    def __init__(self):
        super().__init__({sz: (
            matrices.IdentityMatrix(sz), np.identity(sz)) for sz in SIZES})


class TestImplicitScaledIdentityMatrix(
        InvertibleMatrixTestCase, SymmetricMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            scalar = rng.normal()
            matrix_pairs[sz] = (
                matrices.ScaledIdentityMatrix(scalar, None),
                scalar * np.identity(sz))
        super().__init__(matrix_pairs, rng)


class DifferentiableScaledIdentityMatrixTestCase(DifferentiableMatrixTestCase):
    """Shared set up for explicitly sized scaled identity matrix tests."""

    def __init__(self, generate_scalar, matrix_class):
        """
        Args:
            generate_scalar: Function called as `generate_scalar(rng)`
                returning the scalar multiplier for each matrix.
            matrix_class: Callable called as `matrix_class(scalar, size)`
                returning the matrix to test.
        """
        rng = np.random.RandomState(SEED)
        matrix_pairs = {}
        for sz in SIZES:
            scalar = generate_scalar(rng)
            matrix_pairs[sz] = (
                matrix_class(scalar, sz), scalar * np.identity(sz))

        if AUTOGRAD_AVAILABLE:

            def param_func(param, matrix):
                return param * anp.eye(matrix.shape[0])

            def get_param(matrix):
                return matrix._scalar

        else:
            param_func, get_param = None, None

        super().__init__(
            matrix_pairs, get_param, param_func, rng)


class TestScaledIdentityMatrix(
        DifferentiableScaledIdentityMatrixTestCase,
        ExplicitShapeSymmetricMatrixTestCase,
        ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        super().__init__(
            lambda rng: rng.normal(), matrices.ScaledIdentityMatrix)


class TestPositiveScaledIdentityMatrix(
        DifferentiableScaledIdentityMatrixTestCase,
        ExplicitShapePositiveDefiniteMatrixTestCase):

    def __init__(self):
        super().__init__(
            lambda rng: abs(rng.normal()),
            matrices.PositiveScaledIdentityMatrix)


class DifferentiableDiagonalMatrixTestCase(DifferentiableMatrixTestCase):
    """Shared set up for diagonal matrix tests."""

    def __init__(self, generate_diagonal, matrix_class):
        """
        Args:
            generate_diagonal: Function called as `generate_diagonal(sz, rng)`
                returning the diagonal values for a matrix of size `sz`.
            matrix_class: Callable called as `matrix_class(diagonal)`
                returning the matrix to test.
        """
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            diagonal = generate_diagonal(sz, rng)
            matrix_pairs[sz] = (matrix_class(diagonal), np.diag(diagonal))

        if AUTOGRAD_AVAILABLE:

            def param_func(param, matrix):
                return anp.diag(param)

            def get_param(matrix):
                return matrix.diagonal

        else:
            param_func, get_param = None, None

        super().__init__(matrix_pairs, get_param, param_func, rng)


class TestDiagonalMatrix(
        DifferentiableDiagonalMatrixTestCase,
        ExplicitShapeSymmetricMatrixTestCase,
        ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        super().__init__(
            lambda sz, rng: rng.standard_normal(sz), matrices.DiagonalMatrix)


class TestPositiveDiagonalMatrix(
        DifferentiableDiagonalMatrixTestCase,
        ExplicitShapePositiveDefiniteMatrixTestCase):

    def __init__(self):
        super().__init__(
            lambda sz, rng: abs(rng.standard_normal(sz)),
            matrices.PositiveDiagonalMatrix)


class TestTriangularMatrix(ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            for lower in [True, False]:
                array = rng.standard_normal((sz, sz))
                tri_array = np.tril(array) if lower else np.triu(array)
                matrix_pairs[(sz, lower)] = (
                    matrices.TriangularMatrix(tri_array, lower), tri_array)
        super().__init__(matrix_pairs, rng)


class TestInverseTriangularMatrix(ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            for lower in [True, False]:
                array = rng.standard_normal((sz, sz))
                inv_tri_array = np.tril(array) if lower else np.triu(array)
                matrix_pairs[(sz, lower)] = (
                    matrices.InverseTriangularMatrix(inv_tri_array, lower),
                    nla.inv(inv_tri_array))
        super().__init__(matrix_pairs, rng)


class DifferentiableTriangularFactoredDefiniteMatrixTestCase(
        DifferentiableMatrixTestCase):
    """Shared set up for triangular factored definite matrix tests."""

    def __init__(self, matrix_class, signs):
        """
        Args:
            matrix_class: Callable called as
                `matrix_class(tri_array, sign, factor_is_lower)` returning
                the matrix to test.
            signs: Iterable of signs (`+1` or `-1`) to generate matrices for.
        """
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            for factor_is_lower in [True, False]:
                for sign in signs:
                    array = rng.standard_normal((sz, sz))
                    tri_array = sla.cholesky(array @ array.T, factor_is_lower)
                    matrix_pairs[(sz, factor_is_lower, sign)] = (
                        matrix_class(tri_array, sign, factor_is_lower),
                        sign * tri_array @ tri_array.T)

        if AUTOGRAD_AVAILABLE:

            def param_func(param, matrix):
                param = (
                    anp.tril(param) if matrix.factor.lower
                    else anp.triu(param))
                return param @ param.T

            def get_param(matrix):
                return matrix.factor.array

        else:
            param_func, get_param = None, None

        super().__init__(matrix_pairs, get_param, param_func, rng)


class TestTriangularFactoredDefiniteMatrix(
        DifferentiableTriangularFactoredDefiniteMatrixTestCase,
        ExplicitShapeSymmetricMatrixTestCase,
        ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        super().__init__(matrices.TriangularFactoredDefiniteMatrix, (+1, -1))


class TestTriangularFactoredPositiveDefiniteMatrix(
        DifferentiableTriangularFactoredDefiniteMatrixTestCase,
        ExplicitShapePositiveDefiniteMatrixTestCase):

    def __init__(self):
        # Positive definite class takes no sign argument so drop it
        super().__init__(
            lambda factor, sign, factor_is_lower:
                matrices.TriangularFactoredPositiveDefiniteMatrix(
                    factor, factor_is_lower),
            (+1,))


class DifferentiableDenseDefiniteMatrixTestCase(DifferentiableMatrixTestCase):
    """Shared set up for dense definite matrix tests."""

    def __init__(self, matrix_class, signs):
        """
        Args:
            matrix_class: Callable called as
                `matrix_class(array, is_posdef=...)` returning the matrix to
                test.
            signs: Iterable of signs (`+1` or `-1`) to generate matrices for.
        """
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            for sign in signs:
                sqrt_array = rng.standard_normal((sz, sz))
                array = sign * sqrt_array @ sqrt_array.T
                matrix_pairs[(sz, sign)] = (
                    matrix_class(array, is_posdef=(sign == 1)), array)

        if AUTOGRAD_AVAILABLE:

            def param_func(param, matrix):
                return param

            def get_param(matrix):
                return matrix.array

        else:
            param_func, get_param = None, None

        super().__init__(matrix_pairs, get_param, param_func, rng)


class TestDenseDefiniteMatrix(
        DifferentiableDenseDefiniteMatrixTestCase,
        ExplicitShapeSymmetricMatrixTestCase,
        ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        super().__init__(matrices.DenseDefiniteMatrix, (+1, -1))


class TestDensePositiveDefiniteMatrix(
        DifferentiableDenseDefiniteMatrixTestCase,
        ExplicitShapePositiveDefiniteMatrixTestCase):

    def __init__(self):
        # Positive definite class takes no `is_posdef` argument so drop it
        super().__init__(
            lambda array, is_posdef:
                matrices.DensePositiveDefiniteMatrix(array),
            (+1,))


class TestDensePositiveDefiniteProductMatrix(
        DifferentiableMatrixTestCase,
        ExplicitShapePositiveDefiniteMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for dim_0 in SIZES:
            for dim_1 in [dim_0 + 1, dim_0 * 2]:
                rect_matrix = rng.standard_normal((dim_0, dim_1))
                pos_def_matrix = rng.standard_normal((dim_1, dim_1))
                pos_def_matrix = pos_def_matrix @ pos_def_matrix.T
                array = rect_matrix @ pos_def_matrix @ rect_matrix.T
                matrix_pairs[(dim_0, dim_1)] = (
                    matrices.DensePositiveDefiniteProductMatrix(
                        rect_matrix, pos_def_matrix),
                    array)

        if AUTOGRAD_AVAILABLE:

            def param_func(param, matrix):
                return param @ matrix._pos_def_matrix @ param.T

            def get_param(matrix):
                return matrix._rect_matrix.array

        else:
            param_func, get_param = None, None

        super().__init__(matrix_pairs, get_param, param_func, rng)


class TestDenseSquareMatrix(ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            array = rng.standard_normal((sz, sz))
            matrix_pairs[sz] = (
                matrices.DenseSquareMatrix(array), array)
        super().__init__(matrix_pairs, rng)


class TestInverseLUFactoredSquareMatrix(ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            for transposed in [True, False]:
                inverse_array = rng.standard_normal((sz, sz))
                inverse_lu_and_piv = sla.lu_factor(
                    inverse_array.T if transposed else inverse_array)
                array = nla.inv(inverse_array)
                matrix_pairs[(sz, transposed)] = (
                    matrices.InverseLUFactoredSquareMatrix(
                        inverse_array, inverse_lu_and_piv, transposed),
                    array)
        super().__init__(matrix_pairs, rng)


class TestDenseSymmetricMatrix(
        ExplicitShapeInvertibleMatrixTestCase,
        ExplicitShapeSymmetricMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            array = rng.standard_normal((sz, sz))
            array = array + array.T
            matrix_pairs[sz] = (
                matrices.DenseSymmetricMatrix(array), array)
        super().__init__(matrix_pairs, rng)


class TestOrthogonalMatrix(ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            array = nla.qr(rng.standard_normal((sz, sz)))[0]
            matrix_pairs[sz] = (matrices.OrthogonalMatrix(array), array)
        super().__init__(matrix_pairs, rng)


class TestScaledOrthogonalMatrix(ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            orth_array = nla.qr(rng.standard_normal((sz, sz)))[0]
            scalar = rng.standard_normal()
            matrix_pairs[sz] = (
                matrices.ScaledOrthogonalMatrix(scalar, orth_array),
                scalar * orth_array)
        super().__init__(matrix_pairs, rng)


class TestEigendecomposedSymmetricMatrix(
        ExplicitShapeInvertibleMatrixTestCase,
        ExplicitShapeSymmetricMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            eigvec = nla.qr(rng.standard_normal((sz, sz)))[0]
            eigval = rng.standard_normal(sz)
            matrix_pairs[sz] = (
                matrices.EigendecomposedSymmetricMatrix(eigvec, eigval),
                (eigvec * eigval) @ eigvec.T)
        super().__init__(matrix_pairs, rng)


class TestEigendecomposedPositiveDefiniteMatrix(
        ExplicitShapePositiveDefiniteMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            eigvec = nla.qr(rng.standard_normal((sz, sz)))[0]
            eigval = np.abs(rng.standard_normal(sz))
            matrix_pairs[sz] = (
                matrices.EigendecomposedPositiveDefiniteMatrix(eigvec, eigval),
                (eigvec * eigval) @ eigvec.T)
        super().__init__(matrix_pairs, rng)


class TestSoftAbsRegularizedPositiveDefiniteMatrix(
        DifferentiableMatrixTestCase,
        ExplicitShapePositiveDefiniteMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for sz in SIZES:
            for softabs_coeff in [0.5, 1., 1.5]:
                sym_array = rng.standard_normal((sz, sz))
                sym_array = sym_array + sym_array.T
                unreg_eigval, eigvec = np.linalg.eigh(sym_array)
                # Reference eigenvalues are eigenvalues of symmetric array
                # transformed by x -> x / tanh(x * softabs_coeff)
                eigval = unreg_eigval / np.tanh(unreg_eigval * softabs_coeff)
                matrix_pairs[(sz, softabs_coeff)] = (
                    matrices.SoftAbsRegularizedPositiveDefiniteMatrix(
                        sym_array, softabs_coeff
                    ),
                    (eigvec * eigval) @ eigvec.T)

        if AUTOGRAD_AVAILABLE:

            def get_param(matrix):
                eigvec = matrix.eigvec.array
                return (eigvec * matrix.unreg_eigval) @ eigvec.T

            def param_func(param, matrix):
                softabs_coeff = matrix._softabs_coeff
                sym_array = (param + param.T) / 2
                unreg_eigval, eigvec = anp.linalg.eigh(sym_array)
                eigval = unreg_eigval / anp.tanh(unreg_eigval * softabs_coeff)
                return (eigvec * eigval) @ eigvec.T

        else:
            param_func, get_param = None, None

        super().__init__(matrix_pairs, get_param, param_func, rng)


class TestMatrixProduct(ExplicitShapeMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for s in SIZES:
            for n_terms in [2, 4]:
                for explicit in [True, False]:
                    # Alternate (s, 2s) and (2s, s) shapes so terms conform
                    arrays = [
                        rng.standard_normal(
                            (s if t % 2 == 0 else 2 * s,
                             2 * s if t % 2 == 0 else s))
                        for t in range(n_terms)]
                    matrices_ = [
                        matrices.DenseRectangularMatrix(a) for a in arrays]
                    if explicit:
                        matrix = matrices.MatrixProduct(matrices_)
                    else:
                        # Construct product using matrix multiply operator
                        matrix = reduce(lambda a, b: a @ b, matrices_)
                    matrix_pairs[(s, n_terms, explicit)] = (
                        matrix, nla.multi_dot(arrays))
        super().__init__(matrix_pairs, rng)


class TestSquareMatrixProduct(ExplicitShapeSquareMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for s in SIZES:
            for n_terms in [2, 5]:
                arrays = [
                    rng.standard_normal((s, s)) for _ in range(n_terms)]
                matrix = matrices.SquareMatrixProduct([
                    matrices.DenseSquareMatrix(a) for a in arrays])
                matrix_pairs[(s, n_terms)] = (matrix, nla.multi_dot(arrays))
        super().__init__(matrix_pairs, rng)


class TestInvertibleMatrixProduct(ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for s in SIZES:
            for n_terms in [2, 5]:
                for explicit in [True, False]:
                    arrays = [
                        rng.standard_normal((s, s)) for _ in range(n_terms)]
                    matrices_ = [
                        matrices.DenseSquareMatrix(a) for a in arrays]
                    if explicit:
                        matrix = matrices.InvertibleMatrixProduct(matrices_)
                    else:
                        # Construct product using matrix multiply operator
                        matrix = reduce(lambda a, b: a @ b, matrices_)
                    matrix_pairs[(s, n_terms, explicit)] = (
                        matrix, nla.multi_dot(arrays))
        super().__init__(matrix_pairs, rng)


class TestSquareBlockDiagonalMatrix(ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for s in SIZES:
            for n_block in [1, 2, 5]:
                arrays = [rng.standard_normal((s, s)) for _ in range(n_block)]
                matrix_pairs[(s, n_block)] = (
                    matrices.SquareBlockDiagonalMatrix(
                        matrices.DenseSquareMatrix(arr) for arr in arrays),
                    sla.block_diag(*arrays))
        super().__init__(matrix_pairs, rng)


class TestSymmetricBlockDiagonalMatrix(
        ExplicitShapeInvertibleMatrixTestCase,
        ExplicitShapeSymmetricMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for s in SIZES:
            for n_block in [1, 2, 5]:
                arrays = [rng.standard_normal((s, s)) for _ in range(n_block)]
                arrays = [arr + arr.T for arr in arrays]
                matrix_pairs[(s, n_block)] = (
                    matrices.SymmetricBlockDiagonalMatrix(
                        matrices.DenseSymmetricMatrix(arr) for arr in arrays),
                    sla.block_diag(*arrays))
        super().__init__(matrix_pairs, rng)


class TestPositiveDefiniteBlockDiagonalMatrix(
        DifferentiableMatrixTestCase,
        ExplicitShapePositiveDefiniteMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for s in SIZES:
            for n_block in [1, 2, 5]:
                arrays = [rng.standard_normal((s, s)) for _ in range(n_block)]
                arrays = [arr @ arr.T for arr in arrays]
                matrix_pairs[(s, n_block)] = (
                    matrices.PositiveDefiniteBlockDiagonalMatrix(
                        matrices.DensePositiveDefiniteMatrix(arr)
                        for arr in arrays),
                    sla.block_diag(*arrays))

        if AUTOGRAD_AVAILABLE:

            # Wrap SciPy block_diag as an Autograd primitive with a manually
            # defined vector-Jacobian product

            @primitive
            def block_diag(blocks):
                return sla.block_diag(*blocks)

            def vjp_block_diag(ans, blocks):
                blocks = tuple(blocks)

                def vjp(g):
                    # Extract the diagonal blocks of g matching each input
                    i, j = 0, 0
                    vjp_blocks = []
                    for block in blocks:
                        j += block.shape[0]
                        vjp_blocks.append(g[i:j, i:j])
                        i = j
                    return tuple(vjp_blocks)

                return vjp

            defvjp(block_diag, vjp_block_diag)

            def get_param(matrix):
                return tuple(
                    block.array for block in matrix._blocks)

            def param_func(param, matrix):
                return block_diag(param)

        else:
            param_func, get_param = None, None

        super().__init__(matrix_pairs, get_param, param_func, rng)


class TestDenseRectangularMatrix(ExplicitShapeMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for s0 in SIZES:
            for s1 in SIZES:
                # Only generate non-square matrices
                if s0 != s1:
                    array = rng.standard_normal((s0, s1))
                    matrix_pairs[(s0, s1)] = (
                        matrices.DenseRectangularMatrix(array), array)
        super().__init__(matrix_pairs, rng)


class TestBlockRowMatrix(ExplicitShapeMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for s in SIZES:
            for n_blocks in [2, 5]:
                blocks = [rng.standard_normal((s, s)) for _ in range(n_blocks)]
                matrix_pairs[(s, n_blocks)] = (
                    matrices.BlockRowMatrix(
                        matrices.DenseSquareMatrix(block) for block in blocks),
                    np.hstack(blocks))
        super().__init__(matrix_pairs, rng)


class TestBlockColumnMatrix(ExplicitShapeMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for s in SIZES:
            for n_blocks in [2, 5]:
                blocks = [rng.standard_normal((s, s)) for _ in range(n_blocks)]
                matrix_pairs[(s, n_blocks)] = (
                    matrices.BlockColumnMatrix(
                        matrices.DenseSquareMatrix(block) for block in blocks),
                    np.vstack(blocks))
        super().__init__(matrix_pairs, rng)


class TestSquareLowRankUpdateMatrix(ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for outer_dim in SIZES:
            for inner_dim in [max(1, outer_dim // 2), max(1, outer_dim - 1)]:
                left_factor_matrix = rng.standard_normal(
                    (outer_dim, inner_dim))
                right_factor_matrix = rng.standard_normal(
                    (inner_dim, outer_dim))
                inner_square_matrix = rng.standard_normal(
                    (inner_dim, inner_dim))
                square_matrix = rng.standard_normal((outer_dim, outer_dim))
                matrix_pairs[(inner_dim, outer_dim)] = (
                    matrices.SquareLowRankUpdateMatrix(
                        matrices.DenseRectangularMatrix(left_factor_matrix),
                        matrices.DenseRectangularMatrix(right_factor_matrix),
                        matrices.DenseSquareMatrix(square_matrix),
                        matrices.DenseSquareMatrix(inner_square_matrix)),
                    square_matrix + left_factor_matrix @ (
                        inner_square_matrix @ right_factor_matrix))
        super().__init__(matrix_pairs, rng)


class TestNoInnerMatrixSquareLowRankUpdateMatrix(
        ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for outer_dim in SIZES:
            inner_dim = max(1, outer_dim // 2)
            left_factor_matrix = rng.standard_normal(
                (outer_dim, inner_dim))
            right_factor_matrix = rng.standard_normal(
                (inner_dim, outer_dim))
            square_matrix = rng.standard_normal((outer_dim, outer_dim))
            matrix_pairs[(inner_dim, outer_dim)] = (
                matrices.SquareLowRankUpdateMatrix(
                    matrices.DenseRectangularMatrix(left_factor_matrix),
                    matrices.DenseRectangularMatrix(right_factor_matrix),
                    matrices.DenseSquareMatrix(square_matrix)),
                square_matrix + left_factor_matrix @ right_factor_matrix)
        super().__init__(matrix_pairs, rng)


class TestSymmetricLowRankUpdateMatrix(
        ExplicitShapeSymmetricMatrixTestCase,
        ExplicitShapeInvertibleMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for outer_dim in SIZES:
            for inner_dim in [max(1, outer_dim // 2), max(1, outer_dim - 1)]:
                factor_matrix = rng.standard_normal(
                    (outer_dim, inner_dim))
                inner_symmetric_matrix = rng.standard_normal(
                    (inner_dim, inner_dim))
                inner_symmetric_matrix = (
                    inner_symmetric_matrix + inner_symmetric_matrix.T)
                symmetric_matrix = rng.standard_normal((outer_dim, outer_dim))
                symmetric_matrix = symmetric_matrix + symmetric_matrix.T
                matrix_pairs[(inner_dim, outer_dim)] = (
                    matrices.SymmetricLowRankUpdateMatrix(
                        matrices.DenseRectangularMatrix(factor_matrix),
                        matrices.DenseSymmetricMatrix(symmetric_matrix),
                        matrices.DenseSymmetricMatrix(inner_symmetric_matrix)),
                    symmetric_matrix + factor_matrix @ (
                        inner_symmetric_matrix @ factor_matrix.T))
        super().__init__(matrix_pairs, rng)


class TestPositiveDefiniteLowRankUpdateMatrix(
        DifferentiableMatrixTestCase,
        ExplicitShapePositiveDefiniteMatrixTestCase):

    def __init__(self):
        matrix_pairs = {}
        rng = np.random.RandomState(SEED)
        for outer_dim in SIZES:
            for inner_dim in [max(1, outer_dim // 2), max(1, outer_dim - 1)]:
                factor_matrix = rng.standard_normal(
                    (outer_dim, inner_dim))
                inner_pos_def_matrix = rng.standard_normal(
                    (inner_dim, inner_dim))
                inner_pos_def_matrix = (
                    inner_pos_def_matrix @ inner_pos_def_matrix.T)
                pos_def_matrix = rng.standard_normal((outer_dim, outer_dim))
                pos_def_matrix = pos_def_matrix @ pos_def_matrix.T
                matrix_pairs[(inner_dim, outer_dim)] = (
                    matrices.PositiveDefiniteLowRankUpdateMatrix(
                        matrices.DenseRectangularMatrix(factor_matrix),
                        matrices.DensePositiveDefiniteMatrix(pos_def_matrix),
                        matrices.DensePositiveDefiniteMatrix(
                            inner_pos_def_matrix)),
                    pos_def_matrix + factor_matrix @ (
                        inner_pos_def_matrix @ factor_matrix.T))

        if AUTOGRAD_AVAILABLE:

            def param_func(param, matrix):
                return (
                    matrix.pos_def_matrix.array +
                    param @ matrix.inner_pos_def_matrix @ param.T)

            def get_param(matrix):
                return matrix.factor_matrix.array

        else:
            param_func, get_param = None, None

        super().__init__(matrix_pairs, get_param, param_func, rng)