import pdb
import numbers
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches, transforms
from matplotlib.path import Path
from numpy.linalg import norm

from .edgenode import Edge, Node, Pin, _node
from .theme import NODE_THEME_DICT, BLUE
from .utils import rotate
from .setting import node_setting, edge_setting
from .import shapes

class Brush(object):
    '''Common ancestor of all brushes; it defines no behaviour of its own.'''


class NodeBrush(Brush):
    '''
    a brush class used to draw node.

    Use ``brush >> xy`` to paint a node at position ``xy``.

    Attributes:
        style (str|tuple): a key of `viznet.theme.NODE_THEME_DICT`, or an explicit theme code
            (anything that is not a `str` is used as the theme code directly).
        ax (:obj:`Axes`): matplotlib Axes instance, `plt.gca()` is used when it is `None`.
        color (str|None): the color of painted node by this brush, it is passed to the handler as `facecolor`.
        size ('huge'|'large'|'normal'|'small'|'tiny'|'dot'|tuple|float): size of node.
        roundness (float): the roundness of edges.
        zorder (int): same to matplotlib zorder.
        rotate (float): angle for rotation.
        ls (str): line style.
        lw (float|None): line width, falls back to `setting['lw']` if it is `None`.
        edgecolor (str|None): edge color, falls back to `setting['edgecolor']` if it is `None`.
        props (dict): other arguments passed to handler.
    '''
    setting = node_setting

    # numeric size for each named size.
    size_dict = {
        'huge': 0.9,
        'large': 0.39,
        'normal': 0.3,
        'small': 0.21,
        'tiny': 0.09,
        'dot': 0.05,
    }

    def __init__(self, style, ax=None, color=None, size='normal', roundness=0, zorder=0, rotate=0., ls='-', lw=None, edgecolor=None, props=None):
        self.style = style
        self.size = size
        self.ax = ax
        self.color = color
        self.zorder = zorder
        self.rotate = rotate
        self.ls = ls
        self.lw = lw
        self.edgecolor = edgecolor
        self.node_handler = basicgeometry_handler
        # a fresh dict per instance, so brushes never share one props object.
        self.props = props if props is not None else {}
        self.roundness = roundness

    @property
    def is_rectangular(self):
        '''bool: True if the second entry of the theme code is 'rectangle'.'''
        return self._style[1] in ['rectangle']

    @property
    def _size(self):
        '''
        numeric size as an array.

        Named sizes are looked up in `size_dict`;
        a scalar size of a rectangular brush is expanded to `[size, size]`.
        '''
        if isinstance(self.size, str):
            size = self.size_dict[self.size]
        else:
            size = self.size
        if self.is_rectangular and np.ndim(size) == 0:
            size = [size, size]
        return np.asarray(size)

    def __rshift__(self, xy):
        '''
        add a node.

        Args:
            xy (tuple): position. When the size has two components,
                each entry is passed through `_get_range`, so it may also be a slice
                (rectangular brushes only); the node is then centered on the slice
                and each half-extent grows by half of the slice span.

        Returns:
            :obj:`Node`: node object.
        '''
        ax = plt.gca() if self.ax is None else self.ax

        # brush values take priority, the class level setting is the fallback.
        lw = self.setting['lw'] if self.lw is None else self.lw
        edgecolor = self.setting['edgecolor'] if self.edgecolor is None else self.edgecolor

        theme_code = self._style

        # get the size and position
        size = self._size
        if np.ndim(size) == 1 and len(size) == 2:
            xstart, xstop = self._get_range(xy[0])
            ystart, ystop = self._get_range(xy[1])
            size = (size[0] + abs(xstop - xstart)/2., size[1] + abs(ystop - ystart)/2.)
            xy = (xstop + xstart)/2., (ystart + ystop)/2.

        # theme color is resolved inside the handler when `self.color` is None.
        objs = self.node_handler(theme_code, xy, size, self.roundness, facecolor=self.color,
                lw=lw, edgecolor=edgecolor, ls=self.ls, zorder=self.zorder, angle=self.rotate, props=self.props)

        # add patches
        for p in objs:
            ax.add_patch(p)
        return Node(objs, xy, self)

    @property
    def _style(self):
        '''theme code: looked up in `NODE_THEME_DICT` for a `str` style, otherwise the style itself.'''
        if isinstance(self.style, str):
            return NODE_THEME_DICT[self.style]
        else:
            return self.style

    def _get_range(self, x):
        '''
        get the (start, stop) pair of a coordinate.

        Args:
            x (slice|number): a slice for gridwise operation, or a plain coordinate.

        Returns:
            tuple: `(x.start, x.stop)` for a slice, `(x, x)` otherwise.

        Raises:
            ValueError: if a slice is given to a brush that is not rectangular.
        '''
        if isinstance(x, slice):   # gridwise operation.
            if not self.is_rectangular:
                raise ValueError('Only Rectangular support slice plot!')
            return x.start, x.stop
        else:
            return x, x

class EdgeBrush(Brush):
    '''
    a brush for drawing straight edges between two nodes.

    Use ``brush >> (start, end)`` to paint an edge.

    Attributes:
        style (str): the style of edge, must be a combination of ('>'|'<'|'-'|'.').
            * '>', right arrow
            * '<', left arrow,
            * '-', line,
            * '.', dashed line.
        ax (:obj:`Axes`): matplotlib Axes instance, `plt.gca()` is used when it is `None`.
        lw (float): line width, it also scales the arrow head.
        color (str): the color of painted edge by this brush.
        zorder (int): zorder given to arrows and lines.
        solid_capstyle (str): cap style forwarded to the line drawing helper.
    '''
    setting = edge_setting

    def __init__(self, style, ax=None, lw=1, color='k', zorder=0, solid_capstyle='butt'):
        self.style = style
        self.ax = ax
        self.lw = lw
        self.color = color
        self.zorder = zorder
        self.solid_capstyle = solid_capstyle
        self.line_handler = basicline_handler

    def __rshift__(self, startend):
        '''
        connect start node and end node

        Args:
            startend (tuple): start node (position) and end node (position).

        Returns:
            :obj:`Edge`: edge object.
        '''
        axes = self.ax if self.ax is not None else plt.gca()
        width = self.lw
        # the arrow head scales with the line width.
        arrow_len = self.setting['arrow_head_length'] * width
        arrow_wid = self.setting['arrow_head_width'] * width

        # unit vector pointing from the start center to the end center.
        src, dst = _node(startend[0]), _node(startend[1])
        delta = np.asarray(dst.position) - np.asarray(src.position)
        heading = delta / norm(delta)

        # the edge runs between the connection points returned by the two nodes.
        sxy = src.get_connection_point(heading)
        exy = dst.get_connection_point(-heading)

        arrows, lines = self.line_handler(sxy, exy, self.style, arrow_len)
        shared = dict(lw=width, color=self.color, zorder=self.zorder)
        drawn = _arrows(axes, arrows, head_width=arrow_wid, head_length=arrow_len, **shared)
        drawn += _lines(axes, lines, use_path=False, solid_capstyle=self.solid_capstyle, **shared)
        return Edge(drawn, sxy, exy, src, dst, brush=self)

class CLinkBrush(EdgeBrush):
    '''
    Brush for C type link.

    Attributes:
        style (str): e.g. '<->', right-side grow with respect to the line direction.
        offsets (list): a copy of the `offsets` argument, forwarded to the line handler.
        roundness (float): forwarded to the line handler.
    '''
    def __init__(self, style, ax=None, offsets=(0.2,), roundness=0, lw=1, color='k', zorder=0, solid_capstyle='butt'):
        super().__init__(style, ax=ax, lw=lw, color=color, zorder=zorder, solid_capstyle=solid_capstyle)
        self.offsets = list(offsets)
        self.roundness = roundness
        # replace the straight line handler installed by the parent class.
        self.line_handler = clink_handler

    def __rshift__(self, startend):
        '''
        connect start node and end node

        Args:
            startend (tuple): start node (position) and end node (position).

        Returns:
            :obj:`Edge`: edge object, its end points are the ones returned by the line handler.
        '''
        axes = self.ax if self.ax is not None else plt.gca()
        width = self.lw
        arrow_len = self.setting['arrow_head_length'] * width
        arrow_wid = self.setting['arrow_head_width'] * width

        # node centers are used directly as the two ends of the link.
        src = _node(startend[0])
        dst = _node(startend[1])
        src_xy = np.asarray(src.position)
        dst_xy = np.asarray(dst.position)

        arrows, lines, endpoints = self.line_handler(
            src_xy, dst_xy, self.style, self.offsets, self.roundness, arrow_len)
        link_start, link_end = endpoints

        shared = dict(lw=width, color=self.color, zorder=self.zorder)
        drawn = _arrows(axes, arrows, head_width=arrow_wid, head_length=arrow_len, **shared)
        drawn += _lines(axes, lines, use_path=True, solid_capstyle=self.solid_capstyle, **shared)
        return Edge(drawn, link_start, link_end, src, dst, brush=self)

class CurveBrush(Brush):
    '''
    a brush for drawing curved edges, painted through `Axes.annotate`.

    Use ``brush >> (start, end, rad)`` or ``brush >> (start, end, rad, text)``.

    Attributes:
        style (str): the style of edge, same as arrowprops in https://matplotlib.org/api/_as_gen/matplotlib.pyplot.annotate.html.
        ax (:obj:`Axes`): matplotlib Axes instance, `plt.gca()` is used when it is `None`.
        lw (float): line width.
        ls (str): line style.
        color (str): the color of painted edge by this brush.
        zorder (int): zorder of the annotation.
        solid_capstyle (str): stored on the brush, it is not used for drawing here.
    '''
    setting = edge_setting

    def __init__(self, style, ax=None, lw=1, color='k', zorder=0, solid_capstyle='butt', ls='-'):
        self.lw = lw
        self.ls = ls
        self.color = color
        self.ax = ax
        self.style = style
        self.zorder = zorder
        self.solid_capstyle = solid_capstyle

    def __rshift__(self, startend):
        '''
        connect start node and end node

        Args:
            startend (tuple): `(start, end, rad)` or `(start, end, rad, text)`,
                where start/end are nodes (positions), `rad` is the `rad` parameter
                of the 'arc3' connection style and `text` is the annotation text
                (empty if not given).

        Returns:
            :obj:`Edge`: edge object.

        Raises:
            IndexError: if `startend` has less than 3 entries.
        '''
        ax = plt.gca() if self.ax is None else self.ax

        # get start position and end position
        start, end = _node(startend[0]), _node(startend[1])
        rad = startend[2]
        text = startend[3] if len(startend) > 3 else ""
        sxy, exy = np.asarray(start.position), np.asarray(end.position)

        # patchA/patchB are taken from the `obj` attribute of each node, if it has one.
        arrowprops = dict(patchA=getattr(start, "obj", None), patchB=getattr(end, "obj", None),
                          connectionstyle="arc3,rad=%f"%rad,
                          arrowstyle=self.style, lw=self.lw, ls=self.ls, color=self.color)

        # the annotated point is the end, the text position is the start.
        obj = ax.annotate(text, exy, sxy, xycoords='data', textcoords='data', zorder=self.zorder, arrowprops=arrowprops)
        return Edge([obj], sxy, exy, start, end, brush=self)
def clink_handler(sxy, exy, style, offsets, roundness, head_length):
    '''
    a C style link between two edges.

    Args:
        sxy (ndarray): start position.
        exy (ndarray): end position.
        style (str): an optional leading arrow ('<'|'>'), exactly one line style code,
            and an optional trailing arrow ('<'|'>'), e.g. '<->'.
        offsets (sequence): one displacement per turn of the path.
        roundness (float): passed to `rounded_path`.
        head_length (float): length of the arrow head.

    Returns:
        tuple: `(arrows, lines, (start, end))`, where `arrows` is a list of
        `(position, direction)`, `lines` is `[(line style, path)]` and
        `(start, end)` are the last points reached from each side.

    Raises:
        ValueError: if `style` does not contain exactly one line style code.
    '''
    # parse and validate the style before any geometry, so a malformed style
    # (e.g. '' or '>') gives a ValueError instead of an IndexError.
    head_code = style[0] if len(style) > 0 and style[0] in ('<', '>') else ''
    body = style[1:] if head_code else style
    tail_code = body[-1] if len(body) > 0 and body[-1] in ('<', '>') else ''
    ls = body[:-1] if tail_code else body
    if len(ls) != 1:
        raise ValueError('style must contain exactly 1 line style code.')

    nturn = len(offsets)
    offsets = np.asarray(offsets)
    unit_t = (exy - sxy)/norm(exy - sxy)
    unit_l = rotate(unit_t, np.pi/2.)
    vl, vr = [sxy], [exy]

    # get arrow locations and directions
    # - first, get head and tail vector.
    arrows = []
    unit_head_vec = (-unit_t if nturn%2 == 0 else unit_l)*(np.sign(offsets[0]) if nturn != 0 else -1)
    sign_eo = -1 if nturn%2 == 0 else 1
    unit_tail_vec = unit_head_vec * sign_eo
    head_vec = unit_head_vec * head_length
    tail_vec = unit_tail_vec * head_length
    if head_code:
        sign = 1 if head_code == '>' else -1
        arrows.append((sxy + head_vec*0.6, unit_head_vec * sign))
        # the line starts after the arrow head.
        vl[0] = vl[0] + head_vec
    if tail_code:
        sign = -1 if tail_code == '>' else 1
        arrows.append((exy + tail_vec*0.6, unit_tail_vec * sign))
        vr[0] = vr[0] + tail_vec

    # get path: both sides move together, alternating between
    # the tangential direction and `unit_l`.
    for i, dxy in enumerate(offsets):
        sxy = sxy + (-unit_t if i%2 == nturn%2 else unit_l)*dxy
        exy = exy + (unit_t if i%2 == nturn%2 else unit_l)*dxy
        vl.append(sxy)
        vr.append(exy)
    return arrows, [(ls, rounded_path(vl+vr[::-1], roundness))], (vl[-1], vr[-1])

def rounded_path(vertices, roundness):
    '''
    make rounded path from vertices.

    Args:
        vertices (sequence): points of the polyline.
        roundness (float): distance from each inner vertex at which the corner curve
            starts and ends, 0 gives the plain polyline.

    Returns:
        :obj:`Path`: the polyline, with every inner vertex replaced by a
        quadratic curve that uses the vertex as control point.
    '''
    pts = np.asarray(vertices)
    if roundness == 0:
        return Path(pts)

    path_codes = [Path.MOVETO]
    path_pts = [pts[0]]
    for k in range(1, len(pts) - 1):
        before, corner, after = pts[k - 1], pts[k], pts[k + 1]
        toward_before = (before - corner)/norm(corner - before)*roundness
        toward_after = (after - corner)/norm(corner - after)*roundness
        # straight line up to the corner region, then curve around the corner.
        path_codes += [Path.LINETO, Path.CURVE3, Path.CURVE3]
        path_pts += [corner + toward_before, corner, corner + toward_after]
    path_codes.append(Path.LINETO)
    path_pts.append(pts[-1])
    return Path(path_pts, path_codes)

def basicline_handler(sxy, exy, style, head_length):
    '''
    draw a line between start and end.

    Args:
        sxy (sequence): start position, it is never modified.
        exy (sequence): end position.
        style (str): combination of arrow codes ('<'|'>') and line style codes,
            every line style code takes an equal share of the line.
        head_length (float): length of the arrow head.

    Returns:
        tuple: `(arrows, lines)`, where `arrows` is a list of `[position, direction]`
        and `lines` is a list of `[line style, [start, end]]`, consecutive identical
        line style codes being merged. `lines` is empty if the style has no line code.
    '''
    # float arrays, so that positions given as list/tuple/int array behave the same.
    sxy = np.asarray(sxy, dtype=float)
    exy = np.asarray(exy, dtype=float)

    # the distance and unit distance
    d = exy - sxy
    unit_d = d / norm(d)
    head_vec = unit_d * head_length

    # split the style into arrows and line segments,
    # an arrow remembers the number of segments before it.
    arrows = []
    segs = []
    for s in style:
        if s in ('>', '<'):
            sign = 1 if s == '>' else -1
            arrows.append([len(segs), sign*unit_d])
        else:
            segs.append(s)
    num_segs = len(segs)

    # get arrow locations.
    vec_d = d - head_vec * 1.2
    for al in arrows:
        al[0] = al[0] * vec_d / max(num_segs, 1) + sxy + 0.6 * head_vec

    # arrows only (or empty style): there is no line to draw.
    if num_segs == 0:
        return arrows, []

    # get the line locations.
    uni = d / num_segs
    lines = []
    end = start = sxy
    seg_pre = segs[0]
    for seg in segs:
        if seg != seg_pre:
            lines.append([seg_pre, [start, end]])
            start = end
        seg_pre = seg
        end = end + uni
    lines.append([seg_pre, [start, end]])

    # fix end of line, leave room for the arrow head.
    # NOTE: no in-place operation here, `start` may be the caller's array.
    if style[-1] in ('<', '>'):
        lines[-1][1][1] = lines[-1][1][1] - head_vec
    if style[0] in ('<', '>'):
        lines[0][1][0] = lines[0][1][0] + head_vec
    return arrows, lines

def _arrows(ax, arrows, **kwargs):
    '''show arrows'''
    objs = []
    for mxy, direction in arrows:
        objs.append(_arrow(ax, mxy, direction, **kwargs))
    return objs

def _lines(ax, lines, **kwargs):
    '''show the lines.'''
    objs = []
    for ls, line in lines:
        objs.extend(_line(ax, ls, line, **kwargs))
    return objs

def _arrow(ax, mxy, direction, head_width, head_length, lw, zorder, color):
    '''draw an arrow.'''
    head_vec = direction * head_length
    mxy = mxy - head_vec * 0.6
    dx, dy = direction
    obj = plt.arrow(*mxy, 1e-8 * dx, 1e-8 * dy,
              head_length=head_length, width=0,
              head_width=head_width, fc=color,
              length_includes_head=False, lw=lw, edgecolor=color, zorder=zorder)
    return obj

def _line(ax, ls, path, lw, color, zorder, use_path, solid_capstyle):
    '''draw a line connecting sxy and exy.'''
    objs = []
    def _plot_line(path_):
        if not use_path:
            sxy_, exy_ = path_
            objs.extend(ax.plot([sxy_[0], exy_[0]], [
                               sxy_[1], exy_[1]], lw=lw, color=color,
                               zorder=zorder, ls=ls, solid_capstyle=solid_capstyle))
        else:
            obj = patches.PathPatch(path_, lw=lw, facecolor='none', edgecolor=color,
                               zorder=zorder, ls=ls)
            objs.append(obj)
            ax.add_patch(obj)
    if ls == '=':
        if not use_path:
            sxy, exy = path
            d = np.asarray(exy) - sxy
            unit_d = d / norm(d)
            perp_d = np.array([-unit_d[1], unit_d[0]])
            offset = perp_d * edge_setting['doubleline_space'] * lw
            ls = '-'
            _plot_line((sxy + offset, exy + offset))
            _plot_line((sxy - offset, exy - offset))
        else:
            raise NotImplementedError('Double line for C link not implemented!')
    else:
        if ls == '.': ls = '--'
        _plot_line(path)
    return objs

def _basicgeometry(xy, geo, size, angle, roundness, props, **kwargs):
    '''
    basic geometric handler.

    Look up the attribute named `geo` in the `shapes` module and call it
    with `(xy, size, angle, roundness, props=props, **kwargs)`.

    Raises:
        AttributeError: if `shapes` has no attribute named `geo`.
    '''
    # plain attribute lookup: `geo` is a name, it must never be evaluated as code.
    return getattr(shapes, geo)(xy, size, angle, roundness, props=props, **kwargs)

def basicgeometry_handler(theme_code, xy, size, roundness, facecolor, ls, lw, edgecolor, zorder, angle, props):
    '''
    basic geometry node handler.

    Args:
        theme_code (tuple): `(default color, geometry, inner geometry)`,
            inner geometry 'none' means no inner shape.
        xy (tuple): position.
        size (float|sequence): size of the outer shape.
        roundness (float): roundness, passed to both shapes.
        facecolor (str|None): face color, the default color of the theme is used if it is `None`.
        ls (str): line style.
        lw (float): line width of the outer shape.
        edgecolor (str): edge color of the outer shape.
        zorder (int): zorder of the outer shape, the inner shape is one level above.
        angle (float): rotation angle, passed to both shapes.
        props (dict): extra arguments, passed to both shapes.

    Returns:
        list: objects of the outer shape, followed by those of the inner shape if there is one.
    '''
    default_color, geo, inner_geo = theme_code
    if facecolor is None:
        facecolor = default_color
    if facecolor is None:  # both color and default color is None
        facecolor = 'none'
        edgecolor = 'none'

    objs = _basicgeometry(xy, geo, size, angle, roundness, props, facecolor=facecolor, edgecolor=edgecolor, ls=ls, zorder=zorder, lw=lw)

    # add a geometric patch at the top of the outer shape.
    if inner_geo != 'none':
        # these inner shapes are as large as the outer one, the others are shrunk.
        if inner_geo in ('measure', 'cross', 'vbar', 'plus'):
            inner_size = size
        else:
            # `size` may be a plain tuple, which can not be multiplied by a float.
            inner_size = 0.7 * np.asarray(size)

        objs += _basicgeometry(xy, inner_geo, inner_size, angle, roundness, props,
                               facecolor=node_setting['inner_facecolor'],
                               edgecolor=node_setting['inner_edgecolor'],
                               lw=node_setting['inner_lw'], ls=ls, zorder=zorder+1)

    # NOTE: a former "self-loop for BLUE nodes" branch compared the theme tuple
    # with the string 'nn.', so it could never run; it has been removed.
    return objs

def rotate_translate_path(path, angle, dxy=(0,0)):
    '''
    rotate path by angle, then translate it by dxy.

    Args:
        path (:obj:`Path`): the path to transform, it is not modified.
        angle (float): rotation angle, as taken by `Affine2D.rotate`.
        dxy (tuple): translation applied after the rotation.

    Returns:
        :obj:`Path`: the transformed copy of the path.
    '''
    # `rotate` and `translate` modify the transform in place and return it,
    # so build the combined transform once and apply it a single time;
    # applying it after each step would rotate the path twice.
    transform = transforms.Affine2D().rotate(angle).translate(*dxy)
    return path.transformed(transform)

# module level brush, created with the 'pin' style and default settings.
pin = NodeBrush('pin')