This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new c14198b  [SYSTEMDS-2542] Federated rowProds, colProds
c14198b is described below

commit c14198b75dcee7c77b4f68ca88807cb00d7ffe97
Author: Olga <[email protected]>
AuthorDate: Mon Apr 12 18:45:54 2021 +0200

    [SYSTEMDS-2542] Federated rowProds, colProds
    
    - Added fed col and row prod
    - newlines and formatting
    
    Closes #1248
---
 .../controlprogram/federated/FederationUtils.java  | 30 ++++++++++++++++++-
 .../primitives/FederatedColAggregateTest.java      | 35 +++++++++++++++-------
 .../primitives/FederatedRowAggregateTest.java      | 23 ++++++++++----
 ...FederatedSumTest.dml => FederatedColProdTest.R} | 24 +++++++--------
 ...atedRowVarTest.dml => FederatedColProdTest.dml} |  4 +--
 ...mTest.dml => FederatedColProdTestReference.dml} | 16 +++-------
 .../federated/aggregate/FederatedColVarTest.dml    |  2 +-
 ...atedRowVarTest.dml => FederatedRowProdTest.dml} |  4 +--
 ...mTest.dml => FederatedRowProdTestReference.dml} | 16 +++-------
 .../federated/aggregate/FederatedRowVarTest.dml    |  2 +-
 .../federated/aggregate/FederatedSumTest.dml       |  2 +-
 11 files changed, 97 insertions(+), 61 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index f569364..94fe0bd 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -203,6 +203,32 @@ public class FederationUtils {
                }
        }
 
+       /**
+        * Combines the per-worker partial products into the final row or column product.
+        *
+        * <p>Each worker response is expected to hold a MatrixBlock at {@code getData()[0]}.
+        * For a row-partitioned map ({@code FType.ROW}) the partials are stacked as the rows
+        * of an {@code ffr.length x cols} staging block; otherwise as the columns of a
+        * {@code rows x ffr.length} staging block. {@code aop} is then applied locally to that
+        * block, producing a {@code 1 x cols} (ROW) or {@code rows x 1} (otherwise) result.
+        *
+        * <p>NOTE(review): assumes every worker returns exactly one row (ROW) or one column
+        * (otherwise) spanning the full opposite dimension, i.e. the product runs across the
+        * partitions (colProds on row-partitioned or rowProds on column-partitioned data).
+        * A partial of any other shape (e.g. a 1x1 full aggregate) would not fit the copy
+        * layout below -- TODO confirm that callers only route such cases here.
+        *
+        * @param ffr    futures of the per-worker partial results
+        * @param fedMap federation map; only its type and its first range's end dims are read
+        * @param aop    aggregate operator applied to the stacked partial results
+        * @return the aggregated product block
+        * @throws DMLRuntimeException wrapping any exception raised while collecting the
+        *         partial results or aggregating them
+        */
+       public static MatrixBlock aggProd(Future<FederatedResponse>[] ffr, 
FederationMap fedMap, AggregateUnaryOperator aop) {
+               try {
+                       // ROW: partials become rows of the staging block; otherwise columns
+                       boolean rowFed = fedMap.getType() == 
FederationMap.FType.ROW;
+                       // Staging block with one slot (row or column) per worker. The opposite
+                       // dimension is taken from the first federated range's end dims, presumably
+                       // the full matrix extent for this partitioning -- NOTE(review): confirm.
+                       MatrixBlock ret = rowFed ?
+                               new MatrixBlock(ffr.length, (int) 
fedMap.getFederatedRanges()[0].getEndDims()[1], 1.0) :
+                               new MatrixBlock((int) 
fedMap.getFederatedRanges()[0].getEndDims()[0], ffr.length, 1.0);
+                       // Output block: a single row (ROW) or a single column (otherwise).
+                       // Both blocks are constructed with 1.0 -- presumably a constant fill with
+                       // the multiplicative identity.
+                       MatrixBlock res = rowFed ?
+                               new MatrixBlock(1, (int) 
fedMap.getFederatedRanges()[0].getEndDims()[1], 1.0) :
+                               new MatrixBlock((int) 
fedMap.getFederatedRanges()[0].getEndDims()[0], 1, 1.0);
+
+                       // Copy worker i's partial result into slot i (row i or column i).
+                       for(int i = 0; i < ffr.length; i++) {
+                               MatrixBlock tmp = (MatrixBlock) 
ffr[i].get().getData()[0];
+                               if(rowFed)
+                                       ret.copy(i, i, 0, 
ret.getNumColumns()-1, tmp, true);
+                               else
+                                       ret.copy(0, ret.getNumRows()-1, i, i, 
tmp, true);
+                       }
+
+                       // Apply aop locally across the stacked partials to get the final product.
+                       LibMatrixAgg.aggregateUnaryMatrix(ret, res, aop);
+                       return res;
+               }
+               catch (Exception ex) {
+                       // NOTE(review): Future.get() can throw InterruptedException; it is wrapped
+                       // here without restoring the thread's interrupt flag.
+                       throw new DMLRuntimeException(ex);
+               }
+       }
+
        public static MatrixBlock aggMinMaxIndex(Future<FederatedResponse>[] 
ffr, boolean isMin, FederationMap map) {
                try {
                        MatrixBlock prev = (MatrixBlock) 
ffr[0].get().getData()[0];
@@ -410,6 +436,8 @@ public class FederationUtils {
                        return aggAdd(ffr);
                else if( aop.aggOp.increOp.fn instanceof Mean )
                        return aggMean(ffr, map);
+               else if(aop.aggOp.increOp.fn instanceof Multiply)
+                       return aggProd(ffr, map, aop);
                else if (aop.aggOp.increOp.fn instanceof Builtin) {
                        if ((((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() 
== BuiltinCode.MIN ||
                                ((Builtin) 
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)) {
@@ -419,7 +447,7 @@ public class FederationUtils {
                        else if((((Builtin) 
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MININDEX)
                                || (((Builtin) 
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAXINDEX)) {
                                boolean isMin = ((Builtin) 
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MININDEX;
-                               return aggMinMaxIndex(ffr,isMin, map);
+                               return aggMinMaxIndex(ffr, isMin, map);
                        }
                        else throw new DMLRuntimeException("Unsupported 
aggregation operator: "
                                        + 
aop.aggOp.increOp.fn.getClass().getSimpleName());
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java
index 1bdcb5b..870e7c2 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java
@@ -41,6 +41,7 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
        private final static String TEST_NAME2 = "FederatedColMeanTest";
        private final static String TEST_NAME3 = "FederatedColMaxTest";
        private final static String TEST_NAME4 = "FederatedColMinTest";
+       private final static String TEST_NAME5 = "FederatedColProdTest";
        private final static String TEST_NAME10 = "FederatedColVarTest";
 
        private final static String TEST_DIR = "functions/federated/aggregate/";
@@ -58,13 +59,13 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
        public static Collection<Object[]> data() {
                return Arrays.asList(
                        new Object[][] {
-                               {10, 1000, false},
+//                             {10, 1000, false},
                                {1000, 40, true},
                });
        }
 
        private enum OpType {
-               SUM, MEAN, MAX, MIN, VAR
+               SUM, MEAN, MAX, MIN, VAR, PROD
        }
 
        @Override
@@ -75,6 +76,7 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
                addTestConfiguration(TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"}));
                addTestConfiguration(TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S"}));
                addTestConfiguration(TEST_NAME10, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME10, new String[] {"S"}));
+               addTestConfiguration(TEST_NAME5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S"}));
        }
 
        @Test
@@ -103,6 +105,11 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
                runAggregateOperationTest(OpType.VAR, ExecMode.SINGLE_NODE);
        }
 
+       @Test
+       public void testColProdDenseMatrixCP() {
+               runAggregateOperationTest(OpType.PROD, ExecMode.SINGLE_NODE);
+       }
+
        private void runAggregateOperationTest(OpType type, ExecMode execMode) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                ExecMode platformOld = rtplatform;
@@ -127,6 +134,9 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
                        case VAR:
                                TEST_NAME = TEST_NAME10;
                                break;
+                       case PROD:
+                               TEST_NAME = TEST_NAME5;
+                               break;
                }
 
                getAndLoadTestConfiguration(TEST_NAME);
@@ -140,16 +150,16 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
                        c = cols;
                }
 
-               double[][] X1 = getRandomMatrix(r, c, 1, 3, 1, 3);
-               double[][] X2 = getRandomMatrix(r, c, 1, 3, 1, 7);
-               double[][] X3 = getRandomMatrix(r, c, 1, 3, 1, 8);
-               double[][] X4 = getRandomMatrix(r, c, 1, 3, 1, 9);
+               double[][] X1 = getRandomMatrix(r, c, 3, 3, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 3, 3, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 3, 3, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 3, 3, 1, 9);
 
                MatrixCharacteristics mc = new MatrixCharacteristics(r, c, 
blocksize, r * c);
-               writeInputMatrixWithMTD("X1", X1, false, mc);
-               writeInputMatrixWithMTD("X2", X2, false, mc);
-               writeInputMatrixWithMTD("X3", X3, false, mc);
-               writeInputMatrixWithMTD("X4", X4, false, mc);
+               writeInputMatrixWithMTD("X1", X1, true, mc);
+               writeInputMatrixWithMTD("X2", X2, true, mc);
+               writeInputMatrixWithMTD("X3", X3, true, mc);
+               writeInputMatrixWithMTD("X4", X4, true, mc);
 
                // empty script name because we don't execute any script, just 
start the worker
                fullDMLScriptName = "";
@@ -189,7 +199,7 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
                runTest(true, false, null, -1);
 
                // compare via files
-               compareResults(type == FederatedColAggregateTest.OpType.VAR ? 
1e-2 : 1e-9);
+               compareResults((type == FederatedColAggregateTest.OpType.VAR) 
|| (type == OpType.PROD) ? 1e-2 : 1e-9);
 
                String fedInst = "fed_uac";
 
@@ -209,6 +219,9 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
                        case VAR:
                                
Assert.assertTrue(heavyHittersContainsString(fedInst.concat("var")));
                                break;
+                       case PROD:
+                               
Assert.assertTrue(heavyHittersContainsString(fedInst.concat("*")));
+                               break;
                }
 
                // check that federated input files are still existing
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
index e0a3632..c140dc8 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
@@ -42,6 +42,7 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
        private final static String TEST_NAME7 = "FederatedRowMaxTest";
        private final static String TEST_NAME8 = "FederatedRowMinTest";
        private final static String TEST_NAME9 = "FederatedRowVarTest";
+       private final static String TEST_NAME10 = "FederatedRowProdTest";
 
        private final static String TEST_DIR = "functions/federated/aggregate/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedRowAggregateTest.class.getSimpleName() + "/";
@@ -64,7 +65,7 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
        }
 
        private enum OpType {
-               SUM, MEAN, MAX, MIN, VAR
+               SUM, MEAN, MAX, MIN, VAR, PROD
        }
 
        @Override
@@ -75,6 +76,7 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
                addTestConfiguration(TEST_NAME7, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {"S"}));
                addTestConfiguration(TEST_NAME8, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME8, new String[] {"S"}));
                addTestConfiguration(TEST_NAME9, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME9, new String[] {"S"}));
+               addTestConfiguration(TEST_NAME10, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME10, new String[] {"S"}));
        }
 
        @Test
@@ -102,6 +104,11 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
                runAggregateOperationTest(OpType.VAR, ExecMode.SINGLE_NODE);
        }
 
+       @Test
+       public void testRowProdDenseMatrixCP() {
+               runAggregateOperationTest(OpType.PROD, ExecMode.SINGLE_NODE);
+       }
+
        private void runAggregateOperationTest(OpType type, ExecMode execMode) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                ExecMode platformOld = rtplatform;
@@ -126,6 +133,9 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
                        case VAR:
                                TEST_NAME = TEST_NAME9;
                                break;
+                       case PROD:
+                               TEST_NAME = TEST_NAME10;
+                               break;
                }
 
                getAndLoadTestConfiguration(TEST_NAME);
@@ -139,10 +149,10 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
                        c = cols;
                }
 
-               double[][] X1 = getRandomMatrix(r, c, 1, 3, 1, 3);
-               double[][] X2 = getRandomMatrix(r, c, 1, 3, 1, 7);
-               double[][] X3 = getRandomMatrix(r, c, 1, 3, 1, 8);
-               double[][] X4 = getRandomMatrix(r, c, 1, 3, 1, 9);
+               double[][] X1 = getRandomMatrix(r, c, 3, 3, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 3, 3, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 3, 3, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 3, 3, 1, 9);
 
                MatrixCharacteristics mc = new MatrixCharacteristics(r, c, 
blocksize, r * c);
                writeInputMatrixWithMTD("X1", X1, false, mc);
@@ -208,6 +218,9 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
                        case VAR:
                                
Assert.assertTrue(heavyHittersContainsString(fedInst.concat("var")));
                                break;
+                       case PROD:
+                               
Assert.assertTrue(heavyHittersContainsString(fedInst.concat("*")));
+                               break;
                }
 
                // check that federated input files are still existing
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedColProdTest.R
similarity index 61%
copy from src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
copy to src/test/scripts/functions/federated/aggregate/FederatedColProdTest.R
index 9de439e..95ef9e2 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedColProdTest.R
@@ -18,17 +18,15 @@
 # under the License.
 #
 #-------------------------------------------------------------
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
 
-if ($rP) {
-    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
-        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
-               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
-} else {
-    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
-            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
-               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
-}
-
-s = sum(A);
-
-write(s, $out_S);
\ No newline at end of file
+X1 = as.matrix(readMM(paste(args[1], "X1.mtx", sep="")));
+X2 = as.matrix(readMM(paste(args[1], "X2.mtx", sep="")));
+X3 = as.matrix(readMM(paste(args[1], "X3.mtx", sep="")));
+X4 = as.matrix(readMM(paste(args[1], "X4.mtx", sep="")));
+X = rbind(X1, X2, X3, X4)
+R = colProds(X)
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedColProdTest.dml
similarity index 97%
copy from src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
copy to src/test/scripts/functions/federated/aggregate/FederatedColProdTest.dml
index 8b4a57d..ae90cd7 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedColProdTest.dml
@@ -30,5 +30,5 @@ if ($rP) {
                list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
 }
 
-s = rowVars(A);
-write(s, $out_S);
\ No newline at end of file
+s = colProds(A);
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedColProdTestReference.dml
similarity index 61%
copy from src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
copy to src/test/scripts/functions/federated/aggregate/FederatedColProdTestReference.dml
index 9de439e..6fa8d53 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedColProdTestReference.dml
@@ -19,16 +19,8 @@
 #
 #-------------------------------------------------------------
 
-if ($rP) {
-    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
-        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
-               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
-} else {
-    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
-            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
-               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
-}
+if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4)); }
 
-s = sum(A);
-
-write(s, $out_S);
\ No newline at end of file
+s = colProds(A);
+write(s, $5);
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedColVarTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedColVarTest.dml
index 186dc1d..a7ac87f 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedColVarTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedColVarTest.dml
@@ -31,4 +31,4 @@ if ($rP) {
 }
 
 s = colVars(A);
-write(s, $out_S);
\ No newline at end of file
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedRowProdTest.dml
similarity index 97%
copy from src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
copy to src/test/scripts/functions/federated/aggregate/FederatedRowProdTest.dml
index 8b4a57d..9d5f11d 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedRowProdTest.dml
@@ -30,5 +30,5 @@ if ($rP) {
                list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
 }
 
-s = rowVars(A);
-write(s, $out_S);
\ No newline at end of file
+s = rowProds(A);
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedRowProdTestReference.dml
similarity index 61%
copy from src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
copy to src/test/scripts/functions/federated/aggregate/FederatedRowProdTestReference.dml
index 9de439e..b917d13 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedRowProdTestReference.dml
@@ -19,16 +19,8 @@
 #
 #-------------------------------------------------------------
 
-if ($rP) {
-    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
-        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
-               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
-} else {
-    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
-            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
-               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
-}
+if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4)); }
 
-s = sum(A);
-
-write(s, $out_S);
\ No newline at end of file
+s = rowProds(A);
+write(s, $5);
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
index 8b4a57d..bec43a2 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
@@ -31,4 +31,4 @@ if ($rP) {
 }
 
 s = rowVars(A);
-write(s, $out_S);
\ No newline at end of file
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
index 9de439e..72a7cd6 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
@@ -31,4 +31,4 @@ if ($rP) {
 
 s = sum(A);
 
-write(s, $out_S);
\ No newline at end of file
+write(s, $out_S);

Reply via email to