This is an automated email from the ASF dual-hosted git repository. tgraves pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 1fe396213b57 [SPARK-47398][SQL] Extract a trait for InMemoryTableScanExec to allow for extending functionality 1fe396213b57 is described below commit 1fe396213b57e4697145f5fa1b9f0d24a35399df Author: Raza Jafri <rja...@nvidia.com> AuthorDate: Thu Mar 21 14:45:56 2024 -0500 [SPARK-47398][SQL] Extract a trait for InMemoryTableScanExec to allow for extending functionality ### What changes were proposed in this pull request? We are proposing to allow the users to have a custom `InMemoryTableScanExec`. To accomplish this we can follow the same convention we followed for `ShuffleExchangeLike` and `BroadcastExchangeLike` ### Why are the changes needed? In the PR added by ulysses-you, we are wrapping `InMemoryTableScanExec` inside `TableCacheQueryStageExec`. This could potentially cause problems, especially in the RAPIDS Accelerator for Apache Spark, where we replace `InMemoryTableScanExec` with a customized version that has optimizations needed by us. This situation could lead to the loss of benefits from [SPARK-42101](https://issues.apache.org/jira/browse/SPARK-42101) or even result in Spark throwing an Exception. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Ran the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #45525 from razajafri/extract-inmem-trait. 
Authored-by: Raza Jafri <rja...@nvidia.com> Signed-off-by: Thomas Graves <tgra...@apache.org> (cherry picked from commit 6a27789ad7d59cd133653a49be0bb49729542abe) Signed-off-by: Thomas Graves <tgra...@apache.org> --- .../execution/adaptive/AdaptiveSparkPlanExec.scala | 16 +++++------ .../sql/execution/adaptive/QueryStageExec.scala | 8 +++--- .../execution/columnar/InMemoryTableScanExec.scala | 33 ++++++++++++++++++++-- .../adaptive/AdaptiveQueryExecSuite.scala | 6 ++-- 4 files changed, 45 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 96b83a91cc73..d2e879e3eddb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._ import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan} -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike import org.apache.spark.sql.execution.exchange._ import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric} import org.apache.spark.sql.internal.SQLConf @@ -254,7 +254,7 @@ case class AdaptiveSparkPlanExec( // and display SQL metrics correctly. // 2. If the `QueryExecution` does not match the current execution ID, it means the execution // ID belongs to another (parent) query, and we should not call update UI in this query. - // e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanExec`. 
+ // e.g., a nested `AdaptiveSparkPlanExec` in `InMemoryTableScanLike`. // // That means only the root `AdaptiveSparkPlanExec` of the main query that triggers this // query execution need to do a plan update for the UI. @@ -557,9 +557,9 @@ case class AdaptiveSparkPlanExec( } } - case i: InMemoryTableScanExec => - // There is no reuse for `InMemoryTableScanExec`, which is different from `Exchange`. If we - // hit it the first time, we should always create a new query stage. + case i: InMemoryTableScanLike => + // There is no reuse for `InMemoryTableScanLike`, which is different from `Exchange`. + // If we hit it the first time, we should always create a new query stage. val newStage = newQueryStage(i) CreateStageResult( newPlan = newStage, @@ -604,12 +604,12 @@ case class AdaptiveSparkPlanExec( } BroadcastQueryStageExec(currentStageId, newPlan, e.canonicalized) } - case i: InMemoryTableScanExec => + case i: InMemoryTableScanLike => // Apply `queryStageOptimizerRules` so that we can reuse subquery. - // No need to apply `postStageCreationRules` for `InMemoryTableScanExec` + // No need to apply `postStageCreationRules` for `InMemoryTableScanLike` // as it's a leaf node. 
val newPlan = optimizeQueryStage(i, isFinalStage = false) - if (!newPlan.isInstanceOf[InMemoryTableScanExec]) { + if (!newPlan.isInstanceOf[InMemoryTableScanLike]) { throw SparkException.internalError( "Custom AQE rules cannot transform table scan node to something else.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index b941feb12fc0..433315c49321 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.columnar.CachedBatch import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.columnar.InMemoryTableScanLike import org.apache.spark.sql.execution.exchange._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -261,7 +261,7 @@ case class BroadcastQueryStageExec( } /** - * A table cache query stage whose child is a [[InMemoryTableScanExec]]. + * A table cache query stage whose child is a [[InMemoryTableScanLike]]. * * @param id the query stage id. * @param plan the underlying plan. 
@@ -271,7 +271,7 @@ case class TableCacheQueryStageExec( override val plan: SparkPlan) extends QueryStageExec { @transient val inMemoryTableScan = plan match { - case i: InMemoryTableScanExec => i + case i: InMemoryTableScanLike => i case _ => throw new IllegalStateException(s"wrong plan for table cache stage:\n ${plan.treeString}") } @@ -294,5 +294,5 @@ case class TableCacheQueryStageExec( override protected def doMaterialize(): Future[Any] = future - override def getRuntimeStatistics: Statistics = inMemoryTableScan.relation.computeStats() + override def getRuntimeStatistics: Statistics = inMemoryTableScan.runtimeStatistics } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 08244a4f84fe..5ff8bfd75f8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.columnar.CachedBatch import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, WholeStageCodegenExec} @@ -28,11 +29,32 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.vectorized.ColumnarBatch +/** + * Common trait for all InMemoryTableScans implementations to facilitate pattern matching. 
+ */ +trait InMemoryTableScanLike extends LeafExecNode { + + /** + * Returns whether the cache buffer is loaded + */ + def isMaterialized: Boolean + + /** + * Returns the actual cached RDD without filters and serialization of row/columnar. + */ + def baseCacheRDD(): RDD[CachedBatch] + + /** + * Returns the runtime statistics after materialization. + */ + def runtimeStatistics: Statistics +} + case class InMemoryTableScanExec( attributes: Seq[Attribute], predicates: Seq[Expression], @transient relation: InMemoryRelation) - extends LeafExecNode { + extends InMemoryTableScanLike { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -167,13 +189,18 @@ case class InMemoryTableScanExec( columnarInputRDD } - def isMaterialized: Boolean = relation.cacheBuilder.isCachedColumnBuffersLoaded + override def isMaterialized: Boolean = relation.cacheBuilder.isCachedColumnBuffersLoaded /** * This method is only used by AQE which executes the actually cached RDD that without filter and * serialization of row/columnar. */ - def baseCacheRDD(): RDD[CachedBatch] = { + override def baseCacheRDD(): RDD[CachedBatch] = { relation.cacheBuilder.cachedColumnBuffers } + + /** + * Returns the runtime statistics after materialization. 
+ */ + override def runtimeStatistics: Statistics = relation.computeStats() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 68bae34790a0..7c280f72ca17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.execution.{CollectLimitExec, ColumnarToRowExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.columnar.{InMemoryTableScanExec, InMemoryTableScanLike} import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.datasources.noop.NoopDataSource import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec @@ -2758,7 +2758,7 @@ class AdaptiveQueryExecSuite case s: SortExec => s }.size == (if (firstAccess) 2 else 0)) assert(collect(initialExecutedPlan) { - case i: InMemoryTableScanExec => i + case i: InMemoryTableScanLike => i }.head.isMaterialized != firstAccess) df.collect() @@ -2770,7 +2770,7 @@ class AdaptiveQueryExecSuite case s: SortExec => s }.isEmpty) assert(collect(initialExecutedPlan) { - case i: InMemoryTableScanExec => i + case i: InMemoryTableScanLike => i }.head.isMaterialized) } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, 
e-mail: commits-h...@spark.apache.org