This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit a3e3ea949c6af02914356c430756477b948965ce
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Aug 15 14:39:45 2020 +0200

    [SYSTEMDS-2620] Federated tsmm operations (e.g., PCA, lmDS, cor)
    
    * Federated tsmm: support for federated tsmm left over row-partitioned
    federated matrices.
    * Performance: aggAdd (e.g., in ba+*, uack+, and tsmm) via nary instead
    of binary operations.
---
 .../controlprogram/federated/FederationUtils.java  |  16 ++-
 .../fed/ComputationFEDInstruction.java             |  10 +-
 .../runtime/instructions/fed/FEDInstruction.java   |   3 +-
 .../instructions/fed/FEDInstructionUtils.java      |   6 +
 .../instructions/fed/TsmmFEDInstruction.java       |  82 +++++++++++++
 .../test/functions/federated/FederatedPCATest.java | 133 +++++++++++++++++++++
 .../functions/federated/FederatedPCATest.dml       |  25 ++++
 .../federated/FederatedPCATestReference.dml        |  24 ++++
 8 files changed, 283 insertions(+), 16 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index ab0b3aa..f2c8227 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -29,13 +29,13 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysds.runtime.functionobjects.KahanFunction;
-import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
-import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
 
 public class FederationUtils {
        private static final IDSequence _idSeq = new IDSequence();
@@ -58,13 +58,11 @@ public class FederationUtils {
 
        public static MatrixBlock aggAdd(Future<FederatedResponse>[] ffr) {
                try {
-                       BinaryOperator bop = 
InstructionUtils.parseBinaryOperator("+");
-                       MatrixBlock ret = (MatrixBlock) 
(ffr[0].get().getData()[0]);
-                       for (int i=1; i<ffr.length; i++) {
-                               MatrixBlock tmp = (MatrixBlock) 
(ffr[i].get().getData()[0]);
-                               ret.binaryOperationsInPlace(bop, tmp);
-                       }
-                       return ret;
+                       SimpleOperator op = new 
SimpleOperator(Plus.getPlusFnObject());
+                       MatrixBlock[] in = new MatrixBlock[ffr.length];
+                       for(int i=0; i<ffr.length; i++)
+                               in[i] = (MatrixBlock) ffr[i].get().getData()[0];
+                       return MatrixBlock.naryOperations(op, in, new 
ScalarObject[0], new MatrixBlock());
                }
                catch(Exception ex) {
                        throw new DMLRuntimeException(ex);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
index 9d972f4..ccaec24 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
@@ -37,9 +37,8 @@ public abstract class ComputationFEDInstruction extends 
FEDInstruction implement
        public final CPOperand output;
        public final CPOperand input1, input2, input3;
        
-       protected ComputationFEDInstruction(FEDType type, Operator op, 
CPOperand in1, CPOperand in2, CPOperand out,
-                       String opcode,
-                       String istr) {
+       protected ComputationFEDInstruction(FEDType type, Operator op,
+               CPOperand in1, CPOperand in2, CPOperand out, String opcode, 
String istr) {
                super(type, op, opcode, istr);
                input1 = in1;
                input2 = in2;
@@ -47,9 +46,8 @@ public abstract class ComputationFEDInstruction extends 
FEDInstruction implement
                output = out;
        }
        
-       protected ComputationFEDInstruction(FEDType type, Operator op, 
CPOperand in1, CPOperand in2, CPOperand in3,
-                       CPOperand out,
-                       String opcode, String istr) {
+       protected ComputationFEDInstruction(FEDType type, Operator op,
+               CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, 
String opcode, String istr) {
                super(type, op, opcode, istr);
                input1 = in1;
                input2 = in2;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index f2d0791..d6bd388 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -32,7 +32,8 @@ public abstract class FEDInstruction extends Instruction {
                Append,
                Binary,
                Init,
-               MultiReturnParameterizedBuiltin
+               MultiReturnParameterizedBuiltin,
+               Tsmm,
        }
        
        protected final FEDType _fedType;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index d639baa..5f97350 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -76,6 +76,12 @@ public class FEDInstructionUtils {
                                }
                        }
                }
+               else if( inst instanceof MMTSJCPInstruction ) {
+                       MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
+                       MatrixObject mo = ec.getMatrixObject(linst.input1);
+                       if( mo.isFederated() )
+                               return 
TsmmFEDInstruction.parseInstruction(linst.toString());
+               }
                return inst;
        }
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
new file mode 100644
index 0000000..a3061ed
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.lops.MMTSJ.MMTSJType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import java.util.concurrent.Future;
+
+public class TsmmFEDInstruction extends BinaryFEDInstruction {
+       private final MMTSJType _type;
+       @SuppressWarnings("unused")
+       private final int _numThreads;
+       
+       public TsmmFEDInstruction(CPOperand in, CPOperand out, MMTSJType type, 
int k, String opcode, String istr) {
+               super(FEDType.Tsmm, null, in, null, out, opcode, istr);
+               _type = type;
+               _numThreads = k;
+       }
+       
+       public static TsmmFEDInstruction parseInstruction(String str) {
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               String opcode = parts[0];
+               if(!opcode.equalsIgnoreCase("tsmm"))
+                       throw new 
DMLRuntimeException("TsmmFedInstruction.parseInstruction():: Unknown opcode " + 
opcode);
+               
+               InstructionUtils.checkNumFields(parts, 4);
+               CPOperand in = new CPOperand(parts[1]);
+               CPOperand out = new CPOperand(parts[2]);
+               MMTSJType type = MMTSJType.valueOf(parts[3]);
+               int k = Integer.parseInt(parts[4]);
+               return new TsmmFEDInstruction(in, out, type, k, opcode, str);
+       }
+       
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               MatrixObject mo1 = ec.getMatrixObject(input1);
+               
+               if(mo1.isFederated() && _type.isLeft()) { // left tsmm
+                       //construct commands: fed tsmm, retrieve results
+                       FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
+                               new CPOperand[]{input1}, new 
long[]{mo1.getFedMapping().getID()});
+                       FederatedRequest fr2 = new 
FederatedRequest(RequestType.GET_VAR, fr1.getID());
+                       
+                       //execute federated operations and aggregate
+                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(fr1, fr2);
+                       MatrixBlock ret = FederationUtils.aggAdd(tmp);
+                       mo1.getFedMapping().cleanup(fr1.getID());
+                       ec.setMatrixOutput(output.getName(), ret);
+               }
+               else { //other combinations
+                       throw new DMLRuntimeException("Federated Tsmm not 
supported with the "
+                               + "following federated objects: 
"+mo1.isFederated()+" "+_fedType);
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
new file mode 100644
index 0000000..29826f8
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedPCATest extends AutomatedTestBase {
+
+       private final static String TEST_DIR = "functions/federated/";
+       private final static String TEST_NAME = "FederatedPCATest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedPCATest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public boolean scaleAndShift;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // rows have to be even and > 1
+               return Arrays.asList(new Object[][] {
+                       {10000, 10, false}, {2000, 50, false}, {1000, 100, 
false},
+                       //TODO support for federated uacmean, uacvar
+                       //{10000, 10, true}, {2000, 50, true}, {1000, 100, true}
+               });
+       }
+
+       @Test
+       public void federatedPCASinglenode() {
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE);
+       }
+       
+       @Test
+       public void federatedPCAHybrid() {
+               federatedL2SVM(Types.ExecMode.HYBRID);
+       }
+
+       public void federatedL2SVM(Types.ExecMode execMode) {
+               ExecMode platformOld = setExecMode(execMode);
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               int halfRows = rows / 2;
+               // We have two matrices, each handled by its own federated worker
+               double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 3);
+               double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 7);
+               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorker(port1);
+               Thread t2 = startLocalFedWorker(port2);
+
+               TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+               setOutputBuffering(false);
+               
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-args", input("X1"), input("X2"),
+                       String.valueOf(scaleAndShift).toUpperCase(), 
expected("Z")};
+               runTest(true, false, null, -1);
+
+               // Run actual dml script with federated matrix
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats",
+                       "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")), "rows=" + rows, "cols=" + cols,
+                       "scaleAndShift=" + 
String.valueOf(scaleAndShift).toUpperCase(), "out=" + output("Z")};
+               runTest(true, false, null, -1);
+
+               // compare via files
+               compareResults(1e-9);
+               TestUtils.shutdownThreads(t1, t2);
+               
+               // check for federated operations
+               Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
+               Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
+               Assert.assertTrue(heavyHittersContainsString("fed_tsmm"));
+               if( scaleAndShift ) {
+                       
Assert.assertTrue(heavyHittersContainsString("fed_uacmean"));
+                       
Assert.assertTrue(heavyHittersContainsString("fed_uacvar"));
+               }
+               
+               resetExecMode(platformOld);
+       }
+}
diff --git a/src/test/scripts/functions/federated/FederatedPCATest.dml 
b/src/test/scripts/functions/federated/FederatedPCATest.dml
new file mode 100644
index 0000000..b235d44
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedPCATest.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), 
list($rows, $cols)))
+[X2,M] = pca(X=X,  K=2, scale=$scaleAndShift, center=$scaleAndShift)
+write(X2, $out)
diff --git a/src/test/scripts/functions/federated/FederatedPCATestReference.dml 
b/src/test/scripts/functions/federated/FederatedPCATestReference.dml
new file mode 100644
index 0000000..0b17fe0
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedPCATestReference.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2))
+[X2,M] = pca(X=X,  K=2, scale=$3, center=$3)
+write(X2, $4)

Reply via email to