from pyrocko import cake_plot as cp
from pyrocko import orthodrome as otd

from pymc3 import plots as pmp

import math
import os
import logging
import copy

from beat import utility
from beat.models import Stage, load_stage
from beat.sampler.metropolis import get_trace_stats
from beat.heart import init_seismic_targets, init_geodetic_targets, \
                       physical_bounds
from beat.config import ffi_mode_str, geometry_mode_str, dist_vars

from matplotlib import pyplot as plt
from matplotlib.patches import Rectangle, FancyArrow
from matplotlib.collections import PatchCollection
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.ticker as tick

from scipy.stats import kde
import numpy as num
from pyrocko.guts import (Object, String, Dict, List,
                          Bool, Int, load, StringChoice)
from pyrocko import util, trace
from pyrocko.cake_plot import str_to_mpl_color as scolor
from pyrocko.cake_plot import light
from pyrocko.plot import beachball, nice_value, AutoScaler

import pyrocko.moment_tensor as mt
from pyrocko.plot import mpl_papersize, mpl_init, mpl_graph_color, mpl_margins

logger = logging.getLogger('plotting')

# number of metres in one kilometre; distances in metres are divided by it
km = 1e3


# public plotting API of this module
__all__ = [
    'PlotOptions',
    'correlation_plot',
    'correlation_plot_hist',
    'get_result_point',
    'seismic_fits',
    'geodetic_fits',
    'traceplot',
    'select_transform',
    'histplot_op',
]

# LaTeX formatted unit labels used in axis annotations.
# u_deg is a raw string: in a normal literal '\c' is an invalid escape
# sequence (SyntaxWarning on Python >= 3.12); the value is unchanged.
u_nm = '$[Nm]$'
u_km = '$[km]$'
u_km_s = '$[km/s]$'
u_deg = r'$[^{\circ}]$'
u_m = '$[m]$'
u_v = '$[m^3]$'
u_s = '$[s]$'
u_rad = '$[rad]$'
u_hyp = ''

# mapping of parameter name to its unit label; hyperparameters starting
# with "h_" share the single key 'h_' (see hypername)
plot_units = {
    'east_shift': u_km,
    'north_shift': u_km,
    'depth': u_km,
    'width': u_km,
    'length': u_km,

    'dip': u_deg,
    'dip1': u_deg,
    'dip2': u_deg,
    'strike': u_deg,
    'strike1': u_deg,
    'strike2': u_deg,
    'rake': u_deg,
    'rake1': u_deg,
    'rake2': u_deg,
    'mix': u_hyp,

    'volume_change': u_v,
    'diameter': u_km,
    'slip': u_m,
    'azimuth': u_deg,
    'bl_azimuth': u_deg,
    'amplitude': u_nm,
    'bl_amplitude': u_m,
    'locking_depth': u_km,

    'nucleation_dip': u_km,
    'nucleation_strike': u_km,
    'nucleation_x': u_hyp,
    'nucleation_y': u_hyp,
    'time_shift': u_s,
    'uperp': u_m,
    'uparr': u_m,
    'durations': u_s,
    'velocities': u_km_s,

    'mnn': u_nm,
    'mee': u_nm,
    'mdd': u_nm,
    'mne': u_nm,
    'mnd': u_nm,
    'med': u_nm,
    'magnitude': u_hyp,

    'u': u_rad,
    'v': u_rad,
    'kappa': u_rad,
    'sigma': u_rad,
    'h': u_hyp,

    'distance': u_km,
    'delta_depth': u_km,
    'delta_time': u_s,
    'time': u_s,
    'duration': u_s,
    'peak_ratio': u_hyp,
    'h_': u_hyp,
    'like': u_hyp}


# projections available for plotting geodetic data (PlotOptions choices)
plot_projections = ['latlon', 'local']


def hypername(varname):
    """
    Map a variable name to its key in ``plot_units``.

    All names with the prefix ``'h_'`` collapse to ``'h_'``; every other
    name is returned unchanged.
    """
    return 'h_' if varname[:2] == 'h_' else varname


class PlotOptions(Object):
    """
    Options controlling which results are plotted and how they are stored.
    """
    post_llk = String.T(
        default='max',
        help='Which model to plot on the specified plot; Default: "max";'
             ' Options: "max", "min", "mean", "all"')
    # choices are restricted to the module level plot_projections
    plot_projection = StringChoice.T(
        default='local',
        choices=plot_projections,
        help='Projection to use for plotting geodetic data;'
             ' options: "latlon", "local"')
    # NOTE(review): "utm" is not among plot_projections; confirm whether
    # this option is still used anywhere.
    utm_zone = Int.T(
        default=36,
        optional=True,
        help='Only relevant if plot_projection is "utm"')
    load_stage = Int.T(
        default=-1,
        help='Which stage to select for plotting')
    figure_dir = String.T(
        default='figures',
        help='Name of the output directory of plots')
    reference = Dict.T(
        default={},
        help='Reference point for example from a synthetic test.',
        optional=True)
    # 'display' shows the figures instead of saving them (see
    # draw_geodetic_fits); other values are used as file format
    outformat = String.T(default='pdf')
    dpi = Int.T(default=300)
    # replot even if the output already exists
    force = Bool.T(default=False)
    varnames = List.T(
        default=[], optional=True, help='Names of variables to plot')
    source_idxs = List.T(
        default=None,
        optional=True,
        help='Indexes to patches of slip distribution to draw marginals for')
    nensemble = Int.T(
        default=1,
        help='Number of draws from the PPD to display fuzzy results.')


def str_dist(dist):
    """
    Return a human readable string of a distance given in metres.

    Below 10 m the value is printed with '%g'; up to 1 km in whole
    metres; up to 10 km in kilometres with one decimal; beyond that in
    whole kilometres.
    """
    if dist < 10.0:
        return '%g m' % dist

    if dist < km:
        return '%.0f m' % dist

    if dist < 10. * km:
        return '%.1f km' % (dist / km)

    return '%.0f km' % (dist / km)


def str_duration(t):
    """
    Return a human readable string of a duration ``t`` in seconds.

    Durations shorter than 10 s are printed in seconds, shorter than one
    hour as minutes:seconds, shorter than one day as hours:minutes and
    longer ones in days. Negative durations get a leading minus sign.
    """
    sign = '-' if t < 0. else ''
    t = abs(t)
    one_day = 24. * 3600.

    if t < 10.0:
        text = '%.2g s' % t
    elif t < 3600.:
        text = util.time_to_str(t, format='%M:%S min')
    elif t < one_day:
        text = util.time_to_str(t, format='%H:%M h')
    else:
        text = '%.1f d' % (t / one_day)

    return sign + text


def kde2plot_op(ax, x, y, grid=200, **kwargs):
    """
    Draw a 2d gaussian kernel density estimate of samples ``x``, ``y``.

    The density is evaluated on a ``grid`` x ``grid`` mesh spanning the
    data ranges of ``x`` and ``y`` and drawn with ``ax.imshow``.

    Parameters
    ----------
    ax : matplotlib axis
    x, y : :class:`numpy.ndarray`
        samples; they are flattened for the density estimation
    grid : int
        number of evaluation points along each axis
    **kwargs
        passed to ``ax.imshow``; an ``extent`` of length 4 overrides the
        image extent, otherwise the data ranges of ``x`` and ``y`` are used
    """
    # the scipy.stats.kde namespace is deprecated; use the public one
    from scipy.stats import gaussian_kde

    xmin, xmax = x.min(), x.max()
    ymin, ymax = y.min(), y.max()

    extent = kwargs.pop('extent', [])
    if len(extent) != 4:
        extent = [xmin, xmax, ymin, ymax]

    # an imaginary step makes mgrid return ``grid`` points including the end
    npoints = grid * 1j
    X, Y = num.mgrid[xmin:xmax:npoints, ymin:ymax:npoints]
    positions = num.vstack([X.ravel(), Y.ravel()])
    values = num.vstack([x.ravel(), y.ravel()])
    kernel = gaussian_kde(values)
    Z = num.reshape(kernel(positions), X.shape)

    # mgrid puts x along the first axis; rotate so that x runs along the
    # image columns and the highest y values end up in the top row
    ax.imshow(num.rot90(Z), extent=extent, **kwargs)


def kde2plot(x, y, grid=200, ax=None, **kwargs):
    """
    Draw a 2d kernel density estimate of ``x`` and ``y`` (see
    ``kde2plot_op``) and return the axis used.

    If no axis is given, a new figure with a single axis is created.
    """
    if ax is None:
        ax = plt.subplots(1, 1, squeeze=True)[1]

    kde2plot_op(ax, x, y, grid, **kwargs)
    return ax


def correlation_plot(
        mtrace, varnames=None,
        transform=lambda x: x, figsize=None, cmap=None, grid=200, point=None,
        point_style='.', point_color='white', point_size='8'):
    """
    Plot 2d marginals (with kernel density estimation) showing the correlations
    of the model parameters.

    Parameters
    ----------
    mtrace : :class:`pymc3.base.MutliTrace`
        Mutlitrace instance containing the sampling results
    varnames : list of variable names
        Variables to be plotted, if None all variable are plotted;
        at least two are required
    transform : callable
        Function to transform data (defaults to identity)
    figsize : figure size tuple
        If None, the size of an A4 landscape page is used
    cmap : matplotlib colormap
    grid : resolution of kernel density estimation
    point : dict
        Dictionary of variable name / value  to be overplotted as marker
        to the posteriors e.g. mean of posteriors, true values of a simulation;
        variable pairs not contained in the dict are not marked
    point_style : str
        style of marker according to matplotlib conventions
    point_color : str or tuple of 3
        color according to matplotlib convention
    point_size : str or float
        marker size according to matplotlib conventions

    Returns
    -------
    fig : figure object
    axs : 2d array of subplot axis handles, shape (nvar - 1, nvar - 1)

    Raises
    ------
    ValueError
        if fewer than two variables are to be plotted
    """

    if varnames is None:
        varnames = mtrace.varnames

    nvar = len(varnames)
    if nvar < 2:
        raise ValueError(
            'Correlation plot requires at least two variables, got %i!' % nvar)

    if figsize is None:
        figsize = mpl_papersize('a4', 'landscape')

    # squeeze=False keeps axs 2d; for nvar == 2 matplotlib would otherwise
    # return a single Axes object that cannot be indexed with [row, col]
    fig, axs = plt.subplots(
        sharey='row', sharex='col',
        nrows=nvar - 1, ncols=nvar - 1, figsize=figsize, squeeze=False)

    d = dict()
    for var in varnames:
        d[var] = transform(
            mtrace.get_values(
                var, combine=True, squeeze=True))

    for k in range(nvar - 1):
        a = d[varnames[k]]
        for j in range(k + 1, nvar):
            logger.debug('%s, %s' % (varnames[k], varnames[j]))
            b = d[varnames[j]]
            ax = axs[j - 1, k]

            kde2plot(a, b, grid=grid, ax=ax, cmap=cmap, aspect='auto')

            if point is not None and \
                    varnames[k] in point and varnames[j] in point:
                ax.plot(
                    point[varnames[k]], point[varnames[j]],
                    color=point_color, marker=point_style,
                    markersize=point_size)

            ax.tick_params(direction='in')

            if k == 0:
                ax.set_ylabel(varnames[j])

        # only the bottom row gets x-axis labels
        axs[nvar - 2, k].set_xlabel(varnames[k])

    # remove the unused axes above the diagonal
    for k in range(nvar - 1):
        for j in range(k):
            fig.delaxes(axs[j, k])

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.05, hspace=0.05)
    return fig, axs


def correlation_plot_hist(
        mtrace, varnames=None,
        transform=lambda x: x, figsize=None, hist_color='orange', cmap=None,
        grid=50, chains=None, ntickmarks=2, point=None,
        point_style='.', point_color='red', point_size=4, alpha=0.35,
        unify=True):
    """
    Plot 2d marginals (with kernel density estimation) showing the correlations
    of the model parameters. In the main diagonal is shown the parameter
    histograms.

    Parameters
    ----------
    mtrace : :class:`pymc3.base.MutliTrace`
        Mutlitrace instance containing the sampling results
    varnames : list of variable names
        Variables to be plotted, if None all variable are plotted
    transform : callable
        Function to transform data (defaults to identity)
    figsize : figure size tuple
        If None, A5 landscape (less than 5 variables) or A4 landscape
        page size is used
    cmap : matplotlib colormap
    hist_color : str or tuple of 3
        color of the histograms according to matplotlib convention
    grid : resolution of kernel density estimation
    chains : int or list of ints
        chain indexes to select from the trace
    ntickmarks : int
        number of ticks at the y-axis of the 2d marginals
    point : dict
        Dictionary of variable name / value  to be overplotted as marker
        to the posteriors e.g. mean of posteriors, true values of a simulation;
        variables not contained in the dict are not marked
    point_style : str
        style of marker according to matplotlib conventions
    point_color : str or tuple of 3
        color according to matplotlib convention
    point_size : int or float
        marker size according to matplotlib conventions; the reference
        lines in the histograms are drawn with a quarter of this width
    alpha : float
        transparency of the histograms
    unify: bool
        If true axis units that belong to one group e.g. [km] will
        have common axis increments

    Returns
    -------
    fig : figure object
    axs : 2d array of subplot axis handles, shape (nvar, nvar)
    """
    fontsize = 9
    ntickmarks_max = 2
    label_pad = 25
    logger.info('Drawing correlation figure ...')

    if varnames is None:
        varnames = mtrace.varnames

    nvar = len(varnames)

    if figsize is None:
        if nvar < 5:
            figsize = mpl_papersize('a5', 'landscape')
        else:
            figsize = mpl_papersize('a4', 'landscape')

    # squeeze=False keeps axs 2d, also for a single variable
    fig, axs = plt.subplots(
        nrows=nvar, ncols=nvar, figsize=figsize, squeeze=False)

    d = dict()

    for var in varnames:
        d[var] = transform(
            mtrace.get_values(
                var, chains=chains, combine=True, squeeze=True))

    hist_ylims = []
    for k in range(nvar):
        v_namea = varnames[k]
        a = d[v_namea]

        for j in range(k, nvar):
            v_nameb = varnames[j]
            ax = axs[j, k]
            logger.debug('%s, %s' % (v_namea, v_nameb))
            if j == k:
                # diagonal: marginal histogram of the variable
                reference = None
                if point is not None and v_namea in point:
                    reference = point[v_namea]
                    ax.axvline(
                        x=reference, color=point_color,
                        lw=point_size / 4.)

                histplot_op(
                    ax, pmp.utils.make_2d(a), alpha=alpha,
                    color=hist_color, tstd=0., reference=reference)

                ax.get_yaxis().set_visible(False)
                format_axes(ax)
                # ticks and limits of the histogram are reused for the 2d
                # marginals below it in the same column
                xticks = ax.get_xticks()
                xlim = ax.get_xlim()
                hist_ylims.append(ax.get_ylim())
            else:
                # below the diagonal: 2d kernel density of the pair
                b = d[v_nameb]

                kde2plot(
                    a, b, grid=grid, ax=ax, cmap=cmap, aspect='auto')

                # both variables have to be in point to place the marker
                if point is not None and \
                        v_namea in point and v_nameb in point:
                    ax.plot(
                        point[v_namea], point[v_nameb],
                        color=point_color, marker=point_style,
                        markersize=point_size)

                yticker = tick.MaxNLocator(nbins=ntickmarks)
                ax.set_xticks(xticks)
                ax.set_xlim(xlim)
                ax.get_yaxis().set_major_locator(yticker)

            if j != nvar - 1:
                ax.get_xaxis().set_ticklabels([])

            if k == 0:
                ax.set_ylabel(
                    v_nameb + '\n ' + plot_units[hypername(v_nameb)],
                    fontsize=fontsize)
                if utility.is_odd(j):
                    ax.tick_params(axis='y', pad=label_pad)
            else:
                ax.get_yaxis().set_ticklabels([])

            ax.tick_params(
                axis='both', direction='in', labelsize=fontsize)
            ax.tick_params(
                axis='both', labelrotation=50.)
            if utility.is_odd(k):
                ax.tick_params(axis='x', pad=label_pad)

        # only the bottom row gets x-axis labels
        axs[nvar - 1, k].set_xlabel(
            v_namea + '\n ' + plot_units[hypername(v_namea)], fontsize=fontsize)

    if unify:
        varnames_repeat_x = [
            var_reap for varname in varnames for var_reap in (varname,) * nvar]
        varnames_repeat_y = varnames * nvar
        unitiesx = unify_tick_intervals(
            axs, varnames_repeat_x, ntickmarks_max=ntickmarks_max, axis='x')
        apply_unified_axis(
            axs, varnames_repeat_x, unitiesx, axis='x', scale_factor=1.,
            ntickmarks_max=ntickmarks_max)
        apply_unified_axis(
            axs, varnames_repeat_y, unitiesx, axis='y', scale_factor=1.,
            ntickmarks_max=ntickmarks_max)

    for k in range(nvar):
        if unify:
            # reset histogram ylims after unify
            axs[k, k].set_ylim(hist_ylims[k])

        # remove the unused axes above the diagonal
        for j in range(k):
            fig.delaxes(axs[j, k])

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.05, hspace=0.05)
    return fig, axs


def plot(uwifg, point_size=20):
    """
    Very simple scatter plot of given IFG for fast inspections.

    Parameters
    ----------
    uwifg : interferogram object
        needs the attributes ``lons``, ``lats``, ``displacement`` and ``name``
    point_size : int
        determines the size of the scatter plot points
    """
    axis = plt.axes()
    scatter = axis.scatter(
        uwifg.lons, uwifg.lats,
        s=point_size, c=uwifg.displacement,
        edgecolors='none')

    plt.colorbar(scatter)
    plt.title('Displacements [m] %s' % uwifg.name)
    plt.show()


def plot_cov(target, point_size=20):
    """
    Very simple scatter plot of the column sums of the prediction covariance
    of a target, for fast inspections.

    Parameters
    ----------
    target : geodetic target
        needs ``lons``, ``lats``, ``covariance.pred_v`` and ``name``
    point_size : int
        determines the size of the scatter plot points
    """
    axis = plt.axes()

    column_sums = target.covariance.pred_v.sum(axis=0)
    colors = num.array(column_sums).flatten()

    scatter = axis.scatter(
        target.lons, target.lats,
        s=point_size, c=colors,
        edgecolors='none')

    plt.colorbar(scatter)
    plt.title('Prediction Covariance [m2] %s' % target.name)
    plt.show()


def plot_matrix(A):
    """
    Very simple plot of a matrix for fast inspections.

    The matrix is drawn with ``matshow`` next to a colorbar and shown
    immediately.
    """
    image = plt.axes().matshow(A)
    plt.colorbar(image)
    plt.show()


def plot_log_cov(cov_mat):
    """
    Very simple plot of a covariance matrix on a signed logarithmic scale.

    Every entry ``c`` is displayed as ``log(|c|)``, multiplied by -1 where
    ``c`` is negative.
    NOTE(review): zero entries become -inf (numpy warns about log(0)).
    """
    axis = plt.axes()

    signs = num.ones_like(cov_mat)
    signs[num.less(cov_mat, 0)] = -1.
    log_magnitude = num.log(num.abs(cov_mat))

    # element-wise product (also for numpy matrix input)
    image = axis.imshow(num.multiply(log_magnitude, signs))
    plt.colorbar(image)
    plt.show()


def get_result_point(stage, config, point_llk='max'):
    """
    Return point of a given stage result.

    Parameters
    ----------
    stage : :class:`models.Stage`
    config : :class:`config.BEATConfig`
    point_llk : str
        with specified llk(max, mean, min). The string 'None' or ``None``
        select no point.

    Returns
    -------
    dict or None

    Raises
    ------
    AttributeError
        if Metropolis results are requested without loaded sampler step
    NotImplementedError
        if the sampler of the configuration is not supported
    """
    # accept the Python None as well as the string used in the configs
    if point_llk is None or point_llk == 'None':
        return None

    sampler_name = config.sampler_config.name
    if sampler_name == 'Metropolis':
        if stage.step is None:
            raise AttributeError(
                'Loading Metropolis results requires'
                ' sampler parameters to be loaded!')

        sc = config.sampler_config.parameters
        pdict, _ = get_trace_stats(
            stage.mtrace, stage.step, sc.burn, sc.thin)
        return pdict[point_llk]

    if sampler_name in ('SMC', 'PT'):
        llk = stage.mtrace.get_values(
            varname='like',
            combine=True)

        posterior_idxs = utility.get_fit_indexes(llk)
        return stage.mtrace.point(idx=posterior_idxs[point_llk])

    raise NotImplementedError(
        'Sampler "%s" is not supported!' % sampler_name)


def plot_quadtree(ax, data, target, cmap, colim, alpha=0.8):
    """
    Plot UnwrappedIFG displacements on the respective quadtree rectangle.

    Every quadtree leaf becomes a rectangle (coordinates converted to km)
    colored by the corresponding entry of ``data`` in the symmetric range
    [-colim, colim]. The axis limits are set to the extent of all leaves.

    Returns
    -------
    :class:`matplotlib.collections.PatchCollection`
    """
    quadtree = target.quadtree

    leaf_patches = [
        Rectangle(
            (east / km, north / km),
            width=size_east / km,
            height=size_north / km,
            edgecolor='black')
        for east, north, size_east, size_north in quadtree.iter_leaves()]

    collection = PatchCollection(
        leaf_patches, match_original=True, alpha=alpha, linewidth=0.5)
    collection.set(array=data, cmap=cmap)
    collection.set_clim((-colim, colim))

    easts = quadtree.east_shifts
    norths = quadtree.north_shifts
    x_limits = (easts.min() / km, (easts + quadtree.sizeE).max() / km)
    y_limits = (norths.min() / km, (norths + quadtree.sizeN).max() / km)

    ax.add_collection(collection)
    ax.set_xlim(x_limits)
    ax.set_ylim(y_limits)
    return collection


def plot_scene(ax, target, data, scattersize, colim,
               outmode='latlon', **kwargs):
    """
    Plot displacements ``data`` of a geodetic target into ``ax``.

    Parameters
    ----------
    ax : matplotlib axis
    target : geodetic target
    data : array of displacements, one value per target point / leaf
    scattersize : size of the scatter points
    colim : float
        colors are limited symmetrically to [-colim, colim]
    outmode : str
        'latlon' scatters at longitude/latitude; 'local' draws the quadtree
        leaves (if the target has a quadtree) or scatters at the local
        east/north shifts in km
    **kwargs
        passed to ``ax.scatter``; for the quadtree only ``cmap`` is used

    Returns
    -------
    the matplotlib collection that was drawn

    Raises
    ------
    ValueError
        if ``outmode`` is not supported
    """
    if outmode == 'latlon':
        x = target.lons
        y = target.lats
    elif outmode == 'local':
        if target.quadtree is not None:
            cmap = kwargs.pop('cmap', plt.cm.jet)
            return plot_quadtree(ax, data, target, cmap, colim)

        x = target.east_shifts / km
        y = target.north_shifts / km
    else:
        # previously fell through to an UnboundLocalError on x and y
        raise ValueError(
            'Plot projection "%s" not supported!'
            ' Options: "latlon", "local"' % outmode)

    return ax.scatter(
        x, y, scattersize, data,
        edgecolors='none', vmin=-colim, vmax=colim, **kwargs)


def format_axes(ax, remove=('right', 'top', 'left')):
    """
    Hide spines (frame lines) of an axis.

    Parameters
    ----------
    ax : matplotlib axis
    remove : iterable of str
        names of the spines to hide; by default right, top and left, so
        that only the bottom spine remains visible. An immutable tuple is
        used as default instead of a shared mutable list.
    """
    for spine_name in remove:
        ax.spines[spine_name].set_visible(False)


def scale_axes(axis, scale, offset=0.):
    """
    Label the ticks of ``axis`` with ``offset + value * scale``.

    Labels are printed with one decimal and a space as thousands
    separator; the tick positions themselves are not changed.
    """
    from matplotlib.ticker import ScalarFormatter

    def _tick_label(value, pos):
        scaled = offset + value * scale
        return '{:,.1f}'.format(scaled).replace(',', ' ')

    class _ScaledFormatter(ScalarFormatter):
        __call__ = staticmethod(_tick_label)

    axis.set_major_formatter(_ScaledFormatter())


def set_anchor(sources, anchor):
    """
    Set the ``anchor`` attribute of every source in ``sources`` in place.
    """
    for src in sources:
        setattr(src, 'anchor', anchor)


def geodetic_fits(problem, stage, plot_options):
    """
    Plot geodetic data, synthetics and residuals.

    Each geodetic dataset is drawn as one row of three panels (data, model,
    residual) on its full resolution kite scene, with up to ``ndmax`` rows
    per figure. Datasets whose full resolution scene cannot be loaded are
    skipped.

    Parameters
    ----------
    problem : problem object
        must contain a geodetic composite; its hierarchicals are initialised
    stage : :class:`models.Stage`
        used to get the result point if no reference point is given
    plot_options : :class:`PlotOptions`

    Returns
    -------
    list of matplotlib figures
    """
    from pyrocko.dataset import gshhg
    from kite.scene import Scene, UserIOWarning
    import gc
    import string

    datatype = 'geodetic'
    mode = problem.config.problem_config.mode
    problem.init_hierarchicals()

    fontsize = 10
    fontsize_title = 12
    ndmax = 3  # maximum number of datasets (rows) per figure
    nxmax = 3  # panels per row: data, model, residual
    cmap = plt.cm.jet

    po = plot_options

    composite = problem.composites[datatype]

    try:
        sources = composite.sources
        ref_sources = None
    except AttributeError:
        logger.info('FFI scene fit, using reference source ...')
        ref_sources = composite.config.gf_config.reference_sources
        set_anchor(ref_sources, anchor='top')
        fault = composite.load_fault_geometry()
        sources = fault.get_all_subfaults(
            datatype=datatype, component=composite.slip_varnames[0])
        set_anchor(sources, anchor='top')

    if po.reference:
        if mode != ffi_mode_str:
            composite.point2sources(po.reference)
            ref_sources = copy.deepcopy(composite.sources)
        point = po.reference
    else:
        point = get_result_point(stage, problem.config, po.post_llk)

    dataset_index = dict(
        (data, i) for (i, data) in enumerate(composite.datasets))

    results = composite.assemble_results(point)
    nrmax = len(results)

    dataset_to_result = {}
    for dataset, result in zip(composite.datasets, results):
        dataset_to_result[dataset] = result

    # full figures plus one partially filled figure for the remainder
    fullfig, restfig = utility.mod_i(nrmax, ndmax)
    factors = num.ones(fullfig).tolist()
    if restfig:
        factors.append(float(restfig) / ndmax)

    figures = []
    axes = []
    for f in factors:
        figsize = list(mpl_papersize('a4', 'portrait'))
        figsize[1] *= f

        fig, ax = plt.subplots(
            nrows=int(round(ndmax * f)), ncols=nxmax, figsize=figsize)
        fig.tight_layout()
        fig.subplots_adjust(
            left=0.08,
            right=1.0 - 0.03,
            bottom=0.06,
            top=1.0 - 0.06,
            wspace=0.,
            hspace=0.1)
        figures.append(fig)
        axes.append(ax)

    def axis_config(axes, source, scene, po):
        """Set tick label scaling, labels and aspect of one row of axes."""

        for ax in axes:
            if po.plot_projection == 'latlon':
                ystr = 'Latitude [deg]'
                xstr = 'Longitude [deg]'
                if scene.frame.isDegree():
                    scale_x = {'scale': 1.}
                    scale_y = {'scale': 1.}
                else:
                    scale_x = {'scale': otd.m2d}
                    scale_y = {'scale': otd.m2d}

                scale_x['offset'] = source.lon
                scale_y['offset'] = source.lat

            elif po.plot_projection == 'local':
                ystr = 'Distance [km]'
                xstr = 'Distance [km]'
                if scene.frame.isDegree():
                    scale_x = {'scale': otd.d2m / km}
                    scale_y = {'scale': otd.d2m / km}
                else:
                    scale_x = {'scale': 1. / km}
                    scale_y = {'scale': 1. / km}
            else:
                raise TypeError(
                    'Plot projection %s not available' % po.plot_projection)

            scale_axes(ax.get_xaxis(), **scale_x)
            scale_axes(ax.get_yaxis(), **scale_y)
            ax.set_aspect('equal')

        axes[1].get_yaxis().set_ticklabels([])
        axes[2].get_yaxis().set_ticklabels([])
        axes[1].get_xaxis().set_ticklabels([])
        axes[2].get_xaxis().set_ticklabels([])
        axes[0].set_ylabel(ystr, fontsize=fontsize)
        axes[0].set_xlabel(xstr, fontsize=fontsize)
        ticker = tick.MaxNLocator(nbins=3)

        axes[0].get_xaxis().set_major_locator(ticker)
        axes[0].get_yaxis().set_major_locator(ticker)
        axes[0].tick_params(
            axis='y', labelrotation=90.)

    def draw_coastlines(ax, xlim, ylim, event, scene, po):
        """
        xlim and ylim in Lon/Lat[deg]
        """

        logger.debug('Drawing coastlines ...')
        coasts = gshhg.GSHHG.full()

        if po.plot_projection == 'latlon':
            west, east = xlim
            south, north = ylim

        elif po.plot_projection == 'local':
            # NOTE(review): the caller passes limits already converted to
            # lon/lat degrees (see docstring); converting them again as
            # kilometre offsets here looks inconsistent — confirm.
            lats, lons = otd.ne_to_latlon(
                event.lat, event.lon,
                north_m=num.array(ylim) * km, east_m=num.array(xlim) * km)
            south, north = lats
            west, east = lons

        polygons = coasts.get_polygons_within(
            west=west, east=east, south=south, north=north)

        for p in polygons:
            if (p.is_land() or p.is_antarctic_grounding_line() or
               p.is_island_in_lake()):

                if scene.frame.isMeter():
                    ys, xs = otd.latlon_to_ne_numpy(
                        event.lat, event.lon, p.lats, p.lons)

                elif scene.frame.isDegree():

                    xs = p.lons - event.lon
                    ys = p.lats - event.lat

                ax.plot(xs, ys, '-k', linewidth=0.5)

    def addArrow(ax, scene):
        """Draw azimuth and line-of-sight arrows in axes coordinates."""
        phi = num.nanmean(scene.phi)
        los_dx = num.cos(phi + num.pi) * .0625
        los_dy = num.sin(phi + num.pi) * .0625

        az_dx = num.cos(phi - num.pi / 2) * .125
        az_dy = num.sin(phi - num.pi / 2) * .125

        anchor_x = .9 if los_dx < 0 else .1
        anchor_y = .85 if los_dx < 0 else .975

        az_arrow = FancyArrow(
            x=anchor_x - az_dx, y=anchor_y - az_dy,
            dx=az_dx, dy=az_dy,
            head_width=.025,
            alpha=.5, fc='k',
            head_starts_at_zero=False,
            length_includes_head=True,
            transform=ax.transAxes)

        los_arrow = FancyArrow(
            x=anchor_x - az_dx / 2, y=anchor_y - az_dy / 2,
            dx=los_dx, dy=los_dy,
            head_width=.02,
            alpha=.5, fc='k',
            head_starts_at_zero=False,
            length_includes_head=True,
            transform=ax.transAxes)

        ax.add_artist(az_arrow)
        ax.add_artist(los_arrow)

    def draw_leaves(ax, scene, offset_e=0, offset_n=0):
        """Draw the outlines and centers of the scene's quadtree leaves."""
        rects = scene.quadtree.getMPLRectangles()
        for r in rects:
            r.set_edgecolor((.4, .4, .4))
            r.set_linewidth(.5)
            r.set_facecolor('none')
            r.set_x(r.get_x() - offset_e)
            r.set_y(r.get_y() - offset_n)
            # explicit call: the former ``map(ax.add_artist, rects)`` is
            # lazy in Python 3 and never added any rectangle
            ax.add_artist(r)

        ax.scatter(scene.quadtree.leaf_coordinates[:, 0] - offset_e,
                   scene.quadtree.leaf_coordinates[:, 1] - offset_n,
                   s=.25, c='black', alpha=.1)

    def draw_sources(ax, sources, scene, po, **kwargs):
        """Draw source outlines, or a star marker for point sources."""
        bgcolor = kwargs.pop('color', None)

        for i, source in enumerate(sources):

            if scene.frame.isMeter():
                fn, fe = source.outline(cs='xy').T
            elif scene.frame.isDegree():
                fn, fe = source.outline(cs='latlon').T
                fn -= source.lat
                fe -= source.lon

            if not bgcolor:
                color = mpl_graph_color(i)
            else:
                color = bgcolor

            if fn.size > 1:
                alpha = 0.4
                ax.plot(
                    fe, fn, '-',
                    linewidth=0.5, color=color, alpha=alpha, **kwargs)
                ax.fill(
                    fe, fn,
                    edgecolor=color,
                    facecolor=light(color, .5), alpha=alpha)
                # highlight the first outline segment
                ax.plot(
                    fe[0:2], fn[0:2], '-k', alpha=0.7,
                    linewidth=1.0)
            else:
                # fn and fe are 1d here; 2d indexing raised an IndexError
                ax.plot(
                    fe, fn, marker='*',
                    markersize=10, color=color, **kwargs)

    def mapDisplacementGrid(displacements, scene):
        """Spread per-leaf values onto the full resolution scene grid."""
        arr = num.full_like(scene.displacement, fill_value=num.nan)
        qt = scene.quadtree

        for syn_v, l in zip(displacements, qt.leaves):
            arr[l._slice_rows, l._slice_cols] = syn_v

        arr[scene.displacement_mask] = num.nan
        return arr

    def cbtick(x):
        """Symmetric colorbar ticks, rounded down to three decimals."""
        rx = math.floor(x * 1000.) / 1000.
        return [-rx, rx]

    colims = [num.max([
        num.max(num.abs(r.processed_obs)),
        num.max(num.abs(r.processed_syn))]) for r in results]
    dcolims = [num.max(num.abs(r.processed_res)) for r in results]

    for idata, dataset in enumerate(composite.datasets):
        subplot_letter = string.ascii_lowercase[idata]
        try:
            homepath = problem.config.geodetic_config.datadir
            scene_path = os.path.join(homepath, dataset.name)
            logger.info(
                'Loading full resolution kite scene: %s' % scene_path)
            scene = Scene.load(scene_path)
        except UserIOWarning:
            logger.warning(
                'Full resolution data could not be loaded! Skipping ...')
            continue

        if scene.frame.isMeter():
            offset_n, offset_e = map(float, otd.latlon_to_ne_numpy(
                scene.frame.llLat, scene.frame.llLon,
                sources[0].lat, sources[0].lon))

        elif scene.frame.isDegree():
            offset_n = sources[0].lat - scene.frame.llLat
            offset_e = sources[0].lon - scene.frame.llLon

        im_extent = (scene.frame.E.min() - offset_e,
                     scene.frame.E.max() - offset_e,
                     scene.frame.N.min() - offset_n,
                     scene.frame.N.max() - offset_n)

        urE, urN, llE, llN = (0., 0., 0., 0.)

        turE, turN, tllE, tllN = zip(
            *[(l.gridE.max() - offset_e,
               l.gridN.max() - offset_n,
               l.gridE.min() - offset_e,
               l.gridN.min() - offset_n)
              for l in scene.quadtree.leaves])

        # extent of all leaves, always including the source location (0, 0)
        turE, turN = map(max, (turE, turN))
        tllE, tllN = map(min, (tllE, tllN))
        urE, urN = map(max, ((turE, urE), (urN, turN)))
        llE, llN = map(min, ((tllE, llE), (llN, tllN)))

        lat, lon = otd.ne_to_latlon(
            sources[0].lat, sources[0].lon,
            num.array([llN, urN]), num.array([llE, urE]))

        result = dataset_to_result[dataset]
        tidx = dataset_index[dataset]

        figidx, rowidx = utility.mod_i(tidx, ndmax)
        axs = axes[figidx][rowidx, :]

        imgs = []
        for ax, data_str in zip(axs, ['obs', 'syn', 'res']):
            logger.info('Plotting %s' % data_str)
            datavec = getattr(result, 'processed_%s' % data_str)

            if data_str == 'res' and po.plot_projection == 'local':
                vmin = -dcolims[tidx]
                vmax = dcolims[tidx]
            else:
                vmin = -colims[tidx]
                vmax = colims[tidx]

            imgs.append(ax.imshow(
                mapDisplacementGrid(datavec, scene),
                extent=im_extent, cmap=cmap,
                vmin=vmin, vmax=vmax,
                origin='lower'))

            ax.set_xlim(llE, urE)
            ax.set_ylim(llN, urN)

            draw_leaves(ax, scene, offset_e, offset_n)
            draw_coastlines(
                ax, lon, lat, sources[0], scene, po)

        fontdict = {
            'fontsize': fontsize,
            'fontweight': 'bold',
            'verticalalignment': 'top'}

        transform = axes[figidx][rowidx, 0].transAxes

        if dataset.name[-5::] == 'dscxn':
            title = 'descending'
        elif dataset.name[-5::] == 'ascxn':
            title = 'ascending'
        else:
            title = dataset.name

        axes[figidx][rowidx, 0].text(
            .025, 1.025, '({}) {}'.format(subplot_letter, title),
            fontsize=fontsize_title, alpha=1.,
            va='bottom', transform=transform)
        for i, quantity in enumerate(['data', 'model', 'residual']):
            transform = axes[figidx][rowidx, i].transAxes
            axes[figidx][rowidx, i].text(
                0.5, 0.95, quantity, fontdict, transform=transform,
                horizontalalignment='center')

        draw_sources(
            axes[figidx][rowidx, 1], sources, scene, po)

        if ref_sources:
            ref_color = scolor('aluminium4')
            logger.info('Plotting reference sources')
            draw_sources(
                axes[figidx][rowidx, 1],
                ref_sources, scene, po, color=ref_color)

        # colorbar position depends on how full the figure is
        f = factors[figidx]
        if f > 2. / 3:
            cbb = (0.68 - (0.3075 * rowidx))
        elif f > 1. / 2:
            cbb = (0.53 - (0.47 * rowidx))
        elif f > 1. / 4:
            cbb = (0.06)

        cbl = 0.46
        cbw = 0.15
        cbh = 0.01

        cbaxes = figures[figidx].add_axes([cbl, cbb, cbw, cbh])

        cblabel = 'LOS displacement [m]'
        cbs = plt.colorbar(
            imgs[1],
            ax=axes[figidx][rowidx, 0],
            ticks=cbtick(colims[tidx]),
            cax=cbaxes,
            orientation='horizontal',
            cmap=cmap)
        cbs.set_label(cblabel, fontsize=fontsize)

        if po.plot_projection == 'local':
            dcbaxes = figures[figidx].add_axes([cbl + 0.3, cbb, cbw, cbh])
            cbr = plt.colorbar(
                imgs[2],
                ax=axes[figidx][rowidx, 2],
                ticks=cbtick(dcolims[tidx]),
                cax=dcbaxes,
                orientation='horizontal',
                cmap=cmap)
            cbr.set_label(cblabel, fontsize=fontsize)

        axis_config(axes[figidx][rowidx, :], sources[0], scene, po)
        addArrow(axes[figidx][rowidx, 0], scene)

        # full resolution scenes are large; release them early
        del scene
        gc.collect()

    return figures


def draw_geodetic_fits(problem, plot_options):
    """
    Plot data, model and residuals of the geodetic composite and save them
    to the figure directory (or display them).

    Parameters
    ----------
    problem : object
        Problem with a 'geodetic' composite.
    plot_options : :class:`PlotOptions`
        Uses ``reference``, ``post_llk``, ``load_stage``, ``figure_dir``,
        ``plot_projection``, ``force``, ``outformat`` and ``dpi``.

    Raises
    ------
    TypeError
        If no geodetic composite is defined in the problem.
    """
    if 'geodetic' not in problem.composites.keys():
        raise TypeError('No geodetic composite defined in the problem!')

    po = plot_options

    stage = Stage(homepath=problem.outfolder,
                  backend=problem.config.sampler_config.backend)

    if not po.reference:
        stage.load_results(
            varnames=problem.varnames,
            model=problem.model, stage_number=po.load_stage,
            load='trace', chains=[-1])
        llk_str = po.post_llk
    else:
        llk_str = 'ref'

    mode = problem.config.problem_config.mode

    # outpath is only a file name prefix; the written files carry a suffix
    outpath = os.path.join(
        problem.config.project_dir,
        mode, po.figure_dir, 'scenes_%s_%s_%s' % (
            stage.number, llk_str, po.plot_projection))

    # check the first file that would actually be written, so that existing
    # plots are detected and ``force`` becomes meaningful
    if po.outformat == 'pdf':
        first_file = outpath + '.pdf'
    else:
        first_file = outpath + '_0.%s' % po.outformat

    if po.outformat != 'display' and os.path.exists(first_file) \
            and not po.force:
        logger.info('scene plots exist. Use force=True for replotting!')
        return

    figs = geodetic_fits(problem, stage, po)

    if po.outformat == 'display':
        plt.show()
    else:
        logger.info('saving figures to %s' % outpath)
        if po.outformat == 'pdf':
            with PdfPages(outpath + '.pdf') as opdf:
                for fig in figs:
                    opdf.savefig(fig)
        else:
            for i, fig in enumerate(figs):
                fig.savefig(
                    outpath + '_%i.%s' % (i, po.outformat), dpi=po.dpi)


def plot_trace(axes, tr, **kwargs):
    """
    Draw a trace as a line on the given axes.

    Parameters
    ----------
    axes : matplotlib axes
    tr : trace object providing ``get_xdata`` and ``get_ydata``
    **kwargs
        Passed unchanged to ``axes.plot``.

    Returns
    -------
    Whatever ``axes.plot`` returns.
    """
    times = tr.get_xdata()
    amplitudes = tr.get_ydata()
    return axes.plot(times, amplitudes, **kwargs)


def plot_taper(axes, t, taper, mode='geometry', **kwargs):
    """
    Draw the taper window as a filled, vertically symmetric outline.

    Parameters
    ----------
    axes : matplotlib axes
    t : :class:`numpy.ndarray`
        Sample times; the taper is called with ``t[0]`` and the sampling
        interval ``t[1] - t[0]``.
    taper : callable
        Applied in place to the half-height array, only in 'geometry' mode.
    mode : str
        For any mode other than 'geometry' a flat outline of height 0.9 is
        drawn.
    **kwargs
        Passed unchanged to ``axes.fill``.
    """
    half_height = num.full(t.size, 0.9)
    if mode == 'geometry':
        taper(half_height, t[0], t[1] - t[0])

    # close the polygon by walking back along the mirrored lower edge
    outline_y = num.hstack([half_height, -half_height[::-1]])
    outline_t = num.hstack([t, t[::-1]])
    axes.fill(outline_t, outline_y, **kwargs)


def plot_dtrace(axes, tr, space, mi, ma, **kwargs):
    """
    Draw a (residual) trace as a filled area below the main plot region.

    The trace values are scaled from the interval [mi, ma] to a band of
    height ``space`` that is shifted down by ``1 + space``; drawing is not
    clipped to the axes.

    Parameters
    ----------
    axes : matplotlib axes
    tr : trace object providing ``get_xdata`` and ``get_ydata``
    space : float
        Height of the band.
    mi, ma : float
        Value range mapped onto the band.
    **kwargs
        Passed unchanged to ``axes.fill``.
    """
    times = tr.get_xdata()
    values = tr.get_ydata()

    # close the area against a zero baseline
    closed = num.concatenate((values, num.zeros(values.size)))
    scaled = (closed - mi) / (ma - mi) * space - (1.0 + space)
    polygon_t = num.concatenate((times, times[::-1]))

    axes.fill(polygon_t, scaled, clip_on=False, **kwargs)


def seismic_fits(problem, stage, plot_options):
    """
    Modified from grond. Plot synthetic and data waveforms and the misfit for
    the selected posterior model.

    Parameters
    ----------
    problem : object
        Problem with a 'seismic' composite.
    stage : :class:`beat.models.Stage`
        Stage whose ``mtrace`` is sampled for the synthetic ensemble (only
        if ``plot_options.nensemble > 1``).
    plot_options : :class:`PlotOptions`
        Uses ``reference``, ``post_llk`` and ``nensemble``.

    Returns
    -------
    list of :class:`matplotlib.figure.Figure`
        Per channel group, one figure per block of up to 4 x 4 stations.

    Raises
    ------
    ValueError
        If neither a result/reference point nor an ensemble is available.
    """
    composite = problem.composites['seismic']

    fontsize = 8
    fontsize_title = 10

    target_index = dict(
        (target, i) for (i, target) in enumerate(composite.targets))

    po = plot_options

    if not po.reference:
        best_point = get_result_point(stage, problem.config, po.post_llk)
    else:
        best_point = po.reference

    ens_results = []
    points = []
    if po.nensemble > 1:
        from tqdm import tqdm
        logger.info(
            'Collecting ensemble of %i synthetic waveforms ...' % po.nensemble)
        nchains = len(stage.mtrace)
        csteps = float(nchains) / po.nensemble
        idxs = num.floor(num.arange(0, nchains, csteps)).astype('int32')
        for idx in tqdm(idxs):
            point = stage.mtrace.point(idx=idx)
            points.append(point)
            results = composite.assemble_results(point)
            ens_results.append(results)

    # remember whether a highlighted point exists before any fallback below
    has_best_point = bool(best_point)
    if has_best_point:
        bresults = composite.assemble_results(best_point)
    else:
        if not points:
            raise ValueError(
                'No result point available for plotting! Provide a'
                ' reference point or use nensemble > 1.')
        # get dummy results for data from the last ensemble member
        best_point = points[-1]
        bresults = composite.assemble_results(best_point)

    try:
        composite.point2sources(best_point, input_depth='center')
        source = composite.sources[0]
    except AttributeError:
        logger.info('FFI waveform fit, using reference source ...')
        source = composite.config.gf_config.reference_sources[0]
        source.time = composite.event.time

    logger.info('Plotting waveforms ...')
    target_to_results = {}
    all_syn_trs_target = {}
    dtraces = []
    for target in composite.targets:
        i = target_index[target]
        target_results = [bresults[i]]
        target_synths = [bresults[i].processed_syn]
        dtraces.append(bresults[i].processed_res)

        # put all ensemble results per target here, not only the best one
        for results in ens_results:
            target_results.append(results[i])
            target_synths.append(results[i].processed_syn)

        target_to_results[target] = target_results
        all_syn_trs_target[target] = target_synths

    def skey(tr):
        return tr.channel

    dminmaxs = trace.minmax(dtraces, skey)

    # normalize residual traces per channel to their maximum abs amplitude;
    # all-zero residuals are left untouched instead of becoming NaN
    for tr in dtraces:
        if tr:
            dmin, dmax = dminmaxs[skey(tr)]
            norm = max(abs(dmin), abs(dmax))
            if norm > 0.:
                tr.ydata /= norm

    cg_to_targets = utility.gather(
        composite.targets,
        lambda t: t.codes[3],
        filter=lambda t: t in target_to_results)

    cgs = cg_to_targets.keys()

    figs = []

    for cg in cgs:
        targets = cg_to_targets[cg]

        # distribute the targets over an nx * ny grid of frames according
        # to their position relative to the source
        nframes = len(targets)

        nx = int(math.ceil(math.sqrt(nframes)))
        ny = (nframes - 1) // nx + 1

        nxmax = 4
        nymax = 4

        nxx = (nx - 1) // nxmax + 1
        nyy = (ny - 1) // nymax + 1

        # normalized frame coordinates; this must be true division, floor
        # division collapses neighbouring frames onto the same position
        xs = num.arange(nx) / ((max(2, nx) - 1.0) / 2.)
        ys = num.arange(ny) / ((max(2, ny) - 1.0) / 2.)

        xs -= num.mean(xs)
        ys -= num.mean(ys)

        fxs = num.tile(xs, ny)
        fys = num.repeat(ys, nx)

        data = []

        for target in targets:
            azi = source.azibazi_to(target)[0]
            dist = source.distance_to(target)
            x = dist * num.sin(num.deg2rad(azi))
            y = dist * num.cos(num.deg2rad(azi))
            data.append((x, y, dist))

        gxs, gys, dists = num.array(data, dtype=float).T

        iorder = num.argsort(dists)

        gxs = gxs[iorder]
        gys = gys[iorder]
        targets_sorted = [targets[ii] for ii in iorder]

        gxs -= num.mean(gxs)
        gys -= num.mean(gys)

        gmax = max(num.max(num.abs(gys)), num.max(num.abs(gxs)))
        if gmax == 0.:
            gmax = 1.

        gxs /= gmax
        gys /= gmax

        dists = num.sqrt(
            (fxs[num.newaxis, :] - gxs[:, num.newaxis]) ** 2 +
            (fys[num.newaxis, :] - gys[:, num.newaxis]) ** 2)

        distmax = num.max(dists)

        # greedily assign each target (closest to the source first) to the
        # nearest still free frame
        availmask = num.ones(dists.shape[1], dtype=bool)
        frame_to_target = {}
        for itarget, target in enumerate(targets_sorted):
            iframe = num.argmin(
                num.where(availmask, dists[itarget], distmax + 1.))
            availmask[iframe] = False
            iy, ix = num.unravel_index(iframe, (ny, nx))
            frame_to_target[iy, ix] = target

        figures = {}
        for iy in range(ny):
            for ix in range(nx):
                if (iy, ix) not in frame_to_target:
                    continue

                ixx = ix // nxmax
                iyy = iy // nymax
                if (iyy, ixx) not in figures:
                    figures[iyy, ixx] = plt.figure(
                        figsize=mpl_papersize('a4', 'landscape'))

                    figures[iyy, ixx].subplots_adjust(
                        left=0.03,
                        right=1.0 - 0.03,
                        bottom=0.03,
                        top=1.0 - 0.06,
                        wspace=0.2,
                        hspace=0.2)

                    figs.append(figures[iyy, ixx])

                fig = figures[iyy, ixx]

                target = frame_to_target[iy, ix]

                # min max of all synthetic traces of this target
                key = target.codes[3]
                amin, amax = trace.minmax(
                    all_syn_trs_target[target],
                    key=skey)[key]
                absmax = max(abs(amin), abs(amax))

                ny_this = nymax
                nx_this = nxmax
                i_this = (iy % ny_this) * nx_this + (ix % nx_this) + 1

                axes2 = fig.add_subplot(ny_this, nx_this, i_this)

                space = 0.5
                space_factor = 1.0 + space
                axes2.set_axis_off()
                axes2.set_ylim(-1.05 * space_factor, 1.05)

                axes = axes2.twinx()
                axes.set_axis_off()

                ymin, ymax = - absmax * 1.33 * space_factor, absmax * 1.33
                axes.set_ylim(ymin, ymax)

                itarget = target_index[target]
                result = bresults[itarget]

                traces = all_syn_trs_target[target]

                dtrace = dtraces[itarget]

                if po.nensemble > 1:
                    fuzzy_waveforms(
                        axes, traces, linewidth=7, zorder=0,
                        grid_size=(500, 500), alpha=1.0)

                tap_color_annot = (0.35, 0.35, 0.25)
                tap_color_edge = (0.85, 0.85, 0.80)

                plot_taper(
                    axes2, result.processed_obs.get_xdata(), result.taper,
                    mode=composite._mode, fc='None', ec=tap_color_edge,
                    zorder=4, alpha=0.6)

                obs_color = scolor('aluminium5')
                syn_color = scolor('scarletred2')
                misfit_color = scolor('scarletred2')

                if has_best_point:
                    # only draw if highlighted point exists
                    plot_dtrace(
                        axes2, dtrace, space, 0., 1.,
                        fc=light(misfit_color, 0.3),
                        ec=misfit_color, zorder=4)

                    plot_trace(
                        axes, result.processed_syn,
                        color=syn_color, lw=0.5, zorder=5)

                plot_trace(
                    axes, result.processed_obs,
                    color=obs_color, lw=0.5, zorder=5)

                xdata = result.processed_obs.get_xdata()
                axes.set_xlim(xdata[0], xdata[-1])

                tmarks = [
                    result.processed_obs.tmin,
                    result.processed_obs.tmax]

                for tmark in tmarks:
                    axes2.plot(
                        [tmark, tmark], [-0.9, 0.1], color=tap_color_annot)

                # raw strings: the LaTeX backslashes are not Python escapes
                for tmark, text, ha, va in [
                        (tmarks[0],
                         r'$\,$ ' + str_duration(tmarks[0] - source.time),
                         'left',
                         'bottom'),
                        (tmarks[1],
                         r'$\Delta$ ' + str_duration(tmarks[1] - tmarks[0]),
                         'right',
                         'bottom')]:

                    axes2.annotate(
                        text,
                        xy=(tmark, -0.9),
                        xycoords='data',
                        xytext=(
                            fontsize * 0.4 * [-1, 1][ha == 'left'],
                            fontsize * 0.2),
                        textcoords='offset points',
                        ha=ha,
                        va=va,
                        color=tap_color_annot,
                        fontsize=fontsize, zorder=10)

                infos = []
                infos.append('.'.join(x for x in target.codes if x))
                dist = source.distance_to(target)
                azi = source.azibazi_to(target)[0]
                infos.append(str_dist(dist))
                infos.append('%.0f\u00B0' % azi)
                axes2.annotate(
                    '\n'.join(infos),
                    xy=(0., 1.),
                    xycoords='axes fraction',
                    xytext=(2., 2.),
                    textcoords='offset points',
                    ha='left',
                    va='top',
                    fontsize=fontsize,
                    fontstyle='normal', zorder=10)

                axes2.set_zorder(10)

        for (iyy, ixx), fig in figures.items():
            title = '.'.join(x for x in cg if x)
            if len(figures) > 1:
                title += ' (%i/%i, %i/%i)' % (iyy + 1, nyy, ixx + 1, nxx)

            fig.suptitle(title, fontsize=fontsize_title)

    return figs


def draw_seismic_fits(problem, po):
    """
    Plot waveform fits of the seismic composite and save them to the figure
    directory (or display them).

    Parameters
    ----------
    problem : object
        Problem with a 'seismic' composite.
    po : :class:`PlotOptions`
        Uses ``reference``, ``post_llk``, ``load_stage``, ``nensemble``,
        ``figure_dir``, ``force``, ``outformat`` and ``dpi``.

    Raises
    ------
    TypeError
        If no seismic composite is defined for the problem.
    """
    if 'seismic' not in problem.composites.keys():
        raise TypeError('No seismic composite defined for this problem!')

    logger.info('Drawing Waveform fits ...')

    stage = Stage(homepath=problem.outfolder,
                  backend=problem.config.sampler_config.backend)

    mode = problem.config.problem_config.mode

    if not po.reference:
        llk_str = po.post_llk
        stage.load_results(
            varnames=problem.varnames,
            model=problem.model, stage_number=po.load_stage,
            load='trace', chains=[-1])
    else:
        llk_str = 'ref'

    # outpath is only a file name prefix; the written files carry a suffix
    outpath = os.path.join(
        problem.config.project_dir,
        mode, po.figure_dir, 'waveforms_%s_%s_%i' % (
            stage.number, llk_str, po.nensemble))

    # check the first file that would actually be written, so that existing
    # plots are detected and ``force`` becomes meaningful
    if po.outformat == 'pdf':
        first_file = outpath + '.pdf'
    else:
        first_file = outpath + '_0.%s' % po.outformat

    if po.outformat != 'display' and os.path.exists(first_file) \
            and not po.force:
        logger.info('waveform plots exist. Use force=True for replotting!')
        return

    figs = seismic_fits(problem, stage, po)

    if po.outformat == 'display':
        plt.show()
    else:
        logger.info('saving figures to %s' % outpath)
        if po.outformat == 'pdf':
            with PdfPages(outpath + '.pdf') as opdf:
                for fig in figs:
                    opdf.savefig(fig)
        else:
            for i, fig in enumerate(figs):
                fig.savefig(
                    outpath + '_%i.%s' % (i, po.outformat), dpi=po.dpi)


def point2array(point, varnames):
    """
    Concatenate values of point according to order of given varnames.

    Parameters
    ----------
    point : dict or None
        Mapping of variable name to a single value (python number or
        size-one array).
    varnames : list of str
        Names of the variables to extract, in output order.

    Returns
    -------
    :class:`numpy.ndarray` (float64) with one entry per varname, or None if
    ``point`` is None.

    Raises
    ------
    KeyError
        If a varname is missing in ``point``.
    ValueError
        If a value holds more than one element.
    """
    if point is None:
        return None

    # ``.item()`` avoids the deprecated conversion of a 1-element array into
    # a scalar slot and also accepts plain python numbers
    return num.array(
        [num.asarray(point[varname]).item() for varname in varnames],
        dtype='float64')


def extract_mt_components(problem, po, include_magnitude=False):
    """
    Extract Moment Tensor components from problem results for plotting.

    Parameters
    ----------
    problem : object
        Problem whose ``config.problem_config.source_type`` is 'MTSource'
        or 'DCSource'.
    po : :class:`PlotOptions`
        Uses ``reference``, ``post_llk``, ``load_stage`` and ``nensemble``.
    include_magnitude : bool
        If True, 'magnitude' is appended as the last column.

    Returns
    -------
    m6s : :class:`numpy.ndarray` of shape (n, nvars)
        About ``po.nensemble`` parameter sets drawn evenly from the trace,
        or a single row holding ``po.reference`` for reference plots.
    best_mt : :class:`numpy.ndarray` or None
        Parameters of the result point selected by ``po.post_llk``; None
        for reference plots.
    llk_str : str
        ``po.post_llk`` or 'ref'.

    Raises
    ------
    ValueError
        For source types other than 'MTSource' and 'DCSource'.
    """
    source_type = problem.config.problem_config.source_type
    if source_type == 'MTSource':
        varnames = ['mnn', 'mee', 'mdd', 'mne', 'mnd', 'med']
    elif source_type == 'DCSource':
        varnames = ['strike', 'dip', 'rake']
    else:
        raise ValueError(
            'Plot is only supported for point "MTSource" and "DCSource"')

    if include_magnitude:
        varnames += ['magnitude']

    if not po.reference:
        llk_str = po.post_llk
        stage = load_stage(
            problem, stage_number=po.load_stage, load='trace', chains=[-1])

        n_mts = len(stage.mtrace)
        m6s = num.empty((n_mts, len(varnames)), dtype='float64')
        for i, varname in enumerate(varnames):
            m6s[:, i] = stage.mtrace.get_values(
                varname, combine=True, squeeze=True).ravel()

        # thin the trace to about nensemble samples; guard against 0
        csteps = float(n_mts) / max(1, po.nensemble)
        idxs = num.floor(
            num.arange(0, n_mts, csteps)).astype('int32')
        m6s = m6s[idxs, :]

        point = get_result_point(stage, problem.config, po.post_llk)
        best_mt = point2array(point, varnames=varnames)
    else:
        llk_str = 'ref'
        # 2d (1, nvars) like the ensemble branch, so callers can use
        # array methods such as ``m6s.mean(axis=0)`` in both cases
        m6s = num.atleast_2d(
            point2array(point=po.reference, varnames=varnames))
        best_mt = None

    return m6s, best_mt, llk_str


def draw_fuzzy_beachball(problem, po):
    """
    Draw a fuzzy beachball of the result ensemble of a single-source problem
    and save it to the figure directory (or display it).

    The result point selected by ``po.post_llk`` (if any) is highlighted in
    red. An existing figure is only redrawn with ``po.force`` or when
    displaying.

    Raises
    ------
    NotImplementedError
        If the problem has more than one source.
    """
    n_sources = problem.config.problem_config.n_sources
    if n_sources > 1:
        raise NotImplementedError(
            'Fuzzy beachball is not yet implemented for more than one source!')

    if po.load_stage is None:
        po.load_stage = -1

    m6s, best_mt, llk_str = extract_mt_components(problem, po)

    logger.info('Drawing Fuzzy Beachball ...')

    plot_kwargs = dict(
        beachball_type='full',
        size=8,
        size_units='data',
        position=(5, 5),
        color_t='black',
        edgecolor='black',
        grid_resolution=400)

    fig = plt.figure(figsize=(4., 4.))
    fig.subplots_adjust(left=0., right=1., bottom=0., top=1.)
    axes = fig.add_subplot(1, 1, 1)

    filename = 'fuzzy_beachball_%i_%s_%i.%s' % (
        po.load_stage, llk_str, po.nensemble, po.outformat)
    outpath = os.path.join(problem.outfolder, po.figure_dir, filename)

    display = po.outformat == 'display'
    if os.path.exists(outpath) and not po.force and not display:
        logger.info('Plot already exists! Please use --force to overwrite!')
        return

    beachball.plot_fuzzy_beachball_mpl_pixmap(
        m6s, axes, best_mt=best_mt, best_color='red', **plot_kwargs)

    axes.set_xlim(0., 10.)
    axes.set_ylim(0., 10.)
    axes.set_axis_off()

    if display:
        plt.show()
    else:
        logger.info('saving figure to %s' % outpath)
        fig.savefig(outpath, dpi=po.dpi)


def fuzzy_mt_decomposition(
        axes, list_m6s,
        labels=None, colors=None, fontsize=12):
    """
    Plot fuzzy moment tensor decompositions for list of mt ensembles.

    Every ensemble is drawn on its own row as five fuzzy beachballs (Full,
    Isotropic, Deviatoric, CLVD, DC). Each is scaled by the mean component
    ratio and annotated with the 2.5-97.5 % quantile range of that ratio.
    Components with a mean ratio <= 1e-4 are marked 'N/A'.

    Parameters
    ----------
    axes : matplotlib axes
        Axes to draw into.
    list_m6s : list of 2d :class:`numpy.ndarray`
        One array per ensemble. Each row is passed to
        ``MomentTensor.from_values``; the mean of the last column is used
        as magnitude for relative beachball sizes.
    labels : list of str, optional
        Row labels, defaults to 'Ensemble' for the first row only.
    colors : list, optional
        Beachball color per row; None entries use ``mpl_graph_color``.
    fontsize : int
        Font size of the annotations.
    """
    from pymc3 import quantiles

    logger.info('Drawing Fuzzy MT Decomposition ...')

    # beachball kwargs; position, size and color_t are set per beachball
    kwargs = {
        'beachball_type': 'full',
        'size': 1.,
        'size_units': 'data',
        'edgecolor': 'black',
        'linewidth': 1,
        'grid_resolution': 200}

    def get_decomps(source_vals):
        # components are indexed below as (moment, ratio, moment tensor)
        isos = []
        dcs = []
        clvds = []
        devs = []
        tots = []
        for val in source_vals:
            m = mt.MomentTensor.from_values(val)
            iso, dc, clvd, dev, tot = m.standard_decomposition()
            isos.append(iso)
            dcs.append(dc)
            clvds.append(clvd)
            devs.append(dev)
            tots.append(tot)
        return isos, dcs, clvds, devs, tots

    yscale = 1.3
    nlines = len(list_m6s)
    nlines_max = nlines * yscale

    if colors is None:
        colors = nlines * [None]

    if labels is None:
        labels = ['Ensemble'] + ([None] * (nlines - 1))

    lines = []
    for i, (label, m6s, color) in enumerate(zip(labels, list_m6s, colors)):
        if color is None:
            color = mpl_graph_color(i)

        lines.append((label, m6s, color))

    # largest mean magnitude of all ensembles as moment; reference for the
    # relative beachball sizes
    moments_full_max = mt.magnitude_to_moment(
        max(m6s.mean(axis=0)[-1] for (_, m6s, _) in lines))

    for xpos, header in [
            (0., 'Full'),
            (2., 'Isotropic'),
            (4., 'Deviatoric'),
            (6., 'CLVD'),
            (8., 'DC')]:
        axes.annotate(
            header,
            xy=(1 + xpos, nlines_max),
            xycoords='data',
            xytext=(0., 0.),
            textcoords='offset points',
            ha='center',
            va='center',
            color='black',
            fontsize=fontsize)

    for i, (label, m6s, color_t) in enumerate(lines):
        ypos = nlines_max - (i * yscale) - 1.0

        isos, dcs, clvds, devs, tots = get_decomps(m6s)

        # size scaling is the same for all components of one row
        moments_full = num.array([tot[0] for tot in tots])
        size0 = moments_full.mean() / moments_full_max

        axes.annotate(
            label,
            xy=(-2., ypos),
            xycoords='data',
            xytext=(0., 0.),
            textcoords='offset points',
            ha='left',
            va='center',
            color='black',
            fontsize=fontsize)

        for xpos, decomp, ops in [
                (0., tots, '-'),
                (2., isos, '='),
                (4., devs, '='),
                (6., clvds, '+'),
                (8., dcs, None)]:

            ratios = num.array([comp[1] for comp in decomp])
            ratio = ratios.mean()

            if ratio > 1e-4:
                ratios_diff = ratios.max() - ratios.min()
                ratios_qu = quantiles(ratios * 100.)
                mt_parts = [comp[2] for comp in decomp]
                try:
                    kwargs['position'] = (1. + xpos, ypos)
                    kwargs['size'] = math.sqrt(ratio) * 0.95 * size0
                    kwargs['color_t'] = color_t
                    beachball.plot_fuzzy_beachball_mpl_pixmap(
                        mt_parts, axes, best_mt=None, **kwargs)

                    if ratios_diff > 0.:
                        ratio_label = '{:03.1f}-{:03.1f}%'.format(
                            ratios_qu[2.5], ratios_qu[97.5])
                    else:
                        ratio_label = '{:03.1f}%'.format(ratios_qu[2.5])

                    axes.annotate(
                        ratio_label,
                        xy=(1. + xpos, ypos - 0.65),
                        xycoords='data',
                        xytext=(0., 0.),
                        textcoords='offset points',
                        ha='center',
                        va='center',
                        color='black',
                        fontsize=fontsize - 2)

                except beachball.BeachballError as e:
                    logger.warning(str(e))

                    axes.annotate(
                        'ERROR',
                        xy=(1. + xpos, ypos),
                        ha='center',
                        va='center',
                        color='red',
                        fontsize=fontsize)

            else:
                axes.annotate(
                    'N/A',
                    xy=(1. + xpos, ypos),
                    ha='center',
                    va='center',
                    color='black',
                    fontsize=fontsize)

                axes.annotate(
                    '{:03.1f}%'.format(0.),
                    xy=(1. + xpos, ypos - 0.65),
                    xycoords='data',
                    xytext=(0., 0.),
                    textcoords='offset points',
                    ha='center',
                    va='center',
                    color='black',
                    fontsize=fontsize - 2)

            if ops is not None:
                axes.annotate(
                    ops,
                    xy=(2. + xpos, ypos),
                    ha='center',
                    va='center',
                    color='black',
                    fontsize=fontsize)

    axes.set_xlim(-2.25, 9.75)
    axes.set_ylim(-0.7, nlines_max + 0.5)
    axes.set_axis_off()
    axes.set_axis_off()


def draw_fuzzy_mt_decomposition(problem, po):
    """
    Draw the fuzzy moment tensor decomposition of the result ensemble of a
    single-source problem and save it to the figure directory (or display
    it). An existing figure is only redrawn with ``po.force`` or when
    displaying.

    Raises
    ------
    NotImplementedError
        If the problem has more than one source.
    """
    fontsize = 10

    if problem.config.problem_config.n_sources > 1:
        raise NotImplementedError(
            'Fuzzy MT decomposition is not yet'
            ' implemented for more than one source!')

    if po.load_stage is None:
        po.load_stage = -1

    m6s, _, llk_str = extract_mt_components(
        problem, po, include_magnitude=True)

    outpath = os.path.join(
        problem.outfolder,
        po.figure_dir,
        'fuzzy_mt_decomposition_%i_%s_%i.%s' % (
            po.load_stage, llk_str, po.nensemble, po.outformat))

    if not os.path.exists(outpath) or po.force or po.outformat == 'display':

        fig = plt.figure(figsize=(6., 2.))
        fig.subplots_adjust(left=0., right=1., bottom=0., top=1.)
        axes = fig.add_subplot(1, 1, 1)

        fuzzy_mt_decomposition(axes, list_m6s=[m6s], fontsize=fontsize)

        if not po.outformat == 'display':
            logger.info('saving figure to %s' % outpath)
            fig.savefig(outpath, dpi=po.dpi)
        else:
            plt.show()

    else:
        logger.info('Plot already exists! Please use --force to overwrite!')


def draw_hudson(problem, po):
    """
    Modified from grond. Plot the hudson graph for the reference event (grey)
    and the best solution (red beachball).
    Also a random number of models from the
    selected stage are plotted as smaller beachballs on the hudson graph.

    About 5 % of the ensemble members (random draw) are shown as small
    beachballs, all others as transparent red dots. The event moment tensor
    is only drawn if ``po.reference`` is not set and the event has one. An
    existing figure is only redrawn with ``po.force`` or when displaying.

    Raises
    ------
    NotImplementedError
        If the problem has more than one source.
    """
    from pyrocko.plot import beachball, hudson
    from pyrocko import moment_tensor as mtm
    from numpy import random

    if problem.config.problem_config.n_sources > 1:
        raise NotImplementedError(
            'Hudson plot is not yet implemented for more than one source!')

    if po.load_stage is None:
        po.load_stage = -1

    m6s, best_mt, llk_str = extract_mt_components(problem, po)

    outpath = os.path.join(
        problem.outfolder,
        po.figure_dir,
        'hudson_%i_%s_%i.%s' % (
            po.load_stage, llk_str, po.nensemble, po.outformat))

    # do not draw anything if the result would be neither saved nor shown
    if os.path.exists(outpath) and not po.force \
            and po.outformat != 'display':
        logger.info('Plot already exists! Please use --force to overwrite!')
        return

    logger.info('Drawing Hudson plot ...')

    fontsize = 12
    beachball_type = 'full'
    color = 'red'
    markersize = fontsize * 1.5
    markersize_small = markersize * 0.2
    beachballsize = markersize
    beachballsize_small = beachballsize * 0.5

    fig = plt.figure(figsize=(4., 4.))
    fig.subplots_adjust(left=0., right=1., bottom=0., top=1.)
    axes = fig.add_subplot(1, 1, 1)
    hudson.draw_axes(axes)

    data = []
    for m6 in m6s:
        mt = mtm.as_mt(m6)
        u, v = hudson.project(mt)

        if random.random() < 0.05:
            try:
                beachball.plot_beachball_mpl(
                    mt, axes,
                    beachball_type=beachball_type,
                    position=(u, v),
                    size=beachballsize_small,
                    color_t='black',
                    alpha=0.5,
                    zorder=1,
                    linewidth=0.25)
            except beachball.BeachballError as e:
                logger.warning(str(e))

        else:
            data.append((u, v))

    if data:
        u, v = num.array(data).T
        axes.plot(
            u, v, 'o',
            color=color,
            ms=markersize_small,
            mec='none',
            mew=0,
            alpha=0.25,
            zorder=0)

    if best_mt is not None:
        mt = mtm.as_mt(best_mt)
        u, v = hudson.project(mt)

        try:
            beachball.plot_beachball_mpl(
                mt, axes,
                beachball_type=beachball_type,
                position=(u, v),
                size=beachballsize,
                color_t=color,
                alpha=0.5,
                zorder=2,
                linewidth=0.25)
        except beachball.BeachballError as e:
            logger.warning(str(e))

    if not po.reference:
        # only project the event MT when it is drawn and actually exists
        event_mt = problem.event.moment_tensor
        if event_mt is None:
            logger.info(
                'Reference event has no moment tensor, not drawing it ...')
        else:
            u, v = hudson.project(event_mt)
            try:
                beachball.plot_beachball_mpl(
                    event_mt, axes,
                    beachball_type=beachball_type,
                    position=(u, v),
                    size=beachballsize,
                    color_t='grey',
                    alpha=0.5,
                    zorder=2,
                    linewidth=0.25)
                logger.info('drawing reference event in grey ...')
            except beachball.BeachballError as e:
                logger.warning(str(e))

    if po.outformat == 'display':
        plt.show()
    else:
        logger.info('saving figure to %s' % outpath)
        fig.savefig(outpath, dpi=po.dpi)


def histplot_op(
        ax, data, reference=None, alpha=.35, color=None, bins=None,
        tstd=None, kwargs={}):
    """
    Modified from pymc3. Additional color argument.

    Draw a density-normalized, filled step histogram for every column of
    ``data`` onto ``ax`` and widen the x-limits to cover the data range
    (and ``reference``), padded by ``tstd`` on both sides.

    Parameters
    ----------
    ax : matplotlib axes
    data : 2d :class:`numpy.ndarray`
        Samples in rows; one histogram per column.
    reference : float, optional
        Value that must lie within the x-limits.
    alpha : float
        Transparency of the histogram.
    color : matplotlib color, optional
        Face and edge color of the histogram.
    bins : int or sequence, optional
        Bins for all columns; if None they are derived per column.
    tstd : float, optional
        Padding of the x-limits; if None the std of each column is used.
    kwargs : dict
        Extra arguments for ``ax.hist``; the dict is not modified.
    """
    # work on a copy so neither the shared default nor the caller's dict
    # are mutated
    hist_kwargs = dict(kwargs)
    major, _ = get_matplotlib_version()
    if major < 3:
        hist_kwargs['normed'] = True
    else:
        hist_kwargs['density'] = True

    for i in range(data.shape[1]):
        d = data[:, i]
        mind = d.min()
        maxd = d.max()

        if reference is not None:
            mind = num.minimum(mind, reference)
            maxd = num.maximum(maxd, reference)

        # per-column defaults, so values of one column don't leak into the
        # next one
        col_tstd = num.std(d) if tstd is None else tstd

        if bins is None:
            step = (maxd - mind) / 40.
            if step > 0.:
                col_bins = int(num.ceil((maxd - mind) / step))
            else:
                # constant data: avoid 0 / 0
                col_bins = 40
        else:
            col_bins = bins

        ax.hist(
            d, bins=col_bins, stacked=True, alpha=alpha,
            align='left', histtype='stepfilled', color=color, edgecolor=color,
            **hist_kwargs)

        left, right = ax.get_xlim()
        leftb = mind - col_tstd
        rightb = maxd + col_tstd

        # presumably (0, 1) are untouched default limits - only merge with
        # the current limits otherwise
        if left != 0.0 or right != 1.0:
            leftb = num.minimum(leftb, left)
            rightb = num.maximum(rightb, right)

        ax.set_xlim(leftb, rightb)


def unify_tick_intervals(axs, varnames, ntickmarks_max=5, axis='x'):
    """
    Take figure axes objects and determine unit ranges between common
    unit classes (see utility.unit_sets). Assures that the number of
    increments is not larger than ntickmarks_max. Will thus overwrite the
    minimum range of a unit class with ``max_range / ntickmarks_max``
    whenever the latter is larger.

    Parameters
    ----------
    axs : :class:`numpy.ndarray` of matplotlib axes
        Flattened in Fortran order and paired with ``varnames``.
    varnames : list of str
        Variable shown in each axes.
    ntickmarks_max : int
        Maximum number of increments per axis.
    axis : str
        'x' or 'y'.

    Returns
    -------
    dict : with utility.unit_sets keys and [min_range, max_range] (floats)
        as values

    Raises
    ------
    ValueError
        If ``axis`` is neither 'x' nor 'y'.
    """
    if axis == 'x':
        def get_limits(ax):
            return ax.get_xlim()
    elif axis == 'y':
        def get_limits(ax):
            return ax.get_ylim()
    else:
        raise ValueError('Only "x" or "y" allowed!')

    unities = {
        setname: [num.inf, -num.inf] for setname in utility.unit_sets.keys()}

    for ax, varname in zip(axs.ravel('F'), varnames):
        lower, upper = get_limits(ax)
        # plain float instead of the 1-element array num.diff would return
        varrange = float(upper - lower)
        for setname, tset in utility.unit_sets.items():
            if varname in tset:
                ranges = unities[setname]
                if varrange < ranges[0]:
                    ranges[0] = varrange
                if varrange > ranges[1]:
                    ranges[1] = varrange

    for setname, (min_range, max_range) in unities.items():
        max_range_frac = max_range / ntickmarks_max
        if max_range_frac > min_range:
            logger.debug(
                'Range difference between min and max for %s is large!'
                ' Extending min_range to %f' % (
                    setname, max_range_frac))
            unities[setname] = [max_range_frac, max_range]

    return unities


def apply_unified_axis(axs, varnames, unities, axis='x', ntickmarks_max=3,
                       scale_factor=2 / 3):
    """
    Apply common tick intervals to axes whose variables share a unit class.

    For variables in ``utility.grouped_vars``, the tick increment is derived
    from ``scale_factor`` times the minimum range of their unit class in
    ``unities`` (as returned by ``unify_tick_intervals``). The axis limits
    are snapped to that increment and truncated to the variable's
    ``physical_bounds``. All other axes get a ``MaxNLocator`` with 3 bins.

    Parameters
    ----------
    axs : :class:`numpy.ndarray` of matplotlib axes
        Flattened in Fortran order and paired with ``varnames``.
    varnames : list of str
    unities : dict
        Unit set name -> [min_range, max_range].
    axis : str
        'x' or 'y'.
    ntickmarks_max : int
        Approximate number of ticks for grouped variables.
    scale_factor : float
        Fraction of the minimum range used as tick increment.

    Raises
    ------
    ValueError
        If ``axis`` is neither 'x' nor 'y'.
    """
    if axis not in ('x', 'y'):
        raise ValueError('Only "x" or "y" allowed!')

    for ax, v in zip(axs.ravel('F'), varnames):
        if axis == 'x':
            axis_obj, get_lim, set_lim = ax.xaxis, ax.get_xlim, ax.set_xlim
        else:
            axis_obj, get_lim, set_lim = ax.yaxis, ax.get_ylim, ax.set_ylim

        if v not in utility.grouped_vars:
            axis_obj.set_major_locator(tick.MaxNLocator(nbins=3))
            continue

        for setname, varrange in unities.items():
            if v not in utility.unit_sets[setname]:
                continue

            inc = nice_value(varrange[0] * scale_factor)
            autos = AutoScaler(
                inc=inc, snap='on', approx_ticks=ntickmarks_max)
            vmin, vmax, _ = autos.make_scale(
                get_lim(), override_mode='min-max')

            # truncate to the physical bounds of the variable
            phys_min, phys_max = physical_bounds[v]
            if vmin < phys_min:
                vmin = phys_min
            if vmax > phys_max:
                vmax = phys_max

            set_lim((vmin, vmax))

            ticks = num.arange(vmin, vmax + inc, inc).tolist()
            axis_obj.set_ticks(ticks)


def traceplot(trace, varnames=None, transform=lambda x: x, figsize=None,
              lines=None, chains=None, combined=False, grid=False,
              varbins=None, nbins=40, color=None, source_idxs=None,
              alpha=0.35, priors=None, prior_alpha=1, prior_style='--',
              axs=None, posterior=None, fig=None, plot_style='kde',
              prior_bounds=None, unify=True, kwargs=None):
    """
    Plots posterior pdfs as histograms from multiple mtrace objects.

    Modified from pymc3.

    Parameters
    ----------

    trace : result of MCMC run
    varnames : list of variable names
        Variables to be plotted, if None all variables not ending with '_'
        are plotted. 'geo_like' and 'seis_like' are never plotted. The
        passed list is not modified.
    transform : callable
        Function to transform data (defaults to identity)
    posterior : str
        To mark posterior value in distribution 'max', 'min', 'mean', 'all';
        None or 'None' disables marking
    figsize : figure size tuple
        If None, A6 landscape (< 5 variables), A5 portrait (< 7 variables)
        or A4 portrait is used
    lines : dict
        Dictionary of variable name / value  to be overplotted as vertical
        lines to the posteriors and horizontal lines on sample values
        e.g. mean of posteriors, true values of a simulation
    chains : int or list of ints
        chain indexes to select from the trace
    combined : bool
        Flag for combining multiple chains into a single chain. If False
        (default), chains will be plotted separately.
    source_idxs : list
        array like, indexes to sources to plot marginals; for distributed
        variables every 10th patch is used if None
    grid : bool
        Flag for adding gridlines to histogram. Defaults to False.
    varbins : list of arrays
        List containing the binning arrays for the variables, if None they will
        be created.
    nbins : int
        Number of bins for each histogram
    color : tuple
        mpl color tuple; if None each source gets a color of the
        ``mpl_graph_color`` cycle
    alpha : float
        Alpha value for plot line. Defaults to 0.35.
    priors, prior_alpha, prior_style :
        currently unused
    axs : axes
        Matplotlib axes array of shape (ceil(n / 2), 2). Defaults to None.
    fig : figure
        Matplotlib figure. Defaults to None (the figure of ``axs`` is used
        if only axes are given).
    plot_style : str
        'kde' or 'hist'
    prior_bounds : dict
        variable name -> object with ``lower`` and ``upper`` attributes,
        shown in the axis label
    unify : bool
        If true axis units that belong to one group e.g. [km] will
        have common axis increments
    kwargs : dict
        for histplot op; 'ntickmarks_max' and 'scale_factor' are consumed
        for axis unification. The passed dict is not modified.

    Returns
    -------

    fig : matplotlib figure
    axs : numpy array of matplotlib axes
    varbins : list of arrays

    Raises
    ------
    TypeError
        If ``axs`` does not have shape (ceil(n / 2), 2).
    IndexError
        If ``source_idxs`` references non-existing patches.
    NotImplementedError
        For unknown ``plot_style``.
    """
    fontsize = 10
    lines = {} if lines is None else lines
    prior_bounds = {} if prior_bounds is None else prior_bounds
    # copy, so consuming the unification options does not mutate the
    # caller's dictionary
    kwargs = dict(kwargs) if kwargs else {}
    ntickmarks_max = kwargs.pop('ntickmarks_max', 3)
    scale_factor = kwargs.pop('scale_factor', 2 / 3)

    num.set_printoptions(precision=3)

    def make_bins(data, nbins=40):
        d = data.flatten()
        return num.linspace(d.min(), d.max(), nbins)

    if varnames is None:
        varnames = [name for name in trace.varnames if not name.endswith('_')]
    else:
        # copy, so removing likelihood names does not alter the caller's list
        varnames = list(varnames)

    for like_name in ('geo_like', 'seis_like'):
        if like_name in varnames:
            varnames.remove(like_name)

    show_posterior = posterior is not None and posterior != 'None'
    if show_posterior:
        llk = trace.get_values(
            'like', combine=combined, chains=chains, squeeze=False)
        llk = num.squeeze(transform(llk[0]))
        llk = pmp.utils.make_2d(llk)

        posterior_idxs = utility.get_fit_indexes(llk)

        colors = {
            'mean': scolor('orange1'),
            'min': scolor('butter1'),
            'max': scolor('scarletred2')}

    n = len(varnames)
    nrow = int(num.ceil(n / 2.))
    ncol = 2

    n_fig = nrow * ncol
    if figsize is None:
        if n < 5:
            figsize = mpl_papersize('a6', 'landscape')
        elif n < 7:
            figsize = mpl_papersize('a5', 'portrait')
        else:
            figsize = mpl_papersize('a4', 'portrait')

    if axs is None:
        fig, axs = plt.subplots(nrow, ncol, figsize=figsize)
        axs = num.atleast_2d(axs)
    elif axs.shape != (nrow, ncol):
        raise TypeError('traceplot requires n*2 subplots %i, %i' % (
                        nrow, ncol))

    if fig is None:
        # axes were passed without their figure
        fig = axs[0, 0].get_figure()

    if varbins is None:
        make_bins_flag = True
        varbins = []
    else:
        make_bins_flag = False

    input_color = copy.deepcopy(color)
    for i in range(n_fig):
        coli, rowi = utility.mod_i(i, nrow)

        if i >= n:
            # remove surplus subplot of an odd number of variables
            try:
                fig.delaxes(axs[rowi, coli])
            except KeyError:
                pass
            continue

        v = varnames[i]
        color = copy.deepcopy(input_color)

        for d in trace.get_values(
                v, combine=combined, chains=chains, squeeze=False):
            d = transform(d)
            # iterate over columns in case varsize > 1

            if v in dist_vars:
                if source_idxs is None:
                    logger.info('No patches defined using 1 every 10!')
                    source_idxs = num.arange(0, d.shape[1], 10).tolist()

                logger.info(
                    'Plotting patches: %s' % utility.list2string(
                        source_idxs))

                try:
                    selected = d.T[source_idxs]
                except IndexError:
                    raise IndexError(
                        'One or several patches do not exist! '
                        'Patch idxs: %s' % utility.list2string(
                            source_idxs))
            else:
                selected = d.T

            for isource, e in enumerate(selected):
                e = pmp.utils.make_2d(e)
                if make_bins_flag:
                    varbin = make_bins(e, nbins=nbins)
                    varbins.append(varbin)
                else:
                    varbin = varbins[i]

                reference = lines.get(v)

                if color is None:
                    pcolor = mpl_graph_color(isource)
                else:
                    pcolor = color

                if plot_style == 'kde':
                    # line and shading share the same color
                    pmp.kdeplot(
                        e, shade=alpha, ax=axs[rowi, coli],
                        color=pcolor, linewidth=1.,
                        kwargs_shade={'color': pcolor})
                    axs[rowi, coli].relim()
                    axs[rowi, coli].autoscale(tight=False)
                    axs[rowi, coli].set_ylim(0)
                    xax = axs[rowi, coli].get_xaxis()
                    xticker = tick.MaxNLocator(nbins=5)
                    xax.set_major_locator(xticker)
                elif plot_style == 'hist':
                    histplot_op(
                        axs[rowi, coli], e, reference=reference,
                        bins=varbin, alpha=alpha, color=pcolor,
                        kwargs=kwargs)
                else:
                    raise NotImplementedError(
                        'Plot style "%s" not implemented' % plot_style)

                try:
                    param = prior_bounds[v]

                    if v in dist_vars:
                        title = '{} {}'.format(v, plot_units[hypername(v)])
                    else:
                        lower = num.array2string(
                            param.lower, separator=',')[1:-1]
                        upper = num.array2string(
                            param.upper, separator=',')[1:-1]

                        title = '{} {} priors: ({}; {})'.format(
                            v, plot_units[hypername(v)], lower, upper)
                except KeyError:
                    try:
                        title = '{} {}'.format(v, float(lines[v]))
                    except KeyError:
                        title = '{} {}'.format(v, plot_units[hypername(v)])

                axs[rowi, coli].set_xlabel(title, fontsize=fontsize)
                axs[rowi, coli].grid(grid)
                axs[rowi, coli].set_yticks([])
                axs[rowi, coli].set_yticklabels([])
                format_axes(axs[rowi, coli])
                axs[rowi, coli].tick_params(axis='x', labelsize=fontsize)

                if lines:
                    try:
                        axs[rowi, coli].axvline(
                            x=lines[v], color="k", lw=1.)
                    except KeyError:
                        pass

                if show_posterior:
                    if posterior == 'all':
                        for k, idx in posterior_idxs.items():
                            axs[rowi, coli].axvline(
                                x=e[idx], color=colors[k], lw=1.)
                    else:
                        idx = posterior_idxs[posterior]
                        axs[rowi, coli].axvline(
                            x=e[idx], color=pcolor, lw=1.)

    if unify:
        unities = unify_tick_intervals(
            axs, varnames, ntickmarks_max=ntickmarks_max, axis='x')
        apply_unified_axis(axs, varnames, unities, axis='x',
                           scale_factor=scale_factor)

    if source_idxs:
        axs[0, 0].legend(source_idxs)

    fig.tight_layout()
    return fig, axs, varbins


def get_matplotlib_version(version=None):
    """
    Return the matplotlib version as a tuple of two floats.

    The first entry is the major version, the second combines minor and
    patch level as ``float('<minor>.<patch>')``, e.g. '3.10.1' -> (3.0, 10.1).
    Pre-release / development suffixes (e.g. 'rc1', '.dev0+g1a2b') are
    ignored, and multi-digit major versions are supported.

    Parameters
    ----------
    version : str, optional
        Version string to parse; defaults to the installed
        ``matplotlib.__version__``.

    Returns
    -------
    tuple of (float, float)

    Raises
    ------
    ValueError
        If the version string does not start with a numeric major version.
    """
    import re

    if version is None:
        from matplotlib import __version__ as version

    match = re.match(r'\s*(\d+)(?:\.(\d+))?(?:\.(\d+))?', version)
    if match is None:
        raise ValueError('Cannot parse matplotlib version "%s"' % version)

    major, minor, patch = match.groups()
    minor = minor or '0'
    if patch is not None:
        minor = '%s.%s' % (minor, patch)

    return float(major), float(minor)


def select_transform(sc, n_steps=None):
    """
    Select transform function to be applied after loading the sampling results.

    Parameters
    ----------
    sc : :class:`config.SamplerConfig`
        Sampler configuration; ``sc.name`` selects the transform and, for
        'Metropolis'/'PT', ``sc.parameters.burn`` and ``sc.parameters.thin``
        control burn-in fraction and thinning.
    n_steps : int
        Number of samples per chain used to slice the stacked samples. If
        None, the identity transform is returned.

    Returns
    -------
    func : callable
        Function mapping the loaded sample array to the selected samples.

    Raises
    ------
    NotImplementedError
        If ``n_steps`` is given and ``sc.name`` is none of 'SMC',
        'Metropolis' or 'PT'.
    """

    pa = sc.parameters

    def last_sample(x):
        # every n_steps-th row, starting with the last sample of chain 0
        return x[(n_steps - 1)::n_steps]

    def burn_sample(x):
        if n_steps == 1:
            return x

        nchains = x.shape[0] // n_steps
        xout = []
        for i in range(nchains):
            # skip the burn-in fraction of each chain, then thin
            nstart = int((n_steps * i) + (n_steps * pa.burn))
            nend = int(n_steps * (i + 1) - 1)
            xout.append(x[nstart:nend:pa.thin])

        return num.vstack(xout)

    def standard(x):
        return x

    if n_steps is None:
        return standard

    if sc.name == 'SMC':
        return last_sample
    elif sc.name in ('Metropolis', 'PT'):
        return burn_sample

    # previously None was returned here, failing later as "not callable"
    raise NotImplementedError(
        'No sample transform implemented for sampler "%s"!' % sc.name)


def select_metropolis_chains(problem, mtrace, post_llk):
    """
    Select chains from Multitrace

    The likelihood (trace variable ``problem._like_name``) of the last draw
    of every chain in ``mtrace`` is collected; the entry ``post_llk`` of
    ``utility.get_fit_indexes`` applied to these values is returned.
    """
    last_draw = len(mtrace) - 1

    final_llks = []
    for chain in mtrace.chains:
        point = mtrace.point(last_draw, chain)
        final_llks.append(point[problem._like_name])

    fit_idxs = utility.get_fit_indexes(num.array(final_llks))
    return fit_idxs[post_llk]


def draw_posteriors(problem, plot_options):
    """
    Identify which stage is the last complete stage and plot posteriors.

    For every stage index returned by the stage handler for
    ``plot_options.load_stage`` a traceplot of the selected variables is
    created. Figures are saved to ``<outfolder>/<figure_dir>``, or shown
    interactively if ``plot_options.outformat`` is 'display'. Existing
    figures are only redrawn if ``plot_options.force`` is set.
    """

    hypers = utility.check_hyper_flag(problem)
    po = plot_options

    stage = Stage(homepath=problem.outfolder,
                  backend=problem.config.sampler_config.backend)

    pc = problem.config.problem_config

    list_indexes = stage.handler.get_stage_indexes(po.load_stage)

    if hypers:
        sc = problem.config.hyper_sampler_config
        varnames = problem.hypernames + ['like']
    else:
        sc = problem.config.sampler_config
        varnames = problem.varnames + problem.hypernames + ['like']

    if len(po.varnames) > 0:
        varnames = po.varnames

    logger.info('Plotting variables: %s' % ', '.join(varnames))

    for s in list_indexes:
        if s == 0:
            draws = 1
        elif s == -1 and not hypers and sc.name == 'Metropolis':
            draws = sc.parameters.n_steps * (sc.parameters.n_stages - 1) + 1
        else:
            draws = None

        transform = select_transform(sc=sc, n_steps=draws)

        if po.source_idxs:
            sidxs = utility.list2string(po.source_idxs, fill='_')
        else:
            sidxs = ''

        outpath = os.path.join(
            problem.outfolder,
            po.figure_dir,
            'stage_%i_%s_%s.%s' % (s, sidxs, po.post_llk, po.outformat))

        if os.path.exists(outpath) and not po.force:
            logger.info(
                'plot for stage %s exists. Use force=True for'
                ' replotting!' % s)
            continue

        logger.info('plotting stage: %s' % stage.handler.stage_path(s))
        stage.load_results(
            varnames=problem.varnames,
            model=problem.model, stage_number=s,
            load='trace', chains=[-1])

        if sc.name == 'Metropolis' and po.post_llk != 'all':
            chains = select_metropolis_chains(
                problem, stage.mtrace, po.post_llk)
            logger.info('plotting result: %s of Metropolis chain %i' % (
                po.post_llk, chains))
        else:
            chains = None

        prior_bounds = {}
        prior_bounds.update(**pc.hyperparameters)
        prior_bounds.update(**pc.priors)

        fig, _, _ = traceplot(
            stage.mtrace,
            varnames=varnames,
            transform=transform,
            chains=chains,
            combined=True,
            source_idxs=po.source_idxs,
            plot_style='hist',
            lines=po.reference,
            posterior=po.post_llk,
            prior_bounds=prior_bounds)

        if po.outformat != 'display':
            logger.info('saving figure to %s' % outpath)
            fig.savefig(outpath, format=po.outformat, dpi=po.dpi)

    # previously compared the builtin ``format`` so figures were never shown
    if po.outformat == 'display':
        plt.show()


def draw_correlation_hist(problem, plot_options):
    """
    Draw parameter correlation plot and histograms of the stage selected by
    ``plot_options.load_stage``.
    Only feasible for 'geometry' problem.

    Raises
    ------
    NotImplementedError
        For problems with more than one source.
    ModeError
        If the problem is not in geometry mode.
    ValueError
        If stage 0 is requested.
    TypeError
        If fewer than two variables are selected.
    """

    if problem.config.problem_config.n_sources > 1:
        raise NotImplementedError(
            'correlation_hist plot not working (yet) for n_sources > 1')

    po = plot_options
    mode = problem.config.problem_config.mode

    # explicit checks instead of ``assert`` (which is stripped under -O)
    if mode != geometry_mode_str:
        raise ModeError(
            'Wrong optimization mode: %s! This plot '
            'variant is only valid for "%s" mode' % (mode, geometry_mode_str))

    if po.load_stage == 0:
        raise ValueError(
            'Correlation plot is not available for stage 0!')

    hypers = utility.check_hyper_flag(problem)

    if hypers:
        sc = problem.config.hyper_sampler_config
        varnames = problem.hypernames
    else:
        sc = problem.config.sampler_config
        varnames = list(problem.varnames) + problem.hypernames + ['like']

    if len(po.varnames) > 0:
        varnames = po.varnames

    logger.info('Plotting variables: %s' % ', '.join(varnames))

    if len(varnames) < 2:
        raise TypeError('Need at least two parameters to compare! '
                        'Found only %i variables! ' % len(varnames))

    if po.load_stage is None and not hypers and sc.name == 'Metropolis':
        draws = sc.parameters.n_steps * (sc.parameters.n_stages - 1) + 1
    else:
        draws = None

    transform = select_transform(sc=sc, n_steps=draws)

    stage = load_stage(
        problem, stage_number=po.load_stage, load='trace', chains=[-1])

    if sc.name == 'Metropolis' and po.post_llk != 'all':
        chains = select_metropolis_chains(problem, stage.mtrace, po.post_llk)
        logger.info('plotting result: %s of Metropolis chain %i' % (
            po.post_llk, chains))
    else:
        chains = None

    if not po.reference:
        reference = get_result_point(stage, problem.config, po.post_llk)
        llk_str = po.post_llk
    else:
        reference = po.reference
        llk_str = 'ref'

    outpath = os.path.join(
        problem.outfolder, po.figure_dir, 'corr_hist_%s_%s.%s' % (
            stage.number, llk_str, po.outformat))

    if os.path.exists(outpath) and not po.force:
        logger.info('correlation plot exists. Use force=True for replotting!')
        return

    fig, axs = correlation_plot_hist(
        mtrace=stage.mtrace,
        varnames=varnames,
        transform=transform,
        cmap=plt.cm.gist_earth_r,
        chains=chains,
        point=reference,
        point_size=6,
        point_color='red')

    if po.outformat == 'display':
        plt.show()
    else:
        logger.info('saving figure to %s' % outpath)
        fig.savefig(outpath, format=po.outformat, dpi=po.dpi)


def n_model_plot(models, axes=None, draw_bg=True, highlightidx=()):
    """
    Plot cake layered earth models.

    All models are drawn as thin, lightened vp and vs profiles over depth.
    Models selected by ``highlightidx`` are drawn again with thick lines;
    the one with the lowest index uses the reference color 'aluminium5'.

    Parameters
    ----------
    models : list
        Earth models providing ``profile('z'|'vp'|'vs')``; ``models[0]`` is
        sketched as background if ``draw_bg``.
    axes : matplotlib axes, optional
        Axes to draw into; if None a new A6 portrait figure is created.
    draw_bg : bool
        Sketch the layering of ``models[0]``; otherwise hide the right and
        top spines.
    highlightidx : sequence of int
        Indexes into ``models`` of the profiles to emphasise.

    Returns
    -------
    fig : matplotlib figure
        Figure containing ``axes``.
    axes : matplotlib axes
    """
    fontsize = 10
    if axes is None:
        mpl_init(fontsize=fontsize)
        fig, axes = plt.subplots(
            nrows=1, ncols=1, figsize=mpl_papersize('a6', 'portrait'))
        labelpos = mpl_margins(
            fig, left=6, bottom=4, top=1.5, right=0.5, units=fontsize)
        labelpos(axes, 2., 1.5)
    else:
        # previously ``fig`` was unbound when axes were passed in
        fig = axes.get_figure()

    def plot_profile(mod, axes, vp_c, vs_c, lw=0.5):
        z = mod.profile('z')
        vp = mod.profile('vp')
        vs = mod.profile('vs')
        axes.plot(vp, z, color=vp_c, lw=lw)
        axes.plot(vs, z, color=vs_c, lw=lw)

    cp.labelspace(axes)
    cp.labels_model(axes=axes)
    if draw_bg:
        cp.sketch_model(models[0], axes=axes)
    else:
        axes.spines['right'].set_visible(False)
        axes.spines['top'].set_visible(False)

    ref_vp_c = scolor('aluminium5')
    ref_vs_c = scolor('aluminium5')
    vp_c = scolor('scarletred2')
    vs_c = scolor('skyblue2')

    for mod in models:
        plot_profile(
            mod, axes, vp_c=light(vp_c, 0.3), vs_c=light(vs_c, 0.3), lw=1.)

    for count, i in enumerate(sorted(highlightidx)):
        if count == 0:
            vpcolor = ref_vp_c
            vscolor = ref_vs_c
        else:
            vpcolor = vp_c
            vscolor = vs_c

        plot_profile(
            models[i], axes, vp_c=vpcolor, vs_c=vscolor, lw=2.)

    ymin, ymax = axes.get_ylim()
    xmin, xmax = axes.get_xlim()
    xmin = 0.
    my = (ymax - ymin) * 0.05
    mx = (xmax - xmin) * 0.2
    # depth increases downwards
    axes.set_ylim(ymax, ymin - my)
    axes.set_xlim(xmin, xmax + mx)
    return fig, axes


def load_earthmodels(store_superdir, targets, depth_max='cmb'):
    """
    Load the 1d earth models from the store configs of the given targets.

    Parameters
    ----------
    store_superdir : str
        Directory containing the stores named by ``target.store_id``.
    targets : list
        Objects with a ``store_id`` attribute.
    depth_max : str or float
        Passed to ``extract`` of each source-side earth model.

    Returns
    -------
    list
        ``[source_models, receiver_models]``; receiver models are only
        collected from configs whose ``earthmodel_receiver_1d`` is set.
    """
    source_models = []
    receiver_models = []
    for target in targets:
        config_path = os.path.join(store_superdir, target.store_id, 'config')
        store_config = load(filename=config_path)
        source_models.append(
            store_config.earthmodel_1d.extract(depth_max=depth_max))

        receiver_model = store_config.earthmodel_receiver_1d
        if receiver_model is not None:
            receiver_models.append(receiver_model)

    return [source_models, receiver_models]


def draw_earthmodels(problem, plot_options):
    """
    Plot the velocity models of the Green's function stores of a problem.

    Seismic composites: one figure per station (or only for the reference
    location, if configured) for the source-side and, if available,
    receiver-side models. Geodetic composites: one figure of the
    source-side models. Figures are saved to ``<outfolder>/<figure_dir>``
    or shown if ``plot_options.outformat`` is 'display'; existing figures
    are skipped unless ``plot_options.force`` is set.

    Raises
    ------
    TypeError
        For datatypes other than 'seismic' and 'geodetic'.
    """
    po = plot_options

    for datatype, composite in problem.composites.items():

        models_dict = {}

        if datatype == 'seismic':
            sc = problem.config.seismic_config

            if sc.gf_config.reference_location is None:
                plot_stations = composite.datahandler.stations
            else:
                # rename a copy; the data handler's station must not change
                reference_station = copy.deepcopy(
                    composite.datahandler.stations[0])
                reference_station.station = \
                    sc.gf_config.reference_location.station
                plot_stations = [reference_station]

            for station in plot_stations:
                outbasepath = os.path.join(
                    problem.outfolder, po.figure_dir,
                    '%s_%s_velocity_model' % (
                        datatype, station.station))
                # files are written with site suffix and extension, so the
                # existence check has to use the full file name
                source_outpath = outbasepath + '_%s.%s' % (
                    'source', po.outformat)

                if os.path.exists(source_outpath) and not po.force:
                    logger.info(
                        '%s earthmodel plot for station %s exists. Use '
                        'force=True for replotting!' % (
                            datatype, station.station))
                    continue

                targets = init_seismic_targets(
                    [station],
                    earth_model_name=sc.gf_config.earth_model_name,
                    channels=sc.get_unique_channels()[0],
                    sample_rate=sc.gf_config.sample_rate,
                    crust_inds=list(range(*sc.gf_config.n_variations)),
                    interpolation='multilinear')

                models = load_earthmodels(
                    composite.engine.store_superdirs[0], targets,
                    depth_max=sc.gf_config.depth_limit_variation * km)

                for i, mods in enumerate(models):
                    site = 'source' if i == 0 else 'receiver'
                    outpath = outbasepath + '_%s.%s' % (site, po.outformat)
                    models_dict[outpath] = mods

        elif datatype == 'geodetic':
            gc = problem.config.geodetic_config

            outpath = os.path.join(
                problem.outfolder, po.figure_dir,
                '%s_%s_velocity_model.%s' % (
                    datatype, 'psgrn', po.outformat))

            if os.path.exists(outpath) and not po.force:
                logger.info(
                    '%s earthmodel plot exists. Use force=True for'
                    ' replotting!' % datatype)
                # continue with the other datatypes instead of returning
                continue

            targets = init_geodetic_targets(
                datasets=composite.datasets,
                earth_model_name=gc.gf_config.earth_model_name,
                interpolation='multilinear',
                crust_inds=list(range(*gc.gf_config.n_variations)),
                sample_rate=gc.gf_config.sample_rate)

            models = load_earthmodels(
                store_superdir=composite.engine.store_superdirs[0],
                targets=targets,
                depth_max=gc.gf_config.source_depth_max * km)
            models_dict[outpath] = models[0]  # select only source site

        else:
            raise TypeError(
                'Plot for datatype %s not (yet) supported' % datatype)

        figures = []
        for outpath, models in models_dict.items():
            if len(models) == 0:
                # e.g. stores without receiver-side models
                continue

            fig, _ = n_model_plot(
                models, axes=None,
                draw_bg=po.reference, highlightidx=[0])
            figures.append((outpath, fig))

        if po.outformat == 'display':
            plt.show()
        else:
            for outpath, fig in figures:
                logger.info('saving figure to %s' % outpath)
                fig.savefig(outpath, format=po.outformat, dpi=po.dpi)


def fuzzy_waveforms(
        ax, traces, linewidth, zorder=0, extent=None,
        grid_size=(500, 500), cmap=None, alpha=0.6):
    """
    Draw an ensemble of waveforms as a density image ("fuzzy waveforms").

    ax : matplotlib axes
        axes to draw the image into
    traces : list
        of class:`pyrocko.trace.Trace`, the times of the traces should not
        vary too much
    linewidth :
        line width passed to ``draw_line_on_array`` for every trace
    zorder : int
        the higher number is drawn above the lower number
    extent : list
        of [xmin, xmax, ymin, ymax] (tmin, tmax, min/max of amplitudes)
        if None, it is determined from the traces of the channel of the
        first trace, symmetric in amplitude around zero
    grid_size : tuple
        shape of the image grid
    cmap : colormap
        if None, a white -> chocolate -> scarlet red colormap is used
    alpha : float
        transparency of the image
    """

    if cmap is None:
        from matplotlib.colors import LinearSegmentedColormap

        cmap = LinearSegmentedColormap.from_list(
            'dummy',
            ['white', scolor('chocolate2'), scolor('scarletred2')],
            N=256)

    if extent is None:
        def channel_key(tr):
            return tr.channel

        channel = traces[0].channel
        amp_min, amp_max = trace.minmax(traces, key=channel_key)[channel]
        tmin, tmax = trace.minmaxtime(traces, key=channel_key)[channel]

        amp_abs_max = max(abs(amp_min), abs(amp_max))
        extent = [tmin, tmax, -amp_abs_max, amp_abs_max]

    grid = num.zeros(grid_size, dtype='float64')

    for tr in traces:
        draw_line_on_array(
            tr.get_xdata(), tr.ydata,
            grid=grid,
            extent=extent,
            grid_resolution=grid.shape,
            linewidth=linewidth)

    ax.imshow(
        grid, extent=extent, origin='lower', cmap=cmap, aspect='auto',
        alpha=alpha, zorder=zorder)


def fuzzy_rupture_fronts(
        ax, rupture_fronts, xgrid, ygrid, alpha=0.6, linewidth=7, zorder=0,
        res_km=25):
    """
    Fuzzy rupture fronts

    Rasterizes all contour lines of all rupture fronts into one image and
    draws it with a white-to-black colormap; pixel values above half the
    number of rupture fronts are clipped to increase contrast.

    rupture_fronts : list
        of output of cs = pyplot.contour; cs.allsegs
    xgrid : array_like
        of center coordinates of the sub-patches of the fault in
        strike-direction in [km]
    ygrid : array_like
        of center coordinates of the sub-patches of the fault in
        dip-direction in [km]
    res_km : int
        image resolution in pixels per km (default 25)
    """

    from matplotlib.colors import LinearSegmentedColormap

    cmap = LinearSegmentedColormap.from_list(
        'dummy', ['white', 'black'], N=256)

    xmin, xmax = xgrid.min(), xgrid.max()
    ymin, ymax = ygrid.min(), ygrid.max()
    extent = (xmin, xmax, ymin, ymax)

    # image size follows the spatial extent; the plain difference (instead
    # of the difference of absolute values) stays correct for negative
    # coordinates
    grid = num.zeros(
        (int((ymax - ymin) * res_km),
         int((xmax - xmin) * res_km)),
        dtype='float64')

    for rupture_front in rupture_fronts:
        for level in rupture_front:
            for line in level:
                draw_line_on_array(
                    line[:, 0], line[:, 1],
                    grid=grid,
                    extent=extent,
                    grid_resolution=grid.shape,
                    linewidth=linewidth)

    # increase contrast reduce high intense values
    truncate = len(rupture_fronts) / 2
    grid[grid > truncate] = truncate
    ax.imshow(
        grid, extent=extent, origin='lower', cmap=cmap, aspect='auto',
        alpha=alpha, zorder=zorder)


def fault_slip_distribution(
        fault, mtrace=None, transform=lambda x: x, alpha=0.9, ntickmarks=5,
        reference=None, nensemble=1):
    """
    Draw discretized fault geometry rotated to the 2-d view of the foot-wall
    of the fault.

    One figure per subfault shows the reference slip (colored patches and
    slip vectors). If ``mtrace`` is given, the posterior mean slip vectors
    and 2-sigma error ellipses are added. For faults with 'seismic'
    datatype an extra figure with the reference rupture durations is
    created and the reference rupture-front contours and nucleation point
    are drawn; with ``mtrace`` an image of ``nensemble`` rupture fronts of
    the posterior ensemble is added as well.

    Parameters
    ----------
    fault : :class:`ffi.fault.FaultGeometry`
    mtrace : MultiTrace, optional
        Sampling result providing the posterior ensemble.
    transform : callable
        Applied to the sample arrays loaded from ``mtrace``.
    alpha : float
        Transparency of the duration patches.
    ntickmarks : int
        Maximum number of tick intervals per axis.
    reference : dict
        Point (variable name -> value) to plot; requires 'uperp' and
        'uparr', and for seismic faults also 'durations',
        'nucleation_strike' and 'nucleation_dip'.
    nensemble : int
        Number of posterior samples rendered as rupture fronts.

    Returns
    -------
    figs : list of matplotlib figures
    axs : list of matplotlib axes

    Raises
    ------
    ValueError
        If no ``reference`` is given.

    TODO: 0,0 is now ll of fault at depth, need to turn around axis that
        origin is top-left
    """

    def draw_quivers(
            ax, uperp, uparr, xgr, ygr, rake, color='black',
            draw_legend=False, normalisation=None, zorder=0):
        """
        Draw slip vectors at the patch centres.

        The arrow angle is atan2(-uperp, uparr) [deg] + rake. Arrow lengths
        are divided by ``normalisation``, which by default scales the
        largest slip to 2/3 of the patch spacing in dip direction.
        ``draw_legend`` is accepted for compatibility; no legend is drawn.

        Returns the quiver object and the normalisation used.
        """
        angles = num.arctan2(-uperp, uparr) * \
            (180. / num.pi) + rake

        slips = num.sqrt((uperp ** 2 + uparr ** 2)).ravel()

        if normalisation is None:
            normalisation = slips.max() / num.abs(
                ygr[1, 0] - ygr[0, 0]) * (3. / 2.)

        slips /= normalisation

        slipsx = num.cos(angles * num.pi / 180.) * slips
        slipsy = num.sin(angles * num.pi / 180.) * slips

        # slip arrows of slip on patches
        quivers = ax.quiver(
            xgr.ravel(), ygr.ravel(), slipsx, slipsy,
            units='dots', angles='xy', scale_units='xy', scale=1,
            width=1., color=color, zorder=zorder)

        return quivers, normalisation

    def draw_patches(ax, fault, subfault_idx, patch_values, cmap, alpha):
        """
        Add one rectangle per patch of subfault ``subfault_idx``, colored by
        ``patch_values``; set limits, labels and tick locators of ``ax``.
        Returns the PatchCollection.
        """
        i = subfault_idx
        np_h, np_w = fault.get_subfault_discretization(i)
        height = fault.ordering.patch_sizes_dip[i]
        width = fault.ordering.patch_sizes_strike[i]

        d_patches = []
        lls = []
        for patch_dip_ll in range(np_h, 0, -1):
            for patch_strike_ll in range(np_w):
                ll = [patch_strike_ll * width, patch_dip_ll * height - height]
                d_patches.append(
                    Rectangle(
                        ll, width=width, height=height, edgecolor='black'))
                lls.append(ll)

        llsa = num.vstack(lls)
        lower = llsa.min(axis=0)
        upper = llsa.max(axis=0)
        xlim = [lower[0], upper[0] + width]
        ylim = [lower[1], upper[1] + height]

        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)

        scale_y = {'scale': 1, 'offset': -ylim[1]}
        scale_axes(ax.yaxis, **scale_y)

        ax.set_xlabel('strike-direction [km]', fontsize=fontsize)
        ax.set_ylabel('dip-direction [km]', fontsize=fontsize)

        xticker = tick.MaxNLocator(nbins=ntickmarks)
        yticker = tick.MaxNLocator(nbins=ntickmarks)

        ax.get_xaxis().set_major_locator(xticker)
        ax.get_yaxis().set_major_locator(yticker)

        pa_col = PatchCollection(
            d_patches, alpha=alpha, match_original=True, zorder=0)
        pa_col.set(array=patch_values, cmap=cmap)

        ax.add_collection(pa_col)
        return pa_col

    def draw_colorbar(fig, ax, cb_related, labeltext):
        cbaxes = fig.add_axes([0.88, 0.4, 0.03, 0.3])
        # ``ax`` of this figure, not the list of previously created axes
        cb = fig.colorbar(cb_related, ax=ax, cax=cbaxes)
        cb.set_label(labeltext, fontsize=fontsize)
        ax.set_aspect('equal', adjustable='box')

    def get_values_from_trace(mtrace, varname, reference):
        try:
            u = transform(
                mtrace.get_values(
                    varname, combine=True, squeeze=True))
        except(ValueError, KeyError):
            u = num.atleast_2d(reference[varname])
        return u

    from beat.colormap import slip_colormap
    fontsize = 12

    if reference is None:
        raise ValueError(
            'fault_slip_distribution requires a reference point!')

    reference_slip = num.sqrt(
        reference['uperp'] ** 2 + reference['uparr'] ** 2)

    figs = []
    axs = []
    for ns in range(fault.nsubfaults):
        fig, ax = plt.subplots(
            nrows=1, ncols=1, figsize=mpl_papersize('a5', 'landscape'))

        np_h, np_w = fault.get_subfault_discretization(ns)
        ext_source = fault.get_subfault(ns)
        patch_idxs = fault.get_patch_indexes(ns)

        # color only with the values of the patches of this subfault
        pa_col = draw_patches(
            ax, fault, subfault_idx=ns,
            patch_values=reference_slip[patch_idxs],
            cmap=slip_colormap(100), alpha=0.65)

        # patch central locations
        hpd = fault.ordering.patch_sizes_dip[ns] / 2.
        hps = fault.ordering.patch_sizes_strike[ns] / 2.

        xvec = num.linspace(hps, ext_source.length / km - hps, np_w)
        yvec = num.linspace(ext_source.width / km - hpd, hpd, np_h)

        xgr, ygr = num.meshgrid(xvec, yvec)

        if 'seismic' in fault.datatypes:
            if mtrace is not None:
                from tqdm import tqdm
                nuc_dip = transform(mtrace.get_values(
                    'nucleation_dip', combine=True, squeeze=True))
                nuc_strike = transform(mtrace.get_values(
                    'nucleation_strike', combine=True, squeeze=True))
                velocities = transform(mtrace.get_values(
                    'velocities', combine=True, squeeze=True))

                nchains = nuc_dip.size
                rupture_fronts = []
                # contours are only computed on this throw-away figure to
                # obtain their line segments
                dummy_fig, dummy_ax = plt.subplots(
                    nrows=1, ncols=1, figsize=mpl_papersize('a5', 'landscape'))
                csteps = float(nchains) / nensemble
                idxs = num.floor(
                    num.arange(0, nchains, csteps)).astype('int32')
                logger.info('Rendering rupture fronts ...')
                for i in tqdm(idxs):
                    # NOTE(review): subfault index 0 is hard-coded here
                    # instead of ``ns``; confirm for multi-subfault faults
                    nuc_dip_idx, nuc_strike_idx = fault.fault_locations2idxs(
                        0, nuc_dip[i], nuc_strike[i], backend='numpy')
                    sts = fault.get_subfault_starttimes(
                        0, velocities[i, :], nuc_dip_idx, nuc_strike_idx)

                    contours = dummy_ax.contour(xgr, ygr, sts)
                    rupture_fronts.append(contours.allsegs)

                plt.close(dummy_fig)

                fuzzy_rupture_fronts(
                    ax, rupture_fronts, xgr, ygr,
                    alpha=1., linewidth=7, zorder=-1)

            # rupture durations
            fig2, ax2 = plt.subplots(
                nrows=1, ncols=1, figsize=mpl_papersize('a5', 'landscape'))

            reference_durations = num.asarray(
                reference['durations'])[patch_idxs]

            pa_col2 = draw_patches(
                ax2, fault, subfault_idx=ns, patch_values=reference_durations,
                cmap=plt.cm.seismic, alpha=alpha)

            draw_colorbar(fig2, ax2, pa_col2, labeltext='durations [s]')
            figs.append(fig2)
            axs.append(ax2)

            ref_starttimes = fault.point2starttimes(reference)
            contours = ax.contour(
                xgr, ygr, ref_starttimes,
                colors='black', linewidths=0.5, alpha=0.9)
            ax.plot(
                reference['nucleation_strike'],
                reference['nucleation_dip'],
                marker='*', color='k', markersize=12)
            ax.clabel(contours, inline=True, fontsize=10)

        if mtrace is not None:
            logger.info('Drawing quantiles ...')

            uparr = get_values_from_trace(
                mtrace, 'uparr', reference)[:, patch_idxs]
            uperp = get_values_from_trace(
                mtrace, 'uperp', reference)[:, patch_idxs]

            uparrmean = uparr.mean(axis=0)
            uperpmean = uperp.mean(axis=0)

            quivers, normalisation = draw_quivers(
                ax, uperpmean, uparrmean, xgr, ygr,
                ext_source.rake, color='grey',
                draw_legend=False)

            uparrstd = uparr.std(axis=0) / normalisation
            uperpstd = uperp.std(axis=0) / normalisation

            slipvecrotmat = mt.euler_to_matrix(
                0.0, 0.0, ext_source.rake * mt.d2r)

            circle = num.linspace(0, 2 * num.pi, 100)
            # 2sigma error ellipses around the tips of the mean slip vectors
            for i, (upe, upa) in enumerate(zip(uperpstd, uparrstd)):
                ellipse_x = 2 * upa * num.cos(circle)
                ellipse_y = 2 * upe * num.sin(circle)
                ellipse = num.vstack(
                    [ellipse_x, ellipse_y, num.zeros_like(ellipse_x)]).T
                rot_ellipse = ellipse.dot(slipvecrotmat)

                xcoords = xgr.ravel()[i] + rot_ellipse[:, 0] + quivers.U[i]
                ycoords = ygr.ravel()[i] + rot_ellipse[:, 1] + quivers.V[i]
                ax.plot(xcoords, ycoords, '-k', linewidth=0.5, zorder=2)
        else:
            normalisation = None

        logger.info('Drawing slip vectors ...')
        draw_quivers(
            ax, reference['uperp'][patch_idxs], reference['uparr'][patch_idxs],
            xgr, ygr, ext_source.rake, color='black', draw_legend=True,
            normalisation=normalisation, zorder=3)

        draw_colorbar(fig, ax, pa_col, labeltext='slip [m]')

        fig.tight_layout()
        figs.append(fig)
        axs.append(ax)

    return figs, axs


class ModeError(Exception):
    """
    Error for plots requested in an optimization mode they do not support,
    e.g. a finite-fault (ffi) plot for a problem in another mode.
    """


def draw_slip_dist(problem, po):
    """
    Plot the finite-fault slip distribution of an optimization result.

    Slip is taken from the result point of the requested stage (merged into
    the test point of the problem) or, if given, from ``po.reference``. The
    figures are either displayed or written to the figure directory, as one
    multi-page pdf or one file per figure.

    Parameters
    ----------
    problem : beat problem object
        provides config, composites and outfolder
    po : :class:`PlotOptions`
        plot options (stage, reference, output format, dpi, ...)

    Raises
    ------
    ModeError
        if the problem is not in finite-fault (ffi) mode
    """
    mode = problem.config.problem_config.mode
    if mode != ffi_mode_str:
        raise ModeError(
            'Wrong optimization mode: %s! This plot '
            'variant is only valid for "%s" mode' % (mode, ffi_mode_str))

    # the fault geometry is loaded from the first composite
    first_composite = list(problem.composites.values())[0]
    fault = first_composite.load_fault_geometry()

    sampler_cfg = problem.config.sampler_config
    metropolis_final = (
        po.load_stage is None and sampler_cfg.name == 'Metropolis')
    if metropolis_final:
        params = sampler_cfg.parameters
        draws = params.n_steps * (params.n_stages - 1) + 1
    else:
        draws = None

    transform = select_transform(sc=sampler_cfg, n_steps=draws)

    stage = load_stage(
        problem, stage_number=po.load_stage, load='trace', chains=[-1])

    if po.reference:
        reference = po.reference
        llk_str = 'ref'
        mtrace = None
    else:
        reference = problem.config.problem_config.get_test_point()
        reference.update(
            get_result_point(stage, problem.config, po.post_llk))
        llk_str = po.post_llk
        mtrace = stage.mtrace

    figs, axs = fault_slip_distribution(
        fault, mtrace, transform=transform,
        reference=reference, nensemble=po.nensemble)

    if po.outformat == 'display':
        plt.show()
        return

    outpath = os.path.join(
        problem.outfolder, po.figure_dir,
        'slip_dist_%i_%s_%i' % (stage.number, llk_str, po.nensemble))
    logger.info('Storing slip-distribution to: %s' % outpath)

    if po.outformat == 'pdf':
        with PdfPages(outpath + '.pdf') as opdf:
            for fig in figs:
                opdf.savefig(fig, dpi=po.dpi)
    else:
        for fig_number, fig in enumerate(figs):
            fig.savefig(
                outpath + '_%i.%s' % (fig_number, po.outformat), dpi=po.dpi)


def _weighted_line(r0, c0, r1, c1, w, rmin=0, rmax=num.inf):
    """
    Draw weighted lines into array
    Modiefied from:
    https://stackoverflow.com/questions/31638651/how-can-i-draw-lines-into-numpy-arrays

    Parameters
    ----------
    r0 : int
        row index for line end point 0
    c0 : int
        col index for line end point 0
    r1 : int
        row index for line end point 1
    c1 : int
        col index for line end point 1
    w : int
        width in pixels for line
    rmin : int
        min row index for grid to draw on
    rmax : int
        max row index for grid to draw on

    Returns
    -------
    rr : array of row indexes of line
    cc : array of col indexes of line
    w : array of line weights
    """
    def trapez(y, y0, w):
        return num.clip(num.minimum(
            y + 1 + w / 2 - y0,
            - y + 1 + w / 2 + y0), 0, 1)
    # The algorithm below works fine if c1 >= c0 and c1-c0 >= abs(r1-r0).
    # If either of these cases are violated, do some switches.
    if abs(c1 - c0) < abs(r1 - r0):
        # Switch x and y, and switch again when returning.
        xx, yy, val = _weighted_line(c0, r0, c1, r1, w=w, rmin=rmin, rmax=rmax)
        return (yy, xx, val)

    # At this point we know that the distance in columns (x) is greater
    # than that in rows (y). Possibly one more switch if c0 > c1.
    if c0 > c1:
        return _weighted_line(r1, c1, r0, c0, w=w, rmin=rmin, rmax=rmax)

    # The following is now always < 1 in abs
    slope = (r1 - r0) / (c1 - c0)

    # Adjust weight by the slope
    w *= num.sqrt(1 + num.abs(slope)) / 2

    # We write y as a function of x, because the slope is always <= 1
    # (in absolute value)
    x = num.arange(c0, c1 + 1, dtype=float)
    y = (x * slope) + ((c1 * r0) - (c0 * r1)) / (c1 - c0)

    # Now instead of 2 values for y, we have 2*np.ceil(w/2).
    # All values are 1 except the upmost and bottommost.
    thickness = num.ceil(w / 2)

    yy = (num.floor(y).reshape(-1, 1) +
          num.arange(-thickness - 1, thickness + 2).reshape(1, -1))
    xx = num.repeat(x, yy.shape[1])

    vals = trapez(yy, y.reshape(-1, 1), w).flatten()

    yy = yy.flatten()

    # Exclude useless parts and those outside of the interval
    # to avoid parts outside of the picture
    mask = num.logical_and.reduce((yy >= rmin, yy < rmax, vals > 0))

    return (yy[mask].astype(int), xx[mask].astype(int), vals[mask])


def draw_line_on_array(
        X, Y, grid=None, extent=None, grid_resolution=(400, 400), linewidth=1):
    """
    Draw a line through the points (X, Y) onto an array.

    Every segment between consecutive points is rasterised with
    :func:`_weighted_line` into a temporary array of the grid's shape, which
    is then added to ``grid``. Within that temporary array, pixels shared by
    consecutive segments are overwritten, not summed, so every line
    contributes its weights once.

    Parameters
    ----------
    X : array_like
        timeseries on xcoordinate (columns of array)
    Y : array_like
        timeseries on ycoordinate (rows of array)
    grid : array_like 2d, optional
        input array that is used for drawing; it is updated in place.
        If None, a zero array of shape ``grid_resolution`` is allocated.
    extent : sequence of float, optional
        [xmin, xmax, ymin, ymax] (cols, rows); if None or empty it is
        derived from the minima and maxima of X and Y
    grid_resolution : tuple
        (nrows, ncols); shape of given grid or of the grid that is being
        allocated
    linewidth : int
        weight (width) of line drawn on grid

    Returns
    -------
    grid, extent

    Raises
    ------
    TypeError
        if X and Y differ in length, extent or grid_resolution have the
        wrong length, the given grid does not match grid_resolution or a
        line point lies beyond the upper grid bounds
    """

    def check_grid_shape(ngr, naim, axis):
        if ngr != naim:
            raise TypeError(
                'Gridsize of given grid is inconistent for axis %i!'
                ' Expected %i got %i' % (axis, naim, ngr))

    def check_line_in_grid(idxs, axis, nmax, extent):
        imax = idxs.max()
        if imax > nmax:
            raise TypeError(
                'Line endpoint outside of given grid Axis "%s"! %i > %i'
                ' Extent [%s]' % (
                    axis, imax, nmax, utility.list2string(extent)))

    nxs = len(X)
    nys = len(Y)
    if nxs != nys:
        raise TypeError(
            'Length of X and Y have to be identical! %i != %i' % (nxs, nys))

    if extent is None or len(extent) == 0:
        xmin = X.min()
        xmax = X.max()
        ymin = Y.min()
        ymax = Y.max()
        extent = [xmin, xmax, ymin, ymax]
    elif len(extent) == 4:
        xmin, xmax, ymin, ymax = extent
    else:
        raise TypeError(
            'extent has to be of length 4! [xmin, xmax, ymin, ymax]')

    if len(grid_resolution) != 2:
        raise TypeError(
            'grid_resolution has to be of length 2! [nrows, ncols]!')

    ynstep, xnstep = grid_resolution

    xvec, xstep = num.linspace(xmin, xmax, xnstep, endpoint=True, retstep=True)
    yvec, ystep = num.linspace(ymin, ymax, ynstep, endpoint=True, retstep=True)

    if grid is not None:
        if grid.ndim != 2:
            raise TypeError('Given grid has to be of dimension 2!')

        for axis, (ngr, naim) in enumerate(
                zip(grid.shape, grid_resolution)):
            check_grid_shape(ngr, naim, axis)
    else:
        grid = num.zeros((ynstep, xnstep), dtype='float64')

    xidxs = utility.positions2idxs(
        X, min_pos=xmin, cell_size=xstep, dtype='int32')
    yidxs = utility.positions2idxs(
        Y, min_pos=ymin, cell_size=ystep, dtype='int32')

    check_line_in_grid(xidxs, 'x', nmax=xnstep - 1, extent=extent)
    check_line_in_grid(yidxs, 'y', nmax=ynstep - 1, extent=extent)

    new_grid = num.zeros_like(grid)
    for i in range(1, nxs):
        c0 = xidxs[i - 1]
        r0 = yidxs[i - 1]
        c1 = xidxs[i]
        r1 = yidxs[i]
        if c0 == c1 and r0 == r1:
            # line start and end fall in the same grid point, the segment
            # has no direction and cannot be drawn
            continue

        # rmax is exclusive, so the top row (ynstep - 1) stays drawable
        rr, cc, weights = _weighted_line(
            r0=r0, c0=c0, r1=r1, c1=c1, w=linewidth, rmax=ynstep)

        # the line thickness may reach beyond the grid borders; drop those
        # pixels instead of wrapping (negative) or overflowing indexes
        inside = (rr >= 0) & (rr < ynstep) & (cc >= 0) & (cc < xnstep)
        new_grid[rr[inside], cc[inside]] = weights[inside].astype(grid.dtype)

    grid += new_grid
    return grid, extent


def fuzzy_moment_rate(
        ax, moment_rates, times, cmap=None, grid_size=(500, 500)):
    """
    Plot fuzzy moment rate function into axes.

    Every moment rate function of the ensemble is drawn as a thick line onto
    one common grid, where overlapping lines add up. The grid is shown as an
    image, so frequently sampled rates appear more intense.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes`
        axes to draw into
    moment_rates : list of arrays
        moment rate values, one array per ensemble member
    times : list of arrays
        times belonging to the entries of moment_rates
    cmap : :class:`matplotlib.colors.Colormap`, optional
        colormap of the image, defaults to ``plt.cm.hot_r``
    grid_size : tuple
        shape (nrows, ncols) of the image grid

    Raises
    ------
    TypeError
        if moment_rates and times differ in length
    """
    n_rates = len(moment_rates)
    n_times = len(times)
    if n_rates != n_times:
        raise TypeError(
            'Number of rates and times have to be identical!'
            ' %i != %i' % (n_rates, n_times))

    colormap = plt.cm.hot_r if cmap is None else cmap

    rate_peak = max(num.max(rates) for rates in moment_rates)
    time_peak = max(num.max(tvec) for tvec in times)
    extent = (0., time_peak, 0., rate_peak)

    density = num.zeros(grid_size, dtype='float64')
    for rates, tvec in zip(moment_rates, times):
        # draw_line_on_array adds each line onto density in place
        draw_line_on_array(
            tvec, rates,
            grid=density,
            extent=extent,
            grid_resolution=density.shape,
            linewidth=7)

    # saturate the densest regions at half the ensemble size to raise the
    # contrast of the less frequently sampled parts
    cap = n_rates / 2
    saturated = density > cap
    density[saturated] = cap

    ax.imshow(
        density, extent=extent, origin='lower', cmap=colormap, aspect='auto')
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('Moment rate [$Nm / s$]')


def draw_moment_rate(problem, po):
    """
    Draw moment rate function for the results of a seismic/joint finite fault
    optimization.

    One figure is produced per subfault. It shows the moment rate function
    of that subfault for the reference (or result) point. If no reference
    point is given in the plot options, it also shows a density image of
    ``po.nensemble`` moment rate functions taken from the stage trace.

    Parameters
    ----------
    problem : beat problem object
        must be in ffi mode and include seismic data
    po : :class:`PlotOptions`
        plot options (stage, reference, output format, dpi, force, ...)

    Raises
    ------
    ModeError
        if the problem is not in finite-fault (ffi) mode
    TypeError
        if the problem does not include seismic data
    """
    fontsize = 12
    mode = problem.config.problem_config.mode

    if mode != ffi_mode_str:
        raise ModeError(
            'Wrong optimization mode: %s! This plot '
            'variant is only valid for "%s" mode' % (mode, ffi_mode_str))

    if 'seismic' not in problem.config.problem_config.datatypes:
        raise TypeError(
            'Moment rate function only available for optimization results that'
            ' include seismic data.')

    sc = problem.composites['seismic']
    fault = sc.load_fault_geometry()

    stage = load_stage(
        problem, stage_number=po.load_stage, load='trace', chains=[-1])

    if not po.reference:
        reference = get_result_point(stage, problem.config, po.post_llk)
        llk_str = po.post_llk
        mtrace = stage.mtrace
    else:
        reference = po.reference
        llk_str = 'ref'
        mtrace = None

    logger.info(
        'Drawing ensemble of %i moment rate functions ...' % po.nensemble)
    # all moment rate functions are computed for the first target of the
    # first wavemap; its store does not change, so look it up only once
    target = sc.wavemaps[0].targets[0]
    store = sc.engine.get_store(target.store_id)

    mpl_init(fontsize=fontsize)
    for ns in range(fault.nsubfaults):
        outpath = os.path.join(
            problem.outfolder, po.figure_dir,
            'moment_rate_%i_%i_%s_%i.%s' % (
                stage.number, ns, llk_str, po.nensemble, po.outformat))

        if os.path.exists(outpath) and not po.force:
            logger.info('Plot exists! Use --force to overwrite!')
            continue

        fig, ax = plt.subplots(
            nrows=1, ncols=1, figsize=mpl_papersize('a7', 'landscape'))
        labelpos = mpl_margins(
            fig, left=5, bottom=4, top=1.5, right=0.5, units=fontsize)
        labelpos(ax, 2., 1.5)

        if mtrace is not None:
            nchains = len(mtrace)
            csteps = float(nchains) / po.nensemble
            idxs = num.floor(
                num.arange(0, nchains, csteps)).astype('int32')
            mrfs_rate = []
            mrfs_time = []
            for idx in idxs:
                point = mtrace.point(idx=idx)
                mrf_rate, mrf_time = fault.get_subfault_moment_rate_function(
                    index=ns, point=point, target=target, store=store)
                mrfs_rate.append(mrf_rate)
                mrfs_time.append(mrf_time)

            fuzzy_moment_rate(ax, mrfs_rate, mrfs_time)

        # reference moment rate function of *this* subfault (previously
        # subfault 0 was plotted on every figure)
        ref_mrf_rates, ref_mrf_times = \
            fault.get_subfault_moment_rate_function(
                index=ns, point=reference, target=target, store=store)

        ax.plot(
            ref_mrf_times, ref_mrf_rates,
            '-k', alpha=0.8, linewidth=1.)
        format_axes(ax, remove=['top', 'right'])

        if po.outformat == 'display':
            plt.show()
        else:
            logger.info('saving figure to %s' % outpath)
            fig.savefig(outpath, format=po.outformat, dpi=po.dpi)
            # one figure is created per subfault; release it once stored
            plt.close(fig)


def source_geometry(fault, ref_sources):
    """
    Plot source geometry in 3d rotatable view

    Each subfault of ``fault`` is drawn in its own graph color, together
    with the outlines of its patches (labelled with the patch index). The
    corresponding reference source is drawn in grey.

    Parameters
    ----------
    fault: :class:`beat.ffi.fault.FaultGeometry`
    ref_sources: list
        of :class:'beat.sources.RectangularSource'

    Notes
    -----
    Sources are plotted with ``anchor='top'``. The anchor is set on copies,
    so the given reference sources and the subfault objects returned by
    ``fault`` are left unmodified.
    """

    # imported for its side effect of enabling projection='3d'
    # (required by older matplotlib versions)
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    from beat.utility import RS_center
    alpha = 0.7

    def plot_subfault(ax, source, color):
        # work on a copy so changing the anchor has no side effect on the
        # caller's source objects
        source = copy.deepcopy(source)
        source.anchor = 'top'
        coords = source.outline()
        ax.plot(
            coords[:, 1], coords[:, 0], coords[:, 2] * -1.,
            color=color, linewidth=2, alpha=alpha)
        # first outline segment highlighted in black
        ax.plot(
            coords[0:2, 1], coords[0:2, 0], coords[0:2, 2] * -1.,
            '-k', linewidth=2, alpha=alpha)
        center = RS_center(source)
        ax.scatter(
            center[0], center[1], center[2] * -1,
            marker='o', s=20, color=color, alpha=alpha)

    fig = plt.figure(figsize=mpl_papersize('a4', 'landscape'))
    ax = fig.add_subplot(111, projection='3d')
    extfs = fault.get_all_subfaults()
    for idx, (refs, exts) in enumerate(zip(ref_sources, extfs)):

        plot_subfault(ax, exts, color=mpl_graph_color(idx))
        plot_subfault(ax, refs, color=scolor('aluminium4'))

        for i, patch in enumerate(fault.get_subfault_patches(idx)):
            coords = patch.outline()
            ax.plot(
                coords[:, 1], coords[:, 0], coords[:, 2] * -1.,
                color=mpl_graph_color(idx), linewidth=0.5, alpha=alpha)
            ax.text(
                patch.east_shift, patch.north_shift, patch.depth * -1., str(i),
                fontsize=10)

    scale = {'scale': 1. / km}
    scale_axes(ax.xaxis, **scale)
    scale_axes(ax.yaxis, **scale)
    scale_axes(ax.zaxis, **scale)
    ax.set_zlabel('Depth [km]')
    ax.set_ylabel('North_shift [km]')
    ax.set_xlabel('East_shift [km]')
    plt.show()


def draw_station_map(problem, po):
    """
    Draw a map of the seismic stations and the event for every wavemap.

    If the largest configured distance of a wavemap exceeds 30 (presumably
    degrees - confirm against the wavemap config), an orthographic global
    map centred on the event is drawn. Otherwise a regional PlateCarree map
    is drawn, spanning ``ceil(distances[1])`` around the event in lon/lat.
    If cartopy is not installed, an error is logged and nothing is drawn.

    Parameters
    ----------
    problem : beat problem object
        must include seismic data
    po : :class:`PlotOptions`
        plot options (output format, dpi, force, ...)

    Raises
    ------
    TypeError
        if the problem does not include seismic data
    """
    import matplotlib.ticker as mticker

    logger.info('Drawing Station Map ...')
    try:
        import cartopy as ctp
    except ImportError:
        logger.error(
            'Cartopy is not installed. '
            'For a station map cartopy needs to be installed!')
        return

    def draw_gridlines(ax, grid_proj):
        # latitude circles only (xlines disabled) in the rotated-pole frame
        gl = ax.gridlines(crs=grid_proj, color='black', linewidth=0.5)
        gl.n_steps = 300
        gl.xlines = False
        gl.ylocator = mticker.FixedLocator([30, 60, 90])

    fontsize = 12

    if 'seismic' not in problem.config.problem_config.datatypes:
        raise TypeError(
            'Station map is available only for seismic stations!'
            ' However, the datatypes do not include "seismic" data')

    event = problem.config.event

    sc = problem.composites['seismic']

    mpl_init(fontsize=fontsize)
    stations_proj = ctp.crs.PlateCarree()
    for wmap in sc.wavemaps:
        outpath = os.path.join(
            problem.outfolder, po.figure_dir, 'station_map_%s_%i.%s' % (
                wmap.name, wmap.mapnumber, po.outformat))

        if os.path.exists(outpath) and not po.force:
            logger.info('Plot exists! Use --force to overwrite!')
            continue

        if max(wmap.config.distances) > 30:
            map_proj = ctp.crs.Orthographic(
                central_longitude=event.lon, central_latitude=event.lat)
            extent = None
        else:
            max_dist = math.ceil(wmap.config.distances[1])
            map_proj = ctp.crs.PlateCarree()
            extent = [
                event.lon - max_dist, event.lon + max_dist,
                event.lat - max_dist, event.lat + max_dist]

        grid_proj = ctp.crs.RotatedPole(
            pole_longitude=event.lon, pole_latitude=event.lat)
        fig, ax = plt.subplots(
            nrows=1, ncols=1, figsize=mpl_papersize('a6', 'landscape'),
            subplot_kw={'projection': map_proj})

        stations_meta = [
            (station.lat, station.lon, station.station)
            for station in wmap.stations]

        if extent:
            # regional map
            mpl_margins(
                fig, left=2, bottom=2, top=2, right=2, units=fontsize)

            import cartopy.feature as cfeature
            from cartopy.mpl.gridliner import \
                LONGITUDE_FORMATTER, LATITUDE_FORMATTER

            ax.set_extent(extent, crs=map_proj)
            ax.add_feature(cfeature.NaturalEarthFeature(
                category='physical', name='land',
                scale='50m', **cfeature.LAND.kwargs))
            ax.add_feature(cfeature.NaturalEarthFeature(
                category='physical', name='ocean',
                scale='50m', **cfeature.OCEAN.kwargs))

            gl = ax.gridlines(
                color='black', linewidth=0.5, draw_labels=True)
            gl.ylocator = tick.MaxNLocator(nbins=5)
            gl.xlocator = tick.MaxNLocator(nbins=5)
            gl.xlabels_top = False
            gl.ylabels_right = False
            gl.xformatter = LONGITUDE_FORMATTER
            gl.yformatter = LATITUDE_FORMATTER

        else:
            # global teleseismic map
            mpl_margins(
                fig, left=1, bottom=1, top=1, right=1, units=fontsize)

            ax.coastlines(linewidth=0.2)
            draw_gridlines(ax, grid_proj)
            ax.stock_img()

        for (lat, lon, name) in stations_meta:
            ax.plot(
                lon, lat, 'r^', transform=stations_proj,
                markeredgecolor='black', markeredgewidth=0.3)
            ax.text(
                lon, lat, name, fontsize=10, transform=stations_proj,
                horizontalalignment='center', verticalalignment='top')

        ax.plot(
            event.lon, event.lat, '*', transform=stations_proj,
            markeredgecolor='black', markeredgewidth=0.3, markersize=12,
            markerfacecolor=scolor('butter1'))
        if po.outformat == 'display':
            plt.show()
        else:
            logger.info('saving figure to %s' % outpath)
            fig.savefig(outpath, format=po.outformat, dpi=po.dpi)
            # one figure per wavemap; release it once stored
            plt.close(fig)


# Mapping of plot names to the functions that draw them. The draw functions
# visible in this module take (problem, po) as arguments.
plots_catalog = {
    'correlation_hist': draw_correlation_hist,
    'stage_posteriors': draw_posteriors,
    'waveform_fits': draw_seismic_fits,
    'scene_fits': draw_geodetic_fits,
    'velocity_models': draw_earthmodels,
    'slip_distribution': draw_slip_dist,
    'hudson': draw_hudson,
    'fuzzy_beachball': draw_fuzzy_beachball,
    'fuzzy_mt_decomp': draw_fuzzy_mt_decomposition,
    'moment_rate': draw_moment_rate,
    'station_map': draw_station_map}


# Plots offered in every optimization mode.
common_plots = [
    'stage_posteriors',
    'velocity_models']


# Plots offered when seismic data is part of the problem.
seismic_plots = [
    'station_map',
    'waveform_fits']


# Plots offered when geodetic data is part of the problem.
geodetic_plots = [
    'scene_fits']


# Plots specific to the geometry optimization mode.
geometry_plots = [
    'correlation_hist',
    'hudson',
    'fuzzy_beachball']


# Plots specific to the finite-fault (ffi) optimization mode.
ffi_plots = [
    'moment_rate',
    'slip_distribution']


# Plot names per optimization mode. The lists are built by concatenation,
# so each mode holds its own list object.
plots_mode_catalog = {
    'geometry': common_plots + geometry_plots,
    'ffi': common_plots + ffi_plots,
}

# Datatype specific plot names, appended per datatype by available_plots.
plots_datatype_catalog = {
    'seismic': seismic_plots,
    'geodetic': geodetic_plots,
}


def available_plots(mode=None, datatypes=('geodetic', 'seismic')):
    """
    Return the names of the plots that can be drawn.

    Parameters
    ----------
    mode : str, optional
        optimization mode, key of ``plots_mode_catalog`` (e.g. 'geometry',
        'ffi'). If None, the names of all plots in ``plots_catalog`` are
        returned.
    datatypes : sequence of str
        datatypes (keys of ``plots_datatype_catalog``) whose specific plots
        are appended to the mode plots

    Returns
    -------
    list of str
        a new list; the module-level catalogs are never modified, so
        repeated calls return the same result
    """
    if mode is None:
        return list(plots_catalog.keys())

    # copy, so that extending does not alter plots_mode_catalog itself
    plots = list(plots_mode_catalog[mode])
    for datatype in datatypes:
        plots.extend(plots_datatype_catalog[datatype])

    return plots