This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new de7d9c3  [SYSTEMDS-3185] Handling of multi-threading in federated workers
de7d9c3 is described below

commit de7d9c3caf451c3c587636023f76d0b6f1fadf6f
Author: ywcb00 <[email protected]>
AuthorDate: Thu Mar 3 21:07:08 2022 +0100

    [SYSTEMDS-3185] Handling of multi-threading in federated workers
    
    Closes #1535.
---
 .../federated/FederatedWorkerHandler.java          |  7 +++++
 .../matrix/operators/AggregateBinaryOperator.java  |  9 ++-----
 .../matrix/operators/AggregateTernaryOperator.java |  9 ++-----
 .../matrix/operators/AggregateUnaryOperator.java   | 11 +++-----
 .../runtime/matrix/operators/BinaryOperator.java   | 13 ++-------
 .../sysds/runtime/matrix/operators/CMOperator.java | 13 +++------
 .../runtime/matrix/operators/COVOperator.java      |  9 ++-----
 ...COVOperator.java => MultiThreadedOperator.java} | 31 ++++++++++------------
 .../runtime/matrix/operators/ReorgOperator.java    | 11 +++-----
 .../runtime/matrix/operators/ScalarOperator.java   | 13 ++-------
 .../runtime/matrix/operators/TernaryOperator.java  |  9 ++-----
 .../runtime/matrix/operators/UnaryOperator.java    |  9 ++-----
 12 files changed, 45 insertions(+), 99 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 5ab9c48..9a758ea 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -64,6 +64,7 @@ import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.meta.MetaDataAll;
 import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
 import org.apache.sysds.runtime.privacy.DMLPrivacyException;
 import org.apache.sysds.runtime.privacy.PrivacyMonitor;
 import org.apache.sysds.utils.Statistics;
@@ -422,6 +423,12 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
                        ec = ecm.get(request.getTID());
                }
 
+               // set the number of threads according to the number of 
processors on the federated worker
+               if(receivedInstruction.getOperator() instanceof 
MultiThreadedOperator) {
+                       int numProcessors = 
Runtime.getRuntime().availableProcessors();
+                       
((MultiThreadedOperator)receivedInstruction.getOperator()).setNumThreads(numProcessors);
+               }
+
                BasicProgramBlock pb = new BasicProgramBlock(null);
                pb.getInstructions().clear();
                pb.getInstructions().add(receivedInstruction);
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateBinaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateBinaryOperator.java
index 506abea..4e36a28 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateBinaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateBinaryOperator.java
@@ -23,12 +23,11 @@ import org.apache.sysds.runtime.functionobjects.Multiply;
 import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 
-public class AggregateBinaryOperator extends Operator {
+public class AggregateBinaryOperator extends MultiThreadedOperator {
        private static final long serialVersionUID = 1666421325090925726L;
 
        public final ValueFunction binaryFn;
        public final AggregateOperator aggOp;
-       private final int k; // num threads
 
        public AggregateBinaryOperator(ValueFunction inner, AggregateOperator 
outer) {
                // default degree of parallelism is 1
@@ -41,10 +40,6 @@ public class AggregateBinaryOperator extends Operator {
                super(inner instanceof Multiply && outer.increOp.fn instanceof 
Plus);
                binaryFn = inner;
                aggOp = outer;
-               k = numThreads;
-       }
-
-       public int getNumThreads() {
-               return k;
+               _numThreads = numThreads;
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateTernaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateTernaryOperator.java
index 90f1a04..ba18801 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateTernaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateTernaryOperator.java
@@ -24,14 +24,13 @@ import org.apache.sysds.runtime.functionobjects.IndexFunction;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 
 
-public class AggregateTernaryOperator extends Operator
+public class AggregateTernaryOperator extends MultiThreadedOperator
 {
        private static final long serialVersionUID = 4251745081160216784L;
        
        public final ValueFunction binaryFn;
        public final AggregateOperator aggOp;
        public final IndexFunction indexFn;
-       private final int k; //num threads
        
        public AggregateTernaryOperator(ValueFunction inner, AggregateOperator 
outer, IndexFunction ixfun) {
                //default degree of parallelism is 1 (e.g., for distributed 
operations)
@@ -44,10 +43,6 @@ public class AggregateTernaryOperator extends Operator
                binaryFn = inner;
                aggOp = outer;
                indexFn = ixfun;
-               k = numThreads;
-       }
-       
-       public int getNumThreads() {
-               return k;
+               _numThreads = numThreads;
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java
index a88b68f..0896cbc 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java
@@ -31,12 +31,11 @@ import org.apache.sysds.runtime.functionobjects.ReduceCol;
 import org.apache.sysds.runtime.functionobjects.ReduceRow;
 
 
-public class AggregateUnaryOperator extends Operator {
+public class AggregateUnaryOperator extends MultiThreadedOperator {
        private static final long serialVersionUID = 6690553323120787735L;
 
        public final AggregateOperator aggOp;
        public final IndexFunction indexFn;
-       private final int k; //num threads
 
        public AggregateUnaryOperator(AggregateOperator aop, IndexFunction iop)
        {
@@ -54,13 +53,9 @@ public class AggregateUnaryOperator extends Operator {
                        || aop.increOp.fn instanceof Minus);
                aggOp = aop;
                indexFn = iop;
-               k = numThreads;
+               _numThreads = numThreads;
        }
-       
-       public int getNumThreads(){
-               return k;
-       }
-       
+
        public boolean isRowAggregate() {
                return indexFn instanceof ReduceCol;
        }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
index b712187..f1b98cd 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
@@ -50,12 +50,11 @@ import org.apache.sysds.runtime.functionobjects.Power;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.functionobjects.Xor;
 
-public class BinaryOperator extends Operator {
+public class BinaryOperator extends MultiThreadedOperator {
        private static final long serialVersionUID = -2547950181558989209L;
 
        public final ValueFunction fn;
        public final boolean commutative;
-       private int _k = 1; // num threads
        
        public BinaryOperator(ValueFunction p) {
                this(p, 1);
@@ -70,17 +69,9 @@ public class BinaryOperator extends Operator {
                fn = p;
                commutative = p instanceof Plus || p instanceof Multiply || p 
instanceof And || p instanceof Or ||
                        p instanceof Xor || p instanceof Minus1Multiply;
-               _k = k;
+               _numThreads = k;
        }
 
-       public void setNumThreads(int k) {
-               _k = k;
-       }
-       
-       public int getNumThreads() {
-               return _k;
-       }
-       
        /**
         * Method for getting the hop binary operator type for a given function 
object.
         * This is used in order to use a common code path for consistency 
between 
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
index 7e56e0e..f928f04 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
@@ -23,7 +23,7 @@ package org.apache.sysds.runtime.matrix.operators;
 import org.apache.sysds.runtime.functionobjects.CM;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 
-public class CMOperator extends Operator 
+public class CMOperator extends MultiThreadedOperator
 {
        private static final long serialVersionUID = 4126894676505115420L;
        
@@ -43,7 +43,6 @@ public class CMOperator extends Operator
 
        public final ValueFunction fn;
        public final AggregateOperationTypes aggOpType;
-       public final int k;
 
        public CMOperator(ValueFunction op, AggregateOperationTypes agg) {
                this(op, agg, 1);
@@ -53,21 +52,17 @@ public class CMOperator extends Operator
                super(true);
                fn = op;
                aggOpType = agg;
-               k = numThreads;
+               _numThreads = numThreads;
        }
 
        public AggregateOperationTypes getAggOpType() {
                return aggOpType;
        }
-       
-       public int getNumThreads() {
-               return k;
-       }
-       
+
        public CMOperator setCMAggOp(int order) {
                AggregateOperationTypes agg = getCMAggOpType(order);
                ValueFunction fn = CM.getCMFnObject(aggOpType);
-               return new CMOperator(fn, agg, k);
+               return new CMOperator(fn, agg, _numThreads);
        }
        
        public static AggregateOperationTypes getCMAggOpType ( int order ) {
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
index 9d288db..19cd1d4 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
@@ -22,12 +22,11 @@ package org.apache.sysds.runtime.matrix.operators;
 
 import org.apache.sysds.runtime.functionobjects.COV;
 
-public class COVOperator extends Operator 
+public class COVOperator extends MultiThreadedOperator
 {
        private static final long serialVersionUID = -8404264552880694469L;
 
        public final COV fn;
-       public final int k;
        
        public COVOperator(COV op) {
                this(op, 1);
@@ -36,10 +35,6 @@ public class COVOperator extends Operator
        public COVOperator(COV op, int numThreads) {
                super(true);
                fn = op;
-               k = numThreads;
-       }
-       
-       public int getNumThreads() {
-               return k;
+               _numThreads = numThreads;
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/MultiThreadedOperator.java
similarity index 70%
copy from src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
copy to src/main/java/org/apache/sysds/runtime/matrix/operators/MultiThreadedOperator.java
index 9d288db..cbe5a41 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/MultiThreadedOperator.java
@@ -17,29 +17,26 @@
  * under the License.
  */
 
-
 package org.apache.sysds.runtime.matrix.operators;
 
-import org.apache.sysds.runtime.functionobjects.COV;
+public class MultiThreadedOperator extends Operator {
+       private static final long serialVersionUID = 3528522245925706630L;
 
-public class COVOperator extends Operator 
-{
-       private static final long serialVersionUID = -8404264552880694469L;
+       protected int _numThreads = 1;
 
-       public final COV fn;
-       public final int k;
-       
-       public COVOperator(COV op) {
-               this(op, 1);
+       public MultiThreadedOperator() {
+               super();
        }
-       
-       public COVOperator(COV op, int numThreads) {
-               super(true);
-               fn = op;
-               k = numThreads;
+
+       public MultiThreadedOperator(boolean sparseSafeFlag) {
+               super(sparseSafeFlag);
        }
-       
+
        public int getNumThreads() {
-               return k;
+               return _numThreads;
+       }
+
+       public void setNumThreads(int numThreads) {
+               _numThreads = numThreads;
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/ReorgOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/ReorgOperator.java
index cc3d2e3..ba45a00 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/ReorgOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/ReorgOperator.java
@@ -22,11 +22,10 @@ package org.apache.sysds.runtime.matrix.operators;
 
 import org.apache.sysds.runtime.functionobjects.IndexFunction;
 
-public class ReorgOperator extends Operator{
+public class ReorgOperator extends MultiThreadedOperator {
        private static final long serialVersionUID = -5322516429026298404L;
 
        public final IndexFunction fn;
-       private final int k; //num threads
        
        public ReorgOperator(IndexFunction p) {
                //default degree of parallelism is 1 
@@ -37,14 +36,10 @@ public class ReorgOperator extends Operator{
        public ReorgOperator(IndexFunction p, int numThreads) {
                super(true);
                fn = p;
-               k = numThreads;
-       }
-
-       public int getNumThreads() {
-               return k;
+               _numThreads = numThreads;
        }
 
        public ReorgOperator setFn(IndexFunction fn) {
-               return new ReorgOperator(fn, k);
+               return new ReorgOperator(fn, _numThreads);
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
index d33bbae..a5a24e6 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
@@ -38,13 +38,12 @@ import org.apache.sysds.runtime.functionobjects.ValueFunction;
  * Base class for all scalar operators.
  * 
  */
-public abstract class ScalarOperator extends Operator 
+public abstract class ScalarOperator extends MultiThreadedOperator
 {
        private static final long serialVersionUID = 4547253761093455869L;
 
        public final ValueFunction fn;
        protected final double _constant;
-       private int _k; //num threads
        
        public ScalarOperator(ValueFunction p, double cst) {
                this(p, cst, false);
@@ -63,21 +62,13 @@ public abstract class ScalarOperator extends Operator
                                || (p instanceof Builtin && 
((Builtin)p).getBuiltinCode()==BuiltinCode.MIN && cst>=0));
                fn = p;
                _constant = cst;
-               _k = numThreads;
+               _numThreads = numThreads;
        }
        
        public double getConstant() {
                return _constant;
        }
        
-       public void setNumThreads(int k) {
-               _k = k;
-       }
-       
-       public int getNumThreads() {
-               return _k;
-       }
-       
        public abstract ScalarOperator setConstant(double cst);
        
        public abstract ScalarOperator setConstant(double cst, int numThreads);
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java
index 92b142d..8fc6dd1 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java
@@ -25,21 +25,16 @@ import org.apache.sysds.runtime.functionobjects.MinusMultiply;
 import org.apache.sysds.runtime.functionobjects.PlusMultiply;
 import org.apache.sysds.runtime.functionobjects.TernaryValueFunction;
 
-public class TernaryOperator extends Operator{
+public class TernaryOperator extends MultiThreadedOperator {
        private static final long serialVersionUID = 3456088891054083634L;
        
        public final TernaryValueFunction fn;
-       private final int _k; // num threads
 
        public TernaryOperator(TernaryValueFunction p, int numThreads) {
                //ternaryop is sparse-safe iff (op 0 0 0) == 0
                super (p instanceof PlusMultiply || p instanceof MinusMultiply 
|| p instanceof IfElse);
                fn = p;
-               _k = numThreads;
-       }
-       
-       public int getNumThreads() {
-               return _k;
+               _numThreads = numThreads;
        }
        
        @Override
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
index fde3f11..d1d2abc 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
@@ -24,12 +24,11 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 
-public class UnaryOperator extends Operator 
+public class UnaryOperator extends MultiThreadedOperator
 {
        private static final long serialVersionUID = 2441990876648978637L;
 
        public final ValueFunction fn;
-       private final int k; //num threads
        private final boolean inplace;
 
        public UnaryOperator(ValueFunction p) {
@@ -45,14 +44,10 @@ public class UnaryOperator extends Operator
                        || ((Builtin)p).bFunc==Builtin.BuiltinCode.SQRT || 
((Builtin)p).bFunc==Builtin.BuiltinCode.SPROP
                        || ((Builtin)p).bFunc==Builtin.BuiltinCode.LOG_NZ || 
((Builtin)p).bFunc==Builtin.BuiltinCode.SIGN) );
                fn = p;
-               k = numThreads;
+               _numThreads = numThreads;
                inplace = inPlace;
        }
        
-       public int getNumThreads() {
-               return k;
-       }
-       
        public boolean isInplace() {
                return inplace;
        }

Reply via email to