This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 406d0e243cf [SPARK-40925][SQL][SS] Fix stateful operator late record filtering 406d0e243cf is described below commit 406d0e243cfec9b29df946e1a0e20ed5fe25e152 Author: Alex Balikov <91913242+alex-bali...@users.noreply.github.com> AuthorDate: Mon Oct 31 10:13:50 2022 +0900 [SPARK-40925][SQL][SS] Fix stateful operator late record filtering ### What changes were proposed in this pull request? This PR fixes the input late record filtering done by stateful operators to allow for chaining of stateful operators. Currently stateful operators are initialized with the current microbatch watermark and perform both input late record filtering and state eviction (e.g. producing aggregations) using the same watermark value. The state evicted (or aggregates produced) due to watermark advancing is behind the watermark and thus effectively late - if a following stateful operator consumes [...] This PR provides two watermark values to the stateful operators - one from the previous microbatch to be used for late record filtering and the one from the current microbatch (as in the existing code) to be used for state eviction. This solves the above problem of the broken late record filtering. Note that this PR still does not solve the issue of time-interval stream join producing records delayed against the watermark. Therefore time-interval streaming join followed by stateful operators is still not supported. That will be fixed in a follow-up PR (and a SPIP) effectively replacing the single global watermark with conceptually watermarks per operator. Also, the stateful operator chains unblocked by this PR (e.g. a chain of window aggregations) are still blocked by the unsupported operations checker. The new test for these scenarios - MultiStatefulOperatorsSuite has to explicitly disable the unsupported ops check. This again will be fixed in a follow-up PR. ### Why are the changes needed?
The PR allows Spark Structured Streaming to support chaining of stateful operators e.g. chaining of time window aggregations which is a meaningful streaming scenario. ### Does this PR introduce _any_ user-facing change? With this PR, chains of stateful operators will be supported in Spark Structured Streaming. ### How was this patch tested? Added a new test suite - MultiStatefulOperatorsSuite Closes #38405 from alex-balikov/multiple_stateful-ops-base. Lead-authored-by: Alex Balikov <91913242+alex-bali...@users.noreply.github.com> Co-authored-by: Alex Balikov <alex.bali...@databricks.com> Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> --- .../sql/catalyst/expressions/SessionWindow.scala | 24 +- .../sql/catalyst/expressions/TimeWindow.scala | 20 +- .../org/apache/spark/sql/internal/SQLConf.scala | 16 + .../spark/sql/execution/QueryExecution.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 6 +- .../spark/sql/execution/aggregate/AggUtils.scala | 9 +- .../FlatMapGroupsInPandasWithStateExec.scala | 13 +- .../streaming/FlatMapGroupsWithStateExec.scala | 27 +- .../execution/streaming/IncrementalExecution.scala | 44 ++- .../execution/streaming/MicroBatchExecution.scala | 1 + .../sql/execution/streaming/OffsetSeqLog.scala | 4 + .../streaming/StreamingSymmetricHashJoinExec.scala | 15 +- .../StreamingSymmetricHashJoinHelper.scala | 6 +- .../streaming/continuous/ContinuousExecution.scala | 1 + .../streaming/sources/MicroBatchWrite.scala | 2 +- .../execution/streaming/statefulOperators.scala | 108 +++-- .../streaming/FlatMapGroupsWithStateSuite.scala | 2 +- .../streaming/MultiStatefulOperatorsSuite.scala | 440 +++++++++++++++++++++ .../spark/sql/streaming/StreamingJoinSuite.scala | 2 +- 19 files changed, 661 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala index 
77e8dfde87b..02273b0c461 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala @@ -68,11 +68,29 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Expression) extend with Unevaluable with NonSQLExpression { + private def inputTypeOnTimeColumn: AbstractDataType = { + TypeCollection( + AnyTimestampType, + // Below two types cover both time window & session window, since they produce the same type + // of output as window column. + new StructType() + .add(StructField("start", TimestampType)) + .add(StructField("end", TimestampType)), + new StructType() + .add(StructField("start", TimestampNTZType)) + .add(StructField("end", TimestampNTZType)) + ) + } + + // NOTE: if the window column is given as a time column, we resolve it to the point of time, + // which resolves to either TimestampType or TimestampNTZType. That means, timeColumn may not + // be "resolved", so it is safe to not rely on the data type of timeColumn directly. + override def children: Seq[Expression] = Seq(timeColumn, gapDuration) - override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType, AnyDataType) + override def inputTypes: Seq[AbstractDataType] = Seq(inputTypeOnTimeColumn, AnyDataType) override def dataType: DataType = new StructType() - .add(StructField("start", timeColumn.dataType)) - .add(StructField("end", timeColumn.dataType)) + .add(StructField("start", children.head.dataType)) + .add(StructField("end", children.head.dataType)) // This expression is replaced in the analyzer. 
override lazy val resolved = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 93c1074dfbe..bc9b7de7464 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -96,8 +96,26 @@ case class TimeWindow( this(timeColumn, windowDuration, windowDuration) } + private def inputTypeOnTimeColumn: AbstractDataType = { + TypeCollection( + AnyTimestampType, + // Below two types cover both time window & session window, since they produce the same type + // of output as window column. + new StructType() + .add(StructField("start", TimestampType)) + .add(StructField("end", TimestampType)), + new StructType() + .add(StructField("start", TimestampNTZType)) + .add(StructField("end", TimestampNTZType)) + ) + } + + // NOTE: if the window column is given as a time column, we resolve it to the point of time, + // which resolves to either TimestampType or TimestampNTZType. That means, timeColumn may not + // be "resolved", so it is safe to not rely on the data type of timeColumn directly. 
+ override def child: Expression = timeColumn - override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType) + override def inputTypes: Seq[AbstractDataType] = Seq(inputTypeOnTimeColumn) override def dataType: DataType = new StructType() .add(StructField("start", child.dataType)) .add(StructField("end", child.dataType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0a60c6b0265..3854f3190a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1941,6 +1941,22 @@ object SQLConf { .booleanConf .createWithDefault(true) + val STATEFUL_OPERATOR_ALLOW_MULTIPLE = + buildConf("spark.sql.streaming.statefulOperator.allowMultiple") + .internal() + .doc("When true, multiple stateful operators are allowed to be present in a streaming " + + "pipeline. The support for multiple stateful operators introduces a minor (semantically " + + "correct) change in respect to late record filtering - late records are detected and " + + "filtered in respect to the watermark from the previous microbatch instead of the " + + "current one. This is a behavior change for Spark streaming pipelines and we allow " + + "users to revert to the previous behavior of late record filtering (late records are " + + "detected and filtered by comparing with the current microbatch watermark) by setting " + + "the flag value to false. 
In this mode, only a single stateful operator will be allowed " + + "in a streaming pipeline.") + .version("3.4.0") + .booleanConf + .createWithDefault(true) + val STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION = buildConf("spark.sql.streaming.statefulOperator.useStrictDistribution") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 10b763b1b51..8bf5d3d317b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -229,7 +229,7 @@ class QueryExecution( // output mode does not matter since there is no `Sink`. new IncrementalExecution( sparkSession, logical, OutputMode.Append(), "<unknown>", - UUID.randomUUID, UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) + UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0)) } else { this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 03e722a86fb..b96e47846fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -678,7 +678,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val execPlan = FlatMapGroupsWithStateExec( func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, dataAttr, sda, outputAttr, None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs = None, - eventTimeWatermark = None, planLater(initialState), hasInitialState, planLater(child) + eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, + planLater(initialState), hasInitialState, planLater(child) ) execPlan :: Nil case _ => @@ -697,7 +698,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] 
{ val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) val execPlan = python.FlatMapGroupsInPandasWithStateExec( func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, - batchTimestampMs = None, eventTimeWatermark = None, planLater(child) + batchTimestampMs = None, eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None, planLater(child) ) execPlan :: Nil case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 579a00c7996..557f0e897ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -369,7 +369,8 @@ object AggUtils { groupingAttributes, stateInfo = None, outputMode = None, - eventTimeWatermark = None, + eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None, stateFormatVersion = stateFormatVersion, partialMerged2) @@ -472,7 +473,8 @@ object AggUtils { // shuffle & sort happens here: most of details are also handled in this physical plan val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes, - sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, + sessionExpression.toAttribute, stateInfo = None, + eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, stateFormatVersion, partialMerged1) val mergedSessions = { @@ -501,7 +503,8 @@ object AggUtils { sessionExpression.toAttribute, stateInfo = None, outputMode = None, - eventTimeWatermark = None, + eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None, stateFormatVersion, mergedSessions) val finalAndCompleteAggregate: SparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 3b096f07241..bc1a5ae17e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -48,7 +48,8 @@ import org.apache.spark.util.CompletionIterator * @param outputMode the output mode of `functionExpr` * @param timeoutConf used to timeout groups that have not received data in a while * @param batchTimestampMs processing timestamp of the current batch. - * @param eventTimeWatermark event time watermark for the current batch + * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events + * @param eventTimeWatermarkForEviction event time watermark for state eviction * @param child logical plan of the underlying data */ case class FlatMapGroupsInPandasWithStateExec( @@ -61,9 +62,9 @@ case class FlatMapGroupsInPandasWithStateExec( outputMode: OutputMode, timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], - eventTimeWatermark: Option[Long], - child: SparkPlan) - extends UnaryExecNode with PythonSQLMetrics with FlatMapGroupsWithStateExecBase { + eventTimeWatermarkForLateEvents: Option[Long], + eventTimeWatermarkForEviction: Option[Long], + child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { // TODO(SPARK-40444): Add the support of initial state. 
override protected val initialStateDeserializer: Expression = null @@ -132,7 +133,7 @@ case class FlatMapGroupsInPandasWithStateExec( if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get - case EventTimeTimeout => eventTimeWatermark.get + case EventTimeTimeout => eventTimeWatermarkForEviction.get case _ => throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") @@ -176,7 +177,7 @@ case class FlatMapGroupsInPandasWithStateExec( val groupedState = GroupStateImpl.createForStreaming( Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r }, batchTimestampMs.getOrElse(NO_TIMESTAMP), - eventTimeWatermark.getOrElse(NO_TIMESTAMP), + eventTimeWatermarkForEviction.getOrElse(NO_TIMESTAMP), timeoutConf, hasTimedOut = hasTimedOut, watermarkPresent).asInstanceOf[GroupStateImpl[Row]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 790a652f211..138029e76c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -54,8 +54,8 @@ trait FlatMapGroupsWithStateExecBase protected val outputMode: OutputMode protected val timeoutConf: GroupStateTimeout protected val batchTimestampMs: Option[Long] - val eventTimeWatermark: Option[Long] - + val eventTimeWatermarkForLateEvents: Option[Long] + val eventTimeWatermarkForEviction: Option[Long] protected val isTimeoutEnabled: Boolean = timeoutConf != NoTimeout protected val watermarkPresent: Boolean = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true @@ -96,7 +96,8 @@ trait FlatMapGroupsWithStateExecBase true // Always run batches to process timeouts case 
EventTimeTimeout => // Process another non-data batch only if the watermark has changed in this executed plan - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermarkForEviction.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get case _ => false } @@ -125,7 +126,7 @@ trait FlatMapGroupsWithStateExecBase var timeoutProcessingStartTimeNs = currentTimeNs // If timeout is based on event time, then filter late data based on watermark - val filteredIter = watermarkPredicateForData match { + val filteredIter = watermarkPredicateForDataForLateEvents match { case Some(predicate) if timeoutConf == EventTimeTimeout => applyRemovingRowsOlderThanWatermark(iter, predicate) case _ => @@ -189,8 +190,12 @@ trait FlatMapGroupsWithStateExecBase case ProcessingTimeTimeout => require(batchTimestampMs.nonEmpty) case EventTimeTimeout => - require(eventTimeWatermark.nonEmpty) // watermark value has been populated - require(watermarkExpression.nonEmpty) // input schema has watermark attribute + // watermark value has been populated + require(eventTimeWatermarkForLateEvents.nonEmpty) + require(eventTimeWatermarkForEviction.nonEmpty) + // input schema has watermark attribute + require(watermarkExpressionForLateEvents.nonEmpty) + require(watermarkExpressionForEviction.nonEmpty) case _ => } @@ -310,7 +315,7 @@ trait FlatMapGroupsWithStateExecBase if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get - case EventTimeTimeout => eventTimeWatermark.get + case EventTimeTimeout => eventTimeWatermarkForEviction.get case _ => throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") @@ -354,7 +359,8 @@ trait FlatMapGroupsWithStateExecBase * @param outputMode the output mode of `func` * @param timeoutConf used to timeout groups that have not received data in a while * @param batchTimestampMs processing timestamp of the current batch. 
- * @param eventTimeWatermark event time watermark for the current batch + * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events + * @param eventTimeWatermarkForEviction event time watermark for state eviction * @param initialState the user specified initial state * @param hasInitialState indicates whether the initial state is provided or not * @param child the physical plan for the underlying data @@ -375,7 +381,8 @@ case class FlatMapGroupsWithStateExec( outputMode: OutputMode, timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], - eventTimeWatermark: Option[Long], + eventTimeWatermarkForLateEvents: Option[Long], + eventTimeWatermarkForEviction: Option[Long], initialState: SparkPlan, hasInitialState: Boolean, child: SparkPlan) @@ -410,7 +417,7 @@ case class FlatMapGroupsWithStateExec( val groupState = GroupStateImpl.createForStreaming( Option(stateData.stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), - eventTimeWatermark.getOrElse(NO_TIMESTAMP), + eventTimeWatermarkForEviction.getOrElse(NO_TIMESTAMP), timeoutConf, hasTimedOut, watermarkPresent) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index f386282a0b3..574709d05b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -48,6 +48,7 @@ class IncrementalExecution( val queryId: UUID, val runId: UUID, val currentBatchId: Long, + val prevOffsetSeqMetadata: Option[OffsetSeqMetadata], val offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { @@ -112,6 +113,17 @@ class IncrementalExecution( numStateStores) } + // Watermarks to use for late record filtering and state eviction in stateful operators. 
+ // Using the previous watermark for late record filtering is a Spark behavior change so we allow + // this to be disabled. + val eventTimeWatermarkForEviction = offsetSeqMetadata.batchWatermarkMs + val eventTimeWatermarkForLateEvents = + if (sparkSession.conf.get(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)) { + prevOffsetSeqMetadata.getOrElse(offsetSeqMetadata).batchWatermarkMs + } else { + eventTimeWatermarkForEviction + } + /** Locates save/restore pairs surrounding aggregation. */ val state = new Rule[SparkPlan] { @@ -158,7 +170,7 @@ class IncrementalExecution( case a: UpdatingSessionsExec if a.isStreaming => a.copy(numShufflePartitions = Some(numStateStores)) - case StateStoreSaveExec(keys, None, None, None, stateFormatVersion, + case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion, UnaryExecNode(agg, StateStoreRestoreExec(_, None, _, child))) => val aggStateInfo = nextStatefulOperationStateInfo @@ -166,7 +178,8 @@ class IncrementalExecution( keys, Some(aggStateInfo), Some(outputMode), - Some(offsetSeqMetadata.batchWatermarkMs), + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), stateFormatVersion, agg.withNewChildren( StateStoreRestoreExec( @@ -175,32 +188,36 @@ class IncrementalExecution( stateFormatVersion, child) :: Nil)) - case SessionWindowStateStoreSaveExec(keys, session, None, None, None, stateFormatVersion, + case SessionWindowStateStoreSaveExec(keys, session, None, None, None, None, + stateFormatVersion, UnaryExecNode(agg, - SessionWindowStateStoreRestoreExec(_, _, None, None, _, child))) => + SessionWindowStateStoreRestoreExec(_, _, None, None, None, _, child))) => val aggStateInfo = nextStatefulOperationStateInfo SessionWindowStateStoreSaveExec( keys, session, Some(aggStateInfo), Some(outputMode), - Some(offsetSeqMetadata.batchWatermarkMs), + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + 
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), stateFormatVersion, agg.withNewChildren( SessionWindowStateStoreRestoreExec( keys, session, Some(aggStateInfo), - Some(offsetSeqMetadata.batchWatermarkMs), + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), stateFormatVersion, child) :: Nil)) - case StreamingDeduplicateExec(keys, child, None, None) => + case StreamingDeduplicateExec(keys, child, None, None, None) => StreamingDeduplicateExec( keys, child, Some(nextStatefulOperationStateInfo), - Some(offsetSeqMetadata.batchWatermarkMs)) + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction)) case m: FlatMapGroupsWithStateExec => // We set this to true only for the first batch of the streaming query. @@ -208,7 +225,8 @@ class IncrementalExecution( m.copy( stateInfo = Some(nextStatefulOperationStateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), - eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs), + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), hasInitialState = hasInitialState ) @@ -216,17 +234,19 @@ class IncrementalExecution( m.copy( stateInfo = Some(nextStatefulOperationStateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), - eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs) + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction) ) case j: StreamingSymmetricHashJoinExec => j.copy( stateInfo = Some(nextStatefulOperationStateInfo), - eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs), + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + eventTimeWatermarkForEviction = 
Some(eventTimeWatermarkForLateEvents), stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, - Some(offsetSeqMetadata.batchWatermarkMs))) + Some(eventTimeWatermarkForEviction))) case l: StreamingGlobalLimitExec => l.copy( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 5f8fb93827b..7ed19b35114 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -695,6 +695,7 @@ class MicroBatchExecution( id, runId, currentBatchId, + offsetLog.offsetSeqMetadataForBatchId(currentBatchId - 1), offsetSeqMetadata) lastExecution.executedPlan // Force the lazy generation of execution plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index 82e50263893..7f00717ea4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -102,6 +102,10 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) } } } + + def offsetSeqMetadataForBatchId(batchId: Long): Option[OffsetSeqMetadata] = { + if (batchId < 0) None else get(batchId).flatMap(_.metadata) + } } object OffsetSeqLog { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 4a8f3b18c09..dfde4156812 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -119,7 +119,8 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} * @param condition Conditions to filter rows, split by left, right, and joined. See * [[JoinConditionSplitPredicates]] * @param stateInfo Version information required to read join state (buffered rows) - * @param eventTimeWatermark Watermark of input event, same for both sides + * @param eventTimeWatermarkForLateEvents Watermark for filtering late events, same for both sides + * @param eventTimeWatermarkForEviction Watermark for state eviction * @param stateWatermarkPredicates Predicates for removal of state, see * [[JoinStateWatermarkPredicates]] * @param left Left child plan @@ -131,7 +132,8 @@ case class StreamingSymmetricHashJoinExec( joinType: JoinType, condition: JoinConditionSplitPredicates, stateInfo: Option[StatefulOperatorStateInfo], - eventTimeWatermark: Option[Long], + eventTimeWatermarkForLateEvents: Option[Long], + eventTimeWatermarkForEviction: Option[Long], stateWatermarkPredicates: JoinStateWatermarkPredicates, stateFormatVersion: Int, left: SparkPlan, @@ -148,7 +150,8 @@ case class StreamingSymmetricHashJoinExec( this( leftKeys, rightKeys, joinType, JoinConditionSplitPredicates(condition, left, right), - stateInfo = None, eventTimeWatermark = None, + stateInfo = None, + eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, stateWatermarkPredicates = JoinStateWatermarkPredicates(), stateFormatVersion, left, right) } @@ -222,7 +225,8 @@ case class StreamingSymmetricHashJoinExec( // Latest watermark value is more than that used in this previous executed plan val watermarkHasChanged = - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermarkForEviction.isDefined && + 
newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get watermarkUsedForStateCleanup && watermarkHasChanged } @@ -555,7 +559,8 @@ case class StreamingSymmetricHashJoinExec( val watermarkAttribute = inputAttributes.find(_.metadata.contains(delayKey)) val nonLateRows = - WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match { + WatermarkSupport.watermarkExpression( + watermarkAttribute, eventTimeWatermarkForLateEvents) match { case Some(watermarkExpr) => val predicate = Predicate.create(watermarkExpr, inputAttributes) applyRemovingRowsOlderThanWatermark(inputIter, predicate) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 2f62dbd7ec5..7bf6381e08f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -137,7 +137,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { leftKeys: Seq[Expression], rightKeys: Seq[Expression], condition: Option[Expression], - eventTimeWatermark: Option[Long]): JoinStateWatermarkPredicates = { + eventTimeWatermarkForEviction: Option[Long]): JoinStateWatermarkPredicates = { // Join keys of both sides generate rows of the same fields, that is, same sequence of data @@ -172,7 +172,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { joinKeyOrdinalForWatermark.get, oneSideJoinKeys(joinKeyOrdinalForWatermark.get).dataType, oneSideJoinKeys(joinKeyOrdinalForWatermark.get).nullable) - val expr = watermarkExpression(Some(keyExprWithWatermark), eventTimeWatermark) + val expr = watermarkExpression(Some(keyExprWithWatermark), eventTimeWatermarkForEviction) expr.map(JoinStateKeyWatermarkPredicate.apply _) } else if (isWatermarkDefinedOnInput) { // 
case 2 in the StreamingSymmetricHashJoinExec docs @@ -180,7 +180,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes), attributesWithEventWatermark = AttributeSet(otherSideInputAttributes), condition, - eventTimeWatermark) + eventTimeWatermarkForEviction) val inputAttributeWithWatermark = oneSideInputAttributes.find(_.metadata.contains(delayKey)) val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark) expr.map(JoinStateValueWatermarkPredicate.apply _) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 5b620eec25f..e8092e072bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -218,6 +218,7 @@ class ContinuousExecution( id, runId, currentBatchId, + None, offsetSeqMetadata) lastExecution.executedPlan // Force the lazy generation of execution plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala index 0a603a3b141..3f474ea533c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactor */ class MicroBatchWrite(epochId: Long, val writeSupport: StreamingWrite) extends BatchWrite { override def toString: String = { - s"MicroBathWrite[epoch: $epochId, writer: $writeSupport]" + s"MicroBatchWrite[epoch: $epochId, writer: 
$writeSupport]" } override def commit(messages: Array[WriterCommitMessage]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index b540f9f0093..457e5f80ae6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -212,34 +212,73 @@ trait WatermarkSupport extends SparkPlan { /** The keys that may have a watermark attribute. */ def keyExpressions: Seq[Attribute] - /** The watermark value. */ - def eventTimeWatermark: Option[Long] + /** + * The watermark value for filtering late events/records. This should be the previous + * batch state eviction watermark. + */ + def eventTimeWatermarkForLateEvents: Option[Long] + /** + * The watermark value for closing aggregates and evicting state. + * It is different from the late events filtering watermark (consider chained aggregators + * agg1 -> agg2: agg1 evicts state which will be effectively late against the eviction watermark + * but should not be late for agg2 input late record filtering watermark. Thus agg1 and agg2 use + * the current batch watermark for state eviction but the previous batch watermark for late + * record filtering. 
+ */ + def eventTimeWatermarkForEviction: Option[Long] + + /** Generate an expression that matches data older than late event filtering watermark */ + lazy val watermarkExpressionForLateEvents: Option[Expression] = + watermarkExpression(eventTimeWatermarkForLateEvents) + /** Generate an expression that matches data older than the state eviction watermark */ + lazy val watermarkExpressionForEviction: Option[Expression] = + watermarkExpression(eventTimeWatermarkForEviction) /** Generate an expression that matches data older than the watermark */ - lazy val watermarkExpression: Option[Expression] = { + private def watermarkExpression(watermark: Option[Long]): Option[Expression] = { WatermarkSupport.watermarkExpression( - child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)), - eventTimeWatermark) + child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)), watermark) } - /** Predicate based on keys that matches data older than the watermark */ - lazy val watermarkPredicateForKeys: Option[BasePredicate] = watermarkExpression.flatMap { e => - if (keyExpressions.exists(_.metadata.contains(EventTimeWatermark.delayKey))) { - Some(Predicate.create(e, keyExpressions)) - } else { - None + /** Predicate based on keys that matches data older than the late event filtering watermark */ + lazy val watermarkPredicateForKeysForLateEvents: Option[BasePredicate] = + watermarkPredicateForKeys(watermarkExpressionForLateEvents) + + /** Generate an expression that matches data older than the state eviction watermark */ + lazy val watermarkPredicateForKeysForEviction: Option[BasePredicate] = + watermarkPredicateForKeys(watermarkExpressionForEviction) + + private def watermarkPredicateForKeys( + watermarkExpression: Option[Expression]): Option[BasePredicate] = { + watermarkExpression.flatMap { e => + if (keyExpressions.exists(_.metadata.contains(EventTimeWatermark.delayKey))) { + Some(Predicate.create(e, keyExpressions)) + } else { + None + } } } - /** Predicate 
based on the child output that matches data older than the watermark. */ - lazy val watermarkPredicateForData: Option[BasePredicate] = + /** + * Predicate based on the child output that matches data older than the watermark for late events + * filtering. + */ + lazy val watermarkPredicateForDataForLateEvents: Option[BasePredicate] = + watermarkPredicateForData(watermarkExpressionForLateEvents) + + lazy val watermarkPredicateForDataForEviction: Option[BasePredicate] = + watermarkPredicateForData(watermarkExpressionForEviction) + + private def watermarkPredicateForData( + watermarkExpression: Option[Expression]): Option[BasePredicate] = { watermarkExpression.map(Predicate.create(_, child.output)) + } protected def removeKeysOlderThanWatermark(store: StateStore): Unit = { - if (watermarkPredicateForKeys.nonEmpty) { + if (watermarkPredicateForKeysForEviction.nonEmpty) { val numRemovedStateRows = longMetric("numRemovedStateRows") store.iterator().foreach { rowPair => - if (watermarkPredicateForKeys.get.eval(rowPair.key)) { + if (watermarkPredicateForKeysForEviction.get.eval(rowPair.key)) { store.remove(rowPair.key) numRemovedStateRows += 1 } @@ -250,10 +289,10 @@ trait WatermarkSupport extends SparkPlan { protected def removeKeysOlderThanWatermark( storeManager: StreamingAggregationStateManager, store: StateStore): Unit = { - if (watermarkPredicateForKeys.nonEmpty) { + if (watermarkPredicateForKeysForEviction.nonEmpty) { val numRemovedStateRows = longMetric("numRemovedStateRows") storeManager.keys(store).foreach { keyRow => - if (watermarkPredicateForKeys.get.eval(keyRow)) { + if (watermarkPredicateForKeysForEviction.get.eval(keyRow)) { storeManager.remove(store, keyRow) numRemovedStateRows += 1 } @@ -354,7 +393,8 @@ case class StateStoreSaveExec( keyExpressions: Seq[Attribute], stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, - eventTimeWatermark: Option[Long] = None, + eventTimeWatermarkForLateEvents: Option[Long] = None, + 
eventTimeWatermarkForEviction: Option[Long] = None, stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { @@ -407,7 +447,7 @@ case class StateStoreSaveExec( case Some(Append) => allUpdatesTimeMs += timeTakenMs { val filteredIter = applyRemovingRowsOlderThanWatermark(iter, - watermarkPredicateForData.get) + watermarkPredicateForDataForLateEvents.get) while (filteredIter.hasNext) { val row = filteredIter.next().asInstanceOf[UnsafeRow] stateManager.put(store, row) @@ -423,7 +463,7 @@ case class StateStoreSaveExec( var removedValueRow: InternalRow = null while(rangeIter.hasNext && removedValueRow == null) { val rowPair = rangeIter.next() - if (watermarkPredicateForKeys.get.eval(rowPair.key)) { + if (watermarkPredicateForKeysForEviction.get.eval(rowPair.key)) { stateManager.remove(store, rowPair.key) numRemovedStateRows += 1 removedValueRow = rowPair.value @@ -453,7 +493,7 @@ case class StateStoreSaveExec( new NextIterator[InternalRow] { // Filter late date using watermark if specified - private[this] val baseIterator = watermarkPredicateForData match { + private[this] val baseIterator = watermarkPredicateForDataForLateEvents match { case Some(predicate) => applyRemovingRowsOlderThanWatermark(iter, predicate) case None => iter } @@ -507,8 +547,8 @@ case class StateStoreSaveExec( override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { (outputMode.contains(Append) || outputMode.contains(Update)) && - eventTimeWatermark.isDefined && - newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermarkForEviction.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get } override protected def withNewChildInternal(newChild: SparkPlan): StateStoreSaveExec = @@ -525,7 +565,8 @@ case class SessionWindowStateStoreRestoreExec( keyWithoutSessionExpressions: Seq[Attribute], sessionExpression: Attribute, stateInfo: Option[StatefulOperatorStateInfo], - eventTimeWatermark: 
Option[Long], + eventTimeWatermarkForLateEvents: Option[Long], + eventTimeWatermarkForEviction: Option[Long], stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreReader with WatermarkSupport { @@ -555,7 +596,7 @@ case class SessionWindowStateStoreRestoreExec( Some(session.streams.stateStoreCoordinator)) { case (store, iter) => // We need to filter out outdated inputs - val filteredIterator = watermarkPredicateForData match { + val filteredIterator = watermarkPredicateForDataForLateEvents match { case Some(predicate) => iter.filter((row: InternalRow) => { val shouldKeep = !predicate.eval(row) if (!shouldKeep) longMetric("numRowsDroppedByWatermark") += 1 @@ -611,7 +652,8 @@ case class SessionWindowStateStoreSaveExec( sessionExpression: Attribute, stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, - eventTimeWatermark: Option[Long] = None, + eventTimeWatermarkForLateEvents: Option[Long] = None, + eventTimeWatermarkForEviction: Option[Long] = None, stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { @@ -667,7 +709,7 @@ case class SessionWindowStateStoreSaveExec( val removalStartTimeNs = System.nanoTime new NextIterator[InternalRow] { private val removedIter = stateManager.removeByValueCondition( - store, watermarkPredicateForData.get.eval) + store, watermarkPredicateForDataForEviction.get.eval) override protected def getNext(): InternalRow = { if (!removedIter.hasNext) { @@ -704,8 +746,8 @@ case class SessionWindowStateStoreSaveExec( override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { (outputMode.contains(Append) || outputMode.contains(Update)) && - eventTimeWatermark.isDefined && - newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermarkForEviction.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get } private def putToStore(iter: Iterator[InternalRow], store: 
StateStore): Unit = { @@ -775,7 +817,8 @@ case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], child: SparkPlan, stateInfo: Option[StatefulOperatorStateInfo] = None, - eventTimeWatermark: Option[Long] = None) + eventTimeWatermarkForLateEvents: Option[Long] = None, + eventTimeWatermarkForEviction: Option[Long] = None) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { /** Distribute by grouping attributes */ @@ -807,7 +850,7 @@ case class StreamingDeduplicateExec( val commitTimeMs = longMetric("commitTimeMs") val numDroppedDuplicateRows = longMetric("numDroppedDuplicateRows") - val baseIterator = watermarkPredicateForData match { + val baseIterator = watermarkPredicateForDataForLateEvents match { case Some(predicate) => applyRemovingRowsOlderThanWatermark(iter, predicate) case None => iter } @@ -851,7 +894,8 @@ case class StreamingDeduplicateExec( override def shortName: String = "dedupe" override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermarkForEviction.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get } override protected def withNewChildInternal(newChild: SparkPlan): StreamingDeduplicateExec = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 14f083bbd30..49f4214ac1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -1048,7 +1048,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { hasInitialState, sga, sda, se, i, c) => FlatMapGroupsWithStateExec( f, k, v, se, g, sga, d, sda, o, None, s, stateFormatVersion, m, t, - Some(currentBatchTimestamp), 
Some(currentBatchWatermark), + Some(currentBatchTimestamp), Some(0), Some(currentBatchWatermark), RDDScanExec(g, emptyRdd, "rdd"), hasInitialState, RDDScanExec(g, emptyRdd, "rdd")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala new file mode 100644 index 00000000000..0a3ea40a677 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.functions._ + +// Tests for the multiple stateful operators support. 
+class MultiStatefulOperatorsSuite + extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { + + import testImplicits._ + + before { + SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' + spark.streams.stateStoreCoordinator // initialize the lazy coordinator + } + + after { + StateStore.stop() + } + + test("window agg -> window agg, append mode") { + // TODO: SPARK-40940 - Fix the unsupported ops checker to allow chaining of stateful ops. + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val inputData = MemoryStream[Int] + + val stream = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "0 seconds") + .groupBy(window($"eventTime", "5 seconds").as("window")) + .agg(count("*").as("count")) + .groupBy(window($"window", "10 seconds")) + .agg(count("*").as("count"), sum("count").as("sum")) + .select($"window".getField("start").cast("long").as[Long], + $"count".as[Long], $"sum".as[Long]) + + testStream(stream)( + AddData(inputData, 10 to 21: _*), + // op1 W (0, 0) + // agg: [10, 15) 5, [15, 20) 5, [20, 25) 2 + // output: None + // state: [10, 15) 5, [15, 20) 5, [20, 25) 2 + // op2 W (0, 0) + // agg: None + // output: None + // state: None + + // no-data batch triggered + + // op1 W (0, 21) + // agg: None + // output: [10, 15) 5, [15, 20) 5 + // state: [20, 25) 2 + // op2 W (0, 21) + // agg: [10, 20) (2, 10) + // output: [10, 20) (2, 10) + // state: None + CheckNewAnswer((10, 2, 10)), + assertNumStateRows(Seq(0, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0)), + + AddData(inputData, 10 to 29: _*), + // op1 W (21, 21) + // agg: [10, 15) 5 - late, [15, 20) 5 - late, [20, 25) 5, [25, 30) 5 + // output: None + // state: [20, 25) 7, [25, 30) 5 + // op2 W (21, 21) + // agg: None + // output: None + // state: None + + // no-data batch triggered + + // op1 W (21, 29) + // agg: None + // output: [20, 25) 7 + // state: [25, 30) 5 + // op2 W (21, 29) + // 
agg: [20, 30) (1, 7) + // output: None + // state: [20, 30) (1, 7) + CheckNewAnswer(), + assertNumStateRows(Seq(1, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 2)), + + // Move the watermark. + AddData(inputData, 30, 31), + // op1 W (29, 29) + // agg: [30, 35) 2 + // output: None + // state: [25, 30) 5 [30, 35) 2 + // op2 W (29, 29) + // agg: None + // output: None + // state: [20, 30) (1, 7) + + // no-data batch triggered + + // op1 W (29, 31) + // agg: None + // output: [25, 30) 5 + // state: [30, 35) 2 + // op2 W (29, 31) + // agg: [20, 30) (2, 12) + // output: [20, 30) (2, 12) + // state: None + CheckNewAnswer((20, 2, 12)), + assertNumStateRows(Seq(0, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) + } + } + + test("agg -> agg -> agg, append mode") { + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val inputData = MemoryStream[Int] + + val stream = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "0 seconds") + .groupBy(window($"eventTime", "5 seconds").as("window")) + .agg(count("*").as("count")) + .groupBy(window(window_time($"window"), "10 seconds")) + .agg(count("*").as("count"), sum("count").as("sum")) + .groupBy(window(window_time($"window"), "20 seconds")) + .agg(count("*").as("count"), sum("sum").as("sum")) + .select( + $"window".getField("start").cast("long").as[Long], + $"window".getField("end").cast("long").as[Long], + $"count".as[Long], $"sum".as[Long]) + + testStream(stream)( + AddData(inputData, 0 to 37: _*), + // op1 W (0, 0) + // agg: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, [25, 30) 5, [30, 35) 5, + // [35, 40) 3 + // output: None + // state: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, [25, 30) 5, [30, 35) 5, + // [35, 40) 3 + // op2 W (0, 0) + // agg: None + // output: None + // state: None + // op3 W (0, 0) + // agg: None + // output: None + // state: None + + // no-data batch triggered + + // op1 W (0, 37) + // agg: None + 
// output: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, [25, 30) 5, [30, 35) 5 + // state: [35, 40) 3 + // op2 W (0, 37) + // agg: [0, 10) (2, 10), [10, 20) (2, 10), [20, 30) (2, 10), [30, 40) (1, 5) + // output: [0, 10) (2, 10), [10, 20) (2, 10), [20, 30) (2, 10) + // state: [30, 40) (1, 5) + // op3 W (0, 37) + // agg: [0, 20) (2, 20), [20, 40) (1, 10) + // output: [0, 20) (2, 20) + // state: [20, 40) (1, 10) + CheckNewAnswer((0, 20, 2, 20)), + assertNumStateRows(Seq(1, 1, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0, 0)), + + AddData(inputData, 30 to 60: _*), + // op1 W (37, 37) + // dropped rows: [30, 35), 1 row <= note that 35, 36, 37 are still in effect + // agg: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5, [60, 65) 1 + // output: None + // state: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5, [60, 65) 1 + // op2 W (37, 37) + // output: None + // state: [30, 40) (1, 5) + // op3 W (37, 37) + // output: None + // state: [20, 40) (1, 10) + + // no-data batch + // op1 W (37, 60) + // output: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5 + // state: [60, 65) 1 + // op2 W (37, 60) + // agg: [30, 40) (2, 13), [40, 50) (2, 10), [50, 60), (2, 10) + // output: [30, 40) (2, 13), [40, 50) (2, 10), [50, 60), (2, 10) + // state: None + // op3 W (37, 60) + // agg: [20, 40) (2, 23), [40, 60) (2, 20) + // output: [20, 40) (2, 23), [40, 60) (2, 20) + // state: None + + CheckNewAnswer((20, 40, 2, 23), (40, 60, 2, 20)), + assertNumStateRows(Seq(0, 0, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0, 1)) + ) + } + } + + test("stream deduplication -> aggregation, append mode") { + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val inputData = MemoryStream[Int] + + val deduplication = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates("value", "eventTime") + + val windowedAggregation = deduplication + 
.groupBy(window($"eventTime", "5 seconds").as("window")) + .agg(count("*").as("count"), sum("value").as("sum")) + .select($"window".getField("start").cast("long").as[Long], + $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 1 to 15: _*), + // op1 W (0, 0) + // input: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + // deduplicated: None + // output: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + // state: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + // op2 W (0, 0) + // agg: [0, 5) 4, [5, 10) 5 [10, 15) 5, [15, 20) 1 + // output: None + // state: [0, 5) 4, [5, 10) 5 [10, 15) 5, [15, 20) 1 + + // no-data batch triggered + + // op1 W (0, 5) + // agg: None + // output: None + // state: 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + // op2 W (0, 5) + // agg: None + // output: [0, 5) 4 + // state: [5, 10) 5 [10, 15) 5, [15, 20) 1 + CheckNewAnswer((0, 4)), + assertNumStateRows(Seq(3, 10)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) + } + } + + test("join -> window agg, append mode") { + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val input1 = MemoryStream[Int] + val inputDF1 = input1.toDF + .withColumnRenamed("value", "value1") + .withColumn("eventTime1", timestamp_seconds($"value1")) + .withWatermark("eventTime1", "0 seconds") + + val input2 = MemoryStream[Int] + val inputDF2 = input2.toDF + .withColumnRenamed("value", "value2") + .withColumn("eventTime2", timestamp_seconds($"value2")) + .withWatermark("eventTime2", "0 seconds") + + val stream = inputDF1.join(inputDF2, expr("eventTime1 = eventTime2"), "inner") + .groupBy(window($"eventTime1", "5 seconds").as("window")) + .agg(count("*").as("count")) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(stream)( + MultiAddData(input1, 1 to 4: _*)(input2, 1 to 4: _*), + + // op1 W (0, 0) + // join output: (1, 1), (2, 2), (3, 3), (4, 4) + // state: (1, 1), (2, 2), (3, 3), (4, 4) + // op2 W (0, 0) + // agg: [0, 5) 
4 + // output: None + // state: [0, 5) 4 + + // no-data batch triggered + + // op1 W (0, 4) + // join output: None + // state: None + // op2 W (0, 4) + // agg: None + // output: None + // state: [0, 5) 4 + CheckNewAnswer(), + assertNumStateRows(Seq(1, 0)), + assertNumRowsDroppedByWatermark(Seq(0, 0)), + + // Move the watermark + MultiAddData(input1, 5)(input2, 5), + + // op1 W (4, 4) + // join output: (5, 5) + // state: (5, 5) + // op2 W (4, 4) + // agg: [5, 10) 1 + // output: None + // state: [0, 5) 4, [5, 10) 1 + + // no-data batch triggered + + // op1 W (4, 5) + // join output: None + // state: None + // op2 W (4, 5) + // agg: None + // output: [0, 5) 4 + // state: [5, 10) 1 + CheckNewAnswer((0, 4)), + assertNumStateRows(Seq(1, 0)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) + } + } + + test("aggregation -> stream deduplication, append mode") { + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val inputData = MemoryStream[Int] + + val aggStream = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "0 seconds") + .groupBy(window($"eventTime", "5 seconds").as("window")) + .agg(count("*").as("count")) + .withColumn("windowEnd", expr("window.end")) + + // dropDuplicates from aggStream without event time column for dropDuplicates - the + // state does not get trimmed due to watermark advancement. 
+ val dedupNoEventTime = aggStream + .dropDuplicates("count", "windowEnd") + .select( + $"windowEnd".cast("long").as[Long], + $"count".as[Long]) + + testStream(dedupNoEventTime)( + AddData(inputData, 1, 5, 10, 15), + + // op1 W (0, 0) + // agg: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 + // output: None + // state: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 + // op2 W (0, 0) + // output: None + // state: None + + // no-data batch triggered + + // op1 W (0, 15) + // agg: None + // output: [0, 5) 1, [5, 10) 1, [10, 15) 1 + // state: [15, 20) 1 + // op2 W (0, 15) + // output: (5, 1), (10, 1), (15, 1) + // state: (5, 1), (10, 1), (15, 1) + + CheckNewAnswer((5, 1), (10, 1), (15, 1)), + assertNumStateRows(Seq(3, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) + + // Similar to the above but add event time. The dedup state will get trimmed. + val dedupWithEventTime = aggStream + .withColumn("windowTime", expr("window_time(window)")) + .withColumn("windowTimeMicros", expr("unix_micros(windowTime)")) + .dropDuplicates("count", "windowEnd", "windowTime") + .select( + $"windowEnd".cast("long").as[Long], + $"windowTimeMicros".cast("long").as[Long], + $"count".as[Long]) + + testStream(dedupWithEventTime)( + AddData(inputData, 1, 5, 10, 15), + + // op1 W (0, 0) + // agg: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 + // output: None + // state: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 + // op2 W (0, 0) + // output: None + // state: None + + // no-data batch triggered + + // op1 W (0, 15) + // agg: None + // output: [0, 5) 1, [5, 10) 1, [10, 15) 1 + // state: [15, 20) 1 + // op2 W (0, 15) + // output: (5, 4999999, 1), (10, 9999999, 1), (15, 14999999, 1) + // state: None - trimmed by watermark + + CheckNewAnswer((5, 4999999, 1), (10, 9999999, 1), (15, 14999999, 1)), + assertNumStateRows(Seq(0, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) + } + } + + private def assertNumStateRows(numTotalRows: Seq[Long]): AssertOnQuery = AssertOnQuery { q => + 
q.processAllAvailable() + val progressWithData = q.recentProgress.lastOption.get + val stateOperators = progressWithData.stateOperators + assert(stateOperators.size === numTotalRows.size) + assert(stateOperators.map(_.numRowsTotal).toSeq === numTotalRows) + true + } + + private def assertNumRowsDroppedByWatermark( + numRowsDroppedByWatermark: Seq[Long]): AssertOnQuery = AssertOnQuery { q => + q.processAllAvailable() + val progressWithData = q.recentProgress.filterNot { p => + // filter out batches which are falling into one of types: + // 1) doesn't execute the batch run + // 2) empty input batch + p.numInputRows == 0 + }.lastOption.get + val stateOperators = progressWithData.stateOperators + assert(stateOperators.size === numRowsDroppedByWatermark.size) + assert(stateOperators.map(_.numRowsDroppedByWatermark).toSeq === numRowsDroppedByWatermark) + true + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 8ef8c21e13a..40868f896f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -618,7 +618,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val numPartitions = spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS) assert(query.lastExecution.executedPlan.collect { - case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, + case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, _, ShuffleExchangeExec(opA: HashPartitioning, _, _), ShuffleExchangeExec(opB: HashPartitioning, _, _)) if partitionExpressionsColumns(opA.expressions) === Seq("a", "b") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org