This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new bd14136b397 [SPARK-45747][SS] Use prefix key information in state metadata to handle reading state for session window aggregation
bd14136b397 is described below

commit bd14136b39784038c3cef7dc3cafac2b07024a92
Author: Chaoqin Li <chaoqin...@databricks.com>
AuthorDate: Thu Nov 16 12:29:10 2023 +0900

    [SPARK-45747][SS] Use prefix key information in state metadata to handle reading state for session window aggregation

    ### What changes were proposed in this pull request?
    Currently, reading state for the session window aggregation operator is not supported because numColsPrefixKey is unknown. We can read the operator state metadata introduced in SPARK-45558 to determine the number of prefix key columns and load the state of a session window aggregation correctly.

    ### Why are the changes needed?
    To support reading state for session window aggregation.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Add integration test.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #43788 from chaoqin-li1123/session_window_agg.

    Authored-by: Chaoqin Li <chaoqin...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../v2/state/StatePartitionReader.scala            | 30 +++++++++++++++----
 .../v2/state/metadata/StateMetadataSource.scala    |  2 +-
 .../v2/state/StateDataSourceReadSuite.scala        | 16 ++++++++++
 .../v2/state/StateDataSourceTestBase.scala         | 34 ++++++++++++++++++++
 4 files changed, 75 insertions(+), 7 deletions(-)
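As an illustration of what this change enables, here is a minimal sketch (based on the integration test added below, not part of the patch itself; "/path/to/checkpoint" is a placeholder) of reading session window aggregation state back with the state data source, which now resolves numColsPrefixKey from the operator state metadata instead of assuming 0:

    // Query the state of a session window aggregation from a checkpoint.
    // "/path/to/checkpoint" is a placeholder for the streaming query's checkpoint root.
    val stateDf = spark.read
      .format("statestore")
      .load("/path/to/checkpoint")

    // Key and value columns follow the schema used by the new test below.
    stateDf.selectExpr(
      "key.sessionId",
      "CAST(value.session_window.start AS LONG) AS sessionStart",
      "CAST(value.session_window.end AS LONG) AS sessionEnd",
      "value.count AS numEvents")
      .show()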
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 824034f42ea..1e5f7216e8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.state
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
@@ -47,7 +48,7 @@ class StatePartitionReader(
     storeConf: StateStoreConf,
     hadoopConf: SerializableConfiguration,
     partition: StateStoreInputPartition,
-    schema: StructType) extends PartitionReader[InternalRow] {
+    schema: StructType) extends PartitionReader[InternalRow] with Logging {
 
   private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
   private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType]
@@ -57,12 +58,29 @@ class StatePartitionReader(
       partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
     val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
 
-    // TODO: This does not handle the case of session window aggregation; we don't have an
-    // information whether the state store uses prefix scan or not. We will have to add such
-    // information to determine the right encoder/decoder for the data.
+    val allStateStoreMetadata = new StateMetadataPartitionReader(
+      partition.sourceOptions.stateCheckpointLocation.getParent.toString, hadoopConf)
+      .stateMetadata.toArray
+
+    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
+      entry.operatorId == partition.sourceOptions.operatorId &&
+        entry.stateStoreName == partition.sourceOptions.storeName
+    }
+    val numColsPrefixKey = if (stateStoreMetadata.isEmpty) {
+      logWarning("Metadata for state store not found, possible cause is this checkpoint " +
+        "is created by older version of spark. If the query has session window aggregation, " +
+        "the state can't be read correctly and runtime exception will be thrown. " +
+        "Run the streaming query in newer spark version to generate state metadata " +
+        "can fix the issue.")
+      0
+    } else {
+      require(stateStoreMetadata.length == 1)
+      stateStoreMetadata.head.numColsPrefixKey
+    }
+
     StateStore.getReadOnly(stateStoreProviderId, keySchema, valueSchema,
-      numColsPrefixKey = 0, version = partition.sourceOptions.batchId + 1, storeConf = storeConf,
-      hadoopConf = hadoopConf.value)
+      numColsPrefixKey = numColsPrefixKey, version = partition.sourceOptions.batchId + 1,
+      storeConf = storeConf, hadoopConf = hadoopConf.value)
   }
 
   private lazy val iter: Iterator[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
index 8a74db8d196..ca123a9e501 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
@@ -195,7 +195,7 @@ class StateMetadataPartitionReader(
     }
   }
 
-  private lazy val stateMetadata: Iterator[StateMetadataTableEntry] = {
+  private[state] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = {
     allOperatorStateMetadata.flatMap { operatorStateMetadata =>
       require(operatorStateMetadata.version == 1)
       val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index 69573f46e68..7166705d03a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -494,6 +494,22 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
     }
   }
 
+  test("Session window aggregation") {
+    withTempDir { checkpointDir =>
+      runSessionWindowAggregationQuery(checkpointDir.getAbsolutePath)
+
+      val df = spark.read.format("statestore").load(checkpointDir.toString)
+      checkAnswer(df.selectExpr("key.sessionId", "CAST(key.sessionStartTime AS LONG)",
+        "CAST(value.session_window.start AS LONG)", "CAST(value.session_window.end AS LONG)",
+        "value.sessionId", "value.count"),
+        Seq(Row("hello", 40, 40, 51, "hello", 2),
+          Row("spark", 40, 40, 50, "spark", 1),
+          Row("streaming", 40, 40, 51, "streaming", 2),
+          Row("world", 40, 40, 51, "world", 2),
+          Row("structured", 41, 41, 51, "structured", 1)))
+    }
+  }
+
   test("flatMapGroupsWithState, state ver 1") {
     testFlatMapGroupsWithState(1)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
index 1fe93f891f4..890a716bbef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
@@ -392,6 +392,40 @@ trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest {
       .select(col("leftId"), col("leftTime").cast("int"),
         col("rightId"), col("rightTime").cast("int"))
   }
+
+  protected def runSessionWindowAggregationQuery(checkpointRoot: String): Unit = {
+    val input = MemoryStream[(String, Long)]
+    val sessionWindow = session_window($"eventTime", "10 seconds")
+
+    val events = input.toDF()
+      .select($"_1".as("value"), $"_2".as("timestamp"))
+      .withColumn("eventTime", $"timestamp".cast("timestamp"))
+      .withWatermark("eventTime", "30 seconds")
+      .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
+
+    val streamingDf = events
+      .groupBy(sessionWindow as Symbol("session"), $"sessionId")
+      .agg(count("*").as("numEvents"))
+      .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
+        "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
+        "numEvents")
+
+    testStream(streamingDf, OutputMode.Complete())(
+      StartStream(checkpointLocation = checkpointRoot),
+      AddData(input,
+        ("hello world spark streaming", 40L),
+        ("world hello structured streaming", 41L)
+      ),
+      CheckNewAnswer(
+        ("hello", 40, 51, 11, 2),
+        ("world", 40, 51, 11, 2),
+        ("streaming", 40, 51, 11, 2),
+        ("spark", 40, 50, 10, 1),
+        ("structured", 41, 51, 10, 1)
+      ),
+      StopStream
+    )
+  }
 }
 
 case class Event(sessionId: String, timestamp: Timestamp)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org