This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 649fc8ec20 [MINOR] Fix ThreadPool for Federated
649fc8ec20 is described below
commit 649fc8ec20254a98c05a543e6b706cfe8edd56a7
Author: baunsgaard <[email protected]>
AuthorDate: Tue Aug 8 17:47:58 2023 +0200
[MINOR] Fix ThreadPool for Federated
This commit goes through the instances that call CommonThreadPool and
fixes the remaining issues. The new double buffering is unfortunately not
one of the fixed instances, so I changed it to use its own dedicated
single-thread executor instead.
Closes #1877
---
src/main/java/org/apache/sysds/api/DMLScript.java | 4 +-
.../apache/sysds/conf/ConfigurationManager.java | 5 +-
.../sysds/runtime/compress/lib/CLALibStack.java | 2 +-
.../runtime/controlprogram/ParForProgramBlock.java | 31 +++++-----
.../controlprogram/federated/FederationMap.java | 1 -
.../controlprogram/paramserv/LocalPSWorker.java | 3 +-
.../sysds/runtime/frame/data/FrameBlock.java | 10 ++--
.../sysds/runtime/io/FrameReaderJSONLParallel.java | 2 +-
.../sysds/runtime/iogen/FormatIdentifyer.java | 19 ++----
.../sysds/runtime/matrix/data/LibMatrixDNN.java | 2 +-
.../sysds/runtime/matrix/data/MatrixBlock.java | 5 +-
.../sysds/runtime/util/CommonThreadPool.java | 15 ++---
.../runtime/util/DoubleBufferingOutputStream.java | 6 +-
.../java/org/apache/sysds/performance/Main.java | 3 +-
.../federated/multitenant/MultiTenantTestBase.java | 70 +++++++++++++++-------
.../functions/paramserv/ParamservSyntaxTest.java | 2 +
16 files changed, 94 insertions(+), 86 deletions(-)
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java
b/src/main/java/org/apache/sysds/api/DMLScript.java
index ddc5ee2517..bf638dfcf7 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -326,6 +326,7 @@ public class DMLScript
//reset runtime platform and visualize flag
setGlobalExecMode(oldrtplatform);
EXPLAIN = oldexplain;
+ CommonThreadPool.shutdownAsyncPools();
}
return true;
@@ -572,9 +573,6 @@ public class DMLScript
//0) cleanup federated workers if necessary
FederatedData.clearFederatedWorkers();
- //0) shutdown prefetch/broadcast thread pool if necessary
- CommonThreadPool.shutdownAsyncPools();
-
//1) cleanup scratch space (everything for current uuid)
//(required otherwise export to hdfs would skip assumed
unnecessary writes if same name)
HDFSTool.deleteFileIfExistOnHDFS(
config.getTextValue(DMLConfig.SCRATCH_SPACE) + dirSuffix );
diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index 62352bd2a0..088545b8ed 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -29,7 +29,6 @@ import org.apache.sysds.conf.CompilerConfig.ConfigType;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Compression.CompressConfig;
import org.apache.sysds.lops.compile.linearization.ILinearize;
-import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.util.CommonThreadPool;
@@ -66,7 +65,7 @@ public class ConfigurationManager{
_dmlconf = new DMLConfig();
_cconf = new CompilerConfig();
- final ExecutorService pool =
CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
+ final ExecutorService pool = CommonThreadPool.get();
pool.submit(() ->{
try{
IOUtilFunctions.getFileSystem(_rJob);
@@ -75,7 +74,7 @@ public class ConfigurationManager{
LOG.warn(e.getMessage());
}
});
- pool.shutdown();
+ // pool.shutdown();
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java
index 178c13ad29..ffea0193b7 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java
@@ -218,7 +218,7 @@ public final class CLALibStack {
}
}
- final ExecutorService pool =
CommonThreadPool.get(Math.max(Math.min(clen / 500, k), 1));
+ final ExecutorService pool = CommonThreadPool.get();
try {
List<AColGroup> finalGroups = pool.submit(() -> {
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 94bbaf2545..790a92de58 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -19,11 +19,25 @@
package org.apache.sysds.runtime.controlprogram;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.CompilerConfig;
@@ -31,7 +45,6 @@ import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.lops.Lop;
-import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ParForStatementBlock;
@@ -91,24 +104,10 @@ import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.CollectionUtils;
-import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.stats.ParForStatistics;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import java.util.stream.Stream;
-
/**
@@ -122,7 +121,7 @@ import java.util.stream.Stream;
*
*/
public class ParForProgramBlock extends ForProgramBlock {
- protected static final Log LOG =
LogFactory.getLog(CommonThreadPool.class.getName());
+ protected static final Log LOG =
LogFactory.getLog(ParForProgramBlock.class.getName());
// execution modes
public enum PExecMode {
LOCAL, //local (master) multi-core execution mode
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 4db4f2b8b2..985fdb056e 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -614,7 +614,6 @@ public class FederationMap {
*/
public void forEachParallel(BiFunction<FederatedRange, FederatedData,
Void> forEachFunction) {
ExecutorService pool = CommonThreadPool.get(_fedMap.size());
-
ArrayList<MappingTask> mappingTasks = new ArrayList<>();
for(Pair<FederatedRange, FederatedData> fedMap : _fedMap)
mappingTasks.add(new MappingTask(fedMap.getKey(),
fedMap.getValue(), forEachFunction, _ID));
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
index 5343332eb5..a3be38cafd 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -31,7 +31,6 @@ import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.util.CommonThreadPool;
@@ -91,7 +90,7 @@ public class LocalPSWorker extends PSWorker implements
Callable<Void> {
ListObject params = pullModel();
Future<ListObject> accGradients =
ConcurrentUtils.constantFuture(null);
if(_tpool == null)
- _tpool =
CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
+ _tpool = CommonThreadPool.get();
try {
for (int j = 0; j < batchIter; j++) {
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
index ed5d48d6b3..94ab8f00de 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
@@ -49,7 +49,6 @@ import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
-import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.frame.data.columns.Array;
import org.apache.sysds.runtime.frame.data.columns.ArrayFactory;
@@ -860,26 +859,25 @@ public class FrameBlock implements
CacheBlock<FrameBlock>, Externalizable {
size +=
ArrayFactory.getInMemorySize(_schema[j], rlen, true);
else {// allocated
if(rlen > 1000 && clen > 10 &&
ConfigurationManager.isParallelIOEnabled()) {
- final ExecutorService pool =
CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
+ final ExecutorService pool =
CommonThreadPool.get();
try {
size += pool.submit(() -> {
return
Arrays.stream(_coldata).parallel() // parallel columns
.map(x ->
x.getInMemorySize()).reduce(0L, Long::sum);
}).get();
- pool.shutdown();
-
}
catch(InterruptedException | ExecutionException
e) {
- pool.shutdown();
LOG.error(e);
for(Array<?> aa : _coldata)
size += aa.getInMemorySize();
}
+ finally{
+ pool.shutdown();
+ }
}
else {
for(Array<?> aa : _coldata)
size += aa.getInMemorySize();
-
}
}
return size;
diff --git
a/src/main/java/org/apache/sysds/runtime/io/FrameReaderJSONLParallel.java
b/src/main/java/org/apache/sysds/runtime/io/FrameReaderJSONLParallel.java
index 17abd9e3c8..14143e0099 100644
--- a/src/main/java/org/apache/sysds/runtime/io/FrameReaderJSONLParallel.java
+++ b/src/main/java/org/apache/sysds/runtime/io/FrameReaderJSONLParallel.java
@@ -53,7 +53,7 @@ public class FrameReaderJSONLParallel extends FrameReaderJSONL
splits = IOUtilFunctions.sortInputSplits(splits);
try{
- ExecutorService executorPool =
CommonThreadPool.get(Math.min(numThreads, splits.length));
+ ExecutorService executorPool =
CommonThreadPool.get(numThreads);
//compute num rows per split
ArrayList<CountRowsTask> countRowsTasks = new
ArrayList<>();
diff --git a/src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java
b/src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java
index aa02ad37fc..3cbc174d64 100644
--- a/src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java
+++ b/src/main/java/org/apache/sysds/runtime/iogen/FormatIdentifyer.java
@@ -647,19 +647,14 @@ public class FormatIdentifyer {
colIndexes.add(0);
try {
- ExecutorService pool = CommonThreadPool.get(1);
ArrayList<BuildColsKeyPatternSingleRowTask> tasks = new
ArrayList<>();
tasks.add(
new
BuildColsKeyPatternSingleRowTask(prefixesRemovedReverse, prefixesRemoved,
prefixes, suffixes,
prefixesRemovedReverseSort, keys,
colSuffixes, lcs, colIndexes));
- //wait until all tasks have been executed
- List<Future<Object>> rt = pool.invokeAll(tasks);
- pool.shutdown();
-
//check for exceptions
- for(Future<Object> task : rt)
- task.get();
+ for(Callable<Object> task : tasks)
+ task.call();
}
catch(Exception e) {
throw new RuntimeException("Failed
BuildValueKeyPattern.", e);
@@ -770,19 +765,13 @@ public class FormatIdentifyer {
colIndexe.add(0);
try {
- ExecutorService pool = CommonThreadPool.get(1);
ArrayList<BuildColsKeyPatternSingleRowTask> tasks = new
ArrayList<>();
tasks.add(
new
BuildColsKeyPatternSingleRowTask(prefixesRemovedReverse, prefixesRemoved,
prefixes, suffixes,
prefixesRemovedReverseSort, keys,
colSuffixes, lcs, colIndexe));
-
- //wait until all tasks have been executed
- List<Future<Object>> rt = pool.invokeAll(tasks);
- pool.shutdown();
-
//check for exceptions
- for(Future<Object> task : rt)
- task.get();
+ for(Callable<Object> task : tasks)
+ task.call();
}
catch(Exception e) {
throw new RuntimeException("Failed
BuildValueKeyPattern.", e);
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
index 598fef549d..26a00425a0 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
@@ -670,7 +670,7 @@ public class LibMatrixDNN {
}
}
else {
- ExecutorService pool = CommonThreadPool.get(
Math.min(k, params.N) );
+ ExecutorService pool = CommonThreadPool.get(k);
List<Future<Long>> taskret =
pool.invokeAll(tasks);
pool.shutdown();
for( Future<Long> task : taskret )
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 5aaf0cd46a..01a5216b4b 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -35,8 +35,8 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
-import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.concurrent.ConcurrentUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -54,7 +54,6 @@ import
org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.lib.CLALibAggTernaryOp;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
-import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.DenseBlockFP64;
import org.apache.sysds.runtime.data.DenseBlockFactory;
@@ -373,7 +372,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
}
public Future<MatrixBlock> allocateBlockAsync() {
- ExecutorService pool =
CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
+ ExecutorService pool = CommonThreadPool.get();
return (pool != null) ? pool.submit(() -> allocateBlock()) :
//async
ConcurrentUtils.constantFuture(allocateBlock());
//fallback sync
}
diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
index cc6483d258..bc3be9844c 100644
--- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
@@ -99,14 +99,14 @@ public class CommonThreadPool implements ExecutorService {
* @param k The number of threads wanted
* @return The executor with specified parallelism
*/
- public static ExecutorService get(int k) {
+ public synchronized static ExecutorService get(int k) {
if(size == k)
return shared;
else if(Thread.currentThread().getName().equals("main")) {
if(shared2 != null && shared2K == k)
return shared2;
else if(shared2 == null) {
- shared2 = new
CommonThreadPool(Executors.newFixedThreadPool(k));
+ shared2 = new CommonThreadPool(new
ForkJoinPool(k));
shared2K = k;
return shared2;
}
@@ -141,12 +141,13 @@ public class CommonThreadPool implements ExecutorService {
// check for errors and exceptions
for(Future<T> r : ret)
r.get();
- // shutdown pool
- pool.shutdown();
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
+ finally{
+ pool.shutdown();
+ }
}
/**
@@ -155,8 +156,8 @@ public class CommonThreadPool implements ExecutorService {
*
* @return A dynamic thread pool.
*/
- public static ExecutorService getDynamicPool() {
- if(asyncPool != null)
+ public synchronized static ExecutorService getDynamicPool() {
+ if(asyncPool != null && !(asyncPool.isShutdown() ||
asyncPool.isTerminated()) )
return asyncPool;
else {
asyncPool = Executors.newCachedThreadPool();
@@ -167,7 +168,7 @@ public class CommonThreadPool implements ExecutorService {
/**
* Shutdown the cached thread pools.
*/
- public static void shutdownAsyncPools() {
+ public synchronized static void shutdownAsyncPools() {
if(asyncPool != null) {
// shutdown prefetch/broadcast thread pool
asyncPool.shutdown();
diff --git
a/src/main/java/org/apache/sysds/runtime/util/DoubleBufferingOutputStream.java
b/src/main/java/org/apache/sysds/runtime/util/DoubleBufferingOutputStream.java
index f1d2fadd93..16504e64ee 100644
---
a/src/main/java/org/apache/sysds/runtime/util/DoubleBufferingOutputStream.java
+++
b/src/main/java/org/apache/sysds/runtime/util/DoubleBufferingOutputStream.java
@@ -24,13 +24,13 @@ import java.io.IOException;
import java.io.OutputStream;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.lang3.concurrent.ConcurrentUtils;
-public class DoubleBufferingOutputStream extends FilterOutputStream
-{
- protected ExecutorService _pool = CommonThreadPool.get(1);
+public class DoubleBufferingOutputStream extends FilterOutputStream {
+ protected ExecutorService _pool = Executors.newSingleThreadExecutor();
protected Future<?>[] _locks;
protected byte[][] _buff;
private int _pos;
diff --git a/src/test/java/org/apache/sysds/performance/Main.java
b/src/test/java/org/apache/sysds/performance/Main.java
index 1e51a703bf..fa89a62b53 100644
--- a/src/test/java/org/apache/sysds/performance/Main.java
+++ b/src/test/java/org/apache/sysds/performance/Main.java
@@ -123,10 +123,11 @@ public class Main {
public static void main(String[] args) {
try {
exec(Integer.parseInt(args[0]), args);
- CommonThreadPool.get().shutdown();
}
catch(Exception e) {
e.printStackTrace();
+ }finally{
+ CommonThreadPool.get().shutdown();
}
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
index aa7141bd18..0b4e193ac4 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
@@ -19,15 +19,20 @@
package org.apache.sysds.test.functions.federated.multitenant;
+import static org.junit.Assert.fail;
+
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
-
-import static org.junit.Assert.fail;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.test.AutomatedTestBase;
@@ -36,6 +41,8 @@ import org.junit.After;
import com.google.crypto.tink.subtle.Random;
public abstract class MultiTenantTestBase extends AutomatedTestBase {
+ protected static final Log LOG =
LogFactory.getLog(MultiTenantTestBase.class.getName());
+
protected ArrayList<Process> workerProcesses = new ArrayList<>();
protected ArrayList<Process> coordinatorProcesses = new ArrayList<>();
@@ -56,8 +63,7 @@ public abstract class MultiTenantTestBase extends
AutomatedTestBase {
}
/**
- * Start numFedWorkers federated worker processes on available ports
and add
- * them to the workerProcesses
+ * Start numFedWorkers federated worker processes on available ports
and add them to the workerProcesses
*
* @param numFedWorkers the number of federated workers to start
* @return int[] the ports of the created federated workers
@@ -67,20 +73,20 @@ public abstract class MultiTenantTestBase extends
AutomatedTestBase {
for(int counter = 0; counter < numFedWorkers; counter++) {
ports[counter] = getRandomAvailablePort();
// start process but only wait long for last one.
- Process tmpProcess =
startLocalFedWorker(ports[counter], addArgs,
- counter == numFedWorkers-1 ? (FED_WORKER_WAIT +
Random.randInt(1000)) * 3 : FED_WORKER_WAIT_S);
+ Process tmpProcess =
startLocalFedWorker(ports[counter], addArgs,
+ counter == numFedWorkers - 1 ? (FED_WORKER_WAIT
+ Random.randInt(1000)) * 3 : FED_WORKER_WAIT_S);
workerProcesses.add(tmpProcess);
}
return ports;
}
/**
- * Start a coordinator process running the specified script with given
arguments
- * and add it to the coordinatorProcesses
+ * Start a coordinator process running the specified script with given
arguments and add it to the
+ * coordinatorProcesses
*
- * @param execMode the execution mode of the coordinator
+ * @param execMode the execution mode of the coordinator
* @param scriptPath the path to the dml script
- * @param args the program arguments for running the dml script
+ * @param args the program arguments for running the dml script
*/
protected void startCoordinator(ExecMode execMode, String scriptPath,
String[] args) {
String separator = System.getProperty("file.separator");
@@ -90,14 +96,14 @@ public abstract class MultiTenantTestBase extends
AutomatedTestBase {
String em = null;
switch(execMode) {
case SINGLE_NODE:
- em = "singlenode";
- break;
+ em = "singlenode";
+ break;
case HYBRID:
- em = "hybrid";
- break;
+ em = "hybrid";
+ break;
case SPARK:
- em = "spark";
- break;
+ em = "spark";
+ break;
}
ArrayList<String> argsList = new ArrayList<>();
@@ -108,13 +114,14 @@ public abstract class MultiTenantTestBase extends
AutomatedTestBase {
argsList.addAll(Arrays.asList(args));
// create the processBuilder and redirect the stderr to its
stdout
- ProcessBuilder processBuilder = new
ProcessBuilder(ArrayUtils.addAll(new String[]{
- path, "-cp", classpath, DMLScript.class.getName()},
argsList.toArray(new String[0])));
+ ProcessBuilder processBuilder = new ProcessBuilder(ArrayUtils
+ .addAll(new String[] {path, "-cp", classpath,
DMLScript.class.getName()}, argsList.toArray(new String[0])));
Process process = null;
try {
process = processBuilder.start();
- } catch(IOException ioe) {
+ }
+ catch(IOException ioe) {
ioe.printStackTrace();
fail("Can't start the coordinator process.");
}
@@ -122,12 +129,28 @@ public abstract class MultiTenantTestBase extends
AutomatedTestBase {
}
/**
- * Wait for all processes of coordinatorProcesses to terminate and
collect
- * their output
+ * Wait for all processes of coordinatorProcesses to terminate and
collect their output
*
* @return String the collected output of the coordinator processes
*/
protected String waitForCoordinators() {
+ return waitForCoordinators(500);
+ }
+
+ protected String waitForCoordinators(int timeout){
+ ExecutorService executor = Executors.newCachedThreadPool();
+ try{
+ return executor.submit(() ->
waitForCoordinatorsActual()).get(timeout, TimeUnit.SECONDS);
+ }
+ catch(Exception e){
+ throw new RuntimeException(e);
+ }
+ finally{
+ executor.shutdown();
+ }
+ }
+
+ private String waitForCoordinatorsActual(){
// wait for the coordinator processes to finish and collect
their output
StringBuilder outputLog = new StringBuilder();
for(int counter = 0; counter < coordinatorProcesses.size();
counter++) {
@@ -139,9 +162,10 @@ public abstract class MultiTenantTestBase extends
AutomatedTestBase {
outputLog.append(IOUtils.toString(coord.getErrorStream(),
Charset.defaultCharset()));
coord.waitFor();
- } catch(Exception ex) {
+ }
+ catch(Exception ex) {
fail(ex.getClass().getSimpleName() + " thrown
while collecting log output of coordinator #"
- + Integer.toString(counter+1) + ".\n");
+ + Integer.toString(counter + 1) +
".\n");
ex.printStackTrace();
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSyntaxTest.java
b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSyntaxTest.java
index f3804066e3..7d4b5f4b25 100644
---
a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSyntaxTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSyntaxTest.java
@@ -93,9 +93,11 @@ public class ParamservSyntaxTest extends AutomatedTestBase {
private void runDMLTest(String testname, boolean exceptionExpected,
Class<?> exceptionClass, String errmsg) {
TestConfiguration config = getTestConfiguration(testname);
+ setOutputBuffering(true);
loadTestConfiguration(config);
programArgs = new String[] { "-explain" };
fullDMLScriptName = HOME + testname + ".dml";
runTest(true, exceptionExpected, exceptionClass, errmsg, -1);
+ setOutputBuffering(false);
}
}