from __future__ import absolute_import

import autograd.numpy as np
import scipy.stats
from autograd.extend import primitive, defvjp
from autograd.numpy.numpy_vjps import unbroadcast_f

# Wrap the scipy Poisson functions with autograd's ``primitive`` so that
# gradients can be attached to them via ``defvjp``.
_poisson = scipy.stats.poisson
cdf = primitive(_poisson.cdf)
logpmf = primitive(_poisson.logpmf)
pmf = primitive(_poisson.pmf)

def grad_poisson_logpmf(k, mu):
    """Return the derivative of the Poisson log-pmf with respect to ``mu``.

    The value is ``k / mu - 1`` wherever ``k`` is a whole number
    (``k % 1 == 0``) and ``0`` everywhere else. ``k`` and ``mu`` are
    combined element-wise, so the usual broadcasting of ``k / mu`` applies.

    No special handling is done for ``mu == 0``; the division is evaluated
    as written.
    """
    # Only whole-number counts get a non-zero gradient.
    is_whole = (k % 1) == 0
    slope = k / mu - 1
    return np.where(is_whole, slope, 0)

# Gradient rules below are registered for the second argument (``mu``) only,
# as selected by ``argnums=[1]``; no rule is registered here for ``k``.
# Each rule hands its gradient function to ``unbroadcast_f`` together with
# ``mu``.


def _cdf_vjp_mu(ans, k, mu):
    """Build the ``mu`` VJP for ``cdf``: ``g * -pmf(floor(k), mu)``."""
    def vjp(g):
        return g * -pmf(np.floor(k), mu)
    return unbroadcast_f(mu, vjp)


def _logpmf_vjp_mu(ans, k, mu):
    """Build the ``mu`` VJP for ``logpmf``: ``g * grad_poisson_logpmf(k, mu)``."""
    def vjp(g):
        return g * grad_poisson_logpmf(k, mu)
    return unbroadcast_f(mu, vjp)


def _pmf_vjp_mu(ans, k, mu):
    """Build the ``mu`` VJP for ``pmf``: ``g * ans * grad_poisson_logpmf(k, mu)``."""
    def vjp(g):
        # ``ans`` is the forward value pmf(k, mu); scaling the log-pmf
        # gradient by it gives the gradient of the pmf itself.
        return g * ans * grad_poisson_logpmf(k, mu)
    return unbroadcast_f(mu, vjp)


defvjp(cdf, _cdf_vjp_mu, argnums=[1])
defvjp(logpmf, _logpmf_vjp_mu, argnums=[1])
defvjp(pmf, _pmf_vjp_mu, argnums=[1])