This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new 7d9320085999 [SPARK-52767][SQL] Optimize maxRows and 
maxRowsPerPartition for join and union
7d9320085999 is described below

commit 7d932008599927797f7e902ed10abc466675c331
Author: zml1206 <[email protected]>
AuthorDate: Tue Nov 18 22:35:43 2025 +0800

    [SPARK-52767][SQL] Optimize maxRows and maxRowsPerPartition for join and 
union
    
    ### What changes were proposed in this pull request?
    Ensure that `maxRows` and `maxRowsPerPartition` are calculated at most once per node.
    
    ### Why are the changes needed?
    Improve performance, especially when there are dozens of joins and unions.
    Before this PR, the number of `maxRows` evaluations for join/union grew
exponentially with the number of joins/unions.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    
    Local test: a 28-table join took 36s before this PR and 4s after; a 29-table join
took 67s before and 5s after.
    ```
        Seq(1).toDF("a").write.mode("overwrite").parquet("tmp/t1")
        spark.read.parquet("tmp/t1").createOrReplaceTempView("t")
        val t1 = System.currentTimeMillis()
        spark.sql(
          """
            |select a,count(1) from (
            |select t1.a from (select distinct a from t) t1
            |join t t2 on t1.a=t2.a
            |join t t3 on t1.a=t3.a
            |join t t4 on t1.a=t4.a
            |join t t5 on t1.a=t5.a
            |join t t6 on t1.a=t6.a
            |join t t7 on t1.a=t7.a
            |join t t8 on t1.a=t8.a
            |join t t9 on t1.a=t9.a
            |join t t10 on t1.a=t10.a
            |join t t11 on t1.a=t11.a
            |join t t12 on t1.a=t12.a
            |join t t13 on t1.a=t13.a
            |join t t14 on t1.a=t14.a
            |join t t15 on t1.a=t15.a
            |join t t16 on t1.a=t16.a
            |join t t17 on t1.a=t17.a
            |join t t18 on t1.a=t18.a
            |join t t19 on t1.a=t19.a
            |join t t20 on t1.a=t20.a
            |join t t21 on t1.a=t21.a
            |join t t22 on t1.a=t22.a
            |join t t23 on t1.a=t23.a
            |join t t24 on t1.a=t24.a
            |join t t25 on t1.a=t25.a
            |join t t26 on t1.a=t26.a
            |join t t27 on t1.a=t27.a
            |join t t28 on t1.a=t28.a
            |) group by a
            |""".stripMargin).show
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #51451 from zml1206/SPARK-52767.
    
    Authored-by: zml1206 <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit aa387f32158a98260f7b9b16dc87feb64b504ab4)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../plans/logical/basicLogicalOperators.scala      | 44 +++++++++++-----------
 1 file changed, 21 insertions(+), 23 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 142420ee258a..b87d018f2ab1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -580,19 +580,18 @@ case class Union(
     allowMissingCol: Boolean = false) extends UnionBase {
   assert(!allowMissingCol || byName, "`allowMissingCol` can be true only if 
`byName` is true.")
 
-  override def maxRows: Option[Long] = {
-    var sum = BigInt(0)
-    children.foreach { child =>
-      if (child.maxRows.isDefined) {
-        sum += child.maxRows.get
-        if (!sum.isValidLong) {
-          return None
+  override lazy val maxRows: Option[Long] = {
+    val sum = children.foldLeft(Option(BigInt(0))) {
+      case (Some(acc), child) =>
+        child.maxRows match {
+          case Some(n) =>
+            val newSum = acc + n
+            if (newSum.isValidLong) Some(newSum) else None
+          case None => None
         }
-      } else {
-        return None
-      }
+      case (None, _) => None
     }
-    Some(sum.toLong)
+    sum.map(_.toLong)
   }
 
   final override val nodePatterns: Seq[TreePattern] = Seq(UNION)
@@ -600,19 +599,18 @@ case class Union(
   /**
    * Note the definition has assumption about how union is implemented 
physically.
    */
-  override def maxRowsPerPartition: Option[Long] = {
-    var sum = BigInt(0)
-    children.foreach { child =>
-      if (child.maxRowsPerPartition.isDefined) {
-        sum += child.maxRowsPerPartition.get
-        if (!sum.isValidLong) {
-          return None
+  override lazy val maxRowsPerPartition: Option[Long] = {
+    val sum = children.foldLeft(Option(BigInt(0))) {
+      case (Some(acc), child) =>
+        child.maxRowsPerPartition match {
+          case Some(n) =>
+            val newSum = acc + n
+            if (newSum.isValidLong) Some(newSum) else None
+          case None => None
         }
-      } else {
-        return None
-      }
+      case (None, _) => None
     }
-    Some(sum.toLong)
+    sum.map(_.toLong)
   }
 
   private def duplicatesResolvedPerBranch: Boolean =
@@ -666,7 +664,7 @@ case class Join(
     hint: JoinHint)
   extends BinaryNode with PredicateHelper {
 
-  override def maxRows: Option[Long] = {
+  override lazy val maxRows: Option[Long] = {
     joinType match {
       case Inner | Cross | FullOuter | LeftOuter | RightOuter | LeftSingle
           if left.maxRows.isDefined && right.maxRows.isDefined =>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to