This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 2215cef40043 [SPARK-46353][CORE] Refactor to improve `RegisterWorker` unit test coverage 2215cef40043 is described below commit 2215cef40043a3205446f8daecafed8f2360a742 Author: Dongjoon Hyun <dh...@apple.com> AuthorDate: Tue Dec 12 09:57:43 2023 -0800 [SPARK-46353][CORE] Refactor to improve `RegisterWorker` unit test coverage ### What changes were proposed in this pull request? This PR aims to improve the unit test coverage for `RegisterWorker` message handling. - Add a `handleRegisterWorker` helper method that can be tested easily. - Add new unit tests for three conditional branches. ### Why are the changes needed? It's easy to test and improve. We can add more tests in this way in the future. ### Does this PR introduce _any_ user-facing change? No. This only refactors the main code and adds new test cases. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #44284 from dongjoon-hyun/SPARK-46353. 
Authored-by: Dongjoon Hyun <dh...@apple.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../org/apache/spark/deploy/master/Master.scala | 75 +++++++++++++--------- .../apache/spark/deploy/master/MasterSuite.scala | 59 ++++++++++++++++- 2 files changed, 102 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index a550f44fc0a4..c8679c185ad7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -37,7 +37,7 @@ import org.apache.spark.internal.config.Deploy._ import org.apache.spark.internal.config.UI._ import org.apache.spark.internal.config.Worker._ import org.apache.spark.metrics.{MetricsSystem, MetricsSystemInstances} -import org.apache.spark.resource.{ResourceProfile, ResourceRequirement, ResourceUtils} +import org.apache.spark.resource.{ResourceInformation, ResourceProfile, ResourceRequirement, ResourceUtils} import org.apache.spark.rpc._ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} import org.apache.spark.util.{SparkUncaughtExceptionHandler, ThreadUtils, Utils} @@ -75,7 +75,8 @@ private[deploy] class Master( private val waitingApps = new ArrayBuffer[ApplicationInfo] val apps = new HashSet[ApplicationInfo] - private val idToWorker = new HashMap[String, WorkerInfo] + // Visible for testing + private[master] val idToWorker = new HashMap[String, WorkerInfo] private val addressToWorker = new HashMap[RpcAddress, WorkerInfo] private val endpointToApp = new HashMap[RpcEndpointRef, ApplicationInfo] @@ -106,7 +107,7 @@ private[deploy] class Master( private[master] var state = RecoveryState.STANDBY - private var persistenceEngine: PersistenceEngine = _ + private[master] var persistenceEngine: PersistenceEngine = _ private var leaderElectionAgent: LeaderElectionAgent = _ @@ -281,33 +282,8 @@ private[deploy] class 
Master( case RegisterWorker( id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl, masterAddress, resources) => - logInfo("Registering worker %s:%d with %d cores, %s RAM".format( - workerHost, workerPort, cores, Utils.megabytesToString(memory))) - if (state == RecoveryState.STANDBY) { - workerRef.send(MasterInStandby) - } else if (idToWorker.contains(id)) { - if (idToWorker(id).state == WorkerState.UNKNOWN) { - logInfo("Worker has been re-registered: " + id) - idToWorker(id).state = WorkerState.ALIVE - } - workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress, true)) - } else { - val workerResources = - resources.map(r => r._1 -> WorkerResourceInfo(r._1, r._2.addresses.toImmutableArraySeq)) - val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, - workerRef, workerWebUiUrl, workerResources) - if (registerWorker(worker)) { - persistenceEngine.addWorker(worker) - workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress, false)) - schedule() - } else { - val workerAddress = worker.endpoint.address - logWarning("Worker registration failed. 
Attempted to re-register worker at same " + - "address: " + workerAddress) - workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress)) - } - } + handleRegisterWorker(id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl, + masterAddress, resources) case RegisterApplication(description, driver) => // TODO Prevent repeated registrations from some driver @@ -676,6 +652,45 @@ private[deploy] class Master( logInfo(f"Recovery complete in ${timeTakenNs / 1000000000d}%.3fs - resuming operations!") } + private[master] def handleRegisterWorker( + id: String, + workerHost: String, + workerPort: Int, + workerRef: RpcEndpointRef, + cores: Int, + memory: Int, + workerWebUiUrl: String, + masterAddress: RpcAddress, + resources: Map[String, ResourceInformation]): Unit = { + logInfo("Registering worker %s:%d with %d cores, %s RAM".format( + workerHost, workerPort, cores, Utils.megabytesToString(memory))) + if (state == RecoveryState.STANDBY) { + workerRef.send(MasterInStandby) + } else if (idToWorker.contains(id)) { + if (idToWorker(id).state == WorkerState.UNKNOWN) { + logInfo("Worker has been re-registered: " + id) + idToWorker(id).state = WorkerState.ALIVE + } + workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress, true)) + } else { + val workerResources = + resources.map(r => r._1 -> WorkerResourceInfo(r._1, r._2.addresses.toImmutableArraySeq)) + val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, + workerRef, workerWebUiUrl, workerResources) + if (registerWorker(worker)) { + persistenceEngine.addWorker(worker) + workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress, false)) + schedule() + } else { + val workerAddress = worker.endpoint.address + logWarning("Worker registration failed. 
Attempted to re-register worker at same " + + "address: " + workerAddress) + workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) + } + } + } + /** * Schedule executors to be launched on the workers. * Returns an array containing number of cores assigned to each worker. diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 9fd1991dab02..e15a5db770eb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -30,12 +30,13 @@ import scala.reflect.ClassTag import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.mockito.ArgumentMatchers.any -import org.mockito.Mockito.{doNothing, mock, when} +import org.mockito.ArgumentMatchers.{any, eq => meq} +import org.mockito.Mockito.{doNothing, mock, times, verify, when} import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually import org.scalatest.matchers.must.Matchers import org.scalatest.matchers.should.Matchers._ +import org.scalatestplus.mockito.MockitoSugar.{mock => smock} import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} @@ -1373,6 +1374,60 @@ class MasterSuite extends SparkFunSuite eventLogCodec = None) assert(master.invokePrivate(_createApplication(desc, null)).id === "spark-45756") } + + test("SPARK-46353: handleRegisterWorker in STANDBY mode") { + val master = makeMaster() + val masterRpcAddress = smock[RpcAddress] + val worker = smock[RpcEndpointRef] + + assert(master.state === RecoveryState.STANDBY) + master.handleRegisterWorker("worker-0", "localhost", 1024, worker, 10, 4096, + "http://localhost:8081", masterRpcAddress, Map.empty) + verify(worker, times(1)).send(meq(MasterInStandby)) + verify(worker, 
times(0)) + .send(meq(RegisteredWorker(master.self, null, masterRpcAddress, duplicate = true))) + verify(worker, times(0)) + .send(meq(RegisteredWorker(master.self, null, masterRpcAddress, duplicate = false))) + assert(master.workers.isEmpty) + assert(master.idToWorker.isEmpty) + } + + test("SPARK-46353: handleRegisterWorker in RECOVERING mode without workers") { + val master = makeMaster() + val masterRpcAddress = smock[RpcAddress] + val worker = smock[RpcEndpointRef] + + master.state = RecoveryState.RECOVERING + master.persistenceEngine = new BlackHolePersistenceEngine() + master.handleRegisterWorker("worker-0", "localhost", 1024, worker, 10, 4096, + "http://localhost:8081", masterRpcAddress, Map.empty) + verify(worker, times(0)).send(meq(MasterInStandby)) + verify(worker, times(1)) + .send(meq(RegisteredWorker(master.self, null, masterRpcAddress, duplicate = false))) + assert(master.workers.size === 1) + assert(master.idToWorker.size === 1) + } + + test("SPARK-46353: handleRegisterWorker in RECOVERING mode with a unknown worker") { + val master = makeMaster() + val masterRpcAddress = smock[RpcAddress] + val worker = smock[RpcEndpointRef] + val workerInfo = smock[WorkerInfo] + when(workerInfo.state).thenReturn(WorkerState.UNKNOWN) + + master.state = RecoveryState.RECOVERING + master.workers.add(workerInfo) + master.idToWorker("worker-0") = workerInfo + master.persistenceEngine = new BlackHolePersistenceEngine() + master.handleRegisterWorker("worker-0", "localhost", 1024, worker, 10, 4096, + "http://localhost:8081", masterRpcAddress, Map.empty) + verify(worker, times(0)).send(meq(MasterInStandby)) + verify(worker, times(1)) + .send(meq(RegisteredWorker(master.self, null, masterRpcAddress, duplicate = true))) + assert(master.state === RecoveryState.RECOVERING) + assert(master.workers.nonEmpty) + assert(master.idToWorker.nonEmpty) + } } private class FakeRecoveryModeFactory(conf: SparkConf, ser: serializer.Serializer) 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org