from __future__ import print_function
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec


def plot_reconstructions(data, recon_mean, loss, loss_type, epoch, args):

    if args.input_type == 'multinomial':
        # data is already between 0 and 1
        num_classes = 256
        # Find largest class logit
        tmp = recon_mean.view(-1, num_classes, *args.input_size).max(dim=1)[1]
        recon_mean = tmp.float() / (num_classes - 1.)
    if epoch == 1:
        if not os.path.exists(args.snap_dir + 'reconstruction/'):
            os.makedirs(args.snap_dir + 'reconstruction/')
        # VISUALIZATION: plot real images
        plot_images(args, data.data.cpu().numpy()[0:9], args.snap_dir + 'reconstruction/', 'real', size_x=3, size_y=3)
    # VISUALIZATION: plot reconstructions
    if loss_type == 'bpd':
        fname = str(epoch) + '_bpd_%5.3f' % loss
    elif loss_type == 'elbo':
        fname = str(epoch) + '_elbo_%6.4f' % loss
    plot_images(args, recon_mean.data.cpu().numpy()[0:9], args.snap_dir + 'reconstruction/', fname, size_x=3, size_y=3)


def plot_images(args, x_sample, dir, file_name, size_x=3, size_y=3):

    fig = plt.figure(figsize=(size_x, size_y))
    # fig = plt.figure(1)
    gs = gridspec.GridSpec(size_x, size_y)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(x_sample):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        sample = sample.reshape((args.input_size[0], args.input_size[1], args.input_size[2]))
        sample = sample.swapaxes(0, 2)
        sample = sample.swapaxes(0, 1)
        if (args.input_type == 'binary') or (args.input_type in ['multinomial'] and args.input_size[0] == 1):
            sample = sample[:, :, 0]
            plt.imshow(sample, cmap='gray', vmin=0, vmax=1)
        else:
            plt.imshow(sample)

    plt.savefig(dir + file_name + '.png', bbox_inches='tight')
    plt.close(fig)