This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 122d982143f [SPARK-44721][CONNECT] Revamp retry logic and make retries run for 10 minutes
122d982143f is described below

commit 122d982143f1ae1f2447701c6a7877cbf8bef4f0
Author: Alice Sayutina <alice.sayut...@databricks.com>
AuthorDate: Thu Aug 17 09:56:57 2023 +0200

    [SPARK-44721][CONNECT] Revamp retry logic and make retries run for 10 minutes
    
    ### What changes were proposed in this pull request?
    
    Change the retry logic. With the existing logic, the maximum total wait time can be extremely low, and with small probability even zero.

    This happens because each backoff waits random(0, T), where T comes from exponentialBackoff(). Revamp the logic to guarantee a minimum total wait time of 10 minutes, and synchronize the retry behavior between the Python and Scala clients.
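
    For illustration only (not part of this patch), a minimal Python sketch of the new schedule, using the default constants from GrpcRetryHandler.RetryPolicy and core.py below, shows where the 10-minute guarantee comes from:

        initial_backoff, max_backoff = 50, 60_000   # milliseconds
        backoff_multiplier, max_retries = 4.0, 15

        backoff, total = float(initial_backoff), 0.0
        for _ in range(max_retries):                # one sleep before each retry
            total += backoff                        # jitter only adds on top of this
            backoff = min(backoff * backoff_multiplier, max_backoff)
        print(total / 60_000)                       # ~10.14 minutes >= 10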
    
    ### Why are the changes needed?
    
    This avoids a certain class of client errors where the client simply doesn't wait long enough.
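
    For comparison, a hedged sketch of the removed Python schedule (same constants as above): each sleep was drawn uniformly from [0, cap), so the total wait is only about 5.5 minutes in expectation and can be near zero with small probability:

        import random
        # per-retry sleep of the old logic, for retries k = 1..15
        total = sum(
            random.randrange(0, int(min(50 * 4.0 ** k, 60_000)))
            for k in range(1, 16)
        )
        print(total / 60_000)  # varies run to run; can be far below 10 minutes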
    
    ### Does this PR introduce _any_ user-facing change?
    Changes are small from the user's perspective: retries run longer and more smoothly.
    
    ### How was this patch tested?
    UT
    
    Closes #42399 from cdkrot/revamp_retry_logic.
    
    Lead-authored-by: Alice Sayutina <alice.sayut...@databricks.com>
    Co-authored-by: Alice Sayutina <cdkr...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 4952a03fdc22b36c1fb5bead09c5e2cc8b4602b8)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../ExecutePlanResponseReattachableIterator.scala  |   4 +-
 .../sql/connect/client/GrpcRetryHandler.scala      |  84 ++++++++++++-----
 .../connect/client/SparkConnectClientSuite.scala   |  26 +++++-
 python/pyspark/errors/error_classes.py             |   5 -
 python/pyspark/sql/connect/client/core.py          | 102 ++++++++++-----------
 .../sql/tests/connect/client/test_client.py        |  24 +++++
 .../sql/tests/connect/test_connect_basic.py        |  12 +++
 7 files changed, 170 insertions(+), 87 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
index 5ef1151682b..aeb452faecf 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
@@ -301,6 +301,6 @@ class ExecutePlanResponseReattachableIterator(
   /**
    * Retries the given function with exponential backoff according to the client's retryPolicy.
    */
-  private def retry[T](fn: => T, currentRetryNum: Int = 0): T =
-    GrpcRetryHandler.retry(retryPolicy)(fn, currentRetryNum)
+  private def retry[T](fn: => T): T =
+    GrpcRetryHandler.retry(retryPolicy)(fn)
 }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
index 6dad5b4b3a9..8b6f070b8f5 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.sql.connect.client
 
-import scala.annotation.tailrec
-import scala.concurrent.duration.FiniteDuration
+import scala.concurrent.duration.{Duration, FiniteDuration}
+import scala.util.Random
 import scala.util.control.NonFatal
 
 import io.grpc.{Status, StatusRuntimeException}
@@ -26,13 +26,15 @@ import io.grpc.stub.StreamObserver
 
 import org.apache.spark.internal.Logging
 
-private[client] class GrpcRetryHandler(private val retryPolicy: GrpcRetryHandler.RetryPolicy) {
+private[client] class GrpcRetryHandler(
+    private val retryPolicy: GrpcRetryHandler.RetryPolicy,
+    private val sleep: Long => Unit = Thread.sleep) {
 
   /**
   * Retries the given function with exponential backoff according to the client's retryPolicy.
    */
-  def retry[T](fn: => T, currentRetryNum: Int = 0): T =
-    GrpcRetryHandler.retry(retryPolicy)(fn, currentRetryNum)
+  def retry[T](fn: => T): T =
+    GrpcRetryHandler.retry(retryPolicy, sleep)(fn)
 
   /**
    * Generalizes the retry logic for RPC calls that return an iterator.
@@ -148,37 +150,62 @@ private[client] object GrpcRetryHandler extends Logging {
 
   /**
   * Retries the given function with exponential backoff according to the client's retryPolicy.
+   *
    * @param retryPolicy
    *   The retry policy
+   * @param sleep
+   *   The function which sleeps (takes number of milliseconds to sleep)
    * @param fn
    *   The function to retry.
-   * @param currentRetryNum
-   *   Current number of retries.
    * @tparam T
    *   The return type of the function.
    * @return
    *   The result of the function.
    */
-  @tailrec final def retry[T](retryPolicy: RetryPolicy)(fn: => T, currentRetryNum: Int = 0): T = {
-    if (currentRetryNum > retryPolicy.maxRetries) {
-      throw new IllegalArgumentException(
-        s"The number of retries ($currentRetryNum) must not exceed " +
-          s"the maximum number of retires (${retryPolicy.maxRetries}).")
+  final def retry[T](retryPolicy: RetryPolicy, sleep: Long => Unit = Thread.sleep)(
+      fn: => T): T = {
+    var currentRetryNum = 0
+    var exceptionList: Seq[Throwable] = Seq.empty
+    var nextBackoff: Duration = retryPolicy.initialBackoff
+
+    if (retryPolicy.maxRetries < 0) {
+      throw new IllegalArgumentException("Can't have negative number of retries")
     }
-    try {
-      return fn
-    } catch {
-      case NonFatal(e)
-          if (retryPolicy.canRetry(e) || e.isInstanceOf[RetryException])
-            && currentRetryNum < retryPolicy.maxRetries =>
-        logWarning(
-          s"Non fatal error during RPC execution: $e, " +
-            s"retrying (currentRetryNum=$currentRetryNum)")
-        Thread.sleep(
-          (retryPolicy.maxBackoff min retryPolicy.initialBackoff * Math
-            .pow(retryPolicy.backoffMultiplier, currentRetryNum)).toMillis)
+
+    while (currentRetryNum <= retryPolicy.maxRetries) {
+      if (currentRetryNum != 0) {
+        var currentBackoff = nextBackoff
+        nextBackoff = nextBackoff * retryPolicy.backoffMultiplier min retryPolicy.maxBackoff
+
+        if (currentBackoff >= retryPolicy.minJitterThreshold) {
+          currentBackoff += Random.nextDouble() * retryPolicy.jitter
+        }
+
+        sleep(currentBackoff.toMillis)
+      }
+
+      try {
+        return fn
+      } catch {
+        case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum < retryPolicy.maxRetries =>
+          currentRetryNum += 1
+          exceptionList = e +: exceptionList
+
+          if (currentRetryNum <= retryPolicy.maxRetries) {
+            logWarning(
+              s"Non-Fatal error during RPC execution: $e, " +
+                s"retrying (currentRetryNum=$currentRetryNum)")
+          } else {
+            logWarning(
+              s"Non-Fatal error during RPC execution: $e, " +
+                s"exceeded retries (currentRetryNum=$currentRetryNum)")
+          }
+      }
     }
-    retry(retryPolicy)(fn, currentRetryNum + 1)
+
+    val exception = exceptionList.head
+    exceptionList.tail.foreach(exception.addSuppressed(_))
+    throw exception
   }
 
   /**
@@ -210,10 +237,17 @@ private[client] object GrpcRetryHandler extends Logging {
   *   Function that determines whether a retry is to be performed in the event of an error.
    */
   case class RetryPolicy(
+      // Please synchronize changes here with Python side:
+      // pyspark/sql/connect/client/core.py
+      //
+      // Note: these constants are selected so that the maximum tolerated wait is guaranteed
+      // to be at least 10 minutes
       maxRetries: Int = 15,
       initialBackoff: FiniteDuration = FiniteDuration(50, "ms"),
       maxBackoff: FiniteDuration = FiniteDuration(1, "min"),
       backoffMultiplier: Double = 4.0,
+      jitter: FiniteDuration = FiniteDuration(500, "ms"),
+      minJitterThreshold: FiniteDuration = FiniteDuration(2, "s"),
       canRetry: Throwable => Boolean = retryException) {}
 
   /**
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index e483e0a7291..6348e0e49ca 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -220,10 +220,10 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     }
   }
 
-  private class DummyFn(val e: Throwable) {
+  private class DummyFn(val e: Throwable, numFails: Int = 3) {
     var counter = 0
     def fn(): Int = {
-      if (counter < 3) {
+      if (counter < numFails) {
         counter += 1
         throw e
       } else {
@@ -232,6 +232,28 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     }
   }
 
+  test("SPARK-44721: Retries run for a minimum period") {
+    // repeat test few times to avoid random flakes
+    for (_ <- 1 to 10) {
+      var totalSleepMs: Long = 0
+
+      def sleep(t: Long): Unit = {
+        totalSleepMs += t
+      }
+
+      val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE), numFails = 100)
+      val retryHandler = new GrpcRetryHandler(GrpcRetryHandler.RetryPolicy(), sleep)
+
+      assertThrows[StatusRuntimeException] {
+        retryHandler.retry {
+          dummyFn.fn()
+        }
+      }
+
+      assert(totalSleepMs >= 10 * 60 * 1000) // waited at least 10 minutes
+    }
+  }
+
   test("SPARK-44275: retry actually retries") {
     val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
     val retryPolicy = GrpcRetryHandler.RetryPolicy()
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 279812ebae1..4709f01ba06 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -231,11 +231,6 @@ ERROR_CLASSES_JSON = """
       "Duplicated field names in Arrow Struct are not allowed, got 
<field_names>"
     ]
   },
-  "EXCEED_RETRY" : {
-    "message" : [
-      "Retries exceeded but no exception caught."
-    ]
-  },
   "HIGHER_ORDER_FUNCTION_SHOULD_RETURN_COLUMN" : {
     "message" : [
       "Function `<func_name>` should return Column, got <return_type>."
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 3e0a35ba926..c2889c10e41 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -92,7 +92,7 @@ from pyspark.sql.pandas.types import _create_converter_to_pandas, from_arrow_sch
 from pyspark.sql.types import DataType, StructType, TimestampType, _has_type
 from pyspark.rdd import PythonEvalType
 from pyspark.storagelevel import StorageLevel
-from pyspark.errors import PySparkValueError, PySparkRuntimeError
+from pyspark.errors import PySparkValueError
 
 
 if TYPE_CHECKING:
@@ -638,10 +638,17 @@ class SparkConnectClient(object):
         )
         self._user_id = None
         self._retry_policy = {
+            # Please synchronize changes here with Scala side
+            # GrpcRetryHandler.scala
+            #
+            # Note: the number of retries is selected so that the maximum tolerated wait
+            # is guaranteed to be at least 10 minutes
             "max_retries": 15,
-            "backoff_multiplier": 4,
+            "backoff_multiplier": 4.0,
             "initial_backoff": 50,
             "max_backoff": 60000,
+            "jitter": 500,
+            "min_jitter_threshold": 2000,
         }
         if retry_policy:
             self._retry_policy.update(retry_policy)
@@ -669,6 +676,11 @@ class SparkConnectClient(object):
         self._use_reattachable_execute = use_reattachable_execute
         # Configure logging for the SparkConnect client.
 
+    def _retrying(self) -> "Retrying":
+        return Retrying(
+            can_retry=SparkConnectClient.retry_exception, **self._retry_policy  # type: ignore
+        )
+
     def disable_reattachable_execute(self) -> "SparkConnectClient":
         self._use_reattachable_execute = False
         return self
@@ -1090,9 +1102,7 @@ class SparkConnectClient(object):
             )
 
         try:
-            for attempt in Retrying(
-                can_retry=SparkConnectClient.retry_exception, **self._retry_policy
-            ):
+            for attempt in self._retrying():
                 with attempt:
                     resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata())
                     if resp.session_id != self._session_id:
@@ -1133,9 +1143,7 @@ class SparkConnectClient(object):
                 for b in generator:
                     handle_response(b)
             else:
-                for attempt in Retrying(
-                    can_retry=SparkConnectClient.retry_exception, **self._retry_policy
-                ):
+                for attempt in self._retrying():
                     with attempt:
                         for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
                             handle_response(b)
@@ -1220,9 +1228,7 @@ class SparkConnectClient(object):
                 for b in generator:
                     yield from handle_response(b)
             else:
-                for attempt in Retrying(
-                    can_retry=SparkConnectClient.retry_exception, **self._retry_policy
-                ):
+                for attempt in self._retrying():
                     with attempt:
                         for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
                             yield from handle_response(b)
@@ -1331,9 +1337,7 @@ class SparkConnectClient(object):
         req = self._config_request_with_metadata()
         req.operation.CopyFrom(operation)
         try:
-            for attempt in Retrying(
-                can_retry=SparkConnectClient.retry_exception, **self._retry_policy
-            ):
+            for attempt in self._retrying():
                 with attempt:
                     resp = self._stub.Config(req, metadata=self._builder.metadata())
                     if resp.session_id != self._session_id:
@@ -1376,9 +1380,7 @@ class SparkConnectClient(object):
     def interrupt_all(self) -> Optional[List[str]]:
         req = self._interrupt_request("all")
         try:
-            for attempt in Retrying(
-                can_retry=SparkConnectClient.retry_exception, **self._retry_policy
-            ):
+            for attempt in self._retrying():
                 with attempt:
                     resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
                     if resp.session_id != self._session_id:
@@ -1394,9 +1396,7 @@ class SparkConnectClient(object):
     def interrupt_tag(self, tag: str) -> Optional[List[str]]:
         req = self._interrupt_request("tag", tag)
         try:
-            for attempt in Retrying(
-                can_retry=SparkConnectClient.retry_exception, **self._retry_policy
-            ):
+            for attempt in self._retrying():
                 with attempt:
                     resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
                     if resp.session_id != self._session_id:
@@ -1412,9 +1412,7 @@ class SparkConnectClient(object):
     def interrupt_operation(self, op_id: str) -> Optional[List[str]]:
         req = self._interrupt_request("operation", op_id)
         try:
-            for attempt in Retrying(
-                can_retry=SparkConnectClient.retry_exception, **self._retry_policy
-            ):
+            for attempt in self._retrying():
                 with attempt:
                     resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
                     if resp.session_id != self._session_id:
@@ -1538,12 +1536,14 @@ class RetryState:
         self._done = False
         self._count = 0
 
-    def set_exception(self, exc: Optional[BaseException]) -> None:
+    def set_exception(self, exc: BaseException) -> None:
         self._exception = exc
         self._count += 1
 
-    def exception(self) -> Optional[BaseException]:
-        return self._exception
+    def throw(self) -> None:
+        if self._exception is None:
+            raise RuntimeError("No exception is set")
+        raise self._exception
 
     def set_done(self) -> None:
         self._done = True
@@ -1614,13 +1614,19 @@ class Retrying:
         initial_backoff: int,
         max_backoff: int,
         backoff_multiplier: float,
+        jitter: int,
+        min_jitter_threshold: int,
         can_retry: Callable[..., bool] = lambda x: True,
+        sleep: Callable[[float], None] = time.sleep,
     ) -> None:
         self._can_retry = can_retry
         self._max_retries = max_retries
         self._initial_backoff = initial_backoff
         self._max_backoff = max_backoff
         self._backoff_multiplier = backoff_multiplier
+        self._jitter = jitter
+        self._min_jitter_threshold = min_jitter_threshold
+        self._sleep = sleep
 
     def __iter__(self) -> Generator[AttemptManager, None, None]:
         """
@@ -1631,35 +1637,25 @@ class Retrying:
         A generator that yields the current attempt.
         """
         retry_state = RetryState()
-        while True:
-            # Check if the operation was completed successfully.
-            if retry_state.done():
-                break
-
-            # If the number of retries have exceeded the maximum allowed retries.
-            if retry_state.count() > self._max_retries:
-                e = retry_state.exception()
-                if e is not None:
-                    raise e
-                else:
-                    raise PySparkRuntimeError(
-                        error_class="EXCEED_RETRY",
-                        message_parameters={},
-                    )
+        next_backoff: float = self._initial_backoff
+
+        if self._max_retries < 0:
+            raise ValueError("Can't have negative number of retries")
 
+        while not retry_state.done() and retry_state.count() <= self._max_retries:
             # Do backoff
             if retry_state.count() > 0:
-                backoff = random.randrange(
-                    0,
-                    int(
-                        min(
-                            self._initial_backoff * self._backoff_multiplier ** retry_state.count(),
-                            self._max_backoff,
-                        )
-                    ),
-                )
-                logger.debug(f"Retrying call after {backoff} ms sleep")
-                # Pythons sleep takes seconds as arguments.
-                time.sleep(backoff / 1000.0)
+                # Randomize backoff for this iteration
+                backoff = next_backoff
+                next_backoff = min(self._max_backoff, next_backoff * self._backoff_multiplier)
 
+                if backoff >= self._min_jitter_threshold:
+                    backoff += random.uniform(0, self._jitter)
+
+                logger.debug(f"Retrying call after {backoff} ms sleep")
+                self._sleep(backoff / 1000.0)
             yield AttemptManager(self._can_retry, retry_state)
+
+        if not retry_state.done():
+            # Exceeded number of retries, throw last exception we had
+            retry_state.throw()
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index 9782add92f4..98f68767b8b 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -23,6 +23,9 @@ from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
 import pyspark.sql.connect.proto as proto
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 
+from pyspark.sql.connect.client.core import Retrying
+from pyspark.sql.connect.client.reattach import RetryException
+
 if should_test_connect:
     import pandas as pd
     import pyarrow as pa
@@ -89,6 +92,27 @@ class SparkConnectClientTestCase(unittest.TestCase):
         client.close()
         self.assertTrue(client.is_closed)
 
+    def test_retry(self):
+        client = SparkConnectClient("sc://foo/;token=bar")
+
+        total_sleep = 0
+
+        def sleep(t):
+            nonlocal total_sleep
+            total_sleep += t
+
+        try:
+            for attempt in Retrying(
+                can_retry=SparkConnectClient.retry_exception, sleep=sleep, **client._retry_policy
+            ):
+                with attempt:
+                    raise RetryException()
+        except RetryException:
+            pass
+
+        # tolerated at least 10 mins of fails
+        self.assertGreaterEqual(total_sleep, 600)
+
     def test_channel_builder_with_session(self):
         dummy = str(uuid.uuid4())
         chan = ChannelBuilder(f"sc://foo/;session_id={dummy}")
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index c1527365fce..9e8f5623971 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3372,6 +3372,8 @@ class ClientTests(unittest.TestCase):
             backoff_multiplier=1,
             initial_backoff=1,
             max_backoff=10,
+            jitter=0,
+            min_jitter_threshold=0,
         ):
             with attempt:
                 stub(2, call_wrap, grpc.StatusCode.INTERNAL)
@@ -3387,6 +3389,8 @@ class ClientTests(unittest.TestCase):
             backoff_multiplier=1,
             initial_backoff=1,
             max_backoff=10,
+            jitter=0,
+            min_jitter_threshold=0,
         ):
             with attempt:
                 stub(2, call_wrap, grpc.StatusCode.INTERNAL)
@@ -3403,6 +3407,8 @@ class ClientTests(unittest.TestCase):
                 max_backoff=50,
                 backoff_multiplier=1,
                 initial_backoff=50,
+                jitter=0,
+                min_jitter_threshold=0,
             ):
                 with attempt:
                     stub(5, call_wrap, grpc.StatusCode.INTERNAL)
@@ -3419,6 +3425,8 @@ class ClientTests(unittest.TestCase):
             backoff_multiplier=1,
             initial_backoff=1,
             max_backoff=10,
+            jitter=0,
+            min_jitter_threshold=0,
         ):
             with attempt:
                 stub(2, call_wrap, grpc.StatusCode.UNAVAILABLE)
@@ -3435,6 +3443,8 @@ class ClientTests(unittest.TestCase):
                 max_backoff=50,
                 backoff_multiplier=1,
                 initial_backoff=50,
+                jitter=0,
+                min_jitter_threshold=0,
             ):
                 with attempt:
                     stub(5, call_wrap, grpc.StatusCode.UNAVAILABLE)
@@ -3451,6 +3461,8 @@ class ClientTests(unittest.TestCase):
                 backoff_multiplier=1,
                 initial_backoff=1,
                 max_backoff=10,
+                jitter=0,
+                min_jitter_threshold=0,
             ):
                 with attempt:
                     stub(5, call_wrap, grpc.StatusCode.INTERNAL)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
