This is an automated email from the ASF dual-hosted git repository.

allisonwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6cfe303b06da [SPARK-51919][PYTHON] Allow overwriting statically registered Python Data Source
6cfe303b06da is described below

commit 6cfe303b06dadcd019eaa731e365fb8c1788a4c8
Author: Haoyu Weng <wengh...@gmail.com>
AuthorDate: Tue Jul 8 11:27:58 2025 -0700

    [SPARK-51919][PYTHON] Allow overwriting statically registered Python Data Source
    
    ### What changes were proposed in this pull request?
    - Allow overwriting static Python Data Sources during registration
    - Update documentation to clarify Python Data Source behavior and registration options
    
    ### Why are the changes needed?
    Static registration is a bit obscure and doesn't always work as expected (e.g. when the module providing DefaultSource is installed after `lookup_data_sources` has already run).
    So in practice users (or LLM agents) often want to explicitly register the data source even if it is provided as a DefaultSource.
    Raising an error in this case interrupts the workflow, making LLM agents spend extra tokens regenerating the same code without the registration call.
    
    This change also makes the behavior consistent with user data source registration, which is already allowed to overwrite previous user registrations.
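    
    As a hedged illustration (the `FakeDataSource` class, the "fake" short name, and the active SparkSession `spark` are assumptions, not part of this patch), the workflow this enables is a plain explicit registration that now succeeds even when a static source of the same name exists:
    
        from pyspark.sql.datasource import DataSource, DataSourceReader
        
        class FakeReader(DataSourceReader):
            def read(self, partition):
                # Yield rows matching the schema declared below.
                yield (0, "zero")
        
        # Hypothetical user-defined source; "fake" could collide with a
        # statically registered DefaultSource of the same short name.
        class FakeDataSource(DataSource):
            @classmethod
            def name(cls):
                return "fake"
        
            def schema(self):
                return "id INT, value STRING"
        
            def reader(self, schema):
                return FakeReader()
        
        # Previously this raised an error when "fake" was registered
        # statically; now it overwrites the static registration.
        spark.dataSource.register(FakeDataSource)
        spark.read.format("fake").load().show()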
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Previously, registering a Python Data Source with the same name as a statically registered one would throw an error. With this change, it will overwrite the static registration.
    
    ### How was this patch tested?
    Added a test in `PythonDataSourceSuite.scala` to verify that static sources can be overwritten correctly.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #50716 from wengh/pyds-overwrite-static.
    
    Authored-by: Haoyu Weng <wengh...@gmail.com>
    Signed-off-by: Allison Wang <allison.w...@databricks.com>
---
 .../source/tutorial/sql/python_data_source.rst     |  4 +++-
 .../execution/datasources/DataSourceManager.scala  | 22 +++++++++++-----------
 .../execution/python/PythonDataSourceSuite.scala   | 18 ++++++++++++++++++
 3 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/python/docs/source/tutorial/sql/python_data_source.rst b/python/docs/source/tutorial/sql/python_data_source.rst
index 22b2a0b5f3c7..41b76c95d580 100644
--- a/python/docs/source/tutorial/sql/python_data_source.rst
+++ b/python/docs/source/tutorial/sql/python_data_source.rst
@@ -520,4 +520,6 @@ The following example demonstrates how to implement a basic Data Source using Ar
 Usage Notes
 -----------
 
-- During Data Source resolution, built-in and Scala/Java Data Sources take precedence over Python Data Sources with the same name; to explicitly use a Python Data Source, make sure its name does not conflict with the other Data Sources.
+- During Data Source resolution, built-in and Scala/Java Data Sources take precedence over Python Data Sources with the same name; to explicitly use a Python Data Source, make sure its name does not conflict with the other non-Python Data Sources.
+- It is allowed to register multiple Python Data Sources with the same name. Later registrations will overwrite earlier ones.
+- To automatically register a data source, export it as ``DefaultSource`` in a top level module with name prefix ``pyspark_``. See `pyspark_huggingface <https://github.com/huggingface/pyspark_huggingface>`_ for an example.
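
As a sketch of the auto-registration convention described in the doc note above (the `pyspark_mysource` package name, the "mysource" short name, and the module contents are assumptions modeled on pyspark_huggingface, not part of this patch):

    # pyspark_mysource/__init__.py -- a top-level module whose pyspark_ name
    # prefix marks it for automatic Python Data Source lookup
    from pyspark.sql.datasource import DataSource, DataSourceReader

    class _Reader(DataSourceReader):
        def read(self, partition):
            # A single hard-coded row, just to make the sketch runnable.
            yield (1,)

    # Exported under the name DefaultSource so Spark can register it
    # automatically when it scans installed pyspark_* modules.
    class DefaultSource(DataSource):
        @classmethod
        def name(cls):
            return "mysource"  # assumed short name for this example

        def schema(self):
            return "id INT"

        def reader(self, schema):
            return _Reader()
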
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index 711e096ebd1f..7a8dbab35964 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -48,14 +48,13 @@ class DataSourceManager extends Logging {
    */
   def registerDataSource(name: String, source: UserDefinedPythonDataSource): Unit = {
     val normalizedName = normalize(name)
-    if (staticDataSourceBuilders.contains(normalizedName)) {
-      // Cannot overwrite static Python Data Sources.
-      throw QueryCompilationErrors.dataSourceAlreadyExists(name)
-    }
     val previousValue = runtimeDataSourceBuilders.put(normalizedName, source)
     if (previousValue != null) {
       logWarning(log"The data source ${MDC(DATA_SOURCE, name)} replaced a 
previously " +
         log"registered data source.")
+    } else if (staticDataSourceBuilders.contains(normalizedName)) {
+      logWarning(log"The data source ${MDC(DATA_SOURCE, name)} replaced a statically " +
+        log"registered data source.")
     }
   }
 
@@ -64,11 +63,7 @@ class DataSourceManager extends Logging {
    * it does not exist.
    */
   def lookupDataSource(name: String): UserDefinedPythonDataSource = {
-    if (dataSourceExists(name)) {
-      val normalizedName = normalize(name)
-      staticDataSourceBuilders.getOrElse(
-        normalizedName, runtimeDataSourceBuilders.get(normalizedName))
-    } else {
+    getDataSource(name).getOrElse {
       throw QueryCompilationErrors.dataSourceDoesNotExist(name)
     }
   }
@@ -77,9 +72,14 @@ class DataSourceManager extends Logging {
    * Checks if a data source with the specified name exists (case-insensitive).
    */
   def dataSourceExists(name: String): Boolean = {
+    getDataSource(name).isDefined
+  }
+
+  private def getDataSource(name: String): Option[UserDefinedPythonDataSource] = {
     val normalizedName = normalize(name)
-    staticDataSourceBuilders.contains(normalizedName) ||
-      runtimeDataSourceBuilders.containsKey(normalizedName)
+    // Runtime registration takes precedence over static.
+    Option(runtimeDataSourceBuilders.get(normalizedName))
+      .orElse(staticDataSourceBuilders.get(normalizedName))
   }
 
   override def clone(): DataSourceManager = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index f9eb01c10ede..d201f1890dbd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -126,6 +126,24 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase {
     assume(shouldTestPandasUDFs)
     val df = spark.read.format(staticSourceName).load()
     checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
+
+    // Overwrite the static source
+    val errorText = "static source overwritten"
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |
+         |class $staticSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        raise Exception("$errorText")
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(
+      name = staticSourceName, pythonScript = dataSourceScript)
+    spark.dataSource.registerPython(staticSourceName, dataSource)
+    val err = intercept[AnalysisException] {
+      spark.read.format(staticSourceName).load()
+    }
+    assert(err.getMessage.contains(errorText))
   }
 
   test("simple data source") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
