This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new a65acda  [SYSTEMDS-2747] Federated Quaternary Operations WSLoss and 
WSigmoid
a65acda is described below

commit a65acda690bc6afc77a324b0dc323994f540cb7a
Author: ywcb00 <[email protected]>
AuthorDate: Mon Dec 28 13:50:07 2020 +0100

    [SYSTEMDS-2747] Federated Quaternary Operations WSLoss and WSigmoid
    
    Quaternary operations, part 2
    
    This commit adds support for two federated quaternary operations,
    along with federated JUnit tests.
    
    - Weighted Sigmoid (wsigmoid)
    - Weighted Squared Loss (wsloss)
    
    Closes #1143
---
 .../instructions/fed/QuaternaryFEDInstruction.java |  91 +++++++---
 .../fed/QuaternaryWSLossFEDInstruction.java        | 119 ++++++++++++
 .../fed/QuaternaryWSigmoidFEDInstruction.java      |  89 +++++++++
 .../primitives/FederatedWeightedSigmoidTest.java   | 202 +++++++++++++++++++++
 .../FederatedWeightedSquaredLossTest.java          | 194 ++++++++++++++++++++
 .../quaternary/FederatedWSLossPostTest.dml         |  31 ++++
 .../FederatedWSLossPostTestReference.dml           |  29 +++
 .../quaternary/FederatedWSLossPreTest.dml          |  31 ++++
 .../quaternary/FederatedWSLossPreTestReference.dml |  29 +++
 .../federated/quaternary/FederatedWSLossTest.dml   |  30 +++
 .../quaternary/FederatedWSLossTestReference.dml    |  28 +++
 .../quaternary/FederatedWSigmoidLogTest.dml        |  31 ++++
 .../FederatedWSigmoidLogTestReference.dml          |  29 +++
 .../quaternary/FederatedWSigmoidMinusLogTest.dml   |  31 ++++
 .../FederatedWSigmoidMinusLogTestReference.dml     |  29 +++
 .../quaternary/FederatedWSigmoidMinusTest.dml      |  31 ++++
 .../FederatedWSigmoidMinusTestReference.dml        |  29 +++
 .../federated/quaternary/FederatedWSigmoidTest.dml |  31 ++++
 .../quaternary/FederatedWSigmoidTestReference.dml  |  29 +++
 19 files changed, 1089 insertions(+), 24 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
index 2b62ec5..ffb385d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
@@ -23,63 +23,106 @@ import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.lops.WeightedCrossEntropy.WCeMMType;
+import org.apache.sysds.lops.WeightedSigmoid.WSigmoidType;
+import org.apache.sysds.lops.WeightedSquaredLoss.WeightsType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.fed.QuaternaryWCeMMFEDInstruction;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
 
-public abstract class QuaternaryFEDInstruction extends 
ComputationFEDInstruction
-{
+public abstract class QuaternaryFEDInstruction extends 
ComputationFEDInstruction {
        protected CPOperand _input4 = null;
 
-       protected QuaternaryFEDInstruction(FEDInstruction.FEDType type, 
Operator operator,
-               CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, 
CPOperand out, String opcode, String instruction_str)
-       {
+       protected QuaternaryFEDInstruction(FEDInstruction.FEDType type, 
Operator operator, CPOperand in1, CPOperand in2,
+               CPOperand in3, CPOperand out, String opcode, String 
instruction_str) {
+               super(type, operator, in1, in2, in3, out, opcode, 
instruction_str);
+       }
+
+       protected QuaternaryFEDInstruction(FEDInstruction.FEDType type, 
Operator operator, CPOperand in1, CPOperand in2,
+               CPOperand in3, CPOperand in4, CPOperand out, String opcode, 
String instruction_str) {
                super(type, operator, in1, in2, in3, out, opcode, 
instruction_str);
                _input4 = in4;
        }
 
-       public static QuaternaryFEDInstruction parseInstruction(String str)
-       {
+       public static QuaternaryFEDInstruction parseInstruction(String str) {
                if(str.startsWith(ExecType.SPARK.name())) {
                        // rewrite the spark instruction to a cp instruction
                        str = str.replace(ExecType.SPARK.name(), 
ExecType.CP.name());
                        str = str.replace("mapwcemm", "wcemm");
-                       str += Lop.OPERAND_DELIMITOR + "1"; //num threads
+                       str = str.replace("mapwsloss", "wsloss");
+                       if(str.contains("redwsloss")) {
+                               str = str.replace("redwsloss", "wsloss");
+                               // remove booleans which indicate cacheU and 
cacheV for redwsloss (NOTE(review): a plain substring replace would also strip an operand whose name starts with true/false)
+                               str = str.replace(Lop.OPERAND_DELIMITOR + 
"true", "");
+                               str = str.replace(Lop.OPERAND_DELIMITOR + 
"false", "");
+                       }
+                       str = str.replace("mapwsigmoid", "wsigmoid");
+                       str += Lop.OPERAND_DELIMITOR + "1"; // num threads
                }
 
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
 
+               int addInput4 = (opcode.equals("wcemm") || 
opcode.equals("wsloss")) ? 1 : 0;
+               // wcemm and wsloss carry a fourth input before the output (wcemm: epsilon, wsloss: weights W)
+               InstructionUtils.checkNumFields(parts, 6 + addInput4);
+
                CPOperand in1 = new CPOperand(parts[1]);
                CPOperand in2 = new CPOperand(parts[2]);
                CPOperand in3 = new CPOperand(parts[3]);
-               CPOperand out = new CPOperand(parts[5]);
+               CPOperand out = new CPOperand(parts[4 + addInput4]);
 
-               InstructionUtils.checkNumFields(parts, 7);
+               checkDataTypes(DataType.MATRIX, in1, in2, in3);
 
-               if(opcode.equals("wcemm")) {
+               QuaternaryOperator qop = null;
+               if(addInput4 == 1) // wcemm, wsloss
+               {
                        CPOperand in4 = new CPOperand(parts[4]);
-                       checkDataTypes(in1, in2, in3, in4);
 
-                       WCeMMType wcemm_type = WCeMMType.valueOf(parts[6]);
-                       QuaternaryOperator quaternary_operator = 
(wcemm_type.hasFourInputs() ?
-                               new QuaternaryOperator(wcemm_type, 
Double.parseDouble(in4.getName())) :
-                               new QuaternaryOperator(wcemm_type));
-                       return new 
QuaternaryWCeMMFEDInstruction(quaternary_operator, in1, in2, in3, in4, out, 
opcode, str);
+                       if(opcode.equals("wcemm")) {
+                               final WCeMMType wcemm_type = 
WCeMMType.valueOf(parts[6]);
+                               if(wcemm_type.hasFourInputs())
+                                       checkDataTypes(new DataType[] 
{DataType.SCALAR, DataType.MATRIX}, in4);
+                               qop = (wcemm_type.hasFourInputs() ? new 
QuaternaryOperator(wcemm_type,
+                                       Double.parseDouble(in4.getName())) : 
new QuaternaryOperator(wcemm_type));
+                               return new QuaternaryWCeMMFEDInstruction(qop, 
in1, in2, in3, in4, out, opcode, str);
+                       }
+                       else if(opcode.equals("wsloss")) {
+                               final WeightsType weights_type = 
WeightsType.valueOf(parts[6]);
+                               if(weights_type.hasFourInputs())
+                                       checkDataTypes(DataType.MATRIX, in4);
+                               qop = new QuaternaryOperator(weights_type);
+                               return new QuaternaryWSLossFEDInstruction(qop, 
in1, in2, in3, in4, out, opcode, str);
+                       }
+               }
+               else if(opcode.equals("wsigmoid")) {
+                       final WSigmoidType wsigmoid_type = 
WSigmoidType.valueOf(parts[5]);
+                       qop = new QuaternaryOperator(wsigmoid_type);
+                       return new QuaternaryWSigmoidFEDInstruction(qop, in1, 
in2, in3, out, opcode, str);
                }
 
                throw new DMLRuntimeException("Unsupported opcode (" + opcode + 
") for QuaternaryFEDInstruction.");
        }
 
-       protected static void checkDataTypes(CPOperand in1, CPOperand in2, 
CPOperand in3, CPOperand in4) {
-               if(in1.getDataType() != DataType.MATRIX || in2.getDataType() != 
DataType.MATRIX 
-                       || in3.getDataType() != DataType.MATRIX 
-                       || !(in4.getDataType() == DataType.SCALAR || 
in4.getDataType() == DataType.MATRIX)) {
-                       throw new DMLRuntimeException("Federated quaternary 
operations "
-                               + "only supported with matrix inputs and scalar 
epsilon.");
+       protected static void checkDataTypes(DataType data_type, CPOperand... 
cp_operands) {
+               checkDataTypes(new DataType[] {data_type}, cp_operands);
+       }
+
+       protected static void checkDataTypes(DataType[] data_types, 
CPOperand... cp_operands) {
+               for(CPOperand cpo : cp_operands) {
+                       if(!checkDataType(data_types, cpo)) {
+                               throw new DMLRuntimeException(
+                                       "Federated quaternary operations do not support operand " + 
cpo.getName() + " of data type " + cpo.getDataType() + ".");
+                       }
+               }
+       }
+
+       private static boolean checkDataType(DataType[] data_types, CPOperand 
cp_operand) {
+               for(DataType dt : data_types) {
+                       if(cp_operand.getDataType() == dt)
+                               return true;
                }
+               return false;
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
new file mode 100644
index 0000000..2cd38a6
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
+
+import java.util.concurrent.Future;
+
+public class QuaternaryWSLossFEDInstruction extends QuaternaryFEDInstruction {
+
+       /**
+        * This instruction performs a federated Weighted Squared Loss function; for 
post-weighting (the operator's WeightsType selects the variant) it computes:
+        * 
+        * Z = sum(W * (X - (U %*% t(V))) ^ 2)
+        * 
+        * @param operator Weighted Squared Loss quaternary operator
+        * @param in1 X
+        * @param in2 U
+        * @param in3 V
+        * @param in4 W
+        * @param out Z
+        * @param opcode instruction opcode (wsloss)
+        * @param instruction_str full instruction string
+        */
+       protected QuaternaryWSLossFEDInstruction(Operator operator, CPOperand 
in1, CPOperand in2, CPOperand in3,
+               CPOperand in4, CPOperand out, String opcode, String 
instruction_str) {
+               super(FEDType.Quaternary, operator, in1, in2, in3, in4, out, 
opcode, instruction_str);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               QuaternaryOperator qop = (QuaternaryOperator) _optr;
+
+               MatrixObject X = ec.getMatrixObject(input1);
+               MatrixObject U = ec.getMatrixObject(input2);
+               MatrixObject V = ec.getMatrixObject(input3);
+
+               MatrixObject W = null;
+               if(qop.hasFourInputs()) {
+                       W = ec.getMatrixObject(_input4);
+               }
+
+               if(!(X.isFederated() && !U.isFederated() && !V.isFederated() && 
(W == null || !W.isFederated())))
+                       throw new DMLRuntimeException("Unsupported federated 
inputs (X, U, V, W) = (" + X.isFederated() + ", "
+                               + U.isFederated() + ", " + V.isFederated() + ", " + (W 
!= null ? W.isFederated() : "none") + ")");
+
+               FederationMap fedMap = X.getFedMapping();
+               FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
+               FederatedRequest frInit2 = fedMap.broadcast(V);
+
+               FederatedRequest[] frInit3 = null;
+               FederatedRequest frCompute1 = null;
+               if(W != null) {
+                       frInit3 = fedMap.broadcastSliced(W, false);
+                       frCompute1 = FederationUtils.callInstruction(instString,
+                               output,
+                               new CPOperand[] {input1, input2, input3, 
_input4},
+                               new long[] {fedMap.getID(), frInit1[0].getID(), 
frInit2.getID(), frInit3[0].getID()});
+               }
+               else {
+                       frCompute1 = FederationUtils.callInstruction(instString,
+                               output,
+                               new CPOperand[] {input1, input2, input3},
+                               new long[] {fedMap.getID(), frInit1[0].getID(), 
frInit2.getID()});
+               }
+
+               FederatedRequest frGet1 = new 
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
+               FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), 
frCompute1.getID());
+               FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), 
frInit1[0].getID());
+               FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), 
frInit2.getID());
+
+               Future<FederatedResponse>[] response;
+               if(frInit3 != null) {
+                       FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), 
frInit3[0].getID());
+                       // execute federated instructions in two batches: U (frInit1) and W (frInit3) are both sliced broadcasts
+                       fedMap.execute(getTID(), true, frInit1, frInit2);
+                       response = fedMap
+                               .execute(getTID(), true, frInit3, frCompute1, 
frGet1, frCleanup1, frCleanup2, frCleanup3, frCleanup4);
+               }
+               else {
+                       // execute federated instructions
+                       response = fedMap
+                               .execute(getTID(), true, frInit1, frInit2, 
frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3);
+               }
+
+               // aggregate partial results from federated responses
+               AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+               ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, 
response));
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
new file mode 100644
index 0000000..10456f8
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.concurrent.Future;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class QuaternaryWSigmoidFEDInstruction extends QuaternaryFEDInstruction 
{
+
+       /**
+        * This instruction performs a federated weighted sigmoid (wsigmoid):
+        * 
+        * UV = U %*% t(V); Z = X * f(UV), e.g. f(UV) = log(1 / (1 + exp(-UV))) for the log variant; the operator's WSigmoidType selects the variant (basic, log, minus, minus-log).
+        * 
+        * @param operator        quaternary operator holding the WSigmoidType
+        * @param in1             X
+        * @param in2             U
+        * @param in3             V
+        * @param out             Z, bound locally from the workers' partial results
+        * @param opcode          instruction opcode (wsigmoid)
+        * @param instruction_str full instruction string
+        */
+       protected QuaternaryWSigmoidFEDInstruction(Operator operator, CPOperand 
in1, CPOperand in2, CPOperand in3,
+               CPOperand out, String opcode, String instruction_str) {
+               super(FEDType.Quaternary, operator, in1, in2, in3, out, opcode, 
instruction_str);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               MatrixObject X = ec.getMatrixObject(input1);
+               MatrixObject U = ec.getMatrixObject(input2);
+               MatrixObject V = ec.getMatrixObject(input3);
+
+               if(!(X.isFederated() && !U.isFederated() && !V.isFederated()))
+                       throw new DMLRuntimeException("Unsupported federated 
inputs (X, U, V) = (" + X.isFederated() + ", "
+                               + U.isFederated() + ", " + V.isFederated() + 
")");
+
+               FederationMap fedMap = X.getFedMapping();
+               FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
+               FederatedRequest frInit2 = fedMap.broadcast(V);
+
+               FederatedRequest frCompute1 = 
FederationUtils.callInstruction(instString,
+                       output,
+                       new CPOperand[] {input1, input2, input3},
+                       new long[] {fedMap.getID(), frInit1[0].getID(), 
frInit2.getID()});
+
+               // get partial results from federated workers
+               FederatedRequest frGet1 = new 
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
+
+               FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), 
frCompute1.getID());
+               FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), 
frInit1[0].getID());
+               FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), 
frInit2.getID());
+
+               // execute federated instructions
+               Future<FederatedResponse>[] response = fedMap
+                       .execute(getTID(), true, frInit1, frInit2, frCompute1, 
frGet1, frCleanup1, frCleanup2, frCleanup3);
+
+               // bind partial results from federated responses
+               ec.setMatrixOutput(output.getName(), 
FederationUtils.bind(response, false));
+
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
new file mode 100644
index 0000000..e73ce82
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedWeightedSigmoidTest extends AutomatedTestBase {
+       private final static String STD_TEST_NAME = "FederatedWSigmoidTest";
+       private final static String LOG_TEST_NAME = "FederatedWSigmoidLogTest";
+       private final static String MINUS_TEST_NAME = 
"FederatedWSigmoidMinusTest";
+       private final static String MINUS_LOG_TEST_NAME = 
"FederatedWSigmoidMinusLogTest";
+       private final static String TEST_DIR = 
"functions/federated/quaternary/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedWeightedSigmoidTest.class.getSimpleName() + "/";
+
+       private final static String OUTPUT_NAME = "Z";
+
+       private final static double TOLERANCE = 0;
+
+       private final static int blocksize = 1024;
+
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public int rank;
+       @Parameterized.Parameter(3)
+       public double sparsity;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(STD_TEST_NAME,
+                       new TestConfiguration(TEST_CLASS_DIR, STD_TEST_NAME, 
new String[] {OUTPUT_NAME}));
+               addTestConfiguration(LOG_TEST_NAME,
+                       new TestConfiguration(TEST_CLASS_DIR, LOG_TEST_NAME, 
new String[] {OUTPUT_NAME}));
+               addTestConfiguration(MINUS_TEST_NAME,
+                       new TestConfiguration(TEST_CLASS_DIR, MINUS_TEST_NAME, 
new String[] {OUTPUT_NAME}));
+               addTestConfiguration(MINUS_LOG_TEST_NAME,
+                       new TestConfiguration(TEST_CLASS_DIR, 
MINUS_LOG_TEST_NAME, new String[] {OUTPUT_NAME}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // rows must be even
+               return Arrays.asList(new Object[][] {
+                       // {rows, cols, rank, sparsity}
+                       // {2000, 50, 10, 0.01},
+                       // {2000, 50, 10, 0.9},
+                       {150, 230, 75, 0.01}, {150, 230, 75, 0.9}});
+       }
+
+       @BeforeClass
+       public static void init() {
+               TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+       }
+
+       @Test
+       public void federatedWeightedSigmoidSingleNode() {
+               federatedWeightedSigmoid(STD_TEST_NAME, ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedWeightedSigmoidSpark() {
+               federatedWeightedSigmoid(STD_TEST_NAME, ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedWeightedSigmoidLogSingleNode() {
+               federatedWeightedSigmoid(LOG_TEST_NAME, ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedWeightedSigmoidLogSpark() {
+               federatedWeightedSigmoid(LOG_TEST_NAME, ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedWeightedSigmoidMinusSingleNode() {
+               federatedWeightedSigmoid(MINUS_TEST_NAME, ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedWeightedSigmoidMinusSpark() {
+               federatedWeightedSigmoid(MINUS_TEST_NAME, ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedWeightedSigmoidMinusLogSingleNode() {
+               federatedWeightedSigmoid(MINUS_LOG_TEST_NAME, 
ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedWeightedSigmoidMinusLogSpark() {
+               federatedWeightedSigmoid(MINUS_LOG_TEST_NAME, ExecMode.SPARK);
+       }
+
+       // 
-----------------------------------------------------------------------------
+
+       public void federatedWeightedSigmoid(String test_name, ExecMode 
exec_mode) {
+               // store the previous platform config; the finally block below restores it (and stops the 
workers) even if an assertion fails, so nothing leaks into other tests
+               ExecMode platform_old = setExecMode(exec_mode);
+
+               getAndLoadTestConfiguration(test_name);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               int fed_rows = rows / 2;
+               int fed_cols = cols;
+
+               // generate dataset
+               // matrix handled by two federated workers
+               double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 
sparsity, 3);
+               double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 
sparsity, 7);
+
+               double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512);
+               double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040);
+
+               writeInputMatrixWithMTD("X1",
+                       X1,
+                       false,
+                       new MatrixCharacteristics(fed_rows, fed_cols, 
blocksize, fed_rows * fed_cols));
+               writeInputMatrixWithMTD("X2",
+                       X2,
+                       false,
+                       new MatrixCharacteristics(fed_rows, fed_cols, 
blocksize, fed_rows * fed_cols));
+
+               writeInputMatrixWithMTD("U", U, true);
+               writeInputMatrixWithMTD("V", V, true);
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               Thread thread1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+               Thread thread2 = startLocalFedWorkerThread(port2);
+               try {
+                       getAndLoadTestConfiguration(test_name);
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + test_name + "Reference.dml";
+                       programArgs = new String[] {"-nvargs", "in_X1=" + input("X1"), 
"in_X2=" + input("X2"), "in_U=" + input("U"),
+                               "in_V=" + input("V"), "out_Z=" + expected(OUTPUT_NAME)};
+                       runTest(true, false, null, -1);
+
+                       // Run actual dml script with federated matrix
+                       fullDMLScriptName = HOME + test_name + ".dml";
+                       programArgs = new String[] {"-stats", "-nvargs", "in_X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
+                               "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")), "in_U=" + input("U"), "in_V=" + input("V"),
+                               "rows=" + fed_rows, "cols=" + fed_cols, "out_Z=" + 
output(OUTPUT_NAME)};
+                       runTest(true, false, null, -1);
+
+                       // compare the results via files
+                       HashMap<CellIndex, Double> refResults = 
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+                       HashMap<CellIndex, Double> fedResults = 
readDMLMatrixFromOutputDir(OUTPUT_NAME);
+                       TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, 
"Fed", "Ref");
+                       // check that the federated wsigmoid instruction was executed
+                       Assert.assertTrue(heavyHittersContainsString("fed_wsigmoid"));
+
+                       // check that federated input files still exist
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               }
+               finally {
+                       TestUtils.shutdownThreads(thread1, thread2);
+                       resetExecMode(platform_old);
+               }
+       }
+
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
new file mode 100644
index 0000000..9b0f7a7
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/**
+ * Federated tests for the weighted squared loss (wsloss) quaternary operation.
+ *
+ * <p>X is row-partitioned across two local federated workers, while U, V and (for the
+ * pre-/post-weighted variants) W are local inputs. Every variant runs a local reference
+ * script and a federated script, compares both results within {@code TOLERANCE}, and
+ * checks that a {@code fed_wsloss} instruction was executed.
+ */
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedWeightedSquaredLossTest extends AutomatedTestBase {
+       private final static String STD_TEST_NAME = "FederatedWSLossTest";
+       private final static String PRE_TEST_NAME = "FederatedWSLossPreTest";
+       private final static String POST_TEST_NAME = "FederatedWSLossPostTest";
+       private final static String TEST_DIR = "functions/federated/quaternary/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + FederatedWeightedSquaredLossTest.class.getSimpleName()
+               + "/";
+
+       private final static String OUTPUT_NAME = "Z";
+
+       private final static double TOLERANCE = 1e-8;
+
+       private final static int blocksize = 1024;
+
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public int rank;
+       @Parameterized.Parameter(3)
+       public double sparsity;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(STD_TEST_NAME,
+                       new TestConfiguration(TEST_CLASS_DIR, STD_TEST_NAME, new String[] {OUTPUT_NAME}));
+               addTestConfiguration(PRE_TEST_NAME,
+                       new TestConfiguration(TEST_CLASS_DIR, PRE_TEST_NAME, new String[] {OUTPUT_NAME}));
+               addTestConfiguration(POST_TEST_NAME,
+                       new TestConfiguration(TEST_CLASS_DIR, POST_TEST_NAME, new String[] {OUTPUT_NAME}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // rows must be even: X is split into two row partitions of rows / 2 each
+               return Arrays.asList(new Object[][] {
+                       // {rows, cols, rank, sparsity}
+                       // {2000, 50, 10, 0.01}, {2000, 50, 10, 0.9},
+                       {100, 250, 25, 0.01}, {100, 250, 25, 0.9}});
+       }
+
+       @BeforeClass
+       public static void init() {
+               TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+       }
+
+       @Test
+       public void federatedWeightedSquaredLossSingleNode() {
+               federatedWeightedSquaredLoss(STD_TEST_NAME, ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedWeightedSquaredLossSpark() {
+               federatedWeightedSquaredLoss(STD_TEST_NAME, ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedWeightedSquaredLossPreSingleNode() {
+               federatedWeightedSquaredLoss(PRE_TEST_NAME, ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedWeightedSquaredLossPreSpark() {
+               federatedWeightedSquaredLoss(PRE_TEST_NAME, ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedWeightedSquaredLossPostSingleNode() {
+               federatedWeightedSquaredLoss(POST_TEST_NAME, ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedWeightedSquaredLossPostSpark() {
+               federatedWeightedSquaredLoss(POST_TEST_NAME, ExecMode.SPARK);
+       }
+
+       // -----------------------------------------------------------------------------
+
+       /**
+        * Runs one wsloss variant on a row-partitioned federated X and compares it with the
+        * local reference script.
+        *
+        * <p>The federated workers are always shut down and the previous exec mode is always
+        * restored, even if a script run, the result comparison, or an assertion fails.
+        *
+        * @param test_name name of the DML script pair to run; the reference script is
+        *                  {@code test_name + "Reference.dml"}
+        * @param exec_mode execution mode used for both the reference and the federated run
+        */
+       public void federatedWeightedSquaredLoss(String test_name, ExecMode exec_mode) {
+               // store the previous platform config to restore it after the test
+               ExecMode platform_old = setExecMode(exec_mode);
+               try {
+                       getAndLoadTestConfiguration(test_name);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       int fed_rows = rows / 2;
+                       int fed_cols = cols;
+
+                       // generate dataset
+                       // matrix handled by two federated workers
+                       double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, sparsity, 3);
+                       double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1, sparsity, 7);
+
+                       double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512);
+                       double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040);
+
+                       // NOTE(review): nnz is declared as fed_rows * fed_cols (an upper bound),
+                       // even for the sparse parameterization
+                       writeInputMatrixWithMTD("X1", X1, false,
+                               new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+                       writeInputMatrixWithMTD("X2", X2, false,
+                               new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+
+                       writeInputMatrixWithMTD("U", U, true);
+                       writeInputMatrixWithMTD("V", V, true);
+
+                       // only the pre-/post-weighted scripts read W; the plain variant ignores $in_W
+                       if(!test_name.equals(STD_TEST_NAME)) {
+                               double[][] W = getRandomMatrix(rows, cols, 0, 1, sparsity, 54);
+                               writeInputMatrixWithMTD("W", W, true);
+                       }
+
+                       // empty script name because we don't execute any script, just start the worker
+                       fullDMLScriptName = "";
+                       int port1 = getRandomAvailablePort();
+                       int port2 = getRandomAvailablePort();
+                       Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+                       Thread thread2 = startLocalFedWorkerThread(port2);
+
+                       try {
+                               getAndLoadTestConfiguration(test_name);
+
+                               // Run reference dml script with normal matrix
+                               fullDMLScriptName = HOME + test_name + "Reference.dml";
+                               programArgs = new String[] {"-nvargs", "in_X1=" + input("X1"), "in_X2=" + input("X2"),
+                                       "in_U=" + input("U"), "in_V=" + input("V"), "in_W=" + input("W"),
+                                       "out_Z=" + expected(OUTPUT_NAME)};
+                               runTest(true, false, null, -1);
+
+                               // Run actual dml script with federated matrix
+                               fullDMLScriptName = HOME + test_name + ".dml";
+                               programArgs = new String[] {"-stats", "-nvargs",
+                                       "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                                       "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                                       "in_U=" + input("U"), "in_V=" + input("V"), "in_W=" + input("W"),
+                                       "rows=" + fed_rows, "cols=" + fed_cols, "out_Z=" + output(OUTPUT_NAME)};
+                               runTest(true, false, null, -1);
+
+                               // compare the results via files
+                               HashMap<CellIndex, Double> refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+                               HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
+                               TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
+                       }
+                       finally {
+                               // stop the workers even if a run or the comparison failed, so that
+                               // threads do not leak into the next parameterized run
+                               TestUtils.shutdownThreads(thread1, thread2);
+                       }
+
+                       // check for federated operations
+                       Assert.assertTrue(heavyHittersContainsString("fed_wsloss"));
+
+                       // check that federated input files are still existing
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               }
+               finally {
+                       resetExecMode(platform_old);
+               }
+       }
+
+}
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSLossPostTest.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPostTest.dml
new file mode 100644
index 0000000..0f43b37
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPostTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Federated wsloss test, post-weighted variant: sum(W * (X - U %*% t(V))^2).
+X = federated(addresses=list($in_X1, $in_X2),
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, 
$cols)))
+# X is row-partitioned: $in_X1 holds the first $rows rows, $in_X2 the next $rows rows.
+U = read($in_U)
+V = read($in_V)
+W = read($in_W)
+# Pattern under test: the JUnit driver checks that a fed_wsloss instruction was executed.
+Z = as.matrix(sum(W * (X - (U %*% t(V))) ^ 2))
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSLossPostTestReference.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPostTestReference.dml
new file mode 100644
index 0000000..5bfc9cc
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPostTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Local reference for FederatedWSLossPostTest: X is assembled by rbind of both partitions.
+X = rbind(read($in_X1), read($in_X2))
+U = read($in_U)
+V = read($in_V)
+W = read($in_W)
+# Same post-weighted loss as the federated script; as.matrix wraps the scalar sum as 1x1.
+Z = as.matrix(sum(W * (X - (U %*% t(V))) ^ 2))
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSLossPreTest.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPreTest.dml
new file mode 100644
index 0000000..98cf21d
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPreTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Federated wsloss test, pre-weighted variant: sum((X - W * (U %*% t(V)))^2).
+X = federated(addresses=list($in_X1, $in_X2),
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, 
$cols)))
+# X is row-partitioned: $in_X1 holds the first $rows rows, $in_X2 the next $rows rows.
+U = read($in_U)
+V = read($in_V)
+W = read($in_W)
+# Pattern under test: the JUnit driver checks that a fed_wsloss instruction was executed.
+Z = as.matrix(sum((X - W * (U %*% t(V))) ^ 2))
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSLossPreTestReference.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPreTestReference.dml
new file mode 100644
index 0000000..08b4d65
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSLossPreTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Local reference for FederatedWSLossPreTest: X is assembled by rbind of both partitions.
+X = rbind(read($in_X1), read($in_X2))
+U = read($in_U)
+V = read($in_V)
+W = read($in_W)
+# Same pre-weighted loss as the federated script; as.matrix wraps the scalar sum as 1x1.
+Z = as.matrix(sum((X - W * (U %*% t(V))) ^ 2))
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSLossTest.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSLossTest.dml
new file mode 100644
index 0000000..9850a0f
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSLossTest.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Federated wsloss test, unweighted variant: sum((X - U %*% t(V))^2).
+X = federated(addresses=list($in_X1, $in_X2),
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, 
$cols)))
+# X is row-partitioned: $in_X1 holds the first $rows rows, $in_X2 the next $rows rows.
+U = read($in_U)
+V = read($in_V)
+# Pattern under test: the JUnit driver checks that a fed_wsloss instruction was executed.
+Z = as.matrix(sum((X - (U %*% t(V))) ^ 2))
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSLossTestReference.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSLossTestReference.dml
new file mode 100644
index 0000000..2caaf15
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSLossTestReference.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Local reference for FederatedWSLossTest: X is assembled by rbind of both partitions.
+X = rbind(read($in_X1), read($in_X2))
+U = read($in_U)
+V = read($in_V)
+# Same unweighted loss as the federated script; as.matrix wraps the scalar sum as 1x1.
+Z = as.matrix(sum((X - (U %*% t(V))) ^ 2))
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidLogTest.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidLogTest.dml
new file mode 100644
index 0000000..a1369b8
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidLogTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Federated wsigmoid test, log variant: X * log(sigmoid(U %*% t(V))).
+X = federated(addresses=list($in_X1, $in_X2),
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, 
$cols)));
+# X is row-partitioned: $in_X1 holds the first $rows rows, $in_X2 the next $rows rows.
+U = read($in_U);
+V = read($in_V);
+# Pattern under test: the JUnit driver checks that a fed_wsigmoid instruction was executed.
+UV = U %*% t(V);
+Z = X * log(1 / (1 + exp(-UV)));
+
+write(Z, $out_Z);
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidLogTestReference.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidLogTestReference.dml
new file mode 100644
index 0000000..0477155
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidLogTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Local reference for FederatedWSigmoidLogTest: X is assembled by rbind of both partitions.
+X = rbind(read($in_X1), read($in_X2));
+U = read($in_U);
+V = read($in_V);
+# Same log-sigmoid weighting as the federated script: X * log(sigmoid(U %*% t(V))).
+UV = U %*% t(V);
+Z = X * log(1 / (1 + exp(-UV)));
+
+write(Z, $out_Z);
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusLogTest.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusLogTest.dml
new file mode 100644
index 0000000..ec90e72
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusLogTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Federated wsigmoid test, minus-log variant: X * log(sigmoid(-(U %*% t(V)))).
+X = federated(addresses=list($in_X1, $in_X2),
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, 
$cols)));
+# X is row-partitioned: $in_X1 holds the first $rows rows, $in_X2 the next $rows rows.
+U = read($in_U);
+V = read($in_V);
+# Pattern under test (UV is negated): the driver checks that fed_wsigmoid was executed.
+UV = -(U %*% t(V));
+Z = X * log(1 / (1 + exp(-UV)));
+
+write(Z, $out_Z);
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusLogTestReference.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusLogTestReference.dml
new file mode 100644
index 0000000..5e279c8
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusLogTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Local reference for FederatedWSigmoidMinusLogTest: X is assembled by rbind of both partitions.
+X = rbind(read($in_X1), read($in_X2));
+U = read($in_U);
+V = read($in_V);
+# Same weighting as the federated script: X * log(sigmoid(-(U %*% t(V)))).
+UV = -(U %*% t(V));
+Z = X * log(1 / (1 + exp(-UV)));
+
+write(Z, $out_Z);
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusTest.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusTest.dml
new file mode 100644
index 0000000..8be3559
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Federated wsigmoid test, minus variant: X * sigmoid(-(U %*% t(V))).
+X = federated(addresses=list($in_X1, $in_X2),
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, 
$cols)));
+# X is row-partitioned: $in_X1 holds the first $rows rows, $in_X2 the next $rows rows.
+U = read($in_U);
+V = read($in_V);
+# Pattern under test (UV is negated): the driver checks that fed_wsigmoid was executed.
+UV = -(U %*% t(V));
+Z = X * (1 / (1 + exp(-UV)));
+
+write(Z, $out_Z);
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusTestReference.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusTestReference.dml
new file mode 100644
index 0000000..455c135
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidMinusTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Local reference for FederatedWSigmoidMinusTest: X is assembled by rbind of both partitions.
+X = rbind(read($in_X1), read($in_X2));
+U = read($in_U);
+V = read($in_V);
+# Same weighting as the federated script: X * sigmoid(-(U %*% t(V))).
+UV = -(U %*% t(V));
+Z = X * (1 / (1 + exp(-UV)));
+
+write(Z, $out_Z);
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidTest.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidTest.dml
new file mode 100644
index 0000000..8fa43c0
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Federated wsigmoid test, basic variant: X * sigmoid(U %*% t(V)).
+X = federated(addresses=list($in_X1, $in_X2),
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, 
$cols)));
+# X is row-partitioned: $in_X1 holds the first $rows rows, $in_X2 the next $rows rows.
+U = read($in_U);
+V = read($in_V);
+# Pattern under test: the JUnit driver checks that a fed_wsigmoid instruction was executed.
+UV = U %*% t(V);
+Z = X * (1 / (1 + exp(-UV)));
+
+write(Z, $out_Z);
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidTestReference.dml b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidTestReference.dml
new file mode 100644
index 0000000..19ce7e6
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWSigmoidTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Local reference for FederatedWSigmoidTest: X is assembled by rbind of both partitions.
+X = rbind(read($in_X1), read($in_X2));
+U = read($in_U);
+V = read($in_V);
+# Same weighting as the federated script: X * sigmoid(U %*% t(V)).
+UV = U %*% t(V);
+Z = X * (1 / (1 + exp(-UV)));
+
+write(Z, $out_Z);

Reply via email to