# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""
Extension for sphinx.

Registers the role and directives used by the *skl2onnx*
documentation to display versions, supported converters and
the operators available through the algebra API.
"""
from importlib import import_module
import sphinx
from docutils import nodes
from docutils.parsers.rst import Directive
from docutils.statemachine import StringList
from sphinx.util.nodes import nested_parse_with_titles
from tabulate import tabulate
import skl2onnx
from skl2onnx._supported_operators import build_sklearn_operator_name_map
from skl2onnx.algebra.onnx_ops import dynamic_class_creation
from skl2onnx.algebra.sklearn_ops import dynamic_class_creation_sklearn
import onnxruntime


def skl2onnx_version_role(role, rawtext, text, lineno, inliner,
                          options=None, content=None):
    """
    Defines custom role *skl2onnxversion* which returns
    *skl2onnx* or *onnxruntime* version.

    :param role: role name used in the document (unused)
    :param rawtext: complete markup of the role, stored as the
        raw source of the returned node
    :param text: role content, ``'v'`` for the *skl2onnx* version,
        ``'rt'`` for the *onnxruntime* version
    :param lineno: line number of the role (unused)
    :param inliner: docutils inliner (unused)
    :param options: role options (unused)
    :param content: role content (unused)
    :return: ``([literal node], [])`` as expected by docutils
    :raises RuntimeError: if *text* is neither ``'v'`` nor ``'rt'``
    """
    if options is None:
        options = {}
    if content is None:
        content = []
    if text == 'v':
        version = 'v' + skl2onnx.__version__
    elif text == 'rt':
        version = 'v' + onnxruntime.__version__
    else:
        raise RuntimeError(
            "skl2onnx_version_role cannot interpret content '{0}'.".format(
                text))
    # A docutils TextElement takes (rawsource, text, ...): passing the
    # version as the only argument stores it as raw source and renders
    # an empty literal, so the displayed text must be the second argument.
    node = nodes.literal(rawtext, version)
    return [node], []


class SupportedSkl2OnnxDirective(Directive):
    """
    Automatically displays the list of models
    *skl2onnx* can currently convert, as a bullet list.
    """
    required_arguments = 0
    optional_arguments = 0
    final_argument_whitespace = True
    option_spec = {}
    has_content = False

    def run(self):
        # NOTE(review): supported_converters(True) is assumed to return
        # model names (one bullet each) -- confirm against skl2onnx.
        models = skl2onnx.supported_converters(True)
        bullets = nodes.bullet_list()
        for mod in models:
            par = nodes.paragraph()
            par += nodes.Text(mod)
            bullets += nodes.list_item('', par)
        return [bullets]


def _three_column_layout(cells):
    """
    Lays *cells* out column by column in a three-column table.

    The first column receives the first ``ceil(len(cells) / 3)`` cells,
    the second column the next ones, and so on. Missing cells in the
    last rows are filled with empty strings.

    :param cells: list of strings
    :return: list of rows, each row being a list of three strings
    """
    count = len(cells)
    cut = count // 3 + (1 if count % 3 else 0)
    return [[cells[j] if j < count else '' for j in (i, i + cut, i + 2 * cut)]
            for i in range(cut)]


def _render_class_catalog(state, classes, label_prefix, module_path):
    """
    Builds the documentation shared by the directives listing
    dynamically created classes.

    It produces a three-column table of references followed by
    one section per class, each section containing a label, a title
    and an ``autoclass`` directive.

    :param state: directive state used to parse the generated RST
    :param classes: mapping ``{name: class}``
    :param label_prefix: prefix of the RST labels (e.g. ``'l-onnx'``)
    :param module_path: module the classes are documented from,
        used in the ``autoclass`` directive
    :return: a container node holding everything
    """
    sorted_keys = sorted(classes)
    main = nodes.container()

    refs = [":ref:`{0}-{1}`".format(label_prefix, classes[name].__name__)
            for name in sorted_keys]
    rst = tabulate(_three_column_layout(refs), tablefmt="rst")
    node = nodes.container()
    nested_parse_with_titles(state, StringList(rst.split("\n")), node)
    main += node

    for name in sorted_keys:
        title = classes[name].__name__
        rows = ['.. _{0}-{1}:'.format(label_prefix, title),
                '',
                title,
                '=' * len(title),
                '',
                ".. autoclass:: {0}.{1}".format(module_path, name)]
        # One container per class, as titles create sections.
        node = nodes.container()
        nested_parse_with_titles(state, StringList(rows), node)
        main += node
    return main


class SupportedOnnxOpsDirective(Directive):
    """
    Automatically displays the list of supported ONNX models
    *skl2onnx* can use to build converters.
    """
    required_arguments = 0
    optional_arguments = 0
    final_argument_whitespace = True
    option_spec = {}
    has_content = False

    def run(self):
        return [_render_class_catalog(
            self.state, dynamic_class_creation(),
            'l-onnx', 'skl2onnx.algebra.onnx_ops')]


class SupportedSklearnOpsDirective(Directive):
    """
    Automatically displays the list of available converters.
    """
    required_arguments = 0
    optional_arguments = 0
    final_argument_whitespace = True
    option_spec = {}
    has_content = False

    def run(self):
        return [_render_class_catalog(
            self.state, dynamic_class_creation_sklearn(),
            'l-sklops', 'skl2onnx.algebra.sklearn_ops')]


def missing_ops():
    """
    Builds the list of scikit-learn models (subclasses of
    *BaseEstimator*) found in the subpackages listed by
    ``sklearn.__all__``, whether or not a converter exists.

    *Pipeline*, *ColumnTransformer*, *FeatureUnion* and *BaseEstimator*
    are skipped, as well as every class from *calibration*, *dummy*
    and *manifold* whose name does not contain ``'Calibrated'``.

    :return: list of ``(class name, subpackage, class)`` sorted
        by class name then subpackage
    """
    from sklearn import __all__
    from sklearn.base import BaseEstimator

    found = []
    for sub in __all__:
        try:
            mod = import_module("{0}.{1}".format("sklearn", sub))
        except ImportError:
            # Not every name in sklearn.__all__ is a subpackage.
            continue
        names = getattr(mod, "__all__", None)
        if names is None:
            names = list(mod.__dict__)
        for attr in names:
            # getattr (and not __dict__[...]) also resolves names
            # provided lazily by a module-level __getattr__.
            cl = getattr(mod, attr, None)
            try:
                issub = issubclass(cl, BaseEstimator)
            except TypeError:
                # Not a class.
                continue
            if not issub:
                continue
            if cl.__name__ in {'Pipeline', 'ColumnTransformer',
                               'FeatureUnion', 'BaseEstimator'}:
                continue
            if (sub in {'calibration', 'dummy', 'manifold'} and
                    'Calibrated' not in cl.__name__):
                continue
            found.append((cl.__name__, sub, cl))
    # Sorting on the tuple itself would compare classes on ties.
    found.sort(key=lambda entry: (entry[0], entry[1]))
    return found


class AllSklearnOpsDirective(Directive):
    """
    Displays the list of models implemented in scikit-learn
    and whether or not there is an associated converter.
    """
    required_arguments = 0
    optional_arguments = 0
    final_argument_whitespace = True
    option_spec = {}
    has_content = False

    def run(self):
        from sklearn import __version__ as skver
        found = missing_ops()
        nbconverters = 0
        supported = set(build_sklearn_operator_name_map())
        # Cells of a list-table row are nested bullets which must be
        # indented under the content of the "* " item.
        rows = [".. list-table::",
                "    :header-rows: 1",
                "    :widths: 10 7 4",
                "",
                "    * - Name",
                "      - Package",
                "      - Supported"]
        for name, sub, cl in found:
            rows.append("    * - " + name)
            rows.append("      - " + sub)
            if cl in supported:
                rows.append("      - Yes")
                nbconverters += 1
            else:
                rows.append("      -")

        rows.append("")
        rows.append("scikit-learn's version is **{0}**.".format(skver))
        rows.append("{0}/{1} models are covered.".format(
            nbconverters, len(found)))

        node = nodes.container()
        st = StringList(rows)
        nested_parse_with_titles(self.state, st, node)
        main = nodes.container()
        main += node
        return [main]


def setup(app):
    """
    Registers the role and directives defined in this module.

    :param app: Sphinx application
    :return: extension metadata
    """
    app.add_role('skl2onnxversion', skl2onnx_version_role)
    app.add_directive('supported-skl2onnx', SupportedSkl2OnnxDirective)
    app.add_directive('supported-onnx-ops', SupportedOnnxOpsDirective)
    app.add_directive('supported-sklearn-ops', SupportedSklearnOpsDirective)
    app.add_directive('covered-sklearn-ops', AllSklearnOpsDirective)
    return {'version': sphinx.__display_version__,
            'parallel_read_safe': True}