This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new c8ad616b988 [SPARK-45600][PYTHON] Make Python data source registration session level
c8ad616b988 is described below

commit c8ad616b988efdd47d7091f51c1e4563564b4e10
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Thu Nov 23 11:04:09 2023 +0900

    [SPARK-45600][PYTHON] Make Python data source registration session level

    ### What changes were proposed in this pull request?

    This PR makes dynamic Python data source registration session-scoped.
    Previously, registered data sources were stored in the `sharedState` and
    could be referenced by other sessions, which won't work with Spark Connect.

    ### Why are the changes needed?

    To make Python data sources support Spark Connect in the future.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    New unit test

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #43742 from allisonwang-db/spark-45600-session-level.

    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/tests/test_python_datasource.py | 30 +++++++++++++++++
 .../org/apache/spark/sql/DataFrameReader.scala     |  4 +--
 .../scala/org/apache/spark/sql/SparkSession.scala  |  2 +-
 .../execution/datasources/DataSourceManager.scala  | 15 ++++++---
 .../sql/internal/BaseSessionStateBuilder.scala     | 15 +++++++++
 .../apache/spark/sql/internal/SessionState.scala   |  5 +++
 .../apache/spark/sql/internal/SharedState.scala    | 12 -------
 .../execution/python/PythonDataSourceSuite.scala   | 38 +++++++++++++++++----
 8 files changed, 93 insertions(+), 28 deletions(-)
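A minimal sketch of the user-visible behavior this change scopes per session, assuming a PySpark build that includes this patch. The `pyspark.sql.datasource` API below matches the new test in the diff; the `newSession()` isolation in the final comment is inferred from moving the registry into `SessionState`, not something the diff itself asserts.

```python
# Sketch: session-scoped Python data source registration (assumes this patch).
from pyspark.sql import SparkSession
from pyspark.sql.datasource import DataSource, DataSourceReader


class OnesReader(DataSourceReader):
    def read(self, partition):
        # Emit a single row matching the "value INT" schema below.
        yield (1,)


class OnesSource(DataSource):
    @classmethod
    def name(cls):
        return "ones"

    def schema(self):
        return "value INT"

    def reader(self, schema):
        return OnesReader()


spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(OnesSource)
spark.read.format("ones").load().show()  # works: registered in this session

other = spark.newSession()
# Inferred behavior: the registry now lives in SessionState, so a fresh
# session should not resolve "ones" and this load is expected to fail.
# other.read.format("ones").load()
```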
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 46b9fa642fd..bab062c4821 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -49,6 +49,36 @@ class BasePythonDataSourceTestsMixin:
         self.assertEqual(list(reader.partitions()), [None])
         self.assertEqual(list(reader.read(None)), [(None,)])
 
+    def test_data_source_register(self):
+        class TestReader(DataSourceReader):
+            def read(self, partition):
+                yield (0, 1)
+
+        class TestDataSource(DataSource):
+            def schema(self):
+                return "a INT, b INT"
+
+            def reader(self, schema):
+                return TestReader()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("TestDataSource").load()
+        assertDataFrameEqual(df, [Row(a=0, b=1)])
+
+        class MyDataSource(TestDataSource):
+            @classmethod
+            def name(cls):
+                return "TestDataSource"
+
+            def schema(self):
+                return "c INT, d INT"
+
+        # Should be able to register the data source with the same name.
+        self.spark.dataSource.register(MyDataSource)
+
+        df = self.spark.read.format("TestDataSource").load()
+        assertDataFrameEqual(df, [Row(c=0, d=1)])
+
     def test_in_memory_data_source(self):
         class InMemDataSourceReader(DataSourceReader):
             DEFAULT_NUM_PARTITIONS: int = 3
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 7fadbbfac68..c29ffb32907 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -210,7 +210,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
 
     val isUserDefinedDataSource =
-      sparkSession.sharedState.dataSourceManager.dataSourceExists(source)
+      sparkSession.sessionState.dataSourceManager.dataSourceExists(source)
 
     Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
       case Success(providerOpt) =>
@@ -243,7 +243,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
   }
 
   private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
-    val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
+    val builder = sparkSession.sessionState.dataSourceManager.lookupDataSource(source)
     // Add `path` and `paths` options to the extra options if specified.
     val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
     val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 5eba9e59c17..24497add04f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -233,7 +233,7 @@ class SparkSession private(
   /**
    * A collection of methods for registering user-defined data sources.
    */
-  private[sql] def dataSource: DataSourceRegistration = sharedState.dataSourceRegistration
+  private[sql] def dataSource: DataSourceRegistration = sessionState.dataSourceRegistration
 
   /**
    * Returns a `StreamingQueryManager` that allows managing all the
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index a8c9c892b8b..1cdc3d9cb69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import java.util.Locale
 import java.util.concurrent.ConcurrentHashMap
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -30,7 +31,7 @@ import org.apache.spark.sql.types.StructType
 
  * A manager for user-defined data sources. It is used to register and lookup data sources by
  * their short names or fully qualified names.
  */
-class DataSourceManager {
+class DataSourceManager extends Logging {
   private type DataSourceBuilder = (
       SparkSession, // Spark session
@@ -49,10 +50,10 @@ class DataSourceManager {
    */
   def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
     val normalizedName = normalize(name)
-    if (dataSourceBuilders.containsKey(normalizedName)) {
-      throw QueryCompilationErrors.dataSourceAlreadyExists(name)
+    val previousValue = dataSourceBuilders.put(normalizedName, builder)
+    if (previousValue != null) {
+      logWarning(f"The data source $name replaced a previously registered data source.")
     }
-    dataSourceBuilders.put(normalizedName, builder)
   }
 
   /**
@@ -73,4 +74,10 @@
   def dataSourceExists(name: String): Boolean = {
     dataSourceBuilders.containsKey(normalize(name))
   }
+
+  override def clone(): DataSourceManager = {
+    val manager = new DataSourceManager
+    dataSourceBuilders.forEach((k, v) => manager.registerDataSource(k, v))
+    manager
+  }
 }
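Reading the hunks above together, the registry semantics are: names are normalized before storage, re-registering a name now overwrites the previous builder with a warning instead of throwing `DATA_SOURCE_ALREADY_EXISTS`, and `clone()` gives a child session its own copy of the parent's entries. Here is a hypothetical Python model of that logic; the class and method names are illustrative only, not Spark's API, and the lowercase normalization is an assumption suggested by the `Locale` import.

```python
# Hypothetical model of DataSourceManager's semantics; not Spark's API.
import logging
from typing import Callable, Dict

logger = logging.getLogger(__name__)


class DataSourceManagerModel:
    def __init__(self) -> None:
        self._builders: Dict[str, Callable] = {}

    @staticmethod
    def _normalize(name: str) -> str:
        # Assumed: lookups are case-insensitive, as the Locale import suggests.
        return name.lower()

    def register(self, name: str, builder: Callable) -> None:
        # Last registration wins; a displaced builder only triggers a warning.
        previous = self._builders.get(self._normalize(name))
        self._builders[self._normalize(name)] = builder
        if previous is not None:
            logger.warning(
                "The data source %s replaced a previously registered data source.",
                name)

    def exists(self, name: str) -> bool:
        return self._normalize(name) in self._builders

    def clone(self) -> "DataSourceManagerModel":
        # A child session starts from a copy of the parent's registrations,
        # so later changes in either session do not leak into the other.
        copy = DataSourceManagerModel()
        for key, builder in self._builders.items():
            copy.register(key, builder)
        return copy
```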
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 630e1202f6d..d198e8f5d1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -120,6 +120,13 @@ abstract class BaseSessionStateBuilder(
       .getOrElse(extensions.registerTableFunctions(TableFunctionRegistry.builtin.clone()))
   }
 
+  /**
+   * Manages the registration of data sources.
+   */
+  protected lazy val dataSourceManager: DataSourceManager = {
+    parentState.map(_.dataSourceManager.clone()).getOrElse(new DataSourceManager)
+  }
+
   /**
    * Experimental methods that can be used to define custom optimization rules and custom planning
    * strategies.
@@ -178,6 +185,12 @@
 
   protected def udtfRegistration: UDTFRegistration = new UDTFRegistration(tableFunctionRegistry)
 
+  /**
+   * A collection of methods used for registering user-defined data sources.
+   */
+  protected def dataSourceRegistration: DataSourceRegistration =
+    new DataSourceRegistration(dataSourceManager)
+
   /**
    * Logical query plan analyzer for resolving unresolved attributes and relations.
    *
@@ -376,6 +389,8 @@ abstract class BaseSessionStateBuilder(
       tableFunctionRegistry,
       udfRegistration,
       udtfRegistration,
+      dataSourceManager,
+      dataSourceRegistration,
       () => catalog,
       sqlParser,
       () => analyzer,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index adf3e0cb6ca..bc6710e6cbd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
+import org.apache.spark.sql.execution.datasources.DataSourceManager
 import org.apache.spark.sql.streaming.StreamingQueryManager
 import org.apache.spark.sql.util.ExecutionListenerManager
 import org.apache.spark.util.{DependencyUtils, Utils}
@@ -49,6 +50,8 @@ import org.apache.spark.util.{DependencyUtils, Utils}
 * @param udfRegistration Interface exposed to the user for registering user-defined functions.
 * @param udtfRegistration Interface exposed to the user for registering user-defined
 *                         table functions.
+ * @param dataSourceManager Internal catalog for managing data sources registered by users.
+ * @param dataSourceRegistration Interface exposed to users for registering data sources.
 * @param catalogBuilder a function to create an internal catalog for managing table and database
 *                       states.
 * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
@@ -73,6 +76,8 @@ private[sql] class SessionState(
     val tableFunctionRegistry: TableFunctionRegistry,
     val udfRegistration: UDFRegistration,
     val udtfRegistration: UDTFRegistration,
+    val dataSourceManager: DataSourceManager,
+    val dataSourceRegistration: DataSourceRegistration,
     catalogBuilder: () => SessionCatalog,
     val sqlParser: ParserInterface,
     analyzerBuilder: () => Analyzer,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index 8adc32fcf62..164710cdd88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -30,11 +30,9 @@ import org.apache.hadoop.fs.{FsUrlStreamHandlerFactory, Path}
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.DataSourceRegistration
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.CacheManager
-import org.apache.spark.sql.execution.datasources.DataSourceManager
 import org.apache.spark.sql.execution.streaming.StreamExecution
 import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab, StreamingQueryStatusStore}
 import org.apache.spark.sql.internal.StaticSQLConf._
@@ -107,16 +105,6 @@ private[sql] class SharedState(
   @GuardedBy("activeQueriesLock")
   private[sql] val activeStreamingQueries = new ConcurrentHashMap[UUID, StreamExecution]()
 
-  /**
-   * A data source manager shared by all sessions.
-   */
-  lazy val dataSourceManager = new DataSourceManager()
-
-  /**
-   * A collection of method used for registering user-defined data sources.
-   */
-  lazy val dataSourceRegistration = new DataSourceRegistration(dataSourceManager)
-
   /**
    * A status store to query SQL status/metrics of this Spark application, based on SQL-specific
    * [[org.apache.spark.scheduler.SparkListenerEvent]]s.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index bd0b08cbec8..33b34b39ab2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
 import org.apache.spark.sql.catalyst.plans.logical.{BatchEvalPythonUDTF, PythonDataSourcePartitions}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 
@@ -143,16 +144,35 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
 
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
-    assert(spark.sharedState.dataSourceManager.dataSourceExists(dataSourceName))
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    val ds1 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
+    checkAnswer(
+      ds1(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
+      Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
 
-    // Check error when registering a data source with the same name.
-    val err = intercept[AnalysisException] {
-      spark.dataSource.registerPython(dataSourceName, dataSource)
-    }
-    checkError(
-      exception = err,
-      errorClass = "DATA_SOURCE_ALREADY_EXISTS",
-      parameters = Map("provider" -> dataSourceName))
+    // Should be able to override an already registered data source.
+    val newScript =
+      s"""
+        |from pyspark.sql.datasource import DataSource, DataSourceReader
+        |class SimpleDataSourceReader(DataSourceReader):
+        |    def read(self, partition):
+        |        yield (0, )
+        |
+        |class $dataSourceName(DataSource):
+        |    def schema(self) -> str:
+        |        return "id INT"
+        |
+        |    def reader(self, schema):
+        |        return SimpleDataSourceReader()
+        |""".stripMargin
+    val newDataSource = createUserDefinedPythonDataSource(dataSourceName, newScript)
+    spark.dataSource.registerPython(dataSourceName, newDataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+
+    val ds2 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
+    checkAnswer(
+      ds2(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
+      Seq(Row(0)))
   }
 
   test("load data source") {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org