This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 65d6ad3 [SYSTEMDS-2747] Federated ALS-CG test, extended federated binary ops
65d6ad3 is described below
commit 65d6ad393acea7f309c1cff4e768cec0b5998d15
Author: ywcb00 <[email protected]>
AuthorDate: Sat Jan 30 23:23:19 2021 +0100
[SYSTEMDS-2747] Federated ALS-CG test, extended federated binary ops
Closes #1170.
---
.../fed/BinaryMatrixMatrixFEDInstruction.java | 33 ++-
.../instructions/fed/FEDInstructionUtils.java | 16 ++
.../federated/algorithms/FederatedAlsCGTest.java | 170 ++++++++++++
.../federated/primitives/FederatedLogicalTest.java | 305 +++++++++++++++++++++
.../functions/federated/FederatedAlsCGTest.dml | 35 +++
.../federated/FederatedAlsCGTestReference.dml | 34 +++
.../binary/FederatedLogicalMatrixMatrixTest.dml | 41 +++
.../FederatedLogicalMatrixMatrixTestReference.dml | 40 +++
.../binary/FederatedLogicalMatrixScalarTest.dml | 41 +++
.../FederatedLogicalMatrixScalarTestReference.dml | 40 +++
10 files changed, 748 insertions(+), 7 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index ea34df1..0ba1935 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -23,6 +23,7 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -39,15 +40,15 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
MatrixObject mo2 = ec.getMatrixObject(input2);
-
+
//canonicalization for federated lhs
- if( !mo1.isFederated() && mo2.isFederated()
- &&
mo1.getDataCharacteristics().equalDims(mo2.getDataCharacteristics())
+ if( !mo1.isFederated() && mo2.isFederated()
+ &&
mo1.getDataCharacteristics().equalDims(mo2.getDataCharacteristics())
&& ((BinaryOperator)_optr).isCommutative() ) {
mo1 = ec.getMatrixObject(input2);
mo2 = ec.getMatrixObject(input1);
}
-
+
//execute federated operation on mo1 or mo2
FederatedRequest fr2 = null;
if( mo2.isFederated() ) {
@@ -63,7 +64,7 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
}
else {
//matrix-matrix binary operations -> lhs fed input ->
fed output
- if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) {
//MV row vector
+ if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) {
//MV col vector
FederatedRequest[] fr1 =
mo1.getFedMapping().broadcastSliced(mo2, false);
fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
new long[]{mo1.getFedMapping().getID(),
fr1[0].getID()});
@@ -71,7 +72,7 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
//execute federated instruction and cleanup
intermediates
mo1.getFedMapping().execute(getTID(), true,
fr1, fr2, fr3);
}
- else { //MM or MV col vector
+ else if(mo2.getNumRows() == 1 && mo2.getNumColumns() >
1) { //MV row vector
FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
new long[]{mo1.getFedMapping().getID(),
fr1.getID()});
@@ -79,8 +80,26 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
//execute federated instruction and cleanup
intermediates
mo1.getFedMapping().execute(getTID(), true,
fr1, fr2, fr3);
}
+ else { //MM
+ if(mo1.isFederated(FType.ROW)) {
+ FederatedRequest[] fr1 =
mo1.getFedMapping().broadcastSliced(mo2, false);
+ fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
+ new
long[]{mo1.getFedMapping().getID(), fr1[0].getID()});
+ FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+ //execute federated instruction and
cleanup intermediates
+ mo1.getFedMapping().execute(getTID(),
true, fr1, fr2, fr3);
+ }
+ else {
+ FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
+ fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
+ new
long[]{mo1.getFedMapping().getID(), fr1.getID()});
+ FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+ //execute federated instruction and
cleanup intermediates
+ mo1.getFedMapping().execute(getTID(),
true, fr1, fr2, fr3);
+ }
+ }
}
-
+
//derive new fed mapping for output
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getDataCharacteristics());
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index fbdb3a2..845f8a4 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.instructions.Instruction;
@@ -43,7 +44,11 @@ import
org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOp
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
+import
org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
+import
org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
+import
org.apache.sysds.runtime.instructions.spark.BinaryTensorTensorBroadcastSPInstruction;
+import
org.apache.sysds.runtime.instructions.spark.BinaryTensorTensorSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuantilePickSPInstruction;
@@ -254,6 +259,17 @@ public class FEDInstructionUtils {
fedinst =
AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
}
}
+ else if (inst instanceof BinaryMatrixScalarSPInstruction
+ || inst instanceof
BinaryMatrixMatrixSPInstruction
+ || inst instanceof
BinaryTensorTensorSPInstruction
+ || inst instanceof
BinaryTensorTensorBroadcastSPInstruction) {
+ BinarySPInstruction instruction =
(BinarySPInstruction) inst;
+ Data data = ec.getVariable(instruction.input1);
+ if((data instanceof MatrixObject &&
((MatrixObject)data).isFederated())
+ || (data instanceof TensorObject &&
((TensorObject)data).isFederated())) {
+ fedinst =
BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+ }
+ }
}
else if (inst instanceof WriteSPInstruction) {
WriteSPInstruction instruction = (WriteSPInstruction)
inst;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
new file mode 100644
index 0000000..4909f7c
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.algorithms;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedAlsCGTest extends AutomatedTestBase
+{
+	private final static String TEST_NAME = "FederatedAlsCGTest";
+	private final static String TEST_DIR = "functions/federated/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedAlsCGTest.class.getSimpleName() + "/";
+
+	private final static String OUTPUT_NAME = "Z";
+	private final static double TOLERANCE = 0.01;
+	private final static int blocksize = 1024;
+
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+	@Parameterized.Parameter(2)
+	public int rank;
+	@Parameterized.Parameter(3)
+	public String regression;
+	@Parameterized.Parameter(4)
+	public double lambda;
+	@Parameterized.Parameter(5)
+	public int max_iter;
+	@Parameterized.Parameter(6)
+	public double threshold;
+	@Parameterized.Parameter(7)
+	public double sparsity;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{OUTPUT_NAME}));
+	}
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		// rows must be even: the input is split into two row partitions of rows/2 each
+		return Arrays.asList(new Object[][] {
+			// {rows, cols, rank, regression, lambda, max_iter, threshold, sparsity}
+			{30, 15, 10, "L2", 0.0000001, 50, 0.000001, 1},
+			{30, 15, 10, "wL2", 0.0000001, 50, 0.000001, 1}
+		});
+	}
+
+	@BeforeClass
+	public static void init() {
+		TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+	}
+
+	@Test
+	public void federatedAlsCGSingleNode() {
+		federatedAlsCG(TEST_NAME, ExecMode.SINGLE_NODE);
+	}
+
+	// NOTE: the Spark variant is disabled in this commit.
+	// @Test
+	// public void federatedAlsCGSpark() {
+	// 	federatedAlsCG(TEST_NAME, ExecMode.SPARK);
+	// }
+
+	// -----------------------------------------------------------------------------
+
+	/**
+	 * Runs alsCG once on the locally concatenated input (reference) and once on the same
+	 * data held by two local federated workers (row partitions), compares the resulting
+	 * products U %*% V, and checks that federated instructions were executed.
+	 *
+	 * @param testname name of the test configuration and prefix of the DML scripts
+	 * @param execMode execution mode used for both runs
+	 */
+	public void federatedAlsCG(String testname, ExecMode execMode)
+	{
+		// store the previous platform config to restore it after the test
+		ExecMode platform_old = setExecMode(execMode);
+		try {
+			getAndLoadTestConfiguration(testname);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			int fed_rows = rows / 2;
+			int fed_cols = cols;
+
+			double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 1, 2, sparsity, 13);
+			double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2, sparsity, 2);
+
+			writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(
+				fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+			writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(
+				fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+
+			// empty script name because we don't execute any script, just start the worker
+			fullDMLScriptName = "";
+			int port1 = getRandomAvailablePort();
+			int port2 = getRandomAvailablePort();
+			Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+			Thread thread2 = startLocalFedWorkerThread(port2);
+
+			try {
+				getAndLoadTestConfiguration(testname);
+
+				// Run reference dml script with normal matrix
+				fullDMLScriptName = HOME + testname + "Reference.dml";
+				programArgs = new String[] {"-stats", "-nvargs",
+					"in_X1=" + input("X1"), "in_X2=" + input("X2"), "in_rank=" + Integer.toString(rank),
+					"in_reg=" + regression, "in_lambda=" + Double.toString(lambda),
+					"in_maxi=" + Integer.toString(max_iter), "in_thr=" + Double.toString(threshold),
+					"out_Z=" + expected(OUTPUT_NAME)};
+				runTest(true, false, null, -1);
+
+				// Run actual dml script with federated matrix
+				fullDMLScriptName = HOME + testname + ".dml";
+				programArgs = new String[] {"-stats", "-nvargs",
+					"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+					"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+					"in_rank=" + Integer.toString(rank),
+					"in_reg=" + regression,
+					"in_lambda=" + Double.toString(lambda),
+					"in_maxi=" + Integer.toString(max_iter),
+					"in_thr=" + Double.toString(threshold),
+					"rows=" + fed_rows, "cols=" + fed_cols,
+					"out_Z=" + output(OUTPUT_NAME)};
+				runTest(true, false, null, -1);
+
+				// compare the results via files
+				HashMap<CellIndex, Double> refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+				HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
+				TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
+
+				// check for federated operations (statistics of the federated run above)
+				Assert.assertTrue(heavyHittersContainsString("fed_!="));
+				Assert.assertTrue(heavyHittersContainsString("fed_fedinit"));
+				Assert.assertTrue(heavyHittersContainsString("fed_wdivmm"));
+				Assert.assertTrue(heavyHittersContainsString("fed_wsloss"));
+			}
+			finally {
+				// release the local workers even if a run or an assertion above fails
+				TestUtils.shutdownThreads(thread1, thread2);
+			}
+
+			// check that federated input files are still existing
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+		}
+		finally {
+			// restore the global exec mode even on failure so later tests are unaffected
+			resetExecMode(platform_old);
+		}
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
new file mode 100644
index 0000000..53dfb2e
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedLogicalTest extends AutomatedTestBase
+{
+	private final static String SCALAR_TEST_NAME = "FederatedLogicalMatrixScalarTest";
+	private final static String MATRIX_TEST_NAME = "FederatedLogicalMatrixMatrixTest";
+	private final static String TEST_DIR = "functions/federated/binary/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedLogicalTest.class.getSimpleName() + "/";
+
+	private final static String OUTPUT_NAME = "Z";
+	private final static double TOLERANCE = 0;
+	private final static int blocksize = 1024;
+
+	// Value range of the federated input X. The right-hand operands are taken from the
+	// same range; operands outside of it (e.g. 0) would make every comparison constant
+	// and the test could not detect a wrong federated result.
+	private final static double X_MIN = 1;
+	private final static double X_MAX = 2;
+	private final static double SCALAR_OPERAND = (X_MIN + X_MAX) / 2;
+
+	/**
+	 * Comparison operators under test. {@code dmlCode} is the selector passed to the DML
+	 * scripts as {@code $in_op_type}; {@code opcode} is the operator whose federated
+	 * instruction ({@code "fed_" + opcode}) must show up in the heavy-hitter statistics.
+	 */
+	public enum Type {
+		GREATER(0, ">"),
+		LESS(1, "<"),
+		EQUALS(2, "=="),
+		NOT_EQUALS(3, "!="),
+		GREATER_EQUALS(4, ">="),
+		LESS_EQUALS(5, "<=");
+
+		public final int dmlCode;
+		public final String opcode;
+
+		Type(int dmlCode, String opcode) {
+			this.dmlCode = dmlCode;
+			this.opcode = opcode;
+		}
+	}
+
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+	@Parameterized.Parameter(2)
+	public double sparsity;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(SCALAR_TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, SCALAR_TEST_NAME, new String[]{OUTPUT_NAME}));
+		addTestConfiguration(MATRIX_TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, MATRIX_TEST_NAME, new String[]{OUTPUT_NAME}));
+	}
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		// rows must be even: the input is split into two row partitions of rows/2 each
+		return Arrays.asList(new Object[][] {
+			// {rows, cols, sparsity}
+			{100, 75, 0.01},
+			{100, 75, 0.9},
+			{2, 75, 0.01},
+			{2, 75, 0.9}
+		});
+	}
+
+	@BeforeClass
+	public static void init() {
+		TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+	}
+
+	//---------------------------MATRIX SCALAR--------------------------
+	@Test
+	public void federatedLogicalScalarGreaterSingleNode() {
+		federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedLogicalScalarGreaterSpark() {
+		federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedLogicalScalarLessSingleNode() {
+		federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedLogicalScalarLessSpark() {
+		federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedLogicalScalarEqualsSingleNode() {
+		federatedLogicalTest(SCALAR_TEST_NAME, Type.EQUALS, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedLogicalScalarEqualsSpark() {
+		federatedLogicalTest(SCALAR_TEST_NAME, Type.EQUALS, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedLogicalScalarNotEqualsSingleNode() {
+		federatedLogicalTest(SCALAR_TEST_NAME, Type.NOT_EQUALS, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedLogicalScalarNotEqualsSpark() {
+		federatedLogicalTest(SCALAR_TEST_NAME, Type.NOT_EQUALS, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedLogicalScalarGreaterEqualsSingleNode() {
+		federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER_EQUALS, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedLogicalScalarGreaterEqualsSpark() {
+		federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER_EQUALS, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedLogicalScalarLessEqualsSingleNode() {
+		federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedLogicalScalarLessEqualsSpark() {
+		federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS, ExecMode.SPARK);
+	}
+
+	//---------------------------MATRIX MATRIX--------------------------
+	@Test
+	public void federatedLogicalMatrixGreaterSingleNode() {
+		federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedLogicalMatrixGreaterSpark() {
+		federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedLogicalMatrixLessSingleNode() {
+		federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedLogicalMatrixLessSpark() {
+		federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedLogicalMatrixEqualsSingleNode() {
+		federatedLogicalTest(MATRIX_TEST_NAME, Type.EQUALS, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedLogicalMatrixEqualsSpark() {
+		federatedLogicalTest(MATRIX_TEST_NAME, Type.EQUALS, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedLogicalMatrixNotEqualsSingleNode() {
+		federatedLogicalTest(MATRIX_TEST_NAME, Type.NOT_EQUALS, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedLogicalMatrixNotEqualsSpark() {
+		federatedLogicalTest(MATRIX_TEST_NAME, Type.NOT_EQUALS, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedLogicalMatrixGreaterEqualsSingleNode() {
+		federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER_EQUALS, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedLogicalMatrixGreaterEqualsSpark() {
+		federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER_EQUALS, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedLogicalMatrixLessEqualsSingleNode() {
+		federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedLogicalMatrixLessEqualsSpark() {
+		federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS, ExecMode.SPARK);
+	}
+
+	// -----------------------------------------------------------------------------
+
+	/**
+	 * Evaluates one comparison once on the locally concatenated input (reference) and once
+	 * on the same data held by two local federated workers (row partitions), compares both
+	 * results exactly, and checks that the federated comparison instruction was executed.
+	 *
+	 * @param testname {@link #SCALAR_TEST_NAME} or {@link #MATRIX_TEST_NAME}
+	 * @param op_type  comparison operator to evaluate
+	 * @param execMode execution mode used for both runs
+	 */
+	public void federatedLogicalTest(String testname, Type op_type, ExecMode execMode)
+	{
+		// store the previous platform config to restore it after the test
+		ExecMode platform_old = setExecMode(execMode);
+		try {
+			getAndLoadTestConfiguration(testname);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			int fed_rows = rows / 2;
+			int fed_cols = cols;
+
+			// generate dataset
+			// matrix handled by two federated workers
+			double[][] X1 = getRandomMatrix(fed_rows, fed_cols, X_MIN, X_MAX, 1, 13);
+			double[][] X2 = getRandomMatrix(fed_rows, fed_cols, X_MIN, X_MAX, 1, 2);
+
+			writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+			writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+
+			boolean is_matrix_test = testname.equals(MATRIX_TEST_NAME);
+
+			if(is_matrix_test) {
+				// non-zero values share X's range, so that at least the dense configurations
+				// produce non-constant comparison results
+				double[][] Y_mat = getRandomMatrix(rows, cols, X_MIN, X_MAX, sparsity, 5040);
+				writeInputMatrixWithMTD("Y", Y_mat, true);
+			}
+			String in_Y = is_matrix_test ? input("Y") : Double.toString(SCALAR_OPERAND);
+
+			// empty script name because we don't execute any script, just start the worker
+			fullDMLScriptName = "";
+			int port1 = getRandomAvailablePort();
+			int port2 = getRandomAvailablePort();
+			Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+			Thread thread2 = startLocalFedWorkerThread(port2);
+
+			try {
+				getAndLoadTestConfiguration(testname);
+
+				// Run reference dml script with normal matrix
+				fullDMLScriptName = HOME + testname + "Reference.dml";
+				programArgs = new String[] {"-nvargs", "in_X1=" + input("X1"), "in_X2=" + input("X2"),
+					"in_Y=" + in_Y,
+					"in_op_type=" + Integer.toString(op_type.dmlCode),
+					"out_Z=" + expected(OUTPUT_NAME)};
+				runTest(true, false, null, -1);
+
+				// Run actual dml script with federated matrix
+				fullDMLScriptName = HOME + testname + ".dml";
+				programArgs = new String[] {"-stats", "-nvargs",
+					"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+					"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+					"in_Y=" + in_Y,
+					"in_op_type=" + Integer.toString(op_type.dmlCode),
+					"rows=" + fed_rows, "cols=" + fed_cols, "out_Z=" + output(OUTPUT_NAME)};
+				runTest(true, false, null, -1);
+
+				// compare the results via files
+				HashMap<CellIndex, Double> refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+				HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
+				TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
+
+				// check for the federated comparison operation
+				Assert.assertTrue(heavyHittersContainsString("fed_" + op_type.opcode));
+			}
+			finally {
+				// release the local workers even if a run or an assertion above fails
+				TestUtils.shutdownThreads(thread1, thread2);
+			}
+
+			// check that federated input files are still existing
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+		}
+		finally {
+			// restore the global exec mode even on failure so later tests are unaffected
+			resetExecMode(platform_old);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/federated/FederatedAlsCGTest.dml
b/src/test/scripts/functions/federated/FederatedAlsCGTest.dml
new file mode 100644
index 0000000..05258f4
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedAlsCGTest.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X consists of two federated row partitions of size $rows x $cols each
+X = federated(addresses=list($in_X1, $in_X2),
+	ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols)));
+
+# factorize X with alsCG, forwarding the test parameters unchanged
+[U, V] = alsCG(X = X, rank = $in_rank, reg = $in_reg, lambda = $in_lambda,
+	maxi = $in_maxi, check = TRUE, thr = $in_thr);
+
+# low-rank reconstruction, compared against the reference run by the test
+Z = U %*% V;
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/FederatedAlsCGTestReference.dml
b/src/test/scripts/functions/federated/FederatedAlsCGTestReference.dml
new file mode 100644
index 0000000..a73efba
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedAlsCGTestReference.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# local (non-federated) equivalent of the federated input: both row partitions stacked
+X = rbind(read($in_X1), read($in_X2));
+
+# factorize X with alsCG, forwarding the test parameters unchanged
+[U, V] = alsCG(X = X, rank = $in_rank, reg = $in_reg, lambda = $in_lambda,
+	maxi = $in_maxi, check = TRUE, thr = $in_thr);
+
+# low-rank reconstruction that serves as the expected result
+Z = U %*% V;
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
new file mode 100644
index 0000000..7fe350c
--- /dev/null
+++
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
@@ -0,0 +1,41 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X consists of two federated row partitions of size $rows x $cols each
+X = federated(addresses=list($in_X1, $in_X2),
+	ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols)));
+
+Y = read($in_Y);
+op_type = $in_op_type;
+
+# op_type selects the comparison: 0: >, 1: <, 2: ==, 3: !=, 4: >=, 5: <=
+if(op_type == 0)
+	Z = (X > Y)
+else if(op_type == 1)
+	Z = (X < Y)
+else if(op_type == 2)
+	Z = (X == Y)
+else if(op_type == 3)
+	Z = (X != Y)
+else if(op_type == 4)
+	Z = (X >= Y)
+else if(op_type == 5)
+	Z = (X <= Y)
+else
+	stop("Unsupported op_type: " + op_type)
+
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
new file mode 100644
index 0000000..1285eb0
--- /dev/null
+++
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# local (non-federated) equivalent of the federated input: both row partitions stacked
+X = rbind(read($in_X1), read($in_X2));
+
+Y = read($in_Y);
+op_type = $in_op_type;
+
+# op_type selects the comparison: 0: >, 1: <, 2: ==, 3: !=, 4: >=, 5: <=
+if(op_type == 0)
+	Z = (X > Y)
+else if(op_type == 1)
+	Z = (X < Y)
+else if(op_type == 2)
+	Z = (X == Y)
+else if(op_type == 3)
+	Z = (X != Y)
+else if(op_type == 4)
+	Z = (X >= Y)
+else if(op_type == 5)
+	Z = (X <= Y)
+else
+	stop("Unsupported op_type: " + op_type)
+
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml
new file mode 100644
index 0000000..1dfb762
--- /dev/null
+++
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml
@@ -0,0 +1,41 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X consists of two federated row partitions of size $rows x $cols each
+X = federated(addresses=list($in_X1, $in_X2),
+	ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols)));
+
+y = $in_Y;
+op_type = $in_op_type;
+
+# op_type selects the comparison: 0: >, 1: <, 2: ==, 3: !=, 4: >=, 5: <=
+if(op_type == 0)
+	Z = (X > y)
+else if(op_type == 1)
+	Z = (X < y)
+else if(op_type == 2)
+	Z = (X == y)
+else if(op_type == 3)
+	Z = (X != y)
+else if(op_type == 4)
+	Z = (X >= y)
+else if(op_type == 5)
+	Z = (X <= y)
+else
+	stop("Unsupported op_type: " + op_type)
+
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml
new file mode 100644
index 0000000..4682aea
--- /dev/null
+++
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# local (non-federated) equivalent of the federated input: both row partitions stacked
+X = rbind(read($in_X1), read($in_X2));
+
+y = $in_Y;
+op_type = $in_op_type;
+
+# op_type selects the comparison: 0: >, 1: <, 2: ==, 3: !=, 4: >=, 5: <=
+if(op_type == 0)
+	Z = (X > y)
+else if(op_type == 1)
+	Z = (X < y)
+else if(op_type == 2)
+	Z = (X == y)
+else if(op_type == 3)
+	Z = (X != y)
+else if(op_type == 4)
+	Z = (X >= y)
+else if(op_type == 5)
+	Z = (X <= y)
+else
+	stop("Unsupported op_type: " + op_type)
+
+write(Z, $out_Z);