""" .. moduleauthor:: Edwin Tye <Edwin.Tye@phe.gov.uk> Functions that is used to determine the composition of the defined ode """ import re from functools import reduce import sympy from sympy.matrices import MatrixBase import numpy as np from .base_ode_model import BaseOdeModel from .transition import TransitionType greekLetter = ('alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta', 'eta', 'theta', 'iota', 'kappa', 'lambda', 'mu', 'nu', 'xi', 'omicron', 'pi', 'rho', 'sigma', 'tau', 'upsilon', 'phi', 'chi', 'psi', 'omega') def generateTransitionGraph(ode_model, file_name=None): """ Generates the transition graph in graphviz given an ode model with transitions Parameters ---------- ode_model: OperateOdeModel an ode model object file_name: str location of the file, if none entered, then the default directory is used Returns ------- dot: graphviz object """ assert isinstance(ode_model, BaseOdeModel), "An ode model object required" from graphviz import Digraph if file_name is None: dot = Digraph(comment='ode model') else: dot = Digraph(comment='ode model', filename=file_name) dot.body.extend(['rankdir=LR']) param = [str(p) for p in ode_model.param_list] states = [str(s) for s in ode_model.state_list] for s in states: dot.node(s) transition = ode_model.transition_list bd_list = ode_model.birth_death_list for transition in (transition + bd_list): s1 = transition.origin eq = _makeEquationPretty(transition.equation, param) if transition.transition_type is TransitionType.T: s2 = transition.destination dot.edge(s1, s2, label=eq) elif transition.transition_type is TransitionType.B: # when we have a birth or death process, do not make the box dot.node(eq, shape="plaintext", width="0", height="0", margin="0") dot.edge(eq, s1) elif transition.transition_type is TransitionType.D: dot.node(eq, shape="plaintext", width="0", height="0", margin="0") dot.edge(s1, eq) else: pass return dot def _makeEquationPretty(eq, param): """ Make the equation suitable for graphviz format by converting beta to β and remove all the multiplication sign We do not process ** and convert it to a superscript because it is only possible with svg (which is a real pain to convert back to png) and only available from graphviz versions after 14 Oct 2011 """ for p in param: if p.lower() in greekLetter: eq = re.sub('(\\W?)(' + p + ')(\\W?)', '\\1&' + p + ';\\3', eq) # eq = re.sub('\*{1}[^\*]', '', eq) # eq = re.sub('([^\*]?)\*([^\*]?)', '\\1 \\2', eq) # eq += " blah<SUP>Yo</SUP> + ha<SUB>Boo</SUB>" return eq def generateDirectedDependencyGraph(ode_matrix, transition=None): """ Returns a binary matrix that contains the direction of the transition in a state Parameters ---------- ode_matrix: :class:`sympy.matrcies.MatrixBase` A matrix of size [number of states x 1]. Obtained by invoking :meth:`DeterministicOde.get_ode_eqn` transition: list, optional list of transitions. Can be generated by :func:`getMatchingExpressionVector` Returns ------- G: :class:`numpy.ndarray` Two dimensional array of size [number of state x number of transitions] where each column has two entry, -1 and 1 to indicate the direction of the transition and the state. All column sum to one, i.e. transition must have a source and target. """ assert isinstance(ode_matrix, MatrixBase), \ "Expecting a vector of expressions" if transition is None: transition = getMatchingExpressionVector(ode_matrix, True) else: assert isinstance(transition, list), "Require a list of transitions" B = np.zeros((len(ode_matrix), len(transition))) for i, a in enumerate(ode_matrix): for j, transitionTuple in enumerate(transition): t1, t2 = transitionTuple if _hasExpression(a, t1): B[i, j] += -1 # going out if _hasExpression(a, t2): B[i, j] += 1 # coming in return B def getUnmatchedExpressionVector(expr_vec, full_output=False): """ Return the unmatched expressions from a vector of equations Parameters ---------- expr_vec: :class:`sympy.matrices.MatrixBase` A matrix of size [number of states x 1]. full_output: bool, optional Defaults to False, if True, also output the list of matched expressions Returns ------- list: of unmatched expressions, i.e. birth or death processes """ assert isinstance(expr_vec, MatrixBase), \ "Expecting a vector of expressions" transition = reduce(lambda x, y: x + y, map(getExpressions, expr_vec)) matched_transition_list = _findMatchingExpression(transition) out = list(set(transition) - set(matched_transition_list)) if full_output: return out, _transitionListToMatchedTuple(matched_transition_list) else: return out def getMatchingExpressionVector(expr_vec, outTuple=False): """ Return the matched expressions from a vector of equations Parameters ---------- expr_vec: :class:`sympy.matrices.MatrixBase` A matrix of size [number of states x 1]. outTuple: bool, optional Defaults to False, if True, the output is a tuple of length two which has the matching elements. The first element is always positive and the second negative Returns ------- list: of matched expressions, i.e. transitions """ assert isinstance(expr_vec, MatrixBase), \ "Expecting a vector of expressions" transition = list() for expr in expr_vec: transition += getExpressions(expr) transition = list(set(_findMatchingExpression(transition))) if outTuple: return _transitionListToMatchedTuple(transition) else: return transition def _findMatchingExpression(expressions, full_output=False): """ Reduce a list of expressions to a list of transitions. A transition is found when two expressions are identical with a change of sign. Parameters ---------- expressions: list the list of expressions full_output: bool, optional If True, output the unmatched expressions as well. Defaults to False. Returns ------- list: of expressions that was matched """ t_list = list() for i in range(len(expressions) - 1): for j in range(i + 1, len(expressions)): b = expressions[i] + expressions[j] if b == 0: t_list.append(expressions[i]) t_list.append(expressions[j]) if full_output: unmatched = set(expressions) - set(t_list) return t_list, list(unmatched) else: return t_list def _transitionListToMatchedTuple(transition): """ Convert a list of transitions to a list of tuple, where each tuple is of length 2 and contains the matched transitions. First element of the tuple is the positive term """ t_tuple_list = list() for i in range(len(transition) - 1): for j in range(i + 1, len(transition)): b = transition[i] + transition[j] # the two terms cancel out if b == 0: if sympy.Integer(-1) in getLeafs(transition[i]): t_tuple_list.append((transition[j], transition[i])) else: t_tuple_list.append((transition[i], transition[j])) return t_tuple_list def getExpressions(expr): input_dict = dict() _getExpression(expr.expand(), input_dict) return list(input_dict.keys()) def getLeafs(expr): input_dict = dict() _getLeaf(expr.expand(), input_dict) return list(input_dict.keys()) def _getLeaf(expr, input_dict): """ Get the leafs of an expression, can probably just do the same with expr.atoms() with most expression but we do not break down power terms i.e. x**2 will be broken down to (x,2) in expr.atoms() but this function will retain (x**2) """ t = expr.args t_lengths = np.array(list(map(_expressionLength, t))) for i, ti in enumerate(t): if t_lengths[i] == 0: input_dict.setdefault(ti, 0) input_dict[ti] += 1 else: _getLeaf(ti, input_dict) def _getExpression(expr, input_dict): """ all the operations is dependent on the conditions whether all the elements are leafs or only some of them. Only return expressions and not the individual elements """ t = expr.args if len(expr.atoms()) > 1 else [expr] # print t # find out the length of the components within this node t_lengths = np.array(list(map(_expressionLength, t))) # print(tLengths) if np.all(t_lengths == 0): # if all components are leafs, then the node is an expression input_dict.setdefault(expr, 0) input_dict[expr] += 1 else: for i, ti in enumerate(t): # if the leaf is a singleton, then it is an expression # else, go further along the tree if t_lengths[i] == 0: input_dict.setdefault(ti, 0) input_dict[ti] += 1 else: if isinstance(ti, sympy.Mul): _getExpression(ti, input_dict) elif isinstance(ti, sympy.Pow): input_dict.setdefault(ti, 0) input_dict[ti] += 1 def _expressionLength(expr): """ Returns the length of the expression i.e. number of terms. If the expression is a power term, i.e. x^2 then we assume that it is one term and return 0. """ # print type(expr) if isinstance(expr, sympy.Mul): return len(expr.args) elif isinstance(expr, sympy.Pow): return 0 else: return 0 def _findIndex(eq_vec, expr): """ Given a vector of expressions, find where you will locate the input term. Parameters ---------- eq_vec: :class:`sympy.Matrix` vector of sympy equation expr: sympy type An expression that we would like to find Returns ------- list: of index that contains the expression. Can be an empty list or with multiple integer """ out = list() for i, a in enumerate(eq_vec): j = _hasExpression(a, expr) if j is True: out.append(i) return out def _hasExpression(eq, expr): """ Test whether the equation eq has the expression expr """ out = False aExpand = eq.expand() if expr == aExpand: out = True if expr in aExpand.args: out = True return out def pureTransitionToOde(A): """ Get the ode from a pure transition matrix Parameters ---------- A: `sympy.Matrix` a transition matrix of size [n \times n] Returns ------- b: `sympy.Matrix` a matrix of size [n \times 1] which is the ode """ nrow, ncol = A.shape assert nrow == ncol, "Need a square matrix" B = [sum(A[:, i]) - sum(A[i, :]) for i in range(nrow)] return sympy.simplify(sympy.Matrix(B)) def stripBDFromOde(fx, bd_list=None): if bd_list is None: bd_list = getUnmatchedExpressionVector(fx, False) fx_copy = fx.copy() for i, fxi in enumerate(fx): term_in_expr = list(map(lambda x: x in fxi.expand().args, bd_list)) for j, term in enumerate(bd_list): fx_copy[i] -= term if term_in_expr[j] else 0 # simplify converts it to an ImmutableMatrix, so we make it into # a mutable object again because we want the expanded form return sympy.Matrix(sympy.simplify(fx_copy)).expand() def odeToPureTransition(fx, states, output_remain=False): bd_list, term_list = getUnmatchedExpressionVector(fx, full_output=True) fx = stripBDFromOde(fx, bd_list) # we now have fx with pure transitions A, remain_terms = _singleOriginTransition(fx, term_list, states) A, remain_terms = _odeToPureTransition(fx, remain_terms, A) # checking if our decomposition is correct fx1 = pureTransitionToOde(A) diff_ode = sympy.simplify(fx - fx1) if np.all(np.array(map(lambda x: x == 0, diff_ode)) == True): if output_remain: return A, remain_terms else: return A else: diff_term = sympy.Matrix(list(filter(lambda x: x != 0, diff_ode))) diff_term_list = getMatchingExpressionVector(diff_term, True) # If there is some single origin transition not being matched up # it is most likely because the transition originates from a # combination like (1-x) which got split into two parts - the # "1" and the "x" part. So we try to reverse the sign to see # if it helps. # TODO: increase robustness so if it does not help, then we # either bail out or revert to the normal version diff_term_list = map(lambda x_y: (x_y[1], x_y[0]), diff_term_list) A, remain_terms = _singleOriginTransition(diff_ode, diff_term_list, states, A) AA, remain_terms = _odeToPureTransition(diff_ode, remain_terms, A) if output_remain: return AA, remain_terms else: return AA def _odeToPureTransition(fx, terms=None, A=None): """ Get the pure transition matrix between states Parameters ---------- fx: :class:`sympy.matrices.MatrixBase` input ode in symbolic form, :math:`f(x)` terms: list of two element tuples which contains the matching terms A: `sympy.matricies.MatrixBase`, optional the matrix to be filled. Defaults to None, which will lead to the creation of a [len(fx), len(fx)] matrix with all zero elements Returns ------- A: :class:`sympy.matricies.MatrixBase` resulting transition matrix remain: list list of which contains the unmatched transitions """ if terms is None: terms = getMatchingExpressionVector(fx, True) if A is None: A = sympy.zeros(len(fx), len(fx)) remain_transition = list() for t1, t2 in terms: remain = True for i, aFrom in enumerate(fx): if _hasExpression(aFrom, t2): # arriving at for j, aTo in enumerate(fx): if _hasExpression(aTo, t1): A[i, j] += t1 # from i to j remain = False if remain: remain_transition.append((t1, t2)) return A, remain_transition def _singleOriginTransition(fx, term_list, states, A=None): if A is None: A = sympy.zeros(len(fx), len(fx)) remain_term_list = list() for k, transition_tuple in enumerate(term_list): t1, t2 = transition_tuple possible_origin = list() remain = True for i, s in enumerate(states): if s in t1.atoms(): possible_origin.append(i) if len(possible_origin) == 1: for j, fxj in enumerate(fx): # print(t1, fxj, possibleOrigin[0] != j, _hasExpression(fxj, t2)) if possible_origin[0] != j and _hasExpression(fxj, t1): A[possible_origin[0], j] += t1 remain = False # print(t1, possibleOrigin, j, fxj, "\n") if remain: remain_term_list.append(transition_tuple) return A, remain_term_list