This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 61030de02f5 [SPARK-44364] [PYTHON] Add support for List[Row] data type for expected
61030de02f5 is described below

commit 61030de02f5735919d6b3e3c2923831c5d2fcc61
Author: Amanda Liu <amanda....@databricks.com>
AuthorDate: Fri Jul 14 08:27:41 2023 +0900

    [SPARK-44364] [PYTHON] Add support for List[Row] data type for expected

    ### What changes were proposed in this pull request?
    This PR adds support for the List[Row] type for the `expected` argument in `assertDataFrameEqual`.

    ### Why are the changes needed?
    The change improves the flexibility of the `assertDataFrameEqual` function by allowing comparison between DataFrame and List[Row] types.

    ### Does this PR introduce _any_ user-facing change?
    Yes, the PR modifies the user-facing API `assertDataFrameEqual`.
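    For illustration, a minimal usage sketch of the updated API (assuming an active SparkSession bound to `spark`; the values mirror the tests added in this PR):

    ```python
    from pyspark.sql import Row
    from pyspark.testing.utils import assertDataFrameEqual

    df = spark.createDataFrame(
        data=[(1, 3000), (2, 1000)],
        schema=["id", "amount"],
    )

    # `expected` may now be a list of Rows instead of a DataFrame.
    # Row order is ignored by default; pass checkRowOrder=True to enforce it.
    assertDataFrameEqual(df, [Row(1, 3000), Row(2, 1000)])
    assertDataFrameEqual(df, [Row(1, 3000), Row(2, 1000)], checkRowOrder=True)
    ```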
    ### How was this patch tested?
    Added tests to `python/pyspark/sql/tests/test_utils.py` and `python/pyspark/sql/tests/connect/test_utils.py`.

    Closes #41924 from asl3/list-row-support.

    Authored-by: Amanda Liu <amanda....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/errors/error_classes.py |   5 -
 python/pyspark/sql/tests/test_utils.py | 266 ++++++++++++++++++++++-----------
 python/pyspark/testing/utils.py        |  48 +++---
 3 files changed, 209 insertions(+), 110 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 57447d56892..8c51024bf06 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -713,11 +713,6 @@ ERROR_CLASSES_JSON = """
       "<data_type> is only supported with pyarrow 2.0.0 and above."
     ]
   },
-  "UNSUPPORTED_DATA_TYPE_FOR_IGNORE_ROW_ORDER" : {
-    "message" : [
-      "Cannot ignore row order because undefined sorting for data type."
-    ]
-  },
   "UNSUPPORTED_JOIN_TYPE" : {
     "message" : [
       "Unsupported join type: <join_type>. Supported join types include: \\"inner\\", \\"outer\\", \\"full\\", \\"fullouter\\", \\"full_outer\\", \\"leftouter\\", \\"left\\", \\"left_outer\\", \\"rightouter\\", \\"right\\", \\"right_outer\\", \\"leftsemi\\", \\"left_semi\\", \\"semi\\", \\"leftanti\\", \\"left_anti\\", \\"anti\\", \\"cross\\"."
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index 1757d8dd2e1..ce8d83e6cb9 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -57,7 +57,8 @@ class UtilsTestsMixin:
             schema=["id", "amount"],
         )

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_assert_equal_arraytype(self):
         df1 = self.spark.createDataFrame(
@@ -85,7 +86,8 @@ class UtilsTestsMixin:
             ),
         )

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_assert_approx_equal_arraytype_float(self):
         df1 = self.spark.createDataFrame(
@@ -113,7 +115,8 @@ class UtilsTestsMixin:
             ),
         )

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_assert_notequal_arraytype(self):
         df1 = self.spark.createDataFrame(
@@ -167,6 +170,15 @@ class UtilsTestsMixin:
             message_parameters={"error_msg": expected_error_message},
         )

+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_ROWS",
+            message_parameters={"error_msg": expected_error_message},
+        )
+
     def test_assert_equal_maptype(self):
         df1 = self.spark.createDataFrame(
             data=[
@@ -193,35 +205,8 @@ class UtilsTestsMixin:
             ),
         )

-        assertDataFrameEqual(df1, df2, check_row_order=True)
-
-    def test_assert_approx_equal_maptype_double(self):
-        df1 = self.spark.createDataFrame(
-            data=[
-                ("student1", {"math": 76.23, "english": 92.64}),
-                ("student2", {"math": 87.89, "english": 84.48}),
-            ],
-            schema=StructType(
-                [
-                    StructField("student", StringType(), True),
-                    StructField("grades", MapType(StringType(), DoubleType()), True),
-                ]
-            ),
-        )
-        df2 = self.spark.createDataFrame(
-            data=[
-                ("student1", {"math": 76.23, "english": 92.63999999}),
-                ("student2", {"math": 87.89, "english": 84.48}),
-            ],
-            schema=StructType(
-                [
-                    StructField("student", StringType(), True),
-                    StructField("grades", MapType(StringType(), DoubleType()), True),
-                ]
-            ),
-        )
-
-        assertDataFrameEqual(df1, df2, check_row_order=True)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_assert_approx_equal_maptype_double(self):
         df1 = self.spark.createDataFrame(
@@ -249,7 +234,8 @@ class UtilsTestsMixin:
             ),
         )

-        assertDataFrameEqual(df1, df2, check_row_order=True)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_assert_approx_equal_nested_struct_double(self):
         df1 = self.spark.createDataFrame(
@@ -296,7 +282,8 @@ class UtilsTestsMixin:
             ),
         )

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_assert_equal_nested_struct_str(self):
         df1 = self.spark.createDataFrame(
@@ -343,7 +330,8 @@ class UtilsTestsMixin:
             ),
         )

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_assert_equal_nested_struct_str_duplicate(self):
         df1 = self.spark.createDataFrame(
@@ -388,7 +376,8 @@ class UtilsTestsMixin:
             ),
         )

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_assert_equal_duplicate_col(self):
         df1 = self.spark.createDataFrame(
@@ -406,7 +395,8 @@ class UtilsTestsMixin:
             schema=["number", "language", "number", "number"],
         )

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_assert_equal_timestamp(self):
         df1 = self.spark.createDataFrame(
@@ -420,7 +410,8 @@ class UtilsTestsMixin:
         df1 = df1.withColumn("timestamp", to_timestamp("timestamp"))
         df2 = df2.withColumn("timestamp", to_timestamp("timestamp"))

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_assert_equal_nullrow(self):
         df1 = self.spark.createDataFrame(
@@ -438,7 +429,8 @@ class UtilsTestsMixin:
             schema=["id", "amount"],
         )

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_assert_notequal_nullval(self):
         df1 = self.spark.createDataFrame(
@@ -482,11 +474,21 @@ class UtilsTestsMixin:
             message_parameters={"error_msg": expected_error_message},
         )

+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_ROWS",
+            message_parameters={"error_msg": expected_error_message},
+        )
+
     def test_assert_equal_nulldf(self):
         df1 = None
         df2 = None

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_assert_error_pandas_df(self):
         import pandas as pd
@@ -503,6 +505,15 @@ class UtilsTestsMixin:
             message_parameters={"data_type": pd.DataFrame},
         )

+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="UNSUPPORTED_DATA_TYPE",
+            message_parameters={"data_type": pd.DataFrame},
+        )
+
     def test_assert_error_non_pyspark_df(self):
         dict1 = {"a": 1, "b": 2}
         dict2 = {"a": 1, "b": 2}
@@ -516,6 +527,15 @@ class UtilsTestsMixin:
             message_parameters={"data_type": type(dict1)},
         )

+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(dict1, dict2, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="UNSUPPORTED_DATA_TYPE",
+            message_parameters={"data_type": type(dict1)},
+        )
+
     def test_row_order_ignored(self):
         # test that row order is ignored (not checked) by default
         df1 = self.spark.createDataFrame(
@@ -536,7 +556,7 @@ class UtilsTestsMixin:
         assertDataFrameEqual(df1, df2)

     def test_check_row_order_error(self):
-        # test check_row_order=True
+        # test checkRowOrder=True
         df1 = self.spark.createDataFrame(
             data=[
                 ("2", 3000.00),
@@ -582,7 +602,7 @@ class UtilsTestsMixin:
         expected_error_message += "\n" + diff_msg

         with self.assertRaises(PySparkAssertionError) as pe:
-            assertDataFrameEqual(df1, df2, check_row_order=True)
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)

         self.check_error(
             exception=pe.exception,
@@ -620,7 +640,8 @@ class UtilsTestsMixin:
             schema=["id", "amount"],
         )

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_assert_pyspark_df_not_equal(self):
         df1 = self.spark.createDataFrame(
@@ -678,6 +699,15 @@ class UtilsTestsMixin:
             message_parameters={"error_msg": expected_error_message},
         )

+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_ROWS",
+            message_parameters={"error_msg": expected_error_message},
+        )
+
     def test_assert_notequal_schema(self):
         df1 = self.spark.createDataFrame(
             data=[
@@ -703,6 +733,15 @@ class UtilsTestsMixin:
             message_parameters={"df_schema": df1.schema, "expected_schema": df2.schema},
         )

+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_SCHEMA",
+            message_parameters={"df_schema": df1.schema, "expected_schema": df2.schema},
+        )
+
     def test_diff_schema_lens(self):
         df1 = self.spark.createDataFrame(
             data=[
@@ -729,48 +768,22 @@ class UtilsTestsMixin:
             message_parameters={"df_schema": df1.schema, "expected_schema": df2.schema},
         )

-    def test_assert_equal_maptype(self):
-        df1 = self.spark.createDataFrame(
-            data=[
-                ("student1", {"id": 222342203655477580}),
-                ("student2", {"grad_year": 422322203155477692}),
-            ],
-            schema=StructType(
-                [
-                    StructField("student", StringType(), True),
-                    StructField("properties", MapType(StringType(), LongType()), True),
-                ]
-            ),
-        )
-        df2 = self.spark.createDataFrame(
-            data=[
-                ("student1", {"id": 222342203655477580}),
-                ("student2", {"id": 422322203155477692}),
-            ],
-            schema=StructType(
-                [
-                    StructField("student", StringType(), True),
-                    StructField("properties", MapType(StringType(), LongType()), True),
-                ]
-            ),
-        )
-
-        from pyspark.errors.exceptions.connect import AnalysisException
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)

-        try:
-            assertDataFrameEqual(df1, df2)
-        except PySparkAssertionError as pe:
-            self.check_error(
-                exception=pe,
-                error_class="UNSUPPORTED_DATA_TYPE_FOR_IGNORE_ROW_ORDER",
-                message_parameters={},
-            )
-        except AnalysisException:
-            # catch AnalysisException for Spark Connect
-            pass
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_SCHEMA",
+            message_parameters={"df_schema": df1.schema, "expected_schema": df2.schema},
+        )

     def test_spark_sql(self):
         assertDataFrameEqual(self.spark.sql("select 1 + 2 AS x"), self.spark.sql("select 3 AS x"))
+        assertDataFrameEqual(
+            self.spark.sql("select 1 + 2 AS x"),
+            self.spark.sql("select 3 AS x"),
+            checkRowOrder=True,
+        )

     def test_spark_sql_sort_rows(self):
         df1 = self.spark.createDataFrame(
@@ -796,26 +809,111 @@ class UtilsTestsMixin:
             self.spark.sql("select * from df1 order by amount"), self.spark.sql("select * from df2")
         )

+        assertDataFrameEqual(
+            self.spark.sql("select * from df1 order by amount"),
+            self.spark.sql("select * from df2"),
+            checkRowOrder=True,
+        )
+
     def test_empty_dataset(self):
         df1 = self.spark.range(0, 10).limit(0)
         df2 = self.spark.range(0, 10).limit(0)

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_no_column(self):
         df1 = self.spark.range(0, 10).drop("id")
         df2 = self.spark.range(0, 10).drop("id")

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)

     def test_empty_no_column(self):
         df1 = self.spark.range(0, 10).drop("id").limit(0)
         df2 = self.spark.range(0, 10).drop("id").limit(0)

-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+    def test_list_row_equal(self):
+        from pyspark.sql import Row
+
+        df1 = self.spark.createDataFrame(
+            data=[
+                (1, 3000),
+                (2, 1000),
+            ],
+            schema=["id", "amount"],
+        )
+
+        list_of_rows = [Row(1, 3000), Row(2, 1000)]
+
+        assertDataFrameEqual(df1, list_of_rows, checkRowOrder=False)
+        assertDataFrameEqual(df1, list_of_rows, checkRowOrder=True)
+
+    def test_list_row_unequal_schema(self):
+        from pyspark.sql import Row
+
+        df1 = self.spark.createDataFrame(
+            data=[
+                (1, 3000),
+                (2, 1000),
+            ],
+            schema=["id", "amount"],
+        )
+
+        list_of_rows = [Row(1, "3000"), Row(2, "1000")]
+
+        expected_error_message = "Results do not match: "
+        percent_diff = (2 / 2) * 100
+        expected_error_message += "( %.5f %% )" % percent_diff
+        diff_msg = (
+            "[df]"
+            + "\n"
+            + str(df1.collect()[0])
+            + "\n\n"
+            + "[expected]"
+            + "\n"
+            + str(list_of_rows[0])
+            + "\n\n"
+            + "********************"
+            + "\n\n"
+        )
+        diff_msg += (
+            "[df]"
+            + "\n"
+            + str(df1.collect()[1])
+            + "\n\n"
+            + "[expected]"
+            + "\n"
+            + str(list_of_rows[1])
+            + "\n\n"
+            + "********************"
+            + "\n\n"
+        )
+        expected_error_message += "\n" + diff_msg
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, list_of_rows)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_ROWS",
+            message_parameters={"error_msg": expected_error_message},
+        )
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, list_of_rows, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_ROWS",
+            message_parameters={"error_msg": expected_error_message},
+        )


 class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index defb6852969..cb3f72b75ed 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -222,10 +222,12 @@ class PySparkErrorTestUtils:
         )


-def assertDataFrameEqual(df: DataFrame, expected: DataFrame, check_row_order: bool = False):
+def assertDataFrameEqual(
+    df: DataFrame, expected: Union[DataFrame, List[Row]], checkRowOrder: bool = False
+):
     """
-    A util function to assert equality between DataFrames `df` and `expected`, with
-    optional parameter `check_row_order`.
+    A util function to assert equality between `df` (DataFrame) and `expected`
+    (either DataFrame or list of Rows), with optional parameter `checkRowOrder`.

     .. versionadded:: 3.5.0

@@ -236,10 +238,10 @@ def assertDataFrameEqual(df: DataFrame, expected: DataFrame, check_row_order: bo
     df : DataFrame
         The DataFrame that is being compared or tested.

-    expected : DataFrame
+    expected : DataFrame or list of Rows
         The expected result of the operation, for comparison with the actual result.

-    check_row_order : bool, optional
+    checkRowOrder : bool, optional
         A flag indicating whether the order of rows should be considered in the comparison.
         If set to `False` (default), the row order is not taken into account.
         If set to `True`, the order of rows is important and will be checked during comparison.
@@ -292,7 +294,11 @@ def assertDataFrameEqual(df: DataFrame, expected: DataFrame, check_row_order: bo
                 error_class="UNSUPPORTED_DATA_TYPE",
                 message_parameters={"data_type": type(df)},
             )
-        elif not isinstance(expected, DataFrame) and not isinstance(expected, ConnectDataFrame):
+        elif (
+            not isinstance(expected, DataFrame)
+            and not isinstance(expected, ConnectDataFrame)
+            and not isinstance(expected, List)
+        ):
             raise PySparkAssertionError(
                 error_class="UNSUPPORTED_DATA_TYPE",
                 message_parameters={"data_type": type(expected)},
@@ -303,7 +309,7 @@ def assertDataFrameEqual(df: DataFrame, expected: DataFrame, check_row_order: bo
                 error_class="UNSUPPORTED_DATA_TYPE",
                 message_parameters={"data_type": type(df)},
             )
-        elif not isinstance(expected, DataFrame):
+        elif not isinstance(expected, DataFrame) and not isinstance(expected, List):
             raise PySparkAssertionError(
                 error_class="UNSUPPORTED_DATA_TYPE",
                 message_parameters={"data_type": type(expected)},
@@ -379,22 +385,22 @@ def assertDataFrameEqual(df: DataFrame, expected: DataFrame, check_row_order: bo
             message_parameters={"error_msg": error_msg},
         )

-    if not check_row_order:
-        try:
-            # rename duplicate columns for sorting
-            renamed_df = df.toDF(*[f"_{i}" for i in range(len(df.columns))])
-            renamed_expected = expected.toDF(*[f"_{i}" for i in range(len(expected.columns))])
+    # convert df and expected to list
+    if not isinstance(expected, List):
+        # only compare schema if expected is not a List
+        assert_schema_equal(df.schema, expected.schema)
+        expected_list = expected.collect()
+    else:
+        expected_list = expected

-            df = renamed_df.sort(renamed_df.columns).toDF(*df.columns)
-            expected = renamed_expected.sort(renamed_expected.columns).toDF(*expected.columns)
-        except Exception:
-            raise PySparkAssertionError(
-                error_class="UNSUPPORTED_DATA_TYPE_FOR_IGNORE_ROW_ORDER",
-                message_parameters={},
-            )
+    df_list = df.collect()
+
+    if not checkRowOrder:
+        # sort collected rows by their string representation so row order is ignored
+        df_list = sorted(df_list, key=lambda x: str(x))
+        expected_list = sorted(expected_list, key=lambda x: str(x))

-    assert_schema_equal(df.schema, expected.schema)
-    assert_rows_equal(df.collect(), expected.collect())
+    assert_rows_equal(df_list, expected_list)


 def _test() -> None:

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org