This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ae5da1814573 [SPARK-48075][SS] Add type checking for PySpark avro functions
ae5da1814573 is described below

commit ae5da181457381be474ae090509f0245f5385744
Author: fanyue-xia <chloe....@databircks.com>
AuthorDate: Thu May 2 14:38:11 2024 +0900

    [SPARK-48075][SS] Add type checking for PySpark avro functions
    
    ### What changes were proposed in this pull request?
    Add type checking for `to_avro` and `from_avro` for PySpark.
    
    ### Why are the changes needed?
    Performing type checking on the arguments and raising sensible errors when the types of the arguments passed into the functions don't match gives the user a better experience.
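
    For illustration, a minimal sketch of the resulting behavior (not part of the patch; it assumes an active SparkSession, and the Avro schema literal is only a placeholder):

    ```python
    from pyspark.errors import PySparkTypeError
    from pyspark.sql import SparkSession
    from pyspark.sql.avro.functions import from_avro

    spark = SparkSession.builder.getOrCreate()
    df = spark.range(5)

    try:
        # `data` must be a pyspark.sql.Column or a column name (str);
        # passing an int now fails fast with a clear Python-side error.
        df.select(from_avro(1, '{"type": "string"}'))
    except PySparkTypeError as e:
        print(e.getErrorClass())  # INVALID_TYPE
    ```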
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit test
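
    As a rough standalone sketch of what the added unit test asserts (not part of the patch; the class name is arbitrary, and no SparkSession is needed because the new guards raise before get_active_spark_context() is reached):

    ```python
    import unittest

    from pyspark.errors import PySparkTypeError
    from pyspark.sql.avro.functions import to_avro


    class AvroTypeCheckSketch(unittest.TestCase):
        def test_to_avro_rejects_non_string_schema(self):
            # Passing an int where a JSON schema string is expected should
            # raise PySparkTypeError with the INVALID_TYPE error class.
            with self.assertRaises(PySparkTypeError) as pe:
                to_avro("value", jsonFormatSchema=1)
            self.assertEqual(
                pe.exception.getMessageParameters(),
                {"arg_name": "jsonFormatSchema", "arg_type": "str"},
            )


    if __name__ == "__main__":
        unittest.main()
    ```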
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46324 from fanyue-xia/avro-typecheck.
    
    Authored-by: fanyue-xia <chloe....@databircks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/avro/functions.py         | 34 +++++++++++++++++++++++++++
 python/pyspark/sql/connect/avro/functions.py | 34 +++++++++++++++++++++++++++
 python/pyspark/sql/tests/test_functions.py   | 35 ++++++++++++++++++++++++++++
 3 files changed, 103 insertions(+)

diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py
index 40a4ce337db0..fb3bd5398495 100644
--- a/python/pyspark/sql/avro/functions.py
+++ b/python/pyspark/sql/avro/functions.py
@@ -22,6 +22,7 @@ A collections of builtin avro functions
 
 from typing import Dict, Optional, TYPE_CHECKING, cast
 
+from pyspark.errors import PySparkTypeError
 from pyspark.sql.column import Column
 from pyspark.sql.utils import get_active_spark_context, try_remote_avro_functions
 from pyspark.util import _print_missing_jar
@@ -80,6 +81,25 @@ def from_avro(
     from py4j.java_gateway import JVMView
     from pyspark.sql.classic.column import _to_java_column
 
+    if not isinstance(data, (Column, str)):
+        raise PySparkTypeError(
+            error_class="INVALID_TYPE",
+            message_parameters={
+                "arg_name": "data",
+                "arg_type": "pyspark.sql.Column or str",
+            },
+        )
+    if not isinstance(jsonFormatSchema, str):
+        raise PySparkTypeError(
+            error_class="INVALID_TYPE",
+            message_parameters={"arg_name": "jsonFormatSchema", "arg_type": 
"str"},
+        )
+    if options is not None and not isinstance(options, dict):
+        raise PySparkTypeError(
+            error_class="INVALID_TYPE",
+            message_parameters={"arg_name": "options", "arg_type": "dict, 
optional"},
+        )
+
     sc = get_active_spark_context()
     try:
         jc = cast(JVMView, sc._jvm).org.apache.spark.sql.avro.functions.from_avro(
@@ -131,6 +151,20 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
     from py4j.java_gateway import JVMView
     from pyspark.sql.classic.column import _to_java_column
 
+    if not isinstance(data, (Column, str)):
+        raise PySparkTypeError(
+            error_class="INVALID_TYPE",
+            message_parameters={
+                "arg_name": "data",
+                "arg_type": "pyspark.sql.Column or str",
+            },
+        )
+    if not isinstance(jsonFormatSchema, str):
+        raise PySparkTypeError(
+            error_class="INVALID_TYPE",
+            message_parameters={"arg_name": "jsonFormatSchema", "arg_type": 
"str"},
+        )
+
     sc = get_active_spark_context()
     try:
         if jsonFormatSchema == "":
diff --git a/python/pyspark/sql/connect/avro/functions.py b/python/pyspark/sql/connect/avro/functions.py
index 7e46bc8d33be..da350f92a531 100644
--- a/python/pyspark/sql/connect/avro/functions.py
+++ b/python/pyspark/sql/connect/avro/functions.py
@@ -19,6 +19,7 @@
 A collections of builtin avro functions
 """
 
+from pyspark.errors import PySparkTypeError
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -36,6 +37,25 @@ if TYPE_CHECKING:
 def from_avro(
     data: "ColumnOrName", jsonFormatSchema: str, options: Optional[Dict[str, 
str]] = None
 ) -> Column:
+    if not isinstance(data, (Column, str)):
+        raise PySparkTypeError(
+            error_class="INVALID_TYPE",
+            message_parameters={
+                "arg_name": "data",
+                "arg_type": "pyspark.sql.Column or str",
+            },
+        )
+    if not isinstance(jsonFormatSchema, str):
+        raise PySparkTypeError(
+            error_class="INVALID_TYPE",
+            message_parameters={"arg_name": "jsonFormatSchema", "arg_type": 
"str"},
+        )
+    if options is not None and not isinstance(options, dict):
+        raise PySparkTypeError(
+            error_class="INVALID_TYPE",
+            message_parameters={"arg_name": "options", "arg_type": "dict, 
optional"},
+        )
+
     if options is None:
         return _invoke_function("from_avro", _to_col(data), 
lit(jsonFormatSchema))
     else:
@@ -48,6 +68,20 @@ from_avro.__doc__ = PyAvroFunctions.from_avro.__doc__
 
 
 def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
+    if not isinstance(data, (Column, str)):
+        raise PySparkTypeError(
+            error_class="INVALID_TYPE",
+            message_parameters={
+                "arg_name": "data",
+                "arg_type": "pyspark.sql.Column or str",
+            },
+        )
+    if not isinstance(jsonFormatSchema, str):
+        raise PySparkTypeError(
+            error_class="INVALID_TYPE",
+            message_parameters={"arg_name": "jsonFormatSchema", "arg_type": 
"str"},
+        )
+
     if jsonFormatSchema == "":
         return _invoke_function("to_avro", _to_col(data))
     else:
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 1fd23906edae..4e9b61f7d0d9 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -26,6 +26,7 @@ import unittest
 
 from pyspark.errors import PySparkTypeError, PySparkValueError, SparkRuntimeException
 from pyspark.sql import Row, Window, functions as F, types
+from pyspark.sql.avro.functions import from_avro, to_avro
 from pyspark.sql.column import Column
 from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
 from pyspark.testing.utils import have_numpy
@@ -1528,6 +1529,40 @@ class FunctionsTestsMixin:
             lambda: df.select(F.json_tuple(df.jstring)),
         )
 
+    def test_avro_type_check(self):
+        parameters = ["data", "jsonFormatSchema", "options"]
+        expected_type = ["pyspark.sql.Column or str", "str", "dict, optional"]
+        dummyDF = self.spark.createDataFrame([Row(a=i, b=i) for i in range(5)])
+
+        # test from_avro type checks for each parameter
+        wrong_type_value = 1
+        with self.assertRaises(PySparkTypeError) as pe1:
+            dummyDF.select(from_avro(wrong_type_value, "jsonSchema", None))
+        with self.assertRaises(PySparkTypeError) as pe2:
+            dummyDF.select(from_avro("value", wrong_type_value, None))
+        with self.assertRaises(PySparkTypeError) as pe3:
+            dummyDF.select(from_avro("value", "jsonSchema", wrong_type_value))
+        from_avro_pes = [pe1, pe2, pe3]
+        for i in range(3):
+            self.check_error(
+                exception=from_avro_pes[i].exception,
+                error_class="INVALID_TYPE",
+                message_parameters={"arg_name": parameters[i], "arg_type": 
expected_type[i]},
+            )
+
+        # test to_avro type checks for each parameter
+        with self.assertRaises(PySparkTypeError) as pe4:
+            dummyDF.select(to_avro(wrong_type_value, "jsonSchema"))
+        with self.assertRaises(PySparkTypeError) as pe5:
+            dummyDF.select(to_avro("value", wrong_type_value))
+        to_avro_pes = [pe4, pe5]
+        for i in range(2):
+            self.check_error(
+                exception=to_avro_pes[i].exception,
+                error_class="INVALID_TYPE",
+                message_parameters={"arg_name": parameters[i], "arg_type": 
expected_type[i]},
+            )
+
 
 class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
     pass


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
