#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import sys

# Hack so you don't have to put the library containing this script in the PYTHONPATH.
sys.path = [os.path.abspath(os.path.join(__file__, '..', '..'))] + sys.path

from os.path import join as pjoin
import argparse

import pickle
import numpy as np

import smartlearner.utils as smartutils
from convnade.utils import Timer


def buildArgsParser():
    DESCRIPTION = "Generate samples from a Conv Deep NADE model."
    p = argparse.ArgumentParser(description=DESCRIPTION, formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    p.add_argument('experiment', type=str, help='folder where to find a trained ConvDeepNADE model')
    p.add_argument('count', type=int, help='number of samples to generate.')
    p.add_argument('--out', type=str, help='name of the samples file')

    # General parameters (optional)
    p.add_argument('--seed', type=int, help='seed used to generate random numbers. Default: 1234', default=1234)
    p.add_argument('--view', action='store_true', help="show samples.")

    p.add_argument('-v', '--verbose', action='store_true', help='produce verbose output')
    p.add_argument('-f', '--force',  action='store_true', help='permit overwriting')

    return p


def load_model(experiment_path):

    with Timer("Loading model"):
        from convnade import DeepConvNadeUsingLasagne, DeepConvNadeWithResidualUsingLasagne
        from convnade import DeepConvNADE, DeepConvNADEWithResidual

        for model_class in [DeepConvNadeUsingLasagne, DeepConvNadeWithResidualUsingLasagne, DeepConvNADE, DeepConvNADEWithResidual]:
            try:
                model = model_class.create(experiment_path)
                return model
            except Exception as e:
                print (e)
                pass

    raise NameError("No model found!")
    return None


def main():
    parser = buildArgsParser()
    args = parser.parse_args()

    # Load experiments hyperparameters
    try:
        hyperparams = smartutils.load_dict_from_json_file(pjoin(args.experiment, "hyperparams.json"))
    except:
        hyperparams = smartutils.load_dict_from_json_file(pjoin(args.experiment, '..', "hyperparams.json"))

    model = load_model(args.experiment)
    print(str(model))

    with Timer("Generating {} samples from Conv Deep NADE".format(args.count)):
        sample = model.build_sampling_function(seed=args.seed)
        samples, probs = sample(args.count, return_probs=True, ordering_seed=args.seed)

    if args.out is not None:
        outfile = pjoin(args.experiment, args.out)
        with Timer("Saving {0} samples to '{1}'".format(args.count, outfile)):
            np.save(outfile, samples)

    if args.view:
        import pylab as plt
        from convnade import vizu
        if hyperparams["dataset"] == "binarized_mnist":
            image_shape = (28, 28)
        else:
            raise ValueError("Unknown dataset: {0}".format(hyperparams["dataset"]))

        plt.figure()
        data = vizu.concatenate_images(samples, shape=image_shape, border_size=1, clim=(0, 1))
        plt.imshow(data, cmap=plt.cm.gray, interpolation='nearest')
        plt.title("Samples")

        plt.figure()
        data = vizu.concatenate_images(probs, shape=image_shape, border_size=1, clim=(0, 1))
        plt.imshow(data, cmap=plt.cm.gray, interpolation='nearest')
        plt.title("Probs")

        plt.show()

if __name__ == '__main__':
    main()