This is an automated email from the ASF dual-hosted git repository.

feiwang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 612464c69 [CELEBORN-2015] Retry IOException failures for RPC requests
612464c69 is described below

commit 612464c69deba27047d725dbcd25cdb8f8b994e5
Author: Sanskar Modi <sanskarmod...@gmail.com>
AuthorDate: Tue May 27 07:37:40 2025 -0700

    [CELEBORN-2015] Retry IOException failures for RPC requests
    
    ### What changes were proposed in this pull request?
    - Retry RpcRequest calls on IOException failures, in addition to 
RpcTimeoutException.
    - Move the duplicated retry logic into Utils.
    
    ### Why are the changes needed?
    
    Currently, if a request fails with a SocketException or another 
IOException, it is not retried, which leads to stage failures. Celeborn should 
retry on such connection failures.
    
    ### Does this PR introduce _any_ user-facing change?
    NA
    
    ### How was this patch tested?
    NA
    
    Closes #3286 from s0nskar/setup_lifecycle_exception.
    
    Authored-by: Sanskar Modi <sanskarmod...@gmail.com>
    Signed-off-by: Wang, Fei <fwan...@ebay.com>
---
 .../celeborn/common/rpc/RpcEndpointRef.scala       | 34 ++++------------------
 .../org/apache/celeborn/common/rpc/RpcEnv.scala    | 33 ++++-----------------
 .../org/apache/celeborn/common/util/Utils.scala    | 29 ++++++++++++++++++
 3 files changed, 39 insertions(+), 57 deletions(-)

diff --git 
a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpointRef.scala 
b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpointRef.scala
index 8c861cf57..483144bdd 100644
--- a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpointRef.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpointRef.scala
@@ -17,14 +17,12 @@
 
 package org.apache.celeborn.common.rpc
 
-import java.util.Random
-import java.util.concurrent.TimeUnit
-
 import scala.concurrent.Future
 import scala.reflect.ClassTag
 
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.common.util.Utils
 
 /**
  * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe.
@@ -111,7 +109,7 @@ abstract class RpcEndpointRef(conf: CelebornConf)
 
   /**
    * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and 
get its result within a
-   * specified timeout, retry if timeout, throw an exception if this still 
fails.
+   * specified timeout, retry if timeout or IOException, throw an exception if 
this still fails.
    *
    * Note: this is a blocking action which may cost a lot of time, so don't 
call it in a message
    * loop of [[RpcEndpoint]].
@@ -128,31 +126,9 @@ abstract class RpcEndpointRef(conf: CelebornConf)
       timeout: RpcTimeout,
       retryCount: Int,
       retryWait: Long): T = {
-    var numRetries = retryCount
-    while (numRetries > 0) {
-      numRetries -= 1
-      try {
-        val future = ask[T](message, timeout)
-        return timeout.awaitResult(future, address)
-      } catch {
-        case e: RpcTimeoutException =>
-          if (numRetries > 0) {
-            val random = new Random
-            val retryWaitMs = random.nextInt(retryWait.toInt)
-            try {
-              TimeUnit.MILLISECONDS.sleep(retryWaitMs)
-            } catch {
-              case _: InterruptedException =>
-                throw e
-            }
-          } else {
-            throw e
-          }
-        case e: Exception =>
-          throw e
-      }
+    Utils.withRetryOnTimeoutOrIOException(retryCount, retryWait) {
+      val future = ask[T](message, timeout)
+      return timeout.awaitResult(future, address)
     }
-    // should never be here
-    null.asInstanceOf[T]
   }
 }
diff --git a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEnv.scala 
b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEnv.scala
index 19f522a0a..606006e0e 100644
--- a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEnv.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEnv.scala
@@ -18,8 +18,6 @@
 package org.apache.celeborn.common.rpc
 
 import java.io.File
-import java.util.Random
-import java.util.concurrent.TimeUnit
 
 import scala.concurrent.Future
 
@@ -151,39 +149,18 @@ abstract class RpcEnv(config: RpcEnvConfig) {
   }
 
   /**
-   * Retrieve the [[RpcEndpointRef]] represented by `address` and 
`endpointName` with timeout retry.
-   * This is a blocking action.
+   * Retrieve the [[RpcEndpointRef]] represented by `address` and 
`endpointName` within a specified
+   * timeout, retry if timeout or IOException, throw an exception if this 
still fails. This is a
+   * blocking action.
    */
   def setupEndpointRef(
       address: RpcAddress,
       endpointName: String,
       retryCount: Int,
       retryWait: Long = defaultRetryWait): RpcEndpointRef = {
-    var numRetries = retryCount
-    while (numRetries > 0) {
-      numRetries -= 1
-      try {
-        return setupEndpointRefByAddr(RpcEndpointAddress(address, 
endpointName))
-      } catch {
-        case e: RpcTimeoutException =>
-          if (numRetries > 0) {
-            val random = new Random
-            val retryWaitMs = random.nextInt(retryWait.toInt)
-            try {
-              TimeUnit.MILLISECONDS.sleep(retryWaitMs)
-            } catch {
-              case _: InterruptedException =>
-                throw e
-            }
-          } else {
-            throw e
-          }
-        case e: Exception =>
-          throw e
-      }
+    Utils.withRetryOnTimeoutOrIOException(retryCount, retryWait) {
+      setupEndpointRefByAddr(RpcEndpointAddress(address, endpointName))
     }
-    // should never be here
-    null
   }
 
   /**
diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala 
b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
index 73f070ab9..8e784f790 100644
--- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala
@@ -52,6 +52,7 @@ import org.apache.celeborn.common.network.util.TransportConf
 import org.apache.celeborn.common.protocol.{PartitionLocation, 
PartitionSplitMode, PartitionType, RpcNameConstants, TransportModuleConstants}
 import org.apache.celeborn.common.protocol.message.{ControlMessages, Message}
 import 
org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource
+import org.apache.celeborn.common.rpc.RpcTimeoutException
 import org.apache.celeborn.reflect.DynConstructors
 
 object Utils extends Logging {
@@ -1278,4 +1279,32 @@ object Utils extends Logging {
       rpcRequestId: Long): String = {
     s"$shuffleKey-$clientChannelId-$rpcRequestId"
   }
+
+  /**
+   * Runs `block`, retrying on [[RpcTimeoutException]] or [[IOException]] after a random sleep
+   * shorter than `retryWait` ms. Makes max(numRetries, 1) attempts, then rethrows the failure.
+   */
+  def withRetryOnTimeoutOrIOException[T](numRetries: Int, retryWait: Long)(block: => T): T = {
+    var retriesLeft = math.max(numRetries, 1)
+    while (true) {
+      retriesLeft -= 1
+      try {
+        return block
+      } catch {
+        case e @ (_: RpcTimeoutException | _: IOException) if retriesLeft > 0 =>
+          // Clamp the bound: nextInt rejects non-positive values and toInt can overflow.
+          val bound = math.min(retryWait, Int.MaxValue.toLong).toInt
+          val retryWaitMs = if (bound > 0) new Random().nextInt(bound) else 0
+          try {
+            TimeUnit.MILLISECONDS.sleep(retryWaitMs)
+          } catch {
+            case _: InterruptedException =>
+              // Keep the interrupt visible to callers, then surface the original failure.
+              Thread.currentThread().interrupt()
+              throw e
+          }
+      }
+    }
+    throw new IllegalStateException("unreachable: the retry loop returns or throws")
+  }
 }

Reply via email to