This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 0f698fe  [SYSTEMDS-2627] Federated mmchain instruction for lmCG, 
MLogreg, GLM
0f698fe is described below

commit 0f698fee39191de324b72cc3e22c609d340b124e
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Aug 18 22:55:29 2020 +0200

    [SYSTEMDS-2627] Federated mmchain instruction for lmCG, MLogreg, GLM
    
    This patch adds a federated mmchain instruction for a common
    matrix-vector multiplication chain as it appears in the inner loop of
    lmCG, MLogreg, and GLM. It also includes a fix for more robust
    instruction manipulation, and a GLM federated test.
    
    Furthermore, we now use a slightly better approach for deciding between
    conf-only and context-based Spark cluster analysis, which avoids
    unnecessary Spark context creation in local mode (context creation
    sometimes interferes with netty port allocation in federated tests).
---
 .../java/org/apache/sysds/lops/MapMultChain.java   |   5 +-
 .../context/SparkExecutionContext.java             |   6 +
 .../controlprogram/federated/FederationUtils.java  |   8 +-
 .../instructions/cp/MMChainCPInstruction.java      |  25 ++--
 .../runtime/instructions/fed/FEDInstruction.java   |   1 +
 .../instructions/fed/FEDInstructionUtils.java      |  31 +++--
 .../instructions/fed/MMChainFEDInstruction.java    | 112 +++++++++++++++++
 .../test/functions/federated/FederatedGLMTest.java | 135 +++++++++++++++++++++
 .../functions/federated/FederatedGLMTest.dml       |  27 +++++
 .../federated/FederatedGLMTestReference.dml        |  25 ++++
 10 files changed, 348 insertions(+), 27 deletions(-)

diff --git a/src/main/java/org/apache/sysds/lops/MapMultChain.java 
b/src/main/java/org/apache/sysds/lops/MapMultChain.java
index 79d57f7..b45d813 100644
--- a/src/main/java/org/apache/sysds/lops/MapMultChain.java
+++ b/src/main/java/org/apache/sysds/lops/MapMultChain.java
@@ -35,7 +35,10 @@ public class MapMultChain extends Lop
                XtXv,  //(t(X) %*% (X %*% v))
                XtwXv, //(t(X) %*% (w * (X %*% v)))
                XtXvy, //(t(X) %*% ((X %*% v) - y))
-               NONE,
+               NONE;
+               public boolean isWeighted() {
+                       return this == XtwXv || this == ChainType.XtXvy;
+               }
        }
        
        private ChainType _chainType = null;
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 65348f1..2be647d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -1771,6 +1771,12 @@ public class SparkExecutionContext extends 
ExecutionContext
                                _defaultPar = (defaultPar>1) ? defaultPar : 
numExecutors * numCoresPerExec;
                                _confOnly &= true;
                        }
+                       else if( DMLScript.USE_LOCAL_SPARK_CONFIG ) {
+                               //avoid unnecessary spark context creation in 
local mode (e.g., tests)
+                               _numExecutors = 1;
+                               _defaultPar = 2;
+                               _confOnly &= true;
+                       }
                        else {
                                //get default parallelism (total number of 
executors and cores)
                                //note: spark context provides this information 
while conf does not
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index c34fa62..429834b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -52,10 +52,14 @@ public class FederationUtils {
                //TODO better and safe replacement of operand names --> 
instruction utils
                long id = getNextFedDataID();
                String linst = inst.replace(ExecType.SPARK.name(), 
ExecType.CP.name());
-               linst = 
linst.replace(Lop.OPERAND_DELIMITOR+varOldOut.getName(), 
Lop.OPERAND_DELIMITOR+String.valueOf(id));
+               linst = linst.replace(
+                       
Lop.OPERAND_DELIMITOR+varOldOut.getName()+Lop.DATATYPE_PREFIX,
+                       
Lop.OPERAND_DELIMITOR+String.valueOf(id)+Lop.DATATYPE_PREFIX);
                for(int i=0; i<varOldIn.length; i++)
                        if( varOldIn[i] != null ) {
-                               linst = 
linst.replace(Lop.OPERAND_DELIMITOR+varOldIn[i].getName(), 
Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i]));
+                               linst = linst.replace(
+                                       
Lop.OPERAND_DELIMITOR+varOldIn[i].getName()+Lop.DATATYPE_PREFIX,
+                                       
Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i])+Lop.DATATYPE_PREFIX);
                                linst = 
linst.replace("="+varOldIn[i].getName(), "="+String.valueOf(varNewIn[i])); 
//parameterized
                        }
                return new FederatedRequest(RequestType.EXEC_INST, id, linst);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MMChainCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MMChainCPInstruction.java
index f540343..dcff65b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MMChainCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MMChainCPInstruction.java
@@ -36,31 +36,30 @@ public class MMChainCPInstruction extends 
UnaryCPInstruction {
                _type = type;
                _numThreads = k;
        }
+       
+       public ChainType getMMChainType() {
+               return _type;
+       }
 
        public static MMChainCPInstruction parseInstruction ( String str ) {
                //parse instruction parts (without exec type)
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType( str );
                InstructionUtils.checkNumFields( parts, 5, 6 );
-       
                String opcode = parts[0];
                CPOperand in1 = new CPOperand(parts[1]);
                CPOperand in2 = new CPOperand(parts[2]);
                
-               if( parts.length==6 )
-               {
+               if( parts.length==6 ) {
                        CPOperand out= new CPOperand(parts[3]);
                        ChainType type = ChainType.valueOf(parts[4]);
                        int k = Integer.parseInt(parts[5]);
-                       
                        return new MMChainCPInstruction(null, in1, in2, null, 
out, type, k, opcode, str);
                }
-               else //parts.length==7
-               {
+               else { //parts.length==7
                        CPOperand in3 = new CPOperand(parts[3]);
                        CPOperand out = new CPOperand(parts[4]);
                        ChainType type = ChainType.valueOf(parts[5]);
                        int k = Integer.parseInt(parts[6]);
-                       
                        return new MMChainCPInstruction(null, in1, in2, in3, 
out, type, k, opcode, str);
                }
        }
@@ -70,19 +69,15 @@ public class MMChainCPInstruction extends 
UnaryCPInstruction {
                //get inputs
                MatrixBlock X = ec.getMatrixInput(input1.getName());
                MatrixBlock v = ec.getMatrixInput(input2.getName());
-               MatrixBlock w = (_type==ChainType.XtwXv || 
_type==ChainType.XtXvy) ? 
-                       ec.getMatrixInput(input3.getName()) : null;
+               MatrixBlock w = _type.isWeighted() ? 
ec.getMatrixInput(input3.getName()) : null;
+               
                //execute mmchain operation 
-                MatrixBlock out = X.chainMatrixMultOperations(v, w, new 
MatrixBlock(), _type, _numThreads);
+               MatrixBlock out = X.chainMatrixMultOperations(v, w, new 
MatrixBlock(), _type, _numThreads);
+               
                //set output and release inputs
                ec.setMatrixOutput(output.getName(), out);
                ec.releaseMatrixInput(input1.getName(), input2.getName());
                if( w !=null )
                        ec.releaseMatrixInput(input3.getName());
        }
-       
-       public ChainType getMMChainType()
-       {
-               return _type;
-       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 6df1b1e..77dedfd 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -35,6 +35,7 @@ public abstract class FEDInstruction extends Instruction {
                MultiReturnParameterizedBuiltin,
                ParameterizedBuiltin,
                Tsmm,
+               MMChain,
        }
        
        protected final FEDType _fedType;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 4325456..bbdaa8e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -45,6 +45,18 @@ public class FEDInstructionUtils {
                                }
                        }
                }
+               else if( inst instanceof MMChainCPInstruction) {
+                       MMChainCPInstruction linst = (MMChainCPInstruction) 
inst;
+                       MatrixObject mo = ec.getMatrixObject(linst.input1);
+                       if( mo.isFederated() )
+                               fedinst = 
MMChainFEDInstruction.parseInstruction(linst.getInstructionString());
+               }
+               else if( inst instanceof MMTSJCPInstruction ) {
+                       MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
+                       MatrixObject mo = ec.getMatrixObject(linst.input1);
+                       if( mo.isFederated() )
+                               fedinst = 
TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
+               }
                else if (inst instanceof AggregateUnaryCPInstruction) {
                        AggregateUnaryCPInstruction instruction = 
(AggregateUnaryCPInstruction) inst;
                        if( instruction.input1.isMatrix() && 
ec.containsVariable(instruction.input1) ) {
@@ -77,12 +89,6 @@ public class FEDInstructionUtils {
                                }
                        }
                }
-               else if( inst instanceof MMTSJCPInstruction ) {
-                       MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
-                       MatrixObject mo = ec.getMatrixObject(linst.input1);
-                       if( mo.isFederated() )
-                               fedinst = 
TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
-               }
                
                //set thread id for federated context management
                if( fedinst != null ) {
@@ -94,13 +100,14 @@ public class FEDInstructionUtils {
        }
        
        public static Instruction checkAndReplaceSP(Instruction inst, 
ExecutionContext ec) {
+               FEDInstruction fedinst = null;
                if (inst instanceof MapmmSPInstruction) {
                        // FIXME does not yet work for MV multiplication. SPARK 
execution mode not supported for federated l2svm
                        MapmmSPInstruction instruction = (MapmmSPInstruction) 
inst;
                        Data data = ec.getVariable(instruction.input1);
                        if (data instanceof MatrixObject && ((MatrixObject) 
data).isFederated()) {
                                // TODO correct FED instruction string
-                               return new 
AggregateBinaryFEDInstruction(instruction.getOperator(),
+                               fedinst = new 
AggregateBinaryFEDInstruction(instruction.getOperator(),
                                        instruction.input1, instruction.input2, 
instruction.output, "ba+*", "FED...");
                        }
                }
@@ -108,7 +115,7 @@ public class FEDInstructionUtils {
                        AggregateUnarySPInstruction instruction = 
(AggregateUnarySPInstruction) inst;
                        Data data = ec.getVariable(instruction.input1);
                        if (data instanceof MatrixObject && ((MatrixObject) 
data).isFederated())
-                               return 
AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
+                               fedinst = 
AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
                }
                else if (inst instanceof WriteSPInstruction) {
                        WriteSPInstruction instruction = (WriteSPInstruction) 
inst;
@@ -124,9 +131,15 @@ public class FEDInstructionUtils {
                        AppendGAlignedSPInstruction instruction = 
(AppendGAlignedSPInstruction) inst;
                        Data data = ec.getVariable(instruction.input1);
                        if (data instanceof MatrixObject && ((MatrixObject) 
data).isFederated()) {
-                               return 
AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
+                               fedinst = 
AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
                        }
                }
+               //set thread id for federated context management
+               if( fedinst != null ) {
+                       fedinst.setTID(ec.getTID());
+                       return fedinst;
+               }
+               
                return inst;
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
new file mode 100644
index 0000000..2dee64b
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.lops.MapMultChain.ChainType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import java.util.concurrent.Future;
+
+public class MMChainFEDInstruction extends UnaryFEDInstruction {
+       
+       public MMChainFEDInstruction(CPOperand in1, CPOperand in2, CPOperand 
in3, 
+               CPOperand out, ChainType type, int k, String opcode, String 
istr) {
+               super(FEDType.MMChain, null, in1, in2, in3, out, opcode, istr);
+               _type = type;
+       }
+       
+       private final ChainType _type;
+
+       public ChainType getMMChainType() {
+               return _type;
+       }
+
+       public static MMChainFEDInstruction parseInstruction ( String str ) {
+               //parse instruction parts (without exec type)
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType( str );
+               InstructionUtils.checkNumFields( parts, 5, 6 );
+               String opcode = parts[0];
+               CPOperand in1 = new CPOperand(parts[1]);
+               CPOperand in2 = new CPOperand(parts[2]);
+               
+               if( parts.length==6 ) {
+                       CPOperand out= new CPOperand(parts[3]);
+                       ChainType type = ChainType.valueOf(parts[4]);
+                       int k = Integer.parseInt(parts[5]);
+                       return new MMChainFEDInstruction(in1, in2, null, out, 
type, k, opcode, str);
+               }
+               else { //parts.length==7
+                       CPOperand in3 = new CPOperand(parts[3]);
+                       CPOperand out = new CPOperand(parts[4]);
+                       ChainType type = ChainType.valueOf(parts[5]);
+                       int k = Integer.parseInt(parts[6]);
+                       return new MMChainFEDInstruction(in1, in2, in3, out, 
type, k, opcode, str);
+               }
+       }
+       
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               MatrixObject mo1 = ec.getMatrixObject(input1);
+               MatrixObject mo2 = ec.getMatrixObject(input2);
+               MatrixObject mo3 = _type.isWeighted() ? 
ec.getMatrixObject(input3) : null;
+               
+               if( !mo1.isFederated() )
+                       throw new DMLRuntimeException("Federated MMChain: 
Federated main input expected, "
+                               + "but invoked w/ "+mo1.isFederated()+" 
"+mo2.isFederated());
+       
+               if( !_type.isWeighted() ) { //XtXv
+                       //construct commands: broadcast vector, execute, get 
and aggregate, cleanup
+                       FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
+                       FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
+                               new CPOperand[]{input1, input2}, new 
long[]{mo1.getFedMapping().getID(), fr1.getID()});
+                       FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+                       
+                       //execute federated operations and aggregate
+                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+                       MatrixBlock ret = FederationUtils.aggAdd(tmp);
+                       mo1.getFedMapping().cleanup(getTID(), fr1.getID(), 
fr2.getID());
+                       ec.setMatrixOutput(output.getName(), ret);
+               }
+               else { //XtwXv | XtXvy
+                       //construct commands: broadcast 2 vectors, execute, get 
and aggregate, cleanup
+                       FederatedRequest[] fr0 = 
mo1.getFedMapping().broadcastSliced(mo3, false);
+                       FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
+                       FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
+                               new CPOperand[]{input1, input2, input3},
+                               new long[]{mo1.getFedMapping().getID(), 
fr1.getID(), fr0[0].getID()});
+                       FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+                       
+                       //execute federated operations and aggregate
+                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr0, fr1, fr2, fr3);
+                       MatrixBlock ret = FederationUtils.aggAdd(tmp);
+                       mo1.getFedMapping().cleanup(getTID(), fr0[0].getID(), 
fr1.getID(), fr2.getID());
+                       ec.setMatrixOutput(output.getName(), ret);
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedGLMTest.java 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedGLMTest.java
new file mode 100644
index 0000000..fe24bc8
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedGLMTest.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedGLMTest extends AutomatedTestBase {
+
+       private final static String TEST_DIR = "functions/federated/";
+       private final static String TEST_NAME = "FederatedGLMTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedGLMTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // rows have to be even and > 1
+               return Arrays.asList(new Object[][] {{10000, 10}, {1000, 100}, 
{2000, 43}});
+       }
+
+       @Test
+       public void federatedSinglenodeGLM() {
+               federatedGLM(Types.ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedHybridGLM() {
+               federatedGLM(Types.ExecMode.HYBRID);
+       }
+
+       
+       public void federatedGLM(Types.ExecMode execMode) {
+               ExecMode platformOld = setExecMode(execMode);
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               int halfRows = rows / 2;
+               // We have two matrices handled by a single federated worker
+               double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+               double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+               double[][] Y = getRandomMatrix(rows, 1, -1, 1, 1, 1233);
+               for(int i = 0; i < rows; i++)
+                       Y[i][0] = (Y[i][0] > 0) ? 1 : -1;
+
+               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+               writeInputMatrixWithMTD("Y", Y, false, new 
MatrixCharacteristics(rows, 1, blocksize, rows));
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorker(port1);
+               Thread t2 = startLocalFedWorker(port2);
+
+               TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+               setOutputBuffering(false);
+               
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-args", input("X1"), input("X2"), 
input("Y"), expected("Z")};
+               runTest(true, false, null, -1);
+
+               // Run actual dml script with federated matrix
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats",
+                       "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")), "rows=" + rows, "cols=" + cols,
+                       "in_Y=" + input("Y"), "out=" + output("Z")};
+               runTest(true, false, null, -1);
+
+               // compare via files
+               compareResults(1e-9);
+
+               TestUtils.shutdownThreads(t1, t2);
+
+               // check for federated operations
+               Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
+               
Assert.assertTrue(heavyHittersContainsString("fed_uark+","fed_uarsqk+"));
+               Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
+               Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+               Assert.assertTrue(heavyHittersContainsString("fed_mmchain"));
+               
+               //check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               
+               resetExecMode(platformOld);
+       }
+}
diff --git a/src/test/scripts/functions/federated/FederatedGLMTest.dml 
b/src/test/scripts/functions/federated/FederatedGLMTest.dml
new file mode 100644
index 0000000..aa23b5e
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedGLMTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), 
list($rows, $cols)))
+Y = read($in_Y)
+
+model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
+write(model, $out)
diff --git a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml 
b/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
new file mode 100644
index 0000000..a307c8c
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2))
+Y = read($3)
+model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
+write(model, $4)

Reply via email to