This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new de7d9c3 [SYSTEMDS-3185] Handling of multi-threading in federated workers
de7d9c3 is described below
commit de7d9c3caf451c3c587636023f76d0b6f1fadf6f
Author: ywcb00 <[email protected]>
AuthorDate: Thu Mar 3 21:07:08 2022 +0100
[SYSTEMDS-3185] Handling of multi-threading in federated workers
Closes #1535.
---
.../federated/FederatedWorkerHandler.java | 7 +++++
.../matrix/operators/AggregateBinaryOperator.java | 9 ++-----
.../matrix/operators/AggregateTernaryOperator.java | 9 ++-----
.../matrix/operators/AggregateUnaryOperator.java | 11 +++-----
.../runtime/matrix/operators/BinaryOperator.java | 13 ++-------
.../sysds/runtime/matrix/operators/CMOperator.java | 13 +++------
.../runtime/matrix/operators/COVOperator.java | 9 ++-----
...COVOperator.java => MultiThreadedOperator.java} | 31 ++++++++++------------
.../runtime/matrix/operators/ReorgOperator.java | 11 +++-----
.../runtime/matrix/operators/ScalarOperator.java | 13 ++-------
.../runtime/matrix/operators/TernaryOperator.java | 9 ++-----
.../runtime/matrix/operators/UnaryOperator.java | 9 ++-----
12 files changed, 45 insertions(+), 99 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 5ab9c48..9a758ea 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -64,6 +64,7 @@ import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataAll;
import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
import org.apache.sysds.runtime.privacy.DMLPrivacyException;
import org.apache.sysds.runtime.privacy.PrivacyMonitor;
import org.apache.sysds.utils.Statistics;
@@ -422,6 +423,12 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
ec = ecm.get(request.getTID());
}
+ // set the number of threads according to the number of processors on the federated worker
+ if(receivedInstruction.getOperator() instanceof MultiThreadedOperator) {
+ int numProcessors = Runtime.getRuntime().availableProcessors();
+ ((MultiThreadedOperator)receivedInstruction.getOperator()).setNumThreads(numProcessors);
+ }
+
BasicProgramBlock pb = new BasicProgramBlock(null);
pb.getInstructions().clear();
pb.getInstructions().add(receivedInstruction);
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateBinaryOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateBinaryOperator.java
index 506abea..4e36a28 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateBinaryOperator.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateBinaryOperator.java
@@ -23,12 +23,11 @@ import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
-public class AggregateBinaryOperator extends Operator {
+public class AggregateBinaryOperator extends MultiThreadedOperator {
private static final long serialVersionUID = 1666421325090925726L;
public final ValueFunction binaryFn;
public final AggregateOperator aggOp;
- private final int k; // num threads
public AggregateBinaryOperator(ValueFunction inner, AggregateOperator
outer) {
// default degree of parallelism is 1
@@ -41,10 +40,6 @@ public class AggregateBinaryOperator extends Operator {
super(inner instanceof Multiply && outer.increOp.fn instanceof
Plus);
binaryFn = inner;
aggOp = outer;
- k = numThreads;
- }
-
- public int getNumThreads() {
- return k;
+ _numThreads = numThreads;
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateTernaryOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateTernaryOperator.java
index 90f1a04..ba18801 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateTernaryOperator.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateTernaryOperator.java
@@ -24,14 +24,13 @@ import
org.apache.sysds.runtime.functionobjects.IndexFunction;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
-public class AggregateTernaryOperator extends Operator
+public class AggregateTernaryOperator extends MultiThreadedOperator
{
private static final long serialVersionUID = 4251745081160216784L;
public final ValueFunction binaryFn;
public final AggregateOperator aggOp;
public final IndexFunction indexFn;
- private final int k; //num threads
public AggregateTernaryOperator(ValueFunction inner, AggregateOperator
outer, IndexFunction ixfun) {
//default degree of parallelism is 1 (e.g., for distributed
operations)
@@ -44,10 +43,6 @@ public class AggregateTernaryOperator extends Operator
binaryFn = inner;
aggOp = outer;
indexFn = ixfun;
- k = numThreads;
- }
-
- public int getNumThreads() {
- return k;
+ _numThreads = numThreads;
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java
index a88b68f..0896cbc 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java
@@ -31,12 +31,11 @@ import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
-public class AggregateUnaryOperator extends Operator {
+public class AggregateUnaryOperator extends MultiThreadedOperator {
private static final long serialVersionUID = 6690553323120787735L;
public final AggregateOperator aggOp;
public final IndexFunction indexFn;
- private final int k; //num threads
public AggregateUnaryOperator(AggregateOperator aop, IndexFunction iop)
{
@@ -54,13 +53,9 @@ public class AggregateUnaryOperator extends Operator {
|| aop.increOp.fn instanceof Minus);
aggOp = aop;
indexFn = iop;
- k = numThreads;
+ _numThreads = numThreads;
}
-
- public int getNumThreads(){
- return k;
- }
-
+
public boolean isRowAggregate() {
return indexFn instanceof ReduceCol;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
index b712187..f1b98cd 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
@@ -50,12 +50,11 @@ import org.apache.sysds.runtime.functionobjects.Power;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.functionobjects.Xor;
-public class BinaryOperator extends Operator {
+public class BinaryOperator extends MultiThreadedOperator {
private static final long serialVersionUID = -2547950181558989209L;
public final ValueFunction fn;
public final boolean commutative;
- private int _k = 1; // num threads
public BinaryOperator(ValueFunction p) {
this(p, 1);
@@ -70,17 +69,9 @@ public class BinaryOperator extends Operator {
fn = p;
commutative = p instanceof Plus || p instanceof Multiply || p
instanceof And || p instanceof Or ||
p instanceof Xor || p instanceof Minus1Multiply;
- _k = k;
+ _numThreads = k;
}
- public void setNumThreads(int k) {
- _k = k;
- }
-
- public int getNumThreads() {
- return _k;
- }
-
/**
* Method for getting the hop binary operator type for a given function
object.
* This is used in order to use a common code path for consistency
between
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
index 7e56e0e..f928f04 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
@@ -23,7 +23,7 @@ package org.apache.sysds.runtime.matrix.operators;
import org.apache.sysds.runtime.functionobjects.CM;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
-public class CMOperator extends Operator
+public class CMOperator extends MultiThreadedOperator
{
private static final long serialVersionUID = 4126894676505115420L;
@@ -43,7 +43,6 @@ public class CMOperator extends Operator
public final ValueFunction fn;
public final AggregateOperationTypes aggOpType;
- public final int k;
public CMOperator(ValueFunction op, AggregateOperationTypes agg) {
this(op, agg, 1);
@@ -53,21 +52,17 @@ public class CMOperator extends Operator
super(true);
fn = op;
aggOpType = agg;
- k = numThreads;
+ _numThreads = numThreads;
}
public AggregateOperationTypes getAggOpType() {
return aggOpType;
}
-
- public int getNumThreads() {
- return k;
- }
-
+
public CMOperator setCMAggOp(int order) {
AggregateOperationTypes agg = getCMAggOpType(order);
ValueFunction fn = CM.getCMFnObject(aggOpType);
- return new CMOperator(fn, agg, k);
+ return new CMOperator(fn, agg, _numThreads);
}
public static AggregateOperationTypes getCMAggOpType ( int order ) {
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
index 9d288db..19cd1d4 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
@@ -22,12 +22,11 @@ package org.apache.sysds.runtime.matrix.operators;
import org.apache.sysds.runtime.functionobjects.COV;
-public class COVOperator extends Operator
+public class COVOperator extends MultiThreadedOperator
{
private static final long serialVersionUID = -8404264552880694469L;
public final COV fn;
- public final int k;
public COVOperator(COV op) {
this(op, 1);
@@ -36,10 +35,6 @@ public class COVOperator extends Operator
public COVOperator(COV op, int numThreads) {
super(true);
fn = op;
- k = numThreads;
- }
-
- public int getNumThreads() {
- return k;
+ _numThreads = numThreads;
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/MultiThreadedOperator.java
similarity index 70%
copy from
src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
copy to
src/main/java/org/apache/sysds/runtime/matrix/operators/MultiThreadedOperator.java
index 9d288db..cbe5a41 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/COVOperator.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/operators/MultiThreadedOperator.java
@@ -17,29 +17,26 @@
* under the License.
*/
-
package org.apache.sysds.runtime.matrix.operators;
-import org.apache.sysds.runtime.functionobjects.COV;
+public class MultiThreadedOperator extends Operator {
+ private static final long serialVersionUID = 3528522245925706630L;
-public class COVOperator extends Operator
-{
- private static final long serialVersionUID = -8404264552880694469L;
+ protected int _numThreads = 1;
- public final COV fn;
- public final int k;
-
- public COVOperator(COV op) {
- this(op, 1);
+ public MultiThreadedOperator() {
+ super();
}
-
- public COVOperator(COV op, int numThreads) {
- super(true);
- fn = op;
- k = numThreads;
+
+ public MultiThreadedOperator(boolean sparseSafeFlag) {
+ super(sparseSafeFlag);
}
-
+
public int getNumThreads() {
- return k;
+ return _numThreads;
+ }
+
+ public void setNumThreads(int numThreads) {
+ _numThreads = numThreads;
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/ReorgOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/ReorgOperator.java
index cc3d2e3..ba45a00 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/ReorgOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/ReorgOperator.java
@@ -22,11 +22,10 @@ package org.apache.sysds.runtime.matrix.operators;
import org.apache.sysds.runtime.functionobjects.IndexFunction;
-public class ReorgOperator extends Operator{
+public class ReorgOperator extends MultiThreadedOperator {
private static final long serialVersionUID = -5322516429026298404L;
public final IndexFunction fn;
- private final int k; //num threads
public ReorgOperator(IndexFunction p) {
//default degree of parallelism is 1
@@ -37,14 +36,10 @@ public class ReorgOperator extends Operator{
public ReorgOperator(IndexFunction p, int numThreads) {
super(true);
fn = p;
- k = numThreads;
- }
-
- public int getNumThreads() {
- return k;
+ _numThreads = numThreads;
}
public ReorgOperator setFn(IndexFunction fn) {
- return new ReorgOperator(fn, k);
+ return new ReorgOperator(fn, _numThreads);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
index d33bbae..a5a24e6 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
@@ -38,13 +38,12 @@ import
org.apache.sysds.runtime.functionobjects.ValueFunction;
* Base class for all scalar operators.
*
*/
-public abstract class ScalarOperator extends Operator
+public abstract class ScalarOperator extends MultiThreadedOperator
{
private static final long serialVersionUID = 4547253761093455869L;
public final ValueFunction fn;
protected final double _constant;
- private int _k; //num threads
public ScalarOperator(ValueFunction p, double cst) {
this(p, cst, false);
@@ -63,21 +62,13 @@ public abstract class ScalarOperator extends Operator
|| (p instanceof Builtin &&
((Builtin)p).getBuiltinCode()==BuiltinCode.MIN && cst>=0));
fn = p;
_constant = cst;
- _k = numThreads;
+ _numThreads = numThreads;
}
public double getConstant() {
return _constant;
}
- public void setNumThreads(int k) {
- _k = k;
- }
-
- public int getNumThreads() {
- return _k;
- }
-
public abstract ScalarOperator setConstant(double cst);
public abstract ScalarOperator setConstant(double cst, int numThreads);
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java
index 92b142d..8fc6dd1 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java
@@ -25,21 +25,16 @@ import
org.apache.sysds.runtime.functionobjects.MinusMultiply;
import org.apache.sysds.runtime.functionobjects.PlusMultiply;
import org.apache.sysds.runtime.functionobjects.TernaryValueFunction;
-public class TernaryOperator extends Operator{
+public class TernaryOperator extends MultiThreadedOperator {
private static final long serialVersionUID = 3456088891054083634L;
public final TernaryValueFunction fn;
- private final int _k; // num threads
public TernaryOperator(TernaryValueFunction p, int numThreads) {
//ternaryop is sparse-safe iff (op 0 0 0) == 0
super (p instanceof PlusMultiply || p instanceof MinusMultiply
|| p instanceof IfElse);
fn = p;
- _k = numThreads;
- }
-
- public int getNumThreads() {
- return _k;
+ _numThreads = numThreads;
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
index fde3f11..d1d2abc 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
@@ -24,12 +24,11 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
-public class UnaryOperator extends Operator
+public class UnaryOperator extends MultiThreadedOperator
{
private static final long serialVersionUID = 2441990876648978637L;
public final ValueFunction fn;
- private final int k; //num threads
private final boolean inplace;
public UnaryOperator(ValueFunction p) {
@@ -45,14 +44,10 @@ public class UnaryOperator extends Operator
|| ((Builtin)p).bFunc==Builtin.BuiltinCode.SQRT ||
((Builtin)p).bFunc==Builtin.BuiltinCode.SPROP
|| ((Builtin)p).bFunc==Builtin.BuiltinCode.LOG_NZ ||
((Builtin)p).bFunc==Builtin.BuiltinCode.SIGN) );
fn = p;
- k = numThreads;
+ _numThreads = numThreads;
inplace = inPlace;
}
- public int getNumThreads() {
- return k;
- }
-
public boolean isInplace() {
return inplace;
}