/* * Copyright (c) 2016 Villu Ruusmann * * This file is part of JPMML-SparkML * * JPMML-SparkML is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * JPMML-SparkML is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>. */ package org.jpmml.sparkml.feature; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import org.apache.spark.ml.feature.StringIndexerModel; import org.dmg.pmml.Apply; import org.dmg.pmml.DataField; import org.dmg.pmml.DataType; import org.dmg.pmml.DerivedField; import org.dmg.pmml.Field; import org.dmg.pmml.InvalidValueTreatmentMethod; import org.dmg.pmml.OpType; import org.dmg.pmml.PMMLFunctions; import org.dmg.pmml.Value; import org.jpmml.converter.CategoricalFeature; import org.jpmml.converter.Feature; import org.jpmml.converter.FeatureUtil; import org.jpmml.converter.InvalidValueDecorator; import org.jpmml.converter.PMMLUtil; import org.jpmml.sparkml.MultiFeatureConverter; import org.jpmml.sparkml.SparkMLEncoder; public class StringIndexerModelConverter extends MultiFeatureConverter<StringIndexerModel> { public StringIndexerModelConverter(StringIndexerModel transformer){ super(transformer); } @Override public List<Feature> encodeFeatures(SparkMLEncoder encoder){ StringIndexerModel transformer = getTransformer(); String[][] labelsArray = transformer.labelsArray(); InOutMode inputMode = getInputMode(); List<Feature> result = new ArrayList<>(); String[] inputCols = inputMode.getInputCols(transformer); for(int i = 0; i < inputCols.length; i++){ String inputCol = inputCols[i]; String[] labels = labelsArray[i]; Feature feature = encoder.getOnlyFeature(inputCol); List<String> categories = new ArrayList<>(); categories.addAll(Arrays.asList(labels)); String invalidCategory; DataType dataType = feature.getDataType(); switch(dataType){ case INTEGER: case FLOAT: case DOUBLE: invalidCategory = "-999"; break; default: invalidCategory = "__unknown"; break; } String handleInvalid = transformer.getHandleInvalid(); Field<?> field = encoder.toCategorical(feature.getName(), categories); if(field instanceof DataField){ DataField dataField = (DataField)field; InvalidValueDecorator invalidValueDecorator; switch(handleInvalid){ case "keep": { invalidValueDecorator = new InvalidValueDecorator(InvalidValueTreatmentMethod.AS_IS, invalidCategory); PMMLUtil.addValues(dataField, Collections.singletonList(invalidCategory), Value.Property.INVALID); categories.add(invalidCategory); } break; case "error": { invalidValueDecorator = new InvalidValueDecorator(InvalidValueTreatmentMethod.RETURN_INVALID, null); } break; default: throw new IllegalArgumentException("Invalid value handling strategy " + handleInvalid + " is not supported"); } encoder.addDecorator(dataField, invalidValueDecorator); } else if(field instanceof DerivedField){ switch(handleInvalid){ case "keep": { Apply setApply = PMMLUtil.createApply(PMMLFunctions.ISIN, feature.ref()); for(String category : categories){ setApply.addExpressions(PMMLUtil.createConstant(category, dataType)); } categories.add(invalidCategory); Apply apply = PMMLUtil.createApply(PMMLFunctions.IF) .addExpressions(setApply) .addExpressions(feature.ref(), PMMLUtil.createConstant(invalidCategory, dataType)); field = encoder.createDerivedField(FeatureUtil.createName("handleInvalid", feature), OpType.CATEGORICAL, dataType, apply); } break; case "error": { // Ignored: Assume that a DerivedField element can never return an erroneous field value } break; default: throw new IllegalArgumentException(handleInvalid); } } else { throw new IllegalArgumentException(); } result.add(new CategoricalFeature(encoder, field, categories)); } return result; } }