/* * Copyright (c) 2018 Villu Ruusmann * * This file is part of JPMML-SparkML * * JPMML-SparkML is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * JPMML-SparkML is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>. */ package org.jpmml.sparkml; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.ListIterator; import java.util.Map; import java.util.function.Function; import java.util.regex.Pattern; import java.util.stream.Collectors; import javax.xml.bind.JAXBException; import com.google.common.collect.Iterables; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.Transformer; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.param.shared.HasPredictionCol; import org.apache.spark.ml.param.shared.HasProbabilityCol; import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel; import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.TrainValidationSplitModel; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.types.StructType; import org.dmg.pmml.DerivedField; import org.dmg.pmml.FieldName; import org.dmg.pmml.MiningField; import org.dmg.pmml.MiningFunction; import org.dmg.pmml.MiningSchema; import org.dmg.pmml.Output; import org.dmg.pmml.OutputField; import org.dmg.pmml.PMML; import org.dmg.pmml.ResultFeature; import org.dmg.pmml.VerificationField; import org.jpmml.converter.Feature; import org.jpmml.converter.ModelUtil; import org.jpmml.converter.mining.MiningModelUtil; import org.jpmml.model.metro.MetroJAXBUtil; public class PMMLBuilder { private StructType schema = null; private PipelineModel pipelineModel = null; private Map<RegexKey, Map<String, Object>> options = new LinkedHashMap<>(); private Verification verification = null; public PMMLBuilder(StructType schema, PipelineModel pipelineModel){ setSchema(schema); setPipelineModel(pipelineModel); } public PMMLBuilder(StructType schema, PipelineStage pipelineStage){ throw new IllegalArgumentException("Expected a fitted pipeline model (class " + PipelineModel.class.getName() + "), got a pipeline stage (" + (pipelineStage != null ? ("class " + (pipelineStage.getClass()).getName()) : null) + ")"); } public PMML build(){ StructType schema = getSchema(); PipelineModel pipelineModel = getPipelineModel(); Map<RegexKey, ? extends Map<String, ?>> options = getOptions(); Verification verification = getVerification(); ConverterFactory converterFactory = new ConverterFactory(options); SparkMLEncoder encoder = new SparkMLEncoder(schema, converterFactory); Map<FieldName, DerivedField> derivedFields = encoder.getDerivedFields(); List<org.dmg.pmml.Model> models = new ArrayList<>(); List<String> predictionColumns = new ArrayList<>(); List<String> probabilityColumns = new ArrayList<>(); // Transformations preceding the last model List<FieldName> preProcessorNames = Collections.emptyList(); Iterable<Transformer> transformers = getTransformers(pipelineModel); for(Transformer transformer : transformers){ TransformerConverter<?> converter = converterFactory.newConverter(transformer); if(converter instanceof FeatureConverter){ FeatureConverter<?> featureConverter = (FeatureConverter<?>)converter; featureConverter.registerFeatures(encoder); } else if(converter instanceof ModelConverter){ ModelConverter<?> modelConverter = (ModelConverter<?>)converter; org.dmg.pmml.Model model = modelConverter.registerModel(encoder); models.add(model); hasPredictionCol: if(transformer instanceof HasPredictionCol){ HasPredictionCol hasPredictionCol = (HasPredictionCol)transformer; // XXX if((transformer instanceof GeneralizedLinearRegressionModel) && (MiningFunction.CLASSIFICATION).equals(model.getMiningFunction())){ break hasPredictionCol; } predictionColumns.add(hasPredictionCol.getPredictionCol()); } // End if if(transformer instanceof HasProbabilityCol){ HasProbabilityCol hasProbabilityCol = (HasProbabilityCol)transformer; probabilityColumns.add(hasProbabilityCol.getProbabilityCol()); } preProcessorNames = new ArrayList<>(derivedFields.keySet()); } else { throw new IllegalArgumentException("Expected a subclass of " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + ", got " + (converter != null ? ("class " + (converter.getClass()).getName()) : null)); } } // Transformations following the last model List<FieldName> postProcessorNames = new ArrayList<>(derivedFields.keySet()); postProcessorNames.removeAll(preProcessorNames); org.dmg.pmml.Model model; if(models.size() == 1){ model = Iterables.getOnlyElement(models); } else if(models.size() > 1){ model = MiningModelUtil.createModelChain(models); } else { throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models"); } // End if if(postProcessorNames.size() > 0){ org.dmg.pmml.Model finalModel = MiningModelUtil.getFinalModel(model); Output output = ModelUtil.ensureOutput(finalModel); for(FieldName postProcessorName : postProcessorNames){ DerivedField derivedField = derivedFields.get(postProcessorName); encoder.removeDerivedField(postProcessorName); OutputField outputField = new OutputField(derivedField.getName(), derivedField.getOpType(), derivedField.getDataType()) .setResultFeature(ResultFeature.TRANSFORMED_VALUE) .setExpression(derivedField.getExpression()); output.addOutputFields(outputField); } } PMML pmml = encoder.encodePMML(model); if((predictionColumns.size() > 0 || probabilityColumns.size() > 0) && (verification != null)){ Dataset<Row> dataset = verification.getDataset(); Dataset<Row> transformedDataset = verification.getTransformedDataset(); Double precision = verification.getPrecision(); Double zeroThreshold = verification.getZeroThreshold(); List<String> inputColumns = new ArrayList<>(); MiningSchema miningSchema = model.getMiningSchema(); List<MiningField> miningFields = miningSchema.getMiningFields(); for(MiningField miningField : miningFields){ MiningField.UsageType usageType = miningField.getUsageType(); switch(usageType){ case ACTIVE: FieldName name = miningField.getName(); inputColumns.add(name.getValue()); break; default: break; } } Map<VerificationField, List<?>> data = new LinkedHashMap<>(); for(String inputColumn : inputColumns){ VerificationField verificationField = ModelUtil.createVerificationField(FieldName.create(inputColumn)); data.put(verificationField, getColumn(dataset, inputColumn)); } for(String predictionColumn : predictionColumns){ Feature feature = encoder.getOnlyFeature(predictionColumn); VerificationField verificationField = ModelUtil.createVerificationField(feature.getName()) .setPrecision(precision) .setZeroThreshold(zeroThreshold); data.put(verificationField, getColumn(transformedDataset, predictionColumn)); } for(String probabilityColumn : probabilityColumns){ List<Feature> features = encoder.getFeatures(probabilityColumn); for(int i = 0; i < features.size(); i++){ Feature feature = features.get(i); VerificationField verificationField = ModelUtil.createVerificationField(feature.getName()) .setPrecision(precision) .setZeroThreshold(zeroThreshold); data.put(verificationField, getVectorColumn(transformedDataset, probabilityColumn, i)); } } model.setModelVerification(ModelUtil.createModelVerification(data)); } return pmml; } public byte[] buildByteArray(){ return buildByteArray(1024 * 1024); } private byte[] buildByteArray(int size){ PMML pmml = build(); ByteArrayOutputStream os = new ByteArrayOutputStream(size); try { MetroJAXBUtil.marshalPMML(pmml, os); } catch(JAXBException je){ throw new RuntimeException(je); } return os.toByteArray(); } public File buildFile(File file) throws IOException { PMML pmml = build(); OutputStream os = new FileOutputStream(file); try { MetroJAXBUtil.marshalPMML(pmml, os); } catch(JAXBException je){ throw new RuntimeException(je); } finally { os.close(); } return file; } public PMMLBuilder putOption(String key, Object value){ return putOptions(Collections.singletonMap(key, value)); } public PMMLBuilder putOptions(Map<String, ?> map){ return putOptions(Pattern.compile(".*"), map); } public PMMLBuilder putOption(PipelineStage pipelineStage, String key, Object value){ return putOptions(pipelineStage, Collections.singletonMap(key, value)); } public PMMLBuilder putOptions(PipelineStage pipelineStage, Map<String, ?> map){ return putOptions(Pattern.compile(pipelineStage.uid(), Pattern.LITERAL), map); } public PMMLBuilder putOptions(Pattern pattern, Map<String, ?> map){ Map<RegexKey, Map<String, Object>> options = getOptions(); RegexKey key = new RegexKey(pattern); Map<String, Object> patternOptions = options.get(key); if(patternOptions == null){ patternOptions = new LinkedHashMap<>(); options.put(key, patternOptions); } patternOptions.putAll(map); return this; } public PMMLBuilder verify(Dataset<Row> dataset){ return verify(dataset, 1e-14, 1e-14); } public PMMLBuilder verify(Dataset<Row> dataset, double precision, double zeroThreshold){ PipelineModel pipelineModel = getPipelineModel(); Dataset<Row> transformedDataset = pipelineModel.transform(dataset); Verification verification = new Verification(dataset, transformedDataset) .setPrecision(precision) .setZeroThreshold(zeroThreshold); return setVerification(verification); } public StructType getSchema(){ return this.schema; } public PMMLBuilder setSchema(StructType schema){ if(schema == null){ throw new IllegalArgumentException(); } this.schema = schema; return this; } public PipelineModel getPipelineModel(){ return this.pipelineModel; } public PMMLBuilder setPipelineModel(PipelineModel pipelineModel){ if(pipelineModel == null){ throw new IllegalArgumentException(); } this.pipelineModel = pipelineModel; return this; } public Map<RegexKey, Map<String, Object>> getOptions(){ return this.options; } private PMMLBuilder setOptions(Map<RegexKey, Map<String, Object>> options){ if(options == null){ throw new IllegalArgumentException(); } this.options = options; return this; } public Verification getVerification(){ return this.verification; } private PMMLBuilder setVerification(Verification verification){ this.verification = verification; return this; } static private Iterable<Transformer> getTransformers(PipelineModel pipelineModel){ List<Transformer> result = new ArrayList<>(); result.add(pipelineModel); Function<Transformer, List<Transformer>> function = new Function<Transformer, List<Transformer>>(){ @Override public List<Transformer> apply(Transformer transformer){ if(transformer instanceof PipelineModel){ PipelineModel pipelineModel = (PipelineModel)transformer; return Arrays.asList(pipelineModel.stages()); } else if(transformer instanceof CrossValidatorModel){ CrossValidatorModel crossValidatorModel = (CrossValidatorModel)transformer; return Collections.singletonList(crossValidatorModel.bestModel()); } else if(transformer instanceof TrainValidationSplitModel){ TrainValidationSplitModel trainValidationSplitModel = (TrainValidationSplitModel)transformer; return Collections.singletonList(trainValidationSplitModel.bestModel()); } return null; } }; while(true){ boolean modified = false; ListIterator<Transformer> transformerIt = result.listIterator(); while(transformerIt.hasNext()){ Transformer transformer = transformerIt.next(); List<Transformer> childTransformers = function.apply(transformer); if(childTransformers != null){ transformerIt.remove(); for(Transformer childTransformer : childTransformers){ transformerIt.add(childTransformer); } modified = true; } } if(!modified){ break; } } return result; } static private List<?> getColumn(Dataset<Row> dataset, String name){ List<Row> rows = dataset.select(name) .collectAsList(); return rows.stream() .map(row -> row.apply(0)) .collect(Collectors.toList()); } static private List<?> getVectorColumn(Dataset<Row> dataset, String name, int index){ List<Vector> column = (List<Vector>)getColumn(dataset, name); return column.stream() .map(vector -> vector.apply(index)) .collect(Collectors.toList()); } static private void init(){ ConverterFactory.checkVersion(); ConverterFactory.checkApplicationClasspath(); ConverterFactory.checkNoShading(); } static public class Verification { private Dataset<Row> dataset = null; private Dataset<Row> transformedDataset = null; public Double precision = null; public Double zeroThreshold = null; private Verification(Dataset<Row> dataset, Dataset<Row> transformedDataset){ setDataset(dataset); setTransformedDataset(transformedDataset); } public Dataset<Row> getDataset(){ return this.dataset; } private Verification setDataset(Dataset<Row> dataset){ this.dataset = dataset; return this; } public Dataset<Row> getTransformedDataset(){ return this.transformedDataset; } private Verification setTransformedDataset(Dataset<Row> transformedDataset){ this.transformedDataset = transformedDataset; return this; } public Double getPrecision(){ return this.precision; } public Verification setPrecision(Double precision){ this.precision = precision; return this; } public Double getZeroThreshold(){ return this.zeroThreshold; } public Verification setZeroThreshold(Double zeroThreshold){ this.zeroThreshold = zeroThreshold; return this; } } static { init(); } }