Java Code Examples for org.apache.flink.ml.api.misc.param.Params#set()

The following examples show how to use `org.apache.flink.ml.api.misc.param.Params#set()`. You can vote up the examples you find useful or vote down the ones you don't, and follow the links above each example to the original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: ParamsTest.java (project: flink, Apache License 2.0) — 6 votes
@Test
public void getOptionalParam() {
	// An optional parameter: it declares a (null) default value, so reading it before
	// any value has been set yields that default rather than failing.
	ParamInfo <String> key = ParamInfoFactory
		.createParamInfo("key", String.class)
		.setHasDefaultValue(null)
		.setDescription("")
		.build();

	Params params = new Params();
	Assert.assertNull(params.get(key));

	String val = "3";
	params.set(key, val);
	// JUnit's contract is assertEquals(expected, actual); keeping that order makes a
	// failure report the values the right way round.
	Assert.assertEquals(val, params.get(key));

	// Explicitly setting null must be accepted and read back as null.
	params.set(key, null);
	Assert.assertNull(params.get(key));
}
 
Example 2
Source File: BaseLinearModelTrainBatchOp.java (project: Alink, Apache License 2.0) — 6 votes
@Override
public void mapPartition(Iterable<Object> rows, Collector<Params> metas) throws Exception {
    // Ordered label values are only collected for classification; regression leaves them null.
    Object[] labelValues = this.isRegProc ? null : orderLabels(rows);

    // Assemble the model meta information and emit it as this partition's single record.
    Params modelMeta = new Params()
        .set(ModelParamName.MODEL_NAME, this.modelName)
        .set(ModelParamName.LINEAR_MODEL_TYPE, this.modelType)
        .set(ModelParamName.LABEL_VALUES, labelValues)
        .set(ModelParamName.HAS_INTERCEPT_ITEM, this.hasInterceptItem)
        .set(ModelParamName.VECTOR_COL_NAME, vectorColName)
        .set(LinearTrainParams.LABEL_COL, labelName);
    metas.collect(modelMeta);
}
 
Example 3
Source File: ClassificationEvaluationUtilTest.java (project: Alink, Apache License 2.0) — 6 votes
@Test
public void judgeEvaluationTypeTest(){
    // Only the prediction-detail column is configured -> detail-based evaluation.
    Params params = new Params()
        .set(HasPredictionDetailCol.PREDICTION_DETAIL_COL, "detail");

    ClassificationEvaluationUtil.Type type = ClassificationEvaluationUtil.judgeEvaluationType(params);
    // assertEquals takes (expected, actual).
    Assert.assertEquals(ClassificationEvaluationUtil.Type.PRED_DETAIL, type);

    // Both columns configured: the detail column still wins.
    params.set(HasPredictionCol.PREDICTION_COL, "pred");
    type = ClassificationEvaluationUtil.judgeEvaluationType(params);
    Assert.assertEquals(ClassificationEvaluationUtil.Type.PRED_DETAIL, type);

    // Only the prediction column remains -> result-based evaluation.
    params.remove(HasPredictionDetailCol.PREDICTION_DETAIL_COL);
    type = ClassificationEvaluationUtil.judgeEvaluationType(params);
    Assert.assertEquals(ClassificationEvaluationUtil.Type.PRED_RESULT, type);

    // Neither column configured -> the call must fail with a descriptive message.
    // NOTE(review): `thrown` is presumably a JUnit ExpectedException rule field — confirm.
    params.remove(HasPredictionCol.PREDICTION_COL);
    thrown.expect(RuntimeException.class);
    thrown.expectMessage("Error Input, must give either predictionCol or predictionDetailCol!");
    ClassificationEvaluationUtil.judgeEvaluationType(params);
}
 
Example 4
Source File: ParamsTest.java (project: flink, Apache License 2.0) — 6 votes
@Test
public void getOptionalParam() {
	// An optional parameter: it declares a (null) default value, so reading it before
	// any value has been set yields that default rather than failing.
	ParamInfo <String> key = ParamInfoFactory
		.createParamInfo("key", String.class)
		.setHasDefaultValue(null)
		.setDescription("")
		.build();

	Params params = new Params();
	Assert.assertNull(params.get(key));

	String val = "3";
	params.set(key, val);
	// JUnit's contract is assertEquals(expected, actual); keeping that order makes a
	// failure report the values the right way round.
	Assert.assertEquals(val, params.get(key));

	// Explicitly setting null must be accepted and read back as null.
	params.set(key, null);
	Assert.assertNull(params.get(key));
}
 
Example 5
Source File: SelectMapperTest.java (project: Alink, Apache License 2.0) — 6 votes
@Test
public void testValueConstructionFunctions() throws Exception {
    TableSchema dataSchema = TableSchema.builder().fields(
        new String[] {"id", "name"},
        new DataType[] {DataTypes.INT(), DataTypes.STRING()}).build();
    Params params = new Params();
    params.set(HasClause.CLAUSE,
        "ROW(1, 2, 3), ARRAY[1, 2, 3], MAP[1, 2, 3, 4]"
    );
    SelectMapper selectMapper = new SelectMapper(dataSchema, params);
    selectMapper.open();
    // Evaluate inside the try so the mapper is closed even if map() itself throws.
    try {
        Row output = selectMapper.map(Row.of(1, "'abc'"));
        // Three select expressions (ROW, ARRAY, MAP) -> three output fields.
        // assertEquals takes (expected, actual).
        assertEquals(3, output.getArity());
    } finally {
        selectMapper.close();
    }
}
 
Example 6
Source File: ParamsTest.java (project: flink, Apache License 2.0) — 6 votes
@Test
public void getRequiredParam() {
	ParamInfo <String> labelWithRequired = ParamInfoFactory
		.createParamInfo("label", String.class)
		.setDescription("")
		.setRequired()
		.build();
	Params params = new Params();
	try {
		params.get(labelWithRequired);
		Assert.fail("failure");
	} catch (IllegalArgumentException ex) {
		Assert.assertTrue(ex.getMessage().startsWith("Missing non-optional parameter"));
	}

	params.set(labelWithRequired, null);
	Assert.assertNull(params.get(labelWithRequired));

	String val = "3";
	params.set(labelWithRequired, val);
	Assert.assertEquals(params.get(labelWithRequired), val);
}
 
Example 7
Source File: RegressionMetricsSummary.java (project: Alink, Apache License 2.0) — 6 votes
/**
 * Builds the regression metrics from the locally accumulated sums.
 *
 * <p>Several metrics are derived from values previously stored in {@code params} (read back
 * through {@code params.get}), so the statements must stay in this order.
 *
 * <p>NOTE(review): every metric divides by {@code total}; when it is 0 the results become
 * NaN/Infinity — confirm callers never reach this with an empty summary.
 *
 * @return the metrics wrapping the populated parameters
 */
@Override
public RegressionMetrics toMetrics() {
    Params params = new Params();
    // SST = sum(y^2) - (sum y)^2 / n, assuming ySum2Local holds the sum of squared labels.
    params.set(RegressionMetrics.SST, ySum2Local - ySumLocal * ySumLocal / total);
    params.set(RegressionMetrics.SSE, sseLocal);
    // SSR = sum(pred^2) - 2 * sum(y) * sum(pred) / n + (sum y)^2 / n, i.e. the expanded form of
    // sum((pred - mean(y))^2), assuming predSum2Local holds the sum of squared predictions.
    params.set(RegressionMetrics.SSR,
        predSum2Local - 2 * ySumLocal * predSumLocal / total + ySumLocal * ySumLocal / total);
    params.set(RegressionMetrics.R2, 1 - params.get(RegressionMetrics.SSE) / params.get(RegressionMetrics.SST));
    // R is taken as sqrt(R2); it is NaN whenever R2 is negative.
    params.set(RegressionMetrics.R, Math.sqrt(params.get(RegressionMetrics.R2)));
    params.set(RegressionMetrics.MSE, params.get(RegressionMetrics.SSE) / total);
    params.set(RegressionMetrics.RMSE, Math.sqrt(params.get(RegressionMetrics.MSE)));
    params.set(RegressionMetrics.SAE, maeLocal);
    params.set(RegressionMetrics.MAE, params.get(RegressionMetrics.SAE) / total);
    params.set(RegressionMetrics.COUNT, (double)total);
    // MAPE is reported as a percentage (scaled by 100).
    params.set(RegressionMetrics.MAPE, mapeLocal * 100 / total);
    params.set(RegressionMetrics.Y_MEAN, ySumLocal / total);
    params.set(RegressionMetrics.PREDICTION_MEAN, predSumLocal / total);
    params.set(RegressionMetrics.EXPLAINED_VARIANCE, params.get(RegressionMetrics.SSR) / total);

    return new RegressionMetrics(params);
}
 
Example 8
Source File: SelectMapperTest.java (project: Alink, Apache License 2.0) — 6 votes
@Test
public void testCollectionFunctions() throws Exception {
    TableSchema dataSchema = TableSchema.builder().fields(
        new String[] {"id", "name"},
        new DataType[] {DataTypes.INT(), DataTypes.STRING()}).build();
    Params params = new Params();
    params.set(HasClause.CLAUSE,
        "CARDINALITY(ARRAY[1,2,3])"
            + ", ARRAY[1,2,3][2]"
            + ", ELEMENT(ARRAY[2])"
            + ", CARDINALITY(MAP[1, 2, 3, 4])"
            + ", MAP[1, 2, 3, 4][3]"
    );
    SelectMapper selectMapper = new SelectMapper(dataSchema, params);
    selectMapper.open();
    // One expected value per select expression above, in order.
    Row expected = Row.of(3, 2, 2, 2, 4);
    // Evaluate inside the try so the mapper is closed even if map() itself throws.
    try {
        Row output = selectMapper.map(Row.of(1, "'abc'"));
        assertEquals(expected, output);
    } finally {
        selectMapper.close();
    }
}
 
Example 9
Source File: MultiMetricsSummary.java (project: Alink, Apache License 2.0) — 5 votes
/**
 * Derives the multi-class metrics from the accumulated confusion matrix.
 *
 * <p>Stores the predicted-label frequency and proportion, then, for every entry of
 * {@code ClassificationEvaluationUtil.Computations}, the per-label values under that entry's
 * array parameter, followed by the common classification params and the log-loss params.
 *
 * @return the metrics wrapping the populated parameters
 */
@Override
public MultiClassMetrics toMetrics() {
    Params metricParams = new Params();
    ConfusionMatrix confusion = new ConfusionMatrix(matrix);
    metricParams
        .set(MultiClassMetrics.PREDICT_LABEL_FREQUENCY, confusion.getPredictLabelFrequency())
        .set(MultiClassMetrics.PREDICT_LABEL_PROPORTION, confusion.getPredictLabelProportion());

    ClassificationEvaluationUtil.Computations[] computations = ClassificationEvaluationUtil.Computations.values();
    for (ClassificationEvaluationUtil.Computations computation : computations) {
        metricParams.set(computation.arrayParamInfo,
            ClassificationEvaluationUtil.getAllValues(computation.computer, confusion));
    }

    setClassificationCommonParams(metricParams, confusion, labels);
    setLoglossParams(metricParams, logLoss, total);
    return new MultiClassMetrics(metricParams);
}
 
Example 10
Source File: ParamsTest.java (project: flink, Apache License 2.0) — 5 votes
@Test
public void testValidator() {
	Params params = new Params();

	// The validator only accepts strictly positive integers.
	ParamInfo<Integer> positiveInt = ParamInfoFactory
		.createParamInfo("a", Integer.class)
		.setValidator(value -> value > 0)
		.build();

	// A valid value is stored without complaint.
	params.set(positiveInt, 1);

	// An invalid value must be rejected by set() with a descriptive message.
	thrown.expect(RuntimeException.class);
	thrown.expectMessage("Setting a as a invalid value:0");
	params.set(positiveInt, 0);
}
 
Example 11
Source File: BaseLinearModelTrainBatchOp.java (project: Alink, Apache License 2.0) — 5 votes
/**
 * Transform train data to Tuple3 format.
 *
 * <p>If neither feature columns nor a vector column are configured, every numeric column
 * other than the label column is used as a feature, and that choice is written back into
 * {@code params} under {@code LinearTrainParams.FEATURE_COLS}.
 *
 * @param in          train data in row format.
 * @param params      train parameters; may be updated with the inferred feature columns.
 * @param labelValues label values, broadcast to the mapper under {@code LABEL_VALUES}.
 * @param isRegProc   is regression process or not.
 * @return train data as {@code Tuple3<weight, label, vector>}.
 * @throws IllegalStateException if a configured feature column is not of a numeric type.
 */
private DataSet<Tuple3<Double, Double, Vector>> transform(BatchOperator in,
                                                          Params params,
                                                          DataSet<Object> labelValues,
                                                          boolean isRegProc) {
    String[] featureColNames = params.get(LinearTrainParams.FEATURE_COLS);
    String labelName = params.get(LinearTrainParams.LABEL_COL);
    String weightColName = params.get(LinearTrainParams.WEIGHT_COL);
    String vectorColName = params.get(LinearTrainParams.VECTOR_COL);
    TableSchema dataSchema = in.getSchema();
    if (null == featureColNames && null == vectorColName) {
        // No explicit features: default to all numeric columns except the label, and record the
        // choice in params so later stages see the same feature set.
        featureColNames = TableUtil.getNumericCols(dataSchema, new String[] {labelName});
        params.set(LinearTrainParams.FEATURE_COLS, featureColNames);
    }
    // Loop-invariant lookups, fetched once instead of once per feature column.
    String[] colNames = in.getColNames();
    int[] featureIndices = null;
    int labelIdx = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), labelName);
    if (featureColNames != null) {
        TypeInformation<?>[] fieldTypes = dataSchema.getFieldTypes();
        featureIndices = new int[featureColNames.length];
        for (int i = 0; i < featureColNames.length; ++i) {
            int idx = TableUtil.findColIndexWithAssertAndHint(colNames, featureColNames[i]);
            featureIndices[i] = idx;
            TypeInformation<?> type = fieldTypes[idx];

            Preconditions.checkState(TableUtil.isNumber(type),
                "linear algorithm only support numerical data type. type is : " + type);
        }
    }
    // -1 marks an absent optional column.
    int weightIdx = weightColName != null ? TableUtil.findColIndexWithAssertAndHint(colNames, weightColName) : -1;
    int vecIdx = vectorColName != null ? TableUtil.findColIndexWithAssertAndHint(colNames, vectorColName) : -1;

    return in.getDataSet().map(new Transform(isRegProc, weightIdx, vecIdx, featureIndices, labelIdx))
        .withBroadcastSet(labelValues, LABEL_VALUES);
}
 
Example 12
Source File: FeedForwardTrainer.java (project: Alink, Apache License 2.0) — 5 votes
/**
 * Train the network.
 *
 * <p>The network is always trained with L-BFGS. The objective function and optimizer receive a
 * copy of {@code optimizationParams} with {@code numSearchStep} set to 3; the caller's
 * {@code Params} instance is left unmodified.
 *
 * @param data               Training data, a dataset of tuples of (label, features).
 * @param optimizationParams Parameters for optimizations. Not modified by this method.
 * @return The model weights.
 */
public DataSet<DenseVector> train(DataSet<Tuple2<Double, DenseVector>> data, Params optimizationParams) {
    final Topology topology = this.topology;
    final int inputSize = this.inputSize;
    final int outputSize = this.outputSize;
    final boolean onehotLabel = this.onehotLabel;

    ParamInfo<Integer> numSearchStep = ParamInfoFactory
        .createParamInfo("numSearchStep", Integer.class)
        .setDescription("num search step")
        .setRequired()
        .build();

    DataSet<DenseVector> initCoef = initModel(data, topology);
    DataSet<Tuple3<Double, Double, Vector>> trainData = stack(data, blockSize, inputSize, outputSize,
        onehotLabel);

    // Work on a copy so the forced search-step setting does not leak back into the caller's Params.
    Params optimParams = optimizationParams.clone();
    optimParams.set(numSearchStep, 3);
    final AnnObjFunc annObjFunc = new AnnObjFunc(topology, inputSize, outputSize, onehotLabel, optimParams);

    // We always use L-BFGS to train the network.
    Optimizer optimizer = new Lbfgs(
        data.getExecutionEnvironment().fromElements(annObjFunc),
        trainData,
        BatchOperator
            .getExecutionEnvironmentFromDataSets(data)
            .fromElements(inputSize),
        optimParams
    );
    optimizer.initCoefWith(initCoef);
    // Keep only the coefficient vector from the optimizer's (coefficients, loss curve) result.
    // NOTE(review): an anonymous class is presumably kept over a lambda so Flink can extract
    // the generic output type (lambdas lose it to erasure) — confirm before converting.
    return optimizer.optimize().map(new MapFunction<Tuple2<DenseVector, double[]>, DenseVector>() {
        @Override
        public DenseVector map(Tuple2<DenseVector, double[]> value) throws Exception {
            return value.f0;
        }
    });
}
 
Example 13
Source File: SelectMapperTest.java (project: Alink, Apache License 2.0) — 5 votes
@Test
public void testStringFunctions() throws Exception {
    TableSchema dataSchema = TableSchema.builder().fields(
        new String[] {"id", "name"},
        new DataType[] {DataTypes.INT(), DataTypes.STRING()}).build();
    Params params = new Params();
    params.set(HasClause.CLAUSE,
        "name || name, CHAR_LENGTH(name), CHARACTER_LENGTH(name), UPPER(name), LOWER(name), POSITION(name IN name),"
            + "TRIM('a' FROM name), REPEAT(name, 3)"
            + ", OVERLAY('This is an old string' PLACING ' new' FROM 10 FOR 5)"
            + ", SUBSTRING(name FROM 2)"
            + ", REPLACE('hello world', 'world', 'flink')"
            + ", INITCAP(name)"
            + ", FROM_BASE64('aGVsbG8gd29ybGQ=')"
            + ", TO_BASE64('hello world')"
            + ", LPAD('hi',4,'??')"
            + ", RPAD('hi',4,'??')"
            + ", REGEXP_REPLACE('foobar', 'oo|ar', '')"
            + ", REGEXP_EXTRACT('foothebar', 'foo(.*?)(bar)', 2)"
            + ", LTRIM(' This is a test String.')"
            + ", RTRIM('This is a test String. ')"
    );
    SelectMapper selectMapper = new SelectMapper(dataSchema, params);
    selectMapper.open();
    // One expected value per select expression above, in order. The input name is the quoted
    // string "'abc'", so the quotes appear inside the name-based results.
    Row expected = Row.of("'abc''abc'", 5, 5, "'ABC'", "'abc'", 1, "'abc'", "'abc''abc''abc'",
        "This is a new string", "abc'", "hello flink", "'Abc'", "hello world", "aGVsbG8gd29ybGQ=", "??hi", "hi??",
        "fb", "bar", "This is a test String.", "This is a test String.");
    // Evaluate and check inside the try so the mapper is closed even if map() or an assertion fails.
    try {
        Row output = selectMapper.map(Row.of(1, "'abc'"));
        assertEquals(expected.getArity(), output.getArity());
        assertEquals(expected, output);
    } finally {
        selectMapper.close();
    }
}
 
Example 14
Source File: ClusterEvaluationUtil.java (project: Alink, Apache License 2.0) — 5 votes
/**
 * Converts a metrics summary into its params and adds the silhouette coefficient.
 *
 * <p>The coefficient is the first element of the {@code SILHOUETTE_COEFFICIENT} broadcast
 * variable divided by the summary's {@code COUNT}.
 *
 * @param t the metrics summary to convert
 * @return the summary's params with the silhouette coefficient added
 */
@Override
public Params map(BaseMetricsSummary t) throws Exception {
    Params metricParams = t.toMetrics().getParams();
    List<Tuple1<Double>> broadcastSilhouette = getRuntimeContext()
        .getBroadcastVariable(EvalClusterBatchOp.SILHOUETTE_COEFFICIENT);
    Tuple1<Double> firstEntry = broadcastSilhouette.get(0);
    metricParams.set(ClusterMetrics.SILHOUETTE_COEFFICIENT,
        firstEntry.f0 / metricParams.get(ClusterMetrics.COUNT));
    return metricParams;
}
 
Example 15
Source File: LinearModelDataConverter.java (project: Alink, Apache License 2.0) — 5 votes
/**
 * Collects the meta information of a linear model into a {@code Params}.
 *
 * <p>The vector column name and vector size are only recorded when the model was trained on a
 * vector column.
 */
private Params getMetaInfo(LinearModelData data) {
    Params metaParams = new Params()
        .set(ModelParamName.MODEL_NAME, data.modelName)
        .set(ModelParamName.HAS_INTERCEPT_ITEM, data.hasInterceptItem)
        .set(ModelParamName.LINEAR_MODEL_TYPE, data.linearModelType);
    boolean usesVectorCol = data.vectorColName != null;
    if (usesVectorCol) {
        metaParams
            .set(HasVectorCol.VECTOR_COL, data.vectorColName)
            .set(ModelParamName.VECTOR_SIZE, data.vectorSize);
    }
    return metaParams.set(HasLabelCol.LABEL_COL, data.labelName);
}
 
Example 16
Source File: ImputerModelDataConverter.java (project: Alink, Apache License 2.0) — 5 votes
/**
 * Serialize the model to {@code Tuple3<Params, Iterable<String>, Iterable<Row>>}.
 *
 * <p>The meta params always carry the strategy and the selected column names. For the
 * {@code MIN}, {@code MAX} and {@code MEAN} strategies the corresponding per-column statistic
 * from the table summary is collected; for any other strategy the fill value is stored in the
 * meta params instead and the values array stays {@code null}. The values array is JSON-encoded
 * as the single model data string.
 *
 * @param modelData The model data to serialize: (strategy, table summary, fill value).
 * @return The serialization result: meta params, a one-element list holding the JSON-encoded
 *         values, and an empty row list.
 */
@Override
public Tuple3<Params, Iterable<String>, Iterable<Row>> serializeModel(Tuple3<Strategy, TableSummary, String> modelData) {
    Strategy strategy = modelData.f0;
    TableSummary summary = modelData.f1;
    String fillValue = modelData.f2;

    double[] values = null;
    Params meta = new Params()
            .set(STRATEGY, strategy)
            .set(SELECTED_COLS, selectedColNames);
    switch (strategy) {
        case MIN:
        case MAX:
        case MEAN:
            values = new double[selectedColNames.length];
            for (int i = 0; i < selectedColNames.length; i++) {
                values[i] = columnStatistic(strategy, summary, selectedColNames[i]);
            }
            break;
        default:
            meta.set(FILL_VALUE, fillValue);
    }

    List<String> data = new ArrayList<>();
    data.add(JsonConverter.toJson(values));

    return Tuple3.of(meta, data, new ArrayList<>());
}

/**
 * Returns the summary statistic of one column for a statistic-based strategy (MIN/MAX/MEAN).
 *
 * @throws IllegalArgumentException if {@code strategy} is not statistic-based.
 */
private double columnStatistic(Strategy strategy, TableSummary summary, String colName) {
    switch (strategy) {
        case MIN:
            return summary.min(colName);
        case MAX:
            return summary.max(colName);
        case MEAN:
            return summary.mean(colName);
        default:
            throw new IllegalArgumentException("Not a statistic-based strategy: " + strategy);
    }
}
 
Example 17
Source File: BaseLinearModelTrainBatchOp.java (project: Alink, Apache License 2.0) — 5 votes
/**
 * Build model data.
 *
 * <p>Writes {@code VECTOR_SIZE} into {@code meta}: the coefficient count minus one slot when
 * {@code HAS_INTERCEPT_ITEM} is set and one more for AFT models. For non-AFT models trained with
 * standardization, the coefficients in {@code coefVector.f0} are rescaled in place by
 * {@code meanVar}, and when {@code hasIntercept} is set the intercept (index 0) is adjusted too.
 *
 * @param meta            meta info; updated with {@code VECTOR_SIZE}.
 * @param featureNames    feature column names.
 * @param labelType       label type.
 * @param meanVar         mean and variance of vector.
 * @param hasIntercept    has interception or not.
 * @param standardization do standardization or not.
 * @param coefVector      coefficient vector (modified in place) and loss curve.
 * @return linear model data.
 */
public static LinearModelData buildLinearModelData(Params meta,
                                                   String[] featureNames,
                                                   TypeInformation labelType,
                                                   DenseVector[] meanVar,
                                                   boolean hasIntercept,
                                                   boolean standardization,
                                                   Tuple2<DenseVector, double[]> coefVector) {
    // Evaluated once and used consistently. Previously the VECTOR_SIZE computation compared the
    // enum constant against meta.get(...).toString(), which disagreed with the other two checks
    // (and, for an enum-typed param, is never equal, so the AFT slot was never subtracted).
    boolean isAft = LinearModelType.AFT.equals(meta.get(ModelParamName.LINEAR_MODEL_TYPE));
    if (!isAft) {
        modifyMeanVar(standardization, meanVar);
    }

    meta.set(ModelParamName.VECTOR_SIZE, coefVector.f0.size()
        - (meta.get(ModelParamName.HAS_INTERCEPT_ITEM) ? 1 : 0)
        - (isAft ? 1 : 0));
    if (!isAft && standardization) {
        int n = meanVar[0].size();
        if (hasIntercept) {
            // Coefficients are offset by one (index 0 is the intercept); fold the mean shift of
            // each feature into the intercept while rescaling.
            double sum = 0.0;
            for (int i = 0; i < n; ++i) {
                sum += coefVector.f0.get(i + 1) * meanVar[0].get(i) / meanVar[1].get(i);
                coefVector.f0.set(i + 1, coefVector.f0.get(i + 1) / meanVar[1].get(i));
            }
            coefVector.f0.set(0, coefVector.f0.get(0) - sum);
        } else {
            for (int i = 0; i < n; ++i) {
                coefVector.f0.set(i, coefVector.f0.get(i) / meanVar[1].get(i));
            }
        }
    }

    LinearModelData modelData = new LinearModelData(labelType, meta, featureNames, coefVector.f0);
    modelData.lossCurve = coefVector.f1;

    return modelData;
}
 
Example 18
Source File: OutputModel.java (project: Alink, Apache License 2.0) — 5 votes
/**
 * Emits the optimized model as a single JSON row.
 *
 * <p>Only task 0 produces output (all other tasks return null). The loss curve is truncated at
 * its first infinite entry, and the coefficient vector is rejected if any entry is NaN or infinite.
 */
@Override
public List <Row> calc(ComContext context) {
	if (context.getTaskId() != 0) {
		return null;
	}

	// get the coefficient of min loss.
	Tuple2 <DenseVector, double[]> minCoef = context.getObj(OptimVariable.minCoef);
	double[] lossCurve = context.getObj(OptimVariable.lossCurve);

	// Length of the loss curve up to (excluding) the first infinite value.
	int curveLength = 0;
	while (curveLength < lossCurve.length && !Double.isInfinite(lossCurve[curveLength])) {
		curveLength++;
	}
	double[] trimmedCurve = new double[curveLength];
	System.arraycopy(lossCurve, 0, trimmedCurve, 0, curveLength);

	DenseVector coef = minCoef.f0;
	for (int idx = 0; idx < coef.size(); ++idx) {
		double weight = coef.get(idx);
		if (Double.isNaN(weight) || Double.isInfinite(weight)) {
			throw new RuntimeException("Optimization result has NAN or infinite value, coefficient is invalid");
		}
	}

	Params modelParams = new Params()
		.set(ModelParamName.COEF, coef)
		.set(ModelParamName.LOSS_CURVE, trimmedCurve);
	List <Row> rows = new ArrayList <>(1);
	rows.add(Row.of(modelParams.toJson()));
	return rows;
}
 
Example 19
Source File: TreeInitObj.java (project: Alink, Apache License 2.0) — 4 votes
/**
 * Builds this task's tree object on the first step and stores it in the context.
 *
 * <p>Reads the local training rows, quantile model, string-indexer model and labels from the
 * context, derives feature/label metadata, copies each row's features column-major into an int
 * array, and registers a histogram buffer under {@code "allReduce"} and the tree object under
 * {@code "treeObj"}. Does nothing on any step other than step 1.
 */
@Override
public void calc(ComContext context) {
	if (context.getStepNo() != 1) {
		return;
	}

	List <Row> dataRows = context.getObj("treeInput");
	List <Row> quantileModel = context.getObj("quantileModel");
	List <Row> stringIndexerModel = context.getObj("stringIndexerModel");
	List<Object[]> labels = context.getObj("labels");

	int nLocalRow = (dataRows != null) ? dataRows.size() : 0;

	// Per-task copy of the shared params, annotated with this task's id, task count and row count.
	Params localParams = params.clone();
	localParams.set(TASK_ID, context.getTaskId());
	localParams.set(NUM_OF_SUBTASKS, context.getNumTask());
	localParams.set(N_LOCAL_ROW, nLocalRow);

	QuantileDiscretizerModelDataConverter quantileDiscretizerModel = initialMapping(quantileModel);

	String[] categoricalCols = params.get(HasCategoricalCols.CATEGORICAL_COLS);
	List<String> lookUpColNames = new ArrayList<>();
	if (categoricalCols != null) {
		lookUpColNames.addAll(Arrays.asList(categoricalCols));
	}

	Map<String, Integer> categoricalColsSize = TreeUtil.extractCategoricalColsSize(
		stringIndexerModel, lookUpColNames.toArray(new String[0]));

	final boolean isRegression = Criteria.isRegression(params.get(TreeUtil.TREE_TYPE));

	// For classification the label column is treated as categorical with one entry per label.
	if (!isRegression) {
		categoricalColsSize.put(params.get(HasLabelCol.LABEL_COL), labels.get(0).length);
	}

	String[] featureCols = params.get(HasFeatureCols.FEATURE_COLS);
	FeatureMeta[] featureMetas = TreeUtil.getFeatureMeta(featureCols, categoricalColsSize);
	FeatureMeta labelMeta = TreeUtil.getLabelMeta(
		params.get(HasLabelCol.LABEL_COL),
		featureCols.length,
		categoricalColsSize);

	TreeObj treeObj = isRegression
		? new RegObj(localParams, quantileDiscretizerModel, featureMetas, labelMeta)
		: new ClassifierObj(localParams, quantileDiscretizerModel, featureMetas, labelMeta);

	int nFeatureCol = localParams.get(RandomForestTrainParams.FEATURE_COLS).length;

	// Features are stored column-major: all rows of feature 0, then all rows of feature 1, ...
	int[] data = new int[nFeatureCol * nLocalRow];
	double[] regLabels = isRegression ? new double[nLocalRow] : null;
	int[] classifyLabels = isRegression ? null : new int[nLocalRow];

	for (int row = 0; row < nLocalRow; ++row) {
		Row current = dataRows.get(row);
		for (int col = 0; col < nFeatureCol; ++col) {
			data[col * nLocalRow + row] = (int) current.getField(col);
		}
		// The label sits in the field right after the feature columns.
		if (isRegression) {
			regLabels[row] = (double) current.getField(nFeatureCol);
		} else {
			classifyLabels[row] = (int) current.getField(nFeatureCol);
		}
	}

	treeObj.setFeatures(data);

	if (isRegression) {
		treeObj.setLabels(regLabels);
	} else {
		treeObj.setLabels(classifyLabels);
	}

	// The same buffer is registered for all-reduce and used as the tree's histogram storage.
	double[] histBuffer = new double[treeObj.getMaxHistBufferSize()];
	context.putObj("allReduce", histBuffer);
	treeObj.setHist(histBuffer);

	treeObj.initialRoot();

	context.putObj("treeObj", treeObj);
}
 
Example 20
Source File: ClassificationEvaluationUtil.java (project: Alink, Apache License 2.0) — 4 votes
static void setLoglossParams(Params params, double logLoss, long total) {
    // Negative values and NaN (for which `>= 0` is false) are treated as "no log-loss"
    // and leave params untouched; otherwise the mean log-loss over `total` is stored.
    if (!(logLoss >= 0)) {
        return;
    }
    double meanLogLoss = logLoss / total;
    params.set(BaseSimpleClassifierMetrics.LOG_LOSS, meanLogLoss);
}