This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit da6a209696baf1102e15c65e4968e8106313a6a5
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Feb 13 00:58:16 2021 +0100

    [MINOR] Performance local parameter server (parallel updates)
---
 .../sysds/runtime/controlprogram/paramserv/LocalPSWorker.java  | 10 ++++++++--
 .../runtime/controlprogram/paramserv/dp/DCLocalScheme.java     |  3 +++
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
index 1b21853..241bfc6 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -37,10 +37,15 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
        protected static final Log LOG = 
LogFactory.getLog(LocalPSWorker.class.getName());
        private static final long serialVersionUID = 5195390748495357295L;
 
+       private boolean _parUpdates = false;
+       
        protected LocalPSWorker() {}
 
-       public LocalPSWorker(int workerID, String updFunc, 
Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, 
ParamServer ps) {
+       public LocalPSWorker(int workerID, String updFunc, 
Statement.PSFrequency freq,
+               int epochs, long batchSize, ExecutionContext ec, ParamServer 
ps, boolean parUpdates)
+       {
                super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
+               _parUpdates = parUpdates;
        }
 
        @Override
@@ -86,7 +91,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 
                                boolean localUpdate = j < batchIter - 1;
                                // Accumulate the intermediate gradients
-                               accGradients = 
ParamservUtils.accrueGradients(accGradients, gradients, !localUpdate);
+                               accGradients = ParamservUtils.accrueGradients(
+                                       accGradients, gradients, _parUpdates, 
!localUpdate);
 
                                // Update the local model with gradients
                                if(localUpdate)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DCLocalScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DCLocalScheme.java
index 9b155c1..5fbe022 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DCLocalScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DCLocalScheme.java
@@ -26,6 +26,7 @@ import java.util.stream.Collectors;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.CollectionUtils;
 
 /**
  * Disjoint_Contiguous data partitioner:
@@ -37,6 +38,8 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 public class DCLocalScheme extends DataPartitionLocalScheme {
 
        public static List<MatrixBlock> partition(int k, MatrixBlock mb) {
+               if( k == 1 )
+                       return CollectionUtils.asArrayList(mb);
                List<MatrixBlock> list = new ArrayList<>();
                long stepSize = (long) Math.ceil((double) mb.getNumRows() / k);
                long begin = 1;

Reply via email to