This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 61030de02f5 [SPARK-44364] [PYTHON] Add support for List[Row] data type for expected
61030de02f5 is described below

commit 61030de02f5735919d6b3e3c2923831c5d2fcc61
Author: Amanda Liu <amanda....@databricks.com>
AuthorDate: Fri Jul 14 08:27:41 2023 +0900

    [SPARK-44364] [PYTHON] Add support for List[Row] data type for expected
    
    ### What changes were proposed in this pull request?
    This PR adds support for the List[Row] type for the `expected` argument in `assertDataFrameEqual`.
    
    ### Why are the changes needed?
    The change improves the flexibility of the `assertDataFrameEqual` function by allowing comparison between a DataFrame and a List[Row].
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the PR modifies the user-facing API `assertDataFrameEqual`.
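    
    A minimal usage sketch of the new behavior (assuming an active `SparkSession` named `spark`; the data here is illustrative):
    
        from pyspark.sql import Row
        from pyspark.testing.utils import assertDataFrameEqual
    
        df = spark.createDataFrame([(1, 1000), (2, 3000)], schema=["id", "amount"])
    
        # `expected` can now be a list of Rows instead of a second DataFrame.
        assertDataFrameEqual(df, [Row(1, 1000), Row(2, 3000)])
    
        # Row order is ignored by default; pass checkRowOrder=True to also
        # require that rows appear in the same order on both sides.
        assertDataFrameEqual(df, [Row(2, 3000), Row(1, 1000)], checkRowOrder=False)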
    
    ### How was this patch tested?
    Added tests to `python/pyspark/sql/tests/test_utils.py` and `python/pyspark/sql/tests/connect/test_utils.py`
    
    Closes #41924 from asl3/list-row-support.
    
    Authored-by: Amanda Liu <amanda....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/errors/error_classes.py |   5 -
 python/pyspark/sql/tests/test_utils.py | 266 ++++++++++++++++++++++-----------
 python/pyspark/testing/utils.py        |  48 +++---
 3 files changed, 209 insertions(+), 110 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 57447d56892..8c51024bf06 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -713,11 +713,6 @@ ERROR_CLASSES_JSON = """
       "<data_type> is only supported with pyarrow 2.0.0 and above."
     ]
   },
-  "UNSUPPORTED_DATA_TYPE_FOR_IGNORE_ROW_ORDER" : {
-    "message" : [
-      "Cannot ignore row order because undefined sorting for data type."
-    ]
-  },
   "UNSUPPORTED_JOIN_TYPE" : {
     "message" : [
       "Unsupported join type: <join_type>. Supported join types include: 
\\"inner\\", \\"outer\\", \\"full\\", \\"fullouter\\", \\"full_outer\\", 
\\"leftouter\\", \\"left\\", \\"left_outer\\", \\"rightouter\\", \\"right\\", 
\\"right_outer\\", \\"leftsemi\\", \\"left_semi\\", \\"semi\\", \\"leftanti\\", 
\\"left_anti\\", \\"anti\\", \\"cross\\"."
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index 1757d8dd2e1..ce8d83e6cb9 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -57,7 +57,8 @@ class UtilsTestsMixin:
             schema=["id", "amount"],
         )
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_assert_equal_arraytype(self):
         df1 = self.spark.createDataFrame(
@@ -85,7 +86,8 @@ class UtilsTestsMixin:
             ),
         )
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_assert_approx_equal_arraytype_float(self):
         df1 = self.spark.createDataFrame(
@@ -113,7 +115,8 @@ class UtilsTestsMixin:
             ),
         )
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_assert_notequal_arraytype(self):
         df1 = self.spark.createDataFrame(
@@ -167,6 +170,15 @@ class UtilsTestsMixin:
             message_parameters={"error_msg": expected_error_message},
         )
 
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_ROWS",
+            message_parameters={"error_msg": expected_error_message},
+        )
+
     def test_assert_equal_maptype(self):
         df1 = self.spark.createDataFrame(
             data=[
@@ -193,35 +205,8 @@ class UtilsTestsMixin:
             ),
         )
 
-        assertDataFrameEqual(df1, df2, check_row_order=True)
-
-    def test_assert_approx_equal_maptype_double(self):
-        df1 = self.spark.createDataFrame(
-            data=[
-                ("student1", {"math": 76.23, "english": 92.64}),
-                ("student2", {"math": 87.89, "english": 84.48}),
-            ],
-            schema=StructType(
-                [
-                    StructField("student", StringType(), True),
-                    StructField("grades", MapType(StringType(), DoubleType()), 
True),
-                ]
-            ),
-        )
-        df2 = self.spark.createDataFrame(
-            data=[
-                ("student1", {"math": 76.23, "english": 92.63999999}),
-                ("student2", {"math": 87.89, "english": 84.48}),
-            ],
-            schema=StructType(
-                [
-                    StructField("student", StringType(), True),
-                    StructField("grades", MapType(StringType(), DoubleType()), 
True),
-                ]
-            ),
-        )
-
-        assertDataFrameEqual(df1, df2, check_row_order=True)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_assert_approx_equal_maptype_double(self):
         df1 = self.spark.createDataFrame(
@@ -249,7 +234,8 @@ class UtilsTestsMixin:
             ),
         )
 
-        assertDataFrameEqual(df1, df2, check_row_order=True)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_assert_approx_equal_nested_struct_double(self):
         df1 = self.spark.createDataFrame(
@@ -296,7 +282,8 @@ class UtilsTestsMixin:
             ),
         )
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_assert_equal_nested_struct_str(self):
         df1 = self.spark.createDataFrame(
@@ -343,7 +330,8 @@ class UtilsTestsMixin:
             ),
         )
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_assert_equal_nested_struct_str_duplicate(self):
         df1 = self.spark.createDataFrame(
@@ -388,7 +376,8 @@ class UtilsTestsMixin:
             ),
         )
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_assert_equal_duplicate_col(self):
         df1 = self.spark.createDataFrame(
@@ -406,7 +395,8 @@ class UtilsTestsMixin:
             schema=["number", "language", "number", "number"],
         )
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_assert_equal_timestamp(self):
         df1 = self.spark.createDataFrame(
@@ -420,7 +410,8 @@ class UtilsTestsMixin:
         df1 = df1.withColumn("timestamp", to_timestamp("timestamp"))
         df2 = df2.withColumn("timestamp", to_timestamp("timestamp"))
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_assert_equal_nullrow(self):
         df1 = self.spark.createDataFrame(
@@ -438,7 +429,8 @@ class UtilsTestsMixin:
             schema=["id", "amount"],
         )
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_assert_notequal_nullval(self):
         df1 = self.spark.createDataFrame(
@@ -482,11 +474,21 @@ class UtilsTestsMixin:
             message_parameters={"error_msg": expected_error_message},
         )
 
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_ROWS",
+            message_parameters={"error_msg": expected_error_message},
+        )
+
     def test_assert_equal_nulldf(self):
         df1 = None
         df2 = None
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_assert_error_pandas_df(self):
         import pandas as pd
@@ -503,6 +505,15 @@ class UtilsTestsMixin:
             message_parameters={"data_type": pd.DataFrame},
         )
 
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="UNSUPPORTED_DATA_TYPE",
+            message_parameters={"data_type": pd.DataFrame},
+        )
+
     def test_assert_error_non_pyspark_df(self):
         dict1 = {"a": 1, "b": 2}
         dict2 = {"a": 1, "b": 2}
@@ -516,6 +527,15 @@ class UtilsTestsMixin:
             message_parameters={"data_type": type(dict1)},
         )
 
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(dict1, dict2, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="UNSUPPORTED_DATA_TYPE",
+            message_parameters={"data_type": type(dict1)},
+        )
+
     def test_row_order_ignored(self):
         # test that row order is ignored (not checked) by default
         df1 = self.spark.createDataFrame(
@@ -536,7 +556,7 @@ class UtilsTestsMixin:
         assertDataFrameEqual(df1, df2)
 
     def test_check_row_order_error(self):
-        # test check_row_order=True
+        # test checkRowOrder=True
         df1 = self.spark.createDataFrame(
             data=[
                 ("2", 3000.00),
@@ -582,7 +602,7 @@ class UtilsTestsMixin:
         expected_error_message += "\n" + diff_msg
 
         with self.assertRaises(PySparkAssertionError) as pe:
-            assertDataFrameEqual(df1, df2, check_row_order=True)
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
         self.check_error(
             exception=pe.exception,
@@ -620,7 +640,8 @@ class UtilsTestsMixin:
             schema=["id", "amount"],
         )
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_assert_pyspark_df_not_equal(self):
         df1 = self.spark.createDataFrame(
@@ -678,6 +699,15 @@ class UtilsTestsMixin:
             message_parameters={"error_msg": expected_error_message},
         )
 
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_ROWS",
+            message_parameters={"error_msg": expected_error_message},
+        )
+
     def test_assert_notequal_schema(self):
         df1 = self.spark.createDataFrame(
             data=[
@@ -703,6 +733,15 @@ class UtilsTestsMixin:
             message_parameters={"df_schema": df1.schema, "expected_schema": 
df2.schema},
         )
 
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_SCHEMA",
+            message_parameters={"df_schema": df1.schema, "expected_schema": 
df2.schema},
+        )
+
     def test_diff_schema_lens(self):
         df1 = self.spark.createDataFrame(
             data=[
@@ -729,48 +768,22 @@ class UtilsTestsMixin:
             message_parameters={"df_schema": df1.schema, "expected_schema": 
df2.schema},
         )
 
-    def test_assert_equal_maptype(self):
-        df1 = self.spark.createDataFrame(
-            data=[
-                ("student1", {"id": 222342203655477580}),
-                ("student2", {"grad_year": 422322203155477692}),
-            ],
-            schema=StructType(
-                [
-                    StructField("student", StringType(), True),
-                    StructField("properties", MapType(StringType(), 
LongType()), True),
-                ]
-            ),
-        )
-        df2 = self.spark.createDataFrame(
-            data=[
-                ("student1", {"id": 222342203655477580}),
-                ("student2", {"id": 422322203155477692}),
-            ],
-            schema=StructType(
-                [
-                    StructField("student", StringType(), True),
-                    StructField("properties", MapType(StringType(), 
LongType()), True),
-                ]
-            ),
-        )
-
-        from pyspark.errors.exceptions.connect import AnalysisException
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
-        try:
-            assertDataFrameEqual(df1, df2)
-        except PySparkAssertionError as pe:
-            self.check_error(
-                exception=pe,
-                error_class="UNSUPPORTED_DATA_TYPE_FOR_IGNORE_ROW_ORDER",
-                message_parameters={},
-            )
-        except AnalysisException:
-            # catch AnalysisException for Spark Connect
-            pass
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_SCHEMA",
+            message_parameters={"df_schema": df1.schema, "expected_schema": 
df2.schema},
+        )
 
     def test_spark_sql(self):
        assertDataFrameEqual(self.spark.sql("select 1 + 2 AS x"), self.spark.sql("select 3 AS x"))
+        assertDataFrameEqual(
+            self.spark.sql("select 1 + 2 AS x"),
+            self.spark.sql("select 3 AS x"),
+            checkRowOrder=True,
+        )
 
     def test_spark_sql_sort_rows(self):
         df1 = self.spark.createDataFrame(
@@ -796,26 +809,111 @@ class UtilsTestsMixin:
             self.spark.sql("select * from df1 order by amount"), 
self.spark.sql("select * from df2")
         )
 
+        assertDataFrameEqual(
+            self.spark.sql("select * from df1 order by amount"),
+            self.spark.sql("select * from df2"),
+            checkRowOrder=True,
+        )
+
     def test_empty_dataset(self):
         df1 = self.spark.range(0, 10).limit(0)
 
         df2 = self.spark.range(0, 10).limit(0)
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_no_column(self):
         df1 = self.spark.range(0, 10).drop("id")
 
         df2 = self.spark.range(0, 10).drop("id")
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
     def test_empty_no_column(self):
         df1 = self.spark.range(0, 10).drop("id").limit(0)
 
         df2 = self.spark.range(0, 10).drop("id").limit(0)
 
-        assertDataFrameEqual(df1, df2)
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+    def test_list_row_equal(self):
+        from pyspark.sql import Row
+
+        df1 = self.spark.createDataFrame(
+            data=[
+                (1, 3000),
+                (2, 1000),
+            ],
+            schema=["id", "amount"],
+        )
+
+        list_of_rows = [Row(1, 3000), Row(2, 1000)]
+
+        assertDataFrameEqual(df1, list_of_rows, checkRowOrder=False)
+        assertDataFrameEqual(df1, list_of_rows, checkRowOrder=True)
+
+    def test_list_row_unequal_schema(self):
+        from pyspark.sql import Row
+
+        df1 = self.spark.createDataFrame(
+            data=[
+                (1, 3000),
+                (2, 1000),
+            ],
+            schema=["id", "amount"],
+        )
+
+        list_of_rows = [Row(1, "3000"), Row(2, "1000")]
+
+        expected_error_message = "Results do not match: "
+        percent_diff = (2 / 2) * 100
+        expected_error_message += "( %.5f %% )" % percent_diff
+        diff_msg = (
+            "[df]"
+            + "\n"
+            + str(df1.collect()[0])
+            + "\n\n"
+            + "[expected]"
+            + "\n"
+            + str(list_of_rows[0])
+            + "\n\n"
+            + "********************"
+            + "\n\n"
+        )
+        diff_msg += (
+            "[df]"
+            + "\n"
+            + str(df1.collect()[1])
+            + "\n\n"
+            + "[expected]"
+            + "\n"
+            + str(list_of_rows[1])
+            + "\n\n"
+            + "********************"
+            + "\n\n"
+        )
+        expected_error_message += "\n" + diff_msg
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, list_of_rows)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_ROWS",
+            message_parameters={"error_msg": expected_error_message},
+        )
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, list_of_rows, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_ROWS",
+            message_parameters={"error_msg": expected_error_message},
+        )
 
 
 class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index defb6852969..cb3f72b75ed 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -222,10 +222,12 @@ class PySparkErrorTestUtils:
         )
 
 
-def assertDataFrameEqual(df: DataFrame, expected: DataFrame, check_row_order: bool = False):
+def assertDataFrameEqual(
+    df: DataFrame, expected: Union[DataFrame, List[Row]], checkRowOrder: bool = False
+):
     """
-    A util function to assert equality between DataFrames `df` and `expected`, with
-    optional parameter `check_row_order`.
+    A util function to assert equality between `df` (DataFrame) and `expected`
+    (either DataFrame or list of Rows), with optional parameter `checkRowOrder`.
 
     .. versionadded:: 3.5.0
 
@@ -236,10 +238,10 @@ def assertDataFrameEqual(df: DataFrame, expected: DataFrame, check_row_order: bo
     df : DataFrame
         The DataFrame that is being compared or tested.
 
-    expected : DataFrame
+    expected : DataFrame or list of Rows
        The expected result of the operation, for comparison with the actual result.
 
-    check_row_order : bool, optional
+    checkRowOrder : bool, optional
        A flag indicating whether the order of rows should be considered in the comparison.
         If set to `False` (default), the row order is not taken into account.
        If set to `True`, the order of rows is important and will be checked during comparison.
@@ -292,7 +294,11 @@ def assertDataFrameEqual(df: DataFrame, expected: DataFrame, check_row_order: bo
                 error_class="UNSUPPORTED_DATA_TYPE",
                 message_parameters={"data_type": type(df)},
             )
-        elif not isinstance(expected, DataFrame) and not isinstance(expected, ConnectDataFrame):
+        elif (
+            not isinstance(expected, DataFrame)
+            and not isinstance(expected, ConnectDataFrame)
+            and not isinstance(expected, List)
+        ):
             raise PySparkAssertionError(
                 error_class="UNSUPPORTED_DATA_TYPE",
                 message_parameters={"data_type": type(expected)},
@@ -303,7 +309,7 @@ def assertDataFrameEqual(df: DataFrame, expected: DataFrame, check_row_order: bo
                 error_class="UNSUPPORTED_DATA_TYPE",
                 message_parameters={"data_type": type(df)},
             )
-        elif not isinstance(expected, DataFrame):
-        elif not isinstance(expected, DataFrame) and not isinstance(expected, List):
             raise PySparkAssertionError(
                 error_class="UNSUPPORTED_DATA_TYPE",
                 message_parameters={"data_type": type(expected)},
@@ -379,22 +385,22 @@ def assertDataFrameEqual(df: DataFrame, expected: DataFrame, check_row_order: bo
                 message_parameters={"error_msg": error_msg},
             )
 
-    if not check_row_order:
-        try:
-            # rename duplicate columns for sorting
-            renamed_df = df.toDF(*[f"_{i}" for i in range(len(df.columns))])
-            renamed_expected = expected.toDF(*[f"_{i}" for i in range(len(expected.columns))])
+    # convert both df and expected to lists of Rows
+    if not isinstance(expected, List):
+        # only compare schema if expected is not a List
+        assert_schema_equal(df.schema, expected.schema)
+        expected_list = expected.collect()
+    else:
+        expected_list = expected
 
-            df = renamed_df.sort(renamed_df.columns).toDF(*df.columns)
-            expected = renamed_expected.sort(renamed_expected.columns).toDF(*expected.columns)
-        except Exception:
-            raise PySparkAssertionError(
-                error_class="UNSUPPORTED_DATA_TYPE_FOR_IGNORE_ROW_ORDER",
-                message_parameters={},
-            )
+    df_list = df.collect()
+
+    if not checkRowOrder:
+        # sort rows by their string representation for an order-insensitive comparison
+        df_list = sorted(df_list, key=lambda x: str(x))
+        expected_list = sorted(expected_list, key=lambda x: str(x))
 
-    assert_schema_equal(df.schema, expected.schema)
-    assert_rows_equal(df.collect(), expected.collect())
+    assert_rows_equal(df_list, expected_list)
 
 
 def _test() -> None:
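
The order-insensitive path above no longer sorts the DataFrames themselves: both sides are collected to lists of Rows, and each list is sorted by the string form of its rows before comparison. A standalone sketch of the technique (the helper name `_sorted_rows` is illustrative, not part of the patch):

    from pyspark.sql import Row

    def _sorted_rows(rows):
        # str() yields a total order even for column types with no defined
        # sort order (e.g. maps), which is why the patch can drop the old
        # UNSUPPORTED_DATA_TYPE_FOR_IGNORE_ROW_ORDER error class.
        return sorted(rows, key=lambda r: str(r))

    actual = [Row(id=2, amount=1000.0), Row(id=1, amount=3000.0)]
    expected = [Row(id=1, amount=3000.0), Row(id=2, amount=1000.0)]

    # Equal row sets compare equal regardless of their original order.
    assert _sorted_rows(actual) == _sorted_rows(expected)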

