This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 250c734  [SYSTEMDS-2763] Federated rowIndexMax and rowIndexMin
250c734 is described below

commit 250c7345980f69f4b7918c252af5f2ab2213c1aa
Author: Olga <[email protected]>
AuthorDate: Wed Dec 16 20:09:55 2020 +0100

    [SYSTEMDS-2763] Federated rowIndexMax and rowIndexMin
    
    This commit adds federated execution support for the functions prod,
    rowIndexMax, and rowIndexMin (tests included).
    
    Closes #1130
---
 .../org/apache/sysds/lops/PartialAggregate.java    |  11 +-
 .../controlprogram/federated/FederationUtils.java  |  62 +++++++-
 .../runtime/instructions/InstructionUtils.java     |  19 ++-
 .../cp/AggregateUnaryCPInstruction.java            |   6 +
 .../fed/AggregateUnaryFEDInstruction.java          |  14 +-
 .../instructions/fed/BinaryFEDInstruction.java     |   2 +-
 .../sysds/runtime/matrix/data/LibMatrixAgg.java    |   3 +
 .../federated/primitives/FederatedProdTest.java    | 153 ++++++++++++++++++++
 .../primitives/FederatedRowIndexTest.java          | 156 +++++++++++++++++++++
 .../functions/federated/FederatedProdTest.dml      |  33 +++++
 .../federated/FederatedProdTestReference.dml       |  26 ++++
 .../functions/federated/FederatedRowIndexTest.dml  |  33 +++++
 .../federated/FederatedRowIndexTestReference.dml   |  26 ++++
 13 files changed, 530 insertions(+), 14 deletions(-)

diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java 
b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index bfec9ff..c28a9d5 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -233,8 +233,15 @@ public class PartialAggregate extends Lop
                sb.append( OPERAND_DELIMITOR );
                if( getExecType() == ExecType.SPARK )
                        sb.append( _aggtype );
-               else if( getExecType() == ExecType.CP )
-                       sb.append( _numThreads );       
+               else if( getExecType() == ExecType.CP ) {
+                       sb.append(_numThreads);
+
+                       //number of outputs, valid for fed instruction
+                       if(getOpcode().equalsIgnoreCase("uarimin") || 
getOpcode().equalsIgnoreCase("uarimax")) {
+                               sb.append(OPERAND_DELIMITOR);
+                               sb.append("1");
+                       }
+               }
                
                return sb.toString();
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 881991a..31a7136 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -38,12 +38,16 @@ import 
org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysds.runtime.functionobjects.CM;
 import org.apache.sysds.runtime.functionobjects.KahanFunction;
 import org.apache.sysds.runtime.functionobjects.Mean;
+import org.apache.sysds.runtime.functionobjects.Multiply;
 import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.functionobjects.ReduceAll;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.matrix.data.LibMatrixAgg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
@@ -189,6 +193,30 @@ public class FederationUtils {
                }
        }
 
+       public static MatrixBlock aggMinMaxIndex(Future<FederatedResponse>[] 
ffr, boolean isMin, FederationMap map) {
+               try {
+                       MatrixBlock prev = (MatrixBlock) 
ffr[0].get().getData()[0];
+                       int size = 0;
+                       for(int i = 1; i < ffr.length; i++) {
+                               MatrixBlock next = (MatrixBlock) 
ffr[i].get().getData()[0];
+                               size = 
map.getFederatedRanges()[i-1].getEndDimsInt()[1];
+                               for(int j = 0; j < prev.getNumRows(); j++) {
+                                       next.setValue(j, 0, next.getValue(j, 0) 
+ size);
+                                       if((prev.getValue(j, 1) > 
next.getValue(j, 1) && !isMin) ||
+                                               (prev.getValue(j, 1) < 
next.getValue(j, 1) && isMin)) {
+                                               next.setValue(j, 0, 
prev.getValue(j, 0));
+                                               next.setValue(j, 1, 
prev.getValue(j, 1));
+                                       }
+                               }
+                               prev = next;
+                       }
+                       return prev.slice(0, prev.getNumRows()-1, 0,0, true, 
new MatrixBlock());
+               }
+               catch (Exception ex) {
+                       throw new DMLRuntimeException(ex);
+               }
+       }
+
        public static MatrixBlock aggVar(Future<FederatedResponse>[] ffr, 
Future<FederatedResponse>[] meanFfr, FederationMap map, boolean isRowAggregate, 
boolean isScalar) {
                try {
                        FederatedRange[] ranges = map.getFederatedRanges();
@@ -325,13 +353,24 @@ public class FederationUtils {
                if(!(aop.aggOp.increOp.fn instanceof KahanFunction || 
(aop.aggOp.increOp.fn instanceof Builtin &&
                        (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == 
BuiltinCode.MIN
                        || ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == 
BuiltinCode.MAX)
-                       || aop.aggOp.increOp.fn instanceof Mean ))) {
+                       || aop.aggOp.increOp.fn instanceof Mean
+                       || aop.aggOp.increOp.fn instanceof Multiply))) {
                        throw new DMLRuntimeException("Unsupported aggregation 
operator: "
                                + aop.aggOp.increOp.getClass().getSimpleName());
                }
 
                try {
-                       if(aop.aggOp.increOp.fn instanceof Builtin){
+                       if(aop.aggOp.increOp.fn instanceof Multiply){
+                               MatrixBlock ret = new MatrixBlock(ffr.length, 
1, false);
+                               MatrixBlock res = new MatrixBlock(0);
+                               for(int i = 0; i < ffr.length; i++)
+                                       ret.setValue(i, 0, 
((ScalarObject)ffr[i].get().getData()[0]).getDoubleValue());
+                               LibMatrixAgg.aggregateUnaryMatrix(ret, res,
+                                       new AggregateUnaryOperator(new 
AggregateOperator(1, Multiply.getMultiplyFnObject()),
+                                               
ReduceAll.getReduceAllFnObject()));
+                               return new DoubleObject(res.quickGetValue(0, 
0));
+                       }
+                       else if(aop.aggOp.increOp.fn instanceof Builtin){
                                // then we know it is a Min or Max based on the 
previous check.
                                boolean isMin = ((Builtin) 
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN;
                                return new DoubleObject(aggMinMax(ffr, isMin, 
true,  Optional.empty()).getValue(0,0));
@@ -361,12 +400,21 @@ public class FederationUtils {
                        return aggAdd(ffr);
                else if( aop.aggOp.increOp.fn instanceof Mean )
                        return aggMean(ffr, map);
-               else if (aop.aggOp.increOp.fn instanceof Builtin &&
-                       (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == 
BuiltinCode.MIN ||
+               else if (aop.aggOp.increOp.fn instanceof Builtin) {
+                       if ((((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() 
== BuiltinCode.MIN ||
                                ((Builtin) 
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)) {
-                       boolean isMin = ((Builtin) 
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN;
-                       return aggMinMax(ffr,isMin,false, 
Optional.of(map.getType()));
-               } else
+                               boolean isMin = ((Builtin) 
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN;
+                               return aggMinMax(ffr,isMin,false, 
Optional.of(map.getType()));
+                       }
+                       else if((((Builtin) 
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MININDEX)
+                               || (((Builtin) 
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAXINDEX)) {
+                               boolean isMin = ((Builtin) 
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MININDEX;
+                               return aggMinMaxIndex(ffr,isMin, map);
+                       }
+                       else throw new DMLRuntimeException("Unsupported 
aggregation operator: "
+                                       + 
aop.aggOp.increOp.fn.getClass().getSimpleName());
+               }
+               else
                        throw new DMLRuntimeException("Unsupported aggregation 
operator: "
                                + 
aop.aggOp.increOp.fn.getClass().getSimpleName());
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index fa4ea24..49c3452 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -376,7 +376,7 @@ public class InstructionUtils
                else if ( opcode.equalsIgnoreCase("uarmax") ) {
                        AggregateOperator agg = new 
AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject("max"));
                        aggun = new AggregateUnaryOperator(agg, 
ReduceCol.getReduceColFnObject(), numThreads);
-               } 
+               }
                else if (opcode.equalsIgnoreCase("uarimax") ) {
                        AggregateOperator agg = new 
AggregateOperator(Double.NEGATIVE_INFINITY, 
Builtin.getBuiltinFnObject("maxindex"), CorrectionLocationType.LASTCOLUMN);
                        aggun = new AggregateUnaryOperator(agg, 
ReduceCol.getReduceColFnObject(), numThreads);
@@ -384,7 +384,7 @@ public class InstructionUtils
                else if ( opcode.equalsIgnoreCase("uarmin") ) {
                        AggregateOperator agg = new 
AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject("min"));
                        aggun = new AggregateUnaryOperator(agg, 
ReduceCol.getReduceColFnObject(), numThreads);
-               } 
+               }
                else if (opcode.equalsIgnoreCase("uarimin") ) {
                        AggregateOperator agg = new 
AggregateOperator(Double.POSITIVE_INFINITY, 
Builtin.getBuiltinFnObject("minindex"), CorrectionLocationType.LASTCOLUMN);
                        aggun = new AggregateUnaryOperator(agg, 
ReduceCol.getReduceColFnObject(), numThreads);
@@ -401,6 +401,21 @@ public class InstructionUtils
                return aggun;
        }
 
+       public static AggregateUnaryOperator 
parseAggregateUnaryRowIndexOperator(String opcode, int numOutputs, int 
numThreads) {
+               AggregateUnaryOperator aggun = null;
+               AggregateOperator agg = null;
+               if (opcode.equalsIgnoreCase("uarimax") )
+                       agg = new AggregateOperator(Double.NEGATIVE_INFINITY, 
Builtin.getBuiltinFnObject("maxindex"),
+                               numOutputs == 1 ? 
CorrectionLocationType.LASTCOLUMN : CorrectionLocationType.NONE);
+
+               else if (opcode.equalsIgnoreCase("uarimin") )
+                       agg = new AggregateOperator(Double.POSITIVE_INFINITY, 
Builtin.getBuiltinFnObject("minindex"),
+                               numOutputs == 1 ? 
CorrectionLocationType.LASTCOLUMN : CorrectionLocationType.NONE);
+
+               aggun = new AggregateUnaryOperator(agg, 
ReduceCol.getReduceColFnObject(), numThreads);
+               return aggun;
+       }
+
        public static AggregateTernaryOperator 
parseAggregateTernaryOperator(String opcode) {
                return parseAggregateTernaryOperator(opcode, 1);
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
index e6e74fb..ef1ff08 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
@@ -85,6 +85,12 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
                        return new AggregateUnaryCPInstruction(new 
SimpleOperator(null),
                        in1, out, AUType.COUNT_DISTINCT_APPROX, opcode, str);
                }
+               else if(opcode.equalsIgnoreCase("uarimax") || 
opcode.equalsIgnoreCase("uarimin")){
+                       // parse with number of outputs
+                       AggregateUnaryOperator aggun = InstructionUtils
+                               .parseAggregateUnaryRowIndexOperator(opcode, 
Integer.parseInt(parts[4]), Integer.parseInt(parts[3]));
+                       return new AggregateUnaryCPInstruction(aggun, in1, out, 
AUType.DEFAULT, opcode, str);
+               }
                else { //DEFAULT BEHAVIOR
                        AggregateUnaryOperator aggun = InstructionUtils
                                .parseBasicAggregateUnaryOperator(opcode, 
Integer.parseInt(parts[3]));
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index b9f220b..4fbe4e6 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.fed;
 
 import java.util.concurrent.Future;
 
+import org.apache.sysds.lops.Lop;
 import org.apache.sysds.lops.LopProperties.ExecType;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -56,7 +57,13 @@ public class AggregateUnaryFEDInstruction extends 
UnaryFEDInstruction {
                String opcode = parts[0];
                CPOperand in1 = new CPOperand(parts[1]);
                CPOperand out = new CPOperand(parts[2]);
-               AggregateUnaryOperator aggun = 
InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
+
+               AggregateUnaryOperator aggun = null;
+               if(opcode.equalsIgnoreCase("uarimax") || 
opcode.equalsIgnoreCase("uarimin"))
+                       aggun = 
InstructionUtils.parseAggregateUnaryRowIndexOperator(opcode, 
Integer.parseInt(parts[4]), 1);
+               else
+                       aggun = 
InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
+
                if(InstructionUtils.getExecType(str) == ExecType.SPARK)
                        str = InstructionUtils.replaceOperand(str, 4, "-1");
                return new AggregateUnaryFEDInstruction(aggun, in1, out, 
opcode, str);
@@ -76,7 +83,10 @@ public class AggregateUnaryFEDInstruction extends 
UnaryFEDInstruction {
                AggregateUnaryOperator aop = (AggregateUnaryOperator) _optr;
                MatrixObject in = ec.getMatrixObject(input1);
                FederationMap map = in.getFedMapping();
-               
+
+               if((instOpcode.equalsIgnoreCase("uarimax") || 
instOpcode.equalsIgnoreCase("uarimin")) && 
in.isFederated(FederationMap.FType.COL))
+                       instString = 
InstructionUtils.replaceOperand(instString, 5, "2");
+
                //create federated commands for aggregation
                FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
                        new CPOperand[]{input1}, new 
long[]{in.getFedMapping().getID()});
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index f1f8f38..ffaf2af 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -58,7 +58,7 @@ public abstract class BinaryFEDInstruction extends 
ComputationFEDInstruction {
                else
                        throw new DMLRuntimeException("Federated binary 
operations not yet supported:" + opcode);
        }
-       
+
        protected static void checkOutputDataType(CPOperand in1, CPOperand in2, 
CPOperand out) {
                // check for valid data type of output
                if( (in1.getDataType() == DataType.MATRIX || in2.getDataType() 
== DataType.MATRIX) && out.getDataType() != DataType.MATRIX )
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
index ce16369..743dce9 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
@@ -231,6 +231,9 @@ public class LibMatrixAgg
        public static void aggregateUnaryMatrix(MatrixBlock in, MatrixBlock 
out, AggregateUnaryOperator uaop, int k) {
                //fall back to sequential version if necessary
                if( !satisfiesMultiThreadingConstraints(in, out, uaop, k) ) {
+                       if(uaop.aggOp.increOp.fn instanceof Builtin && 
(((((Builtin) uaop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MININDEX)
+                               || (((Builtin) 
uaop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAXINDEX)) && 
uaop.aggOp.correction.getNumRemovedRowsColumns()==0))
+                                       out.clen = 2;
                        aggregateUnaryMatrix(in, out, uaop);
                        return;
                }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedProdTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedProdTest.java
new file mode 100644
index 0000000..70859d4
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedProdTest.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
[email protected]
+public class FederatedProdTest extends AutomatedTestBase {
+
+       private final static String TEST_NAME = "FederatedProdTest";
+
+       private final static String TEST_DIR = "functions/federated/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedProdTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][] {
+                       {100, 12, true},
+                       {100, 12, false}
+               });
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S.scalar"}));
+       }
+
+       @Test
+       public void testProdCP() { runProdTest(ExecMode.SINGLE_NODE); }
+
+       private void runProdTest(ExecMode execMode) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               ExecMode platformOld = rtplatform;
+
+               if(rtplatform == ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               int r = rows;
+               int c = cols / 4;
+               if(rowPartitioned) {
+                       r = rows / 4;
+                       c = cols;
+               }
+
+               double[][] X1 = getRandomMatrix(r, c, 0, 2, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 0, 2, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 0, 2, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 0, 2, 1, 9);
+
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, c, 
blocksize, r * c);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+               Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+               Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+               Thread t4 = startLocalFedWorkerThread(port4);
+
+               rtplatform = execMode;
+               if(rtplatform == ExecMode.SPARK) {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+               TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-stats", "100", "-args", 
input("X1"), input("X2"), input("X3"), input("X4"),
+                       Boolean.toString(rowPartitioned).toUpperCase(), 
expected("S")};
+
+               runTest(null);
+
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats", "100", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                       "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
+                       "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")), "rows=" + rows, "cols=" + cols,
+                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), 
"out_S=" + output("S")};
+
+               runTest(null);
+
+               // compare via files
+               compareResults(1e-9);
+
+               Assert.assertTrue(heavyHittersContainsString("fed_ua*"));
+
+               // check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+               TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+               rtplatform = platformOld;
+               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowIndexTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowIndexTest.java
new file mode 100644
index 0000000..e4c7534
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowIndexTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
[email protected]
+public class FederatedRowIndexTest extends AutomatedTestBase {
+
+       private final static String TEST_NAME = "FederatedRowIndexTest";
+
+       private final static String TEST_DIR = "functions/federated/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedRowIndexTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][] {
+                       {1000, 12, true},
+                       {1000, 12, false}
+               });
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+       }
+
+       @Test
+       public void testRowIndexCP() {
+               runRowIndexTest(ExecMode.SINGLE_NODE);
+       }
+
+       private void runRowIndexTest(ExecMode execMode) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               ExecMode platformOld = rtplatform;
+
+               if(rtplatform == ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               int r = rows;
+               int c = cols / 4;
+               if(rowPartitioned) {
+                       r = rows / 4;
+                       c = cols;
+               }
+
+               double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               // empty script name because we don't execute any script, just start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+               Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+               Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+               Thread t4 = startLocalFedWorkerThread(port4);
+
+               rtplatform = execMode;
+               // the requested exec mode may be SPARK; use the local Spark config then
+               if(rtplatform == ExecMode.SPARK) {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+               TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+                       Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
+
+               runTest(null);
+
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats", "100", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                       "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+                       "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+
+               runTest(null);
+
+               // compare via files
+               compareResults(1e-9);
+
+               Assert.assertTrue(heavyHittersContainsString("fed_uarimax"));
+
+               // check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+               TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+               rtplatform = platformOld;
+               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+
+       }
+}
diff --git a/src/test/scripts/functions/federated/FederatedProdTest.dml 
b/src/test/scripts/functions/federated/FederatedProdTest.dml
new file mode 100644
index 0000000..ead8936
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedProdTest.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if ($rP) {  # row-partitioned: each of the four workers holds rows/4 complete rows of A
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {  # column-partitioned: each of the four workers holds cols/4 complete columns of A
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+
+s = prod(A);  # product of all cells of the federated matrix A
+write(s, $out_S);  # $out_S: output path passed via -nvargs by the Java test driver
diff --git a/src/test/scripts/functions/federated/FederatedProdTestReference.dml b/src/test/scripts/functions/federated/FederatedProdTestReference.dml
new file mode 100644
index 0000000..32293cf
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedProdTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if($5) { A = rbind(read($1), read($2), read($3), read($4)); }  # $5 TRUE: row-partitioned parts, stack them vertically
+else { A = cbind(read($1), read($2), read($3), read($4)); }  # $5 FALSE: column-partitioned parts, concatenate them horizontally
+
+s = prod(A);  # local (non-federated) reference result for the federated prod test
+write(s, $6);  # $6: expected-output path supplied by the Java test driver
diff --git a/src/test/scripts/functions/federated/FederatedRowIndexTest.dml b/src/test/scripts/functions/federated/FederatedRowIndexTest.dml
new file mode 100644
index 0000000..bc80bfc
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRowIndexTest.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if ($rP) {  # row-partitioned: each of the four workers holds rows/4 complete rows of A
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {  # column-partitioned: each of the four workers holds cols/4 complete columns of A
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+
+s = rowIndexMax(A);  # per-row column index of the maximum; the Java test expects a fed_uarimax instruction
+write(s, $out_S);  # $out_S: output path passed via -nvargs by the Java test driver
diff --git a/src/test/scripts/functions/federated/FederatedRowIndexTestReference.dml b/src/test/scripts/functions/federated/FederatedRowIndexTestReference.dml
new file mode 100644
index 0000000..a430bf1
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRowIndexTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if($5) { A = rbind(read($1), read($2), read($3), read($4)); }  # $5 TRUE: row-partitioned parts, stack them vertically
+else { A = cbind(read($1), read($2), read($3), read($4)); }  # $5 FALSE: column-partitioned parts, concatenate them horizontally
+
+s = rowIndexMax(A);  # local (non-federated) reference result for the federated rowIndexMax test
+write(s, $6);  # $6: expected-output path supplied by the Java test driver

Reply via email to