This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new a083501 [SYSTEMDS-2543,2549,2623] Additional federated instructions
(for pca)
a083501 is described below
commit a0835010840346a260029eaebe3813d8f7a05a0f
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Aug 15 19:09:06 2020 +0200
[SYSTEMDS-2543,2549,2623] Additional federated instructions (for pca)
This patch adds all remaining federated instructions required to run PCA with and
without scale and shift. In detail, this includes:
* Federated binary matrix-matrix operations (matrix-matrix and matrix-vector,
with one federated input)
* Federated column aggregates for uacmean, incl. local compensation (partial means are weighted by the row counts of their partitions)
* Federated replace (parameterized builtin function)
* Cleanup of federated statistics output (alignment, spaces)
---
.../compress/AbstractCompressedMatrixBlock.java | 3 +-
.../controlprogram/federated/FederatedRange.java | 9 +-
.../controlprogram/federated/FederationMap.java | 4 +
.../controlprogram/federated/FederationUtils.java | 53 +++++++--
.../cp/ParameterizedBuiltinCPInstruction.java | 4 +
.../fed/AggregateUnaryFEDInstruction.java | 8 +-
.../instructions/fed/BinaryFEDInstruction.java | 2 +-
.../fed/BinaryMatrixMatrixFEDInstruction.java | 61 +++++++++++
.../runtime/instructions/fed/FEDInstruction.java | 1 +
.../instructions/fed/FEDInstructionUtils.java | 27 +++--
.../fed/ParameterizedBuiltinFEDInstruction.java | 121 +++++++++++++++++++++
.../sysds/runtime/matrix/data/CM_N_COVCell.java | 2 +-
.../sysds/runtime/matrix/data/MatrixBlock.java | 3 +-
.../sysds/runtime/matrix/data/MatrixCell.java | 3 +-
.../sysds/runtime/matrix/data/MatrixValue.java | 2 +-
.../java/org/apache/sysds/utils/Statistics.java | 14 +--
.../test/functions/federated/FederatedPCATest.java | 9 +-
17 files changed, 280 insertions(+), 46 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
index bf86ede..913563c 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/AbstractCompressedMatrixBlock.java
@@ -148,11 +148,12 @@ public abstract class AbstractCompressedMatrixBlock
extends MatrixBlock {
}
@Override
- public void binaryOperationsInPlace(BinaryOperator op, MatrixValue
thatValue) {
+ public MatrixBlock binaryOperationsInPlace(BinaryOperator op,
MatrixValue thatValue) {
printDecompressWarning("binaryOperationsInPlace", (MatrixBlock)
thatValue);
MatrixBlock left = isCompressed() ? decompress() : this;
MatrixBlock right = getUncompressed(thatValue);
left.binaryOperationsInPlace(op, right);
+ return this;
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index b4f69ad..6571666 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -71,12 +71,15 @@ public class FederatedRange implements
Comparable<FederatedRange> {
public long getSize() {
long size = 1;
- for (int i = 0; i < _beginDims.length; i++) {
- size *= _endDims[i] - _beginDims[i];
- }
+ for (int i = 0; i < _beginDims.length; i++)
+ size *= getSize(i);
return size;
}
+ public long getSize(int dim) { // extent of this range along dimension dim (end index minus begin index)
+ return _endDims[dim] - _beginDims[dim];
+ }
+
@Override
public int compareTo(FederatedRange o) {
for (int i = 0; i < _beginDims.length; i++) {
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index f224da2..04532fd 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -62,6 +62,10 @@ public class FederationMap
return _ID >= 0;
}
+ public FederatedRange[] getFederatedRanges() { // array copy of the map's ranges, in _fedMap key-set iteration order
+ return _fedMap.keySet().toArray(new FederatedRange[0]);
+ }
+
public FederatedRequest broadcast(CacheableData<?> data) {
//prepare single request for all federated data
long id = FederationUtils.getNextFedDataID();
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index f2c8227..c34fa62 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -29,12 +29,16 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
+import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
public class FederationUtils {
@@ -50,9 +54,10 @@ public class FederationUtils {
String linst = inst.replace(ExecType.SPARK.name(),
ExecType.CP.name());
linst =
linst.replace(Lop.OPERAND_DELIMITOR+varOldOut.getName(),
Lop.OPERAND_DELIMITOR+String.valueOf(id));
for(int i=0; i<varOldIn.length; i++)
- if( varOldIn[i] != null )
- linst =
linst.replace(Lop.OPERAND_DELIMITOR+varOldIn[i].getName(),
-
Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i]));
+ if( varOldIn[i] != null ) {
+ linst =
linst.replace(Lop.OPERAND_DELIMITOR+varOldIn[i].getName(),
Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i]));
+ linst =
linst.replace("="+varOldIn[i].getName(), "="+String.valueOf(varNewIn[i]));
//parameterized
+ }
return new FederatedRequest(RequestType.EXEC_INST, id, linst);
}
@@ -69,6 +74,29 @@ public class FederationUtils {
}
}
+ public static MatrixBlock aggMean(Future<FederatedResponse>[] ffr,
FederationMap map) {
+ try { // combines partial means: sum_i(n_i * mean_i) / sum_i(n_i), n_i = extent of range i along dim 0
+ FederatedRange[] ranges = map.getFederatedRanges(); // NOTE(review): assumes ranges[i] is the partition that produced ffr[i] - confirm this matches the order used by map.execute
+ BinaryOperator bop =
InstructionUtils.parseBinaryOperator("+");
+ ScalarOperator sop1 =
InstructionUtils.parseScalarBinaryOperator("*", false);
+ MatrixBlock ret = null;
+ long size = 0; // total extent along dim 0 over all partitions (rows, for row-partitioned inputs)
+ for(int i=0; i<ffr.length; i++) {
+ MatrixBlock tmp =
(MatrixBlock)ffr[i].get().getData()[0];
+ size += ranges[i].getSize(0);
+ sop1 = sop1.setConstant(ranges[i].getSize(0)); // weight this partial mean by its partition's extent along dim 0
+ tmp = tmp.scalarOperations(sop1, new
MatrixBlock());
+ ret = (ret==null) ? tmp :
ret.binaryOperationsInPlace(bop, tmp);
+ }
+ ScalarOperator sop2 =
InstructionUtils.parseScalarBinaryOperator("/", false);
+ sop2 = sop2.setConstant(size);
+ return ret.scalarOperations(sop2, new MatrixBlock()); // NOTE(review): ret stays null for an empty ffr; the resulting NPE is wrapped by the catch below
+ }
+ catch(Exception ex) { // NOTE(review): also catches InterruptedException from Future.get() without restoring the thread's interrupt flag
+ throw new DMLRuntimeException(ex);
+ }
+ }
+
public static MatrixBlock[] getResults(Future<FederatedResponse>[] ffr)
{
try {
MatrixBlock[] ret = new MatrixBlock[ffr.length];
@@ -111,13 +139,20 @@ public class FederationUtils {
}
}
- public static MatrixBlock aggMatrix(AggregateUnaryOperator aop,
Future<FederatedResponse>[] ffr) {
- if( !(aop.aggOp.increOp.fn instanceof KahanFunction) ) {
- throw new DMLRuntimeException("Unsupported aggregation
operator: "
- + aop.aggOp.increOp.getClass().getSimpleName());
+ public static MatrixBlock aggMatrix(AggregateUnaryOperator aop,
Future<FederatedResponse>[] ffr, FederationMap map) {
+ // handle row aggregate
+ if( aop.isRowAggregate() ) {
+ //independent of aggregation function for
row-partitioned federated matrices
+ return rbind(ffr);
}
- //assumes full row partitions for row and col aggregates
- return aop.isRowAggregate() ? rbind(ffr) : aggAdd(ffr);
+ // handle col aggregate
+ if( aop.aggOp.increOp.fn instanceof KahanFunction )
+ return aggAdd(ffr);
+ else if( aop.aggOp.increOp.fn instanceof Mean )
+ return aggMean(ffr, map);
+ else
+ throw new DMLRuntimeException("Unsupported aggregation
operator: "
+ +
aop.aggOp.increOp.fn.getClass().getSimpleName());
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index f330793..4ced46f 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -450,6 +450,10 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
}
}
+ public MatrixObject getTarget(ExecutionContext ec) { // matrix bound to the "target" parameter; FEDInstructionUtils checks it to route replace to the FED variant
+ return ec.getMatrixObject(params.get("target"));
+ }
+
private CPOperand getTargetOperand() {
return new CPOperand(params.get("target"), ValueType.FP64,
DataType.MATRIX);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index e5dd81e..a9b655b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -26,6 +26,7 @@ import
org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -61,11 +62,12 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
FederatedRequest fr2 = new
FederatedRequest(RequestType.GET_VAR, fr1.getID());
//execute federated commands and cleanups
- Future<FederatedResponse>[] tmp =
in.getFedMapping().execute(fr1, fr2);
- in.getFedMapping().cleanup(fr1.getID());
+ FederationMap map = in.getFedMapping();
+ Future<FederatedResponse>[] tmp = map.execute(fr1, fr2);
+ map.cleanup(fr1.getID());
if( output.isScalar() )
ec.setVariable(output.getName(),
FederationUtils.aggScalar(aop, tmp));
else
- ec.setMatrixOutput(output.getName(),
FederationUtils.aggMatrix(aop, tmp));
+ ec.setMatrixOutput(output.getName(),
FederationUtils.aggMatrix(aop, tmp, map));
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index 9782558..f1f8f38 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -48,7 +48,7 @@ public abstract class BinaryFEDInstruction extends
ComputationFEDInstruction {
if( in1.getDataType() == DataType.SCALAR && in2.getDataType()
== DataType.SCALAR )
throw new DMLRuntimeException("Federated binary scalar
scalar operations not yet supported");
else if( in1.getDataType() == DataType.MATRIX &&
in2.getDataType() == DataType.MATRIX )
- throw new DMLRuntimeException("Federated binary matrix
matrix operations not yet supported");
+ return new BinaryMatrixMatrixFEDInstruction(operator,
in1, in2, out, opcode, str);
else if( in1.getDataType() == DataType.TENSOR &&
in2.getDataType() == DataType.TENSOR )
throw new DMLRuntimeException("Federated binary tensor
tensor operations not yet supported");
else if( in1.isMatrix() && in2.isScalar() )
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
new file mode 100644
index 0000000..d124c76
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
+{ // binary matrix-matrix op with a federated lhs (input1) and a local rhs (input2)
+ protected BinaryMatrixMatrixFEDInstruction(Operator op,
+ CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr) {
+ super(FEDType.Binary, op, in1, in2, out, opcode, istr);
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ MatrixObject mo1 = ec.getMatrixObject(input1);
+ MatrixObject mo2 = ec.getMatrixObject(input2);
+
+ if( mo2.isFederated() ) { // a federated rhs is rejected; only a federated lhs is handled below
+ throw new DMLRuntimeException("Matrix-matrix binary
operations "
+ + " with a federated right input are not
supported yet.");
+ }
+
+ //matrix-matrix binary operations -> lhs fed input -> fed output; mo2 is broadcast in full to every partition of mo1 (NOTE(review): presumably only dimension-compatible when mo2 is a vector matching each partition, e.g. a row vector for row partitions - confirm)
+ FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+ FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2}, new
long[]{mo1.getFedMapping().getID(), fr1.getID()});
+
+ //execute federated instruction and cleanup intermediates (only the broadcast rhs fr1 is removed; fr2's result stays on the workers as the output)
+ mo1.getFedMapping().execute(fr1, fr2);
+ mo1.getFedMapping().cleanup(fr1.getID());
+
+ //derive new fed mapping for output: same dims and partitioning as the lhs, keyed by fr2's ID
+ MatrixObject out = ec.getMatrixObject(output);
+ out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index d6bd388..9e58e52 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -33,6 +33,7 @@ public abstract class FEDInstruction extends Instruction {
Binary,
Init,
MultiReturnParameterizedBuiltin,
+ ParameterizedBuiltin,
Tsmm,
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 5f97350..00f3b04 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -54,25 +54,24 @@ public class FEDInstructionUtils {
}
else if (inst instanceof BinaryCPInstruction) {
BinaryCPInstruction instruction = (BinaryCPInstruction)
inst;
- if( instruction.input1.isMatrix() &&
instruction.input2.isScalar() ){
- MatrixObject mo =
ec.getMatrixObject(instruction.input1);
- if(mo.isFederated())
- return
BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+ if( instruction.input1.isMatrix() &&
ec.getMatrixObject(instruction.input1).isFederated()
+ || instruction.input2.isMatrix() &&
ec.getMatrixObject(instruction.input2).isFederated() ) {
+ return
BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
}
- if( instruction.input2.isMatrix() &&
instruction.input1.isScalar() ){
- MatrixObject mo =
ec.getMatrixObject(instruction.input2);
- if(mo.isFederated())
- return
BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+ }
+ else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
+ ParameterizedBuiltinCPInstruction pinst =
(ParameterizedBuiltinCPInstruction)inst;
+ if(pinst.getOpcode().equals("replace") &&
pinst.getTarget(ec).isFederated()) {
+ return
ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
}
}
else if (inst instanceof
MultiReturnParameterizedBuiltinCPInstruction) {
- MultiReturnParameterizedBuiltinCPInstruction
instruction = (MultiReturnParameterizedBuiltinCPInstruction) inst;
- String opcode = instruction.getOpcode();
- if(opcode.equals("transformencode") &&
instruction.input1.isFrame()) {
- CacheableData<?> fo =
ec.getCacheableData(instruction.input1);
+ MultiReturnParameterizedBuiltinCPInstruction minst =
(MultiReturnParameterizedBuiltinCPInstruction) inst;
+ if(minst.getOpcode().equals("transformencode") &&
minst.input1.isFrame()) {
+ CacheableData<?> fo =
ec.getCacheableData(minst.input1);
if(fo.isFederated()) {
return
MultiReturnParameterizedBuiltinFEDInstruction
-
.parseInstruction(instruction.getInstructionString());
+
.parseInstruction(minst.getInstructionString());
}
}
}
@@ -80,7 +79,7 @@ public class FEDInstructionUtils {
MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
MatrixObject mo = ec.getMatrixObject(linst.input1);
if( mo.isFederated() )
- return
TsmmFEDInstruction.parseInstruction(linst.toString());
+ return
TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
}
return inst;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
new file mode 100644
index 0000000..3a5ff8a
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
+import org.apache.sysds.runtime.functionobjects.ValueFunction;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
+
+public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstruction { // federated parameterized builtins; parseInstruction currently accepts only replace
+
+ protected final LinkedHashMap<String, String> params; // name -> value operands parsed from the instruction string
+
+ protected ParameterizedBuiltinFEDInstruction(Operator op,
+ LinkedHashMap<String, String> paramsMap, CPOperand out, String
opcode, String istr)
+ {
+ super(FEDType.ParameterizedBuiltin, op, null, null, out,
opcode, istr);
+ params = paramsMap;
+ }
+
+ public HashMap<String,String> getParameterMap() { // returns the internal map itself, not a copy
+ return params;
+ }
+
+ public String getParam(String key) { // value of parameter key, or null if absent
+ return getParameterMap().get(key);
+ }
+
+ public static LinkedHashMap<String, String>
constructParameterMap(String[] params) {
+ // process all elements in "params" except first(opcode) and
last(output)
+ LinkedHashMap<String,String> paramMap = new LinkedHashMap<>();
+
+ // all parameters are of form <name=value>; NOTE(review): a part without the separator makes parts[1] throw, and a value containing it is truncated by split - confirm neither can occur
+ String[] parts;
+ for ( int i=1; i <= params.length-2; i++ ) {
+ parts = params[i].split(Lop.NAME_VALUE_SEPARATOR);
+ paramMap.put(parts[0], parts[1]);
+ }
+
+ return paramMap;
+ }
+
+ public static ParameterizedBuiltinFEDInstruction parseInstruction (
String str ) {
+ String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
+ // first part is always the opcode
+ String opcode = parts[0];
+ // last part is always the output
+ CPOperand out = new CPOperand( parts[parts.length-1] );
+
+ // process remaining parts and build a hash map (insertion-ordered, name -> value)
+ LinkedHashMap<String,String> paramsMap =
constructParameterMap(parts);
+
+ // determine the appropriate value function; only "replace" (case-insensitive) is supported, other opcodes throw
+ ValueFunction func = null;
+ if( opcode.equalsIgnoreCase("replace") ) {
+ func =
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
+ return new ParameterizedBuiltinFEDInstruction(new
SimpleOperator(func), paramsMap, out, opcode, str);
+ }
+ else {
+ throw new DMLRuntimeException("Unsupported opcode (" +
opcode + ") for ParameterizedBuiltinFEDInstruction.");
+ }
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ String opcode = getOpcode();
+ if ( opcode.equalsIgnoreCase("replace") ) {
+ //similar to unary federated instructions, get
federated input
+ //execute instruction, and derive federated output
matrix
+ MatrixObject mo = getTarget(ec);
+ FederatedRequest fr1 =
FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{getTargetOperand()}, new
long[]{mo.getFedMapping().getID()});
+ mo.getFedMapping().execute(fr1);
+
+ //derive new fed mapping for output: target dims and partitioning, keyed by fr1's ID (fr1 is not cleaned up because its result is the output)
+ MatrixObject out = ec.getMatrixObject(output);
+
out.getDataCharacteristics().set(mo.getDataCharacteristics());
+
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
+ }
+ else {
+ throw new DMLRuntimeException("Unknown opcode : " +
opcode);
+ }
+ }
+
+ public MatrixObject getTarget(ExecutionContext ec) { // federated input: the matrix named by the "target" parameter
+ return ec.getMatrixObject(params.get("target"));
+ }
+
+ private CPOperand getTargetOperand() { // operand for the "target" parameter, typed as an FP64 matrix
+ return new CPOperand(params.get("target"), ValueType.FP64,
DataType.MATRIX);
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
index a79677b..063ff77 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
@@ -61,7 +61,7 @@ public class CM_N_COVCell extends MatrixValue implements
WritableComparable
}
@Override
- public void binaryOperationsInPlace(BinaryOperator op, MatrixValue
thatValue) {
+ public MatrixValue binaryOperationsInPlace(BinaryOperator op,
MatrixValue thatValue) {
throw new DMLRuntimeException("operation not supported for
CM_N_COVCell");
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 45a0965..ed52481 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -2837,7 +2837,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock, Externalizab
}
@Override
- public void binaryOperationsInPlace(BinaryOperator op, MatrixValue
thatValue) {
+ public MatrixBlock binaryOperationsInPlace(BinaryOperator op,
MatrixValue thatValue) {
MatrixBlock that=checkType(thatValue);
if( !LibMatrixBincell.isValidDimensionsBinary(this, that) ) {
throw new RuntimeException("block sizes are not matched
for binary " +
@@ -2853,6 +2853,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock, Externalizab
//core binary cell operation
LibMatrixBincell.bincellOpInPlace(this, that, op);
+ return this;
}
public MatrixBlock ternaryOperations(TernaryOperator op, MatrixBlock
m2, MatrixBlock m3, MatrixBlock ret) {
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
index 42338d9..10d7e61 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
@@ -198,10 +198,11 @@ public class MatrixCell extends MatrixValue implements
WritableComparable, Seria
}
@Override
- public void binaryOperationsInPlace(BinaryOperator op,
+ public MatrixValue binaryOperationsInPlace(BinaryOperator op,
MatrixValue thatValue) {
MatrixCell c2=checkType(thatValue);
setValue(op.fn.execute(this.getValue(), c2.getValue()));
+ return this;
}
public void denseScalarOperationsInPlace(ScalarOperator op) {
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
index 102e433..9b213ec 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
@@ -103,7 +103,7 @@ public abstract class MatrixValue implements
WritableComparable
public abstract MatrixValue binaryOperations(BinaryOperator op,
MatrixValue thatValue, MatrixValue result);
- public abstract void binaryOperationsInPlace(BinaryOperator op,
MatrixValue thatValue);
+ public abstract MatrixValue binaryOperationsInPlace(BinaryOperator op,
MatrixValue thatValue);
public abstract MatrixValue reorgOperations(ReorgOperator op,
MatrixValue result,
int startRow, int startColumn, int length);
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java
b/src/main/java/org/apache/sysds/utils/Statistics.java
index 0a6ec38..b498b0e 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -1020,13 +1020,13 @@ public class Statistics
sb.append("ParFor total update in-place:\t" +
lTotalUIPVar + "/" + lTotalLixUIP + "/" + lTotalLix + "\n");
}
if( federatedReadCount.longValue() > 0){
- sb.append("Federated (Reads,Puts,Gets) :\t(" +
- federatedReadCount.longValue() + "," +
- federatedPutCount.longValue() + "," +
- federatedGetCount.longValue() + ")\n");
- sb.append("Federated Execute (In,UDF) :\t(" +
-
federatedExecuteInstructionCount.longValue() + "," +
- federatedExecuteUDFCount.longValue() +
")\n");
+ sb.append("Federated I/O (Read, Put, Get):\t" +
+ federatedReadCount.longValue() + "/" +
+ federatedPutCount.longValue() + "/" +
+ federatedGetCount.longValue() + ".\n");
+ sb.append("Federated Execute (Inst, UDF):\t" +
+
federatedExecuteInstructionCount.longValue() + "/" +
+ federatedExecuteUDFCount.longValue() +
".\n");
}
sb.append("Total JIT compile time:\t\t" +
((double)getJITCompileTime())/1000 + " sec.\n");
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
index 29826f8..bf674a8 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
@@ -61,8 +61,7 @@ public class FederatedPCATest extends AutomatedTestBase {
// rows have to be even and > 1
return Arrays.asList(new Object[][] {
{10000, 10, false}, {2000, 50, false}, {1000, 100,
false},
- //TODO support for federated uacmean, uacvar
- //{10000, 10, true}, {2000, 50, true}, {1000, 100, true}
+ {10000, 10, true}, {2000, 50, true}, {1000, 100, true}
});
}
@@ -99,7 +98,6 @@ public class FederatedPCATest extends AutomatedTestBase {
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
- setOutputBuffering(false);
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
@@ -124,8 +122,11 @@ public class FederatedPCATest extends AutomatedTestBase {
Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
Assert.assertTrue(heavyHittersContainsString("fed_tsmm"));
if( scaleAndShift ) {
+
Assert.assertTrue(heavyHittersContainsString("fed_uacsqk+"));
Assert.assertTrue(heavyHittersContainsString("fed_uacmean"));
-
Assert.assertTrue(heavyHittersContainsString("fed_uacvar"));
+ Assert.assertTrue(heavyHittersContainsString("fed_-"));
+ Assert.assertTrue(heavyHittersContainsString("fed_/"));
+
Assert.assertTrue(heavyHittersContainsString("fed_replace"));
}
resetExecMode(platformOld);