""" Command line interface for running YATSM algorithms on individual pixels """ import datetime as dt import logging import re import click import matplotlib as mpl import matplotlib.cm # noqa import matplotlib.pyplot as plt import numpy as np import patsy import yaml from . import options, console from ..algorithms import postprocess from ..config_parser import convert_config, parse_config_file from ..io import read_pixel_timeseries from ..utils import csvfile_to_dataframe, get_image_IDs from ..regression.transforms import harm # noqa avail_plots = ['TS', 'DOY', 'VAL'] SEASONS = { 'winter': ((11, 12, 1, 2, 3), 'b', 0.5), 'spring': ((4, 5), 'c', 0.5), 'summer': ((6, 7, 8), 'g', 1.,), 'fall': ((9, 10), 'y', 0.5) } _DEFAULT_PLOT_CMAP = ('viridis', 'cubehelix', 'jet') PLOT_CMAP = _DEFAULT_PLOT_CMAP[-1] for _cmap in _DEFAULT_PLOT_CMAP: if _cmap in mpl.cm.cmap_d: PLOT_CMAP = _cmap break plot_styles = [] if hasattr(mpl, 'style'): plot_styles = mpl.style.available if hasattr(plt, 'xkcd'): plot_styles.append('xkcd') logger = logging.getLogger('yatsm') @click.command(short_help='Run YATSM algorithm on individual pixels') @options.arg_config_file @click.argument('px', metavar='<px>', nargs=1, type=click.INT) @click.argument('py', metavar='<py>', nargs=1, type=click.INT) @click.option('--band', metavar='<n>', nargs=1, type=click.INT, default=1, show_default=True, help='Band to plot') @click.option('--plot', default=('TS',), multiple=True, show_default=True, type=click.Choice(avail_plots), help='Plot type') @click.option('--ylim', metavar='<min> <max>', nargs=2, type=float, show_default=True, help='Y-axis limits') @click.option('--style', metavar='<style>', default='ggplot', show_default=True, type=click.Choice(plot_styles), help='Plot style') @click.option('--cmap', metavar='<cmap>', default=PLOT_CMAP, show_default=True, help='DOY/VAL plot colormap') @click.option('--embed', is_flag=True, help='Drop to (I)Python interpreter at various points') @click.option('--seed', help='Set NumPy RNG seed value') @click.option('--algo_kw', multiple=True, callback=options.callback_dict, help='Algorithm parameter overrides') @click.option('--result_prefix', type=str, default='', show_default=True, multiple=True, help='Plot coef/rmse from refit that used this prefix') @click.option('--seasons', is_flag=True, help='Plot using seasonal symbology') @click.pass_context def pixel(ctx, config, px, py, band, plot, ylim, style, cmap, embed, seed, algo_kw, result_prefix, seasons): # Set seed np.random.seed(seed) # Convert band to index band -= 1 # Format result prefix if result_prefix: result_prefix = set((_pref if _pref[-1] == '_' else _pref + '_') for _pref in result_prefix) result_prefix.add('') # add in no prefix to show original fit else: result_prefix = ('', ) # Get colormap if cmap not in mpl.cm.cmap_d: raise click.ClickException('Cannot find specified colormap ({}) in ' 'matplotlib'.format(cmap)) # Parse config cfg = parse_config_file(config) # Apply algorithm overrides for kw in algo_kw: value = yaml.load(algo_kw[kw]) cfg = trawl_replace_keys(cfg, kw, value) if algo_kw: # revalidate configuration cfg = convert_config(cfg) # Dataset information df = csvfile_to_dataframe(cfg['dataset']['input_file'], date_format=cfg['dataset']['date_format']) df['image_ID'] = get_image_IDs(df['filename']) df['x'] = df['date'] dates = df['date'].values # Initialize timeseries model model = cfg['YATSM']['algorithm_object'] algo_cfg = cfg[cfg['YATSM']['algorithm']] yatsm = model(estimator=cfg['YATSM']['estimator'], **algo_cfg.get('init', {})) yatsm.px = px yatsm.py = py # Setup algorithm and create design matrix (if needed) X = yatsm.setup(df, **cfg) design_info = getattr(X, 'design_info', None) # Read pixel data Y = read_pixel_timeseries(df['filename'], px, py) if Y.shape[0] != cfg['dataset']['n_bands']: raise click.ClickException( 'Number of bands in image {f} ({nf}) do not match number in ' 'configuration file ({nc})'.format( f=df['filename'][0], nf=Y.shape[0], nc=cfg['dataset']['n_bands'])) # Preprocess pixel data X, Y, dates = yatsm.preprocess(X, Y, dates, **cfg['dataset']) # Convert ordinal to datetime dt_dates = np.array([dt.datetime.fromordinal(d) for d in dates]) # Plot before fitting with plt.xkcd() if style == 'xkcd' else mpl.style.context(style): for _plot in plot: if _plot == 'TS': plot_TS(dt_dates, Y[band, :], seasons) elif _plot == 'DOY': plot_DOY(dt_dates, Y[band, :], cmap) elif _plot == 'VAL': plot_VAL(dt_dates, Y[band, :], cmap) if ylim: plt.ylim(ylim) plt.title('Timeseries: px={px} py={py}'.format(px=px, py=py)) plt.ylabel('Band {b}'.format(b=band + 1)) plt.tight_layout() plt.show() # Fit model yatsm.fit(X, Y, dates, **algo_cfg.get('fit', {})) for prefix, estimator, stay_reg, fitopt in zip( cfg['YATSM']['refit']['prefix'], cfg['YATSM']['refit']['prediction_object'], cfg['YATSM']['refit']['stay_regularized'], cfg['YATSM']['refit']['fit']): yatsm.record = postprocess.refit_record( yatsm, prefix, estimator, fitopt=fitopt, keep_regularized=stay_reg) # Plot after predictions with plt.xkcd() if style == 'xkcd' else mpl.style.context(style): for _plot in plot: if _plot == 'TS': plot_TS(dt_dates, Y[band, :], seasons) elif _plot == 'DOY': plot_DOY(dt_dates, Y[band, :], cmap) elif _plot == 'VAL': plot_VAL(dt_dates, Y[band, :], cmap) if ylim: plt.ylim(ylim) plt.title('Timeseries: px={px} py={py}'.format(px=px, py=py)) plt.ylabel('Band {b}'.format(b=band + 1)) for _prefix in set(result_prefix): plot_results(band, cfg, yatsm, design_info, result_prefix=_prefix, plot_type=_plot) plt.tight_layout() plt.show() if embed: console.open_interpreter( yatsm, message=("Additional functions:\n" "plot_TS, plot_DOY, plot_VAL, plot_results"), variables={ 'config': cfg, }, funcs={ 'plot_TS': plot_TS, 'plot_DOY': plot_DOY, 'plot_VAL': plot_VAL, 'plot_results': plot_results } ) def plot_TS(dates, y, seasons): """ Create a standard timeseries plot Args: dates (iterable): sequence of datetime y (np.ndarray): variable to plot seasons (bool): Plot seasonal symbology """ # Plot data if seasons: months = np.array([d.month for d in dates]) for season_months, color, alpha in SEASONS.values(): season_idx = np.in1d(months, season_months) plt.plot(dates[season_idx], y[season_idx], marker='o', mec=color, mfc=color, alpha=alpha, ls='') else: plt.scatter(dates, y, c='k', marker='o', edgecolors='none', s=35) plt.xlabel('Date') def plot_DOY(dates, y, mpl_cmap): """ Create a DOY plot Args: dates (iterable): sequence of datetime y (np.ndarray): variable to plot mpl_cmap (colormap): matplotlib colormap """ doy = np.array([d.timetuple().tm_yday for d in dates]) year = np.array([d.year for d in dates]) sp = plt.scatter(doy, y, c=year, cmap=mpl_cmap, marker='o', edgecolors='none', s=35) plt.colorbar(sp) months = mpl.dates.MonthLocator() # every month months_fmrt = mpl.dates.DateFormatter('%b') plt.tick_params(axis='x', which='minor', direction='in', pad=-10) plt.axes().xaxis.set_minor_locator(months) plt.axes().xaxis.set_minor_formatter(months_fmrt) plt.xlim(1, 366) plt.xlabel('Day of Year') def plot_VAL(dates, y, mpl_cmap, reps=2): """ Create a "Valerie Pasquarella" plot (repeated DOY plot) Args: dates (iterable): sequence of datetime y (np.ndarray): variable to plot mpl_cmap (colormap): matplotlib colormap reps (int, optional): number of additional repetitions """ doy = np.array([d.timetuple().tm_yday for d in dates]) year = np.array([d.year for d in dates]) # Replicate `reps` times _doy = doy.copy() for r in range(1, reps + 1): _doy = np.concatenate((_doy, doy + r * 366)) _year = np.tile(year, reps + 1) _y = np.tile(y, reps + 1) sp = plt.scatter(_doy, _y, c=_year, cmap=mpl_cmap, marker='o', edgecolors='none', s=35) plt.colorbar(sp) plt.xlabel('Day of Year') def plot_results(band, cfg, model, design_info, result_prefix='', plot_type='TS'): """ Plot model results Args: band (int): plot results for this band cfg (dict): YATSM configuration dictionary model (YATSM model): fitted YATSM timeseries model design_info (patsy.DesignInfo): patsy design information result_prefix (str): Prefix to 'coef' and 'rmse' plot_type (str): type of plot to add results to (TS, DOY, or VAL) """ # Results prefix result_k = model.record.dtype.names coef_k = result_prefix + 'coef' rmse_k = result_prefix + 'rmse' if coef_k not in result_k or rmse_k not in result_k: raise KeyError('Cannot find result prefix "{}" in results' .format(result_prefix)) if result_prefix: click.echo('Using "{}" re-fitted results'.format(result_prefix)) # Handle reverse step = -1 if cfg['YATSM']['reverse'] else 1 # Remove categorical info from predictions design = re.sub(r'[\+\-][\ ]+C\(.*\)', '', cfg['YATSM']['design_matrix']) i_coef = [] for k, v in design_info.column_name_indexes.items(): if not re.match('C\(.*\)', k): i_coef.append(v) i_coef = np.sort(np.asarray(i_coef)) _prefix = result_prefix or cfg['YATSM']['prediction'] for i, r in enumerate(model.record): label = 'Model {i} ({prefix})'.format(i=i, prefix=_prefix) if plot_type == 'TS': # Prediction mx = np.arange(r['start'], r['end'], step) mX = patsy.dmatrix(design, {'x': mx}).T my = np.dot(r[coef_k][i_coef, band], mX) mx_date = np.array([dt.datetime.fromordinal(int(_x)) for _x in mx]) # Break if r['break']: bx = dt.datetime.fromordinal(r['break']) plt.axvline(bx, c='red', lw=2) elif plot_type in ('DOY', 'VAL'): yr_end = dt.datetime.fromordinal(r['end']).year yr_start = dt.datetime.fromordinal(r['start']).year yr_mid = int(yr_end - (yr_end - yr_start) / 2) mx = np.arange(dt.date(yr_mid, 1, 1).toordinal(), dt.date(yr_mid + 1, 1, 1).toordinal(), 1) mX = patsy.dmatrix(design, {'x': mx}).T my = np.dot(r[coef_k][i_coef, band], mX) mx_date = np.array([dt.datetime.fromordinal(d).timetuple().tm_yday for d in mx]) label = 'Model {i} - {yr} ({prefix})'.format(i=i, yr=yr_mid, prefix=_prefix) plt.plot(mx_date, my, lw=3, label=label) leg = plt.legend() leg.draggable(state=True) # UTILITY FUNCTIONS def trawl_replace_keys(d, key, value, s=''): """ Return modified dictionary ``d`` """ md = d.copy() for _key in md: if isinstance(md[_key], dict): # Recursively replace md[_key] = trawl_replace_keys(md[_key], key, value, s='{}[{}]'.format(s, _key)) else: if _key == key: s += '[{}]'.format(_key) click.echo('Replacing d{k}={ov} with {nv}' .format(k=s, ov=md[_key], nv=value)) md[_key] = value return md