This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 649fc8ec20 [MINOR] Fix ThreadPool for Federated
649fc8ec20 is described below

commit 649fc8ec20254a98c05a543e6b706cfe8edd56a7
Author: baunsgaard <[email protected]>
AuthorDate: Tue Aug 8 17:47:58 2023 +0200

    [MINOR] Fix ThreadPool for Federated
    
    This commit goes through the call sites of CommonThreadPool and fixes
    the remaining issues. The new double buffering is unfortunately not
    among the instances that could be fixed this way, so I changed it to
    use its own dedicated single-thread executor instead.
    
    Closes #1877
---
 src/main/java/org/apache/sysds/api/DMLScript.java  |  4 +-
 .../apache/sysds/conf/ConfigurationManager.java    |  5 +-
 .../sysds/runtime/compress/lib/CLALibStack.java    |  2 +-
 .../runtime/controlprogram/ParForProgramBlock.java | 31 +++++-----
 .../controlprogram/federated/FederationMap.java    |  1 -
 .../controlprogram/paramserv/LocalPSWorker.java    |  3 +-
 .../sysds/runtime/frame/data/FrameBlock.java       | 10 ++--
 .../sysds/runtime/io/FrameReaderJSONLParallel.java |  2 +-
 .../sysds/runtime/iogen/FormatIdentifyer.java      | 19 ++----
 .../sysds/runtime/matrix/data/LibMatrixDNN.java    |  2 +-
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  5 +-
 .../sysds/runtime/util/CommonThreadPool.java       | 15 ++---
 .../runtime/util/DoubleBufferingOutputStream.java  |  6 +-
 .../java/org/apache/sysds/performance/Main.java    |  3 +-
 .../federated/multitenant/MultiTenantTestBase.java | 70 +++++++++++++++-------
 .../functions/paramserv/ParamservSyntaxTest.java   |  2 +
 16 files changed, 94 insertions(+), 86 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java 
b/src/main/java/org/apache/sysds/api/DMLScript.java
index ddc5ee2517..bf638dfcf7 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -326,6 +326,7 @@ public class DMLScript
                        //reset runtime platform and visualize flag
                        setGlobalExecMode(oldrtplatform);
                        EXPLAIN = oldexplain;
+                       CommonThreadPool.shutdownAsyncPools();
                }
                
                return true;
@@ -572,9 +573,6 @@ public class DMLScript
                //0) cleanup federated workers if necessary
                FederatedData.clearFederatedWorkers();
                
-               //0) shutdown prefetch/broadcast thread pool if necessary
-               CommonThreadPool.shutdownAsyncPools();
-
                //1) cleanup scratch space (everything for current uuid)
                //(required otherwise export to hdfs would skip assumed 
unnecessary writes if same name)
                HDFSTool.deleteFileIfExistOnHDFS( 
config.getTextValue(DMLConfig.SCRATCH_SPACE) + dirSuffix );
diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java 
b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index 62352bd2a0..088545b8ed 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -29,7 +29,6 @@ import org.apache.sysds.conf.CompilerConfig.ConfigType;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.lops.Compression.CompressConfig;
 import org.apache.sysds.lops.compile.linearization.ILinearize;
-import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.util.CommonThreadPool;
 
@@ -66,7 +65,7 @@ public class ConfigurationManager{
                _dmlconf = new DMLConfig();
                _cconf = new CompilerConfig();
 
-               final ExecutorService pool = 
CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
+               final ExecutorService pool = CommonThreadPool.get();
                pool.submit(() ->{
                        try{
                                IOUtilFunctions.getFileSystem(_rJob);
@@ -75,7 +74,7 @@ public class ConfigurationManager{
                                LOG.warn(e.getMessage());
                        }
                });
-               pool.shutdown();
+               // pool.shutdown();
        }
        
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java
index 178c13ad29..ffea0193b7 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java
@@ -218,7 +218,7 @@ public final class CLALibStack {
                        }
                }
 
-               final ExecutorService pool = 
CommonThreadPool.get(Math.max(Math.min(clen / 500, k), 1));
+               final ExecutorService pool = CommonThreadPool.get();
                try {
 
                        List<AColGroup> finalGroups = pool.submit(() -> {
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 94bbaf2545..790a92de58 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -19,11 +19,25 @@
 
 package org.apache.sysds.runtime.controlprogram;
 
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.log4j.Level;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.common.Types.FileFormat;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.CompilerConfig;
@@ -31,7 +45,6 @@ import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.recompile.Recompiler;
 import org.apache.sysds.lops.Lop;
-import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.parser.ParForStatementBlock;
@@ -91,24 +104,10 @@ import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.util.CollectionUtils;
-import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.ProgramConverter;
 import org.apache.sysds.runtime.util.UtilFunctions;
 import org.apache.sysds.utils.stats.ParForStatistics;
 
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import java.util.stream.Stream;
-
 
 
 /**
@@ -122,7 +121,7 @@ import java.util.stream.Stream;
  *
  */
 public class ParForProgramBlock extends ForProgramBlock {      
-       protected static final Log LOG = 
LogFactory.getLog(CommonThreadPool.class.getName());
+       protected static final Log LOG = 
LogFactory.getLog(ParForProgramBlock.class.getName());
        // execution modes
        public enum PExecMode {
                LOCAL,          //local (master) multi-core execution mode
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 4db4f2b8b2..985fdb056e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -614,7 +614,6 @@ public class FederationMap {
         */
        public void forEachParallel(BiFunction<FederatedRange, FederatedData, 
Void> forEachFunction) {
                ExecutorService pool = CommonThreadPool.get(_fedMap.size());
-
                ArrayList<MappingTask> mappingTasks = new ArrayList<>();
                for(Pair<FederatedRange, FederatedData> fedMap : _fedMap)
                        mappingTasks.add(new MappingTask(fedMap.getKey(), 
fedMap.getValue(), forEachFunction, _ID));
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
index 5343332eb5..a3be38cafd 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -31,7 +31,6 @@ import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.util.CommonThreadPool;
@@ -91,7 +90,7 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
                        ListObject params = pullModel();
                        Future<ListObject> accGradients = 
ConcurrentUtils.constantFuture(null);
                        if(_tpool == null)
-                               _tpool = 
CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
+                               _tpool = CommonThreadPool.get();
 
                        try {
                                for (int j = 0; j < batchIter; j++) {
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java 
b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
index ed5d48d6b3..94ab8f00de 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
@@ -49,7 +49,6 @@ import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.codegen.CodegenUtils;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
-import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysds.runtime.frame.data.columns.Array;
 import org.apache.sysds.runtime.frame.data.columns.ArrayFactory;
@@ -860,26 +859,25 @@ public class FrameBlock implements 
CacheBlock<FrameBlock>, Externalizable {
                                size += 
ArrayFactory.getInMemorySize(_schema[j], rlen, true);
                else {// allocated
                        if(rlen > 1000 && clen > 10 && 
ConfigurationManager.isParallelIOEnabled()) {
-                               final ExecutorService pool = 
CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
+                               final ExecutorService pool = 
CommonThreadPool.get();
                                try {
                                        size += pool.submit(() -> {
                                                return 
Arrays.stream(_coldata).parallel() // parallel columns
                                                        .map(x -> 
x.getInMemorySize()).reduce(0L, Long::sum);
                                        }).get();
-                                       pool.shutdown();
-
                                }
                                catch(InterruptedException | ExecutionException 
e) {
-                                       pool.shutdown();
                                        LOG.error(e);
                                        for(Array<?> aa : _coldata)
                                                size += aa.getInMemorySize();
                                }
+                               finally{
+                                       pool.shutdown();
+                               }
                        }
                        else {
                                for(Array<?> aa : _coldata)
                                        size += aa.getInMemorySize();
-
                        }
                }
                return size;
diff --git 
a/src/main/java/org/apache/sysds/runtime/io/FrameReaderJSONLParallel.java 
b/src/main/java/org/apache/sysds/runtime/io/FrameReaderJSONLParallel.java
index 17abd9e3c8..14143e0099 100644
--- a/src/main/java/org/apache/sysds/runtime/io/FrameReaderJSONLParallel.java
+++ b/src/main/java/org/apache/sysds/runtime/io/FrameReaderJSONLParallel.java
@@ -53,7 +53,7 @@ public class FrameReaderJSONLParallel extends FrameReaderJSONL
                splits = IOUtilFunctions.sortInputSplits(splits);
 
                try{
-                       ExecutorService executorPool = 
CommonThreadPool.get(Math.min(numThreads, splits.length));
+                       ExecutorService executorPool = 
CommonThreadPool.get(numThreads);
 
                        //compute num rows per split
                        ArrayList<CountRowsTask> countRowsTasks = new 
ArrayList<>();
diff --git a/src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java 
b/src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java
index aa02ad37fc..3cbc174d64 100644
--- a/src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java
+++ b/src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java
@@ -647,19 +647,14 @@ public class FormatIdentifyer {
                colIndexes.add(0);
 
                try {
-                       ExecutorService pool = CommonThreadPool.get(1);
                        ArrayList<BuildColsKeyPatternSingleRowTask> tasks = new 
ArrayList<>();
                        tasks.add(
                                new 
BuildColsKeyPatternSingleRowTask(prefixesRemovedReverse, prefixesRemoved, 
prefixes, suffixes,
                                        prefixesRemovedReverseSort, keys, 
colSuffixes, lcs, colIndexes));
 
-                       //wait until all tasks have been executed
-                       List<Future<Object>> rt = pool.invokeAll(tasks);
-                       pool.shutdown();
-
                        //check for exceptions
-                       for(Future<Object> task : rt)
-                               task.get();
+                       for(Callable<Object> task : tasks)
+                               task.call();
                }
                catch(Exception e) {
                        throw new RuntimeException("Failed 
BuildValueKeyPattern.", e);
@@ -770,19 +765,13 @@ public class FormatIdentifyer {
                colIndexe.add(0);
 
                try {
-                       ExecutorService pool = CommonThreadPool.get(1);
                        ArrayList<BuildColsKeyPatternSingleRowTask> tasks = new 
ArrayList<>();
                        tasks.add(
                                new 
BuildColsKeyPatternSingleRowTask(prefixesRemovedReverse, prefixesRemoved, 
prefixes, suffixes,
                                        prefixesRemovedReverseSort, keys, 
colSuffixes, lcs, colIndexe));
-
-                       //wait until all tasks have been executed
-                       List<Future<Object>> rt = pool.invokeAll(tasks);
-                       pool.shutdown();
-
                        //check for exceptions
-                       for(Future<Object> task : rt)
-                               task.get();
+                       for(Callable<Object> task : tasks)
+                               task.call();
                }
                catch(Exception e) {
                        throw new RuntimeException("Failed 
BuildValueKeyPattern.", e);
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
index 598fef549d..26a00425a0 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
@@ -670,7 +670,7 @@ public class LibMatrixDNN {
                                }
                        }
                        else {
-                               ExecutorService pool = CommonThreadPool.get( 
Math.min(k, params.N) );
+                               ExecutorService pool = CommonThreadPool.get(k);
                                List<Future<Long>> taskret = 
pool.invokeAll(tasks);
                                pool.shutdown();
                                for( Future<Long> task : taskret )
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 5aaf0cd46a..01a5216b4b 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -35,8 +35,8 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 import java.util.stream.IntStream;
 
-import org.apache.commons.lang3.NotImplementedException;
 import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.NotImplementedException;
 import org.apache.commons.lang3.concurrent.ConcurrentUtils;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -54,7 +54,6 @@ import 
org.apache.sysds.runtime.compress.DMLCompressionException;
 import org.apache.sysds.runtime.compress.lib.CLALibAggTernaryOp;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
-import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.DenseBlockFP64;
 import org.apache.sysds.runtime.data.DenseBlockFactory;
@@ -373,7 +372,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
        }
        
        public Future<MatrixBlock> allocateBlockAsync() {
-               ExecutorService pool = 
CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
+               ExecutorService pool = CommonThreadPool.get();
                return (pool != null) ? pool.submit(() -> allocateBlock()) : 
//async
                        ConcurrentUtils.constantFuture(allocateBlock()); 
//fallback sync
        }
diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java 
b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
index cc6483d258..bc3be9844c 100644
--- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
@@ -99,14 +99,14 @@ public class CommonThreadPool implements ExecutorService {
         * @param k The number of threads wanted
         * @return The executor with specified parallelism
         */
-       public static ExecutorService get(int k) {
+       public synchronized static ExecutorService get(int k) {
                if(size == k)
                        return shared;
                else if(Thread.currentThread().getName().equals("main")) {
                        if(shared2 != null && shared2K == k)
                                return shared2;
                        else if(shared2 == null) {
-                               shared2 = new 
CommonThreadPool(Executors.newFixedThreadPool(k));
+                               shared2 = new CommonThreadPool(new 
ForkJoinPool(k));
                                shared2K = k;
                                return shared2;
                        }
@@ -141,12 +141,13 @@ public class CommonThreadPool implements ExecutorService {
                        // check for errors and exceptions
                        for(Future<T> r : ret)
                                r.get();
-                       // shutdown pool
-                       pool.shutdown();
                }
                catch(Exception ex) {
                        throw new DMLRuntimeException(ex);
                }
+               finally{
+                       pool.shutdown();
+               }
        }
 
        /**
@@ -155,8 +156,8 @@ public class CommonThreadPool implements ExecutorService {
         * 
         * @return A dynamic thread pool.
         */
-       public static ExecutorService getDynamicPool() {
-               if(asyncPool != null)
+       public synchronized static ExecutorService getDynamicPool() {
+               if(asyncPool != null && !(asyncPool.isShutdown() || 
asyncPool.isTerminated()) )
                        return asyncPool;
                else {
                        asyncPool = Executors.newCachedThreadPool();
@@ -167,7 +168,7 @@ public class CommonThreadPool implements ExecutorService {
        /**
         * Shutdown the cached thread pools.
         */
-       public static void shutdownAsyncPools() {
+       public synchronized static void shutdownAsyncPools() {
                if(asyncPool != null) {
                        // shutdown prefetch/broadcast thread pool
                        asyncPool.shutdown();
diff --git 
a/src/main/java/org/apache/sysds/runtime/util/DoubleBufferingOutputStream.java 
b/src/main/java/org/apache/sysds/runtime/util/DoubleBufferingOutputStream.java
index f1d2fadd93..16504e64ee 100644
--- 
a/src/main/java/org/apache/sysds/runtime/util/DoubleBufferingOutputStream.java
+++ 
b/src/main/java/org/apache/sysds/runtime/util/DoubleBufferingOutputStream.java
@@ -24,13 +24,13 @@ import java.io.IOException;
 import java.io.OutputStream;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 
 import org.apache.commons.lang3.concurrent.ConcurrentUtils;
 
-public class DoubleBufferingOutputStream extends FilterOutputStream
-{
-       protected ExecutorService _pool = CommonThreadPool.get(1);
+public class DoubleBufferingOutputStream extends FilterOutputStream {
+       protected ExecutorService _pool = Executors.newSingleThreadExecutor();
        protected Future<?>[] _locks;
        protected byte[][] _buff;
        private int _pos;
diff --git a/src/test/java/org/apache/sysds/performance/Main.java 
b/src/test/java/org/apache/sysds/performance/Main.java
index 1e51a703bf..fa89a62b53 100644
--- a/src/test/java/org/apache/sysds/performance/Main.java
+++ b/src/test/java/org/apache/sysds/performance/Main.java
@@ -123,10 +123,11 @@ public class Main {
        public static void main(String[] args) {
                try {
                        exec(Integer.parseInt(args[0]), args);
-                       CommonThreadPool.get().shutdown();
                }
                catch(Exception e) {
                        e.printStackTrace();
+               }finally{
+                       CommonThreadPool.get().shutdown();
                }
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
index aa7141bd18..0b4e193ac4 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
@@ -19,15 +19,20 @@
 
 package org.apache.sysds.test.functions.federated.multitenant;
 
+import static org.junit.Assert.fail;
+
 import java.io.IOException;
 import java.nio.charset.Charset;
 import java.util.ArrayList;
 import java.util.Arrays;
-
-import static org.junit.Assert.fail;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
 
 import org.apache.commons.io.IOUtils;
 import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -36,6 +41,8 @@ import org.junit.After;
 import com.google.crypto.tink.subtle.Random;
 
 public abstract class MultiTenantTestBase extends AutomatedTestBase {
+       protected static final Log LOG = 
LogFactory.getLog(MultiTenantTestBase.class.getName());
+
        protected ArrayList<Process> workerProcesses = new ArrayList<>();
        protected ArrayList<Process> coordinatorProcesses = new ArrayList<>();
 
@@ -56,8 +63,7 @@ public abstract class MultiTenantTestBase extends 
AutomatedTestBase {
        }
 
        /**
-        * Start numFedWorkers federated worker processes on available ports 
and add
-        * them to the workerProcesses
+        * Start numFedWorkers federated worker processes on available ports 
and add them to the workerProcesses
         *
         * @param numFedWorkers the number of federated workers to start
         * @return int[] the ports of the created federated workers
@@ -67,20 +73,20 @@ public abstract class MultiTenantTestBase extends 
AutomatedTestBase {
                for(int counter = 0; counter < numFedWorkers; counter++) {
                        ports[counter] = getRandomAvailablePort();
                        // start process but only wait long for last one.
-                       Process tmpProcess = 
startLocalFedWorker(ports[counter], addArgs, 
-                               counter == numFedWorkers-1 ? (FED_WORKER_WAIT + 
Random.randInt(1000)) * 3 : FED_WORKER_WAIT_S);
+                       Process tmpProcess = 
startLocalFedWorker(ports[counter], addArgs,
+                               counter == numFedWorkers - 1 ? (FED_WORKER_WAIT 
+ Random.randInt(1000)) * 3 : FED_WORKER_WAIT_S);
                        workerProcesses.add(tmpProcess);
                }
                return ports;
        }
 
        /**
-        * Start a coordinator process running the specified script with given 
arguments
-        * and add it to the coordinatorProcesses
+        * Start a coordinator process running the specified script with given 
arguments and add it to the
+        * coordinatorProcesses
         *
-        * @param execMode the execution mode of the coordinator
+        * @param execMode   the execution mode of the coordinator
         * @param scriptPath the path to the dml script
-        * @param args the program arguments for running the dml script
+        * @param args       the program arguments for running the dml script
         */
        protected void startCoordinator(ExecMode execMode, String scriptPath, 
String[] args) {
                String separator = System.getProperty("file.separator");
@@ -90,14 +96,14 @@ public abstract class MultiTenantTestBase extends 
AutomatedTestBase {
                String em = null;
                switch(execMode) {
                        case SINGLE_NODE:
-                       em = "singlenode";
-                       break;
+                               em = "singlenode";
+                               break;
                        case HYBRID:
-                       em = "hybrid";
-                       break;
+                               em = "hybrid";
+                               break;
                        case SPARK:
-                       em = "spark";
-                       break;
+                               em = "spark";
+                               break;
                }
 
                ArrayList<String> argsList = new ArrayList<>();
@@ -108,13 +114,14 @@ public abstract class MultiTenantTestBase extends 
AutomatedTestBase {
                argsList.addAll(Arrays.asList(args));
 
                // create the processBuilder and redirect the stderr to its 
stdout
-               ProcessBuilder processBuilder = new 
ProcessBuilder(ArrayUtils.addAll(new String[]{
-                       path, "-cp", classpath, DMLScript.class.getName()}, 
argsList.toArray(new String[0])));
+               ProcessBuilder processBuilder = new ProcessBuilder(ArrayUtils
+                       .addAll(new String[] {path, "-cp", classpath, 
DMLScript.class.getName()}, argsList.toArray(new String[0])));
 
                Process process = null;
                try {
                        process = processBuilder.start();
-               } catch(IOException ioe) {
+               }
+               catch(IOException ioe) {
                        ioe.printStackTrace();
                        fail("Can't start the coordinator process.");
                }
@@ -122,12 +129,28 @@ public abstract class MultiTenantTestBase extends 
AutomatedTestBase {
        }
 
        /**
-        * Wait for all processes of coordinatorProcesses to terminate and 
collect
-        * their output
+        * Wait for all processes of coordinatorProcesses to terminate and 
collect their output
         *
         * @return String the collected output of the coordinator processes
         */
        protected String waitForCoordinators() {
+               return waitForCoordinators(500);
+       }
+
+       protected String waitForCoordinators(int timeout){
+               ExecutorService executor = Executors.newCachedThreadPool();
+               try{
+                       return executor.submit(() -> 
waitForCoordinatorsActual()).get(timeout, TimeUnit.SECONDS);
+               }
+               catch(Exception e){
+                       throw new RuntimeException(e);
+               }
+               finally{
+                       executor.shutdown();
+               }
+       }
+
+       private String waitForCoordinatorsActual(){
                // wait for the coordinator processes to finish and collect 
their output
                StringBuilder outputLog = new StringBuilder();
                for(int counter = 0; counter < coordinatorProcesses.size(); 
counter++) {
@@ -139,9 +162,10 @@ public abstract class MultiTenantTestBase extends 
AutomatedTestBase {
                                
outputLog.append(IOUtils.toString(coord.getErrorStream(), 
Charset.defaultCharset()));
 
                                coord.waitFor();
-                       } catch(Exception ex) {
+                       }
+                       catch(Exception ex) {
                                fail(ex.getClass().getSimpleName() + " thrown 
while collecting log output of coordinator #"
-                                       + Integer.toString(counter+1) + ".\n");
+                                       + Integer.toString(counter + 1) + 
".\n");
                                ex.printStackTrace();
                        }
                }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSyntaxTest.java
 
b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSyntaxTest.java
index f3804066e3..7d4b5f4b25 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSyntaxTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSyntaxTest.java
@@ -93,9 +93,11 @@ public class ParamservSyntaxTest extends AutomatedTestBase {
 
        private void runDMLTest(String testname, boolean exceptionExpected, 
Class<?> exceptionClass, String errmsg) {
                TestConfiguration config = getTestConfiguration(testname);
+               setOutputBuffering(true);
                loadTestConfiguration(config);
                programArgs = new String[] { "-explain" };
                fullDMLScriptName = HOME + testname + ".dml";
                runTest(true, exceptionExpected, exceptionClass, errmsg, -1);
+               setOutputBuffering(false);
        }
 }

Reply via email to