This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 7e0c31445c31 [SPARK-48481][SQL][SS] Do not apply OptimizeOneRowPlan 
against streaming Dataset
7e0c31445c31 is described below

commit 7e0c31445c31a76f6e1835f204e8a09eee2b57dc
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Sat Jun 1 15:15:01 2024 +0900

    [SPARK-48481][SQL][SS] Do not apply OptimizeOneRowPlan against streaming 
Dataset
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to exclude streaming Dataset from the target of 
OptimizeOneRowPlan.
    
    ### Why are the changes needed?
    
    The rule should not be applied to a streaming source, since the number of 
rows it sees is just for the current microbatch. It does not mean the streaming 
source will only ever produce at most 1 row during the lifetime of the query.
    
    Suppose the following case: batch 0 runs with empty data in streaming 
source A, which triggers the rule on the Aggregate, and batch 1 runs with 
several rows in streaming source A, which no longer triggers the rule.
    
    In the above scenario, this could fail the query, as the stateful operator 
is expected to be planned for every batch, whereas here it is planned 
"selectively".
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, but the previous behavior can be restored with a new config, 
`spark.sql.streaming.optimizeOneRowPlan.enabled`, although I believe it should 
be a really rare case where users have to turn the config on.
    
    ### How was this patch tested?
    
    New UT.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46820 from HeartSaVioR/SPARK-48481.
    
    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    (cherry picked from commit 1cecdc7596e078b4917f456bfbd2435ff9022f2f)
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../catalyst/optimizer/OptimizeOneRowPlan.scala    | 27 ++++++--
 .../org/apache/spark/sql/internal/SQLConf.scala    | 11 ++++
 ...treamingQueryOptimizationCorrectnessSuite.scala | 73 ++++++++++++++++++++++
 3 files changed, 107 insertions(+), 4 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
index 83646611578c..61c08eb8f8b6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * The rule is applied both normal and AQE Optimizer. It optimizes plan using 
max rows:
@@ -31,19 +32,37 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._
  *     it's grouping only(include the rewritten distinct plan), convert 
aggregate to project
  *   - if the max rows of the child of aggregate is less than or equal to 1,
  *     set distinct to false in all aggregate expression
+ *
+ * Note: the rule should not be applied to streaming source, since the number 
of rows it sees is
+ * just for current microbatch. It does not mean the streaming source will 
ever produce max 1
+ * rows during lifetime of the query. Suppose the case: the streaming query 
has a case where
+ * batch 0 runs with empty data in streaming source A which triggers the rule 
with Aggregate,
+ * and batch 1 runs with several data in streaming source A which no longer 
trigger the rule.
+ * In the above scenario, this could fail the query as stateful operator is 
expected to be planned
+ * for every batches whereas here it is planned "selectively".
  */
 object OptimizeOneRowPlan extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
+    val enableForStreaming = 
conf.getConf(SQLConf.STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED)
+
     plan.transformUpWithPruning(_.containsAnyPattern(SORT, AGGREGATE), ruleId) 
{
-      case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) => child
-      case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) 
=> child
-      case agg @ Aggregate(_, _, child) if agg.groupOnly && 
child.maxRows.exists(_ <= 1L) =>
+      case Sort(_, _, child) if child.maxRows.exists(_ <= 1L) &&
+        isChildEligible(child, enableForStreaming) => child
+      case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) 
&&
+        isChildEligible(child, enableForStreaming) => child
+      case agg @ Aggregate(_, _, child) if agg.groupOnly && 
child.maxRows.exists(_ <= 1L) &&
+        isChildEligible(child, enableForStreaming) =>
         Project(agg.aggregateExpressions, child)
-      case agg: Aggregate if agg.child.maxRows.exists(_ <= 1L) =>
+      case agg: Aggregate if agg.child.maxRows.exists(_ <= 1L) &&
+        isChildEligible(agg.child, enableForStreaming) =>
         agg.transformExpressions {
           case aggExpr: AggregateExpression if aggExpr.isDistinct =>
             aggExpr.copy(isDistinct = false)
         }
     }
   }
+
+  private def isChildEligible(child: LogicalPlan, enableForStreaming: 
Boolean): Boolean = {
+    enableForStreaming || !child.isStreaming
+  }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 3e62f656ac9e..74ff4f09a157 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2142,6 +2142,17 @@ object SQLConf {
       .createWithDefault(true)
 
 
+  val STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED =
+    buildConf("spark.sql.streaming.optimizeOneRowPlan.enabled")
+      .internal()
+      .doc("When true, enable OptimizeOneRowPlan rule for the case where the 
child is a " +
+        "streaming Dataset. This is a fallback flag to revert the 'incorrect' 
behavior, hence " +
+        "this configuration must not be used without understanding in depth. 
Use this only to " +
+        "quickly recover failure in existing query!")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val VARIABLE_SUBSTITUTE_ENABLED =
     buildConf("spark.sql.variable.substitute")
       .doc("This enables substitution using syntax like `${var}`, 
`${system:var}`, " +
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala
index d17da5d31edd..782badaef924 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala
@@ -22,6 +22,7 @@ import java.sql.Timestamp
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions.{expr, lit, window}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * This test ensures that any optimizations done by Spark SQL optimizer are
@@ -451,4 +452,76 @@ class StreamingQueryOptimizationCorrectnessSuite extends 
StreamTest {
       )
     }
   }
+
+  test("SPARK-48481: DISTINCT with empty stream source should retain 
AGGREGATE") {
+    def doTest(numExpectedStatefulOperatorsForOneEmptySource: Int): Unit = {
+      withTempView("tv1", "tv2") {
+        val inputStream1 = MemoryStream[Int]
+        val ds1 = inputStream1.toDS()
+        ds1.registerTempTable("tv1")
+
+        val inputStream2 = MemoryStream[Int]
+        val ds2 = inputStream2.toDS()
+        ds2.registerTempTable("tv2")
+
+        // DISTINCT is rewritten to AGGREGATE, hence an AGGREGATEs for each 
source
+        val unioned = spark.sql(
+          """
+            | WITH u AS (
+            |   SELECT DISTINCT value AS value FROM tv1
+            | ), v AS (
+            |   SELECT DISTINCT value AS value FROM tv2
+            | )
+            | SELECT value FROM u UNION ALL SELECT value FROM v
+            |""".stripMargin
+        )
+
+        testStream(unioned, OutputMode.Update())(
+          MultiAddData(inputStream1, 1, 1, 2)(inputStream2, 1, 1, 2),
+          CheckNewAnswer(1, 2, 1, 2),
+          Execute { qe =>
+            val stateOperators = qe.lastProgress.stateOperators
+            // Aggregate should be "stateful" one
+            assert(stateOperators.length === 2)
+            stateOperators.zipWithIndex.foreach { case (op, id) =>
+              assert(op.numRowsUpdated === 2, s"stateful OP ID: $id")
+            }
+          },
+          AddData(inputStream2, 2, 2, 3),
+          // NOTE: this is probably far from expectation to have 2 as output 
given user intends
+          // deduplicate, but the behavior is still correct with rewritten 
node and output mode:
+          // Aggregate & Update mode.
+          // TODO: Probably we should disallow DISTINCT or rewrite to
+          //  dropDuplicates(WithinWatermark) for streaming source?
+          CheckNewAnswer(2, 3),
+          Execute { qe =>
+            val stateOperators = qe.lastProgress.stateOperators
+            // Aggregate should be "stateful" one
+            assert(stateOperators.length === 
numExpectedStatefulOperatorsForOneEmptySource)
+            val opWithUpdatedRows = 
stateOperators.zipWithIndex.filterNot(_._1.numRowsUpdated == 0)
+            assert(opWithUpdatedRows.length === 1)
+            // If this were dropDuplicates, numRowsUpdated should have been 1.
+            assert(opWithUpdatedRows.head._1.numRowsUpdated === 2,
+              s"stateful OP ID: ${opWithUpdatedRows.head._2}")
+          },
+          AddData(inputStream1, 4, 4, 5),
+          CheckNewAnswer(4, 5),
+          Execute { qe =>
+            val stateOperators = qe.lastProgress.stateOperators
+            assert(stateOperators.length === 
numExpectedStatefulOperatorsForOneEmptySource)
+            val opWithUpdatedRows = 
stateOperators.zipWithIndex.filterNot(_._1.numRowsUpdated == 0)
+            assert(opWithUpdatedRows.length === 1)
+            assert(opWithUpdatedRows.head._1.numRowsUpdated === 2,
+              s"stateful OP ID: ${opWithUpdatedRows.head._2}")
+          }
+        )
+      }
+    }
+
+    doTest(numExpectedStatefulOperatorsForOneEmptySource = 2)
+
+    withSQLConf(SQLConf.STREAMING_OPTIMIZE_ONE_ROW_PLAN_ENABLED.key -> "true") 
{
+      doTest(numExpectedStatefulOperatorsForOneEmptySource = 1)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to