This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c8ad616b988 [SPARK-45600][PYTHON] Make Python data source registration session level
c8ad616b988 is described below

commit c8ad616b988efdd47d7091f51c1e4563564b4e10
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Thu Nov 23 11:04:09 2023 +0900

    [SPARK-45600][PYTHON] Make Python data source registration session level
    
    ### What changes were proposed in this pull request?
    
    This PR makes dynamic Python data source registration session-scoped.
    Previously, registered data sources were stored in the `sharedState` and
    could be referenced by other sessions, which does not work with Spark Connect.
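
    For illustration, here is a minimal sketch of the session-scoped behavior,
    assuming an active `SparkSession` named `spark`. The class, short name, and
    columns below are illustrative; the `spark.dataSource.register` and
    `spark.read.format(...).load()` calls mirror the new Python test added in
    this commit:

    ```python
    from pyspark.sql.datasource import DataSource, DataSourceReader

    class SimpleReader(DataSourceReader):
        def read(self, partition):
            # Rows are yielded as tuples matching the declared schema.
            yield (0, 1)

    class SimpleDataSource(DataSource):
        @classmethod
        def name(cls):
            return "simple"  # short name used with spark.read.format(...)

        def schema(self):
            return "a INT, b INT"

        def reader(self, schema):
            return SimpleReader()

    # Registration is now tracked in the session state, so it is visible only
    # to this SparkSession.
    spark.dataSource.register(SimpleDataSource)
    spark.read.format("simple").load().show()

    # Re-registering the same short name no longer raises DATA_SOURCE_ALREADY_EXISTS;
    # the new builder replaces the previous one, and a warning is logged on the JVM side.
    spark.dataSource.register(SimpleDataSource)
    ```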
    
    ### Why are the changes needed?
    
    To make Python data sources support Spark Connect in the future.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New unit test
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43742 from allisonwang-db/spark-45600-session-level.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/tests/test_python_datasource.py | 30 +++++++++++++++++
 .../org/apache/spark/sql/DataFrameReader.scala     |  4 +--
 .../scala/org/apache/spark/sql/SparkSession.scala  |  2 +-
 .../execution/datasources/DataSourceManager.scala  | 15 ++++++---
 .../sql/internal/BaseSessionStateBuilder.scala     | 15 +++++++++
 .../apache/spark/sql/internal/SessionState.scala   |  5 +++
 .../apache/spark/sql/internal/SharedState.scala    | 12 -------
 .../execution/python/PythonDataSourceSuite.scala   | 38 +++++++++++++++++-----
 8 files changed, 93 insertions(+), 28 deletions(-)

diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 46b9fa642fd..bab062c4821 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -49,6 +49,36 @@ class BasePythonDataSourceTestsMixin:
         self.assertEqual(list(reader.partitions()), [None])
         self.assertEqual(list(reader.read(None)), [(None,)])
 
+    def test_data_source_register(self):
+        class TestReader(DataSourceReader):
+            def read(self, partition):
+                yield (0, 1)
+
+        class TestDataSource(DataSource):
+            def schema(self):
+                return "a INT, b INT"
+
+            def reader(self, schema):
+                return TestReader()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("TestDataSource").load()
+        assertDataFrameEqual(df, [Row(a=0, b=1)])
+
+        class MyDataSource(TestDataSource):
+            @classmethod
+            def name(cls):
+                return "TestDataSource"
+
+            def schema(self):
+                return "c INT, d INT"
+
+        # Should be able to register the data source with the same name.
+        self.spark.dataSource.register(MyDataSource)
+
+        df = self.spark.read.format("TestDataSource").load()
+        assertDataFrameEqual(df, [Row(c=0, d=1)])
+
     def test_in_memory_data_source(self):
         class InMemDataSourceReader(DataSourceReader):
             DEFAULT_NUM_PARTITIONS: int = 3
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 7fadbbfac68..c29ffb32907 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -210,7 +210,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
 
     val isUserDefinedDataSource =
-      sparkSession.sharedState.dataSourceManager.dataSourceExists(source)
+      sparkSession.sessionState.dataSourceManager.dataSourceExists(source)
 
     Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
       case Success(providerOpt) =>
@@ -243,7 +243,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
   }
 
   private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
-    val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
+    val builder = sparkSession.sessionState.dataSourceManager.lookupDataSource(source)
     // Add `path` and `paths` options to the extra options if specified.
     val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
     val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 5eba9e59c17..24497add04f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -233,7 +233,7 @@ class SparkSession private(
   /**
    * A collection of methods for registering user-defined data sources.
    */
-  private[sql] def dataSource: DataSourceRegistration = sharedState.dataSourceRegistration
+  private[sql] def dataSource: DataSourceRegistration = sessionState.dataSourceRegistration
 
   /**
    * Returns a `StreamingQueryManager` that allows managing all the
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index a8c9c892b8b..1cdc3d9cb69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import java.util.Locale
 import java.util.concurrent.ConcurrentHashMap
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -30,7 +31,7 @@ import org.apache.spark.sql.types.StructType
  * A manager for user-defined data sources. It is used to register and lookup data sources by
  * their short names or fully qualified names.
  */
-class DataSourceManager {
+class DataSourceManager extends Logging {
 
   private type DataSourceBuilder = (
     SparkSession,  // Spark session
@@ -49,10 +50,10 @@ class DataSourceManager {
    */
   def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
     val normalizedName = normalize(name)
-    if (dataSourceBuilders.containsKey(normalizedName)) {
-      throw QueryCompilationErrors.dataSourceAlreadyExists(name)
+    val previousValue = dataSourceBuilders.put(normalizedName, builder)
+    if (previousValue != null) {
+      logWarning(f"The data source $name replaced a previously registered data source.")
     }
-    dataSourceBuilders.put(normalizedName, builder)
   }
 
   /**
@@ -73,4 +74,10 @@ class DataSourceManager {
   def dataSourceExists(name: String): Boolean = {
     dataSourceBuilders.containsKey(normalize(name))
   }
+
+  override def clone(): DataSourceManager = {
+    val manager = new DataSourceManager
+    dataSourceBuilders.forEach((k, v) => manager.registerDataSource(k, v))
+    manager
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 630e1202f6d..d198e8f5d1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -120,6 +120,13 @@ abstract class BaseSessionStateBuilder(
       .getOrElse(extensions.registerTableFunctions(TableFunctionRegistry.builtin.clone()))
   }
 
+  /**
+   * Manages the registration of data sources
+   */
+  protected lazy val dataSourceManager: DataSourceManager = {
+    parentState.map(_.dataSourceManager.clone()).getOrElse(new DataSourceManager)
+  }
+
   /**
   * Experimental methods that can be used to define custom optimization rules and custom planning
    * strategies.
@@ -178,6 +185,12 @@ abstract class BaseSessionStateBuilder(
 
   protected def udtfRegistration: UDTFRegistration = new UDTFRegistration(tableFunctionRegistry)
 
+  /**
+   * A collection of methods used for registering user-defined data sources.
+   */
+  protected def dataSourceRegistration: DataSourceRegistration =
+    new DataSourceRegistration(dataSourceManager)
+
   /**
   * Logical query plan analyzer for resolving unresolved attributes and relations.
    *
@@ -376,6 +389,8 @@ abstract class BaseSessionStateBuilder(
       tableFunctionRegistry,
       udfRegistration,
       udtfRegistration,
+      dataSourceManager,
+      dataSourceRegistration,
       () => catalog,
       sqlParser,
       () => analyzer,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index adf3e0cb6ca..bc6710e6cbd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
+import org.apache.spark.sql.execution.datasources.DataSourceManager
 import org.apache.spark.sql.streaming.StreamingQueryManager
 import org.apache.spark.sql.util.ExecutionListenerManager
 import org.apache.spark.util.{DependencyUtils, Utils}
@@ -49,6 +50,8 @@ import org.apache.spark.util.{DependencyUtils, Utils}
  * @param udfRegistration Interface exposed to the user for registering user-defined functions.
  * @param udtfRegistration Interface exposed to the user for registering user-defined
  *                         table functions.
+ * @param dataSourceManager Internal catalog for managing data sources registered by users.
+ * @param dataSourceRegistration Interface exposed to users for registering data sources.
  * @param catalogBuilder a function to create an internal catalog for managing table and database
  *                       states.
  * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
@@ -73,6 +76,8 @@ private[sql] class SessionState(
     val tableFunctionRegistry: TableFunctionRegistry,
     val udfRegistration: UDFRegistration,
     val udtfRegistration: UDTFRegistration,
+    val dataSourceManager: DataSourceManager,
+    val dataSourceRegistration: DataSourceRegistration,
     catalogBuilder: () => SessionCatalog,
     val sqlParser: ParserInterface,
     analyzerBuilder: () => Analyzer,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index 8adc32fcf62..164710cdd88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -30,11 +30,9 @@ import org.apache.hadoop.fs.{FsUrlStreamHandlerFactory, Path}
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.DataSourceRegistration
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.CacheManager
-import org.apache.spark.sql.execution.datasources.DataSourceManager
 import org.apache.spark.sql.execution.streaming.StreamExecution
 import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab, StreamingQueryStatusStore}
 import org.apache.spark.sql.internal.StaticSQLConf._
@@ -107,16 +105,6 @@ private[sql] class SharedState(
   @GuardedBy("activeQueriesLock")
   private[sql] val activeStreamingQueries = new ConcurrentHashMap[UUID, StreamExecution]()
 
-  /**
-   * A data source manager shared by all sessions.
-   */
-  lazy val dataSourceManager = new DataSourceManager()
-
-  /**
-   * A collection of method used for registering user-defined data sources.
-   */
-  lazy val dataSourceRegistration = new DataSourceRegistration(dataSourceManager)
-
   /**
   * A status store to query SQL status/metrics of this Spark application, based on SQL-specific
    * [[org.apache.spark.scheduler.SparkListenerEvent]]s.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index bd0b08cbec8..33b34b39ab2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
 import org.apache.spark.sql.catalyst.plans.logical.{BatchEvalPythonUDTF, PythonDataSourcePartitions}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 
@@ -143,16 +144,35 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
 
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
-    assert(spark.sharedState.dataSourceManager.dataSourceExists(dataSourceName))
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    val ds1 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
+    checkAnswer(
+      ds1(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
+      Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
 
-    // Check error when registering a data source with the same name.
-    val err = intercept[AnalysisException] {
-      spark.dataSource.registerPython(dataSourceName, dataSource)
-    }
-    checkError(
-      exception = err,
-      errorClass = "DATA_SOURCE_ALREADY_EXISTS",
-      parameters = Map("provider" -> dataSourceName))
+    // Should be able to override an already registered data source.
+    val newScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceReader
+         |class SimpleDataSourceReader(DataSourceReader):
+         |    def read(self, partition):
+         |        yield (0, )
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |""".stripMargin
+    val newDataSource = createUserDefinedPythonDataSource(dataSourceName, newScript)
+    spark.dataSource.registerPython(dataSourceName, newDataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+
+    val ds2 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
+    checkAnswer(
+      ds2(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
+      Seq(Row(0)))
   }
 
   test("load data source") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
