"""
Script for generating the product-of-experts (PoE) image-patch sample figure.
"""
import numpy as np
import matplotlib
matplotlib.use('Agg')  # no displayed figures -- need to call before loading pylab
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import itertools
from scipy.sparse import rand

from mjhmc.misc.distributions import ProductOfT
from mjhmc.samplers.markov_jump_hmc import MarkovJumpHMC, ControlHMC
from mjhmc.samplers.hmc_state import HMCState

from nuts import nuts6

# Interactive mode; figures are never displayed here anyway because the Agg
# backend is selected at import time (see the matplotlib.use call).
plt.ion()


# Seed the global numpy RNG at import time, presumably so the random PoE
# parameters drawn in generate_figure_samples are reproducible.
np.random.seed(1234)

# Keyword arguments forwarded to the MarkovJumpHMC and ControlHMC
# constructors in generate_figure_samples.
mjhmc_params = dict(epsilon=0.127, beta=0.01, num_leapfrog_steps=1)
control_params = dict(epsilon=0.065, beta=0.01, num_leapfrog_steps=1)

def generate_figure_samples(samples_per_frame, n_frames, burnin = int(1e4)):
    """ Draws product-of-experts samples with MJHMC and ControlHMC and picks out frames

    A ProductOfT distribution over 36 dimensions with 72 experts is built; the
    expert weight matrix is a sparse random matrix concatenated with its own
    negation. MJHMC is run for ``burnin`` steps, then ControlHMC and a fresh
    MJHMC sampler are both started from the final burn-in sample and run for
    ``samples_per_frame * n_frames`` steps each.

    :param samples_per_frame: number of sample steps between each frame
    :param n_frames: number of frames to draw
    :param burnin: number of MJHMC steps run before the compared chains start
    :returns: ``(frames, names, frame_grads)`` where ``frames[i][j]`` is a copy of
      the sample column taken for frame ``j`` of the sampler named ``names[i]``,
      and ``frame_grads[j]`` is the sample index at which frame ``j`` was taken
    :rtype: tuple
    """
    n_samples = samples_per_frame * n_frames
    ndims = 36
    nbasis = 72
    # sample indices at which frames are taken (0, spf, 2*spf, ...)
    frame_grads = [f_idx * samples_per_frame for f_idx in range(n_frames)]

    # Half of the expert weights are sparse random, the other half their negation.
    # Floor division keeps the column count an int under true division as well.
    rand_val = rand(ndims, nbasis // 2, density=0.25).toarray()
    W = np.concatenate([rand_val, -rand_val], axis=1)
    logalpha = np.random.randn(nbasis, 1)
    poe = ProductOfT(nbatch=1, W=W, logalpha=logalpha)

    # NUTS is intentionally not compared: it uses a different number of gradient
    # evaluations for each update step, so it can't be lined up with the other
    # samplers at a fixed number of update steps.

    ## burnin
    print("MJHMC burnin")
    burnin_sampler = MarkovJumpHMC(distribution=poe.reset(), **mjhmc_params)
    burnin_sampler.state = HMCState(poe.Xinit.copy(), burnin_sampler)
    burnin_samples = burnin_sampler.sample(burnin)
    print(burnin_samples.shape)
    # Start from the LAST burn-in sample; the first column is only one step away
    # from the initial state and would make the burn-in pointless.
    x_init = burnin_samples[:, [-1]]

    # control HMC
    print("Control")
    hmc = ControlHMC(distribution=poe.reset(), **control_params)
    hmc.state = HMCState(x_init.copy(), hmc)
    hmc_samples = hmc.sample(n_samples)
    hmc_frames = [hmc_samples[:, idx].copy() for idx in frame_grads]

    # MJHMC
    print("MJHMC")
    mjhmc = MarkovJumpHMC(distribution=poe.reset(), resample=False, **mjhmc_params)
    mjhmc.state = HMCState(x_init.copy(), mjhmc)
    mjhmc_samples = mjhmc.sample(n_samples)
    mjhmc_frames = [mjhmc_samples[:, idx].copy() for idx in frame_grads]

    # Diagnostic counters kept by each sampler, printed side by side (MJHMC, control).
    for counter in ('r_count', 'l_count', 'f_count', 'fl_count'):
        print("%s %s" % (getattr(mjhmc, counter), getattr(hmc, counter)))

    frames = [mjhmc_frames, hmc_frames]
    names = ['MJHMC', 'ControlHMC']
    return frames, names, frame_grads


def plot_imgs(imgs, samp_names, step_nums, vmin = -2, vmax = 2, filename='poe_samples.pdf'):
    """ Draws a grid of image patches (one row per sampler, one column per frame) and saves it

    The first grid row holds the column labels taken from ``step_nums`` and the
    first grid column holds the row labels taken from ``samp_names``.

    :param imgs: nested sequence; ``imgs[samp_i][step_i]`` is a patch whose total
      number of pixels is a perfect square (it is reshaped to a square image)
    :param samp_names: row labels, one per sampler
    :param step_nums: integer column labels, one per frame
    :param vmin: patch value drawn at the dark end of the grey scale
    :param vmax: patch value drawn at the light end of the grey scale
    :param filename: path the figure is written to
    :returns: None
    :rtype: None
    """
    plt.figure(figsize=(5.5, 3.6))

    nrows = len(samp_names) + 1
    ncols = len(step_nums) + 1

    # top-left corner label
    plt.subplot(nrows, ncols, 1)
    plt.axis('off')
    plt.text(0.9, -0.1, "# grads",
             horizontalalignment='right',
             verticalalignment='bottom')

    # column headers
    for step_i, step_num in enumerate(step_nums):
        plt.subplot(nrows, ncols, 2 + step_i)
        plt.axis('off')
        plt.text(0.5, -0.1, "%d" % step_num,
                 horizontalalignment='center',
                 verticalalignment='bottom')

    # row headers
    for samp_i, samp_name in enumerate(samp_names):
        plt.subplot(nrows, ncols, 1 + (samp_i + 1) * ncols)
        plt.axis('off')
        plt.text(0.9, 0.5, samp_name,
                 horizontalalignment='right',
                 verticalalignment='center')

    scale = float(vmax - vmin)
    for samp_i in range(nrows - 1):
        for step_i in range(ncols - 1):
            plt.subplot(nrows, ncols, 2 + step_i + (samp_i + 1) * ncols)

            # float copy so integer patches are handled and the input is never mutated
            ptch = np.asarray(imgs[samp_i][step_i], dtype=float)
            # reshape needs an integer side length, not the float from np.sqrt
            img_w = int(round(np.sqrt(ptch.size)))
            ptch = (ptch.reshape((img_w, img_w)) - vmin) / scale

            # Fix the colour range to [0, 1]; otherwise imshow auto-scales each
            # patch to its own min/max and vmin/vmax would have no effect.
            plt.imshow(ptch, interpolation='nearest', cmap=cm.Greys_r,
                       vmin=0., vmax=1.)
            plt.axis('off')

    plt.savefig(filename)
    plt.close()



def plot_concat_imgs(imgs, border_thickness=2, axis=None, normalize=False):
    """ concatenate the imgs together into one big image separated by borders and plot it

    Images are laid out row-major on a square grid; border pixels are NaN, so
    they are left uncoloured by imshow. Also sets the seaborn style to 'dark'.

    :param imgs: list or array of images. total number of images must be a perfect square and
      images must be square
    :param border_thickness: how many pixels of border between neighbouring images
    :param axis: optional matplotlib axis object to plot on (defaults to the current pyplot axes)
    :param normalize: if True, scale every image to unit L2 norm (all-zero images are left as is)
    :returns: array containing all the images tiled together
    :rtype: array
    """
    sns.set_style('dark')
    assert isinstance(border_thickness, int)
    n_imgs = len(imgs)
    layer_length = int(round(np.sqrt(n_imgs)))
    assert layer_length ** 2 == n_imgs
    assert imgs[0].shape[0] == imgs[0].shape[1]
    if normalize:
        # float dtype so integer input can be divided; new array, input untouched
        imgs = np.array(imgs, dtype=float)
        # divide by the L2 norm (not the squared norm) so each image has norm 1
        norms = np.sqrt(np.sum(imgs ** 2, axis=(1, 2))).reshape(-1, 1, 1)
        norms[norms == 0] = 1.
        imgs = imgs / norms

    img_length = imgs[0].shape[0]
    stride = img_length + border_thickness
    concat_length = layer_length * img_length + (layer_length - 1) * border_thickness
    concat_rf = np.full((concat_length, concat_length), np.nan)
    for x_idx, y_idx in itertools.product(range(layer_length), repeat=2):
        # row-major index into imgs
        flat_idx = x_idx * layer_length + y_idx
        row = x_idx * stride
        col = y_idx * stride
        concat_rf[row:row + img_length, col:col + img_length] = imgs[flat_idx]

    target = axis if axis is not None else plt
    target.imshow(concat_rf, interpolation='none', aspect='auto')
    return concat_rf

def ac_plot(n_samples=5000, **kwargs):
    """ Plots autocorrelation on a 36-dimensional product of experts via ``plot_best``

    :param n_samples: forwarded to ``plot_best`` as ``num_steps``
    :param kwargs: any further keyword arguments, forwarded unchanged to ``plot_best``
    :returns: None
    :rtype: None
    """
    from mjhmc.figures.ac_fig import plot_best

    # Reseed the global RNG, presumably so any random parameters ProductOfT
    # draws are identical on every call.
    np.random.seed(2015)
    n_dims = n_basis = 36
    target = ProductOfT(nbatch=25, ndims=n_dims, nbasis=n_basis)
    plot_best(target, num_steps=n_samples, update_params=False, **kwargs)