This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 29b3c612de [SYSTEMDS-3729] Add missing federated roll reorg operations
29b3c612de is described below
commit 29b3c612deb2312130987cc906d6790e7c7187ca
Author: min-guk <[email protected]>
AuthorDate: Sun Oct 20 18:15:36 2024 +0200
[SYSTEMDS-3729] Add missing federated roll reorg operations
Closes #2126.
---
.../controlprogram/federated/FederationMap.java | 31 +++-
.../runtime/instructions/FEDInstructionParser.java | 1 +
.../instructions/cp/ReorgCPInstruction.java | 2 +-
.../instructions/fed/ReorgFEDInstruction.java | 145 +++++++++++++++-
.../instructions/fed/UnaryFEDInstruction.java | 6 +-
.../instructions/spark/ReorgSPInstruction.java | 2 +-
.../primitives/part2/FederatedRollTest.java | 187 +++++++++++++++++++++
.../functions/federated/FederatedRollTest.dml | 32 ++++
.../federated/FederatedRollTestReference.dml | 26 +++
9 files changed, 422 insertions(+), 10 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 985fdb056e..91e6c156c4 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -406,14 +406,40 @@ public class FederationMap {
return ret.toArray(new Future[0]);
}
+	/**
+	 * Sends the two slice requests of a federated roll to the partitions that need them.
+	 * <p>
+	 * {@code frEnd} goes to every partition whose data is bound to {@code frEnd}'s variable ID.
+	 * {@code frStart} goes to every partition bound to {@code frStart}'s ID. Partitions that still
+	 * reference another variable (e.g. the unmodified input) receive no request.
+	 *
+	 * @param tid     thread ID attached to both requests
+	 * @param wait    if true, block until all responses have arrived
+	 * @param frEnd   request producing the slices placed at the end of the output
+	 * @param frStart request producing the slices placed at the start of the output
+	 * @param rlen    number of rows of the rolled matrix (kept for interface compatibility)
+	 * @return futures of the federated responses, one per request that was sent
+	 */
+	@SuppressWarnings("unchecked")
+	public Future<FederatedResponse>[] executeRoll(long tid, boolean wait,
+		FederatedRequest frEnd, FederatedRequest frStart, long rlen)
+	{
+		setThreadID(tid, new FederatedRequest[]{frStart, frEnd});
+		List<Future<FederatedResponse>> ret = new ArrayList<>();
+
+		for(Pair<FederatedRange, FederatedData> e : _fedMap) {
+			// Dispatch by variable ID rather than by row range. A partition that shifted
+			// without wrapping may also begin at row 0 or end at rlen. It still references
+			// the input data and must not receive a slice request: that would create an
+			// unreferenced worker variable and distort response-based statistics.
+			long varID = e.getValue().getVarID();
+			if(varID == frEnd.getID())
+				ret.add(e.getValue().executeFederatedOperation(frEnd));
+			else if(varID == frStart.getID())
+				ret.add(e.getValue().executeFederatedOperation(frStart));
+		}
+
+		// optional wait to ensure the order of requests without data dependencies
+		if(wait)
+			FederationUtils.waitFor(ret);
+		return (Future<FederatedResponse>[]) ret.toArray(new Future[0]);
+	}
+
+	/**
+	 * Requests the data of every federated partition (GET_VAR) from its worker.
+	 *
+	 * @return one (range, future response) pair per partition, in map order
+	 * @throws DMLRuntimeException if this federation map is not initialized
+	 */
public List<Pair<FederatedRange, Future<FederatedResponse>>>
requestFederatedData() {
if(!isInitialized())
throw new DMLRuntimeException("Federated matrix read
only supported on initialized FederatedData");
List<Pair<FederatedRange, Future<FederatedResponse>>>
readResponses = new ArrayList<>();
- FederatedRequest request = new
FederatedRequest(RequestType.GET_VAR, _ID);
- for(Pair<FederatedRange, FederatedData> e : _fedMap)
+
+ for(Pair<FederatedRange, FederatedData> e : _fedMap){
+		// Request each partition under its own variable ID, not the map-wide _ID.
+		// A map may reference different worker variables per entry; for example, the
+		// output of ReorgFEDInstruction.rollFedMap mixes the input ID with two slice IDs.
+ FederatedRequest request = new
FederatedRequest(RequestType.GET_VAR, e.getValue().getVarID());
readResponses.add(Pair.of(e.getKey(),
e.getValue().executeFederatedOperation(request)));
+ }
+
return readResponses;
}
@@ -692,6 +718,7 @@ public class FederationMap {
}
}
+
private static class MappingTask implements Callable<Void> {
private final FederatedRange _range;
private final FederatedData _data;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index f61e86e800..820d07031d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -86,6 +86,7 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "r'" , FEDType.Reorg );
String2FEDInstructionType.put( "rdiag" , FEDType.Reorg );
String2FEDInstructionType.put( "rev" , FEDType.Reorg );
+ String2FEDInstructionType.put( "roll" , FEDType.Reorg );
//String2FEDInstructionType.put( "rshape" , FEDType.Reorg );
Not supported by ReorgFEDInstruction parser!
//String2FEDInstructionType.put( "rsort" , FEDType.Reorg );
Not supported by ReorgFEDInstruction parser!
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
index e7b3000d52..ab105a9585 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
@@ -86,7 +86,7 @@ public class ReorgCPInstruction extends UnaryCPInstruction {
* @param istr ?
*/
private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out,
CPOperand shift, String opcode, String istr) {
- super(CPType.Reorg, op, in, out, opcode, istr);
+ super(CPType.Reorg, op, in, shift, out, opcode, istr);
_col = null;
_desc = null;
_ixret = null;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index c10ca27259..2c8748f783 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -36,6 +36,7 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
@@ -43,6 +44,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
+import org.apache.sysds.runtime.functionobjects.RollIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -57,6 +59,8 @@ import
org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
public class ReorgFEDInstruction extends UnaryFEDInstruction {
+ // roll-specific attributes
+ private CPOperand _shift = null;
public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out,
String opcode, String istr, FederatedOutput fedOut) {
super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
@@ -66,14 +70,29 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
super(FEDType.Reorg, op, in1, out, opcode, istr);
}
+	/**
+	 * Creates a reorg instruction with an additional shift operand, as used by roll.
+	 *
+	 * @param op     reorg operator
+	 * @param in     matrix input
+	 * @param shift  scalar shift operand (passed to the super constructor after the input)
+	 * @param out    output operand
+	 * @param opcode instruction opcode
+	 * @param istr   instruction string
+	 * @param fedOut federated output flag
+	 */
+	private ReorgFEDInstruction(Operator op, CPOperand in, CPOperand shift, CPOperand out, String opcode,
+		String istr, FederatedOutput fedOut) {
+		super(FEDType.Reorg, op, in, shift, out, opcode, istr, fedOut);
+		_shift = shift;
+	}
+
public static ReorgFEDInstruction parseInstruction(ReorgCPInstruction
rinst) {
- return new ReorgFEDInstruction(rinst.getOperator(),
rinst.input1, rinst.output, rinst.getOpcode(),
- rinst.getInstructionString(), FederatedOutput.NONE);
+		return fromReorgInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output,
+			rinst.getOpcode(), rinst.getInstructionString());
}
public static ReorgFEDInstruction parseInstruction(ReorgSPInstruction
rinst) {
- return new ReorgFEDInstruction(rinst.getOperator(),
rinst.input1, rinst.output, rinst.getOpcode(),
- rinst.getInstructionString(), FederatedOutput.NONE);
+		return fromReorgInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output,
+			rinst.getOpcode(), rinst.getInstructionString());
}
+
+	/**
+	 * Shared conversion of a CP or SP reorg instruction into a federated reorg instruction.
+	 *
+	 * @param op     reorg operator of the source instruction
+	 * @param in1    matrix input
+	 * @param in2    second input; non-null for roll, where it carries the shift operand
+	 * @param out    output operand
+	 * @param opcode instruction opcode
+	 * @param istr   instruction string
+	 * @return the federated instruction, always with {@link FederatedOutput#NONE}
+	 */
+	private static ReorgFEDInstruction fromReorgInstruction(Operator op, CPOperand in1, CPOperand in2,
+		CPOperand out, String opcode, String istr) {
+		// the CP/SP roll constructors register the shift as input2 (see their super calls)
+		if(in2 != null)
+			return new ReorgFEDInstruction(op, in1, in2, out, opcode, istr, FederatedOutput.NONE);
+		return new ReorgFEDInstruction(op, in1, out, opcode, istr, FederatedOutput.NONE);
+	}
public static ReorgFEDInstruction parseInstruction(String str) {
@@ -105,6 +124,15 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
return new ReorgFEDInstruction(new
ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str,
fedOut);
}
+		else if (opcode.equalsIgnoreCase("roll")) {
+			// format: roll <in> <shift> <out>
+			InstructionUtils.checkNumFields(str, 3);
+			in.split(parts[1]);
+			out.split(parts[3]);
+			CPOperand shift = new CPOperand(parts[2]);
+			fedOut = parseFedOutFlag(str, 3);
+			// The argument order must match the private constructor (op, in, shift, out, ...).
+			// All three operands are CPOperands, so a swapped order would still compile.
+			return new ReorgFEDInstruction(new ReorgOperator(new RollIndex(0)),
+				in, shift, out, opcode, str, fedOut);
+		}
else {
throw new DMLRuntimeException("ReorgFEDInstruction:
unsupported opcode: " + opcode);
}
@@ -167,6 +195,36 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
.setBlocksize(mo1.getBlocksize()).setNonZeros(nnz);
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+ optionalForceLocal(out);
+ } else if (instOpcode.equalsIgnoreCase("roll")) {
+			long rlen = mo1.getNumRows();
+			long shift = ec.getScalarInput(_shift).getLongValue();
+			// Normalize into [0, rlen). Java's % keeps the sign of the dividend, so a
+			// negative shift would otherwise produce negative row offsets.
+			if(rlen > 0) {
+				shift %= rlen;
+				if(shift < 0)
+					shift += rlen;
+			}
+
+			long inID = mo1.getFedMapping().getID();
+			long outEndID = FederationUtils.getNextFedDataID();
+			long outStartID = FederationUtils.getNextFedDataID();
+
+			// derive the output partitioning and the row offset at which the wrapped partition is cut
+			List<Pair<FederatedRange, FederatedData>> inMap = mo1.getFedMapping().getMap();
+			Pair<FederationMap, Long> rollResult = rollFedMap(
+				inMap, inID, outEndID, outStartID, shift, rlen, mo1.getFedMapping().getType());
+			long length = rollResult.getValue();
+			FederationMap outFedMap = rollResult.getKey();
+
+			// on the worker(s) holding a wrapped partition, cut it into an end and a start piece
+			FederatedRequest frEnd = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outEndID,
+				new ReorgFEDInstruction.SliceMatrix(inID, outEndID, length, true));
+			FederatedRequest frStart = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outStartID,
+				new ReorgFEDInstruction.SliceMatrix(inID, outStartID, length, false));
+			outFedMap.executeRoll(getTID(), true, frEnd, frStart, rlen);
+
+			// Derive output characteristics. Roll only permutes rows, so dimensions and
+			// the number of non-zeros equal those of the input; an unknown nnz (-1) stays
+			// unknown rather than being replaced by a partial sum over the sliced pieces.
+			MatrixObject out = ec.getMatrixObject(output);
+			out.getDataCharacteristics()
+				.setDimension(mo1.getNumRows(), mo1.getNumColumns())
+				.setBlocksize(mo1.getBlocksize())
+				.setNonZeros(mo1.getNnz());
+			out.setFedMapping(outFedMap);
optionalForceLocal(out);
}
else if (instOpcode.equals("rdiag")) {
@@ -189,6 +247,40 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
}
}
+
+	/**
+	 * Derives the output federation map of a row roll by {@code shift}.
+	 * <p>
+	 * Every partition's row range is shifted by {@code shift} modulo {@code rlen}.
+	 * A partition that stays contiguous keeps referencing the input variable {@code inID}.
+	 * A partition that wraps past the last row is split into an end piece (variable
+	 * {@code outEndID}) and a start piece (variable {@code outStartID}).
+	 *
+	 * @param oldMap     partitions of the input federation map
+	 * @param inID       variable ID of the input data on the workers
+	 * @param outEndID   variable ID for pieces placed at the end of the output; also the output map's ID
+	 * @param outStartID variable ID for pieces placed at the start of the output
+	 * @param shift      number of rows to roll by; any value, normalized into [0, rlen)
+	 * @param rlen       number of rows of the matrix
+	 * @param type       federation type of the output map
+	 * @return the output federation map, and the number of leading rows of a wrapped
+	 *         partition that move to the end of the output (0 if no partition wraps)
+	 */
+	public Pair<FederationMap, Long> rollFedMap(List<Pair<FederatedRange, FederatedData>> oldMap, long inID,
+		long outEndID, long outStartID, long shift, long rlen, FType type) {
+		List<Pair<FederatedRange, FederatedData>> map = new ArrayList<>();
+		long length = 0;
+
+		// Normalize the shift so that negative shifts and shifts >= rlen map to the
+		// equivalent rotation (Java's % keeps the sign of the dividend).
+		if(rlen > 0)
+			shift = ((shift % rlen) + rlen) % rlen;
+
+		for(Pair<FederatedRange, FederatedData> e : oldMap) {
+			if(e.getKey().getSize() == 0)
+				continue;
+			FederatedRange fedRange = new FederatedRange(e.getKey());
+			long beginRow = fedRange.getBeginDims()[0] + shift;
+			long endRow = fedRange.getEndDims()[0] + shift;
+
+			// Begin is inclusive and end is exclusive. A begin of exactly rlen wraps to 0,
+			// while an end of exactly rlen is still a valid, non-wrapping upper bound.
+			if(beginRow >= rlen)
+				beginRow -= rlen;
+			if(endRow > rlen)
+				endRow -= rlen;
+
+			if(beginRow < endRow) {
+				// The partition moves as a whole and keeps referencing the input data.
+				// NOTE(review): the output map then shares worker variables with the input
+				// and holds entries with differing variable IDs. Confirm that cleanup of the
+				// input, and later operations addressing the map by its single ID, handle this.
+				fedRange.setBeginDim(0, beginRow);
+				fedRange.setEndDim(0, endRow);
+				map.add(Pair.of(fedRange, e.getValue().copyWithNewID(inID)));
+			}
+			else {
+				// The partition wraps around. Its first (rlen - beginRow) rows go to
+				// [beginRow, rlen) and its remaining rows go to [0, endRow). With column
+				// partitioning every partition wraps with the same cut, so one length suffices.
+				length = rlen - beginRow;
+				fedRange.setBeginDim(0, beginRow);
+				fedRange.setEndDim(0, rlen);
+				map.add(Pair.of(fedRange, e.getValue().copyWithNewID(outEndID)));
+
+				FederatedRange startRange = new FederatedRange(fedRange);
+				startRange.setBeginDim(0, 0);
+				startRange.setEndDim(0, endRow);
+				map.add(Pair.of(startRange, e.getValue().copyWithNewID(outStartID)));
+			}
+		}
+		return Pair.of(new FederationMap(outEndID, map, type), length);
+	}
+
/**
* Update the federated ranges of result and return the updated
federation map.
* @param result RdiagResult for which the fedmap is updated
@@ -307,6 +399,51 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
return new RdiagResult(diagFedMap, dcs);
}
+ public static class SliceMatrix extends FederatedUDF {
+ private static final long serialVersionUID =
-3466926635958851402L;
+ private final long _outputID;
+ private final int _sliceRow;
+ private final boolean _isRight;
+
+ private SliceMatrix(long input, long outputID, long sliceRow,
boolean isRight) {
+ super(new long[] {input});
+ _outputID = outputID;
+ _sliceRow = (int) sliceRow;
+ _isRight = isRight;
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data...
data) {
+ MatrixBlock oriBlock = ((MatrixObject)
data[0]).acquireReadAndRelease();
+ MatrixBlock resBlock;
+
+ if (_sliceRow != 0){
+ if (_isRight){
+ resBlock = oriBlock.slice(0,
_sliceRow-1, 0,
+
oriBlock.getNumColumns()-1, new MatrixBlock());
+ } else{
+ resBlock = oriBlock.slice(_sliceRow,
oriBlock.getNumRows()-1,
+ 0,
oriBlock.getNumColumns()-1, new MatrixBlock());
+ }
+ } else{
+ resBlock = oriBlock;
+ }
+ ec.setMatrixOutput(String.valueOf(_outputID), resBlock);
+ return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, resBlock);
+ }
+
+ @Override
+ public List<Long> getOutputIds() {
+ return new ArrayList<>(Arrays.asList(_outputID));
+ }
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return Pair.of(String.valueOf(_outputID),
+ new LineageItem());
+ }
+ }
+
public static class Rdiag extends FederatedUDF {
private static final long serialVersionUID =
-3466926635958851402L;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
index f025983e74..2311a1afe2 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
@@ -88,7 +88,8 @@ public abstract class UnaryFEDInstruction extends
ComputationFEDInstruction {
}
}
else if(inst instanceof ReorgCPInstruction &&
- (inst.getOpcode().equals("r'") ||
inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) {
+ (inst.getOpcode().equals("r'") ||
inst.getOpcode().equals("rdiag")
+ ||
inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) {
ReorgCPInstruction rinst = (ReorgCPInstruction) inst;
CacheableData<?> mo = ec.getCacheableData(rinst.input1);
@@ -157,7 +158,8 @@ public abstract class UnaryFEDInstruction extends
ComputationFEDInstruction {
return
AggregateUnaryFEDInstruction.parseInstruction(auinstruction);
}
else if(inst instanceof ReorgSPInstruction &&
- (inst.getOpcode().equals("r'") ||
inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) {
+ (inst.getOpcode().equals("r'") ||
inst.getOpcode().equals("rdiag")
+ ||
inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) {
ReorgSPInstruction rinst = (ReorgSPInstruction) inst;
CacheableData<?> mo = ec.getCacheableData(rinst.input1);
if((mo instanceof MatrixObject || mo instanceof
FrameObject) && mo.isFederated() &&
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
index b096405959..1a4f8fef0d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
@@ -85,7 +85,7 @@ public class ReorgSPInstruction extends UnarySPInstruction {
}
private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out,
CPOperand shift, String opcode, String istr) {
- this(op, in, out, opcode, istr);
+		// Call super directly so the shift operand is registered as input2;
+		// ReorgFEDInstruction.parseInstruction(ReorgSPInstruction) checks input2 to pick
+		// the roll constructor. NOTE(review): confirm the null argument is the unused third input.
+ super(SPType.Reorg, op, in, shift, null, out, opcode, istr);
_shift = shift;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java
new file mode 100644
index 0000000000..f242710338
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives.part2;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedRollTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "FederatedRollTest";
+
+	private final static String TEST_DIR = "functions/federated/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRollTest.class.getSimpleName() + "/";
+
+	private final static int blocksize = 1024;
+
+	/** Rows of partition X3 that are cleared to zero, so that a partition contains empty rows. */
+	private final static int[] ZERO_ROWS = {1, 2, 3};
+
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+
+	@Parameterized.Parameter(2)
+	public boolean rowPartitioned;
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		return Arrays.asList(new Object[][]{{100, 12, true}, {100, 12, false}});
+	}
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"S"}));
+	}
+
+	@Test
+	public void testRollCP() {
+		runRollTest(ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	@Ignore
+	public void testRollSP() {
+		runRollTest(ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedCompilationRollCP() {
+		runRollTest(ExecMode.SINGLE_NODE, true);
+	}
+
+	@Test
+	@Ignore
+	public void federatedCompilationRollSP() {
+		runRollTest(ExecMode.SPARK, true);
+	}
+
+	private void runRollTest(ExecMode execMode) {
+		runRollTest(execMode, false);
+	}
+
+	/**
+	 * Computes roll(A, 1) on a locally assembled reference matrix and on a 4-way federated
+	 * matrix built from the same partitions, then compares both results.
+	 *
+	 * @param execMode               execution mode for both runs
+	 * @param activateFedCompilation whether federated compilation is enabled for the federated run
+	 */
+	private void runRollTest(ExecMode execMode, boolean activateFedCompilation) {
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		ExecMode platformOld = rtplatform;
+
+		getAndLoadTestConfiguration(TEST_NAME);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+
+		// write input matrices: four row blocks (rows/4 x cols) or four column blocks (rows x cols/4)
+		int r = rowPartitioned ? rows / 4 : rows;
+		int c = rowPartitioned ? cols : cols / 4;
+
+		double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+		double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+		double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+		double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+		for(int k : ZERO_ROWS)
+			Arrays.fill(X3[k], 0);
+
+		// values are drawn from [1, 5], so only the cleared rows of X3 reduce its non-zero count
+		MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+		MatrixCharacteristics mc3 = new MatrixCharacteristics(r, c, blocksize, (r - ZERO_ROWS.length) * c);
+		writeInputMatrixWithMTD("X1", X1, false, mc);
+		writeInputMatrixWithMTD("X2", X2, false, mc);
+		writeInputMatrixWithMTD("X3", X3, false, mc3);
+		writeInputMatrixWithMTD("X4", X4, false, mc);
+
+		// empty script name because we don't execute any script, just start the worker
+		fullDMLScriptName = "";
+		int port1 = getRandomAvailablePort();
+		int port2 = getRandomAvailablePort();
+		int port3 = getRandomAvailablePort();
+		int port4 = getRandomAvailablePort();
+		Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
+		Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
+		Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
+		Process t4 = startLocalFedWorker(port4);
+
+		try {
+			if(!isAlive(t1, t2, t3, t4))
+				throw new RuntimeException("Failed starting federated worker");
+			// platform and Spark config are switched only inside try so that finally restores them
+			rtplatform = execMode;
+			if(rtplatform == ExecMode.SPARK)
+				DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+			TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+			loadTestConfiguration(config);
+
+			// Run reference dml script with normal matrix
+			fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+			programArgs = new String[]{"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+				Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
+			runTest(null);
+
+			// Run the same computation on the federated matrix
+			OptimizerUtils.FEDERATED_COMPILATION = activateFedCompilation;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[]{"-stats", "100", "-nvargs",
+				"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+				"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+				"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+				"in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+				"rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+			runTest(null);
+
+			// compare via files
+			compareResults(0.01, "Stat-DML1", "Stat-DML2");
+
+			Assert.assertTrue(heavyHittersContainsString("fed_roll"));
+
+			// check that federated input files are still existing
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+		}
+		finally {
+			TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+			OptimizerUtils.FEDERATED_COMPILATION = false;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/federated/FederatedRollTest.dml
b/src/test/scripts/functions/federated/FederatedRollTest.dml
new file mode 100644
index 0000000000..cb464256ed
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRollTest.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ #
+ #-------------------------------------------------------------
+# Build federated matrix A from the four worker partitions $in_X1..$in_X4.
+# Each partition has a (begin, end) coordinate pair: blocks of $rows/4 rows
+# if $rP is TRUE, otherwise blocks of $cols/4 columns.
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+}
+
+# roll A by a shift of 1 and write the result to $out_S
+s = roll(A, 1);
+write(s, $out_S);
\ No newline at end of file
diff --git
a/src/test/scripts/functions/federated/FederatedRollTestReference.dml
b/src/test/scripts/functions/federated/FederatedRollTestReference.dml
new file mode 100644
index 0000000000..694bd5f1d4
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRollTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ #
+ #-------------------------------------------------------------
+
+ # Local reference: assemble the four partitions $1..$4 row-wise if $5 is TRUE,
+ # otherwise column-wise, then roll by a shift of 1 and write the result to $6.
+ if($5) { A = rbind(read($1), read($2), read($3), read($4)); }
+ else { A = cbind(read($1), read($2), read($3), read($4)); }
+
+ s = roll(A, 1);
+ write(s, $6);
\ No newline at end of file