This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new d9f9723  [SYSTEMDS-2859] Fix federated broadcasting w/ single federated worker
d9f9723 is described below

commit d9f9723198e01f3470c1f98492f647dada312a1e
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Feb 11 19:04:08 2021 +0100

    [SYSTEMDS-2859] Fix federated broadcasting w/ single federated worker
    
    The newly introduced federated partition type FULL caused various
    problems in sliced broadcasting because this primitive only expected ROW
    and COL partitioning. We now handle FULL by simply using a normal
    broadcast without slicing, for both correctness and performance.
    
    Furthermore, this patch also adds dedicated tests with a single worker
    for both KMeans and L2SVM.
---
 .../controlprogram/federated/FederationMap.java      |  3 +++
 .../federated/algorithms/FederatedKmeansTest.java    | 20 +++++++++++++-------
 .../federated/algorithms/FederatedL2SVMTest.java     | 16 +++++++++++-----
 .../functions/federated/FederatedKmeansTest.dml      |  7 ++++++-
 .../federated/FederatedKmeansTestReference.dml       | 11 ++++++++---
 .../functions/federated/FederatedL2SVMTest.dml       | 13 +++++++++++--
 .../federated/FederatedL2SVMTestReference.dml        | 11 +++++++++--
 7 files changed, 61 insertions(+), 20 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 4f70dd0..7278123 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -140,6 +140,9 @@ public class FederationMap {
         * @return array of federated requests corresponding to federated data
         */
        public FederatedRequest[] broadcastSliced(CacheableData<?> data, 
boolean transposed) {
+               if( _type == FType.FULL )
+                       return new FederatedRequest[]{broadcast(data)};
+               
                // prepare broadcast id and pin input
                long id = FederationUtils.getNextFedDataID();
                CacheBlock cb = data.acquireReadAndRelease();
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
index e352b5a..2f8ce8e 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
@@ -76,17 +76,22 @@ public class FederatedKmeansTest extends AutomatedTestBase {
        }
 
        @Test
-       public void federatedKmeansSinglenode() {
-               federatedKmeans(Types.ExecMode.SINGLE_NODE);
+       public void federatedKmeans2Singlenode() {
+               federatedKmeans(Types.ExecMode.SINGLE_NODE, false);
        }
 
        @Test
+       public void federatedKmeans1Singlenode() {
+               federatedKmeans(Types.ExecMode.SINGLE_NODE, true);
+       }
+       
+       @Test
        @Ignore
-       public void federatedKmeansHybrid() {
-               federatedKmeans(Types.ExecMode.HYBRID);
+       public void federatedKmeans2Hybrid() {
+               federatedKmeans(Types.ExecMode.HYBRID, false);
        }
 
-       public void federatedKmeans(Types.ExecMode execMode) {
+       public void federatedKmeans(Types.ExecMode execMode, boolean 
singleWorker) {
                ExecMode platformOld = setExecMode(execMode);
 
                getAndLoadTestConfiguration(TEST_NAME);
@@ -112,14 +117,15 @@ public class FederatedKmeansTest extends AutomatedTestBase {
 
                // Run reference dml script with normal matrix
                fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-               programArgs = new String[] {"-args", input("X1"), input("X2"), 
String.valueOf(runs), expected("Z")};
+               programArgs = new String[] {"-args", input("X1"), input("X2"),
+                       String.valueOf(singleWorker).toUpperCase(), 
String.valueOf(runs), expected("Z")};
                runTest(true, false, null, -1);
 
                // Run actual dml script with federated matrix
                fullDMLScriptName = HOME + TEST_NAME + ".dml";
                programArgs = new String[] {"-stats", "-nvargs", "in_X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
                        "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")), "rows=" + rows, "cols=" + cols,
-                       "runs=" + String.valueOf(runs), "out=" + output("Z")};
+                       "single=" + String.valueOf(singleWorker).toUpperCase(), 
"runs=" + String.valueOf(runs), "out=" + output("Z")};
 
                for(int i = 0; i < rep; i++) {
                        ParForProgramBlock.resetWorkerIDs();
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
index 95e5ba4..f7040e6 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
@@ -61,8 +61,13 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
        }
 
        @Test
-       public void federatedL2SVMCP() {
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE);
+       public void federatedL2SVM2CP() {
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, false);
+       }
+       
+       @Test
+       public void federatedL2SVM1CP() {
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, true);
        }
 
        /*
@@ -71,7 +76,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
         * @Test public void federatedL2SVMSP() { 
federatedL2SVM(Types.ExecMode.SPARK); }
         */
 
-       public void federatedL2SVM(Types.ExecMode execMode) {
+       public void federatedL2SVM(Types.ExecMode execMode, boolean 
singleWorker) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                Types.ExecMode platformOld = rtplatform;
                rtplatform = execMode;
@@ -107,14 +112,15 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
 
                // Run reference dml script with normal matrix
                fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-               programArgs = new String[] {"-args", input("X1"), input("X2"), 
input("Y"), expected("Z")};
+               programArgs = new String[] {"-args", input("X1"), input("X2"), 
input("Y"),
+                       String.valueOf(singleWorker).toUpperCase(), 
expected("Z")};
                runTest(true, false, null, -1);
 
                // Run actual dml script with federated matrix
                fullDMLScriptName = HOME + TEST_NAME + ".dml";
                programArgs = new String[] {"-nvargs", "in_X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
                        "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")), "rows=" + rows, "cols=" + cols,
-                       "in_Y=" + input("Y"), "out=" + output("Z")};
+                       "in_Y=" + input("Y"), "single=" + 
String.valueOf(singleWorker).toUpperCase(), "out=" + output("Z")};
                runTest(true, false, null, -1);
 
                // compare via files
diff --git a/src/test/scripts/functions/federated/FederatedKmeansTest.dml b/src/test/scripts/functions/federated/FederatedKmeansTest.dml
index 13e89ea..017ac51 100644
--- a/src/test/scripts/functions/federated/FederatedKmeansTest.dml
+++ b/src/test/scripts/functions/federated/FederatedKmeansTest.dml
@@ -19,7 +19,12 @@
 #
 #-------------------------------------------------------------
 
-X = federated(addresses=list($in_X1, $in_X2),
+if( $single )
+  X = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows/2, 
$cols)))
+else
+  X = federated(addresses=list($in_X1, $in_X2),
     ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), 
list($rows, $cols)))
+
 [C,Y] = kmeans(X=X, k=4, runs=$runs, max_iter=150)
+
 write(C, $out)
diff --git a/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml b/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
index e72c9b5..3046eae 100644
--- a/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
@@ -19,6 +19,11 @@
 #
 #-------------------------------------------------------------
 
-X = rbind(read($1), read($2))
-[C,Y] = kmeans(X=X, k=4, runs=$3, max_iter=150)
-write(C, $4)
+if( $3 )
+  X = read($1)
+else
+  X = rbind(read($1), read($2))
+
+[C,Y] = kmeans(X=X, k=4, runs=$4, max_iter=150)
+
+write(C, $5)
diff --git a/src/test/scripts/functions/federated/FederatedL2SVMTest.dml b/src/test/scripts/functions/federated/FederatedL2SVMTest.dml
index b5a4228..7ae1a57 100644
--- a/src/test/scripts/functions/federated/FederatedL2SVMTest.dml
+++ b/src/test/scripts/functions/federated/FederatedL2SVMTest.dml
@@ -19,8 +19,17 @@
 #
 #-------------------------------------------------------------
 
-X = federated(addresses=list($in_X1, $in_X2),
-    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), 
list($rows, $cols)))
 Y = read($in_Y)
+
+if( $single ) {
+  X = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows/2, 
$cols)))
+  Y = Y[1:nrow(X),]
+}
+else {
+  X = federated(addresses=list($in_X1, $in_X2),
+    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), 
list($rows, $cols)))
+}
+
 model = l2svm(X=X,  Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1, 
maxIterations = 100)
+
 write(model, $out)
diff --git a/src/test/scripts/functions/federated/FederatedL2SVMTestReference.dml b/src/test/scripts/functions/federated/FederatedL2SVMTestReference.dml
index 0b028d8..b5439d4 100644
--- a/src/test/scripts/functions/federated/FederatedL2SVMTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedL2SVMTestReference.dml
@@ -19,7 +19,14 @@
 #
 #-------------------------------------------------------------
 
-X = rbind(read($1), read($2))
 Y = read($3)
+if( $4 ) {
+  X = read($1)
+  Y = Y[1:nrow(X),]
+}
+else
+  X = rbind(read($1), read($2))
+
 model = l2svm(X=X,  Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1, 
maxIterations = 100)
-write(model, $4)
+
+write(model, $5)

Reply via email to