This is an automated email from the ASF dual-hosted git repository.

tgraves pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 1fe396213b57 [SPARK-47398][SQL] Extract a trait for 
InMemoryTableScanExec to allow for extending functionality
1fe396213b57 is described below

commit 1fe396213b57e4697145f5fa1b9f0d24a35399df
Author: Raza Jafri <rja...@nvidia.com>
AuthorDate: Thu Mar 21 14:45:56 2024 -0500

    [SPARK-47398][SQL] Extract a trait for InMemoryTableScanExec to allow for 
extending functionality
    
    ### What changes were proposed in this pull request?
    We are proposing to allow the users to have a custom 
`InMemoryTableScanExec`. To accomplish this we can follow the same convention 
we followed for `ShuffleExchangeLike` and `BroadcastExchangeLike`
    
    ### Why are the changes needed?
    In the PR added by ulysses-you, we are wrapping `InMemoryTableScanExec` 
inside `TableCacheQueryStageExec`. This could potentially cause problems, 
especially in the RAPIDS Accelerator for Apache Spark, where we replace 
`InMemoryTableScanExec` with a customized version that has optimizations needed 
by us. This situation could lead to the loss of benefits from 
[SPARK-42101](https://issues.apache.org/jira/browse/SPARK-42101) or even result 
in Spark throwing an Exception.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Ran the existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #45525 from razajafri/extract-inmem-trait.
    
    Authored-by: Raza Jafri <rja...@nvidia.com>
    Signed-off-by: Thomas Graves <tgra...@apache.org>
    (cherry picked from commit 6a27789ad7d59cd133653a49be0bb49729542abe)
    Signed-off-by: Thomas Graves <tgra...@apache.org>
---
 .../execution/adaptive/AdaptiveSparkPlanExec.scala | 16 +++++------
 .../sql/execution/adaptive/QueryStageExec.scala    |  8 +++---
 .../execution/columnar/InMemoryTableScanExec.scala | 33 ++++++++++++++++++++--
 .../adaptive/AdaptiveQueryExecSuite.scala          |  6 ++--
 4 files changed, 45 insertions(+), 18 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 96b83a91cc73..d2e879e3eddb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
 import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, 
DisableUnnecessaryBucketedScan}
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
 import org.apache.spark.sql.execution.exchange._
 import 
org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, 
SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
 import org.apache.spark.sql.internal.SQLConf
@@ -254,7 +254,7 @@ case class AdaptiveSparkPlanExec(
     //    and display SQL metrics correctly.
     // 2. If the `QueryExecution` does not match the current execution ID, it 
means the execution
     //    ID belongs to another (parent) query, and we should not call update 
UI in this query.
-    //    e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanExec`.
+    //    e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanLike`.
     //
     // That means only the root `AdaptiveSparkPlanExec` of the main query that 
triggers this
     // query execution need to do a plan update for the UI.
@@ -557,9 +557,9 @@ case class AdaptiveSparkPlanExec(
           }
       }
 
-    case i: InMemoryTableScanExec =>
-      // There is no reuse for `InMemoryTableScanExec`, which is different 
from `Exchange`. If we
-      // hit it the first time, we should always create a new query stage.
+    case i: InMemoryTableScanLike =>
+      // There is no reuse for `InMemoryTableScanLike`, which is different 
from `Exchange`.
+      // If we hit it the first time, we should always create a new query 
stage.
       val newStage = newQueryStage(i)
       CreateStageResult(
         newPlan = newStage,
@@ -604,12 +604,12 @@ case class AdaptiveSparkPlanExec(
           }
           BroadcastQueryStageExec(currentStageId, newPlan, e.canonicalized)
         }
-      case i: InMemoryTableScanExec =>
+      case i: InMemoryTableScanLike =>
         // Apply `queryStageOptimizerRules` so that we can reuse subquery.
-        // No need to apply `postStageCreationRules` for 
`InMemoryTableScanExec`
+        // No need to apply `postStageCreationRules` for 
`InMemoryTableScanLike`
         // as it's a leaf node.
         val newPlan = optimizeQueryStage(i, isFinalStage = false)
-        if (!newPlan.isInstanceOf[InMemoryTableScanExec]) {
+        if (!newPlan.isInstanceOf[InMemoryTableScanLike]) {
           throw SparkException.internalError(
             "Custom AQE rules cannot transform table scan node to something 
else.")
         }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index b941feb12fc0..433315c49321 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.columnar.CachedBatch
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
@@ -261,7 +261,7 @@ case class BroadcastQueryStageExec(
 }
 
 /**
- * A table cache query stage whose child is a [[InMemoryTableScanExec]].
+ * A table cache query stage whose child is a [[InMemoryTableScanLike]].
  *
  * @param id the query stage id.
  * @param plan the underlying plan.
@@ -271,7 +271,7 @@ case class TableCacheQueryStageExec(
     override val plan: SparkPlan) extends QueryStageExec {
 
   @transient val inMemoryTableScan = plan match {
-    case i: InMemoryTableScanExec => i
+    case i: InMemoryTableScanLike => i
     case _ =>
       throw new IllegalStateException(s"wrong plan for table cache stage:\n 
${plan.treeString}")
   }
@@ -294,5 +294,5 @@ case class TableCacheQueryStageExec(
 
   override protected def doMaterialize(): Future[Any] = future
 
-  override def getRuntimeStatistics: Statistics = 
inMemoryTableScan.relation.computeStats()
+  override def getRuntimeStatistics: Statistics = 
inMemoryTableScan.runtimeStatistics
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 08244a4f84fe..5ff8bfd75f8a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.columnar.CachedBatch
 import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, 
WholeStageCodegenExec}
@@ -28,11 +29,32 @@ import 
org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+/**
+ * Common trait for all InMemoryTableScan implementations to facilitate 
pattern matching.
+ */
+trait InMemoryTableScanLike extends LeafExecNode {
+
+  /**
+   * Returns whether the cache buffer is loaded
+   */
+  def isMaterialized: Boolean
+
+  /**
+   * Returns the actual cached RDD without filters and serialization of 
row/columnar.
+   */
+  def baseCacheRDD(): RDD[CachedBatch]
+
+  /**
+   * Returns the runtime statistics after materialization.
+   */
+  def runtimeStatistics: Statistics
+}
+
 case class InMemoryTableScanExec(
     attributes: Seq[Attribute],
     predicates: Seq[Expression],
     @transient relation: InMemoryRelation)
-  extends LeafExecNode {
+  extends InMemoryTableScanLike {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"))
@@ -167,13 +189,18 @@ case class InMemoryTableScanExec(
     columnarInputRDD
   }
 
-  def isMaterialized: Boolean = 
relation.cacheBuilder.isCachedColumnBuffersLoaded
+  override def isMaterialized: Boolean = 
relation.cacheBuilder.isCachedColumnBuffersLoaded
 
   /**
    * This method is only used by AQE which executes the actually cached RDD 
that without filter and
    * serialization of row/columnar.
    */
-  def baseCacheRDD(): RDD[CachedBatch] = {
+  override def baseCacheRDD(): RDD[CachedBatch] = {
     relation.cacheBuilder.cachedColumnBuffers
   }
+
+  /**
+   * Returns the runtime statistics after materialization.
+   */
+  override def runtimeStatistics: Statistics = relation.computeStats()
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 68bae34790a0..7c280f72ca17 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, 
BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{CollectLimitExec, ColumnarToRowExec, 
LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, 
ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, 
UnionExec}
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, 
InMemoryTableScanLike}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
 import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
@@ -2758,7 +2758,7 @@ class AdaptiveQueryExecSuite
           case s: SortExec => s
         }.size == (if (firstAccess) 2 else 0))
         assert(collect(initialExecutedPlan) {
-          case i: InMemoryTableScanExec => i
+          case i: InMemoryTableScanLike => i
         }.head.isMaterialized != firstAccess)
 
         df.collect()
@@ -2770,7 +2770,7 @@ class AdaptiveQueryExecSuite
           case s: SortExec => s
         }.isEmpty)
         assert(collect(initialExecutedPlan) {
-          case i: InMemoryTableScanExec => i
+          case i: InMemoryTableScanLike => i
         }.head.isMaterialized)
       }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to