"""
    Functions that are used to determine the composition of the
    defined ode
"""

import re
from functools import reduce
from itertools import combinations

import sympy
from sympy.matrices import MatrixBase
import numpy as np

from .base_ode_model import BaseOdeModel
from .transition import TransitionType

# Lower case names of the Greek letters.  Parameters whose (lower cased)
# name is in this tuple are written as HTML entities, e.g. &beta;, when an
# equation is prepared for graphviz.
greekLetter = tuple(
    'alpha beta gamma delta epsilon zeta eta theta '
    'iota kappa lambda mu nu xi omicron pi rho '
    'sigma tau upsilon phi chi psi omega'.split()
)

def generateTransitionGraph(ode_model, file_name=None):
    """
    Generates the transition graph in graphviz given an ode model with
    transitions

    Every state becomes a node.  A transition becomes an edge from its
    origin to its destination, labelled with its equation.  A birth
    process becomes an edge from a plain-text node showing the equation
    into its state; a death process becomes an edge from its state to
    such a node.

    Parameters
    ----------
    ode_model: OperateOdeModel
        an ode model object
    file_name: str, optional
        location of the file, if none entered, then the default directory
        is used

    Returns
    -------
    dot: :class:`graphviz.Digraph`
        the transition graph
    """
    assert isinstance(ode_model, BaseOdeModel), "An ode model object required"

    # imported here so graphviz is only required when a graph is requested
    from graphviz import Digraph

    if file_name is None:
        dot = Digraph(comment='ode model')
    else:
        dot = Digraph(comment='ode model', filename=file_name)

    param = [str(p) for p in ode_model.param_list]
    states = [str(s) for s in ode_model.state_list]

    for s in states:
        dot.node(s)

    # keep the list under its own name rather than reusing the loop variable
    all_transitions = ode_model.transition_list + ode_model.birth_death_list

    for transition in all_transitions:
        s1 = transition.origin
        eq = _makeEquationPretty(transition.equation, param)

        if transition.transition_type is TransitionType.T:
            s2 = transition.destination
            dot.edge(s1, s2, label=eq)
        elif transition.transition_type is TransitionType.B:
            # when we have a birth or death process, do not make the box
            dot.node(eq, shape="plaintext", width="0", height="0", margin="0")
            dot.edge(eq, s1)
        elif transition.transition_type is TransitionType.D:
            dot.node(eq, shape="plaintext", width="0", height="0", margin="0")
            dot.edge(s1, eq)
        # any other transition type is not drawn

    return dot

def _makeEquationPretty(eq, param):
    Make the equation suitable for graphviz format by converting
    beta to &beta;  and remove all the multiplication sign

    We do not process ** and convert it to a superscript because
    it is only possible with svg (which is a real pain to convert
    back to png) and only available from graphviz versions after
    14 Oct 2011
    for p in param:
        if p.lower() in greekLetter:
            eq = re.sub('(\\W?)(' + p + ')(\\W?)', '\\1&' + p + ';\\3', eq)
    # eq = re.sub('\*{1}[^\*]', '', eq)
    # eq = re.sub('([^\*]?)\*([^\*]?)', '\\1 \\2', eq)
    # eq += " blah<SUP>Yo</SUP> + ha<SUB>Boo</SUB>"
    return eq

def generateDirectedDependencyGraph(ode_matrix, transition=None):
    """
    Returns a matrix that records, for every state, its involvement in
    each transition

    Parameters
    ----------
    ode_matrix: :class:`sympy.matrices.MatrixBase`
        A matrix of size [number of states x 1].  Obtained by
        invoking :meth:`DeterministicOde.get_ode_eqn`
    transition: list, optional
        list of two element tuples ``(t1, t2)`` of matching terms.  When
        None, it is generated by :func:`getMatchingExpressionVector` with
        ``outTuple=True``.

    Returns
    -------
    G: :class:`numpy.ndarray`
        Two dimensional array of size [number of states x number of
        transitions].  Entry ``[i, j]`` gets -1 when the equation of state
        i contains ``t1`` of transition j and +1 when it contains ``t2``
        (both together give 0).  A well formed transition has exactly one
        -1 and one 1 in its column, so every column sums to zero.
    """
    assert isinstance(ode_matrix, MatrixBase), \
        "Expecting a vector of expressions"

    if transition is None:
        transition = getMatchingExpressionVector(ode_matrix, True)
    else:
        assert isinstance(transition, list), "Require a list of transitions"

    B = np.zeros((len(ode_matrix), len(transition)))
    for i, a in enumerate(ode_matrix):
        for j, (t1, t2) in enumerate(transition):
            # NOTE(review): getMatchingExpressionVector puts the positive
            # (arriving) term first and _odeToPureTransition treats the
            # state holding t2 as the origin, so -1 here marks the
            # destination.  This contradicts the original "going out"
            # comment.  Values are kept unchanged; confirm the intended
            # convention.
            if _hasExpression(a, t1):
                B[i, j] -= 1
            if _hasExpression(a, t2):
                B[i, j] += 1
    return B

def getUnmatchedExpressionVector(expr_vec, full_output=False):
    """
    Return the unmatched expressions from a vector of equations

    Parameters
    ----------
    expr_vec: :class:`sympy.matrices.MatrixBase`
        A matrix of size [number of states x 1].
    full_output: bool, optional
        Defaults to False, if True, also output the list of matched
        expressions

    Returns
    -------
    list
        of unmatched expressions, i.e. birth or death processes, each
        listed once in the order first seen.  When ``full_output`` is True
        a tuple ``(unmatched, matched)`` is returned instead, where
        ``matched`` is the output of :func:`_transitionListToMatchedTuple`.
    """
    assert isinstance(expr_vec, MatrixBase), \
        "Expecting a vector of expressions"

    # flatten the terms of every equation; unlike reduce() without an
    # initial value this also works for an empty vector
    transition = [term for expr in expr_vec for term in getExpressions(expr)]
    matched_transition_list = _findMatchingExpression(transition)
    matched = set(matched_transition_list)
    # dict.fromkeys removes duplicates but keeps first-seen order, so the
    # result does not depend on hash ordering the way a set difference does
    out = [t for t in dict.fromkeys(transition) if t not in matched]

    if full_output:
        return out, _transitionListToMatchedTuple(matched_transition_list)
    return out

def getMatchingExpressionVector(expr_vec, outTuple=False):
    """
    Return the matched expressions from a vector of equations

    Parameters
    ----------
    expr_vec: :class:`sympy.matrices.MatrixBase`
        A matrix of size [number of states x 1].
    outTuple: bool, optional
        Defaults to False, if True, the output is a list of tuples of
        length two which hold the matching elements.  The first element
        is always positive and the second negative

    Returns
    -------
    list
        of matched expressions, i.e. transitions, each listed once in the
        order first seen (or the tuples described above)
    """
    assert isinstance(expr_vec, MatrixBase), \
        "Expecting a vector of expressions"

    transition = [term for expr in expr_vec for term in getExpressions(expr)]

    # unique matched terms in first-seen order (deterministic, unlike set)
    transition = list(dict.fromkeys(_findMatchingExpression(transition)))

    if outTuple:
        return _transitionListToMatchedTuple(transition)
    return transition

def _findMatchingExpression(expressions, full_output=False):
    Reduce a list of expressions to a list of transitions.  A transition
    is found when two expressions are identical with a change of sign.

    expressions: list
        the list of expressions
    full_output: bool, optional
        If True, output the unmatched expressions as well. Defaults to False.

        of expressions that was matched
    t_list = list()
    for i in range(len(expressions) - 1):
        for j in range(i + 1, len(expressions)):
            b = expressions[i] + expressions[j]
            if b == 0:

    if full_output:
        unmatched = set(expressions) - set(t_list)
        return t_list, list(unmatched)
        return t_list

def _transitionListToMatchedTuple(transition):
    """
    Convert a list of transitions to a list of tuple, where each tuple
    is of length 2 and contains the matched transitions. First element
    of the tuple is the positive term

    Parameters
    ----------
    transition: list
        sympy expressions, e.g. the output of :func:`_findMatchingExpression`

    Returns
    -------
    list
        of ``(positive, negative)`` tuples, one for every pair of entries
        that cancel out
    """
    t_tuple_list = list()
    for first, second in combinations(transition, 2):
        # the two terms cancel out
        if first + second == 0:
            # could_extract_minus_sign handles any negative coefficient
            # (e.g. -2*x), not only a literal -1 factor
            if first.could_extract_minus_sign():
                t_tuple_list.append((second, first))
            else:
                t_tuple_list.append((first, second))
    return t_tuple_list

def getExpressions(expr):
    """
    Return the distinct expressions (terms) found in ``expr``.

    ``expr`` is expanded and then walked by :func:`_getExpression`; each
    expression appears once, in the order it was first recorded.
    """
    counts = {}
    _getExpression(expr.expand(), counts)
    return [*counts]

def getLeafs(expr):
    """
    Return the distinct leafs of ``expr``.

    ``expr`` is expanded and then walked by :func:`_getLeaf`; each leaf
    appears once, in the order it was first recorded.
    """
    counts = {}
    _getLeaf(expr.expand(), counts)
    return [*counts]

def _getLeaf(expr, input_dict):
    """
    Get the leafs of an expression, can probably just do
    the same with expr.atoms() with most expression but we
    do not break down power terms i.e. x**2 will be broken
    down to (x,2) in expr.atoms() but this function will
    retain (x**2) when it appears as an argument, e.g. as a
    factor of a product

    The arguments of ``expr`` are walked recursively.  An argument whose
    :func:`_expressionLength` is zero is counted as a leaf; any other
    argument is descended into.

    Parameters
    ----------
    expr: sympy expression
        the expression to walk (callers pass it already expanded)
    input_dict: dict
        accumulator mapping each leaf to the number of times it was
        seen; updated in place
    """
    for arg in expr.args:
        if _expressionLength(arg) == 0:
            input_dict[arg] = input_dict.get(arg, 0) + 1
        else:
            _getLeaf(arg, input_dict)

def _getExpression(expr, input_dict):
    """
    Collect the expressions (terms) of ``expr`` into ``input_dict``.

    All the operations depend on whether all the components of a node
    are leafs or only some of them.  Only expressions are recorded and
    not the individual elements of a product:

    - if every component has length zero (see :func:`_expressionLength`)
      the node itself is recorded as one expression;
    - otherwise each zero-length component is recorded on its own and
      each product component is searched recursively.

    Parameters
    ----------
    expr: sympy expression
        the expression to walk (callers pass it already expanded)
    input_dict: dict
        accumulator mapping each expression to the number of times it was
        seen; updated in place
    """
    # an expression with at most one distinct atom is its own only component
    components = expr.args if len(expr.atoms()) > 1 else [expr]
    lengths = [_expressionLength(c) for c in components]

    if all(length == 0 for length in lengths):
        # if all components are leafs, then the node is an expression
        input_dict[expr] = input_dict.get(expr, 0) + 1
        return

    for component, length in zip(components, lengths):
        if length == 0:
            # a singleton component is an expression on its own
            input_dict[component] = input_dict.get(component, 0) + 1
        elif isinstance(component, sympy.Mul):
            # a product with several factors: go further along the tree
            _getExpression(component, input_dict)

def _expressionLength(expr):
    """
    Returns the length of the expression, i.e. the number of factors when
    it is a product (:class:`sympy.Mul`).  Anything else, including a power
    term such as x^2, is treated as a single leaf and gives 0.

    Parameters
    ----------
    expr: sympy expression

    Returns
    -------
    int
    """
    if isinstance(expr, sympy.Mul):
        return len(expr.args)
    # powers, symbols and numbers all count as one leaf
    return 0

def _findIndex(eq_vec, expr):
    """
    Given a vector of expressions, find where you will locate the
    input term.

    Parameters
    ----------
    eq_vec: :class:`sympy.Matrix`
        vector of sympy equation
    expr: sympy type
        An expression that we would like to find

    Returns
    -------
    list
        of index that contains the expression (as decided by
        :func:`_hasExpression`).  Can be an empty list or contain
        multiple integers
    """
    return [i for i, eq in enumerate(eq_vec) if _hasExpression(eq, expr)]

def _hasExpression(eq, expr):
    Test whether the equation eq has the expression expr
    out = False
    aExpand = eq.expand()
    if expr == aExpand:
        out = True
    if expr in aExpand.args:
        out = True
    return out

def pureTransitionToOde(A):
    """
    Get the ode from a pure transition matrix

    Parameters
    ----------
    A: `sympy.Matrix`
        a transition matrix of size [n x n] where ``A[i, j]`` is the
        transition from state i to state j

    Returns
    -------
    b: `sympy.Matrix`
        a matrix of size [n x 1] which is the ode; entry i is the sum of
        column i (inflow) minus the sum of row i (outflow), simplified

    Raises
    ------
    AssertionError
        if ``A`` is not square
    """
    nrow, ncol = A.shape
    assert nrow == ncol, "Need a square matrix"
    B = [sum(A[:, i]) - sum(A[i, :]) for i in range(nrow)]
    return sympy.simplify(sympy.Matrix(B))

def stripBDFromOde(fx, bd_list=None):
    """
    Remove the birth and death terms from an ode.

    Parameters
    ----------
    fx: :class:`sympy.matrices.MatrixBase`
        the ode, a matrix of size [number of states x 1]
    bd_list: list, optional
        the birth/death terms to remove.  Defaults to the unmatched
        expressions of ``fx`` from :func:`getUnmatchedExpressionVector`.

    Returns
    -------
    :class:`sympy.Matrix`
        expanded, simplified copy of ``fx``.  Every term of ``bd_list``
        that is a top-level argument of an expanded equation has been
        subtracted from it.  ``fx`` itself is not modified.
    """
    if bd_list is None:
        bd_list = getUnmatchedExpressionVector(fx, False)

    stripped = fx.copy()
    for row, fx_row in enumerate(fx):
        row_terms = fx_row.expand().args
        for bd_term in bd_list:
            if bd_term in row_terms:
                stripped[row] -= bd_term

    # simplify converts it to an ImmutableMatrix, so we make it into
    # a mutable object again because we want the expanded form
    return sympy.Matrix(sympy.simplify(stripped)).expand()

def odeToPureTransition(fx, states, output_remain=False):
    """
    Decompose an ode into a pure transition matrix between states.

    1. Birth and death processes (unmatched terms) are removed with
       :func:`stripBDFromOde`.
    2. Matched pairs whose first term involves exactly one state are
       placed by :func:`_singleOriginTransition`.
    3. The rest are placed by :func:`_odeToPureTransition`.
    4. The result is checked by rebuilding the ode with
       :func:`pureTransitionToOde`.  If a difference remains, its matched
       pairs are placed again with the order of each pair reversed.

    Parameters
    ----------
    fx: :class:`sympy.matrices.MatrixBase`
        the ode, a matrix of size [number of states x 1]
    states: list
        the state symbols, compared against the atoms of each term
    output_remain: bool, optional
        if True, also return the pairs of terms that could not be placed

    Returns
    -------
    A: :class:`sympy.matrices.MatrixBase`
        transition matrix where ``A[i, j]`` is the transition from state
        i to state j
    remain_terms: list
        only when ``output_remain`` is True
    """
    bd_list, term_list = getUnmatchedExpressionVector(fx, full_output=True)
    fx = stripBDFromOde(fx, bd_list)
    # we now have fx with pure transitions
    A, remain_terms = _singleOriginTransition(fx, term_list, states)
    A, remain_terms = _odeToPureTransition(fx, remain_terms, A)

    # checking if our decomposition is correct; iterate the sympy matrix
    # directly (np.array(map(...)) is a 0-d object array on Python 3)
    diff_ode = sympy.simplify(fx - pureTransitionToOde(A))
    if any(term != 0 for term in diff_ode):
        diff_term = sympy.Matrix([term for term in diff_ode if term != 0])
        diff_term_list = getMatchingExpressionVector(diff_term, True)
        # If there is some single origin transition not being matched up
        # it is most likely because the transition originates from a
        # combination like (1-x) which got split into two parts - the
        # "1" and the "x" part.  So we try to reverse the sign to see
        # if it helps.
        # TODO: increase robustness so if it does not help, then we
        # either bail out or revert to the normal version
        diff_term_list = [(second, first) for first, second in diff_term_list]
        A, remain_terms = _singleOriginTransition(diff_ode, diff_term_list,
                                                  states, A)
        A, remain_terms = _odeToPureTransition(diff_ode, remain_terms, A)

    if output_remain:
        return A, remain_terms
    return A

def _odeToPureTransition(fx, terms=None, A=None):
    """
    Get the pure transition matrix between states

    For every pair ``(t1, t2)``, each state i whose equation contains
    ``t2`` is taken as an origin and each state j whose equation contains
    ``t1`` as a destination, and ``t1`` is added to ``A[i, j]``.

    Parameters
    ----------
    fx: :class:`sympy.matrices.MatrixBase`
        input ode in symbolic form, :math:`f(x)`
    terms: list, optional
        list of two element tuples which contains the matching terms.
        Defaults to None, in which case they are obtained from
        :func:`getMatchingExpressionVector`
    A: `sympy.matrices.MatrixBase`, optional
        the matrix to be filled; it is updated in place.  Defaults to
        None, which will lead to the creation of a [len(fx), len(fx)]
        matrix with all zero elements

    Returns
    -------
    A: :class:`sympy.matrices.MatrixBase`
        resulting transition matrix
    remain: list
        the tuples from ``terms`` that could not be placed
    """
    if terms is None:
        terms = getMatchingExpressionVector(fx, True)

    if A is None:
        A = sympy.zeros(len(fx), len(fx))

    remain_transition = list()
    for t1, t2 in terms:
        # states whose equation gains t1, i.e. where the transition arrives
        destinations = [j for j, a_to in enumerate(fx) if _hasExpression(a_to, t1)]
        remain = True
        for i, a_from in enumerate(fx):
            if _hasExpression(a_from, t2):
                for j in destinations:
                    A[i, j] += t1  # from i to j
                    remain = False
        if remain:
            remain_transition.append((t1, t2))

    return A, remain_transition

def _singleOriginTransition(fx, term_list, states, A=None):
    """
    Place transitions whose origin is identified by a single state.

    For each pair ``(t1, t2)`` in ``term_list``, if exactly one of
    ``states`` is among the atoms of ``t1``, that state is the origin.
    ``t1`` is then added to ``A[origin, j]`` for every other state j whose
    equation contains ``t1``.

    Parameters
    ----------
    fx: :class:`sympy.matrices.MatrixBase`
        the ode, a matrix of size [number of states x 1]
    term_list: iterable
        two element tuples of matching terms
    states: list
        the state symbols, compared against the atoms of ``t1``
    A: `sympy.matrices.MatrixBase`, optional
        the matrix to be filled; it is updated in place.  Defaults to a
        [len(fx), len(fx)] zero matrix.

    Returns
    -------
    A: :class:`sympy.matrices.MatrixBase`
        the transition matrix
    remain_term_list: list
        the ``(t1, t2)`` pairs that could not be placed this way
    """
    if A is None:
        A = sympy.zeros(len(fx), len(fx))

    remain_term_list = list()
    for t1, t2 in term_list:
        t1_atoms = t1.atoms()
        possible_origin = [i for i, s in enumerate(states) if s in t1_atoms]
        remain = True
        if len(possible_origin) == 1:
            origin = possible_origin[0]
            for j, fxj in enumerate(fx):
                if j != origin and _hasExpression(fxj, t1):
                    A[origin, j] += t1
                    remain = False
        if remain:
            remain_term_list.append((t1, t2))

    return A, remain_term_list