Repository: spark
Updated Branches:
  refs/heads/master bce72797f -> 8cad854ef


[SPARK-8345] [ML] Add an SQL node as a feature transformer

Implements transforms that are defined by a SQL statement.
Currently we only support SQL syntax like 'SELECT ... FROM __THIS__'
where '__THIS__' represents the underlying table of the input dataset.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #7465 from yanboliang/spark-8345 and squashes the following commits:

b403fcb [Yanbo Liang] address comments
0d4bb15 [Yanbo Liang] a better transformSchema() implementation
51eb9e7 [Yanbo Liang] Add an SQL node as a feature transformer


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8cad854e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8cad854e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8cad854e

Branch: refs/heads/master
Commit: 8cad854ef6a2066de5adffcca6b79a205ccfd5f3
Parents: bce7279
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Tue Aug 11 11:01:59 2015 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Tue Aug 11 11:01:59 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/feature/SQLTransformer.scala       | 72 ++++++++++++++++++++
 .../spark/ml/feature/SQLTransformerSuite.scala  | 44 ++++++++++++
 2 files changed, 116 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8cad854e/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
new file mode 100644
index 0000000..95e4305
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.param.{ParamMap, Param}
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.{SQLContext, DataFrame, Row}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * :: Experimental ::
+ * Implements the transforms which are defined by SQL statement.
+ * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__'
+ * where '__THIS__' represents the underlying table of the input dataset.
+ */
@Experimental
class SQLTransformer (override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("sql"))

  /**
   * SQL statement parameter. The statement is provided in string form and must
   * reference the input dataset via the placeholder table name `__THIS__`,
   * e.g. "SELECT ... FROM __THIS__".
   * @group param
   */
  final val statement: Param[String] = new Param[String](this, "statement", "SQL statement")

  /** @group setParam */
  def setStatement(value: String): this.type = set(statement, value)

  /** @group getParam */
  def getStatement: String = $(statement)

  /** Placeholder in [[statement]] that stands for the input dataset's temp table. */
  private val tableIdentifier: String = "__THIS__"

  override def transform(dataset: DataFrame): DataFrame = {
    // Register under a random unique name so concurrent transform() calls on the
    // same SQLContext cannot clobber each other's registrations.
    val tableName = Identifiable.randomUID(uid)
    dataset.registerTempTable(tableName)
    val realStatement = $(statement).replace(tableIdentifier, tableName)
    // NOTE(review): the temp table is intentionally not dropped here — the returned
    // DataFrame is lazy, and dropping before materialization may break execution on
    // this API version. Revisit once analysis is eager. TODO confirm.
    dataset.sqlContext.sql(realStatement)
  }

  override def transformSchema(schema: StructType): StructType = {
    val sc = SparkContext.getOrCreate()
    val sqlContext = SQLContext.getOrCreate(sc)
    // Build an empty dummy DataFrame carrying the input schema; only the analyzed
    // schema of the statement is needed, no data is read.
    val dummyRDD = sc.parallelize(Seq(Row.empty))
    val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
    // Use a random unique table name (not the literal "__THIS__") so repeated or
    // concurrent calls do not collide on a shared registration.
    val tableName = Identifiable.randomUID(uid)
    val realStatement = $(statement).replace(tableIdentifier, tableName)
    dummyDF.registerTempTable(tableName)
    // Accessing .schema forces analysis eagerly, so the table can be dropped
    // immediately afterwards instead of leaking a registration per call.
    val outputSchema = sqlContext.sql(realStatement).schema
    sqlContext.dropTempTable(tableName)
    outputSchema
  }

  override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
}

http://git-wip-us.apache.org/repos/asf/spark/blob/8cad854e/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
new file mode 100644
index 0000000..d190528
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("params") {
    ParamsSuite.checkParams(new SQLTransformer())
  }

  test("transform numeric data") {
    // Two-row input with an id column and two numeric feature columns.
    val input = sqlContext
      .createDataFrame(Seq((0, 1.0, 3.0), (2, 2.0, 5.0)))
      .toDF("id", "v1", "v2")
    val transformer = new SQLTransformer()
      .setStatement("SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
    val transformed = transformer.transform(input)
    val declaredSchema = transformer.transformSchema(input.schema)
    // Expected output appends the sum (v3) and product (v4) of v1 and v2.
    val expected = sqlContext
      .createDataFrame(Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)))
      .toDF("id", "v1", "v2", "v3", "v4")
    // transformSchema must agree with the schema actually produced by transform.
    assert(transformed.schema.toString == declaredSchema.toString)
    assert(declaredSchema == expected.schema)
    assert(transformed.collect().toSeq == expected.collect().toSeq)
  }
}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to