Java Code Examples for org.apache.flink.ml.api.misc.param.Params#contains()

The following examples show how to use org.apache.flink.ml.api.misc.param.Params#contains(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also check out related API usage in the sidebar.
Example 1
Source File: MySqlDB.java    From Alink with Apache License 2.0 6 votes vote down vote up
/**
 * Creates a stream table for the given table name.
 *
 * <p>When {@code params} carries no {@code MySqlSourceParams.SCHEMA_STR}, the call is forwarded to
 * the superclass. Otherwise the schema is parsed from that parameter and the rows are read through
 * a JDBC input format running {@code select * from <tableName>}.
 *
 * @param tableName name of the table to read.
 * @param params    parameters, optionally containing the schema string.
 * @param sessionId id used to look up the ML environment.
 * @return the table built on the stream execution environment.
 * @throws Exception propagated from the superclass call or from building the input.
 */
@Override
public Table getStreamTable(String tableName, Params params, Long sessionId) throws Exception {
	// Without an explicit schema there is nothing to customise: defer to the parent.
	if (!params.contains(MySqlSourceParams.SCHEMA_STR)) {
		return super.getStreamTable(tableName, params, sessionId);
	}

	final TableSchema tableSchema = CsvUtil.schemaStr2Schema(params.get(MySqlSourceParams.SCHEMA_STR));

	// NOTE(review): tableName is concatenated into the SQL text as-is — confirm that callers
	// only pass trusted table names.
	final String selectAll = "select * from " + tableName;

	final JDBCInputFormat jdbcInput = JDBCInputFormat.buildJDBCInputFormat()
		.setUsername(getUserName())
		.setPassword(getPassword())
		.setDrivername(getDriverName())
		.setDBUrl(getDbUrl())
		.setQuery(selectAll)
		.setRowTypeInfo(new RowTypeInfo(tableSchema.getFieldTypes(), tableSchema.getFieldNames()))
		.finish();

	return DataStreamConversionUtil.toTable(
		sessionId,
		MLEnvironmentFactory.get(sessionId).getStreamExecutionEnvironment().createInput(jdbcInput),
		tableSchema.getFieldNames(),
		tableSchema.getFieldTypes());
}
 
Example 2
Source File: MySqlDB.java    From Alink with Apache License 2.0 6 votes vote down vote up
/**
 * Creates a batch table for the given table name.
 *
 * <p>When {@code params} carries no {@code MySqlSourceParams.SCHEMA_STR}, the call is forwarded to
 * the superclass. Otherwise the schema is parsed from that parameter and the rows are read through
 * a JDBC input format running {@code select * from <tableName>}.
 *
 * @param tableName name of the table to read.
 * @param params    parameters, optionally containing the schema string.
 * @param sessionId id used to look up the ML environment.
 * @return the table built on the batch execution environment.
 * @throws Exception propagated from the superclass call or from building the input.
 */
@Override
public Table getBatchTable(String tableName, Params params, Long sessionId) throws Exception {
	// Without an explicit schema there is nothing to customise: defer to the parent.
	if (!params.contains(MySqlSourceParams.SCHEMA_STR)) {
		return super.getBatchTable(tableName, params, sessionId);
	}

	final TableSchema tableSchema = CsvUtil.schemaStr2Schema(params.get(MySqlSourceParams.SCHEMA_STR));

	// NOTE(review): tableName is concatenated into the SQL text as-is — confirm that callers
	// only pass trusted table names.
	final String selectAll = "select * from " + tableName;

	final JDBCInputFormat jdbcInput = JDBCInputFormat.buildJDBCInputFormat()
		.setUsername(getUserName())
		.setPassword(getPassword())
		.setDrivername(getDriverName())
		.setDBUrl(getDbUrl())
		.setQuery(selectAll)
		.setRowTypeInfo(new RowTypeInfo(tableSchema.getFieldTypes(), tableSchema.getFieldNames()))
		.finish();

	return DataSetConversionUtil.toTable(
		sessionId,
		MLEnvironmentFactory.get(sessionId).getExecutionEnvironment().createInput(jdbcInput),
		tableSchema.getFieldNames(),
		tableSchema.getFieldTypes());
}
 
Example 3
Source File: LinearModelDataConverter.java    From Alink with Apache License 2.0 6 votes vote down vote up
/**
 * Rebuilds a {@link LinearModelData} from its serialized meta and data parts.
 *
 * @param meta           The model meta data.
 * @param data           The model concrete data; only its first element is read.
 * @param distinctLabels All the label values in the data set. When non-null, they replace any
 *                       label values recovered from the meta.
 * @return The deserialized model data.
 */
@Override
public LinearModelData deserializeModel(Params meta, Iterable<String> data, Iterable<Object> distinctLabels) {
    LinearModelData result = new LinearModelData();

    if (meta.contains(ModelParamName.LABEL_VALUES)) {
        result.labelValues = FeatureLabelUtil.recoverLabelType(
            meta.get(ModelParamName.LABEL_VALUES), this.labelType);
    }
    setMetaInfo(meta, result);

    // Explicitly supplied labels take precedence over the ones stored in the meta.
    if (distinctLabels != null) {
        List<Object> collected = new ArrayList<>();
        for (Object label : distinctLabels) {
            collected.add(label);
        }
        result.labelValues = collected.toArray();
    }

    String serialized = data.iterator().next();
    setModelData(JsonConverter.fromJson(serialized, ModelData.class), result);
    return result;
}
 
Example 4
Source File: CorrelationDataConverter.java    From Alink with Apache License 2.0 6 votes vote down vote up
/**
 * Deserialize the model from "Params meta" and "Iterable<String> data".
 *
 * <p>Each string in {@code data} is parsed as one dense vector and copied into the matching row of
 * a square matrix whose size is taken from the first vector. The selected column names are read
 * from the meta when present.
 *
 * @param meta the model meta data.
 * @param data the serialized rows of the matrix.
 * @return the correlation result; its matrix is {@code null} when {@code data} is empty.
 */
@Override
public CorrelationResult deserializeModel(Params meta, Iterable<String> data) {
    String[] selectedCols = meta.contains(CorrelationParams.SELECTED_COLS)
        ? meta.get(CorrelationParams.SELECTED_COLS)
        : null;

    DenseMatrix corr = null;
    int row = 0;
    for (String line : data) {
        DenseVector rowVec = (DenseVector) VectorUtil.getVector(line);
        int dim = rowVec.size();
        // The matrix is allocated lazily because its size is only known from the first row.
        if (corr == null) {
            corr = new DenseMatrix(dim, dim);
        }
        for (int col = 0; col < dim; col++) {
            corr.set(row, col, rowVec.get(col));
        }
        row++;
    }

    return new CorrelationResult(corr, selectedCols);
}
 
Example 5
Source File: OneVsRestModelMapper.java    From Alink with Apache License 2.0 6 votes vote down vote up
/**
 * Creates the mapper and prepares its output columns.
 *
 * <p>The label type is taken from the last field of the model schema. A detail column is added to
 * the output only when {@code OneVsRestPredictParams.PREDICTION_DETAIL_COL} is present in
 * {@code params}. A cloned parameter set with fixed column names is kept in
 * {@code binClsPredParams}.
 *
 * @param modelSchema schema of the model table.
 * @param dataSchema  schema of the data to predict on.
 * @param params      prediction parameters.
 */
public OneVsRestModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
    super(modelSchema, dataSchema, params);

    final String predictionCol = params.get(OneVsRestPredictParams.PREDICTION_COL);
    final String[] reservedCols = params.get(OneVsRestPredictParams.RESERVED_COLS);
    this.predDetail = params.contains(OneVsRestPredictParams.PREDICTION_DETAIL_COL);

    final int lastModelCol = modelSchema.getFieldNames().length - 1;
    final TypeInformation labelType = modelSchema.getFieldTypes()[lastModelCol];

    if (!predDetail) {
        outputColsHelper = new OutputColsHelper(dataSchema, predictionCol, labelType, reservedCols);
    } else {
        final String detailCol = params.get(OneVsRestPredictParams.PREDICTION_DETAIL_COL);
        final String[] outputNames = new String[]{predictionCol, detailCol};
        final TypeInformation[] outputTypes = new TypeInformation[]{labelType, Types.STRING};
        outputColsHelper = new OutputColsHelper(dataSchema, outputNames, outputTypes, reservedCols);
    }

    // Work on a copy so the caller's params are not modified by the overrides below.
    this.binClsPredParams = params.clone();
    this.binClsPredParams.set(OneVsRestPredictParams.RESERVED_COLS, new String[0]);
    this.binClsPredParams.set(OneVsRestPredictParams.PREDICTION_COL, "pred_result");
    this.binClsPredParams.set(OneVsRestPredictParams.PREDICTION_DETAIL_COL, "pred_detail");
}
 
Example 6
Source File: KMeansModelMapper.java    From Alink with Apache License 2.0 6 votes vote down vote up
/**
 * Creates the mapper and prepares its output columns.
 *
 * <p>The prediction column (type LONG) is always produced. A detail column (STRING) and a distance
 * column (DOUBLE) are appended only when the corresponding parameter is present in {@code params}.
 *
 * @param modelSchema schema of the model table.
 * @param dataSchema  schema of the data to predict on.
 * @param params      prediction parameters.
 */
public KMeansModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
    super(modelSchema, dataSchema, params);

    final String[] keepCols = this.params.get(KMeansPredictParams.RESERVED_COLS);
    final String predictionCol = this.params.get(KMeansPredictParams.PREDICTION_COL);
    isPredDetail = params.contains(KMeansPredictParams.PREDICTION_DETAIL_COL);
    isPredDistance = params.contains(KMeansPredictParams.PREDICTION_DISTANCE_COL);

    // Names and types are appended in lockstep so that index i of one matches index i of the other.
    final List<String> names = new ArrayList<>();
    final List<TypeInformation> types = new ArrayList<>();

    names.add(predictionCol);
    types.add(Types.LONG);

    if (isPredDetail) {
        names.add(params.get(KMeansPredictParams.PREDICTION_DETAIL_COL));
        types.add(Types.STRING);
    }

    if (isPredDistance) {
        names.add(params.get(KMeansPredictParams.PREDICTION_DISTANCE_COL));
        types.add(Types.DOUBLE);
    }

    final String[] nameArray = names.toArray(new String[0]);
    final TypeInformation[] typeArray = types.toArray(new TypeInformation[0]);
    this.outputColsHelper = new OutputColsHelper(dataSchema, nameArray, typeArray, keepCols);
}
 
Example 7
Source File: Preprocessing.java    From Alink with Apache License 2.0 5 votes vote down vote up
/**
 * Prepares the label column of {@code input} for training.
 *
 * <p>For non-regression tasks the data set is passed through {@code findIndexOfLabel} together
 * with {@code labels} (presumably replacing each label by its index — confirm against that
 * helper), and the label column is declared as {@code Types.INT} in the resulting operator.
 * For regression tasks the label column, when one is configured, is cast to DOUBLE.
 *
 * @param input        the operator providing the training data.
 * @param params       parameters holding the label column name.
 * @param labels       the label values handed to {@code findIndexOfLabel}.
 * @param isRegression whether the task is a regression task.
 * @return the operator with the label column converted, or {@code input} itself when it is a
 *         regression task without a configured label column.
 */
public static BatchOperator<?> castLabel(
	BatchOperator<?> input, Params params, DataSet<Object[]> labels, boolean isRegression) {
	String[] inputColNames = input.getColNames();
	if (!isRegression) {
		final String labelColName = params.get(HasLabelCol.LABEL_COL);
		final TypeInformation<?>[] types = input.getColTypes();
		// Resolve the label position once instead of once per column inside the stream below.
		final int labelColIndex = TableUtil.findColIndex(inputColNames, labelColName);
		input = new DataSetWrapperBatchOp(
			findIndexOfLabel(input.getDataSet(), labels, labelColIndex),
			inputColNames,
			IntStream.range(0, types.length)
				.mapToObj(x -> x == labelColIndex ? Types.INT : types[x])
				.toArray(TypeInformation[]::new)
		).setMLEnvironmentId(input.getMLEnvironmentId());
	} else {
		if (params.contains(HasLabelCol.LABEL_COL)) {
			input = new NumericalTypeCastBatchOp()
				.setMLEnvironmentId(input.getMLEnvironmentId())
				.setSelectedCols(params.get(HasLabelCol.LABEL_COL))
				.setTargetType("DOUBLE")
				.linkFrom(input);
		}
	}

	return input;
}
 
Example 8
Source File: Preprocessing.java    From Alink with Apache License 2.0 5 votes vote down vote up
/**
 * Builds the string-indexer model for the categorical columns named in {@code params}.
 *
 * <p>When categorical columns are configured, a {@code MultiStringIndexerTrainBatchOp} is trained
 * on {@code input}. Otherwise an operator with the indexer model schema and no rows is returned.
 *
 * @param input  the operator providing the training data.
 * @param params parameters, optionally containing the categorical column names.
 * @return the operator holding the string-indexer model.
 */
public static BatchOperator<?> generateStringIndexerModel(BatchOperator<?> input, Params params) {
	final String[] categoricalCols = params.contains(HasCategoricalCols.CATEGORICAL_COLS)
		? params.get(HasCategoricalCols.CATEGORICAL_COLS)
		: null;

	final boolean hasCategoricalCols = categoricalCols != null && categoricalCols.length != 0;
	if (hasCategoricalCols) {
		return new MultiStringIndexerTrainBatchOp()
			.setMLEnvironmentId(input.getMLEnvironmentId())
			.setSelectedCols(categoricalCols)
			.linkFrom(input);
	}

	// Nothing to index: produce a data set that emits no rows but carries the model schema.
	final MultiStringIndexerModelDataConverter schemaProvider = new MultiStringIndexerModelDataConverter();
	final DataSet<Row> noRows = MLEnvironmentFactory
		.get(input.getMLEnvironmentId())
		.getExecutionEnvironment()
		.fromElements(1)
		.mapPartition(new MapPartitionFunction<Integer, Row>() {
			@Override
			public void mapPartition(Iterable<Integer> values, Collector<Row> out) throws Exception {
				// Intentionally emits nothing.
			}
		});

	return new DataSetWrapperBatchOp(
		noRows,
		schemaProvider.getModelSchema().getFieldNames(),
		schemaProvider.getModelSchema().getFieldTypes()
	).setMLEnvironmentId(input.getMLEnvironmentId());
}
 
Example 9
Source File: TreeUtil.java    From Alink with Apache License 2.0 5 votes vote down vote up
/**
 * Collects the names of the columns used for training.
 *
 * <p>The result starts with the feature columns, followed by the label column when it is present
 * in {@code params}, followed by the weight column when its value is not {@code null}.
 *
 * @param params the training parameters.
 * @return the column names in the order described above.
 */
public static String[] trainColNames(Params params) {
	final String[] featureCols = params.get(HasFeatureCols.FEATURE_COLS);
	final ArrayList<String> columns = new ArrayList<>(Arrays.asList(featureCols));

	if (params.contains(HasLabelCol.LABEL_COL)) {
		columns.add(params.get(HasLabelCol.LABEL_COL));
	}

	final String weightCol = params.get(HasWeightColDefaultAsNull.WEIGHT_COL);
	if (weightCol != null) {
		columns.add(weightCol);
	}

	return columns.toArray(new String[0]);
}
 
Example 10
Source File: BaseSourceBatchOp.java    From Alink with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a batch source operator from the given parameters.
 *
 * <p>The parameters must contain an IO type equal to {@code IO_TYPE} and an IO name. A database
 * description yields a {@code DBSourceBatchOp}; anything else is created by name through
 * {@code AnnotationUtils.createOp}.
 *
 * @param params parameters describing the source.
 * @return the created source operator.
 * @throws Exception propagated from the operator creation.
 * @throws RuntimeException with message "Parameter Error." when the parameters do not describe
 *                          a source of this IO type.
 */
public static BaseSourceBatchOp of(Params params) throws Exception {
    final boolean describesThisIo = params.contains(HasIoType.IO_TYPE)
        && params.get(HasIoType.IO_TYPE).equals(IO_TYPE)
        && params.contains(HasIoName.IO_NAME);

    if (describesThisIo) {
        if (BaseDB.isDB(params)) {
            return new DBSourceBatchOp(BaseDB.of(params), params);
        }
        if (params.contains(HasIoName.IO_NAME)) {
            final String ioName = params.get(HasIoName.IO_NAME);
            return (BaseSourceBatchOp) AnnotationUtils.createOp(ioName, IO_TYPE, params);
        }
    }

    throw new RuntimeException("Parameter Error.");
}
 
Example 11
Source File: BaseSinkBatchOp.java    From Alink with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a batch sink operator from the given parameters.
 *
 * <p>The parameters must contain an IO type equal to {@code IO_TYPE} and an IO name. A database
 * description yields a {@code DBSinkBatchOp}; anything else is created by name through
 * {@code AnnotationUtils.createOp}.
 *
 * @param params parameters describing the sink.
 * @return the created sink operator.
 * @throws Exception propagated from the operator creation.
 * @throws RuntimeException with message "Parameter Error." when the parameters do not describe
 *                          a sink of this IO type.
 */
public static BaseSinkBatchOp of(Params params) throws Exception {
	final boolean describesThisIo = params.contains(HasIoType.IO_TYPE)
		&& params.get(HasIoType.IO_TYPE).equals(IO_TYPE)
		&& params.contains(HasIoName.IO_NAME);

	if (describesThisIo) {
		if (BaseDB.isDB(params)) {
			return new DBSinkBatchOp(BaseDB.of(params), params);
		}
		if (params.contains(HasIoName.IO_NAME)) {
			final String ioName = params.get(HasIoName.IO_NAME);
			return (BaseSinkBatchOp) AnnotationUtils.createOp(ioName, IO_TYPE, params);
		}
	}

	throw new RuntimeException("Parameter Error.");
}
 
Example 12
Source File: BaseLinearModelTrainBatchOp.java    From Alink with Apache License 2.0 5 votes vote down vote up
/**
 * Optimizes a linear problem and returns its coefficients.
 *
 * <p>The coefficient dimension is derived from the broadcast vector size when a vector column is
 * configured, and from the number of feature columns otherwise (plus one for the intercept and
 * plus one for the AFT model type where applicable). The optimizer is the configured
 * {@code OPTIM_METHOD} when present, else OWLQN when the L1 parameter is positive, else LBFGS.
 *
 * @param params     parameters needed by the optimizer.
 * @param vectorSize vector size.
 * @param trainData  train data.
 * @param modelType  linear model type.
 * @param session    machine learning environment.
 * @return coefficient of linear problem.
 * @throws IllegalArgumentException if neither a vector column nor any feature column is set.
 */
public static DataSet<Tuple2<DenseVector, double[]>> optimize(Params params,
                                                              DataSet<Integer> vectorSize,
                                                              DataSet<Tuple3<Double, Double, Vector>> trainData,
                                                              LinearModelType modelType,
                                                              MLEnvironment session) {
    boolean hasInterceptItem = params.get(LinearTrainParams.WITH_INTERCEPT);
    String[] featureColNames = params.get(LinearTrainParams.FEATURE_COLS);
    String vectorColName = params.get(LinearTrainParams.VECTOR_COL);
    // An empty vector column name is treated the same as an absent one.
    boolean useVectorCol = vectorColName != null && !vectorColName.isEmpty();

    DataSet<Integer> coefficientDim;

    if (useVectorCol) {
        coefficientDim = session.getExecutionEnvironment().fromElements(0)
            .map(new DimTrans(hasInterceptItem, modelType))
            .withBroadcastSet(vectorSize, VECTOR_SIZE);
    } else {
        // Fail with a clear message instead of a bare NullPointerException on featureColNames.
        if (featureColNames == null || featureColNames.length == 0) {
            throw new IllegalArgumentException(
                "Either a vector column or at least one feature column must be set.");
        }
        coefficientDim = session.getExecutionEnvironment().fromElements(featureColNames.length
            + (hasInterceptItem ? 1 : 0) + (modelType.equals(LinearModelType.AFT) ? 1 : 0));
    }

    // Loss object function
    DataSet<OptimObjFunc> objFunc = session.getExecutionEnvironment()
        .fromElements(getObjFunction(modelType, params));

    if (params.contains(LinearTrainParams.OPTIM_METHOD)) {
        LinearTrainParams.OptimMethod method = params.get(LinearTrainParams.OPTIM_METHOD);
        return OptimizerFactory.create(objFunc, trainData, coefficientDim, params, method).optimize();
    } else if (params.get(HasL1.L_1) > 0) {
        return new Owlqn(objFunc, trainData, coefficientDim, params).optimize();
    } else {
        return new Lbfgs(objFunc, trainData, coefficientDim, params).optimize();
    }
}
 
Example 13
Source File: LinearModelData.java    From Alink with Apache License 2.0 5 votes vote down vote up
/**
 * Creates the model data from its parts.
 *
 * <p>Label values are recovered from the meta only when it contains
 * {@code ModelParamName.LABEL_VALUES}; the remaining meta fields are applied through
 * {@code setMetaInfo}.
 *
 * @param labelType    label type.
 * @param meta         meta information of model.
 * @param featureNames the feature column names.
 * @param coefVector   the coefficient vector of the model.
 */
public LinearModelData(TypeInformation labelType, Params meta, String[] featureNames, DenseVector coefVector) {
	this.featureNames = featureNames;
	this.coefVector = coefVector;
	// The label type must be stored before the label values are recovered, which read it back.
	this.labelType = labelType;

	if (meta.contains(ModelParamName.LABEL_VALUES)) {
		this.labelValues = FeatureLabelUtil.recoverLabelType(
			meta.get(ModelParamName.LABEL_VALUES), this.labelType);
	}

	setMetaInfo(meta);
}
 
Example 14
Source File: LinearModelData.java    From Alink with Apache License 2.0 5 votes vote down vote up
/**
 * Copies the meta information of a model into this object.
 *
 * <p>The model name is always read. Every other field is read only when the meta contains it and
 * otherwise falls back to a default: {@code null} for the linear model type and the vector column
 * name, {@code true} for the intercept flag and {@code 0} for the vector size.
 *
 * @param meta meta information of model.
 */
public void setMetaInfo(Params meta) {
	this.modelName = meta.get(ModelParamName.MODEL_NAME);
	this.linearModelType = meta.contains(ModelParamName.LINEAR_MODEL_TYPE)
		? meta.get(ModelParamName.LINEAR_MODEL_TYPE) : null;
	// A missing intercept flag is treated as "has intercept".
	this.hasInterceptItem = meta.contains(ModelParamName.HAS_INTERCEPT_ITEM) ? meta.get(
		ModelParamName.HAS_INTERCEPT_ITEM) : true;
	this.vectorSize = meta.contains(ModelParamName.VECTOR_SIZE) ? meta.get(ModelParamName.VECTOR_SIZE) : 0;
	this.vectorColName = meta.contains(HasVectorCol.VECTOR_COL) ? meta.get(HasVectorCol.VECTOR_COL) : null;
}
 
Example 15
Source File: LinearModelData.java    From Alink with Apache License 2.0 5 votes vote down vote up
/**
 * Reads the label type and label values from the given meta.
 *
 * <p>As a side effect, {@code labelType} of this object is set from
 * {@code ModelParamName.LABEL_TYPE_NAME}.
 *
 * @param meta meta information of model.
 * @return the recovered label values, or an empty list when the meta contains none.
 */
private List<Object> recoverLabelsFromOldFormatModel(Params meta) {
	this.labelType = FlinkTypeConverter.getFlinkType(meta.get(ModelParamName.LABEL_TYPE_NAME));

	if (!meta.contains(ModelParamName.LABEL_VALUES)) {
		return new ArrayList<>();
	}

	Object[] recovered = FeatureLabelUtil.recoverLabelType(meta.get(ModelParamName.LABEL_VALUES), labelType);
	return Arrays.asList(recovered);
}
 
Example 16
Source File: BaseSourceStreamOp.java    From Alink with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a stream source operator from the given parameters.
 *
 * <p>The parameters must contain an IO type equal to {@code IO_TYPE} and an IO name. A database
 * description yields a {@code DBSourceStreamOp}; anything else is created by name through
 * {@code AnnotationUtils.createOp}.
 *
 * @param params parameters describing the source.
 * @return the created source operator.
 * @throws Exception propagated from the operator creation.
 * @throws RuntimeException with message "Parameter Error." when the parameters do not describe
 *                          a source of this IO type.
 */
public static BaseSourceStreamOp of(Params params) throws Exception {
	final boolean describesThisIo = params.contains(HasIoType.IO_TYPE)
		&& params.get(HasIoType.IO_TYPE).equals(IO_TYPE)
		&& params.contains(HasIoName.IO_NAME);

	if (describesThisIo) {
		if (BaseDB.isDB(params)) {
			return new DBSourceStreamOp(BaseDB.of(params), params);
		}
		if (params.contains(HasIoName.IO_NAME)) {
			final String ioName = params.get(HasIoName.IO_NAME);
			return (BaseSourceStreamOp) AnnotationUtils.createOp(ioName, IO_TYPE, params);
		}
	}

	throw new RuntimeException("Parameter Error.");
}
 
Example 17
Source File: LinearModelDataConverter.java    From Alink with Apache License 2.0 5 votes vote down vote up
/**
 * Set the meta information into the linear model data.
 *
 * <p>The model name is always read. Every other field is read only when the meta contains it and
 * otherwise falls back to a default: {@code null} for the linear model type, the vector column
 * name and the label name, {@code true} for the intercept flag and {@code 0} for the vector size.
 *
 * @param meta The model meta data.
 * @param data The linear model data to fill in.
 */
private void setMetaInfo(Params meta, LinearModelData data) {
    data.modelName = meta.get(ModelParamName.MODEL_NAME);
    data.linearModelType
        = meta.contains(ModelParamName.LINEAR_MODEL_TYPE) ? meta.get(ModelParamName.LINEAR_MODEL_TYPE) : null;
    // A missing intercept flag is treated as "has intercept".
    data.hasInterceptItem
        = meta.contains(ModelParamName.HAS_INTERCEPT_ITEM) ? meta.get(ModelParamName.HAS_INTERCEPT_ITEM) : true;
    data.vectorSize = meta.contains(ModelParamName.VECTOR_SIZE) ? meta.get(ModelParamName.VECTOR_SIZE) : 0;
    data.vectorColName = meta.contains(HasVectorCol.VECTOR_COL) ? meta.get(HasVectorCol.VECTOR_COL) : null;
    data.labelName = meta.contains(HasLabelCol.LABEL_COL) ? meta.get(HasLabelCol.LABEL_COL) : null;
}
 
Example 18
Source File: BaseSinkStreamOp.java    From Alink with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a stream sink operator from the given parameters.
 *
 * <p>The parameters must contain an IO type equal to {@code IO_TYPE} and an IO name. A database
 * description yields a {@code DBSinkStreamOp}; anything else is created by name through
 * {@code AnnotationUtils.createOp}.
 *
 * @param params parameters describing the sink.
 * @return the created sink operator.
 * @throws Exception propagated from the operator creation.
 * @throws RuntimeException with message "Parameter Error." when the parameters do not describe
 *                          a sink of this IO type.
 */
public static BaseSinkStreamOp of(Params params) throws Exception {
	final boolean describesThisIo = params.contains(HasIoType.IO_TYPE)
		&& params.get(HasIoType.IO_TYPE).equals(IO_TYPE)
		&& params.contains(HasIoName.IO_NAME);

	if (describesThisIo) {
		if (BaseDB.isDB(params)) {
			return new DBSinkStreamOp(BaseDB.of(params), params);
		}
		if (params.contains(HasIoName.IO_NAME)) {
			final String ioName = params.get(HasIoName.IO_NAME);
			return (BaseSinkStreamOp) AnnotationUtils.createOp(ioName, IO_TYPE, params);
		}
	}

	throw new RuntimeException("Parameter Error.");
}
 
Example 19
Source File: FormatTransMapper.java    From Alink with Apache License 2.0 4 votes vote down vote up
/**
 * Creates the reader matching the source format configured in {@code params}.
 *
 * <p>The second tuple field holds the source column names, which are only known for the CSV format
 * (taken from its schema string) and the COLUMNS format (the selected columns, or all columns of
 * the data schema when none are selected). For every other format it is {@code null}.
 *
 * @param dataSchema schema of the input data.
 * @param params     parameters holding the source format and its format-specific settings.
 * @return the reader together with the source column names (possibly {@code null}).
 * @throws IllegalArgumentException if the source format is not handled.
 */
public static Tuple2<FormatReader, String[]> initFormatReader(TableSchema dataSchema, Params params) {
    final FormatType fromFormat = params.get(FormatTransParams.FROM_FORMAT);
    final String[] dataColNames = dataSchema.getFieldNames();

    FormatReader reader;
    // Only CSV and COLUMNS know their column names; every other format leaves this as null.
    String[] fromColNames = null;

    switch (fromFormat) {
        case KV: {
            String kvCol = params.get(FromKvParams.KV_COL);
            int kvIdx = TableUtil.findColIndexWithAssertAndHint(dataColNames, kvCol);
            reader = new KvReader(
                kvIdx,
                params.get(FromKvParams.KV_COL_DELIMITER),
                params.get(FromKvParams.KV_VAL_DELIMITER)
            );
            break;
        }
        case CSV: {
            String csvCol = params.get(FromCsvParams.CSV_COL);
            int csvIdx = TableUtil.findColIndexWithAssertAndHint(dataColNames, csvCol);
            TableSchema csvSchema = CsvUtil.schemaStr2Schema(params.get(FromCsvParams.SCHEMA_STR));
            reader = new CsvReader(
                csvIdx,
                csvSchema,
                params.get(FromCsvParams.CSV_FIELD_DELIMITER),
                params.get(FromCsvParams.QUOTE_CHAR)
            );
            fromColNames = csvSchema.getFieldNames();
            break;
        }
        case VECTOR: {
            String vectorCol = params.get(FromVectorParams.VECTOR_COL);
            int vectorIdx = TableUtil.findColIndexWithAssertAndHint(dataColNames, vectorCol);
            // The schema is optional for vectors; the reader receives null when it is absent.
            TableSchema vectorSchema = params.contains(HasSchemaStr.SCHEMA_STR)
                ? CsvUtil.schemaStr2Schema(params.get(HasSchemaStr.SCHEMA_STR))
                : null;
            reader = new VectorReader(vectorIdx, vectorSchema);
            break;
        }
        case JSON: {
            String jsonCol = params.get(FromJsonParams.JSON_COL);
            int jsonIdx = TableUtil.findColIndexWithAssertAndHint(dataColNames, jsonCol);
            reader = new JsonReader(jsonIdx);
            break;
        }
        case COLUMNS: {
            fromColNames = params.get(FromColumnsParams.SELECTED_COLS);
            if (fromColNames == null) {
                fromColNames = dataSchema.getFieldNames();
            }
            int[] selectedIdx = TableUtil.findColIndicesWithAssertAndHint(dataColNames, fromColNames);
            reader = new ColumnsReader(selectedIdx, fromColNames);
            break;
        }
        default:
            throw new IllegalArgumentException("Can not translate this type : " + fromFormat);
    }

    return new Tuple2<>(reader, fromColNames);
}
 
Example 20
Source File: StringIndexerModel.java    From Alink with Apache License 2.0 4 votes vote down vote up
/**
 * Creates the model and, when {@code params} contains {@code StringIndexer.MODEL_NAME},
 * registers it under that name.
 *
 * @param params the model parameters.
 */
public StringIndexerModel(Params params) {
    super(StringIndexerModelMapper::new, params);

    // Unnamed models are not registered.
    if (!params.contains(StringIndexer.MODEL_NAME)) {
        return;
    }
    registerModel(params.get(StringIndexer.MODEL_NAME), this);
}