'''
MMD estimators for Theano variables.
Each variant returns a pair (mmd2, objective_to_maximize).

To compile and call one, do e.g. (assuming `import theano`, with X and Y
numpy arrays):

    Xth, Yth = T.matrices('X', 'Y')
    sigmath = T.scalar('sigma')
    fn = theano.function([Xth, Yth, sigmath],
                         rbf_mmd2_and_ratio(Xth, Yth, sigma=sigmath))

    mmd2, ratio = fn(X, Y, 1)
'''
from __future__ import division
import numpy as np
import theano.tensor as T
from theano.tensor import slinalg

_eps = 1e-8

################################################################################
### Quadratic-time MMD with Gaussian RBF kernel
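#
# With samples X = (x_1, ..., x_m), Y = (y_1, ..., y_n) and kernel k, the
# biased estimator computed below is
#     MMD^2_b = mean_{i,j} k(x_i, x_j) + mean_{i,j} k(y_i, y_j)
#             - 2 mean_{i,j} k(x_i, y_j),
# and the unbiased one replaces the first two terms by their off-diagonal
# means.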

def rbf_mmd2(X, Y, sigma=1, biased=True):
    # sigma is the RBF kernel bandwidth; it can be a Theano scalar
    gamma = 1 / (2 * sigma**2)

    XX = T.dot(X, X.T)
    XY = T.dot(X, Y.T)
    YY = T.dot(Y, Y.T)

    # using ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
    X_sqnorms = T.diagonal(XX)
    Y_sqnorms = T.diagonal(YY)

    K_XY = T.exp(-gamma * (
            -2 * XY + X_sqnorms[:, np.newaxis] + Y_sqnorms[np.newaxis, :]))
    K_XX = T.exp(-gamma * (
            -2 * XX + X_sqnorms[:, np.newaxis] + X_sqnorms[np.newaxis, :]))
    K_YY = T.exp(-gamma * (
            -2 * YY + Y_sqnorms[:, np.newaxis] + Y_sqnorms[np.newaxis, :]))

    if biased:
        mmd2 = K_XX.mean() + K_YY.mean() - 2 * K_XY.mean()
    else:
        m = K_XX.shape[0]
        n = K_YY.shape[0]

        # subtract the unit diagonals (k(x, x) = 1 for an RBF kernel), then
        # average over the off-diagonal entries
        mmd2 = ((K_XX.sum() - m) / (m * (m - 1))
              + (K_YY.sum() - n) / (n * (n - 1))
              - 2 * K_XY.mean())
    return mmd2, mmd2


def rbf_mmd2_and_ratio(X, Y, sigma=1, biased=True):
    gamma = 1 / (2 * sigma**2)

    XX = T.dot(X, X.T)
    XY = T.dot(X, Y.T)
    YY = T.dot(Y, Y.T)

    X_sqnorms = T.diagonal(XX)
    Y_sqnorms = T.diagonal(YY)

    K_XY = T.exp(-gamma * (
            -2 * XY + X_sqnorms[:, np.newaxis] + Y_sqnorms[np.newaxis, :]))
    K_XX = T.exp(-gamma * (
            -2 * XX + X_sqnorms[:, np.newaxis] + X_sqnorms[np.newaxis, :]))
    K_YY = T.exp(-gamma * (
            -2 * YY + Y_sqnorms[:, np.newaxis] + Y_sqnorms[np.newaxis, :]))

    return _mmd2_and_ratio(K_XX, K_XY, K_YY, unit_diagonal=True, biased=biased)


################################################################################
### Linear-time MMD with Gaussian RBF kernel

# Estimator and the idea of optimizing the ratio from:
#    Gretton, Sriperumbudur, Sejdinovic, Strathmann, and Pontil.
#    Optimal kernel choice for large-scale two-sample tests. NIPS 2012.
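#
# The estimator pairs up consecutive samples and averages
#     h((x_1, y_1), (x_2, y_2)) = k(x_1, x_2) + k(y_1, y_2)
#                               - k(x_1, y_2) - k(x_2, y_1)
# over the ~n/2 disjoint pairs, so each sample is used exactly once.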

def rbf_mmd2_streaming(X, Y, sigma=1):
    # n = (T.smallest(X.shape[0], Y.shape[0]) // 2) * 2
    n = (X.shape[0] // 2) * 2
    gamma = 1 / (2 * sigma**2)
    rbf = lambda A, B: T.exp(-gamma * ((A - B) ** 2).sum(axis=1))
    mmd2 = (rbf(X[:n:2], X[1:n:2]) + rbf(Y[:n:2], Y[1:n:2])
          - rbf(X[:n:2], Y[1:n:2]) - rbf(X[1:n:2], Y[:n:2])).mean()
    return mmd2, mmd2


def rbf_mmd2_streaming_and_ratio(X, Y, sigma=1):
    # n = (T.smallest(X.shape[0], Y.shape[0]) // 2) * 2
    n = (X.shape[0] // 2) * 2
    gamma = 1 / (2 * sigma**2)
    rbf = lambda A, B: T.exp(-gamma * ((A - B) ** 2).sum(axis=1))
    h_bits = (rbf(X[:n:2], X[1:n:2]) + rbf(Y[:n:2], Y[1:n:2])
            - rbf(X[:n:2], Y[1:n:2]) - rbf(X[1:n:2], Y[:n:2]))

    mmd2 = h_bits.mean()

    # variance is 1/2 E_{v, v'} (h(v) - h(v'))^2;
    # estimate it from the differences between even and odd entries of h_bits
    m = (h_bits.shape[0] // 2) * 2  # largest even number of entries to pair
    approx_var = 1/2 * ((h_bits[:m:2] - h_bits[1:m:2]) ** 2).mean()
    ratio = mmd2 / T.sqrt(T.largest(approx_var, _eps))
    return mmd2, ratio


################################################################################
### MMD with linear kernel

# Hotelling test statistic is from:
#    Jitkrittum, Szabo, Chwialkowski, and Gretton.
#    Interpretable Distribution Features with Maximum Testing Power.
#    NIPS 2016.
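#
# With the linear kernel k(x, y) = <x, y>, the biased MMD^2 estimate reduces
# to ||mean(X) - mean(Y)||^2, which linear_mmd2 computes directly.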

def linear_mmd2(X, Y, biased=True):
    if not biased:
        raise ValueError("Haven't implemented unbiased linear_mmd2 yet")
    X_bar = X.mean(axis=0)
    Y_bar = Y.mean(axis=0)
    Z_bar = X_bar - Y_bar
    mmd2 = Z_bar.dot(Z_bar)
    return mmd2, mmd2


def linear_mmd2_and_hotelling(X, Y, biased=True, reg=0):
    if not biased:
        raise ValueError("linear_mmd2_and_hotelling only works for biased est")

    n = X.shape[0]
    p = X.shape[1]
    Z = X - Y  # treats samples as paired; assumes X and Y have the same shape
    Z_bar = Z.mean(axis=0)
    mmd2 = Z_bar.dot(Z_bar)

    Z_cent = Z - Z_bar
    S = Z_cent.T.dot(Z_cent) / (n - 1)
    # z' inv(S) z = z' inv(L L') z = z' inv(L)' inv(L) z = ||inv(L) z||^2
    L = slinalg.cholesky(S + reg * T.eye(p))
    # note: the Cholesky and triangular solve run on the CPU in Theano
    Linv_Z_bar = slinalg.solve_lower_triangular(L, Z_bar)
    lambda_ = n * Linv_Z_bar.dot(Linv_Z_bar)
    return mmd2, lambda_


def linear_mmd2_and_ratio(X, Y, biased=True):
    # TODO: can definitely do this faster for a linear kernel...
    K_XX = T.dot(X, X.T)
    K_XY = T.dot(X, Y.T)
    K_YY = T.dot(Y, Y.T)
    return _mmd2_and_ratio(K_XX, K_XY, K_YY, unit_diagonal=False, biased=biased)


################################################################################
### Helper functions to compute variances based on kernel matrices
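#
# The var_est formula in _mmd2_and_variance estimates the variance of the
# unbiased quadratic-time MMD^2 purely from kernel sums; cf.
#    Sutherland, Tung, Strathmann, De, Ramdas, Smola, and Gretton.
#    Generative Models and Model Criticism via Optimized Maximum Mean
#    Discrepancy. ICLR 2017.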

def _mmd2_and_ratio(K_XX, K_XY, K_YY, unit_diagonal=False, biased=False,
                    min_var_est=_eps):
    mmd2, var_est = _mmd2_and_variance(
        K_XX, K_XY, K_YY, unit_diagonal=unit_diagonal, biased=biased)
    ratio = mmd2 / T.sqrt(T.largest(var_est, min_var_est))
    return mmd2, ratio


def _mmd2_and_variance(K_XX, K_XY, K_YY, unit_diagonal=False, biased=False):
    m = K_XX.shape[0]  # assumes X and Y have the same number of samples

    ### Get the various sums of kernels that we'll use
    # The Kt ("K-tilde") matrices drop the diagonal, but we never need to
    # form them explicitly; only their sums are needed.
    if unit_diagonal:
        diag_X = diag_Y = 1
        sum_diag_X = sum_diag_Y = m
        sum_diag2_X = sum_diag2_Y = m
    else:
        diag_X = T.diagonal(K_XX)
        diag_Y = T.diagonal(K_YY)

        sum_diag_X = diag_X.sum()
        sum_diag_Y = diag_Y.sum()

        sum_diag2_X = diag_X.dot(diag_X)
        sum_diag2_Y = diag_Y.dot(diag_Y)

    Kt_XX_sums = K_XX.sum(axis=1) - diag_X
    Kt_YY_sums = K_YY.sum(axis=1) - diag_Y
    K_XY_sums_0 = K_XY.sum(axis=0)
    K_XY_sums_1 = K_XY.sum(axis=1)

    Kt_XX_sum = Kt_XX_sums.sum()
    Kt_YY_sum = Kt_YY_sums.sum()
    K_XY_sum = K_XY_sums_0.sum()

    # TODO: turn these into dot products?
    # should figure out if that's faster or not on GPU / with theano...
    Kt_XX_2_sum = (K_XX ** 2).sum() - sum_diag2_X
    Kt_YY_2_sum = (K_YY ** 2).sum() - sum_diag2_Y
    K_XY_2_sum  = (K_XY ** 2).sum()

    if biased:
        mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
              + (Kt_YY_sum + sum_diag_Y) / (m * m)
              - 2 * K_XY_sum / (m * m))
    else:
        mmd2 = (Kt_XX_sum / (m * (m-1))
              + Kt_YY_sum / (m * (m-1))
              - 2 * K_XY_sum / (m * m))

    var_est = (
          2 / (m**2 * (m-1)**2) * (
              2 * Kt_XX_sums.dot(Kt_XX_sums) - Kt_XX_2_sum
            + 2 * Kt_YY_sums.dot(Kt_YY_sums) - Kt_YY_2_sum)
        - (4*m-6) / (m**3 * (m-1)**3) * (Kt_XX_sum**2 + Kt_YY_sum**2)
        + 4*(m-2) / (m**3 * (m-1)**2) * (
              K_XY_sums_1.dot(K_XY_sums_1)
            + K_XY_sums_0.dot(K_XY_sums_0))
        - 4 * (m-3) / (m**3 * (m-1)**2) * K_XY_2_sum
        - (8*m - 12) / (m**5 * (m-1)) * K_XY_sum**2
        + 8 / (m**3 * (m-1)) * (
              1/m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
            - Kt_XX_sums.dot(K_XY_sums_1)
            - Kt_YY_sums.dot(K_XY_sums_0))
    )

    return mmd2, var_est
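

if __name__ == '__main__':
    # Quick smoke test (a sketch, not part of the API): compile
    # rbf_mmd2_and_ratio as in the module docstring and evaluate it on toy
    # data from two different Gaussians, where mmd2 should be clearly
    # positive.
    import theano
    rng = np.random.RandomState(0)
    X = rng.randn(100, 3).astype(theano.config.floatX)
    Y = (rng.randn(100, 3) + 1).astype(theano.config.floatX)

    Xth, Yth = T.matrices('X', 'Y')
    sigmath = T.scalar('sigma')
    fn = theano.function([Xth, Yth, sigmath],
                         rbf_mmd2_and_ratio(Xth, Yth, sigma=sigmath))
    mmd2, ratio = fn(X, Y, 1)
    print('mmd2: {}, ratio: {}'.format(mmd2, ratio))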