This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new a00a32a2f96  [SPARK-44216][PYTHON] Make assertSchemaEqual API public
a00a32a2f96 is described below

commit a00a32a2f967908eb160e8476f6640c7d5d3fee2
Author: Amanda Liu <amanda....@databricks.com>
AuthorDate: Mon Jul 17 08:46:11 2023 +0900

    [SPARK-44216][PYTHON] Make assertSchemaEqual API public

    ### What changes were proposed in this pull request?
    This PR implements and exposes the PySpark util function `assertSchemaEqual` to test for DataFrame schema equality. It uses the Python built-in difflib library to display differences between schemas.

    SPIP: https://docs.google.com/document/d/1OkyBn3JbEHkkQgSQ45Lq82esXjr9rm2Vj7Ih_4zycRc/edit#heading=h.f5f0u2riv07v

    ### Why are the changes needed?
    The `assertSchemaEqual` function compares schema equality, simplifying the testing process for PySpark users. It adds functionality to ignore the nullability flag for schemas and nested types.

    ### Does this PR introduce _any_ user-facing change?
    Yes, the PR exposes a user-facing PySpark util function `assertSchemaEqual`.

    ### How was this patch tested?
    Added tests to `runtime/python/pyspark/sql/tests/test_utils.py` and `runtime/python/pyspark/sql/tests/connect/test_utils.py`

    Sample schema inequality error message:
    ![Screenshot 2023-07-14 at 10 07 22 AM](https://github.com/apache/spark/assets/68875504/e1aa0c87-e2c4-44c7-b84c-8af21d5fad84)

    Closes #41927 from asl3/assert-schema-equal.

    Authored-by: Amanda Liu <amanda....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/docs/source/reference/pyspark.testing.rst |   1 +
 python/pyspark/errors/error_classes.py           |   7 +-
 python/pyspark/sql/tests/test_utils.py           | 166 +++++++++++++++++++----
 python/pyspark/testing/__init__.py               |   4 +-
 python/pyspark/testing/utils.py                  | 166 ++++++++++++++++-------
 5 files changed, 268 insertions(+), 76 deletions(-)

diff --git a/python/docs/source/reference/pyspark.testing.rst b/python/docs/source/reference/pyspark.testing.rst
index df0a334db70..7a6b6cc0d70 100644
--- a/python/docs/source/reference/pyspark.testing.rst
+++ b/python/docs/source/reference/pyspark.testing.rst
@@ -26,3 +26,4 @@ Testing
     :toctree: api/
 
     assertDataFrameEqual
+    assertSchemaEqual
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 56b166b53c5..2cecee4da44 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -171,9 +171,10 @@ ERROR_CLASSES_JSON = """
   },
   "DIFFERENT_SCHEMA" : {
     "message" : [
-      "Schemas do not match:",
-      "df schema: <df_schema>",
-      "expected schema: <expected_schema>"
+      "Schemas do not match.",
+      "--- actual",
+      "+++ expected",
+      "<error_msg>"
     ]
   },
   "DISALLOWED_TYPE_FOR_CONTAINER" : {
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index 46692450cbd..5b859ad15a5 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -23,7 +23,7 @@ from pyspark.errors import (
     IllegalArgumentException,
     SparkUpgradeException,
 )
-from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 import pyspark.sql.functions as F
 from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
@@ -37,8 +37,11 @@ from pyspark.sql.types import (
     DoubleType,
     StructField,
     IntegerType,
+    BooleanType,
 )
 
+import difflib
+
 
 class UtilsTestsMixin:
     def test_assert_equal_inttype(self):
@@ -149,7 +152,7 @@ class UtilsTestsMixin:
         percent_diff = (1 / 2) * 100
         expected_error_message += "( %.5f %% )" % percent_diff
         diff_msg = (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[1])
             + "\n\n"
@@ -292,7 +295,7 @@ class UtilsTestsMixin:
         percent_diff = (1 / 2) * 100
         expected_error_message += "( %.5f %% )" % percent_diff
         diff_msg = (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[1])
             + "\n\n"
@@ -596,7 +599,7 @@ class UtilsTestsMixin:
         percent_diff = (1 / 2) * 100
         expected_error_message += "( %.5f %% )" % percent_diff
         diff_msg = (
-            "[df]"
+            "[actual]"
            + "\n"
             + str(df1.collect()[1])
             + "\n\n"
@@ -720,7 +723,7 @@ class UtilsTestsMixin:
         percent_diff = (2 / 2) * 100
         expected_error_message += "( %.5f %% )" % percent_diff
         diff_msg = (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[0])
             + "\n\n"
@@ -732,7 +735,7 @@ class UtilsTestsMixin:
             + "\n\n"
         )
         diff_msg += (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[1])
             + "\n\n"
@@ -827,7 +830,7 @@ class UtilsTestsMixin:
         percent_diff = (2 / 3) * 100
         expected_error_message += "( %.5f %% )" % percent_diff
         diff_msg = (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[0])
             + "\n\n"
@@ -839,7 +842,7 @@ class UtilsTestsMixin:
             + "\n\n"
         )
         diff_msg += (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[2])
             + "\n\n"
@@ -876,32 +879,27 @@ class UtilsTestsMixin:
                 (1, 1000),
                 (2, 3000),
             ],
-            schema=["id", "amount"],
+            schema=["id", "number"],
         )
         df2 = self.spark.createDataFrame(
             data=[
                 ("1", 1000),
-                ("2", 3000),
+                ("2", 5000),
             ],
             schema=["id", "amount"],
         )
 
-        with self.assertRaises(PySparkAssertionError) as pe:
-            assertDataFrameEqual(df1, df2)
+        generated_diff = difflib.ndiff(str(df1.schema).splitlines(), str(df2.schema).splitlines())
 
-        self.check_error(
-            exception=pe.exception,
-            error_class="DIFFERENT_SCHEMA",
-            message_parameters={"df_schema": df1.schema, "expected_schema": df2.schema},
-        )
+        expected_error_msg = "\n".join(generated_diff)
 
         with self.assertRaises(PySparkAssertionError) as pe:
-            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+            assertDataFrameEqual(df1, df2)
 
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_SCHEMA",
-            message_parameters={"df_schema": df1.schema, "expected_schema": df2.schema},
+            message_parameters={"error_msg": expected_error_msg},
         )
 
     def test_diff_schema_lens(self):
@@ -921,22 +919,142 @@ class UtilsTestsMixin:
             schema=["id", "amount", "letter"],
         )
 
+        generated_diff = difflib.ndiff(str(df1.schema).splitlines(), str(df2.schema).splitlines())
+
+        expected_error_msg = "\n".join(generated_diff)
+
         with self.assertRaises(PySparkAssertionError) as pe:
             assertDataFrameEqual(df1, df2)
 
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_SCHEMA",
-            message_parameters={"df_schema": df1.schema, "expected_schema": df2.schema},
+            message_parameters={"error_msg": expected_error_msg},
+        )
+
+    def test_schema_ignore_nullable(self):
+        s1 = StructType(
+            [StructField("id", IntegerType(), True), StructField("name", StringType(), True)]
         )
+        df1 = self.spark.createDataFrame([(1, "jane"), (2, "john")], s1)
+
+        s2 = StructType(
+            [StructField("id", IntegerType(), True), StructField("name", StringType(), False)]
+        )
+
+        df2 = self.spark.createDataFrame([(1, "jane"), (2, "john")], s2)
+
+        assertDataFrameEqual(df1, df2)
+
+    def test_schema_ignore_nullable_array_equal(self):
+        s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
+        s2 = StructType([StructField("names", ArrayType(DoubleType(), False), False)])
+
+        assertSchemaEqual(s1, s2)
+
+    def test_schema_ignore_nullable_struct_equal(self):
+        s1 = StructType(
+            [StructField("names", StructType([StructField("age", IntegerType(), True)]), True)]
+        )
+        s2 = StructType(
+            [StructField("names", StructType([StructField("age", IntegerType(), False)]), False)]
+        )
+        assertSchemaEqual(s1, s2)
+
+    def test_schema_array_unequal(self):
+        s1 = StructType([StructField("names", ArrayType(IntegerType(), True), True)])
+        s2 = StructType([StructField("names", ArrayType(DoubleType(), False), False)])
+
+        generated_diff = difflib.ndiff(str(s1).splitlines(), str(s2).splitlines())
+
+        expected_error_msg = "\n".join(generated_diff)
+
         with self.assertRaises(PySparkAssertionError) as pe:
-            assertDataFrameEqual(df1, df2, checkRowOrder=True)
+            assertSchemaEqual(s1, s2)
 
         self.check_error(
             exception=pe.exception,
             error_class="DIFFERENT_SCHEMA",
-            message_parameters={"df_schema": df1.schema, "expected_schema": df2.schema},
+            message_parameters={"error_msg": expected_error_msg},
+        )
+
+    def test_schema_struct_unequal(self):
+        s1 = StructType(
+            [StructField("names", StructType([StructField("age", DoubleType(), True)]), True)]
+        )
+        s2 = StructType(
+            [StructField("names", StructType([StructField("age", IntegerType(), True)]), True)]
+        )
+
+        generated_diff = difflib.ndiff(str(s1).splitlines(), str(s2).splitlines())
+
+        expected_error_msg = "\n".join(generated_diff)
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertSchemaEqual(s1, s2)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_SCHEMA",
+            message_parameters={"error_msg": expected_error_msg},
+        )
+
+    def test_schema_more_nested_struct_unequal(self):
+        s1 = StructType(
+            [
+                StructField(
+                    "name",
+                    StructType(
+                        [
+                            StructField("firstname", StringType(), True),
+                            StructField("middlename", StringType(), True),
+                            StructField("lastname", StringType(), True),
+                        ]
+                    ),
+                ),
+            ]
+        )
+
+        s2 = StructType(
+            [
+                StructField(
+                    "name",
+                    StructType(
+                        [
+                            StructField("firstname", StringType(), True),
+                            StructField("middlename", BooleanType(), True),
+                            StructField("lastname", StringType(), True),
+                        ]
+                    ),
+                ),
+            ]
+        )
+
+        generated_diff = difflib.ndiff(str(s1).splitlines(), str(s2).splitlines())
+
+        expected_error_msg = "\n".join(generated_diff)
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertSchemaEqual(s1, s2)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_SCHEMA",
+            message_parameters={"error_msg": expected_error_msg},
+        )
+
+    def test_schema_unsupported_type(self):
+        s1 = "names: int"
+        s2 = "names: int"
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertSchemaEqual(s1, s2)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="UNSUPPORTED_DATA_TYPE",
+            message_parameters={"data_type": type(s1)},
         )
 
     def test_spark_sql(self):
@@ -1056,7 +1174,7 @@ class UtilsTestsMixin:
         percent_diff = (2 / 2) * 100
         expected_error_message += "( %.5f %% )" % percent_diff
         diff_msg = (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[0])
             + "\n\n"
@@ -1068,7 +1186,7 @@ class UtilsTestsMixin:
             + "\n\n"
         )
         diff_msg += (
-            "[df]"
+            "[actual]"
             + "\n"
             + str(df1.collect()[1])
             + "\n\n"
diff --git a/python/pyspark/testing/__init__.py b/python/pyspark/testing/__init__.py
index 1bf70befc42..88853e925f8 100644
--- a/python/pyspark/testing/__init__.py
+++ b/python/pyspark/testing/__init__.py
@@ -14,6 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from pyspark.testing.utils import assertDataFrameEqual
+from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual
 
-__all__ = ["assertDataFrameEqual"]
+__all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 476470d7490..21c7b7e4dcd 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -20,6 +20,7 @@ import os
 import struct
 import sys
 import unittest
+import difflib
 from time import time, sleep
 from typing import (
     Any,
@@ -36,7 +37,7 @@ from pyspark.errors import PySparkAssertionError, PySparkException
 from pyspark.find_spark_home import _find_spark_home
 from pyspark.sql.dataframe import DataFrame as DataFrame
 from pyspark.sql import Row
-from pyspark.sql.types import StructType, AtomicType
+from pyspark.sql.types import StructType, AtomicType, StructField
 
 have_scipy = False
 have_numpy = False
@@ -55,7 +56,7 @@ except ImportError:
     # No NumPy, but that's okay, we'll skip those tests
     pass
 
-__all__ = ["assertDataFrameEqual"]
+__all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
 
 SPARK_HOME = _find_spark_home()
 
@@ -222,22 +223,112 @@ class PySparkErrorTestUtils:
         )
 
 
+def assertSchemaEqual(actual: StructType, expected: StructType):
+    r"""
+    A util function to assert equality between DataFrame schemas `actual` and `expected`.
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    actual : StructType
+        The DataFrame schema that is being compared or tested.
+    expected : StructType
+        The expected schema, for comparison with the actual schema.
+
+    Notes
+    -----
+    When assertSchemaEqual fails, the error message uses the Python `difflib` library to display
+    a diff log of the `actual` and `expected` schemas.
+
+    Examples
+    --------
+    >>> from pyspark.sql.types import StructType, StructField, ArrayType, IntegerType, DoubleType
+    >>> s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
+    >>> s2 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
+    >>> assertSchemaEqual(s1, s2)  # pass, schemas are identical
+    >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "number"])
+    >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 5000)], schema=["id", "amount"])
+    >>> assertSchemaEqual(df1.schema, df2.schema)  # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+    ...
+    PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
+    --- actual
+    +++ expected
+    - StructType([StructField('id', LongType(), True), StructField('number', LongType(), True)])
+    ?                                ^^                                 ^^^^^
+    + StructType([StructField('id', StringType(), True), StructField('amount', LongType(), True)])
+    ?                                ^^^^ ++++                         ^
+    """
+    if not isinstance(actual, StructType):
+        raise PySparkAssertionError(
+            error_class="UNSUPPORTED_DATA_TYPE",
+            message_parameters={"data_type": type(actual)},
+        )
+    if not isinstance(expected, StructType):
+        raise PySparkAssertionError(
+            error_class="UNSUPPORTED_DATA_TYPE",
+            message_parameters={"data_type": type(expected)},
+        )
+
+    def compare_schemas_ignore_nullable(s1: StructType, s2: StructType):
+        if len(s1) != len(s2):
+            return False
+        zipped = zip_longest(s1, s2)
+        for sf1, sf2 in zipped:
+            if not compare_structfields_ignore_nullable(sf1, sf2):
+                return False
+        return True
+
+    def compare_structfields_ignore_nullable(actualSF: StructField, expectedSF: StructField):
+        if actualSF is None and expectedSF is None:
+            return True
+        elif actualSF is None or expectedSF is None:
+            return False
+        if actualSF.name != expectedSF.name:
+            return False
+        else:
+            return compare_datatypes_ignore_nullable(actualSF.dataType, expectedSF.dataType)
+
+    def compare_datatypes_ignore_nullable(dt1: Any, dt2: Any):
+        # checks datatype equality, using recursion to ignore nullable
+        if dt1.typeName() == dt2.typeName():
+            if dt1.typeName() == "array":
+                return compare_datatypes_ignore_nullable(dt1.elementType, dt2.elementType)
+            elif dt1.typeName() == "struct":
+                return compare_schemas_ignore_nullable(dt1, dt2)
+            else:
+                return True
+        else:
+            return False
+
+    if not compare_schemas_ignore_nullable(actual, expected):
+        generated_diff = difflib.ndiff(str(actual).splitlines(), str(expected).splitlines())
+
+        error_msg = "\n".join(generated_diff)
+
+        raise PySparkAssertionError(
+            error_class="DIFFERENT_SCHEMA",
+            message_parameters={"error_msg": error_msg},
+        )
+
+
 def assertDataFrameEqual(
-    df: DataFrame,
+    actual: DataFrame,
     expected: Union[DataFrame, List[Row]],
     checkRowOrder: bool = False,
     rtol: float = 1e-5,
     atol: float = 1e-8,
 ):
-    """
-    A util function to assert equality between `df` (DataFrame) and `expected`
+    r"""
+    A util function to assert equality between `actual` (DataFrame) and `expected`
     (either DataFrame or list of Rows), with optional parameter `checkRowOrder`.
 
     .. versionadded:: 3.5.0
 
     Parameters
     ----------
-    df : DataFrame
+    actual : DataFrame
         The DataFrame that is being compared or tested.
     expected : DataFrame or list of Rows
         The expected result of the operation, for comparison with the actual result.
@@ -247,10 +338,10 @@ def assertDataFrameEqual(
         If set to `True`, the order of rows is important and will be checked during comparison.
         (See Notes)
     rtol : float, optional
-        The relative tolerance, used in asserting approximate equality for float values in df
+        The relative tolerance, used in asserting approximate equality for float values in actual
         and expected. Set to 1e-5 by default. (See Notes)
     atol : float, optional
-        The absolute tolerance, used in asserting approximate equality for float values in df
+        The absolute tolerance, used in asserting approximate equality for float values in actual
         and expected. Set to 1e-8 by default. (See Notes)
 
     Notes
@@ -267,49 +358,40 @@ def assertDataFrameEqual(
     --------
     >>> df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
     >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
-    >>> assertDataFrameEqual(df1, df2)
-
-    Pass, DataFrames are identical
-
+    >>> assertDataFrameEqual(df1, df2)  # pass, DataFrames are identical
     >>> df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], schema=["id", "amount"])
     >>> df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], schema=["id", "amount"])
-    >>> assertDataFrameEqual(df1, df2, rtol=1e-1)
-
-    Pass, DataFrames are approx equal by rtol
-
-    >>> df1 = spark.createDataFrame(data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)],
-    ...                             schema=["id", "amount"])
-    >>> df2 = spark.createDataFrame(data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)],
-    ...                             schema=["id", "amount"])
+    >>> assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are approx equal by rtol
+    >>> df1 = spark.createDataFrame(
+    ...     data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"])
+    >>> df2 = spark.createDataFrame(
+    ...     data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], schema=["id", "amount"])
     >>> assertDataFrameEqual(df1, df2)  # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
     ...
     PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.667 % )
-    [df]
+    [actual]
     Row(id='1', amount=1000.0)
-
     [expected]
     Row(id='1', amount=1001.0)
-
-    [df]
+    [actual]
     Row(id='3', amount=2000.0)
-
     [expected]
     Row(id='3', amount=2003.0)
     """
-    if df is None and expected is None:
+    if actual is None and expected is None:
         return True
-    elif df is None or expected is None:
+    elif actual is None or expected is None:
         return False
 
     try:
         # If Spark Connect dependencies are available, allow Spark Connect DataFrame
         from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
 
-        if not isinstance(df, DataFrame) and not isinstance(df, ConnectDataFrame):
+        if not isinstance(actual, DataFrame) and not isinstance(actual, ConnectDataFrame):
             raise PySparkAssertionError(
                 error_class="UNSUPPORTED_DATA_TYPE",
-                message_parameters={"data_type": type(df)},
+                message_parameters={"data_type": type(actual)},
             )
         elif (
             not isinstance(expected, DataFrame)
@@ -321,10 +403,10 @@ def assertDataFrameEqual(
                 message_parameters={"data_type": type(expected)},
             )
     except Exception:
-        if not isinstance(df, DataFrame):
+        if not isinstance(actual, DataFrame):
             raise PySparkAssertionError(
                 error_class="UNSUPPORTED_DATA_TYPE",
-                message_parameters={"data_type": type(df)},
+                message_parameters={"data_type": type(actual)},
             )
         elif not isinstance(expected, DataFrame) and not isinstance(expected, List):
             raise PySparkAssertionError(
@@ -333,8 +415,8 @@ def assertDataFrameEqual(
             )
 
     # special cases: empty datasets, datasets with 0 columns
-    if (df.first() is None and expected.first() is None) or (
-        len(df.columns) == 0 and len(expected.columns) == 0
+    if (actual.first() is None and expected.first() is None) or (
+        len(actual.columns) == 0 and len(expected.columns) == 0
     ):
         return True
 
@@ -367,16 +449,6 @@ def assertDataFrameEqual(
 
         return compare_vals(r1, r2)
 
-    def assert_schema_equal(
-        df_schema: StructType,
-        expected_schema: StructType,
-    ):
-        if df_schema != expected_schema:
-            raise PySparkAssertionError(
-                error_class="DIFFERENT_SCHEMA",
-                message_parameters={"df_schema": df_schema, "expected_schema": expected_schema},
-            )
-
     def assert_rows_equal(rows1: List[Row], rows2: List[Row]):
         zipped = list(zip_longest(rows1, rows2))
         rows_equal = True
@@ -389,7 +461,7 @@ def assertDataFrameEqual(
                 rows_equal = False
                 diff_rows_cnt += 1
                 diff_msg += (
-                    "[df]" + "\n" + str(r1) + "\n\n" + "[expected]" + "\n" + str(r2) + "\n\n"
+                    "[actual]" + "\n" + str(r1) + "\n\n" + "[expected]" + "\n" + str(r2) + "\n\n"
                 )
                 diff_msg += "********************" + "\n\n"
 
@@ -402,15 +474,15 @@ def assertDataFrameEqual(
             message_parameters={"error_msg": error_msg},
         )
 
-    # convert df and expected to list
+    # convert actual and expected to list
     if not isinstance(expected, List):
         # only compare schema if expected is not a List
-        assert_schema_equal(df.schema, expected.schema)
+        assertSchemaEqual(actual.schema, expected.schema)
         expected_list = expected.collect()
     else:
         expected_list = expected
 
-    df_list = df.collect()
+    df_list = actual.collect()
 
     if not checkRowOrder:
         # rename duplicate columns for sorting

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
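For reference, a minimal usage sketch of the API this patch makes public, based on the doctests and tests in the diff above (it assumes a PySpark build that includes this change and an active SparkSession bound to `spark`; the schemas and values are illustrative):

```python
from pyspark.sql.types import ArrayType, DoubleType, StructField, StructType
from pyspark.testing import assertDataFrameEqual, assertSchemaEqual

# Nullability differs on both the field and the array element type, but
# assertSchemaEqual ignores nullable flags (including nested ones), so this passes.
s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
s2 = StructType([StructField("names", ArrayType(DoubleType(), False), False)])
assertSchemaEqual(s1, s2)

# A column name or data type mismatch raises PySparkAssertionError with the
# DIFFERENT_SCHEMA error class and a difflib-generated diff of the two schemas.
df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "number"])
df2 = spark.createDataFrame(data=[("1", 1000), ("2", 5000)], schema=["id", "amount"])
assertSchemaEqual(df1.schema, df2.schema)  # raises PySparkAssertionError

# assertDataFrameEqual now compares schemas via assertSchemaEqual before comparing rows.
assertDataFrameEqual(df1, df1)
```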