This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new ae5da1814573 [SPARK-48075][SS] Add type checking for PySpark avro functions ae5da1814573 is described below commit ae5da181457381be474ae090509f0245f5385744 Author: fanyue-xia <chloe....@databircks.com> AuthorDate: Thu May 2 14:38:11 2024 +0900 [SPARK-48075][SS] Add type checking for PySpark avro functions ### What changes were proposed in this pull request? Add type checking for `to_avro` and `from_avro` for PySpark. ### Why are the changes needed? If we perform type checking for arguments and output sensible errors when the types of the arguments passed into the functions don’t match, we can give the user a better user experience. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit test ### Was this patch authored or co-authored using generative AI tooling? No Closes #46324 from fanyue-xia/avro-typecheck. Authored-by: fanyue-xia <chloe....@databircks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/sql/avro/functions.py | 34 +++++++++++++++++++++++++++ python/pyspark/sql/connect/avro/functions.py | 34 +++++++++++++++++++++++++++ python/pyspark/sql/tests/test_functions.py | 35 ++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+) diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py index 40a4ce337db0..fb3bd5398495 100644 --- a/python/pyspark/sql/avro/functions.py +++ b/python/pyspark/sql/avro/functions.py @@ -22,6 +22,7 @@ A collections of builtin avro functions from typing import Dict, Optional, TYPE_CHECKING, cast +from pyspark.errors import PySparkTypeError from pyspark.sql.column import Column from pyspark.sql.utils import get_active_spark_context, try_remote_avro_functions from pyspark.util import _print_missing_jar @@ -80,6 +81,25 @@ def from_avro( from py4j.java_gateway import JVMView from pyspark.sql.classic.column import _to_java_column + if not isinstance(data, (Column, str)): + 
raise PySparkTypeError( + error_class="INVALID_TYPE", + message_parameters={ + "arg_name": "data", + "arg_type": "pyspark.sql.Column or str", + }, + ) + if not isinstance(jsonFormatSchema, str): + raise PySparkTypeError( + error_class="INVALID_TYPE", + message_parameters={"arg_name": "jsonFormatSchema", "arg_type": "str"}, + ) + if options is not None and not isinstance(options, dict): + raise PySparkTypeError( + error_class="INVALID_TYPE", + message_parameters={"arg_name": "options", "arg_type": "dict, optional"}, + ) + sc = get_active_spark_context() try: jc = cast(JVMView, sc._jvm).org.apache.spark.sql.avro.functions.from_avro( @@ -131,6 +151,20 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column: from py4j.java_gateway import JVMView from pyspark.sql.classic.column import _to_java_column + if not isinstance(data, (Column, str)): + raise PySparkTypeError( + error_class="INVALID_TYPE", + message_parameters={ + "arg_name": "data", + "arg_type": "pyspark.sql.Column or str", + }, + ) + if not isinstance(jsonFormatSchema, str): + raise PySparkTypeError( + error_class="INVALID_TYPE", + message_parameters={"arg_name": "jsonFormatSchema", "arg_type": "str"}, + ) + sc = get_active_spark_context() try: if jsonFormatSchema == "": diff --git a/python/pyspark/sql/connect/avro/functions.py b/python/pyspark/sql/connect/avro/functions.py index 7e46bc8d33be..da350f92a531 100644 --- a/python/pyspark/sql/connect/avro/functions.py +++ b/python/pyspark/sql/connect/avro/functions.py @@ -19,6 +19,7 @@ A collections of builtin avro functions """ +from pyspark.errors import PySparkTypeError from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) @@ -36,6 +37,25 @@ if TYPE_CHECKING: def from_avro( data: "ColumnOrName", jsonFormatSchema: str, options: Optional[Dict[str, str]] = None ) -> Column: + if not isinstance(data, (Column, str)): + raise PySparkTypeError( + error_class="INVALID_TYPE", + message_parameters={ + "arg_name": "data", 
+ "arg_type": "pyspark.sql.Column or str", + }, + ) + if not isinstance(jsonFormatSchema, str): + raise PySparkTypeError( + error_class="INVALID_TYPE", + message_parameters={"arg_name": "jsonFormatSchema", "arg_type": "str"}, + ) + if options is not None and not isinstance(options, dict): + raise PySparkTypeError( + error_class="INVALID_TYPE", + message_parameters={"arg_name": "options", "arg_type": "dict, optional"}, + ) + if options is None: return _invoke_function("from_avro", _to_col(data), lit(jsonFormatSchema)) else: @@ -48,6 +68,20 @@ from_avro.__doc__ = PyAvroFunctions.from_avro.__doc__ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column: + if not isinstance(data, (Column, str)): + raise PySparkTypeError( + error_class="INVALID_TYPE", + message_parameters={ + "arg_name": "data", + "arg_type": "pyspark.sql.Column or str", + }, + ) + if not isinstance(jsonFormatSchema, str): + raise PySparkTypeError( + error_class="INVALID_TYPE", + message_parameters={"arg_name": "jsonFormatSchema", "arg_type": "str"}, + ) + if jsonFormatSchema == "": return _invoke_function("to_avro", _to_col(data)) else: diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 1fd23906edae..4e9b61f7d0d9 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -26,6 +26,7 @@ import unittest from pyspark.errors import PySparkTypeError, PySparkValueError, SparkRuntimeException from pyspark.sql import Row, Window, functions as F, types +from pyspark.sql.avro.functions import from_avro, to_avro from pyspark.sql.column import Column from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils from pyspark.testing.utils import have_numpy @@ -1528,6 +1529,40 @@ class FunctionsTestsMixin: lambda: df.select(F.json_tuple(df.jstring)), ) + def test_avro_type_check(self): + parameters = ["data", "jsonFormatSchema", "options"] + expected_type = ["pyspark.sql.Column or str", 
"str", "dict, optional"] + dummyDF = self.spark.createDataFrame([Row(a=i, b=i) for i in range(5)]) + + # test from_avro type checks for each parameter + wrong_type_value = 1 + with self.assertRaises(PySparkTypeError) as pe1: + dummyDF.select(from_avro(wrong_type_value, "jsonSchema", None)) + with self.assertRaises(PySparkTypeError) as pe2: + dummyDF.select(from_avro("value", wrong_type_value, None)) + with self.assertRaises(PySparkTypeError) as pe3: + dummyDF.select(from_avro("value", "jsonSchema", wrong_type_value)) + from_avro_pes = [pe1, pe2, pe3] + for i in range(3): + self.check_error( + exception=from_avro_pes[i].exception, + error_class="INVALID_TYPE", + message_parameters={"arg_name": parameters[i], "arg_type": expected_type[i]}, + ) + + # test to_avro type checks for each parameter + with self.assertRaises(PySparkTypeError) as pe4: + dummyDF.select(to_avro(wrong_type_value, "jsonSchema")) + with self.assertRaises(PySparkTypeError) as pe5: + dummyDF.select(to_avro("value", wrong_type_value)) + to_avro_pes = [pe4, pe5] + for i in range(2): + self.check_error( + exception=to_avro_pes[i].exception, + error_class="INVALID_TYPE", + message_parameters={"arg_name": parameters[i], "arg_type": expected_type[i]}, + ) + class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin): pass --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org