xinrong-databricks commented on a change in pull request #32177:
URL: https://github.com/apache/spark/pull/32177#discussion_r613723822



##########
File path: python/pyspark/testing/utils.py
##########
@@ -171,3 +209,393 @@ def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
         raise Exception("Found multiple JARs: %s; please remove all but one" % 
(", ".join(jars)))
     else:
         return jars[0]
+
+
+# Utilities below are used mainly in pyspark/pandas
class SQLTestUtils(object):
    """
    Mixin of context managers that clean up Spark SQL state when a test block exits.

    It assumes instances have a ``spark`` attribute holding a Spark session. It is
    usually mixed into ``ReusedSQLTestCase``, but it can be used with any class whose
    instances are sure to have a ``spark`` attribute.
    """

    def _check_spark(self):
        # Shared precondition of every helper below. Kept as an ``assert`` so the
        # raised exception (AssertionError) and message are unchanged for callers.
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

    @contextmanager
    def sql_conf(self, pairs):
        """
        Temporarily apply SQL configurations for the duration of a ``with`` block.

        Each ``key -> value`` item of ``pairs`` is set, and restored when the block
        exits; the work is delegated to ``sqlc`` bound to ``self.spark``.
        """
        self._check_spark()

        with sqlc(pairs, spark=self.spark):
            yield

    @contextmanager
    def database(self, *databases):
        """
        On exit, drop each given database (``DROP DATABASE IF EXISTS ... CASCADE``)
        and then reset the current database to ``"default"``.

        Cleanup runs even when the ``with`` block raises.
        """
        self._check_spark()

        try:
            yield
        finally:
            for db in databases:
                self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
            self.spark.catalog.setCurrentDatabase("default")

    @contextmanager
    def table(self, *tables):
        """
        On exit, drop each given table if it exists (``DROP TABLE IF EXISTS``).

        Cleanup runs even when the ``with`` block raises.
        """
        self._check_spark()

        try:
            yield
        finally:
            for t in tables:
                self.spark.sql("DROP TABLE IF EXISTS %s" % t)

    @contextmanager
    def tempView(self, *views):
        """
        On exit, drop each given temporary view via ``spark.catalog.dropTempView``.

        Cleanup runs even when the ``with`` block raises.
        """
        self._check_spark()

        try:
            yield
        finally:
            for v in views:
                self.spark.catalog.dropTempView(v)

    @contextmanager
    def function(self, *functions):
        """
        On exit, drop each given function if it exists (``DROP FUNCTION IF EXISTS``).

        Cleanup runs even when the ``with`` block raises.
        """
        self._check_spark()

        try:
            yield
        finally:
            for f in functions:
                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
+
+
class ReusedSQLTestCase(unittest.TestCase, SQLTestUtils):
    """
    Test case base that shares one Spark session across tests and provides
    assertions comparing pandas and pandas-on-Spark objects.
    """

    @classmethod
    def setUpClass(cls):
        cls.spark = default_session()
        # Turn on the config named by SPARK_CONF_ARROW_ENABLED for the shared session.
        cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True)

    @classmethod
    def tearDownClass(cls):
        # We don't stop Spark session to reuse across all tests.
        # The Spark session will be started and stopped at PyTest session level.
        # Please see databricks/koalas/conftest.py.
        # NOTE(review): that path predates the move into pyspark; confirm which
        # conftest.py now owns the session lifecycle and update this reference.
        pass

    @staticmethod
    def _version_dependent_kwargs():
        """Return extra keyword arguments for pandas' ``assert_*_equal`` helpers."""
        # ``check_freq`` exists from pandas 1.1; disable it there so that a differing
        # index ``freq`` attribute alone does not fail the comparison.
        if LooseVersion(pd.__version__) >= LooseVersion("1.1"):
            return dict(check_freq=False)
        return dict()

    def assertPandasEqual(self, left, right, check_exact=True):
        """
        Assert that two pandas objects of the same kind are equal.

        :param left: pandas DataFrame, Series or Index.
        :param right: object of the same kind as ``left``.
        :param check_exact: forwarded to pandas' ``assert_*_equal`` helpers.
        :raises AssertionError: if they differ; the message appends both objects and
            their dtypes to pandas' own explanation.
        :raises ValueError: if the pair is not two DataFrames, two Series or two Indexes.
        """
        if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
            try:
                assert_frame_equal(
                    left,
                    right,
                    # Index/column type checks are skipped for empty axes.
                    check_index_type=("equiv" if len(left.index) > 0 else False),
                    check_column_type=("equiv" if len(left.columns) > 0 else False),
                    check_exact=check_exact,
                    **self._version_dependent_kwargs()
                )
            except AssertionError as e:
                msg = (
                    str(e)
                    + "\n\nLeft:\n%s\n%s" % (left, left.dtypes)
                    + "\n\nRight:\n%s\n%s" % (right, right.dtypes)
                )
                raise AssertionError(msg) from e
        elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
            try:
                assert_series_equal(
                    left,
                    right,
                    check_index_type=("equiv" if len(left.index) > 0 else False),
                    check_exact=check_exact,
                    **self._version_dependent_kwargs()
                )
            except AssertionError as e:
                msg = (
                    str(e)
                    + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
                    + "\n\nRight:\n%s\n%s" % (right, right.dtype)
                )
                raise AssertionError(msg) from e
        elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
            try:
                assert_index_equal(left, right, check_exact=check_exact)
            except AssertionError as e:
                msg = (
                    str(e)
                    + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
                    + "\n\nRight:\n%s\n%s" % (right, right.dtype)
                )
                raise AssertionError(msg) from e
        else:
            raise ValueError("Unexpected values: (%s, %s)" % (left, right))

    def assertPandasAlmostEqual(self, left, right):
        """
        Assert that two pandas objects of the same kind are approximately equal.

        Null positions (NaN, NaT, None) must match exactly; the remaining values are
        compared with ``unittest``'s ``assertAlmostEqual`` (7 decimal places).

        - DataFrame: shape, column labels, per-column null masks and values, and
          ``columns.names`` are compared; the row index is NOT compared.
        - Series: name, length, null mask and values.
        - MultiIndex: length, then each tuple level-by-level.
        - Index: length, null mask and values.

        :raises AssertionError: if the objects differ.
        :raises ValueError: if the pair is not of one of the kinds above.
        """
        if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame):
            msg = (
                "DataFrames are not almost equal: "
                + "\n\nLeft:\n%s\n%s" % (left, left.dtypes)
                + "\n\nRight:\n%s\n%s" % (right, right.dtypes)
            )
            self.assertEqual(left.shape, right.shape, msg=msg)
            for lcol, rcol in zip(left.columns, right.columns):
                self.assertEqual(lcol, rcol, msg=msg)
                for lnull, rnull in zip(left[lcol].isnull(), right[rcol].isnull()):
                    self.assertEqual(lnull, rnull, msg=msg)
                for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()):
                    self.assertAlmostEqual(lval, rval, msg=msg)
            self.assertEqual(left.columns.names, right.columns.names, msg=msg)
        elif isinstance(left, pd.Series) and isinstance(right, pd.Series):
            msg = (
                "Series are not almost equal: "
                + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
                + "\n\nRight:\n%s\n%s" % (right, right.dtype)
            )
            self.assertEqual(left.name, right.name, msg=msg)
            self.assertEqual(len(left), len(right), msg=msg)
            for lnull, rnull in zip(left.isnull(), right.isnull()):
                self.assertEqual(lnull, rnull, msg=msg)
            for lval, rval in zip(left.dropna(), right.dropna()):
                self.assertAlmostEqual(lval, rval, msg=msg)
        elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex):
            msg = (
                "MultiIndices are not almost equal: "
                + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
                + "\n\nRight:\n%s\n%s" % (right, right.dtype)
            )
            self.assertEqual(len(left), len(right), msg=msg)
            for ltup, rtup in zip(left, right):
                if ltup == rtup:
                    continue
                # Items are tuples of level values. assertAlmostEqual cannot subtract
                # tuples (TypeError), so compare the levels one by one, treating a
                # pair of missing values as equal.
                self.assertEqual(len(ltup), len(rtup), msg=msg)
                for lval, rval in zip(ltup, rtup):
                    if pd.isna(lval) and pd.isna(rval):
                        continue
                    self.assertAlmostEqual(lval, rval, msg=msg)
        elif isinstance(left, pd.Index) and isinstance(right, pd.Index):
            msg = (
                "Indices are not almost equal: "
                + "\n\nLeft:\n%s\n%s" % (left, left.dtype)
                + "\n\nRight:\n%s\n%s" % (right, right.dtype)
            )
            self.assertEqual(len(left), len(right), msg=msg)
            for lnull, rnull in zip(left.isnull(), right.isnull()):
                self.assertEqual(lnull, rnull, msg=msg)
            for lval, rval in zip(left.dropna(), right.dropna()):
                self.assertAlmostEqual(lval, rval, msg=msg)
        else:
            raise ValueError("Unexpected values: (%s, %s)" % (left, right))

    def assert_eq(self, left, right, check_exact=True, almost=False):
        """
        Asserts if two arbitrary objects are equal or not. If given objects are Koalas
        DataFrame, Series or Index, they are converted into pandas' and compared.

        List-like objects are compared item by item (recursively). For dicts, keys are
        compared in order and then the corresponding values. Two missing scalars
        (NaN/NaT) are considered equal.

        :param left: object to compare
        :param right: object to compare
        :param check_exact: if this is False, the comparison is done less precisely.
        :param almost: if this is enabled, the comparison is delegated to `unittest`'s
                       `assertAlmostEqual`. See its documentation for more details.
        """
        lobj = self._to_pandas(left)
        robj = self._to_pandas(right)
        if isinstance(lobj, (pd.DataFrame, pd.Series, pd.Index)):
            if almost:
                self.assertPandasAlmostEqual(lobj, robj)
            else:
                self.assertPandasEqual(lobj, robj, check_exact=check_exact)
        elif is_list_like(lobj) and is_list_like(robj):
            # Use the converted objects: the raw inputs may be pandas-on-Spark objects,
            # which are not meant to be iterated directly.
            self.assertEqual(len(lobj), len(robj))
            for litem, ritem in zip(lobj, robj):
                self.assert_eq(litem, ritem, check_exact=check_exact, almost=almost)
            if isinstance(lobj, dict) and isinstance(robj, dict):
                # Iterating a dict yields only its keys (checked above, in order);
                # the values must be compared too.
                for lval, rval in zip(lobj.values(), robj.values()):
                    self.assert_eq(lval, rval, check_exact=check_exact, almost=almost)
        elif (lobj is not None and pd.isna(lobj)) and (robj is not None and pd.isna(robj)):
            pass
        else:
            if almost:
                self.assertAlmostEqual(lobj, robj)
            else:
                self.assertEqual(lobj, robj)

    @staticmethod
    def _to_pandas(obj):
        """Convert a pandas-on-Spark DataFrame/Series/Index to pandas; return others as-is."""
        if isinstance(obj, (DataFrame, Series, Index)):
            return obj.to_pandas()
        else:
            return obj
+
+
class TestUtils(object):
    """Mixin providing context managers for temporary filesystem paths."""

    @contextmanager
    def temp_dir(self):
        """
        Create a fresh temporary directory and yield its path.

        The directory and everything in it are removed with ``shutil.rmtree``
        when the ``with`` block exits, even if the block raises.
        """
        tmp = tempfile.mkdtemp()
        try:
            yield tmp
        finally:
            shutil.rmtree(tmp)

    @contextmanager
    def temp_file(self):
        """
        Yield a path for a file that does not exist yet.

        The path lives inside a fresh temporary directory, which is removed
        (together with anything written there) when the ``with`` block exits.
        """
        import os

        with self.temp_dir() as tmp:
            # ``tempfile.mktemp`` is deprecated and race-prone. ``tmp`` was just created
            # by ``mkdtemp`` for this call only, so a fixed name inside it is unique and
            # guaranteed not to exist yet.
            yield os.path.join(tmp, "temp_file")
+
+
class ComparisonTestBase(ReusedSQLTestCase):
    """
    Base class for tests that compare pandas with pandas-on-Spark.

    ``pdf`` and ``kdf`` are each defined in terms of the other, so a subclass must
    override at least one of them; the other one is then derived from it.
    """

    @property
    def kdf(self):
        """pandas-on-Spark DataFrame built from ``pdf`` with ``ps.from_pandas``."""
        if type(self).pdf is ComparisonTestBase.pdf:
            # Neither property was overridden: fail with a clear message instead of
            # recursing between ``kdf`` and ``pdf`` until RecursionError.
            raise NotImplementedError(
                "%s must override 'pdf' or 'kdf'." % type(self).__name__
            )
        return ps.from_pandas(self.pdf)

    @property
    def pdf(self):
        """pandas DataFrame obtained by calling ``to_pandas()`` on ``kdf``."""
        if type(self).kdf is ComparisonTestBase.kdf:
            # See ``kdf``: avoid infinite mutual recursion when nothing is overridden.
            raise NotImplementedError(
                "%s must override 'pdf' or 'kdf'." % type(self).__name__
            )
        return self.kdf.to_pandas()
+
+
def compare_both(f=None, almost=True):
    """
    Decorator that runs a test body against both ``self.pdf`` and ``self.kdf``.

    The decorated function is called as ``f(self, self.pdf)`` and ``f(self, self.kdf)``;
    each call must return an iterable of results. Results are paired in order, the
    pandas-on-Spark result is converted with ``to_pandas()``, and each pair is checked
    with ``self.assertPandasAlmostEqual`` (``almost=True``, the default) or
    ``self.assertPandasEqual``.

    Usable as ``@compare_both``, ``@compare_both(almost=False)`` or ``@compare_both(False)``.

    :raises AssertionError: if the two calls produce a different number of results, or
        any paired results differ.
    """
    if f is None:
        return functools.partial(compare_both, almost=almost)
    elif isinstance(f, bool):
        # ``@compare_both(False)``: the positional argument is the ``almost`` flag.
        return functools.partial(compare_both, almost=f)

    @functools.wraps(f)
    def wrapped(self):
        compare = self.assertPandasAlmostEqual if almost else self.assertPandasEqual

        pandas_results = list(f(self, self.pdf))
        spark_results = list(f(self, self.kdf))
        # ``zip`` alone would silently drop surplus results from either side.
        if len(pandas_results) != len(spark_results):
            raise AssertionError(
                "Different number of results: %d from pandas vs %d from pandas-on-Spark"
                % (len(pandas_results), len(spark_results))
            )

        for result_pandas, result_spark in zip(pandas_results, spark_results):
            compare(result_pandas, result_spark.to_pandas())

    return wrapped
+
+
+@contextmanager
+def assert_produces_warning(
+        expected_warning=Warning,
+        filter_level="always",
+        check_stacklevel=True,
+        raise_on_extra_warnings=True,

Review comment:
       Cool!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to