This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 7a9ecff801 [SYSTEMDS-3790] Restore federated planner tests (all, heuristic)
7a9ecff801 is described below
commit 7a9ecff8014c9ba490910f7058082d2faba42ea3
Author: min-guk <[email protected]>
AuthorDate: Sun Nov 17 14:04:21 2024 +0100
[SYSTEMDS-3790] Restore federated planner tests (all, heuristic)
Closes #2139.
---
.github/workflows/javaTests.yml | 2 +-
.../org/apache/sysds/hops/fedplanner/FTypes.java | 9 +-
.../fedplanning/FederatedDynamicPlanningTest.java | 176 ++++++++++
.../fedplanning/FederatedKMeansPlanningTest.java | 156 +++++++++
.../fedplanning/FederatedL2SVMPlanningTest.java | 185 ++++++++++
.../fedplanning/FederatedMultiplyPlanningTest.java | 318 ++++++++++++++++++
.../test/functions/fedplanning/FTypeCombTest.java | 71 ++++
.../fedplanning/FederatedCostEstimatorTest.java | 373 +++++++++++++++++++++
.../fedplanning/FederatedDynamicPlanningTest.java | 188 +++++++++++
.../fedplanning/FederatedKMeansPlanningTest.java | 168 ++++++++++
.../fedplanning/FederatedL2SVMPlanningTest.java | 202 +++++++++++
.../fedplanning/FederatedMultiplyPlanningTest.java | 334 ++++++++++++++++++
12 files changed, 2174 insertions(+), 8 deletions(-)
diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml
index cd6e28670e..9f0258a0d1 100644
--- a/.github/workflows/javaTests.yml
+++ b/.github/workflows/javaTests.yml
@@ -62,7 +62,7 @@ jobs:
"**.functions.compress.**,**.functions.data.tensor.**,**.functions.codegenalg.parttwo.**,**.functions.codegen.**,**.functions.caching.**",
"**.functions.binary.matrix_full_cellwise.**,**.functions.binary.matrix_full_other.**",
"**.functions.federated.algorithms.**,**.functions.federated.io.**,**.functions.federated.paramserv.**",
- "**.functions.federated.transform.**",
+ "**.functions.federated.transform.**,**.functions.federated.fedplanning.**",
"**.functions.federated.primitives.part1.** -Dtest-threadCount=1 -Dtest-forkCount=1",
"**.functions.federated.primitives.part2.** -Dtest-threadCount=1 -Dtest-forkCount=1",
"**.functions.federated.primitives.part3.** -Dtest-threadCount=1 -Dtest-forkCount=1",
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
index a82a56e88b..de9c9cb670 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
@@ -44,13 +44,8 @@ public class FTypes
return this != NONE && this != RUNTIME;
}
public static boolean isCompiled(String planner) {
- try {
- return FederatedPlanner.valueOf(planner.toUpperCase()).isCompiled();
- }
- catch(Exception ex) {
- ex.printStackTrace();
- return false;
- }
+ // null -> false; unknown names fail fast with IllegalArgumentException from valueOf.
+ // Locale.ROOT: a default-locale toUpperCase() maps 'i' to U+0130 under Turkish locales.
+ return planner != null
+ && FederatedPlanner.valueOf(planner.toUpperCase(java.util.Locale.ROOT)).isCompiled();
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java
new file mode 100644
index 0000000000..bd098bf827
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.fedplanning;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.Arrays;
+
+import static org.junit.Assert.fail;
+
+/**
+ * Runs the FederatedDynamicFunctionPlanningTest script on federated inputs and its
+ * Reference script on local inputs, compares the outputs, and checks the heavy hitters
+ * for the fout and heuristic planner configurations.
+ *
+ * <p>Not thread-safe: the chosen configuration is kept in a mutable static field and
+ * every run starts two local federated worker threads.
+ */
+@net.jcip.annotations.NotThreadSafe
+public class FederatedDynamicPlanningTest extends AutomatedTestBase {
+	private static final Log LOG = LogFactory.getLog(FederatedDynamicPlanningTest.class.getName());
+
+	private final static String TEST_DIR = "functions/privacy/fedplanning/";
+	private final static String TEST_NAME = "FederatedDynamicFunctionPlanningTest";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedDynamicPlanningTest.class.getSimpleName() + "/";
+	// Config template returned by getConfigTemplateFile(); chosen per test via setTestConf().
+	private static File TEST_CONF_FILE;
+
+	private final static int blocksize = 1024;
+	public final int rows = 1000;
+	public final int cols = 1000;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+	}
+
+	@Test
+	@Ignore
+	public void runDynamicFullFunctionTest() {
+		// compared to `FederatedL2SVMPlanningTest` this does not create `fed_+*` or `fed_tsmm`, probably due to
+		// some rewrites not being applied. Might be a bug.
+		String[] expectedHeavyHitters = new String[] {"fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_max",
+			"fed_1-*", "fed_>"};
+		setTestConf("SystemDS-config-fout.xml");
+		loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+	}
+
+	@Test
+	@Ignore
+	public void runDynamicHeuristicFunctionTest() {
+		// compared to `FederatedL2SVMPlanningTest` this does not create `fed_+*` or `fed_tsmm`, probably due to
+		// some rewrites not being applied. Might be a bug.
+		String[] expectedHeavyHitters = new String[] {"fed_fedinit", "fed_ba+*"};
+		setTestConf("SystemDS-config-heuristic.xml");
+		loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+	}
+
+	/** Selects a config template file located in {@code SCRIPT_DIR + TEST_DIR}. */
+	private void setTestConf(String test_conf) {
+		TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+	}
+
+	/** Writes A (rows x 1, entries -1/1), row halves B1/B2 and column halves C1/C2. */
+	private void writeInputMatrices() {
+		writeBinaryVector("A", 42, rows);
+		writeStandardMatrix("B1", 65, rows / 2, cols);
+		writeStandardMatrix("B2", 75, rows / 2, cols);
+		writeStandardMatrix("C1", 13, rows, cols / 2);
+		writeStandardMatrix("C2", 17, rows, cols / 2);
+	}
+
+	/** Writes a numRows x 1 vector whose random entries are mapped to 1 if positive, else -1. */
+	private void writeBinaryVector(String matrixName, long seed, int numRows){
+		double[][] matrix = getRandomMatrix(numRows, 1, -1, 1, 1, seed);
+		for(int i = 0; i < numRows; i++)
+			matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1;
+		MatrixCharacteristics mc = new MatrixCharacteristics(numRows, 1, blocksize, numRows);
+		writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+	}
+
+	/** Writes a random numRows x numCols matrix (getRandomMatrix bounds 0..1, sparsity argument 1). */
+	private void writeStandardMatrix(String matrixName, long seed, int numRows, int numCols) {
+		double[][] matrix = getRandomMatrix(numRows, numCols, 0, 1, 1, seed);
+		writeStandardMatrix(matrixName, numRows, numCols, matrix);
+	}
+
+	/** Writes the given matrix with metadata declaring numRows * numCols non-zeros. */
+	private void writeStandardMatrix(String matrixName, int numRows, int numCols, double[][] matrix) {
+		MatrixCharacteristics mc = new MatrixCharacteristics(numRows, numCols, blocksize, (long) numRows * numCols);
+		writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+	}
+
+	/**
+	 * Starts two local federated workers, runs {@code testName}.dml with federated B/C inputs
+	 * and {@code testName}Reference.dml with local inputs, compares Z (tolerance 1e-9), and
+	 * fails if any expected heavy hitter is missing. Restores rtplatform and
+	 * USE_LOCAL_SPARK_CONFIG and shuts the workers down afterwards.
+	 */
+	private void loadAndRunTest(String[] expectedHeavyHitters, String testName) {
+
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		Types.ExecMode platformOld = rtplatform;
+		rtplatform = Types.ExecMode.SINGLE_NODE;
+
+		Thread t1 = null, t2 = null;
+
+		try {
+			getAndLoadTestConfiguration(testName);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			writeInputMatrices();
+
+			int port1 = getRandomAvailablePort();
+			int port2 = getRandomAvailablePort();
+			t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+			t2 = startLocalFedWorkerThread(port2);
+
+			// Run actual dml script with federated matrix
+			fullDMLScriptName = HOME + testName + ".dml";
+			programArgs = new String[] {"-stats", "-nvargs",
+				"r=" + rows, "c=" + cols,
+				"A=" + input("A"),
+				"B1=" + TestUtils.federatedAddress(port1, input("B1")),
+				"B2=" + TestUtils.federatedAddress(port2, input("B2")),
+				"C1=" + TestUtils.federatedAddress(port1, input("C1")),
+				"C2=" + TestUtils.federatedAddress(port2, input("C2")),
+				"lB1=" + input("B1"),
+				"lB2=" + input("B2"),
+				"Z=" + output("Z")};
+			runTest(true, false, null, -1);
+
+			// Run reference dml script with normal matrix
+			fullDMLScriptName = HOME + testName + "Reference.dml";
+			programArgs = new String[] {"-nvargs",
+				"r=" + rows, "c=" + cols,
+				"A=" + input("A"),
+				"B1=" + input("B1"),
+				"B2=" + input("B2"),
+				"C1=" + input("C1"),
+				"C2=" + input("C2"),
+				"Z=" + expected("Z")};
+			runTest(true, false, null, -1);
+
+			// compare via files
+			compareResults(1e-9);
+			if(!heavyHittersContainsAllString(expectedHeavyHitters))
+				fail("The following expected heavy hitters are missing: "
+					+ Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+		}
+		finally {
+			TestUtils.shutdownThreads(t1, t2);
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+
+	/**
+	 * Override default configuration with custom test configuration to ensure scratch space and local temporary
+	 * directory locations are also updated.
+	 */
+	@Override
+	protected File getConfigTemplateFile() {
+		// Instrumentation in this test's output log to show custom configuration file used for template.
+		LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
+		return TEST_CONF_FILE;
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java
new file mode 100644
index 0000000000..326516d423
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.fedplanning;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.Arrays;
+
+import static org.junit.Assert.fail;
+
+/**
+ * Runs the FederatedKMeansPlanningTest script on two row-partitioned federated inputs
+ * (X1, X2) and its Reference script on the same data locally, then compares the outputs.
+ * The script is run under the fout, the heuristic and the default test configuration.
+ *
+ * <p>Not thread-safe, like the other planning tests in this package: the chosen
+ * configuration is kept in a mutable static field and every run starts two local
+ * federated worker threads.
+ */
+@net.jcip.annotations.NotThreadSafe
+public class FederatedKMeansPlanningTest extends AutomatedTestBase {
+	private static final Log LOG = LogFactory.getLog(FederatedKMeansPlanningTest.class.getName());
+
+	private final static String TEST_DIR = "functions/privacy/fedplanning/";
+	private final static String TEST_NAME = "FederatedKMeansPlanningTest";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedKMeansPlanningTest.class.getSimpleName() + "/";
+	// Config template returned by getConfigTemplateFile(); chosen per test.
+	private static File TEST_CONF_FILE;
+
+	private final static int blocksize = 1024;
+	public final int rows = 1000;
+	public final int cols = 100;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+	}
+
+	// NOTE(review): all three tests pass no expected heavy hitters, so they only check
+	// result equality; add federated opcodes once the expected plans are known.
+	@Test
+	public void runKMeansFOUTTest(){
+		String[] expectedHeavyHitters = new String[]{};
+		setTestConf("SystemDS-config-fout.xml");
+		loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+	}
+
+	@Test
+	public void runKMeansHeuristicTest(){
+		String[] expectedHeavyHitters = new String[]{};
+		setTestConf("SystemDS-config-heuristic.xml");
+		loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+	}
+
+	@Test
+	public void runRuntimeTest(){
+		String[] expectedHeavyHitters = new String[]{};
+		TEST_CONF_FILE = new File("src/test/config/SystemDS-config.xml");
+		loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+	}
+
+	/** Selects a config template file located in {@code SCRIPT_DIR + TEST_DIR}. */
+	private void setTestConf(String test_conf){
+		TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+	}
+
+	/**
+	 * Override default configuration with custom test configuration to ensure
+	 * scratch space and local temporary directory locations are also updated.
+	 */
+	@Override
+	protected File getConfigTemplateFile() {
+		// Instrumentation in this test's output log to show custom configuration file used for template.
+		LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
+		return TEST_CONF_FILE;
+	}
+
+	/** Writes the two row halves X1 and X2, each (rows/2) x cols. */
+	private void writeInputMatrices(){
+		writeStandardRowFedMatrix("X1", 65);
+		writeStandardRowFedMatrix("X2", 75);
+	}
+
+	/** Writes a random numRows x cols matrix (getRandomMatrix bounds 0..1, sparsity argument 1). */
+	private void writeStandardMatrix(String matrixName, long seed, int numRows){
+		double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1, seed);
+		writeStandardMatrix(matrixName, numRows, matrix);
+	}
+
+	/** Writes the given matrix with metadata declaring numRows * cols non-zeros. */
+	private void writeStandardMatrix(String matrixName, int numRows, double[][] matrix){
+		MatrixCharacteristics mc = new MatrixCharacteristics(numRows, cols, blocksize, (long) numRows * cols);
+		writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+	}
+
+	/** Writes one row partition of (rows/2) x cols. */
+	private void writeStandardRowFedMatrix(String matrixName, long seed){
+		int halfRows = rows/2;
+		writeStandardMatrix(matrixName, seed, halfRows);
+	}
+
+	/**
+	 * Starts two local federated workers, runs {@code testName}.dml with federated X1/X2 and
+	 * {@code testName}Reference.dml with local inputs, compares Z (tolerance 1e-9), and fails
+	 * if any expected heavy hitter is missing. Restores rtplatform and USE_LOCAL_SPARK_CONFIG
+	 * and shuts the workers down afterwards.
+	 */
+	private void loadAndRunTest(String[] expectedHeavyHitters, String testName){
+
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		Types.ExecMode platformOld = rtplatform;
+		rtplatform = Types.ExecMode.SINGLE_NODE;
+
+		Thread t1 = null, t2 = null;
+
+		try {
+			getAndLoadTestConfiguration(testName);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			writeInputMatrices();
+
+			int port1 = getRandomAvailablePort();
+			int port2 = getRandomAvailablePort();
+			t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+			t2 = startLocalFedWorkerThread(port2);
+
+			// NOTE(review): "Y" is passed to both scripts but writeInputMatrices() never
+			// writes it; confirm the KMeans scripts do not read $Y.
+			// Run actual dml script with federated matrix
+			fullDMLScriptName = HOME + testName + ".dml";
+			programArgs = new String[] { "-stats", "-nvargs",
+				"X1=" + TestUtils.federatedAddress(port1, input("X1")),
+				"X2=" + TestUtils.federatedAddress(port2, input("X2")),
+				"Y=" + input("Y"), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+			runTest(true, false, null, -1);
+
+			// Run reference dml script with normal matrix
+			fullDMLScriptName = HOME + testName + "Reference.dml";
+			programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"),
+				"Y=" + input("Y"), "Z=" + expected("Z")};
+			runTest(true, false, null, -1);
+
+			// compare via files
+			compareResults(1e-9);
+			if (!heavyHittersContainsAllString(expectedHeavyHitters))
+				fail("The following expected heavy hitters are missing: "
+					+ Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+		}
+		finally {
+			TestUtils.shutdownThreads(t1, t2);
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
new file mode 100644
index 0000000000..3e8f8719a6
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.fedplanning;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.Arrays;
+
+import static org.junit.Assert.fail;
+
+/**
+ * Runs the L2SVM planning scripts on two row-partitioned federated inputs (X1, X2) with
+ * a local label vector Y, compares against the Reference scripts on local data, and checks
+ * the federated heavy hitters for the fout and heuristic planner configurations.
+ *
+ * <p>Not thread-safe: the chosen configuration is kept in a mutable static field and
+ * every run starts two local federated worker threads.
+ */
+@net.jcip.annotations.NotThreadSafe
+public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
+	private static final Log LOG = LogFactory.getLog(FederatedL2SVMPlanningTest.class.getName());
+
+	private final static String TEST_DIR = "functions/privacy/fedplanning/";
+	private final static String TEST_NAME = "FederatedL2SVMPlanningTest";
+	private final static String TEST_NAME_2 = "FederatedL2SVMFunctionPlanningTest";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedL2SVMPlanningTest.class.getSimpleName() + "/";
+	// Config template returned by getConfigTemplateFile(); chosen per test via setTestConf().
+	private static File TEST_CONF_FILE;
+
+	private final static int blocksize = 1024;
+	public final int rows = 1000;
+	public final int cols = 100;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+		addTestConfiguration(TEST_NAME_2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"}));
+	}
+
+	@Test
+	public void runL2SVMFOUTTest(){
+		String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*",
+			"fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+		setTestConf("SystemDS-config-fout.xml");
+		loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+	}
+
+	@Test
+	public void runL2SVMHeuristicTest(){
+		String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*"};
+		setTestConf("SystemDS-config-heuristic.xml");
+		loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+	}
+
+	@Test
+	@Ignore //TODO
+	public void runL2SVMFunctionFOUTTest(){
+		String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*",
+			"fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+		setTestConf("SystemDS-config-fout.xml");
+		loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+	}
+
+	@Test
+	@Ignore //TODO
+	public void runL2SVMFunctionHeuristicTest(){
+		String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*"};
+		setTestConf("SystemDS-config-heuristic.xml");
+		loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+	}
+
+	/** Selects a config template file located in {@code SCRIPT_DIR + TEST_DIR}. */
+	private void setTestConf(String test_conf){
+		TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+	}
+
+	/** Writes the row halves X1/X2 ((rows/2) x cols each) and the rows x 1 label vector Y. */
+	private void writeInputMatrices(){
+		writeStandardRowFedMatrix("X1", 65);
+		writeStandardRowFedMatrix("X2", 75);
+		writeBinaryVector("Y", 44);
+	}
+
+	/** Writes a rows x 1 vector whose random entries are mapped to 1 if positive, else -1. */
+	private void writeBinaryVector(String matrixName, long seed){
+		double[][] matrix = getRandomMatrix(rows, 1, -1, 1, 1, seed);
+		for(int i = 0; i < rows; i++)
+			matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1;
+		MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1, blocksize, rows);
+		writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+	}
+
+	@SuppressWarnings("unused")
+	private void writeStandardMatrix(String matrixName, long seed){
+		writeStandardMatrix(matrixName, seed, rows);
+	}
+
+	/** Writes a random numRows x cols matrix (getRandomMatrix bounds 0..1, sparsity argument 1). */
+	private void writeStandardMatrix(String matrixName, long seed, int numRows){
+		double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1, seed);
+		writeStandardMatrix(matrixName, numRows, matrix);
+	}
+
+	/** Writes the given matrix with metadata declaring numRows * cols non-zeros. */
+	private void writeStandardMatrix(String matrixName, int numRows, double[][] matrix){
+		MatrixCharacteristics mc = new MatrixCharacteristics(numRows, cols, blocksize, (long) numRows * cols);
+		writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+	}
+
+	/** Writes one row partition of (rows/2) x cols. */
+	private void writeStandardRowFedMatrix(String matrixName, long seed){
+		int halfRows = rows/2;
+		writeStandardMatrix(matrixName, seed, halfRows);
+	}
+
+	/**
+	 * Starts two local federated workers, runs {@code testName}.dml with federated X1/X2 and
+	 * {@code testName}Reference.dml with local inputs, compares Z (tolerance 1e-9), and fails
+	 * if any expected heavy hitter is missing. Restores rtplatform and USE_LOCAL_SPARK_CONFIG
+	 * and shuts the workers down afterwards.
+	 */
+	private void loadAndRunTest(String[] expectedHeavyHitters, String testName){
+
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		Types.ExecMode platformOld = rtplatform;
+		rtplatform = Types.ExecMode.SINGLE_NODE;
+
+		Thread t1 = null, t2 = null;
+
+		try {
+			getAndLoadTestConfiguration(testName);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			writeInputMatrices();
+
+			int port1 = getRandomAvailablePort();
+			int port2 = getRandomAvailablePort();
+			t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+			t2 = startLocalFedWorkerThread(port2);
+
+			// Run actual dml script with federated matrix
+			fullDMLScriptName = HOME + testName + ".dml";
+			programArgs = new String[] { "-stats", "-nvargs",
+				"X1=" + TestUtils.federatedAddress(port1, input("X1")),
+				"X2=" + TestUtils.federatedAddress(port2, input("X2")),
+				"Y=" + input("Y"), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+			runTest(true, false, null, -1);
+
+			// Run reference dml script with normal matrix
+			fullDMLScriptName = HOME + testName + "Reference.dml";
+			programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"),
+				"Y=" + input("Y"), "Z=" + expected("Z")};
+			runTest(true, false, null, -1);
+
+			// compare via files
+			compareResults(1e-9);
+			if (!heavyHittersContainsAllString(expectedHeavyHitters))
+				fail("The following expected heavy hitters are missing: "
+					+ Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+		}
+		finally {
+			TestUtils.shutdownThreads(t1, t2);
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+
+	/**
+	 * Override default configuration with custom test configuration to ensure
+	 * scratch space and local temporary directory locations are also updated.
+	 */
+	@Override
+	protected File getConfigTemplateFile() {
+		// Instrumentation in this test's output log to show custom configuration file used for template.
+		LOG.info("This test case overrides default configuration with " + TEST_CONF_FILE.getPath());
+		return TEST_CONF_FILE;
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java
new file mode 100644
index 0000000000..5b54f14d05
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.fedplanning;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.io.File;
+import java.util.Arrays;
+import java.util.Collection;
+
+import static org.junit.Assert.fail;
+
+@RunWith(value = Parameterized.class)
[email protected]
+public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
+ private static final Log LOG =
LogFactory.getLog(FederatedMultiplyPlanningTest.class.getName());
+
+ private final static String TEST_DIR = "functions/privacy/fedplanning/";
+ private final static String TEST_NAME = "FederatedMultiplyPlanningTest";
+ private final static String TEST_NAME_2 =
"FederatedMultiplyPlanningTest2";
+ private final static String TEST_NAME_3 =
"FederatedMultiplyPlanningTest3";
+ private final static String TEST_NAME_4 =
"FederatedMultiplyPlanningTest4";
+ private final static String TEST_NAME_5 =
"FederatedMultiplyPlanningTest5";
+ private final static String TEST_NAME_6 =
"FederatedMultiplyPlanningTest6";
+ private final static String TEST_NAME_7 =
"FederatedMultiplyPlanningTest7";
+ private final static String TEST_NAME_8 =
"FederatedMultiplyPlanningTest8";
+ private final static String TEST_NAME_9 =
"FederatedMultiplyPlanningTest9";
+ private final static String TEST_NAME_10 =
"FederatedMultiplyPlanningTest10";
+ private final static String TEST_NAME_11 =
"FederatedMultiplyPlanningTest11";
+ private final static String TEST_NAME_12 =
"FederatedMultiplyPlanningTest12";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
+ private static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR,
"SystemDS-config-heuristic.xml");
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+
+	/** Registers every planning script; scripts 3, 8 and 9 produce a scalar output. */
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		String[] testNames = {TEST_NAME, TEST_NAME_2, TEST_NAME_3, TEST_NAME_4, TEST_NAME_5, TEST_NAME_6,
+			TEST_NAME_7, TEST_NAME_8, TEST_NAME_9, TEST_NAME_10, TEST_NAME_11, TEST_NAME_12};
+		for(String name : testNames) {
+			boolean scalarOutput = name.equals(TEST_NAME_3) || name.equals(TEST_NAME_8) || name.equals(TEST_NAME_9);
+			String outputName = scalarOutput ? "Z.scalar" : "Z";
+			addTestConfiguration(name, new TestConfiguration(TEST_CLASS_DIR, name, new String[] {outputName}));
+		}
+	}
+
+	/** Parameter sets as {rows, cols}; rows have to be even and > 1. */
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		Object[][] params = {{100, 10}};
+		return Arrays.asList(params);
+	}
+
+	@Test
+	public void federatedMultiplyCP() {
+		final String[] fedOps = {"fed_*", "fed_fedinit", "fed_r'", "fed_ba+*"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME, fedOps);
+	}
+
+	@Test
+	public void federatedRowSum(){
+		final String[] fedOps = {"fed_*", "fed_r'", "fed_fedinit", "fed_ba+*", "fed_uark+"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_2, fedOps);
+	}
+
+	@Test
+	public void federatedTernarySequence(){
+		final String[] fedOps = {"fed_+*", "fed_1-*", "fed_fedinit", "fed_uak+"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_3, fedOps);
+	}
+
+	@Test
+	public void federatedAggregateBinarySequence(){
+		// run with square inputs (cols = rows)
+		cols = rows;
+		final String[] fedOps = {"fed_ba+*", "fed_*", "fed_fedinit"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_4, fedOps);
+	}
+
+	@Test
+	public void federatedAggregateBinaryColFedSequence(){
+		// run with square inputs (cols = rows)
+		cols = rows;
+		//TODO: once alignment checks are added to getFederatedOut in AFederatedPlanner, also expect
+		// "fed_*" (i.e. {"fed_ba+*","fed_*","fed_fedinit"}); until then fed_* is not generated.
+		final String[] fedOps = {"fed_ba+*", "fed_fedinit"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_5, fedOps);
+	}
+
+	@Test
+	public void federatedAggregateBinarySequence2(){
+		final String[] fedOps = {"fed_ba+*", "fed_fedinit"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_6, fedOps);
+	}
+
+	@Test
+	public void federatedMultiplyDoubleHop() {
+		//TODO: check whether "fed_r'" should be expected as well
+		final String[] fedOps = {"fed_*", "fed_fedinit", "fed_ba+*"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_7, fedOps);
+	}
+
+	@Test
+	public void federatedMultiplyDoubleHop2() {
+		final String[] fedOps = {"fed_fedinit", "fed_ba+*"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_8, fedOps);
+	}
+
+	@Test
+	public void federatedMultiplyPlanningTest9(){
+		//TODO: "fed_tak+*" is not expected yet
+		final String[] fedOps = {"fed_+*", "fed_1-*", "fed_fedinit", "fed_max"};
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_9, fedOps);
+	}
+
+	/**
+	 * Runs script 10 with the fout config instead of the class default (heuristic).
+	 * TEST_CONF_FILE is static, so the override is restored afterwards; otherwise it
+	 * would leak into every test JUnit happens to run after this one.
+	 */
+	@Test
+	public void federatedMultiplyPlanningTest10(){
+		String[] expectedHeavyHitters = new String[]{"fed_fedinit", "fed_^2"};
+		File previousConf = TEST_CONF_FILE;
+		TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-fout.xml");
+		try {
+			federatedTwoMatricesSingleNodeTest(TEST_NAME_10, expectedHeavyHitters);
+		}
+		finally {
+			TEST_CONF_FILE = previousConf;
+		}
+	}
+
+	@Test
+	public void federatedMultiplyPlanningTest11(){
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_11, new String[] {"fed_fedinit"});
+	}
+
+	@Test
+	public void federatedMultiplyPlanningTest12(){
+		// this script runs on a fixed 30 x 30 input instead of the parameterized size
+		rows = 30;
+		cols = 30;
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_12, new String[] {"fed_fedinit"});
+	}
+
+	/** Writes one row partition: a random (rows/2) x cols matrix with its metadata. */
+	private void writeStandardMatrix(String matrixName, long seed){
+		final int partRows = rows / 2;
+		double[][] values = getRandomMatrix(partRows, cols, 0, 1, 1, seed);
+		final long nnz = (long) partRows * cols;
+		writeInputMatrixWithMTD(matrixName, values, false,
+			new MatrixCharacteristics(partRows, cols, blocksize, nnz));
+	}
+
+	/** Writes one column partition: a random rows x (cols/2) matrix with its metadata. */
+	private void writeColStandardMatrix(String matrixName, long seed){
+		final int partCols = cols / 2;
+		double[][] values = getRandomMatrix(rows, partCols, 0, 1, 1, seed);
+		final long nnz = (long) partCols * rows;
+		writeInputMatrixWithMTD(matrixName, values, false,
+			new MatrixCharacteristics(rows, partCols, blocksize, nnz));
+	}
+
+	/**
+	 * Writes a random column vector with {@code cols/2} rows (used as one half of a
+	 * row-partitioned federated vector) plus its metadata file.
+	 */
+	private void writeRowFederatedVector(String matrixName, long seed){
+		int halfCols = cols / 2;
+		double[][] matrix = getRandomMatrix(halfCols, 1, 0, 1, 1, seed);
+		// The vector is halfCols x 1, so nnz is at most halfCols. The previous value
+		// (halfCols * rows) was copied from writeColStandardMatrix and exceeded the cell count.
+		MatrixCharacteristics mc = new MatrixCharacteristics(halfCols, 1, blocksize, (long) halfCols);
+		writeInputMatrixWithMTD(matrixName, matrix, false, mc);
+	}
+
+	/**
+	 * Writes the local input partitions needed by the given test script.
+	 * Tests not listed explicitly use row halves (see writeStandardMatrix) for X1/X2 and Y1/Y2.
+	 */
+	private void writeInputMatrices(String testName){
+		if ( testName.equals(TEST_NAME_5) ){
+			writeColStandardMatrix("X1", 42);
+			writeColStandardMatrix("X2", 1340);
+			writeColStandardMatrix("Y1", 44);
+			writeColStandardMatrix("Y2", 21);
+		}
+		else if ( testName.equals(TEST_NAME_6) ){
+			writeColStandardMatrix("X1", 42);
+			writeColStandardMatrix("X2", 1340);
+			writeRowFederatedVector("Y1", 44);
+			writeRowFederatedVector("Y2", 21);
+		}
+		else if ( testName.equals(TEST_NAME_8) ){
+			writeColStandardMatrix("X1", 42);
+			writeColStandardMatrix("X2", 1340);
+			writeColStandardMatrix("Y1", 44);
+			writeColStandardMatrix("Y2", 21);
+			writeColStandardMatrix("W1", 76);
+			writeColStandardMatrix("W2", 11);
+		}
+		else if ( testName.equals(TEST_NAME_10) || testName.equals(TEST_NAME_12) ){
+			writeStandardMatrix("X1", 42);
+			writeStandardMatrix("X2", 1340);
+		}
+		else {
+			// All remaining tests share the same row-half X and Y inputs. This includes
+			// TEST_NAME_4, whose Y1/Y2 are passed to the script as local matrices.
+			// The former TEST_NAME_4 special case wrote exactly the same data.
+			writeStandardMatrix("X1", 42);
+			writeStandardMatrix("X2", 1340);
+			writeStandardMatrix("Y1", 44);
+			writeStandardMatrix("Y2", 21);
+		}
+	}
+
+ private void federatedTwoMatricesSingleNodeTest(String testName,
String[] expectedHeavyHitters){
+ federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName,
expectedHeavyHitters);
+ }
+
+ private void federatedTwoMatricesTest(Types.ExecMode execMode, String
testName, String[] expectedHeavyHitters) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ Thread t1 = null, t2 = null;
+
+ try{
+ getAndLoadTestConfiguration(testName);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ writeInputMatrices(testName);
+
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
+ t2 = startLocalFedWorkerThread(port2);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + testName + ".dml";
+ programArgs = new String[] {"-stats", "-nvargs", "X1="
+ TestUtils.federatedAddress(port1, input("X1")),
+ "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "Y1=" + TestUtils.federatedAddress(port1,
input("Y1")),
+ "Y2=" + TestUtils.federatedAddress(port2,
input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+ rewriteRealProgramArgs(testName, port1, port2);
+ runTest(true, false, null, -1);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + testName + "Reference.dml";
+ programArgs = new String[] {"-nvargs", "X1=" +
input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+ "Y2=" + input("Y2"), "Z=" + expected("Z")};
+ rewriteReferenceProgramArgs(testName);
+ runTest(true, false, null, -1);
+
+ // compare via files
+ compareResults(1e-9);
+ if
(!heavyHittersContainsAllString(expectedHeavyHitters))
+ fail("The following expected heavy hitters are
missing: "
+ +
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+ } finally {
+ TestUtils.shutdownThreads(t1, t2);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+
+ private void rewriteRealProgramArgs(String testName, int port1, int
port2){
+ if ( testName.equals(TEST_NAME_4) ||
testName.equals(TEST_NAME_5) ){
+ programArgs = new String[] {"-stats","-nvargs", "X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+ "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "Y1=" + input("Y1"),
+ "Y2=" + input("Y2"), "r=" + rows, "c=" + cols,
"Z=" + output("Z")};
+ } else if ( testName.equals(TEST_NAME_8) ){
+ programArgs = new String[] {"-stats","-nvargs", "X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+ "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "Y1=" + TestUtils.federatedAddress(port1,
input("Y1")),
+ "Y2=" + TestUtils.federatedAddress(port2,
input("Y2")),
+ "W1=" + input("W1"),
+ "W2=" + input("W2"),
+ "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+ }
+ }
+
+	/** Adds the local W1/W2 inputs to the reference-run arguments for TEST_NAME_8; no-op otherwise. */
+	private void rewriteReferenceProgramArgs(String testName){
+		if ( !testName.equals(TEST_NAME_8) )
+			return;
+		programArgs = new String[] {"-nvargs",
+			"X1=" + input("X1"), "X2=" + input("X2"),
+			"Y1=" + input("Y1"), "Y2=" + input("Y2"),
+			"W1=" + input("W1"), "W2=" + input("W2"),
+			"Z=" + expected("Z")};
+	}
+
+ /**
+ * Override default configuration with custom test configuration to
ensure
+ * scratch space and local temporary directory locations are also
updated.
+ */
+ @Override
+ protected File getConfigTemplateFile() {
+ // Instrumentation in this test's output log to show custom
configuration file used for template.
+ LOG.info("This test case overrides default configuration with "
+ TEST_CONF_FILE.getPath());
+ return TEST_CONF_FILE;
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/fedplanning/FTypeCombTest.java
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FTypeCombTest.java
new file mode 100644
index 0000000000..e36d517d98
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FTypeCombTest.java
@@ -0,0 +1,71 @@
+package org.apache.sysds.test.functions.fedplanning;
+///*
+// * Licensed to the Apache Software Foundation (ASF) under one
+// * or more contributor license agreements. See the NOTICE file
+// * distributed with this work for additional information
+// * regarding copyright ownership. The ASF licenses this file
+// * to you under the Apache License, Version 2.0 (the
+// * "License"); you may not use this file except in compliance
+// * with the License. You may obtain a copy of the License at
+// *
+// * http://www.apache.org/licenses/LICENSE-2.0
+// *
+// * Unless required by applicable law or agreed to in writing,
+// * software distributed under the License is distributed on an
+// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// * KIND, either express or implied. See the License for the
+// * specific language governing permissions and limitations
+// * under the License.
+// */
+//
+//package org.apache.sysds.test.functions.privacy.fedplanning;
+//
+//import org.apache.sysds.hops.fedplanner.FTypes.FType;
+//import org.apache.sysds.hops.fedplanner.FederatedPlannerCostbased;
+//import org.apache.sysds.test.AutomatedTestBase;
+//import org.junit.Assert;
+//import org.junit.Test;
+//
+//import java.util.ArrayList;
+//import java.util.List;
+//
+//public class FTypeCombTest extends AutomatedTestBase {
+//
+// @Override public void setUp() {}
+//
+// @Test
+// public void ftypeCombTest(){
+// List<FType> secondInput = new ArrayList<>();
+// secondInput.add(null);
+// List<List<FType>> inputFTypes = List.of(
+// List.of(FType.ROW,FType.COL),
+// secondInput,
+// List.of(FType.BROADCAST,FType.FULL)
+// );
+//
+// FederatedPlannerCostbased planner = new
FederatedPlannerCostbased();
+// List<List<FType>> actualCombinations =
planner.getAllCombinations(inputFTypes);
+//
+// List<FType> expected1 = new ArrayList<>();
+// expected1.add(FType.ROW);
+// expected1.add(null);
+// expected1.add(FType.BROADCAST);
+// List<FType> expected2 = new ArrayList<>();
+// expected2.add(FType.ROW);
+// expected2.add(null);
+// expected2.add(FType.FULL);
+// List<FType> expected3 = new ArrayList<>();
+// expected3.add(FType.COL);
+// expected3.add(null);
+// expected3.add(FType.BROADCAST);
+// List<FType> expected4 = new ArrayList<>();
+// expected4.add(FType.COL);
+// expected4.add(null);
+// expected4.add(FType.FULL);
+// List<List<FType>> expectedCombinations =
List.of(expected1,expected2, expected3, expected4);
+//
+// Assert.assertEquals(expectedCombinations.size(),
actualCombinations.size());
+// for (List<FType> expectedComb : expectedCombinations)
+//
Assert.assertTrue(actualCombinations.contains(expectedComb));
+// }
+//}
diff --git
a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedCostEstimatorTest.java
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedCostEstimatorTest.java
new file mode 100644
index 0000000000..073c8f1d9d
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedCostEstimatorTest.java
@@ -0,0 +1,373 @@
+package org.apache.sysds.test.functions.fedplanning;
+///*
+// * Licensed to the Apache Software Foundation (ASF) under one
+// * or more contributor license agreements. See the NOTICE file
+// * distributed with this work for additional information
+// * regarding copyright ownership. The ASF licenses this file
+// * to you under the Apache License, Version 2.0 (the
+// * "License"); you may not use this file except in compliance
+// * with the License. You may obtain a copy of the License at
+// *
+// * http://www.apache.org/licenses/LICENSE-2.0
+// *
+// * Unless required by applicable law or agreed to in writing,
+// * software distributed under the License is distributed on an
+// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// * KIND, either express or implied. See the License for the
+// * specific language governing permissions and limitations
+// * under the License.
+// */
+//
+//package org.apache.sysds.test.functions.privacy.fedplanning;
+//
+//import net.jcip.annotations.NotThreadSafe;
+//import org.apache.sysds.api.DMLScript;
+//import org.apache.sysds.common.Types;
+//import org.apache.sysds.conf.ConfigurationManager;
+//import org.apache.sysds.conf.DMLConfig;
+//import org.apache.sysds.hops.AggBinaryOp;
+//import org.apache.sysds.hops.BinaryOp;
+//import org.apache.sysds.hops.DataOp;
+//import org.apache.sysds.hops.Hop;
+//import org.apache.sysds.hops.LiteralOp;
+//import org.apache.sysds.hops.NaryOp;
+//import org.apache.sysds.hops.ReorgOp;
+//import org.apache.sysds.hops.cost.FederatedCost;
+//import org.apache.sysds.hops.cost.FederatedCostEstimator;
+//import org.apache.sysds.hops.fedplanner.FederatedPlannerCostbased;
+//import org.apache.sysds.hops.ipa.FunctionCallGraph;
+//import org.apache.sysds.parser.DMLProgram;
+//import org.apache.sysds.parser.DMLTranslator;
+//import org.apache.sysds.parser.LanguageException;
+//import org.apache.sysds.parser.ParserFactory;
+//import org.apache.sysds.parser.ParserWrapper;
+//import org.apache.sysds.parser.StatementBlock;
+//import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+//import org.apache.sysds.test.AutomatedTestBase;
+//import org.apache.sysds.test.TestConfiguration;
+//import org.junit.After;
+//import org.junit.Assert;
+//import org.junit.Before;
+//import org.junit.BeforeClass;
+//import org.junit.Test;
+//
+//import java.io.FileNotFoundException;
+//import java.io.IOException;
+//import java.util.HashMap;
+//import java.util.HashSet;
+//import java.util.Set;
+//
+//import static org.apache.sysds.common.Types.OpOp2.MULT;
+//
+//@NotThreadSafe
+//public class FederatedCostEstimatorTest extends AutomatedTestBase {
+//
+// private static final String TEST_DIR = "functions/privacy/fedplanning/";
+// private static final String HOME = SCRIPT_DIR + TEST_DIR;
+// private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedCostEstimatorTest.class.getSimpleName() + "/";
+// FederatedCostEstimator fedCostEstimator = new FederatedCostEstimator();
+//
+// private static double COMPUTE_FLOPS;
+// private static double READ_PS;
+// private static double NETWORK_PS;
+//
+// @Override
+// public void setUp() {}
+//
+// @BeforeClass
+// public static void storeConstants(){
+// COMPUTE_FLOPS =
FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS;
+// READ_PS = FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS;
+// NETWORK_PS =
FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS;
+// }
+//
+// @Before
+// public void setConstants(){
+// FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
+// FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+// FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;
+// }
+//
+// @After
+// public void resetConstants(){
+// FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS =
COMPUTE_FLOPS;
+// FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = READ_PS;
+// FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS =
NETWORK_PS;
+// }
+//
+// @Test
+// public void simpleBinary() {
+//
+// /*
+// * HOP Occurrences ComputeCost
ReadCost ComputeCostFinal ReadCostFinal
+// *
------------------------------------------------------------------------------------------
+// * LiteralOp 16 1
0 0.0625 0
+// * DataGenOp 2 100
64 6.25 6.4
+// * BinaryOp 1 100
1600 6.25 160
+// * TOSTRING 1 1
800 0.0625
80
+// * UnaryOp 1 1
8 0.0625
0.8
+// */
+// double computeCost = (16+2*100+100+1+1) /
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+// double readCost = (2*64+1600+800+8) /
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+//
+// double expectedCost = computeCost + readCost;
+// runTest("BinaryCostEstimatorTest.dml", false, expectedCost);
+// }
+//
+// @Test
+// public void simpleBinaryHopRelTest() {
+// runHopRelTest("BinaryCostEstimatorTest.dml", false);
+// }
+//
+// @Test
+// public void ifElseTest(){
+// double computeCost = (16+2*100+100+1+1) /
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+// double readCost = (2*64+1600+800+8) /
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+// double expectedCost = ((computeCost + readCost + 0.8 + 0.0625 +
0.0625) / 2) + 0.0625 + 0.8 + 0.0625;
+// runTest("IfElseCostEstimatorTest.dml", false, expectedCost);
+// }
+//
+// @Test
+// public void ifElseHopRelTest(){
+// runHopRelTest("IfElseCostEstimatorTest.dml", false);
+// }
+//
+// @Test
+// public void whileTest(){
+// double computeCost = (16+2*100+100+1+1) /
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+// double readCost = (2*64+1600+800+8) /
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+// double expectedCost = (computeCost + readCost + 0.0625 + 0.0625
+ 0.8) * StatementBlock.DEFAULT_LOOP_REPETITIONS;
+// runTest("WhileCostEstimatorTest.dml", false, expectedCost);
+// }
+//
+// @Test
+// public void whileHopRelTest(){
+// runHopRelTest("WhileCostEstimatorTest.dml", false);
+// }
+//
+// @Test
+// public void forLoopTest(){
+// double computeCost = (16+2*100+100+1+1) /
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+// double readCost = (2*64+1600+800+8) /
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+// double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 +
0.0625 + 0.0625 + 0.8 + 0.0625;
+// double expectedCost = (computeCost + readCost + predicateCost)
* 5;
+// runTest("ForLoopCostEstimatorTest.dml", false, expectedCost);
+// }
+//
+// @Test
+// public void forLoopHopRelTest(){
+// runHopRelTest("ForLoopCostEstimatorTest.dml", false);
+// }
+//
+// @Test
+// public void parForLoopTest(){
+// double computeCost = (16+2*100+100+1+1) /
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+// double readCost = (2*64+1600+800+8) /
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+// double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 +
0.0625 + 0.0625 + 0.8 + 0.0625;
+// double expectedCost = (computeCost + readCost + predicateCost)
* 5;
+// runTest("ParForLoopCostEstimatorTest.dml", false, expectedCost);
+// }
+//
+// @Test
+// public void parForLoopHopRelTest(){
+// runHopRelTest("ParForLoopCostEstimatorTest.dml", false);
+// }
+//
+// @Test
+// public void functionTest(){
+// double computeCost = (16+2*100+100+1+1) /
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS *
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+// double readCost = (2*64+1600+800+8) /
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+// double expectedCost = (computeCost + readCost);
+// runTest("FunctionCostEstimatorTest.dml", false, expectedCost);
+// }
+//
+// @Test
+// public void functionHopRelTest(){
+// runHopRelTest("FunctionCostEstimatorTest.dml", false);
+// }
+//
+// @Test
+// public void federatedMultiply() {
+//
+// double literalOpCost = 10*0.0625;
+// double naryOpCostSpecial = (0.125+2.2);
+// double naryOpCostSpecial2 = (0.25+6.4);
+// double naryOpCost = 4*(0.125+1.6);
+// double reorgOpCost = 6250+80015.2+160030.4;
+// double binaryOpMultCost = 3125+160000;
+// double aggBinaryOpCost = 125000+160015.2+160030.4+190.4;
+// double dataOpCost = 2*(6250+5.6);
+// double dataOpWriteCost = 6.25+100.3;
+//
+// double expectedCost = literalOpCost + naryOpCost +
naryOpCostSpecial + naryOpCostSpecial2 + reorgOpCost
+// + binaryOpMultCost + aggBinaryOpCost + dataOpCost +
dataOpWriteCost;
+// runTest("FederatedMultiplyCostEstimatorTest.dml", false,
expectedCost);
+//
+// double aggBinaryActualCost = hops.stream()
+// .filter(hop -> hop instanceof AggBinaryOp)
+// .mapToDouble(aggHop ->
aggHop.getFederatedCost().getTotal()-aggHop.getFederatedCost().getInputTotalCost())
+// .sum();
+// Assert.assertEquals(aggBinaryOpCost, aggBinaryActualCost,
0.0001);
+//
+// double writeActualCost = hops.stream()
+// .filter(hop -> hop instanceof DataOp)
+// .mapToDouble(writeHop ->
writeHop.getFederatedCost().getTotal()-writeHop.getFederatedCost().getInputTotalCost())
+// .sum();
+// Assert.assertEquals(dataOpWriteCost+dataOpCost,
writeActualCost, 0.0001);
+// }
+//
+// Set<Hop> hops = new HashSet<>();
+//
+// /**
+// * Recursively adds the hop and its inputs to the set of hops.
+// * @param hop root to be added to set of hops
+// */
+// private void addHop(Hop hop){
+// hops.add(hop);
+// for(Hop inHop : hop.getInput()){
+// addHop(inHop);
+// }
+// }
+//
+// /**
+// * Sets dimensions of federated X and Y and sets binary multiplication
to FOUT.
+// * @param prog dml program where the HOPS are modified
+// */
+// private void modifyFedouts(DMLProgram prog){
+// prog.getStatementBlocks().forEach(stmBlock ->
stmBlock.getHops().forEach(this::addHop));
+// hops.forEach(hop -> {
+// if ( hop instanceof DataOp || (hop instanceof BinaryOp
&& ((BinaryOp) hop).getOp() == MULT ) ){
+//
hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
+// hop.setExecType(Types.ExecType.FED);
+// } else {
+//
hop.setFederatedOutput(FEDInstruction.FederatedOutput.LOUT);
+// }
+// if ( hop.getOpString().equals("Fed Y") ||
hop.getOpString().equals("Fed X") ){
+// hop.setDim1(10000);
+// hop.setDim2(10);
+// }
+// });
+// }
+//
+// @SuppressWarnings("unused")
+// private void printHopsInfo(){
+// //LiteralOp
+// long literalCount = hops.stream().filter(hop -> hop instanceof
LiteralOp).count();
+// System.out.println("LiteralOp Count: " + literalCount);
+// //NaryOp
+// long naryCount = hops.stream().filter(hop -> hop instanceof
NaryOp).count();
+// System.out.println("NaryOp Count " + naryCount);
+// //ReorgOp
+// long reorgCount = hops.stream().filter(hop -> hop instanceof
ReorgOp).count();
+// System.out.println("ReorgOp Count: " + reorgCount);
+// //BinaryOp
+// long binaryCount = hops.stream().filter(hop -> hop instanceof
BinaryOp).count();
+// System.out.println("Binary count: " + binaryCount);
+// //AggBinaryOp
+// long aggBinaryCount = hops.stream().filter(hop -> hop
instanceof AggBinaryOp).count();
+// System.out.println("AggBinaryOp Count: " + aggBinaryCount);
+// //DataOp
+// long dataOpCount = hops.stream().filter(hop -> hop instanceof
DataOp).count();
+// System.out.println("DataOp Count: " + dataOpCount);
+//
+//
hops.stream().map(Hop::getClass).distinct().forEach(System.out::println);
+// }
+//
+// private DMLProgram testSetup(String scriptFilename) throws IOException{
+// setTestConfig(scriptFilename);
+// String dmlScriptString = readScript(scriptFilename);
+//
+// //parsing, dependency analysis and constructing hops (step 3
and 4 in DMLScript.java)
+// ParserWrapper parser = ParserFactory.createParser();
+// DMLProgram prog =
parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new
HashMap<>());
+// DMLTranslator dmlt = new DMLTranslator(prog);
+// dmlt.liveVariableAnalysis(prog);
+// dmlt.validateParseTree(prog);
+// dmlt.constructHops(prog);
+// if (
scriptFilename.equals("FederatedMultiplyCostEstimatorTest.dml")){
+// modifyFedouts(prog);
+// dmlt.rewriteHopsDAG(prog);
+// hops = new HashSet<>();
+// prog.getStatementBlocks().forEach(stmBlock ->
stmBlock.getHops().forEach(this::addHop));
+// }
+// return prog;
+// }
+//
+// private void compareResults(DMLProgram prog) {
+// FederatedPlannerCostbased rewriter = new
FederatedPlannerCostbased();
+// rewriter.rewriteProgram(prog, new FunctionCallGraph(prog),
null);
+//
+// double actualCost = 0;
+// for ( Hop root : rewriter.getTerminalHops() ){
+// actualCost += root.getFederatedCost().getTotal();
+// }
+//
+//
+// rewriter.getTerminalHops().forEach(Hop::resetFederatedCost);
+// fedCostEstimator = new FederatedCostEstimator();
+// double expectedCost = 0;
+// for ( Hop root : rewriter.getTerminalHops() )
+// expectedCost +=
fedCostEstimator.costEstimate(root).getTotal();
+// Assert.assertEquals(expectedCost, actualCost, 0.0001);
+// }
+//
+// private void runHopRelTest( String scriptFilename, boolean
expectedException ) {
+// boolean raisedException = false;
+// try
+// {
+// DMLProgram prog = testSetup(scriptFilename);
+// compareResults(prog);
+// }
+// catch(LanguageException ex) {
+// raisedException = true;
+// if(raisedException!=expectedException)
+// ex.printStackTrace();
+// }
+// catch(Exception ex2) {
+// ex2.printStackTrace();
+// throw new RuntimeException(ex2);
+// }
+//
+// Assert.assertEquals("Expected exception does not match raised
exception",
+// expectedException, raisedException);
+// }
+//
+// private void runTest( String scriptFilename, boolean expectedException,
double expectedCost ) {
+// boolean raisedException = false;
+// try
+// {
+// DMLProgram prog = testSetup(scriptFilename);
+//
+// fedCostEstimator = new FederatedCostEstimator();
+// FederatedCost actualCost =
fedCostEstimator.costEstimate(prog);
+// Assert.assertEquals(expectedCost,
actualCost.getTotal(), 0.0001);
+// }
+// catch(LanguageException ex) {
+// raisedException = true;
+// if(raisedException!=expectedException)
+// ex.printStackTrace();
+// }
+// catch(Exception ex2) {
+// ex2.printStackTrace();
+// throw new RuntimeException(ex2);
+// }
+//
+// Assert.assertEquals("Expected exception does not match raised
exception",
+// expectedException, raisedException);
+// }
+//
+// private void setTestConfig(String scriptFilename) throws
FileNotFoundException {
+// int index = scriptFilename.lastIndexOf(".dml");
+// String testName = scriptFilename.substring(0, index > 0 ? index
: scriptFilename.length());
+// TestConfiguration testConfig = new
TestConfiguration(TEST_CLASS_DIR, testName, new String[] {});
+// addTestConfiguration(testName, testConfig);
+// loadTestConfiguration(testConfig);
+//
+// DMLConfig conf = new DMLConfig(getCurConfigFile().getPath());
+// ConfigurationManager.setLocalConfig(conf);
+// }
+//
+// private static String readScript(String scriptFilename) throws
IOException {
+// return DMLScript.readDMLScript(true, HOME + scriptFilename);
+// }
+//}
diff --git
a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedDynamicPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedDynamicPlanningTest.java
new file mode 100644
index 0000000000..23da01d438
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedDynamicPlanningTest.java
@@ -0,0 +1,188 @@
+package org.apache.sysds.test.functions.fedplanning;
+///*
+// * Licensed to the Apache Software Foundation (ASF) under one
+// * or more contributor license agreements. See the NOTICE file
+// * distributed with this work for additional information
+// * regarding copyright ownership. The ASF licenses this file
+// * to you under the Apache License, Version 2.0 (the
+// * "License"); you may not use this file except in compliance
+// * with the License. You may obtain a copy of the License at
+// *
+// * http://www.apache.org/licenses/LICENSE-2.0
+// *
+// * Unless required by applicable law or agreed to in writing,
+// * software distributed under the License is distributed on an
+// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// * KIND, either express or implied. See the License for the
+// * specific language governing permissions and limitations
+// * under the License.
+// */
+//
+//package org.apache.sysds.test.functions.privacy.fedplanning;
+//
+//import org.apache.commons.logging.Log;
+//import org.apache.commons.logging.LogFactory;
+//import org.apache.sysds.api.DMLScript;
+//import org.apache.sysds.common.Types;
+//import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+//import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+//import org.apache.sysds.test.AutomatedTestBase;
+//import org.apache.sysds.test.TestConfiguration;
+//import org.apache.sysds.test.TestUtils;
+//import org.junit.Test;
+//
+//import java.io.File;
+//import java.util.Arrays;
+//
+//import static org.junit.Assert.fail;
+//
+//@net.jcip.annotations.NotThreadSafe
+//public class FederatedDynamicPlanningTest extends AutomatedTestBase {
+// private static final Log LOG =
LogFactory.getLog(FederatedDynamicPlanningTest.class.getName());
+//
+// private final static String TEST_DIR = "functions/privacy/fedplanning/";
+// private final static String TEST_NAME =
"FederatedDynamicFunctionPlanningTest";
+// private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedDynamicPlanningTest.class.getSimpleName() + "/";
+// private static File TEST_CONF_FILE;
+//
+// private final static int blocksize = 1024;
+// public final int rows = 1000;
+// public final int cols = 1000;
+//
+// @Override
+// public void setUp() {
+// TestUtils.clearAssertionInformation();
+// addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+// }
+//
+// @Test
+// public void runDynamicFullFunctionTest() {
+// // compared to `FederatedL2SVMPlanningTest` this does not
create `fed_+*` or `fed_tsmm`, probably due to
+// // some rewrites not being applied. Might be a bug.
+// String[] expectedHeavyHitters = new String[] {"fed_fedinit",
"fed_ba+*", "fed_tak+*", "fed_max",
+// "fed_1-*", "fed_>"};
+// setTestConf("SystemDS-config-fout.xml");
+// loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+// }
+//
+// @Test
+// public void runDynamicHeuristicFunctionTest() {
+// // compared to `FederatedL2SVMPlanningTest` this does not
create `fed_+*` or `fed_tsmm`, probably due to
+// // some rewrites not being applied. Might be a bug.
+// String[] expectedHeavyHitters = new String[] {"fed_fedinit",
"fed_ba+*"};
+// setTestConf("SystemDS-config-heuristic.xml");
+// loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+// }
+//
+// @Test
+// public void runDynamicCostBasedFunctionTest() {
+// // compared to `FederatedL2SVMPlanningTest` this does not
create `fed_+*` or `fed_tsmm`, probably due to
+// // some rewrites not being applied. Might be a bug.
+// String[] expectedHeavyHitters = new String[] {"fed_fedinit",
"fed_ba+*", "fed_tak+*", "fed_max",
+// "fed_1-*", "fed_>"};
+// setTestConf("SystemDS-config-cost-based.xml");
+// loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+// }
+//
+// private void setTestConf(String test_conf) {
+// TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+// }
+//
+// private void writeInputMatrices() {
+// writeBinaryVector("A", 42, rows, null);
+// writeStandardMatrix("B1", 65, rows / 2, cols, null);
+// writeStandardMatrix("B2", 75, rows / 2, cols, null);
+// writeStandardMatrix("C1", 13, rows, cols / 2, null);
+// writeStandardMatrix("C2", 17, rows, cols / 2, null);
+// }
+//
+// private void writeBinaryVector(String matrixName, long seed, int
numRows, PrivacyConstraint privacyConstraint){
+// double[][] matrix = getRandomMatrix(numRows, 1, -1, 1, 1, seed);
+// for(int i = 0; i < numRows; i++)
+// matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1;
+// MatrixCharacteristics mc = new MatrixCharacteristics(numRows,
1, blocksize, numRows);
+// writeInputMatrixWithMTD(matrixName, matrix, false, mc,
privacyConstraint);
+// }
+//
+// private void writeStandardMatrix(String matrixName, long seed, int
numRows, int numCols,
+// PrivacyConstraint privacyConstraint) {
+// double[][] matrix = getRandomMatrix(numRows, numCols, 0, 1, 1,
seed);
+// writeStandardMatrix(matrixName, numRows, numCols,
privacyConstraint, matrix);
+// }
+//
+// private void writeStandardMatrix(String matrixName, int numRows, int
numCols, PrivacyConstraint privacyConstraint,
+// double[][] matrix) {
+// MatrixCharacteristics mc = new MatrixCharacteristics(numRows,
numCols, blocksize, (long) numRows * numCols);
+// writeInputMatrixWithMTD(matrixName, matrix, false, mc,
privacyConstraint);
+// }
+//
+// private void loadAndRunTest(String[] expectedHeavyHitters, String
testName) {
+//
+// boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+// Types.ExecMode platformOld = rtplatform;
+// rtplatform = Types.ExecMode.SINGLE_NODE;
+//
+// Thread t1 = null, t2 = null;
+//
+// try {
+// getAndLoadTestConfiguration(testName);
+// String HOME = SCRIPT_DIR + TEST_DIR;
+//
+// writeInputMatrices();
+//
+// int port1 = getRandomAvailablePort();
+// int port2 = getRandomAvailablePort();
+// t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
+// t2 = startLocalFedWorkerThread(port2);
+//
+// // Run actual dml script with federated matrix
+// fullDMLScriptName = HOME + testName + ".dml";
+// programArgs = new String[] {"-stats", "-nvargs",
+// "r=" + rows, "c=" + cols,
+// "A=" + input("A"),
+// "B1=" + TestUtils.federatedAddress(port1,
input("B1")),
+// "B2=" + TestUtils.federatedAddress(port2,
input("B2")),
+// "C1=" + TestUtils.federatedAddress(port1,
input("C1")),
+// "C2=" + TestUtils.federatedAddress(port2,
input("C2")),
+// "lB1=" + input("B1"),
+// "lB2=" + input("B2"),
+// "Z=" + output("Z")};
+// runTest(true, false, null, -1);
+//
+// // Run reference dml script with normal matrix
+// fullDMLScriptName = HOME + testName + "Reference.dml";
+// programArgs = new String[] {"-nvargs",
+// "r=" + rows, "c=" + cols,
+// "A=" + input("A"),
+// "B1=" + input("B1"),
+// "B2=" + input("B2"),
+// "C1=" + input("C1"),
+// "C2=" + input("C2"),
+// "Z=" + expected("Z")};
+// runTest(true, false, null, -1);
+//
+// // compare via files
+// compareResults(1e-9);
+// if(!heavyHittersContainsAllString(expectedHeavyHitters))
+// fail("The following expected heavy hitters are
missing: "
+// +
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+// }
+// finally {
+// TestUtils.shutdownThreads(t1, t2);
+// rtplatform = platformOld;
+// DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+// }
+// }
+//
+// /**
+// * Override default configuration with custom test configuration to
ensure scratch space and local temporary
+// * directory locations are also updated.
+// */
+// @Override
+// protected File getConfigTemplateFile() {
+// // Instrumentation in this test's output log to show custom
configuration file used for template.
+// LOG.info("This test case overrides default configuration with "
+ TEST_CONF_FILE.getPath());
+// return TEST_CONF_FILE;
+// }
+//
+//}
diff --git
a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedKMeansPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedKMeansPlanningTest.java
new file mode 100644
index 0000000000..48d9a06b8c
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedKMeansPlanningTest.java
@@ -0,0 +1,168 @@
+package org.apache.sysds.test.functions.fedplanning;
+///*
+// * Licensed to the Apache Software Foundation (ASF) under one
+// * or more contributor license agreements. See the NOTICE file
+// * distributed with this work for additional information
+// * regarding copyright ownership. The ASF licenses this file
+// * to you under the Apache License, Version 2.0 (the
+// * "License"); you may not use this file except in compliance
+// * with the License. You may obtain a copy of the License at
+// *
+// * http://www.apache.org/licenses/LICENSE-2.0
+// *
+// * Unless required by applicable law or agreed to in writing,
+// * software distributed under the License is distributed on an
+// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// * KIND, either express or implied. See the License for the
+// * specific language governing permissions and limitations
+// * under the License.
+// */
+//
+//package org.apache.sysds.test.functions.privacy.fedplanning;
+//
+//import org.apache.commons.logging.Log;
+//import org.apache.commons.logging.LogFactory;
+//import org.apache.sysds.api.DMLScript;
+//import org.apache.sysds.common.Types;
+//import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+//import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+//import org.apache.sysds.test.AutomatedTestBase;
+//import org.apache.sysds.test.TestConfiguration;
+//import org.apache.sysds.test.TestUtils;
+//import org.junit.Ignore;
+//import org.junit.Test;
+//
+//import java.io.File;
+//import java.util.Arrays;
+//
+//import static org.junit.Assert.fail;
+//
+//public class FederatedKMeansPlanningTest extends AutomatedTestBase {
+// private static final Log LOG =
LogFactory.getLog(FederatedKMeansPlanningTest.class.getName());
+//
+// private final static String TEST_DIR = "functions/privacy/fedplanning/";
+// private final static String TEST_NAME = "FederatedKMeansPlanningTest";
+// private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedKMeansPlanningTest.class.getSimpleName() + "/";
+// private static File TEST_CONF_FILE;
+//
+// private final static int blocksize = 1024;
+// public final int rows = 1000;
+// public final int cols = 100;
+//
+// @Override
+// public void setUp() {
+// TestUtils.clearAssertionInformation();
+// addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+// }
+//
+// @Test
+// public void runKMeansFOUTTest(){
+// String[] expectedHeavyHitters = new String[]{};
+// setTestConf("SystemDS-config-fout.xml");
+// loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+// }
+//
+// @Test
+// public void runKMeansHeuristicTest(){
+// String[] expectedHeavyHitters = new String[]{};
+// setTestConf("SystemDS-config-heuristic.xml");
+// loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+// }
+//
+// @Test
+// public void runKMeansCostBasedTest(){
+// String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*", "fed_*", "fed_uack+", "fed_bcumoffk+"};
+// setTestConf("SystemDS-config-cost-based.xml");
+// loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+// }
+//
+// @Test
+// public void runRuntimeTest(){
+// String[] expectedHeavyHitters = new String[]{};
+// TEST_CONF_FILE = new
File("src/test/config/SystemDS-config.xml");
+// loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+// }
+//
+// private void setTestConf(String test_conf){
+// TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+// }
+//
+// /**
+// * Override default configuration with custom test configuration to
ensure
+// * scratch space and local temporary directory locations are also
updated.
+// */
+// @Override
+// protected File getConfigTemplateFile() {
+// // Instrumentation in this test's output log to show custom
configuration file used for template.
+// LOG.info("This test case overrides default configuration with "
+ TEST_CONF_FILE.getPath());
+// return TEST_CONF_FILE;
+// }
+//
+// private void writeInputMatrices(){
+// writeStandardRowFedMatrix("X1", 65, null);
+// writeStandardRowFedMatrix("X2", 75, null);
+// }
+//
+// private void writeStandardMatrix(String matrixName, long seed, int
numRows, PrivacyConstraint privacyConstraint){
+// double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1,
seed);
+// writeStandardMatrix(matrixName, numRows, privacyConstraint,
matrix);
+// }
+//
+// private void writeStandardMatrix(String matrixName, int numRows,
PrivacyConstraint privacyConstraint, double[][] matrix){
+// MatrixCharacteristics mc = new MatrixCharacteristics(numRows,
cols, blocksize, (long) numRows * cols);
+// writeInputMatrixWithMTD(matrixName, matrix, false, mc,
privacyConstraint);
+// }
+//
+// private void writeStandardRowFedMatrix(String matrixName, long seed,
PrivacyConstraint privacyConstraint){
+// int halfRows = rows/2;
+// writeStandardMatrix(matrixName, seed, halfRows,
privacyConstraint);
+// }
+//
+// private void loadAndRunTest(String[] expectedHeavyHitters, String
testName){
+//
+// boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+// Types.ExecMode platformOld = rtplatform;
+// rtplatform = Types.ExecMode.SINGLE_NODE;
+//
+// Thread t1 = null, t2 = null;
+//
+// try {
+// getAndLoadTestConfiguration(testName);
+// String HOME = SCRIPT_DIR + TEST_DIR;
+//
+// writeInputMatrices();
+//
+// int port1 = getRandomAvailablePort();
+// int port2 = getRandomAvailablePort();
+// t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
+// t2 = startLocalFedWorkerThread(port2);
+//
+// // Run actual dml script with federated matrix
+// fullDMLScriptName = HOME + testName + ".dml";
+// programArgs = new String[] { "-stats", "-nvargs",
+// "X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+// "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+// "Y=" + input("Y"), "r=" + rows, "c=" + cols,
"Z=" + output("Z")};
+// runTest(true, false, null, -1);
+//
+// // Run reference dml script with normal matrix
+// fullDMLScriptName = HOME + testName + "Reference.dml";
+// programArgs = new String[] {"-nvargs", "X1=" +
input("X1"), "X2=" + input("X2"),
+// "Y=" + input("Y"), "Z=" + expected("Z")};
+// runTest(true, false, null, -1);
+//
+// // compare via files
+// compareResults(1e-9);
+// if
(!heavyHittersContainsAllString(expectedHeavyHitters))
+// fail("The following expected heavy hitters are
missing: "
+// +
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+// }
+// finally {
+// TestUtils.shutdownThreads(t1, t2);
+// rtplatform = platformOld;
+// DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+// }
+// }
+//
+//
+//}
diff --git
a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedL2SVMPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedL2SVMPlanningTest.java
new file mode 100644
index 0000000000..0ef4fde6a4
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedL2SVMPlanningTest.java
@@ -0,0 +1,202 @@
+package org.apache.sysds.test.functions.fedplanning;
+///*
+// * Licensed to the Apache Software Foundation (ASF) under one
+// * or more contributor license agreements. See the NOTICE file
+// * distributed with this work for additional information
+// * regarding copyright ownership. The ASF licenses this file
+// * to you under the Apache License, Version 2.0 (the
+// * "License"); you may not use this file except in compliance
+// * with the License. You may obtain a copy of the License at
+// *
+// * http://www.apache.org/licenses/LICENSE-2.0
+// *
+// * Unless required by applicable law or agreed to in writing,
+// * software distributed under the License is distributed on an
+// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// * KIND, either express or implied. See the License for the
+// * specific language governing permissions and limitations
+// * under the License.
+// */
+//
+//package org.apache.sysds.test.functions.privacy.fedplanning;
+//
+//import org.apache.commons.logging.Log;
+//import org.apache.commons.logging.LogFactory;
+//import org.apache.sysds.api.DMLScript;
+//import org.apache.sysds.common.Types;
+//import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+//import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+//import org.apache.sysds.test.AutomatedTestBase;
+//import org.apache.sysds.test.TestConfiguration;
+//import org.apache.sysds.test.TestUtils;
+//import org.junit.Ignore;
+//import org.junit.Test;
+//
+//import java.io.File;
+//import java.util.Arrays;
+//
+//import static org.junit.Assert.fail;
+//
+//@net.jcip.annotations.NotThreadSafe
+//public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
+// private static final Log LOG =
LogFactory.getLog(FederatedL2SVMPlanningTest.class.getName());
+//
+// private final static String TEST_DIR = "functions/privacy/fedplanning/";
+// private final static String TEST_NAME = "FederatedL2SVMPlanningTest";
+// private final static String TEST_NAME_2 =
"FederatedL2SVMFunctionPlanningTest";
+// private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedL2SVMPlanningTest.class.getSimpleName() + "/";
+// private static File TEST_CONF_FILE;
+//
+// private final static int blocksize = 1024;
+// public final int rows = 1000;
+// public final int cols = 100;
+//
+// @Override
+// public void setUp() {
+// TestUtils.clearAssertionInformation();
+// addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+// addTestConfiguration(TEST_NAME_2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"}));
+// }
+//
+// @Test
+// public void runL2SVMFOUTTest(){
+// String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*", "fed_tak+*", "fed_+*",
+// "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+// setTestConf("SystemDS-config-fout.xml");
+// loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+// }
+//
+// @Test
+// public void runL2SVMHeuristicTest(){
+// String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*"};
+// setTestConf("SystemDS-config-heuristic.xml");
+// loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+// }
+//
+// @Test
+// public void runL2SVMCostBasedTest(){
+// String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*", "fed_tak+*", "fed_+*",
+// "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+// setTestConf("SystemDS-config-cost-based.xml");
+// loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+// }
+//
+// @Test
+// public void runL2SVMFunctionFOUTTest(){
+// String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*", "fed_tak+*", "fed_+*",
+// "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+// setTestConf("SystemDS-config-fout.xml");
+// loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+// }
+//
+// @Test
+// public void runL2SVMFunctionHeuristicTest(){
+// String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*"};
+// setTestConf("SystemDS-config-heuristic.xml");
+// loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+// }
+//
+// @Test
+// public void runL2SVMFunctionCostBasedTest(){
+// String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*", "fed_tak+*", "fed_+*",
+// "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+// setTestConf("SystemDS-config-cost-based.xml");
+// loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+// }
+//
+// private void setTestConf(String test_conf){
+// TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+// }
+//
+// private void writeInputMatrices(){
+// writeStandardRowFedMatrix("X1", 65, null);
+// writeStandardRowFedMatrix("X2", 75, null);
+// writeBinaryVector("Y", 44, null);
+// }
+//
+// private void writeBinaryVector(String matrixName, long seed,
PrivacyConstraint privacyConstraint){
+// double[][] matrix = getRandomMatrix(rows, 1, -1, 1, 1, seed);
+// for(int i = 0; i < rows; i++)
+// matrix[i][0] = (matrix[i][0] > 0) ? 1 : -1;
+// MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1,
blocksize, rows);
+// writeInputMatrixWithMTD(matrixName, matrix, false, mc,
privacyConstraint);
+// }
+//
+// @SuppressWarnings("unused")
+// private void writeStandardMatrix(String matrixName, long seed,
PrivacyConstraint privacyConstraint){
+// writeStandardMatrix(matrixName, seed, rows, privacyConstraint);
+// }
+//
+// private void writeStandardMatrix(String matrixName, long seed, int
numRows, PrivacyConstraint privacyConstraint){
+// double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1,
seed);
+// writeStandardMatrix(matrixName, numRows, privacyConstraint,
matrix);
+// }
+//
+// private void writeStandardMatrix(String matrixName, int numRows,
PrivacyConstraint privacyConstraint, double[][] matrix){
+// MatrixCharacteristics mc = new MatrixCharacteristics(numRows,
cols, blocksize, (long) numRows * cols);
+// writeInputMatrixWithMTD(matrixName, matrix, false, mc,
privacyConstraint);
+// }
+//
+// private void writeStandardRowFedMatrix(String matrixName, long seed,
PrivacyConstraint privacyConstraint){
+// int halfRows = rows/2;
+// writeStandardMatrix(matrixName, seed, halfRows,
privacyConstraint);
+// }
+//
+// private void loadAndRunTest(String[] expectedHeavyHitters, String
testName){
+//
+// boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+// Types.ExecMode platformOld = rtplatform;
+// rtplatform = Types.ExecMode.SINGLE_NODE;
+//
+// Thread t1 = null, t2 = null;
+//
+// try {
+// getAndLoadTestConfiguration(testName);
+// String HOME = SCRIPT_DIR + TEST_DIR;
+//
+// writeInputMatrices();
+//
+// int port1 = getRandomAvailablePort();
+// int port2 = getRandomAvailablePort();
+// t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
+// t2 = startLocalFedWorkerThread(port2);
+//
+// // Run actual dml script with federated matrix
+// fullDMLScriptName = HOME + testName + ".dml";
+// programArgs = new String[] { "-stats", "-nvargs",
+// "X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+// "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+// "Y=" + input("Y"), "r=" + rows, "c=" + cols,
"Z=" + output("Z")};
+// runTest(true, false, null, -1);
+//
+// // Run reference dml script with normal matrix
+// fullDMLScriptName = HOME + testName + "Reference.dml";
+// programArgs = new String[] {"-nvargs", "X1=" +
input("X1"), "X2=" + input("X2"),
+// "Y=" + input("Y"), "Z=" + expected("Z")};
+// runTest(true, false, null, -1);
+//
+// // compare via files
+// compareResults(1e-9);
+// if
(!heavyHittersContainsAllString(expectedHeavyHitters))
+// fail("The following expected heavy hitters are
missing: "
+// +
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+// }
+// finally {
+// TestUtils.shutdownThreads(t1, t2);
+// rtplatform = platformOld;
+// DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+// }
+// }
+//
+// /**
+// * Override default configuration with custom test configuration to
ensure
+// * scratch space and local temporary directory locations are also
updated.
+// */
+// @Override
+// protected File getConfigTemplateFile() {
+// // Instrumentation in this test's output log to show custom
configuration file used for template.
+// LOG.info("This test case overrides default configuration with "
+ TEST_CONF_FILE.getPath());
+// return TEST_CONF_FILE;
+// }
+//
+//}
diff --git
a/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedMultiplyPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedMultiplyPlanningTest.java
new file mode 100644
index 0000000000..f3eee4ee41
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/fedplanning/FederatedMultiplyPlanningTest.java
@@ -0,0 +1,334 @@
+package org.apache.sysds.test.functions.fedplanning;
+///*
+// * Licensed to the Apache Software Foundation (ASF) under one
+// * or more contributor license agreements. See the NOTICE file
+// * distributed with this work for additional information
+// * regarding copyright ownership. The ASF licenses this file
+// * to you under the Apache License, Version 2.0 (the
+// * "License"); you may not use this file except in compliance
+// * with the License. You may obtain a copy of the License at
+// *
+// * http://www.apache.org/licenses/LICENSE-2.0
+// *
+// * Unless required by applicable law or agreed to in writing,
+// * software distributed under the License is distributed on an
+// * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// * KIND, either express or implied. See the License for the
+// * specific language governing permissions and limitations
+// * under the License.
+// */
+//
+//package org.apache.sysds.test.functions.privacy.fedplanning;
+//
+//import org.apache.commons.logging.Log;
+//import org.apache.commons.logging.LogFactory;
+//import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+//import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
+//import org.junit.Test;
+//import org.junit.runner.RunWith;
+//import org.junit.runners.Parameterized;
+//import org.apache.sysds.api.DMLScript;
+//import org.apache.sysds.common.Types;
+//import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+//import org.apache.sysds.test.AutomatedTestBase;
+//import org.apache.sysds.test.TestConfiguration;
+//import org.apache.sysds.test.TestUtils;
+//
+//import java.io.File;
+//import java.util.Arrays;
+//import java.util.Collection;
+//
+//import static org.junit.Assert.fail;
+//
+//@RunWith(value = Parameterized.class)
+//@net.jcip.annotations.NotThreadSafe
+//public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
+// private static final Log LOG =
LogFactory.getLog(FederatedMultiplyPlanningTest.class.getName());
+//
+// private final static String TEST_DIR = "functions/privacy/fedplanning/";
+// private final static String TEST_NAME = "FederatedMultiplyPlanningTest";
+// private final static String TEST_NAME_2 =
"FederatedMultiplyPlanningTest2";
+// private final static String TEST_NAME_3 =
"FederatedMultiplyPlanningTest3";
+// private final static String TEST_NAME_4 =
"FederatedMultiplyPlanningTest4";
+// private final static String TEST_NAME_5 =
"FederatedMultiplyPlanningTest5";
+// private final static String TEST_NAME_6 =
"FederatedMultiplyPlanningTest6";
+// private final static String TEST_NAME_7 =
"FederatedMultiplyPlanningTest7";
+// private final static String TEST_NAME_8 =
"FederatedMultiplyPlanningTest8";
+// private final static String TEST_NAME_9 =
"FederatedMultiplyPlanningTest9";
+// private final static String TEST_NAME_10 =
"FederatedMultiplyPlanningTest10";
+// private final static String TEST_NAME_11 =
"FederatedMultiplyPlanningTest11";
+// private final static String TEST_NAME_12 =
"FederatedMultiplyPlanningTest12";
+// private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
+// private static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR,
"SystemDS-config-cost-based.xml");
+//
+// private final static int blocksize = 1024;
+// @Parameterized.Parameter()
+// public int rows;
+// @Parameterized.Parameter(1)
+// public int cols;
+//
+// @Override
+// public void setUp() {
+// TestUtils.clearAssertionInformation();
+// addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+// addTestConfiguration(TEST_NAME_2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"}));
+// addTestConfiguration(TEST_NAME_3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_3, new String[] {"Z.scalar"}));
+// addTestConfiguration(TEST_NAME_4, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_4, new String[] {"Z"}));
+// addTestConfiguration(TEST_NAME_5, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_5, new String[] {"Z"}));
+// addTestConfiguration(TEST_NAME_6, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_6, new String[] {"Z"}));
+// addTestConfiguration(TEST_NAME_7, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_7, new String[] {"Z"}));
+// addTestConfiguration(TEST_NAME_8, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"}));
+// addTestConfiguration(TEST_NAME_9, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"}));
+// addTestConfiguration(TEST_NAME_10, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_10, new String[] {"Z"}));
+// addTestConfiguration(TEST_NAME_11, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_11, new String[] {"Z"}));
+// addTestConfiguration(TEST_NAME_12, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_12, new String[] {"Z"}));
+// }
+//
+// @Parameterized.Parameters
+// public static Collection<Object[]> data() {
+// // rows have to be even and > 1
+// return Arrays.asList(new Object[][] {
+// {100, 10}
+// });
+// }
+//
+// @Test
+// public void federatedMultiplyCP() {
+// String[] expectedHeavyHitters = new String[]{"fed_*",
"fed_fedinit", "fed_r'", "fed_ba+*"};
+// federatedTwoMatricesSingleNodeTest(TEST_NAME,
expectedHeavyHitters);
+// }
+//
+// @Test
+// public void federatedRowSum(){
+// String[] expectedHeavyHitters = new String[]{"fed_*", "fed_r'",
"fed_fedinit", "fed_ba+*", "fed_uark+"};
+// federatedTwoMatricesSingleNodeTest(TEST_NAME_2,
expectedHeavyHitters);
+// }
+//
+// @Test
+// public void federatedTernarySequence(){
+// String[] expectedHeavyHitters = new String[]{"fed_+*",
"fed_1-*", "fed_fedinit", "fed_uak+"};
+// federatedTwoMatricesSingleNodeTest(TEST_NAME_3,
expectedHeavyHitters);
+// }
+//
+// @Test
+// public void federatedAggregateBinarySequence(){
+// cols = rows;
+// String[] expectedHeavyHitters = new String[]{"fed_ba+*",
"fed_*", "fed_fedinit"};
+// federatedTwoMatricesSingleNodeTest(TEST_NAME_4,
expectedHeavyHitters);
+// }
+//
+// @Test
+// public void federatedAggregateBinaryColFedSequence(){
+// cols = rows;
+// //TODO: When alignment checks have been added to
getFederatedOut in AFederatedPlanner,
+// // the following expectedHeavyHitters can be added. Until then,
fed_* will not be generated.
+// //String[] expectedHeavyHitters = new
String[]{"fed_ba+*","fed_*","fed_fedinit"};
+// String[] expectedHeavyHitters = new
String[]{"fed_ba+*","fed_fedinit"};
+// federatedTwoMatricesSingleNodeTest(TEST_NAME_5,
expectedHeavyHitters);
+// }
+//
+// @Test
+// public void federatedAggregateBinarySequence2(){
+// String[] expectedHeavyHitters = new
String[]{"fed_ba+*","fed_fedinit"};
+// federatedTwoMatricesSingleNodeTest(TEST_NAME_6,
expectedHeavyHitters);
+// }
+//
+// @Test
+// public void federatedMultiplyDoubleHop() {
+// String[] expectedHeavyHitters = new String[]{"fed_*",
"fed_fedinit", "fed_r'", "fed_ba+*"};
+// federatedTwoMatricesSingleNodeTest(TEST_NAME_7,
expectedHeavyHitters);
+// }
+//
+// @Test
+// public void federatedMultiplyDoubleHop2() {
+// String[] expectedHeavyHitters = new String[]{"fed_fedinit",
"fed_ba+*"};
+// federatedTwoMatricesSingleNodeTest(TEST_NAME_8,
expectedHeavyHitters);
+// }
+//
+// @Test
+// public void federatedMultiplyPlanningTest9(){
+// String[] expectedHeavyHitters = new String[]{"fed_+*",
"fed_1-*", "fed_fedinit", "fed_tak+*", "fed_max"};
+// federatedTwoMatricesSingleNodeTest(TEST_NAME_9,
expectedHeavyHitters);
+// }
+//
+// @Test
+// public void federatedMultiplyPlanningTest10(){
+// String[] expectedHeavyHitters = new String[]{"fed_fedinit",
"fed_^2"};
+// TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR,
"SystemDS-config-fout.xml");
+// federatedTwoMatricesSingleNodeTest(TEST_NAME_10,
expectedHeavyHitters);
+// }
+//
+// @Test
+// public void federatedMultiplyPlanningTest11(){
+// String[] expectedHeavyHitters = new String[]{"fed_fedinit"};
+// federatedTwoMatricesSingleNodeTest(TEST_NAME_11,
expectedHeavyHitters);
+// }
+//
+// @Test
+// public void federatedMultiplyPlanningTest12(){
+// String[] expectedHeavyHitters = new String[]{"fed_fedinit"};
+// rows = 30;
+// cols = 30;
+// federatedTwoMatricesSingleNodeTest(TEST_NAME_12,
expectedHeavyHitters);
+// }
+//
+// private void writeStandardMatrix(String matrixName, long seed){
+// writeStandardMatrix(matrixName, seed, new
PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
+// }
+//
+// private void writeStandardMatrix(String matrixName, long seed,
PrivacyConstraint privacyConstraint){
+// int halfRows = rows/2;
+// double[][] matrix = getRandomMatrix(halfRows, cols, 0, 1, 1,
seed);
+// MatrixCharacteristics mc = new MatrixCharacteristics(halfRows,
cols, blocksize, (long) halfRows * cols);
+// writeInputMatrixWithMTD(matrixName, matrix, false, mc,
privacyConstraint);
+// }
+//
+// private void writeColStandardMatrix(String matrixName, long seed){
+// writeColStandardMatrix(matrixName, seed, new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+// }
+//
+// private void writeColStandardMatrix(String matrixName, long seed,
PrivacyConstraint privacyConstraint){
+// int halfCols = cols/2;
+// double[][] matrix = getRandomMatrix(rows, halfCols, 0, 1, 1,
seed);
+// MatrixCharacteristics mc = new MatrixCharacteristics(rows,
halfCols, blocksize, (long) halfCols *rows);
+// writeInputMatrixWithMTD(matrixName, matrix, false, mc,
privacyConstraint);
+// }
+//
+// private void writeRowFederatedVector(String matrixName, long seed){
+// writeRowFederatedVector(matrixName, seed, new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
+// }
+//
+// private void writeRowFederatedVector(String matrixName, long seed,
PrivacyConstraint privacyConstraint){
+// int halfCols = cols / 2;
+// double[][] matrix = getRandomMatrix(halfCols, 1, 0, 1, 1, seed);
+// MatrixCharacteristics mc = new MatrixCharacteristics(halfCols,
1, blocksize, (long) halfCols *rows);
+// writeInputMatrixWithMTD(matrixName, matrix, false, mc,
privacyConstraint);
+// }
+//
+// private void writeInputMatrices(String testName){
+// if ( testName.equals(TEST_NAME_5) ){
+// writeColStandardMatrix("X1", 42);
+// writeColStandardMatrix("X2", 1340);
+// writeColStandardMatrix("Y1", 44, null);
+// writeColStandardMatrix("Y2", 21, null);
+// }
+// else if ( testName.equals(TEST_NAME_6) ){
+// writeColStandardMatrix("X1", 42);
+// writeColStandardMatrix("X2", 1340);
+// writeRowFederatedVector("Y1", 44);
+// writeRowFederatedVector("Y2", 21);
+// }
+// else if ( testName.equals(TEST_NAME_8) ){
+// writeColStandardMatrix("X1", 42, null);
+// writeColStandardMatrix("X2", 1340, null);
+// writeColStandardMatrix("Y1", 44, null);
+// writeColStandardMatrix("Y2", 21, null);
+// writeColStandardMatrix("W1", 76, null);
+// writeColStandardMatrix("W2", 11, null);
+// }
+// else if ( testName.equals(TEST_NAME_10) ||
testName.equals(TEST_NAME_12) ){
+// writeStandardMatrix("X1", 42, null);
+// writeStandardMatrix("X2", 1340, null);
+// }
+// else {
+// writeStandardMatrix("X1", 42);
+// writeStandardMatrix("X2", 1340);
+// if ( testName.equals(TEST_NAME_4) ){
+// writeStandardMatrix("Y1", 44, null);
+// writeStandardMatrix("Y2", 21, null);
+// }
+// else {
+// writeStandardMatrix("Y1", 44);
+// writeStandardMatrix("Y2", 21);
+// }
+// }
+// }
+//
+// private void federatedTwoMatricesSingleNodeTest(String testName,
String[] expectedHeavyHitters){
+// federatedTwoMatricesTest(Types.ExecMode.SINGLE_NODE, testName,
expectedHeavyHitters);
+// }
+//
+// private void federatedTwoMatricesTest(Types.ExecMode execMode, String
testName, String[] expectedHeavyHitters) {
+// boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+// Types.ExecMode platformOld = rtplatform;
+// rtplatform = execMode;
+// if(rtplatform == Types.ExecMode.SPARK) {
+// DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+// }
+// Thread t1 = null, t2 = null;
+//
+// try{
+// getAndLoadTestConfiguration(testName);
+// String HOME = SCRIPT_DIR + TEST_DIR;
+//
+// writeInputMatrices(testName);
+//
+// int port1 = getRandomAvailablePort();
+// int port2 = getRandomAvailablePort();
+// t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
+// t2 = startLocalFedWorkerThread(port2);
+//
+// // Run actual dml script with federated matrix
+// fullDMLScriptName = HOME + testName + ".dml";
+// programArgs = new String[] {"-stats", "-nvargs", "X1="
+ TestUtils.federatedAddress(port1, input("X1")),
+// "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+// "Y1=" + TestUtils.federatedAddress(port1,
input("Y1")),
+// "Y2=" + TestUtils.federatedAddress(port2,
input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+// rewriteRealProgramArgs(testName, port1, port2);
+// runTest(true, false, null, -1);
+//
+// // Run reference dml script with normal matrix
+// fullDMLScriptName = HOME + testName + "Reference.dml";
+// programArgs = new String[] {"-nvargs", "X1=" +
input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+// "Y2=" + input("Y2"), "Z=" + expected("Z")};
+// rewriteReferenceProgramArgs(testName);
+// runTest(true, false, null, -1);
+//
+// // compare via files
+// compareResults(1e-9);
+// if
(!heavyHittersContainsAllString(expectedHeavyHitters))
+// fail("The following expected heavy hitters are
missing: "
+// +
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+// } finally {
+// TestUtils.shutdownThreads(t1, t2);
+// rtplatform = platformOld;
+// DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+// }
+// }
+//
+// private void rewriteRealProgramArgs(String testName, int port1, int
port2){
+// if ( testName.equals(TEST_NAME_4) ||
testName.equals(TEST_NAME_5) ){
+// programArgs = new String[] {"-stats","-nvargs", "X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+// "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+// "Y1=" + input("Y1"),
+// "Y2=" + input("Y2"), "r=" + rows, "c=" + cols,
"Z=" + output("Z")};
+// } else if ( testName.equals(TEST_NAME_8) ){
+// programArgs = new String[] {"-stats","-nvargs", "X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+// "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+// "Y1=" + TestUtils.federatedAddress(port1,
input("Y1")),
+// "Y2=" + TestUtils.federatedAddress(port2,
input("Y2")),
+// "W1=" + input("W1"),
+// "W2=" + input("W2"),
+// "r=" + rows, "c=" + cols, "Z=" + output("Z")};
+// }
+// }
+//
+// private void rewriteReferenceProgramArgs(String testName){
+// if ( testName.equals(TEST_NAME_8) ){
+// programArgs = new String[] {"-nvargs", "X1=" +
input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
+// "Y2=" + input("Y2"), "W1=" + input("W1"), "W2="
+ input("W2"), "Z=" + expected("Z")};
+// }
+// }
+//
+// /**
+// * Override default configuration with custom test configuration to
ensure
+// * scratch space and local temporary directory locations are also
updated.
+// */
+// @Override
+// protected File getConfigTemplateFile() {
+// // Instrumentation in this test's output log to show custom
configuration file used for template.
+// LOG.info("This test case overrides default configuration with "
+ TEST_CONF_FILE.getPath());
+// return TEST_CONF_FILE;
+// }
+//}
+//