""" Command line interface for running YATSM algorithms on individual pixels
"""
import datetime as dt
import logging
import re

import click
import matplotlib as mpl
import matplotlib.cm  # noqa
import matplotlib.pyplot as plt
import numpy as np
import patsy
import yaml

from . import options, console
from ..algorithms import postprocess
from ..config_parser import convert_config, parse_config_file
from ..io import read_pixel_timeseries
from ..utils import csvfile_to_dataframe, get_image_IDs
from ..regression.transforms import harm  # noqa

# Plot types accepted by ``--plot``
avail_plots = ['TS', 'DOY', 'VAL']

# Seasonal symbology used by ``plot_TS``: name -> (months, color, alpha)
SEASONS = {
    'winter': ((11, 12, 1, 2, 3), 'b', 0.5),
    'spring': ((4, 5), 'c', 0.5),
    'summer': ((6, 7, 8), 'g', 1.0),
    'fall': ((9, 10), 'y', 0.5)
}

# Default colormap: the first of these preferences this matplotlib provides,
# falling back to the last entry.
# ``pyplot.colormaps()`` exists on both old and new matplotlib, whereas
# ``matplotlib.cm.cmap_d`` was deprecated and later removed.
_DEFAULT_PLOT_CMAP = ('viridis', 'cubehelix', 'jet')
_available_cmaps = set(plt.colormaps())
PLOT_CMAP = _DEFAULT_PLOT_CMAP[-1]
for _cmap in _DEFAULT_PLOT_CMAP:
    if _cmap in _available_cmaps:
        PLOT_CMAP = _cmap
        break

# Styles accepted by ``--style``. Copy matplotlib's list so that appending
# 'xkcd' below does not mutate matplotlib's own ``style.available``.
plot_styles = list(mpl.style.available) if hasattr(mpl, 'style') else []
if hasattr(plt, 'xkcd'):
    plot_styles.append('xkcd')

logger = logging.getLogger('yatsm')


@click.command(short_help='Run YATSM algorithm on individual pixels')
@options.arg_config_file
@click.argument('px', metavar='<px>', nargs=1, type=click.INT)
@click.argument('py', metavar='<py>', nargs=1, type=click.INT)
@click.option('--band', metavar='<n>', nargs=1, type=click.INT, default=1,
              show_default=True, help='Band to plot')
@click.option('--plot', default=('TS',), multiple=True, show_default=True,
              type=click.Choice(avail_plots), help='Plot type')
@click.option('--ylim', metavar='<min> <max>', nargs=2, type=float,
              show_default=True, help='Y-axis limits')
@click.option('--style', metavar='<style>', default='ggplot',
              show_default=True, type=click.Choice(plot_styles),
              help='Plot style')
@click.option('--cmap', metavar='<cmap>', default=PLOT_CMAP,
              show_default=True, help='DOY/VAL plot colormap')
@click.option('--embed', is_flag=True,
              help='Drop to (I)Python interpreter at various points')
@click.option('--seed', type=click.INT, help='Set NumPy RNG seed value')
@click.option('--algo_kw', multiple=True, callback=options.callback_dict,
              help='Algorithm parameter overrides')
@click.option('--result_prefix', type=str, default='', show_default=True,
              multiple=True,
              help='Plot coef/rmse from refit that used this prefix')
@click.option('--seasons', is_flag=True, help='Plot using seasonal symbology')
@click.pass_context
def pixel(ctx, config, px, py, band, plot, ylim, style, cmap,
          embed, seed, algo_kw, result_prefix, seasons):
    """ Run YATSM algorithm on an individual pixel and plot the results
    """
    # Seed NumPy's global RNG. ``--seed`` is typed as an int so that a
    # command-line value is not passed to NumPy as a string.
    np.random.seed(seed)
    # ``--band`` is 1-indexed on the command line; convert to an array index
    band -= 1
    # Normalize every result prefix to end with "_" (fields are read as
    # "<prefix>coef" / "<prefix>rmse"). Always include "" so the original
    # fit is also shown. An empty prefix is kept as "" (no IndexError).
    if result_prefix:
        result_prefix = set(
            _pref if not _pref or _pref.endswith('_') else _pref + '_'
            for _pref in result_prefix)
        result_prefix.add('')
    else:
        result_prefix = ('', )

    # Validate the colormap name. ``pyplot.colormaps()`` works on old and
    # new matplotlib, whereas ``matplotlib.cm.cmap_d`` was removed.
    if cmap not in plt.colormaps():
        raise click.ClickException('Cannot find specified colormap ({}) in '
                                   'matplotlib'.format(cmap))

    # Parse config
    cfg = parse_config_file(config)

    # Apply algorithm overrides, which are given as YAML strings.
    # ``safe_load`` parses plain YAML values without constructing arbitrary
    # Python objects. Unlike a bare ``yaml.load``, it needs no Loader
    # argument on current PyYAML.
    for kw in algo_kw:
        value = yaml.safe_load(algo_kw[kw])
        cfg = trawl_replace_keys(cfg, kw, value)
    if algo_kw:  # revalidate configuration
        cfg = convert_config(cfg)

    # Dataset information
    df = csvfile_to_dataframe(cfg['dataset']['input_file'],
                              date_format=cfg['dataset']['date_format'])
    df['image_ID'] = get_image_IDs(df['filename'])
    df['x'] = df['date']
    dates = df['date'].values

    # Initialize timeseries model
    model = cfg['YATSM']['algorithm_object']
    algo_cfg = cfg[cfg['YATSM']['algorithm']]
    yatsm = model(estimator=cfg['YATSM']['estimator'],
                  **algo_cfg.get('init', {}))
    yatsm.px = px
    yatsm.py = py

    # Setup algorithm and create design matrix (if needed)
    X = yatsm.setup(df, **cfg)
    design_info = getattr(X, 'design_info', None)

    # Read pixel data and check it against the configured band count
    Y = read_pixel_timeseries(df['filename'], px, py)
    if Y.shape[0] != cfg['dataset']['n_bands']:
        raise click.ClickException(
            'Number of bands in image {f} ({nf}) do not match number in '
            'configuration file ({nc})'.format(
                f=df['filename'].iloc[0],  # positional: index may not be 0..n
                nf=Y.shape[0],
                nc=cfg['dataset']['n_bands']))

    # Preprocess pixel data
    X, Y, dates = yatsm.preprocess(X, Y, dates, **cfg['dataset'])

    # Convert ordinal to datetime
    dt_dates = np.array([dt.datetime.fromordinal(d) for d in dates])

    # Plot observations before fitting (one figure shown per plot type)
    with plt.xkcd() if style == 'xkcd' else mpl.style.context(style):
        for _plot in plot:
            _plot_pixel_data(_plot, dt_dates, Y[band, :], seasons, cmap,
                             ylim, px, py, band)
            plt.tight_layout()
            plt.show()

    # Fit model, then refit its records as configured under YATSM/refit
    yatsm.fit(X, Y, dates, **algo_cfg.get('fit', {}))
    for prefix, estimator, stay_reg, fitopt in zip(
            cfg['YATSM']['refit']['prefix'],
            cfg['YATSM']['refit']['prediction_object'],
            cfg['YATSM']['refit']['stay_regularized'],
            cfg['YATSM']['refit']['fit']):
        yatsm.record = postprocess.refit_record(
            yatsm, prefix, estimator,
            fitopt=fitopt, keep_regularized=stay_reg)

    # Plot observations again, overlaid with each requested set of results
    with plt.xkcd() if style == 'xkcd' else mpl.style.context(style):
        for _plot in plot:
            _plot_pixel_data(_plot, dt_dates, Y[band, :], seasons, cmap,
                             ylim, px, py, band)
            for _prefix in set(result_prefix):
                plot_results(band, cfg, yatsm, design_info,
                             result_prefix=_prefix,
                             plot_type=_plot)
            plt.tight_layout()
            plt.show()

    if embed:
        console.open_interpreter(
            yatsm,
            message=("Additional functions:\n"
                     "plot_TS, plot_DOY, plot_VAL, plot_results"),
            variables={
                'config': cfg,
            },
            funcs={
                'plot_TS': plot_TS, 'plot_DOY': plot_DOY,
                'plot_VAL': plot_VAL, 'plot_results': plot_results
            }
        )


def _plot_pixel_data(plot_type, dt_dates, y, seasons, cmap, ylim, px, py,
                     band):
    """ Draw pixel observations as a TS, DOY, or VAL plot, with axis labels

    Args:
        plot_type (str): 'TS', 'DOY', or 'VAL'; other values draw no data
        dt_dates (np.ndarray): observation dates as ``datetime`` objects
        y (np.ndarray): observations of the plotted band
        seasons (bool): use seasonal symbology for TS plots
        cmap (str): colormap name for DOY/VAL plots
        ylim (tuple or None): optional (min, max) y-axis limits
        px (int): pixel x coordinate, shown in the title
        py (int): pixel y coordinate, shown in the title
        band (int): zero-based band index (labelled 1-based)
    """
    if plot_type == 'TS':
        plot_TS(dt_dates, y, seasons)
    elif plot_type == 'DOY':
        plot_DOY(dt_dates, y, cmap)
    elif plot_type == 'VAL':
        plot_VAL(dt_dates, y, cmap)

    if ylim:
        plt.ylim(ylim)
    plt.title('Timeseries: px={px} py={py}'.format(px=px, py=py))
    plt.ylabel('Band {b}'.format(b=band + 1))


def plot_TS(dates, y, seasons):
    """ Create a standard timeseries plot on the current axes

    Args:
        dates (iterable): sequence of datetime
        y (np.ndarray): variable to plot
        seasons (bool): Plot seasonal symbology (marker color and alpha per
            season in ``SEASONS``) instead of uniform black markers
    """
    if seasons:
        months = np.array([d.month for d in dates])
        # Boolean-mask indexing below needs arrays. Convert so that plain
        # sequences (e.g. a list of datetimes) work as documented.
        dates, y = np.asarray(dates), np.asarray(y)
        for season_months, color, alpha in SEASONS.values():
            # ``np.isin`` replaces the deprecated ``np.in1d``
            season_idx = np.isin(months, season_months)
            plt.plot(dates[season_idx], y[season_idx], marker='o',
                     mec=color, mfc=color, alpha=alpha, ls='')
    else:
        plt.scatter(dates, y, c='k', marker='o', edgecolors='none', s=35)
    plt.xlabel('Date')


def plot_DOY(dates, y, mpl_cmap):
    """ Create a DOY plot: observations against day of year, colored by year

    Args:
        dates (iterable): sequence of datetime
        y (np.ndarray): variable to plot
        mpl_cmap (colormap): matplotlib colormap (instance or registered name)
    """
    doy = np.array([d.timetuple().tm_yday for d in dates])
    year = np.array([d.year for d in dates])

    sp = plt.scatter(doy, y, c=year, cmap=mpl_cmap,
                     marker='o', edgecolors='none', s=35)
    plt.colorbar(sp)

    # Label months on the minor ticks of the DOY axis.
    # NOTE(review): MonthLocator treats the DOY values as matplotlib date
    # numbers, so month positions depend on matplotlib's date epoch; confirm
    # the alignment on the matplotlib version in use.
    months = mpl.dates.MonthLocator()  # every month
    months_fmrt = mpl.dates.DateFormatter('%b')

    # ``plt.gca()`` returns the scatter's axes. Newer matplotlib makes
    # ``plt.axes()`` create a new, overlaid Axes, which would receive the
    # tick settings instead.
    ax = plt.gca()
    plt.tick_params(axis='x', which='minor', direction='in', pad=-10)
    ax.xaxis.set_minor_locator(months)
    ax.xaxis.set_minor_formatter(months_fmrt)

    plt.xlim(1, 366)
    plt.xlabel('Day of Year')


def plot_VAL(dates, y, mpl_cmap, reps=2):
    """ Create a "Valerie Pasquarella" plot, i.e. a DOY plot repeated end to end

    The observations are drawn ``reps + 1`` times. Copy ``i`` has its day of
    year shifted right by ``i * 366`` so consecutive copies sit side by side.

    Args:
        dates (iterable): sequence of datetime
        y (np.ndarray): variable to plot
        mpl_cmap (colormap): matplotlib colormap
        reps (int, optional): number of additional repetitions
    """
    day_of_year = np.array([d.timetuple().tm_yday for d in dates])
    obs_year = np.array([d.year for d in dates])

    n_copies = reps + 1
    shifted = [day_of_year + i * 366 for i in range(1, n_copies)]
    all_doy = np.concatenate([day_of_year] + shifted)
    all_year = np.tile(obs_year, n_copies)
    all_y = np.tile(y, n_copies)

    points = plt.scatter(all_doy, all_y, c=all_year, cmap=mpl_cmap,
                         marker='o', edgecolors='none', s=35)
    plt.colorbar(points)
    plt.xlabel('Day of Year')


def plot_results(band, cfg, model, design_info,
                 result_prefix='', plot_type='TS'):
    """ Plot model results on the current axes, one line per record

    Args:
        band (int): plot results for this band
        cfg (dict): YATSM configuration dictionary
        model (YATSM model): fitted YATSM timeseries model
        design_info (patsy.DesignInfo): patsy design information
        result_prefix (str): Prefix to 'coef' and 'rmse'
        plot_type (str): type of plot to add results to (TS, DOY, or VAL)

    Raises:
        ValueError: if ``plot_type`` is not TS, DOY, or VAL
        KeyError: if ``model.record`` lacks the ``<result_prefix>coef`` or
            ``<result_prefix>rmse`` field
    """
    # Fail fast. Otherwise an unknown type leaves the prediction variables
    # unbound and fails later with an obscure UnboundLocalError.
    if plot_type not in ('TS', 'DOY', 'VAL'):
        raise ValueError('Unknown plot type "{}"; expected TS, DOY, or VAL'
                         .format(plot_type))

    # Results prefix
    result_k = model.record.dtype.names
    coef_k = result_prefix + 'coef'
    rmse_k = result_prefix + 'rmse'
    if coef_k not in result_k or rmse_k not in result_k:
        raise KeyError('Cannot find result prefix "{}" in results'
                       .format(result_prefix))
    if result_prefix:
        click.echo('Using "{}" re-fitted results'.format(result_prefix))

    # Step through prediction dates backwards when the config runs in reverse
    step = -1 if cfg['YATSM']['reverse'] else 1

    # Remove categorical terms ("+ C(...)" / "- C(...)") from the formula so
    # predictions are built from the remaining terms only.
    # NOTE(review): ``.*`` is greedy, so any terms after the first C(...)
    # are removed too. Confirm categorical terms always come last.
    design = re.sub(r'[\+\-][\ ]+C\(.*\)', '',
                    cfg['YATSM']['design_matrix'])

    # Column indexes of the non-categorical coefficients, in column order
    i_coef = np.sort(np.asarray([
        idx for name, idx in design_info.column_name_indexes.items()
        if not re.match(r'C\(.*\)', name)
    ]))

    _prefix = result_prefix or cfg['YATSM']['prediction']
    for i, r in enumerate(model.record):
        label = 'Model {i} ({prefix})'.format(i=i, prefix=_prefix)
        if plot_type == 'TS':
            # Predict for each ordinal day from start (inclusive) to end
            mx = np.arange(r['start'], r['end'], step)
            mX = patsy.dmatrix(design, {'x': mx}).T

            my = np.dot(r[coef_k][i_coef, band], mX)
            mx_date = np.array([dt.datetime.fromordinal(int(_x))
                                for _x in mx])
            # Mark the break date, if one is recorded (non-zero)
            if r['break']:
                bx = dt.datetime.fromordinal(int(r['break']))
                plt.axvline(bx, c='red', lw=2)
        else:
            # DOY / VAL: predict over the calendar year at the segment's
            # midpoint, plotted against day of year
            yr_end = dt.datetime.fromordinal(int(r['end'])).year
            yr_start = dt.datetime.fromordinal(int(r['start'])).year
            yr_mid = int(yr_end - (yr_end - yr_start) / 2)

            mx = np.arange(dt.date(yr_mid, 1, 1).toordinal(),
                           dt.date(yr_mid + 1, 1, 1).toordinal(), 1)
            mX = patsy.dmatrix(design, {'x': mx}).T

            my = np.dot(r[coef_k][i_coef, band], mX)
            mx_date = np.array([
                dt.datetime.fromordinal(int(d)).timetuple().tm_yday
                for d in mx])

            label = 'Model {i} - {yr} ({prefix})'.format(i=i, yr=yr_mid,
                                                         prefix=_prefix)

        plt.plot(mx_date, my, lw=3, label=label)
    leg = plt.legend()
    # Newer matplotlib provides ``set_draggable``; older releases only offer
    # the since-removed ``draggable`` method
    if hasattr(leg, 'set_draggable'):
        leg.set_draggable(True)
    else:
        leg.draggable(state=True)


# UTILITY FUNCTIONS
def trawl_replace_keys(d, key, value, s=''):
    """ Return a copy of ``d`` with every non-dict entry named ``key`` set to ``value``

    Nested dictionaries are searched recursively and copied, so ``d`` itself
    is not modified. An entry named ``key`` whose value is a dict is searched,
    not replaced. Each replacement is echoed with its full key path.

    Args:
        d (dict): dictionary to search
        key: key whose (non-dict) values are replaced
        value: replacement value
        s (str): key path of ``d`` within the top-level dict, used only in
            the echoed message (e.g. ``'[YATSM]'``)

    Returns:
        dict: modified copy of ``d``
    """
    md = d.copy()
    for _key, _value in d.items():
        # Build each entry's path from ``s`` without reassigning ``s``.
        # Appending to ``s`` in the loop would leak one entry's path into
        # the paths of the entries that follow it.
        path = '{}[{}]'.format(s, _key)
        if isinstance(_value, dict):
            # Recursively replace
            md[_key] = trawl_replace_keys(_value, key, value, s=path)
        elif _key == key:
            click.echo('Replacing d{k}={ov} with {nv}'
                       .format(k=path, ov=_value, nv=value))
            md[_key] = value
    return md