"""
Support for cftime axis in matplotlib.

"""

from __future__ import (absolute_import, division, print_function)
from six.moves import (filter, input, map, range, zip)  # noqa

from collections import namedtuple

import matplotlib.dates as mdates
import matplotlib.ticker as mticker
import matplotlib.transforms as mtransforms
import matplotlib.units as munits
import cftime
import numpy as np

# Define __version__ based on versioneer's interpretation.
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions


# Lower and upper are in number of days.
FormatOption = namedtuple('FormatOption', ['lower', 'upper', 'format_string'])


class CalendarDateTime(object):
    """
    Container for :class:`cftime.datetime` object and calendar.

    """
    def __init__(self, datetime, calendar):
        self.datetime = datetime
        self.calendar = calendar

    def __eq__(self, other):
        return (isinstance(other, self.__class__) and
                self.datetime == other.datetime and
                self.calendar == other.calendar)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        msg = '<{}: datetime={}, calendar={}>'
        return msg.format(type(self).__name__, self.datetime, self.calendar)


class NetCDFTimeDateFormatter(mticker.Formatter):
    """
    Formatter for cftime.datetime data.

    """
    # Some magic numbers. These seem to work pretty well.
    format_options = [FormatOption(0.0, 0.2, '%H:%M:%S'),
                      FormatOption(0.2, 0.8, '%H:%M'),
                      FormatOption(0.8, 15, '%Y-%m-%d %H:%M'),
                      FormatOption(15, 90, '%Y-%m-%d'),
                      FormatOption(90, 900, '%Y-%m'),
                      FormatOption(900, 6000000, '%Y')]

    def __init__(self, locator, calendar, time_units):
        #: The locator associated with this formatter. This is used to get hold
        #: of the scaling information.
        self.locator = locator
        self.calendar = calendar
        self.time_units = time_units

    def pick_format(self, ndays):
        """
        Returns a format string for an interval of the given number of days.

        """
        for option in self.format_options:
            if option.lower < ndays <= option.upper:
                return option.format_string
        else:
            msg = 'No formatter found for an interval of {} days.'
            raise ValueError(msg.format(ndays))

    def __call__(self, x, pos=0):
        format_string = self.pick_format(ndays=self.locator.ndays)
        dt = cftime.utime(self.time_units, self.calendar).num2date(x)
        return dt.strftime(format_string)


class NetCDFTimeDateLocator(mticker.Locator):
    """
    Determines tick locations when plotting cftime.datetime data.

    """
    def __init__(self, max_n_ticks, calendar, date_unit, min_n_ticks=3):
        # The date unit must be in the form of days since ...

        self.max_n_ticks = max_n_ticks
        self.min_n_ticks = min_n_ticks
        self._max_n_locator = mticker.MaxNLocator(max_n_ticks, integer=True)
        self._max_n_locator_days = mticker.MaxNLocator(
            max_n_ticks, integer=True, steps=[1, 2, 4, 7, 10])
        self.calendar = calendar
        self.date_unit = date_unit
        if not self.date_unit.lower().startswith('days since'):
            msg = 'The date unit must be days since for a NetCDF time locator.'
            raise ValueError(msg)

        self._cached_resolution = {}

    def compute_resolution(self, num1, num2, date1, date2):
        """
        Returns the resolution of the dates (hourly, minutely, yearly), and
        an **approximate** number of those units.

        """
        num_days = float(np.abs(num1 - num2))
        resolution = 'SECONDLY'
        n = mdates.SEC_PER_DAY
        if num_days * mdates.MINUTES_PER_DAY > self.max_n_ticks:
            resolution = 'MINUTELY'
            n = int(num_days / mdates.MINUTES_PER_DAY)
        if num_days * mdates.HOURS_PER_DAY > self.max_n_ticks:
            resolution = 'HOURLY'
            n = int(num_days / mdates.HOURS_PER_DAY)
        if num_days > self.max_n_ticks:
            resolution = 'DAILY'
            n = int(num_days)
        if num_days > 30 * self.max_n_ticks:
            resolution = 'MONTHLY'
            n = num_days // 30
        if num_days > 365 * self.max_n_ticks:
            resolution = 'YEARLY'
            n = abs(date1.year - date2.year)

        return resolution, n

    def __call__(self):
        vmin, vmax = self.axis.get_view_interval()
        return self.tick_values(vmin, vmax)

    def tick_values(self, vmin, vmax):
        vmin, vmax = mtransforms.nonsingular(vmin, vmax, expander=1e-7,
                                             tiny=1e-13)

        self.ndays = float(abs(vmax - vmin))

        utime = cftime.utime(self.date_unit, self.calendar)
        lower = utime.num2date(vmin)
        upper = utime.num2date(vmax)

        resolution, n = self.compute_resolution(vmin, vmax, lower, upper)

        if resolution == 'YEARLY':
            # TODO START AT THE BEGINNING OF A DECADE/CENTURY/MILLENIUM as
            # appropriate.
            years = self._max_n_locator.tick_values(lower.year, upper.year)
            ticks = [cftime.datetime(int(year), 1, 1) for year in years]
        elif resolution == 'MONTHLY':
            # TODO START AT THE BEGINNING OF A DECADE/CENTURY/MILLENIUM as
            # appropriate.
            months_offset = self._max_n_locator.tick_values(0, n)
            ticks = []
            for offset in months_offset:
                year = lower.year + np.floor((lower.month + offset) / 12)
                month = ((lower.month + offset) % 12) + 1
                ticks.append(cftime.datetime(int(year), int(month), 1))
        elif resolution == 'DAILY':
            # TODO: It would be great if this favoured multiples of 7.
            days = self._max_n_locator_days.tick_values(vmin, vmax)
            ticks = [utime.num2date(dt) for dt in days]
        elif resolution == 'HOURLY':
            hour_unit = 'hours since 2000-01-01'
            hour_utime = cftime.utime(hour_unit, self.calendar)
            in_hours = hour_utime.date2num([lower, upper])
            hours = self._max_n_locator.tick_values(in_hours[0], in_hours[1])
            ticks = [hour_utime.num2date(dt) for dt in hours]
        elif resolution == 'MINUTELY':
            minute_unit = 'minutes since 2000-01-01'
            minute_utime = cftime.utime(minute_unit, self.calendar)
            in_minutes = minute_utime.date2num([lower, upper])
            minutes = self._max_n_locator.tick_values(in_minutes[0],
                                                      in_minutes[1])
            ticks = [minute_utime.num2date(dt) for dt in minutes]
        elif resolution == 'SECONDLY':
            second_unit = 'seconds since 2000-01-01'
            second_utime = cftime.utime(second_unit, self.calendar)
            in_seconds = second_utime.date2num([lower, upper])
            seconds = self._max_n_locator.tick_values(in_seconds[0],
                                                      in_seconds[1])
            ticks = [second_utime.num2date(dt) for dt in seconds]
        else:
            msg = 'Resolution {} not implemented yet.'.format(resolution)
            raise ValueError(msg)

        return utime.date2num(ticks)


class NetCDFTimeConverter(mdates.DateConverter):
    """
    Converter for cftime.datetime data.

    """
    standard_unit = 'days since 2000-01-01'

    @staticmethod
    def axisinfo(unit, axis):
        """
        Returns the :class:`~matplotlib.units.AxisInfo` for *unit*.

        *unit* is a tzinfo instance or None.
        The *axis* argument is required but not used.
        """
        calendar, date_unit, date_type = unit

        majloc = NetCDFTimeDateLocator(4, calendar=calendar,
                                       date_unit=date_unit)
        majfmt = NetCDFTimeDateFormatter(majloc, calendar=calendar,
                                         time_units=date_unit)
        if date_type is CalendarDateTime:
            datemin = CalendarDateTime(cftime.datetime(2000, 1, 1),
                                       calendar=calendar)
            datemax = CalendarDateTime(cftime.datetime(2010, 1, 1),
                                       calendar=calendar)
        else:
            datemin = date_type(2000, 1, 1)
            datemax = date_type(2010, 1, 1)
        return munits.AxisInfo(majloc=majloc, majfmt=majfmt, label='',
                               default_limits=(datemin, datemax))

    @classmethod
    def default_units(cls, sample_point, axis):
        """
        Computes some units for the given data point.

        """
        if hasattr(sample_point, '__iter__'):
            # Deal with nD `sample_point` arrays.
            if isinstance(sample_point, np.ndarray):
                sample_point = sample_point.reshape(-1)
            calendars = np.array([point.calendar for point in sample_point])
            if np.all(calendars == calendars[0]):
                calendar = calendars[0]
            else:
                raise ValueError('Calendar units are not all equal.')
            date_type = type(sample_point[0])
        else:
            # Deal with a single `sample_point` value.
            if not hasattr(sample_point, 'calendar'):
                msg = ('Expecting cftimes with an extra '
                       '"calendar" attribute.')
                raise ValueError(msg)
            else:
                calendar = sample_point.calendar
            date_type = type(sample_point)
        return calendar, cls.standard_unit, date_type

    @classmethod
    def convert(cls, value, unit, axis):
        """
        Converts value, if it is not already a number or sequence of numbers,
        with :func:`cftime.utime().date2num`.

        """
        shape = None
        if isinstance(value, np.ndarray):
            # Don't do anything with numeric types.
            if value.dtype != np.object:
                return value
            shape = value.shape
            value = value.reshape(-1)
            first_value = value[0]
        else:
            # Don't do anything with numeric types.
            if munits.ConversionInterface.is_numlike(value):
                return value
            first_value = value

        if not isinstance(first_value, (CalendarDateTime, cftime.datetime)):
            raise ValueError('The values must be numbers or instances of '
                             '"nc_time_axis.CalendarDateTime" or '
                             '"cftime.datetime".')

        if isinstance(first_value, CalendarDateTime):
            if not isinstance(first_value.datetime, cftime.datetime):
                raise ValueError('The datetime attribute of the '
                                 'CalendarDateTime object must be of type '
                                 '`cftime.datetime`.')

        ut = cftime.utime(cls.standard_unit, calendar=first_value.calendar)

        if isinstance(value, (CalendarDateTime, cftime.datetime)):
            value = [value]

        if isinstance(first_value, CalendarDateTime):
            result = ut.date2num([v.datetime for v in value])
        else:
            result = ut.date2num(value)

        if shape is not None:
            result = result.reshape(shape)

        return result


# Automatically register NetCDFTimeConverter with matplotlib.unit's converter
# dictionary.
if CalendarDateTime not in munits.registry:
    munits.registry[CalendarDateTime] = NetCDFTimeConverter()

CFTIME_TYPES = [cftime.DatetimeNoLeap, cftime.DatetimeAllLeap,
                cftime.DatetimeProlepticGregorian, cftime.DatetimeGregorian,
                cftime.Datetime360Day, cftime.DatetimeJulian]
for date_type in CFTIME_TYPES:
    if date_type not in munits.registry:
        munits.registry[date_type] = NetCDFTimeConverter()