Java Code Examples for org.apache.flink.ml.api.misc.param.Params#get()

The following examples show how to use org.apache.flink.ml.api.misc.param.Params#get(). You can vote up the examples you find useful or vote down those you don't, and you can open the original project or source file through the links above each example. Related API usage is listed in the sidebar.
Example 1
Source File: GlmModelDataConverter.java, from Alink (Apache License 2.0), 6 votes
/**
 * Rebuilds a {@link GlmModelData} from its serialized form.
 *
 * <p>The training configuration (column names, family/link settings, regularization and
 * convergence options) is read from {@code meta}. The learned parameters are read from
 * {@code data}, which must yield, in this order: the JSON-encoded coefficient array, the
 * JSON-encoded intercept, and the JSON-encoded {@code diagInvAtWA} array.
 *
 * @param meta The model meta data.
 * @param data The model concrete data.
 * @return GlmModelData
 * @throws java.util.NoSuchElementException if {@code data} yields fewer than three elements.
 */
@Override
public GlmModelData deserializeModel(Params meta, Iterable<String> data) {
    GlmModelData model = new GlmModelData();

    // Column configuration.
    model.featureColNames = meta.get(GlmTrainParams.FEATURE_COLS);
    model.offsetColName = meta.get(GlmTrainParams.OFFSET_COL);
    model.weightColName = meta.get(GlmTrainParams.WEIGHT_COL);
    model.labelColName = meta.get(GlmTrainParams.LABEL_COL);

    // Distribution, link function and optimizer settings.
    model.familyName = meta.get(GlmTrainParams.FAMILY);
    model.variancePower = meta.get(GlmTrainParams.VARIANCE_POWER);
    model.linkName = meta.get(GlmTrainParams.LINK);
    model.linkPower = meta.get(GlmTrainParams.LINK_POWER);
    model.fitIntercept = meta.get(GlmTrainParams.FIT_INTERCEPT);
    model.regParam = meta.get(GlmTrainParams.REG_PARAM);
    model.numIter = meta.get(GlmTrainParams.MAX_ITER);
    model.epsilon = meta.get(GlmTrainParams.EPSILON);

    // Learned parameters, consumed sequentially from a single iterator.
    Iterator<String> lines = data.iterator();
    model.coefficients = JsonConverter.fromJson(lines.next(), double[].class);
    model.intercept = JsonConverter.fromJson(lines.next(), double.class);
    model.diagInvAtWA = JsonConverter.fromJson(lines.next(), double[].class);

    return model;
}
 
Example 2
Source File: DirectReader.java, from Alink (Apache License 2.0), 6 votes
/**
 * Create data bridge from batch operator.
 * The type of result DataBridge is the one with matching policy in global configuration.
 *
 * <p>Generators are discovered with {@link ServiceLoader} using this class's class loader. The
 * first generator whose class-level {@link DataBridgeGeneratorPolicy} annotation has a
 * {@code policy()} equal to the configured policy is used. Generators without that annotation
 * are skipped instead of failing with a {@link NullPointerException}.
 *
 * @param model the operator to collect data.
 * @return the created DataBridge.
 * @throws IllegalArgumentException if no policy is configured or no generator matches it.
 */
public static DataBridge collect(BatchOperator<?> model) {
	final Params globalParams = DirectReader.readProperties();
	final String policy = globalParams.get(POLICY_KEY);

	// Fail with a clear message instead of an NPE on policy.equals(...) below.
	if (policy == null) {
		throw new IllegalArgumentException("The data bridge policy is not configured.");
	}

	for (DataBridgeGenerator generator : ServiceLoader.load(DataBridgeGenerator.class, DirectReader.class.getClassLoader())) {
		// getAnnotation returns null when the generator class is not annotated.
		DataBridgeGeneratorPolicy annotation = generator
			.getClass()
			.getAnnotation(DataBridgeGeneratorPolicy.class);
		if (annotation != null && policy.equals(annotation.policy())) {
			return generator.generate(model, globalParams);
		}
	}

	throw new IllegalArgumentException("Can not find the policy: " + policy);
}
 
Example 3
Source File: ParamsTest.java, from flink (Apache License 2.0), 6 votes
/**
 * Checks alias handling in {@link Params#get}.
 *
 * <p>A value stored only under the primary name is returned. When values are stored under both
 * the primary name and one of its aliases, {@code get} must throw an
 * {@link IllegalArgumentException} reporting the duplicate.
 */
@Test
public void testGetAliasParam() {
	final ParamInfo<String> predResultColName = ParamInfoFactory
		.createParamInfo("predResultColName", String.class)
		.setDescription("Column name of predicted result.")
		.setRequired()
		.setAlias(new String[] {"predColName", "outputColName"})
		.build();

	// Only the primary name is present, so the value resolves directly.
	final Params primaryOnly = Params.fromJson("{\"predResultColName\":\"\\\"f0\\\"\"}");
	Assert.assertEquals("f0", primaryOnly.get(predResultColName));

	// Both the primary name and an alias are present, so get() must reject the ambiguity.
	final Params primaryAndAlias =
		Params.fromJson("{\"predResultColName\":\"\\\"f0\\\"\", \"predColName\":\"\\\"f0\\\"\"}");
	try {
		primaryAndAlias.get(predResultColName);
		Assert.fail("failure");
	} catch (IllegalArgumentException expected) {
		Assert.assertTrue(expected.getMessage().startsWith("Duplicate parameters of predResultColName and predColName"));
	}
}
 
Example 4
Source File: StandardScalerModelMapper.java, from Alink (Apache License 2.0), 6 votes
/**
 * Creates a mapper that applies a standard-scaler model to the selected input columns.
 *
 * <p>The selected column names and types come from the model schema, and their indices are
 * resolved in the data schema. Output column names are read from
 * {@code SrtPredictMapperParams.OUTPUT_COLS}; when that parameter is null, the selected column
 * names are used instead.
 *
 * <p>NOTE(review): the column metadata is extracted through {@code ImputerModelDataConverter}.
 * Confirm this shared schema layout is intentional and not carried over from ImputerModelMapper.
 *
 * @param modelSchema the model schema.
 * @param dataSchema  the data schema.
 * @param params      the params.
 */
public StandardScalerModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
    super(modelSchema, dataSchema, params);
    this.selectedColNames = ImputerModelDataConverter.extractSelectedColNames(modelSchema);
    this.selectedColTypes = ImputerModelDataConverter.extractSelectedColTypes(modelSchema);
    this.selectedColIndices = TableUtil.findColIndicesWithAssert(dataSchema, selectedColNames);

    final String[] configuredOutputCols = params.get(SrtPredictMapperParams.OUTPUT_COLS);
    final String[] resultColNames = (configuredOutputCols != null) ? configuredOutputCols : selectedColNames;

    this.predResultColsHelper = new OutputColsHelper(dataSchema, resultColNames, this.selectedColTypes, null);
}
 
Example 5
Source File: ImputerModelMapper.java, from Alink (Apache License 2.0), 6 votes
/**
 * Constructor.
 *
 * <p>The constructor does three things:
 * <ul>
 *   <li>resolves the selected columns from the model schema;</li>
 *   <li>picks the output columns ({@code SrtPredictMapperParams.OUTPUT_COLS}, or the selected
 *       column names when that is null);</li>
 *   <li>maps each selected column's Java class to the {@code Type} constant named after its
 *       upper-cased simple class name.</li>
 * </ul>
 *
 * @param modelSchema the model schema.
 * @param dataSchema  the data schema.
 * @param params      the params.
 */
public ImputerModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
    super(modelSchema, dataSchema, params);
    String[] selectedColNames = ImputerModelDataConverter.extractSelectedColNames(modelSchema);
    TypeInformation[] selectedColTypes = ImputerModelDataConverter.extractSelectedColTypes(modelSchema);
    this.selectedColIndices = TableUtil.findColIndicesWithAssert(dataSchema, selectedColNames);

    String[] outputColNames = params.get(SrtPredictMapperParams.OUTPUT_COLS);
    if (outputColNames == null) {
        outputColNames = selectedColNames;
    }

    this.predictResultColsHelper = new OutputColsHelper(dataSchema, outputColNames, selectedColTypes, null);
    int length = selectedColTypes.length;
    this.type = new Type[length];
    for (int i = 0; i < length; i++) {
        // Locale.ROOT keeps the enum lookup locale-independent. Under a Turkish default locale,
        // "String".toUpperCase() yields "STRİNG" (dotted capital I), and valueOf would throw.
        String typeName = selectedColTypes[i].getTypeClass().getSimpleName().toUpperCase(java.util.Locale.ROOT);
        this.type[i] = Type.valueOf(typeName);
    }
}
 
Example 6
Source File: AFTModelMapper.java, from Alink (Apache License 2.0), 5 votes
/**
 * Constructor.
 *
 * <p>Reads the quantile probabilities from {@code params}. When a non-empty vector column name is
 * configured, it also resolves that column's index in the data schema. If {@code params} is null,
 * the quantile probabilities are set to null and the vector column index is not assigned here.
 *
 * @param modelSchema the model schema.
 * @param dataSchema  the data schema.
 * @param params      the params.
 */
public AFTModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
    super(modelSchema, dataSchema, params);
    // Every read from params is null-guarded. The original read QUANTILE_PROBABILITIES before
    // the check below, which made that check ineffective. The ternary keeps the field
    // definitely assigned on both paths.
    this.quantileProbabilities = (null != params)
        ? params.get(AftRegPredictParams.QUANTILE_PROBABILITIES)
        : null;
    if (null != params) {
        String vectorColName = params.get(LinearModelMapperParams.VECTOR_COL);
        if (null != vectorColName && vectorColName.length() != 0) {
            this.vectorColIndex = TableUtil.findColIndexWithAssert(dataSchema.getFieldNames(), vectorColName);
        }
    }
}
 
Example 7
Source File: SoftmaxModelMapper.java, from Alink (Apache License 2.0), 5 votes
public SoftmaxModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
	super(modelSchema, dataSchema, params);
	if (null != params) {
		String vectorColName = params.get(SoftmaxPredictParams.VECTOR_COL);
		if (null != vectorColName && vectorColName.length() != 0) {
			this.vectorColIndex = TableUtil.findColIndexWithAssert(dataSchema.getFieldNames(), vectorColName);
		}
	}
}
 
Example 8
Source File: ParamsTest.java, from Alink (Apache License 2.0), 5 votes
/**
 * Checks that an enum-typed parameter set on {@link Params} is reported as contained. It also
 * checks that {@link Params#get} returns the same constant, alongside an unrelated string
 * parameter.
 */
@Test
public void testContain4() {
    Params params = new Params()
        .set(HasEnumType.ENUM_TYPE, CalcType.aAA)
        .set(HasAppendType.APPEND_TYPE, "Dense");

    // Assert instead of printing, so that a regression actually fails the test.
    org.junit.Assert.assertTrue(params.contains(HasEnumType.ENUM_TYPE));
    CalcType type = params.get(HasEnumType.ENUM_TYPE);
    org.junit.Assert.assertEquals(CalcType.aAA, type);
}
 
Example 9
Source File: VectorStandardScalerModelDataConverter.java, from Alink (Apache License 2.0), 5 votes
/**
 * Deserialize the model data.
 *
 * <p>{@code data} must yield, in order, the JSON-encoded vector of means and then the
 * JSON-encoded vector of standard deviations. The with-mean and with-std flags are read from
 * {@code meta}. {@code additionData} is not used.
 *
 * @param meta         The model meta data.
 * @param data         The model concrete data.
 * @param additionData The additional data.
 * @return The model data used by mapper: (withMean, withStd, means, stdDevs).
 * @throws java.util.NoSuchElementException if {@code data} yields fewer than two elements.
 */
@Override
public Tuple4<Boolean, Boolean, double[], double[]> deserializeModel(Params meta, Iterable<String> data, Iterable<Row> additionData) {
    // Read both vectors from ONE iterator. Calling data.iterator() twice would, for a
    // re-iterable source such as a List, start over at the first element. The means would then
    // be decoded a second time as the standard deviations.
    java.util.Iterator<String> dataIterator = data.iterator();
    double[] means = JsonConverter.fromJson(dataIterator.next(), double[].class);
    double[] stdDevs = JsonConverter.fromJson(dataIterator.next(), double[].class);

    Boolean withMean = meta.get(VectorStandardTrainParams.WITH_MEAN);
    Boolean withStd = meta.get(VectorStandardTrainParams.WITH_STD);

    return Tuple4.of(withMean, withStd, means, stdDevs);
}
 
Example 10
Source File: MISOMapper.java, from Alink (Apache License 2.0), 5 votes
/**
 * Constructor for a multi-input, single-output mapper.
 *
 * <p>The constructor performs these steps:
 * <ul>
 *   <li>resolves the indices of the selected input columns in {@code dataSchema};</li>
 *   <li>reads the output column name;</li>
 *   <li>optionally reads the reserved columns to keep (null when not set);</li>
 *   <li>builds the output-column helper from the type returned by {@code initOutputColType()}.</li>
 * </ul>
 *
 * <p>NOTE(review): both {@code this.params} and the {@code params} argument are read here. They
 * are presumably the same object after {@code super(...)}; confirm before unifying them.
 *
 * @param dataSchema input table schema.
 * @param params     input parameters.
 */
public MISOMapper(TableSchema dataSchema, Params params) {
	super(dataSchema, params);
	final String[] inputColNames = this.params.get(MISOMapperParams.SELECTED_COLS);
	this.colIndices = TableUtil.findColIndicesWithAssertAndHint(dataSchema.getFieldNames(), inputColNames);

	final String outputColName = params.get(MISOMapperParams.OUTPUT_COL);
	final String[] keepColNames = this.params.contains(MISOMapperParams.RESERVED_COLS)
		? this.params.get(MISOMapperParams.RESERVED_COLS)
		: null;

	this.outputColsHelper = new OutputColsHelper(dataSchema, outputColName, initOutputColType(), keepColNames);
}
 
Example 11
Source File: Preprocessing.java, from Alink (Apache License 2.0), 5 votes
/**
 * Builds the label data set for a training job.
 *
 * <p>For regression, an empty {@code DataSet<Object[]>} is created on the input operator's ML
 * environment. For classification, the label column ({@code HasLabelCol.LABEL_COL}) is selected
 * from {@code input}. The first field of each row is extracted and the result is passed to
 * {@code distinctLabels}.
 *
 * @param input        the training data.
 * @param params       parameters holding the label column name.
 * @param isRegression whether the task is regression (no discrete labels).
 * @return the label data set.
 */
public static DataSet<Object[]> generateLabels(
	BatchOperator<?> input,
	Params params,
	boolean isRegression) {
	if (isRegression) {
		// A single seed element feeds a partition function that emits nothing: an empty data set.
		return MLEnvironmentFactory.get(input.getMLEnvironmentId()).getExecutionEnvironment().fromElements(1)
			.mapPartition(new MapPartitionFunction<Integer, Object[]>() {
				@Override
				public void mapPartition(Iterable<Integer> values, Collector<Object[]> out) throws Exception {
					// Intentionally emits no records.
				}
			});
	}

	final String labelColName = params.get(HasLabelCol.LABEL_COL);
	// Anonymous classes (not lambdas) keep Flink's type extraction working for the generic types.
	final DataSet<Object> rawLabels = select(input, labelColName).getDataSet()
		.map(new MapFunction<Row, Object>() {
			@Override
			public Object map(Row value) throws Exception {
				return value.getField(0);
			}
		});

	return distinctLabels(rawLabels);
}
 
Example 12
Source File: TableBucketingSink.java, from Alink (Apache License 2.0), 5 votes
/**
 * Creates a bucketing sink that writes rows of {@code schema} into tables prefixed with
 * {@code tableName} in {@code db}.
 *
 * <p>Rolling is controlled by a batch size and a rollover interval read from {@code params}. The
 * two are adjusted as follows:
 * <ul>
 *   <li>A positive batch size with a negative interval sets the interval to
 *       {@code Long.MAX_VALUE}.</li>
 *   <li>A negative batch size with a positive interval sets the batch size to
 *       {@code Integer.MAX_VALUE}.</li>
 * </ul>
 * Presumably this disables the unused trigger; confirm.
 *
 * @param tableName prefix for the target table names.
 * @param params    sink parameters (batch size and rollover interval).
 * @param schema    schema of the rows to write.
 * @param db        target database.
 */
public TableBucketingSink(String tableName, Params params, TableSchema schema, BaseDB db) {
	this.tableNamePrefix = tableName;
	this.types = schema.getFieldTypes();
	this.colNames = schema.getFieldNames();
	this.db = db;

	this.batchRolloverInterval = params.get(TableBucketingSinkParams.BATCH_ROLLOVER_INTERVAL);
	this.batchSize = params.get(TableBucketingSinkParams.BATCH_SIZE);

	// The two adjustments test opposite signs of batchSize, so at most one of them applies.
	if (batchSize > 0) {
		if (batchRolloverInterval < 0L) {
			batchRolloverInterval = Long.MAX_VALUE;
		}
	} else if (batchSize < 0 && batchRolloverInterval > 0L) {
		batchSize = Integer.MAX_VALUE;
	}
}
 
Example 13
Source File: Preprocessing.java, from Alink (Apache License 2.0), 5 votes
/**
 * Casts the weight column of {@code input} to DOUBLE, if a weight column is configured.
 *
 * @param input  the operator whose weight column should be cast.
 * @param params parameters holding {@code HasWeightColDefaultAsNull.WEIGHT_COL}.
 * @return {@code input} itself when the weight column is null. Otherwise, a new
 *         {@code NumericalTypeCastBatchOp} on the same ML environment, linked from {@code input}.
 */
public static BatchOperator<?> castWeightCol(
	BatchOperator<?> input,
	Params params) {
	final String weightCol = params.get(HasWeightColDefaultAsNull.WEIGHT_COL);
	if (weightCol != null) {
		return new NumericalTypeCastBatchOp()
			.setMLEnvironmentId(input.getMLEnvironmentId())
			.setSelectedCols(weightCol)
			.setTargetType("DOUBLE")
			.linkFrom(input);
	}
	return input;
}
 
Example 14
Source File: DocCountVectorizerTrainBatchOp.java, from Alink (Apache License 2.0), 4 votes
public BuildDocCountModel(Params params) {
    this.featureType = params.get(DocHashCountVectorizerTrainParams.FEATURE_TYPE).name();
    this.minTF = params.get(DocHashCountVectorizerTrainParams.MIN_TF);
}
 
Example 15
Source File: DocHashCountVectorizerTrainBatchOp.java, from Alink (Apache License 2.0), 4 votes
public BuildModel(Params params) {
    this.minDocFrequency = params.get(DocHashCountVectorizerTrainParams.MIN_DF);
    this.numFeatures = params.get(DocHashCountVectorizerTrainParams.NUM_FEATURES);
    this.featureType = params.get(DocHashCountVectorizerTrainParams.FEATURE_TYPE).name();
    this.minTF = params.get(DocHashCountVectorizerTrainParams.MIN_TF);
}
 
Example 16
Source File: FormatTransMapper.java, from Alink (Apache License 2.0), 4 votes
/**
 * Builds the {@link FormatReader} for the source format given by
 * {@code FormatTransParams.FROM_FORMAT}.
 *
 * @param dataSchema schema of the input table; used to resolve source column indices.
 * @param params     transformation parameters.
 * @return the reader paired with the names of the columns it produces. The names are non-null
 *         only for CSV (from the CSV schema string) and COLUMNS (the selected columns, or all
 *         input columns when none are selected).
 * @throws IllegalArgumentException if the source format is not handled.
 */
public static Tuple2<FormatReader, String[]> initFormatReader(TableSchema dataSchema, Params params) {
    final FormatType fromFormat = params.get(FormatTransParams.FROM_FORMAT);

    switch (fromFormat) {
        case KV: {
            final String kvColName = params.get(FromKvParams.KV_COL);
            final int kvColIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), kvColName);
            final FormatReader kvReader = new KvReader(
                kvColIndex,
                params.get(FromKvParams.KV_COL_DELIMITER),
                params.get(FromKvParams.KV_VAL_DELIMITER)
            );
            return new Tuple2<>(kvReader, null);
        }
        case CSV: {
            final String csvColName = params.get(FromCsvParams.CSV_COL);
            final int csvColIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), csvColName);
            final TableSchema csvSchema = CsvUtil.schemaStr2Schema(params.get(FromCsvParams.SCHEMA_STR));
            final FormatReader csvReader = new CsvReader(
                csvColIndex,
                csvSchema,
                params.get(FromCsvParams.CSV_FIELD_DELIMITER),
                params.get(FromCsvParams.QUOTE_CHAR)
            );
            return new Tuple2<>(csvReader, csvSchema.getFieldNames());
        }
        case VECTOR: {
            final String vectorColName = params.get(FromVectorParams.VECTOR_COL);
            final int vectorColIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(),
                vectorColName);
            // The vector schema is optional; without it the reader receives null.
            final TableSchema vectorSchema = params.contains(HasSchemaStr.SCHEMA_STR)
                ? CsvUtil.schemaStr2Schema(params.get(HasSchemaStr.SCHEMA_STR))
                : null;
            final FormatReader vectorReader = new VectorReader(vectorColIndex, vectorSchema);
            return new Tuple2<>(vectorReader, null);
        }
        case JSON: {
            final String jsonColName = params.get(FromJsonParams.JSON_COL);
            final int jsonColIndex = TableUtil.findColIndexWithAssertAndHint(dataSchema.getFieldNames(), jsonColName);
            final FormatReader jsonReader = new JsonReader(jsonColIndex);
            return new Tuple2<>(jsonReader, null);
        }
        case COLUMNS: {
            final String[] selectedCols = params.get(FromColumnsParams.SELECTED_COLS);
            final String[] columnNames = (selectedCols != null) ? selectedCols : dataSchema.getFieldNames();
            final int[] colIndices = TableUtil.findColIndicesWithAssertAndHint(dataSchema.getFieldNames(), columnNames);
            final FormatReader columnsReader = new ColumnsReader(colIndices, columnNames);
            return new Tuple2<>(columnsReader, columnNames);
        }
        default:
            throw new IllegalArgumentException("Can not translate this type : " + fromFormat);
    }
}
 
Example 17
Source File: StringToColumnsMappers.java, from Alink (Apache License 2.0), 4 votes
/**
 * Creates a key-value string parser for the given output fields.
 *
 * <p>The pair delimiter and key/value delimiter are read from {@code params}
 * ({@code KvToColumnsParams.COL_DELIMITER} and {@code KvToColumnsParams.VAL_DELIMITER}).
 */
@Override
protected StringParsers.StringParser getParser(String[] fieldNames, TypeInformation[] fieldTypes, Params params) {
    return new StringParsers.KvParser(
        fieldNames,
        fieldTypes,
        params.get(KvToColumnsParams.COL_DELIMITER),
        params.get(KvToColumnsParams.VAL_DELIMITER)
    );
}
 
Example 18
Source File: SOSImpl.java, from Alink (Apache License 2.0), 4 votes
/**
 * Creates the SOS implementation, taking its perplexity setting from {@code params}.
 *
 * @param params algorithm parameters; {@code SosParams.PERPLEXITY} is read.
 */
public SOSImpl(Params params) {
	this.perplexity = params.get(SosParams.PERPLEXITY);
}
 
Example 19
Source File: TreeModelDataConverter.java, from Alink (Apache License 2.0), 4 votes
/**
 * Restores this converter's state from serialized model lines.
 *
 * <p>The lines of {@code iterable} are materialized into a list and addressed by index:
 * <ul>
 *   <li>{@code STRING_INDEXER_MODEL_PARTITION} in {@code meta} gives the half-open line range
 *       [f0, f1) holding the string-indexer model. Each of those lines is a JSON array whose
 *       first three elements become a {@link Row} (the first one as a long). When f0 equals f1,
 *       {@code stringIndexerModelSerialized} is set to null.</li>
 *   <li>{@code TREE_PARTITIONS} in {@code meta} gives one line range per tree, and each range is
 *       decoded with {@code deserializeTree}.</li>
 * </ul>
 *
 * @param meta           model meta data holding the partitions.
 * @param iterable       serialized model lines.
 * @param distinctLabels label values, stored in {@code labels} in iteration order.
 * @return this converter, populated.
 */
@Override
protected TreeModelDataConverter deserializeModel(Params meta, Iterable<String> iterable, Iterable<Object> distinctLabels) {
	// Line range of the string-indexer model (categorical features).
	Partition stringIndexerModelPartition = meta.get(
		STRING_INDEXER_MODEL_PARTITION
	);

	// Materialize the lines so partitions can be addressed by index.
	List<String> data = new ArrayList<>();
	iterable.forEach(data::add);
	// NOTE(review): if getF0()/getF1() return Integer rather than int, this != compares
	// references. Confirm they are primitive.
	if (stringIndexerModelPartition.getF1() != stringIndexerModelPartition.getF0()) {
		stringIndexerModelSerialized = new ArrayList<>();

		for (int i = stringIndexerModelPartition.getF0(); i < stringIndexerModelPartition.getF1(); ++i) {
			Object[] deserialized = JsonConverter.fromJson(data.get(i), Object[].class);
			stringIndexerModelSerialized.add(
				Row.of(
					// Cast through Number rather than Integer. A JSON decoder may produce Integer,
					// Long or Double for an integral value, and longValue() accepts any of them.
					((Number) deserialized[0]).longValue(),
					deserialized[1],
					deserialized[2]
				)
			);
		}
	} else {
		stringIndexerModelSerialized = null;
	}

	// One line range per tree; decode each range into a root node.
	Partitions treesPartition = meta.get(
		TREE_PARTITIONS
	);

	roots = treesPartition.getPartitions().stream()
		.map(x -> deserializeTree(data.subList(x.getF0(), x.getF1())))
		.toArray(Node[]::new);

	this.meta = meta;
	List<Object> labelList = new ArrayList<>();
	distinctLabels.forEach(labelList::add);
	this.labels = labelList.toArray();

	return this;
}
 
Example 20
Source File: DBSinkStreamOp.java, from Alink (Apache License 2.0), 4 votes
/**
 * Creates a stream sink that writes into a table of {@code db}.
 *
 * <p>The operator is named from {@code db}'s class annotation. Its parameters are a clone of
 * {@code db}'s parameters merged with {@code parameter}. The target table name is read from
 * {@code parameter} (not from the merged parameters) under the table-alias key derived from
 * {@code db}'s class.
 *
 * @param db        the target database.
 * @param parameter sink parameters, including the table name.
 */
public DBSinkStreamOp(BaseDB db, Params parameter) {
	super(AnnotationUtils.annotatedName(db.getClass()), db.getParams().clone().merge(parameter));

	this.tableName = parameter.get(AnnotationUtils.tableAliasParamKey(db.getClass()));
	this.db = db;
}