This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 0a11235 [SYSTEMDS-2855] Fix missing federated col-partitioned matrix
multiply
0a11235 is described below
commit 0a112356e059c20baf609cd6c1f06a232ddd2f4c
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Feb 9 17:30:47 2021 +0100
[SYSTEMDS-2855] Fix missing federated col-partitioned matrix multiply
This patch adds the missing support for federated matrix multiplication
with a column-partitioned federated matrix as the left-hand input. In
addition, we lowered the log level of the per-request "Executing command"
message from info to debug (and the full-command dump from debug to
trace) to reduce default output in local tests.
---
.../federated/FederatedWorkerHandler.java | 8 ++++----
.../controlprogram/federated/FederationMap.java | 19 +++++++++++--------
.../fed/AggregateBinaryFEDInstruction.java | 13 +++++++++++++
.../runtime/instructions/fed/FEDInstructionUtils.java | 2 +-
4 files changed, 29 insertions(+), 13 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index a75c97a..57d5ba3 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -96,10 +96,10 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
for(int i = 0; i < requests.length; i++) {
FederatedRequest request = requests[i];
- if(log.isInfoEnabled()) {
- log.info("Executing command " + (i + 1) + "/" +
requests.length + ": " + request.getType().name());
- if(log.isDebugEnabled()) {
- log.debug("full command: " +
request.toString());
+ if(log.isDebugEnabled()) {
+ log.debug("Executing command " + (i + 1) + "/"
+ requests.length + ": " + request.getType().name());
+ if(log.isTraceEnabled()) {
+ log.trace("full command: " +
request.toString());
}
}
PrivacyMonitor.setCheckPrivacy(request.checkPrivacy());
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index e933979..4f70dd0 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -143,18 +143,21 @@ public class FederationMap {
// prepare broadcast id and pin input
long id = FederationUtils.getNextFedDataID();
CacheBlock cb = data.acquireReadAndRelease();
-
+
// prepare indexing ranges
int[][] ix = new int[_fedMap.size()][];
int pos = 0;
for(Entry<FederatedRange, FederatedData> e :
_fedMap.entrySet()) {
- int rl, ru, cl, cu;
- // TODO Handle different cases than ROW aligned
Matrices.
- rl = transposed ? 0 : e.getKey().getBeginDimsInt()[0];
- ru = transposed ? cb.getNumRows() - 1 :
e.getKey().getEndDimsInt()[0] - 1;
- cl = transposed ? e.getKey().getBeginDimsInt()[0] : 0;
- cu = transposed ? e.getKey().getEndDimsInt()[0] - 1 :
cb.getNumColumns() - 1;
- ix[pos++] = new int[] {rl, ru, cl, cu};
+ int beg = e.getKey().getBeginDimsInt()[(_type ==
FType.ROW ? 0 : 1)];
+ int end = e.getKey().getEndDimsInt()[(_type ==
FType.ROW ? 0 : 1)];
+ int nr = _type == FType.ROW ? cb.getNumRows() :
cb.getNumColumns();
+ int nc = _type == FType.ROW ? cb.getNumColumns() :
cb.getNumRows();
+ int rl = transposed ? 0 : beg;
+ int ru = transposed ? nr - 1 : end - 1;
+ int cl = transposed ? beg : 0;
+ int cu = transposed ? end - 1 : nc - 1;
+ ix[pos++] = _type == FType.ROW ?
+ new int[] {rl, ru, cl, cu} : new int[] {cl, cu,
rl, ru};
}
// multi-threaded block slicing and federation request creation
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 6ed642e..12616ed 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -110,6 +110,19 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
MatrixBlock ret = FederationUtils.aggAdd(tmp);
ec.setMatrixOutput(output.getName(), ret);
}
+ //#3 col-federated matrix vector multiplication
+ else if (mo1.isFederated(FType.COL)) {// VM + MM
+ //construct commands: broadcast rhs, fed mv, retrieve
results
+ FederatedRequest[] fr1 =
mo1.getFedMapping().broadcastSliced(mo2, true);
+ FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2}, new
long[]{mo1.getFedMapping().getID(), fr1[0].getID()});
+ FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 =
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
else { //other combinations
throw new DMLRuntimeException("Federated
AggregateBinary not supported with the "
+ "following federated objects:
"+mo1.isFederated()+":"+mo1.getFedMapping()
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 845f8a4..6c0e3ba 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -78,7 +78,7 @@ public class FEDInstructionUtils {
if( instruction.input1.isMatrix() &&
instruction.input2.isMatrix() ) {
MatrixObject mo1 =
ec.getMatrixObject(instruction.input1);
MatrixObject mo2 =
ec.getMatrixObject(instruction.input2);
- if (mo1.isFederated(FType.ROW) ||
mo2.isFederated(FType.ROW)) {
+ if (mo1.isFederated(FType.ROW) ||
mo2.isFederated(FType.ROW) || mo1.isFederated(FType.COL)) {
fedinst =
AggregateBinaryFEDInstruction.parseInstruction(inst.getInstructionString());
}
}