This is an automated email from the ASF dual-hosted git repository.

tgraves pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a27789ad7d5 [SPARK-47398][SQL] Extract a trait for 
InMemoryTableScanExec to allow for extending functionality
6a27789ad7d5 is described below

commit 6a27789ad7d59cd133653a49be0bb49729542abe
Author: Raza Jafri <rja...@nvidia.com>
AuthorDate: Thu Mar 21 14:45:56 2024 -0500

    [SPARK-47398][SQL] Extract a trait for InMemoryTableScanExec to allow for 
extending functionality
    
    ### What changes were proposed in this pull request?
    We are proposing to allow users to plug in a custom 
`InMemoryTableScanExec`. To accomplish this, we extract a common trait, 
`InMemoryTableScanLike`, following the same convention used for 
`ShuffleExchangeLike` and `BroadcastExchangeLike`.
    
    ### Why are the changes needed?
    In the PR added by ulysses-you, we are wrapping `InMemoryTableScanExec` 
inside `TableCacheQueryStageExec`. This could potentially cause problems, 
especially in the RAPIDS Accelerator for Apache Spark, where we replace 
`InMemoryTableScanExec` with a customized version that has optimizations needed 
by us. This situation could lead to the loss of benefits from 
[SPARK-42101](https://issues.apache.org/jira/browse/SPARK-42101) or even result 
in Spark throwing an Exception.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Ran the existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #45525 from razajafri/extract-inmem-trait.
    
    Authored-by: Raza Jafri <rja...@nvidia.com>
    Signed-off-by: Thomas Graves <tgra...@apache.org>
---
 .../execution/adaptive/AdaptiveSparkPlanExec.scala | 16 +++++------
 .../sql/execution/adaptive/QueryStageExec.scala    |  8 +++---
 .../execution/columnar/InMemoryTableScanExec.scala | 33 ++++++++++++++++++++--
 .../adaptive/AdaptiveQueryExecSuite.scala          |  6 ++--
 4 files changed, 45 insertions(+), 18 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 2879aaca7215..a5e681535cb8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._
 import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, 
DisableUnnecessaryBucketedScan}
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
 import org.apache.spark.sql.execution.exchange._
 import 
org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, 
SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric}
 import org.apache.spark.sql.internal.SQLConf
@@ -255,7 +255,7 @@ case class AdaptiveSparkPlanExec(
     //    and display SQL metrics correctly.
     // 2. If the `QueryExecution` does not match the current execution ID, it 
means the execution
     //    ID belongs to another (parent) query, and we should not call update 
UI in this query.
-    //    e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanExec`.
+    //    e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanLike`.
     //
     // That means only the root `AdaptiveSparkPlanExec` of the main query that 
triggers this
     // query execution need to do a plan update for the UI.
@@ -558,9 +558,9 @@ case class AdaptiveSparkPlanExec(
           }
       }
 
-    case i: InMemoryTableScanExec =>
-      // There is no reuse for `InMemoryTableScanExec`, which is different 
from `Exchange`. If we
-      // hit it the first time, we should always create a new query stage.
+    case i: InMemoryTableScanLike =>
+      // There is no reuse for `InMemoryTableScanLike`, which is different 
from `Exchange`.
+      // If we hit it the first time, we should always create a new query 
stage.
       val newStage = newQueryStage(i)
       CreateStageResult(
         newPlan = newStage,
@@ -605,12 +605,12 @@ case class AdaptiveSparkPlanExec(
           }
           BroadcastQueryStageExec(currentStageId, newPlan, e.canonicalized)
         }
-      case i: InMemoryTableScanExec =>
+      case i: InMemoryTableScanLike =>
         // Apply `queryStageOptimizerRules` so that we can reuse subquery.
-        // No need to apply `postStageCreationRules` for 
`InMemoryTableScanExec`
+        // No need to apply `postStageCreationRules` for 
`InMemoryTableScanLike`
         // as it's a leaf node.
         val newPlan = optimizeQueryStage(i, isFinalStage = false)
-        if (!newPlan.isInstanceOf[InMemoryTableScanExec]) {
+        if (!newPlan.isInstanceOf[InMemoryTableScanLike]) {
           throw SparkException.internalError(
             "Custom AQE rules cannot transform table scan node to something 
else.")
         }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index 88954d6f822d..7db9271aee0c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.columnar.CachedBatch
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike
 import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
@@ -261,7 +261,7 @@ case class BroadcastQueryStageExec(
 }
 
 /**
- * A table cache query stage whose child is a [[InMemoryTableScanExec]].
+ * A table cache query stage whose child is a [[InMemoryTableScanLike]].
  *
  * @param id the query stage id.
  * @param plan the underlying plan.
@@ -271,7 +271,7 @@ case class TableCacheQueryStageExec(
     override val plan: SparkPlan) extends QueryStageExec {
 
   @transient val inMemoryTableScan = plan match {
-    case i: InMemoryTableScanExec => i
+    case i: InMemoryTableScanLike => i
     case _ =>
       throw SparkException.internalError(s"wrong plan for table cache stage:\n 
${plan.treeString}")
   }
@@ -294,5 +294,5 @@ case class TableCacheQueryStageExec(
 
   override protected def doMaterialize(): Future[Any] = future
 
-  override def getRuntimeStatistics: Statistics = 
inMemoryTableScan.relation.computeStats()
+  override def getRuntimeStatistics: Statistics = 
inMemoryTableScan.runtimeStatistics
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 064a46369055..cfcfd282e548 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.columnar.CachedBatch
 import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, 
WholeStageCodegenExec}
@@ -28,11 +29,32 @@ import 
org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+/**
+ * Common trait for all InMemoryTableScan implementations to facilitate 
pattern matching.
+ */
+trait InMemoryTableScanLike extends LeafExecNode {
+
+  /**
+   * Returns whether the cache buffer is loaded
+   */
+  def isMaterialized: Boolean
+
+  /**
+   * Returns the actual cached RDD without filters and serialization of 
row/columnar.
+   */
+  def baseCacheRDD(): RDD[CachedBatch]
+
+  /**
+   * Returns the runtime statistics after materialization.
+   */
+  def runtimeStatistics: Statistics
+}
+
 case class InMemoryTableScanExec(
     attributes: Seq[Attribute],
     predicates: Seq[Expression],
     @transient relation: InMemoryRelation)
-  extends LeafExecNode {
+  extends InMemoryTableScanLike {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"))
@@ -176,13 +198,18 @@ case class InMemoryTableScanExec(
     columnarInputRDD
   }
 
-  def isMaterialized: Boolean = 
relation.cacheBuilder.isCachedColumnBuffersLoaded
+  override def isMaterialized: Boolean = 
relation.cacheBuilder.isCachedColumnBuffersLoaded
 
   /**
    * This method is only used by AQE which executes the actually cached RDD 
that without filter and
    * serialization of row/columnar.
    */
-  def baseCacheRDD(): RDD[CachedBatch] = {
+  override def baseCacheRDD(): RDD[CachedBatch] = {
     relation.cacheBuilder.cachedColumnBuffers
   }
+
+  /**
+   * Returns the runtime statistics after materialization.
+   */
+  override def runtimeStatistics: Statistics = relation.computeStats()
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index f528c5584fee..39f6aa8505b3 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, 
BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{CollectLimitExec, ColumnarToRowExec, 
LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, 
ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, 
UnionExec}
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, 
InMemoryTableScanLike}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
 import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
@@ -2763,7 +2763,7 @@ class AdaptiveQueryExecSuite
           case s: SortExec => s
         }.size == (if (firstAccess) 2 else 0))
         assert(collect(initialExecutedPlan) {
-          case i: InMemoryTableScanExec => i
+          case i: InMemoryTableScanLike => i
         }.head.isMaterialized != firstAccess)
 
         df.collect()
@@ -2775,7 +2775,7 @@ class AdaptiveQueryExecSuite
           case s: SortExec => s
         }.isEmpty)
         assert(collect(initialExecutedPlan) {
-          case i: InMemoryTableScanExec => i
+          case i: InMemoryTableScanLike => i
         }.head.isMaterialized)
       }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to