This is an automated email from the ASF dual-hosted git repository.

allisonwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 6cfe303b06da [SPARK-51919][PYTHON] Allow overwriting statically registered Python Data Source
6cfe303b06da is described below

commit 6cfe303b06dadcd019eaa731e365fb8c1788a4c8
Author: Haoyu Weng <wengh...@gmail.com>
AuthorDate: Tue Jul 8 11:27:58 2025 -0700

    [SPARK-51919][PYTHON] Allow overwriting statically registered Python Data Source

    ### What changes were proposed in this pull request?

    - Allow overwriting static Python Data Sources during registration
    - Update documentation to clarify Python Data Source behavior and registration options

    ### Why are the changes needed?

    Static registration is a bit obscure and doesn't always work as expected (e.g. when the module providing DefaultSource is installed after `lookup_data_sources` has already run). So in practice users (or LLM agents) often want to explicitly register the data source even if it is provided as a DefaultSource. Raising an error in this case interrupts the workflow, making LLM agents spend extra tokens regenerating the same code without the registration. This change also makes the behavior consistent with user data source registration, which is already allowed to overwrite previous user registrations.

    ### Does this PR introduce _any_ user-facing change?

    Yes. Previously, registering a Python Data Source with the same name as a statically registered one would throw an error. With this change, it overwrites the static registration.

    ### How was this patch tested?

    Added a test in `PythonDataSourceSuite.scala` to verify that static sources can be overwritten correctly.

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #50716 from wengh/pyds-overwrite-static.

    Authored-by: Haoyu Weng <wengh...@gmail.com>
    Signed-off-by: Allison Wang <allison.w...@databricks.com>
---
 .../source/tutorial/sql/python_data_source.rst     |  4 +++-
 .../execution/datasources/DataSourceManager.scala  | 22 +++++++++++-----------
 .../execution/python/PythonDataSourceSuite.scala   | 18 ++++++++++++++++++
 3 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/python/docs/source/tutorial/sql/python_data_source.rst b/python/docs/source/tutorial/sql/python_data_source.rst
index 22b2a0b5f3c7..41b76c95d580 100644
--- a/python/docs/source/tutorial/sql/python_data_source.rst
+++ b/python/docs/source/tutorial/sql/python_data_source.rst
@@ -520,4 +520,6 @@ The following example demonstrates how to implement a basic Data Source using Ar
 Usage Notes
 -----------
 
-- During Data Source resolution, built-in and Scala/Java Data Sources take precedence over Python Data Sources with the same name; to explicitly use a Python Data Source, make sure its name does not conflict with the other Data Sources.
+- During Data Source resolution, built-in and Scala/Java Data Sources take precedence over Python Data Sources with the same name; to explicitly use a Python Data Source, make sure its name does not conflict with the other non-Python Data Sources.
+- It is allowed to register multiple Python Data Sources with the same name. Later registrations will overwrite earlier ones.
+- To automatically register a data source, export it as ``DefaultSource`` in a top-level module with the name prefix ``pyspark_``. See `pyspark_huggingface <https://github.com/huggingface/pyspark_huggingface>`_ for an example.
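To make the new usage notes concrete, here is a minimal sketch of the explicit registration path they describe. The short name "fake", the schema, and the rows are illustrative stand-ins (not part of this commit), and `spark` is assumed to be an active SparkSession:

    from pyspark.sql.datasource import DataSource, DataSourceReader

    class FakeReader(DataSourceReader):
        def read(self, partition):
            # Emit hard-coded rows matching the schema declared below.
            yield (0, "a")
            yield (1, "b")

    class FakeDataSource(DataSource):
        @classmethod
        def name(cls):
            # Short name used with spark.read.format(...); illustrative only.
            return "fake"

        def schema(self):
            return "id INT, value STRING"

        def reader(self, schema):
            return FakeReader()

    # Explicit registration: with this change it succeeds even if a source
    # named "fake" was already registered statically, logging a warning
    # instead of raising an error.
    spark.dataSource.register(FakeDataSource)
    spark.read.format("fake").load().show()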
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index 711e096ebd1f..7a8dbab35964 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -48,14 +48,13 @@ class DataSourceManager extends Logging {
    */
   def registerDataSource(name: String, source: UserDefinedPythonDataSource): Unit = {
     val normalizedName = normalize(name)
-    if (staticDataSourceBuilders.contains(normalizedName)) {
-      // Cannot overwrite static Python Data Sources.
-      throw QueryCompilationErrors.dataSourceAlreadyExists(name)
-    }
     val previousValue = runtimeDataSourceBuilders.put(normalizedName, source)
     if (previousValue != null) {
       logWarning(log"The data source ${MDC(DATA_SOURCE, name)} replaced a previously " +
         log"registered data source.")
+    } else if (staticDataSourceBuilders.contains(normalizedName)) {
+      logWarning(log"The data source ${MDC(DATA_SOURCE, name)} replaced a statically " +
+        log"registered data source.")
     }
   }
 
@@ -64,11 +63,7 @@ class DataSourceManager extends Logging {
    * it does not exist.
    */
   def lookupDataSource(name: String): UserDefinedPythonDataSource = {
-    if (dataSourceExists(name)) {
-      val normalizedName = normalize(name)
-      staticDataSourceBuilders.getOrElse(
-        normalizedName, runtimeDataSourceBuilders.get(normalizedName))
-    } else {
+    getDataSource(name).getOrElse {
       throw QueryCompilationErrors.dataSourceDoesNotExist(name)
     }
   }
@@ -77,9 +72,14 @@ class DataSourceManager extends Logging {
    * Checks if a data source with the specified name exists (case-insensitive).
    */
   def dataSourceExists(name: String): Boolean = {
+    getDataSource(name).isDefined
+  }
+
+  private def getDataSource(name: String): Option[UserDefinedPythonDataSource] = {
     val normalizedName = normalize(name)
-    staticDataSourceBuilders.contains(normalizedName) ||
-      runtimeDataSourceBuilders.containsKey(normalizedName)
+    // Runtime registration takes precedence over static.
+    Option(runtimeDataSourceBuilders.get(normalizedName))
+      .orElse(staticDataSourceBuilders.get(normalizedName))
   }
 
   override def clone(): DataSourceManager = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index f9eb01c10ede..d201f1890dbd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -126,6 +126,24 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase {
     assume(shouldTestPandasUDFs)
     val df = spark.read.format(staticSourceName).load()
     checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
+
+    // Overwrite the static source
+    val errorText = "static source overwritten"
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |
+         |class $staticSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        raise Exception("$errorText")
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(
+      name = staticSourceName, pythonScript = dataSourceScript)
+    spark.dataSource.registerPython(staticSourceName, dataSource)
+    val err = intercept[AnalysisException] {
+      spark.read.format(staticSourceName).load()
+    }
+    assert(err.getMessage.contains(errorText))
   }
 
   test("simple data source") {
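For comparison, the scenario exercised by the new Scala test looks roughly like this from a plain Python session. This is a sketch of the behavior rather than the test itself; the name "fake" stands in for the suite's `staticSourceName`, which in the test comes from a statically installed package:

    from pyspark.sql.datasource import DataSource

    class OverwritingSource(DataSource):
        @classmethod
        def name(cls):
            return "fake"  # stand-in for the statically registered name

        def schema(self):
            # If lookup now resolves to this runtime registration instead of
            # the static one, any read surfaces this error.
            raise Exception("static source overwritten")

    # Previously this threw dataSourceAlreadyExists; now it only logs a warning.
    spark.dataSource.register(OverwritingSource)

    # Fails with "static source overwritten", showing that the runtime
    # registration shadowed the static one.
    spark.read.format("fake").load()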