This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 8565b1d  [SYSTEMDS-2766] Federated covariance
8565b1d is described below

commit 8565b1db948fc7222e3227a84622c6714fa6e425
Author: Olga <[email protected]>
AuthorDate: Mon Jan 11 01:11:47 2021 +0100

    [SYSTEMDS-2766] Federated covariance
    
    Closes #1150
---
 .../instructions/fed/BinaryFEDInstruction.java     |  28 ++
 .../instructions/fed/CovarianceFEDInstruction.java | 312 +++++++++++++++++++++
 .../instructions/fed/FEDInstructionUtils.java      |   3 +
 ...EmptyTest.java => FederatedCovarianceTest.java} | 125 +++++----
 .../primitives/FederatedRemoveEmptyTest.java       |   1 -
 .../federated/FederatedCovarianceAlignedTest.dml   |  31 ++
 .../FederatedCovarianceAlignedTestReference.dml    |  27 ++
 .../federated/FederatedCovarianceTest.dml          |  28 ++
 .../federated/FederatedCovarianceTestReference.dml |  27 ++
 9 files changed, 529 insertions(+), 53 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index 9f0c91a..1adaf09 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -35,6 +35,11 @@ public abstract class BinaryFEDInstruction extends 
ComputationFEDInstruction {
                super(type, op, in1, in2, out, opcode, istr);
        }
 
+       /**
+        * Creates a binary federated instruction that carries an additional third input operand
+        * (used, for example, for the weights of a weighted covariance).
+        *
+        * @param type   federated instruction type
+        * @param op     operator applied by the instruction
+        * @param in1    first input operand
+        * @param in2    second input operand
+        * @param in3    third input operand
+        * @param out    output operand
+        * @param opcode instruction opcode
+        * @param istr   full instruction string
+        */
+       public BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
+               CPOperand in1, CPOperand in2, CPOperand in3,
+               CPOperand out, String opcode, String istr)
+       {
+               super(type, op, in1, in2, in3, out, opcode, istr);
+       }
+
        public static BinaryFEDInstruction parseInstruction(String str) {
                if(str.startsWith(ExecType.SPARK.name())) {
                        // rewrite the spark instruction to a cp instruction
@@ -67,6 +72,29 @@ public abstract class BinaryFEDInstruction extends 
ComputationFEDInstruction {
                        throw new DMLRuntimeException("Federated binary 
operations not yet supported:" + opcode);
        }
 
+       /**
+        * Splits an instruction string of the form {@code opcode, in1, in2, out} into the given operands.
+        * The field-count check also admits one extra trailing field, which is not parsed here.
+        *
+        * @param instr instruction string
+        * @param in1   receives the first input operand
+        * @param in2   receives the second input operand
+        * @param out   receives the output operand
+        * @return the opcode
+        */
+       protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) {
+               String[] fields = InstructionUtils.getInstructionPartsWithValueType(instr);
+               InstructionUtils.checkNumFields(fields, 3, 4);
+               in1.split(fields[1]);
+               in2.split(fields[2]);
+               out.split(fields[3]);
+               return fields[0];
+       }
+
+       /**
+        * Splits an instruction string of the form {@code opcode, in1, in2, in3, out} into the given operands.
+        *
+        * @param instr instruction string
+        * @param in1   receives the first input operand
+        * @param in2   receives the second input operand
+        * @param in3   receives the third input operand
+        * @param out   receives the output operand
+        * @return the opcode
+        */
+       protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3,
+               CPOperand out) {
+               String[] fields = InstructionUtils.getInstructionPartsWithValueType(instr);
+               InstructionUtils.checkNumFields(fields, 4);
+               in1.split(fields[1]);
+               in2.split(fields[2]);
+               in3.split(fields[3]);
+               out.split(fields[4]);
+               return fields[0];
+       }
+
        protected static void checkOutputDataType(CPOperand in1, CPOperand in2, 
CPOperand out) {
                // check for valid data type of output
                if( (in1.getDataType() == DataType.MATRIX || in2.getDataType() 
== DataType.MATRIX) && out.getDataType() != DataType.MATRIX )
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
new file mode 100644
index 0000000..dd38a2f
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
@@ -0,0 +1,312 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.Future;
+import java.util.stream.IntStream;
+
+import org.apache.commons.lang3.tuple.ImmutableTriple;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.functionobjects.COV;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.COVOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+/**
+ * Federated {@code cov(X, Y[, W])}. Two execution strategies are used:
+ * <ul>
+ * <li>X and Y federated with aligned row partitions: each worker computes a partial covariance (and the
+ * partition means), which are merged on the coordinator. Non-aligned weights are broadcast as row slices.</li>
+ * <li>Exactly one of X/Y federated: row slices of the local input (and of the weights, if any) are shipped
+ * to the workers, which return partial CM_COV objects that are combined with the COV function.</li>
+ * </ul>
+ */
+public class CovarianceFEDInstruction extends BinaryFEDInstruction {
+       private CovarianceFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode,
+               String istr) {
+               super(FEDInstruction.FEDType.AggregateBinary, op, in1, in2, out, opcode, istr);
+       }
+
+       private CovarianceFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
+               String opcode, String istr) {
+               super(FEDInstruction.FEDType.AggregateBinary, op, in1, in2, in3, out, opcode, istr);
+       }
+
+       /**
+        * Parses a {@code cov} instruction with either two inputs (X, Y) or three inputs (X, Y, weights).
+        *
+        * @param str instruction string
+        * @return the parsed federated covariance instruction
+        * @throws DMLRuntimeException if the opcode is not {@code cov} or the number of operands is invalid
+        */
+       public static CovarianceFEDInstruction parseInstruction(String str) {
+               CPOperand in1 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+               CPOperand in2 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+               CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+
+               String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+               String opcode = parts[0];
+
+               if( !opcode.equalsIgnoreCase("cov") ) {
+                       throw new DMLRuntimeException("CovarianceFEDInstruction.parseInstruction():: Unknown opcode " + opcode);
+               }
+
+               COVOperator cov = new COVOperator(COV.getCOMFnObject());
+               // parts[0] is the opcode: 4 parts = (X, Y, out), 5 parts = (X, Y, weights, out)
+               if( parts.length == 4 ) {
+                       parseBinaryInstruction(str, in1, in2, out);
+                       return new CovarianceFEDInstruction(cov, in1, in2, out, opcode, str);
+               }
+               else if( parts.length == 5 ) {
+                       CPOperand in3 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+                       parseBinaryInstruction(str, in1, in2, in3, out);
+                       return new CovarianceFEDInstruction(cov, in1, in2, in3, out, opcode, str);
+               }
+               else {
+                       throw new DMLRuntimeException("Invalid number of arguments in Instruction: " + str);
+               }
+       }
+
+       /**
+        * Dispatches to the aligned, aligned-with-broadcast-weights, or single-federated-input strategy.
+        *
+        * @throws DMLRuntimeException if X and Y are both federated but not aligned
+        */
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               MatrixObject mo1 = ec.getMatrixObject(input1);
+               MatrixObject mo2 = ec.getMatrixObject(input2);
+               MatrixObject weights = input3 != null ? ec.getMatrixObject(input3) : null;
+
+               boolean bothFederated = mo1.isFederated() && mo2.isFederated();
+               boolean moAligned = bothFederated && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false);
+               if(bothFederated && !moAligned)
+                       throw new DMLRuntimeException("Not supported matrix-matrix binary operation: covariance.");
+
+               boolean weightsAligned = weights == null || (weights.isFederated() && mo2.isFederated()
+                       && weights.getFedMapping().isAligned(mo2.getFedMapping(), false));
+
+               // all aligned
+               if(moAligned && weightsAligned)
+                       processAlignedFedCov(ec, mo1, mo2, weights);
+               // weights are not aligned, broadcast them as row slices
+               else if(moAligned)
+                       processFedCovWeights(ec, mo1, mo2, weights);
+               // exactly one of X/Y is federated
+               else
+                       processCov(ec, mo1, mo2);
+       }
+
+       /** Runs the cov instruction on every worker with X, Y (and weights) aligned, then merges the partials. */
+       private void processAlignedFedCov(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+               FederationMap fedMap = mo1.getFedMapping();
+               FederatedRequest fr1;
+               if(mo3 == null)
+                       fr1 = FederationUtils.callInstruction(instString, output,
+                               new CPOperand[]{input1, input2},
+                               new long[]{fedMap.getID(), mo2.getFedMapping().getID()});
+               else
+                       fr1 = FederationUtils.callInstruction(instString, output,
+                               new CPOperand[]{input1, input2, input3},
+                               new long[]{fedMap.getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
+
+               FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
+               FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1.getID());
+               Future<FederatedResponse>[] covTmp = fedMap.execute(getTID(), fr1, fr2, fr3);
+
+               aggregatePartialCov(ec, mo1, mo2, covTmp);
+       }
+
+       /**
+        * Broadcasts the row slices of the (non-aligned) weights to the workers of X, binds them as the third
+        * cov operand, runs the cov instruction and merges the partials.
+        */
+       private void processFedCovWeights(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+               FederationMap fedMap = mo1.getFedMapping();
+               // one PUT_VAR per partition; all slices share the same variable ID
+               FederatedRequest[] fr1 = fedMap.broadcastSliced(mo3, false);
+               FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+                       new CPOperand[]{input1, input2, input3},
+                       new long[]{fedMap.getID(), mo2.getFedMapping().getID(), fr1[0].getID()});
+               FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+               FederatedRequest fr4 = fedMap.cleanup(getTID(), fr1[0].getID(), fr2.getID());
+               // slices first, so the weights exist on each worker before the instruction runs
+               Future<FederatedResponse>[] covTmp = fedMap.execute(getTID(), fr1, fr2, fr3, fr4);
+
+               aggregatePartialCov(ec, mo1, mo2, covTmp);
+       }
+
+       /**
+        * Fetches the per-partition means of X and Y, merges them with the per-partition covariances and
+        * binds the result to the output variable.
+        * NOTE(review): assumes the i-th future of every request refers to the same row partition (aligned maps).
+        * TODO: with weights, the partials are weighted covariances while the means and partition sizes used
+        * for merging are unweighted, so the merged value is only exact for unit weights.
+        */
+       private void aggregatePartialCov(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2,
+               Future<FederatedResponse>[] covTmp) {
+               Future<FederatedResponse>[] meanTmp1 = processMean(mo1, 0);
+               Future<FederatedResponse>[] meanTmp2 = processMean(mo2, 1);
+
+               ImmutableTriple<Double[], Double[], Double[]> res = getResponses(covTmp, meanTmp1, meanTmp2);
+
+               double result = aggCov(res.left, res.middle, res.right, mo1.getFedMapping().getFederatedRanges());
+               ec.setVariable(output.getName(), new DoubleObject(result));
+       }
+
+       /**
+        * Handles the case where exactly one of X/Y is federated: the matching row slice of the local input
+        * (and of the weights, if present) is shipped to each worker, which returns a partial CM_COV object.
+        * The partials are combined with the COV function.
+        */
+       private void processCov(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2) {
+               COVOperator cop = ((COVOperator)_optr);
+
+               boolean firstIsLocal = !mo1.isFederated() && mo2.isFederated();
+               MatrixObject mo = firstIsLocal ? mo2 : mo1;
+               String localName = (firstIsLocal ? input1 : input2).getName();
+
+               // acquire local inputs once and release them in finally (getMatrixInput pins the data)
+               MatrixBlock mb = ec.getMatrixInput(localName);
+               MatrixBlock wtBlock = (input3 != null) ? ec.getMatrixInput(input3.getName()) : null;
+
+               List<CM_COV_Object> globalCmobj = new ArrayList<>();
+               try {
+                       FederationMap fedMapping = mo.getFedMapping();
+                       long varID = FederationUtils.getNextFedDataID();
+                       fedMapping.mapParallel(varID, (range, data) -> {
+                               int rl = range.getBeginDimsInt()[0];
+                               int ru = range.getEndDimsInt()[0] - 1;
+                               // weights must be sliced to the same rows as the federated partition
+                               FederatedUDF udf = (wtBlock == null) ?
+                                       new CovarianceFEDInstruction.COVFunction(data.getVarID(), mb.slice(rl, ru), cop) :
+                                       new CovarianceFEDInstruction.COVWeightsFunction(data.getVarID(), mb.slice(rl, ru), cop,
+                                               wtBlock.slice(rl, ru));
+                               try {
+                                       FederatedResponse response = data.executeFederatedOperation(
+                                               new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, udf)).get();
+                                       if(!response.isSuccessful())
+                                               response.throwExceptionFromResponse();
+                                       synchronized(globalCmobj) {
+                                               globalCmobj.add((CM_COV_Object) response.getData()[0]);
+                                       }
+                               }
+                               catch(Exception e) {
+                                       throw new DMLRuntimeException(e);
+                               }
+                               return null;
+                       });
+               }
+               finally {
+                       ec.releaseMatrixInput(localName);
+                       if(input3 != null)
+                               ec.releaseMatrixInput(input3.getName());
+               }
+
+               Optional<CM_COV_Object> res = globalCmobj.stream()
+                       .reduce((arg0, arg1) -> (CM_COV_Object) cop.fn.execute(arg0, arg1));
+               try {
+                       ec.setScalarOutput(output.getName(), new DoubleObject(res.get().getRequiredResult(cop)));
+               }
+               catch(Exception e) {
+                       throw new DMLRuntimeException(e);
+               }
+       }
+
+       /** Waits for the per-partition scalar results (covariance, mean of X, mean of Y), index-aligned. */
+       private static ImmutableTriple<Double[], Double[], Double[]> getResponses(Future<FederatedResponse>[] covFfr,
+               Future<FederatedResponse>[] mean1Ffr, Future<FederatedResponse>[] mean2Ffr) {
+               Double[] cov = new Double[covFfr.length];
+               Double[] mean1 = new Double[mean1Ffr.length];
+               Double[] mean2 = new Double[mean2Ffr.length];
+               IntStream.range(0, covFfr.length).forEach(i -> {
+                       try {
+                               cov[i] = ((ScalarObject) covFfr[i].get().getData()[0]).getDoubleValue();
+                               mean1[i] = ((ScalarObject) mean1Ffr[i].get().getData()[0]).getDoubleValue();
+                               mean2[i] = ((ScalarObject) mean2Ffr[i].get().getData()[0]).getDoubleValue();
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException("CovarianceFEDInstruction: incorrect means or cov.", e);
+                       }
+               });
+
+               return new ImmutableTriple<>(cov, mean1, mean2);
+       }
+
+       /**
+        * Merges per-partition sample covariances into the sample covariance over all rows using the pairwise
+        * co-moment update {@code C = C_a + C_b + (mx_b - mx_a) * (my_b - my_a) * n_a * n_b / (n_a + n_b)},
+        * where {@code C = cov * (n - 1)} is the co-moment of a partition.
+        * NOTE(review): assumes the partial values are sample covariances (denominator n - 1, as DML cov) and that
+        * ranges[i] describes the rows behind covValues[i], mean1[i] and mean2[i].
+        */
+       private static double aggCov(Double[] covValues, Double[] mean1, Double[] mean2, FederatedRange[] ranges) {
+               long n = ranges[0].getSize();
+               double meanX = mean1[0];
+               double meanY = mean2[0];
+               // a single-row partition has a zero co-moment (its sample covariance is undefined)
+               double coMoment = n > 1 ? covValues[0] * (n - 1) : 0;
+
+               for(int i = 1; i < covValues.length; i++) {
+                       long nb = ranges[i].getSize();
+                       long nNew = n + nb;
+                       double dx = mean1[i] - meanX;
+                       double dy = mean2[i] - meanY;
+                       double coMomentB = nb > 1 ? covValues[i] * (nb - 1) : 0;
+
+                       coMoment += coMomentB + dx * dy * ((double) n * nb / nNew);
+                       meanX += dx * nb / nNew;
+                       meanY += dy * nb / nNew;
+                       n = nNew;
+               }
+               return coMoment / (n - 1);
+       }
+
+       /**
+        * Computes the mean of every partition of {@code mo} with a federated {@code uamean} instruction.
+        *
+        * @param mo  federated matrix whose partition means are computed
+        * @param var 0 to reuse the operand slot of the first cov input, 1 for the second
+        * @return one future per partition, each yielding the partition mean as a scalar
+        */
+       private Future<FederatedResponse>[] processMean(MatrixObject mo, int var) {
+               // instString parts: [execType, opcode, X, Y, (weights,) out]; rebuilt positionally instead of
+               // global string replaces, which corrupted names containing "cov" or identical X/Y operands
+               String[] parts = instString.split("°");
+               CPOperand input = (var == 0) ? input1 : input2;
+               String inOperand = parts[var == 0 ? 2 : 3];
+               // same output rewrite as before: value type STRING plus a trailing thread-count field of 16
+               String outOperand = parts[parts.length - 1].replace("FP64", "STRING°16");
+               String meanInstr = String.join("°", parts[0], "uamean", inOperand, outOperand);
+
+               //create federated commands for aggregation
+               FederationMap fedMap = mo.getFedMapping();
+               FederatedRequest meanFr1 = FederationUtils.callInstruction(meanInstr, output,
+                       new CPOperand[]{input}, new long[]{fedMap.getID()});
+               FederatedRequest meanFr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, meanFr1.getID());
+               FederatedRequest meanFr3 = fedMap.cleanup(getTID(), meanFr1.getID());
+               return fedMap.execute(getTID(), meanFr1, meanFr2, meanFr3);
+       }
+
+       /** Worker-side UDF: covariance between the worker's partition and a shipped local row slice. */
+       private static class COVFunction extends FederatedUDF {
+
+               private static final long serialVersionUID = -501036588060113499L;
+               private final MatrixBlock _mo2;
+               private final COVOperator _op;
+
+               public COVFunction (long input, MatrixBlock mo2, COVOperator op) {
+                       super(new long[] {input});
+                       _op = op;
+                       _mo2 = mo2;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.covOperations(_op, _mo2));
+               }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
+       }
+
+       /** Worker-side UDF: weighted covariance with shipped row slices of the local input and the weights. */
+       private static class COVWeightsFunction extends FederatedUDF {
+               private static final long serialVersionUID = -1768739786192949573L;
+               private final COVOperator _op;
+               private final MatrixBlock _mo2;
+               private final MatrixBlock _weights;
+
+               protected COVWeightsFunction(long input, MatrixBlock mo2, COVOperator op, MatrixBlock weights) {
+                       super(new long[] {input});
+                       _mo2 = mo2;
+                       _op = op;
+                       _weights = weights;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
+                               mb.covOperations(_op, _mo2, _weights));
+               }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 613ff31..1417c90 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -136,6 +136,9 @@ public class FEDInstructionUtils {
                                        fedinst = 
AppendFEDInstruction.parseInstruction(inst.getInstructionString());
                                else if(instruction.getOpcode().equals("qpick"))
                                        fedinst = 
QuantilePickFEDInstruction.parseInstruction(inst.getInstructionString());
+                               else if("cov".equals(instruction.getOpcode()) 
&& (ec.getMatrixObject(instruction.input1).isFederated(FType.ROW) ||
+                                       
ec.getMatrixObject(instruction.input2).isFederated(FType.ROW)))
+                                       fedinst = 
CovarianceFEDInstruction.parseInstruction(inst.getInstructionString());
                                else
                                        fedinst = 
BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
                        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCovarianceTest.java
similarity index 51%
copy from 
src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
copy to 
src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCovarianceTest.java
index 10a6711..557341a 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCovarianceTest.java
@@ -25,7 +25,6 @@ import java.util.Collection;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
-import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
@@ -36,69 +35,64 @@ import org.junit.runners.Parameterized;
 
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
-public class FederatedRemoveEmptyTest extends AutomatedTestBase {
-       // private static final Log LOG = 
LogFactory.getLog(FederatedRightIndexTest.class.getName());
-
-       private final static String TEST_NAME = "FederatedRemoveEmptyTest";
+public class FederatedCovarianceTest extends AutomatedTestBase {
 
+       private final static String TEST_NAME1 = "FederatedCovarianceTest";
+       private final static String TEST_NAME2 = 
"FederatedCovarianceAlignedTest";
        private final static String TEST_DIR = "functions/federated/";
-       private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedRemoveEmptyTest.class.getSimpleName() + "/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedCovarianceTest.class.getSimpleName() + "/";
 
        private final static int blocksize = 1024;
-       @Parameterized.Parameter()
+       @Parameterized.Parameter
        public int rows;
        @Parameterized.Parameter(1)
        public int cols;
 
-       @Parameterized.Parameter(2)
-       public boolean rowPartitioned;
-
        @Parameterized.Parameters
        public static Collection<Object[]> data() {
                return Arrays.asList(new Object[][] {
-                       {20, 12, true},
-                       {20, 12, false}
+                       {20, 1},
+//                     {100, 1}, {1000, 1}
                });
        }
 
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
        }
 
        @Test
-       public void testRemoveEmptyCP() {
-               runAggregateOperationTest(ExecMode.SINGLE_NODE);
-       }
+       public void testCovCP() { runCovTest(ExecMode.SINGLE_NODE, false); }
+
+       @Test
+       public void testAlignedCovCP() { runCovTest(ExecMode.SINGLE_NODE, 
true); }
 
-       private void runAggregateOperationTest(ExecMode execMode) {
+       private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                ExecMode platformOld = rtplatform;
 
                if(rtplatform == ExecMode.SPARK)
                        DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 
+               String TEST_NAME = alignedFedInput ? TEST_NAME2 : TEST_NAME1;
+
                getAndLoadTestConfiguration(TEST_NAME);
                String HOME = SCRIPT_DIR + TEST_DIR;
 
                // write input matrices
-               int r = rows;
-               int c = cols / 4;
-               if(rowPartitioned) {
-                       r = rows / 4;
-                       c = cols;
-               }
+               int r = r = rows / 4;
+               int c = cols;
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
 
                double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
                double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
                double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
                double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
 
-               for(int k : new int[] {1, 2, 3}) {
-                       Arrays.fill(X3[k], 0);
-               }
-
                MatrixCharacteristics mc = new MatrixCharacteristics(r, c, 
blocksize, r * c);
                writeInputMatrixWithMTD("X1", X1, false, mc);
                writeInputMatrixWithMTD("X2", X2, false, mc);
@@ -124,36 +118,63 @@ public class FederatedRemoveEmptyTest extends 
AutomatedTestBase {
                TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                loadTestConfiguration(config);
 
-               // Run reference dml script with normal matrix
-               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-               programArgs = new String[] {"-stats", "100", "-args", 
input("X1"), input("X2"), input("X3"), input("X4"),
-                       Boolean.toString(rowPartitioned).toUpperCase(), 
expected("S")};
-
-               runTest(null);
-
-               fullDMLScriptName = HOME + TEST_NAME + ".dml";
-               programArgs = new String[] {"-stats", "100", "-nvargs",
-                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
-                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
-                       "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
-                       "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")), "rows=" + rows, "cols=" + cols,
-                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), 
"out_S=" + output("S")};
-
-               runTest(null);
+               if(alignedFedInput) {
+                       double[][] Y1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+                       double[][] Y2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+                       double[][] Y3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+                       double[][] Y4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+                       writeInputMatrixWithMTD("Y1", Y1, false, mc);
+                       writeInputMatrixWithMTD("Y2", Y2, false, mc);
+                       writeInputMatrixWithMTD("Y3", Y3, false, mc);
+                       writeInputMatrixWithMTD("Y4", Y4, false, mc);
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                       programArgs = new String[] {"-stats", "100", "-args", 
input("X1"), input("X2"), input("X3"), input("X4"),
+                               input("Y1"), input("Y2"), input("Y3"), 
input("Y4"), expected("S")};
+                       runTest(null);
+
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-stats", "100", "-nvargs",
+                               "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                               "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
+                               "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")),
+                               "in_Y1=" + TestUtils.federatedAddress(port1, 
input("Y1")),
+                               "in_Y2=" + TestUtils.federatedAddress(port2, 
input("Y2")),
+                               "in_Y3=" + TestUtils.federatedAddress(port3, 
input("Y3")),
+                               "in_Y4=" + TestUtils.federatedAddress(port4, 
input("Y4")),
+                               "rows=" + rows, "cols=" + cols, "out_S=" + 
output("S")};
+                       runTest(null);
+
+               } else {
+                       double[][] Y = getRandomMatrix(rows, c, 1, 5, 1, 3);
+                       writeInputMatrixWithMTD("Y", Y, false, new 
MatrixCharacteristics(rows, cols, blocksize, r*c));
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                                       programArgs = new String[] {"-stats", 
"100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+                                               input("Y"), expected("S"),};
+                       runTest(null);
+
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-stats", "100", "-nvargs",
+                               "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                               "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
+                               "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")),
+                               "Y=" + input("Y"), "rows=" + rows, "cols=" + 
cols, "out_S=" + output("S")};
+                       runTest(null);
+               }
 
                // compare via files
-               compareResults(1e-9);
-
-               // check that federated input files are still existing
-               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
-               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
-               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
-               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+               compareResults(1e-2);
+               Assert.assertTrue(heavyHittersContainsString("fed_cov"));
 
                TestUtils.shutdownThreads(t1, t2, t3, t4);
-
                rtplatform = platformOld;
                DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
-
        }
 }
+
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
index 10a6711..89f67b2 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
@@ -49,7 +49,6 @@ public class FederatedRemoveEmptyTest extends 
AutomatedTestBase {
        public int rows;
        @Parameterized.Parameter(1)
        public int cols;
-
        @Parameterized.Parameter(2)
        public boolean rowPartitioned;
 
diff --git 
a/src/test/scripts/functions/federated/FederatedCovarianceAlignedTest.dml 
b/src/test/scripts/functions/federated/FederatedCovarianceAlignedTest.dml
new file mode 100644
index 0000000..9f64ad0
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCovarianceAlignedTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+  ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+  list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, 
$cols)));
+
+B = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
+  ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+  list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, 
$cols)));
+
+s = cov(A, B);
+write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/FederatedCovarianceAlignedTestReference.dml
 
b/src/test/scripts/functions/federated/FederatedCovarianceAlignedTestReference.dml
new file mode 100644
index 0000000..9039286
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/FederatedCovarianceAlignedTestReference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rbind(read($1), read($2), read($3), read($4));
+B = rbind(read($5), read($6), read($7), read($8));
+
+s = cov(A, B);
+write(s, $9);
diff --git a/src/test/scripts/functions/federated/FederatedCovarianceTest.dml 
b/src/test/scripts/functions/federated/FederatedCovarianceTest.dml
new file mode 100644
index 0000000..ee1315a
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCovarianceTest.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+  ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+  list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, 
$cols)));
+B = read($Y);
+
+s = cov(A, B);
+write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/FederatedCovarianceTestReference.dml 
b/src/test/scripts/functions/federated/FederatedCovarianceTestReference.dml
new file mode 100644
index 0000000..f3c3a3a
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCovarianceTestReference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rbind(read($1), read($2), read($3), read($4));
+B = read($5);
+
+s = cov(A, B);
+write(s, $6);

Reply via email to