This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new a083501  [SYSTEMDS-2543,2549,2623] Additional federated instructions (for pca)
a083501 is described below

commit a0835010840346a260029eaebe3813d8f7a05a0f
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Aug 15 19:09:06 2020 +0200

    [SYSTEMDS-2543,2549,2623] Additional federated instructions (for pca)
    
    This patch adds all remaining federated instructions needed to run PCA with
    and without scale and shift. In detail, this includes:
    
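    * Federated matrix-matrix binary operations (matrix-matrix and matrix-vector
    with one federated input; currently only the left-hand input may be
    federated, a federated right-hand input is rejected as not yet supported)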
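    * Federated column aggregates for uacmean, including local compensation
    (the partial column means are weighted by each partition's row count,
    summed, and divided by the total number of rows)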
    * Federated replace parameterized builtin function
    * Clean up federated statistics output (alignment, spacing)
---
 .../compress/AbstractCompressedMatrixBlock.java    |   3 +-
 .../controlprogram/federated/FederatedRange.java   |   9 +-
 .../controlprogram/federated/FederationMap.java    |   4 +
 .../controlprogram/federated/FederationUtils.java  |  53 +++++++--
 .../cp/ParameterizedBuiltinCPInstruction.java      |   4 +
 .../fed/AggregateUnaryFEDInstruction.java          |   8 +-
 .../instructions/fed/BinaryFEDInstruction.java     |   2 +-
 .../fed/BinaryMatrixMatrixFEDInstruction.java      |  61 +++++++++++
 .../runtime/instructions/fed/FEDInstruction.java   |   1 +
 .../instructions/fed/FEDInstructionUtils.java      |  27 +++--
 .../fed/ParameterizedBuiltinFEDInstruction.java    | 121 +++++++++++++++++++++
 .../sysds/runtime/matrix/data/CM_N_COVCell.java    |   2 +-
 .../sysds/runtime/matrix/data/MatrixBlock.java     |   3 +-
 .../sysds/runtime/matrix/data/MatrixCell.java      |   3 +-
 .../sysds/runtime/matrix/data/MatrixValue.java     |   2 +-
 .../java/org/apache/sysds/utils/Statistics.java    |  14 +--
 .../test/functions/federated/FederatedPCATest.java |   9 +-
 17 files changed, 280 insertions(+), 46 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
 
b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
index bf86ede..913563c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
@@ -148,11 +148,12 @@ public abstract class AbstractCompressedMatrixBlock 
extends MatrixBlock {
        }
 
        @Override
-       public void binaryOperationsInPlace(BinaryOperator op, MatrixValue 
thatValue) {
+       public MatrixBlock binaryOperationsInPlace(BinaryOperator op, 
MatrixValue thatValue) {
                printDecompressWarning("binaryOperationsInPlace", (MatrixBlock) 
thatValue);
                MatrixBlock left = isCompressed() ? decompress() : this;
                MatrixBlock right = getUncompressed(thatValue);
                left.binaryOperationsInPlace(op, right);
+               return this;
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index b4f69ad..6571666 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -71,12 +71,15 @@ public class FederatedRange implements 
Comparable<FederatedRange> {
        
        public long getSize() {
                long size = 1;
-               for (int i = 0; i < _beginDims.length; i++) {
-                       size *= _endDims[i] - _beginDims[i];
-               }
+               for (int i = 0; i < _beginDims.length; i++)
+                       size *= getSize(i);
                return size;
        }
        
+       public long getSize(int dim) {
+               return _endDims[dim] - _beginDims[dim];
+       }
+       
        @Override
        public int compareTo(FederatedRange o) {
                for (int i = 0; i < _beginDims.length; i++) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index f224da2..04532fd 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -62,6 +62,10 @@ public class FederationMap
                return _ID >= 0;
        }
        
+       public FederatedRange[] getFederatedRanges() {
+               return _fedMap.keySet().toArray(new FederatedRange[0]);
+       }
+       
        public FederatedRequest broadcast(CacheableData<?> data) {
                //prepare single request for all federated data
                long id = FederationUtils.getNextFedDataID();
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index f2c8227..c34fa62 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -29,12 +29,16 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysds.runtime.functionobjects.KahanFunction;
+import org.apache.sysds.runtime.functionobjects.Mean;
 import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
 import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
 
 public class FederationUtils {
@@ -50,9 +54,10 @@ public class FederationUtils {
                String linst = inst.replace(ExecType.SPARK.name(), 
ExecType.CP.name());
                linst = 
linst.replace(Lop.OPERAND_DELIMITOR+varOldOut.getName(), 
Lop.OPERAND_DELIMITOR+String.valueOf(id));
                for(int i=0; i<varOldIn.length; i++)
-                       if( varOldIn[i] != null )
-                               linst = 
linst.replace(Lop.OPERAND_DELIMITOR+varOldIn[i].getName(),
-                                       
Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i]));
+                       if( varOldIn[i] != null ) {
+                               linst = 
linst.replace(Lop.OPERAND_DELIMITOR+varOldIn[i].getName(), 
Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i]));
+                               linst = 
linst.replace("="+varOldIn[i].getName(), "="+String.valueOf(varNewIn[i])); 
//parameterized
+                       }
                return new FederatedRequest(RequestType.EXEC_INST, id, linst);
        }
 
@@ -69,6 +74,29 @@ public class FederationUtils {
                }
        }
        
+       public static MatrixBlock aggMean(Future<FederatedResponse>[] ffr, 
FederationMap map) {
+               try {
+                       FederatedRange[] ranges = map.getFederatedRanges();
+                       BinaryOperator bop = 
InstructionUtils.parseBinaryOperator("+");
+                       ScalarOperator sop1 = 
InstructionUtils.parseScalarBinaryOperator("*", false);
+                       MatrixBlock ret = null;
+                       long size = 0;
+                       for(int i=0; i<ffr.length; i++) {
+                               MatrixBlock tmp = 
(MatrixBlock)ffr[i].get().getData()[0];
+                               size += ranges[i].getSize(0);
+                               sop1 = sop1.setConstant(ranges[i].getSize(0));
+                               tmp = tmp.scalarOperations(sop1, new 
MatrixBlock());
+                               ret = (ret==null) ? tmp : 
ret.binaryOperationsInPlace(bop, tmp);
+                       }
+                       ScalarOperator sop2 = 
InstructionUtils.parseScalarBinaryOperator("/", false);
+                       sop2 = sop2.setConstant(size);
+                       return ret.scalarOperations(sop2, new MatrixBlock());
+               }
+               catch(Exception ex) {
+                       throw new DMLRuntimeException(ex);
+               }
+       }
+       
        public static MatrixBlock[] getResults(Future<FederatedResponse>[] ffr) 
{
                try {
                        MatrixBlock[] ret = new MatrixBlock[ffr.length];
@@ -111,13 +139,20 @@ public class FederationUtils {
                }
        }
 
-       public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, 
Future<FederatedResponse>[] ffr) {
-               if( !(aop.aggOp.increOp.fn instanceof KahanFunction) ) {
-                       throw new DMLRuntimeException("Unsupported aggregation 
operator: "
-                               + aop.aggOp.increOp.getClass().getSimpleName());
+       public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, 
Future<FederatedResponse>[] ffr, FederationMap map) {
+               // handle row aggregate
+               if( aop.isRowAggregate() ) {
+                       //independent of aggregation function for 
row-partitioned federated matrices
+                       return rbind(ffr);
                }
                
-               //assumes full row partitions for row and col aggregates
-               return aop.isRowAggregate() ?  rbind(ffr) : aggAdd(ffr);
+               // handle col aggregate
+               if( aop.aggOp.increOp.fn instanceof KahanFunction )
+                       return aggAdd(ffr);
+               else if( aop.aggOp.increOp.fn instanceof Mean )
+                       return aggMean(ffr, map);
+               else
+                       throw new DMLRuntimeException("Unsupported aggregation 
operator: "
+                               + 
aop.aggOp.increOp.fn.getClass().getSimpleName());
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index f330793..4ced46f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -450,6 +450,10 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                }
        }
        
+       public MatrixObject getTarget(ExecutionContext ec) {
+               return ec.getMatrixObject(params.get("target"));
+       }
+       
        private CPOperand getTargetOperand() {
                return new CPOperand(params.get("target"), ValueType.FP64, 
DataType.MATRIX);
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index e5dd81e..a9b655b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -26,6 +26,7 @@ import 
org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -61,11 +62,12 @@ public class AggregateUnaryFEDInstruction extends 
UnaryFEDInstruction {
                FederatedRequest fr2 = new 
FederatedRequest(RequestType.GET_VAR, fr1.getID());
                
                //execute federated commands and cleanups
-               Future<FederatedResponse>[] tmp = 
in.getFedMapping().execute(fr1, fr2);
-               in.getFedMapping().cleanup(fr1.getID());
+               FederationMap map = in.getFedMapping();
+               Future<FederatedResponse>[] tmp = map.execute(fr1, fr2);
+               map.cleanup(fr1.getID());
                if( output.isScalar() )
                        ec.setVariable(output.getName(), 
FederationUtils.aggScalar(aop, tmp));
                else
-                       ec.setMatrixOutput(output.getName(), 
FederationUtils.aggMatrix(aop, tmp));
+                       ec.setMatrixOutput(output.getName(), 
FederationUtils.aggMatrix(aop, tmp, map));
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index 9782558..f1f8f38 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -48,7 +48,7 @@ public abstract class BinaryFEDInstruction extends 
ComputationFEDInstruction {
                if( in1.getDataType() == DataType.SCALAR && in2.getDataType() 
== DataType.SCALAR )
                        throw new DMLRuntimeException("Federated binary scalar 
scalar operations not yet supported");
                else if( in1.getDataType() == DataType.MATRIX && 
in2.getDataType() == DataType.MATRIX )
-                       throw new DMLRuntimeException("Federated binary matrix 
matrix operations not yet supported");
+                       return new BinaryMatrixMatrixFEDInstruction(operator, 
in1, in2, out, opcode, str);
                else if( in1.getDataType() == DataType.TENSOR && 
in2.getDataType() == DataType.TENSOR )
                        throw new DMLRuntimeException("Federated binary tensor 
tensor operations not yet supported");
                else if( in1.isMatrix() && in2.isScalar() )
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
new file mode 100644
index 0000000..d124c76
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
+{
+       protected BinaryMatrixMatrixFEDInstruction(Operator op,
+               CPOperand in1, CPOperand in2, CPOperand out, String opcode, 
String istr) {
+               super(FEDType.Binary, op, in1, in2, out, opcode, istr);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               MatrixObject mo1 = ec.getMatrixObject(input1);
+               MatrixObject mo2 = ec.getMatrixObject(input2);
+               
+               if( mo2.isFederated() ) {
+                       throw new DMLRuntimeException("Matrix-matrix binary 
operations "
+                               + " with a federated right input are not 
supported yet.");
+               }
+               
+               //matrix-matrix binary operations -> lhs fed input -> fed output
+               FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+               FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
+                       new CPOperand[]{input1, input2}, new 
long[]{mo1.getFedMapping().getID(), fr1.getID()});
+               
+               //execute federated instruction and cleanup intermediates
+               mo1.getFedMapping().execute(fr1, fr2);
+               mo1.getFedMapping().cleanup(fr1.getID());
+               
+               //derive new fed mapping for output
+               MatrixObject out = ec.getMatrixObject(output);
+               out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+               
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index d6bd388..9e58e52 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -33,6 +33,7 @@ public abstract class FEDInstruction extends Instruction {
                Binary,
                Init,
                MultiReturnParameterizedBuiltin,
+               ParameterizedBuiltin,
                Tsmm,
        }
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 5f97350..00f3b04 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -54,25 +54,24 @@ public class FEDInstructionUtils {
                }
                else if (inst instanceof BinaryCPInstruction) {
                        BinaryCPInstruction instruction = (BinaryCPInstruction) 
inst;
-                       if( instruction.input1.isMatrix() && 
instruction.input2.isScalar() ){
-                               MatrixObject mo = 
ec.getMatrixObject(instruction.input1);
-                               if(mo.isFederated())
-                                       return 
BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+                       if( instruction.input1.isMatrix() && 
ec.getMatrixObject(instruction.input1).isFederated()
+                               || instruction.input2.isMatrix() && 
ec.getMatrixObject(instruction.input2).isFederated() ) {
+                               return 
BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
                        }
-                       if( instruction.input2.isMatrix() && 
instruction.input1.isScalar() ){
-                               MatrixObject mo = 
ec.getMatrixObject(instruction.input2);
-                               if(mo.isFederated())
-                                       return 
BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+               }
+               else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
+                       ParameterizedBuiltinCPInstruction pinst = 
(ParameterizedBuiltinCPInstruction)inst;
+                       if(pinst.getOpcode().equals("replace") && 
pinst.getTarget(ec).isFederated()) {
+                               return 
ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
                        }
                }
                else if (inst instanceof 
MultiReturnParameterizedBuiltinCPInstruction) {
-                       MultiReturnParameterizedBuiltinCPInstruction 
instruction = (MultiReturnParameterizedBuiltinCPInstruction) inst;
-                       String opcode = instruction.getOpcode();
-                       if(opcode.equals("transformencode") && 
instruction.input1.isFrame()) {
-                               CacheableData<?> fo = 
ec.getCacheableData(instruction.input1);
+                       MultiReturnParameterizedBuiltinCPInstruction minst = 
(MultiReturnParameterizedBuiltinCPInstruction) inst;
+                       if(minst.getOpcode().equals("transformencode") && 
minst.input1.isFrame()) {
+                               CacheableData<?> fo = 
ec.getCacheableData(minst.input1);
                                if(fo.isFederated()) {
                                        return 
MultiReturnParameterizedBuiltinFEDInstruction
-                                               
.parseInstruction(instruction.getInstructionString());
+                                               
.parseInstruction(minst.getInstructionString());
                                }
                        }
                }
@@ -80,7 +79,7 @@ public class FEDInstructionUtils {
                        MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
                        MatrixObject mo = ec.getMatrixObject(linst.input1);
                        if( mo.isFederated() )
-                               return 
TsmmFEDInstruction.parseInstruction(linst.toString());
+                               return 
TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
                }
                return inst;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
new file mode 100644
index 0000000..3a5ff8a
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
+import org.apache.sysds.runtime.functionobjects.ValueFunction;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
+
+public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstruction {
+
+       protected final LinkedHashMap<String, String> params;
+       
+       protected ParameterizedBuiltinFEDInstruction(Operator op,
+               LinkedHashMap<String, String> paramsMap, CPOperand out, String 
opcode, String istr)
+       {
+               super(FEDType.ParameterizedBuiltin, op, null, null, out, 
opcode, istr);
+               params = paramsMap;
+       }
+       
+       public HashMap<String,String> getParameterMap() { 
+               return params; 
+       }
+       
+       public String getParam(String key) {
+               return getParameterMap().get(key);
+       }
+       
+       public static LinkedHashMap<String, String> 
constructParameterMap(String[] params) {
+               // process all elements in "params" except first(opcode) and 
last(output)
+               LinkedHashMap<String,String> paramMap = new LinkedHashMap<>();
+               
+               // all parameters are of form <name=value>
+               String[] parts;
+               for ( int i=1; i <= params.length-2; i++ ) {
+                       parts = params[i].split(Lop.NAME_VALUE_SEPARATOR);
+                       paramMap.put(parts[0], parts[1]);
+               }
+               
+               return paramMap;
+       }
+       
+       public static ParameterizedBuiltinFEDInstruction parseInstruction ( 
String str ) {
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               // first part is always the opcode
+               String opcode = parts[0];
+               // last part is always the output
+               CPOperand out = new CPOperand( parts[parts.length-1] ); 
+       
+               // process remaining parts and build a hash map
+               LinkedHashMap<String,String> paramsMap = 
constructParameterMap(parts);
+       
+               // determine the appropriate value function
+               ValueFunction func = null;
+               if( opcode.equalsIgnoreCase("replace") ) {
+                       func = 
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
+                       return new ParameterizedBuiltinFEDInstruction(new 
SimpleOperator(func), paramsMap, out, opcode, str);
+               }
+               else {
+                       throw new DMLRuntimeException("Unsupported opcode (" + 
opcode + ") for ParameterizedBuiltinFEDInstruction.");
+               }
+       }
+       
+       @Override 
+       public void processInstruction(ExecutionContext ec) {
+               String opcode = getOpcode();
+               if ( opcode.equalsIgnoreCase("replace") ) {
+                       //similar to unary federated instructions, get 
federated input
+                       //execute instruction, and derive federated output 
matrix
+                       MatrixObject mo = getTarget(ec);
+                       FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
+                               new CPOperand[]{getTargetOperand()}, new 
long[]{mo.getFedMapping().getID()});
+                       mo.getFedMapping().execute(fr1);
+                       
+                       //derive new fed mapping for output
+                       MatrixObject out = ec.getMatrixObject(output);
+                       
out.getDataCharacteristics().set(mo.getDataCharacteristics());
+                       
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
+               }
+               else {
+                       throw new DMLRuntimeException("Unknown opcode : " + 
opcode);
+               }
+       }
+       
+       public MatrixObject getTarget(ExecutionContext ec) {
+               return ec.getMatrixObject(params.get("target"));
+       }
+       
+       private CPOperand getTargetOperand() {
+               return new CPOperand(params.get("target"), ValueType.FP64, 
DataType.MATRIX);
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
index a79677b..063ff77 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
@@ -61,7 +61,7 @@ public class CM_N_COVCell extends MatrixValue implements 
WritableComparable
        }
 
        @Override
-       public void binaryOperationsInPlace(BinaryOperator op, MatrixValue 
thatValue) {
+       public MatrixValue binaryOperationsInPlace(BinaryOperator op, 
MatrixValue thatValue) {
                throw new DMLRuntimeException("operation not supported for 
CM_N_COVCell");
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 45a0965..ed52481 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -2837,7 +2837,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
        }
 
        @Override
-       public void binaryOperationsInPlace(BinaryOperator op, MatrixValue 
thatValue) {
+       public MatrixBlock binaryOperationsInPlace(BinaryOperator op, 
MatrixValue thatValue) {
                MatrixBlock that=checkType(thatValue);
                if( !LibMatrixBincell.isValidDimensionsBinary(this, that) ) {
                        throw new RuntimeException("block sizes are not matched 
for binary " +
@@ -2853,6 +2853,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                
                //core binary cell operation
                LibMatrixBincell.bincellOpInPlace(this, that, op);
+               return this;
        }
        
        public MatrixBlock ternaryOperations(TernaryOperator op, MatrixBlock 
m2, MatrixBlock m3, MatrixBlock ret) {
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
index 42338d9..10d7e61 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
@@ -198,10 +198,11 @@ public class MatrixCell extends MatrixValue implements 
WritableComparable, Seria
        }
 
        @Override
-       public void binaryOperationsInPlace(BinaryOperator op,
+       public MatrixValue binaryOperationsInPlace(BinaryOperator op,
                        MatrixValue thatValue) {
                MatrixCell c2=checkType(thatValue);
                setValue(op.fn.execute(this.getValue(), c2.getValue()));
+               return this;
        }
 
        public void denseScalarOperationsInPlace(ScalarOperator op) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
index 102e433..9b213ec 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
@@ -103,7 +103,7 @@ public abstract class MatrixValue implements 
WritableComparable
        
        public abstract MatrixValue binaryOperations(BinaryOperator op, 
MatrixValue thatValue, MatrixValue result);
        
-       public abstract void binaryOperationsInPlace(BinaryOperator op, 
MatrixValue thatValue);
+       public abstract MatrixValue binaryOperationsInPlace(BinaryOperator op, 
MatrixValue thatValue);
        
        public abstract MatrixValue reorgOperations(ReorgOperator op, 
MatrixValue result,
                        int startRow, int startColumn, int length);
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java 
b/src/main/java/org/apache/sysds/utils/Statistics.java
index 0a6ec38..b498b0e 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -1020,13 +1020,13 @@ public class Statistics
                                sb.append("ParFor total update in-place:\t" + 
lTotalUIPVar + "/" + lTotalLixUIP + "/" + lTotalLix + "\n");
                        }
                        if( federatedReadCount.longValue() > 0){
-                               sb.append("Federated (Reads,Puts,Gets) :\t(" + 
-                                       federatedReadCount.longValue() + "," +
-                                       federatedPutCount.longValue() + "," +
-                                       federatedGetCount.longValue() + ")\n");
-                               sb.append("Federated Execute (In,UDF)  :\t(" +
-                                       
federatedExecuteInstructionCount.longValue() + "," +
-                                       federatedExecuteUDFCount.longValue() + 
")\n");
+                               sb.append("Federated I/O (Read, Put, Get):\t" + 
+                                       federatedReadCount.longValue() + "/" +
+                                       federatedPutCount.longValue() + "/" +
+                                       federatedGetCount.longValue() + ".\n");
+                               sb.append("Federated Execute (Inst, UDF):\t" +
+                                       
federatedExecuteInstructionCount.longValue() + "/" +
+                                       federatedExecuteUDFCount.longValue() + 
".\n");
                        }
 
                        sb.append("Total JIT compile time:\t\t" + 
((double)getJITCompileTime())/1000 + " sec.\n");
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
index 29826f8..bf674a8 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
@@ -61,8 +61,7 @@ public class FederatedPCATest extends AutomatedTestBase {
                // rows have to be even and > 1
                return Arrays.asList(new Object[][] {
                        {10000, 10, false}, {2000, 50, false}, {1000, 100, 
false},
-                       //TODO support for federated uacmean, uacvar
-                       //{10000, 10, true}, {2000, 50, true}, {1000, 100, true}
+                       {10000, 10, true}, {2000, 50, true}, {1000, 100, true}
                });
        }
 
@@ -99,7 +98,6 @@ public class FederatedPCATest extends AutomatedTestBase {
 
                TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                loadTestConfiguration(config);
-               setOutputBuffering(false);
                
                // Run reference dml script with normal matrix
                fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
@@ -124,8 +122,11 @@ public class FederatedPCATest extends AutomatedTestBase {
                Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
                Assert.assertTrue(heavyHittersContainsString("fed_tsmm"));
                if( scaleAndShift ) {
+                       
Assert.assertTrue(heavyHittersContainsString("fed_uacsqk+"));
                        
Assert.assertTrue(heavyHittersContainsString("fed_uacmean"));
-                       
Assert.assertTrue(heavyHittersContainsString("fed_uacvar"));
+                       Assert.assertTrue(heavyHittersContainsString("fed_-"));
+                       Assert.assertTrue(heavyHittersContainsString("fed_/"));
+                       
Assert.assertTrue(heavyHittersContainsString("fed_replace"));
                }
                
                resetExecMode(platformOld);

Reply via email to