This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 7a9ecff801 [SYSTEMDS-3790] Restore federated planner tests (all, heuristic)
7a9ecff801 is described below

commit 7a9ecff8014c9ba490910f7058082d2faba42ea3
Author: min-guk <[email protected]>
AuthorDate: Sun Nov 17 14:04:21 2024 +0100

    [SYSTEMDS-3790] Restore federated planner tests (all, heuristic)
    
    Closes #2139.
---
 .github/workflows/javaTests.yml                    |   2 +-
 .../org/apache/sysds/hops/fedplanner/FTypes.java   |   9 +-
 .../fedplanning/FederatedDynamicPlanningTest.java  | 176 ++++++++++
 .../fedplanning/FederatedKMeansPlanningTest.java   | 156 +++++++++
 .../fedplanning/FederatedL2SVMPlanningTest.java    | 185 ++++++++++
 .../fedplanning/FederatedMultiplyPlanningTest.java | 318 ++++++++++++++++++
 .../test/functions/fedplanning/FTypeCombTest.java  |  71 ++++
 .../fedplanning/FederatedCostEstimatorTest.java    | 373 +++++++++++++++++++++
 .../fedplanning/FederatedDynamicPlanningTest.java  | 188 +++++++++++
 .../fedplanning/FederatedKMeansPlanningTest.java   | 168 ++++++++++
 .../fedplanning/FederatedL2SVMPlanningTest.java    | 202 +++++++++++
 .../fedplanning/FederatedMultiplyPlanningTest.java | 334 ++++++++++++++++++
 12 files changed, 2174 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml
index cd6e28670e..9f0258a0d1 100644
--- a/.github/workflows/javaTests.yml
+++ b/.github/workflows/javaTests.yml
@@ -62,7 +62,7 @@ jobs:
           "**.functions.compress.**,**.functions.data.tensor.**,**.functions.codegenalg.parttwo.**,**.functions.codegen.**,**.functions.caching.**",
           "**.functions.binary.matrix_full_cellwise.**,**.functions.binary.matrix_full_other.**",
           "**.functions.federated.algorithms.**,**.functions.federated.io.**,**.functions.federated.paramserv.**",
-          "**.functions.federated.transform.**",
+          "**.functions.federated.transform.**,**.functions.federated.fedplanning.**",
           "**.functions.federated.primitives.part1.** -Dtest-threadCount=1 -Dtest-forkCount=1",
           "**.functions.federated.primitives.part2.** -Dtest-threadCount=1 -Dtest-forkCount=1",
           "**.functions.federated.primitives.part3.** -Dtest-threadCount=1 -Dtest-forkCount=1",
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
index a82a56e88b..de9c9cb670 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
@@ -44,13 +44,8 @@ public class FTypes
                        return this != NONE && this != RUNTIME;
                }
                public static boolean isCompiled(String planner) {
-                       try {
-                               return FederatedPlanner.valueOf(planner.toUpperCase()).isCompiled();
-                       }
-                       catch(Exception ex) {
-                               ex.printStackTrace();
-                               return false;
-                       }
+                       // Returns false for a null planner name; an unknown name makes valueOf throw
+                       // IllegalArgumentException. Upper-casing uses Locale.ROOT so that names such as
+                       // "heuristic" do not turn into "HEURİSTİC" under e.g. a Turkish default locale.
+                       return planner != null
+                               && FederatedPlanner.valueOf(planner.toUpperCase(java.util.Locale.ROOT)).isCompiled();
                }
        }
        
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java
new file mode 100644
index 0000000000..bd098bf827
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java
@@ -0,0 +1,176 @@
+/*
+ *  Licensed to the Apache Software Foundation (ASF) under one
+ *  or more contributor license agreements.  See the NOTICE file
+ *  distributed with this work for additional information
+ *  regarding copyright ownership.  The ASF licenses this file
+ *  to you under the Apache License, Version 2.0 (the
+ *  "License"); you may not use this file except in compliance
+ *  with the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing,
+ *  software distributed under the License is distributed on an
+ *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ *  KIND, either express or implied.  See the License for the
+ *  specific language governing permissions and limitations
+ *  under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.fedplanning;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.Arrays;
+
+import static org.junit.Assert.fail;
+
+@net.jcip.annotations.NotThreadSafe
+public class FederatedDynamicPlanningTest extends AutomatedTestBase {
+       private static final Log LOG = LogFactory.getLog(FederatedDynamicPlanningTest.class.getName());
+
+       private final static String TEST_DIR = "functions/privacy/fedplanning/";
+       private final static String TEST_NAME = "FederatedDynamicFunctionPlanningTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + FederatedDynamicPlanningTest.class.getSimpleName() + "/";
+       /** Configuration template set by the most recent {@link #setTestConf(String)} call (static, shared by all test methods). */
+       private static File TEST_CONF_FILE;
+
+       private final static int blocksize = 1024;
+       public final int rows = 1000;
+       public final int cols = 1000;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+       }
+
+       @Test
+       @Ignore
+       public void runDynamicFullFunctionTest() {
+               // compared to `FederatedL2SVMPlanningTest` this does not create `fed_+*` or `fed_tsmm`, probably due to
+               // some rewrites not being applied. Might be a bug.
+               String[] expectedHeavyHitters = new String[] {"fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_max",
+                               "fed_1-*", "fed_>"};
+               setTestConf("SystemDS-config-fout.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+       }
+
+       @Test
+       @Ignore
+       public void runDynamicHeuristicFunctionTest() {
+               // compared to `FederatedL2SVMPlanningTest` this does not create `fed_+*` or `fed_tsmm`, probably due to
+               // some rewrites not being applied. Might be a bug.
+               String[] expectedHeavyHitters = new String[] {"fed_fedinit", "fed_ba+*"};
+               setTestConf("SystemDS-config-heuristic.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+       }
+
+       /** Selects a configuration file from the fedplanning test script directory as config template. */
+       private void setTestConf(String test_conf) {
+               TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+       }
+
+       /**
+        * Writes the inputs: A is a +/-1 vector with {@code rows} entries, B1/B2 are the two row halves
+        * (rows/2 x cols) and C1/C2 the two column halves (rows x cols/2).
+        */
+       private void writeInputMatrices() {
+               writeBinaryVector("A", 42, rows);
+               writeStandardMatrix("B1", 65, rows / 2, cols);
+               writeStandardMatrix("B2", 75, rows / 2, cols);
+               writeStandardMatrix("C1", 13, rows, cols / 2);
+               writeStandardMatrix("C2", 17, rows, cols / 2);
+       }
+
+       /** Writes a numRows x 1 vector whose entries are the signs (+1/-1) of uniform random values in [-1, 1]. */
+       private void writeBinaryVector(String matrixName, long seed, int numRows){
+               double[][] matrix = getRandomMatrix(numRows, 1, -1, 1, 1, seed);
+               for(int i = 0; i < numRows; i++)
+                       matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1;
+               // every entry is +1 or -1, hence nnz == numRows
+               MatrixCharacteristics mc = new MatrixCharacteristics(numRows, 1, blocksize, numRows);
+               writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+       }
+
+       /** Writes a dense random numRows x numCols matrix with values in [0, 1]. */
+       private void writeStandardMatrix(String matrixName, long seed, int numRows, int numCols) {
+               double[][] matrix = getRandomMatrix(numRows, numCols, 0, 1, 1, seed);
+               writeStandardMatrix(matrixName, numRows, numCols, matrix);
+       }
+
+       /** Writes the given matrix together with its metadata file, declaring it fully dense. */
+       private void writeStandardMatrix(String matrixName, int numRows, int numCols, double[][] matrix) {
+               MatrixCharacteristics mc = new MatrixCharacteristics(numRows, numCols, blocksize, (long) numRows * numCols);
+               writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+       }
+
+       /**
+        * Starts two local federated workers and runs {@code testName}.dml with B1/B2/C1/C2 as federated
+        * inputs. Then runs {@code testName}Reference.dml on the local inputs, compares both outputs
+        * (tolerance 1e-9), and fails if any expected heavy hitter is missing. Workers, exec mode and the
+        * local Spark flag are restored in {@code finally}.
+        */
+       private void loadAndRunTest(String[] expectedHeavyHitters, String testName) {
+
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               Types.ExecMode platformOld = rtplatform;
+               rtplatform = Types.ExecMode.SINGLE_NODE;
+
+               Thread t1 = null, t2 = null;
+
+               try {
+                       getAndLoadTestConfiguration(testName);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       writeInputMatrices();
+
+                       int port1 = getRandomAvailablePort();
+                       int port2 = getRandomAvailablePort();
+                       t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+                       t2 = startLocalFedWorkerThread(port2);
+
+                       // Run actual dml script with federated matrix
+                       fullDMLScriptName = HOME + testName + ".dml";
+                       programArgs = new String[] {"-stats", "-nvargs",
+                               "r=" + rows, "c=" + cols,
+                               "A=" + input("A"),
+                               "B1=" + TestUtils.federatedAddress(port1, input("B1")),
+                               "B2=" + TestUtils.federatedAddress(port2, input("B2")),
+                               "C1=" + TestUtils.federatedAddress(port1, input("C1")),
+                               "C2=" + TestUtils.federatedAddress(port2, input("C2")),
+                               "lB1=" + input("B1"),
+                               "lB2=" + input("B2"),
+                               "Z=" + output("Z")};
+                       runTest(true, false, null, -1);
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + testName + "Reference.dml";
+                       programArgs = new String[] {"-nvargs",
+                               "r=" + rows, "c=" + cols,
+                               "A=" + input("A"),
+                               "B1=" + input("B1"),
+                               "B2=" + input("B2"),
+                               "C1=" + input("C1"),
+                               "C2=" + input("C2"),
+                               "Z=" + expected("Z")};
+                       runTest(true, false, null, -1);
+
+                       // compare via files
+                       compareResults(1e-9);
+                       if(!heavyHittersContainsAllString(expectedHeavyHitters))
+                               fail("The following expected heavy hitters are missing: "
+                                       + Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+               }
+               finally {
+                       TestUtils.shutdownThreads(t1, t2);
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+
+       /**
+        * Override default configuration with custom test configuration to ensure scratch space and local temporary
+        * directory locations are also updated. Falls back to the default template if no custom configuration has
+        * been selected yet, instead of failing with a NullPointerException.
+        */
+       @Override
+       protected File getConfigTemplateFile() {
+               if(TEST_CONF_FILE == null)
+                       return super.getConfigTemplateFile();
+               // Instrumentation in this test's output log to show custom configuration file used for template.
+               LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
+               return TEST_CONF_FILE;
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java
new file mode 100644
index 0000000000..326516d423
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.fedplanning;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.Arrays;
+
+import static org.junit.Assert.fail;
+
+public class FederatedKMeansPlanningTest extends AutomatedTestBase {
+       private static final Log LOG = LogFactory.getLog(FederatedKMeansPlanningTest.class.getName());
+
+       private final static String TEST_DIR = "functions/privacy/fedplanning/";
+       private final static String TEST_NAME = "FederatedKMeansPlanningTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + FederatedKMeansPlanningTest.class.getSimpleName() + "/";
+       /** Configuration template set by the most recent test method (static, shared by all test methods). */
+       private static File TEST_CONF_FILE;
+
+       private final static int blocksize = 1024;
+       public final int rows = 1000;
+       public final int cols = 100;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+       }
+
+       // NOTE(review): all KMeans tests expect no heavy hitters, so they only verify result equality with the
+       // reference script, not which federated instructions the planner generated.
+       @Test
+       public void runKMeansFOUTTest(){
+               String[] expectedHeavyHitters = new String[]{};
+               setTestConf("SystemDS-config-fout.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+       }
+
+       @Test
+       public void runKMeansHeuristicTest(){
+               String[] expectedHeavyHitters = new String[]{};
+               setTestConf("SystemDS-config-heuristic.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+       }
+
+       @Test
+       public void runRuntimeTest(){
+               String[] expectedHeavyHitters = new String[]{};
+               TEST_CONF_FILE = new File("src/test/config/SystemDS-config.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+       }
+
+       /** Selects a configuration file from the fedplanning test script directory as config template. */
+       private void setTestConf(String test_conf){
+               TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+       }
+
+       /**
+        * Override default configuration with custom test configuration to ensure
+        * scratch space and local temporary directory locations are also updated.
+        * Falls back to the default template if no custom configuration has been
+        * selected yet, instead of failing with a NullPointerException.
+        */
+       @Override
+       protected File getConfigTemplateFile() {
+               if(TEST_CONF_FILE == null)
+                       return super.getConfigTemplateFile();
+               // Instrumentation in this test's output log to show custom configuration file used for template.
+               LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
+               return TEST_CONF_FILE;
+       }
+
+       /** Writes the two row partitions X1 and X2 (rows/2 x cols each). */
+       private void writeInputMatrices(){
+               writeStandardRowFedMatrix("X1", 65);
+               writeStandardRowFedMatrix("X2", 75);
+       }
+
+       /** Writes a dense random numRows x cols matrix with values in [0, 1]. */
+       private void writeStandardMatrix(String matrixName, long seed, int numRows){
+               double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1, seed);
+               writeStandardMatrix(matrixName, numRows, matrix);
+       }
+
+       /** Writes the given matrix together with its metadata file, declaring it fully dense. */
+       private void writeStandardMatrix(String matrixName, int numRows, double[][] matrix){
+               MatrixCharacteristics mc = new MatrixCharacteristics(numRows, cols, blocksize, (long) numRows * cols);
+               writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+       }
+
+       /** Writes one row partition (half of {@code rows}) of the federated input. */
+       private void writeStandardRowFedMatrix(String matrixName, long seed){
+               int halfRows = rows/2;
+               writeStandardMatrix(matrixName, seed, halfRows);
+       }
+
+       /**
+        * Starts two local federated workers, runs {@code testName}.dml with X1/X2 as federated inputs and
+        * {@code testName}Reference.dml on the local inputs, compares both outputs (tolerance 1e-9), and fails
+        * if any expected heavy hitter is missing. Workers, exec mode and the local Spark flag are restored in
+        * {@code finally}.
+        */
+       private void loadAndRunTest(String[] expectedHeavyHitters, String testName){
+
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               Types.ExecMode platformOld = rtplatform;
+               rtplatform = Types.ExecMode.SINGLE_NODE;
+
+               Thread t1 = null, t2 = null;
+
+               try {
+                       getAndLoadTestConfiguration(testName);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       writeInputMatrices();
+
+                       int port1 = getRandomAvailablePort();
+                       int port2 = getRandomAvailablePort();
+                       t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+                       t2 = startLocalFedWorkerThread(port2);
+
+                       // NOTE(review): "Y" is passed to both scripts below, but writeInputMatrices() never writes it;
+                       // confirm the KMeans scripts do not read $Y.
+                       // Run actual dml script with federated matrix
+                       fullDMLScriptName = HOME + testName + ".dml";
+                       programArgs = new String[] { "-stats", "-nvargs",
+                               "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                               "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                               "Y=" + input("Y"), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+                       runTest(true, false, null, -1);
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + testName + "Reference.dml";
+                       programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"),
+                               "Y=" + input("Y"), "Z=" + expected("Z")};
+                       runTest(true, false, null, -1);
+
+                       // compare via files
+                       compareResults(1e-9);
+                       if (!heavyHittersContainsAllString(expectedHeavyHitters))
+                               fail("The following expected heavy hitters are missing: "
+                                       + Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+               }
+               finally {
+                       TestUtils.shutdownThreads(t1, t2);
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
new file mode 100644
index 0000000000..3e8f8719a6
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.fedplanning;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.Arrays;
+
+import static org.junit.Assert.fail;
+
+@net.jcip.annotations.NotThreadSafe
+public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
+       private static final Log LOG = LogFactory.getLog(FederatedL2SVMPlanningTest.class.getName());
+
+       private final static String TEST_DIR = "functions/privacy/fedplanning/";
+       private final static String TEST_NAME = "FederatedL2SVMPlanningTest";
+       private final static String TEST_NAME_2 = "FederatedL2SVMFunctionPlanningTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + FederatedL2SVMPlanningTest.class.getSimpleName() + "/";
+       /** Configuration template set by the most recent {@link #setTestConf(String)} call (static, shared by all test methods). */
+       private static File TEST_CONF_FILE;
+
+       private final static int blocksize = 1024;
+       public final int rows = 1000;
+       public final int cols = 100;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"}));
+       }
+
+       @Test
+       public void runL2SVMFOUTTest(){
+               String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*",
+                       "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+               setTestConf("SystemDS-config-fout.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+       }
+
+       @Test
+       public void runL2SVMHeuristicTest(){
+               String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*"};
+               setTestConf("SystemDS-config-heuristic.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+       }
+
+       @Test
+       @Ignore //TODO
+       public void runL2SVMFunctionFOUTTest(){
+               String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*",
+                       "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+               setTestConf("SystemDS-config-fout.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+       }
+
+       @Test
+       @Ignore //TODO
+       public void runL2SVMFunctionHeuristicTest(){
+               String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*"};
+               setTestConf("SystemDS-config-heuristic.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+       }
+
+       /** Selects a configuration file from the fedplanning test script directory as config template. */
+       private void setTestConf(String test_conf){
+               TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+       }
+
+       /** Writes the row partitions X1/X2 (rows/2 x cols each) and the +/-1 label vector Y (rows x 1). */
+       private void writeInputMatrices(){
+               writeStandardRowFedMatrix("X1", 65);
+               writeStandardRowFedMatrix("X2", 75);
+               writeBinaryVector("Y", 44);
+       }
+
+       /** Writes a rows x 1 vector whose entries are the signs (+1/-1) of uniform random values in [-1, 1]. */
+       private void writeBinaryVector(String matrixName, long seed){
+               double[][] matrix = getRandomMatrix(rows, 1, -1, 1, 1, seed);
+               for(int i = 0; i < rows; i++)
+                       matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1;
+               // every entry is +1 or -1, hence nnz == rows
+               MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1, blocksize, rows);
+               writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+       }
+
+       /** Writes a dense random numRows x cols matrix with values in [0, 1]. */
+       private void writeStandardMatrix(String matrixName, long seed, int numRows){
+               double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1, seed);
+               writeStandardMatrix(matrixName, numRows, matrix);
+       }
+
+       /** Writes the given matrix together with its metadata file, declaring it fully dense. */
+       private void writeStandardMatrix(String matrixName, int numRows, double[][] matrix){
+               MatrixCharacteristics mc = new MatrixCharacteristics(numRows, cols, blocksize, (long) numRows * cols);
+               writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+       }
+
+       /** Writes one row partition (half of {@code rows}) of the federated input. */
+       private void writeStandardRowFedMatrix(String matrixName, long seed){
+               int halfRows = rows/2;
+               writeStandardMatrix(matrixName, seed, halfRows);
+       }
+
+       /**
+        * Starts two local federated workers, runs {@code testName}.dml with X1/X2 as federated inputs and
+        * {@code testName}Reference.dml on the local inputs, compares both outputs (tolerance 1e-9), and fails
+        * if any expected heavy hitter is missing. Workers, exec mode and the local Spark flag are restored in
+        * {@code finally}.
+        */
+       private void loadAndRunTest(String[] expectedHeavyHitters, String testName){
+
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               Types.ExecMode platformOld = rtplatform;
+               rtplatform = Types.ExecMode.SINGLE_NODE;
+
+               Thread t1 = null, t2 = null;
+
+               try {
+                       getAndLoadTestConfiguration(testName);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       writeInputMatrices();
+
+                       int port1 = getRandomAvailablePort();
+                       int port2 = getRandomAvailablePort();
+                       t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+                       t2 = startLocalFedWorkerThread(port2);
+
+                       // Run actual dml script with federated matrix
+                       fullDMLScriptName = HOME + testName + ".dml";
+                       programArgs = new String[] { "-stats", "-nvargs",
+                               "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                               "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                               "Y=" + input("Y"), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+                       runTest(true, false, null, -1);
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + testName + "Reference.dml";
+                       programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"),
+                               "Y=" + input("Y"), "Z=" + expected("Z")};
+                       runTest(true, false, null, -1);
+
+                       // compare via files
+                       compareResults(1e-9);
+                       if (!heavyHittersContainsAllString(expectedHeavyHitters))
+                               fail("The following expected heavy hitters are missing: "
+                                       + Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+               }
+               finally {
+                       TestUtils.shutdownThreads(t1, t2);
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+
+       /**
+        * Override default configuration with custom test configuration to ensure
+        * scratch space and local temporary directory locations are also updated.
+        * Falls back to the default template if no custom configuration has been
+        * selected yet, instead of failing with a NullPointerException.
+        */
+       @Override
+       protected File getConfigTemplateFile() {
+               if(TEST_CONF_FILE == null)
+                       return super.getConfigTemplateFile();
+               // Instrumentation in this test's output log to show custom configuration file used for template.
+               LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
+               return TEST_CONF_FILE;
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java
new file mode 100644
index 0000000000..5b54f14d05
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.fedplanning;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.io.File;
+import java.util.Arrays;
+import java.util.Collection;
+
+import static org.junit.Assert.fail;
+
+/**
+ * Federated planning tests for matrix multiplication and related operator sequences.
+ * Each test writes its input partitions, runs the federated DML script against two local
+ * federated workers with {@code -stats}, runs a reference script on the local inputs,
+ * compares both results, and checks that the expected federated heavy hitters occurred.
+ */
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
+       private static final Log LOG = LogFactory.getLog(FederatedMultiplyPlanningTest.class.getName());
+
+       private static final String TEST_DIR = "functions/privacy/fedplanning/";
+       private static final String TEST_NAME = "FederatedMultiplyPlanningTest";
+       private static final String TEST_NAME_2 = "FederatedMultiplyPlanningTest2";
+       private static final String TEST_NAME_3 = "FederatedMultiplyPlanningTest3";
+       private static final String TEST_NAME_4 = "FederatedMultiplyPlanningTest4";
+       private static final String TEST_NAME_5 = "FederatedMultiplyPlanningTest5";
+       private static final String TEST_NAME_6 = "FederatedMultiplyPlanningTest6";
+       private static final String TEST_NAME_7 = "FederatedMultiplyPlanningTest7";
+       private static final String TEST_NAME_8 = "FederatedMultiplyPlanningTest8";
+       private static final String TEST_NAME_9 = "FederatedMultiplyPlanningTest9";
+       private static final String TEST_NAME_10 = "FederatedMultiplyPlanningTest10";
+       private static final String TEST_NAME_11 = "FederatedMultiplyPlanningTest11";
+       private static final String TEST_NAME_12 = "FederatedMultiplyPlanningTest12";
+       private static final String TEST_CLASS_DIR = TEST_DIR + FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
+       /** Default configuration template ({@code SystemDS-config-heuristic.xml}). */
+       private static final File DEFAULT_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-heuristic.xml");
+       /**
+        * Configuration template returned by {@link #getConfigTemplateFile()}. It is static, so a
+        * test that swaps it must restore it afterwards (see federatedMultiplyPlanningTest10).
+        */
+       private static File TEST_CONF_FILE = DEFAULT_CONF_FILE;
+
+       private static final int blocksize = 1024;
+       /** Number of rows of the full (unpartitioned) input matrix. */
+       @Parameterized.Parameter()
+       public int rows;
+       /** Number of columns of the full (unpartitioned) input matrix. */
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       /**
+        * Registers one test configuration per DML script. Tests 3, 8 and 9 produce a scalar output.
+        */
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_3, new String[] {"Z.scalar"}));
+               addTestConfiguration(TEST_NAME_4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_4, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_5, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_6, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_7, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"}));
+               addTestConfiguration(TEST_NAME_9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"}));
+               addTestConfiguration(TEST_NAME_10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_10, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_11, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_11, new String[] {"Z"}));
+               addTestConfiguration(TEST_NAME_12, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_12, new String[] {"Z"}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // rows have to be even and > 1
+               return Arrays.asList(new Object[][] {
+                       {100, 10}
+               });
+       }
+
+       @Test
+       public void federatedMultiplyCP() {
+               String[] expectedHeavyHitters = new String[]{"fed_*", "fed_fedinit", "fed_r'", "fed_ba+*"};
+               federatedTwoMatricesSingleNodeTest(TEST_NAME, expectedHeavyHitters);
+       }
+
+       @Test
+       public void federatedRowSum(){
+               String[] expectedHeavyHitters = new String[]{"fed_*", "fed_r'", "fed_fedinit", "fed_ba+*", "fed_uark+"};
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_2, expectedHeavyHitters);
+       }
+
+       @Test
+       public void federatedTernarySequence(){
+               String[] expectedHeavyHitters = new String[]{"fed_+*", "fed_1-*", "fed_fedinit", "fed_uak+"};
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_3, expectedHeavyHitters);
+       }
+
+       @Test
+       public void federatedAggregateBinarySequence(){
+               cols = rows;
+               String[] expectedHeavyHitters = new String[]{"fed_ba+*", "fed_*", "fed_fedinit"};
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_4, expectedHeavyHitters);
+       }
+
+       @Test
+       public void federatedAggregateBinaryColFedSequence(){
+               cols = rows;
+               //TODO: When alignment checks have been added to getFederatedOut in AFederatedPlanner,
+               // the following expectedHeavyHitters can be added. Until then, fed_* will not be generated.
+               //String[] expectedHeavyHitters = new String[]{"fed_ba+*","fed_*","fed_fedinit"};
+               String[] expectedHeavyHitters = new String[]{"fed_ba+*","fed_fedinit"};
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_5, expectedHeavyHitters);
+       }
+
+       @Test
+       public void federatedAggregateBinarySequence2(){
+               String[] expectedHeavyHitters = new String[]{"fed_ba+*","fed_fedinit"};
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_6, expectedHeavyHitters);
+       }
+
+       @Test
+       public void federatedMultiplyDoubleHop() {
+               String[] expectedHeavyHitters = new String[]{"fed_*", "fed_fedinit", "fed_ba+*"}; //TODO "fed_r' " ?
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_7, expectedHeavyHitters);
+       }
+
+       @Test
+       public void federatedMultiplyDoubleHop2() {
+               String[] expectedHeavyHitters = new String[]{"fed_fedinit", "fed_ba+*"};
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_8, expectedHeavyHitters);
+       }
+
+       @Test
+       public void federatedMultiplyPlanningTest9(){
+               String[] expectedHeavyHitters = new String[]{"fed_+*", "fed_1-*", "fed_fedinit", "fed_max"}; //TODO "fed_tak+*"
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_9, expectedHeavyHitters);
+       }
+
+       @Test
+       public void federatedMultiplyPlanningTest10(){
+               String[] expectedHeavyHitters = new String[]{"fed_fedinit", "fed_^2"};
+               // This test runs with the SystemDS-config-fout.xml template. TEST_CONF_FILE is static,
+               // so restore the previous template afterwards. Otherwise tests running later in the
+               // same JVM would silently pick up this configuration instead of the default one.
+               File previousConfFile = TEST_CONF_FILE;
+               TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-fout.xml");
+               try {
+                       federatedTwoMatricesSingleNodeTest(TEST_NAME_10, expectedHeavyHitters);
+               }
+               finally {
+                       TEST_CONF_FILE = previousConfFile;
+               }
+       }
+
+       @Test
+       public void federatedMultiplyPlanningTest11(){
+               String[] expectedHeavyHitters = new String[]{"fed_fedinit"};
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_11, expectedHeavyHitters);
+       }
+
+       @Test
+       public void federatedMultiplyPlanningTest12(){
+               String[] expectedHeavyHitters = new String[]{"fed_fedinit"};
+               rows = 30;
+               cols = 30;
+               federatedTwoMatricesSingleNodeTest(TEST_NAME_12, expectedHeavyHitters);
+       }
+
+       /**
+        * Writes one row partition of a row-split matrix: a random (rows/2) x cols dense matrix.
+        */
+       private void writeStandardMatrix(String matrixName, long seed){
+               int halfRows = rows/2;
+               double[][] matrix = getRandomMatrix(halfRows, cols, 0, 1, 1, seed);
+               MatrixCharacteristics mc = new MatrixCharacteristics(halfRows, cols, blocksize, (long) halfRows * cols);
+               writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+       }
+
+       /**
+        * Writes one column partition of a column-split matrix: a random rows x (cols/2) dense matrix.
+        */
+       private void writeColStandardMatrix(String matrixName, long seed){
+               int halfCols = cols/2;
+               double[][] matrix = getRandomMatrix(rows, halfCols, 0, 1, 1, seed);
+               MatrixCharacteristics mc = new MatrixCharacteristics(rows, halfCols, blocksize, (long) halfCols * rows);
+               writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+       }
+
+       /**
+        * Writes one row partition of a row-split column vector: a random (cols/2) x 1 dense matrix.
+        */
+       private void writeRowFederatedVector(String matrixName, long seed){
+               int halfCols = cols / 2;
+               double[][] matrix = getRandomMatrix(halfCols, 1, 0, 1, 1, seed);
+               // The written matrix is halfCols x 1, so the metadata nnz must not exceed halfCols cells
+               // (it was previously declared as halfCols * rows).
+               MatrixCharacteristics mc = new MatrixCharacteristics(halfCols, 1, blocksize, (long) halfCols);
+               writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+       }
+
+       /**
+        * Writes the input partitions required by the given test script.
+        * Tests 5, 6 and 8 use column partitions. Tests 10 and 12 only use X. All other tests use
+        * row partitions for both X and Y.
+        */
+       private void writeInputMatrices(String testName){
+               if ( testName.equals(TEST_NAME_5) ){
+                       writeColStandardMatrix("X1", 42);
+                       writeColStandardMatrix("X2", 1340);
+                       writeColStandardMatrix("Y1", 44);
+                       writeColStandardMatrix("Y2", 21);
+               }
+               else if ( testName.equals(TEST_NAME_6) ){
+                       writeColStandardMatrix("X1", 42);
+                       writeColStandardMatrix("X2", 1340);
+                       writeRowFederatedVector("Y1", 44);
+                       writeRowFederatedVector("Y2", 21);
+               }
+               else if ( testName.equals(TEST_NAME_8) ){
+                       writeColStandardMatrix("X1", 42);
+                       writeColStandardMatrix("X2", 1340);
+                       writeColStandardMatrix("Y1", 44);
+                       writeColStandardMatrix("Y2", 21);
+                       writeColStandardMatrix("W1", 76);
+                       writeColStandardMatrix("W2", 11);
+               }
+               else if ( testName.equals(TEST_NAME_10) || testName.equals(TEST_NAME_12) ){
+                       writeStandardMatrix("X1", 42);
+                       writeStandardMatrix("X2", 1340);
+               }
+               else {
+                       // Previously split on TEST_NAME_4 with two identical branches; the data is the same.
+                       writeStandardMatrix("X1", 42);
+                       writeStandardMatrix("X2", 1340);
+                       writeStandardMatrix("Y1", 44);
+                       writeStandardMatrix("Y2", 21);
+               }
+       }
+
+       private void federatedTwoMatricesSingleNodeTest(String testName, String[] expectedHeavyHitters){
+               federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName, expectedHeavyHitters);
+       }
+
+       /**
+        * Runs the federated script against two local federated workers, then the reference script on
+        * local data. It compares the outputs and fails if any expected heavy hitter is missing.
+        * The exec mode and the Spark config flag are restored afterwards, and the workers shut down.
+        *
+        * @param execMode execution mode to run the scripts with
+        * @param testName name of the test configuration / DML script
+        * @param expectedHeavyHitters instruction opcodes that must appear in the statistics
+        */
+       private void federatedTwoMatricesTest(Types.ExecMode execMode, String testName, String[] expectedHeavyHitters) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               Types.ExecMode platformOld = rtplatform;
+               rtplatform = execMode;
+               if(rtplatform == Types.ExecMode.SPARK) {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+               Thread t1 = null, t2 = null;
+
+               try{
+                       getAndLoadTestConfiguration(testName);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       writeInputMatrices(testName);
+
+                       int port1 = getRandomAvailablePort();
+                       int port2 = getRandomAvailablePort();
+                       t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+                       t2 = startLocalFedWorkerThread(port2);
+
+                       // Run actual dml script with federated matrix
+                       fullDMLScriptName = HOME + testName + ".dml";
+                       programArgs = new String[] {"-stats", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                               "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                               "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+                               "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+                       rewriteRealProgramArgs(testName, port1, port2);
+                       runTest(true, false, null, -1);
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + testName + "Reference.dml";
+                       programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+                               "Y2=" + input("Y2"), "Z=" + expected("Z")};
+                       rewriteReferenceProgramArgs(testName);
+                       runTest(true, false, null, -1);
+
+                       // compare via files
+                       compareResults(1e-9);
+                       if (!heavyHittersContainsAllString(expectedHeavyHitters))
+                               fail("The following expected heavy hitters are missing: "
+                                       + Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+               } finally {
+                       TestUtils.shutdownThreads(t1, t2);
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+
+       /**
+        * Replaces the default federated-run arguments for tests with different inputs. Tests 4 and 5
+        * pass Y1/Y2 as local (non-federated) inputs. Test 8 additionally passes local W1/W2.
+        */
+       private void rewriteRealProgramArgs(String testName, int port1, int port2){
+               if ( testName.equals(TEST_NAME_4) || testName.equals(TEST_NAME_5) ){
+                       programArgs = new String[] {"-stats","-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                               "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                               "Y1=" + input("Y1"),
+                               "Y2=" + input("Y2"), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+               } else if ( testName.equals(TEST_NAME_8) ){
+                       programArgs = new String[] {"-stats","-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                               "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                               "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+                               "Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
+                               "W1=" + input("W1"),
+                               "W2=" + input("W2"),
+                               "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+               }
+       }
+
+       /** Adds the W1/W2 inputs to the reference-run arguments of test 8. */
+       private void rewriteReferenceProgramArgs(String testName){
+               if ( testName.equals(TEST_NAME_8) ){
+                       programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+                               "Y2=" + input("Y2"), "W1=" + input("W1"), "W2=" + input("W2"), "Z=" + expected("Z")};
+               }
+       }
+
+       /**
+        * Override default configuration with custom test configuration to ensure
+        * scratch space and local temporary directory locations are also updated.
+        */
+       @Override
+       protected File getConfigTemplateFile() {
+               // Instrumentation in this test's output log to show custom configuration file used for template.
+               LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
+               return TEST_CONF_FILE;
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/fedplanning/FTypeCombTest.java 
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FTypeCombTest.java
new file mode 100644
index 0000000000..e36d517d98
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FTypeCombTest.java
@@ -0,0 +1,71 @@
+package org.apache.sysds.test.functions.fedplanning;
+///*
+// * Licensed to the Apache Software Foundation (ASF) under one
+// * or more contributor license agreements.  See the NOTICE file
+// * distributed with this work for additional information
+// * regarding copyright ownership.  The ASF licenses this file
+// * to you under the Apache License, Version 2.0 (the
+// * "License"); you may not use this file except in compliance
+// * with the License.  You may obtain a copy of the License at
+// *
+// *   http://www.apache.org/licenses/LICENSE-2.0
+// *
+// * Unless required by applicable law or agreed to in writing,
+// * software distributed under the License is distributed on an
+// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// * KIND, either express or implied.  See the License for the
+// * specific language governing permissions and limitations
+// * under the License.
+// */
+//
+//package org.apache.sysds.test.functions.privacy.fedplanning;
+//
+//import org.apache.sysds.hops.fedplanner.FTypes.FType;
+//import org.apache.sysds.hops.fedplanner.FederatedPlannerCostbased;
+//import org.apache.sysds.test.AutomatedTestBase;
+//import org.junit.Assert;
+//import org.junit.Test;
+//
+//import java.util.ArrayList;
+//import java.util.List;
+//
+//public class FTypeCombTest extends AutomatedTestBase {
+//
+//     @Override public void setUp() {}
+//
+//     @Test
+//     public void ftypeCombTest(){
+//             List<FType> secondInput = new ArrayList<>();
+//             secondInput.add(null);
+//             List<List<FType>> inputFTypes = List.of(
+//                     List.of(FType.ROW,FType.COL),
+//                     secondInput,
+//                     List.of(FType.BROADCAST,FType.FULL)
+//             );
+//
+//             FederatedPlannerCostbased planner = new 
FederatedPlannerCostbased();
+//             List<List<FType>> actualCombinations = 
planner.getAllCombinations(inputFTypes);
+//
+//             List<FType> expected1 = new ArrayList<>();
+//             expected1.add(FType.ROW);
+//             expected1.add(null);
+//             expected1.add(FType.BROADCAST);
+//             List<FType> expected2 = new ArrayList<>();
+//             expected2.add(FType.ROW);
+//             expected2.add(null);
+//             expected2.add(FType.FULL);
+//             List<FType> expected3 = new ArrayList<>();
+//             expected3.add(FType.COL);
+//             expected3.add(null);
+//             expected3.add(FType.BROADCAST);
+//             List<FType> expected4 = new ArrayList<>();
+//             expected4.add(FType.COL);
+//             expected4.add(null);
+//             expected4.add(FType.FULL);
+//             List<List<FType>> expectedCombinations = 
List.of(expected1,expected2, expected3, expected4);
+//
+//             Assert.assertEquals(expectedCombinations.size(), 
actualCombinations.size());
+//             for (List<FType> expectedComb : expectedCombinations)
+//                     
Assert.assertTrue(actualCombinations.contains(expectedComb));
+//     }
+//}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedCostEstimatorTest.java
 
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedCostEstimatorTest.java
new file mode 100644
index 0000000000..073c8f1d9d
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedCostEstimatorTest.java
@@ -0,0 +1,373 @@
+package org.apache.sysds.test.functions.fedplanning;
+///*
+// * Licensed to the Apache Software Foundation (ASF) under one
+// * or more contributor license agreements.  See the NOTICE file
+// * distributed with this work for additional information
+// * regarding copyright ownership.  The ASF licenses this file
+// * to you under the Apache License, Version 2.0 (the
+// * "License"); you may not use this file except in compliance
+// * with the License.  You may obtain a copy of the License at
+// *
+// *   http://www.apache.org/licenses/LICENSE-2.0
+// *
+// * Unless required by applicable law or agreed to in writing,
+// * software distributed under the License is distributed on an
+// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// * KIND, either express or implied.  See the License for the
+// * specific language governing permissions and limitations
+// * under the License.
+// */
+//
+//package org.apache.sysds.test.functions.privacy.fedplanning;
+//
+//import net.jcip.annotations.NotThreadSafe;
+//import org.apache.sysds.api.DMLScript;
+//import org.apache.sysds.common.Types;
+//import org.apache.sysds.conf.ConfigurationManager;
+//import org.apache.sysds.conf.DMLConfig;
+//import org.apache.sysds.hops.AggBinaryOp;
+//import org.apache.sysds.hops.BinaryOp;
+//import org.apache.sysds.hops.DataOp;
+//import org.apache.sysds.hops.Hop;
+//import org.apache.sysds.hops.LiteralOp;
+//import org.apache.sysds.hops.NaryOp;
+//import org.apache.sysds.hops.ReorgOp;
+//import org.apache.sysds.hops.cost.FederatedCost;
+//import org.apache.sysds.hops.cost.FederatedCostEstimator;
+//import org.apache.sysds.hops.fedplanner.FederatedPlannerCostbased;
+//import org.apache.sysds.hops.ipa.FunctionCallGraph;
+//import org.apache.sysds.parser.DMLProgram;
+//import org.apache.sysds.parser.DMLTranslator;
+//import org.apache.sysds.parser.LanguageException;
+//import org.apache.sysds.parser.ParserFactory;
+//import org.apache.sysds.parser.ParserWrapper;
+//import org.apache.sysds.parser.StatementBlock;
+//import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+//import org.apache.sysds.test.AutomatedTestBase;
+//import org.apache.sysds.test.TestConfiguration;
+//import org.junit.After;
+//import org.junit.Assert;
+//import org.junit.Before;
+//import org.junit.BeforeClass;
+//import org.junit.Test;
+//
+//import java.io.FileNotFoundException;
+//import java.io.IOException;
+//import java.util.HashMap;
+//import java.util.HashSet;
+//import java.util.Set;
+//
+//import static org.apache.sysds.common.Types.OpOp2.MULT;
+//
+//@NotThreadSafe
+//public class FederatedCostEstimatorTest extends AutomatedTestBase {
+//
+//     private static final String TEST_DIR = "functions/privacy/fedplanning/";
+//     private static final String HOME = SCRIPT_DIR + TEST_DIR;
+//     private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedCostEstimatorTest.class.getSimpleName() + "/";
+//     FederatedCostEstimator fedCostEstimator = new FederatedCostEstimator();
+//
+//     private static double COMPUTE_FLOPS;
+//     private static double READ_PS;
+//     private static double NETWORK_PS;
+//
+//     @Override
+//     public void setUp() {}
+//
+//     @BeforeClass
+//     public static void storeConstants(){
+//             COMPUTE_FLOPS = 
FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS;
+//             READ_PS = FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS;
+//             NETWORK_PS = 
FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS;
+//     }
+//
+//     @Before
+//     public void setConstants(){
+//             FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
+//             FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+//             FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;
+//     }
+//
+//     @After
+//     public void resetConstants(){
+//             FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 
COMPUTE_FLOPS;
+//             FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = READ_PS;
+//             FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 
NETWORK_PS;
+//     }
+//
+//     @Test
+//     public void simpleBinary() {
+//
+//             /*
+//              * HOP                  Occurrences             ComputeCost     
        ReadCost        ComputeCostFinal        ReadCostFinal
+//              * 
------------------------------------------------------------------------------------------
+//              * LiteralOp    16                              1               
                0                       0.0625                          0
+//              * DataGenOp    2                               100             
                64                      6.25                            6.4
+//              * BinaryOp             1                               100     
                        1600            6.25                            160
+//              * TOSTRING             1                               1       
                        800                     0.0625                          
80
+//              * UnaryOp              1                               1       
                        8                       0.0625                          
0.8
+//              */
+//             double computeCost = (16+2*100+100+1+1) / 
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * 
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+//             double readCost = (2*64+1600+800+8) / 
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+//
+//             double expectedCost = computeCost + readCost;
+//             runTest("BinaryCostEstimatorTest.dml", false, expectedCost);
+//     }
+//
+//     @Test
+//     public void simpleBinaryHopRelTest() {
+//             runHopRelTest("BinaryCostEstimatorTest.dml", false);
+//     }
+//
+//     @Test
+//     public void ifElseTest(){
+//             double computeCost = (16+2*100+100+1+1) / 
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * 
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+//             double readCost = (2*64+1600+800+8) / 
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+//             double expectedCost = ((computeCost + readCost + 0.8 + 0.0625 + 
0.0625) / 2) + 0.0625 + 0.8 + 0.0625;
+//             runTest("IfElseCostEstimatorTest.dml", false, expectedCost);
+//     }
+//
+//     @Test
+//     public void ifElseHopRelTest(){
+//             runHopRelTest("IfElseCostEstimatorTest.dml", false);
+//     }
+//
+//     @Test
+//     public void whileTest(){
+//             double computeCost = (16+2*100+100+1+1) / 
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * 
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+//             double readCost = (2*64+1600+800+8) / 
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+//             double expectedCost = (computeCost + readCost + 0.0625 + 0.0625 
+ 0.8) * StatementBlock.DEFAULT_LOOP_REPETITIONS;
+//             runTest("WhileCostEstimatorTest.dml", false, expectedCost);
+//     }
+//
+//     @Test
+//     public void whileHopRelTest(){
+//             runHopRelTest("WhileCostEstimatorTest.dml", false);
+//     }
+//
+//     @Test
+//     public void forLoopTest(){
+//             double computeCost = (16+2*100+100+1+1) / 
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * 
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+//             double readCost = (2*64+1600+800+8) / 
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+//             double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 
0.0625 + 0.0625 + 0.8 + 0.0625;
+//             double expectedCost = (computeCost + readCost + predicateCost) 
* 5;
+//             runTest("ForLoopCostEstimatorTest.dml", false, expectedCost);
+//     }
+//
+//     @Test
+//     public void forLoopHopRelTest(){
+//             runHopRelTest("ForLoopCostEstimatorTest.dml", false);
+//     }
+//
+//     @Test
+//     public void parForLoopTest(){
+//             double computeCost = (16+2*100+100+1+1) / 
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * 
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+//             double readCost = (2*64+1600+800+8) / 
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+//             double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 
0.0625 + 0.0625 + 0.8 + 0.0625;
+//             double expectedCost = (computeCost + readCost + predicateCost) 
* 5;
+//             runTest("ParForLoopCostEstimatorTest.dml", false, expectedCost);
+//     }
+//
+//     @Test
+//     public void parForLoopHopRelTest(){
+//             runHopRelTest("ParForLoopCostEstimatorTest.dml", false);
+//     }
+//
+//     @Test
+//     public void functionTest(){
+//             double computeCost = (16+2*100+100+1+1) / 
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * 
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+//             double readCost = (2*64+1600+800+8) / 
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+//             double expectedCost = (computeCost + readCost);
+//             runTest("FunctionCostEstimatorTest.dml", false, expectedCost);
+//     }
+//
+//     @Test
+//     public void functionHopRelTest(){
+//             runHopRelTest("FunctionCostEstimatorTest.dml", false);
+//     }
+//
+//     @Test
+//     public void federatedMultiply() {
+//
+//             double literalOpCost = 10*0.0625;
+//             double naryOpCostSpecial = (0.125+2.2);
+//             double naryOpCostSpecial2 = (0.25+6.4);
+//             double naryOpCost = 4*(0.125+1.6);
+//             double reorgOpCost = 6250+80015.2+160030.4;
+//             double binaryOpMultCost = 3125+160000;
+//             double aggBinaryOpCost = 125000+160015.2+160030.4+190.4;
+//             double dataOpCost = 2*(6250+5.6);
+//             double dataOpWriteCost = 6.25+100.3;
+//
+//             double expectedCost = literalOpCost + naryOpCost + 
naryOpCostSpecial + naryOpCostSpecial2 + reorgOpCost
+//                     + binaryOpMultCost + aggBinaryOpCost + dataOpCost + 
dataOpWriteCost;
+//             runTest("FederatedMultiplyCostEstimatorTest.dml", false, 
expectedCost);
+//
+//             double aggBinaryActualCost = hops.stream()
+//                     .filter(hop -> hop instanceof AggBinaryOp)
+//                     .mapToDouble(aggHop -> 
aggHop.getFederatedCost().getTotal()-aggHop.getFederatedCost().getInputTotalCost())
+//                     .sum();
+//             Assert.assertEquals(aggBinaryOpCost, aggBinaryActualCost, 
0.0001);
+//
+//             double writeActualCost = hops.stream()
+//                     .filter(hop -> hop instanceof DataOp)
+//                     .mapToDouble(writeHop -> 
writeHop.getFederatedCost().getTotal()-writeHop.getFederatedCost().getInputTotalCost())
+//                     .sum();
+//             Assert.assertEquals(dataOpWriteCost+dataOpCost, 
writeActualCost, 0.0001);
+//     }
+//
+//     Set<Hop> hops = new HashSet<>();
+//
+//     /**
+//      * Recursively adds the hop and its inputs to the set of hops.
+//      * @param hop root to be added to set of hops
+//      */
+//     private void addHop(Hop hop){
+//             hops.add(hop);
+//             for(Hop inHop : hop.getInput()){
+//                     addHop(inHop);
+//             }
+//     }
+//
+//     /**
+//      * Marks DataOps and binary multiplications as FOUT (exec type FED), all other hops as LOUT, 
and sets the dimensions of federated X and Y to 10000 x 10.
+//      * @param prog dml program where the HOPS are modified
+//      */
+//     private void modifyFedouts(DMLProgram prog){
+//             prog.getStatementBlocks().forEach(stmBlock -> 
stmBlock.getHops().forEach(this::addHop));
+//             hops.forEach(hop -> {
+//                     if ( hop instanceof DataOp || (hop instanceof BinaryOp 
&& ((BinaryOp) hop).getOp() == MULT ) ){
+//                             
hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
+//                             hop.setExecType(Types.ExecType.FED);
+//                     } else {
+//                             
hop.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT);
+//                     }
+//                     if ( hop.getOpString().equals("Fed Y") || 
hop.getOpString().equals("Fed X") ){
+//                             hop.setDim1(10000);
+//                             hop.setDim2(10);
+//                     }
+//             });
+//     }
+//
+//     @SuppressWarnings("unused")
+//     private void printHopsInfo(){
+//             //LiteralOp
+//             long literalCount = hops.stream().filter(hop -> hop instanceof 
LiteralOp).count();
+//             System.out.println("LiteralOp Count: " + literalCount);
+//             //NaryOp
+//             long naryCount = hops.stream().filter(hop -> hop instanceof 
NaryOp).count();
+//             System.out.println("NaryOp Count " + naryCount);
+//             //ReorgOp
+//             long reorgCount = hops.stream().filter(hop -> hop instanceof 
ReorgOp).count();
+//             System.out.println("ReorgOp Count: " + reorgCount);
+//             //BinaryOp
+//             long binaryCount = hops.stream().filter(hop -> hop instanceof 
BinaryOp).count();
+//             System.out.println("Binary count: " + binaryCount);
+//             //AggBinaryOp
+//             long aggBinaryCount = hops.stream().filter(hop -> hop 
instanceof AggBinaryOp).count();
+//             System.out.println("AggBinaryOp Count: " + aggBinaryCount);
+//             //DataOp
+//             long dataOpCount = hops.stream().filter(hop -> hop instanceof 
DataOp).count();
+//             System.out.println("DataOp Count: " + dataOpCount);
+//
+//             
hops.stream().map(Hop::getClass).distinct().forEach(System.out::println);
+//     }
+//
+//     private DMLProgram testSetup(String scriptFilename) throws IOException{
+//             setTestConfig(scriptFilename);
+//             String dmlScriptString = readScript(scriptFilename);
+//
+//             //parsing, dependency analysis and constructing hops (steps 3 
and 4 in DMLScript.java)
+//             ParserWrapper parser = ParserFactory.createParser();
+//             DMLProgram prog = 
parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new 
HashMap<>());
+//             DMLTranslator dmlt = new DMLTranslator(prog);
+//             dmlt.liveVariableAnalysis(prog);
+//             dmlt.validateParseTree(prog);
+//             dmlt.constructHops(prog);
+//             if ( 
scriptFilename.equals("FederatedMultiplyCostEstimatorTest.dml")){
+//                     modifyFedouts(prog);
+//                     dmlt.rewriteHopsDAG(prog);
+//                     hops = new HashSet<>();
+//                     prog.getStatementBlocks().forEach(stmBlock -> 
stmBlock.getHops().forEach(this::addHop));
+//             }
+//             return prog;
+//     }
+//
+//     private void compareResults(DMLProgram prog) {
+//             FederatedPlannerCostbased rewriter = new 
FederatedPlannerCostbased();
+//             rewriter.rewriteProgram(prog, new FunctionCallGraph(prog), 
null);
+//
+//             double actualCost = 0;
+//             for ( Hop root : rewriter.getTerminalHops() ){
+//                     actualCost += root.getFederatedCost().getTotal();
+//             }
+//
+//
+//             rewriter.getTerminalHops().forEach(Hop::resetFederatedCost);
+//             fedCostEstimator = new FederatedCostEstimator();
+//             double expectedCost = 0;
+//             for ( Hop root : rewriter.getTerminalHops() )
+//                     expectedCost += 
fedCostEstimator.costEstimate(root).getTotal();
+//             Assert.assertEquals(expectedCost, actualCost, 0.0001);
+//     }
+//
+//     private void runHopRelTest( String scriptFilename, boolean 
expectedException ) {
+//             boolean raisedException = false;
+//             try
+//             {
+//                     DMLProgram prog = testSetup(scriptFilename);
+//                     compareResults(prog);
+//             }
+//             catch(LanguageException ex) {
+//                     raisedException = true;
+//                     if(raisedException!=expectedException)
+//                             ex.printStackTrace();
+//             }
+//             catch(Exception ex2) {
+//                     ex2.printStackTrace();
+//                     throw new RuntimeException(ex2);
+//             }
+//
+//             Assert.assertEquals("Expected exception does not match raised 
exception",
+//                     expectedException, raisedException);
+//     }
+//
+//     private void runTest( String scriptFilename, boolean expectedException, 
double expectedCost ) {
+//             boolean raisedException = false;
+//             try
+//             {
+//                     DMLProgram prog = testSetup(scriptFilename);
+//
+//                     fedCostEstimator = new FederatedCostEstimator();
+//                     FederatedCost actualCost = 
fedCostEstimator.costEstimate(prog);
+//                     Assert.assertEquals(expectedCost, 
actualCost.getTotal(), 0.0001);
+//             }
+//             catch(LanguageException ex) {
+//                     raisedException = true;
+//                     if(raisedException!=expectedException)
+//                             ex.printStackTrace();
+//             }
+//             catch(Exception ex2) {
+//                     ex2.printStackTrace();
+//                     throw new RuntimeException(ex2);
+//             }
+//
+//             Assert.assertEquals("Expected exception does not match raised 
exception",
+//                     expectedException, raisedException);
+//     }
+//
+//     private void setTestConfig(String scriptFilename) throws 
FileNotFoundException {
+//             int index = scriptFilename.lastIndexOf(".dml");
+//             String testName = scriptFilename.substring(0, index > 0 ? index 
: scriptFilename.length());
+//             TestConfiguration testConfig = new 
TestConfiguration(TEST_CLASS_DIR, testName, new String[] {});
+//             addTestConfiguration(testName, testConfig);
+//             loadTestConfiguration(testConfig);
+//
+//             DMLConfig conf = new DMLConfig(getCurConfigFile().getPath());
+//             ConfigurationManager.setLocalConfig(conf);
+//     }
+//
+//     private static String readScript(String scriptFilename) throws 
IOException {
+//             return DMLScript.readDMLScript(true, HOME + scriptFilename);
+//     }
+//}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedDynamicPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedDynamicPlanningTest.java
new file mode 100644
index 0000000000..23da01d438
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedDynamicPlanningTest.java
@@ -0,0 +1,188 @@
+package org.apache.sysds.test.functions.fedplanning;
+///*
+// *  Licensed to the Apache Software Foundation (ASF) under one
+// *  or more contributor license agreements.  See the NOTICE file
+// *  distributed with this work for additional information
+// *  regarding copyright ownership.  The ASF licenses this file
+// *  to you under the Apache License, Version 2.0 (the
+// *  "License"); you may not use this file except in compliance
+// *  with the License.  You may obtain a copy of the License at
+// *
+// *    http://www.apache.org/licenses/LICENSE-2.0
+// *
+// *  Unless required by applicable law or agreed to in writing,
+// *  software distributed under the License is distributed on an
+// *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// *  KIND, either express or implied.  See the License for the
+// *  specific language governing permissions and limitations
+// *  under the License.
+// */
+//
+//package org.apache.sysds.test.functions.privacy.fedplanning;
+//
+//import org.apache.commons.logging.Log;
+//import org.apache.commons.logging.LogFactory;
+//import org.apache.sysds.api.DMLScript;
+//import org.apache.sysds.common.Types;
+//import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+//import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+//import org.apache.sysds.test.AutomatedTestBase;
+//import org.apache.sysds.test.TestConfiguration;
+//import org.apache.sysds.test.TestUtils;
+//import org.junit.Test;
+//
+//import java.io.File;
+//import java.util.Arrays;
+//
+//import static org.junit.Assert.fail;
+//
+//@net.jcip.annotations.NotThreadSafe
+//public class FederatedDynamicPlanningTest extends AutomatedTestBase {
+//     private static final Log LOG = 
LogFactory.getLog(FederatedDynamicPlanningTest.class.getName());
+//
+//     private final static String TEST_DIR = "functions/privacy/fedplanning/";
+//     private final static String TEST_NAME = 
"FederatedDynamicFunctionPlanningTest";
+//     private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedDynamicPlanningTest.class.getSimpleName() + "/";
+//     private static File TEST_CONF_FILE;
+//
+//     private final static int blocksize = 1024;
+//     public final int rows = 1000;
+//     public final int cols = 1000;
+//
+//     @Override
+//     public void setUp() {
+//             TestUtils.clearAssertionInformation();
+//             addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+//     }
+//
+//     @Test
+//     public void runDynamicFullFunctionTest() {
+//             // compared to `FederatedL2SVMPlanningTest` this does not 
create `fed_+*` or `fed_tsmm`, probably due to
+//             // some rewrites not being applied. Might be a bug.
+//             String[] expectedHeavyHitters = new String[] {"fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_max",
+//                             "fed_1-*", "fed_>"};
+//             setTestConf("SystemDS-config-fout.xml");
+//             loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+//     }
+//
+//     @Test
+//     public void runDynamicHeuristicFunctionTest() {
+//             // compared to `FederatedL2SVMPlanningTest` this does not 
create `fed_+*` or `fed_tsmm`, probably due to
+//             // some rewrites not being applied. Might be a bug.
+//             String[] expectedHeavyHitters = new String[] {"fed_fedinit", 
"fed_ba+*"};
+//             setTestConf("SystemDS-config-heuristic.xml");
+//             loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+//     }
+//
+//     @Test
+//     public void runDynamicCostBasedFunctionTest() {
+//             // compared to `FederatedL2SVMPlanningTest` this does not 
create `fed_+*` or `fed_tsmm`, probably due to
+//             // some rewrites not being applied. Might be a bug.
+//             String[] expectedHeavyHitters = new String[] {"fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_max",
+//                     "fed_1-*", "fed_>"};
+//             setTestConf("SystemDS-config-cost-based.xml");
+//             loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+//     }
+//
+//     private void setTestConf(String test_conf) {
+//             TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+//     }
+//
+//     private void writeInputMatrices() {
+//             writeBinaryVector("A", 42, rows, null);
+//             writeStandardMatrix("B1", 65, rows / 2, cols, null);
+//             writeStandardMatrix("B2", 75, rows / 2, cols, null);
+//             writeStandardMatrix("C1", 13, rows, cols / 2, null);
+//             writeStandardMatrix("C2", 17, rows, cols / 2, null);
+//     }
+//
+//     private void writeBinaryVector(String matrixName, long seed, int 
numRows, PrivacyConstraint privacyConstraint){
+//             double[][] matrix = getRandomMatrix(numRows, 1, -1, 1, 1, seed);
+//             for(int i = 0; i < numRows; i++)
+//                     matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1;
+//             MatrixCharacteristics mc = new MatrixCharacteristics(numRows, 
1, blocksize, numRows);
+//             writeInputMatrixWithMTD(matrixName, matrix, false, mc, 
privacyConstraint);
+//     }
+//
+//     private void writeStandardMatrix(String matrixName, long seed, int 
numRows, int numCols,
+//             PrivacyConstraint privacyConstraint) {
+//             double[][] matrix = getRandomMatrix(numRows, numCols, 0, 1, 1, 
seed);
+//             writeStandardMatrix(matrixName, numRows, numCols, 
privacyConstraint, matrix);
+//     }
+//
+//     private void writeStandardMatrix(String matrixName, int numRows, int 
numCols, PrivacyConstraint privacyConstraint,
+//             double[][] matrix) {
+//             MatrixCharacteristics mc = new MatrixCharacteristics(numRows, 
numCols, blocksize, (long) numRows * numCols);
+//             writeInputMatrixWithMTD(matrixName, matrix, false, mc, 
privacyConstraint);
+//     }
+//
+//     private void loadAndRunTest(String[] expectedHeavyHitters, String 
testName) {
+//
+//             boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+//             Types.ExecMode platformOld = rtplatform;
+//             rtplatform = Types.ExecMode.SINGLE_NODE;
+//
+//             Thread t1 = null, t2 = null;
+//
+//             try {
+//                     getAndLoadTestConfiguration(testName);
+//                     String HOME = SCRIPT_DIR + TEST_DIR;
+//
+//                     writeInputMatrices();
+//
+//                     int port1 = getRandomAvailablePort();
+//                     int port2 = getRandomAvailablePort();
+//                     t1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+//                     t2 = startLocalFedWorkerThread(port2);
+//
+//                     // Run actual dml script with federated matrix
+//                     fullDMLScriptName = HOME + testName + ".dml";
+//                     programArgs = new String[] {"-stats", "-nvargs",
+//                              "r=" + rows, "c=" + cols,
+//                             "A=" + input("A"),
+//                             "B1=" + TestUtils.federatedAddress(port1, 
input("B1")),
+//                             "B2=" + TestUtils.federatedAddress(port2, 
input("B2")),
+//                             "C1=" + TestUtils.federatedAddress(port1, 
input("C1")),
+//                             "C2=" + TestUtils.federatedAddress(port2, 
input("C2")),
+//                             "lB1=" + input("B1"),
+//                             "lB2=" + input("B2"),
+//                             "Z=" + output("Z")};
+//                     runTest(true, false, null, -1);
+//
+//                     // Run reference dml script with normal matrix
+//                     fullDMLScriptName = HOME + testName + "Reference.dml";
+//                     programArgs = new String[] {"-nvargs",
+//                             "r=" + rows, "c=" + cols,
+//                             "A=" + input("A"),
+//                             "B1=" + input("B1"),
+//                             "B2=" + input("B2"),
+//                             "C1=" + input("C1"),
+//                             "C2=" + input("C2"),
+//                             "Z=" + expected("Z")};
+//                     runTest(true, false, null, -1);
+//
+//                     // compare via files
+//                     compareResults(1e-9);
+//                     if(!heavyHittersContainsAllString(expectedHeavyHitters))
+//                             fail("The following expected heavy hitters are 
missing: "
+//                                     + 
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+//             }
+//             finally {
+//                     TestUtils.shutdownThreads(t1, t2);
+//                     rtplatform = platformOld;
+//                     DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+//             }
+//     }
+//
+//     /**
+//      * Override default configuration with custom test configuration to 
ensure scratch space and local temporary
+//      * directory locations are also updated.
+//      */
+//     @Override
+//     protected File getConfigTemplateFile() {
+//             // Instrumentation in this test's output log to show custom 
configuration file used for template.
+//             LOG.info("This test case overrides default configuration with " 
+ TEST_CONF_FILE.getPath());
+//             return TEST_CONF_FILE;
+//     }
+//
+//}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedKMeansPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedKMeansPlanningTest.java
new file mode 100644
index 0000000000..48d9a06b8c
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedKMeansPlanningTest.java
@@ -0,0 +1,168 @@
+package org.apache.sysds.test.functions.fedplanning;
+///*
+// * Licensed to the Apache Software Foundation (ASF) under one
+// * or more contributor license agreements.  See the NOTICE file
+// * distributed with this work for additional information
+// * regarding copyright ownership.  The ASF licenses this file
+// * to you under the Apache License, Version 2.0 (the
+// * "License"); you may not use this file except in compliance
+// * with the License.  You may obtain a copy of the License at
+// *
+// *   http://www.apache.org/licenses/LICENSE-2.0
+// *
+// * Unless required by applicable law or agreed to in writing,
+// * software distributed under the License is distributed on an
+// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// * KIND, either express or implied.  See the License for the
+// * specific language governing permissions and limitations
+// * under the License.
+// */
+//
+//package org.apache.sysds.test.functions.privacy.fedplanning;
+//
+//import org.apache.commons.logging.Log;
+//import org.apache.commons.logging.LogFactory;
+//import org.apache.sysds.api.DMLScript;
+//import org.apache.sysds.common.Types;
+//import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+//import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+//import org.apache.sysds.test.AutomatedTestBase;
+//import org.apache.sysds.test.TestConfiguration;
+//import org.apache.sysds.test.TestUtils;
+//import org.junit.Ignore;
+//import org.junit.Test;
+//
+//import java.io.File;
+//import java.util.Arrays;
+//
+//import static org.junit.Assert.fail;
+//
+//public class FederatedKMeansPlanningTest extends AutomatedTestBase {
+//     private static final Log LOG = 
LogFactory.getLog(FederatedKMeansPlanningTest.class.getName());
+//
+//     private final static String TEST_DIR = "functions/privacy/fedplanning/";
+//     private final static String TEST_NAME = "FederatedKMeansPlanningTest";
+//     private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedKMeansPlanningTest.class.getSimpleName() + "/";
+//     private static File TEST_CONF_FILE;
+//
+//     private final static int blocksize = 1024;
+//     public final int rows = 1000;
+//     public final int cols = 100;
+//
+//     @Override
+//     public void setUp() {
+//             TestUtils.clearAssertionInformation();
+//             addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+//     }
+//
+//     @Test
+//     public void runKMeansFOUTTest(){
+//             String[] expectedHeavyHitters = new String[]{};
+//             setTestConf("SystemDS-config-fout.xml");
+//             loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+//     }
+//
+//     @Test
+//     public void runKMeansHeuristicTest(){
+//             String[] expectedHeavyHitters = new String[]{};
+//             setTestConf("SystemDS-config-heuristic.xml");
+//             loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+//     }
+//
+//     @Test
+//     public void runKMeansCostBasedTest(){
+//             String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_*", "fed_uack+", "fed_bcumoffk+"};
+//             setTestConf("SystemDS-config-cost-based.xml");
+//             loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+//     }
+//
+//     @Test
+//     public void runRuntimeTest(){
+//             String[] expectedHeavyHitters = new String[]{};
+//             TEST_CONF_FILE = new 
File("src/test/config/SystemDS-config.xml");
+//             loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+//     }
+//
+//     private void setTestConf(String test_conf){
+//             TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+//     }
+//
+//     /**
+//      * Override default configuration with custom test configuration to 
ensure
+//      * scratch space and local temporary directory locations are also 
updated.
+//      */
+//     @Override
+//     protected File getConfigTemplateFile() {
+//             // Instrumentation in this test's output log to show custom 
configuration file used for template.
+//             LOG.info("This test case overrides default configuration with " 
+ TEST_CONF_FILE.getPath());
+//             return TEST_CONF_FILE;
+//     }
+//
+//     private void writeInputMatrices(){
+//             writeStandardRowFedMatrix("X1", 65, null);
+//             writeStandardRowFedMatrix("X2", 75, null);
+//     }
+//
+//     private void writeStandardMatrix(String matrixName, long seed, int 
numRows, PrivacyConstraint privacyConstraint){
+//             double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1, 
seed);
+//             writeStandardMatrix(matrixName, numRows, privacyConstraint, 
matrix);
+//     }
+//
+//     private void writeStandardMatrix(String matrixName, int numRows, 
PrivacyConstraint privacyConstraint, double[][] matrix){
+//             MatrixCharacteristics mc = new MatrixCharacteristics(numRows, 
cols, blocksize, (long) numRows * cols);
+//             writeInputMatrixWithMTD(matrixName, matrix, false, mc, 
privacyConstraint);
+//     }
+//
+//     private void writeStandardRowFedMatrix(String matrixName, long seed, 
PrivacyConstraint privacyConstraint){
+//             int halfRows = rows/2;
+//             writeStandardMatrix(matrixName, seed, halfRows, 
privacyConstraint);
+//     }
+//
+//     private void loadAndRunTest(String[] expectedHeavyHitters, String 
testName){
+//
+//             boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+//             Types.ExecMode platformOld = rtplatform;
+//             rtplatform = Types.ExecMode.SINGLE_NODE;
+//
+//             Thread t1 = null, t2 = null;
+//
+//             try {
+//                     getAndLoadTestConfiguration(testName);
+//                     String HOME = SCRIPT_DIR + TEST_DIR;
+//
+//                     writeInputMatrices();
+//
+//                     int port1 = getRandomAvailablePort();
+//                     int port2 = getRandomAvailablePort();
+//                     t1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+//                     t2 = startLocalFedWorkerThread(port2);
+//
+//                     // Run actual dml script with federated matrix
+//                     fullDMLScriptName = HOME + testName + ".dml";
+//                     programArgs = new String[] { "-stats", "-nvargs",
+//                             "X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+//                             "X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+//                             "Y=" + input("Y"), "r=" + rows, "c=" + cols, 
"Z=" + output("Z")};
+//                     runTest(true, false, null, -1);
+//
+//                     // Run reference dml script with normal matrix
+//                     fullDMLScriptName = HOME + testName + "Reference.dml";
+//                     programArgs = new String[] {"-nvargs", "X1=" + 
input("X1"), "X2=" + input("X2"),
+//                             "Y=" + input("Y"), "Z=" + expected("Z")};
+//                     runTest(true, false, null, -1);
+//
+//                     // compare via files
+//                     compareResults(1e-9);
+//                     if 
(!heavyHittersContainsAllString(expectedHeavyHitters))
+//                             fail("The following expected heavy hitters are 
missing: "
+//                                     + 
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+//             }
+//             finally {
+//                     TestUtils.shutdownThreads(t1, t2);
+//                     rtplatform = platformOld;
+//                     DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+//             }
+//     }
+//
+//
+//}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedL2SVMPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedL2SVMPlanningTest.java
new file mode 100644
index 0000000000..0ef4fde6a4
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedL2SVMPlanningTest.java
@@ -0,0 +1,202 @@
+package org.apache.sysds.test.functions.fedplanning;
+///*
+// * Licensed to the Apache Software Foundation (ASF) under one
+// * or more contributor license agreements.  See the NOTICE file
+// * distributed with this work for additional information
+// * regarding copyright ownership.  The ASF licenses this file
+// * to you under the Apache License, Version 2.0 (the
+// * "License"); you may not use this file except in compliance
+// * with the License.  You may obtain a copy of the License at
+// *
+// *   http://www.apache.org/licenses/LICENSE-2.0
+// *
+// * Unless required by applicable law or agreed to in writing,
+// * software distributed under the License is distributed on an
+// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// * KIND, either express or implied.  See the License for the
+// * specific language governing permissions and limitations
+// * under the License.
+// */
+//
+//package org.apache.sysds.test.functions.privacy.fedplanning;
+//
+//import org.apache.commons.logging.Log;
+//import org.apache.commons.logging.LogFactory;
+//import org.apache.sysds.api.DMLScript;
+//import org.apache.sysds.common.Types;
+//import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+//import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+//import org.apache.sysds.test.AutomatedTestBase;
+//import org.apache.sysds.test.TestConfiguration;
+//import org.apache.sysds.test.TestUtils;
+//import org.junit.Ignore;
+//import org.junit.Test;
+//
+//import java.io.File;
+//import java.util.Arrays;
+//
+//import static org.junit.Assert.fail;
+//
+//@net.jcip.annotations.NotThreadSafe
+//public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
+//     private static final Log LOG = 
LogFactory.getLog(FederatedL2SVMPlanningTest.class.getName());
+//
+//     private final static String TEST_DIR = "functions/privacy/fedplanning/";
+//     private final static String TEST_NAME = "FederatedL2SVMPlanningTest";
+//     private final static String TEST_NAME_2 = 
"FederatedL2SVMFunctionPlanningTest";
+//     private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedL2SVMPlanningTest.class.getSimpleName() + "/";
+//     private static File TEST_CONF_FILE;
+//
+//     private final static int blocksize = 1024;
+//     public final int rows = 1000;
+//     public final int cols = 100;
+//
+//     @Override
+//     public void setUp() {
+//             TestUtils.clearAssertionInformation();
+//             addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+//             addTestConfiguration(TEST_NAME_2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"}));
+//     }
+//
+//     @Test
+//     public void runL2SVMFOUTTest(){
+//             String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_+*",
+//                     "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+//             setTestConf("SystemDS-config-fout.xml");
+//             loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+//     }
+//
+//     @Test
+//     public void runL2SVMHeuristicTest(){
+//             String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*"};
+//             setTestConf("SystemDS-config-heuristic.xml");
+//             loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+//     }
+//
+//     @Test
+//     public void runL2SVMCostBasedTest(){
+//             String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_+*",
+//                     "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+//             setTestConf("SystemDS-config-cost-based.xml");
+//             loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+//     }
+//
+//     @Test
+//     public void runL2SVMFunctionFOUTTest(){
+//             String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_+*",
+//                     "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+//             setTestConf("SystemDS-config-fout.xml");
+//             loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+//     }
+//
+//     @Test
+//     public void runL2SVMFunctionHeuristicTest(){
+//             String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*"};
+//             setTestConf("SystemDS-config-heuristic.xml");
+//             loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+//     }
+//
+//     @Test
+//     public void runL2SVMFunctionCostBasedTest(){
+//             String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_+*",
+//                     "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+//             setTestConf("SystemDS-config-cost-based.xml");
+//             loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+//     }
+//
+//     private void setTestConf(String test_conf){
+//             TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+//     }
+//
+//     private void writeInputMatrices(){
+//             writeStandardRowFedMatrix("X1", 65, null);
+//             writeStandardRowFedMatrix("X2", 75, null);
+//             writeBinaryVector("Y", 44, null);
+//     }
+//
+//     private void writeBinaryVector(String matrixName, long seed, 
PrivacyConstraint privacyConstraint){
+//             double[][] matrix = getRandomMatrix(rows, 1, -1, 1, 1, seed);
+//             for(int i = 0; i < rows; i++)
+//                     matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1;
+//             MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1, 
blocksize, rows);
+//             writeInputMatrixWithMTD(matrixName, matrix, false, mc, 
privacyConstraint);
+//     }
+//
+//     @SuppressWarnings("unused")
+//     private void writeStandardMatrix(String matrixName, long seed, 
PrivacyConstraint privacyConstraint){
+//             writeStandardMatrix(matrixName, seed, rows, privacyConstraint);
+//     }
+//
+//     private void writeStandardMatrix(String matrixName, long seed, int 
numRows, PrivacyConstraint privacyConstraint){
+//             double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1, 
seed);
+//             writeStandardMatrix(matrixName, numRows, privacyConstraint, 
matrix);
+//     }
+//
+//     private void writeStandardMatrix(String matrixName, int numRows, 
PrivacyConstraint privacyConstraint, double[][] matrix){
+//             MatrixCharacteristics mc = new MatrixCharacteristics(numRows, 
cols, blocksize, (long) numRows * cols);
+//             writeInputMatrixWithMTD(matrixName, matrix, false, mc, 
privacyConstraint);
+//     }
+//
+//     private void writeStandardRowFedMatrix(String matrixName, long seed, 
PrivacyConstraint privacyConstraint){
+//             int halfRows = rows/2;
+//             writeStandardMatrix(matrixName, seed, halfRows, 
privacyConstraint);
+//     }
+//
+//     private void loadAndRunTest(String[] expectedHeavyHitters, String 
testName){
+//
+//             boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+//             Types.ExecMode platformOld = rtplatform;
+//             rtplatform = Types.ExecMode.SINGLE_NODE;
+//
+//             Thread t1 = null, t2 = null;
+//
+//             try {
+//                     getAndLoadTestConfiguration(testName);
+//                     String HOME = SCRIPT_DIR + TEST_DIR;
+//
+//                     writeInputMatrices();
+//
+//                     int port1 = getRandomAvailablePort();
+//                     int port2 = getRandomAvailablePort();
+//                     t1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+//                     t2 = startLocalFedWorkerThread(port2);
+//
+//                     // Run actual dml script with federated matrix
+//                     fullDMLScriptName = HOME + testName + ".dml";
+//                     programArgs = new String[] { "-stats", "-nvargs",
+//                             "X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+//                             "X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+//                             "Y=" + input("Y"), "r=" + rows, "c=" + cols, 
"Z=" + output("Z")};
+//                     runTest(true, false, null, -1);
+//
+//                     // Run reference dml script with normal matrix
+//                     fullDMLScriptName = HOME + testName + "Reference.dml";
+//                     programArgs = new String[] {"-nvargs", "X1=" + 
input("X1"), "X2=" + input("X2"),
+//                             "Y=" + input("Y"), "Z=" + expected("Z")};
+//                     runTest(true, false, null, -1);
+//
+//                     // compare via files
+//                     compareResults(1e-9);
+//                     if 
(!heavyHittersContainsAllString(expectedHeavyHitters))
+//                             fail("The following expected heavy hitters are 
missing: "
+//                                     + 
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+//             }
+//             finally {
+//                     TestUtils.shutdownThreads(t1, t2);
+//                     rtplatform = platformOld;
+//                     DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+//             }
+//     }
+//
+//     /**
+//      * Override default configuration with custom test configuration to 
ensure
+//      * scratch space and local temporary directory locations are also 
updated.
+//      */
+//     @Override
+//     protected File getConfigTemplateFile() {
+//             // Instrumentation in this test's output log to show custom 
configuration file used for template.
+//             LOG.info("This test case overrides default configuration with " 
+ TEST_CONF_FILE.getPath());
+//             return TEST_CONF_FILE;
+//     }
+//
+//}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedMultiplyPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedMultiplyPlanningTest.java
new file mode 100644
index 0000000000..f3eee4ee41
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedMultiplyPlanningTest.java
@@ -0,0 +1,334 @@
+package org.apache.sysds.test.functions.fedplanning;
+///*
+// * Licensed to the Apache Software Foundation (ASF) under one
+// * or more contributor license agreements.  See the NOTICE file
+// * distributed with this work for additional information
+// * regarding copyright ownership.  The ASF licenses this file
+// * to you under the Apache License, Version 2.0 (the
+// * "License"); you may not use this file except in compliance
+// * with the License.  You may obtain a copy of the License at
+// *
+// *   http://www.apache.org/licenses/LICENSE-2.0
+// *
+// * Unless required by applicable law or agreed to in writing,
+// * software distributed under the License is distributed on an
+// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// * KIND, either express or implied.  See the License for the
+// * specific language governing permissions and limitations
+// * under the License.
+// */
+//
+//package org.apache.sysds.test.functions.privacy.fedplanning;
+//
+//import org.apache.commons.logging.Log;
+//import org.apache.commons.logging.LogFactory;
+//import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+//import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
+//import org.junit.Test;
+//import org.junit.runner.RunWith;
+//import org.junit.runners.Parameterized;
+//import org.apache.sysds.api.DMLScript;
+//import org.apache.sysds.common.Types;
+//import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+//import org.apache.sysds.test.AutomatedTestBase;
+//import org.apache.sysds.test.TestConfiguration;
+//import org.apache.sysds.test.TestUtils;
+//
+//import java.io.File;
+//import java.util.Arrays;
+//import java.util.Collection;
+//
+//import static org.junit.Assert.fail;
+//
+//@RunWith(value = Parameterized.class)
+//@net.jcip.annotations.NotThreadSafe
+//public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
+//     private static final Log LOG = 
LogFactory.getLog(FederatedMultiplyPlanningTest.class.getName());
+//
+//     private final static String TEST_DIR = "functions/privacy/fedplanning/";
+//     private final static String TEST_NAME = "FederatedMultiplyPlanningTest";
+//     private final static String TEST_NAME_2 = 
"FederatedMultiplyPlanningTest2";
+//     private final static String TEST_NAME_3 = 
"FederatedMultiplyPlanningTest3";
+//     private final static String TEST_NAME_4 = 
"FederatedMultiplyPlanningTest4";
+//     private final static String TEST_NAME_5 = 
"FederatedMultiplyPlanningTest5";
+//     private final static String TEST_NAME_6 = 
"FederatedMultiplyPlanningTest6";
+//     private final static String TEST_NAME_7 = 
"FederatedMultiplyPlanningTest7";
+//     private final static String TEST_NAME_8 = 
"FederatedMultiplyPlanningTest8";
+//     private final static String TEST_NAME_9 = 
"FederatedMultiplyPlanningTest9";
+//     private final static String TEST_NAME_10 = 
"FederatedMultiplyPlanningTest10";
+//     private final static String TEST_NAME_11 = 
"FederatedMultiplyPlanningTest11";
+//     private final static String TEST_NAME_12 = 
"FederatedMultiplyPlanningTest12";
+//     private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
+//     private static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, 
"SystemDS-config-cost-based.xml");
+//
+//     private final static int blocksize = 1024;
+//     @Parameterized.Parameter()
+//     public int rows;
+//     @Parameterized.Parameter(1)
+//     public int cols;
+//
+//     @Override
+//     public void setUp() {
+//             TestUtils.clearAssertionInformation();
+//             addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+//             addTestConfiguration(TEST_NAME_2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"}));
+//             addTestConfiguration(TEST_NAME_3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_3, new String[] {"Z.scalar"}));
+//             addTestConfiguration(TEST_NAME_4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_4, new String[] {"Z"}));
+//             addTestConfiguration(TEST_NAME_5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_5, new String[] {"Z"}));
+//             addTestConfiguration(TEST_NAME_6, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_6, new String[] {"Z"}));
+//             addTestConfiguration(TEST_NAME_7, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_7, new String[] {"Z"}));
+//             addTestConfiguration(TEST_NAME_8, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"}));
+//             addTestConfiguration(TEST_NAME_9, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"}));
+//             addTestConfiguration(TEST_NAME_10, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_10, new String[] {"Z"}));
+//             addTestConfiguration(TEST_NAME_11, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_11, new String[] {"Z"}));
+//             addTestConfiguration(TEST_NAME_12, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_12, new String[] {"Z"}));
+//     }
+//
+//     @Parameterized.Parameters
+//     public static Collection<Object[]> data() {
+//             // rows have to be even and > 1
+//             return Arrays.asList(new Object[][] {
+//                     {100, 10}
+//             });
+//     }
+//
+//     @Test
+//     public void federatedMultiplyCP() {
+//             String[] expectedHeavyHitters = new String[]{"fed_*", 
"fed_fedinit", "fed_r'", "fed_ba+*"};
+//             federatedTwoMatricesSingleNodeTest(TEST_NAME, 
expectedHeavyHitters);
+//     }
+//
+//     @Test
+//     public void federatedRowSum(){
+//             String[] expectedHeavyHitters = new String[]{"fed_*", "fed_r'", 
"fed_fedinit", "fed_ba+*", "fed_uark+"};
+//             federatedTwoMatricesSingleNodeTest(TEST_NAME_2, 
expectedHeavyHitters);
+//     }
+//
+//     @Test
+//     public void federatedTernarySequence(){
+//             String[] expectedHeavyHitters = new String[]{"fed_+*", 
"fed_1-*", "fed_fedinit", "fed_uak+"};
+//             federatedTwoMatricesSingleNodeTest(TEST_NAME_3, 
expectedHeavyHitters);
+//     }
+//
+//     @Test
+//     public void federatedAggregateBinarySequence(){
+//             cols = rows;
+//             String[] expectedHeavyHitters = new String[]{"fed_ba+*", 
"fed_*", "fed_fedinit"};
+//             federatedTwoMatricesSingleNodeTest(TEST_NAME_4, 
expectedHeavyHitters);
+//     }
+//
+//     @Test
+//     public void federatedAggregateBinaryColFedSequence(){
+//             cols = rows;
+//             //TODO: When alignment checks have been added to 
getFederatedOut in AFederatedPlanner,
+//             // the following expectedHeavyHitters can be added. Until then, 
fed_* will not be generated.
+//             //String[] expectedHeavyHitters = new 
String[]{"fed_ba+*","fed_*","fed_fedinit"};
+//             String[] expectedHeavyHitters = new 
String[]{"fed_ba+*","fed_fedinit"};
+//             federatedTwoMatricesSingleNodeTest(TEST_NAME_5, 
expectedHeavyHitters);
+//     }
+//
+//     @Test
+//     public void federatedAggregateBinarySequence2(){
+//             String[] expectedHeavyHitters = new 
String[]{"fed_ba+*","fed_fedinit"};
+//             federatedTwoMatricesSingleNodeTest(TEST_NAME_6, 
expectedHeavyHitters);
+//     }
+//
+//     @Test
+//     public void federatedMultiplyDoubleHop() {
+//             String[] expectedHeavyHitters = new String[]{"fed_*", 
"fed_fedinit", "fed_r'", "fed_ba+*"};
+//             federatedTwoMatricesSingleNodeTest(TEST_NAME_7, 
expectedHeavyHitters);
+//     }
+//
+//     @Test
+//     public void federatedMultiplyDoubleHop2() {
+//             String[] expectedHeavyHitters = new String[]{"fed_fedinit", 
"fed_ba+*"};
+//             federatedTwoMatricesSingleNodeTest(TEST_NAME_8, 
expectedHeavyHitters);
+//     }
+//
+//     @Test
+//     public void federatedMultiplyPlanningTest9(){
+//             String[] expectedHeavyHitters = new String[]{"fed_+*", 
"fed_1-*", "fed_fedinit", "fed_tak+*", "fed_max"};
+//             federatedTwoMatricesSingleNodeTest(TEST_NAME_9, 
expectedHeavyHitters);
+//     }
+//
+//     @Test
+//     public void federatedMultiplyPlanningTest10(){
+//             String[] expectedHeavyHitters = new String[]{"fed_fedinit", 
"fed_^2"};
+//             TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, 
"SystemDS-config-fout.xml");
+//             federatedTwoMatricesSingleNodeTest(TEST_NAME_10, 
expectedHeavyHitters);
+//     }
+//
+//     @Test
+//     public void federatedMultiplyPlanningTest11(){
+//             String[] expectedHeavyHitters = new String[]{"fed_fedinit"};
+//             federatedTwoMatricesSingleNodeTest(TEST_NAME_11, 
expectedHeavyHitters);
+//     }
+//
+//     @Test
+//     public void federatedMultiplyPlanningTest12(){
+//             String[] expectedHeavyHitters = new String[]{"fed_fedinit"};
+//             rows = 30;
+//             cols = 30;
+//             federatedTwoMatricesSingleNodeTest(TEST_NAME_12, 
expectedHeavyHitters);
+//     }
+//
+//     private void writeStandardMatrix(String matrixName, long seed){
+//             writeStandardMatrix(matrixName, seed, new 
PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
+//     }
+//
+//     private void writeStandardMatrix(String matrixName, long seed, 
PrivacyConstraint privacyConstraint){
+//             int halfRows = rows/2;
+//             double[][] matrix = getRandomMatrix(halfRows, cols, 0, 1, 1, 
seed);
+//             MatrixCharacteristics mc = new MatrixCharacteristics(halfRows, 
cols, blocksize, (long) halfRows * cols);
+//             writeInputMatrixWithMTD(matrixName, matrix, false, mc, 
privacyConstraint);
+//     }
+//
+//     private void writeColStandardMatrix(String matrixName, long seed){
+//             writeColStandardMatrix(matrixName, seed, new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+//     }
+//
+//     private void writeColStandardMatrix(String matrixName, long seed, 
PrivacyConstraint privacyConstraint){
+//             int halfCols = cols/2;
+//             double[][] matrix = getRandomMatrix(rows, halfCols, 0, 1, 1, 
seed);
+//             MatrixCharacteristics mc = new MatrixCharacteristics(rows, 
halfCols, blocksize, (long) halfCols *rows);
+//             writeInputMatrixWithMTD(matrixName, matrix, false, mc, 
privacyConstraint);
+//     }
+//
+//     private void writeRowFederatedVector(String matrixName, long seed){
+//             writeRowFederatedVector(matrixName, seed, new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+//     }
+//
+//     private void writeRowFederatedVector(String matrixName, long seed, 
PrivacyConstraint privacyConstraint){
+//             int halfCols = cols / 2;
+//             double[][] matrix = getRandomMatrix(halfCols, 1, 0, 1, 1, seed);
+//             MatrixCharacteristics mc = new MatrixCharacteristics(halfCols, 
1, blocksize, (long) halfCols *rows);
+//             writeInputMatrixWithMTD(matrixName, matrix, false, mc, 
privacyConstraint);
+//     }
+//
+//     private void writeInputMatrices(String testName){
+//             if ( testName.equals(TEST_NAME_5) ){
+//                     writeColStandardMatrix("X1", 42);
+//                     writeColStandardMatrix("X2", 1340);
+//                     writeColStandardMatrix("Y1", 44, null);
+//                     writeColStandardMatrix("Y2", 21, null);
+//             }
+//             else if ( testName.equals(TEST_NAME_6) ){
+//                     writeColStandardMatrix("X1", 42);
+//                     writeColStandardMatrix("X2", 1340);
+//                     writeRowFederatedVector("Y1", 44);
+//                     writeRowFederatedVector("Y2", 21);
+//             }
+//             else if ( testName.equals(TEST_NAME_8) ){
+//                     writeColStandardMatrix("X1", 42, null);
+//                     writeColStandardMatrix("X2", 1340, null);
+//                     writeColStandardMatrix("Y1", 44, null);
+//                     writeColStandardMatrix("Y2", 21, null);
+//                     writeColStandardMatrix("W1", 76, null);
+//                     writeColStandardMatrix("W2", 11, null);
+//             }
+//             else if ( testName.equals(TEST_NAME_10) || 
testName.equals(TEST_NAME_12) ){
+//                     writeStandardMatrix("X1", 42, null);
+//                     writeStandardMatrix("X2", 1340, null);
+//             }
+//             else {
+//                     writeStandardMatrix("X1", 42);
+//                     writeStandardMatrix("X2", 1340);
+//                     if ( testName.equals(TEST_NAME_4) ){
+//                             writeStandardMatrix("Y1", 44, null);
+//                             writeStandardMatrix("Y2", 21, null);
+//                     }
+//                     else {
+//                             writeStandardMatrix("Y1", 44);
+//                             writeStandardMatrix("Y2", 21);
+//                     }
+//             }
+//     }
+//
+//     private void federatedTwoMatricesSingleNodeTest(String testName, 
String[] expectedHeavyHitters){
+//             federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName, 
expectedHeavyHitters);
+//     }
+//
+//     private void federatedTwoMatricesTest(Types.ExecMode execMode, String 
testName, String[] expectedHeavyHitters) {
+//             boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+//             Types.ExecMode platformOld = rtplatform;
+//             rtplatform = execMode;
+//             if(rtplatform == Types.ExecMode.SPARK) {
+//                     DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+//             }
+//             Thread t1 = null, t2 = null;
+//
+//             try{
+//                     getAndLoadTestConfiguration(testName);
+//                     String HOME = SCRIPT_DIR + TEST_DIR;
+//
+//                     writeInputMatrices(testName);
+//
+//                     int port1 = getRandomAvailablePort();
+//                     int port2 = getRandomAvailablePort();
+//                     t1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+//                     t2 = startLocalFedWorkerThread(port2);
+//
+//                     // Run actual dml script with federated matrix
+//                     fullDMLScriptName = HOME + testName + ".dml";
+//                     programArgs = new String[] {"-stats", "-nvargs", "X1=" 
+ TestUtils.federatedAddress(port1, input("X1")),
+//                             "X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+//                             "Y1=" + TestUtils.federatedAddress(port1, 
input("Y1")),
+//                             "Y2=" + TestUtils.federatedAddress(port2, 
input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+//                     rewriteRealProgramArgs(testName, port1, port2);
+//                     runTest(true, false, null, -1);
+//
+//                     // Run reference dml script with normal matrix
+//                     fullDMLScriptName = HOME + testName + "Reference.dml";
+//                     programArgs = new String[] {"-nvargs", "X1=" + 
input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+//                             "Y2=" + input("Y2"), "Z=" + expected("Z")};
+//                     rewriteReferenceProgramArgs(testName);
+//                     runTest(true, false, null, -1);
+//
+//                     // compare via files
+//                     compareResults(1e-9);
+//                     if 
(!heavyHittersContainsAllString(expectedHeavyHitters))
+//                             fail("The following expected heavy hitters are 
missing: "
+//                                     + 
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+//             } finally {
+//                     TestUtils.shutdownThreads(t1, t2);
+//                     rtplatform = platformOld;
+//                     DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+//             }
+//     }
+//
+//     private void rewriteRealProgramArgs(String testName, int port1, int 
port2){
+//             if ( testName.equals(TEST_NAME_4) || 
testName.equals(TEST_NAME_5) ){
+//                     programArgs = new String[] {"-stats","-nvargs", "X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
+//                             "X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+//                             "Y1=" + input("Y1"),
+//                             "Y2=" + input("Y2"), "r=" + rows, "c=" + cols, 
"Z=" + output("Z")};
+//             } else if ( testName.equals(TEST_NAME_8) ){
+//                     programArgs = new String[] {"-stats","-nvargs", "X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
+//                             "X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+//                             "Y1=" + TestUtils.federatedAddress(port1, 
input("Y1")),
+//                             "Y2=" + TestUtils.federatedAddress(port2, 
input("Y2")),
+//                             "W1=" + input("W1"),
+//                             "W2=" + input("W2"),
+//                             "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+//             }
+//     }
+//
+//     private void rewriteReferenceProgramArgs(String testName){
+//             if ( testName.equals(TEST_NAME_8) ){
+//                     programArgs = new String[] {"-nvargs", "X1=" + 
input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+//                             "Y2=" + input("Y2"), "W1=" + input("W1"), "W2=" 
+ input("W2"), "Z=" + expected("Z")};
+//             }
+//     }
+//
+//     /**
+//      * Override default configuration with custom test configuration to 
ensure
+//      * scratch space and local temporary directory locations are also 
updated.
+//      */
+//     @Override
+//     protected File getConfigTemplateFile() {
+//             // Instrumentation in this test's output log to show custom 
configuration file used for template.
+//             LOG.info("This test case overrides default configuration with " 
+ TEST_CONF_FILE.getPath());
+//             return TEST_CONF_FILE;
+//     }
+//}
+//


Reply via email to