This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new a98c885da09d [SPARK-46568][PYTHON] Make Python data source options a case-insensitive dictionary
a98c885da09d is described below

commit a98c885da09d45a19568f5d853086f747e0ecd95
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Thu Jan 4 22:28:49 2024 -0800

    [SPARK-46568][PYTHON] Make Python data source options a case-insensitive dictionary

    ### What changes were proposed in this pull request?

    This PR updates the `options` field to use a case-insensitive dictionary, keeping the behavior consistent with the Scala side (which uses `CaseInsensitiveStringMap`). Currently, `options` are stored in a plain Python dictionary, which can be confusing to users. For instance:

    ```python
    class MyDataSource(DataSource):
        def __init__(self, options):
            self.api_key = options.get("API_KEY")  # <- This is None

    spark.read.format(..).option("API_KEY", my_key).load(...)
    ```

    Here, `options` will not contain "API_KEY", because all option keys are converted to lowercase on the Scala side.

    ### Why are the changes needed?

    To improve usability.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    New unit tests

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #44564 from allisonwang-db/spark-46568-ds-options.

    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
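For illustration only (this snippet is not part of the patch), a minimal sketch of how the `CaseInsensitiveDict` added to `pyspark.sql.datasource` by this commit is expected to behave; the option keys and values below are made up:

```python
# A minimal sketch, not part of this commit: it only illustrates the expected
# behavior of the CaseInsensitiveDict that this patch adds to pyspark.sql.datasource.
# The key/value pairs here are hypothetical examples.
from pyspark.sql.datasource import CaseInsensitiveDict

# Spark lowercases option keys on the Scala side before they reach Python.
options = CaseInsensitiveDict({"api_key": "secret"})

# With this change, lookups succeed regardless of the casing the user wrote.
assert options.get("API_KEY") == "secret"
assert "Api_Key" in options

# Writes are normalized too: a mixed-case key updates the same entry.
options["API_KEY"] = "rotated"
assert options["api_key"] == "rotated"
assert len(options) == 1
```

Keys are lowercased on every write and lookup, mirroring `CaseInsensitiveStringMap` on the JVM side, while the stored values are left untouched.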
---
 python/pyspark/sql/datasource.py                   | 51 +++++++++++++++++++---
 python/pyspark/sql/tests/test_python_datasource.py | 21 +++++++++
 python/pyspark/sql/worker/create_data_source.py    |  6 +--
 .../pyspark/sql/worker/write_into_data_source.py   |  6 +--
 .../execution/python/PythonDataSourceSuite.scala   | 43 ++++++++++++++++++
 5 files changed, 115 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index e20d44039a69..bdedbac3544e 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -15,18 +15,25 @@
 # limitations under the License.
 #
 from abc import ABC, abstractmethod
-from typing import final, Any, Dict, Iterator, List, Sequence, Tuple, Type, Union, TYPE_CHECKING
+from collections import UserDict
+from typing import Any, Dict, Iterator, List, Sequence, Tuple, Type, Union, TYPE_CHECKING

 from pyspark.sql import Row
 from pyspark.sql.types import StructType
 from pyspark.errors import PySparkNotImplementedError

 if TYPE_CHECKING:
-    from pyspark.sql._typing import OptionalPrimitiveType
     from pyspark.sql.session import SparkSession

-__all__ = ["DataSource", "DataSourceReader", "DataSourceWriter", "DataSourceRegistration"]
+__all__ = [
+    "DataSource",
+    "DataSourceReader",
+    "DataSourceWriter",
+    "DataSourceRegistration",
+    "InputPartition",
+    "WriterCommitMessage",
+]


 class DataSource(ABC):
@@ -45,15 +52,14 @@ class DataSource(ABC):
     .. versionadded: 4.0.0
     """

-    @final
-    def __init__(self, options: Dict[str, "OptionalPrimitiveType"]) -> None:
+    def __init__(self, options: Dict[str, str]) -> None:
         """
         Initializes the data source with user-provided options.

         Parameters
         ----------
         options : dict
-            A dictionary representing the options for this data source.
+            A case-insensitive dictionary representing the options for this data source.

         Notes
         -----
@@ -403,3 +409,36 @@ class DataSourceRegistration:
         assert sc._jvm is not None
         ds = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonDataSource(wrapped)
         self.sparkSession._jsparkSession.dataSource().registerPython(name, ds)
+
+
+class CaseInsensitiveDict(UserDict):
+    """
+    A case-insensitive map of string keys to values.
+
+    This is used by Python data source options to ensure consistent case insensitivity.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.update(*args, **kwargs)
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        super().__setitem__(key.lower(), value)
+
+    def __getitem__(self, key: str) -> Any:
+        return super().__getitem__(key.lower())
+
+    def __delitem__(self, key: str) -> None:
+        super().__delitem__(key.lower())
+
+    def __contains__(self, key: object) -> bool:
+        if isinstance(key, str):
+            return super().__contains__(key.lower())
+        return False
+
+    def update(self, *args: Any, **kwargs: Any) -> None:
+        for k, v in dict(*args, **kwargs).items():
+            self[k] = v
+
+    def copy(self) -> "CaseInsensitiveDict":
+        return type(self)(self)
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index ce629b2718e2..79414cb7ed69 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -26,6 +26,7 @@ from pyspark.sql.datasource import (
     InputPartition,
     DataSourceWriter,
     WriterCommitMessage,
+    CaseInsensitiveDict,
 )
 from pyspark.sql.types import Row, StructType
 from pyspark.testing import assertDataFrameEqual
@@ -346,6 +347,26 @@ class BasePythonDataSourceTestsMixin:
             text = file.read()
         assert text == "failed"

+    def test_case_insensitive_dict(self):
+        d = CaseInsensitiveDict({"foo": 1, "Bar": 2})
+        self.assertEqual(d["foo"], d["FOO"])
+        self.assertEqual(d["bar"], d["BAR"])
+        self.assertTrue("baR" in d)
+        d["BAR"] = 3
+        self.assertEqual(d["BAR"], 3)
+        # Test update
+        d.update({"BaZ": 3})
+        self.assertEqual(d["BAZ"], 3)
+        d.update({"FOO": 4})
+        self.assertEqual(d["foo"], 4)
+        # Test delete
+        del d["FoO"]
+        self.assertFalse("FOO" in d)
+        # Test copy
+        d2 = d.copy()
+        self.assertEqual(d2["BaR"], 3)
+        self.assertEqual(d2["baz"], 3)
+

 class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
     ...
diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
index 1ba4dc9e8a3c..a377911c6e9b 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -29,7 +29,7 @@ from pyspark.serializers import (
     write_with_length,
     SpecialLengths,
 )
-from pyspark.sql.datasource import DataSource
+from pyspark.sql.datasource import DataSource, CaseInsensitiveDict
 from pyspark.sql.types import _parse_datatype_json_string, StructType
 from pyspark.util import handle_worker_exception
 from pyspark.worker_util import (
@@ -120,7 +120,7 @@ def main(infile: IO, outfile: IO) -> None:
         )

         # Receive the options.
-        options = dict()
+        options = CaseInsensitiveDict()
         num_options = read_int(infile)
         for _ in range(num_options):
             key = utf8_deserializer.loads(infile)
@@ -129,7 +129,7 @@ def main(infile: IO, outfile: IO) -> None:

         # Instantiate a data source.
         try:
-            data_source = data_source_cls(options=options)
+            data_source = data_source_cls(options=options)  # type: ignore
         except Exception as e:
             raise PySparkRuntimeError(
                 error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
index 36b3c23b3379..0ba6fc6eb17f 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -30,7 +30,7 @@ from pyspark.serializers import (
     SpecialLengths,
 )
 from pyspark.sql import Row
-from pyspark.sql.datasource import DataSource, WriterCommitMessage
+from pyspark.sql.datasource import DataSource, WriterCommitMessage, CaseInsensitiveDict
 from pyspark.sql.types import (
     _parse_datatype_json_string,
     StructType,
@@ -142,7 +142,7 @@ def main(infile: IO, outfile: IO) -> None:
         return_col_name = return_type[0].name

         # Receive the options.
-        options = dict()
+        options = CaseInsensitiveDict()
         num_options = read_int(infile)
         for _ in range(num_options):
             key = utf8_deserializer.loads(infile)
@@ -153,7 +153,7 @@ def main(infile: IO, outfile: IO) -> None:

         overwrite = read_bool(infile)

         # Instantiate a data source.
-        data_source = data_source_cls(options=options)
+        data_source = data_source_cls(options=options)  # type: ignore

         # Instantiate the data source writer.
         writer = data_source.writer(schema, overwrite)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index a4614c2b4bad..1cd8fb6819cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -790,4 +790,47 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-46568: case insensitive options") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+        |from pyspark.sql.datasource import (
+        |    DataSource, DataSourceReader, DataSourceWriter, WriterCommitMessage)
+        |class SimpleDataSourceReader(DataSourceReader):
+        |    def __init__(self, options):
+        |        self.options = options
+        |
+        |    def read(self, partition):
+        |        foo = self.options.get("Foo")
+        |        bar = self.options.get("BAR")
+        |        baz = "BaZ" in self.options
+        |        yield (foo, bar, baz)
+        |
+        |class SimpleDataSourceWriter(DataSourceWriter):
+        |    def __init__(self, options):
+        |        self.options = options
+        |
+        |    def write(self, row):
+        |        if "FOO" not in self.options or "BAR" not in self.options:
+        |            raise Exception("FOO or BAR not found")
+        |        return WriterCommitMessage()
+        |
+        |class $dataSourceName(DataSource):
+        |    def schema(self) -> str:
+        |        return "a string, b string, c string"
+        |
+        |    def reader(self, schema):
+        |        return SimpleDataSourceReader(self.options)
+        |
+        |    def writer(self, schema, overwrite):
+        |        return SimpleDataSourceWriter(self.options)
+        |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val df = spark.read.option("foo", 1).option("bar", 2).option("BAZ", 3)
+      .format(dataSourceName).load()
+    checkAnswer(df, Row("1", "2", "true"))
+    df.write.option("foo", 1).option("bar", 2).format(dataSourceName).mode("append").save()
+  }
 }