# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

'''
Mapping and utility functions for Name to Spark ML operators
'''

from pyspark.ml.feature import Binarizer
from pyspark.ml.feature import BucketedRandomProjectionLSHModel
from pyspark.ml.feature import Bucketizer
from pyspark.ml.feature import ChiSqSelectorModel
from pyspark.ml.feature import CountVectorizerModel
from pyspark.ml.feature import DCT
from pyspark.ml.feature import ElementwiseProduct
from pyspark.ml.feature import HashingTF
from pyspark.ml.feature import IDFModel
from pyspark.ml.feature import ImputerModel
from pyspark.ml.feature import IndexToString
from pyspark.ml.feature import MaxAbsScalerModel
from pyspark.ml.feature import MinHashLSHModel
from pyspark.ml.feature import MinMaxScalerModel
from pyspark.ml.feature import NGram
from pyspark.ml.feature import Normalizer
from pyspark.ml.feature import OneHotEncoderModel
from pyspark.ml.feature import PCAModel
from pyspark.ml.feature import PolynomialExpansion
from pyspark.ml.feature import QuantileDiscretizer
from pyspark.ml.feature import RegexTokenizer
from pyspark.ml.feature import StandardScalerModel
from pyspark.ml.feature import StopWordsRemover
from pyspark.ml.feature import StringIndexerModel
from pyspark.ml.feature import Tokenizer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import VectorIndexerModel
from pyspark.ml.feature import VectorSlicer
from pyspark.ml.feature import Word2VecModel

from pyspark.ml.classification import (
    LinearSVCModel,
    RandomForestClassificationModel,
    GBTClassificationModel,
    MultilayerPerceptronClassificationModel,
)
from pyspark.ml.classification import LogisticRegressionModel
from pyspark.ml.classification import DecisionTreeClassificationModel
from pyspark.ml.classification import NaiveBayesModel
from pyspark.ml.classification import OneVsRestModel

from pyspark.ml.regression import (
    AFTSurvivalRegressionModel,
    DecisionTreeRegressionModel,
    RandomForestRegressionModel,
)
from pyspark.ml.regression import GBTRegressionModel
from pyspark.ml.regression import GeneralizedLinearRegressionModel
from pyspark.ml.regression import IsotonicRegressionModel
from pyspark.ml.regression import LinearRegressionModel

# NOTE(review): these clustering classes are imported but never registered in
# the operator name map built below -- confirm whether other modules rely on
# them being importable from here before removing.
from pyspark.ml.clustering import BisectingKMeans
from pyspark.ml.clustering import KMeans
from pyspark.ml.clustering import GaussianMixture
from pyspark.ml.clustering import LDA


def build_sparkml_operator_name_map():
    '''
    Build the lookup table from Spark ML classes to operator names.

    Every registered class is mapped to a fixed module prefix followed by the
    class name, e.g. ``Binarizer`` -> ``"pyspark.ml.feature.Binarizer"``.

    :return: A dict whose keys are Spark ML classes and whose values are the
        corresponding operator name strings
    '''
    # (module prefix, classes registered under that prefix). The prefixes are
    # spelled out rather than read from ``cls.__module__`` so the operator
    # names stay exactly as written here.
    classes_by_module = (
        ("pyspark.ml.feature", [
            Binarizer,
            BucketedRandomProjectionLSHModel,
            Bucketizer,
            ChiSqSelectorModel,
            CountVectorizerModel,
            DCT,
            ElementwiseProduct,
            HashingTF,
            IDFModel,
            ImputerModel,
            IndexToString,
            MaxAbsScalerModel,
            MinHashLSHModel,
            MinMaxScalerModel,
            NGram,
            Normalizer,
            OneHotEncoderModel,
            PCAModel,
            PolynomialExpansion,
            QuantileDiscretizer,
            RegexTokenizer,
            StandardScalerModel,
            StopWordsRemover,
            StringIndexerModel,
            Tokenizer,
            VectorAssembler,
            VectorIndexerModel,
            VectorSlicer,
            Word2VecModel,
        ]),
        ("pyspark.ml.classification", [
            LinearSVCModel,
            LogisticRegressionModel,
            DecisionTreeClassificationModel,
            GBTClassificationModel,
            RandomForestClassificationModel,
            NaiveBayesModel,
            MultilayerPerceptronClassificationModel,
            OneVsRestModel,
        ]),
        ("pyspark.ml.regression", [
            AFTSurvivalRegressionModel,
            DecisionTreeRegressionModel,
            GBTRegressionModel,
            GeneralizedLinearRegressionModel,
            IsotonicRegressionModel,
            LinearRegressionModel,
            RandomForestRegressionModel,
        ]),
    )
    return {
        cls: module + "." + cls.__name__
        for module, classes in classes_by_module
        for cls in classes
    }


# Built once at import time; get_sparkml_operator_name() reads from it.
sparkml_operator_name_map = build_sparkml_operator_name_map()


def get_sparkml_operator_name(model_type):
    '''
    Get operator name of the input argument

    :param model_type: The class of a Spark ML model or transformer (e.g.
        LinearRegressionModel, StringIndexerModel). The lookup table is keyed
        by classes, so pass the class itself rather than an instance.
    :return: A string which stands for the type of the input model in our
        conversion framework
    :raises ValueError: If ``model_type`` is not a key of
        ``sparkml_operator_name_map``
    '''
    if model_type not in sparkml_operator_name_map:
        # Wrap in a 1-tuple so that a tuple-valued argument is formatted as a
        # whole instead of being unpacked by the % operator.
        raise ValueError(
            "No proper operator name found for '%s'" % (model_type,))
    return sparkml_operator_name_map[model_type]