This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit da6a209696baf1102e15c65e4968e8106313a6a5 Author: Matthias Boehm <[email protected]> AuthorDate: Sat Feb 13 00:58:16 2021 +0100 [MINOR] Performance local parameter server (parallel updates) --- .../sysds/runtime/controlprogram/paramserv/LocalPSWorker.java | 10 ++++++++-- .../runtime/controlprogram/paramserv/dp/DCLocalScheme.java | 3 +++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java index 1b21853..241bfc6 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -37,10 +37,15 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName()); private static final long serialVersionUID = 5195390748495357295L; + private boolean _parUpdates = false; + protected LocalPSWorker() {} - public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) { + public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, + int epochs, long batchSize, ExecutionContext ec, ParamServer ps, boolean parUpdates) + { super(workerID, updFunc, freq, epochs, batchSize, ec, ps); + _parUpdates = parUpdates; } @Override @@ -86,7 +91,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { boolean localUpdate = j < batchIter - 1; // Accumulate the intermediate gradients - accGradients = ParamservUtils.accrueGradients(accGradients, gradients, !localUpdate); + accGradients = ParamservUtils.accrueGradients( + accGradients, gradients, _parUpdates, !localUpdate); // Update the local model with gradients if(localUpdate) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DCLocalScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DCLocalScheme.java index 9b155c1..5fbe022 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DCLocalScheme.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DCLocalScheme.java @@ -26,6 +26,7 @@ import java.util.stream.Collectors; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.CollectionUtils; /** * Disjoint_Contiguous data partitioner: @@ -37,6 +38,8 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; public class DCLocalScheme extends DataPartitionLocalScheme { public static List<MatrixBlock> partition(int k, MatrixBlock mb) { + if( k == 1 ) + return CollectionUtils.asArrayList(mb); List<MatrixBlock> list = new ArrayList<>(); long stepSize = (long) Math.ceil((double) mb.getNumRows() / k); long begin = 1;
