This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 182e2d236c5 [SPARK-45851][CONNECT][SCALA] Support multiple policies in scala client
182e2d236c5 is described below

commit 182e2d236c5c39f3c4dba248d6df77eb9c363dfd
Author: Alice Sayutina <alice.sayut...@databricks.com>
AuthorDate: Thu Nov 16 19:41:10 2023 +0900

    [SPARK-45851][CONNECT][SCALA] Support multiple policies in scala client
    
    ### What changes were proposed in this pull request?
    
    Support multiple retry policies defined at the same time. Each policy
    determines which error types it can retry and how retries for those errors
    should be spread out over time.
    
    Scala parity for https://github.com/apache/spark/pull/43591
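
    For illustration, a minimal sketch of how a caller could combine the default
    policies with an additional one via the client builder (the endpoint, policy
    values, and names below are made up for this example and are not part of the
    patch):

    ```scala
    import scala.concurrent.duration.FiniteDuration

    import io.grpc.{Status, StatusRuntimeException}

    import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient}

    // Keep the default policies and add one that also retries all INTERNAL errors (illustrative).
    val internalOnly = RetryPolicy(
      maxRetries = Some(3),
      initialBackoff = FiniteDuration(100, "ms"),
      canRetry = {
        case e: StatusRuntimeException => e.getStatus.getCode == Status.Code.INTERNAL
        case _ => false
      },
      name = "InternalErrorPolicy")

    val client = SparkConnectClient
      .builder()
      .connectionString("sc://localhost:15002") // illustrative endpoint
      .retryPolicy(RetryPolicy.defaultPolicies() :+ internalOnly)
      .build()
    ```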
    
    ### Why are the changes needed?
    
    Different error types should be treated differently. For instance, network
    connectivity errors and errors caused by remote resources that are still
    being initialized should be retried under separate policies.
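
    To make this concrete, a minimal sketch of per-error-class policies (the
    helper name, retry budgets, and error classes below are illustrative, not
    part of this patch; they mirror the pattern used in the new unit tests):

    ```scala
    import io.grpc.{Status, StatusRuntimeException}

    import org.apache.spark.sql.connect.client.RetryPolicy

    // One policy per error class, each with its own retry budget (maxRetries)
    // and the default backoff settings.
    def policyFor(code: Status.Code, retries: Int): RetryPolicy = RetryPolicy(
      maxRetries = Some(retries),
      name = s"Policy for $code",
      canRetry = {
        case e: StatusRuntimeException => e.getStatus.getCode == code
        case _ => false
      })

    val networkPolicy = policyFor(Status.Code.UNAVAILABLE, retries = 7)
    val startupPolicy = policyFor(Status.Code.INTERNAL, retries = 3)
    ```

    On each failure the retry handler asks every policy, in order, whether it can
    retry the error; the first matching policy that still has attempts left
    determines the wait before the next attempt. Once all matching policies are
    exhausted, a RetriesExceeded is thrown with the observed errors attached as
    suppressed exceptions.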
    
    ### Does this PR introduce _any_ user-facing change?
    No (as long as the user doesn't poke around in client internals).
    
    ### How was this patch tested?
    Unit tests and some manual testing.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43757 from cdkrot/SPARK-45851-scala-multiple-policies.
    
    Authored-by: Alice Sayutina <alice.sayut...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../spark/sql/connect/client/ArtifactSuite.scala   |   4 +-
 .../connect/client/SparkConnectClientSuite.scala   |  71 +++++++--
 .../apache/spark/sql/test/RemoteSparkSession.scala |   7 +-
 .../client/CustomSparkConnectBlockingStub.scala    |   4 +-
 .../ExecutePlanResponseReattachableIterator.scala  |  16 +-
 .../sql/connect/client/GrpcRetryHandler.scala      | 166 ++++++++-------------
 .../spark/sql/connect/client/RetriesExceeded.scala |  25 ++++
 .../spark/sql/connect/client/RetryPolicy.scala     | 134 +++++++++++++++++
 .../sql/connect/client/SparkConnectClient.scala    |  12 +-
 .../sql/connect/client/SparkConnectStubState.scala |  10 +-
 .../spark/sql/connect/SparkConnectServerTest.scala |   6 +-
 11 files changed, 311 insertions(+), 144 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
index 79aba053ea0..f945313d242 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
@@ -42,7 +42,6 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
   private var server: Server = _
   private var artifactManager: ArtifactManager = _
   private var channel: ManagedChannel = _
-  private var retryPolicy: GrpcRetryHandler.RetryPolicy = _
   private var bstub: CustomSparkConnectBlockingStub = _
   private var stub: CustomSparkConnectStub = _
   private var state: SparkConnectStubState = _
@@ -58,8 +57,7 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
 
   private def createArtifactManager(): Unit = {
     channel = 
InProcessChannelBuilder.forName(getClass.getName).directExecutor().build()
-    retryPolicy = GrpcRetryHandler.RetryPolicy()
-    state = new SparkConnectStubState(channel, retryPolicy)
+    state = new SparkConnectStubState(channel, RetryPolicy.defaultPolicies())
     bstub = new CustomSparkConnectBlockingStub(channel, state)
     stub = new CustomSparkConnectStub(channel, state)
     artifactManager = new ArtifactManager(Configuration(), "", bstub, stub)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index b93713383b2..e226484d87a 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -119,7 +119,7 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     client = SparkConnectClient
       .builder()
       .connectionString(s"sc://localhost:${server.getPort}/;use_ssl=true")
-      .retryPolicy(GrpcRetryHandler.RetryPolicy(maxRetries = 0))
+      .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, 
name = "TestPolicy"))
       .build()
 
     val request = 
AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()
@@ -311,7 +311,7 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     }
   }
 
-  private class DummyFn(val e: Throwable, numFails: Int = 3) {
+  private class DummyFn(e: => Throwable, numFails: Int = 3) {
     var counter = 0
     def fn(): Int = {
       if (counter < numFails) {
@@ -333,9 +333,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
       }
 
       val dummyFn = new DummyFn(new 
StatusRuntimeException(Status.UNAVAILABLE), numFails = 100)
-      val retryHandler = new GrpcRetryHandler(GrpcRetryHandler.RetryPolicy(), 
sleep)
+      val retryHandler = new GrpcRetryHandler(RetryPolicy.defaultPolicies(), 
sleep)
 
-      assertThrows[StatusRuntimeException] {
+      assertThrows[RetriesExceeded] {
         retryHandler.retry {
           dummyFn.fn()
         }
@@ -347,8 +347,8 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
 
   test("SPARK-44275: retry actually retries") {
     val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
-    val retryPolicy = GrpcRetryHandler.RetryPolicy()
-    val retryHandler = new GrpcRetryHandler(retryPolicy)
+    val retryPolicies = RetryPolicy.defaultPolicies()
+    val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
     val result = retryHandler.retry { dummyFn.fn() }
 
     assert(result == 42)
@@ -357,8 +357,8 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
 
   test("SPARK-44275: default retryException retries only on UNAVAILABLE") {
     val dummyFn = new DummyFn(new StatusRuntimeException(Status.ABORTED))
-    val retryPolicy = GrpcRetryHandler.RetryPolicy()
-    val retryHandler = new GrpcRetryHandler(retryPolicy)
+    val retryPolicies = RetryPolicy.defaultPolicies()
+    val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
 
     assertThrows[StatusRuntimeException] {
       retryHandler.retry { dummyFn.fn() }
@@ -368,7 +368,7 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
 
   test("SPARK-44275: retry uses canRetry to filter exceptions") {
     val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
-    val retryPolicy = GrpcRetryHandler.RetryPolicy(canRetry = _ => false)
+    val retryPolicy = RetryPolicy(canRetry = _ => false, name = "TestPolicy")
     val retryHandler = new GrpcRetryHandler(retryPolicy)
 
     assertThrows[StatusRuntimeException] {
@@ -379,15 +379,62 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
 
   test("SPARK-44275: retry does not exceed maxRetries") {
     val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
-    val retryPolicy = GrpcRetryHandler.RetryPolicy(canRetry = _ => true, 
maxRetries = 1)
-    val retryHandler = new GrpcRetryHandler(retryPolicy)
+    val retryPolicy = RetryPolicy(canRetry = _ => true, maxRetries = Some(1), 
name = "TestPolicy")
+    val retryHandler = new GrpcRetryHandler(retryPolicy, sleep = _ => {})
 
-    assertThrows[StatusRuntimeException] {
+    assertThrows[RetriesExceeded] {
       retryHandler.retry { dummyFn.fn() }
     }
     assert(dummyFn.counter == 2)
   }
 
+  def testPolicySpecificError(maxRetries: Int, status: Status): RetryPolicy = {
+    RetryPolicy(
+      maxRetries = Some(maxRetries),
+      name = s"Policy for ${status.getCode}",
+      canRetry = {
+        case e: StatusRuntimeException => e.getStatus.getCode == status.getCode
+        case _ => false
+      })
+  }
+
+  test("Test multiple policies") {
+    val policy1 = testPolicySpecificError(maxRetries = 2, status = 
Status.UNAVAILABLE)
+    val policy2 = testPolicySpecificError(maxRetries = 4, status = 
Status.INTERNAL)
+
+    // Tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors
+
+    val errors = (List.fill(2)(Status.UNAVAILABLE) ++ 
List.fill(4)(Status.INTERNAL)).iterator
+
+    new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({
+      val e = errors.nextOption()
+      if (e.isDefined) {
+        throw e.get.asRuntimeException()
+      }
+    })
+
+    assert(!errors.hasNext)
+  }
+
+  test("Test multiple policies exceed") {
+    val policy1 = testPolicySpecificError(maxRetries = 2, status = 
Status.INTERNAL)
+    val policy2 = testPolicySpecificError(maxRetries = 4, status = 
Status.INTERNAL)
+
+    val errors = List.fill(10)(Status.INTERNAL).iterator
+    var countAttempted = 0
+
+    assertThrows[RetriesExceeded](
+      new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({
+        countAttempted += 1
+        val e = errors.nextOption()
+        if (e.isDefined) {
+          throw e.get.asRuntimeException()
+        }
+      }))
+
+    assert(countAttempted == 7)
+  }
+
   test("SPARK-45871: Client execute iterator.toSeq consumes the reattachable 
iterator") {
     startDummyServer(0)
     client = SparkConnectClient
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala
index 172efb7db7c..3444e2d7544 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala
@@ -17,7 +17,6 @@
 package org.apache.spark.sql.test
 
 import java.io.{File, IOException, OutputStream}
-import java.lang.ProcessBuilder
 import java.lang.ProcessBuilder.Redirect
 import java.nio.file.Paths
 import java.util.concurrent.TimeUnit
@@ -28,7 +27,7 @@ import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.SparkBuildInfo
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connect.client.GrpcRetryHandler.RetryPolicy
+import org.apache.spark.sql.connect.client.RetryPolicy
 import org.apache.spark.sql.connect.client.SparkConnectClient
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.test.IntegrationTestUtils._
@@ -189,7 +188,9 @@ object SparkConnectServerUtils {
           .builder()
           .userId("test")
           .port(port)
-          .retryPolicy(RetryPolicy(maxRetries = 7, maxBackoff = 
FiniteDuration(10, "s")))
+          .retryPolicy(RetryPolicy
+            .defaultPolicy()
+            .copy(maxRetries = Some(7), maxBackoff = Some(FiniteDuration(10, 
"s"))))
           .build())
       .create()
 
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index f8df2fa3f65..d7867229248 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -61,9 +61,9 @@ private[connect] class CustomSparkConnectBlockingStub(
         request.getSessionId,
         request.getUserContext,
         request.getClientType,
-        // Don't use retryHandler - own retry handling is inside.
         stubState.responseValidator.wrapIterator(
-          new ExecutePlanResponseReattachableIterator(request, channel, 
stubState.retryPolicy)))
+          // ExecutePlanResponseReattachableIterator does all retries by 
itself, don't wrap it here
+          new ExecutePlanResponseReattachableIterator(request, channel, 
stubState.retryHandler)))
     }
   }
 
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
index cfa492ef063..5854a9225db 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
@@ -27,6 +27,7 @@ import io.grpc.stub.StreamObserver
 
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connect.client.GrpcRetryHandler.RetryException
 
 /**
  * Retryable iterator of ExecutePlanResponses to an ExecutePlan call.
@@ -50,10 +51,15 @@ import org.apache.spark.internal.Logging
 class ExecutePlanResponseReattachableIterator(
     request: proto.ExecutePlanRequest,
     channel: ManagedChannel,
-    retryPolicy: GrpcRetryHandler.RetryPolicy)
+    retryHandler: GrpcRetryHandler)
     extends WrappedCloseableIterator[proto.ExecutePlanResponse]
     with Logging {
 
+  /**
+   * Retries the given function with exponential backoff according to the 
client's retryPolicy.
+   */
+  private def retry[T](fn: => T): T = retryHandler.retry(fn)
+
   val operationId = if (request.hasOperationId) {
     request.getOperationId
   } else {
@@ -236,7 +242,7 @@ class ExecutePlanResponseReattachableIterator(
         }
         // Try a new ExecutePlan, and throw upstream for retry.
         iter = Some(rawBlockingStub.executePlan(initialRequest))
-        val error = new GrpcRetryHandler.RetryException()
+        val error = new RetryException()
         error.addSuppressed(ex)
         throw error
       case NonFatal(e) =>
@@ -319,12 +325,6 @@ class ExecutePlanResponseReattachableIterator(
 
     release.build()
   }
-
-  /**
-   * Retries the given function with exponential backoff according to the 
client's retryPolicy.
-   */
-  private def retry[T](fn: => T): T =
-    GrpcRetryHandler.retry(retryPolicy)(fn)
 }
 
 private[connect] object ExecutePlanResponseReattachableIterator {
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
index 0f8178cdb5a..2418dfa0350 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
@@ -17,24 +17,23 @@
 
 package org.apache.spark.sql.connect.client
 
-import scala.concurrent.duration.{Duration, FiniteDuration}
-import scala.util.Random
 import scala.util.control.NonFatal
 
-import io.grpc.{Status, StatusRuntimeException}
 import io.grpc.stub.StreamObserver
 
 import org.apache.spark.internal.Logging
 
 private[sql] class GrpcRetryHandler(
-    private val retryPolicy: GrpcRetryHandler.RetryPolicy,
+    private val policies: Seq[RetryPolicy],
     private val sleep: Long => Unit = Thread.sleep) {
 
+  def this(policy: RetryPolicy, sleep: Long => Unit) = this(List(policy), 
sleep)
+  def this(policy: RetryPolicy) = this(policy, Thread.sleep)
+
   /**
    * Retries the given function with exponential backoff according to the 
client's retryPolicy.
    */
-  def retry[T](fn: => T): T =
-    GrpcRetryHandler.retry(retryPolicy, sleep)(fn)
+  def retry[T](fn: => T): T = new GrpcRetryHandler.Retrying(policies, sleep, 
fn).retry()
 
   /**
    * Generalizes the retry logic for RPC calls that return an iterator.
@@ -151,125 +150,86 @@ private[sql] class GrpcRetryHandler(
 private[sql] object GrpcRetryHandler extends Logging {
 
   /**
-   * Retries the given function with exponential backoff according to the 
client's retryPolicy.
-   *
-   * @param retryPolicy
-   *   The retry policy
+   * Class managing the state of the retrying logic during a single retryable 
block.
+   * @param retryPolicies
+   *   list of policies to apply (in order)
    * @param sleep
-   *   The function which sleeps (takes number of milliseconds to sleep)
+   *   typically Thread.sleep
    * @param fn
-   *   The function to retry.
+   *   the function to compute
    * @tparam T
-   *   The return type of the function.
-   * @return
-   *   The result of the function.
+   *   result of function fn
    */
-  final def retry[T](retryPolicy: RetryPolicy, sleep: Long => Unit = 
Thread.sleep)(
-      fn: => T): T = {
-    var currentRetryNum = 0
-    var exceptionList: Seq[Throwable] = Seq.empty
-    var nextBackoff: Duration = retryPolicy.initialBackoff
-
-    if (retryPolicy.maxRetries < 0) {
-      throw new IllegalArgumentException("Can't have negative number of 
retries")
-    }
-
-    while (currentRetryNum <= retryPolicy.maxRetries) {
-      if (currentRetryNum != 0) {
-        var currentBackoff = nextBackoff
-        nextBackoff = nextBackoff * retryPolicy.backoffMultiplier min 
retryPolicy.maxBackoff
+  class Retrying[T](retryPolicies: Seq[RetryPolicy], sleep: Long => Unit, fn: 
=> T) {
+    private var currentRetryNum: Int = 0
+    private var exceptionList: Seq[Throwable] = Seq.empty
+    private val policies: Seq[RetryPolicy.RetryPolicyState] = 
retryPolicies.map(_.toState)
 
-        if (currentBackoff >= retryPolicy.minJitterThreshold) {
-          currentBackoff += Random.nextDouble() * retryPolicy.jitter
-        }
-
-        sleep(currentBackoff.toMillis)
-      }
+    def canRetry(throwable: Throwable): Boolean = {
+      throwable.isInstanceOf[RetryException] || policies.exists(p => 
p.canRetry(throwable))
+    }
 
+    def makeAttempt(): Option[T] = {
       try {
-        return fn
+        Some(fn)
       } catch {
-        case NonFatal(e) if retryPolicy.canRetry(e) && currentRetryNum < 
retryPolicy.maxRetries =>
+        case NonFatal(e) if canRetry(e) =>
           currentRetryNum += 1
           exceptionList = e +: exceptionList
-
-          if (currentRetryNum <= retryPolicy.maxRetries) {
-            logWarning(
-              s"Non-Fatal error during RPC execution: $e, " +
-                s"retrying (currentRetryNum=$currentRetryNum)")
-          } else {
-            logWarning(
-              s"Non-Fatal error during RPC execution: $e, " +
-                s"exceeded retries (currentRetryNum=$currentRetryNum)")
-          }
+          None
       }
     }
 
-    val exception = exceptionList.head
-    exceptionList.tail.foreach(exception.addSuppressed(_))
-    throw exception
-  }
+    def waitAfterAttempt(): Unit = {
+      // find policy which will accept this exception
+      val lastException = exceptionList.head
 
-  /**
-   * Default canRetry in [[RetryPolicy]].
-   * @param e
-   *   The exception to check.
-   * @return
-   *   true if the exception is a [[StatusRuntimeException]] with code 
UNAVAILABLE.
-   */
-  private[client] def retryException(e: Throwable): Boolean = {
-    e match {
-      case e: StatusRuntimeException =>
-        val statusCode: Status.Code = e.getStatus.getCode
+      if (lastException.isInstanceOf[RetryException]) {
+        // retry exception is considered immediately retriable without any 
policies.
+        logWarning(
+          s"Non-Fatal error during RPC execution: $lastException, retrying " +
+            s"(currentRetryNum=$currentRetryNum)")
+        return
+      }
 
-        if (statusCode == Status.Code.INTERNAL) {
-          val msg: String = e.toString
+      for (policy <- policies if policy.canRetry(lastException)) {
+        val time = policy.nextAttempt()
 
-          // This error happens if another RPC preempts this RPC.
-          if (msg.contains("INVALID_CURSOR.DISCONNECTED")) {
-            return true
-          }
-        }
+        if (time.isDefined) {
+          logWarning(
+            s"Non-Fatal error during RPC execution: $lastException, retrying " 
+
+              s"(wait=${time.get.toMillis}, currentRetryNum=$currentRetryNum, 
" +
+              s"policy: ${policy.getName})")
 
-        if (statusCode == Status.Code.UNAVAILABLE) {
-          return true
+          sleep(time.get.toMillis)
+          return
         }
-        false
-      case _ => false
+      }
+
+      logWarning(
+        s"Non-Fatal error during RPC execution: $lastException, exceeded 
retries " +
+          s"(currentRetryNum=$currentRetryNum)")
+
+      val error = new RetriesExceeded()
+      exceptionList.foreach(error.addSuppressed)
+      throw error
     }
-  }
 
-  /**
-   * [[RetryPolicy]] configure the retry mechanism in [[GrpcRetryHandler]]
-   *
-   * @param maxRetries
-   *   Maximum number of retries.
-   * @param initialBackoff
-   *   Start value of the exponential backoff (ms).
-   * @param maxBackoff
-   *   Maximal value of the exponential backoff (ms).
-   * @param backoffMultiplier
-   *   Multiplicative base of the exponential backoff.
-   * @param canRetry
-   *   Function that determines whether a retry is to be performed in the 
event of an error.
-   */
-  case class RetryPolicy(
-      // Please synchronize changes here with Python side:
-      // pyspark/sql/connect/client/core.py
-      //
-      // Note: these constants are selected so that the maximum tolerated wait 
is guaranteed
-      // to be at least 10 minutes
-      maxRetries: Int = 15,
-      initialBackoff: FiniteDuration = FiniteDuration(50, "ms"),
-      maxBackoff: FiniteDuration = FiniteDuration(1, "min"),
-      backoffMultiplier: Double = 4.0,
-      jitter: FiniteDuration = FiniteDuration(500, "ms"),
-      minJitterThreshold: FiniteDuration = FiniteDuration(2, "s"),
-      canRetry: Throwable => Boolean = retryException) {}
+    def retry(): T = {
+      var result = makeAttempt()
+
+      while (result.isEmpty) {
+        waitAfterAttempt()
+        result = makeAttempt()
+      }
+
+      result.get
+    }
+  }
 
   /**
-   * An exception that can be thrown upstream when inside retry and which will 
be retryable
-   * regardless of policy.
+   * An exception that can be thrown upstream when inside retry and which will 
be always retryable
+   * without any policies.
    */
   class RetryException extends Throwable
 }
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetriesExceeded.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetriesExceeded.scala
new file mode 100644
index 00000000000..77e1c0deab2
--- /dev/null
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetriesExceeded.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.client
+
+/**
+ * Represents an exception which was considered retriable but has exceeded 
retry limits.
+ *
+ * The actual exceptions incurred can be retrieved with getSuppressed()
+ */
+class RetriesExceeded extends Throwable
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala
new file mode 100644
index 00000000000..cb5b97f2e4a
--- /dev/null
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.client
+
+import scala.concurrent.duration.{Duration, FiniteDuration}
+import scala.util.Random
+
+import io.grpc.{Status, StatusRuntimeException}
+
+/**
+ * [[RetryPolicy]] configure the retry mechanism in [[GrpcRetryHandler]]
+ *
+ * @param maxRetries
+ *   Maximum number of retries.
+ * @param initialBackoff
+ *   Start value of the exponential backoff (ms).
+ * @param maxBackoff
+ *   Maximal value of the exponential backoff (ms).
+ * @param backoffMultiplier
+ *   Multiplicative base of the exponential backoff.
+ * @param canRetry
+ *   Function that determines whether a retry is to be performed in the event 
of an error.
+ */
+case class RetryPolicy(
+    maxRetries: Option[Int] = None,
+    initialBackoff: FiniteDuration = FiniteDuration(1000, "ms"),
+    maxBackoff: Option[FiniteDuration] = None,
+    backoffMultiplier: Double = 1.0,
+    jitter: FiniteDuration = FiniteDuration(0, "s"),
+    minJitterThreshold: FiniteDuration = FiniteDuration(0, "s"),
+    canRetry: Throwable => Boolean,
+    name: String) {
+
+  def getName: String = name
+
+  def toState: RetryPolicy.RetryPolicyState = new 
RetryPolicy.RetryPolicyState(this)
+}
+
+object RetryPolicy {
+  def defaultPolicy(): RetryPolicy = RetryPolicy(
+    name = "DefaultPolicy",
+    // Please synchronize changes here with Python side:
+    // pyspark/sql/connect/client/core.py
+    //
+    // Note: these constants are selected so that the maximum tolerated wait 
is guaranteed
+    // to be at least 10 minutes
+    maxRetries = Some(15),
+    initialBackoff = FiniteDuration(50, "ms"),
+    maxBackoff = Some(FiniteDuration(1, "min")),
+    backoffMultiplier = 4.0,
+    jitter = FiniteDuration(500, "ms"),
+    minJitterThreshold = FiniteDuration(2, "s"),
+    canRetry = defaultPolicyRetryException)
+
+  // list of policies to be used by this client
+  def defaultPolicies(): Seq[RetryPolicy] = List(defaultPolicy())
+
+  // represents a state of the specific policy
+  // (how many retries have happened and how much to wait until next one)
+  private[client] class RetryPolicyState(val policy: RetryPolicy) {
+    private var numberAttempts = 0
+    private var nextWait: Duration = policy.initialBackoff
+
+    // return waiting time until next attempt, or None if has exceeded max 
retries
+    def nextAttempt(): Option[Duration] = {
+      if (policy.maxRetries.isDefined && numberAttempts >= 
policy.maxRetries.get) {
+        return None
+      }
+
+      numberAttempts += 1
+
+      var currentWait = nextWait
+      nextWait = nextWait * policy.backoffMultiplier
+      if (policy.maxBackoff.isDefined) {
+        nextWait = nextWait min policy.maxBackoff.get
+      }
+
+      if (currentWait >= policy.minJitterThreshold) {
+        currentWait += Random.nextDouble() * policy.jitter
+      }
+
+      Some(currentWait)
+    }
+
+    def canRetry(throwable: Throwable): Boolean = policy.canRetry(throwable)
+
+    def getName: String = policy.getName
+  }
+
+  /**
+   * Default canRetry in [[RetryPolicy]].
+   *
+   * @param e
+   *   The exception to check.
+   * @return
+   *   true if the exception is a [[StatusRuntimeException]] with code 
UNAVAILABLE.
+   */
+  private[client] def defaultPolicyRetryException(e: Throwable): Boolean = {
+    e match {
+      case e: StatusRuntimeException =>
+        val statusCode: Status.Code = e.getStatus.getCode
+
+        if (statusCode == Status.Code.INTERNAL) {
+          val msg: String = e.toString
+
+          // This error happens if another RPC preempts this RPC.
+          if (msg.contains("INVALID_CURSOR.DISCONNECTED")) {
+            return true
+          }
+        }
+
+        if (statusCode == Status.Code.UNAVAILABLE) {
+          return true
+        }
+        false
+      case _ => false
+    }
+  }
+}
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 9fc74f1af2c..c2776e65392 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -43,7 +43,7 @@ private[sql] class SparkConnectClient(
 
   private val userContext: UserContext = configuration.userContext
 
-  private[this] val stubState = new SparkConnectStubState(channel, 
configuration.retryPolicy)
+  private[this] val stubState = new SparkConnectStubState(channel, 
configuration.retryPolicies)
   private[this] val bstub =
     new CustomSparkConnectBlockingStub(channel, stubState)
   private[this] val stub =
@@ -443,11 +443,15 @@ object SparkConnectClient {
 
     def sslEnabled: Boolean = _configuration.isSslEnabled.contains(true)
 
-    def retryPolicy(policy: GrpcRetryHandler.RetryPolicy): Builder = {
-      _configuration = _configuration.copy(retryPolicy = policy)
+    def retryPolicy(policies: Seq[RetryPolicy]): Builder = {
+      _configuration = _configuration.copy(retryPolicies = policies)
       this
     }
 
+    def retryPolicy(policy: RetryPolicy): Builder = {
+      retryPolicy(List(policy))
+    }
+
     private object URIParams {
       val PARAM_USER_ID = "user_id"
       val PARAM_USE_SSL = "use_ssl"
@@ -634,7 +638,7 @@ object SparkConnectClient {
       metadata: Map[String, String] = Map.empty,
       userAgent: String = genUserAgent(
         sys.env.getOrElse("SPARK_CONNECT_USER_AGENT", DEFAULT_USER_AGENT)),
-      retryPolicy: GrpcRetryHandler.RetryPolicy = 
GrpcRetryHandler.RetryPolicy(),
+      retryPolicies: Seq[RetryPolicy] = RetryPolicy.defaultPolicies(),
       useReattachableExecute: Boolean = true,
       interceptors: List[ClientInterceptor] = List.empty,
       sessionId: Option[String] = None) {
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala
index e6c7ebf9211..2ec9ecad903 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala
@@ -26,17 +26,15 @@ import org.apache.spark.internal.Logging
 // that the same stub instance is used for all requests from the same client. 
In addition,
 // this class provides access to the commonly configured retry policy and 
exception conversion
 // logic.
-class SparkConnectStubState(
-    channel: ManagedChannel,
-    val retryPolicy: GrpcRetryHandler.RetryPolicy)
+class SparkConnectStubState(channel: ManagedChannel, retryPolicies: 
Seq[RetryPolicy])
     extends Logging {
 
+  // Manages the retry handler logic used by the stubs.
+  lazy val retryHandler = new GrpcRetryHandler(retryPolicies)
+
   // Responsible to convert the GRPC Status exceptions into Spark exceptions.
   lazy val exceptionConverter: GrpcExceptionConverter = new 
GrpcExceptionConverter(channel)
 
-  // Manages the retry handler logic used by the stubs.
-  lazy val retryHandler = new GrpcRetryHandler(retryPolicy)
-
   // Provides a helper for validating the responses processed by the stub.
   lazy val responseValidator = new ResponseValidator()
 
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 1c0d9a68ab6..dbb06437c4d 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -27,7 +27,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.connect.client.{CloseableIterator, 
CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, 
GrpcRetryHandler, SparkConnectClient, SparkConnectStubState}
+import org.apache.spark.sql.connect.client.{CloseableIterator, 
CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, 
RetryPolicy, SparkConnectClient, SparkConnectStubState}
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.connect.config.Connect
@@ -244,11 +244,11 @@ trait SparkConnectServerTest extends SharedSparkSession {
   }
 
   protected def withCustomBlockingStub(
-      retryPolicy: GrpcRetryHandler.RetryPolicy = 
GrpcRetryHandler.RetryPolicy())(
+      retryPolicies: Seq[RetryPolicy] = RetryPolicy.defaultPolicies())(
       f: CustomSparkConnectBlockingStub => Unit): Unit = {
     val conf = SparkConnectClient.Configuration(port = serverPort)
     val channel = conf.createChannel()
-    val stubState = new SparkConnectStubState(channel, retryPolicy)
+    val stubState = new SparkConnectStubState(channel, retryPolicies)
     val bstub = new CustomSparkConnectBlockingStub(channel, stubState)
     try f(bstub)
     finally {


