This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 53cba126b1 [SYSTEMDS-3800] Improve Code Coverage for Federated Operations
53cba126b1 is described below

commit 53cba126b1663a2923b34778086e8a3093502a5a
Author: Grigorii <[email protected]>
AuthorDate: Mon Dec 2 17:35:46 2024 +0100

    [SYSTEMDS-3800] Improve Code Coverage for Federated Operations
    
    Closes #2148.
---
 .../instructions/fed/TernaryFEDInstruction.java    |  40 ++++++-
 .../part1/FederatedCentralMomentTest.java          |  98 ++++++++++++-----
 .../part5/FederatedFullCumulativeTest.java         |  29 ++++-
 .../primitives/part5/FederatedIfelseTest.java      | 120 ++++++++++-----------
 ....dml => FederatedCentralMomentWeightedTest.dml} |  14 ++-
 ...ederatedCentralMomentWeightedTestReference.dml} |  13 +--
 ...ml => FederatedIfelseSingleMatrixInputTest.dml} |  12 ++-
 ...eratedIfelseSingleMatrixInputTestReference.dml} |   9 +-
 .../functions/federated/FederatedIfelseTest.dml    |  13 ++-
 .../federated/FederatedIfelseTestReference.dml     |   8 +-
 10 files changed, 224 insertions(+), 132 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index 0883e6fe02..786f08344e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -122,7 +122,38 @@ public class TernaryFEDInstruction extends 
ComputationFEDInstruction {
                if(matrixInputsCount == 3)
                        processMatrixInput(ec, mo1, mo2, mo3);
                else if(matrixInputsCount == 1) {
-                       CPOperand in = mo1 == null ? mo2 == null ? input3 : 
input2 : input1;
+                       CPOperand in;
+                       // determine the position of a matrix in the input and 
whether any of the scalars are not literals
+                       if (mo1 == null) {
+                               if (mo2 == null) { // sc, sc, mat
+                                       in = input3;
+                                       instString = 
InstructionUtils.replaceOperand(instString, 2,
+                                               
InstructionUtils.createLiteralOperand(ec.getScalarInput(input1).getStringValue(),
 Types.ValueType.FP64));
+                                       if (!input2.isLiteral()) {
+                                               instString = 
InstructionUtils.replaceOperand(instString, 3,
+                                                       
InstructionUtils.createLiteralOperand(ec.getScalarInput(input2).getStringValue(),
 Types.ValueType.FP64));
+                                       }
+                               } else { // sc, mat, sc
+                                       in = input2;
+                                       instString = 
InstructionUtils.replaceOperand(instString, 2,
+                                               
InstructionUtils.createLiteralOperand(ec.getScalarInput(input1).getStringValue(),
 Types.ValueType.FP64));
+                                       if (!input3.isLiteral()) {
+                                               instString = 
InstructionUtils.replaceOperand(instString, 4,
+                                                       
InstructionUtils.createLiteralOperand(ec.getScalarInput(input3).getStringValue(),
 Types.ValueType.FP64));
+                                       }
+                               }
+                       } else { // mat, sc, sc
+                               in = input1;
+                               if (!input2.isLiteral()) {
+                                       instString = 
InstructionUtils.replaceOperand(instString, 3,
+                                               
InstructionUtils.createLiteralOperand(ec.getScalarInput(input2).getStringValue(),
 Types.ValueType.FP64));
+                               }
+                               if (!input3.isLiteral()) {
+                                       instString = 
InstructionUtils.replaceOperand(instString, 4,
+                                               
InstructionUtils.createLiteralOperand(ec.getScalarInput(input3).getStringValue(),
 Types.ValueType.FP64));
+                               }
+                       }
+
                        mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
                        processMatrixScalarInput(ec, mo1, in);
                }
@@ -150,11 +181,10 @@ public class TernaryFEDInstruction extends 
ComputationFEDInstruction {
 
        private void processMatrixScalarInput(ExecutionContext ec, 
MatrixLineagePair mo1, CPOperand in) {
                long id = FederationUtils.getNextFedDataID();
-               FederatedRequest fr1 = new 
FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new 
MatrixCharacteristics(-1, -1), mo1.getDataType());
-
-               FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output, id, new CPOperand[] {in}, 
new long[] {mo1.getFedMapping().getID()},
+               FederatedRequest fr = 
FederationUtils.callInstruction(instString, output, id, new CPOperand[] {in}, 
new long[] {mo1.getFedMapping().getID()},
                        InstructionUtils.getExecType(instString), false);
-               sendFederatedRequests(ec, mo1.getMO(), fr1.getID(), fr1, fr2);
+
+               sendFederatedRequests(ec, mo1.getMO(), fr.getID(), fr);
        }
 
        private void process2MatrixScalarInput(ExecutionContext ec, 
MatrixLineagePair mo1, MatrixLineagePair mo2, CPOperand in1, CPOperand in2) {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
index 6b23897bd5..1d9f951e78 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
@@ -40,7 +40,8 @@ import org.junit.runners.Parameterized;
 public class FederatedCentralMomentTest extends AutomatedTestBase {
 
        private final static String TEST_DIR = "functions/federated/";
-       private final static String TEST_NAME = "FederatedCentralMomentTest";
+       private final static String TEST_NAME1 = "FederatedCentralMomentTest";
+       private final static String TEST_NAME2 = 
"FederatedCentralMomentWeightedTest";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedCentralMomentTest.class.getSimpleName() + "/";
 
        private final static int blocksize = 1024;
@@ -48,44 +49,62 @@ public class FederatedCentralMomentTest extends 
AutomatedTestBase {
        public int rows;
 
        @Parameterized.Parameter(1)
+       public int cols;
+
+       @Parameterized.Parameter(2)
        public int k;
 
        @Parameterized.Parameters
        public static Collection<Object[]> data() {
-               return Arrays.asList(new Object[][] {{1000, 2}, {1000, 3}, 
{1000, 4}});
+               return Arrays.asList(new Object[][] {{20, 1, 2}});
        }
 
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S.scalar"}));
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
        }
 
        @Test
-       @Ignore // infinite runtime online but works locally.
        public void federatedCentralMomentCP() {
-               federatedCentralMoment(Types.ExecMode.SINGLE_NODE);
+               federatedCentralMoment(Types.ExecMode.SINGLE_NODE, false);
+       }
+
+       @Test
+       public void federatedCentralMomentWeightedCP() {
+               federatedCentralMoment(Types.ExecMode.SINGLE_NODE, true);
        }
 
        @Test
-       @Ignore
        public void federatedCentralMomentSP() {
-               federatedCentralMoment(Types.ExecMode.SPARK);
+               federatedCentralMoment(Types.ExecMode.SPARK, false);
+       }
+
+       // The test fails due to an error while executing rmvar instruction 
after cm calculation
+       // The CacheStatus of the weights variable is READ hence it can't be 
modified
+       // In this test the input matrix is federated and weights are read from 
file
+       @Ignore
+       @Test
+       public void federatedCentralMomentWeightedSP() {
+               federatedCentralMoment(Types.ExecMode.SPARK, true);
        }
 
-       public void federatedCentralMoment(Types.ExecMode execMode) {
+       public void federatedCentralMoment(Types.ExecMode execMode, boolean 
isWeighted) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                Types.ExecMode platformOld = rtplatform;
 
+               String TEST_NAME = isWeighted ? TEST_NAME2 : TEST_NAME1;
                getAndLoadTestConfiguration(TEST_NAME);
                String HOME = SCRIPT_DIR + TEST_DIR;
 
                int r = rows / 4;
+               int c = cols;
 
-               double[][] X1 = getRandomMatrix(r, 1, 1, 5, 1, 3);
-               double[][] X2 = getRandomMatrix(r, 1, 1, 5, 1, 7);
-               double[][] X3 = getRandomMatrix(r, 1, 1, 5, 1, 8);
-               double[][] X4 = getRandomMatrix(r, 1, 1, 5, 1, 9);
+               double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
 
                MatrixCharacteristics mc = new MatrixCharacteristics(r, 1, 
blocksize, r);
                writeInputMatrixWithMTD("X1", X1, false, mc);
@@ -114,24 +133,47 @@ public class FederatedCentralMomentTest extends 
AutomatedTestBase {
                        if(rtplatform == Types.ExecMode.SPARK) {
                                DMLScript.USE_LOCAL_SPARK_CONFIG = true;
                        }
-                       // Run reference dml script with normal matrix for 
Row/Col
-                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-                       programArgs = new String[] {"-stats", "100", "-args", 
input("X1"), input("X2"), input("X3"), input("X4"),
-                               expected("S"), String.valueOf(k)};
-                       runTest(null);
-
                        TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                        loadTestConfiguration(config);
-
-                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[] {"-stats", "100", "-nvargs",
-                               "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
-                               "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
-                               "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
-                               "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")), "rows=" + rows, "cols=" + 1,
-                               "out_S=" + output("S"), "k=" + k};
-                       runTest(null);
-
+                       if (isWeighted) {
+                               double[][] W1 = getRandomMatrix(r, c, 0, 1, 1, 
3);
+
+                               writeInputMatrixWithMTD("W1", W1, false, mc);
+
+                               // Run reference dml script with normal matrix
+                               fullDMLScriptName = HOME + TEST_NAME + 
"Reference.dml";
+                               programArgs = new String[] {"-stats", "100", 
"-args", input("X1"), input("X2"), input("X3"), input("X4"),
+                                               input("W1"), expected("S"), "" 
+ k};
+                               runTest(null);
+
+                               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                               programArgs = new String[] {"-stats", "100", 
"-nvargs",
+                                               "in_X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
+                                               "in_X2=" + 
TestUtils.federatedAddress(port2, input("X2")),
+                                               "in_X3=" + 
TestUtils.federatedAddress(port3, input("X3")),
+                                               "in_X4=" + 
TestUtils.federatedAddress(port4, input("X4")),
+                                               "in_W1=" + input("W1"),
+                                               "rows=" + rows, "cols=" + cols, 
"k=" + k,
+                                               "out_S=" + output("S")};
+                               runTest(null);
+                       }
+                       else {
+                               // Run reference dml script with normal matrix 
for Row/Col
+                               fullDMLScriptName = HOME + TEST_NAME + 
"Reference.dml";
+                               programArgs = new String[]{"-stats", "100", 
"-args", input("X1"), input("X2"), input("X3"), input("X4"),
+                                               expected("S"), 
String.valueOf(k)};
+                               runTest(null);
+
+
+                               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                               programArgs = new String[]{"-stats", "100", 
"-nvargs",
+                                               "in_X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
+                                               "in_X2=" + 
TestUtils.federatedAddress(port2, input("X2")),
+                                               "in_X3=" + 
TestUtils.federatedAddress(port3, input("X3")),
+                                               "in_X4=" + 
TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + 1,
+                                               "out_S=" + output("S"), "k=" + 
k};
+                               runTest(null);
+                       }
                        // compare all sums via files
                        compareResults(0.01);
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
index b47b920006..e5ff93e02c 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
@@ -89,6 +89,16 @@ public class FederatedFullCumulativeTest extends 
AutomatedTestBase {
                runCumOperationTest(OpType.MIN, ExecType.CP);
        }
 
+       @Test
+       public void testProdDenseMatrixCP() {
+               runCumOperationTest(OpType.PROD, ExecType.CP);
+       }
+
+       @Test
+       public void testSumProdDenseMatrixCP() {
+               runCumOperationTest(OpType.SUMPROD, ExecType.CP);
+       }
+
        @Test
        @Ignore
        public void testSumDenseMatrixSP() {
@@ -189,7 +199,10 @@ public class FederatedFullCumulativeTest extends 
AutomatedTestBase {
                        runTest(true, false, null, -1);
 
                        // compare via files
-                       compareResults(1e-6, "DML1", "DML2");
+                       if (type != OpType.SUMPROD && type != OpType.PROD)
+                               compareResults(1e-6, "DML1", "DML2");
+                       else // we sum over the cumsumprod matrix and get a 
very large number, hence the large tolerance
+                               compareResults(1e+73, "DML1", "DML2");
 
                        switch(type) {
                                case SUM:
@@ -208,12 +221,20 @@ public class FederatedFullCumulativeTest extends 
AutomatedTestBase {
                                                
heavyHittersContainsString(instType == ExecType.SPARK ? "fed_bcumoffmin" : 
"fed_ucummin"));
                                        break;
                                case SUMPROD:
-                                       
Assert.assertTrue(heavyHittersContainsString(instType == ExecType.SPARK ? 
"fed_bcumoff+*" : "ucumk+*"));
+                                       // when input is column-partitioned, 
ucumk+* is executed instead of fed_ucumk+*
+                                       
Assert.assertTrue(heavyHittersContainsString(instType == ExecType.SPARK ? 
"fed_bcumoff+*" :
+                                               rowPartitioned ? "fed_ucumk+*" 
: "ucumk+*"));
                                        break;
                        }
 
-                       if(instType != ExecType.SPARK) // verify output is 
federated
-                               
Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+                       if(instType != ExecType.SPARK) { // verify output is 
federated
+                               if (type == OpType.SUMPROD && !rowPartitioned) {
+                                       
Assert.assertTrue(heavyHittersContainsString("uak+"));
+                               } else {
+                                       
Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+                               }
+                       }
+
 
                        // check that federated input files are still existing
                        
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
index 9dcf9482bc..abf0c8c228 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
@@ -39,6 +39,7 @@ import org.junit.runners.Parameterized;
 public class FederatedIfelseTest extends AutomatedTestBase {
        private final static String TEST_NAME1 = "FederatedIfelseTest";
        private final static String TEST_NAME2 = "FederatedIfelseAlignedTest";
+       private final static String TEST_NAME3 = 
"FederatedIfelseSingleMatrixInputTest";
 
        private final static String TEST_DIR = "functions/federated/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedIfelseTest.class.getSimpleName() + "/";
@@ -62,36 +63,42 @@ public class FederatedIfelseTest extends AutomatedTestBase {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
                addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"}));
+               addTestConfiguration(TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"}));
        }
 
        @Test
        public void testIfelseDiffWorkersCP() {
-               runTernaryTest(ExecMode.SINGLE_NODE, false);
+               runTernaryTest(ExecMode.SINGLE_NODE, false, false);
+       }
+
+       @Test
+       public void testIfelseDiffWorkersSingleMatInCP() {
+               runTernaryTest(ExecMode.SINGLE_NODE, false, true);
        }
 
        @Test
        public void testIfelseAlignedCP() {
-               runTernaryTest(ExecMode.SINGLE_NODE, true);
+               runTernaryTest(ExecMode.SINGLE_NODE, true, false);
        }
 
        @Test
        public void testIfelseDiffWorkersSP() {
-               runTernaryTest(ExecMode.SPARK, false);
+               runTernaryTest(ExecMode.SPARK, false, false);
        }
 
        @Test
        public void testIfelseAlignedSP() {
-               runTernaryTest(ExecMode.SPARK, true);
+               runTernaryTest(ExecMode.SPARK, true, false);
        }
 
-       private void runTernaryTest(ExecMode execMode, boolean aligned) {
+       private void runTernaryTest(ExecMode execMode, boolean aligned, boolean 
singleMatrixInput) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                ExecMode platformOld = rtplatform;
 
                if(rtplatform == ExecMode.SPARK)
                        DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 
-               String TEST_NAME = aligned ? TEST_NAME2 : TEST_NAME1;
+               String TEST_NAME = aligned ? TEST_NAME2 : (!singleMatrixInput ? 
TEST_NAME1 : TEST_NAME3);
 
                getAndLoadTestConfiguration(TEST_NAME);
                String HOME = SCRIPT_DIR + TEST_DIR;
@@ -148,10 +155,51 @@ public class FederatedIfelseTest extends 
AutomatedTestBase {
                        TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                        loadTestConfiguration(config);
 
-                       if(aligned)
-                               runAlignedTernary(HOME, TEST_NAME, r, c, port1, 
port2, port3, port4);
-                       else
-                               runTernary(HOME, TEST_NAME, port1, port2, 
port3, port4);
+                       if(aligned) {
+                               // Run reference dml script with normal matrix
+                               fullDMLScriptName = HOME + TEST_NAME + 
"Reference.dml";
+                               programArgs = new String[] {"-stats", "100", 
"-args", input("X1"), input("X2"), input("X3"), input("X4"),
+                                       input("Y1"), input("Y2"), input("Y3"), 
input("Y4"), expected("S"),
+                                       
Boolean.toString(rowPartitioned).toUpperCase()};
+                               runTest(true, false, null, -1);
+
+                               // Run actual dml script with federated matrix
+
+                               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                               programArgs = new String[] {"-stats", "100", 
"-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                                       "in_X2=" + 
TestUtils.federatedAddress(port2, input("X2")),
+                                       "in_X3=" + 
TestUtils.federatedAddress(port3, input("X3")),
+                                       "in_X4=" + 
TestUtils.federatedAddress(port4, input("X4")),
+                                       "in_Y1=" + 
TestUtils.federatedAddress(port1, input("Y1")),
+                                       "in_Y2=" + 
TestUtils.federatedAddress(port2, input("Y2")),
+                                       "in_Y3=" + 
TestUtils.federatedAddress(port3, input("Y3")),
+                                       "in_Y4=" + 
TestUtils.federatedAddress(port4, input("Y4")), "rows=" + rows, "cols=" + cols,
+                                       "rP=" + 
Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+                               runTest(true, false, null, -1);
+                       } else {
+                               // Run reference dml script with normal matrix
+                               fullDMLScriptName = HOME + TEST_NAME + 
"Reference.dml";
+                               programArgs = new String[] {"-stats", "100", 
"-args", input("X1"), input("X2"), input("X3"),
+                                       input("X4"), expected("S"), 
Boolean.toString(rowPartitioned).toUpperCase()};
+                               runTest(true, false, null, -1);
+
+                               // Run actual dml script with federated matrix
+                               double[][] x = getRandomMatrix(1, 1, 3.0, 3.0, 
1, 1);
+                               double[][] y = getRandomMatrix(1, 1, 4.0, 4.0, 
1, 1);
+                               MatrixCharacteristics mc1 = new 
MatrixCharacteristics(1, 1, blocksize, 1 * 1);
+                               writeInputMatrixWithMTD("x", x, false, mc1);
+                               writeInputMatrixWithMTD("y", y, false, mc1);
+
+                               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                               programArgs = new String[] {"-stats", "100", 
"-nvargs",
+                                       "in_X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
+                                       "in_X2=" + 
TestUtils.federatedAddress(port2, input("X2")),
+                                       "in_X3=" + 
TestUtils.federatedAddress(port3, input("X3")),
+                                       "in_X4=" + 
TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+                                       "x=" + input("x"), "y=" + input("y"), 
"rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
+                                       "out_S=" + output("S")};
+                               runTest(true, false, null, -1);
+                       }
 
                        // compare via files
                        compareResults(1e-9, "DML1", "DML2");
@@ -178,56 +226,4 @@ public class FederatedIfelseTest extends AutomatedTestBase 
{
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
                }
        }
-
-       private void runTernary(String HOME, String TEST_NAME, int port1, int 
port2, int port3, int port4) {
-               // Run reference dml script with normal matrix
-               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-               programArgs = new String[] {"-stats", "100", "-args", 
input("X1"), input("X2"), input("X3"), input("X4"),
-                       expected("S"), 
Boolean.toString(rowPartitioned).toUpperCase()};
-               runTest(true, false, null, -1);
-
-               // Run actual dml script with federated matrix
-
-               fullDMLScriptName = HOME + TEST_NAME + ".dml";
-               programArgs = new String[] {"-stats", "100", "-nvargs", 
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
-                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
-                       "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
-                       "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")), "rows=" + rows, "cols=" + cols,
-                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), 
"out_S=" + output("S")};
-               runTest(true, false, null, -1);
-       }
-
-       private void runAlignedTernary(String HOME, String TEST_NAME, int r, 
int c, int port1, int port2, int port3,
-               int port4) {
-               double[][] Y1 = getRandomMatrix(r, c, 10, 15, 1, 3);
-               double[][] Y2 = getRandomMatrix(r, c, 10, 15, 1, 7);
-               double[][] Y3 = getRandomMatrix(r, c, 10, 15, 1, 8);
-               double[][] Y4 = getRandomMatrix(r, c, 10, 15, 1, 9);
-               MatrixCharacteristics mc2 = new MatrixCharacteristics(r, c, 
blocksize, r * c);
-               writeInputMatrixWithMTD("Y1", Y1, false, mc2);
-               writeInputMatrixWithMTD("Y2", Y2, false, mc2);
-               writeInputMatrixWithMTD("Y3", Y3, false, mc2);
-               writeInputMatrixWithMTD("Y4", Y4, false, mc2);
-
-               // Run reference dml script with normal matrix
-               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-               programArgs = new String[] {"-stats", "100", "-args", 
input("X1"), input("X2"), input("X3"), input("X4"),
-                       input("Y1"), input("Y2"), input("Y3"), input("Y4"), 
expected("S"),
-                       Boolean.toString(rowPartitioned).toUpperCase()};
-               runTest(true, false, null, -1);
-
-               // Run actual dml script with federated matrix
-
-               fullDMLScriptName = HOME + TEST_NAME + ".dml";
-               programArgs = new String[] {"-stats", "100", "-nvargs", 
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
-                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
-                       "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
-                       "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")),
-                       "in_Y1=" + TestUtils.federatedAddress(port1, 
input("Y1")),
-                       "in_Y2=" + TestUtils.federatedAddress(port2, 
input("Y2")),
-                       "in_Y3=" + TestUtils.federatedAddress(port3, 
input("Y3")),
-                       "in_Y4=" + TestUtils.federatedAddress(port4, 
input("Y4")), "rows=" + rows, "cols=" + cols,
-                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), 
"out_S=" + output("S")};
-               runTest(true, false, null, -1);
-       }
 }
diff --git 
a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml 
b/src/test/scripts/functions/federated/FederatedCentralMomentWeightedTest.dml
similarity index 73%
copy from src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
copy to 
src/test/scripts/functions/federated/FederatedCentralMomentWeightedTest.dml
index e232cdf095..64771c7cd0 100644
--- a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
+++ 
b/src/test/scripts/functions/federated/FederatedCentralMomentWeightedTest.dml
@@ -19,13 +19,11 @@
 #
 #-------------------------------------------------------------
 
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
 
-if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
-else { A = cbind(read($1), read($2), read($3), read($4)); }
+W = read($in_W1);
 
-c1 = ifelse(A>0, A + matrix(1, nrow(A), ncol(A)), A*2);
-c2 = ifelse(A-3>0, A + matrix(1, nrow(A), ncol(A)), 3);
-c3 = ifelse(1, matrix(1, nrow(A), ncol(A)), 3);
-s = c2 + c3;
-s = s + 10*c1;
-write(s, $5);
+s = moment(X, W, $k);
+write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml 
b/src/test/scripts/functions/federated/FederatedCentralMomentWeightedTestReference.dml
similarity index 73%
copy from src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
copy to 
src/test/scripts/functions/federated/FederatedCentralMomentWeightedTestReference.dml
index e232cdf095..83ee778694 100644
--- a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
+++ 
b/src/test/scripts/functions/federated/FederatedCentralMomentWeightedTestReference.dml
@@ -19,13 +19,10 @@
 #
 #-------------------------------------------------------------
 
+X = rbind(read($1), read($2), read($3), read($4));
 
-if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
-else { A = cbind(read($1), read($2), read($3), read($4)); }
+w = read($5);
+W = rbind(w, w, w, w)
 
-c1 = ifelse(A>0, A + matrix(1, nrow(A), ncol(A)), A*2);
-c2 = ifelse(A-3>0, A + matrix(1, nrow(A), ncol(A)), 3);
-c3 = ifelse(1, matrix(1, nrow(A), ncol(A)), 3);
-s = c2 + c3;
-s = s + 10*c1;
-write(s, $5);
+s = moment(X, W, $7);
+write(s, $6);
diff --git a/src/test/scripts/functions/federated/FederatedIfelseTest.dml 
b/src/test/scripts/functions/federated/FederatedIfelseSingleMatrixInputTest.dml
similarity index 88%
copy from src/test/scripts/functions/federated/FederatedIfelseTest.dml
copy to 
src/test/scripts/functions/federated/FederatedIfelseSingleMatrixInputTest.dml
index da650a3e65..579033134c 100644
--- a/src/test/scripts/functions/federated/FederatedIfelseTest.dml
+++ 
b/src/test/scripts/functions/federated/FederatedIfelseSingleMatrixInputTest.dml
@@ -28,9 +28,13 @@ if ($rP) {
             ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
                list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
 }
-c1 = ifelse(A>0, A + matrix(1, $rows, $cols), A*2);
-c2 = ifelse(A-3>0, A + matrix(1, $rows, $cols), 3);
-c3 = ifelse(1, matrix(1, $rows, $cols), 3);
+x = as.integer(as.scalar(read($x)));
+y = as.integer(as.scalar(read($y)));
+
+c1 = ifelse(x, A, y);
+c2 = ifelse(A-3>0, y, x);
+c3 = ifelse(x, y, A + matrix(1, $rows, $cols));
+
 s = c2 + c3;
-s = s + 10*c1;
+s = s + x*c1;
 write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml 
b/src/test/scripts/functions/federated/FederatedIfelseSingleMatrixInputTestReference.dml
similarity index 85%
copy from src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
copy to 
src/test/scripts/functions/federated/FederatedIfelseSingleMatrixInputTestReference.dml
index e232cdf095..644f2795a7 100644
--- a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
+++ 
b/src/test/scripts/functions/federated/FederatedIfelseSingleMatrixInputTestReference.dml
@@ -23,9 +23,10 @@
 if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
 else { A = cbind(read($1), read($2), read($3), read($4)); }
 
-c1 = ifelse(A>0, A + matrix(1, nrow(A), ncol(A)), A*2);
-c2 = ifelse(A-3>0, A + matrix(1, nrow(A), ncol(A)), 3);
-c3 = ifelse(1, matrix(1, nrow(A), ncol(A)), 3);
+c1 = ifelse(3, A, 4);
+c2 = ifelse(A-3>0, 4, 3);
+c3 = ifelse(3, 4, A + matrix(1, nrow(A), ncol(A)));
+
 s = c2 + c3;
-s = s + 10*c1;
+s = s + 3*c1;
 write(s, $5);
diff --git a/src/test/scripts/functions/federated/FederatedIfelseTest.dml 
b/src/test/scripts/functions/federated/FederatedIfelseTest.dml
index da650a3e65..e3b6ee9524 100644
--- a/src/test/scripts/functions/federated/FederatedIfelseTest.dml
+++ b/src/test/scripts/functions/federated/FederatedIfelseTest.dml
@@ -28,9 +28,12 @@ if ($rP) {
             ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
                list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
 }
-c1 = ifelse(A>0, A + matrix(1, $rows, $cols), A*2);
-c2 = ifelse(A-3>0, A + matrix(1, $rows, $cols), 3);
-c3 = ifelse(1, matrix(1, $rows, $cols), 3);
-s = c2 + c3;
-s = s + 10*c1;
+x = as.integer(as.scalar(read($x)));
+y = as.integer(as.scalar(read($y)));
+
+c1 = ifelse(x, A + matrix(1, $rows, $cols), A*2);
+c2 = ifelse(A-3>0, A + matrix(1, $rows, $cols), x);
+
+s = c2 + 1;
+s = s + x*c1;
 write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml 
b/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
index e232cdf095..416ff6756c 100644
--- a/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedIfelseTestReference.dml
@@ -23,9 +23,9 @@
 if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
 else { A = cbind(read($1), read($2), read($3), read($4)); }
 
-c1 = ifelse(A>0, A + matrix(1, nrow(A), ncol(A)), A*2);
+c1 = ifelse(3, A + matrix(1, nrow(A), ncol(A)), A*2);
 c2 = ifelse(A-3>0, A + matrix(1, nrow(A), ncol(A)), 3);
-c3 = ifelse(1, matrix(1, nrow(A), ncol(A)), 3);
-s = c2 + c3;
-s = s + 10*c1;
+
+s = c2 + 1;
+s = s + 3*c1;
 write(s, $5);

Reply via email to