This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit a3e3ea949c6af02914356c430756477b948965ce Author: Matthias Boehm <[email protected]> AuthorDate: Sat Aug 15 14:39:45 2020 +0200 [SYSTEMDS-2620] Federated tsmm operations (e.g., PCA, lmDS, cor) * Federated tsmm: support for federated tsmm left over row-partitioned federated matrices. * Performance: aggAdd (e.g., in ba+*, uack+, and tsmm) via nary instead of binary operations. --- .../controlprogram/federated/FederationUtils.java | 16 ++- .../fed/ComputationFEDInstruction.java | 10 +- .../runtime/instructions/fed/FEDInstruction.java | 3 +- .../instructions/fed/FEDInstructionUtils.java | 6 + .../instructions/fed/TsmmFEDInstruction.java | 82 +++++++++++++ .../test/functions/federated/FederatedPCATest.java | 133 +++++++++++++++++++++ .../functions/federated/FederatedPCATest.dml | 25 ++++ .../federated/FederatedPCATestReference.dml | 24 ++++ 8 files changed, 283 insertions(+), 16 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java index ab0b3aa..f2c8227 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java @@ -29,13 +29,13 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.functionobjects.KahanFunction; -import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.cp.DoubleObject; import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import 
org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; -import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.SimpleOperator; public class FederationUtils { private static final IDSequence _idSeq = new IDSequence(); @@ -58,13 +58,11 @@ public class FederationUtils { public static MatrixBlock aggAdd(Future<FederatedResponse>[] ffr) { try { - BinaryOperator bop = InstructionUtils.parseBinaryOperator("+"); - MatrixBlock ret = (MatrixBlock) (ffr[0].get().getData()[0]); - for (int i=1; i<ffr.length; i++) { - MatrixBlock tmp = (MatrixBlock) (ffr[i].get().getData()[0]); - ret.binaryOperationsInPlace(bop, tmp); - } - return ret; + SimpleOperator op = new SimpleOperator(Plus.getPlusFnObject()); + MatrixBlock[] in = new MatrixBlock[ffr.length]; + for(int i=0; i<ffr.length; i++) + in[i] = (MatrixBlock) ffr[i].get().getData()[0]; + return MatrixBlock.naryOperations(op, in, new ScalarObject[0], new MatrixBlock()); } catch(Exception ex) { throw new DMLRuntimeException(ex); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java index 9d972f4..ccaec24 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java @@ -37,9 +37,8 @@ public abstract class ComputationFEDInstruction extends FEDInstruction implement public final CPOperand output; public final CPOperand input1, input2, input3; - protected ComputationFEDInstruction(FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, - String opcode, - String istr) { + protected ComputationFEDInstruction(FEDType type, Operator op, + CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) { super(type, op, opcode, istr); input1 = in1; input2 = in2; @@ -47,9 +46,8 @@ public 
abstract class ComputationFEDInstruction extends FEDInstruction implement output = out; } - protected ComputationFEDInstruction(FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, - CPOperand out, - String opcode, String istr) { + protected ComputationFEDInstruction(FEDType type, Operator op, + CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) { super(type, op, opcode, istr); input1 = in1; input2 = in2; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java index f2d0791..d6bd388 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java @@ -32,7 +32,8 @@ public abstract class FEDInstruction extends Instruction { Append, Binary, Init, - MultiReturnParameterizedBuiltin + MultiReturnParameterizedBuiltin, + Tsmm, } protected final FEDType _fedType; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java index d639baa..5f97350 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java @@ -76,6 +76,12 @@ public class FEDInstructionUtils { } } } + else if( inst instanceof MMTSJCPInstruction ) { + MMTSJCPInstruction linst = (MMTSJCPInstruction) inst; + MatrixObject mo = ec.getMatrixObject(linst.input1); + if( mo.isFederated() ) + return TsmmFEDInstruction.parseInstruction(linst.toString()); + } return inst; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java new file mode 100644 index 0000000..a3061ed --- /dev/null +++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.fed; + +import org.apache.sysds.lops.MMTSJ.MMTSJType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; +import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +import java.util.concurrent.Future; + +public class TsmmFEDInstruction extends BinaryFEDInstruction { + private final MMTSJType _type; + @SuppressWarnings("unused") + private final int _numThreads; + + public TsmmFEDInstruction(CPOperand in, CPOperand out, MMTSJType type, int
k, String opcode, String istr) { + super(FEDType.Tsmm, null, in, null, out, opcode, istr); + _type = type; + _numThreads = k; + } + + public static TsmmFEDInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + if(!opcode.equalsIgnoreCase("tsmm")) + throw new DMLRuntimeException("TsmmFEDInstruction.parseInstruction():: Unknown opcode " + opcode); + + InstructionUtils.checkNumFields(parts, 4); + CPOperand in = new CPOperand(parts[1]); + CPOperand out = new CPOperand(parts[2]); + MMTSJType type = MMTSJType.valueOf(parts[3]); + int k = Integer.parseInt(parts[4]); + return new TsmmFEDInstruction(in, out, type, k, opcode, str); + } + + @Override + public void processInstruction(ExecutionContext ec) { + MatrixObject mo1 = ec.getMatrixObject(input1); + + if(mo1.isFederated() && _type.isLeft()) { // left tsmm: per-worker partial results are summed via aggAdd + //construct commands: fed tsmm, retrieve results + FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, + new CPOperand[]{input1}, new long[]{mo1.getFedMapping().getID()}); + FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID()); + + //execute federated operations and aggregate + Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(fr1, fr2); + MatrixBlock ret = FederationUtils.aggAdd(tmp); + mo1.getFedMapping().cleanup(fr1.getID()); + ec.setMatrixOutput(output.getName(), ret); + } + else { //other combinations (non-federated input or non-left tsmm type) are not supported + throw new DMLRuntimeException("Federated Tsmm not supported with the " + + "following federated objects: "+mo1.isFederated()+" "+_type); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java new file mode 100644 index 0000000..29826f8 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; + +import java.util.Arrays; +import java.util.Collection; + +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class FederatedPCATest extends AutomatedTestBase { + + private final static String TEST_DIR = "functions/federated/"; + private final static String TEST_NAME = "FederatedPCATest"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedPCATest.class.getSimpleName() + "/"; + + private final static int blocksize = 1024; + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + @Parameterized.Parameter(2) + public boolean scaleAndShift; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"})); + } + + @Parameterized.Parameters + public static Collection<Object[]> data() { + // rows have to be even and > 1 + return Arrays.asList(new Object[][] { + {10000, 10, false}, {2000, 50, false}, {1000, 100, false}, + //TODO support for federated uacmean, uacvar + //{10000, 10, true}, {2000, 50, true}, {1000, 100, true} + }); + } + + @Test + public void federatedPCASinglenode() { + runFederatedPCATest(Types.ExecMode.SINGLE_NODE); + } + + @Test + public void federatedPCAHybrid() { + runFederatedPCATest(Types.ExecMode.HYBRID); + } + + public void runFederatedPCATest(Types.ExecMode execMode) { + ExecMode platformOld = setExecMode(execMode); + + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + + // write input matrices + int halfRows = rows / 2; + // We have two matrices, each handled by its own federated worker + double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 3); + double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 7); + writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + Thread t1 = startLocalFedWorker(port1); + Thread t2 = startLocalFedWorker(port2); + + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + setOutputBuffering(false); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] {"-args", input("X1"), input("X2"), + String.valueOf(scaleAndShift).toUpperCase(), expected("Z")}; + runTest(true, false, null, -1); + + // Run actual dml script with federated matrix +
fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", + "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols, + "scaleAndShift=" + String.valueOf(scaleAndShift).toUpperCase(), "out=" + output("Z")}; + runTest(true, false, null, -1); + + // compare via files + compareResults(1e-9); + TestUtils.shutdownThreads(t1, t2); + + // check for federated operations + Assert.assertTrue(heavyHittersContainsString("fed_ba+*")); + Assert.assertTrue(heavyHittersContainsString("fed_uack+")); + Assert.assertTrue(heavyHittersContainsString("fed_tsmm")); + if( scaleAndShift ) { + Assert.assertTrue(heavyHittersContainsString("fed_uacmean")); + Assert.assertTrue(heavyHittersContainsString("fed_uacvar")); + } + + resetExecMode(platformOld); + } +} diff --git a/src/test/scripts/functions/federated/FederatedPCATest.dml b/src/test/scripts/functions/federated/FederatedPCATest.dml new file mode 100644 index 0000000..b235d44 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedPCATest.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.
See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = federated(addresses=list($in_X1, $in_X2), + ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols))) +[X2,M] = pca(X=X, K=2, scale=$scaleAndShift, center=$scaleAndShift) +write(X2, $out) diff --git a/src/test/scripts/functions/federated/FederatedPCATestReference.dml b/src/test/scripts/functions/federated/FederatedPCATestReference.dml new file mode 100644 index 0000000..0b17fe0 --- /dev/null +++ b/src/test/scripts/functions/federated/FederatedPCATestReference.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rbind(read($1), read($2)) +[X2,M] = pca(X=X, K=2, scale=$3, center=$3) +write(X2, $4)
