import sys

sys.path.append("./..")
import pdb
import os
import scipy as sp
import scipy.stats as st
import scipy.linalg as la
import numpy as np
import pylab as pl
import torch
from torch.autograd import Variable


def _compose(orig, recon):
    _imgo = []
    _imgr = []
    for i in range(orig.shape[0]):
        _imgo.append(orig[i])
    for i in range(orig.shape[0]):
        _imgr.append(recon[i])
    _imgo = sp.concatenate(_imgo, 1)
    _imgr = sp.concatenate(_imgr, 1)
    _rv = sp.concatenate([_imgo, _imgr], 0)
    _rv = sp.clip(_rv, 0, 1)
    return _rv


def _compose_multi(imgs):
    _imgs = []
    for i in range(len(imgs)):
        _imgs.append([])
        for j in range(imgs[i].shape[0]):
            _imgs[i].append(imgs[i][j])
        _imgs[i] = sp.concatenate(_imgs[i], 1)
    _rv = sp.concatenate(_imgs, 0)
    _rv = sp.clip(_rv, 0, 1)
    return _rv


def callback(epoch, val_queue, vae, history, figname, device):
    """Save a diagnostic figure for the VAE to ``figname``.

    The figure contains the training/validation history curves, histograms
    of the two encoder outputs over the validation set, and four panels
    comparing the first 24 validation images with their reconstructions.

    Args:
        epoch: epoch index; history curves are plotted against
            ``1 .. epoch + 1``, so each history series needs ``epoch + 1``
            entries.
        val_queue: iterable of batches whose first element is a tensor fed
            to ``vae.encode``; must also expose ``dataset.Y`` (a tensor).
        vae: model providing ``encode(y)`` (returning two tensors) and
            ``decode(z)``.
        history: mapping with the series ``loss``, ``nll``, ``kld``,
            ``mse`` and their ``*_val`` counterparts.
        figname: output path handed to ``pl.savefig``.
        device: torch device on which the model is evaluated.
    """
    n_show = 24  # number of validation images shown
    per_panel = 6  # images per reconstruction panel

    with torch.no_grad():

        # encode the whole validation set
        zm, zs = [], []
        for data in val_queue:
            y = data[0].to(device)
            _zm, _zs = vae.encode(y)
            zm.append(_zm.detach().cpu().numpy())
            zs.append(_zs.detach().cpu().numpy())
        # NumPy is used directly: the top-level ``scipy`` aliases of NumPy
        # functions were deprecated in SciPy 1.4 and are gone from recent
        # releases.
        zm, zs = np.concatenate(zm, 0), np.concatenate(zs, 0)

        # init fig
        pl.figure(1, figsize=(8, 8))

        # plot history: training curve in black, validation curve in red
        xs = np.arange(1, epoch + 2)
        panels = (("loss", 1), ("nll", 2), ("kld", 5), ("mse", 6))
        for key, slot in panels:
            pl.subplot(4, 4, slot)
            pl.title(key)
            pl.plot(xs, history[key], "k")
            pl.plot(xs, history[key + "_val"], "r")
            if key == "mse":
                pl.ylim(0.0, 0.01)

        # plot hist of zm and zs as line plots over the bin centres
        pl.subplot(4, 4, 13)
        pl.title("Zm")
        counts, edges = np.histogram(zm.ravel(), 30)
        pl.plot(0.5 * (edges[:-1] + edges[1:]), counts, "k")
        pl.subplot(4, 4, 14)
        pl.title("log$_{10}$Zs")
        counts, edges = np.histogram(np.log10(zs.ravel()), 30)
        pl.plot(0.5 * (edges[:-1] + edges[1:]), counts, "k")

        # val reconstructions; axis 1 is moved last for imshow
        # (presumably the channel axis -- confirm against the model)
        z_show = torch.tensor(zm[:n_show]).to(device)
        Rv = vae.decode(z_show).detach().cpu().numpy().transpose((0, 2, 3, 1))
        Yv = val_queue.dataset.Y[:n_show].numpy().transpose((0, 2, 3, 1))

        # make plot: four panels in the right-hand column
        for row, start in enumerate(range(0, n_show, per_panel)):
            stop = start + per_panel
            pl.subplot(4, 2, 2 * (row + 1))
            pl.imshow(_compose(Yv[start:stop], Rv[start:stop]))

        pl.savefig(figname)
        pl.close()


def callback_gppvae0(epoch, history, covs, imgs, ffile):
    """Save a diagnostic figure for the ``gppvae0`` run to ``ffile``.

    The figure shows the ``loss`` and ``mse_out`` histories, the two columns
    of ``history["vs"]``, the ``XX`` and ``WW`` matrices from ``covs``, and
    four panels comparing ``imgs["Yv"]`` (top strip) with ``imgs["Yo"]``
    (bottom strip) for the first 24 images.

    Args:
        epoch: unused; kept so the signature matches the other callbacks.
        history: mapping with the series ``loss``, ``vs`` (rows with at
            least two columns) and ``mse_out``.
        covs: mapping with the 2-D arrays ``XX`` and ``WW``.
        imgs: mapping with the image batches ``Yv`` and ``Yo``.
        ffile: output path handed to ``pl.savefig``.
    """
    # init fig
    pl.figure(1, figsize=(8, 8))
    pl.subplot(4, 4, 1)
    pl.title("loss")
    pl.plot(history["loss"])

    pl.subplot(4, 4, 2)
    pl.title("vars")
    # ``np.array`` replaces the ``scipy.array`` alias, which was deprecated
    # in SciPy 1.4 and is gone from recent releases.
    vs = np.array(history["vs"])
    pl.plot(vs[:, 0], "r")
    pl.plot(vs[:, 1], "k")

    pl.subplot(4, 4, 5)
    pl.title("mse_out")
    pl.plot(history["mse_out"])

    # matrices from ``covs`` drawn on a shared colour scale
    for slot, name in ((9, "XX"), (10, "WW")):
        pl.subplot(4, 4, slot)
        pl.title(name)
        pl.imshow(covs[name], vmin=-0.4, vmax=1)
        pl.colorbar()

    Yv, Rv = imgs["Yv"], imgs["Yo"]

    # make plot: four panels of six images in the right-hand column
    for row, start in enumerate(range(0, 24, 6)):
        stop = start + 6
        pl.subplot(4, 2, 2 * (row + 1))
        pl.imshow(_compose(Yv[start:stop], Rv[start:stop]))

    pl.savefig(ffile)
    pl.close()


def callback_gppvae(epoch, history, covs, imgs, ffile):
    """Save a diagnostic figure for the ``gppvae`` run to ``ffile``.

    The figure shows the ``loss``, ``recon_term``, ``gp_nll``, ``mse_out``
    and ``mse`` / ``mse_val`` histories, the two columns of
    ``history["vs"]``, the ``XX`` and ``WW`` matrices from ``covs``, and
    four panels stacking ``imgs["Yv"]``, ``imgs["Yr"]`` and ``imgs["Yo"]``
    (top to bottom) for the first 24 images.

    Args:
        epoch: unused; kept so the signature matches the other callbacks.
        history: mapping with the series ``loss``, ``vs`` (rows with at
            least two columns), ``recon_term``, ``gp_nll``, ``mse_out``,
            ``mse`` and ``mse_val``.
        covs: mapping with the 2-D arrays ``XX`` and ``WW``.
        imgs: mapping with the image batches ``Yv``, ``Yr`` and ``Yo``.
        ffile: output path handed to ``pl.savefig``.
    """
    # init fig
    pl.figure(1, figsize=(8, 8))
    pl.subplot(4, 4, 1)
    pl.title("loss")
    pl.plot(history["loss"], "k")

    pl.subplot(4, 4, 2)
    pl.title("vars")
    # ``np.array`` replaces the ``scipy.array`` alias, which was deprecated
    # in SciPy 1.4 and is gone from recent releases.
    vs = np.array(history["vs"])
    pl.plot(vs[:, 0], "r")
    pl.plot(vs[:, 1], "k")
    pl.ylim(0, 1)

    pl.subplot(4, 4, 5)
    pl.title("recon_term")
    pl.plot(history["recon_term"], "k")

    pl.subplot(4, 4, 6)
    pl.title("gp_nll")
    pl.plot(history["gp_nll"], "k")

    pl.subplot(4, 4, 9)
    pl.title("mse_out")
    pl.plot(history["mse_out"], "k")
    pl.ylim(0, 0.1)

    # first series in black, ``*_val`` series in red
    pl.subplot(4, 4, 10)
    pl.title("mse")
    pl.plot(history["mse"], "k")
    pl.plot(history["mse_val"], "r")
    pl.ylim(0, 0.01)

    # matrices from ``covs`` drawn on a shared colour scale
    for slot, name in ((13, "XX"), (14, "WW")):
        pl.subplot(4, 4, slot)
        pl.title(name)
        pl.imshow(covs[name], vmin=-0.4, vmax=1)
        pl.colorbar()

    Yv, Yr, Rv = imgs["Yv"], imgs["Yr"], imgs["Yo"]

    # make plot: four panels of six images in the right-hand column
    for row, start in enumerate(range(0, 24, 6)):
        stop = start + 6
        pl.subplot(4, 2, 2 * (row + 1))
        pl.imshow(
            _compose_multi([Yv[start:stop], Yr[start:stop], Rv[start:stop]])
        )

    pl.savefig(ffile)
    pl.close()