This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new a0f1e81 [SYSTEMDS-2988] Additional federated GLM algorithm tests (col parts)
a0f1e81 is described below
commit a0f1e8192ba9671e9a5bb431b3d69d7755008588
Author: Olga <[email protected]>
AuthorDate: Sat Jun 5 16:08:20 2021 +0200
[SYSTEMDS-2988] Additional federated GLM algorithm tests (col parts)
Additional tests (GLM column-partitioned, and mmchain) as well as
related improvements of the federated backend.
Closes #1289.
---
.../runtime/instructions/InstructionUtils.java | 2 +-
.../instructions/fed/FEDInstructionUtils.java | 6 +-
.../instructions/fed/MMChainFEDInstruction.java | 4 +-
.../fed/ParameterizedBuiltinFEDInstruction.java | 2 +-
.../instructions/fed/TsmmFEDInstruction.java | 6 +-
.../federated/algorithms/FederatedGLMTest.java | 41 ++---
.../federated/primitives/FederatedMMChainTest.java | 165 +++++++++++++++++++++
.../functions/federated/FederatedGLMTest.dml | 10 +-
.../federated/FederatedGLMTestReference.dml | 8 +-
...deratedGLMTest.dml => FederatedMMChainTest.dml} | 20 ++-
...rence.dml => FederatedMMChainTestReference.dml} | 12 +-
...LMTest.dml => FederatedMMChainWeights2Test.dml} | 19 ++-
...l => FederatedMMChainWeights2TestReference.dml} | 12 +-
...GLMTest.dml => FederatedMMChainWeightsTest.dml} | 19 ++-
...ml => FederatedMMChainWeightsTestReference.dml} | 12 +-
15 files changed, 279 insertions(+), 59 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 5269e79..2d8fdbd 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -1089,7 +1089,7 @@ public class InstructionUtils
return InstructionUtils.concatOperands(parts[0], parts[1],
createOperand(op1), createOperand(op2), createOperand(out));
}
- public static String constructUnaryInstString(String instString,
CPOperand op1, String opcode, CPOperand out) {
+ public static String constructUnaryInstString(String instString, String
opcode, CPOperand op1, CPOperand out) {
String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
parts[1] = opcode;
return InstructionUtils.concatOperands(parts[0], parts[1],
createOperand(op1), createOperand(out));
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 4ab080d..df39028 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.instructions.fed;
import org.apache.commons.lang3.ArrayUtils;
+
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
@@ -105,13 +106,14 @@ public class FEDInstructionUtils {
else if( inst instanceof MMChainCPInstruction) {
MMChainCPInstruction linst = (MMChainCPInstruction)
inst;
MatrixObject mo = ec.getMatrixObject(linst.input1);
- if( mo.isFederated() )
+ if( mo.isFederated(FType.ROW) )
fedinst =
MMChainFEDInstruction.parseInstruction(linst.getInstructionString());
}
else if( inst instanceof MMTSJCPInstruction ) {
MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
MatrixObject mo = ec.getMatrixObject(linst.input1);
- if( mo.isFederated() )
+ if( (mo.isFederated(FType.ROW) &&
linst.getMMTSJType().isLeft()) ||
+ (mo.isFederated(FType.COL) &&
linst.getMMTSJType().isRight()))
fedinst =
TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
}
else if (inst instanceof UnaryCPInstruction && ! (inst
instanceof IndexingCPInstruction)) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
index 99a305b..25df0d3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
@@ -69,7 +69,7 @@ public class MMChainFEDInstruction extends UnaryFEDInstruction {
return new MMChainFEDInstruction(in1, in2, in3, out,
type, k, opcode, str);
}
}
-
+
@Override
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
@@ -104,7 +104,7 @@ public class MMChainFEDInstruction extends UnaryFEDInstruction {
FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
FederatedRequest fr4 = mo1.getFedMapping()
.cleanup(getTID(), fr0[0].getID(), fr1.getID(),
fr2.getID());
-
+
//execute federated operations and aggregate
Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr0, fr1, fr2, fr3, fr4);
MatrixBlock ret = FederationUtils.aggAdd(tmp);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index c64817f..08331a6 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -179,7 +179,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
long ncolId = FederationUtils.getNextFedDataID();
CPOperand ncolOp = new CPOperand(String.valueOf(ncolId),
ValueType.INT64, DataType.SCALAR);
- String unaryString =
InstructionUtils.constructUnaryInstString(instString, output, "ncol", ncolOp);
+ String unaryString =
InstructionUtils.constructUnaryInstString(instString, "ncol", ncolOp, output);
FederatedRequest fr2 =
FederationUtils.callInstruction(unaryString, ncolOp,
new CPOperand[] {output}, new long[]
{out.getFedMapping().getID()});
FederatedRequest fr3 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index c0fc942..aef46ce 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -19,21 +19,21 @@
package org.apache.sysds.runtime.instructions.fed;
+import java.util.concurrent.Future;
+
import org.apache.sysds.lops.MMTSJ.MMTSJType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
-import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import java.util.concurrent.Future;
-
public class TsmmFEDInstruction extends BinaryFEDInstruction {
private final MMTSJType _type;
@SuppressWarnings("unused")
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
index 37a7787..6d8e816 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
@@ -19,10 +19,9 @@
package org.apache.sysds.test.functions.federated.algorithms;
-import org.junit.Assert;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
+import java.util.Arrays;
+import java.util.Collection;
+
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -30,9 +29,10 @@ import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-
-import java.util.Arrays;
-import java.util.Collection;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
@@ -47,6 +47,8 @@ public class FederatedGLMTest extends AutomatedTestBase {
public int rows;
@Parameterized.Parameter(1)
public int cols;
+ @Parameterized.Parameter(2)
+ public boolean rowPartitioned;
@Override
public void setUp() {
@@ -58,8 +60,9 @@ public class FederatedGLMTest extends AutomatedTestBase {
public static Collection<Object[]> data() {
// rows have to be even and > 1
return Arrays.asList(new Object[][] {
- // {10000, 10}, {1000, 100},
- {2000, 43}});
+ // {10000, 10, true}, {1000, 100, false},
+ {2000, 44, true},
+ {2000, 44, false}});
}
@Test
@@ -79,16 +82,18 @@ public class FederatedGLMTest extends AutomatedTestBase {
String HOME = SCRIPT_DIR + TEST_DIR;
// write input matrices
- int halfRows = rows / 2;
+ int r = rowPartitioned ? rows / 2 : rows;
+ int c = rowPartitioned ? cols : cols / 2;
+
// We have two matrices handled by a single federated worker
- double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
- double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+ double[][] X1 = getRandomMatrix(r, c, 0, 1, 1, 42);
+ double[][] X2 = getRandomMatrix(r, c, 0, 1, 1, 1340);
double[][] Y = getRandomMatrix(rows, 1, -1, 1, 1, 1233);
for(int i = 0; i < rows; i++)
Y[i][0] = (Y[i][0] > 0) ? 1 : -1;
- writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
- writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(r, c, blocksize, r * c));
+ writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(r, c, blocksize, r * c));
writeInputMatrixWithMTD("Y", Y, false, new
MatrixCharacteristics(rows, 1, blocksize, rows));
// empty script name because we don't execute any script, just
start the worker
@@ -104,18 +109,18 @@ public class FederatedGLMTest extends AutomatedTestBase {
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-args", input("X1"), input("X2"),
input("Y"), expected("Z")};
+ programArgs = new String[] {"-args", input("X1"), input("X2"),
input("Y"), Boolean.toString(rowPartitioned).toUpperCase(), expected("Z")};
runTest(true, false, null, -1);
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "-nvargs", "in_X1=" +
TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2,
input("X2")), "rows=" + rows, "cols=" + cols,
- "in_Y=" + input("Y"), "out=" + output("Z")};
+ "in_Y=" + input("Y"), "rP=" +
Boolean.toString(rowPartitioned).toUpperCase(), "out=" + output("Z")};
runTest(true, false, null, -1);
// compare via files
- compareResults(1e-9);
+ compareResults(1e-2);
TestUtils.shutdownThreads(t1, t2);
@@ -124,7 +129,7 @@ public class FederatedGLMTest extends AutomatedTestBase {
Assert.assertTrue(heavyHittersContainsString("fed_uark+",
"fed_uarsqk+"));
Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
// Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
- Assert.assertTrue(heavyHittersContainsString("fed_mmchain"));
+ Assert.assertTrue(!rowPartitioned ||
heavyHittersContainsString("fed_mmchain"));
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMMChainTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMMChainTest.java
new file mode 100644
index 0000000..f223cf2
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMMChainTest.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedMMChainTest extends AutomatedTestBase {
+
+ private final static String TEST_NAME1 = "FederatedMMChainTest";
+ private final static String TEST_NAME2 = "FederatedMMChainWeightsTest";
+ private final static String TEST_NAME3 = "FederatedMMChainWeights2Test";
+
+ private final static String TEST_DIR = "functions/federated/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedMMChainTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+ @Parameterized.Parameter(2)
+ public boolean rowPartitioned;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {1000, 100, true},
+ {100, 1000, false}
+ });
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
+ addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"}));
+ addTestConfiguration(TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"}));
+ }
+
+ @Test
+ public void testMMChainCP() { runMMChainTest(ExecMode.SINGLE_NODE,
TEST_NAME1); }
+ @Test
+ public void testMMChainWeightsCP() {
runMMChainTest(ExecMode.SINGLE_NODE, TEST_NAME2); }
+ @Test
+ public void testMMChainWeights2CP() {
runMMChainTest(ExecMode.SINGLE_NODE, TEST_NAME3); }
+
+ private void runMMChainTest(ExecMode execMode, String TEST_NAME) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ ExecMode platformOld = rtplatform;
+
+ if(rtplatform == ExecMode.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int r = rows;
+ int c = cols / 4;
+ if(rowPartitioned) {
+ r = rows / 4;
+ c = cols;
+ }
+
+ double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+ double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+ double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+ double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(r, c,
blocksize, r * c);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+ writeInputMatrixWithMTD("X2", X2, false, mc);
+ writeInputMatrixWithMTD("X3", X3, false, mc);
+ writeInputMatrixWithMTD("X4", X4, false, mc);
+
+ double[][] v = getRandomMatrix(cols, 1, 0, 1, 0.7, 3);
+ writeInputMatrixWithMTD("v", v, true);
+ if(!TEST_NAME.equals(TEST_NAME1)){
+ double[][] w = getRandomMatrix(rows, 1, 0, 1, 0.7, 10);
+ writeInputMatrixWithMTD("w", w, true);
+ }
+
+ // empty script name because we don't execute any script, just
start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ int port3 = getRandomAvailablePort();
+ int port4 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+ Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+ Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+ Thread t4 = startLocalFedWorkerThread(port4);
+
+ rtplatform = execMode;
+ if(rtplatform == ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"),
+ Boolean.toString(rowPartitioned).toUpperCase(),
input("v"), input("w"), expected("S")};
+ runTest(null);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "in_X3=" + TestUtils.federatedAddress(port3,
input("X3")),
+ "in_X4=" + TestUtils.federatedAddress(port4,
input("X4")), "rows=" + rows, "cols=" + cols,
+ "rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
+ "in_v=" + input("v"),
+ "in_w=" + input("w"),
+ "out_S=" + output("S")};
+ runTest(null);
+
+ // compare via files
+ compareResults(1e-9);
+
+ // check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+ TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+}
diff --git a/src/test/scripts/functions/federated/FederatedGLMTest.dml b/src/test/scripts/functions/federated/FederatedGLMTest.dml
index aa23b5e..6c349bd 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTest.dml
+++ b/src/test/scripts/functions/federated/FederatedGLMTest.dml
@@ -18,9 +18,13 @@
# under the License.
#
#-------------------------------------------------------------
-
-X = federated(addresses=list($in_X1, $in_X2),
- ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0),
list($rows, $cols)))
+if ($rP) {
+ X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0),
list($rows, $cols)))
+} else {
+ X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows, $cols / 2), list(0, $cols / 2),
list($rows, $cols)))
+}
Y = read($in_Y)
model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
diff --git a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml b/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
index a307c8c..fe815a4 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
@@ -19,7 +19,11 @@
#
#-------------------------------------------------------------
-X = rbind(read($1), read($2))
+if ($4) {
+ X = rbind(read($1), read($2))
+} else {
+ X = cbind(read($1), read($2))
+}
Y = read($3)
model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $4)
+write(model, $5)
diff --git a/src/test/scripts/functions/federated/FederatedGLMTest.dml b/src/test/scripts/functions/federated/FederatedMMChainTest.dml
similarity index 59%
copy from src/test/scripts/functions/federated/FederatedGLMTest.dml
copy to src/test/scripts/functions/federated/FederatedMMChainTest.dml
index aa23b5e..2f205b6 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTest.dml
+++ b/src/test/scripts/functions/federated/FederatedMMChainTest.dml
@@ -19,9 +19,19 @@
#
#-------------------------------------------------------------
-X = federated(addresses=list($in_X1, $in_X2),
- ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0),
list($rows, $cols)))
-Y = read($in_Y)
+if ($rP) {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+ } else {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+ }
-model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $out)
+v = read($in_v);
+
+S = (t(X) %*% (X %*% v));
+print(nrow(S))
+print(ncol(S))
+write(S, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml b/src/test/scripts/functions/federated/FederatedMMChainTestReference.dml
similarity index 83%
copy from src/test/scripts/functions/federated/FederatedGLMTestReference.dml
copy to src/test/scripts/functions/federated/FederatedMMChainTestReference.dml
index a307c8c..d423c7f 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedMMChainTestReference.dml
@@ -19,7 +19,11 @@
#
#-------------------------------------------------------------
-X = rbind(read($1), read($2))
-Y = read($3)
-model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $4)
+if($5) { X = rbind(read($1), read($2), read($3), read($4)); }
+else { X = cbind(read($1), read($2), read($3), read($4));}
+
+v = read($6);
+
+S = (t(X) %*% (X %*% v));
+
+write(S, $8);
diff --git a/src/test/scripts/functions/federated/FederatedGLMTest.dml b/src/test/scripts/functions/federated/FederatedMMChainWeights2Test.dml
similarity index 59%
copy from src/test/scripts/functions/federated/FederatedGLMTest.dml
copy to src/test/scripts/functions/federated/FederatedMMChainWeights2Test.dml
index aa23b5e..dc4e3ab 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTest.dml
+++ b/src/test/scripts/functions/federated/FederatedMMChainWeights2Test.dml
@@ -19,9 +19,18 @@
#
#-------------------------------------------------------------
-X = federated(addresses=list($in_X1, $in_X2),
- ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0),
list($rows, $cols)))
-Y = read($in_Y)
+if ($rP) {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+ } else {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+ }
-model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $out)
+v = read($in_v);
+w = read($in_w);
+S = t(X) %*% ((X %*% v)-w);
+
+write(S, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml b/src/test/scripts/functions/federated/FederatedMMChainWeights2TestReference.dml
similarity index 82%
copy from src/test/scripts/functions/federated/FederatedGLMTestReference.dml
copy to src/test/scripts/functions/federated/FederatedMMChainWeights2TestReference.dml
index a307c8c..cd9a939 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedMMChainWeights2TestReference.dml
@@ -19,7 +19,11 @@
#
#-------------------------------------------------------------
-X = rbind(read($1), read($2))
-Y = read($3)
-model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $4)
+if($5) { X = rbind(read($1), read($2), read($3), read($4)); }
+else { X = cbind(read($1), read($2), read($3), read($4));}
+
+v = read($6);
+w = read($7);
+S = t(X) %*% ((X %*% v)-w);
+
+write(S, $8);
diff --git a/src/test/scripts/functions/federated/FederatedGLMTest.dml b/src/test/scripts/functions/federated/FederatedMMChainWeightsTest.dml
similarity index 59%
copy from src/test/scripts/functions/federated/FederatedGLMTest.dml
copy to src/test/scripts/functions/federated/FederatedMMChainWeightsTest.dml
index aa23b5e..d68d059 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTest.dml
+++ b/src/test/scripts/functions/federated/FederatedMMChainWeightsTest.dml
@@ -19,9 +19,18 @@
#
#-------------------------------------------------------------
-X = federated(addresses=list($in_X1, $in_X2),
- ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0),
list($rows, $cols)))
-Y = read($in_Y)
+if ($rP) {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+ } else {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+ }
-model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $out)
+v = read($in_v);
+w = read($in_w);
+S = (t(X) %*% (w*(X %*% v)));
+
+write(S, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml b/src/test/scripts/functions/federated/FederatedMMChainWeightsTestReference.dml
similarity index 82%
copy from src/test/scripts/functions/federated/FederatedGLMTestReference.dml
copy to src/test/scripts/functions/federated/FederatedMMChainWeightsTestReference.dml
index a307c8c..61947b0 100644
--- a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedMMChainWeightsTestReference.dml
@@ -19,7 +19,11 @@
#
#-------------------------------------------------------------
-X = rbind(read($1), read($2))
-Y = read($3)
-model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
-write(model, $4)
+if($5) { X = rbind(read($1), read($2), read($3), read($4)); }
+else { X = cbind(read($1), read($2), read($3), read($4));}
+
+v = read($6);
+w = read($7);
+S = (t(X) %*% (w*(X %*% v)));
+
+write(S, $8);