This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 7528bf1 [SYSTEMDS-2982] Federated instructions w/ multiple aligned
matrices
7528bf1 is described below
commit 7528bf1d2d4bd9255d865ba95b7afd00d64c032f
Author: ywcb00 <[email protected]>
AuthorDate: Mon May 31 19:16:43 2021 +0200
[SYSTEMDS-2982] Federated instructions w/ multiple aligned matrices
Change alignment check to accept partial alignments
(row alignment - row partitioned / col alignment - col partitioned).
Introduce new enum AlignType to describe the different possible alignments.
Along with this enum, add a new method FederationMap.isAligned(FederationMap,
AlignType...) to simplify checking for several possible alignments at once.
- fix instruction string creation for fed_ba+*
- add support for three aligned federated matrices
- create static method SpoofFEDInstruction.isFederated() to centralize the
isFederated methods of SpoofCPInstruction and SpoofSPInstruction
- remove the instance methods isFederated from SpoofCPInstruction and
SpoofSPInstruction in favor of the static isFederated method of
SpoofFEDInstruction
- change alignment checks to use the new method isAligned(AlignType...)
- add getter method getInputs() to SpoofCPInstruction and SpoofSPInstruction to expose their input CPOperands to callers
- change calls to isFederated from CP or SP spoof instruction to directly
call the static method SpoofFEDInstruction.isFederated()
- extend test to also test with two aligned federated matrices
- extend test for partitioned federated data
- add test script to test the dot product of two aligned federated matrices
- add test script to test the ternary aggregation "fed_tack+*"
- add test script to test the ternary aggregation "fed_tak+*"
Closes #1314
---
.../controlprogram/federated/FederationMap.java | 43 +++++++++
.../runtime/instructions/InstructionUtils.java | 6 +-
.../instructions/cp/SpoofCPInstruction.java | 40 +-------
.../fed/AggregateBinaryFEDInstruction.java | 3 +-
.../fed/AggregateTernaryFEDInstruction.java | 25 ++++-
.../fed/BinaryMatrixMatrixFEDInstruction.java | 4 +-
.../instructions/fed/CtableFEDInstruction.java | 16 +++-
.../instructions/fed/FEDInstructionUtils.java | 22 ++---
.../instructions/fed/MMChainFEDInstruction.java | 29 ++++--
.../instructions/fed/SpoofFEDInstruction.java | 34 +++++++
.../instructions/spark/SpoofSPInstruction.java | 42 +--------
.../primitives/FederatedColAggregateTest.java | 15 ++-
.../primitives/FederatedFullAggregateTest.java | 37 +++++---
.../federated/primitives/FederatedIfelseTest.java | 8 +-
.../federated/primitives/FederatedLogicalTest.java | 105 +++++++++++++--------
.../federated/primitives/FederatedRCBindTest.java | 54 +++++++----
.../primitives/FederatedRowAggregateTest.java | 15 ++-
.../functions/federated/FederatedRCBindTest.dml | 27 ++++--
.../federated/FederatedRCBindTestReference.dml | 11 ++-
.../FederatedMMTest.dml} | 31 +++---
.../FederatedMMTestReference.dml} | 26 ++---
.../FederatedTernaryColSumTest.dml} | 44 +++------
.../FederatedTernaryColSumTestReference.dml} | 27 +++---
.../FederatedTernarySumTest.dml} | 44 +++------
.../FederatedTernarySumTestReference.dml} | 27 +++---
.../binary/FederatedLogicalMatrixMatrixTest.dml | 8 ++
.../FederatedLogicalMatrixMatrixTestReference.dml | 8 ++
27 files changed, 453 insertions(+), 298 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index c77ff79..64a6cf9 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -98,6 +98,29 @@ public class FederationMap {
}
}
+ // Alignment Check Type
+ public enum AlignType {
+ FULL, // exact matching dimensions of partitions on the same
federated worker
+ ROW, // matching rows of partitions on the same federated worker
+ COL, // matching columns of partitions on the same federated
worker
+ FULL_T, // matching dimensions with transposed dimensions of
partitions on the same federated worker
+ ROW_T, // matching rows with columns of partitions on the same
federated worker
+ COL_T; // matching columns with rows of partitions on the same
federated worker
+
+ public boolean isTransposed() {
+ return (this == FULL_T || this == ROW_T || this ==
COL_T);
+ }
+ public boolean isFullType() {
+ return (this == FULL || this == FULL_T);
+ }
+ public boolean isRowType() {
+ return (this == ROW || this == ROW_T);
+ }
+ public boolean isColType() {
+ return (this == COL || this == COL_T);
+ }
+ }
+
private long _ID = -1;
private final List<Pair<FederatedRange, FederatedData>> _fedMap;
private FType _type;
@@ -231,6 +254,26 @@ public class FederationMap {
return ret;
}
+
+ /**
+ * helper function for checking multiple allowed alignment types
+ * @param that FederationMap to check alignment with
+ * @param alignTypes collection of alignment types which should be
checked
+ * @return true if this and that FederationMap are aligned according to
at least one alignment type
+ */
+ public boolean isAligned(FederationMap that, AlignType... alignTypes) {
+ boolean ret = false;
+ for(AlignType at : alignTypes) {
+ if(at.isFullType())
+ ret |= isAligned(that, at.isTransposed());
+ else
+ ret |= isAligned(that, at.isTransposed(),
at.isRowType(), at.isColType());
+ if(ret) // early stopping - alignment already found
+ break;
+ }
+ return ret;
+ }
+
/**
* Determines if the two federation maps are aligned row/column
partitions
* at the same federated sites (which allows for purely federated
operation)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 2d8fdbd..076f5fc 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -1085,14 +1085,12 @@ public class InstructionUtils
public static String constructBinaryInstString(String instString,
String opcode, CPOperand op1, CPOperand op2, CPOperand out) {
String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
- parts[1] = opcode;
- return InstructionUtils.concatOperands(parts[0], parts[1],
createOperand(op1), createOperand(op2), createOperand(out));
+ return InstructionUtils.concatOperands(parts[0], opcode,
createOperand(op1), createOperand(op2), createOperand(out));
}
public static String constructUnaryInstString(String instString, String
opcode, CPOperand op1, CPOperand out) {
String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
- parts[1] = opcode;
- return InstructionUtils.concatOperands(parts[0], parts[1],
createOperand(op1), createOperand(out));
+ return InstructionUtils.concatOperands(parts[0], opcode,
createOperand(op1), createOperand(out));
}
/**
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
index 38fd8d7..97047a4 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/SpoofCPInstruction.java
@@ -27,12 +27,8 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.SpoofOperator;
-import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.lineage.LineageCodegenItem;
import org.apache.sysds.runtime.lineage.LineageItem;
@@ -133,39 +129,7 @@ public class SpoofCPInstruction extends
ComputationCPInstruction {
return Pair.of(output.getName(), LIroot);
}
- public boolean isFederated(ExecutionContext ec) {
- return isFederated(ec, null);
- }
-
- public boolean isFederated(ExecutionContext ec, FType type) {
- FederationMap fedMap = null;
- boolean retVal = false;
-
- // flags for alignment check
- boolean equalRows = false;
- boolean equalCols = false;
- boolean transposed = false; // flag indicates to check for
transposed alignment
-
- for(CPOperand input : _in) {
- Data data = ec.getVariable(input);
- if(data instanceof MatrixObject && ((MatrixObject)
data).isFederated(type)) {
- MatrixObject mo = ((MatrixObject) data);
- if(fedMap == null) { // first federated matrix
- fedMap = mo.getFedMapping();
- retVal = true;
-
- // setting the constraints for
alignment check on further federated matrices
- equalRows = mo.isFederated(FType.ROW);
- equalCols = mo.isFederated(FType.COL);
- transposed =
(getOperatorClass().getSuperclass() == SpoofOuterProduct.class);
- }
- else if(!fedMap.isAligned(mo.getFedMapping(),
false, equalRows, equalCols)
- && (!transposed ||
!(fedMap.isAligned(mo.getFedMapping(), true, equalRows, equalCols)
- ||
mo.getFedMapping().isAligned(fedMap, true, equalRows, equalCols)))) {
- retVal = false; // multiple federated
matrices must be aligned
- }
- }
- }
- return retVal;
+ public CPOperand[] getInputs() {
+ return _in;
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index c731ce0..10dd7c6 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -28,6 +28,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -72,7 +73,7 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
//TODO cleanup unnecessary redundancy
//#1 federated matrix-vector multiplication
if(mo1.isFederated(FType.COL) && mo2.isFederated(FType.ROW)
- && mo1.getFedMapping().isAligned(mo2.getFedMapping(),
true) ) {
+ && mo1.getFedMapping().isAligned(mo2.getFedMapping(),
AlignType.COL_T) ) {
FederatedRequest fr1 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2},
new long[]{mo1.getFedMapping().getID(),
mo2.getFedMapping().getID()}, true);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
index 42a6e0e..17fd58a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
@@ -24,6 +24,8 @@ import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
@@ -31,7 +33,9 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
public class AggregateTernaryFEDInstruction extends FEDInstruction {
// private static final Log LOG =
LogFactory.getLog(AggregateTernaryFEDInstruction.class.getName());
@@ -52,7 +56,26 @@ public class AggregateTernaryFEDInstruction extends
FEDInstruction {
MatrixObject mo1 = ec.getMatrixObject(_ins.input1);
MatrixObject mo2 = ec.getMatrixObject(_ins.input2);
MatrixObject mo3 = _ins.input3.isLiteral() ? null :
ec.getMatrixObject(_ins.input3);
- if(mo1.isFederated() && mo2.isFederated() &&
mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) &&
+ if(mo3 != null && mo1.isFederated() && mo2.isFederated() &&
mo3.isFederated()
+ &&
mo1.getFedMapping().isAligned(mo2.getFedMapping(), mo1.isFederated(FType.ROW) ?
AlignType.ROW : AlignType.COL)
+ &&
mo2.getFedMapping().isAligned(mo3.getFedMapping(), mo1.isFederated(FType.ROW) ?
AlignType.ROW : AlignType.COL)) {
+ FederatedRequest fr1 =
FederationUtils.callInstruction(_ins.getInstructionString(), _ins.getOutput(),
+ new CPOperand[] {_ins.input1, _ins.input2,
_ins.input3},
+ new long[] {mo1.getFedMapping().getID(),
mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
+ FederatedRequest fr2 = new
FederatedRequest(RequestType.GET_VAR, fr1.getID());
+ FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+ Future<FederatedResponse>[] response =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+
+ if(_ins.output.getDataType().isScalar()) {
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+ ec.setScalarOutput(_ins.output.getName(),
FederationUtils.aggScalar(aop, response, mo1.getFedMapping()));
+ }
+ else {
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator(_ins.getOpcode().equals("fed_tak+*")
? "uak+" : "uack+");
+ ec.setMatrixOutput(_ins.output.getName(),
FederationUtils.aggMatrix(aop, response, mo1.getFedMapping()));
+ }
+ }
+ else if(mo1.isFederated() && mo2.isFederated() &&
mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) &&
mo3 == null) {
FederatedRequest fr1 =
mo1.getFedMapping().broadcast(ec.getScalarInput(_ins.input3));
FederatedRequest fr2 =
FederationUtils.callInstruction(_ins.getInstructionString(),
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index db82bba..2a4d766 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -24,6 +24,7 @@ import
org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -53,7 +54,8 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
//execute federated operation on mo1 or mo2
FederatedRequest fr2 = null;
if( mo2.isFederated() ) {
- if(mo1.isFederated() &&
mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
+ if(mo1.isFederated() &&
mo1.getFedMapping().isAligned(mo2.getFedMapping(),
+ mo1.isFederated(FType.ROW) ?
AlignType.ROW : AlignType.COL)) {
fr2 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2},
new long[]{mo1.getFedMapping().getID(),
mo2.getFedMapping().getID()}, true);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index 00137e8..e12ed5a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -36,6 +36,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.And;
import org.apache.sysds.runtime.instructions.Instruction;
@@ -127,7 +128,20 @@ public class CtableFEDInstruction extends
ComputationFEDInstruction {
FederatedRequest[] fr1 =
mo1.getFedMapping().broadcastSliced(mo2, false);
FederatedRequest fr2, fr3;
- if(mo3 == null) {
+ if(mo3 != null && mo1.isFederated() && mo3.isFederated()
+ && mo1.getFedMapping().isAligned(mo3.getFedMapping(),
AlignType.FULL)) { // mo1 and mo3 federated and aligned
+ if(!reversed)
+ fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1,
input2, input3},
+ new long[]
{mo1.getFedMapping().getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
+ else
+ fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1,
input2, input3},
+ new long[] {fr1[0].getID(),
mo1.getFedMapping().getID(), mo3.getFedMapping().getID()});
+
+ fr3 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 =
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+ ffr = mo1.getFedMapping().execute(getTID(), true, fr1,
fr2, fr3, fr4);
+ }
+ else if(mo3 == null) {
if(!reversed)
fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1,
input2},
new long[]
{mo1.getFedMapping().getID(), fr1[0].getID()});
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index df39028..bc05449 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -232,12 +232,12 @@ public class FEDInstructionUtils {
fedinst =
QuaternaryFEDInstruction.parseInstruction(instruction.getInstructionString());
}
else if(inst instanceof SpoofCPInstruction) {
- SpoofCPInstruction instruction = (SpoofCPInstruction)
inst;
- Class<?> scla =
instruction.getOperatorClass().getSuperclass();
- if(((scla == SpoofCellwise.class || scla ==
SpoofMultiAggregate.class
- || scla == SpoofOuterProduct.class) &&
instruction.isFederated(ec))
- || (scla == SpoofRowwise.class &&
instruction.isFederated(ec, FType.ROW))) {
- fedinst =
SpoofFEDInstruction.parseInstruction(instruction.getInstructionString());
+ SpoofCPInstruction ins = (SpoofCPInstruction) inst;
+ Class<?> scla = ins.getOperatorClass().getSuperclass();
+ if(((scla == SpoofCellwise.class || scla ==
SpoofMultiAggregate.class || scla == SpoofOuterProduct.class)
+ && SpoofFEDInstruction.isFederated(ec,
ins.getInputs(), scla))
+ || (scla == SpoofRowwise.class &&
SpoofFEDInstruction.isFederated(ec, FType.ROW, ins.getInputs(), scla))) {
+ fedinst =
SpoofFEDInstruction.parseInstruction(ins.getInstructionString());
}
}
else if(inst instanceof CtableCPInstruction) {
@@ -342,11 +342,11 @@ public class FEDInstructionUtils {
fedinst =
QuaternaryFEDInstruction.parseInstruction(instruction.getInstructionString());
}
else if(inst instanceof SpoofSPInstruction) {
- SpoofSPInstruction instruction = (SpoofSPInstruction)
inst;
- Class<?> scla =
instruction.getOperatorClass().getSuperclass();
- if(((scla == SpoofCellwise.class || scla ==
SpoofMultiAggregate.class
- || scla ==
SpoofOuterProduct.class) && instruction.isFederated(ec))
- || (scla == SpoofRowwise.class &&
instruction.isFederated(ec, FType.ROW))) {
+ SpoofSPInstruction ins = (SpoofSPInstruction) inst;
+ Class<?> scla = ins.getOperatorClass().getSuperclass();
+ if(((scla == SpoofCellwise.class || scla ==
SpoofMultiAggregate.class || scla == SpoofOuterProduct.class)
+ && SpoofFEDInstruction.isFederated(ec,
ins.getInputs(), scla))
+ || (scla == SpoofRowwise.class &&
SpoofFEDInstruction.isFederated(ec, FType.ROW, ins.getInputs(), scla))) {
fedinst =
SpoofFEDInstruction.parseInstruction(inst.getInstructionString());
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
index 25df0d3..7aa3ca9 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
@@ -25,6 +25,7 @@ import
org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -79,10 +80,27 @@ public class MMChainFEDInstruction extends
UnaryFEDInstruction {
if( !mo1.isFederated() )
throw new DMLRuntimeException("Federated MMChain:
Federated main input expected, "
+ "but invoked w/ "+mo1.isFederated()+"
"+mo2.isFederated());
-
- if( !_type.isWeighted() ) { //XtXv
- //construct commands: broadcast vector, execute, get
and aggregate, cleanup
- FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
+
+ // broadcast vector mo2
+ FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+
+ if(_type.isWeighted() && mo3.isFederated()
+ && mo1.getFedMapping().isAligned(mo3.getFedMapping(),
AlignType.ROW)) {
+ //construct commands: execute, get and
aggregate, cleanup
+ FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2, input3},
+ new long[]{mo1.getFedMapping().getID(),
fr1.getID(), mo3.getFedMapping().getID()});
+ FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo1.getFedMapping()
+ .cleanup(getTID(), fr1.getID(),
fr2.getID());
+
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ else if( !_type.isWeighted() ) { //XtXv
+ //construct commands: execute, get and aggregate,
cleanup
FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2}, new
long[]{mo1.getFedMapping().getID(), fr1.getID()});
FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
@@ -95,9 +113,8 @@ public class MMChainFEDInstruction extends
UnaryFEDInstruction {
ec.setMatrixOutput(output.getName(), ret);
}
else { //XtwXv | XtXvy
- //construct commands: broadcast 2 vectors, execute, get
and aggregate, cleanup
+ //construct commands: broadcast vector mo3, execute,
get and aggregate, cleanup
FederatedRequest[] fr0 =
mo1.getFedMapping().broadcastSliced(mo3, false);
- FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2, input3},
new long[]{mo1.getFedMapping().getID(),
fr1.getID(), fr0[0].getID()});
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index 918ec00..d8717c0 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -36,6 +36,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -47,6 +48,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.concurrent.Future;
public class SpoofFEDInstruction extends FEDInstruction
@@ -427,4 +429,36 @@ public class SpoofFEDInstruction extends FEDInstruction
}
}
}
+
+ public static boolean isFederated(ExecutionContext ec, CPOperand[]
inputs, Class<?> scla) {
+ return isFederated(ec, null, inputs, scla);
+ }
+
+ public static boolean isFederated(ExecutionContext ec, FType type,
CPOperand[] inputs, Class<?> scla) {
+ FederationMap fedMap = null;
+ boolean retVal = false;
+
+ ArrayList<AlignType> alignmentTypes = new ArrayList<>();
+
+ for(CPOperand input : inputs) {
+ Data data = ec.getVariable(input);
+ if(data instanceof MatrixObject && ((MatrixObject)
data).isFederated(type)) {
+ MatrixObject mo = ((MatrixObject) data);
+ if(fedMap == null) { // first federated matrix
+ fedMap = mo.getFedMapping();
+ retVal = true;
+
+ // setting the alignment types for
alignment check on further federated matrices
+
alignmentTypes.add(mo.isFederated(FType.ROW) ? AlignType.ROW : AlignType.COL);
+ if(scla == SpoofOuterProduct.class)
+
Collections.addAll(alignmentTypes, AlignType.ROW_T, AlignType.COL_T);
+ }
+ else if(!fedMap.isAligned(mo.getFedMapping(),
alignmentTypes.toArray(new AlignType[0]))) {
+ retVal = false; // multiple federated
matrices must be aligned
+ }
+ }
+ }
+ return retVal;
+ }
+
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
index cc6d7bc..c0ccf9f 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SpoofSPInstruction.java
@@ -39,17 +39,13 @@ import
org.apache.sysds.runtime.codegen.SpoofOuterProduct.OutProdType;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
@@ -680,41 +676,7 @@ public class SpoofSPInstruction extends SPInstruction {
return null;
}
- public boolean isFederated(ExecutionContext ec) {
- return isFederated(ec, null);
- }
-
- public boolean isFederated(ExecutionContext ec, FType type) {
- //FIXME remove redundancy with SpoofCPInstruction
-
- FederationMap fedMap = null;
- boolean retVal = false;
-
- // flags for alignment check
- boolean equalRows = false;
- boolean equalCols = false;
- boolean transposed = false; // flag indicates to check for
transposed alignment
-
- for(CPOperand input : _in) {
- Data data = ec.getVariable(input);
- if(data instanceof MatrixObject && ((MatrixObject)
data).isFederated(type)) {
- MatrixObject mo = ((MatrixObject) data);
- if(fedMap == null) { // first federated matrix
- fedMap = mo.getFedMapping();
- retVal = true;
-
- // setting the constraints for
alignment check on further federated matrices
- equalRows = mo.isFederated(FType.ROW);
- equalCols = mo.isFederated(FType.COL);
- transposed =
(getOperatorClass().getSuperclass() == SpoofOuterProduct.class);
- }
- else if(!fedMap.isAligned(mo.getFedMapping(),
false, equalRows, equalCols)
- && (!transposed ||
!(fedMap.isAligned(mo.getFedMapping(), true, equalRows, equalCols)
- ||
mo.getFedMapping().isAligned(fedMap, true, equalRows, equalCols)))) {
- retVal = false; // multiple federated
matrices must be aligned
- }
- }
- }
- return retVal;
+ public CPOperand[] getInputs() {
+ return _in;
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java
index 870e7c2..19e45e5 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java
@@ -43,6 +43,7 @@ public class FederatedColAggregateTest extends
AutomatedTestBase {
private final static String TEST_NAME4 = "FederatedColMinTest";
private final static String TEST_NAME5 = "FederatedColProdTest";
private final static String TEST_NAME10 = "FederatedColVarTest";
+ private final static String TEST_NAME11 = "FederatedTernaryColSumTest";
private final static String TEST_DIR = "functions/federated/aggregate/";
private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedColAggregateTest.class.getSimpleName() + "/";
@@ -65,7 +66,7 @@ public class FederatedColAggregateTest extends
AutomatedTestBase {
}
private enum OpType {
- SUM, MEAN, MAX, MIN, VAR, PROD
+ SUM, MEAN, MAX, MIN, VAR, PROD, TERNARY_SUM
}
@Override
@@ -77,6 +78,7 @@ public class FederatedColAggregateTest extends
AutomatedTestBase {
addTestConfiguration(TEST_NAME4, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S"}));
addTestConfiguration(TEST_NAME10, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME10, new String[] {"S"}));
addTestConfiguration(TEST_NAME5, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S"}));
+ addTestConfiguration(TEST_NAME11, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME11, new String[] {"S"}));
}
@Test
@@ -94,7 +96,6 @@ public class FederatedColAggregateTest extends
AutomatedTestBase {
runAggregateOperationTest(OpType.MAX, ExecMode.SINGLE_NODE);
}
-
@Test
public void testColMinDenseMatrixCP() {
runAggregateOperationTest(OpType.MIN, ExecMode.SINGLE_NODE);
@@ -110,6 +111,11 @@ public class FederatedColAggregateTest extends
AutomatedTestBase {
runAggregateOperationTest(OpType.PROD, ExecMode.SINGLE_NODE);
}
+ @Test
+ public void testTernaryColSumDenseMatrixCP() {
+ runAggregateOperationTest(OpType.TERNARY_SUM,
ExecMode.SINGLE_NODE);
+ }
+
private void runAggregateOperationTest(OpType type, ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -137,6 +143,9 @@ public class FederatedColAggregateTest extends
AutomatedTestBase {
case PROD:
TEST_NAME = TEST_NAME5;
break;
+ case TERNARY_SUM:
+ TEST_NAME = TEST_NAME11;
+ break;
}
getAndLoadTestConfiguration(TEST_NAME);
@@ -222,6 +231,8 @@ public class FederatedColAggregateTest extends
AutomatedTestBase {
case PROD:
Assert.assertTrue(heavyHittersContainsString(fedInst.concat("*")));
break;
+ case TERNARY_SUM:
+
Assert.assertTrue(heavyHittersContainsString("fed_tack+*"));
}
// check that federated input files are still existing
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java
index 7d79ddb..abf7df2 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java
@@ -44,6 +44,7 @@ public class FederatedFullAggregateTest extends
AutomatedTestBase {
private final static String TEST_NAME3 = "FederatedMaxTest";
private final static String TEST_NAME4 = "FederatedMinTest";
private final static String TEST_NAME5 = "FederatedVarTest";
+ private final static String TEST_NAME6 = "FederatedTernarySumTest";
private final static String TEST_DIR = "functions/federated/aggregate/";
private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedFullAggregateTest.class.getSimpleName() + "/";
@@ -69,7 +70,7 @@ public class FederatedFullAggregateTest extends
AutomatedTestBase {
}
private enum OpType {
- SUM, MEAN, MAX, MIN, VAR
+ SUM, MEAN, MAX, MIN, VAR, TERNARY_SUM
}
@Override
@@ -80,64 +81,70 @@ public class FederatedFullAggregateTest extends
AutomatedTestBase {
addTestConfiguration(TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"}));
addTestConfiguration(TEST_NAME4, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S.scalar"}));
addTestConfiguration(TEST_NAME5, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME6, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {"S.scalar"}));
}
@Test
public void testSumDenseMatrixCP() {
- runColAggregateOperationTest(OpType.SUM, ExecType.CP);
+ runAggregateOperationTest(OpType.SUM, ExecType.CP);
}
@Test
public void testMeanDenseMatrixCP() {
- runColAggregateOperationTest(OpType.MEAN, ExecType.CP);
+ runAggregateOperationTest(OpType.MEAN, ExecType.CP);
}
@Test
public void testMaxDenseMatrixCP() {
- runColAggregateOperationTest(OpType.MAX, ExecType.CP);
+ runAggregateOperationTest(OpType.MAX, ExecType.CP);
}
@Test
public void testMinDenseMatrixCP() {
- runColAggregateOperationTest(OpType.MIN, ExecType.CP);
+ runAggregateOperationTest(OpType.MIN, ExecType.CP);
}
@Test
public void testVarDenseMatrixCP() {
- runColAggregateOperationTest(OpType.VAR, ExecType.CP);
+ runAggregateOperationTest(OpType.VAR, ExecType.CP);
+ }
+
+ @Test
+ public void testTernarySumDenseMatrixCP() {
+ runAggregateOperationTest(OpType.TERNARY_SUM, ExecType.CP);
}
@Test
@Ignore
public void testSumDenseMatrixSP() {
- runColAggregateOperationTest(OpType.SUM, ExecType.SPARK);
+ runAggregateOperationTest(OpType.SUM, ExecType.SPARK);
}
@Test
@Ignore
public void testMeanDenseMatrixSP() {
- runColAggregateOperationTest(OpType.MEAN, ExecType.SPARK);
+ runAggregateOperationTest(OpType.MEAN, ExecType.SPARK);
}
@Test
@Ignore
public void testMaxDenseMatrixSP() {
- runColAggregateOperationTest(OpType.MAX, ExecType.SPARK);
+ runAggregateOperationTest(OpType.MAX, ExecType.SPARK);
}
@Test
@Ignore
public void testMinDenseMatrixSP() {
- runColAggregateOperationTest(OpType.MIN, ExecType.SPARK);
+ runAggregateOperationTest(OpType.MIN, ExecType.SPARK);
}
@Test
@Ignore
public void testVarDenseMatrixSP() {
- runColAggregateOperationTest(OpType.VAR, ExecType.SPARK);
+ runAggregateOperationTest(OpType.VAR, ExecType.SPARK);
}
- private void runColAggregateOperationTest(OpType type, ExecType
instType) {
+ private void runAggregateOperationTest(OpType type, ExecType instType) {
ExecMode platformOld = rtplatform;
switch(instType) {
case SPARK:
@@ -168,6 +175,9 @@ public class FederatedFullAggregateTest extends
AutomatedTestBase {
case VAR:
TEST_NAME = TEST_NAME5;
break;
+ case TERNARY_SUM:
+ TEST_NAME = TEST_NAME6;
+ break;
}
getAndLoadTestConfiguration(TEST_NAME);
@@ -243,6 +253,9 @@ public class FederatedFullAggregateTest extends
AutomatedTestBase {
case VAR:
Assert.assertTrue(heavyHittersContainsString("fed_uavar"));
break;
+ case TERNARY_SUM:
+
Assert.assertTrue(heavyHittersContainsString("fed_tak+*"));
+ break;
}
// check that federated input files are still existing
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java
index c8733ff..11234e3 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedIfelseTest.java
@@ -79,14 +79,14 @@ public class FederatedIfelseTest extends AutomatedTestBase {
runTernaryTest(ExecMode.SINGLE_NODE, true);
}
- private void runTernaryTest(ExecMode execMode, boolean alligned) {
+ private void runTernaryTest(ExecMode execMode, boolean aligned) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
if(rtplatform == ExecMode.SPARK)
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
- String TEST_NAME = alligned ? TEST_NAME2 : TEST_NAME1;
+ String TEST_NAME = aligned ? TEST_NAME2 : TEST_NAME1;
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
@@ -138,7 +138,7 @@ public class FederatedIfelseTest extends AutomatedTestBase {
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
- if(alligned)
+ if(aligned)
runAlignedTernary(HOME, TEST_NAME, r, c, port1, port2,
port3, port4);
else
runTernary(HOME, TEST_NAME, port1, port2, port3, port4);
@@ -153,7 +153,7 @@ public class FederatedIfelseTest extends AutomatedTestBase {
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
- if(alligned) {
+ if(aligned) {
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y1")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y2")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y3")));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
index a79e3b7..4ebbd6d 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
@@ -28,6 +28,7 @@ import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.BeforeClass;
+import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -74,7 +75,10 @@ public class FederatedLogicalTest extends AutomatedTestBase
private enum YType {
MATRIX,
ROW_VEC,
- COL_VEC
+ COL_VEC,
+ FED_MAT, // federated matrix Y
+ FED_RV, // federated row vector Y
+ FED_CV // federated col vector Y
}
@Parameterized.Parameter()
@@ -109,6 +113,7 @@ public class FederatedLogicalTest extends AutomatedTestBase
// {4, 75, 0.9, FederationType.ROW_PARTITIONED,
YType.MATRIX},
// {100, 1, 0.01, FederationType.ROW_PARTITIONED,
YType.MATRIX},
// {100, 1, 0.9, FederationType.ROW_PARTITIONED,
YType.MATRIX},
+ {24, 16, 0.25, FederationType.ROW_PARTITIONED,
YType.FED_MAT},
// row partitioned MV row vector
{100, 75, 0.01, FederationType.ROW_PARTITIONED,
YType.ROW_VEC},
@@ -125,6 +130,7 @@ public class FederatedLogicalTest extends AutomatedTestBase
// {4, 75, 0.9, FederationType.ROW_PARTITIONED,
YType.COL_VEC},
// {100, 1, 0.01, FederationType.ROW_PARTITIONED,
YType.COL_VEC},
// {100, 1, 0.9, FederationType.ROW_PARTITIONED,
YType.COL_VEC},
+ {24, 16, 0.25, FederationType.ROW_PARTITIONED,
YType.FED_CV},
// col partitioned MM
{100, 76, 0.01, FederationType.COL_PARTITIONED,
YType.MATRIX},
@@ -133,6 +139,7 @@ public class FederatedLogicalTest extends AutomatedTestBase
// {1, 76, 0.9, FederationType.COL_PARTITIONED,
YType.MATRIX},
// {100, 4, 0.01, FederationType.COL_PARTITIONED,
YType.MATRIX},
// {100, 4, 0.9, FederationType.COL_PARTITIONED,
YType.MATRIX},
+ {24, 16, 0.25, FederationType.COL_PARTITIONED,
YType.FED_MAT},
// col partitioned MV row vector
{100, 76, 0.01, FederationType.COL_PARTITIONED,
YType.ROW_VEC},
@@ -141,6 +148,7 @@ public class FederatedLogicalTest extends AutomatedTestBase
// {1, 76, 0.9, FederationType.COL_PARTITIONED,
YType.ROW_VEC},
// {100, 4, 0.01, FederationType.COL_PARTITIONED,
YType.ROW_VEC},
// {100, 4, 0.9, FederationType.COL_PARTITIONED,
YType.ROW_VEC},
+ {24, 16, 0.25, FederationType.COL_PARTITIONED,
YType.FED_RV},
// col partitioned MV col vector
{100, 76, 0.01, FederationType.COL_PARTITIONED,
YType.COL_VEC},
@@ -157,6 +165,7 @@ public class FederatedLogicalTest extends AutomatedTestBase
// {1, 75, 0.9, FederationType.SINGLE_FED_WORKER,
YType.MATRIX},
// {100, 1, 0.01, FederationType.SINGLE_FED_WORKER,
YType.MATRIX},
// {100, 1, 0.9, FederationType.SINGLE_FED_WORKER,
YType.MATRIX},
+ {24, 16, 0.25, FederationType.SINGLE_FED_WORKER,
YType.FED_MAT},
// full partitioned (not supported yet)
// {70, 80, 0.01, FederationType.FULL_PARTITIONED,
YType.MATRIX},
@@ -182,17 +191,20 @@ public class FederatedLogicalTest extends
AutomatedTestBase
federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER,
ExecMode.SPARK);
}
-// @Test
-// public void federatedLogicalScalarLessSingleNode() {
-// federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS,
ExecMode.SINGLE_NODE);
-// }
-//
-// @Test
-// public void federatedLogicalScalarLessSpark() {
-// federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS,
ExecMode.SPARK);
-// }
+ @Test
+ @Ignore
+ public void federatedLogicalScalarLessSingleNode() {
+ federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS,
ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ @Ignore
+ public void federatedLogicalScalarLessSpark() {
+ federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS,
ExecMode.SPARK);
+ }
@Test
+ @Ignore
public void federatedLogicalScalarEqualsSingleNode() {
federatedLogicalTest(SCALAR_TEST_NAME, Type.EQUALS,
ExecMode.SINGLE_NODE);
}
@@ -208,29 +220,34 @@ public class FederatedLogicalTest extends
AutomatedTestBase
}
@Test
+ @Ignore
public void federatedLogicalScalarNotEqualsSpark() {
federatedLogicalTest(SCALAR_TEST_NAME, Type.NOT_EQUALS,
ExecMode.SPARK);
}
@Test
+ @Ignore
public void federatedLogicalScalarGreaterEqualsSingleNode() {
federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER_EQUALS,
ExecMode.SINGLE_NODE);
}
@Test
+ @Ignore
public void federatedLogicalScalarGreaterEqualsSpark() {
federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER_EQUALS,
ExecMode.SPARK);
}
-// @Test
-// public void federatedLogicalScalarLessEqualsSingleNode() {
-// federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SINGLE_NODE);
-// }
-//
-// @Test
-// public void federatedLogicalScalarLessEqualsSpark() {
-// federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SPARK);
-// }
+ @Test
+ @Ignore
+ public void federatedLogicalScalarLessEqualsSingleNode() {
+ federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ @Ignore
+ public void federatedLogicalScalarLessEqualsSpark() {
+ federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SPARK);
+ }
//---------------------------MATRIX MATRIX--------------------------
@Test
@@ -243,15 +260,17 @@ public class FederatedLogicalTest extends
AutomatedTestBase
federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER,
ExecMode.SPARK);
}
-// @Test
-// public void federatedLogicalMatrixLessSingleNode() {
-// federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS,
ExecMode.SINGLE_NODE);
-// }
-//
-// @Test
-// public void federatedLogicalMatrixLessSpark() {
-// federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS,
ExecMode.SPARK);
-// }
+ @Test
+ @Ignore
+ public void federatedLogicalMatrixLessSingleNode() {
+ federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS,
ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ @Ignore
+ public void federatedLogicalMatrixLessSpark() {
+ federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS,
ExecMode.SPARK);
+ }
@Test
public void federatedLogicalMatrixEqualsSingleNode() {
@@ -259,11 +278,13 @@ public class FederatedLogicalTest extends
AutomatedTestBase
}
@Test
+ @Ignore
public void federatedLogicalMatrixEqualsSpark() {
federatedLogicalTest(MATRIX_TEST_NAME, Type.EQUALS,
ExecMode.SPARK);
}
@Test
+ @Ignore
public void federatedLogicalMatrixNotEqualsSingleNode() {
federatedLogicalTest(MATRIX_TEST_NAME, Type.NOT_EQUALS,
ExecMode.SINGLE_NODE);
}
@@ -274,24 +295,28 @@ public class FederatedLogicalTest extends
AutomatedTestBase
}
@Test
+ @Ignore
public void federatedLogicalMatrixGreaterEqualsSingleNode() {
federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER_EQUALS,
ExecMode.SINGLE_NODE);
}
@Test
+ @Ignore
public void federatedLogicalMatrixGreaterEqualsSpark() {
federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER_EQUALS,
ExecMode.SPARK);
}
-// @Test
-// public void federatedLogicalMatrixLessEqualsSingleNode() {
-// federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SINGLE_NODE);
-// }
-//
-// @Test
-// public void federatedLogicalMatrixLessEqualsSpark() {
-// federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SPARK);
-// }
+ @Test
+ @Ignore
+ public void federatedLogicalMatrixLessEqualsSingleNode() {
+ federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ @Ignore
+ public void federatedLogicalMatrixLessEqualsSpark() {
+ federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SPARK);
+ }
//
-----------------------------------------------------------------------------
@@ -346,8 +371,8 @@ public class FederatedLogicalTest extends AutomatedTestBase
double[][] Y_mat = null;
double Y_scal = 0;
if(is_matrix_test) {
- int y_rows = (y_type == YType.ROW_VEC ? 1 : rows);
- int y_cols = (y_type == YType.COL_VEC ? 1 : cols);
+ int y_rows = ((y_type == YType.ROW_VEC || y_type ==
YType.FED_RV) ? 1 : rows);
+ int y_cols = ((y_type == YType.COL_VEC || y_type ==
YType.FED_CV) ? 1 : cols);
Y_mat = getRandomMatrix(y_rows, y_cols, 0, 1, sparsity,
5040);
writeInputMatrixWithMTD("Y", Y_mat, false, new
MatrixCharacteristics(y_rows, y_cols, BLOCKSIZE, y_rows * y_cols));
@@ -375,6 +400,7 @@ public class FederatedLogicalTest extends AutomatedTestBase
"in_X4=" + (!single_fed_worker ? input("X4") :
input("X1")), // not needed in case of a single federated worker
"in_Y=" + (is_matrix_test ? input("Y") :
Double.toString(Y_scal)),
"in_fed_type=" + Integer.toString(fed_type.ordinal()),
+ "in_y_type=" + Integer.toString(y_type.ordinal()),
"in_op_type=" + Integer.toString(op_type.ordinal()),
"out_Z=" + expected(OUTPUT_NAME)};
runTest(true, false, null, -1);
@@ -388,6 +414,7 @@ public class FederatedLogicalTest extends AutomatedTestBase
"in_X4=" + (!single_fed_worker ?
TestUtils.federatedAddress(port4, input("X4")) : null),
"in_Y=" + (is_matrix_test ? input("Y") :
Double.toString(Y_scal)),
"in_fed_type=" + Integer.toString(fed_type.ordinal()),
+ "in_y_type=" + Integer.toString(y_type.ordinal()),
"in_op_type=" + Integer.toString(op_type.ordinal()),
"rows=" + Integer.toString(fed_rows), "cols=" +
Integer.toString(fed_cols),
"out_Z=" + output(OUTPUT_NAME)};
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
index efde5b7..cdcb408 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
@@ -45,17 +45,22 @@ public class FederatedRCBindTest extends AutomatedTestBase {
public int rows;
@Parameterized.Parameter(1)
public int cols;
+ @Parameterized.Parameter(2)
+ public boolean partitioned;
@Parameterized.Parameters
public static Collection<Object[]> data() {
//TODO add tests and support of aligned blocksized (which is
however a special case)
+ // rows must be even if partitioned (each of the two federated partitions holds rows/2 rows)
return Arrays.asList(new Object[][] {
- // {1, 1001},
- // {10, 100},
- {100, 10},
- // {1001, 1},
- // {10, 2001},
- // {2001, 10}
+ // (rows, cols, partitioned)
+ // {1, 1001, false},
+ {10, 100, false},
+ {100, 10, true},
+ // {1001, 1, false},
+ // {10, 2001, false},
+ // {2000, 10, true},
+ // {100, 100, true},
});
}
@@ -87,10 +92,20 @@ public class FederatedRCBindTest extends AutomatedTestBase {
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
- double[][] A = getRandomMatrix(rows, cols, -10, 10, 1, 1);
- writeInputMatrixWithMTD("A", A, false, new
MatrixCharacteristics(rows, cols, blocksize, rows * cols));
- double[][] B = getRandomMatrix(rows, cols, -10, 10, 1, 2);
- writeInputMatrixWithMTD("B", B, false, new
MatrixCharacteristics(rows, cols, blocksize, rows * cols));
+ if(partitioned)
+ rows = rows / 2;
+
+ double[][] A1 = getRandomMatrix(rows, cols, -10, 10, 1, 1);
+ writeInputMatrixWithMTD("A1", A1, false, new
MatrixCharacteristics(rows, cols, blocksize, rows * cols));
+ double[][] B1 = getRandomMatrix(rows, cols, -10, 10, 1, 2);
+ writeInputMatrixWithMTD("B1", B1, false, new
MatrixCharacteristics(rows, cols, blocksize, rows * cols));
+
+ double[][] A2 = partitioned ? getRandomMatrix(rows, cols, -10,
10, 1, 1) : null;
+ double[][] B2 = partitioned ? getRandomMatrix(rows, cols, -10,
10, 1, 2) : null;
+ if(partitioned) {
+ writeInputMatrixWithMTD("A2", A2, false, new
MatrixCharacteristics(rows, cols, blocksize, rows * cols));
+ writeInputMatrixWithMTD("B2", B2, false, new
MatrixCharacteristics(rows, cols, blocksize, rows * cols));
+ }
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
@@ -101,7 +116,10 @@ public class FederatedRCBindTest extends AutomatedTestBase
{
rtplatform = Types.ExecMode.SINGLE_NODE;
// Run reference dml script with normal matrix for Row/Col sum
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-nvargs", "in1=" + input("A"),
"in2=" + input("B"), "out_R_FF=" + expected("R_FF"),
+ programArgs = new String[] {"-nvargs", "in_A1=" + input("A1"),
"in_A2=" + input("A2"),
+ "in_B1=" + input("B1"), "in_B2=" + input("B2"),
+ "in_partitioned=" +
Boolean.toString(partitioned).toUpperCase(),
+ "out_R_FF=" + expected("R_FF"),
"out_R_FL=" + expected("R_FL"), "out_R_LF=" +
expected("R_LF"), "out_C_FF=" + expected("C_FF"),
"out_C_FL=" + expected("C_FL"), "out_C_LF=" +
expected("C_LF")};
runTest(true, false, null, -1);
@@ -114,11 +132,15 @@ public class FederatedRCBindTest extends
AutomatedTestBase {
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-nvargs", "in1=" +
TestUtils.federatedAddress(port1, input("A")),
- "in2=" + TestUtils.federatedAddress(port2, input("B")),
"in2_local=" + input("B"), "rows=" + rows,
- "cols=" + cols, "out_R_FF=" + output("R_FF"),
"out_R_FL=" + output("R_FL"),
- "out_R_LF=" + output("R_LF"), "out_C_FF=" +
output("C_FF"), "out_C_FL=" + output("C_FL"),
- "out_C_LF=" + output("C_LF")};
+ programArgs = new String[] {"-nvargs",
+ "in_A1=" + TestUtils.federatedAddress(port1,
input("A1")),
+ "in_A2=" + TestUtils.federatedAddress(port1,
input("A2")),
+ "in_B1=" + TestUtils.federatedAddress(port2,
input("B1")),
+ "in_B2=" + TestUtils.federatedAddress(port2,
input("B2")),
+ "in_partitioned=" +
Boolean.toString(partitioned).toUpperCase(),
+ "in_B1_local=" + input("B1"), "in_B2_local=" +
input("B2"), "rows=" + rows, "cols=" + cols,
+ "out_R_FF=" + output("R_FF"), "out_R_FL=" +
output("R_FL"), "out_R_LF=" + output("R_LF"),
+ "out_C_FF=" + output("C_FF"), "out_C_FL=" +
output("C_FL"), "out_C_LF=" + output("C_LF")};
runTest(true, false, null, -1);
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
index c140dc8..91452eb 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
@@ -43,6 +43,7 @@ public class FederatedRowAggregateTest extends
AutomatedTestBase {
private final static String TEST_NAME8 = "FederatedRowMinTest";
private final static String TEST_NAME9 = "FederatedRowVarTest";
private final static String TEST_NAME10 = "FederatedRowProdTest";
+ private final static String TEST_NAME11 = "FederatedMMTest";
private final static String TEST_DIR = "functions/federated/aggregate/";
private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedRowAggregateTest.class.getSimpleName() + "/";
@@ -65,7 +66,7 @@ public class FederatedRowAggregateTest extends
AutomatedTestBase {
}
private enum OpType {
- SUM, MEAN, MAX, MIN, VAR, PROD
+ SUM, MEAN, MAX, MIN, VAR, PROD, MM
}
@Override
@@ -77,6 +78,7 @@ public class FederatedRowAggregateTest extends
AutomatedTestBase {
addTestConfiguration(TEST_NAME8, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME8, new String[] {"S"}));
addTestConfiguration(TEST_NAME9, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME9, new String[] {"S"}));
addTestConfiguration(TEST_NAME10, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME10, new String[] {"S"}));
+ addTestConfiguration(TEST_NAME11, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME11, new String[] {"S"}));
}
@Test
@@ -109,6 +111,11 @@ public class FederatedRowAggregateTest extends
AutomatedTestBase {
runAggregateOperationTest(OpType.PROD, ExecMode.SINGLE_NODE);
}
+ @Test
+ public void testMMDenseMatrixCP() {
+ runAggregateOperationTest(OpType.MM, ExecMode.SINGLE_NODE);
+ }
+
private void runAggregateOperationTest(OpType type, ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -136,6 +143,9 @@ public class FederatedRowAggregateTest extends
AutomatedTestBase {
case PROD:
TEST_NAME = TEST_NAME10;
break;
+ case MM:
+ TEST_NAME = TEST_NAME11;
+ break;
}
getAndLoadTestConfiguration(TEST_NAME);
@@ -221,6 +231,9 @@ public class FederatedRowAggregateTest extends
AutomatedTestBase {
case PROD:
Assert.assertTrue(heavyHittersContainsString(fedInst.concat("*")));
break;
+ case MM:
+
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
+ break;
}
// check that federated input files are still existing
diff --git a/src/test/scripts/functions/federated/FederatedRCBindTest.dml
b/src/test/scripts/functions/federated/FederatedRCBindTest.dml
index 4447b95..422b2cb 100644
--- a/src/test/scripts/functions/federated/FederatedRCBindTest.dml
+++ b/src/test/scripts/functions/federated/FederatedRCBindTest.dml
@@ -19,16 +19,25 @@
#
#-------------------------------------------------------------
-A = federated(addresses=list($in1), ranges=list(list(0, 0), list($rows,
$cols)))
-BF = federated(addresses=list($in2), ranges=list(list(0, 0), list($rows,
$cols)))
-B = read($in2_local)
+if($in_partitioned) {
+ AF = federated(addresses=list($in_A1, $in_A2), ranges=list(list(0, 0),
list($rows, $cols), list($rows, 0), list($rows * 2, $cols)));
+ BF = federated(addresses=list($in_B1, $in_B2), ranges=list(list(0, 0),
list($rows, $cols), list($rows, 0), list($rows * 2, $cols)));
+ B = rbind(read($in_B1_local), read($in_B2_local));
+ while(FALSE) { }
+}
+else {
+ AF = federated(addresses=list($in_A1), ranges=list(list(0, 0), list($rows,
$cols)))
+ BF = federated(addresses=list($in_B1), ranges=list(list(0, 0), list($rows,
$cols)))
+ B = read($in_B1_local);
+}
-R_FF = rbind(A, BF)
-C_FF = cbind(A, BF)
-R_FL = rbind(A, B)
-C_FL = cbind(A, B)
-R_LF = rbind(B, A)
-C_LF = cbind(B, A)
+
+R_FF = rbind(AF, BF)
+C_FF = cbind(AF, BF)
+R_FL = rbind(AF, B)
+C_FL = cbind(AF, B)
+R_LF = rbind(B, AF)
+C_LF = cbind(B, AF)
write(R_FF, $out_R_FF)
write(R_FL, $out_R_FL)
diff --git
a/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
b/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
index 034a957..30712e0 100644
--- a/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
@@ -19,8 +19,15 @@
#
#-------------------------------------------------------------
-A = read($in1)
-B = read($in2)
+if($in_partitioned) {
+ A = rbind(read($in_A1), read($in_A2));
+ B = rbind(read($in_B1), read($in_B2));
+}
+else {
+ A = read($in_A1);
+ B = read($in_B1);
+}
+
R = rbind(A, B)
C = cbind(A, B)
R_LF = rbind(B, A)
diff --git a/src/test/scripts/functions/federated/FederatedRCBindTest.dml
b/src/test/scripts/functions/federated/aggregate/FederatedMMTest.dml
similarity index 55%
copy from src/test/scripts/functions/federated/FederatedRCBindTest.dml
copy to src/test/scripts/functions/federated/aggregate/FederatedMMTest.dml
index 4447b95..9ba4176 100644
--- a/src/test/scripts/functions/federated/FederatedRCBindTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedMMTest.dml
@@ -19,21 +19,20 @@
#
#-------------------------------------------------------------
-A = federated(addresses=list($in1), ranges=list(list(0, 0), list($rows,
$cols)))
-BF = federated(addresses=list($in2), ranges=list(list(0, 0), list($rows,
$cols)))
-B = read($in2_local)
+if ($rP) {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+ Y = X * 7; # row partitioned federated Y
+ X = t(X); # col partitioned federated X
+} else {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+ Y = t(X) * 7; # row partitioned federated Y
+}
-R_FF = rbind(A, BF)
-C_FF = cbind(A, BF)
-R_FL = rbind(A, B)
-C_FL = cbind(A, B)
-R_LF = rbind(B, A)
-C_LF = cbind(B, A)
+while(FALSE) { }
-write(R_FF, $out_R_FF)
-write(R_FL, $out_R_FL)
-write(R_LF, $out_R_LF)
-
-write(C_FF, $out_C_FF)
-write(C_FL, $out_C_FL)
-write(C_LF, $out_C_LF)
+s = X %*% Y;
+write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
b/src/test/scripts/functions/federated/aggregate/FederatedMMTestReference.dml
similarity index 80%
copy from src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
copy to
src/test/scripts/functions/federated/aggregate/FederatedMMTestReference.dml
index 034a957..47f1ce0 100644
--- a/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
+++
b/src/test/scripts/functions/federated/aggregate/FederatedMMTestReference.dml
@@ -19,15 +19,17 @@
#
#-------------------------------------------------------------
-A = read($in1)
-B = read($in2)
-R = rbind(A, B)
-C = cbind(A, B)
-R_LF = rbind(B, A)
-C_LF = cbind(B, A)
-write(R, $out_R_FF)
-write(R, $out_R_FL)
-write(R_LF, $out_R_LF)
-write(C, $out_C_FF)
-write(C, $out_C_FL)
-write(C_LF, $out_C_LF)
+if($6) {
+ X = rbind(read($1), read($2), read($3), read($4));
+ Y = X * 7;
+ X = t(X);
+}
+else {
+ X = cbind(read($1), read($2), read($3), read($4));
+ Y = t(X) * 7;
+}
+
+while(FALSE) { }
+
+s = X %*% Y;
+write(s, $5);
diff --git
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
b/src/test/scripts/functions/federated/aggregate/FederatedTernaryColSumTest.dml
similarity index 54%
copy from
src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
copy to
src/test/scripts/functions/federated/aggregate/FederatedTernaryColSumTest.dml
index e217ae8..d733a45 100644
---
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
+++
b/src/test/scripts/functions/federated/aggregate/FederatedTernaryColSumTest.dml
@@ -19,35 +19,21 @@
#
#-------------------------------------------------------------
-fed_type = $in_fed_type;
-
-if(fed_type == 0) { # single federated worker
- X = read($in_X1);
-}
-else if(fed_type == 1) { # row partitioned
- X = rbind(read($in_X1), read($in_X2), read($in_X3), read($in_X4));
-}
-else if(fed_type == 2) { # col partitioned
- X = cbind(read($in_X1), read($in_X2), read($in_X3), read($in_X4));
+if ($rP) {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+ Y = X * 7; # row partitioned federated Y
+ Z = X - 1.5; # row partitioned federated Z
+} else {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+ Y = X * 7; # col partitioned federated Y
+ Z = X - 1.5; # col partitioned federated Z
}
-else if(fed_type == 3) { # full partitioned
- X = rbind(cbind(read($in_X1), read($in_X2)), cbind(read($in_X3),
read($in_X4)));
-}
-
-Y = read($in_Y);
-op_type = $in_op_type;
-if(op_type == 0)
- Z = (X > Y)
-else if(op_type == 1)
- Z = (X < Y)
-else if(op_type == 2)
- Z = (X == Y)
-else if(op_type == 3)
- Z = (X != Y)
-else if(op_type == 4)
- Z = (X >= Y)
-else if(op_type == 5)
- Z = (X <= Y)
+while(FALSE) { }
-write(Z, $out_Z);
+s = colSums(X * Y * Z);
+write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
b/src/test/scripts/functions/federated/aggregate/FederatedTernaryColSumTestReference.dml
similarity index 79%
copy from src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
copy to
src/test/scripts/functions/federated/aggregate/FederatedTernaryColSumTestReference.dml
index 034a957..58c8181 100644
--- a/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
+++
b/src/test/scripts/functions/federated/aggregate/FederatedTernaryColSumTestReference.dml
@@ -19,15 +19,18 @@
#
#-------------------------------------------------------------
-A = read($in1)
-B = read($in2)
-R = rbind(A, B)
-C = cbind(A, B)
-R_LF = rbind(B, A)
-C_LF = cbind(B, A)
-write(R, $out_R_FF)
-write(R, $out_R_FL)
-write(R_LF, $out_R_LF)
-write(C, $out_C_FF)
-write(C, $out_C_FL)
-write(C_LF, $out_C_LF)
+if($6) {
+ X = rbind(read($1), read($2), read($3), read($4));
+ Y = X * 7;
+ Z = X - 1.5;
+}
+else {
+ X = cbind(read($1), read($2), read($3), read($4));
+ Y = X * 7;
+ Z = X - 1.5;
+}
+
+while(FALSE) { }
+
+s = colSums(X * Y * Z);
+write(s, $5);
diff --git
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
b/src/test/scripts/functions/federated/aggregate/FederatedTernarySumTest.dml
similarity index 54%
copy from
src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
copy to
src/test/scripts/functions/federated/aggregate/FederatedTernarySumTest.dml
index e217ae8..d2790a0 100644
---
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedTernarySumTest.dml
@@ -19,35 +19,21 @@
#
#-------------------------------------------------------------
-fed_type = $in_fed_type;
-
-if(fed_type == 0) { # single federated worker
- X = read($in_X1);
-}
-else if(fed_type == 1) { # row partitioned
- X = rbind(read($in_X1), read($in_X2), read($in_X3), read($in_X4));
-}
-else if(fed_type == 2) { # col partitioned
- X = cbind(read($in_X1), read($in_X2), read($in_X3), read($in_X4));
+if ($rP) {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+ Y = X * 7; # row partitioned federated Y
+ Z = X - 1.5; # row partitioned federated Z
+} else {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+ Y = X * 7; # col partitioned federated Y
+ Z = X - 1.5; # col partitioned federated Z
}
-else if(fed_type == 3) { # full partitioned
- X = rbind(cbind(read($in_X1), read($in_X2)), cbind(read($in_X3),
read($in_X4)));
-}
-
-Y = read($in_Y);
-op_type = $in_op_type;
-if(op_type == 0)
- Z = (X > Y)
-else if(op_type == 1)
- Z = (X < Y)
-else if(op_type == 2)
- Z = (X == Y)
-else if(op_type == 3)
- Z = (X != Y)
-else if(op_type == 4)
- Z = (X >= Y)
-else if(op_type == 5)
- Z = (X <= Y)
+while(FALSE) { }
-write(Z, $out_Z);
+s = sum(X * Y * Z);
+write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
b/src/test/scripts/functions/federated/aggregate/FederatedTernarySumTestReference.dml
similarity index 79%
copy from src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
copy to
src/test/scripts/functions/federated/aggregate/FederatedTernarySumTestReference.dml
index 034a957..8b634a1 100644
--- a/src/test/scripts/functions/federated/FederatedRCBindTestReference.dml
+++
b/src/test/scripts/functions/federated/aggregate/FederatedTernarySumTestReference.dml
@@ -19,15 +19,18 @@
#
#-------------------------------------------------------------
-A = read($in1)
-B = read($in2)
-R = rbind(A, B)
-C = cbind(A, B)
-R_LF = rbind(B, A)
-C_LF = cbind(B, A)
-write(R, $out_R_FF)
-write(R, $out_R_FL)
-write(R_LF, $out_R_LF)
-write(C, $out_C_FF)
-write(C, $out_C_FL)
-write(C_LF, $out_C_LF)
+if($6) {
+ X = rbind(read($1), read($2), read($3), read($4));
+ Y = X * 7;
+ Z = X - 1.5;
+}
+else {
+ X = cbind(read($1), read($2), read($3), read($4));
+ Y = X * 7;
+ Z = X - 1.5;
+}
+
+while(FALSE) { }
+
+s = sum(X * Y * Z);
+write(s, $5);
diff --git
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
index b84ff6f..e824fd8 100644
---
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
+++
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
@@ -20,6 +20,7 @@
#-------------------------------------------------------------
fed_type = $in_fed_type;
+y_type = $in_y_type;
if(fed_type == 0) { # single federated worker
X = federated(addresses=list($in_X1),
@@ -42,6 +43,13 @@ else if(fed_type == 3) { # full partitioned
}
Y = read($in_Y);
+if(y_type == 3) # make Y federated
+ Y = X + Y;
+else if(y_type == 4) # make Y federated
+ Y = X[1,] + Y;
+else if(y_type == 5) # make Y federated
+ Y = X[, 1] + Y;
+
op_type = $in_op_type;
if(op_type == 0)
diff --git
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
index e217ae8..53b35af 100644
---
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
+++
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
@@ -20,6 +20,7 @@
#-------------------------------------------------------------
fed_type = $in_fed_type;
+y_type = $in_y_type;
if(fed_type == 0) { # single federated worker
X = read($in_X1);
@@ -35,6 +36,13 @@ else if(fed_type == 3) { # full partitioned
}
Y = read($in_Y);
+if(y_type == 3) # make Y federated
+ Y = X + Y;
+else if(y_type == 4) # make Y federated
+ Y = X[1,] + Y;
+else if(y_type == 5) # make Y federated
+ Y = X[, 1] + Y;
+
op_type = $in_op_type;
if(op_type == 0)