This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f66f091ad5 [SPARK-44218][PYTHON] Customize diff log in assertDataFrameEqual error message format
4f66f091ad5 is described below

commit 4f66f091ad5dbdd9177a7550a8da7e266a869147
Author: Amanda Liu <amanda....@databricks.com>
AuthorDate: Wed Aug 2 08:40:54 2023 +0900

    [SPARK-44218][PYTHON] Customize diff log in assertDataFrameEqual error message format
    
    ### What changes were proposed in this pull request?
    This PR improves the error message format for `assertDataFrameEqual` by creating a new custom diff log function, `_context_diff`, to print all Rows for `actual` and `expected` and highlight the rows that differ in color.
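    
    As a quick illustration (not part of this commit), the new diff surfaces directly in the raised `PySparkAssertionError`; the DataFrames below are hypothetical:
    
    ```python
    # Hedged sketch: the sample DataFrames are illustrative, not taken from this PR's tests.
    from pyspark.sql import SparkSession
    from pyspark.errors import PySparkAssertionError
    from pyspark.testing.utils import assertDataFrameEqual

    spark = SparkSession.builder.getOrCreate()
    df_actual = spark.createDataFrame([(1, 3000), (2, 1000)], schema=["id", "amount"])
    df_expected = spark.createDataFrame([(1, 3000), (2, 2000)], schema=["id", "amount"])

    try:
        assertDataFrameEqual(df_actual, df_expected)
    except PySparkAssertionError as e:
        # The DIFFERENT_ROWS error message now embeds the _context_diff output:
        # a "Results do not match: ( 50.00000 % )" header followed by
        # "*** actual ***" / "*** expected ***" sections, with differing rows
        # prefixed by "! " (shown in red when the terminal supports color).
        print(e)
    ```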
    
    ### Why are the changes needed?
    The change is needed to clarify the error message for unequal DataFrames.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the PR affects the error message display for users. See the new error message output examples below.
    
    ### How was this patch tested?
    Added tests to `python/pyspark/sql/tests/test_utils.py` and `python/pyspark/sql/tests/connect/test_utils.py`.
    
    Example error messages:
    <img width="894" alt="Screenshot 2023-07-31 at 7 41 55 PM" 
src="https://github.com/apache/spark/assets/68875504/04b5b985-4670-4d4b-8032-9704523a3df1";>
    
    <img width="875" alt="errormessage1" 
src="https://github.com/apache/spark/assets/68875504/d7b420c0-80d3-4853-b8e5-48ff6b459dc7";>
    
    Closes #42196 from asl3/update-difflib.
    
    Authored-by: Amanda Liu <amanda....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/errors/error_classes.py    |  24 ++-
 python/pyspark/pandas/tests/test_utils.py |   1 -
 python/pyspark/sql/tests/test_utils.py    | 269 ++++++++++++++++++++----------
 python/pyspark/testing/utils.py           |  69 +++++++-
 4 files changed, 261 insertions(+), 102 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index f1396c49af0..d6f093246da 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -167,36 +167,44 @@ ERROR_CLASSES_JSON = """
   "DIFFERENT_PANDAS_DATAFRAME" : {
     "message" : [
       "DataFrames are not almost equal:",
-      "Left: <left>",
+      "Left:",
+      "<left>",
       "<left_dtype>",
-      "Right: <right>",
+      "Right:",
+      "<right>",
       "<right_dtype>"
     ]
   },
   "DIFFERENT_PANDAS_INDEX" : {
     "message" : [
       "Indices are not almost equal:",
-      "Left: <left>",
+      "Left:",
+      "<left>",
       "<left_dtype>",
-      "Right: <right>",
+      "Right:",
+      "<right>",
       "<right_dtype>"
     ]
   },
   "DIFFERENT_PANDAS_MULTIINDEX" : {
     "message" : [
       "MultiIndices are not almost equal:",
-      "Left: <left>",
+      "Left:",
+      "<left>",
       "<left_dtype>",
-      "Right: <right>",
+      "Right:",
+      "<right>",
       "<right_dtype>"
     ]
   },
   "DIFFERENT_PANDAS_SERIES" : {
     "message" : [
       "Series are not almost equal:",
-      "Left: <left>",
+      "Left:",
+      "<left>",
       "<left_dtype>",
-      "Right: <right>",
+      "Right:",
+      "<right>",
       "<right_dtype>"
     ]
   },
diff --git a/python/pyspark/pandas/tests/test_utils.py b/python/pyspark/pandas/tests/test_utils.py
index de7b0449dae..3d658446f27 100644
--- a/python/pyspark/pandas/tests/test_utils.py
+++ b/python/pyspark/pandas/tests/test_utils.py
@@ -16,7 +16,6 @@
 #
 
 import pandas as pd
-from typing import Union
 
 from pyspark.pandas.indexes.base import Index
 from pyspark.pandas.utils import (
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index b7ab596880f..76d397e3ade 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -23,7 +23,7 @@ from pyspark.errors import (
     IllegalArgumentException,
     SparkUpgradeException,
 )
-from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual
+from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual, _context_diff
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 from pyspark.sql import Row
 import pyspark.sql.functions as F
@@ -44,6 +44,7 @@ from pyspark.sql.dataframe import DataFrame
 
 import difflib
 from typing import List, Union
+from itertools import zip_longest
 
 
 class UtilsTestsMixin:
@@ -151,17 +152,22 @@ class UtilsTestsMixin:
             ),
         )
 
-        expected_error_message = "Results do not match: "
-        percent_diff = (1 / 2) * 100
-        expected_error_message += "( %.5f %% )" % percent_diff
+        rows_str1 = ""
+        rows_str2 = ""
+
+        # count different rows
+        for r1, r2 in list(zip_longest(df1.collect(), df2.collect())):
+            rows_str1 += str(r1) + "\n"
+            rows_str2 += str(r2) + "\n"
 
-        generated_diff = difflib.ndiff(
-            str(df1.collect()[1]).splitlines(), str(df2.collect()[1]).splitlines()
+        generated_diff = _context_diff(
+            actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2
         )
-        diff_msg = "\n" + "\n".join(generated_diff) + "\n"
-        diff_msg += "********************" + "\n"
 
-        expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg
+        error_msg = "Results do not match: "
+        percent_diff = (1 / 2) * 100
+        error_msg += "( %.5f %% )" % percent_diff
+        error_msg += "\n" + "\n".join(generated_diff)
 
         with self.assertRaises(PySparkAssertionError) as pe:
             assertDataFrameEqual(df1, df2)
@@ -169,7 +175,16 @@ class UtilsTestsMixin:
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_ROWS",
-            message_parameters={"error_msg": expected_error_message},
+            message_parameters={"error_msg": error_msg},
+        )
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_ROWS",
+            message_parameters={"error_msg": error_msg},
         )
 
     def test_assert_approx_equal_arraytype_float_custom_rtol_pass(self):
@@ -266,6 +281,7 @@ class UtilsTestsMixin:
     def test_assert_notequal_arraytype(self):
         df1 = self.spark.createDataFrame(
             data=[
+                ("Amy", ["C++", "Rust"]),
                 ("John", ["Python", "Java"]),
                 ("Jane", ["Scala", "SQL", "Java"]),
             ],
@@ -278,6 +294,7 @@ class UtilsTestsMixin:
         )
         df2 = self.spark.createDataFrame(
             data=[
+                ("Amy", ["C++", "Rust"]),
                 ("John", ["Python", "Java"]),
                 ("Jane", ["Scala", "Java"]),
             ],
@@ -289,17 +306,25 @@ class UtilsTestsMixin:
             ),
         )
 
-        expected_error_message = "Results do not match: "
-        percent_diff = (1 / 2) * 100
-        expected_error_message += "( %.5f %% )" % percent_diff
+        rows_str1 = ""
+        rows_str2 = ""
+
+        sorted_list1 = sorted(df1.collect(), key=lambda x: str(x))
+        sorted_list2 = sorted(df2.collect(), key=lambda x: str(x))
+
+        # count different rows
+        for r1, r2 in list(zip_longest(sorted_list1, sorted_list2)):
+            rows_str1 += str(r1) + "\n"
+            rows_str2 += str(r2) + "\n"
 
-        generated_diff = difflib.ndiff(
-            str(df1.collect()[1]).splitlines(), str(df2.collect()[1]).splitlines()
+        generated_diff = _context_diff(
+            actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=3
         )
-        diff_msg = "\n" + "\n".join(generated_diff) + "\n"
-        diff_msg += "********************" + "\n"
 
-        expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg
+        error_msg = "Results do not match: "
+        percent_diff = (1 / 3) * 100
+        error_msg += "( %.5f %% )" % percent_diff
+        error_msg += "\n" + "\n".join(generated_diff)
 
         with self.assertRaises(PySparkAssertionError) as pe:
             assertDataFrameEqual(df1, df2)
@@ -307,16 +332,33 @@ class UtilsTestsMixin:
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_ROWS",
-            message_parameters={"error_msg": expected_error_message},
+            message_parameters={"error_msg": error_msg},
+        )
+
+        rows_str1 = ""
+        rows_str2 = ""
+
+        # count different rows
+        for r1, r2 in list(zip_longest(df1.collect(), df2.collect())):
+            rows_str1 += str(r1) + "\n"
+            rows_str2 += str(r2) + "\n"
+
+        generated_diff = _context_diff(
+            actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=3
         )
 
+        error_msg = "Results do not match: "
+        percent_diff = (1 / 3) * 100
+        error_msg += "( %.5f %% )" % percent_diff
+        error_msg += "\n" + "\n".join(generated_diff)
+
         with self.assertRaises(PySparkAssertionError) as pe:
             assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_ROWS",
-            message_parameters={"error_msg": expected_error_message},
+            message_parameters={"error_msg": error_msg},
         )
 
     def test_assert_equal_maptype(self):
@@ -588,17 +630,22 @@ class UtilsTestsMixin:
             schema=["id", "amount"],
         )
 
-        expected_error_message = "Results do not match: "
-        percent_diff = (1 / 2) * 100
-        expected_error_message += "( %.5f %% )" % percent_diff
+        rows_str1 = ""
+        rows_str2 = ""
+
+        # count different rows
+        for r1, r2 in list(zip_longest(df1.collect(), df2.collect())):
+            rows_str1 += str(r1) + "\n"
+            rows_str2 += str(r2) + "\n"
 
-        generated_diff = difflib.ndiff(
-            str(df1.collect()[1]).splitlines(), str(df2.collect()[1]).splitlines()
+        generated_diff = _context_diff(
+            actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2
         )
-        diff_msg = "\n" + "\n".join(generated_diff) + "\n"
-        diff_msg += "********************" + "\n"
 
-        expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg
+        error_msg = "Results do not match: "
+        percent_diff = (1 / 2) * 100
+        error_msg += "( %.5f %% )" % percent_diff
+        error_msg += "\n" + "\n".join(generated_diff)
 
         with self.assertRaises(PySparkAssertionError) as pe:
             assertDataFrameEqual(df1, df2)
@@ -606,7 +653,7 @@ class UtilsTestsMixin:
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_ROWS",
-            message_parameters={"error_msg": expected_error_message},
+            message_parameters={"error_msg": error_msg},
         )
 
         with self.assertRaises(PySparkAssertionError) as pe:
@@ -615,7 +662,7 @@ class UtilsTestsMixin:
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_ROWS",
-            message_parameters={"error_msg": expected_error_message},
+            message_parameters={"error_msg": error_msg},
         )
 
     def test_assert_equal_nulldf(self):
@@ -762,22 +809,22 @@ class UtilsTestsMixin:
             schema=["id", "amount"],
         )
 
-        expected_error_message = "Results do not match: "
-        percent_diff = (2 / 2) * 100
-        expected_error_message += "( %.5f %% )" % percent_diff
+        rows_str1 = ""
+        rows_str2 = ""
 
-        generated_diff = difflib.ndiff(
-            str(df1.collect()[0]).splitlines(), str(df2.collect()[0]).splitlines()
-        )
-        diff_msg = "\n" + "\n".join(generated_diff) + "\n"
-        diff_msg += "********************" + "\n"
-        generated_diff = difflib.ndiff(
-            str(df1.collect()[1]).splitlines(), str(df2.collect()[1]).splitlines()
+        # count different rows
+        for r1, r2 in list(zip_longest(df1.collect(), df2.collect())):
+            rows_str1 += str(r1) + "\n"
+            rows_str2 += str(r2) + "\n"
+
+        generated_diff = _context_diff(
+            actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2
         )
-        diff_msg += "\n" + "\n".join(generated_diff) + "\n"
-        diff_msg += "********************" + "\n"
 
-        expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg
+        error_msg = "Results do not match: "
+        percent_diff = (2 / 2) * 100
+        error_msg += "( %.5f %% )" % percent_diff
+        error_msg += "\n" + "\n".join(generated_diff)
 
         with self.assertRaises(PySparkAssertionError) as pe:
             assertDataFrameEqual(df1, df2, checkRowOrder=True)
@@ -785,7 +832,7 @@ class UtilsTestsMixin:
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_ROWS",
-            message_parameters={"error_msg": expected_error_message},
+            message_parameters={"error_msg": error_msg},
         )
 
     def test_remove_non_word_characters_long(self):
@@ -857,22 +904,22 @@ class UtilsTestsMixin:
             schema=["id", "amount"],
         )
 
-        expected_error_message = "Results do not match: "
-        percent_diff = (2 / 3) * 100
-        expected_error_message += "( %.5f %% )" % percent_diff
+        rows_str1 = ""
+        rows_str2 = ""
 
-        generated_diff = difflib.ndiff(
-            str(df1.collect()[0]).splitlines(), str(df2.collect()[0]).splitlines()
-        )
-        diff_msg = "\n" + "\n".join(generated_diff) + "\n"
-        diff_msg += "********************" + "\n"
-        generated_diff = difflib.ndiff(
-            str(df1.collect()[2]).splitlines(), str(df2.collect()[2]).splitlines()
+        # count different rows
+        for r1, r2 in list(zip_longest(df1.collect(), df2.collect())):
+            rows_str1 += str(r1) + "\n"
+            rows_str2 += str(r2) + "\n"
+
+        generated_diff = _context_diff(
+            actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=3
         )
-        diff_msg += "\n" + "\n".join(generated_diff) + "\n"
-        diff_msg += "********************" + "\n"
 
-        expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg
+        error_msg = "Results do not match: "
+        percent_diff = (2 / 3) * 100
+        error_msg += "( %.5f %% )" % percent_diff
+        error_msg += "\n" + "\n".join(generated_diff)
 
         with self.assertRaises(PySparkAssertionError) as pe:
             assertDataFrameEqual(df1, df2)
@@ -880,7 +927,7 @@ class UtilsTestsMixin:
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_ROWS",
-            message_parameters={"error_msg": expected_error_message},
+            message_parameters={"error_msg": error_msg},
         )
 
         with self.assertRaises(PySparkAssertionError) as pe:
@@ -889,7 +936,7 @@ class UtilsTestsMixin:
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_ROWS",
-            message_parameters={"error_msg": expected_error_message},
+            message_parameters={"error_msg": error_msg},
         )
 
     def test_assert_notequal_schema(self):
@@ -1209,17 +1256,22 @@ class UtilsTestsMixin:
         list_of_rows1 = [Row(1, "abc", 5000), Row(2, "def", 1000)]
         list_of_rows2 = [Row(1, "abc", 5000), Row(2, "defg", 1000)]
 
-        expected_error_message = "Results do not match: "
-        percent_diff = (1 / 2) * 100
-        expected_error_message += "( %.5f %% )" % percent_diff
+        rows_str1 = ""
+        rows_str2 = ""
 
-        generated_diff = difflib.ndiff(
-            str(list_of_rows1[1]).splitlines(), str(list_of_rows2[1]).splitlines()
+        # count different rows
+        for r1, r2 in list(zip_longest(list_of_rows1, list_of_rows2)):
+            rows_str1 += str(r1) + "\n"
+            rows_str2 += str(r2) + "\n"
+
+        generated_diff = _context_diff(
+            actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2
         )
-        diff_msg = "\n" + "\n".join(generated_diff) + "\n"
-        diff_msg += "********************" + "\n"
 
-        expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg
+        error_msg = "Results do not match: "
+        percent_diff = (1 / 2) * 100
+        error_msg += "( %.5f %% )" % percent_diff
+        error_msg += "\n" + "\n".join(generated_diff)
 
         with self.assertRaises(PySparkAssertionError) as pe:
             assertDataFrameEqual(list_of_rows1, list_of_rows2)
@@ -1227,7 +1279,7 @@ class UtilsTestsMixin:
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_ROWS",
-            message_parameters={"error_msg": expected_error_message},
+            message_parameters={"error_msg": error_msg},
         )
 
         with self.assertRaises(PySparkAssertionError) as pe:
@@ -1236,7 +1288,7 @@ class UtilsTestsMixin:
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_ROWS",
-            message_parameters={"error_msg": expected_error_message},
+            message_parameters={"error_msg": error_msg},
         )
 
     def test_list_row_unequal_schema(self):
@@ -1244,28 +1296,77 @@ class UtilsTestsMixin:
             data=[
                 (1, 3000),
                 (2, 1000),
+                (3, 10),
             ],
             schema=["id", "amount"],
         )
 
-        list_of_rows = [Row(1, "3000"), Row(2, "1000")]
+        list_of_rows = [Row(id=1, amount=300), Row(id=2, amount=100), Row(id=3, amount=10)]
 
-        expected_error_message = "Results do not match: "
-        percent_diff = (2 / 2) * 100
-        expected_error_message += "( %.5f %% )" % percent_diff
+        rows_str1 = ""
+        rows_str2 = ""
 
-        generated_diff = difflib.ndiff(
-            str(df1.collect()[0]).splitlines(), str(list_of_rows[0]).splitlines()
+        # count different rows
+        for r1, r2 in list(zip_longest(df1, list_of_rows)):
+            rows_str1 += str(r1) + "\n"
+            rows_str2 += str(r2) + "\n"
+
+        generated_diff = _context_diff(
+            actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=3
+        )
+
+        error_msg = "Results do not match: "
+        percent_diff = (2 / 3) * 100
+        error_msg += "( %.5f %% )" % percent_diff
+        error_msg += "\n" + "\n".join(generated_diff)
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, list_of_rows)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_ROWS",
+            message_parameters={"error_msg": error_msg},
         )
-        diff_msg = "\n" + "\n".join(generated_diff) + "\n"
-        diff_msg += "********************" + "\n"
-        generated_diff = difflib.ndiff(
-            str(df1.collect()[1]).splitlines(), str(list_of_rows[1]).splitlines()
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, list_of_rows, checkRowOrder=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_ROWS",
+            message_parameters={"error_msg": error_msg},
         )
-        diff_msg += "\n" + "\n".join(generated_diff) + "\n"
-        diff_msg += "********************" + "\n"
 
-        expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg
+    def test_list_row_unequal_schema(self):
+        from pyspark.sql import Row
+
+        df1 = self.spark.createDataFrame(
+            data=[
+                (1, 3000),
+                (2, 1000),
+            ],
+            schema=["id", "amount"],
+        )
+
+        list_of_rows = [Row(1, "3000"), Row(2, "1000")]
+
+        rows_str1 = ""
+        rows_str2 = ""
+
+        # count different rows
+        for r1, r2 in list(zip_longest(df1.collect(), list_of_rows)):
+            rows_str1 += str(r1) + "\n"
+            rows_str2 += str(r2) + "\n"
+
+        generated_diff = _context_diff(
+            actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2
+        )
+
+        error_msg = "Results do not match: "
+        percent_diff = (2 / 2) * 100
+        error_msg += "( %.5f %% )" % percent_diff
+        error_msg += "\n" + "\n".join(generated_diff)
 
         with self.assertRaises(PySparkAssertionError) as pe:
             assertDataFrameEqual(df1, list_of_rows)
@@ -1273,7 +1374,7 @@ class UtilsTestsMixin:
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_ROWS",
-            message_parameters={"error_msg": expected_error_message},
+            message_parameters={"error_msg": error_msg},
         )
 
         with self.assertRaises(PySparkAssertionError) as pe:
@@ -1282,7 +1383,7 @@ class UtilsTestsMixin:
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_ROWS",
-            message_parameters={"error_msg": expected_error_message},
+            message_parameters={"error_msg": error_msg},
         )
 
 
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index eb55255863d..2a23476112f 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -29,6 +29,7 @@ from typing import (
     Dict,
     List,
     Tuple,
+    Iterator,
 )
 from itertools import zip_longest
 
@@ -190,6 +191,49 @@ def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
         return jars[0]
 
 
+def _terminal_color_support():
+    try:
+        # determine if environment supports color
+        script = "$(test $(tput colors)) && $(test $(tput colors) -ge 8) && 
echo true || echo false"
+        return os.popen(script).read()
+    except Exception:
+        return False
+
+
+def _context_diff(actual: List[str], expected: List[str], n: int = 3):
+    """
+    Modified from difflib context_diff API,
+    see original code here: https://github.com/python/cpython/blob/main/Lib/difflib.py#L1180
+    """
+
+    def red(s: str) -> str:
+        red_color = "\033[31m"
+        no_color = "\033[0m"
+        return red_color + str(s) + no_color
+
+    prefix = dict(insert="+ ", delete="- ", replace="! ", equal="  ")
+    for group in difflib.SequenceMatcher(None, actual, expected).get_grouped_opcodes(n):
+        yield "*** actual ***"
+        if any(tag in {"replace", "delete"} for tag, _, _, _, _ in group):
+            for tag, i1, i2, _, _ in group:
+                for line in actual[i1:i2]:
+                    if tag != "equal" and _terminal_color_support():
+                        yield red(prefix[tag] + str(line))
+                    else:
+                        yield prefix[tag] + str(line)
+
+        yield "\n"
+
+        yield "*** expected ***"
+        if any(tag in {"replace", "insert"} for tag, _, _, _, _ in group):
+            for tag, _, _, j1, j2 in group:
+                for line in expected[j1:j2]:
+                    if tag != "equal" and _terminal_color_support():
+                        yield red(prefix[tag] + str(line))
+                    else:
+                        yield prefix[tag] + str(line)
+
+
 class PySparkErrorTestUtils:
     """
     This util provide functions to accurate and consistent error testing
@@ -303,6 +347,7 @@ def assertSchemaEqual(actual: StructType, expected: StructType):
         else:
             return False
 
+    # ignore nullable flag by default
     if not compare_schemas_ignore_nullable(actual, expected):
         generated_diff = difflib.ndiff(str(actual).splitlines(), str(expected).splitlines())
 
@@ -492,23 +537,29 @@ def assertDataFrameEqual(
 
     def assert_rows_equal(rows1: List[Row], rows2: List[Row]):
         zipped = list(zip_longest(rows1, rows2))
-        rows_equal = True
-        error_msg = "Results do not match: "
-        diff_msg = ""
         diff_rows_cnt = 0
+        diff_rows = False
 
+        rows_str1 = ""
+        rows_str2 = ""
+
+        # count different rows
         for r1, r2 in zipped:
+            rows_str1 += str(r1) + "\n"
+            rows_str2 += str(r2) + "\n"
             if not compare_rows(r1, r2):
-                rows_equal = False
                 diff_rows_cnt += 1
-                generated_diff = difflib.ndiff(str(r1).splitlines(), str(r2).splitlines())
-                diff_msg += "\n" + "\n".join(generated_diff) + "\n"
-                diff_msg += "********************" + "\n"
+                diff_rows = True
+
+        generated_diff = _context_diff(
+            actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=len(zipped)
+        )
 
-        if not rows_equal:
+        if diff_rows:
+            error_msg = "Results do not match: "
             percent_diff = (diff_rows_cnt / len(zipped)) * 100
             error_msg += "( %.5f %% )" % percent_diff
-            error_msg += "\n" + "--- actual\n+++ expected\n" + diff_msg
+            error_msg += "\n" + "\n".join(generated_diff)
             raise PySparkAssertionError(
                 error_class="DIFFERENT_ROWS",
                 message_parameters={"error_msg": error_msg},

