This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9d93b7112a31 [SPARK-45639][SQL][PYTHON] Support loading Python data 
sources in DataFrameReader
9d93b7112a31 is described below

commit 9d93b7112a31965447a34301889f90d14578e628
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Wed Nov 8 09:23:12 2023 -0800

    [SPARK-45639][SQL][PYTHON] Support loading Python data sources in 
DataFrameReader
    
    ### What changes were proposed in this pull request?
    
    This PR supports `spark.read.format(...).load()` for Python data sources.
    
    After this PR, users can use a Python data source directly like this:
    ```python
    from pyspark.sql.datasource import DataSource, DataSourceReader
    
    class MyReader(DataSourceReader):
        def read(self, partition):
            yield (0, 1)
    
    class MyDataSource(DataSource):
        @classmethod
        def name(cls):
            return "my-source"
    
        def schema(self):
            return "id INT, value INT"
    
        def reader(self, schema):
            return MyReader()
    
    spark.dataSource.register(MyDataSource)
    
    df = spark.read.format("my-source").load()
    df.show()
    +---+-----+
    | id|value|
    +---+-----+
    |  0|    1|
    +---+-----+
    ```
    
    ### Why are the changes needed?
    
    To support Python data sources.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. After this PR, users can load a custom Python data source using 
`spark.read.format(...).load()`.
    
    ### How was this patch tested?
    
    New unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43630 from allisonwang-db/spark-45639-ds-lookup.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../src/main/resources/error/error-classes.json    | 12 +++
 dev/sparktestsupport/modules.py                    |  1 +
 docs/sql-error-conditions.md                       | 12 +++
 python/pyspark/sql/session.py                      |  4 +
 python/pyspark/sql/tests/test_python_datasource.py | 97 ++++++++++++++++++++--
 python/pyspark/sql/worker/create_data_source.py    | 16 +++-
 .../spark/sql/errors/QueryCompilationErrors.scala  | 12 +++
 .../org/apache/spark/sql/DataFrameReader.scala     | 48 +++++++++--
 .../execution/datasources/DataSourceManager.scala  | 31 ++++++-
 .../python/UserDefinedPythonDataSource.scala       | 15 ++--
 .../execution/python/PythonDataSourceSuite.scala   | 35 ++++++++
 11 files changed, 255 insertions(+), 28 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json 
b/common/utils/src/main/resources/error/error-classes.json
index db46ee8ca208..c38171c3d9e6 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -850,6 +850,12 @@
     ],
     "sqlState" : "42710"
   },
+  "DATA_SOURCE_NOT_EXIST" : {
+    "message" : [
+      "Data source '<provider>' not found. Please make sure the data source is 
registered."
+    ],
+    "sqlState" : "42704"
+  },
   "DATA_SOURCE_NOT_FOUND" : {
     "message" : [
       "Failed to find the data source: <provider>. Please find packages at 
`https://spark.apache.org/third-party-projects.html`."
@@ -1095,6 +1101,12 @@
     ],
     "sqlState" : "42809"
   },
+  "FOUND_MULTIPLE_DATA_SOURCES" : {
+    "message" : [
+      "Detected multiple data sources with the name '<provider>'. Please check 
the data source isn't simultaneously registered and located in the classpath."
+    ],
+    "sqlState" : "42710"
+  },
   "GENERATED_COLUMN_WITH_DEFAULT_VALUE" : {
     "message" : [
       "A column cannot have both a default value and a generation expression 
but column <colName> has default value: (<defaultValue>) and generation 
expression: (<genExpr>)."
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 95c9069a8313..01757ba28dd2 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -511,6 +511,7 @@ pyspark_sql = Module(
         "pyspark.sql.tests.pandas.test_pandas_udf_window",
         "pyspark.sql.tests.pandas.test_converter",
         "pyspark.sql.tests.test_pandas_sqlmetrics",
+        "pyspark.sql.tests.test_python_datasource",
         "pyspark.sql.tests.test_readwriter",
         "pyspark.sql.tests.test_serde",
         "pyspark.sql.tests.test_session",
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 7b0bc8ceb2b5..8a5faa15dc9c 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -454,6 +454,12 @@ DataType `<type>` requires a length parameter, for example 
`<type>`(10). Please
 
 Data source '`<provider>`' already exists in the registry. Please use a 
different name for the new data source.
 
+### DATA_SOURCE_NOT_EXIST
+
+[SQLSTATE: 
42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Data source '`<provider>`' not found. Please make sure the data source is 
registered.
+
 ### DATA_SOURCE_NOT_FOUND
 
 [SQLSTATE: 
42K02](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
@@ -669,6 +675,12 @@ No such struct field `<fieldName>` in `<fields>`.
 
 The operation `<statement>` is not allowed on the `<objectType>`: 
`<objectName>`.
 
+### FOUND_MULTIPLE_DATA_SOURCES
+
+[SQLSTATE: 
42710](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Detected multiple data sources with the name '`<provider>`'. Please check the 
data source isn't simultaneously registered and located in the classpath.
+
 ### GENERATED_COLUMN_WITH_DEFAULT_VALUE
 
 [SQLSTATE: 
42623](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 4ab7281d7ac8..85aff09aa3df 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -884,6 +884,10 @@ class SparkSession(SparkConversionMixin):
         Returns
         -------
         :class:`DataSourceRegistration`
+
+        Notes
+        -----
+        This feature is experimental and unstable.
         """
         from pyspark.sql.datasource import DataSourceRegistration
 
diff --git a/python/pyspark/sql/tests/test_python_datasource.py 
b/python/pyspark/sql/tests/test_python_datasource.py
index b429d73fb7d7..fe6a84175274 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -14,10 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import unittest
 
 from pyspark.sql.datasource import DataSource, DataSourceReader
+from pyspark.sql.types import Row
+from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import SPARK_HOME
 
 
 class BasePythonDataSourceTestsMixin:
@@ -45,16 +49,93 @@ class BasePythonDataSourceTestsMixin:
         self.assertEqual(list(reader.partitions()), [None])
         self.assertEqual(list(reader.read(None)), [(None,)])
 
-    def test_register_data_source(self):
-        class MyDataSource(DataSource):
-            ...
+    def test_in_memory_data_source(self):
+        class InMemDataSourceReader(DataSourceReader):
+            DEFAULT_NUM_PARTITIONS: int = 3
+
+            def __init__(self, paths, options):
+                self.paths = paths
+                self.options = options
+
+            def partitions(self):
+                if "num_partitions" in self.options:
+                    num_partitions = int(self.options["num_partitions"])
+                else:
+                    num_partitions = self.DEFAULT_NUM_PARTITIONS
+                return range(num_partitions)
+
+            def read(self, partition):
+                yield partition, str(partition)
+
+        class InMemoryDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "memory"
+
+            def schema(self):
+                return "x INT, y STRING"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return InMemDataSourceReader(self.paths, self.options)
+
+        self.spark.dataSource.register(InMemoryDataSource)
+        df = self.spark.read.format("memory").load()
+        self.assertEqual(df.rdd.getNumPartitions(), 3)
+        assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1"), Row(x=2, 
y="2")])
 
-        self.spark.dataSource.register(MyDataSource)
+        df = self.spark.read.format("memory").option("num_partitions", 
2).load()
+        assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
+        self.assertEqual(df.rdd.getNumPartitions(), 2)
+
+    def test_custom_json_data_source(self):
+        import json
+
+        class JsonDataSourceReader(DataSourceReader):
+            def __init__(self, paths, options):
+                self.paths = paths
+                self.options = options
+
+            def partitions(self):
+                return iter(self.paths)
+
+            def read(self, path):
+                with open(path, "r") as file:
+                    for line in file.readlines():
+                        if line.strip():
+                            data = json.loads(line)
+                            yield data.get("name"), data.get("age")
+
+        class JsonDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "my-json"
+
+            def schema(self):
+                return "name STRING, age INT"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return JsonDataSourceReader(self.paths, self.options)
+
+        self.spark.dataSource.register(JsonDataSource)
+        path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
+        path2 = os.path.join(SPARK_HOME, 
"python/test_support/sql/people1.json")
+        df1 = self.spark.read.format("my-json").load(path1)
+        self.assertEqual(df1.rdd.getNumPartitions(), 1)
+        assertDataFrameEqual(
+            df1,
+            [Row(name="Michael", age=None), Row(name="Andy", age=30), 
Row(name="Justin", age=19)],
+        )
 
-        self.assertTrue(
-            self.spark._jsparkSession.sharedState()
-            .dataSourceRegistry()
-            .dataSourceExists("MyDataSource")
+        df2 = self.spark.read.format("my-json").load([path1, path2])
+        self.assertEqual(df2.rdd.getNumPartitions(), 2)
+        assertDataFrameEqual(
+            df2,
+            [
+                Row(name="Michael", age=None),
+                Row(name="Andy", age=30),
+                Row(name="Justin", age=19),
+                Row(name="Jonathan", age=None),
+            ],
         )
 
 
diff --git a/python/pyspark/sql/worker/create_data_source.py 
b/python/pyspark/sql/worker/create_data_source.py
index ea56d2cc7522..6a9ef79b7c18 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -14,13 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import inspect
 import os
 import sys
 from typing import IO, List
 
 from pyspark.accumulators import _accumulatorRegistry
-from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, 
PySparkTypeError
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import (
     read_bool,
@@ -84,8 +84,20 @@ def main(infile: IO, outfile: IO) -> None:
                 },
             )
 
+        # Check the name method is a class method.
+        if not inspect.ismethod(data_source_cls.name):
+            raise PySparkTypeError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "'name()' method to be a classmethod",
+                    "actual": f"'{type(data_source_cls.name).__name__}'",
+                },
+            )
+
         # Receive the provider name.
         provider = utf8_deserializer.loads(infile)
+
+        # Check if the provider name matches the data source's name.
         if provider.lower() != data_source_cls.name().lower():
             raise PySparkAssertionError(
                 error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 1925eddd2ce2..0c5dcb1ead01 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -3805,4 +3805,16 @@ private[sql] object QueryCompilationErrors extends 
QueryErrorsBase with Compilat
       errorClass = "DATA_SOURCE_ALREADY_EXISTS",
       messageParameters = Map("provider" -> name))
   }
+
+  def dataSourceDoesNotExist(name: String): Throwable = {
+    new AnalysisException(
+      errorClass = "DATA_SOURCE_NOT_EXIST",
+      messageParameters = Map("provider" -> name))
+  }
+
+  def foundMultipleDataSources(provider: String): Throwable = {
+    new AnalysisException(
+      errorClass = "FOUND_MULTIPLE_DATA_SOURCES",
+      messageParameters = Map("provider" -> provider))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 9992d8cbba07..ef447e8a8010 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -17,11 +17,12 @@
 
 package org.apache.spark.sql
 
-import java.util.{Locale, Properties}
+import java.util.{Locale, Properties, ServiceConfigurationError}
 
 import scala.jdk.CollectionConverters._
+import scala.util.{Failure, Success, Try}
 
-import org.apache.spark.Partition
+import org.apache.spark.{Partition, SparkClassNotFoundException, 
SparkThrowable}
 import org.apache.spark.annotation.Stable
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
@@ -208,10 +209,45 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
       throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError()
     }
 
-    DataSource.lookupDataSourceV2(source, 
sparkSession.sessionState.conf).flatMap { provider =>
-      DataSourceV2Utils.loadV2Source(sparkSession, provider, 
userSpecifiedSchema, extraOptions,
-        source, paths: _*)
-    }.getOrElse(loadV1Source(paths: _*))
+    val isUserDefinedDataSource =
+      sparkSession.sharedState.dataSourceManager.dataSourceExists(source)
+
+    Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) 
match {
+      case Success(providerOpt) =>
+        // The source can be successfully loaded as either a V1 or a V2 data 
source.
+        // Check if it is also a user-defined data source.
+        if (isUserDefinedDataSource) {
+          throw QueryCompilationErrors.foundMultipleDataSources(source)
+        }
+        providerOpt.flatMap { provider =>
+          DataSourceV2Utils.loadV2Source(
+            sparkSession, provider, userSpecifiedSchema, extraOptions, source, 
paths: _*)
+        }.getOrElse(loadV1Source(paths: _*))
+      case Failure(exception) =>
+        // Exceptions are thrown while trying to load the data source as a V1 
or V2 data source.
+        // For the following not found exceptions, if the user-defined data 
source is defined,
+        // we can instead return the user-defined data source.
+        val isNotFoundError = exception match {
+          case _: NoClassDefFoundError | _: SparkClassNotFoundException => true
+          case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND"
+          case e: ServiceConfigurationError => 
e.getCause.isInstanceOf[NoClassDefFoundError]
+          case _ => false
+        }
+        if (isNotFoundError && isUserDefinedDataSource) {
+          loadUserDefinedDataSource(paths)
+        } else {
+          // Throw the original exception.
+          throw exception
+        }
+    }
+  }
+
+  private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
+    val builder = 
sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
+    // Unless the legacy path option behavior is enabled, the extraOptions here
+    // should not include "path" or "paths" as keys.
+    val plan = builder(sparkSession, source, paths, userSpecifiedSchema, 
extraOptions)
+    Dataset.ofRows(sparkSession, plan)
   }
 
   private def loadV1Source(paths: String*) = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index 283ca2ac62ed..72a9e6497aca 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -22,10 +22,14 @@ import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
+/**
+ * A manager for user-defined data sources. It is used to register and lookup 
data sources by
+ * their short names or fully qualified names.
+ */
 class DataSourceManager {
 
   private type DataSourceBuilder = (
@@ -33,22 +37,41 @@ class DataSourceManager {
     String,  // provider name
     Seq[String],  // paths
     Option[StructType],  // user specified schema
-    CaseInsensitiveStringMap  // options
+    CaseInsensitiveMap[String]  // options
   ) => LogicalPlan
 
   private val dataSourceBuilders = new ConcurrentHashMap[String, 
DataSourceBuilder]()
 
   private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
 
+  /**
+   * Register a data source builder for the given provider.
+   * Note that the provider name is case-insensitive.
+   */
   def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
     val normalizedName = normalize(name)
     if (dataSourceBuilders.containsKey(normalizedName)) {
       throw QueryCompilationErrors.dataSourceAlreadyExists(name)
     }
-    // TODO(SPARK-45639): check if the data source is a DSv1 or DSv2 using 
loadDataSource.
     dataSourceBuilders.put(normalizedName, builder)
   }
 
-  def dataSourceExists(name: String): Boolean =
+  /**
+   * Returns a data source builder for the given provider and throw an 
exception if
+   * it does not exist.
+   */
+  def lookupDataSource(name: String): DataSourceBuilder = {
+    if (dataSourceExists(name)) {
+      dataSourceBuilders.get(normalize(name))
+    } else {
+      throw QueryCompilationErrors.dataSourceDoesNotExist(name)
+    }
+  }
+
+  /**
+   * Checks if a data source with the specified name exists (case-insensitive).
+   */
+  def dataSourceExists(name: String): Boolean = {
     dataSourceBuilders.containsKey(normalize(name))
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index dbff8eefcd5f..703c1e10ce26 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.python
 import java.io.{DataInputStream, DataOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.jdk.CollectionConverters._
 
 import net.razorvine.pickle.Pickler
 
@@ -28,9 +27,9 @@ import org.apache.spark.api.python.{PythonFunction, 
PythonWorkerUtils, SimplePyt
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, 
PythonDataSource}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
  * A user-defined Python data source. This is used by the Python API.
@@ -44,7 +43,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: 
PythonFunction) {
       provider: String,
       paths: Seq[String],
       userSpecifiedSchema: Option[StructType],
-      options: CaseInsensitiveStringMap): LogicalPlan = {
+      options: CaseInsensitiveMap[String]): LogicalPlan = {
 
     val runner = new UserDefinedPythonDataSourceRunner(
       dataSourceCls, provider, paths, userSpecifiedSchema, options)
@@ -70,7 +69,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: 
PythonFunction) {
       provider: String,
       paths: Seq[String] = Seq.empty,
       userSpecifiedSchema: Option[StructType] = None,
-      options: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty): 
DataFrame = {
+      options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): 
DataFrame = {
     val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, 
options)
     Dataset.ofRows(sparkSession, plan)
   }
@@ -91,7 +90,7 @@ class UserDefinedPythonDataSourceRunner(
     provider: String,
     paths: Seq[String],
     userSpecifiedSchema: Option[StructType],
-    options: CaseInsensitiveStringMap)
+    options: CaseInsensitiveMap[String])
   extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) {
 
   override val workerModule = "pyspark.sql.worker.create_data_source"
@@ -113,9 +112,9 @@ class UserDefinedPythonDataSourceRunner(
 
     // Send the options
     dataOut.writeInt(options.size)
-    options.entrySet.asScala.foreach { e =>
-      PythonWorkerUtils.writeUTF(e.getKey, dataOut)
-      PythonWorkerUtils.writeUTF(e.getValue, dataOut)
+    options.iterator.foreach { case (key, value) =>
+      PythonWorkerUtils.writeUTF(key, dataOut)
+      PythonWorkerUtils.writeUTF(value, dataOut)
     }
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 6c749c2c9b67..22a1e5250cd9 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -155,6 +155,41 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
       parameters = Map("provider" -> dataSourceName))
   }
 
+  test("load data source") {
+    assume(shouldTestPythonUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceReader
+         |class SimpleDataSourceReader(DataSourceReader):
+         |    def __init__(self, paths, options):
+         |        self.paths = paths
+         |        self.options = options
+         |
+         |    def partitions(self):
+         |        return iter(self.paths)
+         |
+         |    def read(self, path):
+         |        yield (path, 1)
+         |
+         |class $dataSourceName(DataSource):
+         |    @classmethod
+         |    def name(cls) -> str:
+         |        return "test"
+         |
+         |    def schema(self) -> str:
+         |        return "id STRING, value INT"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader(self.paths, self.options)
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, 
dataSourceScript)
+    spark.dataSource.registerPython("test", dataSource)
+
+    checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
+    checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
+    checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), 
Row("2", 1)))
+  }
+
   test("reader not implemented") {
     assume(shouldTestPythonUDFs)
     val dataSourceScript =


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to