This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 1d2517e  [SYSTEMDS-2747] Federated weighted cross entropy operations (WCEMM)
1d2517e is described below

commit 1d2517e1eb75646363e767a04c5c1b37ad12124e
Author: ywcb00 <[email protected]>
AuthorDate: Mon Dec 28 22:08:54 2020 +0100

    [SYSTEMDS-2747] Federated weighted cross entropy operations (WCEMM)
    
    Quaternary operations, part 1
    Closes #1133.
---
 .../controlprogram/federated/FederationUtils.java  |   8 +-
 .../runtime/instructions/fed/FEDInstruction.java   |  19 +--
 .../instructions/fed/FEDInstructionUtils.java      |  32 ++--
 .../instructions/fed/QuaternaryFEDInstruction.java |  85 ++++++++++
 .../fed/QuaternaryWCeMMFEDInstruction.java         | 116 ++++++++++++++
 .../instructions/fed/ReorgFEDInstruction.java      |   5 -
 .../org/apache/sysds/test/AutomatedTestBase.java   |   4 +
 .../FederatedWeightedCrossEntropyTest.java         | 175 +++++++++++++++++++++
 .../federated/quaternary/FederatedWCeMMEpsTest.dml |  31 ++++
 .../quaternary/FederatedWCeMMEpsTestReference.dml  |  29 ++++
 .../federated/quaternary/FederatedWCeMMTest.dml    |  30 ++++
 .../quaternary/FederatedWCeMMTestReference.dml     |  28 ++++
 12 files changed, 537 insertions(+), 25 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index e9c41b2..881991a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -119,7 +119,7 @@ public class FederationUtils {
                        long size = 0;
                        for(int i=0; i<ffr.length; i++) {
                                Object input = ffr[i].get().getData()[0];
-                               MatrixBlock tmp = (input instanceof 
ScalarObject) ? 
+                               MatrixBlock tmp = (input instanceof 
ScalarObject) ?
                                        new 
MatrixBlock(((ScalarObject)input).getDoubleValue()) : (MatrixBlock) input;
                                size += ranges[i].getSize(0);
                                sop1 = sop1.setConstant(ranges[i].getSize(0));
@@ -317,6 +317,10 @@ public class FederationUtils {
                }
        }
 
+       /**
+        * Aggregates the federated responses in {@code ffr} into a scalar using {@code aop}.
+        * Convenience overload that delegates to the three-argument variant with a
+        * {@code null} federation map.
+        */
+       public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr) {
+               return aggScalar(aop, ffr, null);
+       }
+
        public static ScalarObject aggScalar(AggregateUnaryOperator aop, 
Future<FederatedResponse>[] ffr, FederationMap map) {
                if(!(aop.aggOp.increOp.fn instanceof KahanFunction || 
(aop.aggOp.increOp.fn instanceof Builtin &&
                        (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == 
BuiltinCode.MIN
@@ -366,7 +370,7 @@ public class FederationUtils {
                        throw new DMLRuntimeException("Unsupported aggregation 
operator: "
                                + 
aop.aggOp.increOp.fn.getClass().getSimpleName());
        }
-       
+
        public static FederationMap federateLocalData(CacheableData<?> data) {
                long id = FederationUtils.getNextFedDataID();
                FederatedLocalData federatedLocalData = new 
FederatedLocalData(id, data);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 00f6b72..0cf1fac 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -25,7 +25,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
 
 public abstract class FEDInstruction extends Instruction {
-       
+
        public enum FEDType {
                AggregateBinary,
                AggregateUnary,
@@ -40,41 +40,42 @@ public abstract class FEDInstruction extends Instruction {
                Reorg,
                Reshape,
                MatrixIndexing,
+               Quaternary,
                QSort,
                QPick
        }
-       
+
        protected final FEDType _fedType;
        protected long _tid = -1; //main
-       
+
        protected FEDInstruction(FEDType type, String opcode, String istr) {
                this(type, null, opcode, istr);
        }
-       
+
        protected FEDInstruction(FEDType type, Operator op, String opcode, 
String istr) {
                super(op);
                _fedType = type;
                instString = istr;
                instOpcode = opcode;
        }
-       
+
        @Override
        public IType getType() {
                return IType.FEDERATED;
        }
-       
+
        public FEDType getFEDInstructionType() {
                return _fedType;
        }
-       
+
        public long getTID() {
                return _tid;
        }
-       
+
        public void setTID(long tid) {
                _tid = tid;
        }
-       
+
        @Override
        public Instruction preprocessInstruction(ExecutionContext ec) {
                Instruction tmp = super.preprocessInstruction(ec);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index e6a64cb..34f40bb 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -36,6 +36,7 @@ import 
org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction;
 import 
org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
 import 
org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
@@ -48,6 +49,7 @@ import 
org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.QuantilePickSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
 import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
 import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
 
@@ -101,7 +103,7 @@ public class FEDInstructionUtils {
                        }
                        else if(instruction.input1 != null && 
instruction.input1.isMatrix()
                                && ec.containsVariable(instruction.input1)) {
-                               
+
                                MatrixObject mo1 = 
ec.getMatrixObject(instruction.input1);
                                
if(instruction.getOpcode().equalsIgnoreCase("cm") && mo1.isFederated()) {
                                        fedinst = 
CentralMomentFEDInstruction.parseInstruction(inst.getInstructionString());
@@ -160,18 +162,18 @@ public class FEDInstructionUtils {
                }
                else if(inst instanceof VariableCPInstruction ){
                        VariableCPInstruction ins = (VariableCPInstruction) 
inst;
-                       if(ins.getVariableOpcode() == 
VariableOperationCode.Write 
+                       if(ins.getVariableOpcode() == 
VariableOperationCode.Write
                                && ins.getInput1().isMatrix()
                                && 
ins.getInput3().getName().contains("federated")){
                                fedinst = 
VariableFEDInstruction.parseInstruction(ins);
                        }
-                       else if(ins.getVariableOpcode() == 
VariableOperationCode.CastAsFrameVariable 
-                               && ins.getInput1().isMatrix() 
+                       else if(ins.getVariableOpcode() == 
VariableOperationCode.CastAsFrameVariable
+                               && ins.getInput1().isMatrix()
                                && 
ec.getCacheableData(ins.getInput1()).isFederated()){
                                fedinst = 
VariableFEDInstruction.parseInstruction(ins);
                        }
-                       else if(ins.getVariableOpcode() == 
VariableOperationCode.CastAsMatrixVariable 
-                               && ins.getInput1().isFrame() 
+                       else if(ins.getVariableOpcode() == 
VariableOperationCode.CastAsMatrixVariable
+                               && ins.getInput1().isFrame()
                                && 
ec.getCacheableData(ins.getInput1()).isFederated()){
                                fedinst = 
VariableFEDInstruction.parseInstruction(ins);
                        }
@@ -183,16 +185,22 @@ public class FEDInstructionUtils {
                                fedinst = 
AggregateTernaryFEDInstruction.parseInstruction(ins);
                        }
                }
+               else if(inst instanceof QuaternaryCPInstruction) {
+                       QuaternaryCPInstruction instruction = 
(QuaternaryCPInstruction) inst;
+                       Data data = ec.getVariable(instruction.input1);
+                       if(data instanceof MatrixObject && ((MatrixObject) 
data).isFederated())
+                               fedinst = 
QuaternaryFEDInstruction.parseInstruction(instruction.getInstructionString());
+               }
 
                //set thread id for federated context management
                if( fedinst != null ) {
                        fedinst.setTID(ec.getTID());
                        return fedinst;
                }
-               
+
                return inst;
        }
-       
+
        public static Instruction checkAndReplaceSP(Instruction inst, 
ExecutionContext ec) {
                FEDInstruction fedinst = null;
                if (inst instanceof MapmmSPInstruction) {
@@ -256,12 +264,18 @@ public class FEDInstructionUtils {
                                return 
VariableCPInstruction.parseInstruction(instruction.getInstructionString());
                        }
                }
+               else if(inst instanceof QuaternarySPInstruction) {
+                       QuaternarySPInstruction instruction = 
(QuaternarySPInstruction) inst;
+                       Data data = ec.getVariable(instruction.input1);
+                       if(data instanceof MatrixObject && ((MatrixObject) 
data).isFederated())
+                               fedinst = 
QuaternaryFEDInstruction.parseInstruction(instruction.getInstructionString());
+               }
                //set thread id for federated context management
                if( fedinst != null ) {
                        fedinst.setTID(ec.getTID());
                        return fedinst;
                }
-               
+
                return inst;
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
new file mode 100644
index 0000000..2b62ec5
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.    See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.WeightedCrossEntropy.WCeMMType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.fed.QuaternaryWCeMMFEDInstruction;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
+
+public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction
+{
+       // fourth operand, e.g., the epsilon (W) of a weighted cross entropy
+       protected CPOperand _input4 = null;
+
+       protected QuaternaryFEDInstruction(FEDInstruction.FEDType type, Operator operator,
+               CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String instruction_str)
+       {
+               super(type, operator, in1, in2, in3, out, opcode, instruction_str);
+               _input4 = in4;
+       }
+
+       /**
+        * Parses a quaternary instruction string into a federated instruction. Spark
+        * instructions (mapwcemm) are first rewritten into the corresponding CP wcemm
+        * instruction with a single thread.
+        *
+        * @param str CP or Spark instruction string
+        * @return the federated quaternary instruction
+        * @throws DMLRuntimeException for unsupported data types or opcodes
+        */
+       public static QuaternaryFEDInstruction parseInstruction(String str)
+       {
+               if(str.startsWith(ExecType.SPARK.name())) {
+                       // rewrite the spark instruction to a cp instruction; only the exec type prefix
+                       // and the opcode are replaced so that operand names containing "SPARK" or
+                       // "mapwcemm" are left untouched
+                       str = ExecType.CP.name() + str.substring(ExecType.SPARK.name().length());
+                       str = str.replaceFirst("mapwcemm", "wcemm"); // the opcode precedes all operands
+                       str += Lop.OPERAND_DELIMITOR + "1"; //num threads
+               }
+
+               String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+               // validate the number of fields before accessing them by index
+               InstructionUtils.checkNumFields(parts, 7);
+               String opcode = parts[0];
+
+               CPOperand in1 = new CPOperand(parts[1]);
+               CPOperand in2 = new CPOperand(parts[2]);
+               CPOperand in3 = new CPOperand(parts[3]);
+               CPOperand out = new CPOperand(parts[5]);
+
+               if(opcode.equals("wcemm")) {
+                       CPOperand in4 = new CPOperand(parts[4]);
+                       checkDataTypes(in1, in2, in3, in4);
+
+                       WCeMMType wcemm_type = WCeMMType.valueOf(parts[6]);
+                       // the epsilon value is only known at parse time for scalar literals; for scalar
+                       // variables or a matrix epsilon it is read during execution, and parsing the
+                       // operand name as a double would fail
+                       boolean literalEps = wcemm_type.hasFourInputs()
+                               && in4.getDataType() == DataType.SCALAR && in4.isLiteral();
+                       QuaternaryOperator quaternary_operator = literalEps ?
+                               new QuaternaryOperator(wcemm_type, Double.parseDouble(in4.getName())) :
+                               new QuaternaryOperator(wcemm_type);
+                       return new QuaternaryWCeMMFEDInstruction(quaternary_operator, in1, in2, in3, in4, out, opcode, str);
+               }
+
+               throw new DMLRuntimeException("Unsupported opcode (" + opcode + ") for QuaternaryFEDInstruction.");
+       }
+
+       /**
+        * Validates that the first three operands are matrices and the fourth is a scalar or matrix.
+        *
+        * @throws DMLRuntimeException if any operand has an unsupported data type
+        */
+       protected static void checkDataTypes(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4) {
+               if(in1.getDataType() != DataType.MATRIX || in2.getDataType() != DataType.MATRIX
+                       || in3.getDataType() != DataType.MATRIX
+                       || !(in4.getDataType() == DataType.SCALAR || in4.getDataType() == DataType.MATRIX)) {
+                       throw new DMLRuntimeException("Federated quaternary operations "
+                               + "only supported with matrix inputs and scalar epsilon.");
+               }
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
new file mode 100644
index 0000000..8566b39
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.    See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
+
+import java.util.concurrent.Future;
+
+public class QuaternaryWCeMMFEDInstruction extends QuaternaryFEDInstruction
+{
+       // input1 ... federated X
+       // input2 ... U
+       // input3 ... V
+       // _input4 ... W (=epsilon)
+       protected QuaternaryWCeMMFEDInstruction(Operator operator,
+               CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
+               CPOperand out, String opcode, String instruction_str)
+       {
+               super(FEDType.Quaternary, operator, in1, in2, in3, in4, out, opcode, instruction_str);
+       }
+
+       /**
+        * Executes the weighted cross entropy for a federated X and local U and V. U is sent
+        * to the workers via broadcastSliced, V (and epsilon, if present) via broadcast, each
+        * worker runs the wcemm instruction, and the returned partial results are aggregated
+        * with uak+ into the scalar output.
+        *
+        * @param ec execution context holding the input and output variables
+        * @throws DMLRuntimeException if X is not federated or U or V is federated
+        */
+       @Override
+       public void processInstruction(ExecutionContext ec)
+       {
+               QuaternaryOperator qop = (QuaternaryOperator) _optr;
+               MatrixObject X = ec.getMatrixObject(input1);
+               MatrixObject U = ec.getMatrixObject(input2);
+               MatrixObject V = ec.getMatrixObject(input3);
+               ScalarObject eps = null;
+
+               if(qop.hasFourInputs()) {
+                       if(_input4.getDataType() == DataType.SCALAR) {
+                               eps = ec.getScalarInput(_input4);
+                       }
+                       else {
+                               // epsilon given as matrix: read its first cell, then release the
+                               // matrix again (every getMatrixInput needs a releaseMatrixInput)
+                               double epsValue = ec.getMatrixInput(_input4.getName()).quickGetValue(0, 0);
+                               ec.releaseMatrixInput(_input4.getName());
+                               eps = new DoubleObject(epsValue);
+                       }
+               }
+
+               if(!(X.isFederated() && !U.isFederated() && !V.isFederated()))
+                       throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = ("
+                               +X.isFederated()+", "+U.isFederated()+", "+V.isFederated()+")");
+
+               FederationMap fedMap = X.getFedMapping();
+               FederatedRequest[] fr1 = fedMap.broadcastSliced(U, false);
+               FederatedRequest fr2 = fedMap.broadcast(V);
+               FederatedRequest fr3 = null;
+               FederatedRequest frComp = null;
+
+               // broadcast the scalar epsilon if there are four inputs
+               if(eps != null) {
+                       fr3 = fedMap.broadcast(eps);
+                       // once broadcast, epsilon is a variable on the workers rather than a literal;
+                       // flip only the literal flag of the fourth operand (parts[4], as in
+                       // parseInstruction) on a local copy instead of every "true" in instString
+                       // NOTE(review): for a matrix epsilon the worker instruction still declares a
+                       // MATRIX operand while a scalar is broadcast -- verify this path
+                       String fedInstString = instString;
+                       String in4Str = InstructionUtils.getInstructionPartsWithValueType(instString)[4];
+                       if(in4Str.endsWith("true"))
+                               fedInstString = instString.replace(in4Str,
+                                       in4Str.substring(0, in4Str.length() - "true".length()) + "false");
+                       frComp = FederationUtils.callInstruction(fedInstString, output,
+                               new CPOperand[]{input1, input2, input3, _input4},
+                               new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID(), fr3.getID()});
+               }
+               else {
+                       frComp = FederationUtils.callInstruction(instString, output,
+                               new CPOperand[]{input1, input2, input3},
+                               new long[]{fedMap.getID(), fr1[0].getID(), fr2.getID()});
+               }
+
+               FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frComp.getID());
+               FederatedRequest frClean1 = fedMap.cleanup(getTID(), frComp.getID());
+               FederatedRequest frClean2 = fedMap.cleanup(getTID(), fr1[0].getID());
+               FederatedRequest frClean3 = fedMap.cleanup(getTID(), fr2.getID());
+
+               Future<FederatedResponse>[] response;
+               if(fr3 != null) {
+                       FederatedRequest frClean4 = fedMap.cleanup(getTID(), fr3.getID());
+                       // execute federated instructions
+                       response = fedMap.execute(getTID(), true, fr1, fr2, fr3,
+                               frComp, frGet, frClean1, frClean2, frClean3, frClean4);
+               }
+               else {
+                       // execute federated instructions
+                       response = fedMap.execute(getTID(), true, fr1, fr2,
+                               frComp, frGet, frClean1, frClean2, frClean3);
+               }
+
+               //aggregate partial results from federated responses
+               AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+               ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, response));
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index b2f3a53..a033769 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -19,9 +19,7 @@
 
 package org.apache.sysds.runtime.instructions.fed;
 
-import java.util.AbstractMap;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 
 import org.apache.sysds.common.Types;
@@ -36,16 +34,13 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.functionobjects.DiagIndex;
 import org.apache.sysds.runtime.functionobjects.RevIndex;
-import org.apache.sysds.runtime.functionobjects.SortIndex;
 import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
-import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
-import org.apache.sysds.runtime.util.IndexRange;
 
 public class ReorgFEDInstruction extends UnaryFEDInstruction {
        
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java 
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 4d3e9d9..98c6b79 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -837,6 +837,10 @@ public abstract class AutomatedTestBase {
                return TestUtils.readDMLMatrixFromHDFS(baseDirectory + 
OUTPUT_DIR + fileName);
        }
 
+       /**
+        * Reads a matrix from the expected-results directory,
+        * i.e., {@code baseDirectory + EXPECTED_DIR + fileName}, via TestUtils.readDMLMatrixFromHDFS.
+        * NOTE(review): unlike readRMatrixFromExpectedDir, the path does not include cacheDir.
+        */
+       protected static HashMap<CellIndex, Double> readDMLMatrixFromExpectedDir(String fileName) {
+               return TestUtils.readDMLMatrixFromHDFS(baseDirectory + EXPECTED_DIR + fileName);
+       }
+       
        public HashMap<CellIndex, Double> readRMatrixFromExpectedDir(String 
fileName) {
                if(LOG.isInfoEnabled())
                        LOG.info("R script out: " + baseDirectory + 
EXPECTED_DIR + cacheDir + fileName);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
new file mode 100644
index 0000000..bf676a3
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.     See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.      The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.   You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.    See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedWeightedCrossEntropyTest extends AutomatedTestBase
+{
+       private final static String STD_TEST_NAME = "FederatedWCeMMTest";
+       private final static String EPS_TEST_NAME = "FederatedWCeMMEpsTest";
+       private final static String TEST_DIR = "functions/federated/quaternary/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + FederatedWeightedCrossEntropyTest.class.getSimpleName() + "/";
+
+       private final static String OUTPUT_NAME = "Z";
+       private final static double TOLERANCE = 1e-9;
+       private final static int blocksize = 1024;
+
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public int rank;
+       @Parameterized.Parameter(3)
+       public double epsilon;
+       @Parameterized.Parameter(4)
+       public double sparsity;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(STD_TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, STD_TEST_NAME, new String[]{OUTPUT_NAME}));
+               addTestConfiguration(EPS_TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, EPS_TEST_NAME, new String[]{OUTPUT_NAME}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // rows must be even (X is split into two halves of rows / 2 rows)
+               return Arrays.asList(new Object[][] {
+                       // {rows, cols, rank, epsilon, sparsity}
+                       {2000, 50, 10, 0.01, 0.01},
+                       {2000, 50, 10, 0.01, 0.9},
+                       {2000, 50, 10, 6.45, 0.01},
+                       {2000, 50, 10, 6.45, 0.9}
+               });
+       }
+
+       @BeforeClass
+       public static void init() {
+               TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+       }
+
+       @Test
+       public void federatedWeightedCrossEntropySingleNode() {
+               federatedWeightedCrossEntropy(STD_TEST_NAME, ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedWeightedCrossEntropySpark() {
+               federatedWeightedCrossEntropy(STD_TEST_NAME, ExecMode.SPARK);
+       }
+
+       @Test
+       public void federatedWeightedCrossEntropySingleNodeEpsilon() {
+               federatedWeightedCrossEntropy(EPS_TEST_NAME, ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       public void federatedWeightedCrossEntropySparkEpsilon() {
+               federatedWeightedCrossEntropy(EPS_TEST_NAME, ExecMode.SPARK);
+       }
+
+// -----------------------------------------------------------------------------
+
+       /**
+        * Runs the reference script of {@code testname} on local data and the federated script
+        * on two local federated workers (each holding half of X's rows), compares both results,
+        * and checks that a federated wcemm was executed.
+        *
+        * @param testname name of the test configuration and DML script
+        * @param execMode execution mode to run the scripts in
+        */
+       public void federatedWeightedCrossEntropy(String testname, ExecMode execMode)
+       {
+               // store the previous platform config to restore it after the test
+               ExecMode platform_old = setExecMode(execMode);
+               try {
+                       getAndLoadTestConfiguration(testname);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       int fed_rows = rows / 2;
+                       int fed_cols = cols;
+
+                       // generate dataset
+                       // matrix handled by two federated workers
+                       double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, sparsity, 3);
+                       double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1, sparsity, 7);
+
+                       double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512);
+                       double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040);
+
+                       writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+                       writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+
+                       writeInputMatrixWithMTD("U", U, true);
+                       writeInputMatrixWithMTD("V", V, true);
+
+                       // empty script name because we don't execute any script, just start the worker
+                       fullDMLScriptName = "";
+                       int port1 = getRandomAvailablePort();
+                       int port2 = getRandomAvailablePort();
+                       Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+                       Thread thread2 = startLocalFedWorkerThread(port2);
+
+                       try {
+                               getAndLoadTestConfiguration(testname);
+
+                               // Run reference dml script with normal matrix
+                               fullDMLScriptName = HOME + testname + "Reference.dml";
+                               programArgs = new String[] {"-nvargs", "in_X1=" + input("X1"), "in_X2=" + input("X2"),
+                                       "in_U=" + input("U"), "in_V=" + input("V"), "in_W=" + Double.toString(epsilon),
+                                       "out_Z=" + expected(OUTPUT_NAME)};
+                               runTest(true, false, null, -1);
+
+                               // Run actual dml script with federated matrix
+                               fullDMLScriptName = HOME + testname + ".dml";
+                               programArgs = new String[] {"-stats", "-nvargs",
+                                       "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                                       "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                                       "in_U=" + input("U"),
+                                       "in_V=" + input("V"),
+                                       "in_W=" + Double.toString(epsilon),
+                                       "rows=" + fed_rows, "cols=" + fed_cols, "out_Z=" + output(OUTPUT_NAME)};
+                               runTest(true, false, null, -1);
+
+                               // compare the results via files
+                               HashMap<CellIndex, Double> refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+                               HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
+                               TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
+                       }
+                       finally {
+                               // stop the workers even if a script run or the comparison fails
+                               TestUtils.shutdownThreads(thread1, thread2);
+                       }
+
+                       // check for federated operations
+                       Assert.assertTrue(heavyHittersContainsString("fed_wcemm"));
+
+                       // check that federated input files are still existing
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               }
+               finally {
+                       // restore the previous exec mode regardless of the test outcome
+                       resetExecMode(platform_old);
+               }
+       }
+}
diff --git 
a/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTest.dml 
b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTest.dml
new file mode 100644
index 0000000..84c0b92
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols)))  # X1: rows 0..$rows, X2: rows $rows..2*$rows, all $cols columns
+
+U = read($in_U)
+V = read($in_V)
+epsilon = $in_W  # additive epsilon inside the log, passed via the in_W argument
+
+Z = as.matrix(sum(X * log(U %*% t(V) + epsilon)))  # weighted cross-entropy (with eps) as a 1x1 matrix; the Java test expects fed_wcemm to run
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTestReference.dml b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTestReference.dml
new file mode 100644
index 0000000..c01f99a
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMEpsTestReference.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2))  # reference: concatenate both partitions locally (no federation)
+U = read($in_U)
+V = read($in_V)
+epsilon = $in_W  # additive epsilon inside the log, passed via the in_W argument
+
+Z = as.matrix(sum(X * log(U %*% t(V) + epsilon)))  # weighted cross-entropy (with eps) as a 1x1 matrix
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTest.dml b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTest.dml
new file mode 100644
index 0000000..75ae2ef
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTest.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols)))  # X1: rows 0..$rows, X2: rows $rows..2*$rows, all $cols columns
+
+U = read($in_U)
+V = read($in_V)
+
+Z = as.matrix(sum(X * log(U %*% t(V))))  # weighted cross-entropy as a 1x1 matrix; the Java test expects fed_wcemm to run
+
+write(Z, $out_Z)
diff --git a/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTestReference.dml b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTestReference.dml
new file mode 100644
index 0000000..499ed3d
--- /dev/null
+++ b/src/test/scripts/functions/federated/quaternary/FederatedWCeMMTestReference.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2))  # reference: concatenate both partitions locally (no federation)
+U = read($in_U)
+V = read($in_V)
+
+Z = as.matrix(sum(X * log(U %*% t(V))))  # weighted cross-entropy as a 1x1 matrix
+
+write(Z, $out_Z)

Reply via email to