Repository: spark Updated Branches: refs/heads/branch-1.6 663a492f0 -> 2554c35e7
[SPARK-14563][ML] use a random table name instead of __THIS__ in SQLTransformer ## What changes were proposed in this pull request? Use a random table name instead of `__THIS__` in SQLTransformer, and add a test for `transformSchema`. The problems of using `__THIS__` are: * It doesn't work under HiveContext (in Spark 1.6) * Race conditions ## How was this patch tested? * Manual test with HiveContext. * Added a unit test for `transformSchema` to improve coverage. cc: yhuai Author: Xiangrui Meng <m...@databricks.com> Closes #12330 from mengxr/SPARK-14563. (cherry picked from commit 1995c2e6482bf4af5a4be087bfc156311c1bec19) Signed-off-by: Xiangrui Meng <m...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2554c35e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2554c35e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2554c35e Branch: refs/heads/branch-1.6 Commit: 2554c35e77bd9f58a45f2ded7e2ff291af6ecc78 Parents: 663a492 Author: Xiangrui Meng <m...@databricks.com> Authored: Tue Apr 12 11:30:09 2016 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Tue Apr 12 11:30:17 2016 -0700 ---------------------------------------------------------------------- .../org/apache/spark/ml/feature/SQLTransformer.scala | 10 ++++++---- .../org/apache/spark/ml/feature/SQLTransformerSuite.scala | 10 ++++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2554c35e/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index c09f4d0..f5509c4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -68,8 +68,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor val tableName = Identifiable.randomUID(uid) dataset.registerTempTable(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) - val outputDF = dataset.sqlContext.sql(realStatement) - outputDF + dataset.sqlContext.sql(realStatement) } @Since("1.6.0") @@ -78,8 +77,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor val sqlContext = SQLContext.getOrCreate(sc) val dummyRDD = sc.parallelize(Seq(Row.empty)) val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) - dummyDF.registerTempTable(tableIdentifier) - val outputSchema = sqlContext.sql($(statement)).schema + val tableName = Identifiable.randomUID(uid) + val realStatement = $(statement).replace(tableIdentifier, tableName) + dummyDF.registerTempTable(tableName) + val outputSchema = sqlContext.sql(realStatement).schema + sqlContext.dropTempTable(tableName) outputSchema } http://git-wip-us.apache.org/repos/asf/spark/blob/2554c35e/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 553e0b8..e213e17 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.{LongType, StructField, StructType} class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -49,4 +50,13 @@ class SQLTransformerSuite .setStatement("select * from __THIS__") testDefaultReadWrite(t) } + + test("transformSchema") { + val df = sqlContext.range(10) + val outputSchema = new SQLTransformer() + .setStatement("SELECT id + 1 AS id1 FROM __THIS__") + .transformSchema(df.schema) + val expected = StructType(Seq(StructField("id1", LongType, nullable = false))) + assert(outputSchema === expected) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org