This is an automated email from the ASF dual-hosted git repository. tdas pushed a commit to branch branch-3.2 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 0d60cb5  [SPARK-36132][SS][SQL] Support initial state for batch mode of flatMapGroupsWithState
0d60cb5 is described below

commit 0d60cb51c01c13b0febe2ff7601db7303bfff56d
Author: Rahul Mahadev <rahul.maha...@databricks.com>
AuthorDate: Wed Jul 21 01:48:58 2021 -0400

    [SPARK-36132][SS][SQL] Support initial state for batch mode of flatMapGroupsWithState

    ### What changes were proposed in this pull request?
    Adding support for accepting an initial state with flatMapGroupsWithState in batch mode.

    ### Why are the changes needed?
    SPARK-35897 added support for accepting an initial state for streaming queries using flatMapGroupsWithState. The code flow is separate for batch and streaming, so batch mode required a different PR.

    ### Does this PR introduce _any_ user-facing change?
    Yes. As discussed above, flatMapGroupsWithState in batch mode can now accept an initialState; previously this would throw an UnsupportedOperationException.

    ### How was this patch tested?
    Added relevant unit tests in FlatMapGroupsWithStateSuite and modified the tests in `JavaDatasetSuite`.

    Closes #33336 from rahulsmahadev/flatMapGroupsWithStateBatch.
Authored-by: Rahul Mahadev <rahul.maha...@databricks.com> Signed-off-by: Tathagata Das <tathagata.das1...@gmail.com> (cherry picked from commit efcce23b913ce0de961ac261050e3d6dbf261f6e) Signed-off-by: Tathagata Das <tathagata.das1...@gmail.com> --- .../analysis/UnsupportedOperationChecker.scala | 6 -- .../spark/sql/execution/SparkStrategies.scala | 11 +++- .../streaming/FlatMapGroupsWithStateExec.scala | 71 +++++++++++++++++++++- .../org/apache/spark/sql/JavaDatasetSuite.java | 18 +----- .../streaming/FlatMapGroupsWithStateSuite.scala | 52 ++++++++++++++++ 5 files changed, 130 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 13c7f75..321725d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -37,12 +37,6 @@ object UnsupportedOperationChecker extends Logging { case p if p.isStreaming => throwError("Queries with streaming sources must be executed with writeStream.start()")(p) - case f: FlatMapGroupsWithState => - if (f.hasInitialState) { - throwError("Initial state is not supported in [flatMap|map]GroupsWithState" + - " operation on a batch DataFrame/Dataset")(f) - } - case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6d10fa8..7624b15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -690,9 +690,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => 
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil case logical.FlatMapGroupsWithState( - f, key, value, grouping, data, output, _, _, _, timeout, _, _, _, _, _, child) => - execution.MapGroupsExec( - f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil + f, keyDeserializer, valueDeserializer, grouping, data, output, stateEncoder, outputMode, + isFlatMapGroupsWithState, timeout, hasInitialState, initialStateGroupAttrs, + initialStateDataAttrs, initialStateDeserializer, initialState, child) => + FlatMapGroupsWithStateExec.generateSparkPlanForBatchQueries( + f, keyDeserializer, valueDeserializer, initialStateDeserializer, grouping, + initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout, + hasInitialState, planLater(initialState), planLater(child) + ) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 03694d4..a00a622 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -309,9 +309,7 @@ case class FlatMapGroupsWithStateExec( var foundInitialStateForKey = false initialStateRowIter.foreach { initialStateRow => if (foundInitialStateForKey) { - throw new IllegalArgumentException("The initial state provided contained " + - "multiple rows(state) with the same key. 
Make sure to de-duplicate the " + - "initial state before passing it.") + FlatMapGroupsWithStateExec.foundDuplicateInitialKeyException() } foundInitialStateForKey = true val initStateObj = getStateObj.get(initialStateRow) @@ -403,3 +401,70 @@ case class FlatMapGroupsWithStateExec( copy(child = newLeft, initialState = newRight) } +object FlatMapGroupsWithStateExec { + + def foundDuplicateInitialKeyException(): Exception = { + throw new IllegalArgumentException("The initial state provided contained " + + "multiple rows(state) with the same key. Make sure to de-duplicate the " + + "initial state before passing it.") + } + + /** + * Plan logical flatmapGroupsWIthState for batch queries + * If the initial state is provided, we create an instance of the CoGroupExec, if the initial + * state is not provided we create an instance of the MapGroupsExec + */ + // scalastyle:off argcount + def generateSparkPlanForBatchQueries( + userFunc: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + initialStateDeserializer: Expression, + groupingAttributes: Seq[Attribute], + initialStateGroupAttrs: Seq[Attribute], + dataAttributes: Seq[Attribute], + initialStateDataAttrs: Seq[Attribute], + outputObjAttr: Attribute, + timeoutConf: GroupStateTimeout, + hasInitialState: Boolean, + initialState: SparkPlan, + child: SparkPlan): SparkPlan = { + if (hasInitialState) { + val watermarkPresent = child.output.exists { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true + case _ => false + } + val func = (keyRow: Any, values: Iterator[Any], states: Iterator[Any]) => { + // Check if there is only one state for every key. 
+ var foundInitialStateForKey = false + val optionalStates = states.map { stateValue => + if (foundInitialStateForKey) { + foundDuplicateInitialKeyException() + } + foundInitialStateForKey = true + stateValue + }.toArray + + // Create group state object + val groupState = GroupStateImpl.createForStreaming( + optionalStates.headOption, + System.currentTimeMillis, + GroupStateImpl.NO_TIMESTAMP, + timeoutConf, + hasTimedOut = false, + watermarkPresent) + + // Call user function with the state and values for this key + userFunc(keyRow, values, groupState) + } + CoGroupExec( + func, keyDeserializer, valueDeserializer, initialStateDeserializer, groupingAttributes, + initialStateGroupAttrs, dataAttributes, initialStateDataAttrs, outputObjAttr, + child, initialState) + } else { + MapGroupsExec( + userFunc, keyDeserializer, valueDeserializer, groupingAttributes, + dataAttributes, outputObjAttr, timeoutConf, child) + } + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 0500c52..28439f2 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -196,14 +196,7 @@ public class JavaDatasetSuite implements Serializable { GroupStateTimeout.NoTimeout(), kvInitStateMappedDS); - Assert.assertThrows( - "Initial state is not supported in [flatMap|map]GroupsWithState " + - "operation on a batch DataFrame/Dataset", - AnalysisException.class, - () -> { - flatMapped2.collectAsList(); - } - ); + Assert.assertEquals(asSet("1a", "2", "3foobar"), toSet(flatMapped2.collectAsList())); Dataset<String> mapped2 = grouped.mapGroupsWithState( (MapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> { StringBuilder sb = new StringBuilder(key.toString()); @@ -216,14 +209,7 @@ public class JavaDatasetSuite implements Serializable { Encoders.STRING(), 
GroupStateTimeout.NoTimeout(), kvInitStateMappedDS); - Assert.assertThrows( - "Initial state is not supported in [flatMap|map]GroupsWithState " + - "operation on a batch DataFrame/Dataset", - AnalysisException.class, - () -> { - mapped2.collectAsList(); - } - ); + Assert.assertEquals(asSet("1a", "2", "3foobar"), toSet(mapped2.collectAsList())); } @Test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 152dd16..d34b2b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -1284,6 +1284,12 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } assertCannotGetWatermark { state.getCurrentWatermarkMs() } assert(!state.hasTimedOut) + if (key.contains("EventTime")) { + state.setTimeoutTimestamp(0, "1 hour") + } + if (key.contains("ProcessingTime")) { + state.setTimeoutDuration("1 hour") + } val count = state.getOption.map(_.count).getOrElse(0L) + valList.size // We need to check if not explicitly calling update will still save the init state or not if (!key.contains("NoUpdate")) { @@ -1413,6 +1419,52 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { ) } + Seq(NoTimeout(), EventTimeTimeout(), ProcessingTimeTimeout()).foreach { timeout => + test(s"flatMapGroupsWithState - initial state - batch mode - timeout ${timeout}") { + // We will test them on different shuffle partition configuration to make sure the + // grouping by key will still work. On higher number of shuffle partitions its possible + // that all keys end up on different partitions. 
+ val initialState = Seq( + (s"keyInStateAndData-1-$timeout", new RunningCount(1)), + ("keyInStateAndData-2", new RunningCount(2)), + ("keyNoUpdate", new RunningCount(2)), // state.update will not be called + ("keyOnlyInState-1", new RunningCount(1)) + ).toDS().groupByKey(x => x._1).mapValues(_._2) + + val inputData = Seq( + ("keyOnlyInData"), ("keyInStateAndData-2") + ) + val result = inputData.toDS().groupByKey(x => x) + .flatMapGroupsWithState( + Update, timeout, initialState)(flatMapGroupsWithStateFunc) + + val expected = Seq( + ("keyOnlyInState-1", Seq[String](), "1"), + ("keyNoUpdate", Seq[String](), "2"), // update will not be called + ("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), // inc by 1 + (s"keyInStateAndData-1-$timeout", Seq[String](), "1"), + ("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1 + ).toDF() + checkAnswer(result.toDF(), expected) + } + } + + testQuietly("flatMapGroupsWithState - initial state - batch mode - duplicate state") { + val initialState = Seq( + ("a", new RunningCount(1)), + ("a", new RunningCount(2)) + ).toDS().groupByKey(x => x._1).mapValues(_._2) + + val e = intercept[SparkException] { + Seq("a", "b").toDS().groupByKey(x => x) + .flatMapGroupsWithState(Update, NoTimeout(), initialState)(flatMapGroupsWithStateFunc) + .show() + } + assert(e.getMessage.contains( + "The initial state provided contained multiple rows(state) with the same key." + + " Make sure to de-duplicate the initial state before passing it.")) + } + testQuietly("flatMapGroupsWithState - initial state - streaming initial state") { val initialStateData = MemoryStream[(String, RunningCount)] initialStateData.addData(("a", new RunningCount(1))) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org