from __future__ import absolute_import
from __future__ import print_function

import autograd.numpy as np
from autograd import grad
from autograd.extend import notrace_primitive

@notrace_primitive
def resampling(w, rs):
    """
    Stratified resampling with "nograd_primitive" to ensure autograd 
    takes no derivatives through it.
    """
    N = w.shape[0]
    bins = np.cumsum(w)
    ind = np.arange(N)
    # One stratified uniform per particle: u_i ~ U[i/N, (i+1)/N)
    u = (ind + rs.rand(N))/N
    
    return np.digitize(u, bins)
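
# Usage sketch (illustrative only; the weight values below are assumptions,
# not from the original code):
#
#   rs = np.random.RandomState(0)
#   w = np.array([0.1, 0.2, 0.3, 0.4])   # normalized particle weights
#   ancestors = resampling(w, rs)        # four ancestor indices in {0, ..., 3}
#
# Because the primitive is untraced, autograd treats the returned indices as
# constants, so no gradient flows through the resampling step.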

def vsmc_lower_bound(prop_params, model_params, y, smc_obj, rs, verbose=False, adapt_resamp=False):
    """
    Estimate the VSMC lower bound. Amenable to (biased) reparameterization 
    gradients.

    .. math::
        ELBO(\theta,\lambda) =
        \mathbb{E}_{\phi}\left[\nabla_\lambda \log \hat p(y_{1:T}) \right]

    Requires an SMC object with 2 member functions:
    -- sim_prop(t, x_{t-1}, y, prop_params, model_params, rs)
    -- log_weights(t, x_t, x_{t-1}, y, prop_params, model_params)
    """
    # Extract constants
    T = y.shape[0]
    Dx = smc_obj.Dx
    N = smc_obj.N

    # Initialize SMC
    X = np.zeros((N,Dx))
    Xp = np.zeros((N,Dx))
    logW = np.zeros(N)
    W = np.exp(logW)
    W /= np.sum(W)
    logZ = 0.
    ESS = 1./np.sum(W**2)/N   # normalized ESS, in [1/N, 1]

    for t in range(T):
        # Resampling
        if adapt_resamp:
            # Resample only when the normalized ESS falls below 1/2, folding
            # the weights accumulated since the last resampling into logZ.
            if ESS < 0.5:
                ancestors = resampling(W, rs)
                Xp = X[ancestors]
                logZ = logZ + max_logW + np.log(np.sum(W)) - np.log(N)
                logW = np.zeros(N)
            else:
                Xp = X
        else:
            if t > 0:
                ancestors = resampling(W, rs)
                Xp = X[ancestors]
            else:
                Xp = X

        # Propagation
        X = smc_obj.sim_prop(t, Xp, y, prop_params, model_params, rs)

        # Weighting
        if adapt_resamp:
            # With adaptive resampling the log-weights accumulate across steps
            # and are reset only when a resampling event occurs.
            logW = logW + smc_obj.log_weights(t, X, Xp, y, prop_params, model_params)
        else:
            logW = smc_obj.log_weights(t, X, Xp, y, prop_params, model_params)
        # Subtract the max log-weight before exponentiating (log-sum-exp trick)
        max_logW = np.max(logW)
        W = np.exp(logW-max_logW)
        if adapt_resamp:
            # Accumulated weights enter logZ only at resampling times and at t = T-1
            if t == T-1:
                logZ = logZ + max_logW + np.log(np.sum(W)) - np.log(N)
        else:
            logZ = logZ + max_logW + np.log(np.sum(W)) - np.log(N)
        W /= np.sum(W)
        ESS = 1./np.sum(W**2)/N   # normalized ESS
    if verbose:
        print('ESS: '+str(ESS))
    return logZ
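
# -----------------------------------------------------------------------------
# Minimal sketch of the SMC object interface expected by vsmc_lower_bound and
# sim_q. The toy linear-Gaussian model below and all of its names (ToySMC, the
# (mu, log_sigma) proposal parameterization, the scalar observation log-std
# passed as model_params) are illustrative assumptions, not part of the
# original code.
# -----------------------------------------------------------------------------
class ToySMC(object):
    """Toy model: x_t ~ N(x_{t-1}, I), y_t ~ N(x_t, exp(2*model_params) I)."""
    def __init__(self, Dx=1, N=100):
        self.Dx = Dx   # latent dimension
        self.N = N     # number of particles

    def _log_normal(self, x, mean, log_sigma):
        # Isotropic Gaussian log-density, summed over the state dimension.
        return (-0.5*np.sum((x - mean)**2, axis=1)/np.exp(2.*log_sigma)
                - self.Dx*(0.5*np.log(2.*np.pi) + log_sigma))

    def sim_prop(self, t, Xp, y, prop_params, model_params, rs):
        # Gaussian random-walk proposal with a learnable shift and scale.
        mu, log_sigma = prop_params
        return Xp + mu + np.exp(log_sigma)*rs.randn(*Xp.shape)

    def log_weights(self, t, Xt, Xp, y, prop_params, model_params):
        # log w_t = log f(x_t | x_{t-1}) + log g(y_t | x_t) - log q(x_t | x_{t-1})
        mu, log_sigma = prop_params
        log_f = self._log_normal(Xt, Xp, 0.)
        log_g = self._log_normal(Xt, y[t], model_params)
        log_q = self._log_normal(Xt, Xp + mu, log_sigma)
        return log_f + log_g - log_q
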

def sim_q(prop_params, model_params, y, smc_obj, rs, verbose=False):
    """
    Simulates a single sample from the VSMC approximation.

    Requires an SMC object with 2 member functions:
    -- sim_prop(t, x_{t-1}, y, prop_params, model_params, rs)
    -- log_weights(t, x_t, x_{t-1}, y, prop_params, model_params)
    """
    # Extract constants
    T = y.shape[0]
    Dx = smc_obj.Dx
    N = smc_obj.N

    # Initialize SMC
    X = np.zeros((N,T,Dx))
    logW = np.zeros(N)
    W = np.zeros((N,T))
    ESS = np.zeros(T)

    for t in range(T):
        # Resampling: reindex the entire particle histories up to time t
        if t > 0:
            ancestors = resampling(W[:,t-1], rs)
            X[:,:t,:] = X[ancestors,:t,:]

        # Propagation
        X[:,t,:] = smc_obj.sim_prop(t, X[:,t-1,:], y, prop_params, model_params, rs)

        # Weighting
        logW = smc_obj.log_weights(t, X[:,t,:], X[:,t-1,:], y, prop_params, model_params)
        max_logW = np.max(logW)
        W[:,t] = np.exp(logW-max_logW)
        W[:,t] /= np.sum(W[:,t])
        ESS[t] = 1./np.sum(W[:,t]**2)

    # Sample from the empirical approximation
    bins = np.cumsum(W[:,-1])
    u = rs.rand()
    B = np.digitize(u,bins)

    if verbose:
        print('Mean ESS (normalized):', np.mean(ESS)/N)
        print('Min ESS (unnormalized):', np.min(ESS))
        
    return X[B,:,:]
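
# -----------------------------------------------------------------------------
# Small self-contained demo of the two entry points above, using the ToySMC
# sketch. The synthetic data, parameter values, and seed are illustrative
# assumptions only.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    rs = np.random.RandomState(0)
    T, Dx, N = 10, 1, 100
    smc_obj = ToySMC(Dx=Dx, N=N)

    # Synthetic observations: a Gaussian random walk observed with noise.
    x_true = np.cumsum(rs.randn(T, Dx), axis=0)
    y = x_true + 0.5*rs.randn(T, Dx)

    prop_params = (0., 0.)      # proposal (mu, log_sigma)
    model_params = np.log(0.5)  # observation log standard deviation

    # ELBO estimate and its reparameterization gradient w.r.t. prop_params.
    print('ELBO estimate:', vsmc_lower_bound(prop_params, model_params, y,
                                             smc_obj, rs, verbose=True))
    print('Gradient w.r.t. prop_params:',
          grad(vsmc_lower_bound)(prop_params, model_params, y, smc_obj, rs))

    # One trajectory sampled from the VSMC approximation.
    x_sample = sim_q(prop_params, model_params, y, smc_obj, rs, verbose=True)
    print('Sampled trajectory shape:', x_sample.shape)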