# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import operator
from numbers import Number

import numpy as np
from multipledispatch import Dispatcher

# Keep handles on the Python builtins, because several of these names
# (abs, all, any, max, min, pow, sum) are rebound to Op instances below.
_builtin_abs = abs
_builtin_all = all
_builtin_any = any
_builtin_max = max
_builtin_min = min
_builtin_pow = pow
_builtin_sum = sum


class Op(Dispatcher):
    """
    A named, type-dispatched operation.

    :param callable fn: The default implementation. It is registered for
        every unary and every binary ``object`` signature; more specific
        implementations can be added later via ``.register(...)``.
    """

    def __init__(self, fn):
        super(Op, self).__init__(fn.__name__)
        # register as default operation
        for nargs in (1, 2):
            default_signature = (object,) * nargs
            self.add(default_signature, fn)

    def __repr__(self):
        return "ops." + self.__name__

    def __str__(self):
        return self.__name__


class TransformOp(Op):
    """
    An :class:`Op` that can additionally carry an inverse and a
    log-abs-det-Jacobian function, attached via :meth:`set_inv` and
    :meth:`set_log_abs_det_jacobian`.
    """

    def set_inv(self, fn):
        """
        :param callable fn: A function that inputs an arg ``y`` and outputs a
            value ``x`` such that ``y=self(x)``.
        """
        assert callable(fn)
        self.inv = fn
        return fn

    def set_log_abs_det_jacobian(self, fn):
        """
        :param callable fn: A function that inputs two args ``x, y``, where
            ``y=self(x)``, and returns ``log(abs(det(dy/dx)))``.
        """
        assert callable(fn)
        self.log_abs_det_jacobian = fn
        return fn

    # Class-level fallbacks, used until an instance attribute of the same
    # name is installed by the setters above.
    @staticmethod
    def inv(x):
        raise NotImplementedError

    @staticmethod
    def log_abs_det_jacobian(x, y):
        raise NotImplementedError


# FIXME Most code assumes this is an AssociativeCommutativeOp.
class AssociativeOp(Op):
    pass


class AddOp(AssociativeOp):
    pass


class MulOp(AssociativeOp):
    pass


class MatmulOp(Op):  # Associative but not commutative.
    pass


class LogAddExpOp(AssociativeOp):
    pass


class SampleOp(LogAddExpOp):
    pass


class SubOp(Op):
    pass


class NegOp(Op):
    pass


class DivOp(Op):
    pass


class NullOp(AssociativeOp):
    """Placeholder associative op that unifies with any other op"""
    pass


@NullOp
def nullop(x, y):
    raise ValueError("should never actually evaluate this!")


class ReshapeMeta(type):
    """Metaclass that memoizes instances by their ``shape`` (as a tuple)."""

    _cache = {}

    def __call__(cls, shape):
        shape = tuple(shape)
        try:
            return ReshapeMeta._cache[shape]
        except KeyError:
            instance = super().__call__(shape)
            ReshapeMeta._cache[shape] = instance
            return instance


class ReshapeOp(Op, metaclass=ReshapeMeta):
    """Op whose default implementation is ``x.reshape(shape)``."""

    def __init__(self, shape):
        self.shape = shape
        super().__init__(self._default)

    def _default(self, x):
        return x.reshape(self.shape)


class GetitemMeta(type):
    """Metaclass that memoizes instances by their ``offset``."""

    _cache = {}

    def __call__(cls, offset):
        try:
            return GetitemMeta._cache[offset]
        except KeyError:
            instance = super(GetitemMeta, cls).__call__(offset)
            GetitemMeta._cache[offset] = instance
            return instance


class GetitemOp(Op, metaclass=GetitemMeta):
    """
    Op encoding an index into one dimension, e.g. ``x[:,:,y]`` for offset of 2.
    """

    def __init__(self, offset):
        assert isinstance(offset, int)
        assert offset >= 0
        self.offset = offset
        # One full slice per skipped leading dimension.
        self._prefix = (slice(None),) * offset
        super(GetitemOp, self).__init__(self._default)
        self.__name__ = 'GetitemOp({})'.format(offset)

    def _default(self, x, y):
        return x[self._prefix + (y,)] if self.offset else x[y]


getitem = GetitemOp(0)
abs = Op(_builtin_abs)
eq = Op(operator.eq)
ge = Op(operator.ge)
gt = Op(operator.gt)
invert = Op(operator.invert)
le = Op(operator.le)
lt = Op(operator.lt)
ne = Op(operator.ne)
neg = NegOp(operator.neg)
sub = SubOp(operator.sub)
truediv = DivOp(operator.truediv)

add = AddOp(operator.add)
and_ = AssociativeOp(operator.and_)
mul = MulOp(operator.mul)
matmul = MatmulOp(operator.matmul)
or_ = AssociativeOp(operator.or_)
xor = AssociativeOp(operator.xor)


@add.register(object)
def _unary_add(x):
    """Unary form of ``add``: reduce ``x`` via its ``.sum()`` method."""
    return x.sum()


@Op
def sqrt(x):
    return np.sqrt(x)


class ExpOp(TransformOp):
    pass


@ExpOp
def exp(x):
    return np.exp(x)


@exp.set_log_abs_det_jacobian
def log_abs_det_jacobian(x, y):
    return add(x)


class LogOp(TransformOp):
    pass


@LogOp
def log(x):
    """
    Logarithm of ``x``. Boolean inputs (a ``bool`` or a boolean
    ``np.ndarray``) map ``True -> 0.`` and ``False -> -inf``.
    """
    if isinstance(x, bool) or (isinstance(x, np.ndarray) and x.dtype == 'bool'):
        return np.where(x, 0., float('-inf'))
    with np.errstate(divide='ignore'):  # skip the warning of log(0.)
        return np.log(x)


@log.set_log_abs_det_jacobian
def log_abs_det_jacobian(x, y):
    return -add(y)


exp.set_inv(log)
log.set_inv(exp)


@Op
def log1p(x):
    return np.log1p(x)


@Op
def sigmoid(x):
    return 1 / (1 + np.exp(-x))


@Op
def pow(x, y):
    return x ** y


@AssociativeOp
def min(x, y):
    """
    Binary minimum. Defers to an ``__min__`` method on either operand if
    one exists, otherwise falls back to the builtin ``min``.
    """
    if hasattr(x, '__min__'):
        return x.__min__(y)
    if hasattr(y, '__min__'):
        return y.__min__(x)
    return _builtin_min(x, y)


@AssociativeOp
def max(x, y):
    """
    Binary maximum. Defers to an ``__max__`` method on either operand if
    one exists, otherwise falls back to the builtin ``max``.
    """
    if hasattr(x, '__max__'):
        return x.__max__(y)
    if hasattr(y, '__max__'):
        return y.__max__(x)
    return _builtin_max(x, y)


def _logaddexp(x, y):
    """
    Compute ``log(exp(x) + exp(y))``.

    Defers to ``x.__logaddexp__(y)`` or, failing that, to the reflected
    hook ``y.__rlogaddexp__(x)``. Otherwise both operands are shifted by
    their maximum before exponentiating.
    """
    if hasattr(x, "__logaddexp__"):
        return x.__logaddexp__(y)
    if hasattr(y, "__rlogaddexp__"):
        # Call the hook that was actually tested for, so an operand that
        # defines only the reflected method does not raise AttributeError.
        return y.__rlogaddexp__(x)
    shift = max(x, y)
    return log(exp(x - shift) + exp(y - shift)) + shift


logaddexp = LogAddExpOp(_logaddexp)
sample = SampleOp(_logaddexp)


@SubOp
def safesub(x, y):
    # NOTE(review): the default implementation only handles a Number ``y``;
    # for any other ``y`` it falls through and returns None.
    if isinstance(y, Number):
        return sub(x, y)


@DivOp
def safediv(x, y):
    # NOTE(review): the default implementation only handles a Number ``y``;
    # for any other ``y`` it falls through and returns None.
    if isinstance(y, Number):
        return truediv(x, y)


class ReciprocalOp(Op):
    pass


@ReciprocalOp
def reciprocal(x):
    """
    Return ``1 / x`` for a ``Number``.

    :raises ValueError: if ``x`` is not a ``Number`` and no more specific
        implementation has been registered for its type.
    """
    if isinstance(x, Number):
        return 1. / x
    raise ValueError("No reciprocal for type {}".format(type(x)))


DISTRIBUTIVE_OPS = frozenset([
    (logaddexp, add),
    (add, mul),
    (max, mul),
    (min, mul),
    (max, add),
    (min, add),
    (sample, add),
])


UNITS = {
    mul: 1.,
    add: 0.,
}


PRODUCT_INVERSES = {
    mul: safediv,
    add: safesub,
}


######################
# Numeric Array Ops
######################

all = Op(np.all)
amax = Op(np.amax)
amin = Op(np.amin)
any = Op(np.any)
astype = Dispatcher("ops.astype")
cat = Dispatcher("ops.cat")
clamp = Dispatcher("ops.clamp")
diagonal = Dispatcher("ops.diagonal")
einsum = Dispatcher("ops.einsum")
full_like = Op(np.full_like)
prod = Op(np.prod)
stack = Dispatcher("ops.stack")
sum = Op(np.sum)
transpose = Dispatcher("ops.transpose")

# Types treated as "numeric array" by the registrations below.
array = (np.ndarray, np.generic)


@astype.register(array, str)
def _astype(x, dtype):
    return x.astype(dtype)


@cat.register(int, [array])
def _cat(dim, *x):
    return np.concatenate(x, axis=dim)


@clamp.register(array, object, object)
def _clamp(x, min, max):
    return np.clip(x, a_min=min, a_max=max)


@Op
def cholesky(x):
    """
    Like :func:`numpy.linalg.cholesky` but uses sqrt for scalar matrices.
    """
    if x.shape[-1] == 1:
        return np.sqrt(x)
    return np.linalg.cholesky(x)


@Op
def cholesky_inverse(x):
    """
    Like :func:`torch.cholesky_inverse` but supports batching and gradients.
    """
    return cholesky_solve(new_eye(x, x.shape[:-1]), x)


@Op
def cholesky_solve(x, y):
    """Return ``inv(y).T @ inv(y) @ x`` (transpose over the last two axes)."""
    y_inv = np.linalg.inv(y)
    A = np.swapaxes(y_inv, -2, -1) @ y_inv
    return A @ x


@Op
def detach(x):
    return x


@diagonal.register(array, int, int)
def _diagonal(x, dim1, dim2):
    return np.diagonal(x, axis1=dim1, axis2=dim2)


@einsum.register(str, [array])
def _einsum(x, *operand):
    return np.einsum(x, *operand)


@Op
def expand(x, shape):
    """
    Broadcast ``x`` to ``shape``.

    :param x: Array-like input.
    :param shape: Target shape with at least ``np.ndim(x)`` entries. An
        entry of ``-1`` aligned with an existing dimension of ``x`` keeps
        that dimension's current size.
    """
    shape = tuple(shape)  # also accept lists
    prepend_dim = len(shape) - np.ndim(x)
    assert prepend_dim >= 0
    shape = shape[:prepend_dim] + tuple(
        dx if size == -1 else size
        for dx, size in zip(np.shape(x), shape[prepend_dim:]))
    return np.broadcast_to(x, shape)


@Op
def finfo(x):
    return np.finfo(x.dtype)


@Op
def is_numeric_array(x):
    return isinstance(x, array)


@Op
def logsumexp(x, dim):
    """Compute ``log(sum(exp(x)))`` along ``dim``, shifting by the max first."""
    amax = np.amax(x, axis=dim, keepdims=True)
    # treat the case x = -inf
    amax = np.where(np.isfinite(amax), amax, 0.)
    return log(np.sum(np.exp(x - amax), axis=dim)) + amax.squeeze(axis=dim)


@max.register(array, array)
def _max(x, y):
    return np.maximum(x, y)


@max.register((int, float), array)
def _max(x, y):
    return np.clip(y, a_min=x, a_max=None)


@max.register(array, (int, float))
def _max(x, y):
    return np.clip(x, a_min=y, a_max=None)


@min.register(array, array)
def _min(x, y):
    return np.minimum(x, y)


@min.register((int, float), array)
def _min(x, y):
    return np.clip(y, a_min=None, a_max=x)


@min.register(array, (int, float))
def _min(x, y):
    return np.clip(x, a_min=None, a_max=y)


@Op
def new_arange(x, stop):
    return np.arange(stop)


@new_arange.register(array, int, int, int)
def _new_arange(x, start, stop, step):
    return np.arange(start, stop, step)


@Op
def new_zeros(x, shape):
    return np.zeros(shape, dtype=x.dtype)


@Op
def new_eye(x, shape):
    """
    Return identity matrices of size ``shape[-1]`` broadcast to
    ``shape + (shape[-1],)``. ``x`` is used only for dispatch.
    """
    n = shape[-1]
    return np.broadcast_to(np.eye(n), shape + (n,))


@Op
def permute(x, dims):
    return np.transpose(x, axes=dims)


@reciprocal.register(array)
def _reciprocal(x):
    """Elementwise reciprocal, capped above at the largest finite value."""
    # a_min must be given explicitly: on NumPy releases where np.clip has no
    # default for it, omitting it raises TypeError.
    result = np.clip(np.reciprocal(x), a_min=None, a_max=np.finfo(x.dtype).max)
    return result


@safediv.register(object, array)
def _safediv(x, y):
    try:
        finfo = np.finfo(y.dtype)
    except ValueError:
        # np.finfo rejects this dtype; use the integer limits instead.
        finfo = np.iinfo(y.dtype)
    return x * np.clip(np.reciprocal(y), a_min=None, a_max=finfo.max)


@safesub.register(object, array)
def _safesub(x, y):
    try:
        finfo = np.finfo(y.dtype)
    except ValueError:
        # np.finfo rejects this dtype; use the integer limits instead.
        finfo = np.iinfo(y.dtype)
    return x + np.clip(-y, a_min=None, a_max=finfo.max)


@stack.register(int, [array])
def _stack(dim, *x):
    return np.stack(x, axis=dim)


@transpose.register(array, int, int)
def _transpose(x, dim1, dim2):
    return np.swapaxes(x, dim1, dim2)


@Op
def triangular_solve(x, y, upper=False, transpose=False):
    """
    Return ``inv(y) @ x``, transposing the last two axes of ``y`` first
    when ``transpose`` is true. ``upper`` is accepted but not used by this
    implementation.
    """
    if transpose:
        y = np.swapaxes(y, -2, -1)
    return np.linalg.inv(y) @ x


@Op
def unsqueeze(x, dim):
    return np.expand_dims(x, axis=dim)


__all__ = [
    'AddOp',
    'AssociativeOp',
    'DISTRIBUTIVE_OPS',
    'ExpOp',
    'GetitemOp',
    'LogAddExpOp',
    'LogOp',
    'NegOp',
    'Op',
    'PRODUCT_INVERSES',
    'ReciprocalOp',
    'SampleOp',
    'SubOp',
    'ReshapeOp',
    'UNITS',
    'abs',
    'add',
    'all',
    'amax',
    'amin',
    'and_',
    'any',
    'astype',
    'cat',
    'cholesky',
    'cholesky_inverse',
    'cholesky_solve',
    'clamp',
    'detach',
    'diagonal',
    'einsum',
    'eq',
    'exp',
    'expand',
    'finfo',
    'full_like',
    'ge',
    'getitem',
    'gt',
    'invert',
    'is_numeric_array',
    'le',
    'log',
    'log1p',
    'logaddexp',
    'logsumexp',
    'lt',
    'matmul',
    'max',
    'min',
    'mul',
    'ne',
    'neg',
    'new_arange',
    'new_eye',
    'new_zeros',
    'or_',
    'pow',
    'prod',
    'reciprocal',
    'safediv',
    'safesub',
    'sample',
    'sigmoid',
    'sqrt',
    'stack',
    'sub',
    'sum',
    'transpose',
    'triangular_solve',
    'truediv',
    'unsqueeze',
    'xor',
]

# Module docstring: one autodata entry per exported name bound to an Op.
__doc__ = """
Built-in operations
-------------------

{}

Operation classes
-----------------
""".format("\n".join(f".. autodata:: {_name}\n"
                     for _name in __all__ if isinstance(globals()[_name], Op)))