Repository: spark
Updated Branches:
  refs/heads/master 7cf9fab33 -> 66a3a5a2d


[SPARK-23099][SS] Migrate foreach sink to DataSourceV2

## What changes were proposed in this pull request?

Migrate foreach sink to DataSourceV2.

Since the previous attempt at this (PR #20552), we've changed and strictly defined the lifecycle of writer components. This means we no longer need the complicated lifecycle shim from that PR; the migration now just works naturally.
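
For reviewers' context, below is a minimal, hypothetical sketch of the user-facing API this migration preserves; the `rate` source, local master, and `println` output are illustrative assumptions, not part of this commit. The same `ForeachWriter` contract (`open`, then `process` per row, then `close`) is simply driven per partition and epoch by the new `ForeachDataWriter` instead of the old `ForeachSink.addBatch` loop.

```scala
import org.apache.spark.sql.{ForeachWriter, Row, SparkSession}

// A rough usage sketch, not code from this patch.
object ForeachSinkSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("foreach-sink-sketch")
      .getOrCreate()

    val query = spark.readStream
      .format("rate")                 // built-in source emitting (timestamp, value) rows
      .option("rowsPerSecond", "5")
      .load()
      .writeStream
      .foreach(new ForeachWriter[Row] {
        // Called once per partition per epoch; returning false skips the partition.
        override def open(partitionId: Long, version: Long): Boolean = true
        // Called for every row produced by the partition.
        override def process(value: Row): Unit = println(value)
        // Always called last, with the error if processing failed.
        override def close(errorOrNull: Throwable): Unit = ()
      })
      .start()

    query.awaitTermination()
  }
}
```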

## How was this patch tested?

Existing tests, with ForeachSinkSuite migrated to the new ForeachWriterSuite.

Author: Jose Torres <torres.joseph.f+git...@gmail.com>

Closes #20951 from jose-torres/foreach.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/66a3a5a2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/66a3a5a2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/66a3a5a2

Branch: refs/heads/master
Commit: 66a3a5a2dc83e03dedcee9839415c1ddc1fb8125
Parents: 7cf9fab
Author: Jose Torres <torres.joseph.f+git...@gmail.com>
Authored: Tue Apr 3 11:05:29 2018 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Tue Apr 3 11:05:29 2018 -0700

----------------------------------------------------------------------
 .../sql/execution/streaming/ForeachSink.scala   |  68 -----
 .../sources/ForeachWriterProvider.scala         | 111 +++++++
 .../spark/sql/streaming/DataStreamWriter.scala  |   4 +-
 .../execution/streaming/ForeachSinkSuite.scala  | 305 ------------------
 .../streaming/sources/ForeachWriterSuite.scala  | 306 +++++++++++++++++++
 .../sql/streaming/StreamingQuerySuite.scala     |   1 +
 6 files changed, 420 insertions(+), 375 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/66a3a5a2/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
deleted file mode 100644
index 2cc5410..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import org.apache.spark.TaskContext
-import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
-import org.apache.spark.sql.catalyst.encoders.encoderFor
-
-/**
- * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
- * [[ForeachWriter]].
- *
- * @param writer The [[ForeachWriter]] to process all data.
- * @tparam T The expected type of the sink.
- */
-class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {
-
-  override def addBatch(batchId: Long, data: DataFrame): Unit = {
-    // This logic should've been as simple as:
-    // ```
-    //   data.as[T].foreachPartition { iter => ... }
-    // ```
-    //
-    // Unfortunately, doing that would just break the incremental planing. The reason is,
-    // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will
-    // create a new plan. Because StreamExecution uses the existing plan to collect metrics and
-    // update watermark, we should never create a new plan. Otherwise, metrics and watermark are
-    // updated in the new plan, and StreamExecution cannot retrieval them.
-    //
-    // Hence, we need to manually convert internal rows to objects using encoder.
-    val encoder = encoderFor[T].resolveAndBind(
-      data.logicalPlan.output,
-      data.sparkSession.sessionState.analyzer)
-    data.queryExecution.toRdd.foreachPartition { iter =>
-      if (writer.open(TaskContext.getPartitionId(), batchId)) {
-        try {
-          while (iter.hasNext) {
-            writer.process(encoder.fromRow(iter.next()))
-          }
-        } catch {
-          case e: Throwable =>
-            writer.close(e)
-            throw e
-        }
-        writer.close(null)
-      } else {
-        writer.close(null)
-      }
-    }
-  }
-
-  override def toString(): String = "ForeachSink"
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/66a3a5a2/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
new file mode 100644
index 0000000..df5d69d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified
+ * [[ForeachWriter]].
+ *
+ * @param writer The [[ForeachWriter]] to process all data.
+ * @tparam T The expected type of the sink.
+ */
+case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport {
+  override def createStreamWriter(
+      queryId: String,
+      schema: StructType,
+      mode: OutputMode,
+      options: DataSourceOptions): StreamWriter = {
+    new StreamWriter with SupportsWriteInternalRow {
+      override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+      override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+
+      override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
+        val encoder = encoderFor[T].resolveAndBind(
+          schema.toAttributes,
+          SparkSession.getActiveSession.get.sessionState.analyzer)
+        ForeachWriterFactory(writer, encoder)
+      }
+
+      override def toString: String = "ForeachSink"
+    }
+  }
+}
+
+case class ForeachWriterFactory[T: Encoder](
+    writer: ForeachWriter[T],
+    encoder: ExpressionEncoder[T])
+  extends DataWriterFactory[InternalRow] {
+  override def createDataWriter(
+      partitionId: Int,
+      attemptNumber: Int,
+      epochId: Long): ForeachDataWriter[T] = {
+    new ForeachDataWriter(writer, encoder, partitionId, epochId)
+  }
+}
+
+/**
+ * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]].
+ * @param writer The [[ForeachWriter]] to process all data.
+ * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]]
+ * @param partitionId
+ * @param epochId
+ * @tparam T The type expected by the writer.
+ */
+class ForeachDataWriter[T : Encoder](
+    writer: ForeachWriter[T],
+    encoder: ExpressionEncoder[T],
+    partitionId: Int,
+    epochId: Long)
+  extends DataWriter[InternalRow] {
+
+  // If open returns false, we should skip writing rows.
+  private val opened = writer.open(partitionId, epochId)
+
+  override def write(record: InternalRow): Unit = {
+    if (!opened) return
+
+    try {
+      writer.process(encoder.fromRow(record))
+    } catch {
+      case t: Throwable =>
+        writer.close(t)
+        throw t
+    }
+  }
+
+  override def commit(): WriterCommitMessage = {
+    writer.close(null)
+    ForeachWriterCommitMessage
+  }
+
+  override def abort(): Unit = {}
+}
+
+/**
+ * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination.
+ */
+case object ForeachWriterCommitMessage extends WriterCommitMessage

http://git-wip-us.apache.org/repos/asf/spark/blob/66a3a5a2/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 2fc9031..effc147 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
-import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2}
+import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2}
 import org.apache.spark.sql.sources.v2.StreamWriteSupport
 
 /**
@@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
       query
     } else if (source == "foreach") {
       assertNotPartitioned("foreach")
-      val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc)
+      val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc)
       df.sparkSession.sessionState.streamingQueryManager.startQuery(
         extraOptions.get("queryName"),
         extraOptions.get("checkpointLocation"),

http://git-wip-us.apache.org/repos/asf/spark/blob/66a3a5a2/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
deleted file mode 100644
index b249dd4..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
+++ /dev/null
@@ -1,305 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import java.util.concurrent.ConcurrentLinkedQueue
-
-import scala.collection.mutable
-
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark.SparkException
-import org.apache.spark.sql.ForeachWriter
-import org.apache.spark.sql.functions.{count, window}
-import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest}
-import org.apache.spark.sql.test.SharedSQLContext
-
-class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
-
-  import testImplicits._
-
-  after {
-    sqlContext.streams.active.foreach(_.stop())
-  }
-
-  test("foreach() with `append` output mode") {
-    withTempDir { checkpointDir =>
-      val input = MemoryStream[Int]
-      val query = input.toDS().repartition(2).writeStream
-        .option("checkpointLocation", checkpointDir.getCanonicalPath)
-        .outputMode(OutputMode.Append)
-        .foreach(new TestForeachWriter())
-        .start()
-
-      def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = {
-        import ForeachSinkSuite._
-
-        val events = ForeachSinkSuite.allEvents()
-        assert(events.size === 2) // one seq of events for each of the 2 partitions
-
-        // Verify both seq of events have an Open event as the first event
-        assert(events.map(_.head).toSet === Set(0, 1).map(p => Open(p, expectedVersion)))
-
-        // Verify all the Process event correspond to the expected data
-        val allProcessEvents = events.flatMap(_.filter(_.isInstanceOf[Process[_]]))
-        assert(allProcessEvents.toSet === expectedData.map { data => Process(data) }.toSet)
-
-        // Verify both seq of events have a Close event as the last event
-        assert(events.map(_.last).toSet === Set(Close(None), Close(None)))
-      }
-
-      // -- batch 0 ---------------------------------------
-      ForeachSinkSuite.clear()
-      input.addData(1, 2, 3, 4)
-      query.processAllAvailable()
-      verifyOutput(expectedVersion = 0, expectedData = 1 to 4)
-
-      // -- batch 1 ---------------------------------------
-      ForeachSinkSuite.clear()
-      input.addData(5, 6, 7, 8)
-      query.processAllAvailable()
-      verifyOutput(expectedVersion = 1, expectedData = 5 to 8)
-
-      query.stop()
-    }
-  }
-
-  test("foreach() with `complete` output mode") {
-    withTempDir { checkpointDir =>
-      val input = MemoryStream[Int]
-
-      val query = input.toDS()
-        .groupBy().count().as[Long].map(_.toInt)
-        .writeStream
-        .option("checkpointLocation", checkpointDir.getCanonicalPath)
-        .outputMode(OutputMode.Complete)
-        .foreach(new TestForeachWriter())
-        .start()
-
-      // -- batch 0 ---------------------------------------
-      input.addData(1, 2, 3, 4)
-      query.processAllAvailable()
-
-      var allEvents = ForeachSinkSuite.allEvents()
-      assert(allEvents.size === 1)
-      var expectedEvents = Seq(
-        ForeachSinkSuite.Open(partition = 0, version = 0),
-        ForeachSinkSuite.Process(value = 4),
-        ForeachSinkSuite.Close(None)
-      )
-      assert(allEvents === Seq(expectedEvents))
-
-      ForeachSinkSuite.clear()
-
-      // -- batch 1 ---------------------------------------
-      input.addData(5, 6, 7, 8)
-      query.processAllAvailable()
-
-      allEvents = ForeachSinkSuite.allEvents()
-      assert(allEvents.size === 1)
-      expectedEvents = Seq(
-        ForeachSinkSuite.Open(partition = 0, version = 1),
-        ForeachSinkSuite.Process(value = 8),
-        ForeachSinkSuite.Close(None)
-      )
-      assert(allEvents === Seq(expectedEvents))
-
-      query.stop()
-    }
-  }
-
-  testQuietly("foreach with error") {
-    withTempDir { checkpointDir =>
-      val input = MemoryStream[Int]
-      val query = input.toDS().repartition(1).writeStream
-        .option("checkpointLocation", checkpointDir.getCanonicalPath)
-        .foreach(new TestForeachWriter() {
-          override def process(value: Int): Unit = {
-            super.process(value)
-            throw new RuntimeException("error")
-          }
-        }).start()
-      input.addData(1, 2, 3, 4)
-
-      // Error in `process` should fail the Spark job
-      val e = intercept[StreamingQueryException] {
-        query.processAllAvailable()
-      }
-      assert(e.getCause.isInstanceOf[SparkException])
-      assert(e.getCause.getCause.getMessage === "error")
-      assert(query.isActive === false)
-
-      val allEvents = ForeachSinkSuite.allEvents()
-      assert(allEvents.size === 1)
-      assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0))
-      assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1))
-
-      // `close` should be called with the error
-      val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close]
-      assert(errorEvent.error.get.isInstanceOf[RuntimeException])
-      assert(errorEvent.error.get.getMessage === "error")
-    }
-  }
-
-  test("foreach with watermark: complete") {
-    val inputData = MemoryStream[Int]
-
-    val windowedAggregation = inputData.toDF()
-      .withColumn("eventTime", $"value".cast("timestamp"))
-      .withWatermark("eventTime", "10 seconds")
-      .groupBy(window($"eventTime", "5 seconds") as 'window)
-      .agg(count("*") as 'count)
-      .select($"count".as[Long])
-      .map(_.toInt)
-      .repartition(1)
-
-    val query = windowedAggregation
-      .writeStream
-      .outputMode(OutputMode.Complete)
-      .foreach(new TestForeachWriter())
-      .start()
-    try {
-      inputData.addData(10, 11, 12)
-      query.processAllAvailable()
-
-      val allEvents = ForeachSinkSuite.allEvents()
-      assert(allEvents.size === 1)
-      val expectedEvents = Seq(
-        ForeachSinkSuite.Open(partition = 0, version = 0),
-        ForeachSinkSuite.Process(value = 3),
-        ForeachSinkSuite.Close(None)
-      )
-      assert(allEvents === Seq(expectedEvents))
-    } finally {
-      query.stop()
-    }
-  }
-
-  test("foreach with watermark: append") {
-    val inputData = MemoryStream[Int]
-
-    val windowedAggregation = inputData.toDF()
-      .withColumn("eventTime", $"value".cast("timestamp"))
-      .withWatermark("eventTime", "10 seconds")
-      .groupBy(window($"eventTime", "5 seconds") as 'window)
-      .agg(count("*") as 'count)
-      .select($"count".as[Long])
-      .map(_.toInt)
-      .repartition(1)
-
-    val query = windowedAggregation
-      .writeStream
-      .outputMode(OutputMode.Append)
-      .foreach(new TestForeachWriter())
-      .start()
-    try {
-      inputData.addData(10, 11, 12)
-      query.processAllAvailable()
-      inputData.addData(25) // Advance watermark to 15 seconds
-      query.processAllAvailable()
-      inputData.addData(25) // Evict items less than previous watermark
-      query.processAllAvailable()
-
-      // There should be 3 batches and only does the last batch contain a value.
-      val allEvents = ForeachSinkSuite.allEvents()
-      assert(allEvents.size === 3)
-      val expectedEvents = Seq(
-        Seq(
-          ForeachSinkSuite.Open(partition = 0, version = 0),
-          ForeachSinkSuite.Close(None)
-        ),
-        Seq(
-          ForeachSinkSuite.Open(partition = 0, version = 1),
-          ForeachSinkSuite.Close(None)
-        ),
-        Seq(
-          ForeachSinkSuite.Open(partition = 0, version = 2),
-          ForeachSinkSuite.Process(value = 3),
-          ForeachSinkSuite.Close(None)
-        )
-      )
-      assert(allEvents === expectedEvents)
-    } finally {
-      query.stop()
-    }
-  }
-
-  test("foreach sink should support metrics") {
-    val inputData = MemoryStream[Int]
-    val query = inputData.toDS()
-      .writeStream
-      .foreach(new TestForeachWriter())
-      .start()
-    try {
-      inputData.addData(10, 11, 12)
-      query.processAllAvailable()
-      val recentProgress = query.recentProgress.filter(_.numInputRows != 0).headOption
-      assert(recentProgress.isDefined && recentProgress.get.numInputRows === 3,
-        s"recentProgress[${query.recentProgress.toList}] doesn't contain correct metrics")
-    } finally {
-      query.stop()
-    }
-  }
-}
-
-/** A global object to collect events in the executor */
-object ForeachSinkSuite {
-
-  trait Event
-
-  case class Open(partition: Long, version: Long) extends Event
-
-  case class Process[T](value: T) extends Event
-
-  case class Close(error: Option[Throwable]) extends Event
-
-  private val _allEvents = new ConcurrentLinkedQueue[Seq[Event]]()
-
-  def addEvents(events: Seq[Event]): Unit = {
-    _allEvents.add(events)
-  }
-
-  def allEvents(): Seq[Seq[Event]] = {
-    _allEvents.toArray(new Array[Seq[Event]](_allEvents.size()))
-  }
-
-  def clear(): Unit = {
-    _allEvents.clear()
-  }
-}
-
-/** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */
-class TestForeachWriter extends ForeachWriter[Int] {
-  ForeachSinkSuite.clear()
-
-  private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]()
-
-  override def open(partitionId: Long, version: Long): Boolean = {
-    events += ForeachSinkSuite.Open(partition = partitionId, version = version)
-    true
-  }
-
-  override def process(value: Int): Unit = {
-    events += ForeachSinkSuite.Process(value)
-  }
-
-  override def close(errorOrNull: Throwable): Unit = {
-    events += ForeachSinkSuite.Close(error = Option(errorOrNull))
-    ForeachSinkSuite.addEvents(events)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/66a3a5a2/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala
new file mode 100644
index 0000000..03bf71b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import scala.collection.mutable
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.functions.{count, window}
+import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest}
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  test("foreach() with `append` output mode") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Int]
+      val query = input.toDS().repartition(2).writeStream
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .outputMode(OutputMode.Append)
+        .foreach(new TestForeachWriter())
+        .start()
+
+      def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = {
+        import ForeachWriterSuite._
+
+        val events = ForeachWriterSuite.allEvents()
+        assert(events.size === 2) // one seq of events for each of the 2 partitions
+
+        // Verify both seq of events have an Open event as the first event
+        assert(events.map(_.head).toSet === Set(0, 1).map(p => Open(p, expectedVersion)))
+
+        // Verify all the Process event correspond to the expected data
+        val allProcessEvents = events.flatMap(_.filter(_.isInstanceOf[Process[_]]))
+        assert(allProcessEvents.toSet === expectedData.map { data => Process(data) }.toSet)
+
+        // Verify both seq of events have a Close event as the last event
+        assert(events.map(_.last).toSet === Set(Close(None), Close(None)))
+      }
+
+      // -- batch 0 ---------------------------------------
+      ForeachWriterSuite.clear()
+      input.addData(1, 2, 3, 4)
+      query.processAllAvailable()
+      verifyOutput(expectedVersion = 0, expectedData = 1 to 4)
+
+      // -- batch 1 ---------------------------------------
+      ForeachWriterSuite.clear()
+      input.addData(5, 6, 7, 8)
+      query.processAllAvailable()
+      verifyOutput(expectedVersion = 1, expectedData = 5 to 8)
+
+      query.stop()
+    }
+  }
+
+  test("foreach() with `complete` output mode") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Int]
+
+      val query = input.toDS()
+        .groupBy().count().as[Long].map(_.toInt)
+        .writeStream
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .outputMode(OutputMode.Complete)
+        .foreach(new TestForeachWriter())
+        .start()
+
+      // -- batch 0 ---------------------------------------
+      input.addData(1, 2, 3, 4)
+      query.processAllAvailable()
+
+      var allEvents = ForeachWriterSuite.allEvents()
+      assert(allEvents.size === 1)
+      var expectedEvents = Seq(
+        ForeachWriterSuite.Open(partition = 0, version = 0),
+        ForeachWriterSuite.Process(value = 4),
+        ForeachWriterSuite.Close(None)
+      )
+      assert(allEvents === Seq(expectedEvents))
+
+      ForeachWriterSuite.clear()
+
+      // -- batch 1 ---------------------------------------
+      input.addData(5, 6, 7, 8)
+      query.processAllAvailable()
+
+      allEvents = ForeachWriterSuite.allEvents()
+      assert(allEvents.size === 1)
+      expectedEvents = Seq(
+        ForeachWriterSuite.Open(partition = 0, version = 1),
+        ForeachWriterSuite.Process(value = 8),
+        ForeachWriterSuite.Close(None)
+      )
+      assert(allEvents === Seq(expectedEvents))
+
+      query.stop()
+    }
+  }
+
+  testQuietly("foreach with error") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[Int]
+      val query = input.toDS().repartition(1).writeStream
+        .option("checkpointLocation", checkpointDir.getCanonicalPath)
+        .foreach(new TestForeachWriter() {
+          override def process(value: Int): Unit = {
+            super.process(value)
+            throw new RuntimeException("ForeachSinkSuite error")
+          }
+        }).start()
+      input.addData(1, 2, 3, 4)
+
+      // Error in `process` should fail the Spark job
+      val e = intercept[StreamingQueryException] {
+        query.processAllAvailable()
+      }
+      assert(e.getCause.isInstanceOf[SparkException])
+      assert(e.getCause.getCause.getCause.getMessage === "ForeachSinkSuite error")
+      assert(query.isActive === false)
+
+      val allEvents = ForeachWriterSuite.allEvents()
+      assert(allEvents.size === 1)
+      assert(allEvents(0)(0) === ForeachWriterSuite.Open(partition = 0, version = 0))
+      assert(allEvents(0)(1) === ForeachWriterSuite.Process(value = 1))
+
+      // `close` should be called with the error
+      val errorEvent = allEvents(0)(2).asInstanceOf[ForeachWriterSuite.Close]
+      assert(errorEvent.error.get.isInstanceOf[RuntimeException])
+      assert(errorEvent.error.get.getMessage === "ForeachSinkSuite error")
+    }
+  }
+
+  test("foreach with watermark: complete") {
+    val inputData = MemoryStream[Int]
+
+    val windowedAggregation = inputData.toDF()
+      .withColumn("eventTime", $"value".cast("timestamp"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"count".as[Long])
+      .map(_.toInt)
+      .repartition(1)
+
+    val query = windowedAggregation
+      .writeStream
+      .outputMode(OutputMode.Complete)
+      .foreach(new TestForeachWriter())
+      .start()
+    try {
+      inputData.addData(10, 11, 12)
+      query.processAllAvailable()
+
+      val allEvents = ForeachWriterSuite.allEvents()
+      assert(allEvents.size === 1)
+      val expectedEvents = Seq(
+        ForeachWriterSuite.Open(partition = 0, version = 0),
+        ForeachWriterSuite.Process(value = 3),
+        ForeachWriterSuite.Close(None)
+      )
+      assert(allEvents === Seq(expectedEvents))
+    } finally {
+      query.stop()
+    }
+  }
+
+  test("foreach with watermark: append") {
+    val inputData = MemoryStream[Int]
+
+    val windowedAggregation = inputData.toDF()
+      .withColumn("eventTime", $"value".cast("timestamp"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"count".as[Long])
+      .map(_.toInt)
+      .repartition(1)
+
+    val query = windowedAggregation
+      .writeStream
+      .outputMode(OutputMode.Append)
+      .foreach(new TestForeachWriter())
+      .start()
+    try {
+      inputData.addData(10, 11, 12)
+      query.processAllAvailable()
+      inputData.addData(25) // Advance watermark to 15 seconds
+      query.processAllAvailable()
+      inputData.addData(25) // Evict items less than previous watermark
+      query.processAllAvailable()
+
+      // There should be 3 batches and only does the last batch contain a value.
+      val allEvents = ForeachWriterSuite.allEvents()
+      assert(allEvents.size === 3)
+      val expectedEvents = Seq(
+        Seq(
+          ForeachWriterSuite.Open(partition = 0, version = 0),
+          ForeachWriterSuite.Close(None)
+        ),
+        Seq(
+          ForeachWriterSuite.Open(partition = 0, version = 1),
+          ForeachWriterSuite.Close(None)
+        ),
+        Seq(
+          ForeachWriterSuite.Open(partition = 0, version = 2),
+          ForeachWriterSuite.Process(value = 3),
+          ForeachWriterSuite.Close(None)
+        )
+      )
+      assert(allEvents === expectedEvents)
+    } finally {
+      query.stop()
+    }
+  }
+
+  test("foreach sink should support metrics") {
+    val inputData = MemoryStream[Int]
+    val query = inputData.toDS()
+      .writeStream
+      .foreach(new TestForeachWriter())
+      .start()
+    try {
+      inputData.addData(10, 11, 12)
+      query.processAllAvailable()
+      val recentProgress = query.recentProgress.filter(_.numInputRows != 0).headOption
+      assert(recentProgress.isDefined && recentProgress.get.numInputRows === 3,
+        s"recentProgress[${query.recentProgress.toList}] doesn't contain correct metrics")
+    } finally {
+      query.stop()
+    }
+  }
+}
+
+/** A global object to collect events in the executor */
+object ForeachWriterSuite {
+
+  trait Event
+
+  case class Open(partition: Long, version: Long) extends Event
+
+  case class Process[T](value: T) extends Event
+
+  case class Close(error: Option[Throwable]) extends Event
+
+  private val _allEvents = new ConcurrentLinkedQueue[Seq[Event]]()
+
+  def addEvents(events: Seq[Event]): Unit = {
+    _allEvents.add(events)
+  }
+
+  def allEvents(): Seq[Seq[Event]] = {
+    _allEvents.toArray(new Array[Seq[Event]](_allEvents.size()))
+  }
+
+  def clear(): Unit = {
+    _allEvents.clear()
+  }
+}
+
+/** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */
+class TestForeachWriter extends ForeachWriter[Int] {
+  ForeachWriterSuite.clear()
+
+  private val events = mutable.ArrayBuffer[ForeachWriterSuite.Event]()
+
+  override def open(partitionId: Long, version: Long): Boolean = {
+    events += ForeachWriterSuite.Open(partition = partitionId, version = version)
+    true
+  }
+
+  override def process(value: Int): Unit = {
+    events += ForeachWriterSuite.Process(value)
+  }
+
+  override def close(errorOrNull: Throwable): Unit = {
+    events += ForeachWriterSuite.Close(error = Option(errorOrNull))
+    ForeachWriterSuite.addEvents(events)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/66a3a5a2/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 08749b4..20942ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.reader.DataReaderFactory

