This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 6c67c61bfd21 [SPARK-47840][SS] Disable foldable propagation across Streaming Aggregate/Join nodes
6c67c61bfd21 is described below

commit 6c67c61bfd21ebe68837f889502368ab9d99ebc5
Author: Bhuwan Sahni <bhuwan.sa...@databricks.com>
AuthorDate: Tue Apr 16 12:36:08 2024 +0900

    [SPARK-47840][SS] Disable foldable propagation across Streaming Aggregate/Join nodes
    
    ### What changes were proposed in this pull request?
    
    Streaming queries that union 2 data streams and then apply an Aggregate (groupBy) can produce incorrect results if the grouping key is a literal that stays constant for the duration of a micro-batch.
    
    The query produces incorrect results because the query optimizer recognizes the literal value in the grouping key as foldable and replaces the grouping key expression with the actual literal value. This optimization is correct for batch queries. However, streaming queries also read information from the StateStore, and the output contains both the results from the StateStore (computed in previous micro-batches) and data from input sources (computed in this micro-batch). The HashAggregate node aft [...]
    
    See the example logical and physical plans below for a query performing a union on 2 data streams, followed by a groupBy. Note that the name#4 expression has been optimized to ds1. The streaming Aggregate adds a StateStoreSave node as a child of HashAggregate; however, any grouping key read from the StateStore will still be read as ds1 due to the optimization.
    
    ### Optimized Logical Plan
    
    ```
    === Applying Rule org.apache.spark.sql.catalyst.optimizer.FoldablePropagation ===
    
    === Old Plan ===
    
    WriteToMicroBatchDataSource MemorySink, eb67645e-30fc-41a8-8006-35bb7649c202, Complete, 0
    +- Aggregate [name#4], [name#4, count(1) AS count#31L]
       +- Project [ds1 AS name#4]
            +- StreamingDataSourceV2ScanRelation[value#1] MemoryStreamDataSource
    
    === New Plan ===
    
    WriteToMicroBatchDataSource MemorySink, eb67645e-30fc-41a8-8006-35bb7649c202, Complete, 0
    +- Aggregate [ds1], [ds1 AS name#4, count(1) AS count#31L]
       +- Project [ds1 AS name#4]
            +- StreamingDataSourceV2ScanRelation[value#1] MemoryStreamDataSource
    
    ====
    ```
    
    ### Corresponding Physical Plan
    
    ```
    WriteToDataSourceV2 MicroBatchWrite[epoch: 0, writer: org.apache.spark.sql.execution.streaming.sources.MemoryStreamingWrite2b4c6242], org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy$$Lambda$3143/185907563435709d26
    +- HashAggregate(keys=[ds1#39], functions=[finalmerge_count(merge count#38L) AS count(1)#30L], output=[name#4, count#31L])
       +- StateStoreSave [ds1#39], state info [ checkpoint = file:/tmp/streaming.metadata-e470782a-18a3-463c-9e61-3a10d0bdf180/state, runId = 4dedecca-910c-4518-855e-456702617414, opId = 0, ver = 0, numPartitions = 5], Complete, 0, 0, 2
            +- HashAggregate(keys=[ds1#39], functions=[merge_count(merge count#38L) AS count#38L], output=[ds1#39, count#38L])
            +- StateStoreRestore [ds1#39], state info [ checkpoint = file:/tmp/streaming.metadata-e470782a-18a3-463c-9e61-3a10d0bdf180/state, runId = 4dedecca-910c-4518-855e-456702617414, opId = 0, ver = 0, numPartitions = 5], 2
                    +- HashAggregate(keys=[ds1#39], functions=[merge_count(merge count#38L) AS count#38L], output=[ds1#39, count#38L])
                    +- HashAggregate(keys=[ds1 AS ds1#39], functions=[partial_count(1) AS count#38L], output=[ds1#39, count#38L])
                    +- Project
                            +- MicroBatchScan[value#1] MemoryStreamDataSource
    
    ```
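
    The effect on keys restored from state can be pictured with a toy model. This is a sketch under stated assumptions, not Spark code: plain Scala maps stand in for the StateStore, and `FoldedKeyToyModel`, `runBatch` and `rewriteKey` are hypothetical names invented for illustration. It assumes a micro-batch in which only one literal reaches the aggregate, so the folded key is that literal.

    ```scala
    // Toy model only: Scala collections stand in for the state store and the
    // merge step of a stateful aggregate. None of these names are Spark APIs.
    object FoldedKeyToyModel {
      type Counts = Map[String, Long]

      // Merge restored state with this batch's partial counts. `rewriteKey`
      // models what happens to every grouping key that flows through the merge.
      def runBatch(state: Counts, batchKeys: Seq[String], rewriteKey: String => String): Counts = {
        val partial = batchKeys.map(k => (k, 1L))
        val merged = (state.toSeq ++ partial).map { case (k, c) => (rewriteKey(k), c) }
        merged.groupBy(_._1).map { case (k, rows) => k -> rows.map(_._2).sum }
      }

      def main(args: Array[String]): Unit = {
        // An earlier batch only saw rows tagged "ds1", so the state holds "ds1".
        val restored: Counts = Map("ds1" -> 1L)

        // The key is folded to the literal "ds2": the restored row is relabelled too.
        val folded = runBatch(restored, Seq("ds2"), _ => "ds2")
        // No folding: keys restored from state keep their own values.
        val unfolded = runBatch(restored, Seq("ds2"), identity)

        println(s"folded:   $folded")
        println(s"unfolded: $unfolded")
      }
    }
    ```

    In this model the folded run merges both rows under a single key, while the unfolded run keeps one row per key, which matches what the new tests expect for the union query.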
    
    This PR disables foldable propagation across Streaming Aggregate/Join nodes in the logical plan.
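
    The two guards added by this change can be sketched in isolation. The node classes below are hypothetical stand-ins for Catalyst's logical plan nodes: only `isStreaming` is modelled, assuming a node is streaming when any of its children is, and `mayPropagate` is an illustrative name. The conditions themselves are copied from the patch.

    ```scala
    // Toy plan nodes; hypothetical stand-ins, not Spark's LogicalPlan classes.
    object StreamingGuardSketch {
      sealed trait Plan { def isStreaming: Boolean }
      final case class Source(isStreaming: Boolean) extends Plan
      final case class Aggregate(child: Plan) extends Plan {
        def isStreaming: Boolean = child.isStreaming
      }
      final case class Join(left: Plan, right: Plan) extends Plan {
        def isStreaming: Boolean = left.isStreaming || right.isStreaming
      }

      // True when foldables may still be propagated through the node.
      def mayPropagate(plan: Plan): Boolean = plan match {
        case a: Aggregate => !a.isStreaming
        case j: Join => !j.left.isStreaming || !j.right.isStreaming
        case _ => true
      }

      def main(args: Array[String]): Unit = {
        val stream = Source(isStreaming = true)
        val batch = Source(isStreaming = false)
        println(mayPropagate(Aggregate(stream)))    // streaming aggregate
        println(mayPropagate(Join(stream, stream))) // stream-stream join
        println(mayPropagate(Join(stream, batch)))  // stream-static join
      }
    }
    ```

    As written, the Join guard skips propagation only when both sides are streaming; a join with at least one batch side is still optimized.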
    
    ### Why are the changes needed?
    
    These changes ensure that streaming queries with a literal value as the grouping key or join key produce correct results.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added the `sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala` test suite.
    
    ```
    
    [info] Run completed in 54 seconds, 150 milliseconds.
    [info] Total number of tests run: 9
    [info] Suites: completed 1, aborted 0
    [info] Tests: succeeded 9, failed 0, canceled 0, ignored 0, pending 0
    [info] All tests passed.
    
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #46035 from sahnib/SPARK-47840.
    
    Authored-by: Bhuwan Sahni <bhuwan.sa...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    (cherry picked from commit f21719346eb0492cf9de47495853a4efad37dbab)
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../spark/sql/catalyst/optimizer/expressions.scala |  18 +-
 ...treamingQueryOptimizationCorrectnessSuite.scala | 419 +++++++++++++++++++++
 2 files changed, 435 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index d4f0f72c9352..456d0da54de5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -934,7 +934,14 @@ object FoldablePropagation extends Rule[LogicalPlan] {
         val newFoldableMap = collectFoldables(newProject.projectList)
         (newProject, newFoldableMap)
 
-      case a: Aggregate =>
+      // FoldablePropagation rule can produce incorrect optimized plan for streaming queries.
+      // This is because the optimizer can replace the grouping expressions, or join column
+      // with a literal value if the grouping key is constant for the micro-batch. However,
+      // as Streaming queries also read from the StateStore, this optimization also
+      // overwrites any keys read from State Store. We need to disable this optimization
+      // until we can make optimizer aware of Streaming state store. The State Store nodes
+      // are currently added in the Physical plan.
+      case a: Aggregate if !a.isStreaming =>
         val (newChild, foldableMap) = propagateFoldables(a.child)
         val newAggregate =
           replaceFoldable(a.withNewChildren(Seq(newChild)).asInstanceOf[Aggregate], foldableMap)
@@ -971,7 +978,14 @@ object FoldablePropagation extends Rule[LogicalPlan] {
       // propagating the foldable expressions.
       // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
       // of outer join.
-      case j: Join =>
+      // FoldablePropagation rule can produce incorrect optimized plan for streaming queries.
+      // This is because the optimizer can replace the grouping expressions, or join column
+      // with a literal value if the grouping key is constant for the micro-batch. However,
+      // as Streaming queries also read from the StateStore, this optimization also
+      // overwrites any keys read from State Store. We need to disable this optimization
+      // until we can make optimizer aware of Streaming state store. The State Store nodes
+      // are currently added in the Physical plan.
+      case j: Join if !j.left.isStreaming || !j.right.isStreaming =>
         val (newChildren, foldableMaps) = j.children.map(propagateFoldables).unzip
         val foldableMap = AttributeMap(
           foldableMaps.foldLeft(Iterable.empty[(Attribute, Alias)])(_ ++ _.baseMap.values))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala
new file mode 100644
index 000000000000..efc84c8e4c7c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala
@@ -0,0 +1,419 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.sql.Timestamp
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.functions.{lit, window}
+
+/**
+ * This test ensures that any optimizations done by Spark SQL optimizer are
+ * correct for Streaming queries.
+ */
+class StreamingQueryOptimizationCorrectnessSuite extends StreamTest {
+  import testImplicits._
+
+  test("streaming Union with literal produces correct results") {
+    val inputStream1 = MemoryStream[Int]
+    val ds1 = inputStream1
+      .toDS()
+      .withColumn("name", lit("ds1"))
+      .withColumn("count", $"value")
+      .select("name", "count")
+
+    val inputStream2 = MemoryStream[Int]
+    val ds2 = inputStream2
+      .toDS()
+      .withColumn("name", lit("ds2"))
+      .withColumn("count", $"value")
+      .select("name", "count")
+
+    val result =
+      ds1.union(ds2)
+        .groupBy("name")
+        .count()
+
+    testStream(result, OutputMode.Complete())(
+      AddData(inputStream1, 1),
+      ProcessAllAvailable(),
+      AddData(inputStream2, 1),
+      ProcessAllAvailable(),
+      CheckNewAnswer(Row("ds1", 1), Row("ds2", 1))
+    )
+  }
+
+  test("streaming aggregate with literal and watermark after literal column" +
+    " produces correct results on query change") {
+    withTempDir { dir =>
+      val inputStream1 = MemoryStream[Timestamp]
+      val ds1 = inputStream1
+        .toDS()
+        .withColumn("name", lit("ds1"))
+        .withColumn("ts", $"value")
+        .withWatermark("ts", "1 minutes")
+        .select("name", "ts")
+
+      val result =
+        ds1.groupBy("name").count()
+
+      testStream(result, OutputMode.Complete())(
+        StartStream(checkpointLocation = dir.getAbsolutePath),
+        AddData(inputStream1, Timestamp.valueOf("2023-01-02 00:00:00")),
+        ProcessAllAvailable()
+      )
+
+      val ds2 = inputStream1
+        .toDS()
+        .withColumn("name", lit("ds2"))
+        .withColumn("ts", $"value")
+        .withWatermark("ts", "1 minutes")
+        .select("name", "ts")
+
+      val result2 =
+        ds2.groupBy("name").count()
+
+      testStream(result2, OutputMode.Complete())(
+        StartStream(checkpointLocation = dir.getAbsolutePath),
+        AddData(inputStream1, Timestamp.valueOf("2023-01-03 00:00:00")),
+        ProcessAllAvailable(),
+        CheckNewAnswer(Row("ds1", 1), Row("ds2", 1)),
+        AddData(inputStream1, Timestamp.valueOf("2023-01-04 00:00:00")),
+        ProcessAllAvailable(),
+        CheckNewAnswer(Row("ds1", 1), Row("ds2", 2))
+      )
+    }
+  }
+
+  test("streaming aggregate with literal and watermark before literal column" +
+    " produces correct results on query change") {
+    withTempDir { dir =>
+      val inputStream1 = MemoryStream[Timestamp]
+      val ds1 = inputStream1
+        .toDS()
+        .withColumn("ts", $"value")
+        .withWatermark("ts", "1 minutes")
+        .withColumn("name", lit("ds1"))
+        .select("name", "ts")
+
+      val result =
+        ds1.groupBy("name").count()
+
+      testStream(result, OutputMode.Complete())(
+        StartStream(checkpointLocation = dir.getAbsolutePath),
+        AddData(inputStream1, Timestamp.valueOf("2023-01-02 00:00:00")),
+        ProcessAllAvailable()
+      )
+
+      val ds2 = inputStream1
+        .toDS()
+        .withColumn("ts", $"value")
+        .withWatermark("ts", "1 minutes")
+        .withColumn("name", lit("ds2"))
+        .select("name", "ts")
+
+      val result2 =
+        ds2.groupBy("name").count()
+
+      testStream(result2, OutputMode.Complete())(
+        StartStream(checkpointLocation = dir.getAbsolutePath),
+        AddData(inputStream1, Timestamp.valueOf("2023-01-03 00:00:00")),
+        ProcessAllAvailable(),
+        CheckNewAnswer(Row("ds1", 1), Row("ds2", 1)),
+        AddData(inputStream1, Timestamp.valueOf("2023-01-04 00:00:00")),
+        ProcessAllAvailable(),
+        CheckNewAnswer(Row("ds1", 1), Row("ds2", 2))
+      )
+    }
+  }
+
+  test("streaming aggregate with literal" +
+    " produces correct results on query change") {
+    withTempDir { dir =>
+      val inputStream1 = MemoryStream[Int]
+      val ds1 = inputStream1
+        .toDS()
+        .withColumn("name", lit("ds1"))
+        .withColumn("count", $"value")
+        .select("name", "count")
+
+      val result =
+        ds1.groupBy("name").count()
+
+      testStream(result, OutputMode.Complete())(
+        StartStream(checkpointLocation = dir.getAbsolutePath),
+        AddData(inputStream1, 1),
+        ProcessAllAvailable()
+      )
+
+      val ds2 = inputStream1
+        .toDS()
+        .withColumn("name", lit("ds2"))
+        .withColumn("count", $"value")
+        .select("name", "count")
+
+      val result2 =
+        ds2.groupBy("name").count()
+
+      testStream(result2, OutputMode.Complete())(
+        StartStream(checkpointLocation = dir.getAbsolutePath),
+        AddData(inputStream1, 1),
+        ProcessAllAvailable(),
+        CheckNewAnswer(Row("ds1", 1), Row("ds2", 1))
+      )
+    }
+  }
+
+  test("stream stream join with literal" +
+    " produces correct results") {
+    withTempDir { dir =>
+      import java.sql.Timestamp
+      val inputStream1 = MemoryStream[Int]
+      val inputStream2 = MemoryStream[Int]
+
+      val ds1 = inputStream1
+        .toDS()
+        .withColumn("name", lit(Timestamp.valueOf("2023-01-01 00:00:00")))
+        .withWatermark("name", "1 minutes")
+        .withColumn("count1", lit(1))
+
+      val ds2 = inputStream2
+        .toDS()
+        .withColumn("name", lit(Timestamp.valueOf("2023-01-02 00:00:00")))
+        .withWatermark("name", "1 minutes")
+        .withColumn("count2", lit(2))
+
+
+      val result =
+        ds1.join(ds2, "name", "full")
+          .select("name", "count1", "count2")
+
+      testStream(result, OutputMode.Append())(
+        StartStream(checkpointLocation = dir.getAbsolutePath),
+        AddData(inputStream1, 1),
+        ProcessAllAvailable(),
+        AddData(inputStream2, 1),
+        ProcessAllAvailable(),
+        AddData(inputStream1, 2),
+        ProcessAllAvailable(),
+        AddData(inputStream2, 2),
+        ProcessAllAvailable(),
+        CheckNewAnswer()
+      )
+
+      // modify the query and update literal values for name
+      val ds3 = inputStream1
+        .toDS()
+        .withColumn("name", lit(Timestamp.valueOf("2023-02-01 00:00:00")))
+        .withWatermark("name", "1 minutes")
+        .withColumn("count1", lit(3))
+
+      val ds4 = inputStream2
+        .toDS()
+        .withColumn("name", lit(Timestamp.valueOf("2023-02-02 00:00:00")))
+        .withWatermark("name", "1 minutes")
+        .withColumn("count2", lit(4))
+
+      val result2 =
+        ds3.join(ds4, "name", "full")
+          .select("name", "count1", "count2")
+
+      testStream(result2, OutputMode.Append())(
+        StartStream(checkpointLocation = dir.getAbsolutePath),
+        AddData(inputStream1, 1),
+        ProcessAllAvailable(),
+        AddData(inputStream2, 1),
+        ProcessAllAvailable(),
+        AddData(inputStream1, 2),
+        ProcessAllAvailable(),
+        AddData(inputStream2, 2),
+        ProcessAllAvailable(),
+        CheckNewAnswer(
+          Row(Timestamp.valueOf("2023-01-01 00:00:00"),
+            1, null.asInstanceOf[java.lang.Integer]),
+          Row(Timestamp.valueOf("2023-01-01 00:00:00"),
+            1, null.asInstanceOf[java.lang.Integer]),
+          Row(Timestamp.valueOf("2023-01-02 00:00:00"),
+            null.asInstanceOf[java.lang.Integer], 2),
+          Row(Timestamp.valueOf("2023-01-02 00:00:00"),
+            null.asInstanceOf[java.lang.Integer], 2)
+        )
+      )
+    }
+  }
+
+  test("streaming SQL distinct usage with literal grouping" +
+    " key produces correct results") {
+    val inputStream1 = MemoryStream[Int]
+    val ds1 = inputStream1
+      .toDS()
+      .withColumn("name", lit("ds1"))
+      .withColumn("count", $"value")
+      .select("name", "count")
+
+    val inputStream2 = MemoryStream[Int]
+    val ds2 = inputStream2
+      .toDS()
+      .withColumn("name", lit("ds2"))
+      .withColumn("count", $"value")
+      .select("name", "count")
+
+    val result =
+      ds1.union(ds2)
+        .groupBy("name")
+        .as[String, (String, Int, Int)]
+        .keys
+
+    testStream(result, OutputMode.Complete())(
+      AddData(inputStream1, 1),
+      ProcessAllAvailable(),
+      AddData(inputStream2, 1),
+      ProcessAllAvailable(),
+      CheckNewAnswer(Row("ds1"), Row("ds2"))
+    )
+  }
+
+  test("streaming window aggregation with literal time column" +
+    " key produces correct results") {
+    val inputStream1 = MemoryStream[Int]
+    val ds1 = inputStream1
+      .toDS()
+      .withColumn("name", lit(Timestamp.valueOf("2023-01-01 00:00:00")))
+      .withColumn("count", $"value")
+      .select("name", "count")
+
+    val inputStream2 = MemoryStream[Int]
+    val ds2 = inputStream2
+      .toDS()
+      .withColumn("name", lit(Timestamp.valueOf("2023-01-02 00:00:00")))
+      .withColumn("count", $"value")
+      .select("name", "count")
+
+    val result =
+      ds1.union(ds2)
+        .groupBy(
+          window($"name", "1 second", "1 second")
+        )
+        .count()
+
+    testStream(result, OutputMode.Complete())(
+      AddData(inputStream1, 1),
+      ProcessAllAvailable(),
+      AddData(inputStream2, 1),
+      ProcessAllAvailable(),
+      CheckNewAnswer(
+        Row(
+          Row(Timestamp.valueOf("2023-01-01 00:00:00"), Timestamp.valueOf("2023-01-01 00:00:01")),
+          1),
+        Row(
+          Row(Timestamp.valueOf("2023-01-02 00:00:00"), Timestamp.valueOf("2023-01-02 00:00:01")),
+          1))
+    )
+  }
+
+  test("stream stream join with literals produces correct value") {
+    withTempDir { dir =>
+      val input1 = MemoryStream[Int]
+      val input2 = MemoryStream[Int]
+
+      val df1 = input1
+        .toDF()
+        .withColumn("key", $"value")
+        .withColumn("leftValue", lit(1))
+        .select("key", "leftValue")
+
+      val df2 = input2
+        .toDF()
+        .withColumn("key", $"value")
+        .withColumn("rightValue", lit(2))
+        .select("key", "rightValue")
+
+      val result = df1
+        .join(df2, "key")
+        .select("key", "leftValue", "rightValue")
+
+      testStream(result, OutputMode.Append())(
+        StartStream(checkpointLocation = dir.getAbsolutePath),
+        AddData(input1, 1),
+        ProcessAllAvailable(),
+        AddData(input2, 1),
+        ProcessAllAvailable(),
+        CheckAnswer(Row(1, 1, 2))
+      )
+    }
+  }
+
+  test("stream stream join with literals produces correct value on query change") {
+    withTempDir { dir =>
+      val input1 = MemoryStream[Int]
+      val input2 = MemoryStream[Int]
+
+      val df1 = input1
+        .toDF()
+        .withColumn("key", lit("key1"))
+        .withColumn("leftValue", lit(1))
+        .select("key", "leftValue")
+
+      val df2 = input2
+        .toDF()
+        .withColumn("key", lit("key2"))
+        .withColumn("rightValue", lit(2))
+        .select("key", "rightValue")
+
+      val result = df1
+        .join(df2, "key")
+        .select("key", "leftValue", "rightValue")
+
+      testStream(result, OutputMode.Append())(
+        StartStream(checkpointLocation = dir.getAbsolutePath),
+        AddData(input1, 1),
+        ProcessAllAvailable(),
+        AddData(input2, 1),
+        ProcessAllAvailable()
+      )
+
+      val df3 = input1
+        .toDF()
+        .withColumn("key", lit("key2"))
+        .withColumn("leftValue", lit(3))
+        .select("key", "leftValue")
+
+      val df4 = input2
+        .toDF()
+        .withColumn("key", lit("key1"))
+        .withColumn("rightValue", lit(4))
+        .select("key", "rightValue")
+
+      val result2 = df3
+        .join(df4, "key")
+        .select("key", "leftValue", "rightValue")
+
+      testStream(result2, OutputMode.Append())(
+        StartStream(checkpointLocation = dir.getAbsolutePath),
+        AddData(input1, 1),
+        ProcessAllAvailable(),
+        AddData(input2, 1),
+        ProcessAllAvailable(),
+        CheckAnswer(
+          Row("key1", 1, 4),
+          Row("key2", 3, 2))
+      )
+    }
+  }
+}


