This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 250c734 [SYSTEMDS-2763] Federated rowIndexMax and rowIndexMin
250c734 is described below
commit 250c7345980f69f4b7918c252af5f2ab2213c1aa
Author: Olga <[email protected]>
AuthorDate: Wed Dec 16 20:09:55 2020 +0100
[SYSTEMDS-2763] Federated rowIndexMax and rowIndexMin
This commit adds the functions prod, rowIndexMax, and rowIndexMin for federated execution.
(also included tests).
Closes #1130
---
.../org/apache/sysds/lops/PartialAggregate.java | 11 +-
.../controlprogram/federated/FederationUtils.java | 62 +++++++-
.../runtime/instructions/InstructionUtils.java | 19 ++-
.../cp/AggregateUnaryCPInstruction.java | 6 +
.../fed/AggregateUnaryFEDInstruction.java | 14 +-
.../instructions/fed/BinaryFEDInstruction.java | 2 +-
.../sysds/runtime/matrix/data/LibMatrixAgg.java | 3 +
.../federated/primitives/FederatedProdTest.java | 153 ++++++++++++++++++++
.../primitives/FederatedRowIndexTest.java | 156 +++++++++++++++++++++
.../functions/federated/FederatedProdTest.dml | 33 +++++
.../federated/FederatedProdTestReference.dml | 26 ++++
.../functions/federated/FederatedRowIndexTest.dml | 33 +++++
.../federated/FederatedRowIndexTestReference.dml | 26 ++++
13 files changed, 530 insertions(+), 14 deletions(-)
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index bfec9ff..c28a9d5 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -233,8 +233,15 @@ public class PartialAggregate extends Lop
sb.append( OPERAND_DELIMITOR );
if( getExecType() == ExecType.SPARK )
sb.append( _aggtype );
- else if( getExecType() == ExecType.CP )
- sb.append( _numThreads );
+ else if( getExecType() == ExecType.CP ) {
+ sb.append(_numThreads);
+
+ //number of outputs, valid for fed instruction
+ if(getOpcode().equalsIgnoreCase("uarimin") ||
getOpcode().equalsIgnoreCase("uarimax")) {
+ sb.append(OPERAND_DELIMITOR);
+ sb.append("1");
+ }
+ }
return sb.toString();
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 881991a..31a7136 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -38,12 +38,16 @@ import
org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.functionobjects.CM;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
import org.apache.sysds.runtime.functionobjects.Mean;
+import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.matrix.data.LibMatrixAgg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
@@ -189,6 +193,30 @@ public class FederationUtils {
}
}
+ public static MatrixBlock aggMinMaxIndex(Future<FederatedResponse>[]
ffr, boolean isMin, FederationMap map) {
+ try {
+ MatrixBlock prev = (MatrixBlock)
ffr[0].get().getData()[0];
+ int size = 0;
+ for(int i = 1; i < ffr.length; i++) {
+ MatrixBlock next = (MatrixBlock)
ffr[i].get().getData()[0];
+ size =
map.getFederatedRanges()[i-1].getEndDimsInt()[1];
+ for(int j = 0; j < prev.getNumRows(); j++) {
+ next.setValue(j, 0, next.getValue(j, 0)
+ size);
+ if((prev.getValue(j, 1) >
next.getValue(j, 1) && !isMin) ||
+ (prev.getValue(j, 1) <
next.getValue(j, 1) && isMin)) {
+ next.setValue(j, 0,
prev.getValue(j, 0));
+ next.setValue(j, 1,
prev.getValue(j, 1));
+ }
+ }
+ prev = next;
+ }
+ return prev.slice(0, prev.getNumRows()-1, 0,0, true,
new MatrixBlock());
+ }
+ catch (Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+ }
+
public static MatrixBlock aggVar(Future<FederatedResponse>[] ffr,
Future<FederatedResponse>[] meanFfr, FederationMap map, boolean isRowAggregate,
boolean isScalar) {
try {
FederatedRange[] ranges = map.getFederatedRanges();
@@ -325,13 +353,24 @@ public class FederationUtils {
if(!(aop.aggOp.increOp.fn instanceof KahanFunction ||
(aop.aggOp.increOp.fn instanceof Builtin &&
(((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() ==
BuiltinCode.MIN
|| ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() ==
BuiltinCode.MAX)
- || aop.aggOp.increOp.fn instanceof Mean ))) {
+ || aop.aggOp.increOp.fn instanceof Mean
+ || aop.aggOp.increOp.fn instanceof Multiply))) {
throw new DMLRuntimeException("Unsupported aggregation
operator: "
+ aop.aggOp.increOp.getClass().getSimpleName());
}
try {
- if(aop.aggOp.increOp.fn instanceof Builtin){
+ if(aop.aggOp.increOp.fn instanceof Multiply){
+ MatrixBlock ret = new MatrixBlock(ffr.length,
1, false);
+ MatrixBlock res = new MatrixBlock(0);
+ for(int i = 0; i < ffr.length; i++)
+ ret.setValue(i, 0,
((ScalarObject)ffr[i].get().getData()[0]).getDoubleValue());
+ LibMatrixAgg.aggregateUnaryMatrix(ret, res,
+ new AggregateUnaryOperator(new
AggregateOperator(1, Multiply.getMultiplyFnObject()),
+
ReduceAll.getReduceAllFnObject()));
+ return new DoubleObject(res.quickGetValue(0,
0));
+ }
+ else if(aop.aggOp.increOp.fn instanceof Builtin){
// then we know it is a Min or Max based on the
previous check.
boolean isMin = ((Builtin)
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN;
return new DoubleObject(aggMinMax(ffr, isMin,
true, Optional.empty()).getValue(0,0));
@@ -361,12 +400,21 @@ public class FederationUtils {
return aggAdd(ffr);
else if( aop.aggOp.increOp.fn instanceof Mean )
return aggMean(ffr, map);
- else if (aop.aggOp.increOp.fn instanceof Builtin &&
- (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() ==
BuiltinCode.MIN ||
+ else if (aop.aggOp.increOp.fn instanceof Builtin) {
+ if ((((Builtin) aop.aggOp.increOp.fn).getBuiltinCode()
== BuiltinCode.MIN ||
((Builtin)
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)) {
- boolean isMin = ((Builtin)
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN;
- return aggMinMax(ffr,isMin,false,
Optional.of(map.getType()));
- } else
+ boolean isMin = ((Builtin)
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN;
+ return aggMinMax(ffr,isMin,false,
Optional.of(map.getType()));
+ }
+ else if((((Builtin)
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MININDEX)
+ || (((Builtin)
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAXINDEX)) {
+ boolean isMin = ((Builtin)
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MININDEX;
+ return aggMinMaxIndex(ffr,isMin, map);
+ }
+ else throw new DMLRuntimeException("Unsupported
aggregation operator: "
+ +
aop.aggOp.increOp.fn.getClass().getSimpleName());
+ }
+ else
throw new DMLRuntimeException("Unsupported aggregation
operator: "
+
aop.aggOp.increOp.fn.getClass().getSimpleName());
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index fa4ea24..49c3452 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -376,7 +376,7 @@ public class InstructionUtils
else if ( opcode.equalsIgnoreCase("uarmax") ) {
AggregateOperator agg = new
AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject("max"));
aggun = new AggregateUnaryOperator(agg,
ReduceCol.getReduceColFnObject(), numThreads);
- }
+ }
else if (opcode.equalsIgnoreCase("uarimax") ) {
AggregateOperator agg = new
AggregateOperator(Double.NEGATIVE_INFINITY,
Builtin.getBuiltinFnObject("maxindex"), CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg,
ReduceCol.getReduceColFnObject(), numThreads);
@@ -384,7 +384,7 @@ public class InstructionUtils
else if ( opcode.equalsIgnoreCase("uarmin") ) {
AggregateOperator agg = new
AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject("min"));
aggun = new AggregateUnaryOperator(agg,
ReduceCol.getReduceColFnObject(), numThreads);
- }
+ }
else if (opcode.equalsIgnoreCase("uarimin") ) {
AggregateOperator agg = new
AggregateOperator(Double.POSITIVE_INFINITY,
Builtin.getBuiltinFnObject("minindex"), CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg,
ReduceCol.getReduceColFnObject(), numThreads);
@@ -401,6 +401,21 @@ public class InstructionUtils
return aggun;
}
+ public static AggregateUnaryOperator
parseAggregateUnaryRowIndexOperator(String opcode, int numOutputs, int
numThreads) {
+ AggregateUnaryOperator aggun = null;
+ AggregateOperator agg = null;
+ if (opcode.equalsIgnoreCase("uarimax") )
+ agg = new AggregateOperator(Double.NEGATIVE_INFINITY,
Builtin.getBuiltinFnObject("maxindex"),
+ numOutputs == 1 ?
CorrectionLocationType.LASTCOLUMN : CorrectionLocationType.NONE);
+
+ else if (opcode.equalsIgnoreCase("uarimin") )
+ agg = new AggregateOperator(Double.POSITIVE_INFINITY,
Builtin.getBuiltinFnObject("minindex"),
+ numOutputs == 1 ?
CorrectionLocationType.LASTCOLUMN : CorrectionLocationType.NONE);
+
+ aggun = new AggregateUnaryOperator(agg,
ReduceCol.getReduceColFnObject(), numThreads);
+ return aggun;
+ }
+
public static AggregateTernaryOperator
parseAggregateTernaryOperator(String opcode) {
return parseAggregateTernaryOperator(opcode, 1);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
index e6e74fb..ef1ff08 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
@@ -85,6 +85,12 @@ public class AggregateUnaryCPInstruction extends
UnaryCPInstruction {
return new AggregateUnaryCPInstruction(new
SimpleOperator(null),
in1, out, AUType.COUNT_DISTINCT_APPROX, opcode, str);
}
+ else if(opcode.equalsIgnoreCase("uarimax") ||
opcode.equalsIgnoreCase("uarimin")){
+ // parse with number of outputs
+ AggregateUnaryOperator aggun = InstructionUtils
+ .parseAggregateUnaryRowIndexOperator(opcode,
Integer.parseInt(parts[4]), Integer.parseInt(parts[3]));
+ return new AggregateUnaryCPInstruction(aggun, in1, out,
AUType.DEFAULT, opcode, str);
+ }
else { //DEFAULT BEHAVIOR
AggregateUnaryOperator aggun = InstructionUtils
.parseBasicAggregateUnaryOperator(opcode,
Integer.parseInt(parts[3]));
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index b9f220b..4fbe4e6 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.concurrent.Future;
+import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -56,7 +57,13 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]);
CPOperand out = new CPOperand(parts[2]);
- AggregateUnaryOperator aggun =
InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
+
+ AggregateUnaryOperator aggun = null;
+ if(opcode.equalsIgnoreCase("uarimax") ||
opcode.equalsIgnoreCase("uarimin"))
+ aggun =
InstructionUtils.parseAggregateUnaryRowIndexOperator(opcode,
Integer.parseInt(parts[4]), 1);
+ else
+ aggun =
InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
+
if(InstructionUtils.getExecType(str) == ExecType.SPARK)
str = InstructionUtils.replaceOperand(str, 4, "-1");
return new AggregateUnaryFEDInstruction(aggun, in1, out,
opcode, str);
@@ -76,7 +83,10 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
AggregateUnaryOperator aop = (AggregateUnaryOperator) _optr;
MatrixObject in = ec.getMatrixObject(input1);
FederationMap map = in.getFedMapping();
-
+
+ if((instOpcode.equalsIgnoreCase("uarimax") ||
instOpcode.equalsIgnoreCase("uarimin")) &&
in.isFederated(FederationMap.FType.COL))
+ instString =
InstructionUtils.replaceOperand(instString, 5, "2");
+
//create federated commands for aggregation
FederatedRequest fr1 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1}, new
long[]{in.getFedMapping().getID()});
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index f1f8f38..ffaf2af 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -58,7 +58,7 @@ public abstract class BinaryFEDInstruction extends
ComputationFEDInstruction {
else
throw new DMLRuntimeException("Federated binary
operations not yet supported:" + opcode);
}
-
+
protected static void checkOutputDataType(CPOperand in1, CPOperand in2,
CPOperand out) {
// check for valid data type of output
if( (in1.getDataType() == DataType.MATRIX || in2.getDataType()
== DataType.MATRIX) && out.getDataType() != DataType.MATRIX )
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
index ce16369..743dce9 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
@@ -231,6 +231,9 @@ public class LibMatrixAgg
public static void aggregateUnaryMatrix(MatrixBlock in, MatrixBlock
out, AggregateUnaryOperator uaop, int k) {
//fall back to sequential version if necessary
if( !satisfiesMultiThreadingConstraints(in, out, uaop, k) ) {
+ if(uaop.aggOp.increOp.fn instanceof Builtin &&
(((((Builtin) uaop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MININDEX)
+ || (((Builtin)
uaop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAXINDEX)) &&
uaop.aggOp.correction.getNumRemovedRowsColumns()==0))
+ out.clen = 2;
aggregateUnaryMatrix(in, out, uaop);
return;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedProdTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedProdTest.java
new file mode 100644
index 0000000..70859d4
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedProdTest.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedProdTest extends AutomatedTestBase {
+
+ private final static String TEST_NAME = "FederatedProdTest";
+
+ private final static String TEST_DIR = "functions/federated/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedProdTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+
+ @Parameterized.Parameter(2)
+ public boolean rowPartitioned;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {100, 12, true},
+ {100, 12, false}
+ });
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S.scalar"}));
+ }
+
+ @Test
+ public void testProdCP() { runProdTest(ExecMode.SINGLE_NODE); }
+
+ private void runProdTest(ExecMode execMode) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ ExecMode platformOld = rtplatform;
+
+ if(rtplatform == ExecMode.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int r = rows;
+ int c = cols / 4;
+ if(rowPartitioned) {
+ r = rows / 4;
+ c = cols;
+ }
+
+ double[][] X1 = getRandomMatrix(r, c, 0, 2, 1, 3);
+ double[][] X2 = getRandomMatrix(r, c, 0, 2, 1, 7);
+ double[][] X3 = getRandomMatrix(r, c, 0, 2, 1, 8);
+ double[][] X4 = getRandomMatrix(r, c, 0, 2, 1, 9);
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(r, c,
blocksize, r * c);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+ writeInputMatrixWithMTD("X2", X2, false, mc);
+ writeInputMatrixWithMTD("X3", X3, false, mc);
+ writeInputMatrixWithMTD("X4", X4, false, mc);
+
+ // empty script name because we don't execute any script, just
start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ int port3 = getRandomAvailablePort();
+ int port4 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+ Thread t4 = startLocalFedWorkerThread(port4);
+
+ rtplatform = execMode;
+ if(rtplatform == ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"),
+ Boolean.toString(rowPartitioned).toUpperCase(),
expected("S")};
+
+ runTest(null);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "in_X3=" + TestUtils.federatedAddress(port3,
input("X3")),
+ "in_X4=" + TestUtils.federatedAddress(port4,
input("X4")), "rows=" + rows, "cols=" + cols,
+ "rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
"out_S=" + output("S")};
+
+ runTest(null);
+
+ // compare via files
+ compareResults(1e-9);
+
+ Assert.assertTrue(heavyHittersContainsString("fed_ua*"));
+
+ // check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+ TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowIndexTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowIndexTest.java
new file mode 100644
index 0000000..e4c7534
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowIndexTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedRowIndexTest extends AutomatedTestBase {
+
+ private final static String TEST_NAME = "FederatedRowIndexTest";
+
+ private final static String TEST_DIR = "functions/federated/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedRowIndexTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+
+ @Parameterized.Parameter(2)
+ public boolean rowPartitioned;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {1000, 12, true},
+ {1000, 12, false}
+ });
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+ }
+
+ @Test
+ public void testRowIndexCP() {
+ runRowIndexTest(ExecMode.SINGLE_NODE);
+ }
+
+ private void runRowIndexTest(ExecMode execMode) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ ExecMode platformOld = rtplatform;
+
+ if(rtplatform == ExecMode.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int r = rows;
+ int c = cols / 4;
+ if(rowPartitioned) {
+ r = rows / 4;
+ c = cols;
+ }
+
+ double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+ double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+ double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+ double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(r, c,
blocksize, r * c);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+ writeInputMatrixWithMTD("X2", X2, false, mc);
+ writeInputMatrixWithMTD("X3", X3, false, mc);
+ writeInputMatrixWithMTD("X4", X4, false, mc);
+
+ // empty script name because we don't execute any script, just
start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ int port3 = getRandomAvailablePort();
+ int port4 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+ Thread t4 = startLocalFedWorkerThread(port4);
+
+ rtplatform = execMode;
+ if(rtplatform == ExecMode.SPARK) {
+ System.out.println(7);
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"),
+ Boolean.toString(rowPartitioned).toUpperCase(),
expected("S")};
+
+ runTest(null);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "in_X3=" + TestUtils.federatedAddress(port3,
input("X3")),
+ "in_X4=" + TestUtils.federatedAddress(port4,
input("X4")), "rows=" + rows, "cols=" + cols,
+ "rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
"out_S=" + output("S")};
+
+ runTest(null);
+
+// compare via files
+ compareResults(1e-9);
+
+ Assert.assertTrue(heavyHittersContainsString("fed_uarimax"));
+
+ // check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+ TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+
+ }
+}
diff --git a/src/test/scripts/functions/federated/FederatedProdTest.dml
b/src/test/scripts/functions/federated/FederatedProdTest.dml
new file mode 100644
index 0000000..ead8936
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedProdTest.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+}
+
+s = prod(A);
+write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/FederatedProdTestReference.dml
b/src/test/scripts/functions/federated/FederatedProdTestReference.dml
new file mode 100644
index 0000000..32293cf
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedProdTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if($5) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4)); }
+
+s = prod(A);
+write(s, $6);
diff --git a/src/test/scripts/functions/federated/FederatedRowIndexTest.dml
b/src/test/scripts/functions/federated/FederatedRowIndexTest.dml
new file mode 100644
index 0000000..bc80bfc
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRowIndexTest.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+}
+
+s = rowIndexMax(A);
+write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/FederatedRowIndexTestReference.dml
b/src/test/scripts/functions/federated/FederatedRowIndexTestReference.dml
new file mode 100644
index 0000000..a430bf1
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRowIndexTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if($5) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4)); }
+
+s = rowIndexMax(A);
+write(s, $6);