This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 6d19dbaffc3e2ead068474c734ad289141025b5a Author: baunsgaard <[email protected]> AuthorDate: Thu Nov 12 15:01:34 2020 +0100 [SYSTEMDS-2723] Cast to frame Federated --- .../controlprogram/federated/FederatedData.java | 7 ++ .../instructions/fed/FEDInstructionUtils.java | 25 ++-- .../instructions/fed/VariableFEDInstruction.java | 57 ++++++++++ .../primitives/FederetedCastToFrameTest.java | 126 +++++++++++++++++++++ .../primitives/FederatedCastToFrameTest.dml | 26 +++++ .../FederatedCastToFrameTestReference.dml | 25 ++++ 6 files changed, 259 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java index f9702c0..d19d132 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java @@ -74,6 +74,13 @@ public class FederatedData { _allFedSites.add(_address); } + public FederatedData(Types.DataType dataType, InetSocketAddress address, String filepath, long varID) { + _dataType = dataType; + _address = address; + _filepath = filepath; + _varID = varID; + } + public InetSocketAddress getAddress() { return _address; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java index 2edc5f2..68e1cee 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java @@ -36,6 +36,8 @@ import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction; import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction; import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction; import 
org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction; +import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction; +import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOperationCode; import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction; @@ -82,13 +84,15 @@ public class FEDInstructionUtils { if( mo.isFederated() ) fedinst = TsmmFEDInstruction.parseInstruction(linst.getInstructionString()); } - else if (inst instanceof AggregateUnaryCPInstruction) { - AggregateUnaryCPInstruction instruction = (AggregateUnaryCPInstruction) inst; - if( instruction.input1.isMatrix() && ec.containsVariable(instruction.input1) ) { - MatrixObject mo1 = ec.getMatrixObject(instruction.input1); - if (mo1.isFederated() && instruction.getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT){ - LOG.debug("Federated UnaryAggregate"); - fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString()); + else if(inst instanceof UnaryCPInstruction){ + if (inst instanceof AggregateUnaryCPInstruction) { + AggregateUnaryCPInstruction instruction = (AggregateUnaryCPInstruction) inst; + if( instruction.input1.isMatrix() && ec.containsVariable(instruction.input1) ) { + MatrixObject mo1 = ec.getMatrixObject(instruction.input1); + if (mo1.isFederated() && instruction.getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT){ + LOG.debug("Federated UnaryAggregate"); + fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString()); + } } } } @@ -141,12 +145,19 @@ public class FEDInstructionUtils { VariableCPInstruction ins = (VariableCPInstruction) inst; if(ins.getVariableOpcode() == VariableOperationCode.Write + && ins.getInput1().isMatrix() && ins.getInput3().getName().contains("federated")){ fedinst = VariableFEDInstruction.parseInstruction(ins); } + 
else if(ins.getVariableOpcode() == VariableOperationCode.CastAsFrameVariable + && ins.getInput1().isMatrix() + && ec.getCacheableData(ins.getInput1()).isFederated()){ + fedinst = VariableFEDInstruction.parseInstruction(ins); + } } + //set thread id for federated context management if( fedinst != null ) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java index 91efc2c..7d39e9d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java @@ -19,11 +19,25 @@ package org.apache.sysds.runtime.instructions.fed; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.FrameObject; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.FederatedData; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederationMap; +import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOperationCode; import org.apache.sysds.runtime.lineage.LineageItem; @@ -52,6 +66,13 @@ 
public class VariableFEDInstruction extends FEDInstruction implements LineageTra processWriteInstruction(ec); break; + case CastAsMatrixVariable: + processCastAsMatrixVariableInstruction(ec); + break; + case CastAsFrameVariable: + processCastAsFrameVariableInstruction(ec); + break; + default: throw new DMLRuntimeException("Unsupported Opcode for federated Variable Instruction : " + opcode); } @@ -66,6 +87,42 @@ public class VariableFEDInstruction extends FEDInstruction implements LineageTra _in.processInstruction(ec); } + private void processCastAsMatrixVariableInstruction(ExecutionContext ec){ + LOG.error("Not Implemented"); + throw new DMLRuntimeException("Not Implemented Cast as Matrix"); + + } + + private void processCastAsFrameVariableInstruction(ExecutionContext ec){ + + MatrixObject mo1 = ec.getMatrixObject(_in.getInput1()); + + if( !mo1.isFederated() ) + throw new DMLRuntimeException("Federated CastAsFrame: " + + "Federated input expected, but invoked w/ "+mo1.isFederated()); + + //execute the cast to frame at each federated site (output var gets fr1's ID) + FederatedRequest fr1 = FederationUtils.callInstruction(_in.getInstructionString(), _in.getOutput(), + new CPOperand[]{_in.getInput1()}, new long[]{mo1.getFedMapping().getID()}); + mo1.getFedMapping().execute(getTID(), true, fr1); + + //derive output federated mapping; as.frame keeps the input shape (rows x cols, no transpose) + FrameObject out = ec.getFrameObject(_in.getOutput()); + out.getDataCharacteristics().set(mo1.getNumRows(), + mo1.getNumColumns(), (int)mo1.getBlocksize(), mo1.getNnz()); + FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr1.getID()); + Map<FederatedRange, FederatedData> newMap = new HashMap<>(); + for(Map.Entry<FederatedRange, FederatedData> pair : outMap.getFedMapping().entrySet()){ + FederatedData om = pair.getValue(); + FederatedData nf = new FederatedData(Types.DataType.FRAME, om.getAddress(),om.getFilepath(),om.getVarID()); + newMap.put(pair.getKey(), nf); + } + ValueType[] schema = new ValueType[(int)mo1.getDataCharacteristics().getCols()]; + Arrays.fill(schema,
ValueType.FP64); + out.setSchema(schema); + out.setFedMapping(outMap); + } + @Override public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { return _in.getLineageItem(ec); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java new file mode 100644 index 0000000..bbef96e --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysds.test.functions.federated.primitives; + +import java.util.Arrays; +import java.util.Collection; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) [email protected] +public class FederetedCastToFrameTest extends AutomatedTestBase { + private static final Log LOG = LogFactory.getLog(FederetedCastToFrameTest.class.getName()); + + private final static String TEST_DIR = "functions/federated/primitives/"; + private final static String TEST_NAME = "FederatedCastToFrameTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederetedCastToFrameTest.class.getSimpleName() + "/"; + + private final static int blocksize = 1024; + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Parameterized.Parameters + public static Collection<Object[]> data() { + // rows have to be even and > 1 + return Arrays.asList(new Object[][] {{10, 32}}); + } + + @Test + public void federatedMultiplyCP() { + federatedMultiply(Types.ExecMode.SINGLE_NODE); + } + + @Test + @Ignore + public void federatedMultiplySP() { + // TODO Fix me Spark execution error + federatedMultiply(Types.ExecMode.SPARK); + } + + public void federatedMultiply(Types.ExecMode execMode) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld 
= rtplatform; + rtplatform = execMode; + if(rtplatform == Types.ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + + // write input matrices + int halfRows = rows / 2; + // We have two matrices handled by a single federated worker + double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42); + double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340); + + writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + Thread t1 = startLocalFedWorkerThread(port1); + Thread t2 = startLocalFedWorkerThread(port2); + + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2")}; + String out = runTest(null).toString().split("SystemDS Statistics:")[0]; + + // Run actual dml script with federated matrix + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "100", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")), + "X2=" + TestUtils.federatedAddress(port2, input("X2")), "r=" + rows, "c=" + cols}; + String fedOut = runTest(null).toString(); + + LOG.error(fedOut); + fedOut = fedOut.split("SystemDS Statistics:")[0]; + Assert.assertTrue("Equal Printed Output", out.equals(fedOut)); + Assert.assertTrue("Contains federated Cast to frame", heavyHittersContainsString("fed_castdtf")); + TestUtils.shutdownThreads(t1, t2); + + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } +} diff --git 
a/src/test/scripts/functions/federated/primitives/FederatedCastToFrameTest.dml b/src/test/scripts/functions/federated/primitives/FederatedCastToFrameTest.dml new file mode 100644 index 0000000..6efd3f4 --- /dev/null +++ b/src/test/scripts/functions/federated/primitives/FederatedCastToFrameTest.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = federated(addresses=list($X1, $X2), + ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c))) + +Z = as.frame(X) +print(toString(Z[1])) diff --git a/src/test/scripts/functions/federated/primitives/FederatedCastToFrameTestReference.dml b/src/test/scripts/functions/federated/primitives/FederatedCastToFrameTestReference.dml new file mode 100644 index 0000000..919e309 --- /dev/null +++ b/src/test/scripts/functions/federated/primitives/FederatedCastToFrameTestReference.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rbind(read($X1), read($X2)) + +Z = as.frame(X) +print(toString(Z[1]))
