This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7c1ad5bb60c [SPARK-44548][PYTHON] Add support for pandas-on-Spark DataFrame assertDataFrameEqual
7c1ad5bb60c is described below

commit 7c1ad5bb60c88c1c659c131e5119aab1f8212af5
Author: Amanda Liu <amanda....@databricks.com>
AuthorDate: Fri Jul 28 14:41:43 2023 +0900

    [SPARK-44548][PYTHON] Add support for pandas-on-Spark DataFrame assertDataFrameEqual
    
    ### What changes were proposed in this pull request?
    This PR adds support for pandas-on-Spark DataFrames to the testing util `assertDataFrameEqual`.
    
    ### Why are the changes needed?
    The change allows users to call the same PySpark API for both Spark and pandas-on-Spark DataFrames.
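
    For illustration, a minimal usage sketch (not part of this patch; it assumes
    a local PySpark installation, and pandas-on-Spark will start a SparkSession
    on first use):

        import pyspark.pandas as ps
        from pyspark.testing.utils import assertDataFrameEqual

        psdf1 = ps.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
        psdf2 = ps.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})

        # The same util that accepts Spark DataFrames now also accepts
        # pandas-on-Spark DataFrames; float data is compared approximately
        # (rtol=1e-5, atol=1e-8, following the pandas convention).
        assertDataFrameEqual(psdf1, psdf2)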
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the PR affects the user-facing testing util `assertDataFrameEqual`.
    
    ### How was this patch tested?
    Added tests to `python/pyspark/sql/tests/test_utils.py`, `python/pyspark/sql/tests/connect/test_utils.py`, and the existing pandas util tests.
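
    A sketch of the kind of test added (a hypothetical standalone version; the
    real tests live in the mixin classes shown in the diff below):

        import unittest
        import pyspark.pandas as ps
        from pyspark.testing.pandasutils import assertPandasOnSparkEqual

        class AssertPandasOnSparkEqualTest(unittest.TestCase):
            def test_df_equal_ignoring_row_order(self):
                psdf1 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
                psdf2 = ps.DataFrame({"a": [2, 1, 3], "b": [5, 4, 6]})
                # checkRowOrder=False is the default, so row order is ignored.
                assertPandasOnSparkEqual(psdf1, psdf2)

        if __name__ == "__main__":
            unittest.main()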
    
    Closes #42158 from asl3/pandas-or-pyspark-df.
    
    Authored-by: Amanda Liu <amanda....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 dev/sparktestsupport/modules.py                  |   1 +
 python/docs/source/reference/pyspark.testing.rst |   1 +
 python/pyspark/errors/error_classes.py           |  42 ++
 python/pyspark/pandas/tests/test_utils.py        | 171 +++++++-
 python/pyspark/sql/tests/test_utils.py           |  60 ++-
 python/pyspark/testing/__init__.py               |   4 +-
 python/pyspark/testing/pandasutils.py            | 506 +++++++++++++++++------
 python/pyspark/testing/utils.py                  |  82 +++-
 8 files changed, 689 insertions(+), 178 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 3cfd82c3d31..79c3f8f26b1 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -514,6 +514,7 @@ pyspark_testing = Module(
     python_test_goals=[
         # doctests
         "pyspark.testing.utils",
+        "pyspark.testing.pandasutils",
     ],
 )
 
diff --git a/python/docs/source/reference/pyspark.testing.rst b/python/docs/source/reference/pyspark.testing.rst
index 7a6b6cc0d70..96b0c72a7bb 100644
--- a/python/docs/source/reference/pyspark.testing.rst
+++ b/python/docs/source/reference/pyspark.testing.rst
@@ -26,4 +26,5 @@ Testing
     :toctree: api/
 
     assertDataFrameEqual
+    assertPandasOnSparkEqual
     assertSchemaEqual
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index f4b643f1d32..5ecba294d0c 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -164,6 +164,42 @@ ERROR_CLASSES_JSON = """
       "Remote client cannot create a SparkContext. Create SparkSession 
instead."
     ]
   },
+  "DIFFERENT_PANDAS_DATAFRAME" : {
+    "message" : [
+      "DataFrames are not almost equal:",
+      "Left: <left>",
+      "<left_dtype>",
+      "Right: <right>",
+      "<right_dtype>"
+    ]
+  },
+  "DIFFERENT_PANDAS_INDEX" : {
+    "message" : [
+      "Indices are not almost equal:",
+      "Left: <left>",
+      "<left_dtype>",
+      "Right: <right>",
+      "<right_dtype>"
+    ]
+  },
+  "DIFFERENT_PANDAS_MULTIINDEX" : {
+    "message" : [
+      "MultiIndices are not almost equal:",
+      "Left: <left>",
+      "<left_dtype>",
+      "Right: <right>",
+      "<right_dtype>"
+    ]
+  },
+  "DIFFERENT_PANDAS_SERIES" : {
+    "message" : [
+      "Series are not almost equal:",
+      "Left: <left>",
+      "<left_dtype>",
+      "Right: <right>",
+      "<right_dtype>"
+    ]
+  },
   "DIFFERENT_ROWS" : {
     "message" : [
       "<error_msg>"
@@ -233,6 +269,12 @@ ERROR_CLASSES_JSON = """
       "NumPy array input should be of <dimensions> dimensions."
     ]
   },
+  "INVALID_PANDAS_ON_SPARK_COMPARISON" : {
+    "message" : [
+      "Expected two pandas-on-Spark DataFrames",
+      "but got actual: <actual_type> and expected: <expected_type>"
+    ]
+  },
   "INVALID_PANDAS_UDF" : {
     "message" : [
       "Invalid function: <detail>"
diff --git a/python/pyspark/pandas/tests/test_utils.py b/python/pyspark/pandas/tests/test_utils.py
index 35ebcf17a0f..de7b0449dae 100644
--- a/python/pyspark/pandas/tests/test_utils.py
+++ b/python/pyspark/pandas/tests/test_utils.py
@@ -16,6 +16,7 @@
 #
 
 import pandas as pd
+from typing import Union
 
 from pyspark.pandas.indexes.base import Index
 from pyspark.pandas.utils import (
@@ -25,8 +26,14 @@ from pyspark.pandas.utils import (
     validate_index_loc,
     validate_mode,
 )
-from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.pandasutils import (
+    PandasOnSparkTestCase,
+    assertPandasOnSparkEqual,
+    _assert_pandas_equal,
+    _assert_pandas_almost_equal,
+)
 from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.errors import PySparkAssertionError
 
 some_global_variable = 0
 
@@ -105,6 +112,168 @@ class UtilsTestsMixin:
         with self.assertRaisesRegex(IndexError, err_msg):
             validate_index_loc(psidx, -4)
 
+    def test_assert_df_assertPandasOnSparkEqual(self):
+        import pyspark.pandas as ps
+
+        psdf1 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
+        psdf2 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
+
+        assertPandasOnSparkEqual(psdf1, psdf2, checkRowOrder=False)
+        assertPandasOnSparkEqual(psdf1, psdf2, checkRowOrder=True)
+
+    def test_assertPandasOnSparkEqual_ignoreOrder_default(self):
+        import pyspark.pandas as ps
+
+        psdf1 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
+        psdf2 = ps.DataFrame({"a": [2, 1, 3], "b": [5, 4, 6], "c": [8, 7, 9]})
+
+        assertPandasOnSparkEqual(psdf1, psdf2)
+
+    def test_assert_series_assertPandasOnSparkEqual(self):
+        import pyspark.pandas as ps
+
+        s1 = ps.Series([212.32, 100.0001])
+        s2 = ps.Series([212.32, 100.0001])
+
+        assertPandasOnSparkEqual(s1, s2, checkExact=False)
+
+    def test_assert_index_assertPandasOnSparkEqual(self):
+        import pyspark.pandas as ps
+
+        s1 = ps.Index([212.300001, 100.000])
+        s2 = ps.Index([212.3, 100.0001])
+
+        assertPandasOnSparkEqual(s1, s2, almost=True)
+
+    def test_assert_error_assertPandasOnSparkEqual(self):
+        import pyspark.pandas as ps
+
+        list1 = [10, 20, 30]
+        list2 = [10, 20, 30]
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertPandasOnSparkEqual(list1, list2)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="INVALID_TYPE_DF_EQUALITY_ARG",
+            message_parameters={
+                "expected_type": f"{ps.DataFrame.__name__}, "
+                f"{ps.Series.__name__}, "
+                f"{ps.Index.__name__}",
+                "arg_name": "actual",
+                "actual_type": type(list1),
+            },
+        )
+
+    def test_assert_None_assertPandasOnSparkEqual(self):
+        psdf1 = None
+        psdf2 = None
+
+        assertPandasOnSparkEqual(psdf1, psdf2)
+
+    def test_assert_empty_assertPandasOnSparkEqual(self):
+        import pyspark.pandas as ps
+
+        psdf1 = ps.DataFrame()
+        psdf2 = ps.DataFrame()
+
+        assertPandasOnSparkEqual(psdf1, psdf2)
+
+    def test_dataframe_error_assert_pandas_equal(self):
+        pdf1 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[0, 1, 3])
+        pdf2 = pd.DataFrame({"a": [1, 3, 3], "b": [4, 5, 6]}, index=[0, 1, 3])
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            _assert_pandas_equal(pdf1, pdf2, True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_PANDAS_DATAFRAME",
+            message_parameters={
+                "left": pdf1.to_string(),
+                "left_dtype": str(pdf1.dtypes),
+                "right": pdf2.to_string(),
+                "right_dtype": str(pdf2.dtypes),
+            },
+        )
+
+    def test_series_error_assert_pandas_equal(self):
+        series1 = pd.Series([1, 2, 3])
+        series2 = pd.Series([4, 5, 6])
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            _assert_pandas_equal(series1, series2, True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_PANDAS_SERIES",
+            message_parameters={
+                "left": series1,
+                "left_dtype": series1.dtype,
+                "right": series2,
+                "right_dtype": series2.dtype,
+            },
+        )
+
+    def test_index_error_assert_pandas_equal(self):
+        index1 = pd.Index([1, 2, 3])
+        index2 = pd.Index([4, 5, 6])
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            _assert_pandas_equal(index1, index2, True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_PANDAS_INDEX",
+            message_parameters={
+                "left": index1,
+                "left_dtype": index1.dtype,
+                "right": index2,
+                "right_dtype": index2.dtype,
+            },
+        )
+
+    def test_multiindex_error_assert_pandas_almost_equal(self):
+        pdf1 = pd.DataFrame({"a": [1, 2], "b": [4, 10]}, index=[0, 1])
+        pdf2 = pd.DataFrame({"a": [1, 5, 3], "b": [1, 5, 6]}, index=[0, 1, 3])
+        multiindex1 = pd.MultiIndex.from_frame(pdf1)
+        multiindex2 = pd.MultiIndex.from_frame(pdf2)
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            _assert_pandas_almost_equal(multiindex1, multiindex2)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_PANDAS_MULTIINDEX",
+            message_parameters={
+                "left": multiindex1,
+                "left_dtype": multiindex1.dtype,
+                "right": multiindex2,
+                "right_dtype": multiindex2.dtype,
+            },
+        )
+
+    def test_dataframe_error_assert_pandas_on_spark_almost_equal(self):
+        import pyspark.pandas as ps
+
+        psdf1 = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})
+        psdf2 = ps.DataFrame({"a": [1, 2], "b": [4, 5], "c": [7, 8]})
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertPandasOnSparkEqual(psdf1, psdf2, almost=True)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_PANDAS_DATAFRAME",
+            message_parameters={
+                "left": psdf1.to_string(),
+                "left_dtype": str(psdf1.dtypes),
+                "right": psdf2.to_string(),
+                "right_dtype": str(psdf2.dtypes),
+            },
+        )
+
 
 class TestClassForLazyProp:
     def __init__(self):
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index a1cefe7c840..500c314e449 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -623,22 +623,47 @@ class UtilsTestsMixin:
         assertDataFrameEqual(df1, df2, checkRowOrder=False)
         assertDataFrameEqual(df1, df2, checkRowOrder=True)
 
-    def test_assert_error_pandas_df(self):
-        import pandas as pd
+    def test_assert_equal_exact_pandas_df(self):
+        import pyspark.pandas as ps
 
-        df1 = pd.DataFrame(data=[10, 20, 30], columns=["Numbers"])
-        df2 = pd.DataFrame(data=[10, 20, 30], columns=["Numbers"])
+        df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"])
+        df2 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"])
+
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+    def test_assert_equal_arbitrary_order_pandas_df(self):
+        import pyspark.pandas as ps
+
+        df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"])
+        df2 = ps.DataFrame(data=[30, 20, 10], columns=["Numbers"])
+
+        assertDataFrameEqual(df1, df2)
+
+    def test_assert_equal_approx_pandas_df(self):
+        import pyspark.pandas as ps
+
+        df1 = ps.DataFrame(data=[10.0001, 20.32, 30.1], columns=["Numbers"])
+        df2 = ps.DataFrame(data=[10.0, 20.32, 30.1], columns=["Numbers"])
+
+        assertDataFrameEqual(df1, df2, checkRowOrder=False)
+        assertDataFrameEqual(df1, df2, checkRowOrder=True)
+
+    def test_assert_error_pandas_pyspark_df(self):
+        import pyspark.pandas as ps
+
+        df1 = ps.DataFrame(data=[10, 20, 30], columns=["Numbers"])
+        df2 = self.spark.createDataFrame([(10,), (11,), (13,)], ["Numbers"])
 
         with self.assertRaises(PySparkAssertionError) as pe:
-            assertDataFrameEqual(df1, df2)
+            assertDataFrameEqual(df1, df2, checkRowOrder=False)
 
         self.check_error(
             exception=pe.exception,
-            error_class="INVALID_TYPE_DF_EQUALITY_ARG",
+            error_class="INVALID_PANDAS_ON_SPARK_COMPARISON",
             message_parameters={
-                "expected_type": DataFrame,
-                "arg_name": "df",
-                "actual_type": pd.DataFrame,
+                "actual_type": type(df1),
+                "expected_type": type(df2),
             },
         )
 
@@ -647,15 +672,16 @@ class UtilsTestsMixin:
 
         self.check_error(
             exception=pe.exception,
-            error_class="INVALID_TYPE_DF_EQUALITY_ARG",
+            error_class="INVALID_PANDAS_ON_SPARK_COMPARISON",
             message_parameters={
-                "expected_type": DataFrame,
-                "arg_name": "df",
-                "actual_type": pd.DataFrame,
+                "actual_type": type(df1),
+                "expected_type": type(df2),
             },
         )
 
     def test_assert_error_non_pyspark_df(self):
+        import pyspark.pandas as ps
+
         dict1 = {"a": 1, "b": 2}
         dict2 = {"a": 1, "b": 2}
 
@@ -666,8 +692,8 @@ class UtilsTestsMixin:
             exception=pe.exception,
             error_class="INVALID_TYPE_DF_EQUALITY_ARG",
             message_parameters={
-                "expected_type": DataFrame,
-                "arg_name": "df",
+                "expected_type": f"{DataFrame.__name__}, 
{ps.DataFrame.__name__}",
+                "arg_name": "actual",
                 "actual_type": type(dict1),
             },
         )
@@ -679,8 +705,8 @@ class UtilsTestsMixin:
             exception=pe.exception,
             error_class="INVALID_TYPE_DF_EQUALITY_ARG",
             message_parameters={
-                "expected_type": DataFrame,
-                "arg_name": "df",
+                "expected_type": f"{DataFrame.__name__}, 
{ps.DataFrame.__name__}",
+                "arg_name": "actual",
                 "actual_type": type(dict1),
             },
         )
diff --git a/python/pyspark/testing/__init__.py b/python/pyspark/testing/__init__.py
index 88853e925f8..57c206629a8 100644
--- a/python/pyspark/testing/__init__.py
+++ b/python/pyspark/testing/__init__.py
@@ -16,4 +16,6 @@
 #
 from pyspark.testing.utils import assertDataFrameEqual, assertSchemaEqual
 
-__all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
+from pyspark.testing.pandasutils import assertPandasOnSparkEqual
+
+__all__ = ["assertDataFrameEqual", "assertSchemaEqual", 
"assertPandasOnSparkEqual"]
diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py
index 202603ca5c0..4ffe8858396 100644
--- a/python/pyspark/testing/pandasutils.py
+++ b/python/pyspark/testing/pandasutils.py
@@ -19,15 +19,19 @@ import functools
 import shutil
 import tempfile
 import warnings
+import pandas as pd
 from contextlib import contextmanager
 from distutils.version import LooseVersion
+import decimal
+from typing import Union
 
-from pyspark import pandas as ps
+import pyspark.pandas as ps
 from pyspark.pandas.frame import DataFrame
 from pyspark.pandas.indexes import Index
 from pyspark.pandas.series import Series
 from pyspark.pandas.utils import SPARK_CONF_ARROW_ENABLED
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.errors import PySparkAssertionError
 
 tabulate_requirement_message = None
 try:
@@ -54,153 +58,381 @@ except ImportError as e:
 have_plotly = plotly_requirement_message is None
 
 
-class PandasOnSparkTestUtils:
-    def convert_str_to_lambda(self, func):
-        """
-        This function coverts `func` str to lambda call
-        """
-        return lambda x: getattr(x, func)()
+__all__ = ["assertPandasOnSparkEqual"]
 
-    def assertPandasEqual(self, left, right, check_exact=True):
-        import pandas as pd
-        from pandas.core.dtypes.common import is_numeric_dtype
-        from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal
-
-        if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
-            try:
-                if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
-                    kwargs = dict(check_freq=False)
-                else:
-                    kwargs = dict()
-
-                if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
-                    # Due to https://github.com/pandas-dev/pandas/issues/35446
-                    check_exact = (
-                        check_exact
-                        and all([is_numeric_dtype(dtype) for dtype in left.dtypes])
-                        and all([is_numeric_dtype(dtype) for dtype in right.dtypes])
-                    )
 
-                assert_frame_equal(
-                    left,
-                    right,
-                    check_index_type=("equiv" if len(left.index) > 0 else 
False),
-                    check_column_type=("equiv" if len(left.columns) > 0 else 
False),
-                    check_exact=check_exact,
-                    **kwargs,
+def _assert_pandas_equal(
+    left: Union[pd.DataFrame, pd.Series, pd.Index],
+    right: Union[pd.DataFrame, pd.Series, pd.Index],
+    checkExact: bool,
+):
+    from pandas.core.dtypes.common import is_numeric_dtype
+    from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal
+
+    if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
+        try:
+            if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
+                kwargs = dict(check_freq=False)
+            else:
+                kwargs = dict()
+
+            if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
+                # Due to https://github.com/pandas-dev/pandas/issues/35446
+                checkExact = (
+                    checkExact
+                    and all([is_numeric_dtype(dtype) for dtype in left.dtypes])
+                    and all([is_numeric_dtype(dtype) for dtype in right.dtypes])
                 )
-            except AssertionError as e:
-                msg = (
-                    str(e)
-                    + "\n\nLeft:\n%s\n%s" % (left, left.dtypes)
-                    + "\n\nRight:\n%s\n%s" % (right, right.dtypes)
+
+            assert_frame_equal(
+                left,
+                right,
+                check_index_type=("equiv" if len(left.index) > 0 else False),
+                check_column_type=("equiv" if len(left.columns) > 0 else 
False),
+                check_exact=checkExact,
+                **kwargs,
+            )
+        except AssertionError:
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_DATAFRAME",
+                message_parameters={
+                    "left": left.to_string(),
+                    "left_dtype": str(left.dtypes),
+                    "right": right.to_string(),
+                    "right_dtype": str(right.dtypes),
+                },
+            )
+    elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
+        try:
+            if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
+                kwargs = dict(check_freq=False)
+            else:
+                kwargs = dict()
+            if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
+                # Due to https://github.com/pandas-dev/pandas/issues/35446
+                checkExact = (
+                    checkExact and is_numeric_dtype(left.dtype) and is_numeric_dtype(right.dtype)
                 )
-                raise AssertionError(msg) from e
-        elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
-            try:
-                if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
-                    kwargs = dict(check_freq=False)
-                else:
-                    kwargs = dict()
-                if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
-                    # Due to https://github.com/pandas-dev/pandas/issues/35446
-                    check_exact = (
-                        check_exact
-                        and is_numeric_dtype(left.dtype)
-                        and is_numeric_dtype(right.dtype)
-                    )
-                assert_series_equal(
-                    left,
-                    right,
-                    check_index_type=("equiv" if len(left.index) > 0 else 
False),
-                    check_exact=check_exact,
-                    **kwargs,
+            assert_series_equal(
+                left,
+                right,
+                check_index_type=("equiv" if len(left.index) > 0 else False),
+                check_exact=checkExact,
+                **kwargs,
+            )
+        except AssertionError:
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_SERIES",
+                message_parameters={
+                    "left": left,
+                    "left_dtype": left.dtype,
+                    "right": right,
+                    "right_dtype": right.dtype,
+                },
+            )
+    elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
+        try:
+            if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
+                # Due to https://github.com/pandas-dev/pandas/issues/35446
+                checkExact = (
+                    checkExact and is_numeric_dtype(left.dtype) and is_numeric_dtype(right.dtype)
                 )
-            except AssertionError as e:
-                msg = (
-                    str(e)
-                    + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
-                    + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+            assert_index_equal(left, right, check_exact=checkExact)
+        except AssertionError:
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_INDEX",
+                message_parameters={
+                    "left": left,
+                    "left_dtype": left.dtype,
+                    "right": right,
+                    "right_dtype": right.dtype,
+                },
+            )
+    else:
+        raise ValueError("Unexpected values: (%s, %s)" % (left, right))
+
+
+def _assert_pandas_almost_equal(
+    left: Union[pd.DataFrame, pd.Series, pd.Index], right: Union[pd.DataFrame, pd.Series, pd.Index]
+):
+    """
+    This function checks whether the given pandas objects are approximately equal,
+    meaning that the conditions below hold:
+      - Both objects are nullable
+      - Floats are compared within relative and absolute tolerance
+        (rtol=1e-5, atol=1e-8) after dropping missing values (NaN, NaT, None)
+    """
+    # following pandas convention, rtol=1e-5 and atol=1e-8
+    rtol = 1e-5
+    atol = 1e-8
+
+    if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
+        if left.shape != right.shape:
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_DATAFRAME",
+                message_parameters={
+                    "left": left.to_string(),
+                    "left_dtype": str(left.dtypes),
+                    "right": right.to_string(),
+                    "right_dtype": str(right.dtypes),
+                },
+            )
+        for lcol, rcol in zip(left.columns, right.columns):
+            if lcol != rcol:
+                raise PySparkAssertionError(
+                    error_class="DIFFERENT_PANDAS_DATAFRAME",
+                    message_parameters={
+                        "left": left.to_string(),
+                        "left_dtype": str(left.dtypes),
+                        "right": right.to_string(),
+                        "right_dtype": str(right.dtypes),
+                    },
                 )
-                raise AssertionError(msg) from e
-        elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
-            try:
-                if LooseVersion(pd.__version__) < LooseVersion("1.1.1"):
-                    # Due to https://github.com/pandas-dev/pandas/issues/35446
-                    check_exact = (
-                        check_exact
-                        and is_numeric_dtype(left.dtype)
-                        and is_numeric_dtype(right.dtype)
+            for lnull, rnull in zip(left[lcol].isnull(), right[rcol].isnull()):
+                if lnull != rnull:
+                    raise PySparkAssertionError(
+                        error_class="DIFFERENT_PANDAS_DATAFRAME",
+                        message_parameters={
+                            "left": left.to_string(),
+                            "left_dtype": str(left.dtypes),
+                            "right": right.to_string(),
+                            "right_dtype": str(right.dtypes),
+                        },
                     )
-                assert_index_equal(left, right, check_exact=check_exact)
-            except AssertionError as e:
-                msg = (
-                    str(e)
-                    + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
-                    + "\n\nRight:\n%s\n%s" % (right, right.dtype)
+            for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()):
+                if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and (
+                    isinstance(rval, float) or isinstance(rval, decimal.Decimal)
+                ):
+                    if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))):
+                        raise PySparkAssertionError(
+                            error_class="DIFFERENT_PANDAS_DATAFRAME",
+                            message_parameters={
+                                "left": left.to_string(),
+                                "left_dtype": str(left.dtypes),
+                                "right": right.to_string(),
+                                "right_dtype": str(right.dtypes),
+                            },
+                        )
+        if left.columns.names != right.columns.names:
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_DATAFRAME",
+                message_parameters={
+                    "left": left.to_string(),
+                    "left_dtype": str(left.dtypes),
+                    "right": right.to_string(),
+                    "right_dtype": str(right.dtypes),
+                },
+            )
+    elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
+        if left.name != right.name or len(left) != len(right):
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_SERIES",
+                message_parameters={
+                    "left": left,
+                    "left_dtype": left.dtype,
+                    "right": right,
+                    "right_dtype": right.dtype,
+                },
+            )
+        for lnull, rnull in zip(left.isnull(), right.isnull()):
+            if lnull != rnull:
+                raise PySparkAssertionError(
+                    error_class="DIFFERENT_PANDAS_SERIES",
+                    message_parameters={
+                        "left": left,
+                        "left_dtype": left.dtype,
+                        "right": right,
+                        "right_dtype": right.dtype,
+                    },
                 )
-                raise AssertionError(msg) from e
+        for lval, rval in zip(left.dropna(), right.dropna()):
+            if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and (
+                isinstance(rval, float) or isinstance(rval, decimal.Decimal)
+            ):
+                if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))):
+                    raise PySparkAssertionError(
+                        error_class="DIFFERENT_PANDAS_SERIES",
+                        message_parameters={
+                            "left": left,
+                            "left_dtype": left.dtype,
+                            "right": right,
+                            "right_dtype": right.dtype,
+                        },
+                    )
+    elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex):
+        if len(left) != len(right):
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_MULTIINDEX",
+                message_parameters={
+                    "left": left,
+                    "left_dtype": left.dtype,
+                    "right": right,
+                    "right_dtype": right.dtype,
+                },
+            )
+        for lval, rval in zip(left, right):
+            if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and (
+                isinstance(rval, float) or isinstance(rval, decimal.Decimal)
+            ):
+                if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))):
+                    raise PySparkAssertionError(
+                        error_class="DIFFERENT_PANDAS_MULTIINDEX",
+                        message_parameters={
+                            "left": left,
+                            "left_dtype": left.dtype,
+                            "right": right,
+                            "right_dtype": right.dtype,
+                        },
+                    )
+    elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
+        if len(left) != len(right):
+            raise PySparkAssertionError(
+                error_class="DIFFERENT_PANDAS_INDEX",
+                message_parameters={
+                    "left": left,
+                    "left_dtype": left.dtype,
+                    "right": right,
+                    "right_dtype": right.dtype,
+                },
+            )
+        for lnull, rnull in zip(left.isnull(), right.isnull()):
+            if lnull != rnull:
+                raise PySparkAssertionError(
+                    error_class="DIFFERENT_PANDAS_INDEX",
+                    message_parameters={
+                        "left": left,
+                        "left_dtype": left.dtype,
+                        "right": right,
+                        "right_dtype": right.dtype,
+                    },
+                )
+        for lval, rval in zip(left.dropna(), right.dropna()):
+            if (isinstance(lval, float) or isinstance(lval, decimal.Decimal)) and (
+                isinstance(rval, float) or isinstance(rval, decimal.Decimal)
+            ):
+                if abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval))):
+                    raise PySparkAssertionError(
+                        error_class="DIFFERENT_PANDAS_INDEX",
+                        message_parameters={
+                            "left": left,
+                            "left_dtype": left.dtype,
+                            "right": right,
+                            "right_dtype": right.dtype,
+                        },
+                    )
+    else:
+        raise ValueError("Unexpected values: (%s, %s)" % (left, right))
+
+
+def assertPandasOnSparkEqual(
+    actual: Union[DataFrame, Series, Index],
+    expected: Union[DataFrame, pd.DataFrame, Series, Index],
+    checkExact: bool = True,
+    almost: bool = False,
+    checkRowOrder: bool = False,
+):
+    r"""
+    A util function to assert equality between actual (pandas-on-Spark DataFrame) and expected
+    (pandas-on-Spark or pandas DataFrame).
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    actual: pyspark.pandas.frame.DataFrame
+        The DataFrame that is being compared or tested.
+    expected: pyspark.pandas.frame.DataFrame or pd.DataFrame
+        The expected DataFrame, for comparison with the actual result.
+    checkExact: bool, optional
+        A flag indicating whether to compare exact equality.
+        If set to 'True' (default), the data is compared exactly.
+        If set to 'False', the data is compared less precisely, following pandas assert_frame_equal
+        approximate comparison (see documentation for more details).
+    almost: bool, optional
+        A flag indicating whether to compare values approximately or exactly.
+        If set to 'True', the comparison is delegated to `_assert_pandas_almost_equal`,
+        which compares floats within relative and absolute tolerance (rtol=1e-5,
+        atol=1e-8).
+        If set to 'False' (default), the data is compared exactly with
+        `_assert_pandas_equal`.
+    checkRowOrder : bool, optional
+        A flag indicating whether the order of rows should be considered in the comparison.
+        If set to `False` (default), the row order is not taken into account.
+        If set to `True`, the order of rows is important and will be checked during comparison.
+        (See Notes)
+
+    Notes
+    -----
+    For `checkRowOrder`, note that pandas-on-Spark DataFrame ordering is non-deterministic, unless
+    explicitly sorted.
+
+    Examples
+    --------
+    >>> import pyspark.pandas as ps
+    >>> psdf1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
+    >>> psdf2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
+    >>> assertPandasOnSparkEqual(psdf1, psdf2)  # pass, ps.DataFrames are equal
+    >>> s1 = ps.Series([212.32, 100.0001])
+    >>> s2 = ps.Series([212.32, 100.0])
+    >>> assertPandasOnSparkEqual(s1, s2, checkExact=False)  # pass, ps.Series are approx equal
+    >>> s1 = ps.Index([212.300001, 100.000])
+    >>> s2 = ps.Index([212.3, 100.0001])
+    >>> assertPandasOnSparkEqual(s1, s2, almost=True)  # pass, ps.Index obj are almost equal
+    """
+    if actual is None and expected is None:
+        return True
+    elif actual is None or expected is None:
+        return False
+
+    if not isinstance(actual, (DataFrame, Series, Index)):
+        raise PySparkAssertionError(
+            error_class="INVALID_TYPE_DF_EQUALITY_ARG",
+            message_parameters={
+                "expected_type": f"{DataFrame.__name__}, {Series.__name__}, 
{Index.__name__}",
+                "arg_name": "actual",
+                "actual_type": type(actual),
+            },
+        )
+    elif not isinstance(expected, (DataFrame, pd.DataFrame, Series, Index)):
+        raise PySparkAssertionError(
+            error_class="INVALID_TYPE_DF_EQUALITY_ARG",
+            message_parameters={
+                "expected_type": f"{DataFrame.__name__}, "
+                f"{pd.DataFrame.__name__}, "
+                f"{Series.__name__}, "
+                f"{Index.__name__}",
+                "arg_name": "expected",
+                "actual_type": type(expected),
+            },
+        )
+    else:
+        actual = actual.to_pandas()
+        if not isinstance(expected, pd.DataFrame):
+            expected = expected.to_pandas()
+
+        if not checkRowOrder:
+            if isinstance(actual, pd.DataFrame) and len(actual.columns) > 0:
+                actual = actual.sort_values(by=actual.columns[0], ignore_index=True)
+            if isinstance(expected, pd.DataFrame) and len(expected.columns) > 0:
+                expected = expected.sort_values(by=expected.columns[0], ignore_index=True)
+
+        if almost:
+            _assert_pandas_almost_equal(actual, expected)
         else:
-            raise ValueError("Unexpected values: (%s, %s)" % (left, right))
+            _assert_pandas_equal(actual, expected, checkExact=checkExact)
 
-    def assertPandasAlmostEqual(self, left, right):
+
+class PandasOnSparkTestUtils:
+    def convert_str_to_lambda(self, func):
         """
-        This function checks if given pandas objects approximately same,
-        which means the conditions below:
-          - Both objects are nullable
-          - Compare floats rounding to the number of decimal places, 7 after
-            dropping missing values (NaN, NaT, None)
+        This function converts `func` str to lambda call
         """
-        import pandas as pd
+        return lambda x: getattr(x, func)()
 
-        if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
-            msg = (
-                "DataFrames are not almost equal: "
-                + "\n\nLeft:\n%s\n%s" % (left, left.dtypes)
-                + "\n\nRight:\n%s\n%s" % (right, right.dtypes)
-            )
-            self.assertEqual(left.shape, right.shape, msg=msg)
-            for lcol, rcol in zip(left.columns, right.columns):
-                self.assertEqual(lcol, rcol, msg=msg)
-                for lnull, rnull in zip(left[lcol].isnull(), right[rcol].isnull()):
-                    self.assertEqual(lnull, rnull, msg=msg)
-                for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()):
-                    self.assertAlmostEqual(lval, rval, msg=msg)
-            self.assertEqual(left.columns.names, right.columns.names, msg=msg)
-        elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
-            msg = (
-                "Series are not almost equal: "
-                + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
-                + "\n\nRight:\n%s\n%s" % (right, right.dtype)
-            )
-            self.assertEqual(left.name, right.name, msg=msg)
-            self.assertEqual(len(left), len(right), msg=msg)
-            for lnull, rnull in zip(left.isnull(), right.isnull()):
-                self.assertEqual(lnull, rnull, msg=msg)
-            for lval, rval in zip(left.dropna(), right.dropna()):
-                self.assertAlmostEqual(lval, rval, msg=msg)
-        elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex):
-            msg = (
-                "MultiIndices are not almost equal: "
-                + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
-                + "\n\nRight:\n%s\n%s" % (right, right.dtype)
-            )
-            self.assertEqual(len(left), len(right), msg=msg)
-            for lval, rval in zip(left, right):
-                self.assertAlmostEqual(lval, rval, msg=msg)
-        elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
-            msg = (
-                "Indices are not almost equal: "
-                + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
-                + "\n\nRight:\n%s\n%s" % (right, right.dtype)
-            )
-            self.assertEqual(len(left), len(right), msg=msg)
-            for lnull, rnull in zip(left.isnull(), right.isnull()):
-                self.assertEqual(lnull, rnull, msg=msg)
-            for lval, rval in zip(left.dropna(), right.dropna()):
-                self.assertAlmostEqual(lval, rval, msg=msg)
-        else:
-            raise ValueError("Unexpected values: (%s, %s)" % (left, right))
+    def assertPandasEqual(self, left, right, check_exact=True):
+        _assert_pandas_equal(left, right, check_exact)
+
+    def assertPandasAlmostEqual(self, left, right):
+        _assert_pandas_almost_equal(left, right)
 
     def assert_eq(self, left, right, check_exact=True, almost=False):
         """
@@ -220,9 +452,9 @@ class PandasOnSparkTestUtils:
         robj = self._to_pandas(right)
         if isinstance(lobj, (pd.DataFrame, pd.Series, pd.Index)):
             if almost:
-                self.assertPandasAlmostEqual(lobj, robj)
+                _assert_pandas_almost_equal(lobj, robj)
             else:
-                self.assertPandasEqual(lobj, robj, check_exact=check_exact)
+                _assert_pandas_equal(lobj, robj, checkExact=check_exact)
         elif is_list_like(lobj) and is_list_like(robj):
             self.assertTrue(len(left) == len(right))
             for litem, ritem in zip(left, right):
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index b8977b6fffd..3ba92017fc4 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -38,6 +38,7 @@ from pyspark.find_spark_home import _find_spark_home
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql import Row
 from pyspark.sql.types import StructType, AtomicType, StructField
+import pyspark.pandas as ps
 
 have_scipy = False
 have_numpy = False
@@ -314,8 +315,8 @@ def assertSchemaEqual(actual: StructType, expected: StructType):
 
 
 def assertDataFrameEqual(
-    actual: DataFrame,
-    expected: Union[DataFrame, List[Row]],
+    actual: Union[DataFrame, ps.DataFrame],
+    expected: Union[DataFrame, ps.DataFrame, List[Row]],
     checkRowOrder: bool = False,
     rtol: float = 1e-5,
     atol: float = 1e-8,
@@ -324,13 +325,17 @@ def assertDataFrameEqual(
     A util function to assert equality between `actual` (DataFrame) and `expected`
     (DataFrame or list of Rows), with optional parameters `checkRowOrder`, `rtol`, and `atol`.
 
+    Supports Spark, Spark Connect, and pandas-on-Spark DataFrames.
+    For more information about pandas-on-Spark DataFrame equality, see the docs for
+    `assertPandasOnSparkEqual`.
+
     .. versionadded:: 3.5.0
 
     Parameters
     ----------
-    actual : DataFrame
+    actual : DataFrame (Spark, Spark Connect, or pandas-on-Spark)
         The DataFrame that is being compared or tested.
-    expected : DataFrame or list of Rows
+    expected : DataFrame (Spark, Spark Connect, or pandas-on-Spark) or list of Rows
         The expected result of the operation, for comparison with the actual result.
     checkRowOrder : bool, optional
         A flag indicating whether the order of rows should be considered in the comparison.
@@ -346,10 +351,10 @@ def assertDataFrameEqual(
 
     Notes
     -----
-    When assertDataFrameEqual fails, the error message uses the Python `difflib` library to display
-    a diff log of each row that differs in `actual` and `expected`.
+    When `assertDataFrameEqual` fails, the error message uses the Python `difflib` library to
+    display a diff log of each row that differs in `actual` and `expected`.
 
-    For checkRowOrder, note that PySpark DataFrame ordering is non-deterministic, unless
+    For `checkRowOrder`, note that PySpark DataFrame ordering is non-deterministic, unless
     explicitly sorted.
 
     Note that schema equality is checked only when `expected` is a DataFrame (not a list of Rows).
@@ -369,7 +374,11 @@ def assertDataFrameEqual(
     >>> assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are approx equal by rtol
     >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "amount"])
     >>> list_of_rows = [Row(1, 1000), Row(2, 3000)]
-    >>> assertDataFrameEqual(df1, list_of_rows)  # pass, actual and expected are equal
+    >>> assertDataFrameEqual(df1, list_of_rows)  # pass, actual and expected data are equal
+    >>> import pyspark.pandas as ps
+    >>> df1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
+    >>> df2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
+    >>> assertDataFrameEqual(df1, df2)  # pass, pandas-on-Spark DataFrames are equal
     >>> df1 = spark.createDataFrame(
     ...     data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"])
     >>> df2 = spark.createDataFrame(
@@ -395,47 +404,76 @@ def assertDataFrameEqual(
     elif actual is None or expected is None:
         return False
 
+    import pyspark.pandas as ps
+    from pyspark.testing.pandasutils import assertPandasOnSparkEqual
+
     try:
         # If Spark Connect dependencies are available, allow Spark Connect DataFrame
         from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
 
-        if not isinstance(actual, DataFrame) and not isinstance(actual, ConnectDataFrame):
+        if isinstance(actual, ps.DataFrame) or isinstance(expected, ps.DataFrame):
+            # handle pandas DataFrames
+            if not (isinstance(actual, ps.DataFrame) and isinstance(expected, ps.DataFrame)):
+                raise PySparkAssertionError(
+                    error_class="INVALID_PANDAS_ON_SPARK_COMPARISON",
+                    message_parameters={
+                        "actual_type": type(actual),
+                        "expected_type": type(expected),
+                    },
+                )
+            # assert approximate equality for float data
+            return assertPandasOnSparkEqual(
+                actual, expected, checkExact=False, checkRowOrder=checkRowOrder
+            )
+        elif not isinstance(actual, (DataFrame, ps.DataFrame, ConnectDataFrame)):
             raise PySparkAssertionError(
                 error_class="INVALID_TYPE_DF_EQUALITY_ARG",
                 message_parameters={
-                    "expected_type": DataFrame,
-                    "arg_name": "df",
+                    "expected_type": DataFrame.__name__,
+                    "arg_name": "actual",
                     "actual_type": type(actual),
                 },
             )
-        elif (
-            not isinstance(expected, DataFrame)
-            and not isinstance(expected, ConnectDataFrame)
-            and not isinstance(expected, List)
-        ):
+        elif not isinstance(expected, (DataFrame, ps.DataFrame, ConnectDataFrame, list)):
             raise PySparkAssertionError(
                 error_class="INVALID_TYPE_DF_EQUALITY_ARG",
                 message_parameters={
-                    "expected_type": Union[DataFrame, List[Row]],
+                    "expected_type": f"{DataFrame.__name__}, 
{List[Row].__name__}",
                     "arg_name": "expected",
                     "actual_type": type(expected),
                 },
             )
     except Exception:
-        if not isinstance(actual, DataFrame):
+        if isinstance(actual, ps.DataFrame) or isinstance(expected, ps.DataFrame):
+            # handle pandas DataFrames
+            if not (isinstance(actual, ps.DataFrame) and isinstance(expected, ps.DataFrame)):
+                raise PySparkAssertionError(
+                    error_class="INVALID_PANDAS_ON_SPARK_COMPARISON",
+                    message_parameters={
+                        "actual_type": type(actual),
+                        "expected_type": type(expected),
+                    },
+                )
+            # assert approximate equality for float data
+            return assertPandasOnSparkEqual(
+                actual, expected, checkExact=False, checkRowOrder=checkRowOrder
+            )
+        elif not isinstance(actual, (DataFrame, ps.DataFrame)):
             raise PySparkAssertionError(
                 error_class="INVALID_TYPE_DF_EQUALITY_ARG",
                 message_parameters={
-                    "expected_type": DataFrame,
-                    "arg_name": "df",
+                    "expected_type": f"{DataFrame.__name__}, 
{ps.DataFrame.__name__}",
+                    "arg_name": "actual",
                     "actual_type": type(actual),
                 },
             )
-        elif not isinstance(expected, DataFrame) and not isinstance(expected, List):
+        elif not isinstance(expected, (DataFrame, ps.DataFrame, list)):
             raise PySparkAssertionError(
                 error_class="INVALID_TYPE_DF_EQUALITY_ARG",
                 message_parameters={
-                    "expected_type": Union[DataFrame, List[Row]],
+                    "expected_type": f"{DataFrame.__name__}, "
+                    f"{ps.DataFrame.__name__}, "
+                    f"{List[Row].__name__}",
                     "arg_name": "expected",
                     "actual_type": type(expected),
                 },

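For reference, the approximate comparison added in `_assert_pandas_almost_equal` follows
the pandas convention of rtol=1e-5 and atol=1e-8. Below is a self-contained sketch of the
per-value tolerance check used in the patch; the helper name `almost_equal` is illustrative
and not part of the diff:

    import decimal

    # pandas convention, as used in the patch
    RTOL, ATOL = 1e-5, 1e-8

    def almost_equal(lval, rval):
        # Floats and Decimals are compared within tolerance, mirroring the patch's
        # check: abs(float(lval) - float(rval)) > (atol + rtol * abs(float(rval)));
        # everything else falls back to exact equality.
        if isinstance(lval, (float, decimal.Decimal)) and isinstance(
            rval, (float, decimal.Decimal)
        ):
            return abs(float(lval) - float(rval)) <= ATOL + RTOL * abs(float(rval))
        return lval == rval

    assert almost_equal(100.000001, 100.0)  # within tolerance
    assert not almost_equal(212.32, 212.4)  # outside tolerance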
