# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""
Custom Operator for NMF Decomposition
=====================================

`NMF <https://scikit-learn.org/stable/modules/generated/
sklearn.decomposition.NMF.html>`_ factorizes an input
matrix into two matrices *W, H* of rank *k* so that :math:`WH \\sim M`.
:math:`M=(m_{ij})` may be a binary matrix where *i* is a user
and *j* a product they bought. The prediction function depends
on whether the recommendation is needed for an existing user
or for a new user. This example addresses the first case.

The second case is more complex as it theoretically
requires the estimation of a new matrix *W* with a
gradient descent.

.. contents::
    :local:

Building a simple model
+++++++++++++++++++++++
"""

import os

import skl2onnx
import onnxruntime
import sklearn
from sklearn.decomposition import NMF
import numpy as np
import matplotlib.pyplot as plt
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
import onnx
from skl2onnx.algebra.onnx_ops import (
    OnnxArrayFeatureExtractor, OnnxMul, OnnxReduceSum)
from skl2onnx.common.data_types import FloatTensorType
from onnxruntime import InferenceSession


# Training matrix: five identical rows [1, 0, 0, 0]; the identity matrix
# is then added to the first ``mat.shape[1]`` rows.
mat = np.array([[1, 0, 0, 0]] * 5, dtype=np.float64)
mat[:mat.shape[1], :] += np.identity(mat.shape[1])

# Factorize ``mat`` into W (returned by fit_transform) and H (components_).
mod = NMF(n_components=2)
W = mod.fit_transform(mat)
H = mod.components_

# Reference predictions computed by scikit-learn from W.
pred = mod.inverse_transform(W)

print("original predictions")
exp = [
    (i, j, pred[i, j])
    for i in range(mat.shape[0])
    for j in range(mat.shape[1])
]

print(exp)

#######################
# Let's rewrite the prediction so that it is closer
# to the function we need to convert into ONNX.
def predict(W, H, row_index, col_index):
    """
    Returns the dot product of row *row_index* of *W*
    and column *col_index* of *H*.
    """
    w_row = W[row_index, :]
    h_col = H[:, col_index]
    return np.dot(w_row, h_col)


got = [
    (i, j, predict(W, H, i, j))
    for i in range(mat.shape[0])
    for j in range(mat.shape[1])
]

print(got)


#################################
# Conversion into ONNX
# ++++++++++++++++++++
#
# There is no implemented converter for
# `NMF <https://scikit-learn.org/stable/modules/generated/
# sklearn.decomposition.NMF.html>`_ as the function we plan
# to convert is neither a transformer nor a predictor.
# The following converter does not need to be registered,
# it just creates an ONNX graph equivalent to function
# *predict* implemented above.

def nmf_to_onnx(W, H, op_version=12):
    """
    Builds an ONNX graph equivalent to function *predict*
    for a NMF described by matrices *W*, *H*
    (*WH* approximates the training data *M*).

    The graph takes two int64 index inputs named ``'row'`` and
    ``'col'`` and produces one output named ``'rec'``, declared as
    a float tensor of shape ``(None, 1)``. The indices are expected
    to refer to rows and columns of the training data.

    :param W: matrix *W* of the decomposition
    :param H: matrix *H* of the decomposition
    :param op_version: opset version given to the ONNX operators
        and used as the target opset of the graph
    :return: the ONNX model returned by ``to_onnx``
    """
    # ArrayFeatureExtractor picks entries along the last axis,
    # which is why W is transposed before the extraction.
    h_col = OnnxArrayFeatureExtractor(H, 'col')
    w_row = OnnxArrayFeatureExtractor(W.T, 'row')

    # Element-wise product followed by a sum, i.e. a dot product.
    product = OnnxMul(h_col, w_row, op_version=op_version)
    total = OnnxReduceSum(product, output_names="rec",
                          op_version=op_version)

    # A sample int64 array is enough to declare the type of both inputs.
    indices_type = np.array([0], dtype=np.int64)
    graph_inputs = {'col': indices_type, 'row': indices_type}
    graph_outputs = [('rec', FloatTensorType((None, 1)))]
    return total.to_onnx(inputs=graph_inputs,
                         outputs=graph_outputs,
                         target_opset=op_version)


model_onnx = nmf_to_onnx(W.astype(np.float32),
                         H.astype(np.float32))
print(model_onnx)

########################################
# Let's compute prediction with it.
sess = InferenceSession(model_onnx.SerializeToString())


def predict_onnx(sess, row_indices, col_indices):
    """
    Runs the ONNX session *sess* with *col_indices* bound to input
    ``'col'`` and *row_indices* bound to input ``'row'`` and returns
    the list of all outputs produced by ``sess.run``.
    """
    feeds = {'col': col_indices, 'row': row_indices}
    return sess.run(None, feeds)


onnx_preds = []
for i in range(mat.shape[0]):
    # The row index only changes with the outer loop.
    row_indices = np.array([i], dtype=np.int64)
    for j in range(mat.shape[1]):
        col_indices = np.array([j], dtype=np.int64)
        # First (and only requested) output of the graph.
        pred = predict_onnx(sess, row_indices, col_indices)[0]
        onnx_preds.append((i, j, pred[0, 0]))

print(onnx_preds)


###################################
# The ONNX graph looks like the following.

node_producer = GetOpNodeProducer("docstring")
pydot_graph = GetPydotGraph(
    model_onnx.graph,
    name=model_onnx.graph.name,
    rankdir="TB",
    node_producer=node_producer)
pydot_graph.write_dot("graph_nmf.dot")

# Render the dot file to a PNG with the Graphviz command line tool.
os.system('dot -O -Tpng graph_nmf.dot')

image = plt.imread("graph_nmf.dot.png")
plt.imshow(image)
plt.axis('off')

#################################
# **Versions used for this example**

for label, module in (
        ("numpy:", np),
        ("scikit-learn:", sklearn),
        ("onnx: ", onnx),
        ("onnxruntime: ", onnxruntime),
        ("skl2onnx: ", skl2onnx)):
    print(label, module.__version__)