This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 0a5230a  [MINOR] Fix federated ternary-aggregate instructions (tak+*)
0a5230a is described below

commit 0a5230ab5bb72bb0b61656f74ed4a41c24fa69ce
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Mar 13 01:33:18 2022 +0100

    [MINOR] Fix federated ternary-aggregate instructions (tak+*)
    
    After the improved ternary-aggregate rewrites, remaining bugs in
    federated ternary-aggregate instructions surfaced. This patch fixes the
    immediate bug but leaves a FIXME for another branch, which requires new
    abstractions for array-based federated requests.
---
 .../instructions/fed/AggregateTernaryFEDInstruction.java | 16 ++++++++++------
 .../federated/algorithms/FederatedYL2SVMTest.java        |  3 ++-
 .../functions/privacy/algorithms/FederatedL2SVMTest.java | 10 ++++++++--
 3 files changed, 20 insertions(+), 9 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
index a9efb89..b253f69 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
@@ -96,14 +96,18 @@ public class AggregateTernaryFEDInstruction extends 
ComputationFEDInstruction {
                        }
                }
                else if(mo1.isFederated() && mo2.isFederated()
-                       && mo1.getFedMapping().isAligned(mo2.getFedMapping(), 
false) && mo3 == null) {
-                       FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(ec.getScalarInput(input3));
+                       && mo1.getFedMapping().isAligned(mo2.getFedMapping(), 
false)) {
+                       FederatedRequest[] fr1 = (mo3 == null) ?
+                               new FederatedRequest[] 
{mo1.getFedMapping().broadcast(ec.getScalarInput(input3))} :
+                               mo1.getFedMapping().broadcastSliced(mo3, false);
                        FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
                                new CPOperand[] {input1, input2, input3},
-                               new long[] {mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID(), fr1.getID()}, true);
+                               new long[] {mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID(), fr1[0].getID()}, true);
                        FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
-                       FederatedRequest fr4 = 
mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
-                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
+                       FederatedRequest fr4 = 
mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
+                       Future<FederatedResponse>[] tmp = (mo3 == null) ?
+                               mo1.getFedMapping().execute(getTID(), fr1[0], 
fr2, fr3, fr4) :
+                               mo1.getFedMapping().execute(getTID(), fr1, fr2, 
fr3, fr4);
 
                        if(output.getDataType().isScalar()) {
                                double sum = 0;
@@ -121,6 +125,7 @@ public class AggregateTernaryFEDInstruction extends 
ComputationFEDInstruction {
                                throw new DMLRuntimeException("Not Implemented 
Federated Ternary Variation");
                        }
                } else if(mo1.isFederatedExcept(FType.BROADCAST) && 
input3.isMatrix() && mo3 != null) {
+                       //FIXME cleanup fr2[0] below for result correctness, 
requires new primitives
                        FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo3, false);
                        FederatedRequest[] fr2 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
                        FederatedRequest fr3 = 
FederationUtils.callInstruction(getInstructionString(), output,
@@ -138,7 +143,6 @@ public class AggregateTernaryFEDInstruction extends 
ComputationFEDInstruction {
                                        catch(Exception e) {
                                                throw new 
DMLRuntimeException("Federated Get data failed with exception on 
TernaryFedInstruction", e);
                                        }
-
                                ec.setScalarOutput(output.getName(), new 
DoubleObject(sum));
                        }
                        else {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
index 42fee13..b8eef26 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
@@ -130,7 +130,8 @@ public class FederatedYL2SVMTest extends AutomatedTestBase {
 
                // Run actual dml script with federated matrixz
                fullDMLScriptName = HOME + testName + ".dml";
-               programArgs = new String[] {"-stats", "-nvargs", "in_X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
+               programArgs = new String[] {"-stats", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
                        "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")), "rows=" + rows, "cols=" + cols,
                        "in_Y1=" + TestUtils.federatedAddress(port1, 
input("Y1")),
                        "in_Y2=" + TestUtils.federatedAddress(port2, 
input("Y2")), "out=" + output("Z")};
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
index 67b790f..2b7eef3 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
@@ -36,6 +36,7 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Assert;
+import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -275,7 +276,9 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
 
        // Require Federated Workers to return matrix
 
-       @Test public void federatedL2SVMCPPrivateAggregationX1Exception()  {
+       @Test
+       @Ignore //Invalid with new plan
+       public void federatedL2SVMCPPrivateAggregationX1Exception()  {
                rows = 1000;
                cols = 1;
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
@@ -284,7 +287,10 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
                        PrivacyLevel.PrivateAggregation);
        }
 
-       @Test public void federatedL2SVMCPPrivateAggregationX2Exception()  {
+       
+       @Test
+       @Ignore //Invalid with new plan
+       public void federatedL2SVMCPPrivateAggregationX2Exception()  {
                rows = 1000;
                cols = 1;
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();

Reply via email to