Repository: spark
Updated Branches:
  refs/heads/master 7cf9fab33 -> 66a3a5a2d
[SPARK-23099][SS] Migrate foreach sink to DataSourceV2

## What changes were proposed in this pull request?

Migrate foreach sink to DataSourceV2. Since the previous attempt at this PR #20552, we've changed and strictly defined the lifecycle of writer components. This means we no longer need the complicated lifecycle shim from that PR; it just naturally works.

## How was this patch tested?

existing tests

Author: Jose Torres <torres.joseph.f+git...@gmail.com>

Closes #20951 from jose-torres/foreach.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/66a3a5a2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/66a3a5a2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/66a3a5a2

Branch: refs/heads/master
Commit: 66a3a5a2dc83e03dedcee9839415c1ddc1fb8125
Parents: 7cf9fab
Author: Jose Torres <torres.joseph.f+git...@gmail.com>
Authored: Tue Apr 3 11:05:29 2018 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Tue Apr 3 11:05:29 2018 -0700

----------------------------------------------------------------------
 .../sql/execution/streaming/ForeachSink.scala   |  68 -----
 .../sources/ForeachWriterProvider.scala         | 111 +++++++
 .../spark/sql/streaming/DataStreamWriter.scala  |   4 +-
 .../execution/streaming/ForeachSinkSuite.scala  | 305 ------------------
 .../streaming/sources/ForeachWriterSuite.scala  | 306 +++++++++++++++++++
 .../sql/streaming/StreamingQuerySuite.scala     |   1 +
 6 files changed, 420 insertions(+), 375 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/66a3a5a2/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
deleted file mode 100644
index 2cc5410..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import org.apache.spark.TaskContext
-import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
-import org.apache.spark.sql.catalyst.encoders.encoderFor
-
-/**
- * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
- * [[ForeachWriter]].
- *
- * @param writer The [[ForeachWriter]] to process all data.
- * @tparam T The expected type of the sink.
- */ -class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable { - - override def addBatch(batchId: Long, data: DataFrame): Unit = { - // This logic should've been as simple as: - // ``` - // data.as[T].foreachPartition { iter => ... } - // ``` - // - // Unfortunately, doing that would just break the incremental planing. The reason is, - // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will - // create a new plan. Because StreamExecution uses the existing plan to collect metrics and - // update watermark, we should never create a new plan. Otherwise, metrics and watermark are - // updated in the new plan, and StreamExecution cannot retrieval them. - // - // Hence, we need to manually convert internal rows to objects using encoder. - val encoder = encoderFor[T].resolveAndBind( - data.logicalPlan.output, - data.sparkSession.sessionState.analyzer) - data.queryExecution.toRdd.foreachPartition { iter => - if (writer.open(TaskContext.getPartitionId(), batchId)) { - try { - while (iter.hasNext) { - writer.process(encoder.fromRow(iter.next())) - } - } catch { - case e: Throwable => - writer.close(e) - throw e - } - writer.close(null) - } else { - writer.close(null) - } - } - } - - override def toString(): String = "ForeachSink" -} http://git-wip-us.apache.org/repos/asf/spark/blob/66a3a5a2/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala new file mode 100644 index 0000000..df5d69d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +/** + * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified + * [[ForeachWriter]]. 
+ * + * @param writer The [[ForeachWriter]] to process all data. + * @tparam T The expected type of the sink. + */ +case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) extends StreamWriteSupport { + override def createStreamWriter( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceOptions): StreamWriter = { + new StreamWriter with SupportsWriteInternalRow { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + val encoder = encoderFor[T].resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + ForeachWriterFactory(writer, encoder) + } + + override def toString: String = "ForeachSink" + } + } +} + +case class ForeachWriterFactory[T: Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]) + extends DataWriterFactory[InternalRow] { + override def createDataWriter( + partitionId: Int, + attemptNumber: Int, + epochId: Long): ForeachDataWriter[T] = { + new ForeachDataWriter(writer, encoder, partitionId, epochId) + } +} + +/** + * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * @param writer The [[ForeachWriter]] to process all data. + * @param encoder An encoder which can convert [[InternalRow]] to the required type [[T]] + * @param partitionId + * @param epochId + * @tparam T The type expected by the writer. + */ +class ForeachDataWriter[T : Encoder]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T], + partitionId: Int, + epochId: Long) + extends DataWriter[InternalRow] { + + // If open returns false, we should skip writing rows. + private val opened = writer.open(partitionId, epochId) + + override def write(record: InternalRow): Unit = { + if (!opened) return + + try { + writer.process(encoder.fromRow(record)) + } catch { + case t: Throwable => + writer.close(t) + throw t + } + } + + override def commit(): WriterCommitMessage = { + writer.close(null) + ForeachWriterCommitMessage + } + + override def abort(): Unit = {} +} + +/** + * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. 
+ */ +case object ForeachWriterCommitMessage extends WriterCommitMessage http://git-wip-us.apache.org/repos/asf/spark/blob/66a3a5a2/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 2fc9031..effc147 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger -import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2} +import org.apache.spark.sql.execution.streaming.sources.{ForeachWriterProvider, MemoryPlanV2, MemorySinkV2} import org.apache.spark.sql.sources.v2.StreamWriteSupport /** @@ -269,7 +269,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc) + val sink = new ForeachWriterProvider[T](foreachWriter)(ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), http://git-wip-us.apache.org/repos/asf/spark/blob/66a3a5a2/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala deleted file mode 100644 index b249dd4..0000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ /dev/null @@ -1,305 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.streaming - -import java.util.concurrent.ConcurrentLinkedQueue - -import scala.collection.mutable - -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkException -import org.apache.spark.sql.ForeachWriter -import org.apache.spark.sql.functions.{count, window} -import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} -import org.apache.spark.sql.test.SharedSQLContext - -class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { - - import testImplicits._ - - after { - sqlContext.streams.active.foreach(_.stop()) - } - - test("foreach() with `append` output mode") { - withTempDir { checkpointDir => - val input = MemoryStream[Int] - val query = input.toDS().repartition(2).writeStream - .option("checkpointLocation", checkpointDir.getCanonicalPath) - .outputMode(OutputMode.Append) - .foreach(new TestForeachWriter()) - .start() - - def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = { - import ForeachSinkSuite._ - - val events = ForeachSinkSuite.allEvents() - assert(events.size === 2) // one seq of events for each of the 2 partitions - - // Verify both seq of events have an Open event as the first event - assert(events.map(_.head).toSet === Set(0, 1).map(p => Open(p, expectedVersion))) - - // Verify all the Process event correspond to the expected data - val allProcessEvents = events.flatMap(_.filter(_.isInstanceOf[Process[_]])) - assert(allProcessEvents.toSet === expectedData.map { data => Process(data) }.toSet) - - // Verify both seq of events have a Close event as the last event - assert(events.map(_.last).toSet === Set(Close(None), Close(None))) - } - - // -- batch 0 --------------------------------------- - ForeachSinkSuite.clear() - input.addData(1, 2, 3, 4) - query.processAllAvailable() - verifyOutput(expectedVersion = 0, expectedData = 1 to 4) - - // -- batch 1 --------------------------------------- - ForeachSinkSuite.clear() - input.addData(5, 6, 7, 8) - query.processAllAvailable() - verifyOutput(expectedVersion = 1, expectedData = 5 to 8) - - query.stop() - } - } - - test("foreach() with `complete` output mode") { - withTempDir { checkpointDir => - val input = MemoryStream[Int] - - val query = input.toDS() - .groupBy().count().as[Long].map(_.toInt) - .writeStream - .option("checkpointLocation", checkpointDir.getCanonicalPath) - .outputMode(OutputMode.Complete) - .foreach(new TestForeachWriter()) - .start() - - // -- batch 0 --------------------------------------- - input.addData(1, 2, 3, 4) - query.processAllAvailable() - - var allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 1) - var expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 4), - ForeachSinkSuite.Close(None) - ) - assert(allEvents === Seq(expectedEvents)) - - ForeachSinkSuite.clear() - - // -- batch 1 --------------------------------------- - input.addData(5, 6, 7, 8) - query.processAllAvailable() - - allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 1) - expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Process(value = 8), - ForeachSinkSuite.Close(None) - ) - assert(allEvents === Seq(expectedEvents)) - - query.stop() - } - } - - testQuietly("foreach with error") { - withTempDir { checkpointDir => - val input = MemoryStream[Int] - val query = input.toDS().repartition(1).writeStream - .option("checkpointLocation", checkpointDir.getCanonicalPath) - 
.foreach(new TestForeachWriter() { - override def process(value: Int): Unit = { - super.process(value) - throw new RuntimeException("error") - } - }).start() - input.addData(1, 2, 3, 4) - - // Error in `process` should fail the Spark job - val e = intercept[StreamingQueryException] { - query.processAllAvailable() - } - assert(e.getCause.isInstanceOf[SparkException]) - assert(e.getCause.getCause.getMessage === "error") - assert(query.isActive === false) - - val allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 1) - assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) - - // `close` should be called with the error - val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] - assert(errorEvent.error.get.isInstanceOf[RuntimeException]) - assert(errorEvent.error.get.getMessage === "error") - } - } - - test("foreach with watermark: complete") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"count".as[Long]) - .map(_.toInt) - .repartition(1) - - val query = windowedAggregation - .writeStream - .outputMode(OutputMode.Complete) - .foreach(new TestForeachWriter()) - .start() - try { - inputData.addData(10, 11, 12) - query.processAllAvailable() - - val allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 1) - val expectedEvents = Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) - ) - assert(allEvents === Seq(expectedEvents)) - } finally { - query.stop() - } - } - - test("foreach with watermark: append") { - val inputData = MemoryStream[Int] - - val windowedAggregation = inputData.toDF() - .withColumn("eventTime", $"value".cast("timestamp")) - .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"count".as[Long]) - .map(_.toInt) - .repartition(1) - - val query = windowedAggregation - .writeStream - .outputMode(OutputMode.Append) - .foreach(new TestForeachWriter()) - .start() - try { - inputData.addData(10, 11, 12) - query.processAllAvailable() - inputData.addData(25) // Advance watermark to 15 seconds - query.processAllAvailable() - inputData.addData(25) // Evict items less than previous watermark - query.processAllAvailable() - - // There should be 3 batches and only does the last batch contain a value. 
- val allEvents = ForeachSinkSuite.allEvents() - assert(allEvents.size === 3) - val expectedEvents = Seq( - Seq( - ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Close(None) - ), - Seq( - ForeachSinkSuite.Open(partition = 0, version = 1), - ForeachSinkSuite.Close(None) - ), - Seq( - ForeachSinkSuite.Open(partition = 0, version = 2), - ForeachSinkSuite.Process(value = 3), - ForeachSinkSuite.Close(None) - ) - ) - assert(allEvents === expectedEvents) - } finally { - query.stop() - } - } - - test("foreach sink should support metrics") { - val inputData = MemoryStream[Int] - val query = inputData.toDS() - .writeStream - .foreach(new TestForeachWriter()) - .start() - try { - inputData.addData(10, 11, 12) - query.processAllAvailable() - val recentProgress = query.recentProgress.filter(_.numInputRows != 0).headOption - assert(recentProgress.isDefined && recentProgress.get.numInputRows === 3, - s"recentProgress[${query.recentProgress.toList}] doesn't contain correct metrics") - } finally { - query.stop() - } - } -} - -/** A global object to collect events in the executor */ -object ForeachSinkSuite { - - trait Event - - case class Open(partition: Long, version: Long) extends Event - - case class Process[T](value: T) extends Event - - case class Close(error: Option[Throwable]) extends Event - - private val _allEvents = new ConcurrentLinkedQueue[Seq[Event]]() - - def addEvents(events: Seq[Event]): Unit = { - _allEvents.add(events) - } - - def allEvents(): Seq[Seq[Event]] = { - _allEvents.toArray(new Array[Seq[Event]](_allEvents.size())) - } - - def clear(): Unit = { - _allEvents.clear() - } -} - -/** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */ -class TestForeachWriter extends ForeachWriter[Int] { - ForeachSinkSuite.clear() - - private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]() - - override def open(partitionId: Long, version: Long): Boolean = { - events += ForeachSinkSuite.Open(partition = partitionId, version = version) - true - } - - override def process(value: Int): Unit = { - events += ForeachSinkSuite.Process(value) - } - - override def close(errorOrNull: Throwable): Unit = { - events += ForeachSinkSuite.Close(error = Option(errorOrNull)) - ForeachSinkSuite.addEvents(events) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/66a3a5a2/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala new file mode 100644 index 0000000..03bf71b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.mutable + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkException +import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions.{count, window} +import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} +import org.apache.spark.sql.test.SharedSQLContext + +class ForeachWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + test("foreach() with `append` output mode") { + withTempDir { checkpointDir => + val input = MemoryStream[Int] + val query = input.toDS().repartition(2).writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Append) + .foreach(new TestForeachWriter()) + .start() + + def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = { + import ForeachWriterSuite._ + + val events = ForeachWriterSuite.allEvents() + assert(events.size === 2) // one seq of events for each of the 2 partitions + + // Verify both seq of events have an Open event as the first event + assert(events.map(_.head).toSet === Set(0, 1).map(p => Open(p, expectedVersion))) + + // Verify all the Process event correspond to the expected data + val allProcessEvents = events.flatMap(_.filter(_.isInstanceOf[Process[_]])) + assert(allProcessEvents.toSet === expectedData.map { data => Process(data) }.toSet) + + // Verify both seq of events have a Close event as the last event + assert(events.map(_.last).toSet === Set(Close(None), Close(None))) + } + + // -- batch 0 --------------------------------------- + ForeachWriterSuite.clear() + input.addData(1, 2, 3, 4) + query.processAllAvailable() + verifyOutput(expectedVersion = 0, expectedData = 1 to 4) + + // -- batch 1 --------------------------------------- + ForeachWriterSuite.clear() + input.addData(5, 6, 7, 8) + query.processAllAvailable() + verifyOutput(expectedVersion = 1, expectedData = 5 to 8) + + query.stop() + } + } + + test("foreach() with `complete` output mode") { + withTempDir { checkpointDir => + val input = MemoryStream[Int] + + val query = input.toDS() + .groupBy().count().as[Long].map(_.toInt) + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Complete) + .foreach(new TestForeachWriter()) + .start() + + // -- batch 0 --------------------------------------- + input.addData(1, 2, 3, 4) + query.processAllAvailable() + + var allEvents = ForeachWriterSuite.allEvents() + assert(allEvents.size === 1) + var expectedEvents = Seq( + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 4), + ForeachWriterSuite.Close(None) + ) + assert(allEvents === Seq(expectedEvents)) + + ForeachWriterSuite.clear() + + // -- batch 1 --------------------------------------- + input.addData(5, 6, 7, 8) + query.processAllAvailable() + + 
allEvents = ForeachWriterSuite.allEvents() + assert(allEvents.size === 1) + expectedEvents = Seq( + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Process(value = 8), + ForeachWriterSuite.Close(None) + ) + assert(allEvents === Seq(expectedEvents)) + + query.stop() + } + } + + testQuietly("foreach with error") { + withTempDir { checkpointDir => + val input = MemoryStream[Int] + val query = input.toDS().repartition(1).writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .foreach(new TestForeachWriter() { + override def process(value: Int): Unit = { + super.process(value) + throw new RuntimeException("ForeachSinkSuite error") + } + }).start() + input.addData(1, 2, 3, 4) + + // Error in `process` should fail the Spark job + val e = intercept[StreamingQueryException] { + query.processAllAvailable() + } + assert(e.getCause.isInstanceOf[SparkException]) + assert(e.getCause.getCause.getCause.getMessage === "ForeachSinkSuite error") + assert(query.isActive === false) + + val allEvents = ForeachWriterSuite.allEvents() + assert(allEvents.size === 1) + assert(allEvents(0)(0) === ForeachWriterSuite.Open(partition = 0, version = 0)) + assert(allEvents(0)(1) === ForeachWriterSuite.Process(value = 1)) + + // `close` should be called with the error + val errorEvent = allEvents(0)(2).asInstanceOf[ForeachWriterSuite.Close] + assert(errorEvent.error.get.isInstanceOf[RuntimeException]) + assert(errorEvent.error.get.getMessage === "ForeachSinkSuite error") + } + } + + test("foreach with watermark: complete") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"count".as[Long]) + .map(_.toInt) + .repartition(1) + + val query = windowedAggregation + .writeStream + .outputMode(OutputMode.Complete) + .foreach(new TestForeachWriter()) + .start() + try { + inputData.addData(10, 11, 12) + query.processAllAvailable() + + val allEvents = ForeachWriterSuite.allEvents() + assert(allEvents.size === 1) + val expectedEvents = Seq( + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) + ) + assert(allEvents === Seq(expectedEvents)) + } finally { + query.stop() + } + } + + test("foreach with watermark: append") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"count".as[Long]) + .map(_.toInt) + .repartition(1) + + val query = windowedAggregation + .writeStream + .outputMode(OutputMode.Append) + .foreach(new TestForeachWriter()) + .start() + try { + inputData.addData(10, 11, 12) + query.processAllAvailable() + inputData.addData(25) // Advance watermark to 15 seconds + query.processAllAvailable() + inputData.addData(25) // Evict items less than previous watermark + query.processAllAvailable() + + // There should be 3 batches and only does the last batch contain a value. 
+ val allEvents = ForeachWriterSuite.allEvents() + assert(allEvents.size === 3) + val expectedEvents = Seq( + Seq( + ForeachWriterSuite.Open(partition = 0, version = 0), + ForeachWriterSuite.Close(None) + ), + Seq( + ForeachWriterSuite.Open(partition = 0, version = 1), + ForeachWriterSuite.Close(None) + ), + Seq( + ForeachWriterSuite.Open(partition = 0, version = 2), + ForeachWriterSuite.Process(value = 3), + ForeachWriterSuite.Close(None) + ) + ) + assert(allEvents === expectedEvents) + } finally { + query.stop() + } + } + + test("foreach sink should support metrics") { + val inputData = MemoryStream[Int] + val query = inputData.toDS() + .writeStream + .foreach(new TestForeachWriter()) + .start() + try { + inputData.addData(10, 11, 12) + query.processAllAvailable() + val recentProgress = query.recentProgress.filter(_.numInputRows != 0).headOption + assert(recentProgress.isDefined && recentProgress.get.numInputRows === 3, + s"recentProgress[${query.recentProgress.toList}] doesn't contain correct metrics") + } finally { + query.stop() + } + } +} + +/** A global object to collect events in the executor */ +object ForeachWriterSuite { + + trait Event + + case class Open(partition: Long, version: Long) extends Event + + case class Process[T](value: T) extends Event + + case class Close(error: Option[Throwable]) extends Event + + private val _allEvents = new ConcurrentLinkedQueue[Seq[Event]]() + + def addEvents(events: Seq[Event]): Unit = { + _allEvents.add(events) + } + + def allEvents(): Seq[Seq[Event]] = { + _allEvents.toArray(new Array[Seq[Event]](_allEvents.size())) + } + + def clear(): Unit = { + _allEvents.clear() + } +} + +/** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */ +class TestForeachWriter extends ForeachWriter[Int] { + ForeachWriterSuite.clear() + + private val events = mutable.ArrayBuffer[ForeachWriterSuite.Event]() + + override def open(partitionId: Long, version: Long): Boolean = { + events += ForeachWriterSuite.Open(partition = partitionId, version = version) + true + } + + override def process(value: Int): Unit = { + events += ForeachWriterSuite.Process(value) + } + + override def close(errorOrNull: Throwable): Unit = { + events += ForeachWriterSuite.Close(error = Option(errorOrNull)) + ForeachWriterSuite.addEvents(events) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/66a3a5a2/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 08749b4..20942ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader.DataReaderFactory --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
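For context when reading the patch above: the user-facing entry point is unchanged, and `DataStreamWriter.foreach(writer)` now routes through `ForeachWriterProvider` instead of `ForeachSink`. The following is a minimal usage sketch of that entry point; the `local[2]` master, the `rate` source, and the `println` output are assumptions made for illustration only, while the `ForeachWriter` callbacks mirror `TestForeachWriter` in `ForeachWriterSuite`.

```scala
// Illustrative sketch only: standalone usage of the foreach sink migrated to
// DataSourceV2 by this patch. The SparkSession setup and the `rate` source are
// assumptions; the ForeachWriter contract (open/process/close) is the existing
// public API exercised by ForeachWriterSuite.
import org.apache.spark.sql.{ForeachWriter, SparkSession}

object ForeachSinkSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("foreach-sink-sketch")
      .getOrCreate()
    import spark.implicits._

    // Any streaming source works; `rate` simply emits (timestamp, value) rows.
    val values = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "5")
      .load()
      .select($"value".as[Long])

    // With this patch, each epoch creates one ForeachDataWriter per partition:
    // open() is called once, process() once per row, and close() exactly once
    // (with the exception if process() threw, otherwise with null).
    val query = values.writeStream
      .outputMode("append")
      .foreach(new ForeachWriter[Long] {
        override def open(partitionId: Long, version: Long): Boolean = true
        override def process(value: Long): Unit = println(s"got value: $value")
        override def close(errorOrNull: Throwable): Unit = ()
      })
      .start()

    query.awaitTermination(10000)  // let the sketch run briefly
    query.stop()
    spark.stop()
  }
}
```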