This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 31ede7330a3 [SPARK-42920][CONNECT][PYTHON] Enable tests for UDF with UDT 31ede7330a3 is described below commit 31ede7330a314b18faa591a9313ed31c5c8b63c1 Author: Takuya UESHIN <ues...@databricks.com> AuthorDate: Mon Mar 27 09:35:33 2023 +0900 [SPARK-42920][CONNECT][PYTHON] Enable tests for UDF with UDT ### What changes were proposed in this pull request? Enables tests for UDF with UDT. ### Why are the changes needed? Now that UDF with UDT should work, the related tests should be enabled to see if it works. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Enabled/modified the related tests. Closes #40549 from ueshin/issues/SPARK-42920/udf_with_udt. Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> (cherry picked from commit 80f8664e8278335788d8fa1dd00654f3eaec8ed6) Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../pyspark/sql/tests/connect/test_parity_types.py | 4 +-- python/pyspark/sql/tests/test_types.py | 38 +++++++++++----------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py index a2f81fbf25e..aacf5793b2b 100644 --- a/python/pyspark/sql/tests/connect/test_parity_types.py +++ b/python/pyspark/sql/tests/connect/test_parity_types.py @@ -84,8 +84,8 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase): super().test_infer_schema_upcast_int_to_string() @unittest.skip("Spark Connect does not support RDD but the tests depend on them.") - def test_udf_with_udt(self): - super().test_udf_with_udt() + def test_rdd_with_udt(self): + super().test_rdd_with_udt() @unittest.skip("Requires JVM access.") def test_udt(self): diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index bee899e928e..5d6476b47f4 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -25,8 +25,7 @@ import sys import unittest from pyspark.sql import Row -from pyspark.sql.functions import col -from pyspark.sql.udf import UserDefinedFunction +from pyspark.sql import functions as F from pyspark.errors import AnalysisException from pyspark.sql.types import ( ByteType, @@ -381,7 +380,7 @@ class TypesTestsMixin: try: self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true") df = self.spark.createDataFrame([(1,), (11,)], ["value"]) - ret = df.select(col("value").cast(DecimalType(1, -1))).collect() + ret = df.select(F.col("value").cast(DecimalType(1, -1))).collect() actual = list(map(lambda r: int(r.value), ret)) self.assertEqual(actual, [0, 10]) finally: @@ -548,8 +547,6 @@ class TypesTestsMixin: df.collect() def test_complex_nested_udt_in_df(self): - from pyspark.sql.functions import udf - schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) df = self.spark.createDataFrame( [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], schema=schema @@ -558,7 +555,7 @@ class TypesTestsMixin: gd = df.groupby("key").agg({"val": "collect_list"}) gd.collect() - udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema)) + udf = F.udf(lambda k, v: [(k, v[0])], ArrayType(df.schema)) gd.select(udf(*gd)).collect() def test_udt_with_none(self): @@ -667,20 +664,27 @@ class TypesTestsMixin: def test_udf_with_udt(self): row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) df = self.spark.createDataFrame([row]) - self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) - udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + udf = F.udf(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) - udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) + udf2 = F.udf(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) df = self.spark.createDataFrame([row]) - self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) - udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + udf = F.udf(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) - udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) + udf2 = F.udf(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + def test_rdd_with_udt(self): + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + df = self.spark.createDataFrame([row]) + self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) + + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.spark.createDataFrame([row]) + self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) + def test_parquet_with_udt(self): row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) df0 = self.spark.createDataFrame([row]) @@ -719,8 +723,6 @@ class TypesTestsMixin: ) def test_cast_to_string_with_udt(self): - from pyspark.sql.functions import col - row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0)) schema = StructType( [ @@ -730,18 +732,16 @@ class TypesTestsMixin: ) df = self.spark.createDataFrame([row], schema) - result = df.select(col("point").cast("string"), col("pypoint").cast("string")).head() + result = df.select(F.col("point").cast("string"), F.col("pypoint").cast("string")).head() self.assertEqual(result, Row(point="(1.0, 2.0)", pypoint="[3.0, 4.0]")) def test_cast_to_udt_with_udt(self): - from pyspark.sql.functions import col - row = Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0)) df = self.spark.createDataFrame([row]) with self.assertRaises(AnalysisException): - df.select(col("point").cast(PythonOnlyUDT())).collect() + df.select(F.col("point").cast(PythonOnlyUDT())).collect() with self.assertRaises(AnalysisException): - df.select(col("python_only_point").cast(ExamplePointUDT())).collect() + df.select(F.col("python_only_point").cast(ExamplePointUDT())).collect() def test_struct_type(self): struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org