from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
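
# Evaluation script: restores a trained checkpoint and can run reconstruction,
# generation, and latent-space interpolation over the test split.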

from six.moves import xrange
import numpy as np

from util import log

from model import Model
from input_ops import create_input_ops, check_data_id

import tensorflow as tf
import time
import imageio
import scipy.misc as sm  # imresize was removed in SciPy 1.3; requires scipy < 1.3


class EvalManager(object):
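    """Accumulates per-batch predictions and reports the average L1 loss."""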

    def __init__(self):
        # collection of batches (not flattened)
        self._ids = []
        self._predictions = []
        self._groundtruths = []

    def add_batch(self, id, prediction, groundtruth):
        # For now, store all results as a list of minibatch chunks.
        self._ids.append(id)
        self._predictions.append(prediction)
        self._groundtruths.append(groundtruth)

    def compute_loss(self, pred, gt):
        # Mean absolute error, averaged over every element.
        return np.mean(np.abs(pred - gt))

    def report(self):
        log.info("Computing scores...")
        total_loss = []

        for pred, gt in zip(self._predictions, self._groundtruths):
            total_loss.append(self.compute_loss(pred, gt))
        avg_loss = np.average(total_loss)
        log.infov("Average loss : %.4f", avg_loss)


class Evaler(object):
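    """Restores a trained model and runs reconstruction, generation, and
    interpolation on the given dataset."""
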
    def __init__(self,
                 config,
                 dataset,
                 dataset_train):
        self.config = config
        self.train_dir = config.train_dir
        log.info("self.train_dir = %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        self.dataset = dataset
        self.dataset_train = dataset_train

        check_data_id(dataset, config.data_id)
        _, self.batch = create_input_ops(dataset, self.batch_size,
                                         data_id=config.data_id,
                                         is_training=False,
                                         shuffle=False)

        # --- create model ---
        self.model = Model(config)

        self.global_step = tf.train.get_or_create_global_step(graph=None)
        self.step_op = tf.no_op(name='step_no_op')

        tf.set_random_seed(123)

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = tf.Session(config=session_config)

        # --- checkpoint and monitoring ---
        self.saver = tf.train.Saver(max_to_keep=100)

        self.checkpoint_path = config.checkpoint_path
        if self.checkpoint_path is None and self.train_dir:
            self.checkpoint_path = tf.train.latest_checkpoint(self.train_dir)
        if self.checkpoint_path is None:
            log.warn("No checkpoint is given. Just random initialization :-)")
            self.session.run(tf.global_variables_initializer())
        else:
            log.info("Checkpoint path : %s", self.checkpoint_path)

    def eval_run(self):
        # load checkpoint
        if self.checkpoint_path:
            self.saver.restore(self.session, self.checkpoint_path)
            log.info("Loaded from checkpoint!")

        log.infov("Start Inference and Evaluation")

        log.info("# of testing examples = %d", len(self.dataset))
        length_dataset = len(self.dataset)

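        # One extra step covers the remainder batch; note that if the input
        # queue repeats, a few trailing examples may be evaluated twice.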
        max_steps = int(length_dataset / self.batch_size) + 1

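        # Launch the queue-runner threads that feed the input pipeline.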
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(self.session,
                                               coord=coord, start=True)

        evaler = EvalManager()

        if not (self.config.interpolate or self.config.generate or self.config.reconstruct):
            raise ValueError('Please specify at least one task by passing '
                             '--reconstruct, --generate, or --interpolate.')

        if self.config.reconstruct:
            try:
                for s in xrange(max_steps):
                    step, loss, step_time, batch_chunk, prediction_pred, prediction_gt = \
                        self.run_single_step(self.batch)
                    self.log_step_message(s, loss, step_time)
                    evaler.add_batch(batch_chunk['id'], prediction_pred, prediction_gt)

            except Exception as e:
                coord.request_stop(e)

            evaler.report()
            log.warning('Completed reconstruction.')

        if self.config.generate:
            x = self.generator(self.batch_size)
            img = self.image_grid(x)
            imageio.imwrite('generate_{}.png'.format(self.config.prefix), img)
            log.warning('Completed generation. Generated samples are saved ' +
                        'as generate_{}.png'.format(self.config.prefix))

        if self.config.interpolate:
            x = self.interpolator(self.dataset_train, self.batch_size)
            img = self.image_grid(x)
            imageio.imwrite('interpolate_{}.png'.format(self.config.prefix), img)
            log.warning('Completed interpolation. Interpolated samples are saved ' +
                        'as interpolate_{}.png'.format(self.config.prefix))

        coord.request_stop()
        try:
            coord.join(threads, stop_grace_period_secs=3)
        except RuntimeError as e:
            log.warn(str(e))

        log.infov("Completed evaluation.")

    def generator(self, num):
        z = np.random.randn(num, self.config.data_info[3])
        # Project each latent code onto the unit sphere (normalize per sample,
        # consistent with the per-vector normalization in interpolator()).
        norms = np.sqrt(np.sum(z ** 2, axis=1))
        z = z / norms[:, np.newaxis]
        x_hat = self.session.run(self.model.x_recon, feed_dict={self.model.z: z})
        return x_hat

    def interpolator(self, dataset, bs, num=15):
        transit_num = num - 2
        img = []
        for i in range(num):
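            # Sample a random consecutive pair and walk between their stored
            # latent codes, renormalizing each step onto the unit sphere.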
            idx = np.random.randint(len(dataset.ids)-1)
            img1, z1 = dataset.get_data(dataset.ids[idx])
            img2, z2 = dataset.get_data(dataset.ids[idx+1])
            z = []
            for j in range(transit_num):
                z_int = (z2 - z1) * (j+1) / (transit_num+1) + z1
                z.append(z_int / np.linalg.norm(z_int))
            z = np.stack(z, axis=0)
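            # Pad with zero latents so the feed matches the graph's expected
            # batch size (assumes bs >= transit_num); padded outputs are unused.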
            z_aug = np.concatenate((z, np.zeros((bs-transit_num, z.shape[1]))), axis=0)
            x_hat = self.session.run(self.model.x_recon, feed_dict={self.model.z: z_aug})
            img.append(np.concatenate((np.expand_dims(img1, 0),
                                       x_hat[:transit_num], np.expand_dims(img2, 0))))
        return np.reshape(np.stack(img, axis=0), (num*(transit_num+2),
                                                  img1.shape[0], img1.shape[1], img1.shape[2]))

    def image_grid(self, x, shape=(2048, 2048)):
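        # Tile the first n*n images into an n-by-n grid before resizing.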
        n = int(np.sqrt(x.shape[0]))
        h, w, c = self.config.data_info[0], self.config.data_info[1], self.config.data_info[2]
        I = np.zeros((n*h, n*w, c))
        for i in range(n):
            for j in range(n):
                I[h * i:h * (i+1), w * j:w * (j+1), :] = x[i * n + j]
        if c == 1:
            I = I[:, :, 0]
        return sm.imresize(I, shape)

    def run_single_step(self, batch):
        _start_time = time.time()

        batch_chunk = self.session.run(batch)

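        # Evaluate loss and reconstructions; step_op is a no-op since
        # evaluation performs no parameter updates.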
        [step, loss, all_targets, all_preds, _] = self.session.run(
            [self.global_step, self.model.loss, self.model.x, self.model.x_recon, self.step_op],
            feed_dict=self.model.get_feed_dict(batch_chunk)
        )

        _end_time = time.time()

        return step, loss, (_end_time - _start_time), batch_chunk, all_preds, all_targets

    def log_step_message(self, step, loss, step_time, is_train=False):
        if step_time == 0: step_time = 0.001  # avoid division by zero below
        log_fn = (is_train and log.info or log.infov)
        log_fn((" [{split_mode:5s} step {step:4d}] " +
                "Loss (test): {loss:.5f} " +
                "({sec_per_batch:.3f} sec/batch, {instance_per_sec:.3f} instances/sec) "
                ).format(split_mode=(is_train and 'train' or 'val'),
                         step=step,
                         loss=loss,
                         sec_per_batch=step_time,
                         instance_per_sec=self.batch_size / step_time,
                         )
               )


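# Example invocations (script and checkpoint names are illustrative):
#   python evaler.py --dataset MNIST --train_dir ./train_dir --reconstruct
#   python evaler.py --dataset SVHN --generate --checkpoint_path ./train_dir/model-10000
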
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--prefix', type=str, default='default')
    parser.add_argument('--checkpoint_path', type=str, default=None)
    parser.add_argument('--train_dir', type=str)
    parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST', 'SVHN', 'CIFAR10'])
    parser.add_argument('--reconstruct', action='store_true', default=False)
    parser.add_argument('--generate', action='store_true', default=False)
    parser.add_argument('--interpolate', action='store_true', default=False)
    parser.add_argument('--data_id', nargs='*', default=None)
    config = parser.parse_args()

    if config.dataset == 'MNIST':
        import datasets.mnist as dataset
    elif config.dataset == 'SVHN':
        import datasets.svhn as dataset
    elif config.dataset == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(config.dataset)

    config.conv_info = dataset.get_conv_info()
    config.deconv_info = dataset.get_deconv_info()
    dataset_train, dataset_test = dataset.create_default_splits()

    m, l = dataset_train.get_data(dataset_train.ids[0])
    config.data_info = np.concatenate([np.asarray(m.shape), np.asarray(l.shape)])

    evaler = Evaler(config, dataset_test, dataset_train)

    log.warning("dataset: %s", config.dataset)
    evaler.eval_run()

if __name__ == '__main__':
    main()