from .base import SequentialParticleAlgorithm
from .kernels import ParticleMetropolisHastings, SymmetricMH
from ..utils import get_ess
from ..filters import ParticleFilter
from ..module import TensorContainer
from torch import isfinite


class SMC2(SequentialParticleAlgorithm):
    """
    Implements the SMC2 algorithm by Chopin et al.

    The parameter particles are weighted by accumulating the log-likelihood increments returned by the filter.
    Whenever the effective sample size drops below a threshold, or any weight becomes non-finite, the particles
    are rejuvenated with a particle MCMC move. If that move has an acceptance rate below 20% and the filter is a
    ``ParticleFilter``, the number of state particles is doubled, at most ``max_increases`` times.
    """

    def __init__(self, filter_, particles, threshold=0.2, kernel: ParticleMetropolisHastings = None,
                 max_increases=5):
        """
        Implements the SMC2 algorithm by Chopin et al.

        :param filter_: The filter whose log-likelihood increments weight the parameter particles
        :param particles: The number of parameter particles (passed on to the base class)
        :param threshold: The fraction of ``particles`` below which the ESS triggers MCMC rejuvenation
        :param kernel: The kernel to use when updating the parameters, defaults to ``SymmetricMH()``
        :param max_increases: The maximum number of times the number of state particles may be doubled
        :raises ValueError: If ``kernel`` is not an instance of ``ParticleMetropolisHastings``
        """
        super().__init__(filter_, particles)

        # ===== When and how to update ===== #
        # `threshold` is a fraction, whereas the ESS is compared in units of particles.
        self._threshold = threshold * particles
        self._kernel = kernel if kernel is not None else SymmetricMH()

        if not isinstance(self._kernel, ParticleMetropolisHastings):
            # Use the class's own `__name__`; `__class__.__name__` would report the metaclass instead.
            raise ValueError(f'The kernel must be of instance {ParticleMetropolisHastings.__name__}!')

        # ===== Some helpers to figure out whether to raise ===== #
        self._max_increases = max_increases
        self._increases = 0

        # ===== Save data ===== #
        # Every observation is kept so the kernel and any re-run of the filter can see the full history.
        self._y = TensorContainer()

    def _update(self, y):
        """
        Processes a single observation: filters it, reweights the parameter particles and, if the effective
        sample size has become too small (or a weight is non-finite), rejuvenates the particles.

        :param y: The new observation
        :return: Self
        """
        # ===== Save data ===== #
        self._y.append(y)

        # ===== Perform a filtering move ===== #
        _, ll = self.filter.filter(y)
        self._w_rec += ll

        # ===== Calculate efficient number of samples ===== #
        ess = get_ess(self._w_rec)
        self._logged_ess.append(ess)

        # ===== Rejuvenate if there are too few samples ===== #
        # The weights are reset inside `rejuvenate`, so that weights set when increasing the number of state
        # particles are not discarded here.
        if ess < self._threshold or (~isfinite(self._w_rec)).any():
            self.rejuvenate()

        return self

    def rejuvenate(self):
        """
        Rejuvenates the particles using a PMCMC move. Afterwards the weights are reset, and if fewer than 20% of
        the proposals were accepted (and the filter is a ``ParticleFilter``), the number of state particles is
        increased, which in turn sets new weights.

        :return: Self
        """
        # ===== Update the description ===== #
        self._kernel.set_data(self._y.tensors)
        self._kernel.update(self.filter.ssm.theta_dists, self.filter, self._w_rec)

        # ===== Reset the weights ===== #
        # The kernel has been handed the current weights (presumably resampling according to them), so they are
        # reset here - *before* `_increase_states`, whose likelihood-ratio weights must survive.
        self._w_rec[:] = 0.

        # ===== Increase states if less than 20% are accepted ===== #
        if self._kernel.accepted < 0.2 and isinstance(self.filter, ParticleFilter):
            self._increase_states()

        return self

    def _increase_states(self):
        """
        Doubles the number of state particles, re-runs the filter on all observations seen so far, and sets the
        weights to the log-ratio of the new and old likelihood estimates.

        :return: Self
        :raises Exception: If the number of state particles has already been increased ``max_increases`` times
        """
        if self._increases >= self._max_increases:
            raise Exception(
                f'Configuration only allows {self._max_increases} increases of the number of state particles!'
            )

        # ===== Create new filter with double the state particles ===== #
        oldlogl = self.filter.result.loglikelihood
        self.filter.reset()
        # NOTE(review): assumes `particles` is indexable with the state-particle count at index 1 - confirm.
        self.filter.particles = 2 * self.filter.particles[1]
        self.filter.set_nparallel(self._w_rec.shape[0]).initialize().longfilter(self._y.tensors, bar=False)

        # ===== Calculate new weights and replace filter ===== #
        # Log-ratio of the likelihood estimates obtained with the new vs. the old number of state particles.
        self._w_rec = self.filter.result.loglikelihood - oldlogl
        self._increases += 1

        return self