This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bd14136b397 [SPARK-45747][SS] Use prefix key information in state metadata to handle reading state for session window aggregation
bd14136b397 is described below

commit bd14136b39784038c3cef7dc3cafac2b07024a92
Author: Chaoqin Li <chaoqin...@databricks.com>
AuthorDate: Thu Nov 16 12:29:10 2023 +0900

    [SPARK-45747][SS] Use prefix key information in state metadata to handle reading state for session window aggregation
    
    ### What changes were proposed in this pull request?
    Currently, reading state for the session window aggregation operator is not supported because numColsPrefixKey is unknown. We can read the operator state metadata introduced in SPARK-45558 to determine the number of prefix key columns and load the session window state correctly.
    
    ### Why are the changes needed?
    To support reading state for session window aggregation.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Add integration test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #43788 from chaoqin-li1123/session_window_agg.
    
    Authored-by: Chaoqin Li <chaoqin...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../v2/state/StatePartitionReader.scala            | 30 +++++++++++++++----
 .../v2/state/metadata/StateMetadataSource.scala    |  2 +-
 .../v2/state/StateDataSourceReadSuite.scala        | 16 ++++++++++
 .../v2/state/StateDataSourceTestBase.scala         | 34 ++++++++++++++++++++++
 4 files changed, 75 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 824034f42ea..1e5f7216e8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.state
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
@@ -47,7 +48,7 @@ class StatePartitionReader(
     storeConf: StateStoreConf,
     hadoopConf: SerializableConfiguration,
     partition: StateStoreInputPartition,
-    schema: StructType) extends PartitionReader[InternalRow] {
+    schema: StructType) extends PartitionReader[InternalRow] with Logging {
 
   private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
   private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType]
@@ -57,12 +58,29 @@ class StatePartitionReader(
       partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
     val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
 
-    // TODO: This does not handle the case of session window aggregation; we don't have an
-    //  information whether the state store uses prefix scan or not. We will have to add such
-    //  information to determine the right encoder/decoder for the data.
+    val allStateStoreMetadata = new StateMetadataPartitionReader(
+      partition.sourceOptions.stateCheckpointLocation.getParent.toString, hadoopConf)
+      .stateMetadata.toArray
+
+    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
+      entry.operatorId == partition.sourceOptions.operatorId &&
+        entry.stateStoreName == partition.sourceOptions.storeName
+    }
+    val numColsPrefixKey = if (stateStoreMetadata.isEmpty) {
+      logWarning("Metadata for state store not found; a possible cause is that this checkpoint " +
+        "was created by an older version of Spark. If the query has session window aggregation, " +
+        "the state can't be read correctly and a runtime exception will be thrown. " +
+        "Re-running the streaming query on a newer Spark version to regenerate the state " +
+        "metadata can fix the issue.")
+      0
+    } else {
+      require(stateStoreMetadata.length == 1)
+      stateStoreMetadata.head.numColsPrefixKey
+    }
+
     StateStore.getReadOnly(stateStoreProviderId, keySchema, valueSchema,
-      numColsPrefixKey = 0, version = partition.sourceOptions.batchId + 1, storeConf = storeConf,
-      hadoopConf = hadoopConf.value)
+      numColsPrefixKey = numColsPrefixKey, version = partition.sourceOptions.batchId + 1,
+      storeConf = storeConf, hadoopConf = hadoopConf.value)
   }
 
   private lazy val iter: Iterator[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
index 8a74db8d196..ca123a9e501 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
@@ -195,7 +195,7 @@ class StateMetadataPartitionReader(
     }
   }
 
-  private lazy val stateMetadata: Iterator[StateMetadataTableEntry] = {
+  private[state] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = {
     allOperatorStateMetadata.flatMap { operatorStateMetadata =>
       require(operatorStateMetadata.version == 1)
       val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index 69573f46e68..7166705d03a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -494,6 +494,22 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
     }
   }
 
+  test("Session window aggregation") {
+    withTempDir { checkpointDir =>
+      runSessionWindowAggregationQuery(checkpointDir.getAbsolutePath)
+
+      val df = spark.read.format("statestore").load(checkpointDir.toString)
+      checkAnswer(df.selectExpr("key.sessionId", "CAST(key.sessionStartTime AS LONG)",
+        "CAST(value.session_window.start AS LONG)", "CAST(value.session_window.end AS LONG)",
+        "value.sessionId", "value.count"),
+        Seq(Row("hello", 40, 40, 51, "hello", 2),
+          Row("spark", 40, 40, 50, "spark", 1),
+          Row("streaming", 40, 40, 51, "streaming", 2),
+          Row("world", 40, 40, 51, "world", 2),
+          Row("structured", 41, 41, 51, "structured", 1)))
+    }
+  }
+
   test("flatMapGroupsWithState, state ver 1") {
     testFlatMapGroupsWithState(1)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
index 1fe93f891f4..890a716bbef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTestBase.scala
@@ -392,6 +392,40 @@ trait StateDataSourceTestBase extends StreamTest with StateStoreMetricsTest {
       .select(col("leftId"), col("leftTime").cast("int"),
         col("rightId"), col("rightTime").cast("int"))
   }
+
+  protected def runSessionWindowAggregationQuery(checkpointRoot: String): Unit = {
+    val input = MemoryStream[(String, Long)]
+    val sessionWindow = session_window($"eventTime", "10 seconds")
+
+    val events = input.toDF()
+      .select($"_1".as("value"), $"_2".as("timestamp"))
+      .withColumn("eventTime", $"timestamp".cast("timestamp"))
+      .withWatermark("eventTime", "30 seconds")
+      .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
+
+    val streamingDf = events
+      .groupBy(sessionWindow as Symbol("session"), $"sessionId")
+      .agg(count("*").as("numEvents"))
+      .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
+        "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
+        "numEvents")
+
+    testStream(streamingDf, OutputMode.Complete())(
+      StartStream(checkpointLocation = checkpointRoot),
+      AddData(input,
+        ("hello world spark streaming", 40L),
+        ("world hello structured streaming", 41L)
+      ),
+      CheckNewAnswer(
+        ("hello", 40, 51, 11, 2),
+        ("world", 40, 51, 11, 2),
+        ("streaming", 40, 51, 11, 2),
+        ("spark", 40, 50, 10, 1),
+        ("structured", 41, 51, 10, 1)
+      ),
+      StopStream
+    )
+  }
 }
 
 case class Event(sessionId: String, timestamp: Timestamp)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
