This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 65d6ad3  [SYSTEMDS-2747] Federated ALS-CG test, extended federated binary ops
65d6ad3 is described below

commit 65d6ad393acea7f309c1cff4e768cec0b5998d15
Author: ywcb00 <[email protected]>
AuthorDate: Sat Jan 30 23:23:19 2021 +0100

    [SYSTEMDS-2747] Federated ALS-CG test, extended federated binary ops
    
    Closes #1170.
---
 .../fed/BinaryMatrixMatrixFEDInstruction.java      |  33 ++-
 .../instructions/fed/FEDInstructionUtils.java      |  16 ++
 .../federated/algorithms/FederatedAlsCGTest.java   | 170 ++++++++++++
 .../federated/primitives/FederatedLogicalTest.java | 305 +++++++++++++++++++++
 .../functions/federated/FederatedAlsCGTest.dml     |  35 +++
 .../federated/FederatedAlsCGTestReference.dml      |  34 +++
 .../binary/FederatedLogicalMatrixMatrixTest.dml    |  41 +++
 .../FederatedLogicalMatrixMatrixTestReference.dml  |  40 +++
 .../binary/FederatedLogicalMatrixScalarTest.dml    |  41 +++
 .../FederatedLogicalMatrixScalarTestReference.dml  |  40 +++
 10 files changed, 748 insertions(+), 7 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index ea34df1..0ba1935 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -23,6 +23,7 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -39,15 +40,15 @@ public class BinaryMatrixMatrixFEDInstruction extends 
BinaryFEDInstruction
        public void processInstruction(ExecutionContext ec) {
                MatrixObject mo1 = ec.getMatrixObject(input1);
                MatrixObject mo2 = ec.getMatrixObject(input2);
-               
+
                //canonicalization for federated lhs
-               if( !mo1.isFederated() && mo2.isFederated() 
-                       && 
mo1.getDataCharacteristics().equalDims(mo2.getDataCharacteristics()) 
+               if( !mo1.isFederated() && mo2.isFederated()
+                       && 
mo1.getDataCharacteristics().equalDims(mo2.getDataCharacteristics())
                        && ((BinaryOperator)_optr).isCommutative() ) {
                        mo1 = ec.getMatrixObject(input2);
                        mo2 = ec.getMatrixObject(input1);
                }
-               
+
                //execute federated operation on mo1 or mo2
                FederatedRequest fr2 = null;
                if( mo2.isFederated() ) {
@@ -63,7 +64,7 @@ public class BinaryMatrixMatrixFEDInstruction extends 
BinaryFEDInstruction
                }
                else {
                        //matrix-matrix binary operations -> lhs fed input -> 
fed output
-                       if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { 
//MV row vector
+                       if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { 
//MV col vector
                                FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
                                fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
                                        new long[]{mo1.getFedMapping().getID(), 
fr1[0].getID()});
@@ -71,7 +72,7 @@ public class BinaryMatrixMatrixFEDInstruction extends 
BinaryFEDInstruction
                                //execute federated instruction and cleanup 
intermediates
                                mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2, fr3);
                        }
-                       else { //MM or MV col vector
+                       else if(mo2.getNumRows() == 1 && mo2.getNumColumns() > 
1) { //MV row vector
                                FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
                                fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
                                        new long[]{mo1.getFedMapping().getID(), 
fr1.getID()});
@@ -79,8 +80,26 @@ public class BinaryMatrixMatrixFEDInstruction extends 
BinaryFEDInstruction
                                //execute federated instruction and cleanup 
intermediates
                                mo1.getFedMapping().execute(getTID(), true, 
fr1, fr2, fr3);
                        }
+                       else { //MM
+                               if(mo1.isFederated(FType.ROW)) {
+                                       FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
+                                       fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
+                                               new 
long[]{mo1.getFedMapping().getID(), fr1[0].getID()});
+                                       FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+                                       //execute federated instruction and 
cleanup intermediates
+                                       mo1.getFedMapping().execute(getTID(), 
true, fr1, fr2, fr3);
+                               }
+                               else {
+                                       FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
+                                       fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
+                                               new 
long[]{mo1.getFedMapping().getID(), fr1.getID()});
+                                       FederatedRequest fr3 = 
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+                                       //execute federated instruction and 
cleanup intermediates
+                                       mo1.getFedMapping().execute(getTID(), 
true, fr1, fr2, fr3);
+                               }
+                       }
                }
-               
+
                //derive new fed mapping for output
                MatrixObject out = ec.getMatrixObject(output);
                out.getDataCharacteristics().set(mo1.getDataCharacteristics());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index fbdb3a2..845f8a4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
 import org.apache.sysds.runtime.instructions.Instruction;
@@ -43,7 +44,11 @@ import 
org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOp
 import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
 import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
+import 
org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
+import 
org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
+import 
org.apache.sysds.runtime.instructions.spark.BinaryTensorTensorBroadcastSPInstruction;
+import 
org.apache.sysds.runtime.instructions.spark.BinaryTensorTensorSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.QuantilePickSPInstruction;
@@ -254,6 +259,17 @@ public class FEDInstructionUtils {
                                        fedinst = 
AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
                                }
                        }
+                       else if (inst instanceof BinaryMatrixScalarSPInstruction
+                               || inst instanceof 
BinaryMatrixMatrixSPInstruction
+                               || inst instanceof 
BinaryTensorTensorSPInstruction
+                               || inst instanceof 
BinaryTensorTensorBroadcastSPInstruction) {
+                               BinarySPInstruction instruction = 
(BinarySPInstruction) inst;
+                               Data data = ec.getVariable(instruction.input1);
+                               if((data instanceof MatrixObject && 
((MatrixObject)data).isFederated())
+                                       || (data instanceof TensorObject && 
((TensorObject)data).isFederated())) {
+                                       fedinst = 
BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+                               }
+                       }
                }
                else if (inst instanceof WriteSPInstruction) {
                        WriteSPInstruction instruction = (WriteSPInstruction) 
inst;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
new file mode 100644
index 0000000..4909f7c
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.     See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.      The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.   You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.    See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.algorithms;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedAlsCGTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME = "FederatedAlsCGTest";
+       private final static String TEST_DIR = "functions/federated/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedAlsCGTest.class.getSimpleName() + "/";
+
+       private final static String OUTPUT_NAME = "Z";
+       private final static double TOLERANCE = 0.01;
+       private final static int blocksize = 1024;
+
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public int rank;
+       @Parameterized.Parameter(3)
+       public String regression;
+       @Parameterized.Parameter(4)
+       public double lambda;
+       @Parameterized.Parameter(5)
+       public int max_iter;
+       @Parameterized.Parameter(6)
+       public double threshold;
+       @Parameterized.Parameter(7)
+       public double sparsity;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{OUTPUT_NAME}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // rows must be even
+               return Arrays.asList(new Object[][] {
+                       // {rows, cols, rank, regression, lambda, max_iter, 
threshold, sparsity}
+                       {30, 15, 10, "L2", 0.0000001, 50, 0.000001, 1},
+                       {30, 15, 10, "wL2", 0.0000001, 50, 0.000001, 1}
+               });
+       }
+
+       @BeforeClass
+       public static void init() {
+               TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+       }
+
+       @Test
+       public void federatedAlsCGSingleNode() {
+               federatedAlsCG(TEST_NAME, ExecMode.SINGLE_NODE);
+       }
+
+//     @Test
+//     public void federatedAlsCGSpark() {
+//             federatedAlsCG(TEST_NAME, ExecMode.SPARK);
+//     }
+
+// 
-----------------------------------------------------------------------------
+
+       public void federatedAlsCG(String testname, ExecMode execMode)
+       {
+               // store the previous platform config to restore it after the 
test
+               ExecMode platform_old = setExecMode(execMode);
+
+               getAndLoadTestConfiguration(testname);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               int fed_rows = rows / 2;
+               int fed_cols = cols;
+
+               double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 
sparsity, 13);
+               double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 
sparsity, 2);
+
+               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(
+                       fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(
+                       fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               Thread thread1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+               Thread thread2 = startLocalFedWorkerThread(port2);
+
+               getAndLoadTestConfiguration(testname);
+
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + testname + "Reference.dml";
+               programArgs = new String[] {"-stats", "-nvargs",
+                       "in_X1=" + input("X1"), "in_X2=" + input("X2"), 
"in_rank=" + Integer.toString(rank),
+                       "in_reg=" + regression, "in_lambda=" + 
Double.toString(lambda),
+                       "in_maxi=" + Integer.toString(max_iter), "in_thr=" + 
Double.toString(threshold),
+                       "out_Z=" + expected(OUTPUT_NAME)};
+               runTest(true, false, null, -1);
+
+               // Run actual dml script with federated matrix
+               fullDMLScriptName = HOME + testname + ".dml";
+               programArgs = new String[] {"-stats", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                       "in_rank=" + Integer.toString(rank),
+                       "in_reg=" + regression,
+                       "in_lambda=" + Double.toString(lambda),
+                       "in_maxi=" + Integer.toString(max_iter),
+                       "in_thr=" + Double.toString(threshold),
+                       "rows=" + fed_rows, "cols=" + fed_cols,
+                       "out_Z=" + output(OUTPUT_NAME)};
+               runTest(true, false, null, -1);
+
+               // compare the results via files
+               HashMap<CellIndex, Double> refResults  = 
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+               HashMap<CellIndex, Double> fedResults = 
readDMLMatrixFromOutputDir(OUTPUT_NAME);
+               TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, 
"Fed", "Ref");
+
+               TestUtils.shutdownThreads(thread1, thread2);
+
+               // check for federated operations
+               Assert.assertTrue(heavyHittersContainsString("fed_!="));
+               Assert.assertTrue(heavyHittersContainsString("fed_fedinit"));
+               Assert.assertTrue(heavyHittersContainsString("fed_wdivmm"));
+               Assert.assertTrue(heavyHittersContainsString("fed_wsloss"));
+
+               // check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+
+               resetExecMode(platform_old);
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
new file mode 100644
index 0000000..53dfb2e
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLogicalTest.java
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.     See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.      The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.   You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.    See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedLogicalTest extends AutomatedTestBase
+{
+       private final static String SCALAR_TEST_NAME = 
"FederatedLogicalMatrixScalarTest";
+       private final static String MATRIX_TEST_NAME = 
"FederatedLogicalMatrixMatrixTest";
+       private final static String TEST_DIR = "functions/federated/binary/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedLogicalTest.class.getSimpleName() + "/";
+
+       private final static String OUTPUT_NAME = "Z";
+       private final static double TOLERANCE = 0;
+       private final static int blocksize = 1024;
+
+       public enum Type{
+               GREATER,
+               LESS,
+               EQUALS,
+               NOT_EQUALS,
+               GREATER_EQUALS,
+               LESS_EQUALS
+       }
+
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public double sparsity;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(SCALAR_TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, SCALAR_TEST_NAME, new String[]{OUTPUT_NAME}));
+               addTestConfiguration(MATRIX_TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, MATRIX_TEST_NAME, new String[]{OUTPUT_NAME}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // rows must be even
+               return Arrays.asList(new Object[][] {
+                       // {rows, cols, sparsity}
+                       {100, 75, 0.01},
+                       {100, 75, 0.9},
+                       {2, 75, 0.01},
+                       {2, 75, 0.9}
+               });
+       }
+
+       @BeforeClass
+       public static void init() {
+               TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+       }
+
+       //---------------------------MATRIX SCALAR--------------------------
+       @Test
+       public void federatedLogicalScalarGreaterSingleNode() {
+               federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER, 
ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedLogicalScalarGreaterSpark() {
+               federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER, 
ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedLogicalScalarLessSingleNode() {
+               federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS, 
ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedLogicalScalarLessSpark() {
+               federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS, 
ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedLogicalScalarEqualsSingleNode() {
+               federatedLogicalTest(SCALAR_TEST_NAME, Type.EQUALS, 
ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedLogicalScalarEqualsSpark() {
+               federatedLogicalTest(SCALAR_TEST_NAME, Type.EQUALS, 
ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedLogicalScalarNotEqualsSingleNode() {
+               federatedLogicalTest(SCALAR_TEST_NAME, Type.NOT_EQUALS, 
ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedLogicalScalarNotEqualsSpark() {
+               federatedLogicalTest(SCALAR_TEST_NAME, Type.NOT_EQUALS, 
ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedLogicalScalarGreaterEqualsSingleNode() {
+               federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER_EQUALS, 
ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedLogicalScalarGreaterEqualsSpark() {
+               federatedLogicalTest(SCALAR_TEST_NAME, Type.GREATER_EQUALS, 
ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedLogicalScalarLessEqualsSingleNode() {
+               federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS, 
ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedLogicalScalarLessEqualsSpark() {
+               federatedLogicalTest(SCALAR_TEST_NAME, Type.LESS_EQUALS, 
ExecMode.SPARK);
+       }
+
+       //---------------------------MATRIX MATRIX--------------------------
+       @Test
+       public void federatedLogicalMatrixGreaterSingleNode() {
+               federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER, 
ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedLogicalMatrixGreaterSpark() {
+               federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER, 
ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedLogicalMatrixLessSingleNode() {
+               federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS, 
ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedLogicalMatrixLessSpark() {
+               federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS, 
ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedLogicalMatrixEqualsSingleNode() {
+               federatedLogicalTest(MATRIX_TEST_NAME, Type.EQUALS, 
ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedLogicalMatrixEqualsSpark() {
+               federatedLogicalTest(MATRIX_TEST_NAME, Type.EQUALS, 
ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedLogicalMatrixNotEqualsSingleNode() {
+               federatedLogicalTest(MATRIX_TEST_NAME, Type.NOT_EQUALS, 
ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedLogicalMatrixNotEqualsSpark() {
+               federatedLogicalTest(MATRIX_TEST_NAME, Type.NOT_EQUALS, 
ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedLogicalMatrixGreaterEqualsSingleNode() {
+               federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER_EQUALS, 
ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedLogicalMatrixGreaterEqualsSpark() {
+               federatedLogicalTest(MATRIX_TEST_NAME, Type.GREATER_EQUALS, 
ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedLogicalMatrixLessEqualsSingleNode() {
+               federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS, 
ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedLogicalMatrixLessEqualsSpark() {
+               federatedLogicalTest(MATRIX_TEST_NAME, Type.LESS_EQUALS, 
ExecMode.SPARK);
+       }
+
+// 
-----------------------------------------------------------------------------
+
+       public void federatedLogicalTest(String testname, Type op_type, 
ExecMode execMode)
+       {
+               // store the previous platform config to restore it after the 
test
+               ExecMode platform_old = setExecMode(execMode);
+
+               getAndLoadTestConfiguration(testname);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               int fed_rows = rows / 2;
+               int fed_cols = cols;
+
+               // generate dataset
+               // matrix handled by two federated workers
+               double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 1, 
13);
+               double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 1, 2);
+
+               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+
+               boolean is_matrix_test = testname.equals(MATRIX_TEST_NAME);
+
+               double[][] Y_mat = null;
+               double Y_scal = 0;
+               if(is_matrix_test) {
+                       Y_mat = getRandomMatrix(rows, cols, 0, 1, sparsity, 
5040);
+                       writeInputMatrixWithMTD("Y", Y_mat, true);
+               }
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               Thread thread1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+               Thread thread2 = startLocalFedWorkerThread(port2);
+
+               getAndLoadTestConfiguration(testname);
+
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + testname + "Reference.dml";
+               programArgs = new String[] {"-nvargs", "in_X1=" + input("X1"), 
"in_X2=" + input("X2"),
+                       "in_Y=" + (is_matrix_test ? input("Y") : 
Double.toString(Y_scal)),
+                       "in_op_type=" + Integer.toString(op_type.ordinal()),
+                       "out_Z=" + expected(OUTPUT_NAME)};
+               runTest(true, false, null, -1);
+
+               // Run actual dml script with federated matrix
+               fullDMLScriptName = HOME + testname + ".dml";
+               programArgs = new String[] {"-stats", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")), "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                       "in_Y=" + (is_matrix_test ? input("Y") : 
Double.toString(Y_scal)),
+                       "in_op_type=" + Integer.toString(op_type.ordinal()),
+                       "rows=" + fed_rows, "cols=" + fed_cols, "out_Z=" + 
output(OUTPUT_NAME)};
+               runTest(true, false, null, -1);
+
+               // compare the results via files
+               HashMap<CellIndex, Double> refResults  = 
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+               HashMap<CellIndex, Double> fedResults = 
readDMLMatrixFromOutputDir(OUTPUT_NAME);
+               TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, 
"Fed", "Ref");
+
+               TestUtils.shutdownThreads(thread1, thread2);
+
+               // check for federated operations
+               switch(op_type)
+               {
+                       case GREATER:
+                               
Assert.assertTrue(heavyHittersContainsString("fed_>"));
+                               break;
+                       case LESS:
+                               
Assert.assertTrue(heavyHittersContainsString("fed_<"));
+                               break;
+                       case EQUALS:
+                               
Assert.assertTrue(heavyHittersContainsString("fed_=="));
+                               break;
+                       case NOT_EQUALS:
+                               
Assert.assertTrue(heavyHittersContainsString("fed_!="));
+                               break;
+                       case GREATER_EQUALS:
+                               
Assert.assertTrue(heavyHittersContainsString("fed_>="));
+                               break;
+                       case LESS_EQUALS:
+                               
Assert.assertTrue(heavyHittersContainsString("fed_<="));
+                               break;
+               }
+
+               // check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+
+               resetExecMode(platform_old);
+       }
+}
diff --git a/src/test/scripts/functions/federated/FederatedAlsCGTest.dml b/src/test/scripts/functions/federated/FederatedAlsCGTest.dml
new file mode 100644
index 0000000..05258f4
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedAlsCGTest.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),  # federated input X built from two worker-held row partitions
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, 
$cols)));  # begin/end corners: X1 covers the first $rows rows, X2 the next $rows rows, both span all $cols columns
+
+rank = $in_rank;  # alsCG hyper-parameters, passed through unchanged from the script arguments
+reg = $in_reg;
+lambda = $in_lambda;
+maxi = $in_maxi;
+thr = $in_thr;
+
+[U, V] = alsCG(X = X, rank = rank, reg = reg, lambda = lambda, maxi = maxi, 
check = TRUE, thr = thr);  # factorize the federated X; NOTE(review): check = TRUE presumably enables the thr-based convergence check -- confirm in the alsCG builtin
+
+Z = U %*% V;  # product of the two factors returned by alsCG
+
+write(Z, $out_Z);  # NOTE(review): presumably compared by the Java test against the output of FederatedAlsCGTestReference.dml -- confirm
diff --git 
a/src/test/scripts/functions/federated/FederatedAlsCGTestReference.dml 
b/src/test/scripts/functions/federated/FederatedAlsCGTestReference.dml
new file mode 100644
index 0000000..a73efba
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedAlsCGTestReference.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2));  # local (non-federated) X: X1 rows stacked on top of X2 rows
+
+rank = $in_rank;  # read from the same $in_* arguments as the federated script
+reg = $in_reg;
+lambda = $in_lambda;
+maxi = $in_maxi;
+thr = $in_thr;
+
+[U, V] = alsCG(X = X, rank = rank, reg = reg, lambda = lambda, maxi = maxi, 
check = TRUE, thr = thr);  # reference factorization computed on the locally assembled X
+
+Z = U %*% V;  # product of the two factors returned by alsCG
+
+write(Z, $out_Z);  # reference result for comparison with the federated run
diff --git 
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
 
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
new file mode 100644
index 0000000..7fe350c
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTest.dml
@@ -0,0 +1,41 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),  # federated left operand built from two worker-held row partitions
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, 
$cols)));  # begin/end corners: X1 covers the first $rows rows, X2 the next $rows rows, both span all $cols columns
+
+Y = read($in_Y);  # local (non-federated) matrix right operand
+op_type = $in_op_type;  # comparison selector: 0 >, 1 <, 2 ==, 3 !=, 4 >=, 5 <=
+
+if(op_type == 0)
+   Z = (X > Y)
+else if(op_type == 1)
+   Z = (X < Y)
+else if(op_type == 2)
+   Z = (X == Y)
+else if(op_type == 3)
+   Z = (X != Y)
+else if(op_type == 4)
+   Z = (X >= Y)
+else if(op_type == 5)
+   Z = (X <= Y)  # NOTE(review): no final else -- an op_type outside 0..5 leaves Z undefined at write()
+
+write(Z, $out_Z);  # write the comparison result of the federated X against Y
diff --git 
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
 
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
new file mode 100644
index 0000000..1285eb0
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixMatrixTestReference.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2));  # local (non-federated) left operand: X1 rows stacked on top of X2 rows
+
+Y = read($in_Y);  # matrix right operand
+op_type = $in_op_type;  # comparison selector: 0 >, 1 <, 2 ==, 3 !=, 4 >=, 5 <=
+
+if(op_type == 0)
+   Z = (X > Y)
+else if(op_type == 1)
+   Z = (X < Y)
+else if(op_type == 2)
+   Z = (X == Y)
+else if(op_type == 3)
+   Z = (X != Y)
+else if(op_type == 4)
+   Z = (X >= Y)
+else if(op_type == 5)
+   Z = (X <= Y)  # NOTE(review): no final else -- an op_type outside 0..5 leaves Z undefined at write()
+
+write(Z, $out_Z);  # reference comparison result for the federated test
diff --git 
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml
 
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml
new file mode 100644
index 0000000..1dfb762
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTest.dml
@@ -0,0 +1,41 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),  # federated left operand built from two worker-held row partitions
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, 
$cols)));  # begin/end corners: X1 covers the first $rows rows, X2 the next $rows rows, both span all $cols columns
+
+y = $in_Y;  # scalar right operand taken directly from the script argument
+op_type = $in_op_type;  # comparison selector: 0 >, 1 <, 2 ==, 3 !=, 4 >=, 5 <=
+
+if(op_type == 0)
+   Z = (X > y)
+else if(op_type == 1)
+   Z = (X < y)
+else if(op_type == 2)
+   Z = (X == y)
+else if(op_type == 3)
+   Z = (X != y)
+else if(op_type == 4)
+   Z = (X >= y)
+else if(op_type == 5)
+   Z = (X <= y)  # NOTE(review): no final else -- an op_type outside 0..5 leaves Z undefined at write()
+
+write(Z, $out_Z);  # write the comparison result of the federated X against scalar y
diff --git 
a/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml
 
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml
new file mode 100644
index 0000000..4682aea
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/binary/FederatedLogicalMatrixScalarTestReference.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2));  # local (non-federated) left operand: X1 rows stacked on top of X2 rows
+
+y = $in_Y;  # scalar right operand taken directly from the script argument
+op_type = $in_op_type;  # comparison selector: 0 >, 1 <, 2 ==, 3 !=, 4 >=, 5 <=
+
+if(op_type == 0)
+   Z = (X > y)
+else if(op_type == 1)
+   Z = (X < y)
+else if(op_type == 2)
+   Z = (X == y)
+else if(op_type == 3)
+   Z = (X != y)
+else if(op_type == 4)
+   Z = (X >= y)
+else if(op_type == 5)
+   Z = (X <= y)  # NOTE(review): no final else -- an op_type outside 0..5 leaves Z undefined at write()
+
+write(Z, $out_Z);  # reference comparison result for the federated test

Reply via email to