This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4af4ddea116 [SPARK-45552][PS] Introduce flexible parameters to `assertDataFrameEqual`
4af4ddea116 is described below

commit 4af4ddea116d26086550596693ce09674e75bfa3
Author: Haejoon Lee <haejoon....@databricks.com>
AuthorDate: Mon Oct 30 11:07:01 2023 +0900

    [SPARK-45552][PS] Introduce flexible parameters to `assertDataFrameEqual`
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add six new parameters to `assertDataFrameEqual` (`ignoreNullable`, `ignoreColumnOrder`, `ignoreColumnName`, `ignoreColumnType`, `maxErrors`, and `showOnlyDiff`) to give users more flexibility in DataFrame testing.
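    
    The column-related flags work by normalizing both DataFrames before any rows are compared. For `ignoreColumnOrder`, the diff sorts columns by name (`actual.select(*sorted(actual.columns))`). For `ignoreColumnName`, it renames them to positional labels `"0"`, `"1"`, ... via `toDF`. The sketch below shows that normalization in pure Python. It assumes a table is modeled as a list of column names plus a list of row tuples. `normalize` is a hypothetical helper for illustration, not part of PySpark.

```python
from typing import List, Tuple

# Hypothetical model of a DataFrame: (column names, rows as tuples).
Table = Tuple[List[str], List[tuple]]


def normalize(
    table: Table, ignore_column_order: bool = False, ignore_column_name: bool = False
) -> Table:
    """Sketch of the column normalization applied before row comparison."""
    columns, rows = table
    if ignore_column_order:
        # Mirrors df.select(*sorted(df.columns)): reorder columns by name.
        order = sorted(range(len(columns)), key=lambda i: columns[i])
        columns = [columns[i] for i in order]
        rows = [tuple(row[i] for i in order) for row in rows]
    if ignore_column_name:
        # Mirrors df.toDF("0", "1", ...): keep positions, replace names.
        columns = [str(i) for i in range(len(columns))]
    return columns, rows


# Same data, different column order (as in the ignoreColumnOrder doctest).
t1 = (["amount", "id"], [(1000, "1"), (5000, "2")])
t2 = (["id", "amount"], [("1", 1000), ("2", 5000)])
print(normalize(t1) == normalize(t2))  # False
print(normalize(t1, ignore_column_order=True) == normalize(t2, ignore_column_order=True))  # True
```

    As in the diff, reordering is applied before renaming, so only the resulting positions matter once `ignore_column_name` is set.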
    
    ### Why are the changes needed?
    
    To let `assertDataFrameEqual` handle common DataFrame comparison scenarios without manual adjustments or workarounds. These include differing nullability, column order, column names, or column types, and long mismatch reports.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. `assertDataFrameEqual` now accepts six new optional parameters:
    
    Parameter | Type | Comment
    -- | -- | --
    ignoreNullable | Boolean [optional] | Specifies whether to ignore a column’s nullable property when checking for schema equality. When set to True (default), the nullable property of the columns being compared is not taken into account, and columns are considered equal even if they have different nullable settings. When set to False, columns are considered equal only if they have the same nullable setting.
    ignoreColumnOrder | Boolean [optional] | Specifies whether to compare columns in the order they appear in the DataFrames or by column name. When set to False (default), columns are compared in the order they appear in the DataFrames. When set to True, a column in the expected DataFrame is compared to the column with the same name in the actual DataFrame. ignoreColumnOrder cannot be set to True if ignoreColumnName is also set to True.
    ignoreColumnName | Boolean [optional] | Specifies whether to fail the initial schema equality check if the column names in the two DataFrames are different. When set to False (default), column names are checked and the function fails if they are different. When set to True, the schema check passes even if column names are different, and column data types are compared for columns in the order they appear in the DataFrames. ignoreColumnName cannot be set to [...]
    ignoreColumnType | Boolean [optional] | Specifies whether to ignore the data type of the columns when comparing. When set to False (default), column data types are checked and the function fails if they are different. When set to True, the schema equality check succeeds even if column data types are different, and the function attempts to compare rows.
    maxErrors | Integer [optional] | The maximum number of row comparison failures to collect before stopping. Once this many row comparisons have failed, the comparison stops and the mismatches found so far are reported, regardless of how many rows remain. Set to None by default, which means all rows are compared regardless of the number of failures.
    showOnlyDiff | Boolean [optional] | If set to True, the error message only includes rows that are different. If set to False (default), the error message includes all rows (when at least one row is different).
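    
    The `maxErrors` and `showOnlyDiff` rows correspond to the loop in `assert_rows_equal`. A mismatching row is always recorded. A matching row is recorded only when `showOnlyDiff` is False. The loop breaks once `maxErrors` mismatches have been seen.

    The sketch below is simplified and uses only the standard library. `build_report` is a hypothetical name. Plain `!=` stands in for the tolerance-aware `compare_rows`, and the `"! "` marker is added directly rather than through `_context_diff`.

```python
from itertools import zip_longest
from typing import List, Optional, Sequence, Tuple


def build_report(
    actual: Sequence,
    expected: Sequence,
    max_errors: Optional[int] = None,
    show_only_diff: bool = False,
) -> Tuple[int, List[str], List[str]]:
    """Collect report lines the way the patched assert_rows_equal does."""
    diff_count = 0
    actual_lines: List[str] = []
    expected_lines: List[str] = []
    for a, e in zip_longest(actual, expected):
        if a != e:
            # Mismatches are always recorded and flagged with "! ".
            diff_count += 1
            actual_lines.append(f"! {a}")
            expected_lines.append(f"! {e}")
            if max_errors is not None and diff_count >= max_errors:
                break  # stop comparing once the failure budget is spent
        elif not show_only_diff:
            # Matching rows are kept only when show_only_diff is False.
            actual_lines.append(f"  {a}")
            expected_lines.append(f"  {e}")
    return diff_count, actual_lines, expected_lines


actual = [(1, "A"), (2, "B"), (3, "C")]
expected = [(1, "A"), (2, "X"), (3, "Y")]
print(build_report(actual, expected, max_errors=1))
# (1, ["  (1, 'A')", "! (2, 'B')"], ["  (1, 'A')", "! (2, 'X')"])
```

    With the three rows from the `maxErrors=1` doctest, the sketch records the matching first row and the first mismatch, then stops. This matches the shape of the doctest's expected report.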
    
    ### How was this patch tested?
    
    Added unit tests in `python/pyspark/sql/tests/test_utils.py` and doctest usage examples for each parameter.
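    
    The `ignoreNullable=False` doctest expects a `DIFFERENT_SCHEMA` message built with `difflib.ndiff`. Its `?` guide lines place carets under the characters that differ. The stdlib-only sketch below reproduces that diff on the two schema strings from the doctest. The strings are written out by hand rather than produced by Spark, and the sketch only illustrates the message format.

```python
import difflib

# Schema strings copied from the ignoreNullable doctest; only the nullable
# flag of the "id" field differs.
actual_schema = (
    "StructType([StructField('amount', LongType(), True), "
    "StructField('id', StringType(), True)])"
)
expected_schema = (
    "StructType([StructField('amount', LongType(), True), "
    "StructField('id', StringType(), False)])"
)

diff = list(difflib.ndiff(actual_schema.splitlines(), expected_schema.splitlines()))
# ndiff's "?" guide lines already end in "\n", so strip them before joining.
print("\n".join(line.rstrip("\n") for line in diff))
```

    The `-` and `+` lines show the two schemas, and the `?` lines mark where `True` and `False` differ. Because the doctest uses `+IGNORE_EXCEPTION_DETAIL`, the exact caret spacing is not checked.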
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #43433 from itholic/SPARK-45552.
    
    Authored-by: Haejoon Lee <haejoon....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/tests/test_utils.py |  68 +++++++++++
 python/pyspark/testing/utils.py        | 215 +++++++++++++++++++++++++++++++--
 2 files changed, 274 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index a2cad4e83bd..421043a41bb 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -1238,6 +1238,9 @@ class UtilsTestsMixin:
 
         assertDataFrameEqual(df1, df2)
 
+        with self.assertRaises(PySparkAssertionError):
+            assertDataFrameEqual(df1, df2, ignoreNullable=False)
+
     def test_schema_ignore_nullable_array_equal(self):
         s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
         s2 = StructType([StructField("names", ArrayType(DoubleType(), False), False)])
@@ -1611,6 +1614,71 @@ class UtilsTestsMixin:
             message_parameters={"error_msg": error_msg},
         )
 
+    def test_dataframe_ignore_column_order(self):
+        df1 = self.spark.createDataFrame([Row(A=1, B=2), Row(A=3, B=4)])
+        df2 = self.spark.createDataFrame([Row(B=2, A=1), Row(B=4, A=3)])
+
+        with self.assertRaises(PySparkAssertionError):
+            assertDataFrameEqual(df1, df2, ignoreColumnOrder=False)
+
+        assertDataFrameEqual(df1, df2, ignoreColumnOrder=True)
+
+    def test_dataframe_ignore_column_name(self):
+        df1 = self.spark.createDataFrame([(1, 2), (3, 4)], ["A", "B"])
+        df2 = self.spark.createDataFrame([(1, 2), (3, 4)], ["X", "Y"])
+
+        with self.assertRaises(PySparkAssertionError):
+            assertDataFrameEqual(df1, df2, ignoreColumnName=False)
+
+        assertDataFrameEqual(df1, df2, ignoreColumnName=True)
+
+    def test_dataframe_ignore_column_type(self):
+        df1 = self.spark.createDataFrame([(1, "2"), (3, "4")], ["A", "B"])
+        df2 = self.spark.createDataFrame([(1, 2), (3, 4)], ["A", "B"])
+
+        with self.assertRaises(PySparkAssertionError):
+            assertDataFrameEqual(df1, df2, ignoreColumnType=False)
+
+        assertDataFrameEqual(df1, df2, ignoreColumnType=True)
+
+    def test_dataframe_max_errors(self):
+        df1 = self.spark.createDataFrame([(1, "a"), (2, "b"), (3, "c"), (4, "d")], ["id", "value"])
+        df2 = self.spark.createDataFrame([(1, "a"), (2, "z"), (3, "x"), (4, "y")], ["id", "value"])
+
+        # We expect differences in rows 2, 3, and 4.
+        # Setting maxErrors to 2 will limit the reported errors.
+        maxErrors = 2
+        with self.assertRaises(PySparkAssertionError) as context:
+            assertDataFrameEqual(df1, df2, maxErrors=maxErrors)
+
+        # Check if the error message contains information about 2 mismatches only.
+        error_message = str(context.exception)
+        self.assertTrue("! Row" in error_message and error_message.count("! Row") == maxErrors * 2)
+
+    def test_dataframe_show_only_diff(self):
+        df1 = self.spark.createDataFrame(
+            [(1, "apple", "red"), (2, "banana", "yellow"), (3, "cherry", "red")],
+            ["id", "fruit", "color"],
+        )
+        df2 = self.spark.createDataFrame(
+            [(1, "apple", "green"), (2, "banana", "yellow"), (3, "cherry", "blue")],
+            ["id", "fruit", "color"],
+        )
+
+        with self.assertRaises(PySparkAssertionError) as context:
+            assertDataFrameEqual(df1, df2, showOnlyDiff=False)
+
+        error_message = str(context.exception)
+
+        self.assertTrue("apple" in error_message and "banana" in error_message)
+
+        with self.assertRaises(PySparkAssertionError) as context:
+            assertDataFrameEqual(df1, df2, showOnlyDiff=True)
+
+        error_message = str(context.exception)
+
+        self.assertTrue("apple" in error_message and "banana" not in error_message)
+
 
 class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
     def test_capture_analysis_exception(self):
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 5ee27862923..282f4cc1cf5 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -22,6 +22,7 @@ import sys
 import unittest
 import difflib
 import functools
+import math
 from decimal import Decimal
 from time import time, sleep
 from typing import (
@@ -57,6 +58,7 @@ from pyspark.find_spark_home import _find_spark_home
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql import Row
 from pyspark.sql.types import StructType, StructField
+from pyspark.sql.functions import col, when
 
 
 __all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
@@ -396,6 +398,12 @@ def assertDataFrameEqual(
     checkRowOrder: bool = False,
     rtol: float = 1e-5,
     atol: float = 1e-8,
+    ignoreNullable: bool = True,
+    ignoreColumnOrder: bool = False,
+    ignoreColumnName: bool = False,
+    ignoreColumnType: bool = False,
+    maxErrors: Optional[int] = None,
+    showOnlyDiff: bool = False,
 ):
     r"""
     A util function to assert equality between `actual` and `expected`
@@ -424,6 +432,55 @@ def assertDataFrameEqual(
     atol : float, optional
         The absolute tolerance, used in asserting approximate equality for float values in actual
         and expected. Set to 1e-8 by default. (See Notes)
+    ignoreNullable : bool, default True
+        Specifies whether a column’s nullable property is included when checking for
+        schema equality.
+        When set to `True` (default), the nullable property of the columns being compared
+        is not taken into account and the columns will be considered equal even if they have
+        different nullable settings.
+        When set to `False`, columns are considered equal only if they have the same nullable
+        setting.
+
+        .. versionadded:: 4.0.0
+    ignoreColumnOrder : bool, default False
+        Specifies whether to compare columns in the order they appear in the DataFrame or by
+        column name.
+        If set to `False` (default), columns are compared in the order they appear in the
+        DataFrames.
+        When set to `True`, a column in the expected DataFrame is compared to the column with the
+        same name in the actual DataFrame.
+
+        .. versionadded:: 4.0.0
+    ignoreColumnName : bool, default False
+        Specifies whether to fail the initial schema equality check if the column names in the two
+        DataFrames are different.
+        When set to `False` (default), column names are checked and the function fails if they are
+        different.
+        When set to `True`, the function will succeed even if column names are different.
+        Column data types are compared for columns in the order they appear in the DataFrames.
+
+        .. versionadded:: 4.0.0
+    ignoreColumnType : bool, default False
+        Specifies whether to ignore the data type of the columns when comparing.
+        When set to `False` (default), column data types are checked and the function fails if they
+        are different.
+        When set to `True`, the schema equality check will succeed even if column data types are
+        different and the function will attempt to compare rows.
+
+        .. versionadded:: 4.0.0
+    maxErrors : bool, optional
+        The maximum number of row comparison failures to encounter before returning.
+        When this number of row comparisons have failed, the function returns independent of how
+        many rows have been compared.
+        Set to None by default which means compare all rows independent of number of failures.
+
+        .. versionadded:: 4.0.0
+    showOnlyDiff : bool, default False
+        If set to `True`, the error message will only include rows that are different.
+        If set to `False` (default), the error message will include all rows
+        (when there is at least one row that is different).
+
+        .. versionadded:: 4.0.0
 
     Notes
     -----
@@ -440,6 +497,9 @@ def assertDataFrameEqual(
 
     ``absolute(a - b) <= (atol + rtol * absolute(b))``.
 
+    `ignoreColumnOrder` cannot be set to `True` if `ignoreColumnNames` is also set to `True`.
+    `ignoreColumnNames` cannot be set to `True` if `ignoreColumnOrder` is also set to `True`.
+
     Examples
     --------
     >>> df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
@@ -469,12 +529,101 @@ def assertDataFrameEqual(
     PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.66667 % )
     *** actual ***
     ! Row(id='1', amount=1000.0)
-    Row(id='2', amount=3000.0)
+      Row(id='2', amount=3000.0)
     ! Row(id='3', amount=2000.0)
     *** expected ***
     ! Row(id='1', amount=1001.0)
-    Row(id='2', amount=3000.0)
+      Row(id='2', amount=3000.0)
     ! Row(id='3', amount=2003.0)
+
+    Example for ignoreNullable
+
+    >>> from pyspark.sql.types import StructType, StructField, StringType, LongType
+    >>> df1_nullable = spark.createDataFrame(
+    ...     data=[(1000, "1"), (5000, "2")],
+    ...     schema=StructType(
+    ...         [StructField("amount", LongType(), True), StructField("id", StringType(), True)]
+    ...     )
+    ... )
+    >>> df2_nullable = spark.createDataFrame(
+    ...     data=[(1000, "1"), (5000, "2")],
+    ...     schema=StructType(
+    ...         [StructField("amount", LongType(), True), StructField("id", StringType(), False)]
+    ...     )
+    ... )
+    >>> assertDataFrameEqual(df1_nullable, df2_nullable, ignoreNullable=True)  # pass
+    >>> assertDataFrameEqual(
+    ...     df1_nullable, df2_nullable, ignoreNullable=False
+    ... )  # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+    ...
+    PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
+    --- actual
+    +++ expected
+    - StructType([StructField('amount', LongType(), True), StructField('id', StringType(), True)])
+    ?                                                                                      ^^^
+    + StructType([StructField('amount', LongType(), True), StructField('id', StringType(), False)])
+    ?                                                                                      ^^^^
+
+    Example for ignoreColumnOrder
+
+    >>> df1_col_order = spark.createDataFrame(
+    ...     data=[(1000, "1"), (5000, "2")], schema=["amount", "id"]
+    ... )
+    >>> df2_col_order = spark.createDataFrame(
+    ...     data=[("1", 1000), ("2", 5000)], schema=["id", "amount"]
+    ... )
+    >>> assertDataFrameEqual(df1_col_order, df2_col_order, ignoreColumnOrder=True)
+
+    Example for ignoreColumnName
+
+    >>> df1_col_names = spark.createDataFrame(
+    ...     data=[(1000, "1"), (5000, "2")], schema=["amount", "identity"]
+    ... )
+    >>> df2_col_names = spark.createDataFrame(
+    ...     data=[(1000, "1"), (5000, "2")], schema=["amount", "id"]
+    ... )
+    >>> assertDataFrameEqual(df1_col_names, df2_col_names, ignoreColumnName=True)
+
+    Example for ignoreColumnType
+
+    >>> df1_col_types = spark.createDataFrame(
+    ...     data=[(1000, "1"), (5000, "2")], schema=["amount", "id"]
+    ... )
+    >>> df2_col_types = spark.createDataFrame(
+    ...     data=[(1000.0, "1"), (5000.0, "2")], schema=["amount", "id"]
+    ... )
+    >>> assertDataFrameEqual(df1_col_types, df2_col_types, ignoreColumnType=True)
+
+    Example for maxErrors (will only report the first mismatching row)
+
+    >>> df1 = spark.createDataFrame([(1, "A"), (2, "B"), (3, "C")])
+    >>> df2 = spark.createDataFrame([(1, "A"), (2, "X"), (3, "Y")])
+    >>> assertDataFrameEqual(df1, df2, maxErrors=1)  # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+    ...
+    PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 33.33333 % )
+    *** actual ***
+      Row(_1=1, _2='A')
+    ! Row(_1=2, _2='B')
+    *** expected ***
+      Row(_1=1, _2='A')
+    ! Row(_1=2, _2='X')
+
+    Example for showOnlyDiff (will only report the mismatching rows)
+
+    >>> df1 = spark.createDataFrame([(1, "A"), (2, "B"), (3, "C")])
+    >>> df2 = spark.createDataFrame([(1, "A"), (2, "X"), (3, "Y")])
+    >>> assertDataFrameEqual(df1, df2, showOnlyDiff=True)  # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+    ...
+    PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.66667 % )
+    *** actual ***
+    ! Row(_1=2, _2='B')
+    ! Row(_1=3, _2='C')
+    *** expected ***
+    ! Row(_1=2, _2='X')
+    ! Row(_1=3, _2='Y')
     """
     if actual is None and expected is None:
         return True
@@ -546,6 +695,37 @@ def assertDataFrameEqual(
                 },
             )
 
+    if ignoreColumnOrder:
+        actual = actual.select(*sorted(actual.columns))
+        expected = expected.select(*sorted(expected.columns))
+
+    def rename_dataframe_columns(df: DataFrame) -> DataFrame:
+        """Rename DataFrame columns to sequential numbers for comparison"""
+        renamed_columns = [str(i) for i in range(len(df.columns))]
+        return df.toDF(*renamed_columns)
+
+    if ignoreColumnName:
+        actual = rename_dataframe_columns(actual)
+        expected = rename_dataframe_columns(expected)
+
+    def cast_columns_to_string(df: DataFrame) -> DataFrame:
+        """Cast all DataFrame columns to string for comparison"""
+        for col_name in df.columns:
+            # Add logic to remove trailing .0 for float columns that are whole numbers
+            df = df.withColumn(
+                col_name,
+                when(
+                    (col(col_name).cast("float").isNotNull())
+                    & (col(col_name).cast("float") == col(col_name).cast("int")),
+                    col(col_name).cast("int").cast("string"),
+                ).otherwise(col(col_name).cast("string")),
+            )
+        return df
+
+    if ignoreColumnType:
+        actual = cast_columns_to_string(actual)
+        expected = cast_columns_to_string(expected)
+
     def compare_rows(r1: Row, r2: Row):
         def compare_vals(val1, val2):
             if isinstance(val1, list) and isinstance(val2, list):
@@ -578,7 +758,9 @@ def assertDataFrameEqual(
 
         return compare_vals(r1, r2)
 
-    def assert_rows_equal(rows1: List[Row], rows2: List[Row]):
+    def assert_rows_equal(
+        rows1: List[Row], rows2: List[Row], maxErrors: int = None, showOnlyDiff: bool = False
+    ):
         zipped = list(zip_longest(rows1, rows2))
         diff_rows_cnt = 0
         diff_rows = False
@@ -588,11 +770,16 @@ def assertDataFrameEqual(
 
         # count different rows
         for r1, r2 in zipped:
-            rows_str1 += str(r1) + "\n"
-            rows_str2 += str(r2) + "\n"
             if not compare_rows(r1, r2):
                 diff_rows_cnt += 1
                 diff_rows = True
+                rows_str1 += str(r1) + "\n"
+                rows_str2 += str(r2) + "\n"
+                if maxErrors is not None and diff_rows_cnt >= maxErrors:
+                    break
+            elif not showOnlyDiff:
+                rows_str1 += str(r1) + "\n"
+                rows_str2 += str(r2) + "\n"
 
         generated_diff = _context_diff(
             actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=len(zipped)
@@ -608,10 +795,20 @@ def assertDataFrameEqual(
                 message_parameters={"error_msg": error_msg},
             )
 
-    # convert actual and expected to list
+    # only compare schema if expected is not a List
     if not isinstance(actual, list) and not isinstance(expected, list):
-        # only compare schema if expected is not a List
-        assertSchemaEqual(actual.schema, expected.schema)
+        if ignoreNullable:
+            assertSchemaEqual(actual.schema, expected.schema)
+        elif actual.schema != expected.schema:
+            generated_diff = difflib.ndiff(
+                str(actual.schema).splitlines(), str(expected.schema).splitlines()
+            )
+            error_msg = "\n".join(generated_diff)
+
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_SCHEMA",
+                message_parameters={"error_msg": error_msg},
+            )
 
     if not isinstance(actual, list):
         actual_list = actual.collect()
@@ -628,7 +825,7 @@ def assertDataFrameEqual(
         actual_list = sorted(actual_list, key=lambda x: str(x))
         expected_list = sorted(expected_list, key=lambda x: str(x))
 
-    assert_rows_equal(actual_list, expected_list)
+    assert_rows_equal(actual_list, expected_list, maxErrors=maxErrors, showOnlyDiff=showOnlyDiff)
 
 
 def _test() -> None:
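
The `cast_columns_to_string` helper in the diff implements `ignoreColumnType` by casting every column to string. It drops a trailing `.0` when a value casts to a whole-number float. The sketch below is a rough pure-Python approximation of that per-value rule, and `normalize_value` is a hypothetical name. Python's `float()` and `int()` stand in for Spark's `cast("float")` and `cast("int")`, so Spark-specific behavior such as 32-bit float precision, string parsing rules, and ANSI mode is not modeled.

```python
from typing import Any, Optional


def _as_float(value: Any) -> Optional[float]:
    """Return value as a float, or None when it cannot be parsed."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def normalize_value(value: Any) -> Optional[str]:
    """Approximate the when/otherwise expression in cast_columns_to_string."""
    if value is None:
        return None  # NULL stays NULL in both branches of the Spark expression.
    as_float = _as_float(value)
    if as_float is not None and as_float.is_integer():
        # Whole-number values are rendered without a trailing ".0".
        return str(int(as_float))
    return str(value)


# Rows from the ignoreColumnType doctest: LongType vs DoubleType "amount".
row_long = (1000, "1")
row_double = (1000.0, "1")
print([normalize_value(v) for v in row_long])    # ['1000', '1']
print([normalize_value(v) for v in row_double])  # ['1000', '1']
```

After this normalization, the long and double versions of the doctest rows compare equal as strings. Non-numeric strings such as `"apple"` pass through unchanged.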

