This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9798244ca64 [SPARK-45135][PYTHON][TESTS] Make `utils.eventually` a parameterized decorator
9798244ca64 is described below

commit 9798244ca647ec68d36f4b9b21356a6de5f73f77
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Sep 14 08:44:53 2023 +0800

    [SPARK-45135][PYTHON][TESTS] Make `utils.eventually` a parameterized decorator
    
    ### What changes were proposed in this pull request?
    
    - Make `utils.eventually` a parameterized decorator
    - Retry `test_read_images` if it fails
    
    ### Why are the changes needed?
    Previously, we used `utils.eventually` to retry flaky tests, e.g. https://github.com/apache/spark/commit/745ed93fe451b3f9e8148b06356c28b889a4db5a
    
    However, it requires modifying the test body, and sometimes the change may be large.
    
    To minimize the changes, I'd like to ~~add a decorator for test retry~~ make `utils.eventually` a parameterized decorator.
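
    For illustration, a minimal sketch of the new decorator usage, modeled on the `test_eventually_decorator` case added in `python/pyspark/tests/test_util.py` below (`flaky_check` is a hypothetical example name, not part of this patch):

        import random

        from pyspark.testing.utils import eventually

        # The decorated body is retried (with a 10ms sleep between attempts)
        # until it returns True/None without raising an AssertionError, or
        # the timeout elapses.
        @eventually(timeout=180, catch_assertions=True)
        def flaky_check():
            assert random.random() < 0.1  # fails ~90% of attempts

        flaky_check()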
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #42891 from zhengruifeng/test_retry_test_read_images.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/ml/tests/test_image.py              |  3 +-
 python/pyspark/ml/tests/test_wrapper.py            |  4 +-
 python/pyspark/mllib/tests/test_algorithms.py      | 37 ++++++------
 .../mllib/tests/test_streaming_algorithms.py       | 24 ++++----
 .../pandas/test_pandas_grouped_map_with_state.py   |  2 +-
 python/pyspark/testing/utils.py                    | 70 ++++++++++++++--------
 python/pyspark/tests/test_taskcontext.py           |  3 +-
 python/pyspark/tests/test_util.py                  | 23 ++++++-
 python/pyspark/tests/test_worker.py                | 15 ++---
 9 files changed, 112 insertions(+), 69 deletions(-)

diff --git a/python/pyspark/ml/tests/test_image.py b/python/pyspark/ml/tests/test_image.py
index 86fa46c3248..ee254a41007 100644
--- a/python/pyspark/ml/tests/test_image.py
+++ b/python/pyspark/ml/tests/test_image.py
@@ -19,10 +19,11 @@ import unittest
 from pyspark.ml.image import ImageSchema
 from pyspark.testing.mlutils import SparkSessionTestCase
 from pyspark.sql import Row
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, eventually
 
 
 class ImageFileFormatTest(SparkSessionTestCase):
+    @eventually(timeout=60.0, catch_assertions=True)
     def test_read_images(self):
         data_path = "data/mllib/images/origin/kittens"
         df = (
diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py
index 33d93c02acd..3efdbabd998 100644
--- a/python/pyspark/ml/tests/test_wrapper.py
+++ b/python/pyspark/ml/tests/test_wrapper.py
@@ -63,7 +63,7 @@ class JavaWrapperMemoryTests(SparkSessionTestCase):
             self.assertIn("LinearRegressionTrainingSummary", 
summary._java_obj.toString())
             return True
 
-        eventually(condition, timeout=10, catch_assertions=True)
+        eventually(timeout=10, catch_assertions=True)(condition)()
 
         try:
             summary.__del__()
@@ -77,7 +77,7 @@ class JavaWrapperMemoryTests(SparkSessionTestCase):
                 summary._java_obj.toString()
             return True
 
-        eventually(condition, timeout=10, catch_assertions=True)
+        eventually(timeout=10, catch_assertions=True)(condition)()
 
 
 class WrapperTests(MLlibTestCase):
diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py
index dc48c2c021d..bcedd65b05b 100644
--- a/python/pyspark/mllib/tests/test_algorithms.py
+++ b/python/pyspark/mllib/tests/test_algorithms.py
@@ -97,27 +97,28 @@ class ListTests(MLlibTestCase):
             # TODO: Allow small numeric difference.
             self.assertTrue(array_equal(c1, c2))
 
+    @eventually(timeout=60, catch_assertions=True)
     def test_gmm(self):
         from pyspark.mllib.clustering import GaussianMixture
 
-        def condition():
-            data = self.sc.parallelize(
-                [
-                    [1, 2],
-                    [8, 9],
-                    [-4, -3],
-                    [-6, -7],
-                ]
-            )
-            clusters = GaussianMixture.train(
-                data, 2, convergenceTol=0.001, maxIterations=10, seed=1
-            )
-            labels = clusters.predict(data).collect()
-            self.assertEqual(labels[0], labels[1])
-            self.assertEqual(labels[2], labels[3])
-            return True
-
-        eventually(condition, timeout=60, catch_assertions=True)
+        data = self.sc.parallelize(
+            [
+                [1, 2],
+                [8, 9],
+                [-4, -3],
+                [-6, -7],
+            ]
+        )
+        clusters = GaussianMixture.train(
+            data,
+            2,
+            convergenceTol=0.001,
+            maxIterations=10,
+            seed=1,
+        )
+        labels = clusters.predict(data).collect()
+        self.assertEqual(labels[0], labels[1])
+        self.assertEqual(labels[2], labels[3])
 
     def test_gmm_deterministic(self):
         from pyspark.mllib.clustering import GaussianMixture
diff --git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py
index 5a06742ba75..3ec98b7e735 100644
--- a/python/pyspark/mllib/tests/test_streaming_algorithms.py
+++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py
@@ -73,7 +73,7 @@ class StreamingKMeansTest(MLLibStreamingTestCase):
             self.assertEqual(stkm.latestModel().clusterWeights, [25.0])
             return True
 
-        eventually(condition, catch_assertions=True)
+        eventually(catch_assertions=True)(condition)()
 
         realCenters = array_sum(array(centers), axis=0)
         for i in range(5):
@@ -118,7 +118,7 @@ class StreamingKMeansTest(MLLibStreamingTestCase):
             self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0])
             return True
 
-        eventually(condition, 90, catch_assertions=True)
+        eventually(timeout=90, catch_assertions=True)(condition)()
 
     def test_predictOn_model(self):
         """Test that the model predicts correctly on toy data."""
@@ -147,7 +147,7 @@ class StreamingKMeansTest(MLLibStreamingTestCase):
             self.assertEqual(result, [[0], [1], [2], [3]])
             return True
 
-        eventually(condition, catch_assertions=True)
+        eventually(catch_assertions=True)(condition)()
 
     @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark")
     def test_trainOn_predictOn(self):
@@ -180,7 +180,7 @@ class StreamingKMeansTest(MLLibStreamingTestCase):
             self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])
             return True
 
-        eventually(condition, catch_assertions=True)
+        eventually(catch_assertions=True)(condition)()
 
 
 class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase):
@@ -223,7 +223,7 @@ class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase):
             self.assertAlmostEqual(rel, 0.1, 1)
             return True
 
-        eventually(condition, timeout=120.0, catch_assertions=True)
+        eventually(timeout=120.0, catch_assertions=True)(condition)()
 
     def test_convergence(self):
         """
@@ -247,7 +247,7 @@ class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase):
             return True
 
         # We want all batches to finish for this test.
-        eventually(condition, 120, catch_assertions=True)
+        eventually(timeout=120, catch_assertions=True)(condition)()
 
         t_models = array(models)
         diff = t_models[1:] - t_models[:-1]
@@ -278,7 +278,7 @@ class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase):
             self.assertEqual(len(true_predicted), len(input_batches))
             return True
 
-        eventually(condition, catch_assertions=True)
+        eventually(catch_assertions=True)(condition)()
 
         # Test that the accuracy error is no more than 0.4 on each batch.
         for batch in true_predicted:
@@ -319,7 +319,7 @@ class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase):
                 return True
             return "Latest errors: " + ", ".join(map(lambda x: str(x), errors))
 
-        eventually(condition, timeout=180.0)
+        eventually(timeout=180.0)(condition)()
 
 
 class StreamingLinearRegressionWithTests(MLLibStreamingTestCase):
@@ -354,7 +354,7 @@ class StreamingLinearRegressionWithTests(MLLibStreamingTestCase):
             self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1)
             return True
 
-        eventually(condition, catch_assertions=True)
+        eventually(catch_assertions=True)(condition)()
 
     def test_parameter_convergence(self):
         """Test that the model parameters improve with streaming data."""
@@ -380,7 +380,7 @@ class StreamingLinearRegressionWithTests(MLLibStreamingTestCase):
             return True
 
         # We want all batches to finish for this test.
-        eventually(condition, 90, catch_assertions=True)
+        eventually(timeout=90, catch_assertions=True)(condition)()
 
         w = array(model_weights)
         diff = w[1:] - w[:-1]
@@ -412,7 +412,7 @@ class StreamingLinearRegressionWithTests(MLLibStreamingTestCase):
             return True
 
         # We want all batches to finish for this test.
-        eventually(condition, catch_assertions=True)
+        eventually(catch_assertions=True)(condition)()
 
         # Test that mean absolute error on each batch is less than 0.1
         for batch in samples:
@@ -456,7 +456,7 @@ class StreamingLinearRegressionWithTests(MLLibStreamingTestCase):
                 return True
             return "Latest errors: " + ", ".join(map(lambda x: str(x), errors))
 
-        eventually(condition, timeout=180.0)
+        eventually(timeout=180.0)(condition)()
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
index e1ec97928f7..12ee9319d2c 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
@@ -319,7 +319,7 @@ class GroupedApplyInPandasWithStateTestsMixin:
                 return False
 
         try:
-            eventually(assert_test, timeout=120)
+            eventually(timeout=120)(assert_test)()
         finally:
             q.stop()
 
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 2a508b8a450..86cfca58a6b 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -21,6 +21,7 @@ import struct
 import sys
 import unittest
 import difflib
+import functools
 from decimal import Decimal
 from time import time, sleep
 from typing import (
@@ -29,8 +30,7 @@ from typing import (
     Union,
     Dict,
     List,
-    Tuple,
-    Iterator,
+    Callable,
 )
 from itertools import zip_longest
 
@@ -71,7 +71,10 @@ def write_int(i):
     return struct.pack("!i", i)
 
 
-def eventually(condition, timeout=30.0, catch_assertions=False):
+def eventually(
+    timeout=30.0,
+    catch_assertions=False,
+):
     """
     Wait a given amount of time for a condition to pass, else fail with an error.
     This is a helper utility for PySpark tests.
@@ -80,7 +80,7 @@ def eventually(condition, timeout=30.0, catch_assertions=False):
     ----------
     condition : function
         Function that checks for termination conditions. condition() can return:
-            - True: Conditions met. Return without error.
+            - True or None: Conditions met. Return without error.
             - other value: Conditions not met yet. Continue. Upon timeout,
               include last such value in error message.
               Note that this method may be called at any time during
@@ -93,26 +93,45 @@ def eventually(condition, timeout=30.0, catch_assertions=False):
         If True, catch AssertionErrors; continue, but save
         error to throw upon timeout.
     """
-    start_time = time()
-    lastValue = None
-    while time() - start_time < timeout:
-        if catch_assertions:
-            try:
-                lastValue = condition()
-            except AssertionError as e:
-                lastValue = e
-        else:
-            lastValue = condition()
-        if lastValue is True:
-            return
-        sleep(0.01)
-    if isinstance(lastValue, AssertionError):
-        raise lastValue
-    else:
-        raise AssertionError(
-            "Test failed due to timeout after %g sec, with last condition 
returning: %s"
-            % (timeout, lastValue)
-        )
+    assert timeout > 0
+    assert isinstance(catch_assertions, bool)
+
+    def decorator(condition: Callable) -> Callable:
+        assert isinstance(condition, Callable)
+
+        @functools.wraps(condition)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            start_time = time()
+            lastValue = None
+            numTries = 0
+            while time() - start_time < timeout:
+                numTries += 1
+
+                if catch_assertions:
+                    try:
+                        lastValue = condition(*args, **kwargs)
+                    except AssertionError as e:
+                        lastValue = e
+                else:
+                    lastValue = condition(*args, **kwargs)
+
+                if lastValue is True or lastValue is None:
+                    return
+
+                print(f"\nAttempt #{numTries} failed!\n{lastValue}")
+                sleep(0.01)
+
+            if isinstance(lastValue, AssertionError):
+                raise lastValue
+            else:
+                raise AssertionError(
+                    "Test failed due to timeout after %g sec, with last 
condition returning: %s"
+                    % (timeout, lastValue)
+                )
+
+        return wrapper
+
+    return decorator
 
 
 class QuietTest:
diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py
index 5d410aa57e7..50acc3bab07 100644
--- a/python/pyspark/tests/test_taskcontext.py
+++ b/python/pyspark/tests/test_taskcontext.py
@@ -286,11 +286,12 @@ class TaskContextTestsWithWorkerReuse(unittest.TestCase):
             self.assertTrue(pid in worker_pids)
         return True
 
+    @eventually(catch_assertions=True)
     def test_task_context_correct_with_python_worker_reuse(self):
         # Retrying the check as the PIDs from Python workers might be different even
         # when reusing Python workers is enabled if a Python worker is dead for some reasons
         # (e.g., socket connection failure) and new Python worker is created.
-        eventually(self.check_task_context_correct_with_python_worker_reuse, catch_assertions=True)
+        self.check_task_context_correct_with_python_worker_reuse()
 
     def tearDown(self):
         self.sc.stop()
diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py
index 77f06721b10..26dc3db74de 100644
--- a/python/pyspark/tests/test_util.py
+++ b/python/pyspark/tests/test_util.py
@@ -20,7 +20,7 @@ import unittest
 from py4j.protocol import Py4JJavaError
 
 from pyspark import keyword_only
-from pyspark.testing.utils import PySparkTestCase
+from pyspark.testing.utils import PySparkTestCase, eventually
 from pyspark.find_spark_home import _find_spark_home
 
 
@@ -84,6 +84,27 @@ class UtilTests(PySparkTestCase):
         finally:
             os.environ["SPARK_HOME"] = origin
 
+    @eventually(timeout=180, catch_assertions=True)
+    def test_eventually_decorator(self):
+        import random
+
+        self.assertTrue(random.random() < 0.1)
+
+    def test_eventually_function(self):
+        import random
+
+        def condition():
+            self.assertTrue(random.random() < 0.1)
+
+        eventually(timeout=180, catch_assertions=True)(condition)()
+
+    def test_eventually_lambda(self):
+        import random
+
+        eventually(timeout=180, catch_assertions=True)(
+            lambda: self.assertTrue(random.random() < 0.1)
+        )()
+
 
 if __name__ == "__main__":
     from pyspark.tests.test_util import *  # noqa: F401
diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py
index 703690bf7f9..2d675811fb9 100644
--- a/python/pyspark/tests/test_worker.py
+++ b/python/pyspark/tests/test_worker.py
@@ -186,16 +186,13 @@ class WorkerTests(ReusedPySparkTestCase):
 
 
 class WorkerReuseTest(PySparkTestCase):
+    @eventually(catch_assertions=True)
     def test_reuse_worker_of_parallelize_range(self):
-        def check_reuse_worker_of_parallelize_range():
-            rdd = self.sc.parallelize(range(20), 8)
-            previous_pids = rdd.map(lambda x: os.getpid()).collect()
-            current_pids = rdd.map(lambda x: os.getpid()).collect()
-            for pid in current_pids:
-                self.assertTrue(pid in previous_pids)
-            return True
-
-        eventually(check_reuse_worker_of_parallelize_range, catch_assertions=True)
+        rdd = self.sc.parallelize(range(20), 8)
+        previous_pids = rdd.map(lambda x: os.getpid()).collect()
+        current_pids = rdd.map(lambda x: os.getpid()).collect()
+        for pid in current_pids:
+            self.assertTrue(pid in previous_pids)
 
 
 @unittest.skipIf(

