This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 9798244ca64 [SPARK-45135][PYTHON][TESTS] Make `utils.eventually` a parameterized decorator 9798244ca64 is described below commit 9798244ca647ec68d36f4b9b21356a6de5f73f77 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Thu Sep 14 08:44:53 2023 +0800 [SPARK-45135][PYTHON][TESTS] Make `utils.eventually` a parameterized decorator ### What changes were proposed in this pull request? - Make utils.eventually a parameterized decorator - Retry `test_read_images` if it fails ### Why are the changes needed? Previously, we used `utils.eventually` to retry flaky tests, e.g. https://github.com/apache/spark/commit/745ed93fe451b3f9e8148b06356c28b889a4db5a However, it requires modifying the test body, and sometimes the change may be large. To minimize the changes, I'd like to ~~add a decorator for test retry~~ Make utils.eventually a parameterized decorator. ### Does this PR introduce _any_ user-facing change? no, test-only ### How was this patch tested? CI ### Was this patch authored or co-authored using generative AI tooling? no Closes #42891 from zhengruifeng/test_retry_test_read_images. 
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/ml/tests/test_image.py | 3 +- python/pyspark/ml/tests/test_wrapper.py | 4 +- python/pyspark/mllib/tests/test_algorithms.py | 37 ++++++------ .../mllib/tests/test_streaming_algorithms.py | 24 ++++---- .../pandas/test_pandas_grouped_map_with_state.py | 2 +- python/pyspark/testing/utils.py | 70 ++++++++++++++-------- python/pyspark/tests/test_taskcontext.py | 3 +- python/pyspark/tests/test_util.py | 23 ++++++- python/pyspark/tests/test_worker.py | 15 ++--- 9 files changed, 112 insertions(+), 69 deletions(-) diff --git a/python/pyspark/ml/tests/test_image.py b/python/pyspark/ml/tests/test_image.py index 86fa46c3248..ee254a41007 100644 --- a/python/pyspark/ml/tests/test_image.py +++ b/python/pyspark/ml/tests/test_image.py @@ -19,10 +19,11 @@ import unittest from pyspark.ml.image import ImageSchema from pyspark.testing.mlutils import SparkSessionTestCase from pyspark.sql import Row -from pyspark.testing.utils import QuietTest +from pyspark.testing.utils import QuietTest, eventually class ImageFileFormatTest(SparkSessionTestCase): + @eventually(timeout=60.0, catch_assertions=True) def test_read_images(self): data_path = "data/mllib/images/origin/kittens" df = ( diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py index 33d93c02acd..3efdbabd998 100644 --- a/python/pyspark/ml/tests/test_wrapper.py +++ b/python/pyspark/ml/tests/test_wrapper.py @@ -63,7 +63,7 @@ class JavaWrapperMemoryTests(SparkSessionTestCase): self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) return True - eventually(condition, timeout=10, catch_assertions=True) + eventually(timeout=10, catch_assertions=True)(condition)() try: summary.__del__() @@ -77,7 +77,7 @@ class JavaWrapperMemoryTests(SparkSessionTestCase): summary._java_obj.toString() return True - eventually(condition, timeout=10, catch_assertions=True) + 
eventually(timeout=10, catch_assertions=True)(condition)() class WrapperTests(MLlibTestCase): diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py index dc48c2c021d..bcedd65b05b 100644 --- a/python/pyspark/mllib/tests/test_algorithms.py +++ b/python/pyspark/mllib/tests/test_algorithms.py @@ -97,27 +97,28 @@ class ListTests(MLlibTestCase): # TODO: Allow small numeric difference. self.assertTrue(array_equal(c1, c2)) + @eventually(timeout=60, catch_assertions=True) def test_gmm(self): from pyspark.mllib.clustering import GaussianMixture - def condition(): - data = self.sc.parallelize( - [ - [1, 2], - [8, 9], - [-4, -3], - [-6, -7], - ] - ) - clusters = GaussianMixture.train( - data, 2, convergenceTol=0.001, maxIterations=10, seed=1 - ) - labels = clusters.predict(data).collect() - self.assertEqual(labels[0], labels[1]) - self.assertEqual(labels[2], labels[3]) - return True - - eventually(condition, timeout=60, catch_assertions=True) + data = self.sc.parallelize( + [ + [1, 2], + [8, 9], + [-4, -3], + [-6, -7], + ] + ) + clusters = GaussianMixture.train( + data, + 2, + convergenceTol=0.001, + maxIterations=10, + seed=1, + ) + labels = clusters.predict(data).collect() + self.assertEqual(labels[0], labels[1]) + self.assertEqual(labels[2], labels[3]) def test_gmm_deterministic(self): from pyspark.mllib.clustering import GaussianMixture diff --git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py index 5a06742ba75..3ec98b7e735 100644 --- a/python/pyspark/mllib/tests/test_streaming_algorithms.py +++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py @@ -73,7 +73,7 @@ class StreamingKMeansTest(MLLibStreamingTestCase): self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) return True - eventually(condition, catch_assertions=True) + eventually(catch_assertions=True)(condition)() realCenters = array_sum(array(centers), axis=0) for i in range(5): @@ 
-118,7 +118,7 @@ class StreamingKMeansTest(MLLibStreamingTestCase): self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) return True - eventually(condition, 90, catch_assertions=True) + eventually(timeout=90, catch_assertions=True)(condition)() def test_predictOn_model(self): """Test that the model predicts correctly on toy data.""" @@ -147,7 +147,7 @@ class StreamingKMeansTest(MLLibStreamingTestCase): self.assertEqual(result, [[0], [1], [2], [3]]) return True - eventually(condition, catch_assertions=True) + eventually(catch_assertions=True)(condition)() @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark") def test_trainOn_predictOn(self): @@ -180,7 +180,7 @@ class StreamingKMeansTest(MLLibStreamingTestCase): self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) return True - eventually(condition, catch_assertions=True) + eventually(catch_assertions=True)(condition)() class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): @@ -223,7 +223,7 @@ class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): self.assertAlmostEqual(rel, 0.1, 1) return True - eventually(condition, timeout=120.0, catch_assertions=True) + eventually(timeout=120.0, catch_assertions=True)(condition)() def test_convergence(self): """ @@ -247,7 +247,7 @@ class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): return True # We want all batches to finish for this test. - eventually(condition, 120, catch_assertions=True) + eventually(timeout=120, catch_assertions=True)(condition)() t_models = array(models) diff = t_models[1:] - t_models[:-1] @@ -278,7 +278,7 @@ class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): self.assertEqual(len(true_predicted), len(input_batches)) return True - eventually(condition, catch_assertions=True) + eventually(catch_assertions=True)(condition)() # Test that the accuracy error is no more than 0.4 on each batch. 
for batch in true_predicted: @@ -319,7 +319,7 @@ class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): return True return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) - eventually(condition, timeout=180.0) + eventually(timeout=180.0)(condition)() class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): @@ -354,7 +354,7 @@ class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) return True - eventually(condition, catch_assertions=True) + eventually(catch_assertions=True)(condition)() def test_parameter_convergence(self): """Test that the model parameters improve with streaming data.""" @@ -380,7 +380,7 @@ class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): return True # We want all batches to finish for this test. - eventually(condition, 90, catch_assertions=True) + eventually(timeout=90, catch_assertions=True)(condition)() w = array(model_weights) diff = w[1:] - w[:-1] @@ -412,7 +412,7 @@ class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): return True # We want all batches to finish for this test. 
- eventually(condition, catch_assertions=True) + eventually(catch_assertions=True)(condition)() # Test that mean absolute error on each batch is less than 0.1 for batch in samples: @@ -456,7 +456,7 @@ class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): return True return "Latest errors: " + ", ".join(map(lambda x: str(x), errors)) - eventually(condition, timeout=180.0) + eventually(timeout=180.0)(condition)() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py index e1ec97928f7..12ee9319d2c 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py @@ -319,7 +319,7 @@ class GroupedApplyInPandasWithStateTestsMixin: return False try: - eventually(assert_test, timeout=120) + eventually(timeout=120)(assert_test)() finally: q.stop() diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py index 2a508b8a450..86cfca58a6b 100644 --- a/python/pyspark/testing/utils.py +++ b/python/pyspark/testing/utils.py @@ -21,6 +21,7 @@ import struct import sys import unittest import difflib +import functools from decimal import Decimal from time import time, sleep from typing import ( @@ -29,8 +30,7 @@ from typing import ( Union, Dict, List, - Tuple, - Iterator, + Callable, ) from itertools import zip_longest @@ -71,7 +71,10 @@ def write_int(i): return struct.pack("!i", i) -def eventually(condition, timeout=30.0, catch_assertions=False): +def eventually( + timeout=30.0, + catch_assertions=False, +): """ Wait a given amount of time for a condition to pass, else fail with an error. This is a helper utility for PySpark tests. @@ -80,7 +83,7 @@ def eventually(condition, timeout=30.0, catch_assertions=False): ---------- condition : function Function that checks for termination conditions. 
condition() can return: - - True: Conditions met. Return without error. + - True or None: Conditions met. Return without error. - other value: Conditions not met yet. Continue. Upon timeout, include last such value in error message. Note that this method may be called at any time during @@ -93,26 +96,45 @@ def eventually(condition, timeout=30.0, catch_assertions=False): If True, catch AssertionErrors; continue, but save error to throw upon timeout. """ - start_time = time() - lastValue = None - while time() - start_time < timeout: - if catch_assertions: - try: - lastValue = condition() - except AssertionError as e: - lastValue = e - else: - lastValue = condition() - if lastValue is True: - return - sleep(0.01) - if isinstance(lastValue, AssertionError): - raise lastValue - else: - raise AssertionError( - "Test failed due to timeout after %g sec, with last condition returning: %s" - % (timeout, lastValue) - ) + assert timeout > 0 + assert isinstance(catch_assertions, bool) + + def decorator(condition: Callable) -> Callable: + assert isinstance(condition, Callable) + + @functools.wraps(condition) + def wrapper(*args: Any, **kwargs: Any) -> Any: + start_time = time() + lastValue = None + numTries = 0 + while time() - start_time < timeout: + numTries += 1 + + if catch_assertions: + try: + lastValue = condition(*args, **kwargs) + except AssertionError as e: + lastValue = e + else: + lastValue = condition(*args, **kwargs) + + if lastValue is True or lastValue is None: + return + + print(f"\nAttempt #{numTries} failed!\n{lastValue}") + sleep(0.01) + + if isinstance(lastValue, AssertionError): + raise lastValue + else: + raise AssertionError( + "Test failed due to timeout after %g sec, with last condition returning: %s" + % (timeout, lastValue) + ) + + return wrapper + + return decorator class QuietTest: diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py index 5d410aa57e7..50acc3bab07 100644 --- 
a/python/pyspark/tests/test_taskcontext.py +++ b/python/pyspark/tests/test_taskcontext.py @@ -286,11 +286,12 @@ class TaskContextTestsWithWorkerReuse(unittest.TestCase): self.assertTrue(pid in worker_pids) return True + @eventually(catch_assertions=True) def test_task_context_correct_with_python_worker_reuse(self): # Retrying the check as the PIDs from Python workers might be different even # when reusing Python workers is enabled if a Python worker is dead for some reasons # (e.g., socket connection failure) and new Python worker is created. - eventually(self.check_task_context_correct_with_python_worker_reuse, catch_assertions=True) + self.check_task_context_correct_with_python_worker_reuse() def tearDown(self): self.sc.stop() diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py index 77f06721b10..26dc3db74de 100644 --- a/python/pyspark/tests/test_util.py +++ b/python/pyspark/tests/test_util.py @@ -20,7 +20,7 @@ import unittest from py4j.protocol import Py4JJavaError from pyspark import keyword_only -from pyspark.testing.utils import PySparkTestCase +from pyspark.testing.utils import PySparkTestCase, eventually from pyspark.find_spark_home import _find_spark_home @@ -84,6 +84,27 @@ class UtilTests(PySparkTestCase): finally: os.environ["SPARK_HOME"] = origin + @eventually(timeout=180, catch_assertions=True) + def test_eventually_decorator(self): + import random + + self.assertTrue(random.random() < 0.1) + + def test_eventually_function(self): + import random + + def condition(): + self.assertTrue(random.random() < 0.1) + + eventually(timeout=180, catch_assertions=True)(condition)() + + def test_eventually_lambda(self): + import random + + eventually(timeout=180, catch_assertions=True)( + lambda: self.assertTrue(random.random() < 0.1) + )() + if __name__ == "__main__": from pyspark.tests.test_util import * # noqa: F401 diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index 703690bf7f9..2d675811fb9 
100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -186,16 +186,13 @@ class WorkerTests(ReusedPySparkTestCase): class WorkerReuseTest(PySparkTestCase): + @eventually(catch_assertions=True) def test_reuse_worker_of_parallelize_range(self): - def check_reuse_worker_of_parallelize_range(): - rdd = self.sc.parallelize(range(20), 8) - previous_pids = rdd.map(lambda x: os.getpid()).collect() - current_pids = rdd.map(lambda x: os.getpid()).collect() - for pid in current_pids: - self.assertTrue(pid in previous_pids) - return True - - eventually(check_reuse_worker_of_parallelize_range, catch_assertions=True) + rdd = self.sc.parallelize(range(20), 8) + previous_pids = rdd.map(lambda x: os.getpid()).collect() + current_pids = rdd.map(lambda x: os.getpid()).collect() + for pid in current_pids: + self.assertTrue(pid in previous_pids) @unittest.skipIf( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org