This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 0a5230a [MINOR] Fix federated ternary-aggregate instructions (tak+*)
0a5230a is described below
commit 0a5230ab5bb72bb0b61656f74ed4a41c24fa69ce
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Mar 13 01:33:18 2022 +0100
[MINOR] Fix federated ternary-aggregate instructions (tak+*)
After the improved ternary aggregate rewrites, remaining bugs of
federated ternary-aggregate instructions surfaced. This patch fixes the
immediate bug but leaves a fixme for another branch that requires new
abstractions for array-based federated requests.
---
.../instructions/fed/AggregateTernaryFEDInstruction.java | 16 ++++++++++------
.../federated/algorithms/FederatedYL2SVMTest.java | 3 ++-
.../functions/privacy/algorithms/FederatedL2SVMTest.java | 10 ++++++++--
3 files changed, 20 insertions(+), 9 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
index a9efb89..b253f69 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
@@ -96,14 +96,18 @@ public class AggregateTernaryFEDInstruction extends
ComputationFEDInstruction {
}
}
else if(mo1.isFederated() && mo2.isFederated()
- && mo1.getFedMapping().isAligned(mo2.getFedMapping(),
false) && mo3 == null) {
- FederatedRequest fr1 =
mo1.getFedMapping().broadcast(ec.getScalarInput(input3));
+ && mo1.getFedMapping().isAligned(mo2.getFedMapping(),
false)) {
+ FederatedRequest[] fr1 = (mo3 == null) ?
+ new FederatedRequest[]
{mo1.getFedMapping().broadcast(ec.getScalarInput(input3))} :
+ mo1.getFedMapping().broadcastSliced(mo3, false);
FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
new CPOperand[] {input1, input2, input3},
- new long[] {mo1.getFedMapping().getID(),
mo2.getFedMapping().getID(), fr1.getID()}, true);
+ new long[] {mo1.getFedMapping().getID(),
mo2.getFedMapping().getID(), fr1[0].getID()}, true);
FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
- FederatedRequest fr4 =
mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
- Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
+ FederatedRequest fr4 =
mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
+ Future<FederatedResponse>[] tmp = (mo3 == null) ?
+ mo1.getFedMapping().execute(getTID(), fr1[0],
fr2, fr3, fr4) :
+ mo1.getFedMapping().execute(getTID(), fr1, fr2,
fr3, fr4);
if(output.getDataType().isScalar()) {
double sum = 0;
@@ -121,6 +125,7 @@ public class AggregateTernaryFEDInstruction extends
ComputationFEDInstruction {
throw new DMLRuntimeException("Not Implemented
Federated Ternary Variation");
}
} else if(mo1.isFederatedExcept(FType.BROADCAST) &&
input3.isMatrix() && mo3 != null) {
+ //FIXME cleanup fr2[0] below for result correctness,
requires new primitives
FederatedRequest[] fr1 =
mo1.getFedMapping().broadcastSliced(mo3, false);
FederatedRequest[] fr2 =
mo1.getFedMapping().broadcastSliced(mo2, false);
FederatedRequest fr3 =
FederationUtils.callInstruction(getInstructionString(), output,
@@ -138,7 +143,6 @@ public class AggregateTernaryFEDInstruction extends
ComputationFEDInstruction {
catch(Exception e) {
throw new
DMLRuntimeException("Federated Get data failed with exception on
TernaryFedInstruction", e);
}
-
ec.setScalarOutput(output.getName(), new
DoubleObject(sum));
}
else {
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
index 42fee13..b8eef26 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
@@ -130,7 +130,8 @@ public class FederatedYL2SVMTest extends AutomatedTestBase {
// Run actual dml script with federated matrixz
fullDMLScriptName = HOME + testName + ".dml";
- programArgs = new String[] {"-stats", "-nvargs", "in_X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+ programArgs = new String[] {"-stats", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2,
input("X2")), "rows=" + rows, "cols=" + cols,
"in_Y1=" + TestUtils.federatedAddress(port1,
input("Y1")),
"in_Y2=" + TestUtils.federatedAddress(port2,
input("Y2")), "out=" + output("Z")};
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
index 67b790f..2b7eef3 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
@@ -36,6 +36,7 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
+import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -275,7 +276,9 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
// Require Federated Workers to return matrix
- @Test public void federatedL2SVMCPPrivateAggregationX1Exception() {
+ @Test
+ @Ignore //Invalid with new plan
+ public void federatedL2SVMCPPrivateAggregationX1Exception() {
rows = 1000;
cols = 1;
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
@@ -284,7 +287,10 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
PrivacyLevel.PrivateAggregation);
}
- @Test public void federatedL2SVMCPPrivateAggregationX2Exception() {
+
+ @Test
+ @Ignore //Invalid with new plan
+ public void federatedL2SVMCPPrivateAggregationX2Exception() {
rows = 1000;
cols = 1;
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();