This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit e2bd5bf4fe90a66f316898208c851f796117a90c Author: Olga <[email protected]> AuthorDate: Tue Nov 10 16:05:30 2020 +0100 [SYSTEMDS-2726] Federated right indexing --- .../controlprogram/federated/FederatedRange.java | 15 ++ .../controlprogram/federated/FederationMap.java | 6 +- .../runtime/instructions/fed/FEDInstruction.java | 1 + .../instructions/fed/FEDInstructionUtils.java | 10 ++ .../instructions/fed/IndexingFEDInstruction.java | 113 ++++++++++++ .../fed/MatrixIndexingFEDInstruction.java | 144 ++++++++++++++++ .../primitives/FederatedRightIndexTest.java | 191 +++++++++++++++++++++ .../federated/FederatedRightIndexFullTest.dml | 36 ++++ .../FederatedRightIndexFullTestReference.dml | 29 ++++ .../federated/FederatedRightIndexLeftTest.dml | 36 ++++ .../FederatedRightIndexLeftTestReference.dml | 29 ++++ .../federated/FederatedRightIndexRightTest.dml | 36 ++++ .../FederatedRightIndexRightTestReference.dml | 29 ++++ 13 files changed, 673 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java index 4289cfe..3bd5734 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java @@ -102,6 +102,21 @@ public class FederatedRange implements Comparable<FederatedRange> { return Arrays.toString(_beginDims) + " - " + Arrays.toString(_endDims); } + @Override public boolean equals(Object o) { + if(this == o) + return true; + if(o == null || getClass() != o.getClass()) + return false; + FederatedRange range = (FederatedRange) o; + return Arrays.equals(_beginDims, range._beginDims) && Arrays.equals(_endDims, range._endDims); + } + + @Override public int hashCode() { + int result = Arrays.hashCode(_beginDims); + result = 31 * result + Arrays.hashCode(_endDims); + return result; + } + public FederatedRange shift(long rshift, long cshift) { //row shift _beginDims[0] += rshift; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index 04251fc..b647476 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -224,8 +224,10 @@ public class FederationMap public FederationMap copyWithNewID(long id) { Map<FederatedRange, FederatedData> map = new TreeMap<>(); //TODO handling of file path, but no danger as never written - for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() ) - map.put(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id)); + for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() ) { + if(e.getKey().getSize() != 0) + map.put(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id)); + } return new FederationMap(id, map, _type); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java index 9301765..8094c96 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java @@ -37,6 +37,7 @@ public abstract class FEDInstruction extends Instruction { Tsmm, MMChain, Reorg, + MatrixIndexing } protected final FEDType _fedType; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java index 795db11..2edc5f2 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java @@ -32,6 +32,7 @@ import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction; import org.apache.sysds.runtime.instructions.cp.Data; import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction; import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction; +import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction; import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction; import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction; import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction; @@ -127,6 +128,15 @@ public class FEDInstructionUtils { if( mo.isFederated() ) fedinst = ReorgFEDInstruction.parseInstruction(rinst.getInstructionString()); } + else if(inst instanceof MatrixIndexingCPInstruction && inst.getOpcode().equalsIgnoreCase("rightIndex")) { + // matrix indexing + MatrixIndexingCPInstruction minst = (MatrixIndexingCPInstruction) inst; + if(minst.input1.isMatrix()) { + CacheableData<?> fo = ec.getCacheableData(minst.input1); + if(fo.isFederated()) + fedinst = MatrixIndexingFEDInstruction.parseInstruction(minst.getInstructionString()); + } + } else if(inst instanceof VariableCPInstruction ){ VariableCPInstruction ins = (VariableCPInstruction) inst; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java new file mode 100644 index 0000000..15fe1ab --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.fed; + +import org.apache.sysds.common.Types; +import org.apache.sysds.lops.LeftIndex; +import org.apache.sysds.lops.RightIndex; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.util.IndexRange; + +public abstract class IndexingFEDInstruction extends UnaryFEDInstruction { + protected final CPOperand rowLower, rowUpper, colLower, colUpper; + + protected IndexingFEDInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu, + CPOperand out, String opcode, String istr) { + super(FEDInstruction.FEDType.MatrixIndexing, null, in, out, opcode, istr); + rowLower = rl; + rowUpper = ru; + colLower = cl; + colUpper = cu; + } + + protected IndexingFEDInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, CPOperand cl, + CPOperand cu, CPOperand out, String opcode, String istr) { + super(FEDInstruction.FEDType.MatrixIndexing, null, lhsInput, rhsInput, out, opcode, istr); + rowLower = rl; + rowUpper = ru; + colLower = cl; + colUpper = cu; + } + + protected IndexRange getIndexRange(ExecutionContext ec) { + return new IndexRange( //rl, ru, cl, ru + (int) (ec.getScalarInput(rowLower).getLongValue() - 1), + (int) (ec.getScalarInput(rowUpper).getLongValue() - 1), + (int) (ec.getScalarInput(colLower).getLongValue() - 1), + (int) (ec.getScalarInput(colUpper).getLongValue() - 1)); + } + + public static IndexingFEDInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + + if(opcode.equalsIgnoreCase(RightIndex.OPCODE)) { + if(parts.length == 7) { + CPOperand in, rl, ru, cl, cu, out; + in = new CPOperand(parts[1]); + rl = new CPOperand(parts[2]); + ru = new CPOperand(parts[3]); + cl = new CPOperand(parts[4]); + cu = new CPOperand(parts[5]); + out = new CPOperand(parts[6]); + if(in.getDataType() == Types.DataType.MATRIX) + return new MatrixIndexingFEDInstruction(in, rl, ru, cl, cu, out, opcode, str); + // else if( in.getDataType() == Types.DataType.FRAME ) + // return new FrameIndexingCPInstruction(in, rl, ru, cl, cu, out, opcode, str); + // else if( in.getDataType() == Types.DataType.LIST ) + // return new ListIndexingCPInstruction(in, rl, ru, cl, cu, out, opcode, str); + else + throw new DMLRuntimeException("Can index only on matrices, frames, and lists."); + } + else { + throw new DMLRuntimeException("Invalid number of operands in instruction: " + str); + } + } + // else if ( opcode.equalsIgnoreCase(LeftIndex.OPCODE)) { + // if ( parts.length == 8 ) { + // CPOperand lhsInput, rhsInput, rl, ru, cl, cu, out; + // lhsInput = new CPOperand(parts[1]); + // rhsInput = new CPOperand(parts[2]); + // rl = new CPOperand(parts[3]); + // ru = new CPOperand(parts[4]); + // cl = new CPOperand(parts[5]); + // cu = new CPOperand(parts[6]); + // out = new CPOperand(parts[7]); + // if( lhsInput.getDataType()== Types.DataType.MATRIX ) + // return new MatrixIndexingFEDInstruction(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, str); + // else if (lhsInput.getDataType() == Types.DataType.FRAME) + // return new FrameIndexingFEDInstruction(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, str); + // else if( lhsInput.getDataType() == Types.DataType.LIST ) + // return new ListIndexingFEDInstruction(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, str); + // else + // throw new DMLRuntimeException("Can index only on matrices, frames, and lists."); + // } + // else { + // throw new DMLRuntimeException("Invalid number of operands in instruction: " + str); + // } + // } + else { + throw new DMLRuntimeException("Unknown opcode while parsing a MatrixIndexingFEDInstruction: " + str); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java new file mode 100644 index 0000000..ea2e905 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.runtime.instructions.fed; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; +import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; +import org.apache.sysds.runtime.controlprogram.federated.FederationMap; +import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.IndexRange; + +public final class MatrixIndexingFEDInstruction extends IndexingFEDInstruction { + private static final Log LOG = LogFactory.getLog(MatrixIndexingFEDInstruction.class.getName()); + + public MatrixIndexingFEDInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu, + CPOperand out, String opcode, String istr) { + super(in, rl, ru, cl, cu, out, opcode, istr); + } + + @Override + public void processInstruction(ExecutionContext ec) { + rightIndexing(ec); + } + + + private void rightIndexing (ExecutionContext ec) { + MatrixObject in = ec.getMatrixObject(input1); + FederationMap fedMapping = in.getFedMapping(); + IndexRange ixrange = getIndexRange(ec); + FederationMap.FType fedType; + Map <FederatedRange, IndexRange> ixs = new HashMap<>(); + + FederatedRange nextDim = new FederatedRange(new long[]{0, 0}, new long[]{0, 0}); + + for (int i = 0; i < fedMapping.getFederatedRanges().length; i++) { + long rs = fedMapping.getFederatedRanges()[i].getBeginDims()[0], re = fedMapping.getFederatedRanges()[i] + .getEndDims()[0], cs = fedMapping.getFederatedRanges()[i].getBeginDims()[1], ce = fedMapping.getFederatedRanges()[i].getEndDims()[1]; + + // for OTHER + fedType = ((i + 1) < fedMapping.getFederatedRanges().length && + fedMapping.getFederatedRanges()[i].getEndDims()[0] == fedMapping.getFederatedRanges()[i+1].getBeginDims()[0]) ? + FederationMap.FType.ROW : FederationMap.FType.COL; + + long rsn = 0, ren = 0, csn = 0, cen = 0; + + rsn = (ixrange.rowStart >= rs && ixrange.rowStart < re) ? (ixrange.rowStart - rs) : 0; + ren = (ixrange.rowEnd >= rs && ixrange.rowEnd < re) ? (ixrange.rowEnd - rs) : (re - rs - 1); + csn = (ixrange.colStart >= cs && ixrange.colStart < ce) ? (ixrange.colStart - cs) : 0; + cen = (ixrange.colEnd >= cs && ixrange.colEnd < ce) ? (ixrange.colEnd - cs) : (ce - cs - 1); + + fedMapping.getFederatedRanges()[i].setBeginDim(0, i != 0 ? nextDim.getBeginDims()[0] : 0); + fedMapping.getFederatedRanges()[i].setBeginDim(1, i != 0 ? nextDim.getBeginDims()[1] : 0); + if((ixrange.colStart < ce) && (ixrange.colEnd >= cs) && (ixrange.rowStart < re) && (ixrange.rowEnd >= rs)) { + fedMapping.getFederatedRanges()[i].setEndDim(0, ren - rsn + 1 + nextDim.getBeginDims()[0]); + fedMapping.getFederatedRanges()[i].setEndDim(1, cen - csn + 1 + nextDim.getBeginDims()[1]); + + ixs.put(fedMapping.getFederatedRanges()[i], new IndexRange(rsn, ren, csn, cen)); + } else { + fedMapping.getFederatedRanges()[i].setEndDim(0, i != 0 ? nextDim.getBeginDims()[0] : 0); + fedMapping.getFederatedRanges()[i].setEndDim(1, i != 0 ? nextDim.getBeginDims()[1] : 0); + } + + if(fedType == FederationMap.FType.ROW) { + nextDim.setBeginDim(0,fedMapping.getFederatedRanges()[i].getEndDims()[0]); + nextDim.setBeginDim(1, fedMapping.getFederatedRanges()[i].getBeginDims()[1]); + } else if(fedType == FederationMap.FType.COL) { + nextDim.setBeginDim(1,fedMapping.getFederatedRanges()[i].getEndDims()[1]); + nextDim.setBeginDim(0, fedMapping.getFederatedRanges()[i].getBeginDims()[0]); + } + } + + long varID = FederationUtils.getNextFedDataID(); + FederationMap slicedMapping = fedMapping.mapParallel(varID, (range, data) -> { + try { + FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, + -1, new SliceMatrix(data.getVarID(), varID, ixs.getOrDefault(range, new IndexRange(-1, -1, -1, -1))))).get(); + if(!response.isSuccessful()) + response.throwExceptionFromResponse(); + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + return null; + }); + + MatrixObject sliced = ec.getMatrixObject(output); + sliced.getDataCharacteristics().set(fedMapping.getMaxIndexInRange(0), fedMapping.getMaxIndexInRange(1), (int) in.getBlocksize()); + sliced.setFedMapping(slicedMapping); + } + + private static class SliceMatrix extends FederatedUDF { + + private static final long serialVersionUID = 5956832933333848772L; + private final long _outputID; + private final IndexRange _ixrange; + + private SliceMatrix(long input, long outputID, IndexRange ixrange) { + super(new long[] {input}); + _outputID = outputID; + _ixrange = ixrange; + } + + + @Override public FederatedResponse execute(ExecutionContext ec, Data... data) { + MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease(); + MatrixBlock res; + if(_ixrange.rowStart != -1) + res = mb.slice(_ixrange, new MatrixBlock()); + else res = new MatrixBlock(); + MatrixObject mout = ExecutionContext.createMatrixObject(res); + ec.setVariable(String.valueOf(_outputID), mout); + + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java new file mode 100644 index 0000000..a16e4ed --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.primitives; + +import java.util.Arrays; +import java.util.Collection; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) [email protected] +public class FederatedRightIndexTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "FederatedRightIndexRightTest"; + private final static String TEST_NAME2 = "FederatedRightIndexLeftTest"; + private final static String TEST_NAME3 = "FederatedRightIndexFullTest"; + + private final static String TEST_DIR = "functions/federated/"; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRightIndexTest.class.getSimpleName() + "/"; + + private final static int blocksize = 1024; + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + + @Parameterized.Parameter(2) + public int from; + + @Parameterized.Parameter(3) + public int to; + + @Parameterized.Parameter(4) + public boolean rowPartitioned; + + @Parameterized.Parameters + public static Collection<Object[]> data() { + return Arrays.asList(new Object[][] { + {20, 10, 6, 8, true}, {20, 10, 2, 10, true}, + {20, 12, 2, 10, false}, {20, 12, 1, 4, false} + }); + } + + private enum IndexType { + RIGHT, LEFT, FULL + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"})); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"})); + } + + @Test + public void testRightIndexRightDenseMatrixCP() { + runAggregateOperationTest(IndexType.RIGHT, ExecMode.SINGLE_NODE); + } + + @Test + public void testRightIndexLeftDenseMatrixCP() { + runAggregateOperationTest(IndexType.LEFT, ExecMode.SINGLE_NODE); + } + + @Test + public void testRightIndexFullDenseMatrixCP() { + runAggregateOperationTest(IndexType.FULL, ExecMode.SINGLE_NODE); + } + + private void runAggregateOperationTest(IndexType type, ExecMode execMode) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + ExecMode platformOld = rtplatform; + + if(rtplatform == ExecMode.SPARK) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + String TEST_NAME = null; + switch(type) { + case RIGHT: + TEST_NAME = TEST_NAME1; break; + case LEFT: + TEST_NAME = TEST_NAME2; break; + case FULL: + TEST_NAME = TEST_NAME3; break; + } + + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + + // write input matrices + int r = rows; + int c = cols / 4; + if(rowPartitioned) { + r = rows / 4; + c = cols; + } + + double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3); + double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7); + double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8); + double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9); + + MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c); + writeInputMatrixWithMTD("X1", X1, false, mc); + writeInputMatrixWithMTD("X2", X2, false, mc); + writeInputMatrixWithMTD("X3", X3, false, mc); + writeInputMatrixWithMTD("X4", X4, false, mc); + + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + int port3 = getRandomAvailablePort(); + int port4 = getRandomAvailablePort(); + Thread t1 = startLocalFedWorkerThread(port1); + Thread t2 = startLocalFedWorkerThread(port2); + Thread t3 = startLocalFedWorkerThread(port3); + Thread t4 = startLocalFedWorkerThread(port4); + + rtplatform = execMode; + if(rtplatform == ExecMode.SPARK) { + System.out.println(7); + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] { "-args", input("X1"), input("X2"), input("X3"), input("X4"), + String.valueOf(from), String.valueOf(to), Boolean.toString(rowPartitioned).toUpperCase(), expected("S")}; + runTest(true, false, null, -1); + + // Run actual dml script with federated matrix + + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "100", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), + "in_X3=" + TestUtils.federatedAddress(port3, input("X3")), + "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols, + "from=" + from, "to=" + to, "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), + "out_S=" + output("S")}; + + runTest(true, false, null, -1); + + // compare via files + compareResults(1e-9); + + Assert.assertTrue(heavyHittersContainsString("fed_rightIndex")); + + // check that federated input files are still existing + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3"))); + Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4"))); + + TestUtils.shutdownThreads(t1, t2, t3, t4); + + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + + } +} diff --git a/src/test/scripts/functions/federated/FederatedRightIndexFullTest.dml b/src/test/scripts/functions/federated/FederatedRightIndexFullTest.dml new file mode 100644 index 0000000..46bc064 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedRightIndexFullTest.dml @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +from = $from; +to = $to; + +if ($rP) { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); +} else { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2), + list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols))); +} + +s = A[from:to, from:to]; +write(s, $out_S); diff --git a/src/test/scripts/functions/federated/FederatedRightIndexFullTestReference.dml b/src/test/scripts/functions/federated/FederatedRightIndexFullTestReference.dml new file mode 100644 index 0000000..8261f5e --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedRightIndexFullTestReference.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +from = $5; +to = $6; + +if($7) { A = rbind(read($1), read($2), read($3), read($4)); } +else { A = cbind(read($1), read($2), read($3), read($4)); } + +s = A[from:to, from:to]; +write(s, $8); diff --git a/src/test/scripts/functions/federated/FederatedRightIndexLeftTest.dml b/src/test/scripts/functions/federated/FederatedRightIndexLeftTest.dml new file mode 100644 index 0000000..3f690b1 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedRightIndexLeftTest.dml @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +from = $from; +to = $to; + +if ($rP) { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); +} else { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2), + list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols))); +} + +s = A[from:to,]; +write(s, $out_S); diff --git a/src/test/scripts/functions/federated/FederatedRightIndexLeftTestReference.dml b/src/test/scripts/functions/federated/FederatedRightIndexLeftTestReference.dml new file mode 100644 index 0000000..ef095f3 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedRightIndexLeftTestReference.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +from = $5; +to = $6; + +if($7) { A = rbind(read($1), read($2), read($3), read($4)); } +else { A = cbind(read($1), read($2), read($3), read($4)); } + +s = A[from:to,]; +write(s, $8); diff --git a/src/test/scripts/functions/federated/FederatedRightIndexRightTest.dml b/src/test/scripts/functions/federated/FederatedRightIndexRightTest.dml new file mode 100644 index 0000000..ee80b46 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedRightIndexRightTest.dml @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +from = $from; +to = $to; + +if ($rP) { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols), + list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols))); +} else { + A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4), + ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2), + list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols))); +} + +s = A[, from:to]; +write(s, $out_S); diff --git a/src/test/scripts/functions/federated/FederatedRightIndexRightTestReference.dml b/src/test/scripts/functions/federated/FederatedRightIndexRightTestReference.dml new file mode 100644 index 0000000..af83ca0 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedRightIndexRightTestReference.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +from = $5; +to = $6; + +if($7) { A = rbind(read($1), read($2), read($3), read($4)); } +else { A = cbind(read($1), read($2), read($3), read($4)); } + +s = A[, from:to]; +write(s, $8);
