"""
    Utilities library for evaluation.

Code borrowed from tensorlow-gan library.
We do not claim any ownership on this code and you should refer to the LICENCE
of the tensorflow-gan library.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import PIL

import tensorflow as tf
import data_provider
import tensorflow_gan as tfgan  # tf


def get_activations(get_images_fn, num_batches, get_logits=False):
    """Get Inception activations.

    Use TF-GAN utility to avoid holding images or Inception activations in
    memory all at once.

    Args:
      get_images_fn: A function that takes no arguments and returns images.
      num_batches: The number of batches to fetch at a time.
      get_logits: If `True`, return (logits, pools). Otherwise just return pools.

    Returns:
      1 or 2 Tensors of Inception activations.
    """
    inception_img_sz = tfgan.eval.INCEPTION_DEFAULT_IMAGE_SIZE
    outputs = tfgan.eval.sample_and_run_inception(sample_fn=lambda _: tf.compat.v1.image.resize(get_images_fn(),
                                                                                                [inception_img_sz, inception_img_sz],
                                                                                                method=tf.image.ResizeMethod.BILINEAR),
                                                  sample_inputs=[1.0] * num_batches)
    if get_logits:
        return outputs['logits'], outputs['pool_3']
    else:
        return outputs['pool_3']


def get_activations_from_dataset(image_ds, num_batches, get_logits=False):
    """Get Inception activations.

    Args:
      image_ds: tf.Dataset for images.
      num_batches: The number of batches to fetch at a time.
      get_logits: If `True`, return (logits, pools). Otherwise just return pools.

    Returns:
      1 or 2 Tensors of Inception activations.
    """
    iterator = tf.compat.v1.data.make_one_shot_iterator(image_ds)

    get_images_fn = iterator.get_next
    return get_activations(get_images_fn, num_batches, get_logits)


def get_real_activations(batch_size,
                         num_batches,
                         shuffle_buffer_size=100000,
                         split='validation',
                         get_logits=False):
    """Fetches batches inception pools and images.

    NOTE: This function runs inference on an Inception network, so it would be
    more efficient to run this on GPU or TPU than on CPU.

    Args:
      batch_size: The number of elements in a single minibatch.
      num_batches: The number of batches to fetch at a time.
      shuffle_buffer_size: The number of records to load before shuffling. Larger
          means more likely randomization.
      split: Shuffle if 'train', else deterministic.
      get_logits: If `True`, return (logits, pools). Otherwise just return pools.

    Returns:
      A Tensor of `real_pools` or (`real_logits`, `real_pools`) with batch
      dimension (batch_size * num_batches).
    """
    ds = data_provider.provide_dataset(batch_size, shuffle_buffer_size, split)
    ds = ds.map(lambda img, lbl: img)  # Remove labels.
    return get_activations_from_dataset(ds, num_batches, get_logits)


def print_debug_statistics(image, labels, dbg_messge_prefix, on_tpu):
    """Adds a Print directive to an image tensor which prints debug statistics."""
    if on_tpu:
        # Print operations are not supported on TPUs.
        return image, labels
    image_means = tf.reduce_mean(input_tensor=image, axis=0, keepdims=True)
    image_vars = tf.reduce_mean(
        input_tensor=tf.math.squared_difference(image, image_means),
        axis=0,
        keepdims=True)
    image = tf.compat.v1.Print(
        image, [
            tf.reduce_mean(input_tensor=image_means),
            tf.reduce_mean(input_tensor=image_vars)
        ],
        dbg_messge_prefix + ' mean and average var',
        first_n=1)
    labels = tf.compat.v1.Print(
        labels, [labels, labels.shape],
        dbg_messge_prefix + ' sparse labels',
        first_n=2)
    return image, labels


def log_and_summarize_variables(var_list, dbg_messge, on_tpu):
    """Logs given variables, summarizes sigma_ratio_vars."""
    tf.compat.v1.logging.info(dbg_messge + str(var_list))
    sigma_ratio_vars = [var for var in var_list if 'sigma_ratio' in var.name]
    tf.compat.v1.logging.info('sigma_ratio_vars %s %s', dbg_messge,
                              sigma_ratio_vars)
    # Reset the name scope so the summary names are displayed as passed to the
    # summary function.
    if not on_tpu:
        # The TPU estimator doesn't support summaries.
        with tf.compat.v1.name_scope(name=None):
            for var in sigma_ratio_vars:
                tf.compat.v1.summary.scalar(
                    'sigma_ratio_vars/' + var.name, var)


def predict_and_write_images(estimator, input_fn, model_dir, filename_suffix):
    """Generates images and write them to the model dir.

    Args:
      estimator: An object of type tfgan.estimator.GANEstimator or
        tfgan.estimator.TPUGANEstimator for performing the predictions.
      input_fn: An input_fn function to be used by `estimator.predict`.
      model_dir: The model directory (the images will be saved inside an 'images'
        subdirectory).
      filename_suffix: A suffix to append to the image file names.
    """
    # Generate images.
    image_iterator = estimator.predict(input_fn)
    if isinstance(estimator, tfgan.estimator.TPUGANEstimator):
        predictions = np.array(
            [next(image_iterator)['generated_data'] for _ in range(16)])
    else:
        predictions = np.array([next(image_iterator) for _ in range(16)])
    # Write images to disk.
    output_dir = os.path.join(model_dir, 'images')
    if not tf.io.gfile.exists(output_dir):
        tf.io.gfile.makedirs(output_dir)
    # Generate a grid of images and write it to disk.
    image_grid = tfgan.eval.python_image_grid(predictions, grid_shape=(4, 4))
    grid_fname = os.path.join(output_dir, 'grid_%s.png' % filename_suffix)
    _write_image_to_disk(image_grid, grid_fname)


def _write_image_to_disk(image, filename):
    with tf.io.gfile.GFile(filename, 'w') as f:
        # Convert tiled_image from float32 in [-1, 1] to unit8 [0, 255].
        img_np = (255 / 2.0) * (image + 1.0)
        pil_image = PIL.Image.fromarray(img_np.astype(np.uint8))
        pil_image.convert('RGB').save(f, 'PNG')
    tf.compat.v1.logging.info('Wrote output to: %s', filename)