Repository: spark
Updated Branches:
  refs/heads/branch-1.6 663a492f0 -> 2554c35e7


[SPARK-14563][ML] use a random table name instead of __THIS__ in SQLTransformer

## What changes were proposed in this pull request?

Use a random table name instead of `__THIS__` in SQLTransformer (see the sketch after
the list below), and add a test for `transformSchema`. The problems with using the
fixed `__THIS__` placeholder are:

* It doesn't work under HiveContext (in Spark 1.6).
* Race conditions: concurrent SQLTransformer instances all register the same `__THIS__` temp table.
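
A minimal, self-contained sketch of the substitution step (the actual patch uses Spark ML's `Identifiable.randomUID` inside `SQLTransformer`; the `PlaceholderRewrite` object and the `java.util.UUID`-based name here are illustrative assumptions):

```scala
import java.util.UUID

object PlaceholderRewrite {
  private val tableIdentifier = "__THIS__"

  // Generate a collision-free temp-table name and rewrite the statement
  // against it, so concurrent transform() calls never register the same table.
  def substitute(statement: String): (String, String) = {
    val tableName = "sql_transformer_" + UUID.randomUUID().toString.replace("-", "")
    (tableName, statement.replace(tableIdentifier, tableName))
  }
}

// Hypothetical driver code against a Spark 1.6 DataFrame `df`:
//   val (name, sql) = PlaceholderRewrite.substitute("SELECT id + 1 AS id1 FROM __THIS__")
//   df.registerTempTable(name)
//   val out = df.sqlContext.sql(sql)
```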

## How was this patch tested?

* Manual test with HiveContext (a sketch of this kind of check follows the list).
* Added a unit test for `transformSchema` to improve coverage.
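
A rough sketch of the kind of manual check described above, assuming an existing `sqlContext` (either `SQLContext` or `HiveContext` on Spark 1.6); the committed `transformSchema` unit test is in the diff below:

```scala
import org.apache.spark.ml.feature.SQLTransformer

val df = sqlContext.range(10)  // single LongType column named "id"
val t = new SQLTransformer()
  .setStatement("SELECT id, id * 2 AS doubled FROM __THIS__")

t.transform(df).show()                 // should succeed under HiveContext as well
println(t.transformSchema(df.schema))  // output schema derived without a full pass over the data
```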

cc: yhuai

Author: Xiangrui Meng <m...@databricks.com>

Closes #12330 from mengxr/SPARK-14563.

(cherry picked from commit 1995c2e6482bf4af5a4be087bfc156311c1bec19)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2554c35e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2554c35e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2554c35e

Branch: refs/heads/branch-1.6
Commit: 2554c35e77bd9f58a45f2ded7e2ff291af6ecc78
Parents: 663a492
Author: Xiangrui Meng <m...@databricks.com>
Authored: Tue Apr 12 11:30:09 2016 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Tue Apr 12 11:30:17 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/SQLTransformer.scala      | 10 ++++++----
 .../org/apache/spark/ml/feature/SQLTransformerSuite.scala | 10 ++++++++++
 2 files changed, 16 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2554c35e/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index c09f4d0..f5509c4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -68,8 +68,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
     val tableName = Identifiable.randomUID(uid)
     dataset.registerTempTable(tableName)
     val realStatement = $(statement).replace(tableIdentifier, tableName)
-    val outputDF = dataset.sqlContext.sql(realStatement)
-    outputDF
+    dataset.sqlContext.sql(realStatement)
   }
 
   @Since("1.6.0")
@@ -78,8 +77,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
     val sqlContext = SQLContext.getOrCreate(sc)
     val dummyRDD = sc.parallelize(Seq(Row.empty))
     val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
-    dummyDF.registerTempTable(tableIdentifier)
-    val outputSchema = sqlContext.sql($(statement)).schema
+    val tableName = Identifiable.randomUID(uid)
+    val realStatement = $(statement).replace(tableIdentifier, tableName)
+    dummyDF.registerTempTable(tableName)
+    val outputSchema = sqlContext.sql(realStatement).schema
+    sqlContext.dropTempTable(tableName)
     outputSchema
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2554c35e/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
index 553e0b8..e213e17 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
 
 class SQLTransformerSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -49,4 +50,13 @@ class SQLTransformerSuite
       .setStatement("select * from __THIS__")
     testDefaultReadWrite(t)
   }
+
+  test("transformSchema") {
+    val df = sqlContext.range(10)
+    val outputSchema = new SQLTransformer()
+      .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
+      .transformSchema(df.schema)
+    val expected = StructType(Seq(StructField("id1", LongType, nullable = false)))
+    assert(outputSchema === expected)
+  }
 }

