This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new fa67b5b  [SPARK-37442][SQL] InMemoryRelation statistics bug causing broadcast join failures with AQE enabled
fa67b5b is described below

commit fa67b5b6aafbfa2e39fd8c382993e8469c3340b3
Author: Michael Chen <mike.c...@workday.com>
AuthorDate: Thu Dec 2 22:33:02 2021 +0800

    [SPARK-37442][SQL] InMemoryRelation statistics bug causing broadcast join failures with AQE enabled
    
    Immediately materialize the underlying RDD cache (using .count) for an
    InMemoryRelation when `buildBuffers` is called.
    
    Currently, when `CachedRDDBuilder.buildBuffers` has been called,
    `InMemoryRelation.computeStats` will try to read the accumulators to
    determine the relation size. However, the accumulators are not accurate
    until the cached RDD has been executed and finished. Until then, they can
    report anything from 0 bytes up to the value they will hold once the
    cached RDD finishes. In AQE, join planning can happen during this window
    and, if it reads the size as 0 bytes, will  [...]
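
    A condensed sketch of the failure scenario (adapted from the new
    AdaptiveQueryExecSuite test below; the exact threshold and session setup
    here are illustrative assumptions, not part of the patch):

        // With AQE on and a ~1MB broadcast threshold, stats can be read
        // before the cached RDD has materialized.
        import spark.implicits._
        spark.conf.set("spark.sql.adaptive.enabled", "true")
        spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "1048584")

        Seq.fill(60000)("00112233445566778899").toDF("key")
          .createOrReplaceTempView("temp")
        // persist() only registers the cache; nothing has run yet, so the
        // size accumulator still reads 0 even though ~1.2MB of data is pending.
        spark.sql("SELECT key AS newKey FROM temp").persist()
        // A join planned at this point may see 0 bytes and wrongly pick a
        // broadcast hash join over this relation.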
    
    Yes. Before, cache materialization didn't happen until the job started to
    run. Now, it happens when trying to get the RDD representing an
    InMemoryRelation.
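
    As a sketch of what this means in practice (a workaround pattern, not code
    from this patch): on builds without the fix, explicitly materializing the
    cache before running dependent queries avoids planning against the
    not-yet-populated accumulators:

        val cached = spark.sql("SELECT key AS newKey FROM temp").persist()
        // Running an action completes the cache job, so sizeInBytesStats now
        // holds the real size and later join planning reads accurate stats.
        cached.count()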
    
    Tests added
    
    Closes #34684 from ChenMichael/SPARK-37442-InMemoryRelation-statistics-inaccurate-during-join-planning.
    
    Authored-by: Michael Chen <mike.c...@workday.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit c37b726bd09d34e1115a8af1969485e60dc02592)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/execution/columnar/InMemoryRelation.scala  | 21 +++++++--
 .../adaptive/AdaptiveQueryExecSuite.scala          | 53 ++++++++++++++++++++++
 .../sql/execution/joins/BroadcastJoinSuite.scala   | 26 +++++++++++
 3 files changed, 97 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index c2ec9ff..525653c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar
 
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapCol
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructType, UserDefinedType}
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
 import org.apache.spark.util.{LongAccumulator, Utils}
 
 /**
@@ -207,6 +207,7 @@ case class CachedRDDBuilder(
     tableName: Option[String]) {
 
   @transient @volatile private var _cachedColumnBuffers: RDD[CachedBatch] = null
+  @transient @volatile private var _cachedColumnBuffersAreLoaded: Boolean = false
 
   val sizeInBytesStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator
   val rowCountStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator
@@ -237,7 +238,21 @@ case class CachedRDDBuilder(
   }
 
   def isCachedColumnBuffersLoaded: Boolean = {
-    _cachedColumnBuffers != null
+    _cachedColumnBuffers != null && isCachedRDDLoaded
+  }
+
+  def isCachedRDDLoaded: Boolean = {
+    _cachedColumnBuffersAreLoaded || {
+      val bmMaster = SparkEnv.get.blockManager.master
+      val rddLoaded = _cachedColumnBuffers.partitions.forall { partition =>
+        bmMaster.getBlockStatus(RDDBlockId(_cachedColumnBuffers.id, partition.index), false)
+          .exists { case (_, blockStatus) => blockStatus.isCached }
+      }
+      if (rddLoaded) {
+        _cachedColumnBuffersAreLoaded = rddLoaded
+      }
+      rddLoaded
+    }
   }
 
   private def buildBuffers(): RDD[CachedBatch] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 9837138..20772cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -2009,6 +2009,59 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-37742: AQE reads invalid InMemoryRelation stats and mistakenly 
plans BHJ") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1048584") {
+      // Spark estimates a string column as 20 bytes so with 60k rows, these relations should be
+      // estimated at ~1.2MB which is greater than the broadcast join threshold.
+      val joinKeyOne = "00112233445566778899"
+      val joinKeyTwo = "11223344556677889900"
+      Seq.fill(60000)(joinKeyOne).toDF("key")
+        .createOrReplaceTempView("temp")
+      Seq.fill(60000)(joinKeyTwo).toDF("key")
+        .createOrReplaceTempView("temp2")
+
+      Seq(joinKeyOne).toDF("key").createOrReplaceTempView("smallTemp")
+      spark.sql("SELECT key as newKey FROM temp").persist()
+
+      // This query is trying to set up a situation where there are three joins.
+      // The first join will join the cached relation with a smaller relation.
+      // The first join is expected to be a broadcast join since the smaller relation will
+      // fit under the broadcast join threshold.
+      // The second join will join the first join with another relation and is expected
+      // to remain as a sort-merge join.
+      // The third join will join the cached relation with another relation and is expected
+      // to remain as a sort-merge join.
+      val query =
+      s"""
+         |SELECT t3.newKey
+         |FROM
+         |  (SELECT t1.newKey
+         |  FROM (SELECT key as newKey FROM temp) as t1
+         |        JOIN
+         |        (SELECT key FROM smallTemp) as t2
+         |        ON t1.newKey = t2.key
+         |  ) as t3
+         |  JOIN
+         |  (SELECT key FROM temp2) as t4
+         |  ON t3.newKey = t4.key
+         |UNION
+         |SELECT t1.newKey
+         |FROM
+         |    (SELECT key as newKey FROM temp) as t1
+         |    JOIN
+         |    (SELECT key FROM temp2) as t2
+         |    ON t1.newKey = t2.key
+         |""".stripMargin
+      val df = spark.sql(query)
+      df.collect()
+      val adaptivePlan = df.queryExecution.executedPlan
+      val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+      assert(bhj.length == 1)
+    }
+  }
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index dd6a412..a8b4856 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -637,6 +637,32 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
     }
   }
 
+  test("SPARK-37742: join planning shouldn't read invalid InMemoryRelation 
stats") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10") {
+      try {
+        val df1 = Seq(1).toDF("key")
+        val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
+        df2.persist()
+        df2.queryExecution.toRdd
+
+        val df3 = df1.join(df2, Seq("key"), "inner")
+        val numCachedPlan = collect(df3.queryExecution.executedPlan) {
+          case i: InMemoryTableScanExec => i
+        }.size
+        // df2 should be cached.
+        assert(numCachedPlan === 1)
+
+        val numBroadCastHashJoin = collect(df3.queryExecution.executedPlan) {
+          case b: BroadcastHashJoinExec => b
+        }.size
+        // df2 should not be broadcasted.
+        assert(numBroadCastHashJoin === 0)
+      } finally {
+        spark.catalog.clearCache()
+      }
+    }
+  }
+
   private def expressionsEqual(l: Seq[Expression], r: Seq[Expression]): Boolean = {
     l.length == r.length && l.zip(r).forall { case (e1, e2) => e1.semanticEquals(e2) }
   }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
