# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import numpy as np
from ._apply_operation import apply_cast, apply_reshape
from ..proto import onnx_proto


def get_label_classes(scope, op):
    """
    Extracts the model classes, handles option ``nocl``.

    :param scope: object providing ``get_options(op, defaults)``; the
        returned mapping is read for the key ``'nocl'`` (default ``False``)
    :param op: model exposing a ``classes_`` array
    :return: ``op.classes_`` itself when ``nocl`` is falsy, otherwise
        ``np.arange(0, len(op.classes_))``, i.e. the labels are replaced
        by their positions
    :raises RuntimeError: if ``nocl`` is set and ``op.classes_`` has more
        than one dimension with a second dimension larger than 1
        (reported as multi-label classification)
    """
    options = scope.get_options(op, dict(nocl=False))
    if options['nocl']:
        if len(op.classes_.shape) > 1 and op.classes_.shape[1] > 1:
            raise RuntimeError(
                "Options 'nocl=True' is not implemented for multi-label "
                "classification (class: {}).".format(op.__class__.__name__))
        classes = np.arange(0, len(op.classes_))
    else:
        classes = op.classes_
    return classes


def _finalize_converter_classes(scope, argmax_output_name, output_full_name,
                                container, classes):
    """
    See :func:`convert_voting_classifier`.

    Stores *classes* as an initializer, selects entries from it with an
    ``ArrayFeatureExtractor`` node fed by *argmax_output_name*, and writes
    the result, reshaped to ``(-1,)``, into *output_full_name*.

    :param scope: object used to generate unique variable and operator names
    :param argmax_output_name: name of the second input of the
        ``ArrayFeatureExtractor`` node (the first one is the class
        initializer)
    :param output_full_name: name of the final output
    :param container: object receiving the initializer and the nodes
    :param classes: numpy array of class labels; floating and integer
        (signed or unsigned) labels are stored as INT32 and the final
        output is cast to INT64, any other dtype is treated as strings
        and encoded to UTF-8 bytes
    """
    if np.issubdtype(classes.dtype, np.floating):
        class_type = onnx_proto.TensorProto.INT32
        # int() drops the fractional part of every label.
        classes = np.array([int(x) for x in classes])
    elif np.issubdtype(classes.dtype, np.integer):
        # np.integer covers signed and unsigned labels; testing only
        # np.signedinteger sent unsigned labels to the string branch
        # where ``.encode`` raised AttributeError.
        class_type = onnx_proto.TensorProto.INT32
    else:
        classes = np.array([s.encode('utf-8') for s in classes])
        class_type = onnx_proto.TensorProto.STRING

    classes_name = scope.get_unique_variable_name('classes')
    container.add_initializer(classes_name, class_type, classes.shape,
                              classes)

    array_feature_extractor_result_name = scope.get_unique_variable_name(
        'array_feature_extractor_result')
    container.add_node(
        'ArrayFeatureExtractor', [classes_name, argmax_output_name],
        array_feature_extractor_result_name, op_domain='ai.onnx.ml',
        name=scope.get_unique_operator_name('ArrayFeatureExtractor'))

    output_shape = (-1,)
    if class_type == onnx_proto.TensorProto.INT32:
        cast2_result_name = scope.get_unique_variable_name('cast2_result')
        reshaped_result_name = scope.get_unique_variable_name(
            'reshaped_result')
        # NOTE(review): the labels go through FLOAT before the reshape and
        # are only then cast to INT64; presumably a workaround for the
        # types Reshape accepted in older opsets -- confirm before
        # simplifying.
        apply_cast(scope, array_feature_extractor_result_name,
                   cast2_result_name, container,
                   to=onnx_proto.TensorProto.FLOAT)
        apply_reshape(scope, cast2_result_name, reshaped_result_name,
                      container, desired_shape=output_shape)
        apply_cast(scope, reshaped_result_name, output_full_name, container,
                   to=onnx_proto.TensorProto.INT64)
    else:  # string labels
        apply_reshape(scope, array_feature_extractor_result_name,
                      output_full_name, container,
                      desired_shape=output_shape)