This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5a49af205411 [SPARK-45555][PYTHON] Includes a debuggable object for failed assertion
5a49af205411 is described below

commit 5a49af205411feaf0f5aee07f5d6d122e10bfe1f
Author: Haejoon Lee <haejoon....@databricks.com>
AuthorDate: Tue Nov 7 16:30:49 2023 -0800

    [SPARK-45555][PYTHON] Includes a debuggable object for failed assertion
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to enhance the `assertDataFrameEqual` function to support an optional `includeDiffRows` parameter. When set to `True`, this parameter attaches the rows from both DataFrames that are not equal to the raised `PySparkAssertionError`.
    
    ### Why are the changes needed?
    
    This enhancement provides users with an easier debugging experience by directly pointing out the rows that do not match, eliminating the need for manual comparison of large DataFrames.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. An optional parameter `includeDiffRows` has been introduced in the `assertDataFrameEqual` function. When set to `True`, the unequal rows are made available on the raised error for further analysis. For example:
    
    ```python
    df1 = spark.createDataFrame(
        data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"])
    df2 = spark.createDataFrame(
        data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], schema=["id", "amount"])
    
    try:
        assertDataFrameEqual(df1, df2, includeDiffRows=True)
    except PySparkAssertionError as e:
        spark.createDataFrame(e.data).show()
    ```
    
    The above code will produce the following DataFrame:
    ```
    +-----------+-----------+
    |         _1|         _2|
    +-----------+-----------+
    |{1, 1000.0}|{1, 1001.0}|
    |{3, 2000.0}|{3, 2003.0}|
    +-----------+-----------+
    ```
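
    Since `e.data` holds the `(actual, expected)` pairs of rows that failed the comparison, the attached data can also be inspected directly without building a new DataFrame; a minimal sketch, reusing `df1` and `df2` from above:

    ```python
    try:
        assertDataFrameEqual(df1, df2, includeDiffRows=True)
    except PySparkAssertionError as e:
        # e.data is an iterable of (actual_row, expected_row) tuples
        for actual, expected in e.data:
            print(f"actual: {actual}, expected: {expected}")
    ```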
    
    ### How was this patch tested?
    
    Added usage example into doctest.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43444 from itholic/SPARK-45555.
    
    Authored-by: Haejoon Lee <haejoon....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/errors/exceptions/base.py | 15 ++++++++++++-
 python/pyspark/sql/tests/test_utils.py   | 22 +++++++++++++++++++
 python/pyspark/testing/utils.py          | 37 +++++++++++++++++++++++++++-----
 3 files changed, 68 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py
index 518a2d99ce88..5ab73b63d362 100644
--- a/python/pyspark/errors/exceptions/base.py
+++ b/python/pyspark/errors/exceptions/base.py
@@ -15,11 +15,14 @@
 # limitations under the License.
 #
 
-from typing import Dict, Optional, cast
+from typing import Dict, Optional, cast, Iterable, TYPE_CHECKING
 
 from pyspark.errors.utils import ErrorClassesReader
 from pickle import PicklingError
 
+if TYPE_CHECKING:
+    from pyspark.sql.types import Row
+
 
 class PySparkException(Exception):
     """
@@ -222,6 +225,16 @@ class PySparkAssertionError(PySparkException, AssertionError):
     Wrapper class for AssertionError to support error classes.
     """
 
+    def __init__(
+        self,
+        message: Optional[str] = None,
+        error_class: Optional[str] = None,
+        message_parameters: Optional[Dict[str, str]] = None,
+        data: Optional[Iterable["Row"]] = None,
+    ):
+        super().__init__(message, error_class, message_parameters)
+        self.data = data
+
 
 class PySparkNotImplementedError(PySparkException, NotImplementedError):
     """
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index 421043a41bb4..ebdab31ec207 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -1614,6 +1614,28 @@ class UtilsTestsMixin:
             message_parameters={"error_msg": error_msg},
         )
 
+    def test_dataframe_include_diff_rows(self):
+        df1 = self.spark.createDataFrame(
+            [("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], ["id", "amount"]
+        )
+        df2 = self.spark.createDataFrame(
+            [("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], ["id", "amount"]
+        )
+
+        with self.assertRaises(PySparkAssertionError) as context:
+            assertDataFrameEqual(df1, df2, includeDiffRows=True)
+
+        # Extracting the differing rows data from the exception
+        error_data = context.exception.data
+
+        # Expected differences
+        expected_diff = [
+            (Row(id="1", amount=1000.0), Row(id="1", amount=1001.0)),
+            (Row(id="3", amount=2000.0), Row(id="3", amount=2003.0)),
+        ]
+
+        self.assertEqual(error_data, expected_diff)
+
     def test_dataframe_ignore_column_order(self):
         df1 = self.spark.createDataFrame([Row(A=1, B=2), Row(A=3, B=4)])
         df2 = self.spark.createDataFrame([Row(B=2, A=1), Row(B=4, A=3)])
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 8b2332208c19..5d284ffc7956 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -484,6 +484,7 @@ def assertDataFrameEqual(
     ignoreColumnType: bool = False,
     maxErrors: Optional[int] = None,
     showOnlyDiff: bool = False,
+    includeDiffRows: bool = False,
 ):
     r"""
     A util function to assert equality between `actual` and `expected`
@@ -560,6 +561,11 @@ def assertDataFrameEqual(
         If set to `False` (default), the error message will include all rows
         (when there is at least one row that is different).
 
+        .. versionadded:: 4.0.0
+    includeDiffRows: bool, False
+        If set to `True`, the unequal rows are included in PySparkAssertionError for further
+        debugging. If set to `False` (default), the unequal rows are not returned as a data set.
+
         .. versionadded:: 4.0.0
 
     Notes
@@ -704,6 +710,24 @@ def assertDataFrameEqual(
     *** expected ***
     ! Row(_1=2, _2='X')
     ! Row(_1=3, _2='Y')
+
+    The `includeDiffRows` parameter can be used to include the rows that did not match
+    in the PySparkAssertionError. This can be useful for debugging or further analysis.
+
+    >>> df1 = spark.createDataFrame(
+    ...     data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"])
+    >>> df2 = spark.createDataFrame(
+    ...     data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], schema=["id", "amount"])
+    >>> try:
+    ...     assertDataFrameEqual(df1, df2, includeDiffRows=True)
+    ... except PySparkAssertionError as e:
+    ...     spark.createDataFrame(e.data).show()  # doctest: +NORMALIZE_WHITESPACE
+    +-----------+-----------+
+    |         _1|         _2|
+    +-----------+-----------+
+    |{1, 1000.0}|{1, 1001.0}|
+    |{3, 2000.0}|{3, 2003.0}|
+    +-----------+-----------+
     """
     if actual is None and expected is None:
         return True
@@ -843,7 +867,8 @@ def assertDataFrameEqual(
     ):
         zipped = list(zip_longest(rows1, rows2))
         diff_rows_cnt = 0
-        diff_rows = False
+        diff_rows = []
+        has_diff_rows = False
 
         rows_str1 = ""
         rows_str2 = ""
@@ -852,7 +877,9 @@ def assertDataFrameEqual(
         for r1, r2 in zipped:
             if not compare_rows(r1, r2):
                 diff_rows_cnt += 1
-                diff_rows = True
+                has_diff_rows = True
+                if includeDiffRows:
+                    diff_rows.append((r1, r2))
                 rows_str1 += str(r1) + "\n"
                 rows_str2 += str(r2) + "\n"
                 if maxErrors is not None and diff_rows_cnt >= maxErrors:
@@ -865,14 +892,14 @@ def assertDataFrameEqual(
             actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=len(zipped)
         )
 
-        if diff_rows:
+        if has_diff_rows:
             error_msg = "Results do not match: "
             percent_diff = (diff_rows_cnt / len(zipped)) * 100
             error_msg += "( %.5f %% )" % percent_diff
             error_msg += "\n" + "\n".join(generated_diff)
+            data = diff_rows if includeDiffRows else None
             raise PySparkAssertionError(
-                error_class="DIFFERENT_ROWS",
-                message_parameters={"error_msg": error_msg},
+                error_class="DIFFERENT_ROWS", message_parameters={"error_msg": error_msg}, data=data
             )
 
     # only compare schema if expected is not a List


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
