This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 9a09c4503ab34942fac7f856be7745e103a62361
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Mon Mar 31 13:15:31 2025 +0200

    [SYSTEMDS-3845] Federated Threading Bug
    
    The federated backend spawns new threads for parallel execution instead of using the thread pool.
    This commit fixes the issue by naming the federated worker threads ("FedExec_<id>") so that CommonThreadPool recognizes them and uses its shared thread pools.
    
    In a local experiment using FederatedKmeansTest, the average runtime improved from 4.3 s to 3.2 s. To reproduce the results, set the federated k-means test to repeat the federated call 20 times.
    
    Closes #2245
---
 .../runtime/controlprogram/federated/FederatedWorkerHandler.java  | 4 ++++
 src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java | 2 +-
 .../test/functions/federated/algorithms/FederatedKmeansTest.java  | 8 ++++++--
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index ce21c79825..2cd9e8abf4 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -616,6 +616,10 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                
                try {
                        // execute single instruction
+                       // TODO move this thread naming to Netty thread 
creation!
+                       Thread curThread = Thread.currentThread();
+                       long id = curThread.getId();
+                       Thread.currentThread().setName("FedExec_"+ id);
                        pb.execute(ec);
                }
                catch(Exception ex) {
diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java 
b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
index 7eb0c2bd72..3ee08da0de 100644
--- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
@@ -115,7 +115,7 @@ public class CommonThreadPool implements ExecutorService {
                final boolean mainThread = threadName.contains("main");
                if(size == k && mainThread)
                        return shared; // use the default thread pool if main 
thread and max parallelism.
-               else if(mainThread || threadName.contains("PARFOR")) {
+               else if(mainThread || threadName.contains("PARFOR") || 
threadName.contains("FedExec")) {
                        CommonThreadPool pool;
                        if(shared2 == null) // If there is no current shared 
pool allocate one.
                                shared2 = new ConcurrentHashMap<>();
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
index 02bfb960d5..c8605ac3d8 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
@@ -22,6 +22,8 @@ package org.apache.sysds.test.functions.federated.algorithms;
 import java.util.Arrays;
 import java.util.Collection;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.controlprogram.ParForProgramBlock;
@@ -31,6 +33,7 @@ import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.stats.Timing;
 import org.junit.Assert;
 import org.junit.Ignore;
 import org.junit.Test;
@@ -40,6 +43,7 @@ import org.junit.runners.Parameterized;
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class FederatedKmeansTest extends AutomatedTestBase {
+       protected static final Log LOG = 
LogFactory.getLog(FederatedKmeansTest.class.getName());
 
        private final static String TEST_DIR = "functions/federated/";
        private final static String TEST_NAME = "FederatedKmeansTest";
@@ -120,7 +124,6 @@ public class FederatedKmeansTest extends AutomatedTestBase {
                programArgs = new String[] {"-args", input("X1"), input("X2"),
                        String.valueOf(singleWorker).toUpperCase(), 
String.valueOf(runs), expected("Z")};
                runTest(true, false, null, -1);
-
                // Run actual dml script with federated matrix
                fullDMLScriptName = HOME + TEST_NAME + ".dml";
                programArgs = new String[] {"-stats","20", "-nvargs", "in_X1=" 
+ TestUtils.federatedAddress(port1, input("X1")),
@@ -130,8 +133,9 @@ public class FederatedKmeansTest extends AutomatedTestBase {
                for(int i = 0; i < rep; i++) {
                        ParForProgramBlock.resetWorkerIDs();
                        FederationUtils.resetFedDataID();
+                       Timing t = new Timing();
                        runTest(true, false, null, -1);
-
+                       LOG.debug("Federated kmeans runtime: " + t);
                        // check for federated operations
                        
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
                        // 
Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));

Reply via email to