This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 11af786b35c [SPARK-45451][SQL] Make the default storage level of dataset cache configurable
11af786b35c is described below

commit 11af786b35cabe6d139dd9763ccf1af9ceb7eb9f
Author: ulysses-you <ulyssesyo...@gmail.com>
AuthorDate: Wed Oct 11 20:51:22 2023 +0800

    [SPARK-45451][SQL] Make the default storage level of dataset cache configurable
    
    ### What changes were proposed in this pull request?
    
    This PR adds a new config, `spark.sql.defaultCacheStorageLevel`, so that
    people can use `set spark.sql.defaultCacheStorageLevel=xxx` to change the
    default storage level of `dataset.cache`.
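
    For example (an illustrative Scala sketch, not part of this patch; the config
    key and its valid values come from the change itself, while the session and
    dataset are arbitrary):

        // Valid values are the StorageLevelMapper names, e.g. MEMORY_ONLY,
        // DISK_ONLY, MEMORY_AND_DISK_2, NONE.
        spark.conf.set("spark.sql.defaultCacheStorageLevel", "DISK_ONLY")
        // Dataset.cache(), catalog.cacheTable() and `CACHE TABLE t` now default
        // to DISK_ONLY instead of MEMORY_AND_DISK.
        spark.range(10).cache()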
    
    ### Why are the changes needed?
    
    Most people use the default storage level, so this PR makes it easy to
    change the storage level without touching code.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    Added a new test in `CachedTableSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #43259 from ulysses-you/cache.
    
    Authored-by: ulysses-you <ulyssesyo...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    | 13 ++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  5 +--
 .../execution/datasources/v2/CacheTableExec.scala  | 22 +++++--------
 .../apache/spark/sql/internal/CatalogImpl.scala    |  4 +--
 .../org/apache/spark/sql/CachedTableSuite.scala    | 37 ++++++++++++++++++++++
 5 files changed, 61 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 65d2e6136e9..12ec9e911d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -44,6 +44,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.types.{AtomicType, TimestampNTZType, TimestampType}
+import org.apache.spark.storage.{StorageLevel, StorageLevelMapper}
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.util.Utils
 
@@ -1563,6 +1564,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val DEFAULT_CACHE_STORAGE_LEVEL = buildConf("spark.sql.defaultCacheStorageLevel")
+    .doc("The default storage level of `dataset.cache()`, `catalog.cacheTable()` and " +
+      "sql query `CACHE TABLE t`.")
+    .version("4.0.0")
+    .stringConf
+    .transform(_.toUpperCase(Locale.ROOT))
+    .checkValues(StorageLevelMapper.values.map(_.name()).toSet)
+    .createWithDefault(StorageLevelMapper.MEMORY_AND_DISK.name())
+
   val CROSS_JOINS_ENABLED = buildConf("spark.sql.crossJoin.enabled")
     .internal()
     .doc("When false, we will throw an error if a query contains a cartesian 
product without " +
@@ -5027,6 +5037,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def groupByAliases: Boolean = getConf(GROUP_BY_ALIASES)
 
+  def defaultCacheStorageLevel: StorageLevel =
+    StorageLevel.fromString(getConf(DEFAULT_CACHE_STORAGE_LEVEL))
+
   def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
 
   override def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 0cc037b157e..5079cfcca9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3798,10 +3798,7 @@ class Dataset[T] private[sql](
    * @group basic
    * @since 1.6.0
    */
-  def persist(): this.type = {
-    sparkSession.sharedState.cacheManager.cacheQuery(this)
-    this
-  }
+  def persist(): this.type = persist(sparkSession.sessionState.conf.defaultCacheStorageLevel)
 
   /**
    * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`).
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala
index 8c14b5e3707..1744df83033 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CacheTableExec.scala
@@ -38,25 +38,19 @@ trait BaseCacheTableExec extends LeafV2CommandExec {
 
   override def run(): Seq[InternalRow] = {
     val storageLevelKey = "storagelevel"
-    val storageLevelValue =
-      CaseInsensitiveMap(options).get(storageLevelKey).map(_.toUpperCase(Locale.ROOT))
+    val storageLevel = CaseInsensitiveMap(options).get(storageLevelKey)
+      .map(s => StorageLevel.fromString(s.toUpperCase(Locale.ROOT)))
+      .getOrElse(conf.defaultCacheStorageLevel)
     val withoutStorageLevel = options.filterKeys(_.toLowerCase(Locale.ROOT) != storageLevelKey)
     if (withoutStorageLevel.nonEmpty) {
       logWarning(s"Invalid options: ${withoutStorageLevel.mkString(", ")}")
     }
 
-    if (storageLevelValue.nonEmpty) {
-      session.sharedState.cacheManager.cacheQuery(
-        session,
-        planToCache,
-        Some(relationName),
-        StorageLevel.fromString(storageLevelValue.get))
-    } else {
-      session.sharedState.cacheManager.cacheQuery(
-        session,
-        planToCache,
-        Some(relationName))
-    }
+    session.sharedState.cacheManager.cacheQuery(
+      session,
+      planToCache,
+      Some(relationName),
+      storageLevel)
 
     if (!isLazy) {
       // Performs eager caching.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index acc2055d779..74a4f1c9d4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -760,13 +760,13 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
   }
 
   /**
-   * Caches the specified table or view in-memory.
+   * Persist the specified table or view with the default storage level.
    *
    * @group cachemgmt
    * @since 2.0.0
    */
   override def cacheTable(tableName: String): Unit = {
-    sparkSession.sharedState.cacheManager.cacheQuery(sparkSession.table(tableName), Some(tableName))
+    cacheTable(tableName, sparkSession.sessionState.conf.defaultCacheStorageLevel)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 51fe3ffd34d..7d411331b8e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -1694,4 +1694,41 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
       }
     }
   }
+
+  test("SPARK-45451: Make the default storage level of dataset cache 
configurable") {
+    def validateStorageLevel(expected: StorageLevel): Unit = {
+      withTempView("t") {
+        spark.range(10).createOrReplaceTempView("t")
+
+        Seq(() => spark.table("t").cache(),
+          () => spark.catalog.cacheTable("t"),
+          () => spark.sql("CACHE TABLE t")).foreach { f =>
+          withCache("t") {
+            f()
+            val cached = spark.table("t")
+            val tableCache = collect(cached.queryExecution.executedPlan) {
+              case i: InMemoryTableScanExec => i
+            }
+            if (expected == StorageLevel.NONE) {
+              assert(tableCache.isEmpty)
+            } else {
+              assert(tableCache.size == 1)
+              assert(tableCache.head.relation.cacheBuilder.storageLevel == expected)
+            }
+          }
+        }
+      }
+    }
+
+    validateStorageLevel(StorageLevel.MEMORY_AND_DISK)
+    withSQLConf(SQLConf.DEFAULT_CACHE_STORAGE_LEVEL.key -> "NONE") {
+      validateStorageLevel(StorageLevel.NONE)
+    }
+    withSQLConf(SQLConf.DEFAULT_CACHE_STORAGE_LEVEL.key -> "MEMORY_AND_DISK_2") {
+      validateStorageLevel(StorageLevel.MEMORY_AND_DISK_2)
+    }
+    intercept[IllegalArgumentException] {
+      withSQLConf(SQLConf.DEFAULT_CACHE_STORAGE_LEVEL.key -> "DISK") {}
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
