from cycler import cycler
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np
from PIL import Image
import pymc3 as pm
import theano
import theano.tensor as tt
from tqdm import tqdm
import scipy


def get_rainbow():
    """Creates a rainbow color cycle"""
    return cycler('color', [
        '#FF0000',
        '#FF7F00',
        '#FFFF00',
        '#00FF00',
        '#0000FF',
        '#4B0082',
        '#9400D3',
    ])


def load_image(image_file, mode=None):
    """Load filename into a numpy array, filling in transparency with 0's.

    Parameters
    ----------
    image_file : str
        File to load. Usually works with .jpg and .png.

    Returns
    -------
    numpy.ndarray of resulting image. Has shape (w, h), (w, h, 3), or (w, h, 4)
        if black and white, color, or color with alpha channel, respectively.
    """
    image = Image.open(image_file)
    if mode is None:
        mode = image.mode
    alpha = image.convert('RGBA').split()[-1]
    background = Image.new("RGBA", image.size, (255, 255, 255, 255,))
    background.paste(image, mask=alpha)
    img = np.flipud(np.asarray(background.convert(mode)))
    img = img / 255
    if mode == 'L':  # I don't know how images work, but .png's are inverted
        img = 1 - img
    return img


class ImageLikelihood(theano.Op):
    """
    Custom theano op for turning a 2d intensity matrix into a density
    distribution.
    """

    itypes = [tt.dvector]
    otypes = [tt.dvector]

    def __init__(self, img):
        self.width, self.height = img.shape
        self.density = scipy.interpolate.RectBivariateSpline(
            x=np.arange(self.width),
            y=np.arange(self.height),
            z=img)

    def perform(self, node, inputs, output_storage):
        """Evaluates the density of the image at the given point."""
        x, y = inputs[0]
        if x < 0 or x > self.width or y < 0 or y > self.height:
            output_storage[0][0] = np.array([np.log(0)])
        else:
            output_storage[0][0] = np.log(self.density(x, y))[0]


def sample_grayscale(image, samples=5000, tune=100, nchains=4, threshold=0.2):
    """Run MCMC on a 1 color image. Works best on logos or text.

    Parameters
    ----------
    image : numpy.ndarray
        Image array from `load_image`.  Should have `image.ndims == 2`.

    samples : int
        Number of samples to draw from the image

    tune : int
        Number of tuning steps to take. Note that this adjusts the step size:
            if you want smaller steps, make tune closer to 0.

    nchains : int
        Number of chains to sample with. This will later turn into the number
        of colors in your plot. Note that you get `samples * nchains` of total
        points in your final scatter.

    threshold : float
        Float between 0 and 1. It looks nicer when an image is binarized, and
        this will do that. Use `None` to not binarize. In theory you should get
        fewer samples from lighter areas, but your mileage may vary.

    Returns
    -------
    pymc3.MultiTrace of samples from the image. Each sample is an (x, y) float
        of indices that were sampled, with the variable name 'image'.
    """
    # preprocess
    image_copy = image.copy()
    if threshold != -1:
        image_copy[image < threshold] = 0
        image_copy[image >= threshold] = 1

    # need an active pixel to start on
    active_pixels = np.array(list(zip(*np.where(image_copy == image_copy.max()))))
    idx = np.random.randint(0, len(active_pixels), nchains)
    start = active_pixels[idx]

    with pm.Model():
        pm.DensityDist('image', ImageLikelihood(image_copy), shape=2)
        trace = pm.sample(samples,
                          tune=tune,
                          chains=nchains, step=pm.Metropolis(),
                          start=[{'image': x} for x in start],
                         )
    return trace


def sample_color(image, samples=5000, tune=1000, nchains=4):
    """Run MCMC on a color image. EXPERIMENTAL!

    Parameters
    ----------
    image : numpy.ndarray
        Image array from `load_image`.  Should have `image.ndims == 2`.

    samples : int
        Number of samples to draw from the image

    tune : int
        All chains start at the same spot, so it is good to let them wander
        apart a bit before beginning

    Returns
    -------
    pymc3.MultiTrace of samples from the image. Each sample is an (x, y) float
        of indices that were sampled, with three variables named 'red',
        'green', 'blue'.
    """

    with pm.Model():
        pm.DensityDist('red', ImageLikelihood(image[:, :, 0]), shape=2)
        pm.DensityDist('green', ImageLikelihood(image[:, :, 1]), shape=2)
        pm.DensityDist('blue', ImageLikelihood(image[:, :, 2]), shape=2)

        trace = pm.sample(samples, chains=nchains, tune=tune, step=pm.Metropolis())
    return trace


def plot_multitrace(trace, image, max_size=10, colors=None, **plot_kwargs):
    """Plot an image of the grayscale trace.

    Parameters
    ----------
    trace : pymc3.MultiTrace
        Get this from sample_grayscale

    image : numpy.ndarray
        Image array from `load_image`, used to produce the trace.

    max_size : float
        Used to set the figsize for the image, maintaining the aspect ratio.
        In inches!

    colors : iterable
        You can set custom colors to cycle through! Default is the rainbow.

    plot_kwargs :
        Other keyword arguments passed to the trace plotting. Some useful
        examples are marker='.' in case you sampled lots of points, alpha=0.3
        to add transparency to the points, or linestyle='-', so you can see the
        actual path the chains took.

    Returns
    -------
    (figure, axis)
        The matplotlib figure and axis with the plot
    """
    default_kwargs = {'marker': 'o', 'linestyle': '', 'alpha': 0.4}
    default_kwargs.update(plot_kwargs)
    if colors is None:
        colors = get_rainbow()
    else:
        colors = cycler('color', colors)

    vals = [trace.get_values('image', chains=chain) for chain in trace.chains]

    fig, ax = plt.subplots(figsize=get_figsize(image, max_size))
    ax.set_prop_cycle(colors)
    ax.set_xlim((0, image.shape[1]))
    ax.set_ylim((0, image.shape[0]))
    ax.axis('off')

    for val in vals:
        ax.plot(val[:, 1], val[:, 0], **default_kwargs)
    return fig, ax


def make_gif(trace, image, steps=200, leading_point=True,
             filename='output.gif', max_size=10, interval=30, dpi=20,
             colors=None, **plot_kwargs):
    """Make a gif of the grayscale trace.

    Parameters
    ----------
    trace : pymc3.MultiTrace
        Get this from sample_grayscale

    image : numpy.ndarray
        Image array from `load_image`, used to produce the trace.

    steps : int
        Number of frames in the resulting .gif

    leading_point : bool
        If true, adds a large point at the head of each chain, so you can
        follow the path easier.

    filename : str
        Place to save the resulting .gif to

    max_size : float
        Used to set the figsize for the image, maintaining the aspect ratio.
        In inches!

    interval : int
        How long each frame lasts. Pretty sure this is hundredths of seconds

    dpi : int
        Quality of the resulting .gif Seems like larger values make the gif
        bigger too.

    colors : iterable
        You can set custom colors to cycle through! Default is the rainbow.

    plot_kwargs :
        Other keyword arguments passed to the trace plotting. Some useful
        examples are marker='.' in case you sampled lots of points, alpha=0.3
        to add transparency to the points, or linestyle='-', so you can see the
        actual path the chains took.

    Returns
    -------
    str
        filename where the gif was saved
    """
    default_kwargs = {'marker': 'o', 'linestyle': '', 'alpha': 0.4}
    default_kwargs.update(plot_kwargs)
    if colors is None:
        colors = get_rainbow()
    else:
        colors = cycler('color', colors)

    vals = [trace.get_values('image', chains=chain) for chain in trace.chains]
    intervals = np.linspace(0, vals[0].shape[0] - 1, num=steps + 1, dtype=int)[1:]  # noqa

    fig, ax = plt.subplots(figsize=get_figsize(image, max_size))
    ax.set_prop_cycle(colors)
    ax.set_xlim((0, image.shape[1]))
    ax.set_ylim((0, image.shape[0]))
    ax.axis('off')

    lines, points = [], []
    for _ in vals:
        lines.append(ax.plot([], [], **default_kwargs)[0])
        if leading_point:
            points.append(ax.plot([], [], 'o', c=lines[-1].get_color(), markersize=20)[0])   # noqa
        else:
            points.append(None)

    def update(idx):
        if idx < len(intervals):
            for pts, lns, val in zip(points, lines, vals):
                lns.set_data(val[:intervals[idx], 1], val[:intervals[idx], 0])
                if leading_point:
                    pts.set_data(val[intervals[idx], 1], val[intervals[idx], 0])
        elif idx == len(intervals) and leading_point:
            for pts in points:
                pts.set_data([], [])

        return ax

    anim = FuncAnimation(fig, update, frames=np.arange(steps + 20), interval=interval)   # noqa
    anim.save(filename, dpi=dpi, writer='imagemagick')
    return filename


def get_figsize(image, max_size=10):
    """Helper to scale figures"""
    scale = max_size / max(image.shape)
    return (scale * image.shape[1], scale * image.shape[0])


def _process_image_trace(trace, image, blur):
    """Adds Gaussian blur"""
    w, h = image.shape[:2]
    colors = ('red', 'green', 'blue')
    channels = [np.zeros((w, h)) for color in colors]
    for color, channel in zip(colors, channels):
        for idx in np.array(np.round(trace[color]), dtype=int):
            x, y = idx
            channel[min(x, w - 1), min(y, h - 1)] += 1
    return [scipy.ndimage.filters.gaussian_filter(channel, blur) for channel in channels]   # noqa


def plot_multitrace_color(trace, image, blur=8, channel_max=None):
    """Plot the trace from a color image

    Does additive blending of the three channels using Pillow. Higher `blur`
    make the colors look right, but the image look blurrier.

    Parameters
    ----------
    trace : pymc3.MultiTrace
        Get this from sample_color

    image : numpy.ndarray
        Image array from `load_image`, used to produce the trace.

    blur : float
        Each point only colors in a single pixel, but a gaussian blur makes the
        samples blend well. This typically must be tuned by eye.

    channel_max : list or None
        This is used internally to normalize channels for making a gif

    Returns
    -------
    PIL.Image
        RGB image of the samples
    """
    smoothed = _process_image_trace(trace, image, blur)
    if channel_max is None:
        channel_max = [channel.max() for channel in smoothed]

    pils = []
    for channel, c_max in zip(smoothed, channel_max):
        pils.append(Image.fromarray(np.uint8(255 * np.flipud(channel / c_max))))
    return Image.merge('RGB', pils)


def make_color_gif(trace, image, blur=8, steps=200, max_size=10, filename='output.gif',
                   interval=30, dpi=20):
    """Make a gif of the color trace. SUPER EXPERIMENTAL!

    Tries to grab portions of the trace from


    Parameters
    ----------
    trace : pymc3.MultiTrace
        Get this from sample_grayscale

    image : numpy.ndarray
        Image array from `load_image`, used to produce the trace.

    blur : float
        Each point only colors in a single pixel, but a gaussian blur makes the
        samples blend well.  This typically must be tuned by eye.
    steps : int
        Number of frames in the resulting .gif

    max_size : float
        Used to set the figsize for the image, maintaining the aspect ratio. In
        inches!

    leading_point : bool
        If true, adds a large point at the head of each chain, so you can
        follow the path easier.

    filename : str
        Place to save the resulting .gif to

    interval : int
        How long each frame lasts. Pretty sure this is hundredths of seconds

    dpi : int
        Quality of the resulting .gif Seems like larger values make the gif
        bigger too.

    Returns
    -------
    str
        filename where the gif was saved
    """
    figsize = get_figsize(image, max_size=max_size)

    intervals = np.linspace(0, len(trace) - 1, num=steps + 1, dtype=int)[1:]

    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(np.zeros_like(image))
    ax.axis('off')
    channel_max = [channel.max() for channel in _process_image_trace(trace, image, blur)]  # noqa

    with tqdm(total=steps) as pbar:
        def update(idx):
            color_image = plot_multitrace_color(trace[:intervals[idx]], image, blur=blur,
                                                channel_max=channel_max)
            ax.imshow(color_image)
            pbar.update(1)
            return ax

        anim = FuncAnimation(fig, update, frames=np.arange(steps), interval=interval)  # noqa
        anim.save(filename, dpi=dpi, writer='imagemagick')
    return filename