from array import array
import collections
import itertools
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

import backtrader as bt

import pandas as pd

from bokeh.models import Span
from bokeh.plotting import figure
from bokeh.models import HoverTool, CrosshairTool
from bokeh.models import LinearAxis, DataRange1d, Renderer
from bokeh.models.formatters import NumeralTickFormatter
from bokeh.models import ColumnDataSource, FuncTickFormatter, DatetimeTickFormatter

from backtrader_plotting.bokeh import label_resolver
from backtrader_plotting.bokeh.label_resolver import plotobj2label
from backtrader_plotting.bokeh.utils import convert_color, sanitize_source_name, get_bar_width, convert_linestyle, get_indicator_data


class HoverContainer(metaclass=bt.MetaParams):
    """Class to store information about hover tooltips. Will be filled while Bokeh glyphs are created. After all figures are complete, hovers will be applied

    Params:
        hover_tooltip_config: comma-separated list of two-letter entries (e.g. ``"di,id"``). The first letter is the
            type of the object that registered a tooltip, the second letter the type of a figure's master object.
            A matching entry makes the tooltip additionally show up on such a (foreign) figure.
            Letters: ``d`` = data feed, ``i`` = indicator, ``o`` = observer. ``None`` or ``""`` disables this.
            Whitespace around entries is ignored.
        is_multidata: if True, tooltips placed on a foreign figure are prefixed with the label of their data
            (indicators and data feeds only) and listed after the figure's own tooltips.
    """

    params = (('hover_tooltip_config', None),
              ('is_multidata', False)
              )

    def __init__(self):
        # (label, tooltip template, object that registered the tooltip)
        self._hover_tooltips = []

        # parsed config: list of (source object type, figure master type) pairs
        self._config = []
        raw_config = self.p.hover_tooltip_config
        # the parameter defaults to None, so it must be guarded before splitting (len(None) would raise)
        input_config = raw_config.split(',') if raw_config else []
        for entry in input_config:
            c = entry.strip()
            if len(c) != 2:
                raise RuntimeError(f'Invalid hover config entry "{entry}"')
            self._config.append((self._get_type(c[0]), self._get_type(c[1])))

    def add_hovertip(self, label: str, tmpl: str, src_obj=None) -> None:
        """Registers a tooltip row ``(label, tmpl)`` that belongs to ``src_obj``. It is applied later by ``apply_hovertips``."""
        self._hover_tooltips.append((label, tmpl, src_obj))

    @staticmethod
    def _get_type(t):
        """Maps a single config letter to the backtrader base class it stands for.

        Raises:
            RuntimeError: if ``t`` is not one of ``d``, ``i`` or ``o``.
        """
        if t == 'd':
            return bt.AbstractDataBase
        elif t == 'i':
            return bt.Indicator
        elif t == 'o':
            return bt.Observer
        else:
            raise RuntimeError(f'Invalid hovertool config type: "{t}"')

    def _apply_to_figure(self, fig, hovertool):
        """Appends all tooltips relevant for the FigureEnvelope ``fig`` to ``hovertool``.

        A tooltip is relevant if its source object is the figure master, if it is a non-subplot indicator/observer
        drawn onto this figure (decided by ``plotmaster`` if set, else by its clock), or if a config entry
        matches (a "foreign" tooltip).
        """
        # provide ordering by two groups
        tooltips_top = []
        tooltips_bottom = []
        for label, tmpl, src_obj in self._hover_tooltips:
            apply: bool = src_obj is fig.master  # apply to own
            foreign = False
            if not apply and isinstance(src_obj, (bt.Observer, bt.Indicator)) and src_obj.plotinfo.subplot is False:
                # add objects that are on the same figure cause subplot is False (for Indicators and Observers)
                # if plotmaster is set then it will decide where to add, otherwise clock is used
                if src_obj.plotinfo.plotmaster is not None:
                    apply = src_obj.plotinfo.plotmaster is fig.master
                else:
                    apply = src_obj._clock is fig.master
            if not apply:
                for src_type, master_type in self._config:
                    if isinstance(src_obj, src_type) and isinstance(fig.master, master_type):
                        apply = True
                        foreign = True
                        break

            if not apply:
                continue

            prefix = ''
            top = True
            # prefix with data name if we got multiple datas
            if self.p.is_multidata and foreign:
                if isinstance(src_obj, bt.Indicator):
                    prefix = label_resolver.datatarget2label(src_obj.datas) + " - "
                elif isinstance(src_obj, bt.AbstractDataBase):
                    prefix = label_resolver.datatarget2label([src_obj]) + " - "
                top = False

            item = (prefix + label, tmpl)
            if top:
                tooltips_top.append(item)
            else:
                tooltips_bottom.append(item)

        # first apply all top hover then all bottoms
        for t in itertools.chain(tooltips_top, tooltips_bottom):
            hovertool.tooltips.append(t)

    def apply_hovertips(self, figures: List['FigureEnvelope']) -> None:
        """Add hovers to all figures from the figures list (only to the first HoverTool of each figure)"""
        for f in figures:
            for t in f.figure.tools:
                if not isinstance(t, HoverTool):
                    continue

                self._apply_to_figure(f, t)
                break


class FigureEnvelope(object):
    """Wraps a single Bokeh figure and plots backtrader data feeds, indicators and observers onto it.

    All glyphs read their values from the shared ``ColumnDataSource`` passed as ``cds``; each plotted line gets its
    own column whose name is derived from ``id()`` of the plotted object (see ``_source_id``). Hover tooltips are
    not put onto the figure directly but registered with the ``HoverContainer``, which applies them later.
    """
    # base tools of every figure; crosshair and hover tools are appended in _init_figure
    _tools = "pan,wheel_zoom,box_zoom,reset"

    # maps marker symbols used in backtrader plotlines to the name of the Bokeh figure method drawing them
    # (several entries look like visual approximations rather than exact equivalents)
    _mrk_fncs = {'^': 'triangle',
                 'v': 'inverted_triangle',
                 'o': 'circle',
                 '<': 'circle_cross',
                 '>': 'circle_x',
                 '1': 'diamond',
                 '2': 'diamond_cross',
                 '3': 'hex',
                 '4': 'square',
                 '8': 'square_cross',
                 's': 'square_x',
                 'p': 'triangle',
                 '*': 'asterisk',
                 'h': 'hex',
                 'H': 'hex',
                 '+': 'asterisk',
                 'x': 'x',
                 'D': 'diamond_cross',
                 'd': 'diamond',
                 }

    def __init__(self, strategy: bt.Strategy, cds: ColumnDataSource, hoverc: HoverContainer, start, end, scheme, master, plotorder, is_multidata):
        self._strategy = strategy
        self._cds: ColumnDataSource = cds
        self._scheme = scheme
        self._start = start
        self._end = end
        self.figure: figure = None
        # True once a single renderer has been pinned as the hover target (see _set_single_hover_renderer)
        self._hover_line_set = False
        self._hover: Optional[HoverTool] = None
        self._hoverc = hoverc
        # per-key index into the scheme's color palette; -1 means no color has been taken for that key yet
        self._coloridx = collections.defaultdict(lambda: -1)
        self.master = master
        self.plottab = None
        self.plotorder = plotorder
        self.datas = []  # list of all datas that have been plotted to this figure
        self._is_multidata = is_multidata
        self._tradingdomain = None
        self._init_figure()

    @staticmethod
    def should_filter_by_tradingdomain(obj, tradingdomain):
        """Returns True if ``obj`` belongs to ``tradingdomain``.

        ``tradingdomain`` may be None (every object matches), a single name or a list of names. Objects whose
        trading domain resolves to True (observers not clocked by a data feed) match every trading domain.
        """
        if tradingdomain is None:
            return True

        if isinstance(tradingdomain, str):
            tradingdomain = [tradingdomain]

        obj_lg = FigureEnvelope._resolve_tradingdomain(obj)
        return obj_lg is True or obj_lg in tradingdomain

    @staticmethod
    def _resolve_tradingdomain(obj) -> Union[bool, str]:
        """Returns the trading domain of ``obj``.

        That is the ``_name`` of the data feed the object is based on, or True for observers that are not clocked
        by a data feed (they belong to all trading domains).

        Raises:
            Exception: for objects that are neither data feeds, indicators nor observers.
        """
        if isinstance(obj, bt.AbstractDataBase):
            # data feeds are end points
            return obj._name
        elif isinstance(obj, bt.IndicatorBase):
            # lets find the data the indicator is based on
            data = get_indicator_data(obj)
            return FigureEnvelope._resolve_tradingdomain(data)
        elif isinstance(obj, bt.ObserverBase):
            if isinstance(obj._clock, bt.AbstractDataBase):
                return FigureEnvelope._resolve_tradingdomain(obj._clock)
            else:
                return True  # for wide observers we return True which means it belongs to all logic groups
        else:
            raise Exception('unsupported')

    def get_tradingdomains(self) -> List[Union[bool, str]]:
        """Returns the trading domains of this figure.

        Uses the ``tradingdomain`` stored from the first plotted object's plotinfo if set (str or list), otherwise
        the trading domain resolved from the figure master (which may be True for wide observers).
        """
        tradingdomains = []
        if self._tradingdomain is None:
            tradingdomains.append(self._resolve_tradingdomain(self.master))
        elif isinstance(self._tradingdomain, list):
            tradingdomains += self._tradingdomain
        elif isinstance(self._tradingdomain, str):
            tradingdomains.append(self._tradingdomain)
        else:
            raise Exception(f'Invalid type for tradingdomain: {type(self._tradingdomain)}')

        return tradingdomains

    def _set_single_hover_renderer(self, ren: Renderer):
        """Sets this figure's hover to a single renderer. Only the first call has an effect."""
        if self._hover_line_set:
            return

        self._hover.renderers = [ren]
        self._hover_line_set = True

    def _add_hover_renderer(self, ren: Renderer):
        """Adds another hover render target. Only has effect if no single renderer has been set before"""
        if self._hover_line_set:
            return

        if isinstance(self._hover.renderers, list):
            self._hover.renderers.append(ren)
        else:
            # the renderers property may hold a non-list value (e.g. Bokeh's 'auto') before the first target is added
            self._hover.renderers = [ren]

    def _nextcolor(self, key: object = None) -> int:
        """Advances the palette index for ``key`` and returns the new index."""
        self._coloridx[key] += 1
        return self._coloridx[key]

    def _color(self, key: object = None):
        """Returns the converted scheme color at the current palette index for ``key`` (without advancing it)."""
        return convert_color(self._scheme.color(self._coloridx[key]))

    def _init_figure(self):
        """Creates the Bokeh figure, styles it from the scheme and adds the crosshair and hover tools."""
        # plot height will be set later
        f = figure(tools=FigureEnvelope._tools, x_axis_type='linear', aspect_ratio=self._scheme.plotaspectratio)
        # TODO: backend webgl (output_backend="webgl") removed due to this bug:
        # https://github.com/bokeh/bokeh/issues/7568

        f.y_range.range_padding = self._scheme.y_range_padding

        f.border_fill_color = convert_color(self._scheme.border_fill)

        f.xaxis.axis_line_color = convert_color(self._scheme.axis_line_color)
        f.yaxis.axis_line_color = convert_color(self._scheme.axis_line_color)
        f.xaxis.minor_tick_line_color = convert_color(self._scheme.tick_line_color)
        f.yaxis.minor_tick_line_color = convert_color(self._scheme.tick_line_color)
        f.xaxis.major_tick_line_color = convert_color(self._scheme.tick_line_color)
        f.yaxis.major_tick_line_color = convert_color(self._scheme.tick_line_color)

        f.xaxis.major_label_text_color = convert_color(self._scheme.axis_label_text_color)
        f.yaxis.major_label_text_color = convert_color(self._scheme.axis_label_text_color)

        f.xgrid.grid_line_color = convert_color(self._scheme.grid_line_color)
        f.ygrid.grid_line_color = convert_color(self._scheme.grid_line_color)
        f.title.text_color = convert_color(self._scheme.plot_title_text_color)

        f.left[0].formatter.use_scientific = False
        f.background_fill_color = convert_color(self._scheme.background_fill)

        # mechanism for proper date axis without gaps, thanks!
        # https://groups.google.com/a/continuum.io/forum/#!topic/bokeh/t3HkalO4TGA
        # the x axis is the integer 'index' column; the JS below maps index ticks back to 'datetime' labels
        f.xaxis.formatter = FuncTickFormatter(
            args=dict(
                axis=f.xaxis[0],
                formatter=DatetimeTickFormatter(days=[self._scheme.axis_tickformat_days],
                                                hourmin=[self._scheme.axis_tickformat_hourmin],
                                                hours=[self._scheme.axis_tickformat_hours],
                                                minsec=[self._scheme.axis_tickformat_minsec],
                                                minutes=[self._scheme.axis_tickformat_minutes],
                                                months=[self._scheme.axis_tickformat_months],
                                                seconds=[self._scheme.axis_tickformat_seconds],
                                                years=[self._scheme.axis_tickformat_years],
                                                ),
                source=self._cds,
            ),
            code="""
                // We override this axis' formatter's `doFormat` method
                // with one that maps index ticks to dates. Some of those dates
                // are undefined (e.g. those whose ticks fall out of defined data
                // range) and we must filter out and account for those, otherwise
                // the formatter computes invalid visible span and returns some
                // labels as 'ERR'.
                // Note, after this assignment statement, on next plot redrawing,
                // our override `doFormat` will be called directly
                // -- FunctionTickFormatter.doFormat(), i.e. _this_ code, no longer
                // executes.
                axis.formatter.doFormat = function (ticks) {
                    const dates = ticks.map(i => source.data.datetime[source.data.index.indexOf(i)]),
                          valid = t => t !== undefined,
                          labels = formatter.doFormat(dates.filter(valid));
                    let i = 0;
                    return dates.map(t => valid(t) ? labels[i++] : '');
                };
                
                // we do this manually only for the first time we are called
                const labels = axis.formatter.doFormat(ticks);
                return labels[index];
            """
            )

        ch = CrosshairTool(line_color=self._scheme.crosshair_line_color)
        f.tools.append(ch)

        h = HoverTool(tooltips=[('Time', f'@datetime{{{self._scheme.hovertool_timeformat}}}')],
                      mode="vline",
                      formatters={'@datetime': 'datetime'}
                      )
        f.tools.append(h)

        self._hover = h
        self.figure = f

    def plot(self, obj, master=None):
        """Plots a data feed, indicator or observer onto this figure.

        The first object plotted may override aspect ratio, tab, order and trading domain via its plotinfo.

        Raises:
            Exception: for unsupported object types.
        """
        if isinstance(obj, bt.AbstractDataBase):
            self.plot_data(obj)
        elif isinstance(obj, bt.indicator.Indicator):
            self.plot_indicator(obj, master)
        elif isinstance(obj, bt.observers.Observer):
            self.plot_observer(obj, master)
        else:
            raise Exception(f"Unsupported plot object: {type(obj)}")

        # first object can apply config
        if not self.datas:
            aspectr = getattr(obj.plotinfo, 'plotaspectratio', None)
            if aspectr is not None:
                self.figure.aspect_ratio = aspectr

            tab = getattr(obj.plotinfo, 'plottab', None)
            if tab is not None:
                self.plottab = tab

            order = getattr(obj.plotinfo, 'plotorder', None)
            if order is not None:
                self.plotorder = order

            # just store the tradingdomain of the master for later reference
            tradingdomain = getattr(obj.plotinfo, 'tradingdomain', None)
            if tradingdomain is not None:
                self._tradingdomain = tradingdomain

        self.datas.append(obj)

    @staticmethod
    def build_color_lines(df: pd.DataFrame, scheme, col_open: str = 'open', col_close: str = 'close', col_prefix: str = '') -> pd.DataFrame:
        """Builds per-bar color columns from the scheme's up/down colors.

        A bar counts as "up" if its close is >= its open. Rows whose open value is missing (NaN/None) get NaN as
        color so they are rendered as gaps.

        Returns:
            DataFrame indexed like ``df`` with the object-dtype columns ``<col_prefix>colors_bars``,
            ``<col_prefix>colors_wicks``, ``<col_prefix>colors_outline`` and ``<col_prefix>colors_volume``
            holding color strings or NaN.
        """
        # binary array determining if up or down bar (comparisons with NaN yield False -> down)
        is_up = (df[col_close] >= df[col_open]).to_numpy()

        # we use the open-line as an indicator for missing values
        missing = df[col_open].isna().to_numpy()

        def _pick(color_up, color_down) -> np.ndarray:
            # choose the up/down color per row (vectorized) and blank out rows without data
            colors = np.where(is_up, convert_color(color_up), convert_color(color_down)).astype(object)
            colors[missing] = np.nan
            return colors

        # dtype=object since the columns hold both str and NaN
        return pd.DataFrame({
            col_prefix + 'colors_bars': _pick(scheme.barup, scheme.bardown),
            col_prefix + 'colors_wicks': _pick(scheme.barup_wick, scheme.bardown_wick),
            col_prefix + 'colors_outline': _pick(scheme.barup_outline, scheme.bardown_outline),
            col_prefix + 'colors_volume': _pick(scheme.volup, scheme.voldown),
        }, index=df.index, dtype=object)

    def _add_column(self, name, dtype):
        """Adds a single empty column ``name`` of ``dtype`` to the shared data source."""
        self._add_columns([(name, dtype)])

    def _add_columns(self, cols: List[Tuple[str, object]]):
        """Adds empty columns given as ``(name, dtype)`` pairs to the shared data source (values are presumably filled in elsewhere)."""
        for name, dtype in cols:
            self._cds.add(np.array([], dtype=dtype), name)

    def plot_data(self, data: bt.AbstractDataBase):
        """Plots a data feed as a line (scheme.style == 'line') or as candles (scheme.style == 'bar').

        Volume is overlaid on an extra axis if the scheme enables both ``volume`` and ``voloverlay``.

        Raises:
            Exception: for unsupported scheme styles.
        """
        source_id = FigureEnvelope._source_id(data)
        title = sanitize_source_name(label_resolver.datatarget2label([data]))

        # append to title
        self._figure_append_title(title)

        self._add_columns([(source_id + x, object) for x in ['open', 'high', 'low', 'close']])
        self._add_columns([(source_id + x, str) for x in ['colors_bars', 'colors_wicks', 'colors_outline']])

        if self._scheme.style == 'line':
            if data.plotinfo.plotmaster is None:
                color = convert_color(self._scheme.loc)
            else:
                self._nextcolor(data.plotinfo.plotmaster)
                color = convert_color(self._color(data.plotinfo.plotmaster))

            renderer = self.figure.line('index', source_id + 'close', source=self._cds, line_color=color, legend_label=title)
            self._set_single_hover_renderer(renderer)

            self._hoverc.add_hovertip("Close", f"@{source_id}close", data)
        elif self._scheme.style == 'bar':
            # wicks as high-low segments, bodies as open-close bars
            self.figure.segment('index', source_id + 'high', 'index', source_id + 'low', source=self._cds, color=source_id + 'colors_wicks', legend_label=title)
            renderer = self.figure.vbar('index',
                                        get_bar_width(),
                                        source_id + 'open',
                                        source_id + 'close',
                                        source=self._cds,
                                        fill_color=source_id + 'colors_bars',
                                        line_color=source_id + 'colors_outline',
                                        legend_label=title,
                                        )

            self._set_single_hover_renderer(renderer)

            self._hoverc.add_hovertip("Open", f"@{source_id}open{{{self._scheme.number_format}}}", data)
            self._hoverc.add_hovertip("High", f"@{source_id}high{{{self._scheme.number_format}}}", data)
            self._hoverc.add_hovertip("Low", f"@{source_id}low{{{self._scheme.number_format}}}", data)
            self._hoverc.add_hovertip("Close", f"@{source_id}close{{{self._scheme.number_format}}}", data)
        else:
            raise Exception(f"Unsupported style '{self._scheme.style}'")

        # make sure the regular y-axis only scales to the normal data on 1st axis (not to e.g. volume data on 2nd axis)
        self.figure.y_range.renderers.append(renderer)

        if self._scheme.volume and self._scheme.voloverlay:
            self.plot_volume(data, self._scheme.voltrans, True)

    def plot_volume(self, data: bt.AbstractDataBase, alpha=1.0, extra_axis=False):
        """Plots the volume of ``data`` as bars with the given fill/line ``alpha``.

        extra_axis displays a second axis (for overlay on data plotting): the bars then use their own y-range
        ('axvol', starting at 0) which only auto-scales to the volume bars. Otherwise the figure's main y-axis
        gets the volume number formatter.
        """
        source_id = FigureEnvelope._source_id(data)

        # plain `object` instead of the `np.object` alias, which was removed in NumPy 1.24
        self._add_columns([(source_id + 'volume', np.float64), (source_id + 'colors_volume', object)])
        kwargs = {'fill_alpha': alpha,
                  'line_alpha': alpha,
                  'name': 'Volume',
                  'legend_label': 'Volume'}

        ax_formatter = NumeralTickFormatter(format=self._scheme.number_format)

        source_data_axis = 'axvol'
        if extra_axis:
            self.figure.extra_y_ranges = {source_data_axis: DataRange1d(
                range_padding=1.0/self._scheme.volscaling,
                start=0,
            )}

            # use colorup
            ax_color = convert_color(self._scheme.volup)

            ax = LinearAxis(y_range_name=source_data_axis, formatter=ax_formatter,
                            axis_label_text_color=ax_color, axis_line_color=ax_color, major_label_text_color=ax_color,
                            major_tick_line_color=ax_color, minor_tick_line_color=ax_color)
            self.figure.add_layout(ax, 'left')
            kwargs['y_range_name'] = source_data_axis
        else:
            self.figure.yaxis.formatter = ax_formatter

        vbars = self.figure.vbar('index', get_bar_width(), f'{source_id}volume', 0, source=self._cds, fill_color=f'{source_id}colors_volume', line_color="black", **kwargs)

        # make sure the new axis only auto-scales to the volume data
        if extra_axis:
            self.figure.extra_y_ranges[source_data_axis].renderers = [vbars]

        self._hoverc.add_hovertip("Volume", f"@{source_id}volume{{({self._scheme.number_format})}}", data)

    def plot_observer(self, obj, master):
        """Plots an observer (same handling as indicators)."""
        self._plot_indicator_observer(obj, master)

    def plot_indicator(self, obj: Union[bt.Indicator, bt.Observer], master):
        """Plots an indicator (same handling as observers)."""
        self._plot_indicator_observer(obj, master)

    def _plot_indicator_observer(self, obj: Union[bt.Indicator, bt.Observer], master):
        """Plots every line of an indicator/observer as marker, bar or line glyph according to its plotlines.

        Lines with ``_plotskip`` set are skipped. Horizontal lines and y ticks from the plotinfo are applied after
        all lines have been drawn.

        Raises:
            Exception: for unsupported markers or plotting methods.
        """
        pl = plotobj2label(obj)

        self._figure_append_title(pl)
        indlabel = obj.plotlabel()
        plotinfo = obj.plotinfo

        is_multiline = obj.size() > 1
        for lineidx in range(obj.size()):
            line = obj.lines[lineidx]
            source_id = FigureEnvelope._source_id(line)
            linealias = obj.lines._getlinealias(lineidx)

            # plotline settings may be addressed by index ('_0', '_1', ...) or by line alias
            lineplotinfo = getattr(obj.plotlines, f'_{lineidx}', None)
            if not lineplotinfo:
                lineplotinfo = getattr(obj.plotlines, linealias, None)

            if not lineplotinfo:
                lineplotinfo = bt.AutoInfoClass()

            if lineplotinfo._get('_plotskip', False):
                continue

            marker = lineplotinfo._get("marker", None)
            method = lineplotinfo._get('_method', "line")

            color = getattr(lineplotinfo, "color", None)
            if color is None:
                # '_samecolor' reuses the color of the previous line instead of taking the next palette color
                if not lineplotinfo._get('_samecolor', False):
                    self._nextcolor()
                color = self._color()
            color = convert_color(color)

            kwglyphs = {'name': linealias}

            self._add_column(source_id, np.float64)

            # either all individual lines of are displayed in the legend or only the ind/obs as a whole
            label = indlabel
            if is_multiline and plotinfo.plotlinelabels:
                label += " " + (lineplotinfo._get("_name", "") or linealias)
            kwglyphs['legend_label'] = label

            if marker is not None:
                kwglyphs['size'] = lineplotinfo.markersize * 1.2
                kwglyphs['color'] = color
                kwglyphs['y'] = source_id

                if marker not in FigureEnvelope._mrk_fncs:
                    raise Exception(f"Sorry, unsupported marker: '{marker}'. Please report to GitHub.")
                glyph_fnc_name = FigureEnvelope._mrk_fncs[marker]
                glyph_fnc = getattr(self.figure, glyph_fnc_name)
            elif method == "bar":
                kwglyphs['bottom'] = 0
                kwglyphs['line_color'] = 'black'
                kwglyphs['fill_color'] = color
                kwglyphs['width'] = get_bar_width()
                kwglyphs['top'] = source_id

                glyph_fnc = self.figure.vbar
            elif method == "line":
                kwglyphs['line_width'] = 1
                kwglyphs['color'] = color
                kwglyphs['y'] = source_id

                linestyle = getattr(lineplotinfo, "ls", None)
                if linestyle is not None:
                    kwglyphs['line_dash'] = convert_linestyle(linestyle)

                glyph_fnc = self.figure.line
            else:
                raise Exception(f"Unknown plotting method '{method}'")

            renderer = glyph_fnc("index", source=self._cds, **kwglyphs)

            # make sure the regular y-axis only scales to the normal data (data + ind/obs) on 1st axis (not to e.g. volume data on 2nd axis)
            self.figure.y_range.renderers.append(renderer)

            # for markers add additional renderer so hover pops up for all of them
            if marker is None:
                self._set_single_hover_renderer(renderer)
            else:
                self._add_hover_renderer(renderer)

            hover_label_suffix = f" - {linealias}" if is_multiline else ""  # we need no suffix if there is just one line in the indicator anyway
            hover_label = indlabel + hover_label_suffix
            hover_data = f"@{source_id}{{{self._scheme.number_format}}}"
            self._hoverc.add_hovertip(hover_label, hover_data, obj)

        self._set_yticks(obj)
        self._plot_hlines(obj)

    def _set_yticks(self, obj):
        """Sets fixed y ticks from plotinfo 'plotyticks', falling back to 'plotyhlines'."""
        yticks = obj.plotinfo._get('plotyticks', [])
        if not yticks:
            yticks = obj.plotinfo._get('plotyhlines', [])

        if yticks:
            self.figure.yaxis.ticker = yticks

    def _plot_hlines(self, obj):
        """Draws horizontal lines from plotinfo 'plothlines', falling back to 'plotyhlines'."""
        hlines = obj.plotinfo._get('plothlines', [])
        if not hlines:
            hlines = obj.plotinfo._get('plotyhlines', [])

        # Horizontal Lines (style is the same for all of them, so convert it once)
        hline_color = convert_color(self._scheme.hlinescolor)
        hline_dash = convert_linestyle(self._scheme.hlinesstyle)
        for hline in hlines:
            span = Span(location=hline,
                        dimension='width',
                        line_color=hline_color,
                        line_dash=hline_dash,
                        line_width=self._scheme.hlineswidth)
            self.figure.renderers.append(span)

    def _figure_append_title(self, title):
        """Appends ``title`` to the figure title, separated by ' | ' from any existing text."""
        if self.figure.title.text:
            self.figure.title.text += " | "
        self.figure.title.text += title

    def _add_to_cds(self, data, name):
        """Adds ``data`` as column ``name`` to the shared data source unless the column already exists."""
        if name in self._cds.column_names:
            return
        self._cds.add(data, name)

    @staticmethod
    def _source_id(source):
        """Returns the column name prefix for ``source`` (its ``id()`` as string)."""
        return str(id(source))