This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a00a32a2f96 [SPARK-44216][PYTHON] Make assertSchemaEqual API public
a00a32a2f96 is described below

commit a00a32a2f967908eb160e8476f6640c7d5d3fee2
Author: Amanda Liu <amanda....@databricks.com>
AuthorDate: Mon Jul 17 08:46:11 2023 +0900

    [SPARK-44216][PYTHON] Make assertSchemaEqual API public
    
    ### What changes were proposed in this pull request?
    This PR implements and exposes the PySpark util function `assertSchemaEqual` to test for DataFrame schema equality. It uses the Python built-in `difflib` library to display differences between schemas.
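    
    For illustration, a minimal sketch of the `difflib.ndiff` call this patch uses to build that diff output (the schemas here are hypothetical, mirroring the doctest below):
    
    ```python
    import difflib

    from pyspark.sql.types import LongType, StringType, StructField, StructType

    # Two schemas that differ in one field's type and one field's name.
    actual = StructType([StructField("id", LongType(), True), StructField("number", LongType(), True)])
    expected = StructType([StructField("id", StringType(), True), StructField("amount", LongType(), True)])

    # difflib.ndiff takes two lists of lines and yields a line-by-line delta whose
    # lines start with "- ", "+ ", or "? "; the assertion joins them into the error message.
    print("\n".join(difflib.ndiff(str(actual).splitlines(), str(expected).splitlines())))
    ```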
    
    SPIP: https://docs.google.com/document/d/1OkyBn3JbEHkkQgSQ45Lq82esXjr9rm2Vj7Ih_4zycRc/edit#heading=h.f5f0u2riv07v
    
    ### Why are the changes needed?
    The `assertSchemaEqual` function checks two schemas for equality, simplifying the testing process for PySpark users. It also ignores the nullability flag when comparing schemas and nested types.
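    
    As a sketch of the nullability behavior, grounded in the tests added below:
    
    ```python
    from pyspark.testing import assertSchemaEqual
    from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructField, StructType

    # Nullability differs on both the array element and the field itself;
    # assertSchemaEqual ignores the flag, so this passes silently.
    s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
    s2 = StructType([StructField("names", ArrayType(DoubleType(), False), False)])
    assertSchemaEqual(s1, s2)

    # A genuine type mismatch still fails:
    s3 = StructType([StructField("names", ArrayType(IntegerType(), True), True)])
    # assertSchemaEqual(s1, s3)  # raises PySparkAssertionError: [DIFFERENT_SCHEMA]
    ```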
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the PR exposes a user-facing PySpark util function `assertSchemaEqual`.
    
    ### How was this patch tested?
    Added tests to `python/pyspark/sql/tests/test_utils.py` and `python/pyspark/sql/tests/connect/test_utils.py`
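    
    A minimal sketch for running the new suite locally via `unittest` (assumes a working PySpark development environment):
    
    ```python
    import unittest

    # Load and run the utils test module by its dotted name.
    suite = unittest.defaultTestLoader.loadTestsFromName("pyspark.sql.tests.test_utils")
    unittest.TextTestRunner(verbosity=2).run(suite)
    ```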
    
    Sample schema inequality error message:
    
    ![Screenshot 2023-07-14 at 10 07 22 AM](https://github.com/apache/spark/assets/68875504/e1aa0c87-e2c4-44c7-b84c-8af21d5fad84)
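    
    In text form, the doctest added in this patch renders the same style of message:
    
    ```
    PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
    --- actual
    +++ expected
    - StructType([StructField('id', LongType(), True), StructField('number', LongType(), True)])
    ?                               ^^                               ^^^^^
    + StructType([StructField('id', StringType(), True), StructField('amount', LongType(), True)])
    ?                               ^^^^                              ++++ ^
    ```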
    
    Closes #41927 from asl3/assert-schema-equal.
    
    Authored-by: Amanda Liu <amanda....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/docs/source/reference/pyspark.testing.rst |   1 +
 python/pyspark/errors/error_classes.py           |   7 +-
 python/pyspark/sql/tests/test_utils.py           | 166 +++++++++++++++++++----
 python/pyspark/testing/__init__.py               |   4 +-
 python/pyspark/testing/utils.py                  | 166 ++++++++++++++++-------
 5 files changed, 268 insertions(+), 76 deletions(-)

diff --git a/python/docs/source/reference/pyspark.testing.rst b/python/docs/source/reference/pyspark.testing.rst
index df0a334db70..7a6b6cc0d70 100644
--- a/python/docs/source/reference/pyspark.testing.rst
+++ b/python/docs/source/reference/pyspark.testing.rst
@@ -26,3 +26,4 @@ Testing
     :toctree: api/
 
     assertDataFrameEqual
+    assertSchemaEqual
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 56b166b53c5..2cecee4da44 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -171,9 +171,10 @@ ERROR_CLASSES_JSON = """
   },
   "DIFFERENT_SCHEMA" : {
     "message" : [
-      "Schemas do not match:",
-      "df schema: <df_schema>",
-      "expected schema: <expected_schema>"
+      "Schemas do not match.",
+      "--- actual",
+      "+++ expected",
+      "<error_msg>"
     ]
   },
   "DISALLOWED_TYPE_FOR_CONTAINER" : {
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index 46692450cbd..5b859ad15a5 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -23,7 +23,7 @@ from pyspark.errors import (
     IllegalArgumentException,
     SparkUpgradeException,
 )
-from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 import pyspark.sql.functions as F
 from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
@@ -37,8 +37,11 @@ from pyspark.sql.types import (
     DoubleType,
     StructField,
     IntegerType,
+    BooleanType,
 )
 
+import difflib
+
 
 class UtilsTestsMixin:
     def test_assert_equal_inttype(self):
@@ -149,7 +152,7 @@ class UtilsTestsMixin:
         percent_diff = (1 / 2) * 100
         expected_error_message += "( %.5f %% )" % percent_diff
         diff_msg = (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[1])
             + "\n\n"
@@ -292,7 +295,7 @@ class UtilsTestsMixin:
         percent_diff = (1 / 2) * 100
         expected_error_message += "( %.5f %% )" % percent_diff
         diff_msg = (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[1])
             + "\n\n"
@@ -596,7 +599,7 @@ class UtilsTestsMixin:
         percent_diff = (1 / 2) * 100
         expected_error_message += "( %.5f %% )" % percent_diff
         diff_msg = (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[1])
             + "\n\n"
@@ -720,7 +723,7 @@ class UtilsTestsMixin:
         percent_diff = (2 / 2) * 100
         expected_error_message += "( %.5f %% )" % percent_diff
         diff_msg = (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[0])
             + "\n\n"
@@ -732,7 +735,7 @@ class UtilsTestsMixin:
             + "\n\n"
         )
         diff_msg += (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[1])
             + "\n\n"
@@ -827,7 +830,7 @@ class UtilsTestsMixin:
         percent_diff = (2 / 3) * 100
         expected_error_message += "( %.5f %% )" % percent_diff
         diff_msg = (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[0])
             + "\n\n"
@@ -839,7 +842,7 @@ class UtilsTestsMixin:
             + "\n\n"
         )
         diff_msg += (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[2])
             + "\n\n"
@@ -876,32 +879,27 @@ class UtilsTestsMixin:
                 (1, 1000),
                 (2, 3000),
             ],
-            schema=["id", "amount"],
+            schema=["id", "number"],
         )
         df2 = self.spark.createDataFrame(
             data=[
                 ("1", 1000),
-                ("2", 3000),
+                ("2", 5000),
             ],
             schema=["id", "amount"],
         )
 
-        with self.assertRaises(PySparkAssertionError) as pe:
-            assertDataFrameEqual(df1, df2)
+        generated_diff = difflib.ndiff(str(df1.schema).splitlines(), str(df2.schema).splitlines())
 
-        self.check_error(
-            exception=pe.exception,
-            error_class="DIFFERENT_SCHEMA",
-            message_parameters={"df_schema": df1.schema, "expected_schema": 
df2.schema},
-        )
+        expected_error_msg = "\n".join(generated_diff)
 
         with self.assertRaises(PySparkAssertionError) as pe:
-            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+            assertDataFrameEqual(df1, df2)
 
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_SCHEMA",
-            message_parameters={"df_schema": df1.schema, "expected_schema": 
df2.schema},
+            message_parameters={"error_msg": expected_error_msg},
         )
 
     def test_diff_schema_lens(self):
@@ -921,22 +919,142 @@ class UtilsTestsMixin:
             schema=["id", "amount", "letter"],
         )
 
+        generated_diff = difflib.ndiff(str(df1.schema).splitlines(), str(df2.schema).splitlines())
+
+        expected_error_msg = "\n".join(generated_diff)
+
         with self.assertRaises(PySparkAssertionError) as pe:
             assertDataFrameEqual(df1, df2)
 
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_SCHEMA",
-            message_parameters={"df_schema": df1.schema, "expected_schema": 
df2.schema},
+            message_parameters={"error_msg": expected_error_msg},
+        )
+
+    def test_schema_ignore_nullable(self):
+        s1 = StructType(
+            [StructField("id", IntegerType(), True), StructField("name", 
StringType(), True)]
         )
 
+        df1 = self.spark.createDataFrame([(1, "jane"), (2, "john")], s1)
+
+        s2 = StructType(
+            [StructField("id", IntegerType(), True), StructField("name", 
StringType(), False)]
+        )
+
+        df2 = self.spark.createDataFrame([(1, "jane"), (2, "john")], s2)
+
+        assertDataFrameEqual(df1, df2)
+
+    def test_schema_ignore_nullable_array_equal(self):
+        s1 = StructType([StructField("names", ArrayType(DoubleType(), True), 
True)])
+        s2 = StructType([StructField("names", ArrayType(DoubleType(), False), 
False)])
+
+        assertSchemaEqual(s1, s2)
+
+    def test_schema_ignore_nullable_struct_equal(self):
+        s1 = StructType(
+            [StructField("names", StructType([StructField("age", 
IntegerType(), True)]), True)]
+        )
+        s2 = StructType(
+            [StructField("names", StructType([StructField("age", 
IntegerType(), False)]), False)]
+        )
+        assertSchemaEqual(s1, s2)
+
+    def test_schema_array_unequal(self):
+        s1 = StructType([StructField("names", ArrayType(IntegerType(), True), 
True)])
+        s2 = StructType([StructField("names", ArrayType(DoubleType(), False), 
False)])
+
+        generated_diff = difflib.ndiff(str(s1).splitlines(), str(s2).splitlines())
+
+        expected_error_msg = "\n".join(generated_diff)
+
         with self.assertRaises(PySparkAssertionError) as pe:
-            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+            assertSchemaEqual(s1, s2)
 
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_SCHEMA",
-            message_parameters={"df_schema": df1.schema, "expected_schema": 
df2.schema},
+            message_parameters={"error_msg": expected_error_msg},
+        )
+
+    def test_schema_struct_unequal(self):
+        s1 = StructType(
+            [StructField("names", StructType([StructField("age", DoubleType(), 
True)]), True)]
+        )
+        s2 = StructType(
+            [StructField("names", StructType([StructField("age", 
IntegerType(), True)]), True)]
+        )
+
+        generated_diff = difflib.ndiff(str(s1).splitlines(), str(s2).splitlines())
+
+        expected_error_msg = "\n".join(generated_diff)
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertSchemaEqual(s1, s2)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_SCHEMA",
+            message_parameters={"error_msg": expected_error_msg},
+        )
+
+    def test_schema_more_nested_struct_unequal(self):
+        s1 = StructType(
+            [
+                StructField(
+                    "name",
+                    StructType(
+                        [
+                            StructField("firstname", StringType(), True),
+                            StructField("middlename", StringType(), True),
+                            StructField("lastname", StringType(), True),
+                        ]
+                    ),
+                ),
+            ]
+        )
+
+        s2 = StructType(
+            [
+                StructField(
+                    "name",
+                    StructType(
+                        [
+                            StructField("firstname", StringType(), True),
+                            StructField("middlename", BooleanType(), True),
+                            StructField("lastname", StringType(), True),
+                        ]
+                    ),
+                ),
+            ]
+        )
+
+        generated_diff = difflib.ndiff(str(s1).splitlines(), str(s2).splitlines())
+
+        expected_error_msg = "\n".join(generated_diff)
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertSchemaEqual(s1, s2)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_SCHEMA",
+            message_parameters={"error_msg": expected_error_msg},
+        )
+
+    def test_schema_unsupported_type(self):
+        s1 = "names: int"
+        s2 = "names: int"
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertSchemaEqual(s1, s2)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="UNSUPPORTED_DATA_TYPE",
+            message_parameters={"data_type": type(s1)},
         )
 
     def test_spark_sql(self):
@@ -1056,7 +1174,7 @@ class UtilsTestsMixin:
         percent_diff = (2 / 2) * 100
         expected_error_message += "( %.5f %% )" % percent_diff
         diff_msg = (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[0])
             + "\n\n"
@@ -1068,7 +1186,7 @@ class UtilsTestsMixin:
             + "\n\n"
         )
         diff_msg += (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[1])
             + "\n\n"
diff --git a/python/pyspark/testing/__init__.py b/python/pyspark/testing/__init__.py
index 1bf70befc42..88853e925f8 100644
--- a/python/pyspark/testing/__init__.py
+++ b/python/pyspark/testing/__init__.py
@@ -14,6 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual
 
-__all__ = ["assertDataFrameEqual"]
+__all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 476470d7490..21c7b7e4dcd 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -20,6 +20,7 @@ import os
 import struct
 import sys
 import unittest
+import difflib
 from time import time, sleep
 from typing import (
     Any,
@@ -36,7 +37,7 @@ from pyspark.errors import PySparkAssertionError, PySparkException
 from pyspark.find_spark_home import _find_spark_home
 from pyspark.sql.dataframe import DataFrame as DataFrame
 from pyspark.sql import Row
-from pyspark.sql.types import StructType, AtomicType
+from pyspark.sql.types import StructType, AtomicType, StructField
 
 have_scipy = False
 have_numpy = False
@@ -55,7 +56,7 @@ except ImportError:
     # No NumPy, but that's okay, we'll skip those tests
     pass
 
-__all__ = ["assertDataFrameEqual"]
+__all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
 
 SPARK_HOME = _find_spark_home()
 
@@ -222,22 +223,112 @@ class PySparkErrorTestUtils:
         )
 
 
+def assertSchemaEqual(actual: StructType, expected: StructType):
+    r"""
+    A util function to assert equality between DataFrame schemas `actual` and `expected`.
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    actual : StructType
+        The DataFrame schema that is being compared or tested.
+    expected : StructType
+        The expected schema, for comparison with the actual schema.
+
+    Notes
+    -----
+    When assertSchemaEqual fails, the error message uses the Python `difflib` library to display
+    a diff log of the `actual` and `expected` schemas.
+
+    Examples
+    --------
+    >>> from pyspark.sql.types import StructType, StructField, ArrayType, 
IntegerType, DoubleType
+    >>> s1 = StructType([StructField("names", ArrayType(DoubleType(), True), 
True)])
+    >>> s2 = StructType([StructField("names", ArrayType(DoubleType(), True), 
True)])
+    >>> assertSchemaEqual(s1, s2)  # pass, schemas are identical
+    >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", 
"number"])
+    >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 5000)], 
schema=["id", "amount"])
+    >>> assertSchemaEqual(df1.schema, df2.schema)  # doctest: 
+IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+    ...
+    PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
+    --- actual
+    +++ expected
+    - StructType([StructField('id', LongType(), True), StructField('number', LongType(), True)])
+    ?                               ^^                               ^^^^^
+    + StructType([StructField('id', StringType(), True), StructField('amount', LongType(), True)])
+    ?                               ^^^^                              ++++ ^
+    """
+    if not isinstance(actual, StructType):
+        raise PySparkAssertionError(
+            error_class="UNSUPPORTED_DATA_TYPE",
+            message_parameters={"data_type": type(actual)},
+        )
+    if not isinstance(expected, StructType):
+        raise PySparkAssertionError(
+            error_class="UNSUPPORTED_DATA_TYPE",
+            message_parameters={"data_type": type(expected)},
+        )
+
+    def compare_schemas_ignore_nullable(s1: StructType, s2: StructType):
+        if len(s1) != len(s2):
+            return False
+        zipped = zip_longest(s1, s2)
+        for sf1, sf2 in zipped:
+            if not compare_structfields_ignore_nullable(sf1, sf2):
+                return False
+        return True
+
+    def compare_structfields_ignore_nullable(actualSF: StructField, expectedSF: StructField):
+        if actualSF is None and expectedSF is None:
+            return True
+        elif actualSF is None or expectedSF is None:
+            return False
+        if actualSF.name != expectedSF.name:
+            return False
+        else:
+            return compare_datatypes_ignore_nullable(actualSF.dataType, expectedSF.dataType)
+
+    def compare_datatypes_ignore_nullable(dt1: Any, dt2: Any):
+        # checks datatype equality, using recursion to ignore nullable
+        if dt1.typeName() == dt2.typeName():
+            if dt1.typeName() == "array":
+                return compare_datatypes_ignore_nullable(dt1.elementType, dt2.elementType)
+            elif dt1.typeName() == "struct":
+                return compare_schemas_ignore_nullable(dt1, dt2)
+            else:
+                return True
+        else:
+            return False
+
+    if not compare_schemas_ignore_nullable(actual, expected):
+        generated_diff = difflib.ndiff(str(actual).splitlines(), str(expected).splitlines())
+
+        error_msg = "\n".join(generated_diff)
+
+        raise PySparkAssertionError(
+            error_class="DIFFERENT_SCHEMA",
+            message_parameters={"error_msg": error_msg},
+        )
+
+
 def assertDataFrameEqual(
-    df: DataFrame,
+    actual: DataFrame,
     expected: Union[DataFrame, List[Row]],
     checkRowOrder: bool = False,
     rtol: float = 1e-5,
     atol: float = 1e-8,
 ):
-    """
-    A util function to assert equality between `df` (DataFrame) and `expected`
+    r"""
+    A util function to assert equality between `actual` (DataFrame) and `expected`
     (either DataFrame or list of Rows), with optional parameter `checkRowOrder`.
 
     .. versionadded:: 3.5.0
 
     Parameters
     ----------
-    df : DataFrame
+    actual : DataFrame
         The DataFrame that is being compared or tested.
     expected : DataFrame or list of Rows
         The expected result of the operation, for comparison with the actual result.
@@ -247,10 +338,10 @@ def assertDataFrameEqual(
         If set to `True`, the order of rows is important and will be checked during comparison.
         (See Notes)
     rtol : float, optional
-        The relative tolerance, used in asserting approximate equality for float values in df
+        The relative tolerance, used in asserting approximate equality for float values in actual
         and expected. Set to 1e-5 by default. (See Notes)
     atol : float, optional
-        The absolute tolerance, used in asserting approximate equality for float values in df
+        The absolute tolerance, used in asserting approximate equality for float values in actual
         and expected. Set to 1e-8 by default. (See Notes)
 
     Notes
@@ -267,49 +358,40 @@ def assertDataFrameEqual(
     --------
     >>> df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], 
schema=["id", "amount"])
     >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], 
schema=["id", "amount"])
-    >>> assertDataFrameEqual(df1, df2)
-
-    Pass, DataFrames are identical
-
+    >>> assertDataFrameEqual(df1, df2)  # pass, DataFrames are identical
     >>> df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], 
schema=["id", "amount"])
     >>> df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], 
schema=["id", "amount"])
-    >>> assertDataFrameEqual(df1, df2, rtol=1e-1)
-
-    Pass, DataFrames are approx equal by rtol
-
-    >>> df1 = spark.createDataFrame(data=[("1", 1000.00), ("2", 3000.00), 
("3", 2000.00)],
-    ... schema=["id", "amount"])
-    >>> df2 = spark.createDataFrame(data=[("1", 1001.00), ("2", 3000.00), 
("3", 2003.00)],
-    ... schema=["id", "amount"])
+    >>> assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are 
approx equal by rtol
+    >>> df1 = spark.createDataFrame(
+    ...     data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], 
schema=["id", "amount"])
+    >>> df2 = spark.createDataFrame(
+    ...     data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], 
schema=["id", "amount"])
     >>> assertDataFrameEqual(df1, df2)  # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
     ...
     PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.667 % )
-    [df]
+    [actual]
     Row(id='1', amount=1000.0)
-
     [expected]
     Row(id='1', amount=1001.0)
-
-    [df]
+    [actual]
     Row(id='3', amount=2000.0)
-
     [expected]
     Row(id='3', amount=2003.0)
     """
-    if df is None and expected is None:
+    if actual is None and expected is None:
         return True
-    elif df is None or expected is None:
+    elif actual is None or expected is None:
         return False
 
     try:
         # If Spark Connect dependencies are available, allow Spark Connect DataFrame
         from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
 
-        if not isinstance(df, DataFrame) and not isinstance(df, ConnectDataFrame):
+        if not isinstance(actual, DataFrame) and not isinstance(actual, ConnectDataFrame):
             raise PySparkAssertionError(
                 error_class="UNSUPPORTED_DATA_TYPE",
-                message_parameters={"data_type": type(df)},
+                message_parameters={"data_type": type(actual)},
             )
         elif (
             not isinstance(expected, DataFrame)
@@ -321,10 +403,10 @@ def assertDataFrameEqual(
                 message_parameters={"data_type": type(expected)},
             )
     except Exception:
-        if not isinstance(df, DataFrame):
+        if not isinstance(actual, DataFrame):
             raise PySparkAssertionError(
                 error_class="UNSUPPORTED_DATA_TYPE",
-                message_parameters={"data_type": type(df)},
+                message_parameters={"data_type": type(actual)},
             )
         elif not isinstance(expected, DataFrame) and not isinstance(expected, List):
             raise PySparkAssertionError(
@@ -333,8 +415,8 @@ def assertDataFrameEqual(
             )
 
     # special cases: empty datasets, datasets with 0 columns
-    if (df.first() is None and expected.first() is None) or (
-        len(df.columns) == 0 and len(expected.columns) == 0
+    if (actual.first() is None and expected.first() is None) or (
+        len(actual.columns) == 0 and len(expected.columns) == 0
     ):
         return True
 
@@ -367,16 +449,6 @@ def assertDataFrameEqual(
 
         return compare_vals(r1, r2)
 
-    def assert_schema_equal(
-        df_schema: StructType,
-        expected_schema: StructType,
-    ):
-        if df_schema != expected_schema:
-            raise PySparkAssertionError(
-                error_class="DIFFERENT_SCHEMA",
-                message_parameters={"df_schema": df_schema, "expected_schema": 
expected_schema},
-            )
-
     def assert_rows_equal(rows1: List[Row], rows2: List[Row]):
         zipped = list(zip_longest(rows1, rows2))
         rows_equal = True
@@ -389,7 +461,7 @@ def assertDataFrameEqual(
                 rows_equal = False
                 diff_rows_cnt += 1
                 diff_msg += (
-                    "[df]" + "\n" + str(r1) + "\n\n" + "[expected]" + "\n" + 
str(r2) + "\n\n"
+                    "[actual]" + "\n" + str(r1) + "\n\n" + "[expected]" + "\n" 
+ str(r2) + "\n\n"
                 )
                 diff_msg += "********************" + "\n\n"
 
@@ -402,15 +474,15 @@ def assertDataFrameEqual(
                 message_parameters={"error_msg": error_msg},
             )
 
-    # convert df and expected to list
+    # convert actual and expected to list
     if not isinstance(expected, List):
         # only compare schema if expected is not a List
-        assert_schema_equal(df.schema, expected.schema)
+        assertSchemaEqual(actual.schema, expected.schema)
         expected_list = expected.collect()
     else:
         expected_list = expected
 
-    df_list = df.collect()
+    df_list = actual.collect()
 
     if not checkRowOrder:
         # rename duplicate columns for sorting

