"""
Common matplotlib utilities
"""
import uuid
import os

from matplotlib.ticker import FuncFormatter
from matplotlib.dates import DateFormatter
import matplotlib.pyplot as plt
import numpy as np
import pandas

import tia.util.fmt as fmt
from tia.util.decorator import DeferredExecutionMixin


class _CustomDateFormatter(DateFormatter):
    """Extend so I can use with pandas Period objects """

    def __call__(self, x, pos=0):
        if not hasattr(x, 'strftime'):
            x = pandas.to_datetime(x)
        x = x.strftime(self.fmt)
        return x


class _AxisFormat(DeferredExecutionMixin):
    def __init__(self, parent):
        super(_AxisFormat, self).__init__()
        self.parent = parent

    @property
    def X(self):
        """Provide ability for user to switch from X to Y and vice versa"""
        return self.parent.X

    @property
    def Y(self):
        """Provide ability for user to switch from X to Y and vice versa"""
        return self.parent.Y

    @property
    def axes(self):
        return self.parent.axes

    def percent(self, precision=2):
        fct = fmt.new_percent_formatter(precision=precision)
        wrapper = lambda x, pos: fct(x)
        self.axis.set_major_formatter(FuncFormatter(wrapper))
        return self

    def thousands(self, precision=1):
        fct = fmt.new_thousands_formatter(precision=precision)
        wrapper = lambda x, pos: fct(x)
        self.axis.set_major_formatter(FuncFormatter(wrapper))
        return self

    def millions(self, precision=1):
        fct = fmt.new_millions_formatter(precision=precision)
        wrapper = lambda x, pos: fct(x)
        self.axis.set_major_formatter(FuncFormatter(wrapper))
        return self

    def date(self, fmt='%Y-%m-%d'):
        fmtfct = DateFormatter(fmt)
        self.axis.set_major_formatter(fmtfct)
        return self

    def apply_format(self, fmtfct=lambda x: x):
        wrapper = lambda x, pos: fmtfct(x)
        self.axis.set_major_formatter(FuncFormatter(wrapper))
        return self

    def apply(self, axes=None):
        self.parent.apply(axes=axes)


class _YAxisFormat(_AxisFormat):
    @property
    def axis(self):
        return self.axes.yaxis

    def rotate(self, rot=40, ha='right'):
        rotate_labels(self.axes, which='y', rot=rot, ha=ha)
        return self

    def label(self, txt, **kwargs):
        self.axes.set_ylabel(txt, **kwargs)
        return self


class _XAxisFormat(_AxisFormat):
    @property
    def axis(self):
        return self.axes.xaxis

    def rotate(self, rot=40, ha='right'):
        rotate_labels(self.axes, which='x', rot=rot, ha=ha)
        return self

    def label(self, txt, **kwargs):
        self.axes.set_xlabel(txt, **kwargs)
        return self


class AxesFormat(DeferredExecutionMixin):
    def __init__(self):
        super(AxesFormat, self).__init__()
        self.X = _XAxisFormat(self)
        self.Y = _YAxisFormat(self)
        self.axes = None

    def apply(self, axes=None):
        self.axes = axes or plt.gca()
        self.X()
        self.Y()
        self()

    def tight_layout(self, pad=1.08, h_pad=None, w_pad=None, rect=None):
        plt.tight_layout(pad, h_pad, w_pad, rect)
        return self


class FigureHelper(object):
    def __init__(self, basedir=None, ext='.pdf', dpi=None):
        if not basedir:
            import tempfile

            basedir = tempfile.gettempdir()
        self.basedir = basedir
        self.last = None
        self.ext = ext
        self.fnmap = {}

        self.ax = None
        self.axiter = None
        self.figure = None
        self.dpi = dpi or 100

    def keys(self):
        return self.fnmap.keys()

    def next_ax(self):
        self.ax = self.axiter.next()
        return self.ax

    def __getitem__(self, item):
        return self.fnmap[item]

    def savefig(self, fn=None, dpi=None, clear=1, ext=None, key=None):
        ext = ext or self.ext
        ext = ext.startswith('.') and ext or '.' + ext
        fn = fn or uuid.uuid1()
        key = key or ''
        fn = '%s%s%s' % (key, fn, ext)
        fn = os.path.join(self.basedir, fn)

        figure = self.figure
        use_plt = 0
        if figure is None:
            figure = plt.gcf()
            use_plt = 1

        figure.savefig(fn, dpi=dpi or self.dpi)
        if clear:
            use_plt and plt.close() or figure.clf()
        if key:
            self.fnmap[key] = fn
        self.last = fn
        return fn

    def subplots(self, *params, **kwargs):
        f, ax = plt.subplots(*params, **kwargs)

        def axes_iter(axes):
            if not hasattr(axes, '__iter__'):
                return iter(list([axes]))
            else:
                if not hasattr(axes[0], '__iter__'):
                    return iter(axes)
                else:
                    # array of arrays
                    return iter([y for x in axes for y in x])

        self.axiter = axes_iter(ax)
        self.figure = f
        return self.next_ax()


def rotate_labels(ax, which='x', rot=40, ha='right'):
    which = which.upper()

    def _apply(lbls):
        for lbl in lbls:
            lbl.set_ha(ha)
            lbl.set_rotation(rot)

    'X' in which and _apply(ax.get_xticklabels())
    'Y' in which and _apply(ax.get_yticklabels())


class GridHelper(object):
    @staticmethod
    def build(numobjs, ncols, **subplot_kwargs):
        nrows = int(np.ceil(float(numobjs) / float(ncols)))
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, **subplot_kwargs)

        if nrows == 1:
            axes = [axes]
        if ncols == 1:
            axes = [[ax] for ax in axes]
        return GridHelper(axes, nrows, ncols, fig=fig)

    def __init__(self, axarr, nrows, ncols, fig=None):
        self.axarr = axarr
        self.nrows = nrows
        self.ncols = ncols
        self.fig = fig

    def __iter__(self):
        import itertools

        flat = list(itertools.chain.from_iterable(self.axarr))
        return iter(flat)

    def get_axes(self, idx):
        """ Allow for simple indexing """
        cidx = 0
        if idx > 0:
            cidx = idx % self.ncols
        ridx = idx / self.ncols
        return self.axarr[ridx][cidx]

    def get_last_row(self):
        return self.axarr[self.nrows - 1]

    def get_first_col(self):
        """ Return the array of Axes objects for the first column """
        return [ax[0] for ax in self.axarr]