import tensorflow as tf


def cholesky_inverse(A):
  """Compute the inverse of `A` using Cholesky decomposition.

  With `A = L L^T`, the inverse is `A^-1 = L^-T L^-1`, so only a triangular
  solve against the identity and one matrix product are needed.

  NOTE: `A` must be symmetric positive definite. This method of inversion is
  not completely stable since `tf.cholesky` is not always stable.

  Args:
    A: tf.Tensor. Symmetric positive definite matrix, `shape=[N, N]`

  Returns:
    tf.Tensor with the inverse of `A`, `shape=[N, N]`

  Raises:
    tf.errors.InvalidArgumentError: Might be raised (at run time) if the
      Cholesky decomposition fails.
  """
  N = tf.shape(A)[0]
  L = tf.cholesky(A)
  # `tf.eye` defaults to float32; build the identity in the dtype of the
  # factor so that float64 inputs do not fail with a dtype mismatch.
  identity = tf.eye(N, dtype=L.dtype)
  # Solves L X = I for X, i.e. X = L^-1 (lower-triangular solve by default).
  L_inv = tf.matrix_triangular_solve(L, identity)
  # A^-1 = (L L^T)^-1 = L^-T L^-1
  A_inv = tf.matmul(L_inv, L_inv, transpose_a=True)
  return A_inv


def sherman_morrison_inverse(A_inv, u, v):
  """Compute the inverse of (A + uv^T) using Sherman-Morrison formula.

  `(A + uv^T)^-1 = A^-1 - (A^-1 u v^T A^-1) / (1 + v^T A^-1 u)`.
  For details see:
  https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula

  Args:
    A_inv: tf.Tensor. The inverse of A or batch. Last two dimensions should
      have shape [N, N]
    u: tf.Tensor. (Batch of) column vector(s). Last two dimensions should
      have shape [N, 1]
    v: tf.Tensor. (Batch of) column vector(s). Last two dimensions should
      have shape [N, 1]

  Returns:
    (A + uv^T)^{-1} with the same shape as `A_inv`
  """
  assert u.shape.as_list()[-1] == 1 and u.shape.ndims >= 2
  assert v.shape.as_list()[-1] == 1 and v.shape.ndims >= 2

  A_inv_u = tf.matmul(A_inv, u)                                     # [..., N, 1]
  num = tf.matmul(A_inv_u, tf.matmul(v, A_inv, transpose_a=True))   # [..., N, N]
  # Keep the denominator as [..., 1, 1] instead of squeezing it to [...]:
  # a squeezed batch axis would be broadcast against the last matrix axis of
  # `num`, which is wrong (or a shape error) for batched inputs.
  denom = 1 + tf.matmul(v, A_inv_u, transpose_a=True)               # [..., 1, 1]
  inverse = A_inv - num / denom
  return inverse


def woodburry_inverse(A_inv, U, V):
  """Compute the inverse of (A + UV) using the Woodbury formula.

  `(A + UV)^-1 = A^-1 - A^-1 U (I + V A^-1 U)^-1 V A^-1`.
  For details see: https://en.wikipedia.org/wiki/Woodbury_matrix_identity

  The computation is carried out in float64; float32 inputs are upcast and
  the result is cast back to float32 when `A_inv` is float32.

  Args:
    A_inv: tf.Tensor. The inverse of A, `shape=[N, N]`
    U: tf.Tensor, `shape=[N, M]`
    V: tf.Tensor. `shape=[M, N]`

  Returns:
    (A + UV)^-1 with the same shape and dtype as `A_inv`
  """
  # NOTE: Must make sure to use double precision.
  # Otherwise results are very inaccurate
  A_inv_64 = tf.cast(A_inv, tf.float64) if A_inv.dtype.base_dtype == tf.float32 else A_inv
  U_64 = tf.cast(U, tf.float64) if U.dtype.base_dtype == tf.float32 else U
  V_64 = tf.cast(V, tf.float64) if V.dtype.base_dtype == tf.float32 else V

  # Anything that is neither float32 nor float64 is rejected here.
  assert A_inv_64.dtype.base_dtype == tf.float64
  assert U_64.dtype.base_dtype == tf.float64
  assert V_64.dtype.base_dtype == tf.float64

  A_inv_U = tf.matmul(A_inv_64, U_64)   # [N, M]
  V_A_inv = tf.matmul(V_64, A_inv_64)   # [M, N]

  # Only an M x M system has to be inverted: (I + V A^-1 U)
  identity = tf.eye(tf.shape(V)[0], dtype=tf.float64)
  inverse = tf.matrix_inverse(identity + tf.matmul(V_64, A_inv_U))
  inverse = tf.matmul(A_inv_U, inverse)
  inverse = tf.matmul(inverse, V_A_inv)
  inverse = A_inv_64 - inverse

  assert inverse.dtype.base_dtype == tf.float64

  # Return the result in the dtype of the `A_inv` that was passed in.
  inverse = tf.cast(inverse, tf.float32) if A_inv.dtype.base_dtype == tf.float32 else inverse
  assert inverse.dtype.base_dtype == A_inv.dtype.base_dtype

  return inverse