/*
 * Copyright (c) 2016 Villu Ruusmann
 *
 * This file is part of JPMML-SparkML
 *
 * JPMML-SparkML is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SparkML is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.sparkml;

import java.util.List;
import java.util.Objects;

import org.apache.spark.ml.Model;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasLabelCol;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.Label;
import org.jpmml.converter.LabelUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.mining.MiningModelUtil;

/**
 * Base class for converters of Spark ML models that expose a features column
 * and a prediction column.
 *
 * <p>
 * Subclasses supply the mining function and the actual PMML model encoding;
 * this class derives the {@link Schema} (label plus features) from the encoder
 * and attaches any extra output fields to the encoded model.
 * </p>
 *
 * @param <T> The type of the Spark ML model being converted.
 */
abstract
public class ModelConverter<T extends Model<T> & HasFeaturesCol & HasPredictionCol> extends TransformerConverter<T> {

	/**
	 * @param model The Spark ML model to convert.
	 */
	public ModelConverter(T model){
		super(model);
	}

	/**
	 * @return The mining function of this model. Only {@link MiningFunction#CLASSIFICATION}
	 * and {@link MiningFunction#REGRESSION} are accepted by {@link #encodeSchema(SparkMLEncoder)}
	 * when the model has a label column.
	 */
	abstract
	public MiningFunction getMiningFunction();

	/**
	 * Encodes the Spark ML model as a PMML model.
	 *
	 * @param schema The schema, as produced by {@link #encodeSchema(SparkMLEncoder)}.
	 *
	 * @return The PMML model.
	 */
	abstract
	public org.dmg.pmml.Model encodeModel(Schema schema);

	/**
	 * Builds the schema (label and features) of the model from the encoder state.
	 *
	 * <p>
	 * The label is only resolved when the model implements {@link HasLabelCol};
	 * otherwise it stays <code>null</code>.
	 * The label and feature counts are checked against the class and feature
	 * counts reported by the model, where the model reports them.
	 * </p>
	 *
	 * @param encoder The encoder holding the features of the pipeline columns.
	 *
	 * @return The schema.
	 *
	 * @throws IllegalArgumentException If the label feature or the mining function is not supported,
	 * or if the label column is also present among the feature columns.
	 */
	public Schema encodeSchema(SparkMLEncoder encoder){
		T model = getTransformer();

		Label label = (model instanceof HasLabelCol) ? encodeLabel(encoder, model) : null;

		// The number of label categories must agree with the number of classes of the model
		if(model instanceof ClassificationModel){
			int numClasses = ((ClassificationModel<?, ?>)model).numClasses();

			SchemaUtil.checkSize(numClasses, (CategoricalLabel)label);
		}

		List<Feature> features = encoder.getFeatures(model.getFeaturesCol());

		if(model instanceof PredictionModel){
			int numFeatures = ((PredictionModel<?, ?>)model).numFeatures();

			// The value -1 is skipped, there is nothing to check against then
			if(numFeatures != -1){
				SchemaUtil.checkSize(numFeatures, features);
			}
		}

		Schema schema = new Schema(label, features);

		checkSchema(schema);

		return schema;
	}

	/**
	 * Resolves the label of a model that implements {@link HasLabelCol},
	 * dispatching on the mining function.
	 */
	private Label encodeLabel(SparkMLEncoder encoder, T model){
		String labelCol = ((HasLabelCol)model).getLabelCol();

		Feature feature = encoder.getOnlyFeature(labelCol);

		MiningFunction miningFunction = getMiningFunction();
		switch(miningFunction){
			case CLASSIFICATION:
				return encodeCategoricalLabel(encoder, model, labelCol, feature);
			case REGRESSION:
				return encodeContinuousLabel(encoder, feature);
			default:
				throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
		}
	}

	/**
	 * Creates the categorical label of a classification model.
	 *
	 * <p>
	 * A continuous label feature is converted to a categorical field first,
	 * and the feature registered for the label column is replaced with an {@link IndexFeature}.
	 * </p>
	 */
	private Label encodeCategoricalLabel(SparkMLEncoder encoder, T model, String labelCol, Feature feature){

		// Must be tested ahead of CategoricalFeature, same as the order of the type checks has always been
		if(feature instanceof BooleanFeature){
			BooleanFeature booleanFeature = (BooleanFeature)feature;

			return new CategoricalLabel(booleanFeature.getName(), booleanFeature.getDataType(), booleanFeature.getValues());
		}

		if(feature instanceof CategoricalFeature){
			DataField dataField = (DataField)((CategoricalFeature)feature).getField();

			return new CategoricalLabel(dataField);
		}

		if(!(feature instanceof ContinuousFeature)){
			throw new IllegalArgumentException("Expected a categorical or categorical-like continuous feature, got " + feature);
		}

		// Two classes are assumed, unless the model reports the number of classes itself
		int numClasses = (model instanceof ClassificationModel) ? ((ClassificationModel<?, ?>)model).numClasses() : 2;

		List<Integer> categories = LabelUtil.createTargetCategories(numClasses);

		Field<?> field = encoder.toCategorical(feature.getName(), categories);

		encoder.putOnlyFeature(labelCol, new IndexFeature(encoder, field, categories));

		return new CategoricalLabel(field.getName(), field.getDataType(), categories);
	}

	/**
	 * Creates the continuous label of a regression model.
	 * The data type of the underlying field is set to {@link DataType#DOUBLE}.
	 */
	private Label encodeContinuousLabel(SparkMLEncoder encoder, Feature feature){
		Field<?> field = encoder.toContinuous(feature.getName());

		field.setDataType(DataType.DOUBLE);

		return new ContinuousLabel(field.getName(), field.getDataType());
	}

	/**
	 * Extension point for contributing extra output fields.
	 * This implementation contributes none.
	 *
	 * @param label The label of the schema.
	 * @param model The encoded PMML model.
	 * @param encoder The encoder.
	 *
	 * @return The output fields to append to the final model, or <code>null</code>.
	 */
	public List<OutputField> registerOutputFields(Label label, org.dmg.pmml.Model model, SparkMLEncoder encoder){
		return null;
	}

	/**
	 * Encodes the schema and the model, and appends the output fields of
	 * {@link #registerOutputFields(Label, org.dmg.pmml.Model, SparkMLEncoder)}, if any,
	 * to the output of the final model.
	 *
	 * @param encoder The encoder.
	 *
	 * @return The PMML model.
	 */
	public org.dmg.pmml.Model registerModel(SparkMLEncoder encoder){
		Schema schema = encodeSchema(encoder);

		org.dmg.pmml.Model pmmlModel = encodeModel(schema);

		List<OutputField> extraOutputFields = registerOutputFields(schema.getLabel(), pmmlModel, encoder);
		if(extraOutputFields == null || extraOutputFields.isEmpty()){
			return pmmlModel;
		}

		Output output = ModelUtil.ensureOutput(MiningModelUtil.getFinalModel(pmmlModel));

		output.getOutputFields().addAll(extraOutputFields);

		return pmmlModel;
	}

	/**
	 * Rejects schemas where a feature has the same name as the label.
	 * Schemas without a label are accepted as they are.
	 *
	 * @throws IllegalArgumentException If the label column is among the feature columns.
	 */
	static
	private void checkSchema(Schema schema){
		Label label = schema.getLabel();
		List<? extends Feature> features = schema.getFeatures();

		if(label == null){
			return;
		}

		for(Feature feature : features){

			if(!Objects.equals(label.getName(), feature.getName())){
				continue;
			}

			throw new IllegalArgumentException("Label column '" + label.getName() + "' is contained in the list of feature columns");
		}
	}
}