This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new a65acda [SYSTEMDS-2747] Federated Quaternary Operations WSLoss and
WSigmoid
a65acda is described below
commit a65acda690bc6afc77a324b0dc323994f540cb7a
Author: ywcb00 <[email protected]>
AuthorDate: Mon Dec 28 13:50:07 2020 +0100
[SYSTEMDS-2747] Federated Quaternary Operations WSLoss and WSigmoid
Quaternary operations, part 2
This commit adds support for two federated quaternary operations,
along with federated JUnit tests.
- Weighted Sigmoid function
- Weighted Squared Loss function
Closes #1143
---
.../instructions/fed/QuaternaryFEDInstruction.java | 91 +++++++---
.../fed/QuaternaryWSLossFEDInstruction.java | 119 ++++++++++++
.../fed/QuaternaryWSigmoidFEDInstruction.java | 89 +++++++++
.../primitives/FederatedWeightedSigmoidTest.java | 202 +++++++++++++++++++++
.../FederatedWeightedSquaredLossTest.java | 194 ++++++++++++++++++++
.../quaternary/FederatedWSLossPostTest.dml | 31 ++++
.../FederatedWSLossPostTestReference.dml | 29 +++
.../quaternary/FederatedWSLossPreTest.dml | 31 ++++
.../quaternary/FederatedWSLossPreTestReference.dml | 29 +++
.../federated/quaternary/FederatedWSLossTest.dml | 30 +++
.../quaternary/FederatedWSLossTestReference.dml | 28 +++
.../quaternary/FederatedWSigmoidLogTest.dml | 31 ++++
.../FederatedWSigmoidLogTestReference.dml | 29 +++
.../quaternary/FederatedWSigmoidMinusLogTest.dml | 31 ++++
.../FederatedWSigmoidMinusLogTestReference.dml | 29 +++
.../quaternary/FederatedWSigmoidMinusTest.dml | 31 ++++
.../FederatedWSigmoidMinusTestReference.dml | 29 +++
.../federated/quaternary/FederatedWSigmoidTest.dml | 31 ++++
.../quaternary/FederatedWSigmoidTestReference.dml | 29 +++
19 files changed, 1089 insertions(+), 24 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
index 2b62ec5..ffb385d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
@@ -23,63 +23,106 @@ import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.WeightedCrossEntropy.WCeMMType;
+import org.apache.sysds.lops.WeightedSigmoid.WSigmoidType;
+import org.apache.sysds.lops.WeightedSquaredLoss.WeightsType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.fed.QuaternaryWCeMMFEDInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
-public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction
-{
+/**
+ * Base class of the federated quaternary instructions (wcemm, wsloss, wsigmoid). Spark instructions are rewritten
+ * into their CP counterparts before they are parsed.
+ */
+public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction {
 	protected CPOperand _input4 = null;
-	protected QuaternaryFEDInstruction(FEDInstruction.FEDType type, Operator operator,
-		CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String instruction_str)
-	{
+	/** Creates a federated quaternary instruction with three inputs (e.g., wsigmoid). */
+	protected QuaternaryFEDInstruction(FEDInstruction.FEDType type, Operator operator, CPOperand in1, CPOperand in2,
+		CPOperand in3, CPOperand out, String opcode, String instruction_str) {
+		super(type, operator, in1, in2, in3, out, opcode, instruction_str);
+	}
+
+	/** Creates a federated quaternary instruction with a fourth input (e.g., wcemm epsilon or wsloss weights). */
+	protected QuaternaryFEDInstruction(FEDInstruction.FEDType type, Operator operator, CPOperand in1, CPOperand in2,
+		CPOperand in3, CPOperand in4, CPOperand out, String opcode, String instruction_str) {
 		super(type, operator, in1, in2, in3, out, opcode, instruction_str);
 		_input4 = in4;
 	}
-	public static QuaternaryFEDInstruction parseInstruction(String str)
-	{
+	/**
+	 * Parses a CP or Spark quaternary instruction string into the matching federated instruction.
+	 *
+	 * @param str instruction string
+	 * @return federated instruction for wcemm, wsloss or wsigmoid
+	 * @throws DMLRuntimeException if the opcode is not supported or an operand has an unsupported data type
+	 */
+	public static QuaternaryFEDInstruction parseInstruction(String str) {
 		if(str.startsWith(ExecType.SPARK.name())) {
 			// rewrite the spark instruction to a cp instruction
 			str = str.replace(ExecType.SPARK.name(), ExecType.CP.name());
 			str = str.replace("mapwcemm", "wcemm");
-			str += Lop.OPERAND_DELIMITOR + "1"; //num threads
+			str = str.replace("mapwsloss", "wsloss");
+			if(str.contains("redwsloss")) {
+				str = str.replace("redwsloss", "wsloss");
+				// remove the trailing booleans which indicate cacheU and cacheV for redwsloss; only the trailing
+				// fields are stripped, so operand names starting with "true"/"false" are left intact
+				str = removeTrailingBooleanFields(str, 2);
+			}
+			str = str.replace("mapwsigmoid", "wsigmoid");
+			str += Lop.OPERAND_DELIMITOR + "1"; // num threads
 		}
 		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
 		String opcode = parts[0];
+		// wcemm and wsloss carry a fourth input operand (at parts[4]) in front of the output operand
+		int addInput4 = (opcode.equals("wcemm") || opcode.equals("wsloss")) ? 1 : 0;
+
+		InstructionUtils.checkNumFields(parts, 6 + addInput4);
+
 		CPOperand in1 = new CPOperand(parts[1]);
 		CPOperand in2 = new CPOperand(parts[2]);
 		CPOperand in3 = new CPOperand(parts[3]);
-		CPOperand out = new CPOperand(parts[5]);
+		CPOperand out = new CPOperand(parts[4 + addInput4]);
-		InstructionUtils.checkNumFields(parts, 7);
+		checkDataTypes(DataType.MATRIX, in1, in2, in3);
-		if(opcode.equals("wcemm")) {
+		if(addInput4 == 1) { // wcemm, wsloss
 			CPOperand in4 = new CPOperand(parts[4]);
-			checkDataTypes(in1, in2, in3, in4);
-			WCeMMType wcemm_type = WCeMMType.valueOf(parts[6]);
-			QuaternaryOperator quaternary_operator = (wcemm_type.hasFourInputs() ?
-				new QuaternaryOperator(wcemm_type, Double.parseDouble(in4.getName())) :
-				new QuaternaryOperator(wcemm_type));
-			return new QuaternaryWCeMMFEDInstruction(quaternary_operator, in1, in2, in3, in4, out, opcode, str);
+			if(opcode.equals("wcemm")) {
+				final WCeMMType wcemm_type = WCeMMType.valueOf(parts[6]);
+				if(wcemm_type.hasFourInputs())
+					checkDataTypes(new DataType[] {DataType.SCALAR, DataType.MATRIX}, in4);
+				// with four inputs, the name of the fourth operand is parsed as the epsilon value
+				QuaternaryOperator qop = wcemm_type.hasFourInputs() ?
+					new QuaternaryOperator(wcemm_type, Double.parseDouble(in4.getName())) :
+					new QuaternaryOperator(wcemm_type);
+				return new QuaternaryWCeMMFEDInstruction(qop, in1, in2, in3, in4, out, opcode, str);
+			}
+			else { // wsloss
+				final WeightsType weights_type = WeightsType.valueOf(parts[6]);
+				if(weights_type.hasFourInputs())
+					checkDataTypes(DataType.MATRIX, in4);
+				QuaternaryOperator qop = new QuaternaryOperator(weights_type);
+				return new QuaternaryWSLossFEDInstruction(qop, in1, in2, in3, in4, out, opcode, str);
+			}
+		}
+		else if(opcode.equals("wsigmoid")) {
+			final WSigmoidType wsigmoid_type = WSigmoidType.valueOf(parts[5]);
+			QuaternaryOperator qop = new QuaternaryOperator(wsigmoid_type);
+			return new QuaternaryWSigmoidFEDInstruction(qop, in1, in2, in3, out, opcode, str);
 		}
 		throw new DMLRuntimeException("Unsupported opcode (" + opcode + ") for QuaternaryFEDInstruction.");
 	}
-	protected static void checkDataTypes(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4) {
-		if(in1.getDataType() != DataType.MATRIX || in2.getDataType() != DataType.MATRIX
-			|| in3.getDataType() != DataType.MATRIX
-			|| !(in4.getDataType() == DataType.SCALAR || in4.getDataType() == DataType.MATRIX)) {
-			throw new DMLRuntimeException("Federated quaternary operations "
-				+ "only supported with matrix inputs and scalar epsilon.");
+	/**
+	 * Checks that all given operands have the given data type.
+	 *
+	 * @param data_type   required data type
+	 * @param cp_operands operands to check
+	 * @throws DMLRuntimeException if an operand has a different data type
+	 */
+	protected static void checkDataTypes(DataType data_type, CPOperand... cp_operands) {
+		checkDataTypes(new DataType[] {data_type}, cp_operands);
+	}
+
+	/**
+	 * Checks that each given operand has one of the given data types.
+	 *
+	 * @param data_types  allowed data types
+	 * @param cp_operands operands to check
+	 * @throws DMLRuntimeException if an operand has none of the allowed data types
+	 */
+	protected static void checkDataTypes(DataType[] data_types, CPOperand... cp_operands) {
+		for(CPOperand cpo : cp_operands) {
+			if(!checkDataType(data_types, cpo)) {
+				throw new DMLRuntimeException("Federated quaternary operations only support operands of data type "
+					+ java.util.Arrays.toString(data_types) + ", but operand " + cpo.getName() + " is of data type "
+					+ cpo.getDataType() + ".");
+			}
+		}
+	}
+
+	/** Returns true if the data type of the operand is one of the given data types. */
+	private static boolean checkDataType(DataType[] data_types, CPOperand cp_operand) {
+		for(DataType dt : data_types) {
+			if(cp_operand.getDataType() == dt)
+				return true;
 		}
+		return false;
 	}
+
+	/**
+	 * Removes up to numFields trailing operand fields from an instruction string. Removal stops at the first trailing
+	 * field that is not a boolean literal ("true" or "false").
+	 *
+	 * @param str       instruction string
+	 * @param numFields maximum number of trailing fields to remove
+	 * @return the instruction string without the trailing boolean fields
+	 */
+	private static String removeTrailingBooleanFields(String str, int numFields) {
+		final String delim = String.valueOf(Lop.OPERAND_DELIMITOR);
+		String ret = str;
+		for(int i = 0; i < numFields; i++) {
+			int pos = ret.lastIndexOf(delim);
+			if(pos < 0)
+				break;
+			String field = ret.substring(pos + delim.length());
+			if(!field.equals("true") && !field.equals("false"))
+				break;
+			ret = ret.substring(0, pos);
+		}
+		return ret;
+	}
 }
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
new file mode 100644
index 0000000..2cd38a6
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
+
+import java.util.concurrent.Future;
+
+public class QuaternaryWSLossFEDInstruction extends QuaternaryFEDInstruction {
+
+	/**
+	 * Creates a federated instruction for the weighted squared loss (wsloss), e.g. with post-weighting:
+	 *
+	 * Z = sum(W * (X - (U %*% t(V))) ^ 2)
+	 *
+	 * The concrete variant (and whether the weights W are used at all) is determined by the weights type of the
+	 * given quaternary operator.
+	 *
+	 * @param operator        quaternary operator of the weighted squared loss
+	 * @param in1             X
+	 * @param in2             U
+	 * @param in3             V
+	 * @param in4             W (only used if the operator has four inputs)
+	 * @param out             Z
+	 * @param opcode          opcode of the instruction
+	 * @param instruction_str complete instruction string
+	 */
+	protected QuaternaryWSLossFEDInstruction(Operator operator, CPOperand in1, CPOperand in2, CPOperand in3,
+		CPOperand in4, CPOperand out, String opcode, String instruction_str) {
+		super(FEDType.Quaternary, operator, in1, in2, in3, in4, out, opcode, instruction_str);
+	}
+
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		QuaternaryOperator qop = (QuaternaryOperator) _optr;
+
+		MatrixObject X = ec.getMatrixObject(input1);
+		MatrixObject U = ec.getMatrixObject(input2);
+		MatrixObject V = ec.getMatrixObject(input3);
+		MatrixObject W = qop.hasFourInputs() ? ec.getMatrixObject(_input4) : null;
+
+		// only a federated X together with local U, V (and W) is supported
+		if(!(X.isFederated() && !U.isFederated() && !V.isFederated() && (W == null || !W.isFederated())))
+			throw new DMLRuntimeException("Unsupported federated inputs (X, U, V, W) = (" + X.isFederated() + ", "
+				+ U.isFederated() + ", " + V.isFederated() + ", " + (W != null ? W.isFederated() : "none") + ")");
+
+		FederationMap fedMap = X.getFedMapping();
+		// U is sliced along the federated ranges of X, V is broadcast as a whole
+		// NOTE(review): the slicing presumably assumes a row-partitioned X -- confirm for other partitionings
+		FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
+		FederatedRequest frInit2 = fedMap.broadcast(V);
+
+		FederatedRequest[] frInit3 = null;
+		FederatedRequest frCompute1 = null;
+		if(W != null) {
+			frInit3 = fedMap.broadcastSliced(W, false);
+			frCompute1 = FederationUtils.callInstruction(instString, output,
+				new CPOperand[] {input1, input2, input3, _input4},
+				new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID(), frInit3[0].getID()});
+		}
+		else {
+			frCompute1 = FederationUtils.callInstruction(instString, output,
+				new CPOperand[] {input1, input2, input3},
+				new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID()});
+		}
+
+		// fetch the partial losses and remove all temporary variables on the workers
+		FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
+		FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID());
+		FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID());
+		FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID());
+
+		Future<FederatedResponse>[] response;
+		if(frInit3 != null) {
+			FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), frInit3[0].getID());
+			// the sliced broadcasts of U and W are sent in two separate execute calls
+			fedMap.execute(getTID(), true, frInit1, frInit2);
+			response = fedMap.execute(getTID(), true, frInit3, frCompute1, frGet1,
+				frCleanup1, frCleanup2, frCleanup3, frCleanup4);
+		}
+		else {
+			response = fedMap.execute(getTID(), true, frInit1, frInit2, frCompute1, frGet1,
+				frCleanup1, frCleanup2, frCleanup3);
+		}
+
+		// sum up the partial losses of all federated workers
+		AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+		ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, response));
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
new file mode 100644
index 0000000..10456f8
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.concurrent.Future;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class QuaternaryWSigmoidFEDInstruction extends QuaternaryFEDInstruction {
+
+	/**
+	 * Creates a federated instruction for the weighted sigmoid (wsigmoid). With UV = U %*% t(V), the log variant
+	 * computes:
+	 *
+	 * Z = X * log(1 / (1 + exp(-UV)))
+	 *
+	 * The variant actually computed (presumably basic, log, minus or log-minus) is selected by the wsigmoid type of
+	 * the given quaternary operator.
+	 *
+	 * @param operator        quaternary operator of the weighted sigmoid
+	 * @param in1             X
+	 * @param in2             U
+	 * @param in3             V
+	 * @param out             the federated result Z
+	 * @param opcode          opcode of the instruction
+	 * @param instruction_str complete instruction string
+	 */
+	protected QuaternaryWSigmoidFEDInstruction(Operator operator, CPOperand in1, CPOperand in2, CPOperand in3,
+		CPOperand out, String opcode, String instruction_str) {
+		super(FEDType.Quaternary, operator, in1, in2, in3, out, opcode, instruction_str);
+	}
+
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		MatrixObject X = ec.getMatrixObject(input1);
+		MatrixObject U = ec.getMatrixObject(input2);
+		MatrixObject V = ec.getMatrixObject(input3);
+
+		// only a federated X together with local U and V is supported
+		if(!(X.isFederated() && !U.isFederated() && !V.isFederated()))
+			throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" + X.isFederated() + ", "
+				+ U.isFederated() + ", " + V.isFederated() + ")");
+
+		FederationMap fedMap = X.getFedMapping();
+		// U is sliced along the federated ranges of X, V is broadcast as a whole
+		// NOTE(review): slicing U and binding the partial results below presumably assume a row-partitioned X
+		FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
+		FederatedRequest frInit2 = fedMap.broadcast(V);
+
+		FederatedRequest frCompute1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {input1, input2, input3},
+			new long[] {fedMap.getID(), frInit1[0].getID(), frInit2.getID()});
+
+		// get partial results from federated workers
+		FederatedRequest frGet1 = new FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
+
+		// remove all temporary variables on the workers
+		FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), frCompute1.getID());
+		FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), frInit1[0].getID());
+		FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), frInit2.getID());
+
+		// execute federated instructions
+		Future<FederatedResponse>[] response = fedMap.execute(getTID(), true, frInit1, frInit2, frCompute1, frGet1,
+			frCleanup1, frCleanup2, frCleanup3);
+
+		// bind partial results from federated responses
+		ec.setMatrixOutput(output.getName(), FederationUtils.bind(response, false));
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
new file mode 100644
index 0000000..e73ce82
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedWeightedSigmoidTest extends AutomatedTestBase {
+	private final static String STD_TEST_NAME = "FederatedWSigmoidTest";
+	private final static String LOG_TEST_NAME = "FederatedWSigmoidLogTest";
+	private final static String MINUS_TEST_NAME = "FederatedWSigmoidMinusTest";
+	private final static String MINUS_LOG_TEST_NAME = "FederatedWSigmoidMinusLogTest";
+	private final static String TEST_DIR = "functions/federated/quaternary/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedWeightedSigmoidTest.class.getSimpleName() + "/";
+
+	private final static String OUTPUT_NAME = "Z";
+
+	private final static double TOLERANCE = 0;
+
+	private final static int blocksize = 1024;
+
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+	@Parameterized.Parameter(2)
+	public int rank;
+	@Parameterized.Parameter(3)
+	public double sparsity;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(STD_TEST_NAME,
+			new TestConfiguration(TEST_CLASS_DIR, STD_TEST_NAME, new String[] {OUTPUT_NAME}));
+		addTestConfiguration(LOG_TEST_NAME,
+			new TestConfiguration(TEST_CLASS_DIR, LOG_TEST_NAME, new String[] {OUTPUT_NAME}));
+		addTestConfiguration(MINUS_TEST_NAME,
+			new TestConfiguration(TEST_CLASS_DIR, MINUS_TEST_NAME, new String[] {OUTPUT_NAME}));
+		addTestConfiguration(MINUS_LOG_TEST_NAME,
+			new TestConfiguration(TEST_CLASS_DIR, MINUS_LOG_TEST_NAME, new String[] {OUTPUT_NAME}));
+	}
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		// rows must be even, since X is split into two federated parts of rows / 2 rows each
+		return Arrays.asList(new Object[][] {
+			// {rows, cols, rank, sparsity}
+			// {2000, 50, 10, 0.01},
+			// {2000, 50, 10, 0.9},
+			{150, 230, 75, 0.01}, {150, 230, 75, 0.9}});
+	}
+
+	@BeforeClass
+	public static void init() {
+		TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+	}
+
+	@Test
+	public void federatedWeightedSigmoidSingleNode() {
+		federatedWeightedSigmoid(STD_TEST_NAME, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedWeightedSigmoidSpark() {
+		federatedWeightedSigmoid(STD_TEST_NAME, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedWeightedSigmoidLogSingleNode() {
+		federatedWeightedSigmoid(LOG_TEST_NAME, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedWeightedSigmoidLogSpark() {
+		federatedWeightedSigmoid(LOG_TEST_NAME, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedWeightedSigmoidMinusSingleNode() {
+		federatedWeightedSigmoid(MINUS_TEST_NAME, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedWeightedSigmoidMinusSpark() {
+		federatedWeightedSigmoid(MINUS_TEST_NAME, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedWeightedSigmoidMinusLogSingleNode() {
+		federatedWeightedSigmoid(MINUS_LOG_TEST_NAME, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedWeightedSigmoidMinusLogSpark() {
+		federatedWeightedSigmoid(MINUS_LOG_TEST_NAME, ExecMode.SPARK);
+	}
+
+	// -----------------------------------------------------------------------------
+
+	/**
+	 * Runs the reference script on local data and the test script on an X that is split row-wise across two local
+	 * federated workers, compares both results and checks that the federated wsigmoid instruction was executed.
+	 *
+	 * @param test_name name of the test configuration and DML script
+	 * @param exec_mode execution mode for running the scripts
+	 */
+	public void federatedWeightedSigmoid(String test_name, ExecMode exec_mode) {
+		// store the previous platform config to restore it after the test
+		ExecMode platform_old = setExecMode(exec_mode);
+		Thread thread1 = null;
+		Thread thread2 = null;
+		try {
+			getAndLoadTestConfiguration(test_name);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			int fed_rows = rows / 2;
+			int fed_cols = cols;
+
+			// generate dataset
+			// matrix handled by two federated workers
+			double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, sparsity, 3);
+			double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1, sparsity, 7);
+
+			double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512);
+			double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040);
+
+			writeInputMatrixWithMTD("X1", X1, false,
+				new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+			writeInputMatrixWithMTD("X2", X2, false,
+				new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+
+			writeInputMatrixWithMTD("U", U, true);
+			writeInputMatrixWithMTD("V", V, true);
+
+			// empty script name because we don't execute any script, just start the worker
+			fullDMLScriptName = "";
+			int port1 = getRandomAvailablePort();
+			int port2 = getRandomAvailablePort();
+			thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+			thread2 = startLocalFedWorkerThread(port2);
+
+			getAndLoadTestConfiguration(test_name);
+
+			// Run reference dml script with normal matrix
+			fullDMLScriptName = HOME + test_name + "Reference.dml";
+			programArgs = new String[] {"-nvargs", "in_X1=" + input("X1"), "in_X2=" + input("X2"),
+				"in_U=" + input("U"), "in_V=" + input("V"), "out_Z=" + expected(OUTPUT_NAME)};
+			runTest(true, false, null, -1);
+
+			// Run actual dml script with federated matrix
+			fullDMLScriptName = HOME + test_name + ".dml";
+			programArgs = new String[] {"-stats", "-nvargs",
+				"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+				"in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "in_U=" + input("U"),
+				"in_V=" + input("V"), "rows=" + fed_rows, "cols=" + fed_cols, "out_Z=" + output(OUTPUT_NAME)};
+			runTest(true, false, null, -1);
+
+			// compare the results via files
+			HashMap<CellIndex, Double> refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+			HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
+			TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
+
+			// check for federated operations
+			Assert.assertTrue(heavyHittersContainsString("fed_wsigmoid"));
+
+			// check that federated input files are still existing
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+		}
+		finally {
+			// stop the workers and restore the exec mode even if a run or an assertion above failed
+			shutdownWorkers(thread1, thread2);
+			resetExecMode(platform_old);
+		}
+	}
+
+	/** Stops the given federated worker threads, skipping threads that were never started. */
+	private static void shutdownWorkers(Thread... threads) {
+		for(Thread t : threads)
+			if(t != null)
+				TestUtils.shutdownThreads(t);
+	}
+
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
new file mode 100644
index 0000000..9b0f7a7
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedWeightedSquaredLossTest extends AutomatedTestBase {
+	private final static String STD_TEST_NAME = "FederatedWSLossTest";
+	private final static String PRE_TEST_NAME = "FederatedWSLossPreTest";
+	private final static String POST_TEST_NAME = "FederatedWSLossPostTest";
+	private final static String TEST_DIR = "functions/federated/quaternary/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedWeightedSquaredLossTest.class.getSimpleName()
+		+ "/";
+
+	private final static String OUTPUT_NAME = "Z";
+
+	private final static double TOLERANCE = 1e-8;
+
+	private final static int blocksize = 1024;
+
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+	@Parameterized.Parameter(2)
+	public int rank;
+	@Parameterized.Parameter(3)
+	public double sparsity;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(STD_TEST_NAME,
+			new TestConfiguration(TEST_CLASS_DIR, STD_TEST_NAME, new String[] {OUTPUT_NAME}));
+		addTestConfiguration(PRE_TEST_NAME,
+			new TestConfiguration(TEST_CLASS_DIR, PRE_TEST_NAME, new String[] {OUTPUT_NAME}));
+		addTestConfiguration(POST_TEST_NAME,
+			new TestConfiguration(TEST_CLASS_DIR, POST_TEST_NAME, new String[] {OUTPUT_NAME}));
+	}
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		// rows must be even, since X is split into two federated parts of rows / 2 rows each
+		return Arrays.asList(new Object[][] {
+			// {rows, cols, rank, sparsity}
+			// {2000, 50, 10, 0.01}, {2000, 50, 10, 0.9},
+			{100, 250, 25, 0.01}, {100, 250, 25, 0.9}});
+	}
+
+	@BeforeClass
+	public static void init() {
+		TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+	}
+
+	@Test
+	public void federatedWeightedSquaredLossSingleNode() {
+		federatedWeightedSquaredLoss(STD_TEST_NAME, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedWeightedSquaredLossSpark() {
+		federatedWeightedSquaredLoss(STD_TEST_NAME, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedWeightedSquaredLossPreSingleNode() {
+		federatedWeightedSquaredLoss(PRE_TEST_NAME, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedWeightedSquaredLossPreSpark() {
+		federatedWeightedSquaredLoss(PRE_TEST_NAME, ExecMode.SPARK);
+	}
+
+	@Test
+	public void federatedWeightedSquaredLossPostSingleNode() {
+		federatedWeightedSquaredLoss(POST_TEST_NAME, ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedWeightedSquaredLossPostSpark() {
+		federatedWeightedSquaredLoss(POST_TEST_NAME, ExecMode.SPARK);
+	}
+
+	// -----------------------------------------------------------------------------
+
+	/**
+	 * Runs the reference script on local data and the test script on an X that is split row-wise across two local
+	 * federated workers, compares both results and checks that the federated wsloss instruction was executed.
+	 *
+	 * @param test_name name of the test configuration and DML script
+	 * @param exec_mode execution mode for running the scripts
+	 */
+	public void federatedWeightedSquaredLoss(String test_name, ExecMode exec_mode) {
+		// store the previous platform config to restore it after the test
+		ExecMode platform_old = setExecMode(exec_mode);
+		Thread thread1 = null;
+		Thread thread2 = null;
+		try {
+			getAndLoadTestConfiguration(test_name);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			int fed_rows = rows / 2;
+			int fed_cols = cols;
+
+			// generate dataset
+			// matrix handled by two federated workers
+			double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, sparsity, 3);
+			double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1, sparsity, 7);
+
+			double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512);
+			double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040);
+
+			writeInputMatrixWithMTD("X1", X1, false,
+				new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+			writeInputMatrixWithMTD("X2", X2, false,
+				new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+
+			writeInputMatrixWithMTD("U", U, true);
+			writeInputMatrixWithMTD("V", V, true);
+
+			// only the weighted variants read W; the standard test script ignores the in_W argument
+			if(!test_name.equals(STD_TEST_NAME)) {
+				double[][] W = getRandomMatrix(rows, cols, 0, 1, sparsity, 54);
+				writeInputMatrixWithMTD("W", W, true);
+			}
+
+			// empty script name because we don't execute any script, just start the worker
+			fullDMLScriptName = "";
+			int port1 = getRandomAvailablePort();
+			int port2 = getRandomAvailablePort();
+			thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+			thread2 = startLocalFedWorkerThread(port2);
+
+			getAndLoadTestConfiguration(test_name);
+
+			// Run reference dml script with normal matrix
+			fullDMLScriptName = HOME + test_name + "Reference.dml";
+			programArgs = new String[] {"-nvargs", "in_X1=" + input("X1"), "in_X2=" + input("X2"),
+				"in_U=" + input("U"), "in_V=" + input("V"), "in_W=" + input("W"), "out_Z=" + expected(OUTPUT_NAME)};
+			runTest(true, false, null, -1);
+
+			// Run actual dml script with federated matrix
+			fullDMLScriptName = HOME + test_name + ".dml";
+			programArgs = new String[] {"-stats", "-nvargs",
+				"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+				"in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "in_U=" + input("U"),
+				"in_V=" + input("V"), "in_W=" + input("W"), "rows=" + fed_rows, "cols=" + fed_cols,
+				"out_Z=" + output(OUTPUT_NAME)};
+			runTest(true, false, null, -1);
+
+			// compare the results via files
+			HashMap<CellIndex, Double> refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+			HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
+			TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
+
+			// check for federated operations
+			Assert.assertTrue(heavyHittersContainsString("fed_wsloss"));
+
+			// check that federated input files are still existing
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+		}
+		finally {
+			// stop the workers and restore the exec mode even if a run or an assertion above failed
+			shutdownWorkers(thread1, thread2);
+			resetExecMode(platform_old);
+		}
+	}
+
+	/** Stops the given federated worker threads, skipping threads that were never started. */
+	private static void shutdownWorkers(Thread... threads) {
+		for(Thread t : threads)
+			if(t != null)
+				TestUtils.shutdownThreads(t);
+	}
+
+}
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSLossPostTest.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPostTest.dml
new file mode 100644
index 0000000..0f43b37
--- /dev/null
+++
b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPostTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2), # row-partitioned X: worker 1 covers rows 0..$rows, worker 2 covers rows $rows..2*$rows
+ ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2,
$cols)))
+
+U = read($in_U)
+V = read($in_V)
+W = read($in_W)
+
+Z = as.matrix(sum(W * (X - (U %*% t(V))) ^ 2)) # post-weighted squared loss sum(W * ((X - U V^T)^2)); NOTE(review): shape presumably matched by the wsloss rewrite, keep as-is
+
+write(Z, $out_Z)
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSLossPostTestReference.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPostTestReference.dml
new file mode 100644
index 0000000..5bfc9cc
--- /dev/null
+++
b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPostTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2)) # local (non-federated) reference: stack both partitions row-wise
+U = read($in_U)
+V = read($in_V)
+W = read($in_W)
+
+Z = as.matrix(sum(W * (X - (U %*% t(V))) ^ 2)) # same post-weighted loss; as.matrix wraps the scalar sum as a 1x1 matrix for writing
+
+write(Z, $out_Z)
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSLossPreTest.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPreTest.dml
new file mode 100644
index 0000000..98cf21d
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPreTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2), # row-partitioned X: worker 1 covers rows 0..$rows, worker 2 covers rows $rows..2*$rows
+ ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2,
$cols)))
+
+U = read($in_U)
+V = read($in_V)
+W = read($in_W)
+
+Z = as.matrix(sum((X - W * (U %*% t(V))) ^ 2)) # pre-weighted squared loss sum((X - W * U V^T)^2); NOTE(review): shape presumably matched by the wsloss rewrite, keep as-is
+
+write(Z, $out_Z)
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSLossPreTestReference.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPreTestReference.dml
new file mode 100644
index 0000000..08b4d65
--- /dev/null
+++
b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPreTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2)) # local (non-federated) reference: stack both partitions row-wise
+U = read($in_U)
+V = read($in_V)
+W = read($in_W)
+
+Z = as.matrix(sum((X - W * (U %*% t(V))) ^ 2)) # same pre-weighted loss; as.matrix wraps the scalar sum as a 1x1 matrix for writing
+
+write(Z, $out_Z)
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSLossTest.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSLossTest.dml
new file mode 100644
index 0000000..9850a0f
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSLossTest.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2), # row-partitioned X: worker 1 covers rows 0..$rows, worker 2 covers rows $rows..2*$rows
+ ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2,
$cols)))
+
+U = read($in_U)
+V = read($in_V)
+
+Z = as.matrix(sum((X - (U %*% t(V))) ^ 2)) # unweighted squared loss sum((X - U V^T)^2); $in_W is not read in this variant
+
+write(Z, $out_Z)
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSLossTestReference.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSLossTestReference.dml
new file mode 100644
index 0000000..2caaf15
--- /dev/null
+++
b/src/test/scripts/functions/federated/quaternary/FederatedWSLossTestReference.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2)) # local (non-federated) reference: stack both partitions row-wise
+U = read($in_U)
+V = read($in_V)
+
+Z = as.matrix(sum((X - (U %*% t(V))) ^ 2)) # same unweighted loss; as.matrix wraps the scalar sum as a 1x1 matrix for writing
+
+write(Z, $out_Z)
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidLogTest.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidLogTest.dml
new file mode 100644
index 0000000..a1369b8
--- /dev/null
+++
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidLogTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2), # row-partitioned X: worker 1 covers rows 0..$rows, worker 2 covers rows $rows..2*$rows
+ ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2,
$cols)));
+
+U = read($in_U);
+V = read($in_V);
+
+UV = U %*% t(V);
+Z = X * log(1 / (1 + exp(-UV))); # X * log(sigmoid(U V^T)); NOTE(review): explicit sigmoid form presumably matched by the wsigmoid rewrite, keep as-is
+
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidLogTestReference.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidLogTestReference.dml
new file mode 100644
index 0000000..0477155
--- /dev/null
+++
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidLogTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2)); # local (non-federated) reference: stack both partitions row-wise
+U = read($in_U);
+V = read($in_V);
+
+UV = U %*% t(V);
+Z = X * log(1 / (1 + exp(-UV))); # X * log(sigmoid(U V^T)), computed locally
+
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusLogTest.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusLogTest.dml
new file mode 100644
index 0000000..ec90e72
--- /dev/null
+++
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusLogTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2), # row-partitioned X: worker 1 covers rows 0..$rows, worker 2 covers rows $rows..2*$rows
+ ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2,
$cols)));
+
+U = read($in_U);
+V = read($in_V);
+
+UV = -(U %*% t(V)); # negated product: the sigmoid below is evaluated on -(U V^T)
+Z = X * log(1 / (1 + exp(-UV))); # X * log(sigmoid(-(U V^T))); NOTE(review): explicit form presumably matched by the wsigmoid rewrite, keep as-is
+
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusLogTestReference.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusLogTestReference.dml
new file mode 100644
index 0000000..5e279c8
--- /dev/null
+++
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusLogTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2)); # local (non-federated) reference: stack both partitions row-wise
+U = read($in_U);
+V = read($in_V);
+
+UV = -(U %*% t(V)); # negated product: the sigmoid below is evaluated on -(U V^T)
+Z = X * log(1 / (1 + exp(-UV))); # X * log(sigmoid(-(U V^T))), computed locally
+
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusTest.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusTest.dml
new file mode 100644
index 0000000..8be3559
--- /dev/null
+++
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2), # row-partitioned X: worker 1 covers rows 0..$rows, worker 2 covers rows $rows..2*$rows
+ ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2,
$cols)));
+
+U = read($in_U);
+V = read($in_V);
+
+UV = -(U %*% t(V)); # negated product: the sigmoid below is evaluated on -(U V^T)
+Z = X * (1 / (1 + exp(-UV))); # X * sigmoid(-(U V^T)); NOTE(review): explicit form presumably matched by the wsigmoid rewrite, keep as-is
+
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusTestReference.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusTestReference.dml
new file mode 100644
index 0000000..455c135
--- /dev/null
+++
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2)); # local (non-federated) reference: stack both partitions row-wise
+U = read($in_U);
+V = read($in_V);
+
+UV = -(U %*% t(V)); # negated product: the sigmoid below is evaluated on -(U V^T)
+Z = X * (1 / (1 + exp(-UV))); # X * sigmoid(-(U V^T)), computed locally
+
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidTest.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidTest.dml
new file mode 100644
index 0000000..8fa43c0
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2), # row-partitioned X: worker 1 covers rows 0..$rows, worker 2 covers rows $rows..2*$rows
+ ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2,
$cols)));
+
+U = read($in_U);
+V = read($in_V);
+
+UV = U %*% t(V);
+Z = X * (1 / (1 + exp(-UV))); # X * sigmoid(U V^T); NOTE(review): explicit sigmoid form presumably matched by the wsigmoid rewrite, keep as-is
+
+write(Z, $out_Z);
diff --git
a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidTestReference.dml
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidTestReference.dml
new file mode 100644
index 0000000..19ce7e6
--- /dev/null
+++
b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2)); # local (non-federated) reference: stack both partitions row-wise
+U = read($in_U);
+V = read($in_V);
+
+UV = U %*% t(V);
+Z = X * (1 / (1 + exp(-UV))); # X * sigmoid(U V^T), computed locally
+
+write(Z, $out_Z);