#!/usr/bin/env python3

import warnings

import torch

from .errors import NanError
from .warnings import NumericalWarning


def psd_safe_cholesky(A, upper=False, out=None, jitter=None, max_tries=3):
    """Compute the Cholesky decomposition of A. If A is only p.s.d, add a small jitter to the diagonal.

    Args:
        :attr:`A` (Tensor):
            The tensor to compute the Cholesky decomposition of
        :attr:`upper` (bool, optional):
            See torch.cholesky
        :attr:`out` (Tensor, optional):
            See torch.cholesky
        :attr:`jitter` (float, optional):
            The jitter to add to the diagonal of A in case A is only p.s.d. If omitted, chosen
            as 1e-8 for double (float64) inputs and 1e-6 for every other dtype.
        :attr:`max_tries` (int, optional):
            Number of attempts (with successively increasing jitter) to make before raising an error.
            Attempt ``i`` (0-based) uses a total diagonal jitter of ``jitter * 10 ** i``.

    Returns:
        Tensor: The Cholesky factor of A, or of A plus the smallest tried jitter that succeeded.

    Raises:
        NanError: If the initial decomposition fails and A contains NaN entries.
        RuntimeError: The error raised by the initial decomposition, if every jittered retry also fails.

    Warns:
        NumericalWarning: When jitter had to be added to the diagonal for the decomposition to succeed.
    """
    # NOTE(review): torch.cholesky is deprecated in newer PyTorch releases in favor of
    # torch.linalg.cholesky; kept here because the minimum supported torch version is unknown.
    try:
        L = torch.cholesky(A, upper=upper, out=out)
        return L
    except RuntimeError as e:
        # NaN entries cannot be repaired by jitter, so fail fast with a descriptive error.
        isnan = torch.isnan(A)
        if isnan.any():
            raise NanError(
                f"cholesky_cpu: {isnan.sum().item()} of {A.numel()} elements of the {A.shape} tensor are NaN."
            )

        if jitter is None:
            # Only double precision resolves a 1e-8 shift on diagonal entries of order 1
            # (float32 machine epsilon is ~1.2e-7), so all other dtypes get the larger default.
            jitter = 1e-8 if A.dtype == torch.float64 else 1e-6
        # Jitter a copy so the caller's tensor is left untouched.
        Aprime = A.clone()
        jitter_prev = 0
        for i in range(max_tries):
            jitter_new = jitter * (10 ** i)
            # Add only the increment so the total jitter on the diagonal equals jitter_new.
            # diagonal(dim1=-2, dim2=-1) also covers batched inputs.
            Aprime.diagonal(dim1=-2, dim2=-1).add_(jitter_new - jitter_prev)
            jitter_prev = jitter_new
            try:
                L = torch.cholesky(Aprime, upper=upper, out=out)
            except RuntimeError:
                continue
            else:
                # stacklevel=2 attributes the warning to the caller rather than this helper.
                warnings.warn(
                    f"A not p.d., added jitter of {jitter_new} to the diagonal", NumericalWarning, stacklevel=2
                )
                return L
        # Every jittered retry failed: surface the original decomposition error.
        raise e