Repository: spark
Updated Branches:
  refs/heads/master 13092d733 -> 2cb976355


[SPARK-24565][SS] Add API in Structured Streaming for exposing output rows of each microbatch as a DataFrame

## What changes were proposed in this pull request?

Currently, the micro-batches in the MicroBatchExecution are not exposed to the
user through any public API. This was because we did not want to expose the
micro-batches, so that every API we expose can eventually be supported in the
Continuous engine. But now that we have a better sense of building a
ContinuousExecution, I am considering adding APIs which will run only in the
MicroBatchExecution. There are quite a few use cases where exposing the
micro-batch output as a DataFrame is useful:
- Pass the output rows of each batch to a library that is designed only for
batch jobs (for example, many ML libraries need to collect() while learning).
- Reuse batch data sources for output whose streaming version does not exist
(e.g. the Redshift data source).
- Write the output rows to multiple places by writing twice for each batch
(see the sketch after this list). This is not the most elegant approach for
multi-output streaming queries, but it is likely better than running two
streaming queries that process the same data twice.
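
As a rough illustration of the multi-writes use case, here is a minimal Scala
sketch; the rate source, output paths, and checkpoint location are illustrative
assumptions, and the JSON write merely stands in for a batch-only connector
such as Redshift:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object ForeachBatchMultiSinkSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("foreachBatch-multi-sink").getOrCreate()

    // Illustrative streaming input; any streaming source works the same way.
    val streamingDF = spark.readStream.format("rate").load()

    // Write each micro-batch to two places using ordinary batch writers.
    val writeBoth: (DataFrame, Long) => Unit = (batchDF, batchId) => {
      batchDF.persist()  // reuse the same micro-batch output for both writes
      batchDF.write.mode("append").parquet("/tmp/out/parquet")  // sink 1: regular batch format
      batchDF.write.mode("append").json("/tmp/out/json")        // sink 2: stand-in for a batch-only source
      batchDF.unpersist()
    }

    val query = streamingDF.writeStream
      .foreachBatch(writeBoth)
      .option("checkpointLocation", "/tmp/out/checkpoint")
      .start()

    query.awaitTermination()
  }
}
```

Persisting the batch once and writing it twice avoids recomputing the
micro-batch for the second sink.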

The proposal is to add a method `foreachBatch(f: (Dataset[T], Long) => Unit)` to
the Scala/Java/Python `DataStreamWriter`.
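
A second minimal sketch of the deduplication aspect (reusing `streamingDF` from
the sketch above; the `batch_id` column and output path are invented for
illustration) tags each row with the batch identifier so that a re-delivered
batch can be deduplicated downstream:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit

// Tag every row with the batch identifier; a batch that is re-delivered after a
// failure appends rows with the same batch_id, which readers can deduplicate on.
val idempotentWrite: (DataFrame, Long) => Unit = (batchDF, batchId) => {
  batchDF
    .withColumn("batch_id", lit(batchId))
    .write
    .mode("append")
    .parquet("/tmp/out/by-batch-id")
}

streamingDF.writeStream.foreachBatch(idempotentWrite).start()
```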

## How was this patch tested?
New unit tests.

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #21571 from tdas/foreachBatch.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2cb97635
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2cb97635
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2cb97635

Branch: refs/heads/master
Commit: 2cb976355c615eee4ebd0a86f3911fa9284fccf6
Parents: 13092d7
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Tue Jun 19 13:56:51 2018 -0700
Committer: Shixiong Zhu <zsxw...@gmail.com>
Committed: Tue Jun 19 13:56:51 2018 -0700

----------------------------------------------------------------------
 python/pyspark/java_gateway.py                  |  25 +++-
 python/pyspark/sql/streaming.py                 |  33 ++++-
 python/pyspark/sql/tests.py                     |  36 +++++
 python/pyspark/sql/utils.py                     |  23 +++
 python/pyspark/streaming/context.py             |  18 +--
 .../streaming/sources/ForeachBatchSink.scala    |  58 ++++++++
 .../spark/sql/streaming/DataStreamWriter.scala  |  63 +++++++-
 .../sources/ForeachBatchSinkSuite.scala         | 148 +++++++++++++++++++
 8 files changed, 383 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2cb97635/python/pyspark/java_gateway.py
----------------------------------------------------------------------
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 0afbe9d..fa2d5e8 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -31,7 +31,7 @@ from subprocess import Popen, PIPE
 if sys.version >= '3':
     xrange = range
 
-from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
-from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
+from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters
 from pyspark.find_spark_home import _find_spark_home
 from pyspark.serializers import read_int, write_with_length, UTF8Deserializer
 
@@ -145,3 +145,26 @@ def do_server_auth(conn, auth_secret):
     if reply != "ok":
         conn.close()
         raise Exception("Unexpected reply from iterator server.")
+
+
+def ensure_callback_server_started(gw):
+    """
+    Start callback server if not already started. The callback server is needed if the Java
+    driver process needs to callback into the Python driver process to execute Python code.
+    """
+
+    # getattr will fallback to JVM, so we cannot test by hasattr()
+    if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
+        gw.callback_server_parameters.eager_load = True
+        gw.callback_server_parameters.daemonize = True
+        gw.callback_server_parameters.daemonize_connections = True
+        gw.callback_server_parameters.port = 0
+        gw.start_callback_server(gw.callback_server_parameters)
+        cbport = gw._callback_server.server_socket.getsockname()[1]
+        gw._callback_server.port = cbport
+        # gateway with real port
+        gw._python_proxy_port = gw._callback_server.port
+        # get the GatewayServer object in JVM by ID
+        jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
+        # update the port of CallbackClient with real port
+    jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port)

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb97635/python/pyspark/sql/streaming.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 4984593..8c1fd4a 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -24,12 +24,14 @@ if sys.version >= '3':
 else:
     intlike = (int, long)
 
+from py4j.java_gateway import java_import
+
 from pyspark import since, keyword_only
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.sql.column import _to_seq
 from pyspark.sql.readwriter import OptionUtils, to_str
 from pyspark.sql.types import *
-from pyspark.sql.utils import StreamingQueryException
+from pyspark.sql.utils import ForeachBatchFunction, StreamingQueryException
 
 __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"]
 
@@ -1016,6 +1018,35 @@ class DataStreamWriter(object):
         self._jwrite.foreach(jForeachWriter)
         return self
 
+    @since(2.4)
+    def foreachBatch(self, func):
+        """
+        Sets the output of the streaming query to be processed using the provided
+        function. This is supported only in the micro-batch execution modes (that is, when the
+        trigger is not continuous). In every micro-batch, the provided function will be called
+        with (i) the output rows as a DataFrame and (ii) the batch identifier.
+        The batchId can be used to deduplicate and transactionally write the output
+        (that is, the provided DataFrame) to external systems. The output DataFrame is guaranteed
+        to be exactly the same for the same batchId (assuming all operations are deterministic
+        in the query).
+
+        .. note:: Evolving.
+
+        >>> def func(batch_df, batch_id):
+        ...     batch_df.collect()
+        ...
+        >>> writer = sdf.writeStream.foreachBatch(func)
+        """
+
+        from pyspark.java_gateway import ensure_callback_server_started
+        gw = self._spark._sc._gateway
+        java_import(gw.jvm, "org.apache.spark.sql.execution.streaming.sources.*")
+
+        wrapped_func = ForeachBatchFunction(self._spark, func)
+        gw.jvm.PythonForeachBatchHelper.callForeachBatch(self._jwrite, wrapped_func)
+        ensure_callback_server_started(gw)
+        return self
+
     @ignore_unicode_prefix
     @since(2.0)
     def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None,

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb97635/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4e5fafa..94ab867 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2126,6 +2126,42 @@ class SQLTests(ReusedSQLTestCase):
         tester.assert_invalid_writer(WriterWithNonCallableClose(),
                                      "'close' in provided object is not callable")
 
+    def test_streaming_foreachBatch(self):
+        q = None
+        collected = dict()
+
+        def collectBatch(batch_df, batch_id):
+            collected[batch_id] = batch_df.collect()
+
+        try:
+            df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+            q = df.writeStream.foreachBatch(collectBatch).start()
+            q.processAllAvailable()
+            self.assertTrue(0 in collected)
+            self.assertEqual(len(collected[0]), 2)
+        finally:
+            if q:
+                q.stop()
+
+    def test_streaming_foreachBatch_propagates_python_errors(self):
+        from pyspark.sql.utils import StreamingQueryException
+
+        q = None
+
+        def collectBatch(df, id):
+            raise Exception("this should fail the query")
+
+        try:
+            df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
+            q = df.writeStream.foreachBatch(collectBatch).start()
+            q.processAllAvailable()
+            self.fail("Expected a failure")
+        except StreamingQueryException as e:
+            self.assertTrue("this should fail" in str(e))
+        finally:
+            if q:
+                q.stop()
+
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb97635/python/pyspark/sql/utils.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 45363f0..bb9ce02 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -150,3 +150,26 @@ def require_minimum_pyarrow_version():
     if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
         raise ImportError("PyArrow >= %s must be installed; however, "
                           "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))
+
+
+class ForeachBatchFunction(object):
+    """
+    This is the Python implementation of the Java interface 'PythonForeachBatchFunction'.
+    This wraps the user-defined 'foreachBatch' function such that it can be called from
+    the JVM when the query is active.
+
+    def __init__(self, sql_ctx, func):
+        self.sql_ctx = sql_ctx
+        self.func = func
+
+    def call(self, jdf, batch_id):
+        from pyspark.sql.dataframe import DataFrame
+        try:
+            self.func(DataFrame(jdf, self.sql_ctx), batch_id)
+        except Exception as e:
+            self.error = e
+            raise e
+
+    class Java:
+        implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction']

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb97635/python/pyspark/streaming/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index dd924ef..a451582 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -79,22 +79,8 @@ class StreamingContext(object):
         java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
         java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
 
-        # start callback server
-        # getattr will fallback to JVM, so we cannot test by hasattr()
-        if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
-            gw.callback_server_parameters.eager_load = True
-            gw.callback_server_parameters.daemonize = True
-            gw.callback_server_parameters.daemonize_connections = True
-            gw.callback_server_parameters.port = 0
-            gw.start_callback_server(gw.callback_server_parameters)
-            cbport = gw._callback_server.server_socket.getsockname()[1]
-            gw._callback_server.port = cbport
-            # gateway with real port
-            gw._python_proxy_port = gw._callback_server.port
-            # get the GatewayServer object in JVM by ID
-            jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
-            # update the port of CallbackClient with real port
-            jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port)
+        from pyspark.java_gateway import ensure_callback_server_started
+        ensure_callback_server_started(gw)
 
         # register serializer for TransformFunction
         # it happens before creating SparkContext when loading from checkpointing

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb97635/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
new file mode 100644
index 0000000..03c567c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import org.apache.spark.api.python.PythonException
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.Sink
+import org.apache.spark.sql.streaming.DataStreamWriter
+
+class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: ExpressionEncoder[T])
+  extends Sink {
+
+  override def addBatch(batchId: Long, data: DataFrame): Unit = {
+    val resolvedEncoder = encoder.resolveAndBind(
+      data.logicalPlan.output,
+      data.sparkSession.sessionState.analyzer)
+    val rdd = data.queryExecution.toRdd.map[T](resolvedEncoder.fromRow)(encoder.clsTag)
+    val ds = data.sparkSession.createDataset(rdd)(encoder)
+    batchWriter(ds, batchId)
+  }
+
+  override def toString(): String = "ForeachBatchSink"
+}
+
+
+/**
+ * Interface that is meant to be extended by Python classes via Py4J.
+ * Py4J allows Python classes to implement Java interfaces so that the JVM can call back into
+ * Python objects. In this case, this allows the user-defined Python `foreachBatch` function
+ * to be called from the JVM when the query is active.
+ */
+trait PythonForeachBatchFunction {
+  /** Call the Python implementation of this function */
+  def call(batchDF: DataFrame, batchId: Long): Unit
+}
+
+object PythonForeachBatchHelper {
+  def callForeachBatch(dsw: DataStreamWriter[Row], pythonFunc: PythonForeachBatchFunction): Unit = {
+    dsw.foreachBatch(pythonFunc.call _)
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb97635/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 43e80e4..926c0b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -21,14 +21,15 @@ import java.util.Locale
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.annotation.InterfaceStability
-import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter}
+import org.apache.spark.annotation.{InterfaceStability, Since}
+import org.apache.spark.api.java.function.VoidFunction2
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
-import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2}
+import org.apache.spark.sql.execution.streaming.sources._
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
 
 /**
@@ -279,6 +280,21 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
         outputMode,
         useTempCheckpointLocation = true,
         trigger = trigger)
+    } else if (source == "foreachBatch") {
+      assertNotPartitioned("foreachBatch")
+      if (trigger.isInstanceOf[ContinuousTrigger]) {
+        throw new AnalysisException("'foreachBatch' is not supported with continuous trigger")
+      }
+      val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc)
+      df.sparkSession.sessionState.streamingQueryManager.startQuery(
+        extraOptions.get("queryName"),
+        extraOptions.get("checkpointLocation"),
+        df,
+        extraOptions.toMap,
+        sink,
+        outputMode,
+        useTempCheckpointLocation = true,
+        trigger = trigger)
     } else {
       val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
       val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
@@ -322,6 +338,45 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
     this
   }
 
+  /**
+   * :: Experimental ::
+   *
+   * (Scala-specific) Sets the output of the streaming query to be processed using the provided
+   * function. This is supported only in the micro-batch execution modes (that is, when the
+   * trigger is not continuous). In every micro-batch, the provided function will be called
+   * with (i) the output rows as a Dataset and (ii) the batch identifier.
+   * The batchId can be used to deduplicate and transactionally write the output
+   * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed
+   * to be exactly the same for the same batchId (assuming all operations are deterministic
+   * in the query).
+   *
+   * @since 2.4.0
+   */
+  @InterfaceStability.Evolving
+  def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = {
+    this.source = "foreachBatch"
+    if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null")
+    this.foreachBatchWriter = function
+    this
+  }
+
+  /**
+   * :: Experimental ::
+   *
+   * (Java-specific) Sets the output of the streaming query to be processed using the provided
+   * function. This is supported only in the micro-batch execution modes (that is, when the
+   * trigger is not continuous). In every micro-batch, the provided function will be called
+   * with (i) the output rows as a Dataset and (ii) the batch identifier.
+   * The batchId can be used to deduplicate and transactionally write the output
+   * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed
+   * to be exactly the same for the same batchId (assuming all operations are deterministic
+   * in the query).
+   *
+   * @since 2.4.0
+   */
+  @InterfaceStability.Evolving
+  def foreachBatch(function: VoidFunction2[Dataset[T], Long]): DataStreamWriter[T] = {
+    foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId))
+  }
+
   private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols =>
     cols.map(normalize(_, "Partition"))
   }
@@ -358,5 +413,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
 
   private var foreachWriter: ForeachWriter[T] = null
 
+  private var foreachBatchWriter: (Dataset[T], Long) => Unit = null
+
   private var partitioningColumns: Option[Seq[String]] = None
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb97635/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
new file mode 100644
index 0000000..a4233e1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import scala.collection.mutable
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.streaming._
+
+case class KV(key: Int, value: Long)
+
+class ForeachBatchSinkSuite extends StreamTest {
+  import testImplicits._
+
+  test("foreachBatch with non-stateful query") {
+    val mem = MemoryStream[Int]
+    val ds = mem.toDS.map(_ + 1)
+
+    val tester = new ForeachBatchTester[Int](mem)
+    val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1))
+
+    import tester._
+    testWriter(ds, writer)(
+      check(in = 1, 2, 3)(out = 3, 4, 5), // out = in + 2 (i.e. 1 in query, 1 in writer)
+      check(in = 5, 6, 7)(out = 7, 8, 9))
+  }
+
+  test("foreachBatch with stateful query in update mode") {
+    val mem = MemoryStream[Int]
+    val ds = mem.toDF()
+      .select($"value" % 2 as "key")
+      .groupBy("key")
+      .agg(count("*") as "value")
+      .toDF.as[KV]
+
+    val tester = new ForeachBatchTester[KV](mem)
+    val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS)
+
+    import tester._
+    testWriter(ds, writer, outputMode = OutputMode.Update)(
+      check(in = 0)(out = (0, 1L)),
+      check(in = 1)(out = (1, 1L)),
+      check(in = 2, 3)(out = (0, 2L), (1, 2L)))
+  }
+
+  test("foreachBatch with stateful query in complete mode") {
+    val mem = MemoryStream[Int]
+    val ds = mem.toDF()
+      .select($"value" % 2 as "key")
+      .groupBy("key")
+      .agg(count("*") as "value")
+      .toDF.as[KV]
+
+    val tester = new ForeachBatchTester[KV](mem)
+    val writer = (batchDS: Dataset[KV], batchId: Long) => tester.record(batchId, batchDS)
+
+    import tester._
+    testWriter(ds, writer, outputMode = OutputMode.Complete)(
+      check(in = 0)(out = (0, 1L)),
+      check(in = 1)(out = (0, 1L), (1, 1L)),
+      check(in = 2)(out = (0, 2L), (1, 1L)))
+  }
+
+  test("foreachBatchSink does not affect metric generation") {
+    val mem = MemoryStream[Int]
+    val ds = mem.toDS.map(_ + 1)
+
+    val tester = new ForeachBatchTester[Int](mem)
+    val writer = (ds: Dataset[Int], batchId: Long) => tester.record(batchId, ds.map(_ + 1))
+
+    import tester._
+    testWriter(ds, writer)(
+      check(in = 1, 2, 3)(out = 3, 4, 5),
+      checkMetrics)
+  }
+
+  test("throws errors in invalid situations") {
+    val ds = MemoryStream[Int].toDS
+    val ex1 = intercept[IllegalArgumentException] {
+      ds.writeStream.foreachBatch(null.asInstanceOf[(Dataset[Int], Long) => Unit]).start()
+    }
+    assert(ex1.getMessage.contains("foreachBatch function cannot be null"))
+    val ex2 = intercept[AnalysisException] {
+      ds.writeStream.foreachBatch((_, _) => {}).trigger(Trigger.Continuous("1 second")).start()
+    }
+    assert(ex2.getMessage.contains("'foreachBatch' is not supported with continuous trigger"))
+    val ex3 = intercept[AnalysisException] {
+      ds.writeStream.foreachBatch((_, _) => {}).partitionBy("value").start()
+    }
+    assert(ex3.getMessage.contains("'foreachBatch' does not support partitioning"))
+  }
+
+  // ============== Helper classes and methods =================
+
+  private class ForeachBatchTester[T: Encoder](memoryStream: MemoryStream[Int]) {
+    trait Test
+    private case class Check(in: Seq[Int], out: Seq[T]) extends Test
+    private case object CheckMetrics extends Test
+
+    private val recordedOutput = new mutable.HashMap[Long, Seq[T]]
+
+    def testWriter(
+        ds: Dataset[T],
+        outputBatchWriter: (Dataset[T], Long) => Unit,
+        outputMode: OutputMode = OutputMode.Append())(tests: Test*): Unit = {
+      try {
+        var expectedBatchId = -1
+        val query = ds.writeStream.outputMode(outputMode).foreachBatch(outputBatchWriter).start()
+
+        tests.foreach {
+          case Check(in, out) =>
+            expectedBatchId += 1
+            memoryStream.addData(in)
+            query.processAllAvailable()
+            assert(recordedOutput.contains(expectedBatchId))
+            val ds: Dataset[T] = spark.createDataset[T](recordedOutput(expectedBatchId))
+            checkDataset[T](ds, out: _*)
+          case CheckMetrics =>
+            assert(query.recentProgress.exists(_.numInputRows > 0))
+        }
+      } finally {
+        sqlContext.streams.active.foreach(_.stop())
+      }
+    }
+
+    def check(in: Int*)(out: T*): Test = Check(in, out)
+    def checkMetrics: Test = CheckMetrics
+    def record(batchId: Long, ds: Dataset[T]): Unit = recordedOutput.put(batchId, ds.collect())
+    implicit def conv(x: (Int, Long)): KV = KV(x._1, x._2)
+  }
+}

