Repository: spark
Updated Branches:
  refs/heads/master a04cab8f1 -> 0f7175def


[SPARK-16350][SQL] Fix support for incremental planning in writeStream.foreach()

## What changes were proposed in this pull request?

There are cases where the `complete` output mode does not output updated aggregated
values; for details, please refer to
[SPARK-16350](https://issues.apache.org/jira/browse/SPARK-16350).

The cause is that `ForeachSink.addBatch()` calls
`data.as[T].foreachPartition { iter => ... }`, and `foreachPartition()` does not
currently support incremental planning.

This patch makes `foreachPartition()` support incremental planning in
`ForeachSink` by creating a special version of `Dataset` whose `rdd()` method
supports incremental planning.

## How was this patch tested?

Added a unit test that failed before this change.

Author: Liwei Lin <lwl...@gmail.com>

Closes #14030 from lw-lin/fix-foreach-complete.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0f7175de
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0f7175de
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0f7175de

Branch: refs/heads/master
Commit: 0f7175def985a7f1e37198680f893e749612ab76
Parents: a04cab8
Author: Liwei Lin <lwl...@gmail.com>
Authored: Thu Jul 7 10:40:42 2016 -0700
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Thu Jul 7 10:40:42 2016 -0700

----------------------------------------------------------------------
 .../sql/execution/streaming/ForeachSink.scala   | 40 ++++++++-
 .../streaming/IncrementalExecution.scala        |  4 +-
 .../execution/streaming/ForeachSinkSuite.scala  | 86 ++++++++++++++++++--
 3 files changed, 117 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0f7175de/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
index 14b9b1c..082664a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.TaskContext
-import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Encoder, ForeachWriter}
+import org.apache.spark.sql.catalyst.plans.logical.CatalystSerde
 
 /**
  * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the 
contract defined by
@@ -30,7 +32,41 @@ import org.apache.spark.sql.{DataFrame, Encoder, 
ForeachWriter}
 class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with 
Serializable {
 
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
-    data.as[T].foreachPartition { iter =>
+    // TODO: Refine this method when SPARK-16264 is resolved; see comments 
below.
+
+    // This logic should've been as simple as:
+    // ```
+    //   data.as[T].foreachPartition { iter => ... }
+    // ```
+    //
+    // Unfortunately, doing that would just break the incremental planing. The 
reason is,
+    // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but 
`Dataset.rdd()` just
+    // does not support `IncrementalExecution`.
+    //
+    // So as a provisional fix, below we've made a special version of 
`Dataset` with its `rdd()`
+    // method supporting incremental planning. But in the long run, we should 
generally make newly
+    // created Datasets use `IncrementalExecution` where necessary (which is 
SPARK-16264 tries to
+    // resolve).
+
+    val datasetWithIncrementalExecution =
+      new Dataset(data.sparkSession, data.logicalPlan, implicitly[Encoder[T]]) 
{
+        override lazy val rdd: RDD[T] = {
+          val objectType = exprEnc.deserializer.dataType
+          val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+
+          // was originally: 
sparkSession.sessionState.executePlan(deserialized) ...
+          val incrementalExecution = new IncrementalExecution(
+            this.sparkSession,
+            deserialized,
+            data.queryExecution.asInstanceOf[IncrementalExecution].outputMode,
+            
data.queryExecution.asInstanceOf[IncrementalExecution].checkpointLocation,
+            
data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId)
+          incrementalExecution.toRdd.mapPartitions { rows =>
+            rows.map(_.get(0, objectType))
+          }.asInstanceOf[RDD[T]]
+        }
+      }
+    datasetWithIncrementalExecution.foreachPartition { iter =>
       if (writer.open(TaskContext.getPartitionId(), batchId)) {
         var isFailed = false
         try {

http://git-wip-us.apache.org/repos/asf/spark/blob/0f7175de/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 0ce0055..7367c68 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -30,8 +30,8 @@ import org.apache.spark.sql.streaming.OutputMode
 class IncrementalExecution private[sql](
     sparkSession: SparkSession,
     logicalPlan: LogicalPlan,
-    outputMode: OutputMode,
-    checkpointLocation: String,
+    val outputMode: OutputMode,
+    val checkpointLocation: String,
     val currentBatchId: Long)
   extends QueryExecution(sparkSession, logicalPlan) {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0f7175de/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
index 6ff597c..7928b8e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.ForeachWriter
-import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
 import org.apache.spark.sql.test.SharedSQLContext
 
 class ForeachSinkSuite extends StreamTest with SharedSQLContext with 
BeforeAndAfter {
@@ -35,35 +35,103 @@ class ForeachSinkSuite extends StreamTest with 
SharedSQLContext with BeforeAndAf
     sqlContext.streams.active.foreach(_.stop())
   }
 
-  test("foreach") {
+  test("foreach() with `append` output mode") {
     withTempDir { checkpointDir =>
       val input = MemoryStream[Int]
       val query = input.toDS().repartition(2).writeStream
         .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .outputMode(OutputMode.Append)
         .foreach(new TestForeachWriter())
         .start()
+
+      // -- batch 0 ---------------------------------------
       input.addData(1, 2, 3, 4)
       query.processAllAvailable()
 
-      val expectedEventsForPartition0 = Seq(
+      var expectedEventsForPartition0 = Seq(
         ForeachSinkSuite.Open(partition = 0, version = 0),
         ForeachSinkSuite.Process(value = 1),
         ForeachSinkSuite.Process(value = 3),
         ForeachSinkSuite.Close(None)
       )
-      val expectedEventsForPartition1 = Seq(
+      var expectedEventsForPartition1 = Seq(
         ForeachSinkSuite.Open(partition = 1, version = 0),
         ForeachSinkSuite.Process(value = 2),
         ForeachSinkSuite.Process(value = 4),
         ForeachSinkSuite.Close(None)
       )
 
-      val allEvents = ForeachSinkSuite.allEvents()
+      var allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 2)
+      assert(allEvents.toSet === Set(expectedEventsForPartition0, 
expectedEventsForPartition1))
+
+      ForeachSinkSuite.clear()
+
+      // -- batch 1 ---------------------------------------
+      input.addData(5, 6, 7, 8)
+      query.processAllAvailable()
+
+      expectedEventsForPartition0 = Seq(
+        ForeachSinkSuite.Open(partition = 0, version = 1),
+        ForeachSinkSuite.Process(value = 5),
+        ForeachSinkSuite.Process(value = 7),
+        ForeachSinkSuite.Close(None)
+      )
+      expectedEventsForPartition1 = Seq(
+        ForeachSinkSuite.Open(partition = 1, version = 1),
+        ForeachSinkSuite.Process(value = 6),
+        ForeachSinkSuite.Process(value = 8),
+        ForeachSinkSuite.Close(None)
+      )
+
+      allEvents = ForeachSinkSuite.allEvents()
       assert(allEvents.size === 2)
-      assert {
-        allEvents === Seq(expectedEventsForPartition0, 
expectedEventsForPartition1) ||
-          allEvents === Seq(expectedEventsForPartition1, 
expectedEventsForPartition0)
-      }
+      assert(allEvents.toSet === Set(expectedEventsForPartition0, 
expectedEventsForPartition1))
+
+      query.stop()
+    }
+  }
+
+  test("foreach() with `complete` output mode") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Int]
+
+      val query = input.toDS()
+        .groupBy().count().as[Long].map(_.toInt)
+        .writeStream
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .outputMode(OutputMode.Complete)
+        .foreach(new TestForeachWriter())
+        .start()
+
+      // -- batch 0 ---------------------------------------
+      input.addData(1, 2, 3, 4)
+      query.processAllAvailable()
+
+      var allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 1)
+      var expectedEvents = Seq(
+        ForeachSinkSuite.Open(partition = 0, version = 0),
+        ForeachSinkSuite.Process(value = 4),
+        ForeachSinkSuite.Close(None)
+      )
+      assert(allEvents === Seq(expectedEvents))
+
+      ForeachSinkSuite.clear()
+
+      // -- batch 1 ---------------------------------------
+      input.addData(5, 6, 7, 8)
+      query.processAllAvailable()
+
+      allEvents = ForeachSinkSuite.allEvents()
+      assert(allEvents.size === 1)
+      expectedEvents = Seq(
+        ForeachSinkSuite.Open(partition = 0, version = 1),
+        ForeachSinkSuite.Process(value = 8),
+        ForeachSinkSuite.Close(None)
+      )
+      assert(allEvents === Seq(expectedEvents))
+
       query.stop()
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to