xinrong-databricks commented on a change in pull request #32177: URL: https://github.com/apache/spark/pull/32177#discussion_r613723822
########## File path: python/pyspark/testing/utils.py ########## @@ -171,3 +209,393 @@ def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix): raise Exception("Found multiple JARs: %s; please remove all but one" % (", ".join(jars))) else: return jars[0] + + +# Utilities below are used mainly in pyspark/pandas +class SQLTestUtils(object): + """ + This util assumes the instance of this to have 'spark' attribute, having a spark session. + It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the + the implementation of this class has 'spark' attribute. + """ + + @contextmanager + def sql_conf(self, pairs): + """ + A convenient context manager to test some configuration specific logic. This sets + `value` to the configuration `key` and then restores it back when it exits. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + with sqlc(pairs, spark=self.spark): + yield + + @contextmanager + def database(self, *databases): + """ + A convenient context manager to test with some specific databases. This drops the given + databases if it exists and sets current database to "default" when it exits. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for db in databases: + self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db) + self.spark.catalog.setCurrentDatabase("default") + + @contextmanager + def table(self, *tables): + """ + A convenient context manager to test with some specific tables. This drops the given tables + if it exists. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for t in tables: + self.spark.sql("DROP TABLE IF EXISTS %s" % t) + + @contextmanager + def tempView(self, *views): + """ + A convenient context manager to test with some specific views. This drops the given views + if it exists. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for v in views: + self.spark.catalog.dropTempView(v) + + @contextmanager + def function(self, *functions): + """ + A convenient context manager to test with some specific functions. This drops the given + functions if it exists. + """ + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." + + try: + yield + finally: + for f in functions: + self.spark.sql("DROP FUNCTION IF EXISTS %s" % f) + + +class ReusedSQLTestCase(unittest.TestCase, SQLTestUtils): + @classmethod + def setUpClass(cls): + cls.spark = default_session() + cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True) + + @classmethod + def tearDownClass(cls): + # We don't stop Spark session to reuse across all tests. + # The Spark session will be started and stopped at PyTest session level. + # Please see databricks/koalas/conftest.py. + pass + + def assertPandasEqual(self, left, right, check_exact=True): + if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): + try: + if LooseVersion(pd.__version__) >= LooseVersion("1.1"): + kwargs = dict(check_freq=False) + else: + kwargs = dict() + + assert_frame_equal( + left, + right, + check_index_type=("equiv" if len(left.index) > 0 else False), + check_column_type=("equiv" if len(left.columns) > 0 else False), + check_exact=check_exact, + **kwargs + ) + except AssertionError as e: + msg = ( + str(e) + + "\n\nLeft:\n%s\n%s" % (left, left.dtypes) + + "\n\nRight:\n%s\n%s" % (right, right.dtypes) + ) + raise AssertionError(msg) from e + elif isinstance(left, pd.Series) and isinstance(right, pd.Series): + try: + if LooseVersion(pd.__version__) >= LooseVersion("1.1"): + kwargs = dict(check_freq=False) + else: + kwargs = dict() + + assert_series_equal( + left, + right, + check_index_type=("equiv" if len(left.index) > 0 else False), + check_exact=check_exact, + **kwargs + ) + except AssertionError as e: + msg = ( + str(e) + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + raise AssertionError(msg) from e + elif isinstance(left, pd.Index) and isinstance(right, pd.Index): + try: + assert_index_equal(left, right, check_exact=check_exact) + except AssertionError as e: + msg = ( + str(e) + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + raise AssertionError(msg) from e + else: + raise ValueError("Unexpected values: (%s, %s)" % (left, right)) + + def assertPandasAlmostEqual(self, left, right): + """ + This function checks if given pandas objects approximately same, + which means the conditions below: + - Both objects are nullable + - Compare floats rounding to the number of decimal places, 7 after + dropping missing values (NaN, NaT, None) + """ + if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): + msg = ( + "DataFrames are not almost equal: " + + "\n\nLeft:\n%s\n%s" % (left, left.dtypes) + + "\n\nRight:\n%s\n%s" % (right, right.dtypes) + ) + self.assertEqual(left.shape, right.shape, msg=msg) + for lcol, rcol in zip(left.columns, right.columns): + self.assertEqual(lcol, rcol, msg=msg) + for lnull, rnull in zip(left[lcol].isnull(), right[rcol].isnull()): + self.assertEqual(lnull, rnull, msg=msg) + for lval, rval in zip(left[lcol].dropna(), right[rcol].dropna()): + self.assertAlmostEqual(lval, rval, msg=msg) + self.assertEqual(left.columns.names, right.columns.names, msg=msg) + elif isinstance(left, pd.Series) and isinstance(right, pd.Series): + msg = ( + "Series are not almost equal: " + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + self.assertEqual(left.name, right.name, msg=msg) + self.assertEqual(len(left), len(right), msg=msg) + for lnull, rnull in zip(left.isnull(), right.isnull()): + self.assertEqual(lnull, rnull, msg=msg) + for lval, rval in zip(left.dropna(), right.dropna()): + self.assertAlmostEqual(lval, rval, msg=msg) + elif isinstance(left, pd.MultiIndex) and isinstance(right, pd.MultiIndex): + msg = ( + "MultiIndices are not almost equal: " + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + self.assertEqual(len(left), len(right), msg=msg) + for lval, rval in zip(left, right): + self.assertAlmostEqual(lval, rval, msg=msg) + elif isinstance(left, pd.Index) and isinstance(right, pd.Index): + msg = ( + "Indices are not almost equal: " + + "\n\nLeft:\n%s\n%s" % (left, left.dtype) + + "\n\nRight:\n%s\n%s" % (right, right.dtype) + ) + self.assertEqual(len(left), len(right), msg=msg) + for lnull, rnull in zip(left.isnull(), right.isnull()): + self.assertEqual(lnull, rnull, msg=msg) + for lval, rval in zip(left.dropna(), right.dropna()): + self.assertAlmostEqual(lval, rval, msg=msg) + else: + raise ValueError("Unexpected values: (%s, %s)" % (left, right)) + + def assert_eq(self, left, right, check_exact=True, almost=False): + """ + Asserts if two arbitrary objects are equal or not. If given objects are Koalas DataFrame + or Series, they are converted into pandas' and compared. + + :param left: object to compare + :param right: object to compare + :param check_exact: if this is False, the comparison is done less precisely. + :param almost: if this is enabled, the comparison is delegated to `unittest`'s + `assertAlmostEqual`. See its documentation for more details. + """ + lobj = self._to_pandas(left) + robj = self._to_pandas(right) + if isinstance(lobj, (pd.DataFrame, pd.Series, pd.Index)): + if almost: + self.assertPandasAlmostEqual(lobj, robj) + else: + self.assertPandasEqual(lobj, robj, check_exact=check_exact) + elif is_list_like(lobj) and is_list_like(robj): + self.assertTrue(len(left) == len(right)) + for litem, ritem in zip(left, right): + self.assert_eq(litem, ritem, check_exact=check_exact, almost=almost) + elif (lobj is not None and pd.isna(lobj)) and (robj is not None and pd.isna(robj)): + pass + else: + if almost: + self.assertAlmostEqual(lobj, robj) + else: + self.assertEqual(lobj, robj) + + @staticmethod + def _to_pandas(obj): + if isinstance(obj, (DataFrame, Series, Index)): + return obj.to_pandas() + else: + return obj + + +class TestUtils(object): + @contextmanager + def temp_dir(self): + tmp = tempfile.mkdtemp() + try: + yield tmp + finally: + shutil.rmtree(tmp) + + @contextmanager + def temp_file(self): + with self.temp_dir() as tmp: + yield tempfile.mktemp(dir=tmp) + + +class ComparisonTestBase(ReusedSQLTestCase): + @property + def kdf(self): + return ps.from_pandas(self.pdf) + + @property + def pdf(self): + return self.kdf.to_pandas() + + +def compare_both(f=None, almost=True): + + if f is None: + return functools.partial(compare_both, almost=almost) + elif isinstance(f, bool): + return functools.partial(compare_both, almost=f) + + @functools.wraps(f) + def wrapped(self): + if almost: + compare = self.assertPandasAlmostEqual + else: + compare = self.assertPandasEqual + + for result_pandas, result_spark in zip(f(self, self.pdf), f(self, self.kdf)): + compare(result_pandas, result_spark.to_pandas()) + + return wrapped + + +@contextmanager +def assert_produces_warning( + expected_warning=Warning, + filter_level="always", + check_stacklevel=True, + raise_on_extra_warnings=True, Review comment: Cool! -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org