This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 2576c2e [SYSTEMDS-2867] Cleanup federated binary operations, incl tests
2576c2e is described below
commit 2576c2e9df350f549e6fd9c3463466c9630d923f
Author: ywcb00 <[email protected]>
AuthorDate: Sat Feb 20 18:04:33 2021 +0100
[SYSTEMDS-2867] Cleanup federated binary operations, incl tests
Closes #1182.
---
.../instructions/fed/BinaryFEDInstruction.java | 19 ++
.../fed/BinaryMatrixMatrixFEDInstruction.java | 60 ++---
.../instructions/fed/FEDInstructionUtils.java | 2 +
.../federated/primitives/FederatedLogicalTest.java | 257 ++++++++++++++++-----
.../binary/FederatedLogicalMatrixMatrixTest.dml | 23 +-
.../FederatedLogicalMatrixMatrixTestReference.dml | 15 +-
.../binary/FederatedLogicalMatrixScalarTest.dml | 23 +-
.../FederatedLogicalMatrixScalarTestReference.dml | 15 +-
8 files changed, 318 insertions(+), 96 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index bfe0c27..9f0c91a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -20,6 +20,9 @@
package org.apache.sysds.runtime.instructions.fed;
import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.lops.BinaryM.VectorType;
+import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -33,6 +36,11 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
}
public static BinaryFEDInstruction parseInstruction(String str) {
+ if(str.startsWith(ExecType.SPARK.name())) {
+ // rewrite the spark instruction to a cp instruction
+ str = rewriteSparkInstructionToCP(str);
+ }
+
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
InstructionUtils.checkNumFields(parts, 3, 4);
String opcode = parts[0];
@@ -65,4 +73,15 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
throw new DMLRuntimeException("Element-wise matrix
operations between variables " + in1.getName() +
" and " + in2.getName() + " must produce a
matrix, which " + out.getName() + " is not");
}
+
+ private static String rewriteSparkInstructionToCP(String inst_str) {
+ // rewrite the spark instruction to a cp instruction
+ inst_str = inst_str.replace(ExecType.SPARK.name(),
ExecType.CP.name());
+ inst_str = inst_str.replace(Lop.OPERAND_DELIMITOR + "map",
Lop.OPERAND_DELIMITOR);
+ inst_str = inst_str.replace(Lop.OPERAND_DELIMITOR + "RIGHT",
"");
+ inst_str = inst_str.replace(Lop.OPERAND_DELIMITOR +
VectorType.ROW_VECTOR.name(), "");
+ inst_str = inst_str.replace(Lop.OPERAND_DELIMITOR +
VectorType.COL_VECTOR.name(), "");
+
+ return inst_str;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 0ba1935..6f7dcc9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -29,6 +29,7 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
+
public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
{
protected BinaryMatrixMatrixFEDInstruction(Operator op,
@@ -62,46 +63,49 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
+ "federated right input are only
supported for special cases yet.");
}
}
- else {
- //matrix-matrix binary operations -> lhs fed input ->
fed output
- if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) {
//MV col vector
- FederatedRequest[] fr1 =
mo1.getFedMapping().broadcastSliced(mo2, false);
+ else { // matrix-matrix binary operations -> lhs fed input ->
fed output
+ if(mo1.isFederated(FType.FULL)) {
+ // full federated (row and col)
+ if(mo1.getFedMapping().getSize() == 1) {
+ // only one partition (MM on a single
fed worker)
+ FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
+ fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
+ new long[]{mo1.getFedMapping().getID(),
fr1.getID()});
+ FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+ //execute federated instruction and
cleanup intermediates
+ mo1.getFedMapping().execute(getTID(),
true, fr1, fr2, fr3);
+ }
+ else {
+ throw new
DMLRuntimeException("Matrix-matrix binary operations with a full partitioned
federated input with multiple partitions are not supported yet.");
+ }
+ }
+ else if((mo1.isFederated(FType.ROW) && mo2.getNumRows()
== 1 && mo2.getNumColumns() > 1)
+ || (mo1.isFederated(FType.COL) &&
mo2.getNumRows() > 1 && mo2.getNumColumns() == 1)) {
+ // MV row partitioned row vector, MV col
partitioned col vector
+ FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
- new long[]{mo1.getFedMapping().getID(),
fr1[0].getID()});
- FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+ new long[]{mo1.getFedMapping().getID(),
fr1.getID()});
+ FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
//execute federated instruction and cleanup
intermediates
mo1.getFedMapping().execute(getTID(), true,
fr1, fr2, fr3);
}
- else if(mo2.getNumRows() == 1 && mo2.getNumColumns() >
1) { //MV row vector
- FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
+ else if(mo1.isFederated(FType.ROW) ^
mo1.isFederated(FType.COL)) {
+ // row partitioned MM or col partitioned MM
+ FederatedRequest[] fr1 =
mo1.getFedMapping().broadcastSliced(mo2, false);
fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
- new long[]{mo1.getFedMapping().getID(),
fr1.getID()});
- FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+ new long[]{mo1.getFedMapping().getID(),
fr1[0].getID()});
+ FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
//execute federated instruction and cleanup
intermediates
mo1.getFedMapping().execute(getTID(), true,
fr1, fr2, fr3);
}
- else { //MM
- if(mo1.isFederated(FType.ROW)) {
- FederatedRequest[] fr1 =
mo1.getFedMapping().broadcastSliced(mo2, false);
- fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
- new
long[]{mo1.getFedMapping().getID(), fr1[0].getID()});
- FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
- //execute federated instruction and
cleanup intermediates
- mo1.getFedMapping().execute(getTID(),
true, fr1, fr2, fr3);
- }
- else {
- FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
- fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
- new
long[]{mo1.getFedMapping().getID(), fr1.getID()});
- FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
- //execute federated instruction and
cleanup intermediates
- mo1.getFedMapping().execute(getTID(),
true, fr1, fr2, fr3);
- }
+ else {
+ throw new DMLRuntimeException("Matrix-matrix
binary operations are only supported with a row partitioned or column
partitioned federated input yet.");
}
}
- //derive new fed mapping for output
+ // derive new fed mapping for output
MatrixObject out = ec.getMatrixObject(output);
+
out.getDataCharacteristics().set(mo1.getDataCharacteristics());
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 2a608f3..613ff31 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -45,6 +45,7 @@ import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOp
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
+import
org.apache.sysds.runtime.instructions.spark.BinaryMatrixBVectorSPInstruction;
import
org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
import
org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
@@ -266,6 +267,7 @@ public class FEDInstructionUtils {
}
else if (inst instanceof BinaryMatrixScalarSPInstruction
|| inst instanceof
BinaryMatrixMatrixSPInstruction
+ || inst instanceof
BinaryMatrixBVectorSPInstruction
|| inst instanceof
BinaryTensorTensorSPInstruction
|| inst instanceof
BinaryTensorTensorBroadcastSPInstruction) {
BinarySPInstruction instruction =
(BinarySPInstruction) inst;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
index 53dfb2e..a79e3b7 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
@@ -36,6 +36,12 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
+/*
+ * Testing following logical operations:
+ * >, <, ==, !=, >=, <=
+ * with a row/col partitioned federated matrix X and a scalar/vector/matrix Y
+*/
+
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class FederatedLogicalTest extends AutomatedTestBase
@@ -47,9 +53,9 @@ public class FederatedLogicalTest extends AutomatedTestBase
private final static String OUTPUT_NAME = "Z";
private final static double TOLERANCE = 0;
- private final static int blocksize = 1024;
+ private final static int BLOCKSIZE = 1024;
- public enum Type{
+ private enum Type {
GREATER,
LESS,
EQUALS,
@@ -58,12 +64,29 @@ public class FederatedLogicalTest extends AutomatedTestBase
LESS_EQUALS
}
+ private enum FederationType {
+ SINGLE_FED_WORKER,
+ ROW_PARTITIONED,
+ COL_PARTITIONED,
+ FULL_PARTITIONED
+ }
+
+ private enum YType {
+ MATRIX,
+ ROW_VEC,
+ COL_VEC
+ }
+
@Parameterized.Parameter()
public int rows;
@Parameterized.Parameter(1)
public int cols;
@Parameterized.Parameter(2)
public double sparsity;
+ @Parameterized.Parameter(3)
+ public FederationType fed_type;
+ @Parameterized.Parameter(4)
+ public YType y_type;
@Override
public void setUp() {
@@ -73,13 +96,73 @@ public class FederatedLogicalTest extends AutomatedTestBase
@Parameterized.Parameters
public static Collection<Object[]> data() {
- // rows must be even
+ // rows must be divisable by 4 for row partitioned data
+ // cols must be divisable by 4 for col partitioned data
+ // rows and cols must be divisable by 2 for full partitioned
data
return Arrays.asList(new Object[][] {
- // {rows, cols, sparsity}
- {100, 75, 0.01},
- {100, 75, 0.9},
- {2, 75, 0.01},
- {2, 75, 0.9}
+ // {rows, cols, sparsity, fed_type, y_type}
+
+ // row partitioned MM
+ {100, 75, 0.01, FederationType.ROW_PARTITIONED,
YType.MATRIX},
+ {100, 75, 0.9, FederationType.ROW_PARTITIONED,
YType.MATRIX},
+ // {4, 75, 0.01, FederationType.ROW_PARTITIONED,
YType.MATRIX},
+ // {4, 75, 0.9, FederationType.ROW_PARTITIONED,
YType.MATRIX},
+ // {100, 1, 0.01, FederationType.ROW_PARTITIONED,
YType.MATRIX},
+ // {100, 1, 0.9, FederationType.ROW_PARTITIONED,
YType.MATRIX},
+
+ // row partitioned MV row vector
+ {100, 75, 0.01, FederationType.ROW_PARTITIONED,
YType.ROW_VEC},
+ {100, 75, 0.9, FederationType.ROW_PARTITIONED,
YType.ROW_VEC},
+ // {4, 75, 0.01, FederationType.ROW_PARTITIONED,
YType.ROW_VEC},
+ // {4, 75, 0.9, FederationType.ROW_PARTITIONED,
YType.ROW_VEC},
+ // {100, 1, 0.01, FederationType.ROW_PARTITIONED,
YType.ROW_VEC},
+ // {100, 1, 0.9, FederationType.ROW_PARTITIONED,
YType.ROW_VEC},
+
+ // row partitioned MV col vector
+ {100, 75, 0.01, FederationType.ROW_PARTITIONED,
YType.COL_VEC},
+ {100, 75, 0.9, FederationType.ROW_PARTITIONED,
YType.COL_VEC},
+ // {4, 75, 0.01, FederationType.ROW_PARTITIONED,
YType.COL_VEC},
+ // {4, 75, 0.9, FederationType.ROW_PARTITIONED,
YType.COL_VEC},
+ // {100, 1, 0.01, FederationType.ROW_PARTITIONED,
YType.COL_VEC},
+ // {100, 1, 0.9, FederationType.ROW_PARTITIONED,
YType.COL_VEC},
+
+ // col partitioned MM
+ {100, 76, 0.01, FederationType.COL_PARTITIONED,
YType.MATRIX},
+ {100, 76, 0.9, FederationType.COL_PARTITIONED,
YType.MATRIX},
+ // {1, 76, 0.01, FederationType.COL_PARTITIONED,
YType.MATRIX},
+ // {1, 76, 0.9, FederationType.COL_PARTITIONED,
YType.MATRIX},
+ // {100, 4, 0.01, FederationType.COL_PARTITIONED,
YType.MATRIX},
+ // {100, 4, 0.9, FederationType.COL_PARTITIONED,
YType.MATRIX},
+
+ // col partitioned MV row vector
+ {100, 76, 0.01, FederationType.COL_PARTITIONED,
YType.ROW_VEC},
+ {100, 76, 0.9, FederationType.COL_PARTITIONED,
YType.ROW_VEC},
+ // {1, 76, 0.01, FederationType.COL_PARTITIONED,
YType.ROW_VEC},
+ // {1, 76, 0.9, FederationType.COL_PARTITIONED,
YType.ROW_VEC},
+ // {100, 4, 0.01, FederationType.COL_PARTITIONED,
YType.ROW_VEC},
+ // {100, 4, 0.9, FederationType.COL_PARTITIONED,
YType.ROW_VEC},
+
+ // col partitioned MV col vector
+ {100, 76, 0.01, FederationType.COL_PARTITIONED,
YType.COL_VEC},
+ {100, 76, 0.9, FederationType.COL_PARTITIONED,
YType.COL_VEC},
+ // {1, 76, 0.01, FederationType.COL_PARTITIONED,
YType.COL_VEC},
+ // {1, 76, 0.9, FederationType.COL_PARTITIONED,
YType.COL_VEC},
+ // {100, 4, 0.01, FederationType.COL_PARTITIONED,
YType.COL_VEC},
+ // {100, 4, 0.9, FederationType.COL_PARTITIONED,
YType.COL_VEC},
+
+ // single federated worker MM
+ {100, 75, 0.01, FederationType.SINGLE_FED_WORKER,
YType.MATRIX},
+ {100, 75, 0.9, FederationType.SINGLE_FED_WORKER,
YType.MATRIX},
+ // {1, 75, 0.01, FederationType.SINGLE_FED_WORKER,
YType.MATRIX},
+ // {1, 75, 0.9, FederationType.SINGLE_FED_WORKER,
YType.MATRIX},
+ // {100, 1, 0.01, FederationType.SINGLE_FED_WORKER,
YType.MATRIX},
+ // {100, 1, 0.9, FederationType.SINGLE_FED_WORKER,
YType.MATRIX},
+
+ // full partitioned (not supported yet)
+ // {70, 80, 0.01, FederationType.FULL_PARTITIONED,
YType.MATRIX},
+ // {70, 80, 0.9, FederationType.FULL_PARTITIONED,
YType.MATRIX},
+ // {2, 2, 0.01, FederationType.FULL_PARTITIONED,
YType.MATRIX},
+ // {2, 2, 0.9, FederationType.FULL_PARTITIONED,
YType.MATRIX}
});
}
@@ -99,15 +182,15 @@ public class FederatedLogicalTest extends AutomatedTestBase
federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER,
ExecMode.SPARK);
}
- @Test
- public void federatedLogicalScalarLessSingleNode() {
- federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS,
ExecMode.SINGLE_NODE);
- }
-
- @Test
- public void federatedLogicalScalarLessSpark() {
- federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS,
ExecMode.SPARK);
- }
+// @Test
+// public void federatedLogicalScalarLessSingleNode() {
+// federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS,
ExecMode.SINGLE_NODE);
+// }
+//
+// @Test
+// public void federatedLogicalScalarLessSpark() {
+// federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS,
ExecMode.SPARK);
+// }
@Test
public void federatedLogicalScalarEqualsSingleNode() {
@@ -139,15 +222,15 @@ public class FederatedLogicalTest extends AutomatedTestBase
federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER_EQUALS,
ExecMode.SPARK);
}
- @Test
- public void federatedLogicalScalarLessEqualsSingleNode() {
- federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SINGLE_NODE);
- }
-
- @Test
- public void federatedLogicalScalarLessEqualsSpark() {
- federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SPARK);
- }
+// @Test
+// public void federatedLogicalScalarLessEqualsSingleNode() {
+// federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SINGLE_NODE);
+// }
+//
+// @Test
+// public void federatedLogicalScalarLessEqualsSpark() {
+// federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SPARK);
+// }
//---------------------------MATRIX MATRIX--------------------------
@Test
@@ -160,15 +243,15 @@ public class FederatedLogicalTest extends AutomatedTestBase
federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER,
ExecMode.SPARK);
}
- @Test
- public void federatedLogicalMatrixLessSingleNode() {
- federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS,
ExecMode.SINGLE_NODE);
- }
-
- @Test
- public void federatedLogicalMatrixLessSpark() {
- federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS,
ExecMode.SPARK);
- }
+// @Test
+// public void federatedLogicalMatrixLessSingleNode() {
+// federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS,
ExecMode.SINGLE_NODE);
+// }
+//
+// @Test
+// public void federatedLogicalMatrixLessSpark() {
+// federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS,
ExecMode.SPARK);
+// }
@Test
public void federatedLogicalMatrixEqualsSingleNode() {
@@ -200,15 +283,15 @@ public class FederatedLogicalTest extends AutomatedTestBase
federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER_EQUALS,
ExecMode.SPARK);
}
- @Test
- public void federatedLogicalMatrixLessEqualsSingleNode() {
- federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SINGLE_NODE);
- }
-
- @Test
- public void federatedLogicalMatrixLessEqualsSpark() {
- federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SPARK);
- }
+// @Test
+// public void federatedLogicalMatrixLessEqualsSingleNode() {
+// federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SINGLE_NODE);
+// }
+//
+// @Test
+// public void federatedLogicalMatrixLessEqualsSpark() {
+// federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS,
ExecMode.SPARK);
+// }
//
-----------------------------------------------------------------------------
@@ -220,39 +303,78 @@ public class FederatedLogicalTest extends AutomatedTestBase
getAndLoadTestConfiguration(testname);
String HOME = SCRIPT_DIR + TEST_DIR;
- int fed_rows = rows / 2;
- int fed_cols = cols;
+ int fed_rows = 0;
+ int fed_cols = 0;
+ switch(fed_type) {
+ case SINGLE_FED_WORKER:
+ fed_rows = rows;
+ fed_cols = cols;
+ break;
+ case ROW_PARTITIONED:
+ fed_rows = rows / 4;
+ fed_cols = cols;
+ break;
+ case COL_PARTITIONED:
+ fed_rows = rows;
+ fed_cols = cols / 4;
+ break;
+ case FULL_PARTITIONED:
+ fed_rows = rows / 2;
+ fed_cols = cols / 2;
+ break;
+ }
+
+ boolean single_fed_worker = (fed_type ==
FederationType.SINGLE_FED_WORKER);
// generate dataset
- // matrix handled by two federated workers
- double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 1,
13);
- double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 1, 2);
-
- writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
- writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+ // matrix handled by four federated workers
+ // X2, X3, X4 not used if single_fed_worker == true
+ double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1,
sparsity, 13);
+ double[][] X2 = (!single_fed_worker ? getRandomMatrix(fed_rows,
fed_cols, 0, 1, sparsity, 2) : null);
+ double[][] X3 = (!single_fed_worker ? getRandomMatrix(fed_rows,
fed_cols, 0, 1, sparsity, 211) : null);
+ double[][] X4 = (!single_fed_worker ? getRandomMatrix(fed_rows,
fed_cols, 0, 1, sparsity, 65) : null);
+
+ writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+ if(!single_fed_worker) {
+ writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+ writeInputMatrixWithMTD("X3", X3, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+ writeInputMatrixWithMTD("X4", X4, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+ }
boolean is_matrix_test = testname.equals(MATRIX_TEST_NAME);
double[][] Y_mat = null;
double Y_scal = 0;
if(is_matrix_test) {
- Y_mat = getRandomMatrix(rows, cols, 0, 1, sparsity,
5040);
- writeInputMatrixWithMTD("Y", Y_mat, true);
+ int y_rows = (y_type == YType.ROW_VEC ? 1 : rows);
+ int y_cols = (y_type == YType.COL_VEC ? 1 : cols);
+
+ Y_mat = getRandomMatrix(y_rows, y_cols, 0, 1, sparsity,
5040);
+ writeInputMatrixWithMTD("Y", Y_mat, false, new
MatrixCharacteristics(y_rows, y_cols, BLOCKSIZE, y_rows * y_cols));
}
// empty script name because we don't execute any script, just
start the worker
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
- int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ int port2 = (!single_fed_worker ? getRandomAvailablePort() : 0);
+ int port3 = (!single_fed_worker ? getRandomAvailablePort() : 0);
+ int port4 = (!single_fed_worker ? getRandomAvailablePort() : 0);
+ Thread thread1 = startLocalFedWorkerThread(port1,
(!single_fed_worker ? FED_WORKER_WAIT_S : FED_WORKER_WAIT));
+ Thread thread2 = (!single_fed_worker ?
startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S) : null);
+ Thread thread3 = (!single_fed_worker ?
startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S) : null);
+ Thread thread4 = (!single_fed_worker ?
startLocalFedWorkerThread(port4) : null);
getAndLoadTestConfiguration(testname);
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + testname + "Reference.dml";
- programArgs = new String[] {"-nvargs", "in_X1=" + input("X1"),
"in_X2=" + input("X2"),
+ programArgs = new String[] {"-nvargs",
+ "in_X1=" + input("X1"),
+ "in_X2=" + (!single_fed_worker ? input("X2") :
input("X1")), // not needed in case of a single federated worker
+ "in_X3=" + (!single_fed_worker ? input("X3") :
input("X1")), // not needed in case of a single federated worker
+ "in_X4=" + (!single_fed_worker ? input("X4") :
input("X1")), // not needed in case of a single federated worker
"in_Y=" + (is_matrix_test ? input("Y") :
Double.toString(Y_scal)),
+ "in_fed_type=" + Integer.toString(fed_type.ordinal()),
"in_op_type=" + Integer.toString(op_type.ordinal()),
"out_Z=" + expected(OUTPUT_NAME)};
runTest(true, false, null, -1);
@@ -260,10 +382,15 @@ public class FederatedLogicalTest extends AutomatedTestBase
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[] {"-stats", "-nvargs",
- "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")), "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X2=" + (!single_fed_worker ?
TestUtils.federatedAddress(port2, input("X2")) : null),
+ "in_X3=" + (!single_fed_worker ?
TestUtils.federatedAddress(port3, input("X3")) : null),
+ "in_X4=" + (!single_fed_worker ?
TestUtils.federatedAddress(port4, input("X4")) : null),
"in_Y=" + (is_matrix_test ? input("Y") :
Double.toString(Y_scal)),
+ "in_fed_type=" + Integer.toString(fed_type.ordinal()),
"in_op_type=" + Integer.toString(op_type.ordinal()),
- "rows=" + fed_rows, "cols=" + fed_cols, "out_Z=" +
output(OUTPUT_NAME)};
+ "rows=" + Integer.toString(fed_rows), "cols=" +
Integer.toString(fed_cols),
+ "out_Z=" + output(OUTPUT_NAME)};
runTest(true, false, null, -1);
// compare the results via files
@@ -271,7 +398,9 @@ public class FederatedLogicalTest extends AutomatedTestBase
HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE,
"Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(thread1);
+ if(!single_fed_worker)
+ TestUtils.shutdownThreads(thread2, thread3, thread4);
// check for federated operations
switch(op_type)
@@ -298,7 +427,11 @@ public class FederatedLogicalTest extends AutomatedTestBase
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+ if(!single_fed_worker) {
+
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+ }
resetExecMode(platform_old);
}
diff --git a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
index 7fe350c..b84ff6f 100644
--- a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
+++ b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
@@ -19,8 +19,27 @@
#
#-------------------------------------------------------------
-X = federated(addresses=list($in_X1, $in_X2),
- ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2,
$cols)));
+fed_type = $in_fed_type;
+
+if(fed_type == 0) { # single federated worker
+ X = federated(addresses=list($in_X1),
+ ranges=list(list(0, 0), list($rows, $cols)));
+}
+else if(fed_type == 1) { # row partitioned
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows *
2, $cols),
+ list($rows * 2, 0), list($rows * 3, $cols), list($rows * 3, 0),
list($rows * 4, $cols)));
+}
+else if(fed_type == 2) { # col partitioned
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols), list(0, $cols), list($rows,
$cols * 2),
+ list(0, $cols * 2), list($rows, $cols * 3), list(0, $cols * 3),
list($rows, $cols * 4)));
+}
+else if(fed_type == 3) { # full partitioned
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols), list(0, $cols), list($rows,
$cols * 2),
+ list($rows, 0), list($rows * 2, $cols), list($rows, $cols), list($rows *
2, $cols * 2)));
+}
Y = read($in_Y);
op_type = $in_op_type;
diff --git a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
index 1285eb0..e217ae8 100644
--- a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
+++ b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
@@ -19,7 +19,20 @@
#
#-------------------------------------------------------------
-X = rbind(read($in_X1), read($in_X2));
+fed_type = $in_fed_type;
+
+if(fed_type == 0) { # single federated worker
+ X = read($in_X1);
+}
+else if(fed_type == 1) { # row partitioned
+ X = rbind(read($in_X1), read($in_X2), read($in_X3), read($in_X4));
+}
+else if(fed_type == 2) { # col partitioned
+ X = cbind(read($in_X1), read($in_X2), read($in_X3), read($in_X4));
+}
+else if(fed_type == 3) { # full partitioned
+ X = rbind(cbind(read($in_X1), read($in_X2)), cbind(read($in_X3),
read($in_X4)));
+}
Y = read($in_Y);
op_type = $in_op_type;
diff --git a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml
index 1dfb762..b4a520c 100644
--- a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml
+++ b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml
@@ -19,8 +19,27 @@
#
#-------------------------------------------------------------
-X = federated(addresses=list($in_X1, $in_X2),
- ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2,
$cols)));
+fed_type = $in_fed_type;
+
+if(fed_type == 0) { # single federated worker
+ X = federated(addresses=list($in_X1),
+ ranges=list(list(0, 0), list($rows, $cols)));
+}
+else if(fed_type == 1) { # row partitioned
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows *
2, $cols),
+ list($rows * 2, 0), list($rows * 3, $cols), list($rows * 3, 0),
list($rows * 4, $cols)));
+}
+else if(fed_type == 2) { # col partitioned
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols), list(0, $cols), list($rows,
$cols * 2),
+ list(0, $cols * 2), list($rows, $cols * 3), list(0, $cols * 3),
list($rows, $cols * 4)));
+}
+else if(fed_type == 3) { # full partitioned
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols), list(0, $cols), list($rows,
$cols * 2),
+ list($rows, 0), list($rows * 2, $cols), list($rows, $cols), list($rows *
2, $cols * 2)));
+}
y = $in_Y;
op_type = $in_op_type;
diff --git a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml
index 4682aea..40bb906 100644
--- a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml
+++ b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml
@@ -19,7 +19,20 @@
#
#-------------------------------------------------------------
-X = rbind(read($in_X1), read($in_X2));
+fed_type = $in_fed_type;
+
+if(fed_type == 0) { # single federated worker
+ X = read($in_X1);
+}
+else if(fed_type == 1) { # row partitioned
+ X = rbind(read($in_X1), read($in_X2), read($in_X3), read($in_X4));
+}
+else if(fed_type == 2) { # col partitioned
+ X = cbind(read($in_X1), read($in_X2), read($in_X3), read($in_X4));
+}
+else if(fed_type == 3) { # full partitioned
+ X = rbind(cbind(read($in_X1), read($in_X2)), cbind(read($in_X3),
read($in_X4)));
+}
y = $in_Y;
op_type = $in_op_type;