This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 53cba126b1 [SYSTEMDS-3800] Improve Code Coverage for Federated Operations
53cba126b1 is described below
commit 53cba126b1663a2923b34778086e8a3093502a5a
Author: Grigorii <[email protected]>
AuthorDate: Mon Dec 2 17:35:46 2024 +0100
[SYSTEMDS-3800] Improve Code Coverage for Federated Operations
Closes #2148.
---
.../instructions/fed/TernaryFEDInstruction.java | 40 ++++++-
.../part1/FederatedCentralMomentTest.java | 98 ++++++++++++-----
.../part5/FederatedFullCumulativeTest.java | 29 ++++-
.../primitives/part5/FederatedIfelseTest.java | 120 ++++++++++-----------
....dml => FederatedCentralMomentWeightedTest.dml} | 14 ++-
...ederatedCentralMomentWeightedTestReference.dml} | 13 +--
...ml => FederatedIfelseSingleMatrixInputTest.dml} | 12 ++-
...eratedIfelseSingleMatrixInputTestReference.dml} | 9 +-
.../functions/federated/FederatedIfelseTest.dml | 13 ++-
.../federated/FederatedIfelseTestReference.dml | 8 +-
10 files changed, 224 insertions(+), 132 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index 0883e6fe02..786f08344e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -122,7 +122,38 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
if(matrixInputsCount == 3)
processMatrixInput(ec, mo1, mo2, mo3);
else if(matrixInputsCount == 1) {
- CPOperand in = mo1 == null ? mo2 == null ? input3 :
input2 : input1;
+ CPOperand in;
+ // determine the position of a matrix in the input and
whether any of the scalars are not literals
+ if (mo1 == null) {
+ if (mo2 == null) { // sc, sc, mat
+ in = input3;
+ instString =
InstructionUtils.replaceOperand(instString, 2,
+
InstructionUtils.createLiteralOperand(ec.getScalarInput(input1).getStringValue(),
Types.ValueType.FP64));
+ if (!input2.isLiteral()) {
+ instString =
InstructionUtils.replaceOperand(instString, 3,
+
InstructionUtils.createLiteralOperand(ec.getScalarInput(input2).getStringValue(),
Types.ValueType.FP64));
+ }
+ } else { // sc, mat, sc
+ in = input2;
+ instString =
InstructionUtils.replaceOperand(instString, 2,
+
InstructionUtils.createLiteralOperand(ec.getScalarInput(input1).getStringValue(),
Types.ValueType.FP64));
+ if (!input3.isLiteral()) {
+ instString =
InstructionUtils.replaceOperand(instString, 4,
+
InstructionUtils.createLiteralOperand(ec.getScalarInput(input3).getStringValue(),
Types.ValueType.FP64));
+ }
+ }
+ } else { // mat, sc, sc
+ in = input1;
+ if (!input2.isLiteral()) {
+ instString =
InstructionUtils.replaceOperand(instString, 3,
+
InstructionUtils.createLiteralOperand(ec.getScalarInput(input2).getStringValue(),
Types.ValueType.FP64));
+ }
+ if (!input3.isLiteral()) {
+ instString =
InstructionUtils.replaceOperand(instString, 4,
+
InstructionUtils.createLiteralOperand(ec.getScalarInput(input3).getStringValue(),
Types.ValueType.FP64));
+ }
+ }
+
mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
processMatrixScalarInput(ec, mo1, in);
}
@@ -150,11 +181,10 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
private void processMatrixScalarInput(ExecutionContext ec,
MatrixLineagePair mo1, CPOperand in) {
long id = FederationUtils.getNextFedDataID();
- FederatedRequest fr1 = new
FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new
MatrixCharacteristics(-1, -1), mo1.getDataType());
-
- FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output, id, new CPOperand[] {in},
new long[] {mo1.getFedMapping().getID()},
+ FederatedRequest fr =
FederationUtils.callInstruction(instString, output, id, new CPOperand[] {in},
new long[] {mo1.getFedMapping().getID()},
InstructionUtils.getExecType(instString), false);
- sendFederatedRequests(ec, mo1.getMO(), fr1.getID(), fr1, fr2);
+
+ sendFederatedRequests(ec, mo1.getMO(), fr.getID(), fr);
}
private void process2MatrixScalarInput(ExecutionContext ec,
MatrixLineagePair mo1, MatrixLineagePair mo2, CPOperand in1, CPOperand in2) {
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
index 6b23897bd5..1d9f951e78 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
@@ -40,7 +40,8 @@ import org.junit.runners.Parameterized;
public class FederatedCentralMomentTest extends AutomatedTestBase {
private final static String TEST_DIR = "functions/federated/";
- private final static String TEST_NAME = "FederatedCentralMomentTest";
+ private final static String TEST_NAME1 = "FederatedCentralMomentTest";
+ private final static String TEST_NAME2 =
"FederatedCentralMomentWeightedTest";
private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedCentralMomentTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@@ -48,44 +49,62 @@ public class FederatedCentralMomentTest extends
AutomatedTestBase {
public int rows;
@Parameterized.Parameter(1)
+ public int cols;
+
+ @Parameterized.Parameter(2)
public int k;
@Parameterized.Parameters
public static Collection<Object[]> data() {
- return Arrays.asList(new Object[][] {{1000, 2}, {1000, 3},
{1000, 4}});
+ return Arrays.asList(new Object[][] {{20, 1, 2}});
}
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
}
@Test
- @Ignore // infinite runtime online but works locally.
public void federatedCentralMomentCP() {
- federatedCentralMoment(Types.ExecMode.SINGLE_NODE);
+ federatedCentralMoment(Types.ExecMode.SINGLE_NODE, false);
+ }
+
+ @Test
+ public void federatedCentralMomentWeightedCP() {
+ federatedCentralMoment(Types.ExecMode.SINGLE_NODE, true);
}
@Test
- @Ignore
public void federatedCentralMomentSP() {
- federatedCentralMoment(Types.ExecMode.SPARK);
+ federatedCentralMoment(Types.ExecMode.SPARK, false);
+ }
+
+ // The test fails due to an error while executing rmvar instruction
after cm calculation
+ // The CacheStatus of the weights variable is READ hence it can't be
modified
+ // In this test the input matrix is federated and weights are read from
file
+ @Ignore
+ @Test
+ public void federatedCentralMomentWeightedSP() {
+ federatedCentralMoment(Types.ExecMode.SPARK, true);
}
- public void federatedCentralMoment(Types.ExecMode execMode) {
+ public void federatedCentralMoment(Types.ExecMode execMode, boolean
isWeighted) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
+ String TEST_NAME = isWeighted ? TEST_NAME2 : TEST_NAME1;
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
int r = rows / 4;
+ int c = cols;
- double[][] X1 = getRandomMatrix(r, 1, 1, 5, 1, 3);
- double[][] X2 = getRandomMatrix(r, 1, 1, 5, 1, 7);
- double[][] X3 = getRandomMatrix(r, 1, 1, 5, 1, 8);
- double[][] X4 = getRandomMatrix(r, 1, 1, 5, 1, 9);
+ double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+ double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+ double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+ double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
MatrixCharacteristics mc = new MatrixCharacteristics(r, 1,
blocksize, r);
writeInputMatrixWithMTD("X1", X1, false, mc);
@@ -114,24 +133,47 @@ public class FederatedCentralMomentTest extends
AutomatedTestBase {
if(rtplatform == Types.ExecMode.SPARK) {
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
}
- // Run reference dml script with normal matrix for
Row/Col
- fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"),
- expected("S"), String.valueOf(k)};
- runTest(null);
-
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
-
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "100", "-nvargs",
- "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
- "in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
- "in_X3=" + TestUtils.federatedAddress(port3,
input("X3")),
- "in_X4=" + TestUtils.federatedAddress(port4,
input("X4")), "rows=" + rows, "cols=" + 1,
- "out_S=" + output("S"), "k=" + k};
- runTest(null);
-
+ if (isWeighted) {
+ double[][] W1 = getRandomMatrix(r, c, 0, 1, 1,
3);
+
+ writeInputMatrixWithMTD("W1", W1, false, mc);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME +
"Reference.dml";
+ programArgs = new String[] {"-stats", "100",
"-args", input("X1"), input("X2"), input("X3"), input("X4"),
+ input("W1"), expected("S"), ""
+ k};
+ runTest(null);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100",
"-nvargs",
+ "in_X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" +
TestUtils.federatedAddress(port2, input("X2")),
+ "in_X3=" +
TestUtils.federatedAddress(port3, input("X3")),
+ "in_X4=" +
TestUtils.federatedAddress(port4, input("X4")),
+ "in_W1=" + input("W1"),
+ "rows=" + rows, "cols=" + cols,
"k=" + k,
+ "out_S=" + output("S")};
+ runTest(null);
+ }
+ else {
+ // Run reference dml script with normal matrix
for Row/Col
+ fullDMLScriptName = HOME + TEST_NAME +
"Reference.dml";
+ programArgs = new String[]{"-stats", "100",
"-args", input("X1"), input("X2"), input("X3"), input("X4"),
+ expected("S"),
String.valueOf(k)};
+ runTest(null);
+
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{"-stats", "100",
"-nvargs",
+ "in_X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" +
TestUtils.federatedAddress(port2, input("X2")),
+ "in_X3=" +
TestUtils.federatedAddress(port3, input("X3")),
+ "in_X4=" +
TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + 1,
+ "out_S=" + output("S"), "k=" +
k};
+ runTest(null);
+ }
// compare all sums via files
compareResults(0.01);
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
index b47b920006..e5ff93e02c 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
@@ -89,6 +89,16 @@ public class FederatedFullCumulativeTest extends
AutomatedTestBase {
runCumOperationTest(OpType.MIN, ExecType.CP);
}
+ @Test
+ public void testProdDenseMatrixCP() {
+ runCumOperationTest(OpType.PROD, ExecType.CP);
+ }
+
+ @Test
+ public void testSumProdDenseMatrixCP() {
+ runCumOperationTest(OpType.SUMPROD, ExecType.CP);
+ }
+
@Test
@Ignore
public void testSumDenseMatrixSP() {
@@ -189,7 +199,10 @@ public class FederatedFullCumulativeTest extends
AutomatedTestBase {
runTest(true, false, null, -1);
// compare via files
- compareResults(1e-6, "DML1", "DML2");
+ if (type != OpType.SUMPROD && type != OpType.PROD)
+ compareResults(1e-6, "DML1", "DML2");
+ else // we sum over the cumsumprod matrix and get a
very large number, hence the large tolerance
+ compareResults(1e+73, "DML1", "DML2");
switch(type) {
case SUM:
@@ -208,12 +221,20 @@ public class FederatedFullCumulativeTest extends
AutomatedTestBase {
heavyHittersContainsString(instType == ExecType.SPARK ? "fed_bcumoffmin" :
"fed_ucummin"));
break;
case SUMPROD:
-
Assert.assertTrue(heavyHittersContainsString(instType == ExecType.SPARK ?
"fed_bcumoff+*" : "ucumk+*"));
+ // when input is column-partitioned,
ucumk+* is executed instead of fed_ucumk+*
+
Assert.assertTrue(heavyHittersContainsString(instType == ExecType.SPARK ?
"fed_bcumoff+*" :
+ rowPartitioned ? "fed_ucumk+*"
: "ucumk+*"));
break;
}
- if(instType != ExecType.SPARK) // verify output is
federated
-
Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+ if(instType != ExecType.SPARK) { // verify output is
federated
+ if (type == OpType.SUMPROD && !rowPartitioned) {
+
Assert.assertTrue(heavyHittersContainsString("uak+"));
+ } else {
+
Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+ }
+ }
+
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
index 9dcf9482bc..abf0c8c228 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
@@ -39,6 +39,7 @@ import org.junit.runners.Parameterized;
public class FederatedIfelseTest extends AutomatedTestBase {
private final static String TEST_NAME1 = "FederatedIfelseTest";
private final static String TEST_NAME2 = "FederatedIfelseAlignedTest";
+ private final static String TEST_NAME3 =
"FederatedIfelseSingleMatrixInputTest";
private final static String TEST_DIR = "functions/federated/";
private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedIfelseTest.class.getSimpleName() + "/";
@@ -62,36 +63,42 @@ public class FederatedIfelseTest extends AutomatedTestBase {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"}));
+ addTestConfiguration(TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"}));
}
@Test
public void testIfelseDiffWorkersCP() {
- runTernaryTest(ExecMode.SINGLE_NODE, false);
+ runTernaryTest(ExecMode.SINGLE_NODE, false, false);
+ }
+
+ @Test
+ public void testIfelseDiffWorkersSingleMatInCP() {
+ runTernaryTest(ExecMode.SINGLE_NODE, false, true);
}
@Test
public void testIfelseAlignedCP() {
- runTernaryTest(ExecMode.SINGLE_NODE, true);
+ runTernaryTest(ExecMode.SINGLE_NODE, true, false);
}
@Test
public void testIfelseDiffWorkersSP() {
- runTernaryTest(ExecMode.SPARK, false);
+ runTernaryTest(ExecMode.SPARK, false, false);
}
@Test
public void testIfelseAlignedSP() {
- runTernaryTest(ExecMode.SPARK, true);
+ runTernaryTest(ExecMode.SPARK, true, false);
}
- private void runTernaryTest(ExecMode execMode, boolean aligned) {
+ private void runTernaryTest(ExecMode execMode, boolean aligned, boolean
singleMatrixInput) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
if(rtplatform == ExecMode.SPARK)
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
- String TEST_NAME = aligned ? TEST_NAME2 : TEST_NAME1;
+ String TEST_NAME = aligned ? TEST_NAME2 : (!singleMatrixInput ?
TEST_NAME1 : TEST_NAME3);
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
@@ -148,10 +155,51 @@ public class FederatedIfelseTest extends
AutomatedTestBase {
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
- if(aligned)
- runAlignedTernary(HOME, TEST_NAME, r, c, port1,
port2, port3, port4);
- else
- runTernary(HOME, TEST_NAME, port1, port2,
port3, port4);
+ if(aligned) {
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME +
"Reference.dml";
+ programArgs = new String[] {"-stats", "100",
"-args", input("X1"), input("X2"), input("X3"), input("X4"),
+ input("Y1"), input("Y2"), input("Y3"),
input("Y4"), expected("S"),
+
Boolean.toString(rowPartitioned).toUpperCase()};
+ runTest(true, false, null, -1);
+
+ // Run actual dml script with federated matrix
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100",
"-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" +
TestUtils.federatedAddress(port2, input("X2")),
+ "in_X3=" +
TestUtils.federatedAddress(port3, input("X3")),
+ "in_X4=" +
TestUtils.federatedAddress(port4, input("X4")),
+ "in_Y1=" +
TestUtils.federatedAddress(port1, input("Y1")),
+ "in_Y2=" +
TestUtils.federatedAddress(port2, input("Y2")),
+ "in_Y3=" +
TestUtils.federatedAddress(port3, input("Y3")),
+ "in_Y4=" +
TestUtils.federatedAddress(port4, input("Y4")), "rows=" + rows, "cols=" + cols,
+ "rP=" +
Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+ runTest(true, false, null, -1);
+ } else {
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME +
"Reference.dml";
+ programArgs = new String[] {"-stats", "100",
"-args", input("X1"), input("X2"), input("X3"),
+ input("X4"), expected("S"),
Boolean.toString(rowPartitioned).toUpperCase()};
+ runTest(true, false, null, -1);
+
+ // Run actual dml script with federated matrix
+ double[][] x = getRandomMatrix(1, 1, 3.0, 3.0,
1, 1);
+ double[][] y = getRandomMatrix(1, 1, 4.0, 4.0,
1, 1);
+ MatrixCharacteristics mc1 = new
MatrixCharacteristics(1, 1, blocksize, 1 * 1);
+ writeInputMatrixWithMTD("x", x, false, mc1);
+ writeInputMatrixWithMTD("y", y, false, mc1);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100",
"-nvargs",
+ "in_X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" +
TestUtils.federatedAddress(port2, input("X2")),
+ "in_X3=" +
TestUtils.federatedAddress(port3, input("X3")),
+ "in_X4=" +
TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+ "x=" + input("x"), "y=" + input("y"),
"rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
+ "out_S=" + output("S")};
+ runTest(true, false, null, -1);
+ }
// compare via files
compareResults(1e-9, "DML1", "DML2");
@@ -178,56 +226,4 @@ public class FederatedIfelseTest extends AutomatedTestBase
{
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}
-
- private void runTernary(String HOME, String TEST_NAME, int port1, int
port2, int port3, int port4) {
- // Run reference dml script with normal matrix
- fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"),
- expected("S"),
Boolean.toString(rowPartitioned).toUpperCase()};
- runTest(true, false, null, -1);
-
- // Run actual dml script with federated matrix
-
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
- "in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
- "in_X3=" + TestUtils.federatedAddress(port3,
input("X3")),
- "in_X4=" + TestUtils.federatedAddress(port4,
input("X4")), "rows=" + rows, "cols=" + cols,
- "rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
"out_S=" + output("S")};
- runTest(true, false, null, -1);
- }
-
- private void runAlignedTernary(String HOME, String TEST_NAME, int r,
int c, int port1, int port2, int port3,
- int port4) {
- double[][] Y1 = getRandomMatrix(r, c, 10, 15, 1, 3);
- double[][] Y2 = getRandomMatrix(r, c, 10, 15, 1, 7);
- double[][] Y3 = getRandomMatrix(r, c, 10, 15, 1, 8);
- double[][] Y4 = getRandomMatrix(r, c, 10, 15, 1, 9);
- MatrixCharacteristics mc2 = new MatrixCharacteristics(r, c,
blocksize, r * c);
- writeInputMatrixWithMTD("Y1", Y1, false, mc2);
- writeInputMatrixWithMTD("Y2", Y2, false, mc2);
- writeInputMatrixWithMTD("Y3", Y3, false, mc2);
- writeInputMatrixWithMTD("Y4", Y4, false, mc2);
-
- // Run reference dml script with normal matrix
- fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"),
- input("Y1"), input("Y2"), input("Y3"), input("Y4"),
expected("S"),
- Boolean.toString(rowPartitioned).toUpperCase()};
- runTest(true, false, null, -1);
-
- // Run actual dml script with federated matrix
-
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
- "in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
- "in_X3=" + TestUtils.federatedAddress(port3,
input("X3")),
- "in_X4=" + TestUtils.federatedAddress(port4,
input("X4")),
- "in_Y1=" + TestUtils.federatedAddress(port1,
input("Y1")),
- "in_Y2=" + TestUtils.federatedAddress(port2,
input("Y2")),
- "in_Y3=" + TestUtils.federatedAddress(port3,
input("Y3")),
- "in_Y4=" + TestUtils.federatedAddress(port4,
input("Y4")), "rows=" + rows, "cols=" + cols,
- "rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
"out_S=" + output("S")};
- runTest(true, false, null, -1);
- }
}
diff --git
a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
b/src/test/scripts/functions/federated/FederatedCentralMomentWeightedTest.dml
similarity index 73%
copy from src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
copy to
src/test/scripts/functions/federated/FederatedCentralMomentWeightedTest.dml
index e232cdf095..64771c7cd0 100644
--- a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
+++
b/src/test/scripts/functions/federated/FederatedCentralMomentWeightedTest.dml
@@ -19,13 +19,11 @@
#
#-------------------------------------------------------------
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
-if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
-else { A = cbind(read($1), read($2), read($3), read($4)); }
+W = read($in_W1);
-c1 = ifelse(A>0, A + matrix(1, nrow(A), ncol(A)), A*2);
-c2 = ifelse(A-3>0, A + matrix(1, nrow(A), ncol(A)), 3);
-c3 = ifelse(1, matrix(1, nrow(A), ncol(A)), 3);
-s = c2 + c3;
-s = s + 10*c1;
-write(s, $5);
+s = moment(X, W, $k);
+write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
b/src/test/scripts/functions/federated/FederatedCentralMomentWeightedTestReference.dml
similarity index 73%
copy from src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
copy to
src/test/scripts/functions/federated/FederatedCentralMomentWeightedTestReference.dml
index e232cdf095..83ee778694 100644
--- a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
+++
b/src/test/scripts/functions/federated/FederatedCentralMomentWeightedTestReference.dml
@@ -19,13 +19,10 @@
#
#-------------------------------------------------------------
+X = rbind(read($1), read($2), read($3), read($4));
-if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
-else { A = cbind(read($1), read($2), read($3), read($4)); }
+w = read($5);
+W = rbind(w, w, w, w)
-c1 = ifelse(A>0, A + matrix(1, nrow(A), ncol(A)), A*2);
-c2 = ifelse(A-3>0, A + matrix(1, nrow(A), ncol(A)), 3);
-c3 = ifelse(1, matrix(1, nrow(A), ncol(A)), 3);
-s = c2 + c3;
-s = s + 10*c1;
-write(s, $5);
+s = moment(X, W, $7);
+write(s, $6);
diff --git a/src/test/scripts/functions/federated/FederatedIfelseTest.dml
b/src/test/scripts/functions/federated/FederatedIfelseSingleMatrixInputTest.dml
similarity index 88%
copy from src/test/scripts/functions/federated/FederatedIfelseTest.dml
copy to
src/test/scripts/functions/federated/FederatedIfelseSingleMatrixInputTest.dml
index da650a3e65..579033134c 100644
--- a/src/test/scripts/functions/federated/FederatedIfelseTest.dml
+++
b/src/test/scripts/functions/federated/FederatedIfelseSingleMatrixInputTest.dml
@@ -28,9 +28,13 @@ if ($rP) {
ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
}
-c1 = ifelse(A>0, A + matrix(1, $rows, $cols), A*2);
-c2 = ifelse(A-3>0, A + matrix(1, $rows, $cols), 3);
-c3 = ifelse(1, matrix(1, $rows, $cols), 3);
+x = as.integer(as.scalar(read($x)));
+y = as.integer(as.scalar(read($y)));
+
+c1 = ifelse(x, A, y);
+c2 = ifelse(A-3>0, y, x);
+c3 = ifelse(x, y, A + matrix(1, $rows, $cols));
+
s = c2 + c3;
-s = s + 10*c1;
+s = s + x*c1;
write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
b/src/test/scripts/functions/federated/FederatedIfelseSingleMatrixInputTestReference.dml
similarity index 85%
copy from src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
copy to
src/test/scripts/functions/federated/FederatedIfelseSingleMatrixInputTestReference.dml
index e232cdf095..644f2795a7 100644
--- a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
+++
b/src/test/scripts/functions/federated/FederatedIfelseSingleMatrixInputTestReference.dml
@@ -23,9 +23,10 @@
if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
else { A = cbind(read($1), read($2), read($3), read($4)); }
-c1 = ifelse(A>0, A + matrix(1, nrow(A), ncol(A)), A*2);
-c2 = ifelse(A-3>0, A + matrix(1, nrow(A), ncol(A)), 3);
-c3 = ifelse(1, matrix(1, nrow(A), ncol(A)), 3);
+c1 = ifelse(3, A, 4);
+c2 = ifelse(A-3>0, 4, 3);
+c3 = ifelse(3, 4, A + matrix(1, nrow(A), ncol(A)));
+
s = c2 + c3;
-s = s + 10*c1;
+s = s + 3*c1;
write(s, $5);
diff --git a/src/test/scripts/functions/federated/FederatedIfelseTest.dml
b/src/test/scripts/functions/federated/FederatedIfelseTest.dml
index da650a3e65..e3b6ee9524 100644
--- a/src/test/scripts/functions/federated/FederatedIfelseTest.dml
+++ b/src/test/scripts/functions/federated/FederatedIfelseTest.dml
@@ -28,9 +28,12 @@ if ($rP) {
ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
}
-c1 = ifelse(A>0, A + matrix(1, $rows, $cols), A*2);
-c2 = ifelse(A-3>0, A + matrix(1, $rows, $cols), 3);
-c3 = ifelse(1, matrix(1, $rows, $cols), 3);
-s = c2 + c3;
-s = s + 10*c1;
+x = as.integer(as.scalar(read($x)));
+y = as.integer(as.scalar(read($y)));
+
+c1 = ifelse(x, A + matrix(1, $rows, $cols), A*2);
+c2 = ifelse(A-3>0, A + matrix(1, $rows, $cols), x);
+
+s = c2 + 1;
+s = s + x*c1;
write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
b/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
index e232cdf095..416ff6756c 100644
--- a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
@@ -23,9 +23,9 @@
if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
else { A = cbind(read($1), read($2), read($3), read($4)); }
-c1 = ifelse(A>0, A + matrix(1, nrow(A), ncol(A)), A*2);
+c1 = ifelse(3, A + matrix(1, nrow(A), ncol(A)), A*2);
c2 = ifelse(A-3>0, A + matrix(1, nrow(A), ncol(A)), 3);
-c3 = ifelse(1, matrix(1, nrow(A), ncol(A)), 3);
-s = c2 + c3;
-s = s + 10*c1;
+
+s = c2 + 1;
+s = s + 3*c1;
write(s, $5);