/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.param.{ParamMap, Param}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util._
import org.apache.spark.sql.{SQLContext, DataFrame, Row}
import org.apache.spark.sql.types.StructType

/**
 * :: Experimental ::
 * Implements the transformations which are defined by a SQL statement.
 * Currently only SQL syntax like 'SELECT ... FROM __THIS__ ...' is supported,
 * where '__THIS__' represents the underlying table of the input dataset.
 * The select clause specifies the fields, constants, and expressions to display in
 * the output; it can be any select clause that Spark SQL supports. Users can also
 * use Spark SQL built-in functions and UDFs to operate on the selected columns.
 * For example, [[SQLTransformer]] supports statements like:
 *  - SELECT a, a + b AS a_b FROM __THIS__
 *  - SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5
 *  - SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b
 *
 * @param uid unique identifier of this transformer instance
 */
@Experimental
@Since("1.6.0")
class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer
  with DefaultParamsWritable {

  /** Creates a transformer with a randomly generated uid prefixed by "sql". */
  @Since("1.6.0")
  def this() = this(Identifiable.randomUID("sql"))

  /**
   * SQL statement parameter. The statement is provided in string form and refers to the
   * input dataset through the '__THIS__' placeholder.
   * @group param
   */
  @Since("1.6.0")
  final val statement: Param[String] = new Param[String](this, "statement", "SQL statement")

  /** @group setParam */
  @Since("1.6.0")
  def setStatement(value: String): this.type = set(statement, value)

  /** @group getParam */
  @Since("1.6.0")
  def getStatement: String = $(statement)

  /** Placeholder that stands for the input dataset's table inside the statement. */
  private val tableIdentifier: String = "__THIS__"

  /**
   * Runs the configured statement against `dataset`.
   *
   * The dataset is registered as a temporary table under a freshly generated name, every
   * occurrence of the '__THIS__' placeholder in the statement is replaced by that name, and
   * the resulting statement is executed on the dataset's own SQLContext.
   *
   * @param dataset input DataFrame exposed to the statement as '__THIS__'
   * @return the DataFrame produced by the statement
   */
  @Since("1.6.0")
  override def transform(dataset: DataFrame): DataFrame = {
    val tableName = Identifiable.randomUID(uid)
    dataset.registerTempTable(tableName)
    val realStatement = $(statement).replace(tableIdentifier, tableName)
    // NOTE(review): the temporary table is intentionally left registered here. Per the
    // SQLContext API docs, dropTempTable also unpersists a cached table, which could uncache
    // the caller's dataset -- confirm before adding cleanup to this method.
    dataset.sqlContext.sql(realStatement)
  }

  /**
   * Derives the output schema of the configured statement for the given input schema.
   *
   * A placeholder DataFrame carrying `schema` is registered under a freshly generated
   * temporary table name, the statement (with '__THIS__' replaced by that name) is handed to
   * the SQLContext, and only the schema of the result is read. The temporary table is always
   * dropped again, so repeated schema checks neither accumulate tables in the shared
   * SQLContext nor collide with each other or with a user table named '__THIS__'.
   *
   * @param schema schema of the input dataset
   * @return schema of the DataFrame the statement would produce
   */
  @Since("1.6.0")
  override def transformSchema(schema: StructType): StructType = {
    val sc = SparkContext.getOrCreate()
    val sqlContext = SQLContext.getOrCreate(sc)
    // The row content is irrelevant: only the schema of the query result is read below.
    val dummyRDD = sc.parallelize(Seq(Row.empty))
    val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
    // Use a unique name, mirroring transform(), instead of the literal placeholder.
    val tableName = Identifiable.randomUID(uid)
    val realStatement = $(statement).replace(tableIdentifier, tableName)
    dummyDF.registerTempTable(tableName)
    try {
      sqlContext.sql(realStatement).schema
    } finally {
      // Clean up even when the statement fails, so no table is left behind.
      sqlContext.dropTempTable(tableName)
    }
  }

  /** Returns a copy of this transformer with `extra` params applied on top of the current ones. */
  @Since("1.6.0")
  override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
}

/** Companion object providing persistence support for [[SQLTransformer]]. */
@Since("1.6.0")
object SQLTransformer extends DefaultParamsReadable[SQLTransformer] {

  /** Loads a [[SQLTransformer]] previously saved at `path`. */
  @Since("1.6.0")
  override def load(path: String): SQLTransformer = super.load(path)
}