import torch
from types import GeneratorType
from .timeseries.parameter import Parameter
from torch.distributions import Distribution, TransformedDistribution
from copy import deepcopy
from .utils import flatten
from typing import Union, Tuple, Type, Dict, Callable

# Key under which `Module.state_dict` stores the class name of the serialized object.
_OBJTYPENAME = 'objtype'


class TensorContainerBase(object):
    """
    Base class for objects that hold tensors and expose them through the `tensors` property.
    """

    @property
    def tensors(self) -> Tuple[torch.Tensor, ...]:
        """
        Returns the tensors held by the container. Must be implemented by subclasses.
        """
        raise NotImplementedError()

    # SEE: https://stackoverflow.com/questions/1500718/how-to-override-the-copy-deepcopy-operations-for-a-python-object
    def __deepcopy__(self, memo):
        """
        Deep copies the container: creates a new instance without calling `__init__` and deep copies every
        instance attribute onto it. The copy is registered in `memo` before the attributes are copied so
        that references back to `self` resolve to the copy.
        """
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result

        for k, v in self.__dict__.items():
            setattr(result, k, deepcopy(v, memo=memo))

        return result


class TensorContainer(TensorContainerBase):
    """
    Ordered, tuple-backed container of tensors and/or nested tensor containers.
    """

    def __init__(self, *args: torch.Tensor):
        """
        :param args: Either the items to hold, or a single generator yielding them
        """
        if len(args) == 0:
            self._cont = tuple()
        elif isinstance(args[0], GeneratorType):
            # The generator is consumed into the tuple. `tuple(*args)` raises a TypeError if further
            # arguments are passed alongside the generator.
            self._cont = tuple(*args)
        else:
            self._cont = tuple(args)

    @property
    def tensors(self):
        """
        Returns the held items passed through `flatten`.
        """
        return flatten(self._cont)

    def append(self, x: Union[torch.Tensor, TensorContainerBase, None]):
        """
        Appends a tensor or a tensor container. `None` is silently ignored.

        :param x: The object to append
        :return: Self
        :raises NotImplementedError: If `x` is neither a tensor, a tensor container nor None
        """
        if x is None:
            # Return self here as well, so that chained calls keep working when `x` is None.
            return self

        if not isinstance(x, (torch.Tensor, TensorContainerBase)):
            raise NotImplementedError()

        self._cont += (x,)

        return self

    def extend(self, *x: torch.Tensor):
        """
        Appends several tensors at once.

        :param x: The tensors to append
        :return: Self
        :raises NotImplementedError: If any argument is not a tensor
        """
        # `*x` always arrives as a tuple, so only the element types need checking.
        if not all(isinstance(t, torch.Tensor) for t in x):
            raise NotImplementedError()

        self._cont += x

        return self

    def __getitem__(self, item):
        return self._cont[item]

    def __iter__(self):
        # Kept as a generator expression: `__init__` treats generator arguments specially, so
        # `TensorContainer(iter(other))` relies on this returning a generator.
        return (t for t in self._cont)

    def __bool__(self):
        return bool(self._cont)

    def __len__(self):
        return len(self._cont)


class TensorContainerDict(TensorContainerBase):
    """
    Dictionary-backed container mapping keys to tensors or nested tensor containers.
    """

    def __init__(self, **kwargs):
        self._dict = dict(**kwargs)

    def __setitem__(self, key, value):
        self._dict[key] = value

    def __getitem__(self, item):
        return self._dict[item]

    def __bool__(self):
        return bool(self._dict)

    @property
    def tensors(self):
        """
        Returns `values()` passed through `flatten`.
        """
        return flatten(self.values())

    def items(self):
        return self._dict.items()

    def values(self):
        """
        Returns the stored values as a tuple. If every stored value is itself a tensor container, the tuple
        instead holds each value's own `values()`.
        """
        # NOTE(review): only `TensorContainerDict` defines `values()`, so a `TensorContainer` value reaching
        # this branch would raise AttributeError. Presumably only nested `TensorContainerDict`s are stored
        # here — confirm with callers.
        if all(isinstance(v, TensorContainerBase) for v in self._dict.values()):
            return tuple(d.values() for d in self._dict.values())

        return tuple(self._dict.values())


def _find_types(x, type_: Type) -> Dict[str, object]:
    """
    Helper method for finding all instance attributes of x that are instances of type_.

    :param x: The object whose instance attributes (`vars(x)`) are searched
    :param type_: The type to filter on
    :return: Dictionary mapping attribute name to attribute value
    """
    return {k: v for k, v in vars(x).items() if isinstance(v, type_)}


# TODO: Wait for pytorch to implement moving entire distributions
def _iterate_distribution(d: Distribution) -> Tuple[torch.Tensor, ...]:
    """
    Helper method for collecting the tensors of a distribution.

    For a regular distribution, it returns the tensor attributes of `d` plus, recursively, those of any
    distribution attributes. For a `TransformedDistribution`, it returns the tensors of the base
    distribution plus the tensor attributes of each transform.

    :param d: The distribution
    :return: Tuple of the tensors found
    """
    res = tuple()
    if not isinstance(d, TransformedDistribution):
        res += tuple(_find_types(d, torch.Tensor).values())

        for sd in _find_types(d, Distribution).values():
            res += _iterate_distribution(sd)

    else:
        res += _iterate_distribution(d.base_dist)

        for t in d.transforms:
            res += tuple(_find_types(t, torch.Tensor).values())

    return res


class Module(object):
    """
    Base class that discovers tensors, moves them between devices and (de)serializes state by inspecting
    the instance attributes of the subclass.
    """

    def _find_obj_helper(self, type_: Type):
        """
        Helper object for finding a specific type of objects in self.

        :param type_: The type to filter on
        :return: Dictionary mapping attribute name to attribute value
        """
        return _find_types(self, type_)

    def modules(self):
        """
        Finds and returns all attributes of self that are instances of `Module`.

        :return: Dictionary mapping attribute name to module
        """
        return self._find_obj_helper(Module)

    def tensors(self) -> Tuple[torch.Tensor, ...]:
        """
        Finds and returns all tensors reachable from self. These are:

        - tensor attributes;
        - the tensors of tensor-container attributes;
        - the distribution tensors of trainable `Parameter`s within those containers;
        - the tensors of distribution attributes;
        - recursively, the tensors of sub-modules.

        No de-duplication is performed.
        """
        res = tuple()

        # ===== Find all tensor types ====== #
        res += tuple(self._find_obj_helper(torch.Tensor).values())

        # ===== Tensor containers ===== #
        for tc in self._find_obj_helper(TensorContainerBase).values():
            res += tc.tensors

            for t in (t_ for t_ in tc.tensors if isinstance(t_, Parameter) and t_.trainable):
                res += _iterate_distribution(t.distr)

        # ===== Pytorch distributions ===== #
        for d in self._find_obj_helper(Distribution).values():
            res += _iterate_distribution(d)

        # ===== Modules ===== #
        for mod in self.modules().values():
            res += mod.tensors()

        return res

    def apply(self, f: Callable[[torch.Tensor], torch.Tensor]):
        """
        Applies function f to all tensors returned by `tensors()`, replacing their `.data` in place.

        Tensors that are not views get `f` applied to their data and, if present, their gradient. Views
        (`_base` is not None) are then handled separately: they are re-viewed from their base, or, when
        the base is 0-dimensional, get `f` applied directly.

        :param f: The callable
        :return: Self
        """
        for t in (t_ for t_ in self.tensors() if t_._base is None):
            t.data = f(t.data)

            if t._grad is not None:
                t._grad.data = f(t._grad.data)

        # NOTE(review): `self.tensors()` is evaluated a second time here rather than reused. Left as-is
        # because attributes such as `Parameter.distr` may be rebuilt per access — confirm before hoisting.
        for t in (t_ for t_ in self.tensors() if t_._base is not None):
            # TODO: Not too sure about this one, happens for some distributions
            if t._base.dim() > 0:
                # Re-derive the view from its (already transformed) base so that they share storage again.
                t.data = t._base.data.view(t.data.shape)
            else:
                t.data = f(t.data)

        return self

    def to_(self, device: str):
        """
        Move to device.

        :param device: The device to move to
        :return: Self
        """
        return self.apply(lambda u: u.to(device))

    def state_dict(self) -> Dict[str, object]:
        """
        Returns the state dictionary.

        The dictionary contains the following entries:

        - the class name, stored under `_OBJTYPENAME`;
        - every tensor attribute and every tensor-container attribute, stored by reference (not copied);
        - the nested `state_dict()` of every sub-module, under the attribute name.
        """
        res = dict()
        res[_OBJTYPENAME] = self.__class__.__name__

        # ===== Tensors ===== #
        tens = self._find_obj_helper(torch.Tensor)
        res.update(tens)

        # ===== Tensor containers ===== #
        conts = self._find_obj_helper(TensorContainerBase)
        res.update(conts)

        # ===== Modules ===== #
        modules = self.modules()

        for k, m in modules.items():
            res[k] = m.state_dict()

        return res

    def load_state_dict(self, state: Dict[str, object]):
        """
        Loads the state dictionary.

        :param state: The state dictionary, as produced by `state_dict`
        :return: Self
        :raises ValueError: If the stored class name differs from the class of self
        """
        # Imported locally, presumably to avoid a circular import with `.timeseries.base` — confirm.
        from .timeseries.base import StochasticProcess

        if state[_OBJTYPENAME] != self.__class__.__name__:
            raise ValueError(f'Cannot cast {state[_OBJTYPENAME]} as {self.__class__.__name__}!')

        for k, v in ((k_, v_) for k_, v_ in state.items() if k_ != _OBJTYPENAME):
            attr = getattr(self, k)

            if isinstance(attr, Module):
                # Sub-modules are restored in place from their nested state dictionary.
                attr.load_state_dict(v)

                if isinstance(attr, StochasticProcess):
                    attr.viewify_params(torch.Size([]))

            elif isinstance(v, TensorContainerBase) and all(isinstance(i, Parameter) for i in v.tensors):
                # Carry over the priors of the current parameters to the loaded ones, pairwise by position.
                # NOTE(review): `zip` silently stops at the shorter container if the sizes differ.
                for new, old in zip(v.tensors, attr.tensors):
                    new._prior = old._prior

                setattr(self, k, v)
            else:
                setattr(self, k, v)

        return self