Repository: spark

Updated Branches:
  refs/heads/master f362363d1 -> 35d9c8aa6
[SPARK-14747][SQL] Add assertStreaming/assertNotStreaming checks in DataFrameWriter

## Problem

If an end user happens to write code that mixes continuous-query-oriented methods with non-continuous-query-oriented methods:

```scala
ctx.read
  .format("text")
  .stream("...")  // continuous query
  .write
  .text("...")    // non-continuous query; should be startStream() here
```

he or she gets this rather confusing exception:

> Exception in thread "main" java.lang.AssertionError: assertion failed: No plan for FileSource[./continuous_query_test_input]
> at scala.Predef$.assert(Predef.scala:170)
> at org.apache.spark.sql.catalyst.planning.QueryPlanner.plan(QueryPlanner.scala:59)
> at org.apache.spark.sql.catalyst.planning.QueryPlanner.planLater(QueryPlanner.scala:54)
> at ...

## What changes were proposed in this pull request?

This PR adds checks to `DataFrameWriter` so that continuous-query-oriented methods and non-continuous-query-oriented methods fail fast when called on the wrong kind of query:

| Method         | Can be called on continuous query? | Can be called on non-continuous query? |
|----------------|------------------------------------|----------------------------------------|
| mode           |                                    | yes                                    |
| trigger        | yes                                |                                        |
| format         | yes                                | yes                                    |
| option/options | yes                                | yes                                    |
| partitionBy    | yes                                | yes                                    |
| bucketBy       |                                    | yes                                    |
| sortBy         |                                    | yes                                    |
| save           |                                    | yes                                    |
| queryName      | yes                                |                                        |
| startStream    | yes                                |                                        |
| insertInto     |                                    | yes                                    |
| saveAsTable    |                                    | yes                                    |
| jdbc           |                                    | yes                                    |
| json           |                                    | yes                                    |
| parquet        |                                    | yes                                    |
| orc            |                                    | yes                                    |
| text           |                                    | yes                                    |
| csv            |                                    | yes                                    |

After this change, the friendlier exception is:

> Exception in thread "main" org.apache.spark.sql.AnalysisException: text() can only be called on non-continuous queries;
> at org.apache.spark.sql.DataFrameWriter.assertNotStreaming(DataFrameWriter.scala:678)
> at org.apache.spark.sql.DataFrameWriter.text(DataFrameWriter.scala:629)
> at ss.SSDemo$.main(SSDemo.scala:47)

## How was this patch tested?

Dedicated unit tests were added.

Author: Liwei Lin <lwl...@gmail.com>

Closes #12521 from lw-lin/dataframe-writer-check.
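For readers following along, here is a minimal sketch of the continuous-query form that these checks steer users toward: the streaming read ends in `startStream()` rather than `text()`. This snippet is not part of the patch; `ctx`, the `"..."` paths, and the `checkpointLocation` value are placeholders (the option simply mirrors what the new tests pass).

```scala
// Hypothetical end-user code; "..." paths are placeholders.
val query = ctx.read
  .format("text")
  .stream("...")                          // continuous (streaming) source
  .write
  .format("text")
  .option("checkpointLocation", "...")    // checkpoint dir, as in the new tests
  .startStream("...")                     // continuous sink; returns a ContinuousQuery

query.awaitTermination()                  // block until the continuous query stops
```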
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/35d9c8aa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/35d9c8aa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/35d9c8aa

Branch: refs/heads/master
Commit: 35d9c8aa69c650f33037813607dc939922c5fc27
Parents: f362363
Author: Liwei Lin <lwl...@gmail.com>
Authored: Mon May 2 16:48:20 2016 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon May 2 16:48:20 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/DataFrameWriter.scala  |  59 ++++++-
 .../streaming/DataFrameReaderWriterSuite.scala  | 156 +++++++++++++++++++
 2 files changed, 210 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/35d9c8aa/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index a57d47d..a8f96a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -53,6 +53,9 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   def mode(saveMode: SaveMode): DataFrameWriter = {
+    // mode() is used for non-continuous queries
+    // outputMode() is used for continuous queries
+    assertNotStreaming("mode() can only be called on non-continuous queries")
     this.mode = saveMode
     this
   }
@@ -67,6 +70,9 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   def mode(saveMode: String): DataFrameWriter = {
+    // mode() is used for non-continuous queries
+    // outputMode() is used for continuous queries
+    assertNotStreaming("mode() can only be called on non-continuous queries")
     this.mode = saveMode.toLowerCase match {
       case "overwrite" => SaveMode.Overwrite
       case "append" => SaveMode.Append
@@ -103,6 +109,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    */
   @Experimental
   def trigger(trigger: Trigger): DataFrameWriter = {
+    assertStreaming("trigger() can only be called on continuous queries")
     this.trigger = trigger
     this
   }
@@ -236,6 +243,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    */
   def save(): Unit = {
     assertNotBucketed()
+    assertNotStreaming("save() can only be called on non-continuous queries")
     val dataSource = DataSource(
       df.sparkSession,
       className = source,
@@ -253,6 +261,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   def queryName(queryName: String): DataFrameWriter = {
+    assertStreaming("queryName() can only be called on continuous queries")
     this.extraOptions += ("queryName" -> queryName)
     this
   }
@@ -276,6 +285,9 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 2.0.0
    */
   def startStream(): ContinuousQuery = {
+    assertNotBucketed
+    assertStreaming("startStream() can only be called on continuous queries")
+
     if (source == "memory") {
       val queryName =
         extraOptions.getOrElse(
@@ -348,6 +360,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {

   private def insertInto(tableIdent: TableIdentifier): Unit = {
     assertNotBucketed()
+    assertNotStreaming("insertInto() can only be called on non-continuous queries")
     val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap)
     val overwrite = mode == SaveMode.Overwrite
@@ -446,6 +459,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
   }

   private def saveAsTable(tableIdent: TableIdentifier): Unit = {
+    assertNotStreaming("saveAsTable() can only be called on non-continuous queries")
+
     val tableExists = df.sparkSession.sessionState.catalog.tableExists(tableIdent)

     (tableExists, mode) match {
@@ -486,6 +501,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
+    assertNotStreaming("jdbc() can only be called on non-continuous queries")
+
     val props = new Properties()
     extraOptions.foreach { case (key, value) =>
       props.put(key, value)
@@ -542,7 +559,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def json(path: String): Unit = format("json").save(path)
+  def json(path: String): Unit = {
+    assertNotStreaming("json() can only be called on non-continuous queries")
+    format("json").save(path)
+  }

   /**
    * Saves the content of the [[DataFrame]] in Parquet format at the specified path.
@@ -558,7 +578,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.4.0
    */
-  def parquet(path: String): Unit = format("parquet").save(path)
+  def parquet(path: String): Unit = {
+    assertNotStreaming("parquet() can only be called on non-continuous queries")
+    format("parquet").save(path)
+  }

   /**
    * Saves the content of the [[DataFrame]] in ORC format at the specified path.
@@ -575,7 +598,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.5.0
    * @note Currently, this method can only be used together with `HiveContext`.
    */
-  def orc(path: String): Unit = format("orc").save(path)
+  def orc(path: String): Unit = {
+    assertNotStreaming("orc() can only be called on non-continuous queries")
+    format("orc").save(path)
+  }

   /**
    * Saves the content of the [[DataFrame]] in a text file at the specified path.
@@ -596,7 +622,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 1.6.0
    */
-  def text(path: String): Unit = format("text").save(path)
+  def text(path: String): Unit = {
+    assertNotStreaming("text() can only be called on non-continuous queries")
+    format("text").save(path)
+  }

   /**
    * Saves the content of the [[DataFrame]] in CSV format at the specified path.
@@ -620,7 +649,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    *
    * @since 2.0.0
    */
-  def csv(path: String): Unit = format("csv").save(path)
+  def csv(path: String): Unit = {
+    assertNotStreaming("csv() can only be called on non-continuous queries")
+    format("csv").save(path)
+  }

   ///////////////////////////////////////////////////////////////////////////////////////
   // Builder pattern config options
@@ -641,4 +673,21 @@ final class DataFrameWriter private[sql](df: DataFrame) {
   private var numBuckets: Option[Int] = None

   private var sortColumnNames: Option[Seq[String]] = None
+
+  ///////////////////////////////////////////////////////////////////////////////////////
+  // Helper functions
+  ///////////////////////////////////////////////////////////////////////////////////////
+
+  private def assertNotStreaming(errMsg: String): Unit = {
+    if (df.isStreaming) {
+      throw new AnalysisException(errMsg)
+    }
+  }
+
+  private def assertStreaming(errMsg: String): Unit = {
+    if (!df.isStreaming) {
+      throw new AnalysisException(errMsg)
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/35d9c8aa/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index 00efe21..c7b2b99 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -368,4 +368,160 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
       "org.apache.spark.sql.streaming.test",
       Map.empty)
   }
+
+  private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath
+
+  test("check trigger() can only be called on continuous queries") {
+    val df = sqlContext.read.text(newTextInput)
+    val w = df.write.option("checkpointLocation", newMetadataDir)
+    val e = intercept[AnalysisException](w.trigger(ProcessingTime("10 seconds")))
+    assert(e.getMessage == "trigger() can only be called on continuous queries;")
+  }
+
+  test("check queryName() can only be called on continuous queries") {
+    val df = sqlContext.read.text(newTextInput)
+    val w = df.write.option("checkpointLocation", newMetadataDir)
+    val e = intercept[AnalysisException](w.queryName("queryName"))
+    assert(e.getMessage == "queryName() can only be called on continuous queries;")
+  }
+
+  test("check startStream() can only be called on continuous queries") {
+    val df = sqlContext.read.text(newTextInput)
+    val w = df.write.option("checkpointLocation", newMetadataDir)
+    val e = intercept[AnalysisException](w.startStream())
+    assert(e.getMessage == "startStream() can only be called on continuous queries;")
+  }
+
+  test("check startStream(path) can only be called on continuous queries") {
+    val df = sqlContext.read.text(newTextInput)
+    val w = df.write.option("checkpointLocation", newMetadataDir)
+    val e = intercept[AnalysisException](w.startStream("non_exist_path"))
+    assert(e.getMessage == "startStream() can only be called on continuous queries;")
+  }
+
+  test("check mode(SaveMode) can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.mode(SaveMode.Append))
+    assert(e.getMessage == "mode() can only be called on non-continuous queries;")
+  }
+
+  test("check mode(string) can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.mode("append"))
+    assert(e.getMessage == "mode() can only be called on non-continuous queries;")
+  }
+
+  test("check bucketBy() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[IllegalArgumentException](w.bucketBy(1, "text").startStream())
+    assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.")
+  }
+
+  test("check sortBy() can only be called on non-continuous queries;") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[IllegalArgumentException](w.sortBy("text").startStream())
+    assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.")
+  }
+
+  test("check save(path) can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.save("non_exist_path"))
+    assert(e.getMessage == "save() can only be called on non-continuous queries;")
+  }
+
+  test("check save() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.save())
+    assert(e.getMessage == "save() can only be called on non-continuous queries;")
+  }
+
+  test("check insertInto() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.insertInto("non_exsit_table"))
+    assert(e.getMessage == "insertInto() can only be called on non-continuous queries;")
+  }
+
+  test("check saveAsTable() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.saveAsTable("non_exsit_table"))
+    assert(e.getMessage == "saveAsTable() can only be called on non-continuous queries;")
+  }
+
+  test("check jdbc() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.jdbc(null, null, null))
+    assert(e.getMessage == "jdbc() can only be called on non-continuous queries;")
+  }
+
+  test("check json() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.json("non_exist_path"))
+    assert(e.getMessage == "json() can only be called on non-continuous queries;")
+  }
+
+  test("check parquet() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.parquet("non_exist_path"))
+    assert(e.getMessage == "parquet() can only be called on non-continuous queries;")
+  }
+
+  test("check orc() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.orc("non_exist_path"))
+    assert(e.getMessage == "orc() can only be called on non-continuous queries;")
+  }
+
+  test("check text() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.text("non_exist_path"))
+    assert(e.getMessage == "text() can only be called on non-continuous queries;")
+  }
+
+  test("check csv() can only be called on non-continuous queries") {
+    val df = sqlContext.read
+      .format("org.apache.spark.sql.streaming.test")
+      .stream()
+    val w = df.write
+    val e = intercept[AnalysisException](w.csv("non_exist_path"))
+    assert(e.getMessage == "csv() can only be called on non-continuous queries;")
+  }
 }
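As a rough, hypothetical illustration of the user-facing effect of these checks (not part of the patch; `ctx` and the `"..."` paths are placeholders), the mixed-up snippet from the problem description now fails immediately with the friendlier `AnalysisException` instead of an `AssertionError` from the planner:

```scala
import org.apache.spark.sql.AnalysisException

// Hypothetical user code; "..." paths are placeholders.
val streamingDF = ctx.read.format("text").stream("...")  // streaming source

try {
  streamingDF.write.text("...")                           // batch-only sink on a streaming DataFrame
} catch {
  case e: AnalysisException =>
    // Prints: "text() can only be called on non-continuous queries;"
    println(e.getMessage)
}
```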