import numpy as np
import tensorflow as tf


def logistic_logpdf(*, x, mean, logscale):
    """
    Elementwise log-density of a logistic distribution.

    x: points at which the density is evaluated
    mean: location parameter
    logscale: natural log of the scale parameter
    """
    inv_scale = tf.exp(-logscale)
    standardized = (x - mean) * inv_scale
    # log f(x) = z - log(s) - 2 * log(1 + exp(z)), where z = (x - mean) / s
    twice_softplus = 2 * tf.nn.softplus(standardized)
    return standardized - logscale - twice_softplus


def logistic_logcdf(*, x, mean, logscale):
    """
    Elementwise log-CDF of a logistic distribution.

    x: points at which the CDF is evaluated
    mean: location parameter
    logscale: natural log of the scale parameter
    """
    inv_scale = tf.exp(-logscale)
    standardized = (x - mean) * inv_scale
    # The logistic CDF is sigmoid(z), so its log is log_sigmoid(z)
    return tf.log_sigmoid(standardized)


def test_logistic():
    """Check logistic_logpdf / logistic_logcdf against scipy.stats.logistic on a grid of (loc, scale)."""
    import scipy.stats

    # Points at which both the log pdf and the log cdf are evaluated
    num_points = 100
    grid = np.linspace(-5, 5, num_points)

    # Build the graph once; x, mean and log-scale are fed for each parameter setting
    tf.reset_default_graph()
    x_ph = tf.placeholder(tf.float64, [None])
    mean_ph = tf.placeholder(tf.float64, [None])
    logscale_ph = tf.placeholder(tf.float64, [None])
    logpdf_op = logistic_logpdf(x=x_ph, mean=mean_ph, logscale=logscale_ph)
    logcdf_op = logistic_logcdf(x=x_ph, mean=mean_ph, logscale=logscale_ph)

    locs = np.linspace(-1, 2, 5)
    scales = np.linspace(.01, 3, 5)
    settings = [(loc, scale) for loc in locs for scale in scales]

    with tf.Session() as sess:
        for loc, scale in settings:
            feed = {
                x_ph: grid,
                mean_ph: np.full(num_points, loc),
                logscale_ph: np.full(num_points, np.log(scale)),
            }
            got_logpdf, got_logcdf = sess.run([logpdf_op, logcdf_op], feed)

            # scipy provides the reference values
            want_logpdf = scipy.stats.logistic.logpdf(grid, loc, scale)
            want_logcdf = scipy.stats.logistic.logcdf(grid, loc, scale)
            assert np.allclose(got_logpdf, want_logpdf)
            assert np.allclose(got_logcdf, want_logcdf)


def mixlogistic_logpdf(*, x, prior_logits, means, logscales):
    """
    Log-density of a mixture of logistics.

    x: evaluation points
    prior_logits, means, logscales: per-component parameters; each must have
        exactly one more axis than x, with the components on the last axis
    """
    expected_rank = len(x.get_shape()) + 1
    for param in (prior_logits, means, logscales):
        assert len(param.get_shape()) == expected_rank

    log_weights = tf.nn.log_softmax(prior_logits, axis=-1)
    # A trailing axis on x lets it broadcast against the component axis
    component_logpdfs = logistic_logpdf(x=tf.expand_dims(x, -1), mean=means, logscale=logscales)
    # log sum_k w_k * f_k(x), computed in log space
    return tf.reduce_logsumexp(log_weights + component_logpdfs, axis=-1)


def mixlogistic_logcdf(*, x, prior_logits, means, logscales):
    """
    Log of the cumulative distribution function of a mixture of logistics.

    x: evaluation points
    prior_logits, means, logscales: per-component parameters; each must have
        exactly one more axis than x, with the components on the last axis
    """
    expected_rank = len(x.get_shape()) + 1
    for param in (prior_logits, means, logscales):
        assert len(param.get_shape()) == expected_rank

    x_expanded = tf.expand_dims(x, -1)  # broadcast x against the component axis
    weighted_logcdfs = (
        tf.nn.log_softmax(prior_logits, axis=-1)
        + logistic_logcdf(x=x_expanded, mean=means, logscale=logscales)
    )
    # log sum_k w_k * F_k(x), computed in log space
    return tf.reduce_logsumexp(weighted_logcdfs, axis=-1)


def test_logistic_mixture():
    """Check mixlogistic_logpdf / mixlogistic_logcdf against a scipy-built 3-component mixture."""
    import scipy.stats

    num_points = 100
    grid = np.linspace(-5, 5, num_points)
    prior_logits = [.1, .2, 4]
    means = [-1., 0., 1]
    logscales = [-5., 0., 0.2]

    tf.reset_default_graph()
    x_ph = tf.placeholder(tf.float64, [None])
    prior_logits_ph = tf.placeholder(tf.float64, [None, None])
    means_ph = tf.placeholder(tf.float64, [None, None])
    logscales_ph = tf.placeholder(tf.float64, [None, None])
    mixture_kwargs = dict(x=x_ph, prior_logits=prior_logits_ph, means=means_ph, logscales=logscales_ph)
    logpdf_op = mixlogistic_logpdf(**mixture_kwargs)
    logcdf_op = mixlogistic_logcdf(**mixture_kwargs)

    # Every grid point is given the same mixture parameters
    feed = {
        x_ph: grid,
        prior_logits_ph: np.tile(prior_logits, (num_points, 1)),
        means_ph: np.tile(means, (num_points, 1)),
        logscales_ph: np.tile(logscales, (num_points, 1)),
    }
    with tf.Session() as sess:
        got_logpdf, got_logcdf = sess.run([logpdf_op, logcdf_op], feed)

    # Reference mixture: softmax weights times the scipy component pdf / cdf
    unnormalized = np.exp(prior_logits)
    weights = unnormalized / unnormalized.sum()
    components = list(zip(weights, means, logscales))
    want_pdf = sum(w * scipy.stats.logistic.pdf(grid, m, np.exp(ls)) for w, m, ls in components)
    want_cdf = sum(w * scipy.stats.logistic.cdf(grid, m, np.exp(ls)) for w, m, ls in components)

    assert want_pdf.shape == got_logpdf.shape
    assert np.allclose(got_logpdf, np.log(want_pdf))
    assert np.allclose(got_logcdf, np.log(want_cdf))


def mixlogistic_sample(*, prior_logits, means, logscales):
    """
    Draw samples from a mixture of logistics.

    A component is chosen per entry by adding Gumbel noise to the logits and
    taking the argmax; the value is then drawn from the chosen component by
    pushing a uniform variate through the logistic inverse CDF.

    prior_logits: unnormalized log mixture weights, components on the last axis
    means: component locations (expected to share the shape of prior_logits)
    logscales: log of the component scales (expected to share the shape of means)

    returns: samples with the component axis reduced away, in the dtype of means

    The random draws and the one-hot mask are created in the dtype of the
    parameters. tf.random_uniform and tf.one_hot both default to float32, which
    cannot be mixed with float64 tensors, so float64 parameters used to fail
    with a dtype mismatch.
    """
    # No-op for tensors; gives lists / arrays a well-defined dtype to match
    prior_logits = tf.convert_to_tensor(prior_logits)
    means = tf.convert_to_tensor(means)
    logscales = tf.convert_to_tensor(logscales)

    # Sample mixture component (uniforms are kept away from 0 and 1 so the logs stay finite)
    gumbel_u = tf.random_uniform(
        tf.shape(prior_logits), minval=1e-5, maxval=1. - 1e-5, dtype=prior_logits.dtype)
    sampled_inds = tf.argmax(prior_logits - tf.log(-tf.log(gumbel_u)), axis=-1)
    sampled_onehot = tf.one_hot(sampled_inds, tf.shape(prior_logits)[-1], dtype=means.dtype)
    # Pull out the sampled mixture component
    means = tf.reduce_sum(means * sampled_onehot, axis=-1)
    logscales = tf.reduce_sum(logscales * sampled_onehot, axis=-1)
    # Sample from the component: mean + scale * logit(u)
    u = tf.random_uniform(tf.shape(means), minval=1e-5, maxval=1. - 1e-5, dtype=means.dtype)
    x = means + tf.exp(logscales) * (tf.log(u) - tf.log(1. - u))
    return x


def assert_in_range(x, *, min, max):
    """Return a tf.Assert op checking that every element of x lies in [min, max]."""
    smallest_ok = tf.greater_equal(tf.reduce_min(x), min)
    largest_ok = tf.less_equal(tf.reduce_max(x), max)
    within_bounds = tf.logical_and(smallest_ok, largest_ok)
    # x is passed as the data argument of the assertion
    return tf.Assert(within_bounds, [x])


def mixlogistic_invcdf(*, y, prior_logits, means, logscales, tol=1e-10, max_bisection_iters=500):
    """
    inverse cumulative distribution function of a mixture of logistics

    Finds x with mixture_cdf(x) = y by bisection, run elementwise inside a
    tf.while_loop that is built with back_prop=False.

    y: target cdf values; a runtime assertion checks that they lie in [0, 1]
    prior_logits, means, logscales: mixture parameters, each with exactly one
        more axis than y (passed unchanged to mixlogistic_logcdf)
    tol: the loop continues while the largest absolute change of x between two
        consecutive iterations exceeds tol
    max_bisection_iters: upper bound on the number of loop iterations

    returns: tensor whose static shape is asserted to equal that of y
    """
    assert len(y.shape) + 1 == len(prior_logits.shape) == len(means.shape) == len(logscales.shape)
    dtype = y.dtype
    # Re-create y under a control dependency on the range assertion, so the y
    # used below is tied to the [0, 1] check
    with tf.control_dependencies([assert_in_range(y, min=0., max=1.)]):
        y = tf.identity(y)

    def body(x, lb, ub, _last_diff):
        # One bisection step on the bracket [lb, ub] with current guess x
        cur_y = tf.exp(mixlogistic_logcdf(x=x, prior_logits=prior_logits, means=means, logscales=logscales))
        # gt is 1 where the cdf at x is above the target, lt is 1 elsewhere;
        # both are used as multiplicative masks
        gt = tf.cast(tf.greater(cur_y, y), dtype=dtype)
        lt = 1 - gt
        # Where gt: move to the midpoint of (lb, x) and make x the new upper bound.
        # Where lt: move to the midpoint of (x, ub) and make x the new lower bound.
        new_x = gt * (x + lb) / 2. + lt * (x + ub) / 2.
        new_lb = gt * lb + lt * x
        new_ub = gt * x + lt * ub
        # Largest step over all elements; this is what the loop condition tests
        diff = tf.reduce_max(tf.abs(new_x - x))
        return new_x, new_lb, new_ub, diff

    # Initial guess is 0 everywhere
    init_x = tf.zeros_like(y)
    # NOTE: despite its name, maxscales is the sum (not the max) of the component scales
    maxscales = tf.reduce_sum(tf.exp(logscales), axis=-1, keepdims=True)  # sum of scales across mixture components
    # Initial bracket: 50 summed-scales below the smallest / above the largest component mean
    init_lb = tf.reduce_min(means - 50 * maxscales, axis=-1)
    init_ub = tf.reduce_max(means + 50 * maxscales, axis=-1)
    # Start at infinity so the loop condition holds before the first iteration
    init_diff = tf.constant(np.inf, dtype=dtype)

    out_x, _, _, _ = tf.while_loop(
        cond=lambda _x, _lb, _ub, last_diff: last_diff > tol,
        body=body,
        loop_vars=(init_x, init_lb, init_ub, init_diff),
        back_prop=False,
        maximum_iterations=max_bisection_iters
    )
    assert out_x.shape == y.shape
    return out_x


def test_mixlogistic_invcdf():
    """Round-trip check: mixlogistic_invcdf applied to the mixture cdf of x should recover x."""
    tf.reset_default_graph()

    float_type = tf.float64
    num_points = 100
    num_components = 3

    # Static shapes are used for all placeholders
    x_ph = tf.placeholder(float_type, [num_points])
    prior_logits_ph = tf.placeholder(float_type, [num_points, num_components])
    means_ph = tf.placeholder(float_type, [num_points, num_components])
    logscales_ph = tf.placeholder(float_type, [num_points, num_components])
    mixture_kwargs = dict(prior_logits=prior_logits_ph, means=means_ph, logscales=logscales_ph)

    logcdf_op = mixlogistic_logcdf(x=x_ph, **mixture_kwargs)
    recovered_op = mixlogistic_invcdf(y=tf.exp(logcdf_op), **mixture_kwargs)
    assert recovered_op.shape == x_ph.shape
    max_err_op = tf.reduce_max(tf.abs(recovered_op - x_ph))

    half_width = 30
    grid = np.linspace(-half_width, half_width, num_points)
    # Every grid point is given the same mixture parameters
    feed = {
        x_ph: grid,
        prior_logits_ph: [[.1, .2, 4]] * num_points,
        means_ph: [[-1., 0., 1]] * num_points,
        logscales_ph: [[-5., 0., 0.2]] * num_points,
    }

    with tf.Session() as sess:
        max_err, fed_x, recovered_x = sess.run([max_err_op, x_ph, recovered_op], feed)
        # Columns: input x, recovered x, absolute error
        print(np.c_[fed_x, recovered_x, np.abs(fed_x - recovered_x)])
        print(max_err)
        assert max_err < 1e-5
        print('ok')


# Running this file as a script executes only the inverse-CDF round-trip test;
# test_logistic and test_logistic_mixture are not invoked from here.
if __name__ == '__main__':
    test_mixlogistic_invcdf()