This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8950add773e6 [SPARK-48188][SQL] Consistently use normalized plan for 
cache
8950add773e6 is described below

commit 8950add773e63a910900f796950a6a58e40a8577
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Wed May 8 20:11:24 2024 +0800

    [SPARK-48188][SQL] Consistently use normalized plan for cache
    
    ### What changes were proposed in this pull request?
    
    We must consistently use normalized plans for cache filling and lookup, or 
inconsistency will lead to cache misses.
    
    To guarantee this, this PR makes `CacheManager` the central place for plan
    normalization, so that callers no longer need to handle it. Most `CacheManager` APIs
    now take either a `Dataset` or a `LogicalPlan`. For a `Dataset`, we use its normalized
    plan directly. For a `LogicalPlan`, we normalize it before further use.
    
    Callers should pass a `Dataset` to `CacheManager` only if they already create one;
    otherwise they should keep passing the `LogicalPlan`. This limits the impact of the
    change, because creating extra `Dataset`s may cause performance issues or introduce
    unexpected analysis exceptions.
    
    ### Why are the changes needed?
    
    Avoid unnecessary cache misses for users who add custom plan normalization rules.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, perf only
    
    ### How was this patch tested?
    
    existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #46465 from cloud-fan/cache.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Kent Yao <y...@apache.org>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |   3 +-
 .../apache/spark/sql/execution/CacheManager.scala  | 160 +++++++++++++--------
 .../spark/sql/execution/QueryExecution.scala       |  37 +++--
 .../execution/command/AnalyzeColumnCommand.scala   |   4 +-
 .../spark/sql/execution/command/CommandUtils.scala |   2 +-
 .../execution/datasources/v2/CacheTableExec.scala  |  30 ++--
 .../datasources/v2/DataSourceV2Strategy.scala      |   2 +-
 .../apache/spark/sql/internal/CatalogImpl.scala    |   5 +-
 .../org/apache/spark/sql/CachedTableSuite.scala    |   2 +-
 .../org/apache/spark/sql/test/SQLTestUtils.scala   |   3 +-
 .../apache/spark/sql/hive/CachedTableSuite.scala   |   9 +-
 11 files changed, 150 insertions(+), 107 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 18c9704afdf8..3e843e64ebbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3904,8 +3904,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def unpersist(blocking: Boolean): this.type = {
-    sparkSession.sharedState.cacheManager.uncacheQuery(
-      sparkSession, logicalPlan, cascade = false, blocking)
+    sparkSession.sharedState.cacheManager.uncacheQuery(this, cascade = false, 
blocking)
     this
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index ae99873a9f77..b96f257e6b5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
 import org.apache.spark.sql.catalyst.expressions.{Attribute, 
SubqueryExpression}
 import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint
-import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, 
LogicalPlan, ResolvedHint, SubqueryAlias, View}
+import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, 
LogicalPlan, ResolvedHint, View}
 import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -38,7 +38,10 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
 
 /** Holds a cached logical plan and its data */
-case class CachedData(plan: LogicalPlan, cachedRepresentation: 
InMemoryRelation) {
+case class CachedData(
+    // A normalized resolved plan (See QueryExecution#normalized).
+    plan: LogicalPlan,
+    cachedRepresentation: InMemoryRelation) {
   override def toString: String =
     s"""
        |CachedData(
@@ -53,7 +56,9 @@ case class CachedData(plan: LogicalPlan, 
cachedRepresentation: InMemoryRelation)
  * InMemoryRelation.  This relation is automatically substituted query plans 
that return the
  * `sameResult` as the originally cached query.
  *
- * Internal to Spark SQL.
+ * Internal to Spark SQL. All its public APIs take analyzed plans and will 
normalize them before
+ * further usage, or take [[Dataset]] and get its normalized plan. See 
`QueryExecution.normalize`
+ * for more details about plan normalization.
  */
 class CacheManager extends Logging with AdaptiveSparkPlanHelper {
 
@@ -77,41 +82,43 @@ class CacheManager extends Logging with 
AdaptiveSparkPlanHelper {
     cachedData.isEmpty
   }
 
+  // Test-only
+  def cacheQuery(query: Dataset[_]): Unit = {
+    cacheQuery(query, tableName = None, storageLevel = MEMORY_AND_DISK)
+  }
+
   /**
    * Caches the data produced by the logical representation of the given 
[[Dataset]].
-   * Unlike `RDD.cache()`, the default storage level is set to be 
`MEMORY_AND_DISK` because
-   * recomputing the in-memory columnar representation of the underlying table 
is expensive.
    */
   def cacheQuery(
       query: Dataset[_],
-      tableName: Option[String] = None,
-      storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = {
-    cacheQuery(query.sparkSession, query.queryExecution.normalized, tableName, 
storageLevel)
+      tableName: Option[String],
+      storageLevel: StorageLevel): Unit = {
+    cacheQueryInternal(query.sparkSession, query.queryExecution.normalized, 
tableName, storageLevel)
   }
 
   /**
-   * Caches the data produced by the given [[LogicalPlan]].
-   * Unlike `RDD.cache()`, the default storage level is set to be 
`MEMORY_AND_DISK` because
-   * recomputing the in-memory columnar representation of the underlying table 
is expensive.
+   * Caches the data produced by the given [[LogicalPlan]]. The given plan 
will be normalized
+   * before being used further.
    */
   def cacheQuery(
       spark: SparkSession,
       planToCache: LogicalPlan,
-      tableName: Option[String]): Unit = {
-    cacheQuery(spark, planToCache, tableName, MEMORY_AND_DISK)
+      tableName: Option[String],
+      storageLevel: StorageLevel): Unit = {
+    val normalized = QueryExecution.normalize(spark, planToCache)
+    cacheQueryInternal(spark, normalized, tableName, storageLevel)
   }
 
-  /**
-   * Caches the data produced by the given [[LogicalPlan]].
-   */
-  def cacheQuery(
+  // The `planToCache` should have been normalized.
+  private def cacheQueryInternal(
       spark: SparkSession,
       planToCache: LogicalPlan,
       tableName: Option[String],
       storageLevel: StorageLevel): Unit = {
     if (storageLevel == StorageLevel.NONE) {
       // Do nothing for StorageLevel.NONE since it will not actually cache any 
data.
-    } else if (lookupCachedData(planToCache).nonEmpty) {
+    } else if (lookupCachedDataInternal(planToCache).nonEmpty) {
       logWarning("Asked to cache already cached data.")
     } else {
       val sessionWithConfigsOff = getOrCloneSessionWithConfigsOff(spark)
@@ -124,7 +131,7 @@ class CacheManager extends Logging with 
AdaptiveSparkPlanHelper {
       }
 
       this.synchronized {
-        if (lookupCachedData(planToCache).nonEmpty) {
+        if (lookupCachedDataInternal(planToCache).nonEmpty) {
           logWarning("Data has already been cached.")
         } else {
           val cd = CachedData(planToCache, inMemoryRelation)
@@ -138,38 +145,64 @@ class CacheManager extends Logging with 
AdaptiveSparkPlanHelper {
 
   /**
    * Un-cache the given plan or all the cache entries that refer to the given 
plan.
-   * @param query     The [[Dataset]] to be un-cached.
-   * @param cascade   If true, un-cache all the cache entries that refer to 
the given
-   *                  [[Dataset]]; otherwise un-cache the given [[Dataset]] 
only.
+   *
+   * @param query    The [[Dataset]] to be un-cached.
+   * @param cascade  If true, un-cache all the cache entries that refer to the 
given
+   *                 [[Dataset]]; otherwise un-cache the given [[Dataset]] 
only.
+   * @param blocking Whether to block until all blocks are deleted.
    */
+  def uncacheQuery(
+      query: Dataset[_],
+      cascade: Boolean,
+      blocking: Boolean): Unit = {
+    uncacheQueryInternal(query.sparkSession, query.queryExecution.normalized, 
cascade, blocking)
+  }
+
+  // An overload to provide default value for the `blocking` parameter.
   def uncacheQuery(
       query: Dataset[_],
       cascade: Boolean): Unit = {
-    uncacheQuery(query.sparkSession, query.queryExecution.normalized, cascade)
+    uncacheQuery(query, cascade, blocking = false)
   }
 
   /**
    * Un-cache the given plan or all the cache entries that refer to the given 
plan.
-   * @param spark     The Spark session.
-   * @param plan      The plan to be un-cached.
-   * @param cascade   If true, un-cache all the cache entries that refer to 
the given
-   *                  plan; otherwise un-cache the given plan only.
-   * @param blocking  Whether to block until all blocks are deleted.
+   *
+   * @param spark    The Spark session.
+   * @param plan     The plan to be un-cached.
+   * @param cascade  If true, un-cache all the cache entries that refer to the 
given
+   *                 plan; otherwise un-cache the given plan only.
+   * @param blocking Whether to block until all blocks are deleted.
    */
   def uncacheQuery(
       spark: SparkSession,
       plan: LogicalPlan,
       cascade: Boolean,
-      blocking: Boolean = false): Unit = {
-    uncacheQuery(spark, _.sameResult(plan), cascade, blocking)
+      blocking: Boolean): Unit = {
+    val normalized = QueryExecution.normalize(spark, plan)
+    uncacheQueryInternal(spark, normalized, cascade, blocking)
+  }
+
+  // An overload to provide default value for the `blocking` parameter.
+  def uncacheQuery(
+      spark: SparkSession,
+      plan: LogicalPlan,
+      cascade: Boolean): Unit = {
+    uncacheQuery(spark, plan, cascade, blocking = false)
+  }
+
+  // The `plan` should have been normalized.
+  private def uncacheQueryInternal(
+      spark: SparkSession,
+      plan: LogicalPlan,
+      cascade: Boolean,
+      blocking: Boolean): Unit = {
+    uncacheByCondition(spark, _.sameResult(plan), cascade, blocking)
   }
 
   def uncacheTableOrView(spark: SparkSession, name: Seq[String], cascade: 
Boolean): Unit = {
-    uncacheQuery(
-      spark,
-      isMatchedTableOrView(_, name, spark.sessionState.conf),
-      cascade,
-      blocking = false)
+    uncacheByCondition(
+      spark, isMatchedTableOrView(_, name, spark.sessionState.conf), cascade, 
blocking = false)
   }
 
   private def isMatchedTableOrView(plan: LogicalPlan, name: Seq[String], conf: 
SQLConf): Boolean = {
@@ -178,28 +211,24 @@ class CacheManager extends Logging with 
AdaptiveSparkPlanHelper {
     }
 
     plan match {
-      case SubqueryAlias(ident, LogicalRelation(_, _, Some(catalogTable), _)) 
=>
-        val v1Ident = catalogTable.identifier
-        isSameName(ident.qualifier :+ ident.name) && 
isSameName(v1Ident.nameParts)
+      case LogicalRelation(_, _, Some(catalogTable), _) =>
+        isSameName(catalogTable.identifier.nameParts)
 
-      case SubqueryAlias(ident, DataSourceV2Relation(_, _, Some(catalog), 
Some(v2Ident), _)) =>
+      case DataSourceV2Relation(_, _, Some(catalog), Some(v2Ident), _) =>
         import 
org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
-        isSameName(ident.qualifier :+ ident.name) &&
-          isSameName(v2Ident.toQualifiedNameParts(catalog))
+        isSameName(v2Ident.toQualifiedNameParts(catalog))
 
-      case SubqueryAlias(ident, View(catalogTable, _, _)) =>
-        val v1Ident = catalogTable.identifier
-        isSameName(ident.qualifier :+ ident.name) && 
isSameName(v1Ident.nameParts)
+      case View(catalogTable, _, _) =>
+        isSameName(catalogTable.identifier.nameParts)
 
-      case SubqueryAlias(ident, HiveTableRelation(catalogTable, _, _, _, _)) =>
-        val v1Ident = catalogTable.identifier
-        isSameName(ident.qualifier :+ ident.name) && 
isSameName(v1Ident.nameParts)
+      case HiveTableRelation(catalogTable, _, _, _, _) =>
+        isSameName(catalogTable.identifier.nameParts)
 
       case _ => false
     }
   }
 
-  def uncacheQuery(
+  private def uncacheByCondition(
       spark: SparkSession,
       isMatchedPlan: LogicalPlan => Boolean,
       cascade: Boolean,
@@ -252,10 +281,12 @@ class CacheManager extends Logging with 
AdaptiveSparkPlanHelper {
   }
 
   /**
-   * Tries to re-cache all the cache entries that refer to the given plan.
+   * Tries to re-cache all the cache entries that refer to the given plan. The 
given plan will be
+   * normalized before being used further.
    */
   def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = {
-    recacheByCondition(spark, _.plan.exists(_.sameResult(plan)))
+    val normalized = QueryExecution.normalize(spark, plan)
+    recacheByCondition(spark, _.plan.exists(_.sameResult(normalized)))
   }
 
   /**
@@ -278,7 +309,7 @@ class CacheManager extends Logging with 
AdaptiveSparkPlanHelper {
       }
       val recomputedPlan = cd.copy(cachedRepresentation = newCache)
       this.synchronized {
-        if (lookupCachedData(recomputedPlan.plan).nonEmpty) {
+        if (lookupCachedDataInternal(recomputedPlan.plan).nonEmpty) {
           logWarning("While recaching, data was already added to cache.")
         } else {
           cachedData = recomputedPlan +: cachedData
@@ -289,13 +320,23 @@ class CacheManager extends Logging with 
AdaptiveSparkPlanHelper {
     }
   }
 
-  /** Optionally returns cached data for the given [[Dataset]] */
+  /**
+   * Optionally returns cached data for the given [[Dataset]]
+   */
   def lookupCachedData(query: Dataset[_]): Option[CachedData] = {
-    lookupCachedData(query.queryExecution.normalized)
+    lookupCachedDataInternal(query.queryExecution.normalized)
   }
 
-  /** Optionally returns cached data for the given [[LogicalPlan]]. */
-  def lookupCachedData(plan: LogicalPlan): Option[CachedData] = {
+  /**
+   * Optionally returns cached data for the given [[LogicalPlan]]. The given 
plan will be normalized
+   * before being used further.
+   */
+  def lookupCachedData(session: SparkSession, plan: LogicalPlan): 
Option[CachedData] = {
+    val normalized = QueryExecution.normalize(session, plan)
+    lookupCachedDataInternal(normalized)
+  }
+
+  private def lookupCachedDataInternal(plan: LogicalPlan): Option[CachedData] 
= {
     val result = cachedData.find(cd => plan.sameResult(cd.plan))
     if (result.isDefined) {
       CacheManager.logCacheOperation(log"Dataframe cache hit for input plan:" +
@@ -305,13 +346,16 @@ class CacheManager extends Logging with 
AdaptiveSparkPlanHelper {
     result
   }
 
-  /** Replaces segments of the given logical plan with cached versions where 
possible. */
-  def useCachedData(plan: LogicalPlan): LogicalPlan = {
+  /**
+   * Replaces segments of the given logical plan with cached versions where 
possible. The input
+   * plan must be normalized.
+   */
+  private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = {
     val newPlan = plan transformDown {
       case command: IgnoreCachedData => command
 
       case currentFragment =>
-        lookupCachedData(currentFragment).map { cached =>
+        lookupCachedDataInternal(currentFragment).map { cached =>
           // After cache lookup, we should still keep the hints from the input 
plan.
           val hints = 
EliminateResolvedHint.extractHintsFromPlan(currentFragment)._2
           val cachedPlan = 
cached.cachedRepresentation.withOutput(currentFragment.output)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index d04d8dc2cd7f..357484ca19df 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -134,19 +134,7 @@ class QueryExecution(
 
   // The plan that has been normalized by custom rules, so that it's more 
likely to hit cache.
   lazy val normalized: LogicalPlan = {
-    val normalizationRules = sparkSession.sessionState.planNormalizationRules
-    if (normalizationRules.isEmpty) {
-      commandExecuted
-    } else {
-      val planChangeLogger = new PlanChangeLogger[LogicalPlan]()
-      val normalized = normalizationRules.foldLeft(commandExecuted) { (p, 
rule) =>
-        val result = rule.apply(p)
-        planChangeLogger.logRule(rule.ruleName, p, result)
-        result
-      }
-      planChangeLogger.logBatch("Plan Normalization", commandExecuted, 
normalized)
-      normalized
-    }
+    QueryExecution.normalize(sparkSession, commandExecuted, Some(tracker))
   }
 
   lazy val withCachedData: LogicalPlan = sparkSession.withActive {
@@ -613,4 +601,27 @@ object QueryExecution {
       case e: Throwable => throw toInternalError(msg, e)
     }
   }
+
+  def normalize(
+      session: SparkSession,
+      plan: LogicalPlan,
+      tracker: Option[QueryPlanningTracker] = None): LogicalPlan = {
+    val normalizationRules = session.sessionState.planNormalizationRules
+    if (normalizationRules.isEmpty) {
+      plan
+    } else {
+      val planChangeLogger = new PlanChangeLogger[LogicalPlan]()
+      val normalized = normalizationRules.foldLeft(plan) { (p, rule) =>
+        val startTime = System.nanoTime()
+        val result = rule.apply(p)
+        val runTime = System.nanoTime() - startTime
+        val effective = !result.fastEquals(p)
+        tracker.foreach(_.recordRuleInvocation(rule.ruleName, runTime, 
effective))
+        planChangeLogger.logRule(rule.ruleName, p, result)
+        result
+      }
+      planChangeLogger.logBatch("Plan Normalization", plan, normalized)
+      normalized
+    }
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index 299f41eb55e1..7b0ce3e59263 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -61,8 +61,8 @@ case class AnalyzeColumnCommand(
 
   private def analyzeColumnInCachedData(plan: LogicalPlan, sparkSession: 
SparkSession): Boolean = {
     val cacheManager = sparkSession.sharedState.cacheManager
-    val planToLookup = sparkSession.sessionState.executePlan(plan).analyzed
-    cacheManager.lookupCachedData(planToLookup).map { cachedData =>
+    val df = Dataset.ofRows(sparkSession, plan)
+    cacheManager.lookupCachedData(df).map { cachedData =>
       val columnsToAnalyze = getColumnsToAnalyze(
         tableIdent, cachedData.cachedRepresentation, columnNames, allColumns)
       cacheManager.analyzeColumnCacheQuery(sparkSession, cachedData, 
columnsToAnalyze)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
index d7c5df151bf1..7acd1cb0852b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
@@ -240,7 +240,7 @@ object CommandUtils extends Logging {
       // Analyzes a catalog view if the view is cached
       val table = sparkSession.table(tableIdent.quotedString)
       val cacheManager = sparkSession.sharedState.cacheManager
-      if (cacheManager.lookupCachedData(table.logicalPlan).isDefined) {
+      if (cacheManager.lookupCachedData(table).isDefined) {
         if (!noScan) {
           // To collect table stats, materializes an underlying columnar RDD
           table.count()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala
index fc8a40f88545..56c44a125681 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala
@@ -21,9 +21,9 @@ import java.util.Locale
 
 import org.apache.spark.internal.LogKeys.OPTIONS
 import org.apache.spark.internal.MDC
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.LocalTempView
+import org.apache.spark.sql.catalyst.analysis.{LocalTempView, 
UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -34,7 +34,6 @@ import org.apache.spark.storage.StorageLevel
 trait BaseCacheTableExec extends LeafV2CommandExec {
   def relationName: String
   def planToCache: LogicalPlan
-  def dataFrameForCachedPlan: DataFrame
   def isLazy: Boolean
   def options: Map[String, String]
 
@@ -49,15 +48,12 @@ trait BaseCacheTableExec extends LeafV2CommandExec {
       logWarning(log"Invalid options: ${MDC(OPTIONS, 
withoutStorageLevel.mkString(", "))}")
     }
 
-    session.sharedState.cacheManager.cacheQuery(
-      session,
-      planToCache,
-      Some(relationName),
-      storageLevel)
+    val df = Dataset.ofRows(session, planToCache)
+    session.sharedState.cacheManager.cacheQuery(df, Some(relationName), 
storageLevel)
 
     if (!isLazy) {
       // Performs eager caching.
-      dataFrameForCachedPlan.count()
+      df.count()
     }
 
     Seq.empty
@@ -74,10 +70,6 @@ case class CacheTableExec(
   override lazy val relationName: String = multipartIdentifier.quoted
 
   override lazy val planToCache: LogicalPlan = relation
-
-  override lazy val dataFrameForCachedPlan: DataFrame = {
-    Dataset.ofRows(session, planToCache)
-  }
 }
 
 case class CacheTableAsSelectExec(
@@ -89,7 +81,10 @@ case class CacheTableAsSelectExec(
     referredTempFunctions: Seq[String]) extends BaseCacheTableExec {
   override lazy val relationName: String = tempViewName
 
-  override lazy val planToCache: LogicalPlan = {
+  override def planToCache: LogicalPlan = UnresolvedRelation(Seq(tempViewName))
+
+  override def run(): Seq[InternalRow] = {
+    // CACHE TABLE ... AS SELECT creates a temp view and caches the temp view.
     CreateViewCommand(
       name = TableIdentifier(tempViewName),
       userSpecifiedColumns = Nil,
@@ -103,12 +98,7 @@ case class CacheTableAsSelectExec(
       isAnalyzed = true,
       referredTempFunctions = referredTempFunctions
     ).run(session)
-
-    dataFrameForCachedPlan.logicalPlan
-  }
-
-  override lazy val dataFrameForCachedPlan: DataFrame = {
-    session.table(tempViewName)
+    super.run()
   }
 }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 0d926dcd99c4..7a668b75c3c7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -83,7 +83,7 @@ class DataSourceV2Strategy(session: SparkSession) extends 
Strategy with Predicat
   // given table, the cache's storage level is returned.
   private def invalidateTableCache(r: ResolvedTable)(): Option[StorageLevel] = 
{
     val v2Relation = DataSourceV2Relation.create(r.table, Some(r.catalog), 
Some(r.identifier))
-    val cache = session.sharedState.cacheManager.lookupCachedData(v2Relation)
+    val cache = session.sharedState.cacheManager.lookupCachedData(session, 
v2Relation)
     session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade 
= true)
     if (cache.isDefined) {
       val cacheLevel = cache.get.cachedRepresentation.cacheBuilder.storageLevel
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index df7c4ab1a0c7..3e20a23a0a06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -734,9 +734,8 @@ class CatalogImpl(sparkSession: SparkSession) extends 
Catalog {
       // same way as how a permanent view is handled. This also avoids a 
potential issue where a
       // dependent view becomes invalid because of the above while its data is 
still cached.
       val viewText = viewDef.desc.viewText
-      val plan = sparkSession.sessionState.executePlan(viewDef)
-      sparkSession.sharedState.cacheManager.uncacheQuery(
-        sparkSession, plan.analyzed, cascade = viewText.isDefined)
+      val df = Dataset.ofRows(sparkSession, viewDef)
+      sparkSession.sharedState.cacheManager.uncacheQuery(df, cascade = 
viewText.isDefined)
     } catch {
       case NonFatal(_) => // ignore
     }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 0ad9ceefc419..d023fb82185a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -1107,7 +1107,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
     assert(queryStats1.map(_._1.name).isEmpty)
 
     val cacheManager = spark.sharedState.cacheManager
-    val cachedData = cacheManager.lookupCachedData(query().logicalPlan)
+    val cachedData = cacheManager.lookupCachedData(query())
     assert(cachedData.isDefined)
     val queryAttrs = cachedData.get.plan.output
     assert(queryAttrs.size === 3)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 7da2bb47038e..5fbf379644f6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -339,8 +339,7 @@ private[sql] trait SQLTestUtilsBase
     val tableIdent = 
spark.sessionState.sqlParser.parseTableIdentifier(tableName)
     val cascade = !spark.sessionState.catalog.isTempView(tableIdent)
     spark.sharedState.cacheManager.uncacheQuery(
-      spark,
-      spark.table(tableName).logicalPlan,
+      spark.table(tableName),
       cascade = cascade,
       blocking = true)
   }
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index 89fe10d5c4bd..d7918f8cbf4f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -335,9 +335,10 @@ class CachedTableSuite extends QueryTest with SQLTestUtils 
with TestHiveSingleto
         options = Map.empty)(sparkSession = spark)
 
       val plan = LogicalRelation(relation, tableMeta)
-      spark.sharedState.cacheManager.cacheQuery(Dataset.ofRows(spark, plan))
+      val df = Dataset.ofRows(spark, plan)
+      spark.sharedState.cacheManager.cacheQuery(df)
 
-      assert(spark.sharedState.cacheManager.lookupCachedData(plan).isDefined)
+      assert(spark.sharedState.cacheManager.lookupCachedData(df).isDefined)
 
       val sameCatalog = new CatalogFileIndex(spark, tableMeta, 0)
       val sameRelation = HadoopFsRelation(
@@ -347,9 +348,9 @@ class CachedTableSuite extends QueryTest with SQLTestUtils 
with TestHiveSingleto
         bucketSpec = None,
         fileFormat = new ParquetFileFormat(),
         options = Map.empty)(sparkSession = spark)
-      val samePlan = LogicalRelation(sameRelation, tableMeta)
+      val samePlanDf = Dataset.ofRows(spark, LogicalRelation(sameRelation, 
tableMeta))
 
-      
assert(spark.sharedState.cacheManager.lookupCachedData(samePlan).isDefined)
+      
assert(spark.sharedState.cacheManager.lookupCachedData(samePlanDf).isDefined)
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to