This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new dbee20a426e [SPARK-44218][PYTHON] Customize diff log in assertDataFrameEqual error message format dbee20a426e is described below commit dbee20a426e8a290e83131dddf23c35b5b249959 Author: Amanda Liu <amanda....@databricks.com> AuthorDate: Wed Aug 2 08:40:54 2023 +0900 [SPARK-44218][PYTHON] Customize diff log in assertDataFrameEqual error message format ### What changes were proposed in this pull request? This PR improves the error message format for `assertDataFrameEqual` by creating a new custom diff log function, `_context_diff`, to print all Rows for `actual` and `expected` and highlight their different rows in color. ### Why are the changes needed? The change is needed to clarify the error message for unequal DataFrames. ### Does this PR introduce _any_ user-facing change? Yes, the PR affects the error message display for users. See the new error message output examples below. ### How was this patch tested? Added tests to `python/pyspark/sql/tests/test_utils.py` and `python/pyspark/sql/tests/connect/test_utils.py`. Example error messages: <img width="894" alt="Screenshot 2023-07-31 at 7 41 55 PM" src="https://github.com/apache/spark/assets/68875504/04b5b985-4670-4d4b-8032-9704523a3df1"> <img width="875" alt="errormessage1" src="https://github.com/apache/spark/assets/68875504/d7b420c0-80d3-4853-b8e5-48ff6b459dc7"> Closes #42196 from asl3/update-difflib. 
Authored-by: Amanda Liu <amanda....@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> (cherry picked from commit 4f66f091ad5dbdd9177a7550a8da7e266a869147) Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/errors/error_classes.py | 24 ++- python/pyspark/pandas/tests/test_utils.py | 1 - python/pyspark/sql/tests/test_utils.py | 269 ++++++++++++++++++++---------- python/pyspark/testing/utils.py | 69 +++++++- 4 files changed, 261 insertions(+), 102 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index acecc48f0a8..554a25952b9 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -167,36 +167,44 @@ ERROR_CLASSES_JSON = """ "DIFFERENT_PANDAS_DATAFRAME" : { "message" : [ "DataFrames are not almost equal:", - "Left: <left>", + "Left:", + "<left>", "<left_dtype>", - "Right: <right>", + "Right:", + "<right>", "<right_dtype>" ] }, "DIFFERENT_PANDAS_INDEX" : { "message" : [ "Indices are not almost equal:", - "Left: <left>", + "Left:", + "<left>", "<left_dtype>", - "Right: <right>", + "Right:", + "<right>", "<right_dtype>" ] }, "DIFFERENT_PANDAS_MULTIINDEX" : { "message" : [ "MultiIndices are not almost equal:", - "Left: <left>", + "Left:", + "<left>", "<left_dtype>", - "Right: <right>", + "Right:", + "<right>", "<right_dtype>" ] }, "DIFFERENT_PANDAS_SERIES" : { "message" : [ "Series are not almost equal:", - "Left: <left>", + "Left:", + "<left>", "<left_dtype>", - "Right: <right>", + "Right:", + "<right>", "<right_dtype>" ] }, diff --git a/python/pyspark/pandas/tests/test_utils.py b/python/pyspark/pandas/tests/test_utils.py index de7b0449dae..3d658446f27 100644 --- a/python/pyspark/pandas/tests/test_utils.py +++ b/python/pyspark/pandas/tests/test_utils.py @@ -16,7 +16,6 @@ # import pandas as pd -from typing import Union from pyspark.pandas.indexes.base import Index from pyspark.pandas.utils import ( diff --git 
a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py index b7ab596880f..76d397e3ade 100644 --- a/python/pyspark/sql/tests/test_utils.py +++ b/python/pyspark/sql/tests/test_utils.py @@ -23,7 +23,7 @@ from pyspark.errors import ( IllegalArgumentException, SparkUpgradeException, ) -from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual +from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual, _context_diff from pyspark.testing.sqlutils import ReusedSQLTestCase from pyspark.sql import Row import pyspark.sql.functions as F @@ -44,6 +44,7 @@ from pyspark.sql.dataframe import DataFrame import difflib from typing import List, Union +from itertools import zip_longest class UtilsTestsMixin: @@ -151,17 +152,22 @@ class UtilsTestsMixin: ), ) - expected_error_message = "Results do not match: " - percent_diff = (1 / 2) * 100 - expected_error_message += "( %.5f %% )" % percent_diff + rows_str1 = "" + rows_str2 = "" + + # count different rows + for r1, r2 in list(zip_longest(df1.collect(), df2.collect())): + rows_str1 += str(r1) + "\n" + rows_str2 += str(r2) + "\n" - generated_diff = difflib.ndiff( - str(df1.collect()[1]).splitlines(), str(df2.collect()[1]).splitlines() + generated_diff = _context_diff( + actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2 ) - diff_msg = "\n" + "\n".join(generated_diff) + "\n" - diff_msg += "********************" + "\n" - expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg + error_msg = "Results do not match: " + percent_diff = (1 / 2) * 100 + error_msg += "( %.5f %% )" % percent_diff + error_msg += "\n" + "\n".join(generated_diff) with self.assertRaises(PySparkAssertionError) as pe: assertDataFrameEqual(df1, df2) @@ -169,7 +175,16 @@ class UtilsTestsMixin: self.check_error( exception=pe.exception, error_class="DIFFERENT_ROWS", - message_parameters={"error_msg": expected_error_message}, + message_parameters={"error_msg": error_msg}, + ) + + with 
self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, df2, checkRowOrder=True) + + self.check_error( + exception=pe.exception, + error_class="DIFFERENT_ROWS", + message_parameters={"error_msg": error_msg}, ) def test_assert_approx_equal_arraytype_float_custom_rtol_pass(self): @@ -266,6 +281,7 @@ class UtilsTestsMixin: def test_assert_notequal_arraytype(self): df1 = self.spark.createDataFrame( data=[ + ("Amy", ["C++", "Rust"]), ("John", ["Python", "Java"]), ("Jane", ["Scala", "SQL", "Java"]), ], @@ -278,6 +294,7 @@ class UtilsTestsMixin: ) df2 = self.spark.createDataFrame( data=[ + ("Amy", ["C++", "Rust"]), ("John", ["Python", "Java"]), ("Jane", ["Scala", "Java"]), ], @@ -289,17 +306,25 @@ class UtilsTestsMixin: ), ) - expected_error_message = "Results do not match: " - percent_diff = (1 / 2) * 100 - expected_error_message += "( %.5f %% )" % percent_diff + rows_str1 = "" + rows_str2 = "" + + sorted_list1 = sorted(df1.collect(), key=lambda x: str(x)) + sorted_list2 = sorted(df2.collect(), key=lambda x: str(x)) + + # count different rows + for r1, r2 in list(zip_longest(sorted_list1, sorted_list2)): + rows_str1 += str(r1) + "\n" + rows_str2 += str(r2) + "\n" - generated_diff = difflib.ndiff( - str(df1.collect()[1]).splitlines(), str(df2.collect()[1]).splitlines() + generated_diff = _context_diff( + actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=3 ) - diff_msg = "\n" + "\n".join(generated_diff) + "\n" - diff_msg += "********************" + "\n" - expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg + error_msg = "Results do not match: " + percent_diff = (1 / 3) * 100 + error_msg += "( %.5f %% )" % percent_diff + error_msg += "\n" + "\n".join(generated_diff) with self.assertRaises(PySparkAssertionError) as pe: assertDataFrameEqual(df1, df2) @@ -307,16 +332,33 @@ class UtilsTestsMixin: self.check_error( exception=pe.exception, error_class="DIFFERENT_ROWS", - message_parameters={"error_msg": 
expected_error_message}, + message_parameters={"error_msg": error_msg}, + ) + + rows_str1 = "" + rows_str2 = "" + + # count different rows + for r1, r2 in list(zip_longest(df1.collect(), df2.collect())): + rows_str1 += str(r1) + "\n" + rows_str2 += str(r2) + "\n" + + generated_diff = _context_diff( + actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=3 ) + error_msg = "Results do not match: " + percent_diff = (1 / 3) * 100 + error_msg += "( %.5f %% )" % percent_diff + error_msg += "\n" + "\n".join(generated_diff) + with self.assertRaises(PySparkAssertionError) as pe: assertDataFrameEqual(df1, df2, checkRowOrder=True) self.check_error( exception=pe.exception, error_class="DIFFERENT_ROWS", - message_parameters={"error_msg": expected_error_message}, + message_parameters={"error_msg": error_msg}, ) def test_assert_equal_maptype(self): @@ -588,17 +630,22 @@ class UtilsTestsMixin: schema=["id", "amount"], ) - expected_error_message = "Results do not match: " - percent_diff = (1 / 2) * 100 - expected_error_message += "( %.5f %% )" % percent_diff + rows_str1 = "" + rows_str2 = "" + + # count different rows + for r1, r2 in list(zip_longest(df1.collect(), df2.collect())): + rows_str1 += str(r1) + "\n" + rows_str2 += str(r2) + "\n" - generated_diff = difflib.ndiff( - str(df1.collect()[1]).splitlines(), str(df2.collect()[1]).splitlines() + generated_diff = _context_diff( + actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2 ) - diff_msg = "\n" + "\n".join(generated_diff) + "\n" - diff_msg += "********************" + "\n" - expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg + error_msg = "Results do not match: " + percent_diff = (1 / 2) * 100 + error_msg += "( %.5f %% )" % percent_diff + error_msg += "\n" + "\n".join(generated_diff) with self.assertRaises(PySparkAssertionError) as pe: assertDataFrameEqual(df1, df2) @@ -606,7 +653,7 @@ class UtilsTestsMixin: self.check_error( exception=pe.exception, 
error_class="DIFFERENT_ROWS", - message_parameters={"error_msg": expected_error_message}, + message_parameters={"error_msg": error_msg}, ) with self.assertRaises(PySparkAssertionError) as pe: @@ -615,7 +662,7 @@ class UtilsTestsMixin: self.check_error( exception=pe.exception, error_class="DIFFERENT_ROWS", - message_parameters={"error_msg": expected_error_message}, + message_parameters={"error_msg": error_msg}, ) def test_assert_equal_nulldf(self): @@ -762,22 +809,22 @@ class UtilsTestsMixin: schema=["id", "amount"], ) - expected_error_message = "Results do not match: " - percent_diff = (2 / 2) * 100 - expected_error_message += "( %.5f %% )" % percent_diff + rows_str1 = "" + rows_str2 = "" - generated_diff = difflib.ndiff( - str(df1.collect()[0]).splitlines(), str(df2.collect()[0]).splitlines() - ) - diff_msg = "\n" + "\n".join(generated_diff) + "\n" - diff_msg += "********************" + "\n" - generated_diff = difflib.ndiff( - str(df1.collect()[1]).splitlines(), str(df2.collect()[1]).splitlines() + # count different rows + for r1, r2 in list(zip_longest(df1.collect(), df2.collect())): + rows_str1 += str(r1) + "\n" + rows_str2 += str(r2) + "\n" + + generated_diff = _context_diff( + actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2 ) - diff_msg += "\n" + "\n".join(generated_diff) + "\n" - diff_msg += "********************" + "\n" - expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg + error_msg = "Results do not match: " + percent_diff = (2 / 2) * 100 + error_msg += "( %.5f %% )" % percent_diff + error_msg += "\n" + "\n".join(generated_diff) with self.assertRaises(PySparkAssertionError) as pe: assertDataFrameEqual(df1, df2, checkRowOrder=True) @@ -785,7 +832,7 @@ class UtilsTestsMixin: self.check_error( exception=pe.exception, error_class="DIFFERENT_ROWS", - message_parameters={"error_msg": expected_error_message}, + message_parameters={"error_msg": error_msg}, ) def test_remove_non_word_characters_long(self): @@ -857,22 
+904,22 @@ class UtilsTestsMixin: schema=["id", "amount"], ) - expected_error_message = "Results do not match: " - percent_diff = (2 / 3) * 100 - expected_error_message += "( %.5f %% )" % percent_diff + rows_str1 = "" + rows_str2 = "" - generated_diff = difflib.ndiff( - str(df1.collect()[0]).splitlines(), str(df2.collect()[0]).splitlines() - ) - diff_msg = "\n" + "\n".join(generated_diff) + "\n" - diff_msg += "********************" + "\n" - generated_diff = difflib.ndiff( - str(df1.collect()[2]).splitlines(), str(df2.collect()[2]).splitlines() + # count different rows + for r1, r2 in list(zip_longest(df1.collect(), df2.collect())): + rows_str1 += str(r1) + "\n" + rows_str2 += str(r2) + "\n" + + generated_diff = _context_diff( + actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=3 ) - diff_msg += "\n" + "\n".join(generated_diff) + "\n" - diff_msg += "********************" + "\n" - expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg + error_msg = "Results do not match: " + percent_diff = (2 / 3) * 100 + error_msg += "( %.5f %% )" % percent_diff + error_msg += "\n" + "\n".join(generated_diff) with self.assertRaises(PySparkAssertionError) as pe: assertDataFrameEqual(df1, df2) @@ -880,7 +927,7 @@ class UtilsTestsMixin: self.check_error( exception=pe.exception, error_class="DIFFERENT_ROWS", - message_parameters={"error_msg": expected_error_message}, + message_parameters={"error_msg": error_msg}, ) with self.assertRaises(PySparkAssertionError) as pe: @@ -889,7 +936,7 @@ class UtilsTestsMixin: self.check_error( exception=pe.exception, error_class="DIFFERENT_ROWS", - message_parameters={"error_msg": expected_error_message}, + message_parameters={"error_msg": error_msg}, ) def test_assert_notequal_schema(self): @@ -1209,17 +1256,22 @@ class UtilsTestsMixin: list_of_rows1 = [Row(1, "abc", 5000), Row(2, "def", 1000)] list_of_rows2 = [Row(1, "abc", 5000), Row(2, "defg", 1000)] - expected_error_message = "Results do not match: " - percent_diff 
= (1 / 2) * 100 - expected_error_message += "( %.5f %% )" % percent_diff + rows_str1 = "" + rows_str2 = "" - generated_diff = difflib.ndiff( - str(list_of_rows1[1]).splitlines(), str(list_of_rows2[1]).splitlines() + # count different rows + for r1, r2 in list(zip_longest(list_of_rows1, list_of_rows2)): + rows_str1 += str(r1) + "\n" + rows_str2 += str(r2) + "\n" + + generated_diff = _context_diff( + actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2 ) - diff_msg = "\n" + "\n".join(generated_diff) + "\n" - diff_msg += "********************" + "\n" - expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg + error_msg = "Results do not match: " + percent_diff = (1 / 2) * 100 + error_msg += "( %.5f %% )" % percent_diff + error_msg += "\n" + "\n".join(generated_diff) with self.assertRaises(PySparkAssertionError) as pe: assertDataFrameEqual(list_of_rows1, list_of_rows2) @@ -1227,7 +1279,7 @@ class UtilsTestsMixin: self.check_error( exception=pe.exception, error_class="DIFFERENT_ROWS", - message_parameters={"error_msg": expected_error_message}, + message_parameters={"error_msg": error_msg}, ) with self.assertRaises(PySparkAssertionError) as pe: @@ -1236,7 +1288,7 @@ class UtilsTestsMixin: self.check_error( exception=pe.exception, error_class="DIFFERENT_ROWS", - message_parameters={"error_msg": expected_error_message}, + message_parameters={"error_msg": error_msg}, ) def test_list_row_unequal_schema(self): @@ -1244,28 +1296,77 @@ class UtilsTestsMixin: data=[ (1, 3000), (2, 1000), + (3, 10), ], schema=["id", "amount"], ) - list_of_rows = [Row(1, "3000"), Row(2, "1000")] + list_of_rows = [Row(id=1, amount=300), Row(id=2, amount=100), Row(id=3, amount=10)] - expected_error_message = "Results do not match: " - percent_diff = (2 / 2) * 100 - expected_error_message += "( %.5f %% )" % percent_diff + rows_str1 = "" + rows_str2 = "" - generated_diff = difflib.ndiff( - str(df1.collect()[0]).splitlines(), str(list_of_rows[0]).splitlines() + # count 
different rows + for r1, r2 in list(zip_longest(df1, list_of_rows)): + rows_str1 += str(r1) + "\n" + rows_str2 += str(r2) + "\n" + + generated_diff = _context_diff( + actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=3 + ) + + error_msg = "Results do not match: " + percent_diff = (2 / 3) * 100 + error_msg += "( %.5f %% )" % percent_diff + error_msg += "\n" + "\n".join(generated_diff) + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, list_of_rows) + + self.check_error( + exception=pe.exception, + error_class="DIFFERENT_ROWS", + message_parameters={"error_msg": error_msg}, ) - diff_msg = "\n" + "\n".join(generated_diff) + "\n" - diff_msg += "********************" + "\n" - generated_diff = difflib.ndiff( - str(df1.collect()[1]).splitlines(), str(list_of_rows[1]).splitlines() + + with self.assertRaises(PySparkAssertionError) as pe: + assertDataFrameEqual(df1, list_of_rows, checkRowOrder=True) + + self.check_error( + exception=pe.exception, + error_class="DIFFERENT_ROWS", + message_parameters={"error_msg": error_msg}, ) - diff_msg += "\n" + "\n".join(generated_diff) + "\n" - diff_msg += "********************" + "\n" - expected_error_message += "\n" + "--- actual\n+++ expected\n" + diff_msg + def test_list_row_unequal_schema(self): + from pyspark.sql import Row + + df1 = self.spark.createDataFrame( + data=[ + (1, 3000), + (2, 1000), + ], + schema=["id", "amount"], + ) + + list_of_rows = [Row(1, "3000"), Row(2, "1000")] + + rows_str1 = "" + rows_str2 = "" + + # count different rows + for r1, r2 in list(zip_longest(df1.collect(), list_of_rows)): + rows_str1 += str(r1) + "\n" + rows_str2 += str(r2) + "\n" + + generated_diff = _context_diff( + actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=2 + ) + + error_msg = "Results do not match: " + percent_diff = (2 / 2) * 100 + error_msg += "( %.5f %% )" % percent_diff + error_msg += "\n" + "\n".join(generated_diff) with self.assertRaises(PySparkAssertionError) as 
pe: assertDataFrameEqual(df1, list_of_rows) @@ -1273,7 +1374,7 @@ class UtilsTestsMixin: self.check_error( exception=pe.exception, error_class="DIFFERENT_ROWS", - message_parameters={"error_msg": expected_error_message}, + message_parameters={"error_msg": error_msg}, ) with self.assertRaises(PySparkAssertionError) as pe: @@ -1282,7 +1383,7 @@ class UtilsTestsMixin: self.check_error( exception=pe.exception, error_class="DIFFERENT_ROWS", - message_parameters={"error_msg": expected_error_message}, + message_parameters={"error_msg": error_msg}, ) diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index eb55255863d..2a23476112f 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -29,6 +29,7 @@ from typing import ( Dict, List, Tuple, + Iterator, ) from itertools import zip_longest @@ -190,6 +191,49 @@ def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix): return jars[0] +def _terminal_color_support(): + try: + # determine if environment supports color + script = "$(test $(tput colors)) && $(test $(tput colors) -ge 8) && echo true || echo false" + return os.popen(script).read() + except Exception: + return False + + +def _context_diff(actual: List[str], expected: List[str], n: int = 3): + """ + Modified from difflib context_diff API, + see original code here: https://github.com/python/cpython/blob/main/Lib/difflib.py#L1180 + """ + + def red(s: str) -> str: + red_color = "\033[31m" + no_color = "\033[0m" + return red_color + str(s) + no_color + + prefix = dict(insert="+ ", delete="- ", replace="! 
", equal=" ") + for group in difflib.SequenceMatcher(None, actual, expected).get_grouped_opcodes(n): + yield "*** actual ***" + if any(tag in {"replace", "delete"} for tag, _, _, _, _ in group): + for tag, i1, i2, _, _ in group: + for line in actual[i1:i2]: + if tag != "equal" and _terminal_color_support(): + yield red(prefix[tag] + str(line)) + else: + yield prefix[tag] + str(line) + + yield "\n" + + yield "*** expected ***" + if any(tag in {"replace", "insert"} for tag, _, _, _, _ in group): + for tag, _, _, j1, j2 in group: + for line in expected[j1:j2]: + if tag != "equal" and _terminal_color_support(): + yield red(prefix[tag] + str(line)) + else: + yield prefix[tag] + str(line) + + class PySparkErrorTestUtils: """ This util provide functions to accurate and consistent error testing @@ -303,6 +347,7 @@ def assertSchemaEqual(actual: StructType, expected: StructType): else: return False + # ignore nullable flag by default if not compare_schemas_ignore_nullable(actual, expected): generated_diff = difflib.ndiff(str(actual).splitlines(), str(expected).splitlines()) @@ -492,23 +537,29 @@ def assertDataFrameEqual( def assert_rows_equal(rows1: List[Row], rows2: List[Row]): zipped = list(zip_longest(rows1, rows2)) - rows_equal = True - error_msg = "Results do not match: " - diff_msg = "" diff_rows_cnt = 0 + diff_rows = False + rows_str1 = "" + rows_str2 = "" + + # count different rows for r1, r2 in zipped: + rows_str1 += str(r1) + "\n" + rows_str2 += str(r2) + "\n" if not compare_rows(r1, r2): - rows_equal = False diff_rows_cnt += 1 - generated_diff = difflib.ndiff(str(r1).splitlines(), str(r2).splitlines()) - diff_msg += "\n" + "\n".join(generated_diff) + "\n" - diff_msg += "********************" + "\n" + diff_rows = True + + generated_diff = _context_diff( + actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=len(zipped) + ) - if not rows_equal: + if diff_rows: + error_msg = "Results do not match: " percent_diff = (diff_rows_cnt / len(zipped)) * 100 
error_msg += "( %.5f %% )" % percent_diff - error_msg += "\n" + "--- actual\n+++ expected\n" + diff_msg + error_msg += "\n" + "\n".join(generated_diff) raise PySparkAssertionError( error_class="DIFFERENT_ROWS", message_parameters={"error_msg": error_msg}, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org