import numpy as np
import os
import matplotlib.pyplot as plt


def plot_row_colors(C, fig_size=6, title=None):
    """
    Draw every row of ``C`` as a thick horizontal band of that RGB color.

    Row 0 is drawn at the top, row N-1 at the bottom. If any entry of ``C``
    exceeds 1.0, the whole array is treated as 0-255 values and divided by 255
    before plotting; otherwise the rows are passed to matplotlib unchanged.

    :param C: An array N x 3 where the rows are considered as RGB colors.
    :param fig_size: Width and height written to ``plt.rcParams['figure.figsize']``.
    :param title: Optional plot title (skipped when None).
    :return: None
    """
    assert isinstance(C, np.ndarray), "C must be a numpy array."
    assert C.ndim == 2, "C must be 2D."
    assert C.shape[1] == 3, "C must have 3 columns."

    num_rows = len(C)
    # Cheap heuristic: any value above 1.0 means the colors are in [0, 255].
    colors = C / 255 if C.max() > 1.0 else C
    plt.rcParams['figure.figsize'] = (fig_size, fig_size)
    for row, color in enumerate(colors):
        y = num_rows - 1 - row  # flip so the first row ends up on top
        plt.plot([0, 1], [y, y], c=color, linewidth=20)
    if title is not None:
        plt.title(title)
    plt.axis("off")
    plt.axis([0, 1, -0.5, num_rows - 0.5])


def plot_image(image, show=True, fig_size=10, title=None):
    """
    Plot an image (np.array).
    Caution: Rescales image to be in range [0,1].

    :param image: RGB uint8 (any numeric array-like accepted; converted to float32).
    :param show: plt.show() now?
    :param fig_size: Width and height written to ``plt.rcParams['figure.figsize']``
        (affects figures created afterwards). Pass None to leave rcParams untouched,
        e.g. when drawing into an existing subplot.
    :param title: Image title
    :return:
    """
    image = np.asarray(image).astype(np.float32)
    m, M = image.min(), image.max()
    if fig_size is not None:
        plt.rcParams['figure.figsize'] = (fig_size, fig_size)
    # Min-max rescale into [0, 1]. A constant image would divide 0 by 0 and
    # yield NaNs, so show it as zeros instead.
    value_range = M - m
    if value_range > 0:
        image = (image - m) / value_range
    else:
        image = np.zeros_like(image)
    # Always draw the image; previously this only happened when fig_size was None.
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.axis("off")
    if show:
        plt.show()


def plot_image_list(images, width=5, sub_sample=False, rand=False, save_name=None, title_list=None, show=True):
    """
    Display a grid of images.

    :param images: List of RGB uint8
    :param width: Number of images per row.
    :param sub_sample: Number of images to subsample or false. Values larger than
        ``len(images)`` are clamped to ``len(images)``.
    :param rand: Should the subsample be randomized?
    :param save_name: File name to save to. Missing parent directories are created;
        a bare file name saves to the current directory.
    :param title_list: A list of titles indexed like ``images`` (title_list[k] labels
        images[k]), so titles stay paired with their images when subsampling.
    :param show: plt.show() now?
    :return:
    """
    n_images = len(images)
    if sub_sample:
        n_selected = min(sub_sample, n_images)
        if rand:
            indices = list(np.random.choice(range(n_images), n_selected, replace=False))
        else:
            indices = list(range(n_selected))
    else:
        indices = list(range(n_images))

    height = int(np.ceil(len(indices) / width))
    plt.rcParams['figure.figsize'] = (18, (18 / width) * height)
    plt.figure()

    for slot, idx in enumerate(indices, start=1):
        plt.subplot(height, width, slot)
        # Index by the selected image index (not the grid slot) so random
        # subsampling actually shows the chosen images and their titles.
        title = title_list[idx] if title_list is not None else None
        plot_image(images[idx], show=False, fig_size=None, title=title)

    if save_name is not None:
        directory = os.path.dirname(save_name)
        # os.makedirs("") raises FileNotFoundError, so skip it for bare file names.
        if directory:
            os.makedirs(directory, exist_ok=True)
        plt.savefig(save_name)

    if show:
        plt.show()