This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new a0f1e81  [SYSTEMDS-2988] Additional federated GLM algorithm tests (col parts)
a0f1e81 is described below

commit a0f1e8192ba9671e9a5bb431b3d69d7755008588
Author: Olga <[email protected]>
AuthorDate: Sat Jun 5 16:08:20 2021 +0200

    [SYSTEMDS-2988] Additional federated GLM algorithm tests (col parts)
    
    Additional tests (GLM column-partitioned, and mmchain) as well as
    related improvements of the federated backend.
    
    Closes #1289.
---
 .../runtime/instructions/InstructionUtils.java     |   2 +-
 .../instructions/fed/FEDInstructionUtils.java      |   6 +-
 .../instructions/fed/MMChainFEDInstruction.java    |   4 +-
 .../fed/ParameterizedBuiltinFEDInstruction.java    |   2 +-
 .../instructions/fed/TsmmFEDInstruction.java       |   6 +-
 .../federated/algorithms/FederatedGLMTest.java     |  41 ++---
 .../federated/primitives/FederatedMMChainTest.java | 165 +++++++++++++++++++++
 .../functions/federated/FederatedGLMTest.dml       |  10 +-
 .../federated/FederatedGLMTestReference.dml        |   8 +-
 ...deratedGLMTest.dml => FederatedMMChainTest.dml} |  20 ++-
 ...rence.dml => FederatedMMChainTestReference.dml} |  12 +-
 ...LMTest.dml => FederatedMMChainWeights2Test.dml} |  19 ++-
 ...l => FederatedMMChainWeights2TestReference.dml} |  12 +-
 ...GLMTest.dml => FederatedMMChainWeightsTest.dml} |  19 ++-
 ...ml => FederatedMMChainWeightsTestReference.dml} |  12 +-
 15 files changed, 279 insertions(+), 59 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 5269e79..2d8fdbd 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -1089,7 +1089,7 @@ public class InstructionUtils
                return InstructionUtils.concatOperands(parts[0], parts[1], 
createOperand(op1), createOperand(op2), createOperand(out));
        }
 
-       public static String constructUnaryInstString(String instString, 
CPOperand op1, String opcode, CPOperand out) {
+       public static String constructUnaryInstString(String instString, String 
opcode, CPOperand op1, CPOperand out) {
                String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
                parts[1] = opcode;
                return InstructionUtils.concatOperands(parts[0], parts[1], 
createOperand(op1), createOperand(out));
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 4ab080d..df39028 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.runtime.instructions.fed;
 
 import org.apache.commons.lang3.ArrayUtils;
+
 import org.apache.sysds.runtime.codegen.SpoofCellwise;
 import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
 import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
@@ -105,13 +106,14 @@ public class FEDInstructionUtils {
                else if( inst instanceof MMChainCPInstruction) {
                        MMChainCPInstruction linst = (MMChainCPInstruction) 
inst;
                        MatrixObject mo = ec.getMatrixObject(linst.input1);
-                       if( mo.isFederated() )
+                       if( mo.isFederated(FType.ROW) )
                                fedinst = 
MMChainFEDInstruction.parseInstruction(linst.getInstructionString());
                }
                else if( inst instanceof MMTSJCPInstruction ) {
                        MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
                        MatrixObject mo = ec.getMatrixObject(linst.input1);
-                       if( mo.isFederated() )
+                       if( (mo.isFederated(FType.ROW) && 
linst.getMMTSJType().isLeft()) ||
+                               (mo.isFederated(FType.COL) && 
linst.getMMTSJType().isRight()))
                                fedinst = 
TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
                }
                else if (inst instanceof UnaryCPInstruction && ! (inst 
instanceof IndexingCPInstruction)) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
index 99a305b..25df0d3 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
@@ -69,7 +69,7 @@ public class MMChainFEDInstruction extends 
UnaryFEDInstruction {
                        return new MMChainFEDInstruction(in1, in2, in3, out, 
type, k, opcode, str);
                }
        }
-       
+
        @Override
        public void processInstruction(ExecutionContext ec) {
                MatrixObject mo1 = ec.getMatrixObject(input1);
@@ -104,7 +104,7 @@ public class MMChainFEDInstruction extends 
UnaryFEDInstruction {
                        FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
                        FederatedRequest fr4 = mo1.getFedMapping()
                                .cleanup(getTID(), fr0[0].getID(), fr1.getID(), 
fr2.getID());
-                       
+
                        //execute federated operations and aggregate
                        Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr0, fr1, fr2, fr3, fr4);
                        MatrixBlock ret = FederationUtils.aggAdd(tmp);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index c64817f..08331a6 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -179,7 +179,7 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                long ncolId = FederationUtils.getNextFedDataID();
                CPOperand ncolOp = new CPOperand(String.valueOf(ncolId), 
ValueType.INT64, DataType.SCALAR);
 
-               String unaryString = 
InstructionUtils.constructUnaryInstString(instString, output, "ncol", ncolOp);
+               String unaryString = 
InstructionUtils.constructUnaryInstString(instString, "ncol", ncolOp, output);
                FederatedRequest fr2 = 
FederationUtils.callInstruction(unaryString, ncolOp,
                        new CPOperand[] {output}, new long[] 
{out.getFedMapping().getID()});
                FederatedRequest fr3 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index c0fc942..aef46ce 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -19,21 +19,21 @@
 
 package org.apache.sysds.runtime.instructions.fed;
 
+import java.util.concurrent.Future;
+
 import org.apache.sysds.lops.MMTSJ.MMTSJType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
-import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
-import java.util.concurrent.Future;
-
 public class TsmmFEDInstruction extends BinaryFEDInstruction {
        private final MMTSJType _type;
        @SuppressWarnings("unused")
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
index 37a7787..6d8e816 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
@@ -19,10 +19,9 @@
 
 package org.apache.sysds.test.functions.federated.algorithms;
 
-import org.junit.Assert;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
+import java.util.Arrays;
+import java.util.Collection;
+
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -30,9 +29,10 @@ import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
-
-import java.util.Arrays;
-import java.util.Collection;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
@@ -47,6 +47,8 @@ public class FederatedGLMTest extends AutomatedTestBase {
        public int rows;
        @Parameterized.Parameter(1)
        public int cols;
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
 
        @Override
        public void setUp() {
@@ -58,8 +60,9 @@ public class FederatedGLMTest extends AutomatedTestBase {
        public static Collection<Object[]> data() {
                // rows have to be even and > 1
                return Arrays.asList(new Object[][] {
-                       // {10000, 10}, {1000, 100},
-                       {2000, 43}});
+                       // {10000, 10, true}, {1000, 100, false},
+                       {2000, 44, true},
+                       {2000, 44, false}});
        }
 
        @Test
@@ -79,16 +82,18 @@ public class FederatedGLMTest extends AutomatedTestBase {
                String HOME = SCRIPT_DIR + TEST_DIR;
 
                // write input matrices
-               int halfRows = rows / 2;
+               int r = rowPartitioned ? rows / 2 : rows;
+               int c = rowPartitioned ? cols : cols / 2;
+
                // We have two matrices handled by a single federated worker
-               double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
-               double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+               double[][] X1 = getRandomMatrix(r, c, 0, 1, 1, 42);
+               double[][] X2 = getRandomMatrix(r, c, 0, 1, 1, 1340);
                double[][] Y = getRandomMatrix(rows, 1, -1, 1, 1, 1233);
                for(int i = 0; i < rows; i++)
                        Y[i][0] = (Y[i][0] > 0) ? 1 : -1;
 
-               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
-               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(r, c, blocksize, r * c));
+               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(r, c, blocksize, r * c));
                writeInputMatrixWithMTD("Y", Y, false, new 
MatrixCharacteristics(rows, 1, blocksize, rows));
 
                // empty script name because we don't execute any script, just 
start the worker
@@ -104,18 +109,18 @@ public class FederatedGLMTest extends AutomatedTestBase {
 
                // Run reference dml script with normal matrix
                fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-               programArgs = new String[] {"-args", input("X1"), input("X2"), 
input("Y"), expected("Z")};
+               programArgs = new String[] {"-args", input("X1"), input("X2"), 
input("Y"), Boolean.toString(rowPartitioned).toUpperCase(), expected("Z")};
                runTest(true, false, null, -1);
 
                // Run actual dml script with federated matrix
                fullDMLScriptName = HOME + TEST_NAME + ".dml";
                programArgs = new String[] {"-stats", "-nvargs", "in_X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
                        "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")), "rows=" + rows, "cols=" + cols,
-                       "in_Y=" + input("Y"), "out=" + output("Z")};
+                       "in_Y=" + input("Y"), "rP=" + 
Boolean.toString(rowPartitioned).toUpperCase(), "out=" + output("Z")};
                runTest(true, false, null, -1);
 
                // compare via files
-               compareResults(1e-9);
+               compareResults(1e-2);
 
                TestUtils.shutdownThreads(t1, t2);
 
@@ -124,7 +129,7 @@ public class FederatedGLMTest extends AutomatedTestBase {
                Assert.assertTrue(heavyHittersContainsString("fed_uark+", 
"fed_uarsqk+"));
                Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
                // Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
-               Assert.assertTrue(heavyHittersContainsString("fed_mmchain"));
+               Assert.assertTrue(!rowPartitioned || 
heavyHittersContainsString("fed_mmchain"));
 
                // check that federated input files are still existing
                Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMMChainTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMMChainTest.java
new file mode 100644
index 0000000..f223cf2
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMMChainTest.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedMMChainTest extends AutomatedTestBase {
+
+       private final static String TEST_NAME1 = "FederatedMMChainTest";
+       private final static String TEST_NAME2 = "FederatedMMChainWeightsTest";
+       private final static String TEST_NAME3 = "FederatedMMChainWeights2Test";
+
+       private final static String TEST_DIR = "functions/federated/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedMMChainTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][] {
+                       {1000, 100, true},
+                       {100, 1000, false}
+               });
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"}));
+               addTestConfiguration(TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"}));
+       }
+
+       @Test
+       public void testMMChainCP() { runMMChainTest(ExecMode.SINGLE_NODE, 
TEST_NAME1); }
+       @Test
+       public void testMMChainWeightsCP() { 
runMMChainTest(ExecMode.SINGLE_NODE, TEST_NAME2); }
+       @Test
+       public void testMMChainWeights2CP() { 
runMMChainTest(ExecMode.SINGLE_NODE, TEST_NAME3); }
+
+       private void runMMChainTest(ExecMode execMode, String TEST_NAME) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               ExecMode platformOld = rtplatform;
+
+               if(rtplatform == ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               int r = rows;
+               int c = cols / 4;
+               if(rowPartitioned) {
+                       r = rows / 4;
+                       c = cols;
+               }
+
+               double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, c, 
blocksize, r * c);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               double[][] v = getRandomMatrix(cols, 1, 0, 1, 0.7, 3);
+               writeInputMatrixWithMTD("v", v, true);
+               if(!TEST_NAME.equals(TEST_NAME1)){
+                       double[][] w = getRandomMatrix(rows, 1, 0, 1, 0.7, 10);
+                       writeInputMatrixWithMTD("w", w, true);
+               }
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+               Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+               Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+               Thread t4 = startLocalFedWorkerThread(port4);
+
+               rtplatform = execMode;
+               if(rtplatform == ExecMode.SPARK) {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+               TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-stats", "100", "-args", 
input("X1"), input("X2"), input("X3"), input("X4"),
+                       Boolean.toString(rowPartitioned).toUpperCase(), 
input("v"), input("w"), expected("S")};
+               runTest(null);
+
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats", "100", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                       "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
+                       "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")), "rows=" + rows, "cols=" + cols,
+                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
+                       "in_v=" + input("v"),
+                       "in_w=" + input("w"),
+                       "out_S=" + output("S")};
+               runTest(null);
+
+               // compare via files
+               compareResults(1e-9);
+
+               // check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+               TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+               rtplatform = platformOld;
+               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+       }
+}
diff --git a/src/test/scripts/functions/federated/FederatedGLMTest.dml 
b/src/test/scripts/functions/federated/FederatedGLMTest.dml
index aa23b5e..6c349bd 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTest.dml
+++ b/src/test/scripts/functions/federated/FederatedGLMTest.dml
@@ -18,9 +18,13 @@
 # under the License.
 #
 #-------------------------------------------------------------
-
-X = federated(addresses=list($in_X1, $in_X2),
-    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), 
list($rows, $cols)))
+if ($rP) {
+  X = federated(addresses=list($in_X1, $in_X2),
+      ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), 
list($rows, $cols)))
+} else {
+  X = federated(addresses=list($in_X1, $in_X2),
+        ranges=list(list(0, 0), list($rows, $cols / 2), list(0, $cols / 2), 
list($rows, $cols)))
+}
 Y = read($in_Y)
 
 model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
diff --git a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml 
b/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
index a307c8c..fe815a4 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
@@ -19,7 +19,11 @@
 #
 #-------------------------------------------------------------
 
-X = rbind(read($1), read($2))
+if ($4) {
+  X = rbind(read($1), read($2))
+} else {
+  X = cbind(read($1), read($2))
+}
 Y = read($3)
 model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $4)
+write(model, $5)
diff --git a/src/test/scripts/functions/federated/FederatedGLMTest.dml 
b/src/test/scripts/functions/federated/FederatedMMChainTest.dml
similarity index 59%
copy from src/test/scripts/functions/federated/FederatedGLMTest.dml
copy to src/test/scripts/functions/federated/FederatedMMChainTest.dml
index aa23b5e..2f205b6 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTest.dml
+++ b/src/test/scripts/functions/federated/FederatedMMChainTest.dml
@@ -19,9 +19,19 @@
 #
 #-------------------------------------------------------------
 
-X = federated(addresses=list($in_X1, $in_X2),
-    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), 
list($rows, $cols)))
-Y = read($in_Y)
+if ($rP) {
+    X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+  } else {
+    X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
+    }
 
-model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $out)
+v = read($in_v);
+
+S = (t(X) %*% (X %*% v));
+print(nrow(S))
+print(ncol(S))
+write(S, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml 
b/src/test/scripts/functions/federated/FederatedMMChainTestReference.dml
similarity index 83%
copy from src/test/scripts/functions/federated/FederatedGLMTestReference.dml
copy to src/test/scripts/functions/federated/FederatedMMChainTestReference.dml
index a307c8c..d423c7f 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedMMChainTestReference.dml
@@ -19,7 +19,11 @@
 #
 #-------------------------------------------------------------
 
-X = rbind(read($1), read($2))
-Y = read($3)
-model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $4)
+if($5) { X = rbind(read($1), read($2), read($3), read($4)); }
+else { X = cbind(read($1), read($2), read($3), read($4));}
+
+v = read($6);
+
+S = (t(X) %*% (X %*% v));
+
+write(S, $8);
diff --git a/src/test/scripts/functions/federated/FederatedGLMTest.dml 
b/src/test/scripts/functions/federated/FederatedMMChainWeights2Test.dml
similarity index 59%
copy from src/test/scripts/functions/federated/FederatedGLMTest.dml
copy to src/test/scripts/functions/federated/FederatedMMChainWeights2Test.dml
index aa23b5e..dc4e3ab 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTest.dml
+++ b/src/test/scripts/functions/federated/FederatedMMChainWeights2Test.dml
@@ -19,9 +19,18 @@
 #
 #-------------------------------------------------------------
 
-X = federated(addresses=list($in_X1, $in_X2),
-    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), 
list($rows, $cols)))
-Y = read($in_Y)
+if ($rP) {
+    X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+  } else {
+    X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
+    }
 
-model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $out)
+v = read($in_v);
+w = read($in_w);
+S = t(X) %*% ((X %*% v)-w);
+
+write(S, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml 
b/src/test/scripts/functions/federated/FederatedMMChainWeights2TestReference.dml
similarity index 82%
copy from src/test/scripts/functions/federated/FederatedGLMTestReference.dml
copy to 
src/test/scripts/functions/federated/FederatedMMChainWeights2TestReference.dml
index a307c8c..cd9a939 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
+++ 
b/src/test/scripts/functions/federated/FederatedMMChainWeights2TestReference.dml
@@ -19,7 +19,11 @@
 #
 #-------------------------------------------------------------
 
-X = rbind(read($1), read($2))
-Y = read($3)
-model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $4)
+if($5) { X = rbind(read($1), read($2), read($3), read($4)); }
+else { X = cbind(read($1), read($2), read($3), read($4));}
+
+v = read($6);
+w = read($7);
+S = t(X) %*% ((X %*% v)-w);
+
+write(S, $8);
diff --git a/src/test/scripts/functions/federated/FederatedGLMTest.dml 
b/src/test/scripts/functions/federated/FederatedMMChainWeightsTest.dml
similarity index 59%
copy from src/test/scripts/functions/federated/FederatedGLMTest.dml
copy to src/test/scripts/functions/federated/FederatedMMChainWeightsTest.dml
index aa23b5e..d68d059 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTest.dml
+++ b/src/test/scripts/functions/federated/FederatedMMChainWeightsTest.dml
@@ -19,9 +19,18 @@
 #
 #-------------------------------------------------------------
 
-X = federated(addresses=list($in_X1, $in_X2),
-    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), 
list($rows, $cols)))
-Y = read($in_Y)
+if ($rP) {
+    X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+  } else {
+    X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
+    }
 
-model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $out)
+v = read($in_v);
+w = read($in_w);
+S = (t(X) %*% (w*(X %*% v)));
+
+write(S, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml 
b/src/test/scripts/functions/federated/FederatedMMChainWeightsTestReference.dml
similarity index 82%
copy from src/test/scripts/functions/federated/FederatedGLMTestReference.dml
copy to 
src/test/scripts/functions/federated/FederatedMMChainWeightsTestReference.dml
index a307c8c..61947b0 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
+++ 
b/src/test/scripts/functions/federated/FederatedMMChainWeightsTestReference.dml
@@ -19,7 +19,11 @@
 #
 #-------------------------------------------------------------
 
-X = rbind(read($1), read($2))
-Y = read($3)
-model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $4)
+if($5) { X = rbind(read($1), read($2), read($3), read($4)); }
+else { X = cbind(read($1), read($2), read($3), read($4));}
+
+v = read($6);
+w = read($7);
+S = (t(X) %*% (w*(X %*% v)));
+
+write(S, $8);

Reply via email to