This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 9a09c4503ab34942fac7f856be7745e103a62361 Author: Sebastian Baunsgaard <[email protected]> AuthorDate: Mon Mar 31 13:15:31 2025 +0200 [SYSTEMDS-3845] Federated Threading Bug The federated backend spawns threads for parallel execution instead of using the thread pool. This commit fixes the issue by naming the worker threads to enable thread pool usage. The performance in a local experiment using the FederatedKMeans test improves from an average of 4.3 sec to an average of 3.2 sec. To reproduce the results, set the federated k-means test to repeat the federated call 20 times. Closes #2245 --- .../runtime/controlprogram/federated/FederatedWorkerHandler.java | 4 ++++ src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java | 2 +- .../test/functions/federated/algorithms/FederatedKmeansTest.java | 8 ++++++-- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java index ce21c79825..2cd9e8abf4 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java @@ -616,6 +616,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { try { // execute single instruction + // TODO move this thread naming to Netty thread creation! 
+ Thread curThread = Thread.currentThread(); + long id = curThread.getId(); + Thread.currentThread().setName("FedExec_"+ id); pb.execute(ec); } catch(Exception ex) { diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java index 7eb0c2bd72..3ee08da0de 100644 --- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java +++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java @@ -115,7 +115,7 @@ public class CommonThreadPool implements ExecutorService { final boolean mainThread = threadName.contains("main"); if(size == k && mainThread) return shared; // use the default thread pool if main thread and max parallelism. - else if(mainThread || threadName.contains("PARFOR")) { + else if(mainThread || threadName.contains("PARFOR") || threadName.contains("FedExec")) { CommonThreadPool pool; if(shared2 == null) // If there is no current shared pool allocate one. shared2 = new ConcurrentHashMap<>(); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java index 02bfb960d5..c8605ac3d8 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java @@ -22,6 +22,8 @@ package org.apache.sysds.test.functions.federated.algorithms; import java.util.Arrays; import java.util.Collection; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.runtime.controlprogram.ParForProgramBlock; @@ -31,6 +33,7 @@ import org.apache.sysds.runtime.util.HDFSTool; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import 
org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.stats.Timing; import org.junit.Assert; import org.junit.Ignore; import org.junit.Test; @@ -40,6 +43,7 @@ import org.junit.runners.Parameterized; @RunWith(value = Parameterized.class) @net.jcip.annotations.NotThreadSafe public class FederatedKmeansTest extends AutomatedTestBase { + protected static final Log LOG = LogFactory.getLog(FederatedKmeansTest.class.getName()); private final static String TEST_DIR = "functions/federated/"; private final static String TEST_NAME = "FederatedKmeansTest"; @@ -120,7 +124,6 @@ public class FederatedKmeansTest extends AutomatedTestBase { programArgs = new String[] {"-args", input("X1"), input("X2"), String.valueOf(singleWorker).toUpperCase(), String.valueOf(runs), expected("Z")}; runTest(true, false, null, -1); - // Run actual dml script with federated matrix fullDMLScriptName = HOME + TEST_NAME + ".dml"; programArgs = new String[] {"-stats","20", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), @@ -130,8 +133,9 @@ public class FederatedKmeansTest extends AutomatedTestBase { for(int i = 0; i < rep; i++) { ParForProgramBlock.resetWorkerIDs(); FederationUtils.resetFedDataID(); + Timing t = new Timing(); runTest(true, false, null, -1); - + LOG.debug("Federated kmeans runtime: " + t); // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_ba+*")); // Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
