This is an automated email from the ASF dual-hosted git repository. feiwang pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push: new 612464c69 [CELEBORN-2015] Retry IOException failures for RPC requests 612464c69 is described below commit 612464c69deba27047d725dbcd25cdb8f8b994e5 Author: Sanskar Modi <sanskarmod...@gmail.com> AuthorDate: Tue May 27 07:37:40 2025 -0700 [CELEBORN-2015] Retry IOException failures for RPC requests ### What changes were proposed in this pull request? - Support retry on IOException failures for RpcRequest in addition to RpcTimeoutException. - Moved duplicate code to Utils ### Why are the changes needed? Currently, if a request fails with SocketException or IOException, it does not get retried, which leads to stage failures. Celeborn should retry on such connection failures. ### Does this PR introduce _any_ user-facing change? NA ### How was this patch tested? NA Closes #3286 from s0nskar/setup_lifecycle_exception. Authored-by: Sanskar Modi <sanskarmod...@gmail.com> Signed-off-by: Wang, Fei <fwan...@ebay.com> --- .../celeborn/common/rpc/RpcEndpointRef.scala | 34 ++++------------------ .../org/apache/celeborn/common/rpc/RpcEnv.scala | 33 ++++----------------- .../org/apache/celeborn/common/util/Utils.scala | 29 ++++++++++++++++++ 3 files changed, 39 insertions(+), 57 deletions(-) diff --git a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpointRef.scala b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpointRef.scala index 8c861cf57..483144bdd 100644 --- a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpointRef.scala +++ b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEndpointRef.scala @@ -17,14 +17,12 @@ package org.apache.celeborn.common.rpc -import java.util.Random -import java.util.concurrent.TimeUnit - import scala.concurrent.Future import scala.reflect.ClassTag import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.internal.Logging +import org.apache.celeborn.common.util.Utils /** * A reference for a remote [[RpcEndpoint]]. 
[[RpcEndpointRef]] is thread-safe. @@ -111,7 +109,7 @@ abstract class RpcEndpointRef(conf: CelebornConf) /** * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a - * specified timeout, retry if timeout, throw an exception if this still fails. + * specified timeout, retry if timeout or IOException, throw an exception if this still fails. * * Note: this is a blocking action which may cost a lot of time, so don't call it in a message * loop of [[RpcEndpoint]]. @@ -128,31 +126,9 @@ abstract class RpcEndpointRef(conf: CelebornConf) timeout: RpcTimeout, retryCount: Int, retryWait: Long): T = { - var numRetries = retryCount - while (numRetries > 0) { - numRetries -= 1 - try { - val future = ask[T](message, timeout) - return timeout.awaitResult(future, address) - } catch { - case e: RpcTimeoutException => - if (numRetries > 0) { - val random = new Random - val retryWaitMs = random.nextInt(retryWait.toInt) - try { - TimeUnit.MILLISECONDS.sleep(retryWaitMs) - } catch { - case _: InterruptedException => - throw e - } - } else { - throw e - } - case e: Exception => - throw e - } + Utils.withRetryOnTimeoutOrIOException(retryCount, retryWait) { + val future = ask[T](message, timeout) + return timeout.awaitResult(future, address) } - // should never be here - null.asInstanceOf[T] } } diff --git a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEnv.scala b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEnv.scala index 19f522a0a..606006e0e 100644 --- a/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEnv.scala +++ b/common/src/main/scala/org/apache/celeborn/common/rpc/RpcEnv.scala @@ -18,8 +18,6 @@ package org.apache.celeborn.common.rpc import java.io.File -import java.util.Random -import java.util.concurrent.TimeUnit import scala.concurrent.Future @@ -151,39 +149,18 @@ abstract class RpcEnv(config: RpcEnvConfig) { } /** - * Retrieve the [[RpcEndpointRef]] represented by `address` and `endpointName` with timeout 
retry. - * This is a blocking action. + * Retrieve the [[RpcEndpointRef]] represented by `address` and `endpointName` within a specified + * timeout, retry if timeout or IOException, throw an exception if this still fails. This is a + * blocking action. */ def setupEndpointRef( address: RpcAddress, endpointName: String, retryCount: Int, retryWait: Long = defaultRetryWait): RpcEndpointRef = { - var numRetries = retryCount - while (numRetries > 0) { - numRetries -= 1 - try { - return setupEndpointRefByAddr(RpcEndpointAddress(address, endpointName)) - } catch { - case e: RpcTimeoutException => - if (numRetries > 0) { - val random = new Random - val retryWaitMs = random.nextInt(retryWait.toInt) - try { - TimeUnit.MILLISECONDS.sleep(retryWaitMs) - } catch { - case _: InterruptedException => - throw e - } - } else { - throw e - } - case e: Exception => - throw e - } + Utils.withRetryOnTimeoutOrIOException(retryCount, retryWait) { + return setupEndpointRefByAddr(RpcEndpointAddress(address, endpointName)) } - // should never be here - null } /** diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala index 73f070ab9..8e784f790 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala @@ -52,6 +52,7 @@ import org.apache.celeborn.common.network.util.TransportConf import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType, RpcNameConstants, TransportModuleConstants} import org.apache.celeborn.common.protocol.message.{ControlMessages, Message} import org.apache.celeborn.common.protocol.message.ControlMessages.WorkerResource +import org.apache.celeborn.common.rpc.RpcTimeoutException import org.apache.celeborn.reflect.DynConstructors object Utils extends Logging { @@ -1278,4 +1279,32 @@ object Utils extends Logging { rpcRequestId: Long): String = { 
s"$shuffleKey-$clientChannelId-$rpcRequestId" } + + def withRetryOnTimeoutOrIOException[T](numRetries: Int, retryWait: Long)(block: => T): T = { + var retriesLeft = numRetries + while (retriesLeft >= 0) { + retriesLeft -= 1 + try { + return block + } catch { + case e @ (_: RpcTimeoutException | _: IOException) => + if (retriesLeft > 0) { + val random = new Random + val retryWaitMs = random.nextInt(retryWait.toInt) + try { + TimeUnit.MILLISECONDS.sleep(retryWaitMs) + } catch { + case _: InterruptedException => + throw e + } + } else { + throw e + } + case e: Exception => + throw e + } + } + // should never be here + null.asInstanceOf[T] + } }