This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new e964be2cca [MINOR] Add Matrix Multiplication Chain Test and Fix Runtime Bug
e964be2cca is described below

commit e964be2cca10a11e357f777a1009dd772f99a5a5
Author: sebwrede <[email protected]>
AuthorDate: Fri Aug 26 16:24:16 2022 +0200

    [MINOR] Add Matrix Multiplication Chain Test and Fix Runtime Bug
    
    Closes #1690.
---
 .../fed/AggregateBinaryFEDInstruction.java         | 24 ++++++++++++++++++-
 .../fedplanning/FederatedMultiplyPlanningTest.java | 12 +++++++++-
 .../FederatedMultiplyPlanningTest12.dml            | 27 ++++++++++++++++++++++
 .../FederatedMultiplyPlanningTest12Reference.dml   | 26 +++++++++++++++++++++
 4 files changed, 87 insertions(+), 2 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 9340e9fb12..1a8115ee94 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -124,7 +124,13 @@ public class AggregateBinaryFEDInstruction extends 
BinaryFEDInstruction {
                                setOutputFedMapping(mo1.getFedMapping(), mo1, 
mo2, fr2.getID(), ec);
                        }
                        else {
-                               aggregateLocally(mo1.getFedMapping(), 
mo1.isFederated(FType.PART), ec, fr1, fr2);
+                               boolean isDoubleBroadcast = 
(mo1.isFederated(FType.BROADCAST) && mo2.isFederated(FType.BROADCAST));
+                               if (isDoubleBroadcast){
+                                       
aggregateLocallySingleWorker(mo1.getFedMapping(), ec, fr1, fr2);
+                               }
+                               else{
+                                       aggregateLocally(mo1.getFedMapping(), 
false, ec, fr1, fr2);
+                               }
                        }
                }
                //#2 vector - federated matrix multiplication
@@ -231,4 +237,20 @@ public class AggregateBinaryFEDInstruction extends 
BinaryFEDInstruction {
                        ret = FederationUtils.bind(ffr, false);
                ec.setMatrixOutput(output.getName(), ret);
        }
+
+       private void aggregateLocallySingleWorker(FederationMap fedMap, 
ExecutionContext ec, FederatedRequest... fr) {
+               //create GET calls on output
+               long callInstID = fr[fr.length - 1].getID();
+               FederatedRequest frG = new 
FederatedRequest(RequestType.GET_VAR, callInstID);
+               FederatedRequest frC = fedMap.cleanup(getTID(), callInstID);
+               //execute federated operations
+               Future<FederatedResponse>[] ffr = fedMap.execute(getTID(), 
ArrayUtils.addAll(fr, frG, frC));
+               try {
+                       //use only one response (all responses contain the same 
result)
+                       MatrixBlock ret = (MatrixBlock) 
ffr[0].get().getData()[0];
+                       ec.setMatrixOutput(output.getName(), ret);
+               } catch(Exception ex){
+                       throw new DMLRuntimeException(ex);
+               }
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 2477bdef85..415cd21178 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -56,6 +56,7 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
        private final static String TEST_NAME_9 = 
"FederatedMultiplyPlanningTest9";
        private final static String TEST_NAME_10 = 
"FederatedMultiplyPlanningTest10";
        private final static String TEST_NAME_11 = 
"FederatedMultiplyPlanningTest11";
+       private final static String TEST_NAME_12 = 
"FederatedMultiplyPlanningTest12";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
        private static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, 
"SystemDS-config-cost-based.xml");
 
@@ -79,6 +80,7 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                addTestConfiguration(TEST_NAME_9, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"}));
                addTestConfiguration(TEST_NAME_10, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_10, new String[] {"Z"}));
                addTestConfiguration(TEST_NAME_11, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_11, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_12, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_12, new String[] {"Z"}));
        }
 
        @Parameterized.Parameters
@@ -161,6 +163,14 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                federatedTwoMatricesSingleNodeTest(TEST_NAME_11, 
expectedHeavyHitters);
        }
 
+       @Test
+       public void federatedMultiplyPlanningTest12(){
+               String[] expectedHeavyHitters = new String[]{"fed_fedinit"};
+               rows = 30;
+               cols = 30;
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_12, 
expectedHeavyHitters);
+       }
+
        private void writeStandardMatrix(String matrixName, long seed){
                writeStandardMatrix(matrixName, seed, new 
PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
        }
@@ -215,7 +225,7 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                        writeColStandardMatrix("W1", 76, null);
                        writeColStandardMatrix("W2", 11, null);
                }
-               else if ( testName.equals(TEST_NAME_10) ){
+               else if ( testName.equals(TEST_NAME_10) || 
testName.equals(TEST_NAME_12) ){
                        writeStandardMatrix("X1", 42, null);
                        writeStandardMatrix("X2", 1340, null);
                }
diff --git 
a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12.dml
 
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12.dml
new file mode 100644
index 0000000000..3ef9909e68
--- /dev/null
+++ 
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+z0 = federated(addresses=list($X1, $X2),
+              ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), 
list($r, $c)))
+z1 = z0 %*% z0
+z2 = z1 %*% z1
+print(toString(z2))
+write(z2, $Z)
diff --git 
a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12Reference.dml
 
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12Reference.dml
new file mode 100644
index 0000000000..652172c2a8
--- /dev/null
+++ 
b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12Reference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+z0 = rbind(read($X1), read($X2))
+z1 = z0 %*% z0
+z2 = z1 %*% z1
+print(toString(z2))
+write(z2, $Z)

Reply via email to