org.dmg.pmml.regression.RegressionModel Java Examples

The following examples show how to use org.dmg.pmml.regression.RegressionModel. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: MultinomialLogisticRegression.java    From jpmml-lightgbm with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Encodes a multi-class model as one member mining model per label category,
 * combined into a classification model using softmax normalization.
 *
 * @param trees All trees; the list is divided evenly between the label categories.
 * @param numIteration Passed through to {@code createMiningModel}.
 * @param schema The target schema; its label is cast to {@link CategoricalLabel}.
 * @return The classification model wrapping the per-category member models.
 */
@Override
public MiningModel encodeMiningModel(List<Tree> trees, Integer numIteration, Schema schema){
	Schema memberSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);

	CategoricalLabel label = (CategoricalLabel)schema.getLabel();

	// One matrix row per label category; integer division fixes the row length
	int rows = label.size();
	int columns = (trees.size() / rows);

	List<MiningModel> members = new ArrayList<>();

	for(int row = 0; row < rows; row++){
		MiningModel member = createMiningModel(FortranMatrixUtil.getRow(trees, rows, columns, row), numIteration, memberSchema)
			.setOutput(ModelUtil.createPredictedOutput(FieldName.create("lgbmValue(" + label.getValue(row) + ")"), OpType.CONTINUOUS, DataType.DOUBLE));

		members.add(member);
	}

	return MiningModelUtil.createClassification(members, RegressionModel.NormalizationMethod.SOFTMAX, true, schema);
}
 
Example #2
Source File: RegressionModelUtil.java    From jpmml-evaluator with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Applies the transformation selected by the normalization method to a
 * binary logistic classification value.
 *
 * <p>
 * {@code NONE} restricts the value to the range bounded by
 * {@code Numbers.DOUBLE_ZERO} and {@code Numbers.DOUBLE_ONE}; every other
 * handled method applies the inverse of the link function of the same name.
 * </p>
 *
 * @param normalizationMethod The normalization method; must be one of
 *  {@code NONE}, {@code LOGIT}, {@code PROBIT}, {@code CLOGLOG}, {@code LOGLOG} or {@code CAUCHIT}.
 * @param value The value to transform.
 * @return The result of the selected {@link Value} operation.
 * @throws IllegalArgumentException If the normalization method is not handled here.
 */
static
public <V extends Number> Value<V> normalizeBinaryLogisticClassificationResult(RegressionModel.NormalizationMethod normalizationMethod, Value<V> value){

	switch(normalizationMethod){
		case NONE:
			return value.restrict(Numbers.DOUBLE_ZERO, Numbers.DOUBLE_ONE);
		case LOGIT:
			return value.inverseLogit();
		case PROBIT:
			return value.inverseProbit();
		case CLOGLOG:
			return value.inverseCloglog();
		case LOGLOG:
			return value.inverseLoglog();
		case CAUCHIT:
			return value.inverseCauchit();
		default:
			// Name the offending method so that the failure can be diagnosed from the message alone
			throw new IllegalArgumentException("Unsupported normalization method: " + normalizationMethod);
	}
}
 
Example #3
Source File: RegressionModelUtil.java    From jpmml-evaluator with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Applies the transformation selected by the normalization method to a
 * regression value.
 *
 * <p>
 * {@code NONE} returns the value as is. {@code SOFTMAX} deliberately falls
 * through to {@code LOGIT}, so both apply {@code inverseLogit()}.
 * </p>
 *
 * @param normalizationMethod The normalization method; must be one of
 *  {@code NONE}, {@code SOFTMAX}, {@code LOGIT}, {@code EXP}, {@code PROBIT},
 *  {@code CLOGLOG}, {@code LOGLOG} or {@code CAUCHIT}.
 * @param value The value to transform.
 * @return The value itself for {@code NONE}, otherwise the result of the selected {@link Value} operation.
 * @throws IllegalArgumentException If the normalization method is not handled here.
 */
static
public <V extends Number> Value<V> normalizeRegressionResult(RegressionModel.NormalizationMethod normalizationMethod, Value<V> value){

	switch(normalizationMethod){
		case NONE:
			return value;
		case SOFTMAX:
			// Intentional fall-through: handled exactly like LOGIT
		case LOGIT:
			return value.inverseLogit();
		case EXP:
			return value.exp();
		case PROBIT:
			return value.inverseProbit();
		case CLOGLOG:
			return value.inverseCloglog();
		case LOGLOG:
			return value.inverseLoglog();
		case CAUCHIT:
			return value.inverseCauchit();
		default:
			// Name the offending method so that the failure can be diagnosed from the message alone
			throw new IllegalArgumentException("Unsupported normalization method: " + normalizationMethod);
	}
}
 
Example #4
Source File: PMMLConverter.java    From pyramid with Apache License 2.0 6 votes vote down vote up
/**
 * Encodes a multi-class model as one member mining model per entry of
 * {@code regTrees}, combined into a classification model using softmax normalization.
 *
 * @param regTrees The regression trees, one inner list per class.
 * @param base_score Passed through to {@code createMiningModel}.
 * @param schema The target schema; its label is cast to {@link CategoricalLabel}.
 * @return The classification model wrapping the per-class member models.
 */
public static MiningModel encodeMiningModel(List<List<RegressionTree>> regTrees, float base_score, Schema schema){
    // Member models predict an anonymous float value over the same features
    Schema memberSchema = new Schema(new ContinuousLabel(null, DataType.FLOAT), schema.getFeatures());

    CategoricalLabel label = (CategoricalLabel)schema.getLabel();

    List<MiningModel> members = new ArrayList<>();

    int classCount = regTrees.size();

    for(int classIndex = 0; classIndex < classCount; classIndex++){
        MiningModel member = createMiningModel(regTrees.get(classIndex), base_score, memberSchema)
                .setOutput(ModelUtil.createPredictedOutput(FieldName.create("class_(" + label.getValue(classIndex) + ")"), OpType.CONTINUOUS, DataType.FLOAT));

        members.add(member);
    }

    return MiningModelUtil.createClassification(members, RegressionModel.NormalizationMethod.SOFTMAX, true, schema);
}
 
Example #5
Source File: AdaConverter.java    From jpmml-r with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Encodes the {@code ada} object as a weighted sum of tree models, wrapped
 * into a binary logistic classification model.
 *
 * @param schema The target schema.
 * @return The binary classification model.
 */
@Override
public Model encodeModel(Schema schema){
	RGenericVector adaObject = getObject();

	RGenericVector adaModel = adaObject.getGenericElement("model");

	RGenericVector treeVector = adaModel.getGenericElement("trees");
	RDoubleVector treeWeights = adaModel.getDoubleElement("alpha");

	List<TreeModel> members = encodeTreeModels(treeVector);

	// The ensemble itself is a regression: the "alpha" values weight the individual trees
	MiningModel ensemble = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(null))
		.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, members, treeWeights.getValues()))
		.setOutput(ModelUtil.createPredictedOutput(FieldName.create("adaValue"), OpType.CONTINUOUS, DataType.DOUBLE));

	return MiningModelUtil.createBinaryLogisticClassification(ensemble, 2d, 0d, RegressionModel.NormalizationMethod.LOGIT, true, schema);
}
 
Example #6
Source File: GBDTLRClassifier.java    From jpmml-sklearn with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Encodes the GBDT + one-hot encoder + linear classifier pipeline as a
 * binary logistic classification model.
 *
 * @param schema The target schema; its label is cast to {@link CategoricalLabel}
 *  and checked to have exactly two categories.
 * @return The binary classification model.
 */
@Override
public Model encodeModel(Schema schema){
	Classifier gbdtStep = getGBDT();
	MultiOneHotEncoder oheStep = getOHE();
	LinearClassifier lrStep = getLR();

	CategoricalLabel label = (CategoricalLabel)schema.getLabel();

	SchemaUtil.checkSize(2, label);

	List<? extends Number> coefficients = lrStep.getCoef();
	List<? extends Number> intercepts = lrStep.getIntercept();

	Schema memberSchema = schema.toAnonymousSchema();

	// A single intercept is expected; getOnlyElement fails otherwise
	MiningModel decisionModel = GBDTUtil.encodeModel(gbdtStep, oheStep, coefficients, Iterables.getOnlyElement(intercepts), memberSchema)
		.setOutput(ModelUtil.createPredictedOutput(FieldName.create("decisionFunction"), OpType.CONTINUOUS, DataType.DOUBLE));

	return MiningModelUtil.createBinaryLogisticClassification(decisionModel, 1d, 0d, RegressionModel.NormalizationMethod.LOGIT, lrStep.hasProbabilityDistribution(), schema);
}
 
Example #7
Source File: LogRegFromSparkThroughPMMLExample.java    From ignite with Apache License 2.0 6 votes vote down vote up
/**
 * Loads a logistic regression model from a PMML file.
 *
 * <p>
 * Only the first model of the document and the first regression table of
 * that model are read. Coefficients are taken from the numeric predictors
 * in list order.
 * </p>
 *
 * @param path Path of the PMML file to read.
 * @return The loaded model, or {@code null} if reading or parsing failed
 *  (the stack trace is printed in that case).
 */
public static LogisticRegressionModel load(String path) {
    try (InputStream pmmlStream = new FileInputStream(new File(path))) {
        PMML doc = PMMLUtil.unmarshal(pmmlStream);

        RegressionModel regMdl = (RegressionModel)doc.getModels().get(0);

        RegressionTable tbl = regMdl.getRegressionTables().get(0);

        Vector weights = new DenseVector(tbl.getNumericPredictors().size());

        for (int idx = 0; idx < tbl.getNumericPredictors().size(); idx++) {
            weights.set(idx, tbl.getNumericPredictors().get(idx).getCoefficient());
        }

        double intercept = tbl.getIntercept();

        return new LogisticRegressionModel(weights, intercept);
    }
    catch (IOException | JAXBException | SAXException e) {
        e.printStackTrace();

        return null;
    }
}
 
Example #8
Source File: HingeClassification.java    From jpmml-xgboost with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Encodes the tree ensemble with a thresholded output field, wrapped into
 * a binary classification model that applies no further normalization.
 *
 * @param trees The trees of the ensemble.
 * @param weights The tree weights.
 * @param base_score Passed through to {@code createMiningModel}.
 * @param ntreeLimit Passed through to {@code createMiningModel}.
 * @param schema The target schema.
 * @return The binary classification model.
 */
@Override
public MiningModel encodeMiningModel(List<RegTree> trees, List<Float> weights, float base_score, Integer ntreeLimit, Schema schema){
	Schema memberSchema = schema.toAnonymousRegressorSchema(DataType.FLOAT);

	Transformation hinge = new FunctionTransformation(PMMLFunctions.THRESHOLD){

		@Override
		public FieldName getName(FieldName name){
			return FieldName.create("hinge(" + name + ")");
		}

		@Override
		public Expression createExpression(FieldRef fieldRef){
			Apply thresholdCall = (Apply)super.createExpression(fieldRef);

			// Append the constant 0 as an extra argument of the threshold function
			thresholdCall.addExpressions(PMMLUtil.createConstant(0f));

			return thresholdCall;
		}
	};

	MiningModel ensemble = createMiningModel(trees, weights, base_score, ntreeLimit, memberSchema)
		.setOutput(ModelUtil.createPredictedOutput(FieldName.create("xgbValue"), OpType.CONTINUOUS, DataType.FLOAT, hinge));

	return MiningModelUtil.createBinaryLogisticClassification(ensemble, 1d, 0d, RegressionModel.NormalizationMethod.NONE, true, schema);
}
 
Example #9
Source File: MultinomialLogisticRegression.java    From jpmml-xgboost with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Encodes a multi-class model as one member mining model per label category,
 * combined into a classification model using softmax normalization.
 *
 * @param trees All trees; the list is divided evenly between the label categories.
 * @param weights The tree weights, or {@code null}; when present they are split the same way as the trees.
 * @param base_score Passed through to {@code createMiningModel}.
 * @param ntreeLimit Passed through to {@code createMiningModel}.
 * @param schema The target schema; its label is cast to {@link CategoricalLabel}.
 * @return The classification model wrapping the per-category member models.
 */
@Override
public MiningModel encodeMiningModel(List<RegTree> trees, List<Float> weights, float base_score, Integer ntreeLimit, Schema schema){
	Schema memberSchema = schema.toAnonymousRegressorSchema(DataType.FLOAT);

	CategoricalLabel label = (CategoricalLabel)schema.getLabel();

	// One matrix column per label category; integer division fixes the column length
	int columns = label.size();
	int rows = (trees.size() / columns);

	List<MiningModel> members = new ArrayList<>();

	for(int column = 0; column < columns; column++){
		MiningModel member = createMiningModel(CMatrixUtil.getColumn(trees, rows, columns, column), (weights != null) ? CMatrixUtil.getColumn(weights, rows, columns, column) : null, base_score, ntreeLimit, memberSchema)
			.setOutput(ModelUtil.createPredictedOutput(FieldName.create("xgbValue(" + label.getValue(column) + ")"), OpType.CONTINUOUS, DataType.FLOAT));

		members.add(member);
	}

	return MiningModelUtil.createClassification(members, RegressionModel.NormalizationMethod.SOFTMAX, true, schema);
}
 
Example #10
Source File: LinearSVCModelConverter.java    From jpmml-sparkml with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Encodes the linear SVC as a linear regression whose "margin" output is
 * thresholded against the model's threshold, wrapped into a binary
 * classification model that applies no further normalization.
 *
 * @param schema The target schema.
 * @return The binary classification model.
 */
@Override
public MiningModel encodeModel(Schema schema){
	LinearSVCModel svc = getTransformer();

	Transformation thresholding = new AbstractTransformation(){

		@Override
		public Expression createExpression(FieldRef fieldRef){
			return PMMLUtil.createApply(PMMLFunctions.THRESHOLD)
				.addExpressions(fieldRef, PMMLUtil.createConstant(svc.getThreshold()));
		}
	};

	Schema memberSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);

	Model marginModel = LinearModelUtil.createRegression(this, svc.coefficients(), svc.intercept(), memberSchema)
		.setOutput(ModelUtil.createPredictedOutput(FieldName.create("margin"), OpType.CONTINUOUS, DataType.DOUBLE, thresholding));

	return MiningModelUtil.createBinaryLogisticClassification(marginModel, 1d, 0d, RegressionModel.NormalizationMethod.NONE, false, schema);
}
 
Example #11
Source File: GBTClassificationModelConverter.java    From jpmml-sparkml with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Encodes the gradient boosted tree classifier as a weighted sum of tree
 * models, wrapped into a binary logistic classification model.
 *
 * @param schema The target schema.
 * @return The binary classification model.
 * @throws IllegalArgumentException If the loss type of the model is anything other than {@code "logistic"}.
 */
@Override
public MiningModel encodeModel(Schema schema){
	GBTClassificationModel gbt = getTransformer();

	String lossType = gbt.getLossType();

	// Only the logistic loss is supported
	if(!lossType.equals("logistic")){
		throw new IllegalArgumentException("Loss function " + lossType + " is not supported");
	}

	Schema memberSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);

	List<TreeModel> members = TreeModelUtil.encodeDecisionTreeEnsemble(this, memberSchema);

	MiningModel ensemble = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(memberSchema.getLabel()))
		.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, members, Doubles.asList(gbt.treeWeights())))
		.setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbtValue"), OpType.CONTINUOUS, DataType.DOUBLE));

	return MiningModelUtil.createBinaryLogisticClassification(ensemble, 2d, 0d, RegressionModel.NormalizationMethod.LOGIT, false, schema);
}
 
Example #12
Source File: LogNetConverter.java    From jpmml-r with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Encodes one column of the fitted coefficients as a binary logistic
 * classification model over the schema features.
 *
 * @param a0 The intercepts; the element at {@code column} is used.
 * @param beta The coefficients object; cast to {@link S4Object}.
 * @param column The column to encode.
 * @param schema The target schema.
 * @return The binary classification model.
 */
@Override
public Model encodeModel(RDoubleVector a0, RExp beta, int column, Schema schema){
	Double columnIntercept = a0.getValue(column);

	List<Double> columnCoefficients = getCoefficients((S4Object)beta, column);

	return RegressionModelUtil.createBinaryLogisticClassification(
		schema.getFeatures(),
		columnCoefficients,
		columnIntercept,
		RegressionModel.NormalizationMethod.LOGIT,
		true,
		schema
	);
}
 
Example #13
Source File: RegressionModelUtilTest.java    From jpmml-evaluator with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Checks that, without normalization, each entry is replaced by the
 * difference between its value and the value of the preceding entry.
 */
@Test
public void computeOrdinalProbabilities(){
	ValueMap<String, Float> cumulative = new ValueMap<>();

	cumulative.put("loud", new FloatValue(0.2f));
	cumulative.put("louder", new FloatValue(0.7f));
	cumulative.put("insane", new FloatValue(1f));

	// Updates the map in place
	RegressionModelUtil.computeOrdinalProbabilities(RegressionModel.NormalizationMethod.NONE, cumulative);

	assertEquals(new FloatValue(0.2f - 0f), cumulative.get("loud"));
	assertEquals(new FloatValue(0.7f - 0.2f), cumulative.get("louder"));
	assertEquals(new FloatValue(1f - 0.7f), cumulative.get("insane"));
}
 
Example #14
Source File: RegressionModelUtilTest.java    From jpmml-evaluator with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Checks that, without normalization, the leading entries are kept as is and
 * the last entry is replaced by one minus their sum.
 */
@Test
public void computeMultinomialProbabilities(){
	ValueMap<String, Float> scores = new ValueMap<>();

	scores.put("red", new FloatValue(0.3f));
	scores.put("yellow", new FloatValue(0.5f));
	// Placeholder; expected to be overwritten by the call below
	scores.put("green", new FloatValue(Float.MAX_VALUE));

	RegressionModelUtil.computeMultinomialProbabilities(RegressionModel.NormalizationMethod.NONE, scores);

	assertEquals(new FloatValue(0.3f), scores.get("red"));
	assertEquals(new FloatValue(0.5f), scores.get("yellow"));
	assertEquals(new FloatValue(1f - (0.3f + 0.5f)), scores.get("green"));
}
 
Example #15
Source File: RegressionModelUtilTest.java    From jpmml-evaluator with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Checks that, without normalization, the first entry is kept as is and the
 * second entry is replaced by its complement.
 */
@Test
public void computeBinomialProbabilities(){
	ValueMap<String, Float> scores = new ValueMap<>();

	scores.put("yes", new FloatValue(0.3f));
	// Placeholder; expected to be overwritten by the call below
	scores.put("no", new FloatValue(Float.MAX_VALUE));

	RegressionModelUtil.computeBinomialProbabilities(RegressionModel.NormalizationMethod.NONE, scores);

	assertEquals(new FloatValue(0.3f), scores.get("yes"));
	assertEquals(new FloatValue(1f - 0.3f), scores.get("no"));
}
 
Example #16
Source File: BinomialLogisticRegression.java    From jpmml-lightgbm with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Encodes the tree ensemble as a binary logistic classification model,
 * using the {@code sigmoid_} field of this objective as the coefficient.
 *
 * @param trees The trees of the ensemble.
 * @param numIteration Passed through to {@code createMiningModel}.
 * @param schema The target schema.
 * @return The binary classification model.
 */
@Override
public MiningModel encodeMiningModel(List<Tree> trees, Integer numIteration, Schema schema){
	Schema memberSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);

	MiningModel ensemble = createMiningModel(trees, numIteration, memberSchema)
		.setOutput(ModelUtil.createPredictedOutput(FieldName.create("lgbmValue"), OpType.CONTINUOUS, DataType.DOUBLE));

	return MiningModelUtil.createBinaryLogisticClassification(
		ensemble,
		BinomialLogisticRegression.this.sigmoid_,
		0d,
		RegressionModel.NormalizationMethod.LOGIT,
		true,
		schema
	);
}
 
Example #17
Source File: RegressionModelEvaluator.java    From jpmml-evaluator with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Creates an evaluator for the given regression model.
 *
 * @param pmml The enclosing PMML document.
 * @param regressionModel The regression model; must have regression tables.
 * @throws MissingElementException If the model has no regression tables.
 */
public RegressionModelEvaluator(PMML pmml, RegressionModel regressionModel){
	super(pmml, regressionModel);

	if(regressionModel.hasRegressionTables()){
		return;
	}

	throw new MissingElementException(regressionModel, PMMLElements.REGRESSIONMODEL_REGRESSIONTABLES);
}
 
Example #18
Source File: MarshallerTest.java    From jpmml-model with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
/**
 * Marshals a minimal PMML 4.4 document holding one regression model and
 * checks the namespace, version and element names of the resulting XML.
 */
@Test
public void marshal() throws Exception {
	RegressionModel model = new RegressionModel()
		.addRegressionTables(new RegressionTable());

	PMML document = new PMML(Version.PMML_4_4.getVersion(), new Header(), new DataDictionary());
	document.addModels(model);

	// The regression package has its own ObjectFactory, which must be registered too
	Class<?>[] factories = {org.dmg.pmml.ObjectFactory.class, org.dmg.pmml.regression.ObjectFactory.class};

	JAXBContext context = JAXBContextFactory.createContext(factories, null);

	Marshaller marshaller = context.createMarshaller();

	String xml;

	try(ByteArrayOutputStream buffer = new ByteArrayOutputStream()){
		marshaller.marshal(document, buffer);

		xml = buffer.toString("UTF-8");
	}

	assertTrue(xml.contains("<PMML xmlns=\"http://www.dmg.org/PMML-4_4\""));
	assertTrue(xml.contains(" version=\"4.4\">"));
	assertTrue(xml.contains("<RegressionModel>"));
	assertTrue(xml.contains("</RegressionModel>"));
	assertTrue(xml.contains("</PMML>"));
}
 
Example #19
Source File: VersionInspectorTest.java    From jpmml-model with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
/**
 * Checks how the asserted PMML version range of a document narrows as
 * models of more types are added to it.
 *
 * <p>
 * The upper bound stays at PMML 4.4 throughout; each batch of added models
 * is followed by an assertion on the lower bound.
 * </p>
 */
@Test
public void inspectTypeAnnotations(){
	PMML pmml = createPMML();

	// Baseline, before any model is added
	assertVersionRange(pmml, Version.PMML_3_0, Version.PMML_4_4);

	// First batch: the lower bound is asserted to be unchanged afterwards.
	// Some model types are commented out; the reason is not visible here.
	pmml.addModels(new AssociationModel(),
		//new ClusteringModel(),
		//new GeneralRegressionModel(),
		//new MiningModel(),
		new NaiveBayesModel(),
		new NeuralNetwork(),
		new RegressionModel(),
		new RuleSetModel(),
		new SequenceModel(),
		//new SupportVectorMachineModel(),
		new TextModel(),
		new TreeModel());

	assertVersionRange(pmml, Version.PMML_3_0, Version.PMML_4_4);

	// Lower bound is asserted to rise to 4.0
	pmml.addModels(new TimeSeriesModel());

	assertVersionRange(pmml, Version.PMML_4_0, Version.PMML_4_4);

	// Lower bound is asserted to rise to 4.1
	pmml.addModels(new BaselineModel(),
		new Scorecard(),
		new NearestNeighborModel());

	assertVersionRange(pmml, Version.PMML_4_1, Version.PMML_4_4);

	// Lower bound is asserted to rise to 4.3
	pmml.addModels(new BayesianNetworkModel(),
		new GaussianProcessModel());

	assertVersionRange(pmml, Version.PMML_4_3, Version.PMML_4_4);
}
 
Example #20
Source File: ReflectionUtilTest.java    From jpmml-model with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
/**
 * Checks that {@code ReflectionUtil.copyState} shares field values,
 * including the live model list, between the source and the target object,
 * and that copying in the opposite direction is rejected.
 */
@Test
public void copyState(){
	PMML source = new PMML(Version.PMML_4_4.getVersion(), new Header(), new DataDictionary());

	// Initialize a live list instance
	source.getModels();

	CustomPMML target = new CustomPMML();

	ReflectionUtil.copyState(source, target);

	assertSame(source.getVersion(), target.getVersion());
	assertSame(source.getHeader(), target.getHeader());
	assertSame(source.getDataDictionary(), target.getDataDictionary());

	assertFalse(source.hasModels());
	assertFalse(target.hasModels());

	// A model added to the source is expected to show up in the target as well
	source.addModels(new RegressionModel());

	assertTrue(source.hasModels());
	assertTrue(target.hasModels());

	assertSame(source.getModels(), target.getModels());

	try {
		ReflectionUtil.copyState(target, source);

		fail();
	} catch(IllegalArgumentException expected){
		// Ignored
	}
}
 
Example #21
Source File: PoissonRegression.java    From jpmml-lightgbm with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Encodes the tree ensemble using the superclass encoder, and wraps it into
 * a regression model that applies the {@code EXP} normalization.
 *
 * @param trees The trees of the ensemble.
 * @param numIteration Passed through to the superclass encoder.
 * @param schema The target schema.
 * @return The regression model.
 */
@Override
public MiningModel encodeMiningModel(List<Tree> trees, Integer numIteration, Schema schema){
	Schema memberSchema = schema.toAnonymousSchema();

	MiningModel ensemble = super.encodeMiningModel(trees, numIteration, memberSchema)
		.setOutput(ModelUtil.createPredictedOutput(FieldName.create("lgbmValue"), OpType.CONTINUOUS, DataType.DOUBLE));

	return MiningModelUtil.createRegression(
		ensemble,
		RegressionModel.NormalizationMethod.EXP,
		schema
	);
}
 
Example #22
Source File: GBMConverter.java    From jpmml-r with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Encodes the tree models as a binary logistic classification model.
 *
 * @param treeModels The tree models of the ensemble.
 * @param initF Passed through to {@code createMiningModel}.
 * @param coefficient The coefficient; it is negated before being passed on.
 * @param schema The target schema.
 * @return The binary classification model.
 */
private MiningModel encodeBinaryClassification(List<TreeModel> treeModels, Double initF, double coefficient, Schema schema){
	Schema memberSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);

	MiningModel ensemble = createMiningModel(treeModels, initF, memberSchema)
		.setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbmValue"), OpType.CONTINUOUS, DataType.DOUBLE));

	return MiningModelUtil.createBinaryLogisticClassification(
		ensemble,
		-coefficient,
		0d,
		RegressionModel.NormalizationMethod.LOGIT,
		true,
		schema
	);
}
 
Example #23
Source File: BinomialLogisticRegression.java    From jpmml-xgboost with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Encodes the tree ensemble as a binary logistic classification model
 * that applies the {@code LOGIT} normalization.
 *
 * @param trees The trees of the ensemble.
 * @param weights The tree weights.
 * @param base_score Passed through to {@code createMiningModel}.
 * @param ntreeLimit Passed through to {@code createMiningModel}.
 * @param schema The target schema.
 * @return The binary classification model.
 */
@Override
public MiningModel encodeMiningModel(List<RegTree> trees, List<Float> weights, float base_score, Integer ntreeLimit, Schema schema){
	Schema memberSchema = schema.toAnonymousRegressorSchema(DataType.FLOAT);

	MiningModel ensemble = createMiningModel(trees, weights, base_score, ntreeLimit, memberSchema)
		.setOutput(ModelUtil.createPredictedOutput(FieldName.create("xgbValue"), OpType.CONTINUOUS, DataType.FLOAT));

	return MiningModelUtil.createBinaryLogisticClassification(
		ensemble,
		1d,
		0d,
		RegressionModel.NormalizationMethod.LOGIT,
		true,
		schema
	);
}
 
Example #24
Source File: LogisticRegression.java    From jpmml-xgboost with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Encodes the tree ensemble as a regression model that applies the
 * {@code LOGIT} normalization.
 *
 * @param trees The trees of the ensemble.
 * @param weights The tree weights.
 * @param base_score Passed through to {@code createMiningModel}.
 * @param ntreeLimit Passed through to {@code createMiningModel}.
 * @param schema The target schema.
 * @return The regression model.
 */
@Override
public MiningModel encodeMiningModel(List<RegTree> trees, List<Float> weights, float base_score, Integer ntreeLimit, Schema schema){
	Schema memberSchema = schema.toAnonymousSchema();

	MiningModel ensemble = createMiningModel(trees, weights, base_score, ntreeLimit, memberSchema)
		.setOutput(ModelUtil.createPredictedOutput(FieldName.create("xgbValue"), OpType.CONTINUOUS, DataType.FLOAT));

	return MiningModelUtil.createRegression(
		ensemble,
		RegressionModel.NormalizationMethod.LOGIT,
		schema
	);
}
 
Example #25
Source File: DummyRegressor.java    From jpmml-sklearn with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Encodes the dummy regressor as a regression model with no features and
 * no coefficients, so that only the intercept contributes.
 *
 * @param schema The target schema.
 * @return The intercept-only regression model.
 */
@Override
public RegressionModel encodeModel(Schema schema){
	List<? extends Number> constants = getConstant();

	// A single constant is expected; getOnlyElement fails otherwise
	Number onlyConstant = Iterables.getOnlyElement(constants);

	return RegressionModelUtil.createRegression(
		Collections.emptyList(),
		Collections.emptyList(),
		onlyConstant.doubleValue(),
		null,
		schema
	);
}
 
Example #26
Source File: LinearRegressor.java    From jpmml-sklearn with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Encodes the linear regressor as a regression model over the schema
 * features, with no normalization method set.
 *
 * @param schema The target schema.
 * @return The regression model.
 */
@Override
public RegressionModel encodeModel(Schema schema){
	List<? extends Number> coefficients = getCoef();
	List<? extends Number> intercepts = getIntercept();

	// A single intercept is expected; getOnlyElement fails otherwise
	return RegressionModelUtil.createRegression(
		schema.getFeatures(),
		coefficients,
		Iterables.getOnlyElement(intercepts),
		null,
		schema
	);
}
 
Example #27
Source File: GeneralizedLinearRegressor.java    From jpmml-sklearn with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Encodes the model using the superclass encoder, then sets the
 * {@code EXP} normalization method on the result.
 *
 * @param schema The target schema.
 * @return The regression model with {@code EXP} normalization.
 */
@Override
public RegressionModel encodeModel(Schema schema){
	return super.encodeModel(schema)
		.setNormalizationMethod(RegressionModel.NormalizationMethod.EXP);
}
 
Example #28
Source File: GeneralizedLinearRegression.java    From jpmml-xgboost with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Encodes the tree ensemble as a regression model that applies the
 * {@code EXP} normalization.
 *
 * @param trees The trees of the ensemble.
 * @param weights The tree weights.
 * @param base_score Passed through to {@code createMiningModel}.
 * @param ntreeLimit Passed through to {@code createMiningModel}.
 * @param schema The target schema.
 * @return The regression model.
 */
@Override
public MiningModel encodeMiningModel(List<RegTree> trees, List<Float> weights, float base_score, Integer ntreeLimit, Schema schema){
	Schema memberSchema = schema.toAnonymousSchema();

	MiningModel ensemble = createMiningModel(trees, weights, base_score, ntreeLimit, memberSchema)
		.setOutput(ModelUtil.createPredictedOutput(FieldName.create("xgbValue"), OpType.CONTINUOUS, DataType.FLOAT));

	return MiningModelUtil.createRegression(
		ensemble,
		RegressionModel.NormalizationMethod.EXP,
		schema
	);
}
 
Example #29
Source File: LinearRegressor.java    From jpmml-tensorflow with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Encodes the estimator as a regression model that predicts a continuous
 * float field named {@code _target}.
 *
 * @param encoder The encoder used for creating the target data field.
 * @return The regression model.
 */
@Override
public RegressionModel encodeModel(TensorFlowEncoder encoder){
	DataField targetField = encoder.createDataField(FieldName.create("_target"), OpType.CONTINUOUS, DataType.FLOAT);

	Label targetLabel = new ContinuousLabel(targetField);

	return encodeRegressionModel(encoder)
		.setMiningFunction(MiningFunction.REGRESSION)
		.setMiningSchema(ModelUtil.createMiningSchema(targetLabel));
}
 
Example #30
Source File: LinearDiscriminantAnalysis.java    From jpmml-sklearn with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Encodes a model with three or more classes as one regression model per
 * label category, combined into a classification model using softmax
 * normalization.
 *
 * <p>
 * Falls back to the superclass encoder when the scikit-learn version is
 * unknown or older than 0.21.
 * </p>
 *
 * @param schema The target schema; its label is cast to {@link CategoricalLabel}.
 * @return The classification model.
 * @throws IllegalArgumentException If the coefficient matrix has fewer than three rows.
 */
private Model encodeMultinomialModel(Schema schema){
	String sklearnVersion = getSkLearnVersion();
	int[] coefShape = getCoefShape();

	int numberOfClasses = coefShape[0];
	int numberOfFeatures = coefShape[1];

	List<? extends Number> coefficients = getCoef();
	List<? extends Number> intercepts = getIntercept();

	CategoricalLabel label = (CategoricalLabel)schema.getLabel();

	List<? extends Feature> features = schema.getFeatures();

	// See https://github.com/scikit-learn/scikit-learn/issues/6848
	if(sklearnVersion == null || SkLearnUtil.compareVersion(sklearnVersion, "0.21") < 0){
		return super.encodeModel(schema);
	} // End if

	if(numberOfClasses < 3){
		throw new IllegalArgumentException();
	}

	SchemaUtil.checkSize(numberOfClasses, label);

	Schema memberSchema = (schema.toAnonymousRegressorSchema(DataType.DOUBLE)).toEmptySchema();

	List<RegressionModel> members = new ArrayList<>();

	int rows = label.size();

	for(int row = 0; row < rows; row++){
		// Each member uses one row of the coefficient matrix and the matching intercept
		RegressionModel member = RegressionModelUtil.createRegression(features, CMatrixUtil.getRow(coefficients, numberOfClasses, numberOfFeatures, row), intercepts.get(row), RegressionModel.NormalizationMethod.NONE, memberSchema)
			.setOutput(ModelUtil.createPredictedOutput(FieldName.create("decisionFunction(" + label.getValue(row) + ")"), OpType.CONTINUOUS, DataType.DOUBLE));

		members.add(member);
	}

	return MiningModelUtil.createClassification(members, RegressionModel.NormalizationMethod.SOFTMAX, true, schema);
}