This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit e2bd5bf4fe90a66f316898208c851f796117a90c
Author: Olga <[email protected]>
AuthorDate: Tue Nov 10 16:05:30 2020 +0100

    [SYSTEMDS-2726] Federated right indexing
---
 .../controlprogram/federated/FederatedRange.java   |  15 ++
 .../controlprogram/federated/FederationMap.java    |   6 +-
 .../runtime/instructions/fed/FEDInstruction.java   |   1 +
 .../instructions/fed/FEDInstructionUtils.java      |  10 ++
 .../instructions/fed/IndexingFEDInstruction.java   | 113 ++++++++++++
 .../fed/MatrixIndexingFEDInstruction.java          | 144 ++++++++++++++++
 .../primitives/FederatedRightIndexTest.java        | 191 +++++++++++++++++++++
 .../federated/FederatedRightIndexFullTest.dml      |  36 ++++
 .../FederatedRightIndexFullTestReference.dml       |  29 ++++
 .../federated/FederatedRightIndexLeftTest.dml      |  36 ++++
 .../FederatedRightIndexLeftTestReference.dml       |  29 ++++
 .../federated/FederatedRightIndexRightTest.dml     |  36 ++++
 .../FederatedRightIndexRightTestReference.dml      |  29 ++++
 13 files changed, 673 insertions(+), 2 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index 4289cfe..3bd5734 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -102,6 +102,21 @@ public class FederatedRange implements 
Comparable<FederatedRange> {
                return Arrays.toString(_beginDims) + " - " + 
Arrays.toString(_endDims);
        }
 
+       /**
+        * Two ranges are equal iff they have exactly the same begin and end dimensions
+        * and the exact same runtime class. Note that the dimension arrays are mutable
+        * (e.g. via shift), so a range must not be modified while used as a hash key.
+        */
+       @Override
+       public boolean equals(Object other) {
+               if(other == this)
+                       return true;
+               if(other == null || other.getClass() != getClass())
+                       return false;
+               final FederatedRange that = (FederatedRange) other;
+               boolean sameBegin = Arrays.equals(_beginDims, that._beginDims);
+               return sameBegin && Arrays.equals(_endDims, that._endDims);
+       }
+
+       /** Hash code consistent with {@link #equals(Object)}, combining both dimension arrays. */
+       @Override
+       public int hashCode() {
+               return 31 * Arrays.hashCode(_beginDims) + Arrays.hashCode(_endDims);
+       }
+
        public FederatedRange shift(long rshift, long cshift) {
                //row shift
                _beginDims[0] += rshift;
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 04251fc..b647476 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -224,8 +224,10 @@ public class FederationMap
+       /**
+        * Returns a copy of this map created with the given id and this map's _type:
+        * each retained range is copied into a new FederatedRange, and each
+        * FederatedData is copied via copyWithNewID(id).
+        */
        public FederationMap copyWithNewID(long id) {
                Map<FederatedRange, FederatedData> map = new TreeMap<>();
                //TODO handling of file path, but no danger as never written
-               for( Entry<FederatedRange, FederatedData> e : 
_fedMap.entrySet() )
-                       map.put(new FederatedRange(e.getKey()), 
e.getValue().copyWithNewID(id));
+               for( Entry<FederatedRange, FederatedData> e : 
_fedMap.entrySet() ) {
+                       // Ranges whose getSize() is 0 are skipped, so the copy may hold fewer
+                       // entries than this map. NOTE(review): presumably meant to drop partitions
+                       // emptied by federated right indexing -- confirm this is intended for all callers.
+                       if(e.getKey().getSize() != 0)
+                               map.put(new FederatedRange(e.getKey()), 
e.getValue().copyWithNewID(id));
+               }
                return new FederationMap(id, map, _type);
        }
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 9301765..8094c96 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -37,6 +37,7 @@ public abstract class FEDInstruction extends Instruction {
                Tsmm,
                MMChain,
                Reorg,
+               MatrixIndexing
        }
        
        protected final FEDType _fedType;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 795db11..2edc5f2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -32,6 +32,7 @@ import 
org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction;
 import 
org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
 import 
org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
@@ -127,6 +128,15 @@ public class FEDInstructionUtils {
                        if( mo.isFederated() )
                                fedinst = 
ReorgFEDInstruction.parseInstruction(rinst.getInstructionString());
                }
+               else if(inst instanceof MatrixIndexingCPInstruction && 
inst.getOpcode().equalsIgnoreCase("rightIndex")) {
+                       // matrix indexing
+                       MatrixIndexingCPInstruction minst = 
(MatrixIndexingCPInstruction) inst;
+                       if(minst.input1.isMatrix()) {
+                               CacheableData<?> fo = 
ec.getCacheableData(minst.input1);
+                               if(fo.isFederated())
+                                       fedinst = 
MatrixIndexingFEDInstruction.parseInstruction(minst.getInstructionString());
+                       }
+               }
                else if(inst instanceof VariableCPInstruction ){
                        VariableCPInstruction ins = (VariableCPInstruction) 
inst;
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
new file mode 100644
index 0000000..15fe1ab
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.LeftIndex;
+import org.apache.sysds.lops.RightIndex;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.util.IndexRange;
+
+public abstract class IndexingFEDInstruction extends UnaryFEDInstruction {
+       /** Scalar operands of the index expression {@code [rowLower:rowUpper, colLower:colUpper]}; 1-based and inclusive. */
+       protected final CPOperand rowLower, rowUpper, colLower, colUpper;
+
+       /** Creates a single-input (right-indexing) instruction {@code out = in[rl:ru, cl:cu]}. */
+       protected IndexingFEDInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu,
+               CPOperand out, String opcode, String istr) {
+               super(FEDInstruction.FEDType.MatrixIndexing, null, in, out, opcode, istr);
+               rowLower = rl;
+               rowUpper = ru;
+               colLower = cl;
+               colUpper = cu;
+       }
+
+       /**
+        * Creates a two-input instruction (left-hand and right-hand input), intended for left
+        * indexing. {@link #parseInstruction(String)} does not create such instructions yet.
+        */
+       protected IndexingFEDInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, CPOperand cl,
+               CPOperand cu, CPOperand out, String opcode, String istr) {
+               super(FEDInstruction.FEDType.MatrixIndexing, null, lhsInput, rhsInput, out, opcode, istr);
+               rowLower = rl;
+               rowUpper = ru;
+               colLower = cl;
+               colUpper = cu;
+       }
+
+       /**
+        * Evaluates the four index operands and converts them from 1-based to 0-based
+        * indexes; the bounds remain inclusive.
+        *
+        * @param ec execution context holding the scalar index values
+        * @return the 0-based index range
+        */
+       protected IndexRange getIndexRange(ExecutionContext ec) {
+               // keep the values as long: narrowing to int would silently overflow
+               // for indexes beyond Integer.MAX_VALUE
+               return new IndexRange( //rl, ru, cl, cu
+                       ec.getScalarInput(rowLower).getLongValue() - 1,
+                       ec.getScalarInput(rowUpper).getLongValue() - 1,
+                       ec.getScalarInput(colLower).getLongValue() - 1,
+                       ec.getScalarInput(colUpper).getLongValue() - 1);
+       }
+
+       /**
+        * Parses a federated indexing instruction. Currently only right indexing
+        * ({@link RightIndex#OPCODE}) on matrices is supported; left indexing
+        * ({@link LeftIndex#OPCODE}) and frame/list inputs are rejected.
+        *
+        * @param str instruction string
+        * @return the parsed instruction
+        * @throws DMLRuntimeException for unsupported opcodes, operand counts or data types
+        */
+       public static IndexingFEDInstruction parseInstruction(String str) {
+               String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+               String opcode = parts[0];
+
+               if(!opcode.equalsIgnoreCase(RightIndex.OPCODE))
+                       throw new DMLRuntimeException("Unknown opcode while parsing a MatrixIndexingFEDInstruction: " + str);
+               if(parts.length != 7)
+                       throw new DMLRuntimeException("Invalid number of operands in instruction: " + str);
+
+               CPOperand in = new CPOperand(parts[1]);
+               CPOperand rl = new CPOperand(parts[2]);
+               CPOperand ru = new CPOperand(parts[3]);
+               CPOperand cl = new CPOperand(parts[4]);
+               CPOperand cu = new CPOperand(parts[5]);
+               CPOperand out = new CPOperand(parts[6]);
+
+               if(in.getDataType() != Types.DataType.MATRIX)
+                       throw new DMLRuntimeException("Federated indexing is only supported on matrices: " + str);
+               return new MatrixIndexingFEDInstruction(in, rl, ru, cl, cu, out, opcode, str);
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
new file mode 100644
index 0000000..ea2e905
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.IndexRange;
+
+public final class MatrixIndexingFEDInstruction extends IndexingFEDInstruction {
+       private static final Log LOG = LogFactory.getLog(MatrixIndexingFEDInstruction.class.getName());
+
+       public MatrixIndexingFEDInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu,
+               CPOperand out, String opcode, String istr) {
+               super(in, rl, ru, cl, cu, out, opcode, istr);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               rightIndexing(ec);
+       }
+
+       /**
+        * Right indexing {@code out = in[rl:ru, cl:cu]} on a federated matrix. Each worker
+        * slices its local partition (partitions outside the index range produce an empty
+        * block), and the output is a new federated matrix whose ranges are re-mapped to
+        * the coordinate space of the slice. The input's federation map is left unchanged.
+        */
+       private void rightIndexing(ExecutionContext ec) {
+               MatrixObject in = ec.getMatrixObject(input1);
+               IndexRange ixrange = getIndexRange(ec);
+
+               // The loop below rewrites the FederatedRange objects returned by getFederatedRanges()
+               // in place (that is what shapes the output mapping). Work on a copy so that the
+               // input's own federation map stays intact and the input remains usable afterwards.
+               FederationMap inMapping = in.getFedMapping();
+               FederationMap fedMapping = inMapping.copyWithNewID(inMapping.getID());
+               FederatedRange[] ranges = fedMapping.getFederatedRanges();
+
+               // per re-mapped range: the 0-based, inclusive slice of the worker's local block
+               // (keys are matched by value via FederatedRange.equals/hashCode)
+               Map<FederatedRange, IndexRange> ixs = new HashMap<>();
+               // begin of the next output range; advances along rows or columns
+               FederatedRange nextDim = new FederatedRange(new long[] {0, 0}, new long[] {0, 0});
+
+               for(int i = 0; i < ranges.length; i++) {
+                       FederatedRange range = ranges[i];
+                       long rs = range.getBeginDims()[0], re = range.getEndDims()[0];
+                       long cs = range.getBeginDims()[1], ce = range.getEndDims()[1];
+
+                       // ROW if the next range starts at this range's end row, otherwise COL
+                       FederationMap.FType fedType = (i + 1 < ranges.length && re == ranges[i + 1].getBeginDims()[0]) ?
+                               FederationMap.FType.ROW : FederationMap.FType.COL;
+
+                       // requested slice clamped to this partition, in local coordinates
+                       long rsn = (ixrange.rowStart >= rs && ixrange.rowStart < re) ? (ixrange.rowStart - rs) : 0;
+                       long ren = (ixrange.rowEnd >= rs && ixrange.rowEnd < re) ? (ixrange.rowEnd - rs) : (re - rs - 1);
+                       long csn = (ixrange.colStart >= cs && ixrange.colStart < ce) ? (ixrange.colStart - cs) : 0;
+                       long cen = (ixrange.colEnd >= cs && ixrange.colEnd < ce) ? (ixrange.colEnd - cs) : (ce - cs - 1);
+
+                       // re-map the range into output coordinates, starting at nextDim
+                       range.setBeginDim(0, i != 0 ? nextDim.getBeginDims()[0] : 0);
+                       range.setBeginDim(1, i != 0 ? nextDim.getBeginDims()[1] : 0);
+                       boolean overlaps = (ixrange.colStart < ce) && (ixrange.colEnd >= cs)
+                               && (ixrange.rowStart < re) && (ixrange.rowEnd >= rs);
+                       if(overlaps) {
+                               range.setEndDim(0, ren - rsn + 1 + nextDim.getBeginDims()[0]);
+                               range.setEndDim(1, cen - csn + 1 + nextDim.getBeginDims()[1]);
+                               ixs.put(range, new IndexRange(rsn, ren, csn, cen));
+                       }
+                       else {
+                               // no overlap: collapse to an empty range (copyWithNewID skips size-0 ranges)
+                               range.setEndDim(0, i != 0 ? nextDim.getBeginDims()[0] : 0);
+                               range.setEndDim(1, i != 0 ? nextDim.getBeginDims()[1] : 0);
+                       }
+
+                       if(fedType == FederationMap.FType.ROW) {
+                               nextDim.setBeginDim(0, range.getEndDims()[0]);
+                               nextDim.setBeginDim(1, range.getBeginDims()[1]);
+                       }
+                       else if(fedType == FederationMap.FType.COL) {
+                               nextDim.setBeginDim(1, range.getEndDims()[1]);
+                               nextDim.setBeginDim(0, range.getBeginDims()[0]);
+                       }
+               }
+
+               long varID = FederationUtils.getNextFedDataID();
+               FederationMap slicedMapping = fedMapping.mapParallel(varID, (range, data) -> {
+                       try {
+                               // ranges without a local slice get a -1 marker and produce an empty block
+                               IndexRange localIx = ixs.getOrDefault(range, new IndexRange(-1, -1, -1, -1));
+                               FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(
+                                       FederatedRequest.RequestType.EXEC_UDF, -1, new SliceMatrix(data.getVarID(), varID, localIx))).get();
+                               if(!response.isSuccessful())
+                                       response.throwExceptionFromResponse();
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException(e);
+                       }
+                       return null;
+               });
+
+               // output dimensions = extent of the re-mapped ranges
+               MatrixObject sliced = ec.getMatrixObject(output);
+               sliced.getDataCharacteristics().set(fedMapping.getMaxIndexInRange(0), fedMapping.getMaxIndexInRange(1),
+                       (int) in.getBlocksize());
+               sliced.setFedMapping(slicedMapping);
+       }
+
+       /**
+        * Worker-side UDF: slices the local input block by the given 0-based, inclusive
+        * index range and binds the result as a new variable named after the output ID.
+        */
+       private static class SliceMatrix extends FederatedUDF {
+               private static final long serialVersionUID = 5956832933333848772L;
+               private final long _outputID;
+               private final IndexRange _ixrange;
+
+               private SliceMatrix(long input, long outputID, IndexRange ixrange) {
+                       super(new long[] {input});
+                       _outputID = outputID;
+                       _ixrange = ixrange;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+                       // rowStart == -1 marks a partition outside the requested slice -> empty block
+                       MatrixBlock res = (_ixrange.rowStart != -1) ? mb.slice(_ixrange, new MatrixBlock()) : new MatrixBlock();
+                       MatrixObject mout = ExecutionContext.createMatrixObject(res);
+                       ec.setVariable(String.valueOf(_outputID), mout);
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
new file mode 100644
index 0000000..a16e4ed
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedRightIndexTest extends AutomatedTestBase {
+       private final static String TEST_NAME1 = "FederatedRightIndexRightTest";
+       private final static String TEST_NAME2 = "FederatedRightIndexLeftTest";
+       private final static String TEST_NAME3 = "FederatedRightIndexFullTest";
+
+       private final static String TEST_DIR = "functions/federated/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRightIndexTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       /** First index of the slice, passed to the scripts as $from (DML indexing: 1-based, inclusive). */
+       @Parameterized.Parameter(2)
+       public int from;
+
+       /** Last index of the slice, passed to the scripts as $to (DML indexing: 1-based, inclusive). */
+       @Parameterized.Parameter(3)
+       public int to;
+
+       /** If true, the four federated partitions split the rows, otherwise the columns. */
+       @Parameterized.Parameter(4)
+       public boolean rowPartitioned;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // {rows, cols, from, to, rowPartitioned}
+               return Arrays.asList(new Object[][] {
+                       {20, 10, 6, 8, true}, {20, 10, 2, 10, true},
+                       {20, 12, 2, 10, false}, {20, 12, 1, 4, false}
+               });
+       }
+
+       /** Selects which test script variant (RIGHT, LEFT or FULL) is executed. */
+       private enum IndexType {
+               RIGHT, LEFT, FULL
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
+               addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"}));
+               addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"}));
+       }
+
+       @Test
+       public void testRightIndexRightDenseMatrixCP() {
+               runRightIndexTest(IndexType.RIGHT, ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void testRightIndexLeftDenseMatrixCP() {
+               runRightIndexTest(IndexType.LEFT, ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void testRightIndexFullDenseMatrixCP() {
+               runRightIndexTest(IndexType.FULL, ExecMode.SINGLE_NODE);
+       }
+
+       /**
+        * Runs the reference script on local matrices and the federated script on four
+        * local federated workers, then compares both results and checks that the
+        * federated right-indexing instruction was used.
+        */
+       private void runRightIndexTest(IndexType type, ExecMode execMode) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               ExecMode platformOld = rtplatform;
+
+               if(rtplatform == ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               String TEST_NAME = null;
+               switch(type) {
+                       case RIGHT:
+                               TEST_NAME = TEST_NAME1; break;
+                       case LEFT:
+                               TEST_NAME = TEST_NAME2; break;
+                       case FULL:
+                               TEST_NAME = TEST_NAME3; break;
+               }
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices: each of the four partitions holds a quarter
+               // of the rows (row-partitioned) or of the columns (column-partitioned)
+               int r = rowPartitioned ? rows / 4 : rows;
+               int c = rowPartitioned ? cols : cols / 4;
+
+               double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               // empty script name because we don't execute any script, just start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1);
+               Thread t2 = startLocalFedWorkerThread(port2);
+               Thread t3 = startLocalFedWorkerThread(port3);
+               Thread t4 = startLocalFedWorkerThread(port4);
+
+               try {
+                       rtplatform = execMode;
+                       if(rtplatform == ExecMode.SPARK)
+                               DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+                       TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+                       loadTestConfiguration(config);
+
+                       // run reference dml script with normal (local) matrices
+                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                       programArgs = new String[] {"-args", input("X1"), input("X2"), input("X3"), input("X4"),
+                               String.valueOf(from), String.valueOf(to), Boolean.toString(rowPartitioned).toUpperCase(),
+                               expected("S")};
+                       runTest(true, false, null, -1);
+
+                       // run actual dml script with federated matrix
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-stats", "100", "-nvargs",
+                               "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                               "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                               "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+                               "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+                               "from=" + from, "to=" + to, "rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
+                               "out_S=" + output("S")};
+                       runTest(true, false, null, -1);
+
+                       // compare via files
+                       compareResults(1e-9);
+
+                       Assert.assertTrue(heavyHittersContainsString("fed_rightIndex"));
+
+                       // check that federated input files are still existing
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+               }
+               finally {
+                       // always stop the workers and restore global state, even if an assertion fails
+                       TestUtils.shutdownThreads(t1, t2, t3, t4);
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+}
diff --git 
a/src/test/scripts/functions/federated/FederatedRightIndexFullTest.dml 
b/src/test/scripts/functions/federated/FederatedRightIndexFullTest.dml
new file mode 100644
index 0000000..46bc064
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRightIndexFullTest.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+from = $from;  # first row/column index of the slice (1-based, inclusive)
+to = $to;      # last row/column index of the slice (1-based, inclusive)
+# Assemble A from four federated workers; ranges holds one begin/end coordinate pair per worker.
+if ($rP) {  # row-partitioned: each worker holds a quarter of the rows
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {  # column-partitioned: each worker holds a quarter of the columns
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+# Slice both dimensions (the Java driver expects a fed_rightIndex instruction for this).
+s = A[from:to, from:to];
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedRightIndexFullTestReference.dml b/src/test/scripts/functions/federated/FederatedRightIndexFullTestReference.dml
new file mode 100644
index 0000000..8261f5e
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRightIndexFullTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+from = $5;  # first row/column index of the slice (1-based, inclusive)
+to = $6;    # last row/column index of the slice (1-based, inclusive)
+# Local baseline: $1..$4 are the four partition files; $7 is TRUE to stack them row-wise (rbind), FALSE to join them column-wise (cbind).
+if($7) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4)); }
+# Slice both dimensions; written to $8 as the expected result for the federated run.
+s = A[from:to, from:to];
+write(s, $8);
diff --git a/src/test/scripts/functions/federated/FederatedRightIndexLeftTest.dml b/src/test/scripts/functions/federated/FederatedRightIndexLeftTest.dml
new file mode 100644
index 0000000..3f690b1
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRightIndexLeftTest.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+from = $from;  # first row index of the slice (1-based, inclusive)
+to = $to;      # last row index of the slice (1-based, inclusive)
+# Assemble A from four federated workers; ranges holds one begin/end coordinate pair per worker.
+if ($rP) {  # row-partitioned: each worker holds a quarter of the rows
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {  # column-partitioned: each worker holds a quarter of the columns
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+# Row slice over all columns (the Java driver expects a fed_rightIndex instruction for this).
+s = A[from:to,];
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedRightIndexLeftTestReference.dml b/src/test/scripts/functions/federated/FederatedRightIndexLeftTestReference.dml
new file mode 100644
index 0000000..ef095f3
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRightIndexLeftTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+from = $5;  # first row index of the slice (1-based, inclusive)
+to = $6;    # last row index of the slice (1-based, inclusive)
+# Local baseline: $1..$4 are the four partition files; $7 is TRUE to stack them row-wise (rbind), FALSE to join them column-wise (cbind).
+if($7) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4)); }
+# Row slice over all columns; written to $8 as the expected result for the federated run.
+s = A[from:to,];
+write(s, $8);
diff --git a/src/test/scripts/functions/federated/FederatedRightIndexRightTest.dml b/src/test/scripts/functions/federated/FederatedRightIndexRightTest.dml
new file mode 100644
index 0000000..ee80b46
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRightIndexRightTest.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+from = $from;  # first column index of the slice (1-based, inclusive)
+to = $to;      # last column index of the slice (1-based, inclusive)
+# Assemble A from four federated workers; ranges holds one begin/end coordinate pair per worker.
+if ($rP) {  # row-partitioned: each worker holds a quarter of the rows
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {  # column-partitioned: each worker holds a quarter of the columns
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+# Column slice over all rows (the Java driver expects a fed_rightIndex instruction for this).
+s = A[, from:to];
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedRightIndexRightTestReference.dml b/src/test/scripts/functions/federated/FederatedRightIndexRightTestReference.dml
new file mode 100644
index 0000000..af83ca0
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRightIndexRightTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+from = $5;  # first column index of the slice (1-based, inclusive)
+to = $6;    # last column index of the slice (1-based, inclusive)
+# Local baseline: $1..$4 are the four partition files; $7 is TRUE to stack them row-wise (rbind), FALSE to join them column-wise (cbind).
+if($7) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4)); }
+# Column slice over all rows; written to $8 as the expected result for the federated run.
+s = A[, from:to];
+write(s, $8);

Reply via email to