from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
from collections import OrderedDict
from six.moves import xrange
import warnings
import logging
import os

# Python and NumPy scalar types grouped under one name so that a value can be
# checked against all of them at once (e.g. with isinstance). Each type is
# listed exactly once.
known_number_types = (int, float, np.float16, np.float32, np.float64,
                      np.int8, np.int16, np.int32, np.int64,
                      np.uint8, np.uint16, np.uint32, np.uint64)


# Parent of the directory that contains this file (two dirname() steps up
# from __file__). The path is not normalized with abspath(), so it is only
# absolute when __file__ itself is.
CLEVERHANS_ROOT = os.path.dirname(os.path.dirname(__file__))


class _ArgsWrapper(object):

    """
    Wrapper that allows attribute access to dictionaries

    The wrapped value is either a dict or any object accepted by ``vars()``
    (its attribute dictionary is used). Reading ``wrapper.key`` returns the
    value stored under ``key``, or None when the key is absent.
    """

    def __init__(self, args):
        """
        :param args: a dict, or an object whose ``vars()`` yields the dict
                     to wrap
        """
        if not isinstance(args, dict):
            args = vars(args)
        self.args = args

    def __getattr__(self, name):
        """
        Look up `name` in the wrapped dictionary.
        :param name: the attribute (dictionary key) being read
        :return: the stored value, or None if the key is missing
        :raises AttributeError: for special (dunder) names and for `args`
                                itself when it has not been set yet
        """
        # __getattr__ only runs after normal lookup has failed. Reaching it
        # for 'args' means the instance is not initialized yet (e.g. while
        # copy/pickle rebuild it); reading self.args here would recurse
        # forever. Special-method probes such as __setstate__ must also fail
        # with AttributeError instead of yielding None, otherwise the
        # copy/pickle protocols try to call None.
        if name == 'args' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(name)
        return self.args.get(name)


class AccuracyReport(object):

    """
    Plain record of accuracy figures for experiments that train on clean
    or adversarial examples and then evaluate on clean or adversarial
    examples. Each field is named ``<train>_train_<eval>_eval`` and starts
    at 0.; the ``train_``-prefixed copies hold the same measurements taken
    on training data.
    """

    def __init__(self):
        base_fields = ('clean_train_clean_eval',
                       'clean_train_adv_eval',
                       'adv_train_clean_eval',
                       'adv_train_adv_eval')

        for field in base_fields:
            setattr(self, field, 0.)

        # Training data accuracy results to be used by tutorials
        for field in base_fields:
            setattr(self, 'train_' + field, 0.)


def batch_indices(batch_nb, data_length, batch_size):
    """
    Compute the slice bounds of one batch inside a dataset.
    :param batch_nb: zero-based index of the batch
    :param data_length: total number of inputs available
    :param batch_size: number of inputs per batch
    :return: (start, end) pair; `end` is exclusive
    """
    start = int(batch_nb * batch_size)
    end = int((batch_nb + 1) * batch_size)

    # A batch that would run past the end of the data is slid backwards so
    # that it stays full, re-using some inputs from the previous batch.
    overshoot = end - data_length
    if overshoot > 0:
        return start - overshoot, end - overshoot

    return start, end


def other_classes(nb_classes, class_ind):
    """
    List every class index of the task except one.
    :param nb_classes: number of classes in the task
    :param class_ind: the class index to leave out
    :return: ascending list of the class indices other than class_ind
    :raises ValueError: if class_ind is negative or not below nb_classes
    """
    out_of_range = class_ind < 0 or class_ind >= nb_classes
    if out_of_range:
        raise ValueError("class_ind must be within the range "
                         "(0, nb_classes - 1)")

    remaining = list(range(nb_classes))
    remaining.remove(class_ind)
    return remaining


def to_categorical(y, num_classes=None):
    """
    Converts a class vector (integers) to binary class matrix.
    This is adapted from the Keras function with the same name.
    :param y: class vector to be converted into a matrix
              (integers from 0 to num_classes - 1). It is flattened
              before conversion.
    :param num_classes: total number of classes. When omitted (or 0), it
                        is inferred as max(y) + 1 and a deprecation
                        warning is emitted.
    :return: A binary matrix representation of the input, as a float array
             of shape (len(y), num_classes).
    """
    y = np.array(y, dtype='int').ravel()
    if not num_classes:
        num_classes = np.max(y) + 1
        # Adjacent literals are concatenated verbatim, so every fragment
        # needs its own trailing space to keep the words apart.
        warnings.warn("FutureWarning: the default value of the second "
                      "argument in function \"to_categorical\" is deprecated. "
                      "On 2018-9-19, the second argument "
                      "will become mandatory.")
    n = y.shape[0]
    categorical = np.zeros((n, num_classes))
    # One fancy-indexing assignment sets column y[i] of row i to 1.
    categorical[np.arange(n), y] = 1
    return categorical


def random_targets(gt, nb_classes):
    """
    Pick, for every correct label, a random label that differs from it.
    This is typically used to choose the target class for targeted
    adversarial example attacks, where the search algorithm is given both
    a source class and a target class.
    :param gt: the ground truth (correct) labels, either as a 1D vector of
               class indices or as a 2D array of one-hot encoded labels.
    :param nb_classes: The number of classes for this task. Each target is
                       drawn from the classes 0 .. nb_classes - 1 other
                       than the correct one.
    :return: A numpy int32 array holding the randomly-selected target
             classes encoded as one-hot labels.
    """
    labels = gt
    # One-hot input is collapsed back to class indices first.
    if len(labels.shape) == 2:
        labels = np.argmax(labels, axis=1)

    # Entries whose label matches no class in range(nb_classes) are never
    # overwritten and therefore keep the initial value 0.
    targets = np.zeros(labels.shape, dtype=np.int32)

    for source_class in xrange(nb_classes):
        # Boolean mask of the samples belonging to this class.
        mask = labels == source_class
        count = np.sum(mask)

        # Any class but the source one is an acceptable target.
        candidates = other_classes(nb_classes, source_class)

        # Independent draws (with replacement), one per masked sample.
        targets[mask] = np.random.choice(candidates, size=count)

    # Back to one-hot, with an integer dtype.
    one_hot = to_categorical(targets, nb_classes)
    return one_hot.astype(np.int32)


def pair_visual(original, adversarial, figure=None):
    """
    This function displays three images side by side: the original, the
    perturbation (adversarial - original) and the adversarial sample
    :param original: the original input
    :param adversarial: the input after perturbations have been applied
    :param figure: if we've already displayed images, use the same plot
    :return: the matplot figure to reuse for future samples
    """
    import matplotlib.pyplot as plt

    # Squeeze the image to remove single-dimensional entries from array shape
    original = np.squeeze(original)
    adversarial = np.squeeze(adversarial)

    # Ensure our inputs are of proper shape
    assert(len(original.shape) == 2 or len(original.shape) == 3)

    # To avoid creating figures per input sample, reuse the sample plot
    if figure is None:
        plt.ion()
        figure = plt.figure()
        # FigureCanvasBase.set_window_title was deprecated in Matplotlib 3.4
        # and removed in 3.6; the figure manager offers the same call on
        # both old and new releases. There may be no manager at all.
        manager = getattr(figure.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title('Cleverhans: Pair Visualization')

    # Add the images to the plot
    perturbations = adversarial - original
    for index, image in enumerate((original, perturbations, adversarial)):
        # Draw through the axes object that was just created so that the
        # images land on `figure` even when it is not pyplot's current one.
        axes = figure.add_subplot(1, 3, index + 1)
        axes.axis('off')

        # If the image is 2D, then we have 1 color channel
        if len(image.shape) == 2:
            axes.imshow(image, cmap='gray')
        else:
            axes.imshow(image)

        # Give the plot some time to update
        plt.pause(0.01)

    # Draw the plot and return
    plt.show()
    return figure


def grid_visual(data):
    """
    This function displays a grid of images to show full misclassification
    :param data: grid data of the form;
        [nb_classes : nb_classes : img_rows : img_cols : nb_channels]
        The first axis selects the grid column and the second axis the
        grid row.
    :return: the matplot figure that was drawn
    """
    import matplotlib.pyplot as plt

    # Ensure interactive mode is disabled and initialize our graph
    plt.ioff()
    figure = plt.figure()
    # FigureCanvasBase.set_window_title was deprecated in Matplotlib 3.4 and
    # removed in 3.6; the figure manager offers the same call on both old
    # and new releases. There may be no manager at all.
    manager = getattr(figure.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title('Cleverhans: Grid Visualization')

    # Add the images to the plot
    num_cols = data.shape[0]
    num_rows = data.shape[1]
    num_channels = data.shape[4]
    for y in xrange(num_rows):
        for x in xrange(num_cols):
            # Subplot positions are 1-based and numbered row by row.
            axes = figure.add_subplot(num_rows, num_cols,
                                      (x + 1) + (y * num_cols))
            axes.axis('off')

            # A single channel is drawn as a grayscale 2D image.
            if num_channels == 1:
                axes.imshow(data[x, y, :, :, 0], cmap='gray')
            else:
                axes.imshow(data[x, y, :, :, :])

    # Draw the plot and return
    plt.show()
    return figure


def get_logits_over_interval(sess, model, x_data, fgsm_params,
                             min_epsilon=-10., max_epsilon=10.,
                             num_points=21):
    """Get logits when the input is perturbed in an interval in adv direction.

    Args:
        sess: Tf session
        model: Model for which we wish to get logits.
        x_data: Numpy array corresponding to single data.
                point of shape [height, width, channels].
        fgsm_params: Parameters for generating adversarial examples.
        min_epsilon: Minimum value of epsilon over the interval.
        max_epsilon: Maximum value of epsilon over the interval.
        num_points: Number of points used to interpolate.

    Returns:
        Numpy array containing logits, as returned by running
        `model.get_logits` on the batch of `num_points` perturbed inputs.

    Raises:
        ValueError if min_epsilon is larger than max_epsilon.
    """
    # Reject an empty interval up front, before any op is added to the graph.
    if min_epsilon > max_epsilon:
        raise ValueError('Minimum epsilon is greater than maximum epsilon')

    import tensorflow as tf
    from cleverhans.attacks import FastGradientMethod

    # Get the height, width and number of channels
    height = x_data.shape[0]
    width = x_data.shape[1]
    channels = x_data.shape[2]

    # Add a leading batch axis of length 1 to match the placeholder.
    x_data = np.expand_dims(x_data, axis=0)

    # Define the data placeholder
    x = tf.placeholder(dtype=tf.float32,
                       shape=[1, height,
                              width,
                              channels],
                       name='x')
    # Define adv_x
    fgsm = FastGradientMethod(model, sess=sess)
    adv_x = fgsm.generate(x, **fgsm_params)

    # NOTE(review): normalization runs along axis 0, which is the length-1
    # batch axis of the placeholder -- confirm this is the intended axis.
    eta = tf.nn.l2_normalize(adv_x - x, dim=0)
    # One epsilon per interpolation point, shaped to broadcast against x.
    epsilon = tf.reshape(tf.lin_space(float(min_epsilon),
                                      float(max_epsilon),
                                      num_points),
                         (num_points, 1, 1, 1))
    lin_batch = x + epsilon * eta
    logits = model.get_logits(lin_batch)
    with sess.as_default():
        log_prob_adv_array = sess.run(logits,
                                      feed_dict={x: x_data})
    return log_prob_adv_array


def linear_extrapolation_plot(log_prob_adv_array, y, file_name,
                              min_epsilon=-10, max_epsilon=10,
                              num_points=21):
    """Generate linear extrapolation plot.

    One curve is drawn per class; the curve of the correct class (the
    argmax of `y`) is solid and thicker, the others are dashed.

    Args:
        log_prob_adv_array: Numpy array containing log probabilities,
            indexed as [point along the epsilon interval, class]
        y: Array of labels with one entry per class (e.g. one-hot); its
            argmax along axis 0 marks the correct class
        file_name: Plot filename
        min_epsilon: Minimum value of epsilon over the interval
        max_epsilon: Maximum value of epsilon over the interval
        num_points: Number of points used to interpolate

    Returns:
        The matplotlib figure that was plotted and saved. It has already
        been cleared by the time it is returned.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # A single figure is titled, plotted on, saved and returned.
    figure = plt.figure()
    # FigureCanvasBase.set_window_title was deprecated in Matplotlib 3.4 and
    # removed in 3.6; the figure manager offers the same call on both old
    # and new releases. There may be no manager at all.
    manager = getattr(figure.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title('Cleverhans: Linear Extrapolation Plot')

    correct_idx = np.argmax(y, axis=0)
    plt.xlabel('Epsilon')
    plt.ylabel('Logits')
    x_axis = np.linspace(min_epsilon, max_epsilon, num_points)
    # Leave a margin of one unit on both sides of the interval.
    plt.xlim(min_epsilon - 1, max_epsilon + 1)
    for i in xrange(y.shape[0]):
        if i == correct_idx:
            ls = '-'
            linewidth = 5
        else:
            ls = '--'
            linewidth = 2
        plt.plot(
            x_axis,
            log_prob_adv_array[:, i],
            ls=ls,
            linewidth=linewidth,
            label='{}'.format(i))
    plt.legend(loc='best', fontsize=14)
    plt.show()
    figure.savefig(file_name)
    plt.clf()
    return figure


def set_log_level(level, name="cleverhans"):
    """
    Change the threshold of a logger (the cleverhans one by default).
    :param level: the new logger threshold. Possible values are listed at
                  https://docs.python.org/2/library/logging.html#levels
    :param name: the name of the logger to adjust
    """
    target = logging.getLogger(name)
    target.setLevel(level)


def get_log_level(name="cleverhans"):
    """
    Report the effective threshold of a logger (the cleverhans one by
    default), i.e. the level inherited from an ancestor when the logger
    has none of its own.
    :param name: the name of the logger to query
    :return: the effective level of that logger
    """
    target = logging.getLogger(name)
    return target.getEffectiveLevel()


class TemporaryLogLevel(object):
    """
    A ContextManager that changes a log level temporarily.

    Note that the log level will be set back to its original value when
    the context manager exits, even if the log level has been changed
    again in the meantime.

    The level that is recorded and restored is the logger's effective
    level. Exceptions raised inside the ``with`` body are not suppressed:
    the level is restored and the exception then propagates to the caller.
    """

    def __init__(self, level, name):
        """
        :param level: the threshold to apply while the context is active
        :param name: the name of the logger whose threshold is changed
        """
        self.name = name
        self.level = level

    def __enter__(self):
        # Remember the current threshold so that __exit__ can restore it.
        self.old_level = get_log_level(self.name)
        set_log_level(self.level, self.name)

    def __exit__(self, type, value, traceback):
        set_log_level(self.old_level, self.name)
        # A true return value would tell the interpreter to swallow any
        # exception raised in the with-body; returning False lets it
        # propagate.
        return False


def create_logger(name):
    """
    Create a logger object with the given name.

    If this is the first time that we call this method, then initialize the
    formatter: a stream handler is attached once to the base "cleverhans"
    logger. Loggers named "cleverhans" or "cleverhans.<something>" are
    returned under their own name and reach that handler through the
    logging hierarchy. Any other name falls back to the base logger, so
    its messages still go through the configured handler.

    :param name: the name of the logger to create
    :return: the logger to use for `name`
    """
    base_name = "cleverhans"
    base = logging.getLogger(base_name)
    if not base.handlers:
        ch = logging.StreamHandler()
        formatter = logging.Formatter('[%(levelname)s %(asctime)s %(name)s] ' +
                                      '%(message)s')
        ch.setFormatter(formatter)
        base.addHandler(ch)

    # Only names inside the "cleverhans" hierarchy propagate to the handler
    # installed above; non-string names are treated as outside of it.
    is_text = hasattr(name, "startswith")
    if is_text and (name == base_name or name.startswith(base_name + ".")):
        return logging.getLogger(name)

    return base


def deterministic_dict(normal_dict):
    """
    Copy `normal_dict` into an OrderedDict whose entries are inserted in
    sorted key order, so that iterating over the result always visits the
    keys in the same order.
    :param normal_dict: the mapping to copy; its keys must be sortable
    :return: an OrderedDict with the same items, ordered by key
    """
    return OrderedDict((key, normal_dict[key])
                       for key in sorted(normal_dict.keys()))


def ordered_union(l1, l2):
    """
    Concatenate `l1` and `l2` and drop every element that is equal to one
    already kept, preserving the order of first occurrence.
    :param l1: first list
    :param l2: second list (must support `l1 + l2`)
    :return: a new list holding each distinct element once
    """
    merged = []
    for element in l1 + l2:
        if element in merged:
            continue
        merged.append(element)
    return merged


def safe_zip(*args):
    """zip, with a guarantee that all arguments are the same length.
    (normal zip silently drops entries to make them the same length)

    :param args: the sized iterables to zip together
    :return: the result of `zip(*args)`
    :raises AssertionError: if the arguments do not all have the same length
    """
    # Like the built-in, zipping nothing yields nothing.
    if not args:
        return zip()

    length = len(args[0])
    # Raised explicitly rather than through an `assert` statement, which is
    # stripped under `python -O` and would silently void the guarantee.
    if not all(len(arg) == length for arg in args):
        raise AssertionError("safe_zip arguments have different lengths: " +
                             str([len(arg) for arg in args]))
    return zip(*args)