# -*- coding: utf-8 -*-
"""
Created on Fri Feb 21 14:32:31 2020

@author: Jonathan.Carruthers
"""
import logging

import sys
import numpy as np
import scipy.stats as st
import matplotlib.pyplot
from inspect import signature

from pygom.utilR import dmvnorm, rmvnorm
from pygom.loss.ode_loss import SquareLoss, NormalLoss, PoissonLoss

""" v7: - allowing us to specify a constraint on initial conditions when inferring initial conditions """
""" v8: 15-04-2020
    - added a parameter class that allows us to more easily implement different
      prior distributions. The prior distributions are those defined in pygom.utilR:
          - normal
          - gamma
          - beta
          - uniform
    - improved plotting that makes it easy to plot pointwise predictions for
      selected states and plot posterior histograms/pairs plots for specific
      parameters.
"""

#%%
def _log_limits(par, logscale):
    # used in plot_scatter to convert axis limits from a log-scale if necessary
    if (not logscale) and par.logscale:
        return 10**par.prior_low, 10**par.prior_high
    else:
        return par.prior_low, par.prior_high

def get_length(attr):
    if hasattr(attr,"__len__"):
        return len(attr)
    else:
        return 0
        
def _get_target(parameters,target):
    # used to separate target_param and target_state from a single list of parameters  
    target_list = [param.name for param in parameters if param.name in target]
    if len(target_list) == 0:
        return None
    else:
        return target_list
    
def get_function(name):
    """
    Get a prior distribution function from pygom.utilR by its name.

    Parameters
    ----------
    name: string
        name of the chosen prior distribution function, e.g. "rgamma"
    """
    try:
        return getattr(sys.modules["pygom.utilR"], name)
    except AttributeError:
        raise AttributeError("The chosen distribution is not available, please choose a different one")

def _get_sigma(i, res, weights, indices):
    diff = res[indices] - res[i]
    return np.einsum('ij,ik,i->jk', diff, diff, weights)
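
# A quick numeric check of _get_sigma (a sketch, not part of the API): with
# res = np.array([[0., 0.], [1., 1.], [2., 0.]]), weights = np.ones(2)/2 and
# indices = [1, 2], _get_sigma(0, res, weights, indices) equals
# 0.5*np.outer([1, 1], [1, 1]) + 0.5*np.outer([2, 0], [2, 0]), i.e. a weighted
# sum of outer products of the particle differences about particle 0.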

#%%
""" Parameter class"""
class Parameter():
    """
    Parameters
    ----------
    name: string specifying the name of the parameter
    distname: string specifying the prior distribution
    distpars: parameters of the chosen prior distribution
    logscale: bool specifying whether parameter should be sampled on a log scale
    """
    def __init__(self, name, distname, *distpars, logscale):
        self.name = name
        self.prior_distribution = distname
        self.prior_rvs = get_function("r"+distname)
        self.prior_density = get_function("d"+distname)
        self.prior_quantile = get_function("q"+distname)
        self.prior_pars = distpars
        self.logscale = logscale

        if self.prior_distribution == "unif":
            self.prior_low, self.prior_high = self.prior_pars
        elif self.prior_distribution == "beta":
            self.prior_low, self.prior_high = 0, 1
        else:
            # we need to round these values so they're a bit nicer
            self.prior_low = np.floor(self.prior_quantile(0.001,*(self.prior_pars)))
            self.prior_high = np.ceil(self.prior_quantile(0.999,*(self.prior_pars)))
        
        sig = signature(self.prior_rvs)
        assert len([param for param in sig.parameters if param not in ["n","seed"]]) == len(distpars), "Check that the correct " \
                + "number of parameters has been specified for the chosen prior distribution."
                
    def random_sample(self):
        return self.prior_rvs(1, *(self.prior_pars), seed=None)
        
    def density(self,x):
        return self.prior_density(x, *(self.prior_pars), log=False)
    
    def plot_prior(self):
        xvalues = np.linspace(self.prior_low, self.prior_high, 1+int((self.prior_high-self.prior_low)/0.001))
        fig, axarr = matplotlib.pyplot.subplots()
        axarr.plot(xvalues,self.density(xvalues))
        axarr.set(xlim=(self.prior_low, self.prior_high), ylim=(0, None), xlabel=self.name, ylabel="Density")
        fig.show()  
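
# A minimal usage sketch of the Parameter class (hypothetical names and values;
# distribution arguments follow the pygom.utilR conventions, e.g. shape and
# rate for "gamma"):
#   growth_rate = Parameter("growth_rate", "gamma", 2.0, 10.0, logscale=False)
#   draw = growth_rate.random_sample()   # one draw from the prior
#   growth_rate.plot_prior()             # visual check of the prior density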
    
    
""" create a loss object by making use of a list of parameters. 
    - This avoids the need to specify `target_param` and `target_state` since 
      these will be determined from the parameters that are provided. 
    - For ABC, it is not necessary to specify an initial guess, theta, for each
      parameter, since the first generation will repeatedly sample from the
      prior distributions instead. This therefore avoids the need to specify
      theta. """
      
def create_loss(loss_type, parameters, ode, x0, t0, t, y, state_name,
                state_weight=None, sigma=None):
    """
    Parameters (see also class `BaseLoss`)
    ----------
    loss_type: class `BaseLoss`
    parameters: list
        a list of objects of class `Parameter`
    ode: class `DeterministicOde`
        the ode class in this package
    x0: array like
        initial conditions
    t0: numeric
        initial time
    t: array like
        time points where observations were made
    y: array like
        observations
    state_name: str or list
        the state(s) which the observations came from
    state_weight: array like
        weight for the observations
    sigma: numeric or array like
        standard deviation of the observations (used by `NormalLoss` only)
    """
    assert t0 != t[0], "Make sure that the times, t, do not include t0"
    assert all(param.name in (ode.param_list+ode.state_list) for param in parameters), "Parameters have been provided that are not in the model"
    
    target_param = _get_target(parameters, ode.param_list)
    target_state = _get_target(parameters, ode.state_list)
    theta = [param.random_sample() for param in parameters
             if target_param is not None and param.name in target_param]
    
    if loss_type == SquareLoss:
        return SquareLoss(theta, ode, x0, t0, t, y, state_name, state_weight, target_param, target_state)
    
    elif loss_type == NormalLoss:
        return NormalLoss(theta, ode, x0, t0, t, y, state_name, sigma, target_param, target_state)
    
    elif loss_type == PoissonLoss:
        return PoissonLoss(theta, ode, x0, t0, t, y, state_name, target_param, target_state)

    else:
        raise ValueError("loss_type must be SquareLoss, NormalLoss or PoissonLoss")
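
# A usage sketch (hypothetical model and data: `ode` is a pygom
# `DeterministicOde`, `y_obs` holds observations of state "I" at times `t_obs`):
#   priors = [Parameter("beta", "unif", 0, 1, logscale=False),
#             Parameter("gamma", "gamma", 2.0, 10.0, logscale=False)]
#   sq_loss = create_loss(SquareLoss, priors, ode, x0, t0, t_obs, y_obs, "I")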

#%% 
""" ABC class and methods for obtaining an approximate posterior sample/plotting the results """
class ABC():
    """
    Parameters
    ----------
    loss_object: class `BaseLoss` e.g. SquareLoss
    parameters: list
        a list of objects of class `Parameter`
    constraint: tuple
        specifies the total population size and which state's initial
        condition should be changed to conserve the population size.
    """
    def __init__(self, loss_object, parameters, constraint=None):
        assert all(isinstance(param, Parameter) for param in parameters), "Use the Parameter class to define all parameters"
        self.obj = loss_object
        self.parameters = parameters
        
        if self.obj._targetParam is None:
            # perform a normal inference with all of the parameters and any unknown initial conditions
            self.numParam = self.obj._num_param + get_length(self.obj._targetState)
        else:
            # perform inference on specified parameters only, along with any unknown initial conditions
            self.numParam = len(self.obj._targetParam) + get_length(self.obj._targetState)
    
        self.log = np.array([param.logscale for param in self.parameters])
        self.prior_range = np.array([(param.prior_high-param.prior_low) for param in self.parameters])
            
        if constraint is not None:
            self.pop_size = constraint[0]
            self.con_state = self.obj._ode.get_state_index(constraint[1])[0]
            # indices of the states that are not changed when conserving total population size
            self.con_state_indices = [i for i in range(self.obj._ode.num_state) if i!=self.con_state]
                    
        
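    # A construction sketch (hypothetical: conserve a total population of 1000
    # by adjusting the initial condition of state "S" after each proposal):
    #   abc = ABC(sq_loss, priors, constraint=(1000, "S"))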

    # NOTE: legacy implementation; `get_posterior_sample` below is the
    # refactored version that delegates the inner loop to `_perform_generation`
    def get_posterior_sample_original(self, N, tol, G=1, q=None, M=None, progress=False, rerun=False):
        """
        Parameters
        ----------
        N: integer
            the number of samples in each generation
        tol: float or array like
            the initial tolerance or sequence of decreasing tolerances
        G: integer
            the number of generations used in ABC SMC/ ABC SMC MNN
        q: float (0 < q < 1)
            the quantile used to specify the tolerance for future generations in ABC SMC/ ABC SMC MNN
        M: integer
            the number of nearest neighbours used in ABC SMC MNN (M < N)
        progress: bool
            if True, reports the generation number, acceptance rate and threshold after each generation
        rerun: bool
            if False, this is the first attempt to obtain the posterior sample
        """      
        self.N = N
        self.tol = tol
        self.G = G
        self.q = q
        self.M = M
                       
        if not rerun:
            self.res = np.zeros((self.N,self.numParam))
            self.w = np.ones(self.N)
            self.dist = np.zeros(self.N)
        self.acceptance_rate = np.zeros(self.G)
        self.tolerances = np.zeros(self.G)
        
        # perform some checks
        if self.G == 1:
            assert not hasattr(self.tol, "__len__"), "When performing rejection sampling ABC, only provide a single tolerance"
        elif self.q is None:
            assert hasattr(self.tol, "__len__"), "When performing ABC SMC, a list of tolerances or quantile must be provided"
            assert len(self.tol) == self.G, "The number of tolerances specified must be equal to the number of generations"
        else:
            assert not hasattr(self.tol, "__len__"), "When specifying a quantile, only provide an initial tolerance"
            
        if self.M is not None:
            assert (self.M < self.N), "The number of nearest neighbours must be less than the sample size (M < N). Omitting M is equivalent to M = N."
                
        # setting the appropriate function for updating the parameters/initial conditions
        par_update = self._get_update_function()
                
        for g in range(rerun, self.G+rerun):
            tolerance = self.get_tolerance(g-rerun)
            self.tolerances[g-rerun] = tolerance
            
            i = 0
            total_counter = 0
            
            # making copies of the parameters and weights for referencing
            res_old = self.res.copy()
            w_old = self.w.copy()/sum(self.w)

            # getting the correct covariance matrix  
            if (self.M is not None):
                sigma_list = [self.sigma_nearest_neighbours(res_old,k) for k in range(self.N)]
            else:
                tilde_indices = np.where(self.dist < tolerance)[0] # (this should have length self.q*self.N)
                # using einsum
                w_tilde = w_old[tilde_indices]
                w_tilde_norm = w_tilde/sum(w_tilde)
                sigma_list = [_get_sigma(i,res_old,w_tilde_norm,tilde_indices) for i in range(self.N)]
                
            while i < self.N:
                total_counter += 1
                if g == 0: 
                    trial_params = np.array([param.random_sample() for param in self.parameters])
                else:
                    random_index = np.random.choice(self.N, p=w_old)
                    sigma = sigma_list[random_index]
                    trial_params = np.atleast_1d(rmvnorm(1,mean=res_old[random_index],sigma=sigma))
                w1 = np.prod([self.parameters[i].density(trial_params[i]) for i in range(self.numParam)])
                if w1:
                    # converting from log-scale and ensuring total population size is conserved
                    model_params = self._log_parameters(trial_params.copy())
                    par_update(model_params)
                    if hasattr(self,"con_state"):
                        self.obj._x0[self.con_state] = self.pop_size - self.obj._x0[self.con_state_indices].sum() 
                    
                    cost = self.obj.cost()
                    if cost < tolerance:
                        self.res[i] = trial_params
                        self.dist[i] = cost
                        if g == 0:
                            w2 = 1
                        else:
                            # the following definition of wk is fine if the kernel is symmetric e.g. for a normal pdf we have (x-mu)**2 = (mu-x)**2
                            wk = dmvnorm(res_old, mean=self.res[i], sigma=sigma)  
                            w2 = np.dot(wk, w_old)
                        self.w[i] = w1/w2
                        i += 1
            accept_rate = 100*self.N/total_counter
            self.acceptance_rate[g-rerun] = accept_rate
            if progress: print("Generation %s \n tolerance = %.5f \n acceptance rate = %.2f%%\n" % (g+1-rerun,tolerance,accept_rate))
        self.final_tol = tolerance
        if self.q is not None:
            self.next_tol = np.quantile(self.dist,self.q)        
        

    def get_posterior_sample(self, N, tol, G=1, q=None, M=None, progress=False, rerun=False):
        """
        Parameters
        ----------
        N: integer
            the number of samples in each generation
        tol: float or array like
            the initial tolerance or sequence of decreasing tolerances
        G: integer
            the number of generations used in ABC SMC/ ABC SMC MNN
        q: float (0 < q < 1)
            the quantile used to specify the tolerance for future generations in ABC SMC/ ABC SMC MNN
        M: integer
            the number of nearest neighbours used in ABC SMC MNN (M < N)
        progress: bool
            if True, reports the generation number, acceptance rate and threshold after each generation
        rerun: bool
            if False, this is the first attempt to obtain the posterior sample
        """      
        self.N = N
        self.tol = tol
        self.G = G
        self.q = q
        self.M = M
                       
        if not rerun:
            self.res = np.zeros((self.N,self.numParam))
            self.w = np.ones(self.N)
            self.dist = np.zeros(self.N)
        self.acceptance_rate = np.zeros(self.G)
        self.tolerances = np.zeros(self.G)
    
        if N < 100:
            # TODO: why does rmvnorm raise a LinAlgError for low N? A better test and catch are needed
            logging.warning("N is low; this may cause errors")
        # perform some checks
        if self.G == 1:
            assert not hasattr(self.tol, "__len__"), "When performing rejection sampling ABC, only provide a single tolerance"
        elif self.q is None:
            assert hasattr(self.tol, "__len__"), "When performing ABC SMC, a list of tolerances or quantile must be provided"
            assert len(self.tol) == self.G, "The number of tolerances specified must be equal to the number of generations"
        else:
            assert not hasattr(self.tol, "__len__"), "When specifying a quantile, only provide an initial tolerance"
            
        if self.M is not None:
            assert (self.M < self.N), "The number of nearest neighbours must be less than the sample size (M < N). Omitting M is equivalent to M = N."
                
        # setting the appropriate function for updating the parameters/initial conditions
        par_update = self._get_update_function()
            
        for g in range(rerun,self.G+rerun):
            tolerance = self.get_tolerance(g-rerun)
            self.tolerances[g-rerun] = tolerance
            
            i = 0
            total_counter = 0
            
            # making copies of the parameters and weights for referencing
            res_old = self.res.copy()
            w_old = self.w.copy()/sum(self.w)
            # Todo: place these into the dask cluster

            # getting the correct covariance matrix   
            if (self.M is not None):
                sigma_list = [self.sigma_nearest_neighbours(res_old,k) for k in range(self.N)]
            else:
                tilde_indices = np.where(self.dist < tolerance)[0] # (this should have length self.q*self.N)
                # using einsum
                w_tilde = w_old[tilde_indices]
                w_tilde_norm = w_tilde/sum(w_tilde)
                sigma_list = [_get_sigma(i,res_old,w_tilde_norm,tilde_indices) for i in range(self.N)]                
                
            for i in range(self.N):
                (self.w[i], 
                 rejections, 
                 self.res[i], 
                 self.dist[i]) = self._perform_generation(generation=g,
                                                  sigma_list=sigma_list,
                                                  tolerance=tolerance,
                                                  par_update=par_update,
                                                  res_old=res_old,
                                                  w_old=w_old)

                total_counter += (rejections + 1)
                
            accept_rate = 100 * self.N / total_counter
            self.acceptance_rate[g-rerun] = accept_rate
            if progress: print("Generation %s \n tolerance = %.5f \n acceptance rate = %.2f%%\n" % (g+1-rerun,tolerance,accept_rate))
        
        self.final_tol = tolerance
        if self.q is not None:
            self.next_tol = np.quantile(self.dist,self.q)        
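
    # A sampling sketch (hypothetical settings): run 5 SMC generations of 500
    # particles, accepting everything in generation 0 and letting the median of
    # the accepted distances set each subsequent tolerance:
    #   abc.get_posterior_sample(N=500, tol=np.inf, G=5, q=0.5, progress=True)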

    def _perform_generation(self, 
                            generation, 
                            sigma_list,
                            tolerance, 
                            par_update,
                            res_old,
                            w_old):
        '''
        Carry out a single generation
        
        Parameters
        ----------
        generation: the generation number
        sigma_list: list of per-particle covariance matrices for the kernel
        tolerance: calculated tolerance for this generation
        par_update: the update function for parameters/initial conditions
        res_old: previous generation parameters
        w_old: previous generation weights
        '''
        rejections = 0
        while True: # Todo: should be some timeout on this
            if generation == 0: 
                trial_params = np.array([param.random_sample() for param in self.parameters])
            else:
                random_index = np.random.choice(self.N,p=w_old)
                sigma = sigma_list[random_index]
                trial_params = np.atleast_1d(rmvnorm(1,
                                                     mean=res_old[random_index],
                                                     sigma=sigma))
            w1 = np.prod([self.parameters[i].density(trial_params[i]) for i in range(self.numParam)])
            if w1:
                # converting from log-scale and ensuring total population size is conserved
                model_params = self._log_parameters(trial_params.copy())
                par_update(model_params)
                if hasattr(self,"con_state"): 
                    self.obj._x0[self.con_state] = self.pop_size - self.obj._x0[self.con_state_indices].sum() 
                
                cost = self.obj.cost()
                if cost < tolerance:
                    if generation == 0:
                        w2 = 1
                    else:
                        # the following definition of wk is fine if the kernel is symmetric e.g. for a normal pdf we have (x-mu)**2 = (mu-x)**2
                        wk = dmvnorm(res_old, mean=trial_params, sigma=sigma)  
                        w2 = np.dot(wk, w_old)
                    break # success, so escape from the while loop
            rejections += 1
        return (w1/w2, rejections, trial_params, cost)
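
    # The returned weight is the usual ABC SMC importance weight: the prior
    # density of the accepted particle (w1) divided by its proposal density
    # under the perturbation kernel,
    # w2 = sum_j w_old[j] * N(theta | res_old[j], sigma);
    # generation 0 samples directly from the prior, so w2 = 1.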


    def continue_posterior_sample(self, N, tol, G=1, q=None, M=None, progress=False):
        """
        Parameters (same as get_posterior_sample)
        ----------
        N: integer
            the number of samples in each generation
        tol: float or array like
            the initial tolerance or sequence of decreasing tolerances
        G: integer
            the number of generations used in ABC SMC/ ABC SMC MNN
        q: float (0 < q < 1)
            the quantile used to specify the tolerance for future generations in ABC SMC/ ABC SMC MNN
        M: integer
            the number of nearest neighbours used in ABC SMC MNN (M < N)
        progress: bool
            if True, reports the generation number, acceptance rate and threshold after each generation
        """
        # perform checks
        assert N == self.N, "For now, set the sample size to be the same as the previous run"
        assert hasattr(self, "res"), "Use 'get_posterior_sample' before 'continue_posterior_sample'"

        if hasattr(tol, "__len__"):
            assert tol[0] <= self.final_tol, "The initial tolerance is greater than the final tolerance from the previous run"
        else:
            assert tol <= self.final_tol, "The initial tolerance is greater than the final tolerance from the previous run"
            
        self.get_posterior_sample(N, tol, G, q, M, progress, rerun=True)
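
    # A continuation sketch (hypothetical): resume from the previous run with
    # the tolerance it suggested for the next generation:
    #   abc.continue_posterior_sample(N=500, tol=abc.next_tol, G=2, q=0.5)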
        
                 
    def plot_posterior_histograms(self,plot_params=None,max_ncol=4):
        """
        Parameters
        ----------
        plot_params: list
            specifies which parameters the posterior distributions are plotted for
        max_ncol: integer
            specifies the maximum number of columns in the figure, purely to
            avoid squashed plots.
        """        
        fit_params = [p.name for p in self.parameters]
        
        if plot_params is not None:
            assert all(p in fit_params for p in plot_params), "You are trying to plot histograms of parameters not included in the inference"
            param_indices = [fit_params.index(p) for p in plot_params]
        else:
            param_indices = [i for i in range(self.numParam)]
        
        numPlotParam = len(param_indices)
        nrows = 1 + (numPlotParam-1)//max_ncol
        ncols = min(numPlotParam,max_ncol)
        f, axarr = matplotlib.pyplot.subplots(nrows, ncols, squeeze=False)

        for pp, ax in enumerate(f.axes):
            # plotting the pdf of the prior distribution
            ind = param_indices[pp]
            plot_low = min(self.parameters[ind].prior_low, np.floor(min(self.res.T[ind])))
            plot_high = max(self.parameters[ind].prior_high, np.ceil(max(self.res.T[ind])))
            xvalues = np.linspace(plot_low, plot_high, 1+int((plot_high-plot_low)/0.001))
            ax.plot(xvalues, self.parameters[ind].density(xvalues), color="r", ls="--", alpha=0.75, lw=1.5)

            # using kernel density estimation to plot a smoothed histogram
            kernel = st.gaussian_kde(self.res.T[ind])
            ax.fill_between(xvalues, kernel(xvalues), facecolor=(0,0,1,0.2), edgecolor=(0,0,1,1), lw=2.0)
            ax.set(xlim=(plot_low, plot_high), ylim=(0, None), xlabel=fit_params[ind])
            ax.set(adjustable='box')
            if pp == numPlotParam-1: break
        f.tight_layout()
        f.show()      
        return f
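
    # e.g. abc.plot_posterior_histograms(plot_params=["beta"]) would overlay
    # the prior pdf (dashed red line) on a kernel density estimate of the
    # posterior sample (hypothetical parameter name).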
                
                
    def plot_pointwise_predictions(self,plot_states=None,new_time=None,max_ncol=3):
        """
        Parameters
        ----------
        plot_states: list
            specifies which states the solution is plotted for
        new_time: array like
            an array of new times to plot the model solution for. This temporarily
            overwrites the array of times used for the inference, but this change
            is reverted once the plotting is completed in case the user wants to
            run 'continue_posterior_sample'.
        max_ncol: integer
            specifies the maximum number of columns in the figure, purely to
            avoid squashed plots.
        """
        par_update = self._get_update_function()
        
        # getting the indices of the states we want to plot the solution for
        if plot_states is not None:
            assert all(s in self.obj._ode.state_list for s in plot_states), "It is only possible to plot the solution for states already in the model"
            state_indices = [self.obj._ode.state_list.index(s) for s in plot_states]
        else:
            state_indices = [i for i in range(self.obj._num_state)]       
        
        # setting up an empty array to store the result for each parameter set
        if new_time is None:
            tt = self.obj._observeT
        else:
            tt = new_time       
        
        # reshaping the observations into a column vector so that fitting to a
        # single state uses the same indexing as fitting to several states
        if len(self.obj._stateName) == 1:
            self.obj._y = self.obj._y.reshape((-1, 1))
                
        numStates = self.obj._num_state
        numPlotStates = len(state_indices)
        
        nrows = 1 + (numPlotStates-1)//max_ncol
        ncols = min(numPlotStates,max_ncol)
        
        # finding the point-wise median solution and 95% credible regions
        solution = np.zeros((self.N,len(tt)*numStates))
        for i in range(self.N):
            params = self.res[i].copy()
            params = self._log_parameters(params)
            par_update(params)
            if hasattr(self,"con_state"):
                self.obj._x0[self.con_state] = self.pop_size - self.obj._x0[self.con_state_indices].sum()
            self.obj._ode.parameters = self.obj._theta
            self.obj._ode.initial_state = self.obj._x0
            solution[i] = self.obj._ode.integrate(tt)[1:,].T.flatten()
        
        median = np.median(solution,axis=0).reshape((numStates,len(tt)))
        credible_95_high = np.quantile(solution,q=0.975,axis=0).reshape((numStates,len(tt)))
        credible_95_low = np.quantile(solution,q=0.025,axis=0).reshape((numStates,len(tt)))
        
        self.median = median
        self.credible_95_high = credible_95_high
        self.credible_95_low = credible_95_low
        
        # plotting the solution
        f, axarr = matplotlib.pyplot.subplots(nrows, ncols, squeeze=False)
        for pp, ax in enumerate(f.axes):
            ind = state_indices[pp]            
            ax.plot(tt,median[ind],color='r')
            ax.fill_between(tt,credible_95_low[ind],credible_95_high[ind],color='gray',alpha=0.5)
            try:
                dd = self.obj._stateName.index(self.obj._ode.state_list[ind])
                ax.scatter(self.obj._observeT, self.obj._y[::,dd], marker='o', facecolor='gray', alpha=0.5, edgecolor=(0,0,0,1))
            except ValueError:
                # this state was not observed, so there is no data to overlay
                pass
            ax.set(xlabel='Time',title=str(self.obj._ode.state_list[ind]))
            if pp == numPlotStates-1: break
        f.tight_layout()
        f.show()
        return f
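
    # A plotting sketch (hypothetical: overlay the fitted state "I" on a finer
    # time grid than the one used for inference):
    #   abc.plot_pointwise_predictions(plot_states=["I"],
    #                                  new_time=np.linspace(0, 50, 201))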


    def plot_scatter(self,plot_params=None,logscale=True):
        """
        Parameters
        ----------
        plot_params: list
            specifies which parameters should be included in the pairs plot, the list should
            ideally contain the names of two or more parameters.
        logscale: bool
            indicates whether parameters sampled on a log-scale should remain on a log-scale 
        """
        fit_params = [p.name for p in self.parameters]
        
        if plot_params is not None:
            assert all(p in fit_params for p in plot_params), "You are trying to plot parameters that were not included in the inference"
            param_indices = [fit_params.index(p) for p in plot_params]
        else:
            param_indices = [i for i in range(self.numParam)]
        numPlotParam = len(param_indices)
        
        posterior_sample = self.res.copy()
        if not logscale:
            posterior_sample[:,self.log] = 10**posterior_sample[:,self.log]
                    
        f, axarr = matplotlib.pyplot.subplots(numPlotParam, numPlotParam, squeeze=False)
        for i in range(numPlotParam):
            for j in range(numPlotParam):
                iind, jind = param_indices[i], param_indices[j]
                axarr[i,j].scatter(posterior_sample.T[jind],posterior_sample.T[iind],marker='o', s=5, color='b', alpha=0.2)
                axarr[i,j].set(xlabel=fit_params[jind],ylabel=fit_params[iind])               
                axarr[i,j].set(xlim=_log_limits(self.parameters[jind],logscale))
                axarr[i,j].set(ylim=_log_limits(self.parameters[iind],logscale))
        for ax in axarr.flat:
            ax.label_outer()
        f.tight_layout()
        f.show()
        return f
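
    # e.g. abc.plot_scatter(plot_params=["beta", "gamma"], logscale=False)
    # would plot pairwise posterior samples on the original (non-log) scale
    # (hypothetical parameter names).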
     
        
    def get_tolerance(self,g):
        """
        Parameters
        ----------
        g: integer
           generation number of the ABC-SMC/MNN algorithm
        """
        # choose the tolerance given the generation number and how q and tol are defined
        if g == 0:
            if not hasattr(self.tol, "__len__"):
                return self.tol
            else:
                return self.tol[0]
        else:
            if self.q is not None:
                return np.quantile(self.dist,self.q)
            else:
                return self.tol[g]
        
        
    def sigma_nearest_neighbours(self,xx,index):
        """
        Parameters
        ----------
        xx: array like
            array of parameters
        index: integer
            index of the parameter set the nearest neighbours will be found for        
        """
        # find the covariance matrix of the M nearest particles to a specified particle
        if self.M == self.N-1:
            return np.cov(xx.T)
        else:
            diff = (xx - xx[index])/self.prior_range
            euclidean_norm = np.sum(diff**2,axis=1)
            nn = np.argpartition(euclidean_norm,self.M+1)[:self.M+1]
            return np.cov(xx[nn].T)
                
        
    def _get_update_function(self):
        if self.obj._targetState is None:
            return self.obj._setParam
        else:
            return self.obj._setParamStateInput
        
        
    def _log_parameters(self, params):
        """
        Parameters
        ----------
        params: array like
            array containing values for each of the inferred parameters
        """
        if hasattr(params,"__len__"):
            params[self.log] = 10**params[self.log]
        else:
            if self.log[0]:
                params = 10**params
        return params
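
    # e.g. with self.log = np.array([True, False]) and
    # params = np.array([-2.0, 0.3]), the result is np.array([0.01, 0.3]):
    # only log-scale entries are transformed by 10**x.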
    
    
    def _vprod(self,a1,a2):
        diff = (a1-a2).reshape((self.numParam,1))
        return np.dot(diff,diff.T)