import warnings

import torch


def _flatten(sequence):
    """Flatten a sequence of Tensors into a single 1-D Tensor."""
    flat = [p.contiguous().view(-1) for p in sequence]
    return torch.cat(flat) if len(flat) > 0 else torch.tensor([])


def _flatten_convert_none_to_zeros(sequence, like_sequence):
    """Flatten a sequence of Tensors, replacing each None entry with zeros
    shaped like the corresponding entry of `like_sequence`."""
    flat = [
        p.contiguous().view(-1) if p is not None else torch.zeros_like(q).view(-1)
        for p, q in zip(sequence, like_sequence)
    ]
    return torch.cat(flat) if len(flat) > 0 else torch.tensor([])


def _possibly_nonzero(x):
    return isinstance(x, torch.Tensor) or x != 0


def _scaled_dot_product(scale, xs, ys):
    """Calculate a scaled, vector inner product between lists of Tensors."""
    # Using _possibly_nonzero lets us avoid wasted computation.
    return sum([(scale * x) * y for x, y in zip(xs, ys) if _possibly_nonzero(x) or _possibly_nonzero(y)])


def _dot_product(xs, ys):
    """Calculate the vector inner product between two lists of Tensors."""
    return sum([x * y for x, y in zip(xs, ys)])


def _has_converged(y0, y1, rtol, atol):
    """Check that each element is within the error tolerance."""
    error_tol = tuple(atol + rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) for y0_, y1_ in zip(y0, y1))
    error = tuple(torch.abs(y0_ - y1_) for y0_, y1_ in zip(y0, y1))
    return all((error_ < error_tol_).all() for error_, error_tol_ in zip(error, error_tol))


def _convert_to_tensor(a, dtype=None, device=None):
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)
    if dtype is not None:
        a = a.type(dtype)
    if device is not None:
        a = a.to(device)
    return a


def _is_finite(tensor):
    _check = (tensor == float('inf')) | (tensor == float('-inf')) | torch.isnan(tensor)
    return not _check.any()


def _decreasing(t):
    return (t[1:] < t[:-1]).all()


def _assert_increasing(t):
    assert (t[1:] > t[:-1]).all(), 't must be strictly increasing or decreasing'


def _is_iterable(inputs):
    try:
        iter(inputs)
        return True
    except TypeError:
        return False


def _norm(x):
    """Compute RMS norm."""
    if torch.is_tensor(x):
        return x.norm() / (x.numel()**0.5)
    else:
        return torch.sqrt(sum(x_.norm()**2 for x_ in x) / sum(x_.numel() for x_ in x))


def _handle_unused_kwargs(solver, unused_kwargs):
    if len(unused_kwargs) > 0:
        warnings.warn('{}: Unexpected arguments {}'.format(solver.__class__.__name__, unused_kwargs))


def _select_initial_step(fun, t0, y0, order, rtol, atol, f0=None):
    """Empirically select a good initial step.

    The algorithm is described in [1]_.

    Parameters
    ----------
    fun : callable
        Right-hand side of the system.
    t0 : Tensor
        Initial value of the independent variable.
    y0 : tuple of Tensors
        Initial values of the dependent variables.
    order : float
        Method order.
    rtol : float
        Desired relative tolerance.
    atol : float
        Desired absolute tolerance.

    Returns
    -------
    h_abs : Tensor
        Absolute value of the suggested initial step.

    References
    ----------
    .. [1] E. Hairer, S. P. Norsett, G. Wanner, "Solving Ordinary Differential
        Equations I: Nonstiff Problems", Sec. II.4.
""" t0 = t0.to(y0[0]) if f0 is None: f0 = fun(t0, y0) rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) atol = atol if _is_iterable(atol) else [atol] * len(y0) scale = tuple(atol_ + torch.abs(y0_) * rtol_ for y0_, atol_, rtol_ in zip(y0, atol, rtol)) d0 = tuple(_norm(y0_ / scale_) for y0_, scale_ in zip(y0, scale)) d1 = tuple(_norm(f0_ / scale_) for f0_, scale_ in zip(f0, scale)) if max(d0).item() < 1e-5 or max(d1).item() < 1e-5: h0 = torch.tensor(1e-6).to(t0) else: h0 = 0.01 * max(d0_ / d1_ for d0_, d1_ in zip(d0, d1)) y1 = tuple(y0_ + h0 * f0_ for y0_, f0_ in zip(y0, f0)) f1 = fun(t0 + h0, y1) d2 = tuple(_norm((f1_ - f0_) / scale_) / h0 for f1_, f0_, scale_ in zip(f1, f0, scale)) if max(d1).item() <= 1e-15 and max(d2).item() <= 1e-15: h1 = torch.max(torch.tensor(1e-6).to(h0), h0 * 1e-3) else: h1 = (0.01 / max(d1 + d2))**(1. / float(order + 1)) return torch.min(100 * h0, h1) def _compute_error_ratio(error_estimate, error_tol=None, rtol=None, atol=None, y0=None, y1=None): if error_tol is None: assert rtol is not None and atol is not None and y0 is not None and y1 is not None rtol if _is_iterable(rtol) else [rtol] * len(y0) atol if _is_iterable(atol) else [atol] * len(y0) error_tol = tuple( atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_)) for atol_, rtol_, y0_, y1_ in zip(atol, rtol, y0, y1) ) error_ratio = tuple(error_estimate_ / error_tol_ for error_estimate_, error_tol_ in zip(error_estimate, error_tol)) mean_sq_error_ratio = tuple(torch.mean(error_ratio_ * error_ratio_) for error_ratio_ in error_ratio) return mean_sq_error_ratio def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5): """Calculate the optimal size for the next step.""" mean_error_ratio = max(mean_error_ratio) # Compute step size based on highest ratio. if mean_error_ratio == 0: return last_step * ifactor if mean_error_ratio < 1: dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device) error_ratio = torch.sqrt(mean_error_ratio).to(last_step) exponent = torch.tensor(1 / order).to(last_step) factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor)) return last_step / factor def _check_inputs(func, y0, t): tensor_input = False if torch.is_tensor(y0): tensor_input = True y0 = (y0,) _base_nontuple_func_ = func func = lambda t, y: (_base_nontuple_func_(t, y[0]),) assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple' for y0_ in y0: assert torch.is_tensor(y0_), 'each element must be a torch.Tensor but received {}'.format(type(y0_)) if _decreasing(t): t = -t _base_reverse_func = func func = lambda t, y: tuple(-f_ for f_ in _base_reverse_func(-t, y)) for y0_ in y0: if not torch.is_floating_point(y0_): raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(y0_.type())) if not torch.is_floating_point(t): raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type())) return tensor_input, func, y0, t