This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new 122d982143f [SPARK-44721][CONNECT] Revamp retry logic and make retries run for 10 minutes
122d982143f is described below

commit 122d982143f1ae1f2447701c6a7877cbf8bef4f0
Author: Alice Sayutina <alice.sayut...@databricks.com>
AuthorDate: Thu Aug 17 09:56:57 2023 +0200

    [SPARK-44721][CONNECT] Revamp retry logic and make retries run for 10 minutes

    ### What changes were proposed in this pull request?
    Change the retry logic. Under the existing logic, the maximum allowed wait time can be extremely low, and with small probability even zero, because each backoff waits random(0, T) for T in exponentialBackoff(). Revamp the logic to guarantee a minimum total wait time of 10 minutes, and synchronize retry behavior between the Python and Scala clients.

    ### Why are the changes needed?
    This avoids a certain class of client errors where the client simply doesn't wait long enough.

    ### Does this PR introduce _any_ user-facing change?
    The changes are small from the user's perspective: retries run longer and more smoothly.

    ### How was this patch tested?
    Unit tests.

    Closes #42399 from cdkrot/revamp_retry_logic.

    Lead-authored-by: Alice Sayutina <alice.sayut...@databricks.com>
    Co-authored-by: Alice Sayutina <cdkr...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 4952a03fdc22b36c1fb5bead09c5e2cc8b4602b8)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../ExecutePlanResponseReattachableIterator.scala | 4 +-
 .../sql/connect/client/GrpcRetryHandler.scala | 84 ++++++++++++-----
 .../connect/client/SparkConnectClientSuite.scala | 26 +++++-
 python/pyspark/errors/error_classes.py | 5 -
 python/pyspark/sql/connect/client/core.py | 102 ++++++++++-----------
 .../sql/tests/connect/client/test_client.py | 24 +++++
 .../sql/tests/connect/test_connect_basic.py | 12 +++
 7 files changed, 170 insertions(+), 87 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index 5ef1151682b..aeb452faecf 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -301,6 +301,6 @@ class ExecutePlanResponseReattachableIterator( /** * Retries the given function with exponential backoff according to the client's retryPolicy.
*/ - private def retry[T](fn: => T, currentRetryNum: Int = 0): T = - GrpcRetryHandler.retry(retryPolicy)(fn, currentRetryNum) + private def retry[T](fn: => T): T = + GrpcRetryHandler.retry(retryPolicy)(fn) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala index 6dad5b4b3a9..8b6f070b8f5 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.connect.client -import scala.annotation.tailrec -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.duration.{Duration, FiniteDuration} +import scala.util.Random import scala.util.control.NonFatal import io.grpc.{Status, StatusRuntimeException} @@ -26,13 +26,15 @@ import io.grpc.stub.StreamObserver import org.apache.spark.internal.Logging -private[client] class GrpcRetryHandler(private val retryPolicy: GrpcRetryHandler.RetryPolicy) { +private[client] class GrpcRetryHandler( + private val retryPolicy: GrpcRetryHandler.RetryPolicy, + private val sleep: Long => Unit = Thread.sleep) { /** * Retries the given function with exponential backoff according to the client's retryPolicy. */ - def retry[T](fn: => T, currentRetryNum: Int = 0): T = - GrpcRetryHandler.retry(retryPolicy)(fn, currentRetryNum) + def retry[T](fn: => T): T = + GrpcRetryHandler.retry(retryPolicy, sleep)(fn) /** * Generalizes the retry logic for RPC calls that return an iterator. @@ -148,37 +150,62 @@ private[client] object GrpcRetryHandler extends Logging { /** * Retries the given function with exponential backoff according to the client's retryPolicy. + * * @param retryPolicy * The retry policy + * @param sleep + * The function which sleeps (takes number of milliseconds to sleep) * @param fn * The function to retry. - * @param currentRetryNum - * Current number of retries. * @tparam T * The return type of the function. * @return * The result of the function. 
*/ - @tailrec final def retry[T](retryPolicy: RetryPolicy)(fn: => T, currentRetryNum: Int = 0): T = { - if (currentRetryNum > retryPolicy.maxRetries) { - throw new IllegalArgumentException( - s"The number of retries ($currentRetryNum) must not exceed " + - s"the maximum number of retires (${retryPolicy.maxRetries}).") + final def retry[T](retryPolicy: RetryPolicy, sleep: Long => Unit = Thread.sleep)( + fn: => T): T = { + var currentRetryNum = 0 + var exceptionList: Seq[Throwable] = Seq.empty + var nextBackoff: Duration = retryPolicy.initialBackoff + + if (retryPolicy.maxRetries < 0) { + throw new IllegalArgumentException("Can't have negative number of retries") } - try { - return fn - } catch { - case NonFatal(e) - if (retryPolicy.canRetry(e) || e.isInstanceOf[RetryException]) - && currentRetryNum < retryPolicy.maxRetries => - logWarning( - s"Non fatal error during RPC execution: $e, " + - s"retrying (currentRetryNum=$currentRetryNum)") - Thread.sleep( - (retryPolicy.maxBackoff min retryPolicy.initialBackoff * Math - .pow(retryPolicy.backoffMultiplier, currentRetryNum)).toMillis) + + while (currentRetryNum <= retryPolicy.maxRetries) { + if (currentRetryNum != 0) { + var currentBackoff = nextBackoff + nextBackoff = nextBackoff * retryPolicy.backoffMultiplier min retryPolicy.maxBackoff + + if (currentBackoff >= retryPolicy.minJitterThreshold) { + currentBackoff += Random.nextDouble() * retryPolicy.jitter + } + + sleep(currentBackoff.toMillis) + } + + try { + return fn + } catch { + case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum < retryPolicy.maxRetries => + currentRetryNum += 1 + exceptionList = e +: exceptionList + + if (currentRetryNum <= retryPolicy.maxRetries) { + logWarning( + s"Non-Fatal error during RPC execution: $e, " + + s"retrying (currentRetryNum=$currentRetryNum)") + } else { + logWarning( + s"Non-Fatal error during RPC execution: $e, " + + s"exceeded retries (currentRetryNum=$currentRetryNum)") + } + } } - retry(retryPolicy)(fn, currentRetryNum + 1) + + val exception = exceptionList.head + exceptionList.tail.foreach(exception.addSuppressed(_)) + throw exception } /** @@ -210,10 +237,17 @@ private[client] object GrpcRetryHandler extends Logging { * Function that determines whether a retry is to be performed in the event of an error. 
*/ case class RetryPolicy( + // Please synchronize changes here with Python side: + // pyspark/sql/connect/client/core.py + // + // Note: these constants are selected so that the maximum tolerated wait is guaranteed + // to be at least 10 minutes maxRetries: Int = 15, initialBackoff: FiniteDuration = FiniteDuration(50, "ms"), maxBackoff: FiniteDuration = FiniteDuration(1, "min"), backoffMultiplier: Double = 4.0, + jitter: FiniteDuration = FiniteDuration(500, "ms"), + minJitterThreshold: FiniteDuration = FiniteDuration(2, "s"), canRetry: Throwable => Boolean = retryException) {} /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index e483e0a7291..6348e0e49ca 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -220,10 +220,10 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { } } - private class DummyFn(val e: Throwable) { + private class DummyFn(val e: Throwable, numFails: Int = 3) { var counter = 0 def fn(): Int = { - if (counter < 3) { + if (counter < numFails) { counter += 1 throw e } else { @@ -232,6 +232,28 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { } } + test("SPARK-44721: Retries run for a minimum period") { + // repeat test few times to avoid random flakes + for (_ <- 1 to 10) { + var totalSleepMs: Long = 0 + + def sleep(t: Long): Unit = { + totalSleepMs += t + } + + val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE), numFails = 100) + val retryHandler = new GrpcRetryHandler(GrpcRetryHandler.RetryPolicy(), sleep) + + assertThrows[StatusRuntimeException] { + retryHandler.retry { + dummyFn.fn() + } + } + + assert(totalSleepMs >= 10 * 60 * 1000) // waited at least 10 minutes + } + } + test("SPARK-44275: retry actually retries") { val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE)) val retryPolicy = GrpcRetryHandler.RetryPolicy() diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 279812ebae1..4709f01ba06 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -231,11 +231,6 @@ ERROR_CLASSES_JSON = """ "Duplicated field names in Arrow Struct are not allowed, got <field_names>" ] }, - "EXCEED_RETRY" : { - "message" : [ - "Retries exceeded but no exception caught." - ] - }, "HIGHER_ORDER_FUNCTION_SHOULD_RETURN_COLUMN" : { "message" : [ "Function `<func_name>` should return Column, got <return_type>." 
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 3e0a35ba926..c2889c10e41 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -92,7 +92,7 @@ from pyspark.sql.pandas.types import _create_converter_to_pandas, from_arrow_sch from pyspark.sql.types import DataType, StructType, TimestampType, _has_type from pyspark.rdd import PythonEvalType from pyspark.storagelevel import StorageLevel -from pyspark.errors import PySparkValueError, PySparkRuntimeError +from pyspark.errors import PySparkValueError if TYPE_CHECKING: @@ -638,10 +638,17 @@ class SparkConnectClient(object): ) self._user_id = None self._retry_policy = { + # Please synchronize changes here with Scala side + # GrpcRetryHandler.scala + # + # Note: the number of retries is selected so that the maximum tolerated wait + # is guaranteed to be at least 10 minutes "max_retries": 15, - "backoff_multiplier": 4, + "backoff_multiplier": 4.0, "initial_backoff": 50, "max_backoff": 60000, + "jitter": 500, + "min_jitter_threshold": 2000, } if retry_policy: self._retry_policy.update(retry_policy) @@ -669,6 +676,11 @@ class SparkConnectClient(object): self._use_reattachable_execute = use_reattachable_execute # Configure logging for the SparkConnect client. + def _retrying(self) -> "Retrying": + return Retrying( + can_retry=SparkConnectClient.retry_exception, **self._retry_policy # type: ignore + ) + def disable_reattachable_execute(self) -> "SparkConnectClient": self._use_reattachable_execute = False return self @@ -1090,9 +1102,7 @@ class SparkConnectClient(object): ) try: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): + for attempt in self._retrying(): with attempt: resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata()) if resp.session_id != self._session_id: @@ -1133,9 +1143,7 @@ class SparkConnectClient(object): for b in generator: handle_response(b) else: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): + for attempt in self._retrying(): with attempt: for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): handle_response(b) @@ -1220,9 +1228,7 @@ class SparkConnectClient(object): for b in generator: yield from handle_response(b) else: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): + for attempt in self._retrying(): with attempt: for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): yield from handle_response(b) @@ -1331,9 +1337,7 @@ class SparkConnectClient(object): req = self._config_request_with_metadata() req.operation.CopyFrom(operation) try: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): + for attempt in self._retrying(): with attempt: resp = self._stub.Config(req, metadata=self._builder.metadata()) if resp.session_id != self._session_id: @@ -1376,9 +1380,7 @@ class SparkConnectClient(object): def interrupt_all(self) -> Optional[List[str]]: req = self._interrupt_request("all") try: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): + for attempt in self._retrying(): with attempt: resp = self._stub.Interrupt(req, metadata=self._builder.metadata()) if resp.session_id != self._session_id: @@ -1394,9 +1396,7 @@ class SparkConnectClient(object): def interrupt_tag(self, tag: str) -> Optional[List[str]]: req = 
self._interrupt_request("tag", tag) try: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): + for attempt in self._retrying(): with attempt: resp = self._stub.Interrupt(req, metadata=self._builder.metadata()) if resp.session_id != self._session_id: @@ -1412,9 +1412,7 @@ class SparkConnectClient(object): def interrupt_operation(self, op_id: str) -> Optional[List[str]]: req = self._interrupt_request("operation", op_id) try: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): + for attempt in self._retrying(): with attempt: resp = self._stub.Interrupt(req, metadata=self._builder.metadata()) if resp.session_id != self._session_id: @@ -1538,12 +1536,14 @@ class RetryState: self._done = False self._count = 0 - def set_exception(self, exc: Optional[BaseException]) -> None: + def set_exception(self, exc: BaseException) -> None: self._exception = exc self._count += 1 - def exception(self) -> Optional[BaseException]: - return self._exception + def throw(self) -> None: + if self._exception is None: + raise RuntimeError("No exception is set") + raise self._exception def set_done(self) -> None: self._done = True @@ -1614,13 +1614,19 @@ class Retrying: initial_backoff: int, max_backoff: int, backoff_multiplier: float, + jitter: int, + min_jitter_threshold: int, can_retry: Callable[..., bool] = lambda x: True, + sleep: Callable[[float], None] = time.sleep, ) -> None: self._can_retry = can_retry self._max_retries = max_retries self._initial_backoff = initial_backoff self._max_backoff = max_backoff self._backoff_multiplier = backoff_multiplier + self._jitter = jitter + self._min_jitter_threshold = min_jitter_threshold + self._sleep = sleep def __iter__(self) -> Generator[AttemptManager, None, None]: """ @@ -1631,35 +1637,25 @@ class Retrying: A generator that yields the current attempt. """ retry_state = RetryState() - while True: - # Check if the operation was completed successfully. - if retry_state.done(): - break - - # If the number of retries have exceeded the maximum allowed retries. - if retry_state.count() > self._max_retries: - e = retry_state.exception() - if e is not None: - raise e - else: - raise PySparkRuntimeError( - error_class="EXCEED_RETRY", - message_parameters={}, - ) + next_backoff: float = self._initial_backoff + + if self._max_retries < 0: + raise ValueError("Can't have negative number of retries") + while not retry_state.done() and retry_state.count() <= self._max_retries: # Do backoff if retry_state.count() > 0: - backoff = random.randrange( - 0, - int( - min( - self._initial_backoff * self._backoff_multiplier ** retry_state.count(), - self._max_backoff, - ) - ), - ) - logger.debug(f"Retrying call after {backoff} ms sleep") - # Pythons sleep takes seconds as arguments. 
- time.sleep(backoff / 1000.0) + # Randomize backoff for this iteration + backoff = next_backoff + next_backoff = min(self._max_backoff, next_backoff * self._backoff_multiplier) + if backoff >= self._min_jitter_threshold: + backoff += random.uniform(0, self._jitter) + + logger.debug(f"Retrying call after {backoff} ms sleep") + self._sleep(backoff / 1000.0) yield AttemptManager(self._can_retry, retry_state) + + if not retry_state.done(): + # Exceeded number of retries, throw last exception we had + retry_state.throw() diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 9782add92f4..98f68767b8b 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -23,6 +23,9 @@ from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder import pyspark.sql.connect.proto as proto from pyspark.testing.connectutils import should_test_connect, connect_requirement_message +from pyspark.sql.connect.client.core import Retrying +from pyspark.sql.connect.client.reattach import RetryException + if should_test_connect: import pandas as pd import pyarrow as pa @@ -89,6 +92,27 @@ class SparkConnectClientTestCase(unittest.TestCase): client.close() self.assertTrue(client.is_closed) + def test_retry(self): + client = SparkConnectClient("sc://foo/;token=bar") + + total_sleep = 0 + + def sleep(t): + nonlocal total_sleep + total_sleep += t + + try: + for attempt in Retrying( + can_retry=SparkConnectClient.retry_exception, sleep=sleep, **client._retry_policy + ): + with attempt: + raise RetryException() + except RetryException: + pass + + # tolerated at least 10 mins of fails + self.assertGreaterEqual(total_sleep, 600) + def test_channel_builder_with_session(self): dummy = str(uuid.uuid4()) chan = ChannelBuilder(f"sc://foo/;session_id={dummy}") diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index c1527365fce..9e8f5623971 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3372,6 +3372,8 @@ class ClientTests(unittest.TestCase): backoff_multiplier=1, initial_backoff=1, max_backoff=10, + jitter=0, + min_jitter_threshold=0, ): with attempt: stub(2, call_wrap, grpc.StatusCode.INTERNAL) @@ -3387,6 +3389,8 @@ class ClientTests(unittest.TestCase): backoff_multiplier=1, initial_backoff=1, max_backoff=10, + jitter=0, + min_jitter_threshold=0, ): with attempt: stub(2, call_wrap, grpc.StatusCode.INTERNAL) @@ -3403,6 +3407,8 @@ class ClientTests(unittest.TestCase): max_backoff=50, backoff_multiplier=1, initial_backoff=50, + jitter=0, + min_jitter_threshold=0, ): with attempt: stub(5, call_wrap, grpc.StatusCode.INTERNAL) @@ -3419,6 +3425,8 @@ class ClientTests(unittest.TestCase): backoff_multiplier=1, initial_backoff=1, max_backoff=10, + jitter=0, + min_jitter_threshold=0, ): with attempt: stub(2, call_wrap, grpc.StatusCode.UNAVAILABLE) @@ -3435,6 +3443,8 @@ class ClientTests(unittest.TestCase): max_backoff=50, backoff_multiplier=1, initial_backoff=50, + jitter=0, + min_jitter_threshold=0, ): with attempt: stub(5, call_wrap, grpc.StatusCode.UNAVAILABLE) @@ -3451,6 +3461,8 @@ class ClientTests(unittest.TestCase): backoff_multiplier=1, initial_backoff=1, max_backoff=10, + jitter=0, + min_jitter_threshold=0, ): with attempt: stub(5, call_wrap, grpc.StatusCode.INTERNAL) 
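For reference, here is a minimal standalone Python sketch (not part of the patch; the helper name min_total_wait_ms is invented for illustration) that adds up the sleeps implied by the new default policy in the diff above (max_retries=15, initial_backoff=50 ms, backoff_multiplier=4.0, max_backoff=60000 ms), ignoring jitter since jitter only adds time. It shows where the "at least 10 minutes" floor asserted by the new tests comes from.

# Standalone sketch: minimum total wait under the new default retry policy,
# mirroring the deterministic part of GrpcRetryHandler.retry / Retrying.__iter__
# (jitter omitted because it can only increase the wait).
def min_total_wait_ms(max_retries=15, initial_backoff=50.0,
                      max_backoff=60000.0, backoff_multiplier=4.0):
    total = 0.0
    next_backoff = initial_backoff
    for _ in range(max_retries):  # one sleep before each retry attempt
        total += next_backoff
        next_backoff = min(max_backoff, next_backoff * backoff_multiplier)
    return total

print(min_total_wait_ms())  # 608250.0 ms = 50 + 200 + 800 + 3200 + 12800 + 51200 + 9 * 60000

Since 608250 ms is roughly 10.1 minutes, the Scala test's assertion totalSleepMs >= 10 * 60 * 1000 and the Python test's assertGreaterEqual(total_sleep, 600) (seconds, as the Python sleep callback receives backoff / 1000.0) both hold even before any jitter is added.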
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org