"""Utilities to debug Myia. Using pytest: pytest -s -T debug.trace.<name>:<arg1>:<arg2>... ... pytest -s -T 'debug.trace.<name>(<arg1>,<arg2>)' ... For example: -T debug.trace.prof # profile -T debug.trace.explore # look at available events and fields -T debug.trace.log # list events in order -T debug.trace.log:opt # list all opt events -T debug.trace.log:opt:opt # show the opt field for all opt events -T debug.trace.log:+opt:opt # show the opts from step_opt # compare node count before and after a successful opt -T debug.trace.compare:+opt:+optname:+countnodes -T debug.trace.graph # show final graph (requires Buche) The '+' prefix taps pre-made paths or rules: * +xyz, as a path argument, refers to debug.trace._path_xyz * +xyz, as a field, refers to debug.trace._rule_xyz """ import os import time from collections import Counter, defaultdict from decimal import Decimal import breakword from colorama import Fore from myia.utils import ( # noqa: F401 DoTrace, Profiler as prof, TraceExplorer as explore, TraceListener, ) from .inject import bucheg _beginning = time.monotonic() _current = time.monotonic() ############# # Utilities # ############# class Time: def __init__(self, t=None): if t is None: self.t = time.monotonic() - _beginning else: self.t = t def compare(self, other): return Time(other.t - self.t) @classmethod def statistics(cls, tdata): data = [t.t for t in tdata] print(f" Min:", Time(min(data))) print(f" Avg:", Time(sum(data) / len(data))) print(f" Max:", Time(max(data))) def __str__(self): d = Decimal(self.t) unit = "s" units = ["ms", "us", "ns"] for other_unit in units: if d >= 1: break else: d *= 1000 unit = other_unit d = round(d, 3) return f"{d}{unit}" def _color(color, text): """Wrap the text with the given color. If Buche is active, the color is not applied. """ if os.environ.get("BUCHE"): return text else: return f"{color}{text}{Fore.RESET}" def _pgraph(path): """Print a graph using Buche.""" def _p(graph, **_): bucheg(graph) return lambda: DoTrace({path: _p}) class Getters(dict): def __init__(self, fields, kwfields): for field in fields: if field == "help": self[field] = lambda **kwargs: ", ".join(kwargs) elif field.startswith("+"): field = field[1:] self[field] = globals()[f"_rule_{field}"] else: self[field] = self._get_by_name(field) for name, getter in kwfields.items(): self[name] = getter def _get_by_name(self, field): def _get(**kwargs): return kwargs.get(field, f"<{field} NOT FOUND>") return _get def __call__(self, kwargs): results = {name: getter(**kwargs) for name, getter in self.items()} return results def _display(curpath, results, word=None, brk=True): w = word or breakword.word() if len(results) == 0: print(w, curpath) elif len(results) == 1: _, value = list(results.items())[0] print(w, _color(Fore.LIGHTBLACK_EX, curpath), value) else: print(w, _color(Fore.LIGHTBLACK_EX, curpath)) for name, value in results.items(): print(f" {name}: {value}") if brk: _brk(w) def _brk(w): if breakword.after(): print("Breaking on:", w) breakpoint(skip=["debug.*", "myia.utils.trace"]) def _resolve_path(p, variant=""): if not p: rval = "**" elif p.startswith("+"): rval = globals()[f"_path{variant}_{p[1:]}"] else: rval = p if isinstance(rval, str): rval = [rval] return rval ########### # Tracers # ########### # Print the final graph graph = _pgraph("step_validate/enter") # Print the graph after monomorphization graph_mono = _pgraph("step_specialize/exit") # Print the graph after parsing graph_parse = _pgraph("step_parse/exit") def log(path=None, *fields, **kwfields): """Log fields of interest on the given path. The breakword module is used for logging, thus it is possible to set a word upon which to enter a breakpoint (using the BREAKWORD environment variable). * When no path is given, show all events. * The "help" field shows all possible fields. """ getters = Getters(fields, kwfields) def _p(**kwargs): _curpath = kwargs["_curpath"] results = getters(kwargs) _display(_curpath, results) return DoTrace({pth: _p for pth in _resolve_path(path)}) def opts(): """Log the optimizations applied during the opt phase.""" return log("step_opt/**/opt/success", opt=lambda opt, **_: opt.name) def compare(path=None, *fields, **kwfields): store = {} getters = Getters(fields, kwfields) def _compare(old, new): if isinstance(old, dict): return {k: _compare(v, new[k]) for k, v in old.items()} elif isinstance(old, (int, float)): diff = new - old if diff == 0: return old c = Fore.LIGHTGREEN_EX if diff > 0 else Fore.LIGHTRED_EX diff = f"+{diff}" if diff > 0 else str(diff) return f"{old} -> {new} ({_color(c, diff)})" elif hasattr(old, "compare"): return old.compare(new) elif old == new: return old else: return f"{old} -> {new}" def _enter(_curpath, **kwargs): _path = _curpath[:-6] w = breakword.word() store[_path] = (w, getters(kwargs)) _brk(w) def _exit(_curpath, **kwargs): if "success" in kwargs and not kwargs["success"]: return _path = _curpath[:-5] w, old = store[_path] new = getters(kwargs) _display(_path, _compare(old, new), word=w, brk=False) path = _resolve_path(path, variant="cmp") return DoTrace({f"{path}/enter": _enter, f"{path}/exit": _exit}) class StatAccumulator(TraceListener): def __init__(self, path, fields, kwfields): """Initialize a StatAccumulator.""" self.path = _resolve_path(path) self.accum = defaultdict(list) self.getters = Getters(fields, kwfields) def install(self, tracer): """Install the StatAccumulator.""" patt = self.path or "**" tracer.on(patt, self._do) def _do(self, **kwargs): for k, v in self.getters(kwargs).items(): self.accum[(k, type(v))].append(v) def post(self): for (name, typ), data in self.accum.items(): print(f"{name}:") if not data: print(" No data.") if issubclass(typ, (int, float)): print(f" Min:", min(data)) print(f" Avg:", sum(data) / len(data)) print(f" Max:", max(data)) elif hasattr(typ, "statistics"): typ.statistics(data) else: counts = Counter(data) align = max(len(str(obj)) for obj in counts) counts = sorted(counts.items(), key=lambda k: -k[1]) for obj, count in counts: print(f" {str(obj).ljust(align)} -> {count}") def stat(path=None, *fields, **kwfields): """Collect and display statistics about certain fields. * Numeric fields will display min/max/avg * String/other fields will count occurrences, sorted descending """ return StatAccumulator(path, fields, kwfields) ######### # Paths # ######### _path_opt = ["step_opt/**/opt/success", "step_opt2/**/opt/success"] _pathcmp_opt = ["step_opt/**/opt", "step_opt2/**/opt"] ######### # Rules # ######### def _rule_optname(opt=None, **kwargs): if opt is None: return "<NOT FOUND>" return opt.name def _rule_optparam(node=None, **kwargs): if node is None: return "<NOT FOUND>" try: return str(node.inputs[1]) except Exception: return "<???>" def _rule_countnodes(graph=None, manager=None, **kwargs): if manager is None: if graph is None: return "<NOT FOUND>" if graph._manager is None: return "<NO MANAGER>" manager = graph.manager return len(manager.all_nodes) def _rule_countgraphs(graph=None, manager=None, **kwargs): if manager is None: if graph is None: return "<NOT FOUND>" if graph._manager is None: return "<NO MANAGER>" manager = graph.manager return len(manager.graphs) def _rule_time(**kwargs): return Time() def _rule_reltime(**kwargs): global _current old = _current _current = time.monotonic() return Time(_current - old)