This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 926ffb9 [SYSTEMDS-2922] Federated codegen row operations, incl tests
926ffb9 is described below
commit 926ffb9da9a1f7481e8bdc3ff7da8f3dbe4c8f33
Author: ywcb00 <[email protected]>
AuthorDate: Fri May 14 14:03:30 2021 +0200
[SYSTEMDS-2922] Federated codegen row operations, incl tests
Closes #1218.
---
.../controlprogram/context/ExecutionContext.java | 15 ++
.../instructions/cp/SpoofCPInstruction.java | 15 +-
.../instructions/fed/FEDInstructionUtils.java | 10 +-
.../instructions/fed/SpoofFEDInstruction.java | 173 ++++++++++++++-------
.../instructions/spark/SpoofSPInstruction.java | 19 ++-
.../codegen/FederatedCellwiseTmplTest.java | 62 ++++----
...TmplTest.java => FederatedRowwiseTmplTest.java} | 103 ++++++------
.../federated/codegen/FederatedRowwiseTmplTest.dml | 125 +++++++++++++++
.../codegen/FederatedRowwiseTmplTestReference.dml | 123 +++++++++++++++
9 files changed, 482 insertions(+), 163 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index deaa680..f0448f1 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -36,6 +36,7 @@ import
org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -789,6 +790,20 @@ public class ExecutionContext {
throw new DMLRuntimeException(ex);
}
}
+
+ public boolean isFederated(CPOperand input) {
+ Data data = getVariable(input);
+ if(data instanceof CacheableData && ((CacheableData<?>)
data).isFederated())
+ return true;
+ return false;
+ }
+
+ public boolean isFederated(CPOperand input, FType type) {
+ Data data = getVariable(input);
+ if(data instanceof CacheableData && ((CacheableData<?>)
data).isFederated(type))
+ return true;
+ return false;
+ }
public void traceLineage(Instruction inst) {
if( _lineage == null )
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
index 0ba12a2..e9bacd3 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
@@ -28,8 +28,8 @@ import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.lineage.LineageCodegenItem;
import org.apache.sysds.runtime.lineage.LineageItem;
@@ -131,11 +131,16 @@ public class SpoofCPInstruction extends
ComputationCPInstruction {
}
public boolean isFederated(ExecutionContext ec) {
- for(CPOperand input : _in) {
- Data data = ec.getVariable(input);
- if(data instanceof MatrixObject && ((MatrixObject)
data).isFederated())
+ for(CPOperand input : _in)
+ if( ec.isFederated(input) )
+ return true;
+ return false;
+ }
+
+ public boolean isFederated(ExecutionContext ec, FType type) {
+ for(CPOperand input : _in)
+ if( ec.isFederated(input, type) )
return true;
- }
return false;
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index ebd69d9..8f22539 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.instructions.fed;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
+import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -220,8 +221,11 @@ public class FEDInstructionUtils {
}
else if(inst instanceof SpoofCPInstruction) {
SpoofCPInstruction instruction = (SpoofCPInstruction)
inst;
- if(instruction.getOperatorClass().getSuperclass() ==
SpoofCellwise.class && instruction.isFederated(ec))
+ Class<?> scla =
instruction.getOperatorClass().getSuperclass();
+ if( (scla == SpoofCellwise.class &&
instruction.isFederated(ec))
+ || (scla == SpoofRowwise.class&&
instruction.isFederated(ec, FType.ROW))) {
fedinst =
SpoofFEDInstruction.parseInstruction(instruction.getInstructionString());
+ }
}
else if(inst instanceof CtableCPInstruction) {
CtableCPInstruction cinst = (CtableCPInstruction) inst;
@@ -324,7 +328,9 @@ public class FEDInstructionUtils {
}
else if(inst instanceof SpoofSPInstruction) {
SpoofSPInstruction instruction = (SpoofSPInstruction)
inst;
- if(instruction.getOperatorClass().getSuperclass() ==
SpoofCellwise.class && instruction.isFederated(ec)) {
+ Class<?> scla =
instruction.getOperatorClass().getSuperclass();
+ if( (scla == SpoofCellwise.class &&
instruction.isFederated(ec))
+ || (scla == SpoofRowwise.class &&
instruction.isFederated(ec, FType.ROW))) {
fedinst =
SpoofFEDInstruction.parseInstruction(inst.getInstructionString());
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index 6e59813..2ceada6 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -22,6 +22,10 @@ package org.apache.sysds.runtime.instructions.fed;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
+import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;
+import org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;
+import org.apache.sysds.runtime.codegen.SpoofRowwise;
+import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -48,9 +52,9 @@ public class SpoofFEDInstruction extends FEDInstruction
private final CPOperand _output;
private SpoofFEDInstruction(SpoofOperator op, CPOperand[] in,
- CPOperand out, String opcode, String inst_str)
+ CPOperand out, String opcode, String instStr)
{
- super(FEDInstruction.FEDType.SpoofFused, opcode, inst_str);
+ super(FEDInstruction.FEDType.SpoofFused, opcode, instStr);
_op = op;
_inputs = in;
_output = out;
@@ -79,15 +83,13 @@ public class SpoofFEDInstruction extends FEDInstruction
ArrayList<CPOperand> inCpoScal = new ArrayList<>();
ArrayList<MatrixObject> inMo = new ArrayList<>();
ArrayList<ScalarObject> inSo = new ArrayList<>();
- MatrixObject fedMo = null;
FederationMap fedMap = null;
for(CPOperand cpo : _inputs) {
Data tmpData = ec.getVariable(cpo);
if(tmpData instanceof MatrixObject) {
MatrixObject tmp = (MatrixObject) tmpData;
- if(fedMo == null & tmp.isFederated()) { //take
first
+ if(fedMap == null & tmp.isFederated()) { //take
first
inCpoMat.add(0, cpo); // insert
federated CPO at the beginning
- fedMo = tmp;
fedMap = tmp.getFedMapping();
}
else {
@@ -108,10 +110,7 @@ public class SpoofFEDInstruction extends FEDInstruction
int index = 0;
frIds[index++] = fedMap.getID(); // insert federation map id at
the beginning
for(MatrixObject mo : inMo) {
- if((fedMo.isFederated(FType.ROW) && mo.getNumRows() > 1
&& (mo.getNumColumns() == 1 || mo.getNumColumns() == fedMap.getSize()))
- || (fedMo.isFederated(FType.ROW) &&
mo.getNumColumns() > 1 && mo.getNumRows() == fedMap.getSize())
- || (fedMo.isFederated(FType.COL) &&
(mo.getNumRows() == 1 || mo.getNumRows() == fedMap.getSize()) &&
mo.getNumColumns() > 1)
- || (fedMo.isFederated(FType.COL) &&
mo.getNumRows() > 1 && mo.getNumColumns() == fedMap.getSize())) {
+ if(needsBroadcastSliced(fedMap, mo.getNumRows(),
mo.getNumColumns())) {
FederatedRequest[] tmpFr =
fedMap.broadcastSliced(mo, false);
frIds[index++] = tmpFr[0].getID();
frBroadcastSliced.add(tmpFr);
@@ -150,68 +149,85 @@ public class SpoofFEDInstruction extends FEDInstruction
Future<FederatedResponse>[] response =
fedMap.executeMultipleSlices(
getTID(), true, frBroadcastSliced.toArray(new
FederatedRequest[0][]), frAll);
- if(((SpoofCellwise)_op).getCellType() ==
SpoofCellwise.CellType.FULL_AGG) { // full aggregation
- if(((SpoofCellwise)_op).getAggOp() ==
SpoofCellwise.AggOp.SUM
- || ((SpoofCellwise)_op).getAggOp() ==
SpoofCellwise.AggOp.SUM_SQ) {
- //aggregate partial results from federated
responses as sum
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
- ec.setVariable(_output.getName(),
FederationUtils.aggScalar(aop, response));
+ if(_op.getClass().getSuperclass() == SpoofCellwise.class)
+ setOutputCellwise(ec, response, fedMap);
+ else if(_op.getClass().getSuperclass() == SpoofRowwise.class)
+ setOutputRowwise(ec, response, fedMap);
+ else
+ throw new DMLRuntimeException("Federated code
generation only supported for cellwise and rowwise templates.");
+ }
+
+ private static boolean needsBroadcastSliced(FederationMap fedMap, long
rowNum, long colNum) {
+ if(fedMap.getType() == FType.ROW) {
+ return (rowNum == fedMap.getMaxIndexInRange(0) &&
(colNum == 1 || colNum == fedMap.getSize()))
+ || (colNum > 1 && rowNum == fedMap.getSize());
+ }
+ else if(fedMap.getType() == FType.COL) {
+ return ((rowNum == 1 || rowNum == fedMap.getSize()) &&
colNum == fedMap.getMaxIndexInRange(1))
+ || (rowNum > 1 && colNum == fedMap.getSize());
+ }
+ throw new DMLRuntimeException("Only row partitioned or column
partitioned federated input supported yet.");
+ }
+
+ private void setOutputCellwise(ExecutionContext ec,
Future<FederatedResponse>[] response, FederationMap fedMap)
+ {
+ FType fedType = fedMap.getType();
+ AggOp aggOp = ((SpoofCellwise)_op).getAggOp();
+ CellType cellType = ((SpoofCellwise)_op).getCellType();
+ if(cellType == CellType.FULL_AGG) { // full aggregation
+ AggregateUnaryOperator aop = null;
+ if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ) {
+ // aggregate partial results from federated
responses as sum
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
}
- else if(((SpoofCellwise)_op).getAggOp() ==
SpoofCellwise.AggOp.MIN) {
- //aggregate partial results from federated
responses as min
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uamin");
- ec.setVariable(_output.getName(),
FederationUtils.aggScalar(aop, response));
+ else if(aggOp == AggOp.MIN) {
+ // aggregate partial results from federated
responses as min
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uamin");
}
- else if(((SpoofCellwise)_op).getAggOp() ==
SpoofCellwise.AggOp.MAX) {
- //aggregate partial results from federated
responses as max
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
- ec.setVariable(_output.getName(),
FederationUtils.aggScalar(aop, response));
+ else if(aggOp == AggOp.MAX) {
+ // aggregate partial results from federated
responses as max
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
}
else {
- throw new DMLRuntimeException("Aggregation type
for federated spoof instructions not supported yet.");
+ throw new DMLRuntimeException("Aggregation
operation not supported yet.");
}
+ ec.setVariable(_output.getName(),
FederationUtils.aggScalar(aop, response));
}
- else if(((SpoofCellwise)_op).getCellType() ==
SpoofCellwise.CellType.ROW_AGG) { // row aggregation
- if(fedMo.isFederated(FType.ROW)) {
+ else if(cellType == CellType.ROW_AGG) { // row aggregation
+ if(fedType == FType.ROW) {
// bind partial results from federated responses
ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, false));
}
- else if(fedMo.isFederated(FType.COL)) {
- if(((SpoofCellwise)_op).getAggOp() ==
SpoofCellwise.AggOp.SUM
- || ((SpoofCellwise)_op).getAggOp() ==
SpoofCellwise.AggOp.SUM_SQ) {
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
- ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
- }
- else if(((SpoofCellwise)_op).getAggOp() ==
SpoofCellwise.AggOp.MIN) {
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uarmin");
- ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
- }
- else if(((SpoofCellwise)_op).getAggOp() ==
SpoofCellwise.AggOp.MAX) {
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uarmax");
- ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
- }
+ else if(fedType == FType.COL) {
+ AggregateUnaryOperator aop = null;
+ if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ)
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
+ else if(aggOp == AggOp.MIN)
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uarmin");
+ else if(aggOp == AggOp.MAX)
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uarmax");
+ else
+ throw new
DMLRuntimeException("Aggregation operation not supported yet.");
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
}
else {
throw new DMLRuntimeException("Aggregation type
for federated spoof instructions not supported yet.");
}
}
- else if(((SpoofCellwise)_op).getCellType() ==
SpoofCellwise.CellType.COL_AGG) { // col aggregation
- if(fedMo.isFederated(FType.ROW)) {
- if(((SpoofCellwise)_op).getAggOp() ==
SpoofCellwise.AggOp.SUM
- || ((SpoofCellwise)_op).getAggOp() ==
SpoofCellwise.AggOp.SUM_SQ) {
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
- ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
- }
- else if(((SpoofCellwise)_op).getAggOp() ==
SpoofCellwise.AggOp.MIN) {
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uacmin");
- ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
- }
- else if(((SpoofCellwise)_op).getAggOp() ==
SpoofCellwise.AggOp.MAX) {
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uacmax");
- ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
- }
+ else if(cellType == CellType.COL_AGG) { // col aggregation
+ if(fedType == FType.ROW) {
+ AggregateUnaryOperator aop = null;
+ if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ)
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
+ else if(aggOp == AggOp.MIN)
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uacmin");
+ else if(aggOp == AggOp.MAX)
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uacmax");
+ else
+ throw new
DMLRuntimeException("Aggregation operation not supported yet.");
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
}
- else if(fedMo.isFederated(FType.COL)) {
+ else if(fedType == FType.COL) {
// bind partial results from federated responses
ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, true));
}
@@ -219,12 +235,12 @@ public class SpoofFEDInstruction extends FEDInstruction
throw new DMLRuntimeException("Aggregation type
for federated spoof instructions not supported yet.");
}
}
- else if(((SpoofCellwise)_op).getCellType() ==
SpoofCellwise.CellType.NO_AGG) { // no aggregation
- if(fedMo.isFederated(FType.ROW)) {
+ else if(cellType == CellType.NO_AGG) { // no aggregation
+ if(fedType == FType.ROW) {
// bind partial results from federated responses
ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, false));
}
- else if(fedMo.isFederated(FType.COL)) {
+ else if(fedType == FType.COL) {
// bind partial results from federated responses
ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, true));
}
@@ -236,4 +252,43 @@ public class SpoofFEDInstruction extends FEDInstruction
throw new DMLRuntimeException("Aggregation type not
supported yet.");
}
}
+
+ private void setOutputRowwise(ExecutionContext ec,
Future<FederatedResponse>[] response, FederationMap fedMap)
+ {
+ RowType rowType = ((SpoofRowwise)_op).getRowType();
+ if(rowType == RowType.FULL_AGG) { // full aggregation
+ // aggregate partial results from federated responses
as sum
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+ ec.setVariable(_output.getName(),
FederationUtils.aggScalar(aop, response));
+ }
+ else if(rowType == RowType.ROW_AGG) { // row aggregation
+ // aggregate partial results from federated responses
as rowSum
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
+ }
+ else if(rowType == RowType.COL_AGG
+ || rowType == RowType.COL_AGG_T
+ || rowType == RowType.COL_AGG_B1
+ || rowType == RowType.COL_AGG_B1_T
+ || rowType == RowType.COL_AGG_B1R
+ || rowType == RowType.COL_AGG_CONST) { // col
aggregation
+ // aggregate partial results from federated responses
as colSum
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
+ }
+ else if(rowType == RowType.NO_AGG
+ || rowType == RowType.NO_AGG_B1
+ || rowType == RowType.NO_AGG_CONST) { // no aggregation
+ if(fedMap.getType() == FType.ROW) {
+ // bind partial results from federated responses
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, false));
+ }
+ else {
+ throw new DMLRuntimeException("Only row
partitioned federated matrices supported yet.");
+ }
+ }
+ else {
+ throw new DMLRuntimeException("AggregationType not
supported yet.");
+ }
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
index dff78d8..f76d74f 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
@@ -39,15 +39,14 @@ import
org.apache.sysds.runtime.codegen.SpoofOuterProduct.OutProdType;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
@@ -75,8 +74,7 @@ public class SpoofSPInstruction extends SPInstruction {
private final CPOperand[] _in;
private final CPOperand _out;
- private SpoofSPInstruction(Class<?> cls, byte[] classBytes, CPOperand[]
in, CPOperand out, String opcode,
- String str) {
+ private SpoofSPInstruction(Class<?> cls, byte[] classBytes, CPOperand[]
in, CPOperand out, String opcode, String str) {
super(SPType.SpoofFused, opcode, str);
_class = cls;
_classBytes = classBytes;
@@ -680,11 +678,16 @@ public class SpoofSPInstruction extends SPInstruction {
}
public boolean isFederated(ExecutionContext ec) {
- for(CPOperand input : _in) {
- Data data = ec.getVariable(input);
- if(data instanceof MatrixObject && ((MatrixObject)
data).isFederated())
+ for(CPOperand input : _in)
+ if( ec.isFederated(input) )
+ return true;
+ return false;
+ }
+
+ public boolean isFederated(ExecutionContext ec, FType type) {
+ for(CPOperand input : _in)
+ if( ec.isFederated(input, type) )
return true;
- }
return false;
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
index 17f7426..5ea32bc 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
@@ -49,7 +49,7 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
private final static String TEST_CONF = "SystemDS-config-codegen.xml";
private final static String OUTPUT_NAME = "Z";
- private final static double TOLERANCE = 1e-8;
+ private final static double TOLERANCE = 1e-11;
private final static int BLOCKSIZE = 1024;
@Parameterized.Parameter()
@@ -59,8 +59,6 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
@Parameterized.Parameter(2)
public int cols;
@Parameterized.Parameter(3)
- public double sparsity;
- @Parameterized.Parameter(4)
public boolean row_partitioned;
@Override
@@ -73,41 +71,41 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
// rows must be even for row partitioned X
// cols must be even for col partitioned X
return Arrays.asList(new Object[][] {
- // {test_num, rows, cols, sparsity, row_partitioned}
+ // {test_num, rows, cols, row_partitioned}
// row partitioned
- {1, 2000, 2000, 1, true},
- {2, 10, 10, 1, true},
- {3, 4, 4, 1, true},
- {4, 4, 4, 1, true},
- {5, 4, 4, 1, true},
- {6, 4, 1, 1, true},
- {9, 500, 2, 1, true},
- {10, 500, 2, 1, true},
- {11, 1100, 2000, 1, true},
- {12, 2, 500, 1, true},
- {13, 2, 4, 1, true},
- {14, 1100, 200, 1, true},
+ {1, 2000, 2000, true},
+ // {2, 10, 10, true},
+ // {3, 4, 4, true},
+ {4, 4, 4, true},
+ // {5, 4, 4, true},
+ {6, 4, 1, true},
+ {9, 500, 2, true},
+ {10, 500, 2, true},
+ // {11, 1100, 2000, true},
+ {12, 2, 500, true},
+ // {13, 2, 4, true},
+ {14, 1100, 200, true},
// column partitioned
- {1, 2000, 2000, 1, false},
- {2, 10, 10, 1, false},
- {3, 4, 4, 1, false},
- {4, 4, 4, 1, false},
- {5, 4, 4, 1, false},
- {9, 500, 2, 1, false},
- {10, 500, 2, 1, false},
- {11, 1100, 2000, 1, false},
- {12, 2, 500, 1, false},
- {14, 1100, 200, 1, false},
+ // {1, 2000, 2000, false},
+ {2, 10, 10, false},
+ {3, 4, 4, false},
+ // {4, 4, 4, false},
+ {5, 4, 4, false},
+ {9, 500, 2, false},
+ {10, 500, 2, false},
+ {11, 1100, 2000, false},
+ // {12, 2, 500, false},
+ {14, 1100, 200, false},
// not working because of fused sequence operation
- // (wrong grix inside genexec call of fed worker)
- // {7, 1000, 1, 1, true},
+ // (wrong grix inside genexec call of fed worker)
+ // {7, 1000, 1, true},
// not creating a FedSpoof instruction
- // {8, 1002, 24, 1, true},
- // {8, 1002, 24, 1, false},
+ // {8, 1002, 24, true},
+ // {8, 1002, 24, false},
});
}
@@ -147,8 +145,8 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
// generate dataset
// matrix handled by two federated workers
- double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1,
sparsity, 3);
- double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1,
sparsity, 7);
+ double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 0.1,
3);
+ double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 0.1,
23);
writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
similarity index 74%
copy from
src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
copy to
src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
index 17f7426..0f63cb3 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
@@ -39,17 +39,17 @@ import java.util.HashMap;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-public class FederatedCellwiseTmplTest extends AutomatedTestBase
+public class FederatedRowwiseTmplTest extends AutomatedTestBase
{
- private final static String TEST_NAME = "FederatedCellwiseTmplTest";
+ private final static String TEST_NAME = "FederatedRowwiseTmplTest";
private final static String TEST_DIR = "functions/federated/codegen/";
- private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedCellwiseTmplTest.class.getSimpleName() + "/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedRowwiseTmplTest.class.getSimpleName() + "/";
private final static String TEST_CONF = "SystemDS-config-codegen.xml";
private final static String OUTPUT_NAME = "Z";
- private final static double TOLERANCE = 1e-8;
+ private final static double TOLERANCE = 1e-13;
private final static int BLOCKSIZE = 1024;
@Parameterized.Parameter()
@@ -59,8 +59,6 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
@Parameterized.Parameter(2)
public int cols;
@Parameterized.Parameter(3)
- public double sparsity;
- @Parameterized.Parameter(4)
public boolean row_partitioned;
@Override
@@ -70,44 +68,33 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
@Parameterized.Parameters
public static Collection<Object[]> data() {
- // rows must be even for row partitioned X
- // cols must be even for col partitioned X
+ // rows must be even
return Arrays.asList(new Object[][] {
- // {test_num, rows, cols, sparsity, row_partitioned}
+ // {test_num, rows, cols, row_partitioned}
// row partitioned
- {1, 2000, 2000, 1, true},
- {2, 10, 10, 1, true},
- {3, 4, 4, 1, true},
- {4, 4, 4, 1, true},
- {5, 4, 4, 1, true},
- {6, 4, 1, 1, true},
- {9, 500, 2, 1, true},
- {10, 500, 2, 1, true},
- {11, 1100, 2000, 1, true},
- {12, 2, 500, 1, true},
- {13, 2, 4, 1, true},
- {14, 1100, 200, 1, true},
-
- // column partitioned
- {1, 2000, 2000, 1, false},
- {2, 10, 10, 1, false},
- {3, 4, 4, 1, false},
- {4, 4, 4, 1, false},
- {5, 4, 4, 1, false},
- {9, 500, 2, 1, false},
- {10, 500, 2, 1, false},
- {11, 1100, 2000, 1, false},
- {12, 2, 500, 1, false},
- {14, 1100, 200, 1, false},
-
- // not working because of fused sequence operation
- // (wrong grix inside genexec call of fed worker)
- // {7, 1000, 1, 1, true},
-
- // not creating a FedSpoof instruction
- // {8, 1002, 24, 1, true},
- // {8, 1002, 24, 1, false},
+ {1, 6, 4, true},
+ // {2, 6, 2, true},
+ {3, 6, 4, true},
+ {4, 6, 4, true},
+ {10, 150, 10, true},
+ {15, 150, 10, true},
+ // {20, 1500, 8, true},
+ {21, 1500, 8, true},
+ {25, 600, 10, true},
+ {31, 150, 10, true},
+ // {40, 300, 20, true},
+ {45, 1500, 100, true},
+ {50, 376, 4, true},
+
+ // col partitioned (should not create a federated spoof
instruction)
+ // column partitioned federated data is not supported
within federated rowwise templates
+ {1, 6, 4, false},
+ {3, 6, 4, false},
+ {15, 150, 10, false},
+ {25, 600, 10, false},
+ {31, 150, 10, false},
+ {50, 376, 4, false},
});
}
@@ -116,22 +103,22 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
}
-// @Test
-// public void federatedCodegenCellwiseSingleNode() {
-// testFederatedCodegen(ExecMode.SINGLE_NODE);
-// }
-//
-// @Test
-// public void federatedCodegenCellwiseSpark() {
-// testFederatedCodegen(ExecMode.SPARK);
-// }
+ // @Test
+ // public void federatedCodegenRowwiseSingleNode() {
+ // testFederatedCodegenRowwise(ExecMode.SINGLE_NODE);
+ // }
+ //
+ // @Test
+ // public void federatedCodegenRowwiseSpark() {
+ // testFederatedCodegenRowwise(ExecMode.SPARK);
+ // }
@Test
public void federatedCodegenCellwiseHybrid() {
- testFederatedCodegen(ExecMode.HYBRID);
+ testFederatedCodegenRowwise(ExecMode.HYBRID);
}
-
- private void testFederatedCodegen(ExecMode exec_mode) {
+
+ private void testFederatedCodegenRowwise(ExecMode exec_mode) {
// store the previous platform config to restore it after the
test
ExecMode platform_old = setExecMode(exec_mode);
@@ -147,8 +134,8 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
// generate dataset
// matrix handled by two federated workers
- double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1,
sparsity, 3);
- double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1,
sparsity, 7);
+ double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 0.1,
3);
+ double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 0.1,
11);
writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
@@ -186,15 +173,17 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
HashMap<CellIndex, Double> refResults =
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE,
"Fed", "Ref");
-
+
TestUtils.shutdownThreads(thread1, thread2);
// check for federated operations
-
Assert.assertTrue(heavyHittersContainsSubString("fed_spoofCell"));
+ if(row_partitioned)
+
Assert.assertTrue(heavyHittersContainsSubString("fed_spoofRA"));
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+
resetExecMode(platform_old);
}
diff --git
a/src/test/scripts/functions/federated/codegen/FederatedRowwiseTmplTest.dml
b/src/test/scripts/functions/federated/codegen/FederatedRowwiseTmplTest.dml
new file mode 100644
index 0000000..468c450
--- /dev/null
+++ b/src/test/scripts/functions/federated/codegen/FederatedRowwiseTmplTest.dml
@@ -0,0 +1,125 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+test_num = $in_test_num; # selects which of the test expressions below is evaluated
+row_part = $in_rp; # TRUE: X is split by rows between the two addresses, FALSE: by columns
+
+if(row_part) { # the ranges below cut X at row $rows / 2
+ X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0),
list($rows, $cols)));
+}
+else { # the ranges below cut X at column $cols / 2
+ X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows, $cols / 2), list(0, $cols / 2),
list($rows, $cols)));
+}
+
+if(test_num == 1) {
+ # X ... 6x4 matrix
+ Y = matrix(1, rows=4, cols=1);
+ lamda = sum(X); # scalar: sum over all cells of X
+
+ Z = t(X) %*% (X %*% (lamda * Y));
+}
+else if(test_num == 2) {
+ # X ... 6x2 matrix
+ U = matrix(1, rows=2, cols=1);
+ Y = matrix( "1 1 1 4 5 6 7 8 9 10 11 12", rows=6, cols=2);
+ lambda = sum(Y); # scalar taken from the literal Y above, not from X
+
+ Z = t(X) %*% (lambda * (X %*% U));
+}
+else if(test_num == 3) {
+ # X ... 6x4 matrix
+ U = matrix(1,rows=4,cols=1);
+ V = matrix( "1 2 3 4 5 6", rows=6, cols=1);
+ W = matrix( "3 3 3 3 3 3", rows=6, cols=1);
+
+ Z = t(X) %*% (W + (2 - (V * (X %*% U))));
+}
+else if(test_num == 4) {
+ # X ... 6x4 matrix
+ Z = colSums(X/rowSums(X)); # divide each row by its row sum, then sum per column
+}
+else if(test_num == 10) {
+ # X ... 150x10 matrix
+
+ Y = (X <= rowMins(X)); # indicator of the cells that hold their row's minimum
+ U = (Y / rowSums(Y));
+ Z = colSums(U);
+}
+else if(test_num == 15) {
+ # X ... 150x10 matrix
+
+ Y1 = X - rowMaxs(X) # shift each row by its maximum
+ Y2 = exp(Y1)
+ Y3 = Y2 / rowSums(Y2) # normalize each row by its row sum
+ Y4 = Y3 * rowSums(Y3)
+ Z = Y4 - Y3 * rowSums(Y4)
+}
+else if(test_num == 20) {
+ # X ... 1500x8 matrix
+
+ Z = X / rowSums(X);
+ Z = 1 / (1 - Z);
+}
+else if(test_num == 21) {
+ # X ... 1500x8 matrix
+
+ Z = as.matrix(sum(X / rowSums(X))); # full aggregate wrapped as a matrix so it can be written as Z
+}
+else if(test_num == 25) {
+ # X ... 600x10 matrix
+ C = matrix(seq(1,40), 4, 10);
+ while(FALSE){} # no-op loop; presumably forces a statement-block boundary here - TODO confirm
+
+ Z = -2 * (X %*% t(C)) + t(rowSums(C^2)) # -2 * X C^T plus the squared row norms of C as a row vector
+}
+else if(test_num == 31) {
+ # X ... 150x10 matrix
+ y = seq(1, ncol(X));
+ Z = cbind((X %*% y), matrix (7, nrow(X), 1)); # append a constant column of 7s to X %*% y
+ Z = Z - rowMaxs(Z);
+}
+else if(test_num == 40) {
+ # X ... 300x20 matrix
+
+ Y = X / rowSums(X);
+ Z = (X > 0) * Y;
+}
+else if(test_num == 45) {
+ # X ... 1500x100 matrix
+ X = X * t(seq(1,100)); # multiply column j of X by j
+ while(FALSE){} # no-op loop; presumably forces a statement-block boundary here - TODO confirm
+
+ X0 = X - 0.5;
+ X1 = X / rowSums(X0);
+ X2 = abs(X1 * 0.5);
+ X3 = X1 / rowSums(X2);
+
+ while(FALSE){} # no-op loop; presumably separates the final aggregate from the row operations above - TODO confirm
+ Z = as.matrix(sum(X3));
+}
+else if(test_num == 50) {
+ # X ... 376x4 matrix
+ Z = colSums(X == rowSums(X)); # per column, count the cells equal to their row's sum
+}
+
+write(Z, $out_Z); # write the selected test's result to the path given by $out_Z
diff --git a/src/test/scripts/functions/federated/codegen/FederatedRowwiseTmplTestReference.dml b/src/test/scripts/functions/federated/codegen/FederatedRowwiseTmplTestReference.dml
new file mode 100644
index 0000000..0f4323f
--- /dev/null
+++ b/src/test/scripts/functions/federated/codegen/FederatedRowwiseTmplTestReference.dml
@@ -0,0 +1,123 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+test_num = $in_test_num; # selects which of the test expressions below is evaluated
+row_part = $in_rp; # TRUE: the two inputs are stacked by rows, FALSE: placed side by side by columns
+
+if(row_part) {
+ X = rbind(read($in_X1), read($in_X2)); # read both input files and stack them row-wise
+}
+else {
+ X = cbind(read($in_X1), read($in_X2)); # read both input files and append them column-wise
+}
+
+if(test_num == 1) {
+ # X ... 6x4 matrix
+ Y = matrix(1, rows=4, cols=1);
+ lamda = sum(X); # scalar: sum over all cells of X
+
+ Z = t(X) %*% (X %*% (lamda * Y));
+}
+else if(test_num == 2) {
+ # X ... 6x2 matrix
+ U = matrix(1, rows=2, cols=1);
+ Y = matrix( "1 1 1 4 5 6 7 8 9 10 11 12", rows=6, cols=2);
+ lambda = sum(Y); # scalar taken from the literal Y above, not from X
+
+ Z = t(X) %*% (lambda * (X %*% U));
+}
+else if(test_num == 3) {
+ # X ... 6x4 matrix
+ U = matrix(1,rows=4,cols=1);
+ V = matrix( "1 2 3 4 5 6", rows=6, cols=1);
+ W = matrix( "3 3 3 3 3 3", rows=6, cols=1);
+
+ Z = t(X) %*% (W + (2 - (V * (X %*% U))));
+}
+else if(test_num == 4) {
+ # X ... 6x4 matrix
+ Z = colSums(X/rowSums(X)); # divide each row by its row sum, then sum per column
+}
+else if(test_num == 10) {
+ # X ... 150x10 matrix
+
+ Y = (X <= rowMins(X)); # indicator of the cells that hold their row's minimum
+ U = (Y / rowSums(Y));
+ Z = colSums(U);
+}
+else if(test_num == 15) {
+ # X ... 150x10 matrix
+
+ Y1 = X - rowMaxs(X) # shift each row by its maximum
+ Y2 = exp(Y1)
+ Y3 = Y2 / rowSums(Y2) # normalize each row by its row sum
+ Y4 = Y3 * rowSums(Y3)
+ Z = Y4 - Y3 * rowSums(Y4)
+}
+else if(test_num == 20) {
+ # X ... 1500x8 matrix
+
+ Z = X / rowSums(X);
+ Z = 1 / (1 - Z);
+}
+else if(test_num == 21) {
+ # X ... 1500x8 matrix
+
+ Z = as.matrix(sum(X / rowSums(X))); # full aggregate wrapped as a matrix so it can be written as Z
+}
+else if(test_num == 25) {
+ # X ... 600x10 matrix
+ C = matrix(seq(1,40), 4, 10);
+ while(FALSE){} # no-op loop; presumably forces a statement-block boundary here - TODO confirm
+
+ Z = -2 * (X %*% t(C)) + t(rowSums(C^2)) # -2 * X C^T plus the squared row norms of C as a row vector
+}
+else if(test_num == 31) {
+ # X ... 150x10 matrix
+ y = seq(1, ncol(X));
+ Z = cbind((X %*% y), matrix (7, nrow(X), 1)); # append a constant column of 7s to X %*% y
+ Z = Z - rowMaxs(Z);
+}
+else if(test_num == 40) {
+ # X ... 300x20 matrix
+
+ Y = X / rowSums(X);
+ Z = (X > 0) * Y;
+}
+else if(test_num == 45) {
+ # X ... 1500x100 matrix
+ X = X * t(seq(1,100)); # multiply column j of X by j
+ while(FALSE){} # no-op loop; presumably forces a statement-block boundary here - TODO confirm
+
+ X0 = X - 0.5;
+ X1 = X / rowSums(X0);
+ X2 = abs(X1 * 0.5);
+ X3 = X1 / rowSums(X2);
+
+ while(FALSE){} # no-op loop; presumably separates the final aggregate from the row operations above - TODO confirm
+ Z = as.matrix(sum(X3));
+}
+else if(test_num == 50) {
+ # X ... 376x4 matrix
+ Z = colSums(X == rowSums(X)); # per column, count the cells equal to their row's sum
+}
+
+write(Z, $out_Z); # write the selected test's result to the path given by $out_Z