This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 29b3c612de [SYSTEMDS-3729] Add missing federated roll reorg operations
29b3c612de is described below

commit 29b3c612deb2312130987cc906d6790e7c7187ca
Author: min-guk <[email protected]>
AuthorDate: Sun Oct 20 18:15:36 2024 +0200

    [SYSTEMDS-3729] Add missing federated roll reorg operations
    
    Closes #2126.
---
 .../controlprogram/federated/FederationMap.java    |  31 +++-
 .../runtime/instructions/FEDInstructionParser.java |   1 +
 .../instructions/cp/ReorgCPInstruction.java        |   2 +-
 .../instructions/fed/ReorgFEDInstruction.java      | 145 +++++++++++++++-
 .../instructions/fed/UnaryFEDInstruction.java      |   6 +-
 .../instructions/spark/ReorgSPInstruction.java     |   2 +-
 .../primitives/part2/FederatedRollTest.java        | 187 +++++++++++++++++++++
 .../functions/federated/FederatedRollTest.dml      |  32 ++++
 .../federated/FederatedRollTestReference.dml       |  26 +++
 9 files changed, 422 insertions(+), 10 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 985fdb056e..91e6c156c4 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -406,14 +406,40 @@ public class FederationMap {
                return ret.toArray(new Future[0]);
        }
 
+       @SuppressWarnings("unchecked")
+       public Future<FederatedResponse>[] executeRoll(long tid, boolean wait,
+               FederatedRequest frEnd, FederatedRequest frStart, long rlen)
+       {
+               // dispatch by target variable ID (not by row position), so only entries that refer to the
+               // end/start slice variables receive a slice request; rlen is kept for API compatibility only
+               setThreadID(tid, new FederatedRequest[]{frStart, frEnd});
+               List<Future<FederatedResponse>> ret = new ArrayList<>();
+               for(Pair<FederatedRange, FederatedData> e : _fedMap) {
+                       long varID = e.getValue().getVarID();
+                       if(varID == frEnd.getID())
+                               ret.add(e.getValue().executeFederatedOperation(frEnd));
+                       else if(varID == frStart.getID())
+                               ret.add(e.getValue().executeFederatedOperation(frStart));
+               }
+
+               // prepare results (future federated responses), with optional wait to ensure the
+               // order of requests without data dependencies (e.g., cleanup RPCs)
+               if(wait)
+                       FederationUtils.waitFor(ret);
+               return (Future<FederatedResponse>[])ret.toArray(new Future[0]);
+       }
+
        public List<Pair<FederatedRange, Future<FederatedResponse>>> requestFederatedData() {
                if(!isInitialized())
                        throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData");
 
                List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = new ArrayList<>();
-               FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, _ID);
-               for(Pair<FederatedRange, FederatedData> e : _fedMap)
+
+               for(Pair<FederatedRange, FederatedData> e : _fedMap){
+                       FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, e.getValue().getVarID());
                        readResponses.add(Pair.of(e.getKey(), e.getValue().executeFederatedOperation(request)));
+               }
+
                return readResponses;
        }
 
@@ -692,6 +718,7 @@ public class FederationMap {
                }
        }
 
+
        private static class MappingTask implements Callable<Void> {
                private final FederatedRange _range;
                private final FederatedData _data;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index f61e86e800..820d07031d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -86,6 +86,7 @@ public class FEDInstructionParser extends InstructionParser
                String2FEDInstructionType.put( "r'"     , FEDType.Reorg );
                String2FEDInstructionType.put( "rdiag"  , FEDType.Reorg );
                String2FEDInstructionType.put( "rev"    , FEDType.Reorg );
+               String2FEDInstructionType.put( "roll"   , FEDType.Reorg );
                //String2FEDInstructionType.put( "rshape" , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser!
                //String2FEDInstructionType.put( "rsort"  , FEDType.Reorg ); Not supported by ReorgFEDInstruction parser!
 
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
index e7b3000d52..ab105a9585 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
@@ -86,7 +86,7 @@ public class ReorgCPInstruction extends UnaryCPInstruction {
         * @param istr   ?
         */
        private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
-               super(CPType.Reorg, op, in, out, opcode, istr);
+               super(CPType.Reorg, op, in, shift, out, opcode, istr); // shift as 2nd input: ReorgFEDInstruction checks rinst.input2 to detect roll
                _col = null;
                _desc = null;
                _ixret = null;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index c10ca27259..2c8748f783 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -36,6 +36,7 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
@@ -43,6 +44,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.functionobjects.DiagIndex;
 import org.apache.sysds.runtime.functionobjects.RevIndex;
+import org.apache.sysds.runtime.functionobjects.RollIndex;
 import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -57,6 +59,8 @@ import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 
 public class ReorgFEDInstruction extends UnaryFEDInstruction {
+       // roll only: scalar operand holding the shift amount (stays null for all other reorg opcodes)
+       private CPOperand _shift = null;
 
        public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
                super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
@@ -66,14 +70,29 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
                super(FEDType.Reorg, op, in1, out, opcode, istr);
        }
 
+       private ReorgFEDInstruction(Operator op, CPOperand in, CPOperand shift, CPOperand out,  String opcode, String istr, FederatedOutput fedOut) {
+               super(FEDType.Reorg, op, in, shift, out, opcode, istr, fedOut);
+               _shift = shift; // roll only: scalar shift operand, evaluated in processInstruction
+       }
+
        public static ReorgFEDInstruction parseInstruction(ReorgCPInstruction rinst) {
-               return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
-                       rinst.getInstructionString(), FederatedOutput.NONE);
+               if (rinst.input2 != null) { // input2 carries the roll shift operand
+                       return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(),
+                                       rinst.getInstructionString(), FederatedOutput.NONE);
+               } else{
+                       return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
+                                       rinst.getInstructionString(), FederatedOutput.NONE);
+               }
        }
 
        public static ReorgFEDInstruction parseInstruction(ReorgSPInstruction rinst) {
-               return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
-                       rinst.getInstructionString(), FederatedOutput.NONE);
+               if (rinst.input2 != null) { // input2 carries the roll shift operand
+                       return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.input2, rinst.output, rinst.getOpcode(),
+                                       rinst.getInstructionString(), FederatedOutput.NONE);
+               } else{
+                       return new ReorgFEDInstruction(rinst.getOperator(), rinst.input1, rinst.output, rinst.getOpcode(),
+                                       rinst.getInstructionString(), FederatedOutput.NONE);
+               }
        }
 
        public static ReorgFEDInstruction parseInstruction(String str) {
@@ -105,6 +124,15 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
                        return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str,
                                fedOut);
                }
+               else if (opcode.equalsIgnoreCase("roll")) {
+                       InstructionUtils.checkNumFields(str, 3, 4); // in, shift, out [, federated output flag]
+                       in.split(parts[1]);
+                       out.split(parts[3]);
+                       CPOperand shift = new CPOperand(parts[2]);
+                       fedOut = parseFedOutFlag(str, 3);
+                       return new ReorgFEDInstruction(new ReorgOperator(new RollIndex(0)),
+                                       in, shift, out, opcode, str, fedOut); // private ctor order is (in, shift, out)
+               }
                else {
                        throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: " + opcode);
                }
@@ -167,6 +195,36 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
                                .setBlocksize(mo1.getBlocksize()).setNonZeros(nnz);
                        out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
 
+                       optionalForceLocal(out);
+               } else if (instOpcode.equalsIgnoreCase("roll")) {
+                       long rlen = mo1.getNumRows();
+                       long shift = ec.getScalarInput(_shift).getLongValue();
+                       // normalize into [0, rlen): Java's % keeps the dividend's sign, so fold negative shifts
+                       shift = (rlen != 0) ? ((shift % rlen) + rlen) % rlen : 0;
+                       long inID = mo1.getFedMapping().getID();
+                       long outEndID = FederationUtils.getNextFedDataID();
+                       long outStartID = FederationUtils.getNextFedDataID();
+
+                       List<Pair<FederatedRange, FederatedData>> inMap = mo1.getFedMapping().getMap();
+                       Pair<FederationMap, Long> rollResult = rollFedMap(
+                               inMap, inID, outEndID, outStartID, shift, rlen, mo1.getFedMapping().getType());
+                       long length = rollResult.getValue();
+                       FederationMap outFedMap = rollResult.getKey();
+
+                       FederatedRequest frEnd = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outEndID,
+                                       new ReorgFEDInstruction.SliceMatrix(inID, outEndID, length, true));
+                       FederatedRequest frStart = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, outStartID,
+                                       new ReorgFEDInstruction.SliceMatrix(inID, outStartID, length, false));
+                       outFedMap.executeRoll(getTID(), true, frEnd, frStart, rlen); // materialize the slices on the workers and wait
+
+                       //derive output federated mapping; roll only permutes rows, hence nnz is preserved
+                       MatrixObject out = ec.getMatrixObject(output);
+                       long nnz = mo1.getNnz(); // unknown (-1) stays unknown; slice responses only cover split partitions
+                       out.getDataCharacteristics()
+                               .setDimension(mo1.getNumRows(), mo1.getNumColumns())
+                               .setBlocksize(mo1.getBlocksize())
+                               .setNonZeros(nnz);
+                       out.setFedMapping(outFedMap);
                        optionalForceLocal(out);
                }
                else if (instOpcode.equals("rdiag")) {
@@ -189,6 +247,40 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
                }
        }
 
+       /** Rolls the row ranges of oldMap by shift rows; returns the rolled map and the row count of the split end part. */
+       public Pair<FederationMap, Long> rollFedMap(List<Pair<FederatedRange, FederatedData>> oldMap, long inID,
+                                                                                               long outEndID, long outStartID, long shift, long rlen, FType type) {
+               List<Pair<FederatedRange, FederatedData>> map = new ArrayList<>();
+               long length = 0;
+               shift = (rlen != 0) ? ((shift % rlen) + rlen) % rlen : 0; // normalize into [0, rlen), incl. negative shifts
+               for(Map.Entry<FederatedRange, FederatedData> e : oldMap) {
+                       if(e.getKey().getSize() == 0) continue;
+                       FederatedRange fedRange = new FederatedRange(e.getKey());
+                       long beginRow = fedRange.getBeginDims()[0] + shift;
+                       long endRow = fedRange.getEndDims()[0] + shift;
+                       // wrap: a begin of exactly rlen maps to row 0, whereas an (exclusive) end of rlen stays
+                       beginRow = beginRow >= rlen ? beginRow - rlen : beginRow;
+                       endRow = endRow > rlen ? endRow - rlen : endRow;
+
+                       if (beginRow < endRow) { // no wrap: reuse input var (NOTE(review): aliases inID - confirm lifetime vs. input cleanup)
+                               fedRange.setBeginDim(0, beginRow);
+                               fedRange.setEndDim(0, endRow);
+                               map.add(Pair.of(fedRange, e.getValue().copyWithNewID(inID)));
+                       } else { // crosses rlen: first `length` rows go to [beginRow, rlen), the rest to [0, endRow)
+                               length = rlen - beginRow;
+                               fedRange.setBeginDim(0, beginRow);
+                               fedRange.setEndDim(0, rlen);
+                               map.add(Pair.of(fedRange, e.getValue().copyWithNewID(outEndID)));
+
+                               FederatedRange startRange = new FederatedRange(fedRange);
+                               startRange.setBeginDim(0, 0);
+                               startRange.setEndDim(0, endRow);
+                               map.add(Pair.of(startRange, e.getValue().copyWithNewID(outStartID)));
+                       }
+               }
+               return Pair.of(new FederationMap(outEndID, map, type), length);
+       }
+
        /**
         * Update the federated ranges of result and return the updated federation map.
         * @param result RdiagResult for which the fedmap is updated
@@ -307,6 +399,51 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
                return new RdiagResult(diagFedMap, dcs);
        }
 
+       public static class SliceMatrix extends FederatedUDF {
+               private static final long serialVersionUID = -3466926635958851402L;
+               private final long _outputID;
+               private final int _sliceRow;
+               private final boolean _isRight;
+               // row slice of the local block: _isRight -> rows [0, _sliceRow), else rows [_sliceRow, n); _sliceRow == 0 -> whole block
+               private SliceMatrix(long input, long outputID, long sliceRow, boolean isRight) {
+                       super(new long[] {input});
+                       _outputID = outputID;
+                       _sliceRow = (int) sliceRow;
+                       _isRight = isRight;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       MatrixBlock oriBlock = ((MatrixObject) data[0]).acquireReadAndRelease();
+                       MatrixBlock resBlock;
+
+                       if (_sliceRow != 0){
+                               if (_isRight){
+                                       resBlock = oriBlock.slice(0, _sliceRow-1, 0,
+                                                       oriBlock.getNumColumns()-1, new MatrixBlock());
+                               } else{
+                                       resBlock = oriBlock.slice(_sliceRow, oriBlock.getNumRows()-1,
+                                                       0, oriBlock.getNumColumns()-1, new MatrixBlock());
+                               }
+                       } else{
+                               resBlock = oriBlock;
+                       }
+                       ec.setMatrixOutput(String.valueOf(_outputID), resBlock);
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY); // keep the slice on the worker; do not transfer raw data to the coordinator
+               }
+
+               @Override
+               public List<Long> getOutputIds() {
+                       return new ArrayList<>(Arrays.asList(_outputID));
+               }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return Pair.of(String.valueOf(_outputID),
+                                       new LineageItem());
+               }
+       }
+
        public static class Rdiag extends FederatedUDF {
 
                private static final long serialVersionUID = -3466926635958851402L;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
index f025983e74..2311a1afe2 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
@@ -88,7 +88,8 @@ public abstract class UnaryFEDInstruction extends ComputationFEDInstruction {
                        }
                }
                else if(inst instanceof ReorgCPInstruction &&
-                       (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) {
+                               (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
+                                               || inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) { // roll: shift operand travels as rinst.input2
                        ReorgCPInstruction rinst = (ReorgCPInstruction) inst;
                        CacheableData<?> mo = ec.getCacheableData(rinst.input1);
 
@@ -157,7 +158,8 @@ public abstract class UnaryFEDInstruction extends ComputationFEDInstruction {
                                        return AggregateUnaryFEDInstruction.parseInstruction(auinstruction);
                }
                else if(inst instanceof ReorgSPInstruction &&
-                       (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) {
+                               (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
+                                               || inst.getOpcode().equals("rev") || inst.getOpcode().equals("roll"))) { // roll: shift operand travels as rinst.input2
                        ReorgSPInstruction rinst = (ReorgSPInstruction) inst;
                        CacheableData<?> mo = ec.getCacheableData(rinst.input1);
                        if((mo instanceof MatrixObject || mo instanceof FrameObject) && mo.isFederated() &&
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
index b096405959..1a4f8fef0d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
@@ -85,7 +85,7 @@ public class ReorgSPInstruction extends UnarySPInstruction {
        }
 
        private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
-               this(op, in, out, opcode, istr);
+               super(SPType.Reorg, op, in, shift, null, out, opcode, istr); // shift as 2nd input (checked via rinst.input2 in ReorgFEDInstruction); NOTE(review): bypasses this(...) - confirm other fields' defaults suffice
                _shift = shift;
        }
 
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java
new file mode 100644
index 0000000000..f242710338
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives.part2;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedRollTest extends AutomatedTestBase {
+       // Compares roll(A, 1) on a 4-way federated matrix (row- or column-partitioned) against a local reference.
+
+       private final static String TEST_NAME = "FederatedRollTest";
+
+       private final static String TEST_DIR = "functions/federated/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRollTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][]{{100, 12, true}, {100, 12, false}});
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"S"}));
+       }
+
+       @Test
+       public void testRollCP() {
+               runRollTest(ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       @Ignore
+       public void testRollSP() {
+               runRollTest(ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedCompilationRollCP() {
+               runRollTest(ExecMode.SINGLE_NODE, true);
+       }
+
+       @Test
+       @Ignore
+       public void federatedCompilationRollSP() {
+               runRollTest(ExecMode.SPARK, true);
+       }
+
+       private void runRollTest(ExecMode execMode) {
+               runRollTest(execMode, false);
+       }
+
+       private void runRollTest(ExecMode execMode, boolean activateFedCompilation) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               ExecMode platformOld = rtplatform;
+
+               if (execMode == ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               int r = rows;
+               int c = cols / 4;
+               if (rowPartitioned) {
+                       r = rows / 4;
+                       c = cols;
+               }
+
+               double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+               for (int k : new int[]{1, 2, 3}) {
+                       Arrays.fill(X3[k], 0);
+               }
+
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               // empty script name because we don't execute any script, just start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+               Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
+               Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
+               Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
+               Process t4 = startLocalFedWorker(port4);
+
+
+               try {
+                       if (!isAlive(t1, t2, t3, t4))
+                               throw new RuntimeException("Failed starting federated worker");
+                       rtplatform = execMode;
+                       if (rtplatform == ExecMode.SPARK) {
+                               DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+                       }
+                       TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+                       loadTestConfiguration(config);
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                       programArgs = new String[]{"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+                                       Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
+
+                       runTest(null);
+
+                       OptimizerUtils.FEDERATED_COMPILATION = activateFedCompilation;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{"-stats", "100", "-nvargs",
+                                       "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                                       "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                                       "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+                                       "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+                                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+
+                       runTest(null);
+
+                       // compare via files
+                       compareResults(0.01, "Stat-DML1", "Stat-DML2");
+
+                       Assert.assertTrue(heavyHittersContainsString("fed_roll"));
+
+                       // check that federated input files are still existing
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+               } finally {
+                       TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       OptimizerUtils.FEDERATED_COMPILATION = false;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/federated/FederatedRollTest.dml b/src/test/scripts/functions/federated/FederatedRollTest.dml
new file mode 100644
index 0000000000..cb464256ed
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRollTest.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+if ($rP) {
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+
+s = roll(A, 1);
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedRollTestReference.dml b/src/test/scripts/functions/federated/FederatedRollTestReference.dml
new file mode 100644
index 0000000000..694bd5f1d4
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRollTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+if($5) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4)); }
+
+s = roll(A, 1);
+write(s, $6);

Reply via email to