# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import numpy as np

__all__ = ['fit_cubic', 'fit_quartic', 'findroot']


def rms(A):
    """Return the root mean square of all elements of array ``A``.

    Returns ``None`` when ``A`` has no elements.
    """
    if A.size == 0:
        return None
    return np.sqrt(np.sum(A ** 2) / A.size)


def pinv(A, log=lambda _: None):
    """Return a pseudoinverse of ``A`` with gap-based rank truncation.

    The singular values of ``A`` are scanned for the first place where the
    ratio of two consecutive values exceeds ``1e3``. Singular values after
    that gap are treated as zero; the remaining ones are inverted. If no
    such gap exists, all singular values are inverted.

    ``log`` is called with a message when the detected gap is smaller
    than ``1e8``.
    """
    # Reduced SVD: A = U @ diag(D) @ V, where V is already the (conjugate)
    # transposed right factor. Works for rectangular matrices as well.
    U, D, V = np.linalg.svd(A, full_matrices=False)
    thre = 1e3
    thre_log = 1e8
    # Zero singular values yield inf/nan ratios; those are expected here and
    # are handled by the comparison, so silence the floating-point warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        gaps = D[:-1] / D[1:]
        above = np.flatnonzero(gaps > thre)
    try:
        n = above[0]
    except IndexError:
        n = len(gaps)
    else:
        gap = gaps[n]
        if gap < thre_log:
            log('Pseudoinverse gap of only: {:.1e}'.format(gap))
    D[n + 1 :] = 0
    D[: n + 1] = 1 / D[: n + 1]
    # A^+ = V^H diag(D^+) U^H. For symmetric (Hermitian) A this coincides
    # with U diag(D^+) V, but only this form is valid for general matrices.
    return V.conj().T.dot(np.diag(D)).dot(U.conj().T)


def cross(a, b):
    """Return the cross product of two 3-component vectors as an array."""
    return np.array(
        [
            a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0],
        ]
    )


def fit_cubic(y0, y1, g0, g1):
    """Fit cubic polynomial to function values and derivatives at x = 0, 1.

    Returns position and function value of the minimum if the fit succeeds,
    ``(None, None)`` otherwise. The fit does not succeed if

    1. the polynomial doesn't have extrema, or
    2. the maximum lies in (0, 1) and is closer to 0.5 than the minimum.

    When the cubic coefficient vanishes exactly, the fit is a parabola (or a
    line): the minimum of an upward-opening parabola is returned, any other
    degenerate case counts as a failed fit.
    """
    a = 2 * (y0 - y1) + g0 + g1
    b = -3 * (y0 - y1) - 2 * g0 - g1
    p = np.array([a, b, g0, y0])
    r = np.roots(np.polyder(p))
    if not np.isreal(r).all():
        return None, None
    r = sorted(x.real for x in r)
    if len(r) < 2:
        # np.roots strips vanishing leading coefficients, so fewer than two
        # stationary points means the polynomial is at most quadratic.
        if len(r) == 1 and b > 0:
            minim = r[0]
            return minim, np.polyval(p, minim)
        return None, None
    # With the roots sorted ascending, a positive leading coefficient puts
    # the maximum first and the minimum second; a negative one the opposite.
    if p[0] > 0:
        maxim, minim = r
    else:
        minim, maxim = r
    if 0 < maxim < 1 and abs(minim - 0.5) > abs(maxim - 0.5):
        return None, None
    return minim, np.polyval(p, minim)


def fit_quartic(y0, y1, g0, g1):
    """Fit constrained quartic polynomial to function values and derivatives
    at x = 0, 1.

    Returns position and function value of minimum, or ``(None, None)`` if
    the fit fails or has a maximum.

    Quartic polynomial is constrained such that its 2nd derivative is zero
    at just one point. This ensures that it has just one local extremum. No
    such or two such quartic polynomials always exist. From the two, the one
    with lower minimum is chosen.
    """

    def g(y0, y1, g0, g1, c):
        # Coefficients (highest power first) of the quartic that matches
        # the values and derivatives at x = 0, 1 for a given quadratic
        # coefficient c.
        a = c + 3 * (y0 - y1) + 2 * g0 + g1
        b = -2 * c - 4 * (y0 - y1) - 3 * g0 - g1
        return np.array([a, b, c, g0, y0])

    def quart_min(p):
        # Locate the stationary point of p and evaluate p there.
        r = np.roots(np.polyder(p))
        is_real = np.isreal(r)
        if is_real.sum() == 1:
            minim = r[is_real][0].real
        else:
            # Otherwise take the root that equals plus or minus the smallest
            # root magnitude.
            minim = r[(r == max(-abs(r))) | (r == -max(-abs(r)))][0].real
        return minim, np.polyval(p, minim)

    # discriminant of d^2y/dx^2=0
    D = (
        -((g0 + g1) ** 2)
        - 2 * g0 * g1
        + 6 * (y1 - y0) * (g0 + g1)
        - 6 * (y1 - y0) ** 2
    )
    if D < 1e-11:
        return None, None
    m = -5 * g0 - g1 - 6 * y0 + 6 * y1
    p1 = g(y0, y1, g0, g1, 0.5 * (m + np.sqrt(2 * D)))
    p2 = g(y0, y1, g0, g1, 0.5 * (m - np.sqrt(2 * D)))
    # Both candidates open downwards: neither has a minimum.
    if p1[0] < 0 and p2[0] < 0:
        return None, None
    minim1, minval1 = quart_min(p1)
    minim2, minval2 = quart_min(p2)
    if minval1 < minval2:
        return minim1, minval1
    return minim2, minval2


class FindrootException(Exception):
    """Raised by findroot when the iteration does not stop in 1000 steps."""


def findroot(f, lim):
    """Find root of increasing function on (-inf,lim).

    Assumes f(-inf) < 0, f(lim) > 0.

    A starting point ``lim - d`` with positive function value is searched
    for by halving ``d`` from 1.0. Newton iterations with a forward-difference
    derivative follow, and the latest iterate is returned as soon as the
    absolute function value stops decreasing.

    Raises ``RuntimeError`` if no starting point with ``f(x) > 0`` is found
    in 1000 halvings, and ``FindrootException`` if the Newton iteration is
    still improving after 1000 steps.
    """
    d = 1.0
    for _ in range(1000):
        val = f(lim - d)
        if val > 0:
            break
        d = d / 2  # find d so that f(lim-d) > 0
    else:
        raise RuntimeError('Cannot find f(x) > 0')
    x = lim - d  # initial guess
    dx = 1e-10  # step for numerical derivative
    fx = f(x)
    err = abs(fx)
    for _ in range(1000):
        fxpdx = f(x + dx)
        dxf = (fxpdx - fx) / dx
        x = x - fx / dxf
        fx = f(x)
        err_new = abs(fx)
        # NOTE: the newest iterate is returned even though its residual is
        # not smaller than the previous one.
        if err_new >= err:
            return x
        err = err_new
    raise FindrootException()