allisonwang-db commented on code in PR #41927:
URL: https://github.com/apache/spark/pull/41927#discussion_r1264721096


##########
python/pyspark/testing/utils.py:
##########
@@ -267,49 +358,40 @@ def assertDataFrameEqual(
     --------
     >>> df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], 
schema=["id", "amount"])
     >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], 
schema=["id", "amount"])
-    >>> assertDataFrameEqual(df1, df2)
-
-    Pass, DataFrames are identical
-
+    >>> assertDataFrameEqual(df1, df2)  # pass, DataFrames are identical
     >>> df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], 
schema=["id", "amount"])
     >>> df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], 
schema=["id", "amount"])
-    >>> assertDataFrameEqual(df1, df2, rtol=1e-1)
-
-    Pass, DataFrames are approx equal by rtol
-
-    >>> df1 = spark.createDataFrame(data=[("1", 1000.00), ("2", 3000.00), 
("3", 2000.00)],
-    ... schema=["id", "amount"])
-    >>> df2 = spark.createDataFrame(data=[("1", 1001.00), ("2", 3000.00), 
("3", 2003.00)],
-    ... schema=["id", "amount"])
+    >>> assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are 
approx equal by rtol
+    >>> df1 = spark.createDataFrame(
+    ...     data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], 
schema=["id", "amount"])
+    >>> df2 = spark.createDataFrame(
+    ...     data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], 
schema=["id", "amount"])

Review Comment:
   We should add another example that uses `expected` as a list[Row].



##########
python/pyspark/testing/utils.py:
##########
@@ -333,8 +415,8 @@ def assertDataFrameEqual(
             )
 
     # special cases: empty datasets, datasets with 0 columns
-    if (df.first() is None and expected.first() is None) or (
-        len(df.columns) == 0 and len(expected.columns) == 0
+    if (actual.first() is None and expected.first() is None) or (

Review Comment:
   This will fail when `expected` is a list of Rows, because a Python list has no `first()` method. We can fix this in a follow-up PR.



##########
python/pyspark/testing/utils.py:
##########
@@ -222,22 +223,112 @@ def check_error(
         )
 
 
+def assertSchemaEqual(actual: StructType, expected: StructType):
+    r"""
+    A util function to assert equality between DataFrame schemas `actual` and 
`expected`.
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    actual : StructType
+        The DataFrame schema that is being compared or tested.
+    expected : StructType
+        The expected schema, for comparison with the actual schema.
+
+    Notes
+    -----
+    When assertSchemaEqual fails, the error message uses the Python `difflib` 
library to display
+    a diff log of the `actual` and `expected` schemas.
+
+    Examples
+    --------
+    >>> from pyspark.sql.types import StructType, StructField, ArrayType, 
IntegerType, DoubleType
+    >>> s1 = StructType([StructField("names", ArrayType(DoubleType(), True), 
True)])
+    >>> s2 = StructType([StructField("names", ArrayType(DoubleType(), True), 
True)])
+    >>> assertSchemaEqual(s1, s2)  # pass, schemas are identical
+    >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", 
"number"])
+    >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 5000)], 
schema=["id", "amount"])
+    >>> assertSchemaEqual(df1.schema, df2.schema)  # doctest: 
+IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+    ...
+    PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
+    --- actual
+    +++ expected
+    - StructType([StructField('id', LongType(), True), StructField('number', 
LongType(), True)])
+    ?                               ^^                               ^^^^^
+    + StructType([StructField('id', StringType(), True), StructField('amount', 
LongType(), True)])
+    ?                               ^^^^                              ++++ ^
+    """
+    if not isinstance(actual, StructType):
+        raise PySparkAssertionError(
+            error_class="UNSUPPORTED_DATA_TYPE",
+            message_parameters={"data_type": type(actual)},
+        )
+    if not isinstance(expected, StructType):
+        raise PySparkAssertionError(
+            error_class="UNSUPPORTED_DATA_TYPE",
+            message_parameters={"data_type": type(expected)},
+        )
+
+    def compare_schemas_ignore_nullable(s1: StructType, s2: StructType):
+        if len(s1) != len(s2):
+            return False
+        zipped = zip_longest(s1, s2)
+        for sf1, sf2 in zipped:
+            if not compare_structfields_ignore_nullable(sf1, sf2):
+                return False
+        return True
+
+    def compare_structfields_ignore_nullable(actualSF: StructField, 
expectedSF: StructField):
+        if actualSF is None and expectedSF is None:
+            return True
+        elif actualSF is None or expectedSF is None:
+            return False
+        if actualSF.name != expectedSF.name:
+            return False
+        else:
+            return compare_datatypes_ignore_nullable(actualSF.dataType, 
expectedSF.dataType)
+
+    def compare_datatypes_ignore_nullable(dt1: Any, dt2: Any):
+        # checks datatype equality, using recursion to ignore nullable
+        if dt1.typeName() == dt2.typeName():
+            if dt1.typeName() == "array":
+                return compare_datatypes_ignore_nullable(dt1.elementType, 
dt2.elementType)
+            elif dt1.typeName() == "struct":
+                return compare_schemas_ignore_nullable(dt1, dt2)
+            else:
+                return True
+        else:
+            return False
+
+    if not compare_schemas_ignore_nullable(actual, expected):
+        generated_diff = difflib.ndiff(str(actual).splitlines(), 
str(expected).splitlines())
+
+        error_msg = "\n".join(generated_diff)
+
+        raise PySparkAssertionError(
+            error_class="DIFFERENT_SCHEMA",
+            message_parameters={"error_msg": error_msg},
+        )
+
+
 def assertDataFrameEqual(
-    df: DataFrame,
+    actual: DataFrame,
     expected: Union[DataFrame, List[Row]],
     checkRowOrder: bool = False,

Review Comment:
   We can add a flag to skip the schema check. The docstring should also mention that the schema is only checked when `expected` is a DataFrame (not a list of Rows).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to