This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 1d1b00825c [SYSTEMDS-3789] Fix federated covariance (missing weighted case)
1d1b00825c is described below

commit 1d1b00825c7160e2cbb36307c5fa8a96c9779c44
Author: Grigorii <[email protected]>
AuthorDate: Sun Nov 17 12:32:18 2024 +0100

    [SYSTEMDS-3789] Fix federated covariance (missing weighted case)
    
    Closes #2137.
---
 .../instructions/fed/CovarianceFEDInstruction.java | 413 ++++++++++++++++++---
 .../primitives/part5/FederatedCovarianceTest.java  | 199 +++++++++-
 .../FederatedCovarianceAlignedWeightedTest.dml     |  35 ++
 ...ratedCovarianceAlignedWeightedTestReference.dml |  27 ++
 .../FederatedCovarianceAllAlignedWeightedTest.dml  |  38 ++
 ...edCovarianceAllAlignedWeightedTestReference.dml |  27 ++
 .../federated/FederatedCovarianceWeightedTest.dml  |  31 ++
 .../FederatedCovarianceWeightedTestReference.dml   |  27 ++
 8 files changed, 739 insertions(+), 58 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
index 9e79b024a2..719fd91588 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
@@ -24,8 +24,9 @@ import java.util.List;
 import java.util.Optional;
 import java.util.concurrent.Future;
 import java.util.stream.IntStream;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 
-import org.apache.commons.lang3.tuple.ImmutableTriple;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -103,45 +104,71 @@ public class CovarianceFEDInstruction extends BinaryFEDInstruction {
        private void processAlignedFedCov(ExecutionContext ec, MatrixObject 
mo1, MatrixObject mo2,
                MatrixLineagePair moLin3) {
                FederatedRequest fr1;
-               if(moLin3 == null)
+               if(moLin3 == null) {
                        fr1 = FederationUtils.callInstruction(instString, 
output,
                                new CPOperand[]{input1, input2}, new 
long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});
-               else
+               }
+               else {
                        fr1 = FederationUtils.callInstruction(instString, 
output,
                                new CPOperand[]{input1, input2, input3}, new 
long[]{mo1.getFedMapping().getID(),
                                        mo2.getFedMapping().getID(), 
moLin3.getFedMapping().getID()});
-
+               }
+                       
                FederatedRequest fr2 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
                FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), 
fr1.getID());
                Future<FederatedResponse>[] covTmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
 
                //means
-               Future<FederatedResponse>[] meanTmp1 = processMean(mo1, 0);
-               Future<FederatedResponse>[] meanTmp2 = processMean(mo2, 1);
+               Future<FederatedResponse>[] meanTmp1 = processMean(mo1, moLin3, 
0);
+               Future<FederatedResponse>[] meanTmp2 = processMean(mo2, moLin3, 
1);
 
-               ImmutableTriple<Double[], Double[], Double[]> res = 
getResponses(covTmp, meanTmp1, meanTmp2);
+               Double[] cov = getResponses(covTmp);
+               Double[] mean1 = getResponses(meanTmp1);
+               Double[] mean2 = getResponses(meanTmp2);
 
-               double result = aggCov(res.left, res.middle, res.right, 
mo1.getFedMapping().getFederatedRanges());
-               ec.setVariable(output.getName(), new DoubleObject(result));
+               if (moLin3 == null) {
+                       double result = aggCov(cov, mean1, mean2, 
mo1.getFedMapping().getFederatedRanges());
+                       ec.setVariable(output.getName(), new 
DoubleObject(result));
+               }
+               else {
+                       Future<FederatedResponse>[] weightsSumTmp = 
getWeightsSum(moLin3, moLin3.getFedMapping().getID(), instString, 
moLin3.getFedMapping());
+                       Double[] weights = getResponses(weightsSumTmp);
+                       
+                       double result = aggWeightedCov(cov, mean1, mean2, 
weights);
+                       ec.setVariable(output.getName(), new 
DoubleObject(result));
+               }
        }
 
        private void processFedCovWeights(ExecutionContext ec, MatrixObject 
mo1, MatrixObject mo2,
                MatrixLineagePair moLin3) {
+               
+               FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(moLin3, false);
 
-               FederatedRequest[] fr2 = 
mo1.getFedMapping().broadcastSliced(moLin3, false);
-               FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
-                       new CPOperand[]{input1, input2}, new 
long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});
-               FederatedRequest fr3 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
-               FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), 
fr1.getID());
-               Future<FederatedResponse>[] covTmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2[0], fr3, fr4);
+               // the original instruction encodes weights as "pREADW", change 
to the new ID
+               String[] parts = instString.split("°");
+               String covInstr = instString.replace(parts[4], 
String.valueOf(fr1[0].getID()) + "·MATRIX·FP64");
+
+               FederatedRequest fr2 = FederationUtils.callInstruction(
+                       covInstr, output,
+                       new CPOperand[]{input1, input2, input3},
+                       new long[]{mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID(), fr1[0].getID()}
+               );
+               FederatedRequest fr3 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+               FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), 
fr2.getID());
+               Future<FederatedResponse>[] covTmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
 
                //means
-               Future<FederatedResponse>[] meanTmp1 = processMean(mo1, 0);
-               Future<FederatedResponse>[] meanTmp2 = processMean(mo2, 1);
+               Future<FederatedResponse>[] meanTmp1 = processMean(mo1, 0, 
fr1[0].getID());
+               Future<FederatedResponse>[] meanTmp2 = processMean(mo2, 1, 
fr1[0].getID());
 
-               ImmutableTriple<Double[], Double[], Double[]> res = 
getResponses(covTmp, meanTmp1, meanTmp2);
+               Double[] cov = getResponses(covTmp);
+               Double[] mean1 = getResponses(meanTmp1);
+               Double[] mean2 = getResponses(meanTmp2);
 
-               double result = aggCov(res.left, res.middle, res.right, 
mo1.getFedMapping().getFederatedRanges());
+               Future<FederatedResponse>[] weightsSumTmp = 
getWeightsSum(moLin3, fr1[0].getID(), instString, mo1.getFedMapping());
+               Double[] weights = getResponses(weightsSumTmp);
+               
+               double result = aggWeightedCov(cov, mean1, mean2, weights);
                ec.setVariable(output.getName(), new DoubleObject(result));
        }
 
@@ -174,11 +201,17 @@ public class CovarianceFEDInstruction extends BinaryFEDInstruction {
                                }
                                // with weights
                                else {
-                                       MatrixBlock wtBlock = 
ec.getMatrixInput(input2.getName());
-                                       response = 
data.executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-                                               new 
CovarianceFEDInstruction.COVWeightsFunction(data.getVarID(),
-                                                       
mb.slice(range.getBeginDimsInt()[0], range.getEndDimsInt()[0] - 1),
-                                                       cop, wtBlock))).get();
+                                       MatrixBlock wtBlock = 
ec.getMatrixInput(input3.getName());
+                                       response = 
data.executeFederatedOperation(
+                                               new FederatedRequest(
+                                                       
FederatedRequest.RequestType.EXEC_UDF, -1,
+                                                       new 
CovarianceFEDInstruction.COVWeightsFunction(
+                                                               data.getVarID(),
+                                                               
mb.slice(range.getBeginDimsInt()[0], range.getEndDimsInt()[0] - 1),
+                                                               cop, 
wtBlock.slice(range.getBeginDimsInt()[0], range.getEndDimsInt()[0] - 1)
+                                                       )
+                                               )
+                                       ).get();
                                }
 
                                if(!response.isSuccessful())
@@ -202,59 +235,329 @@ public class CovarianceFEDInstruction extends BinaryFEDInstruction {
                }
        }
 
-       private static ImmutableTriple<Double[], Double[], Double[]> 
getResponses(Future<FederatedResponse>[] covFfr, Future<FederatedResponse>[] 
mean1Ffr, Future<FederatedResponse>[] mean2Ffr) {
-               Double[] cov = new Double[covFfr.length];
-               Double[] mean1 = new Double[mean1Ffr.length];
-               Double[] mean2 = new Double[mean2Ffr.length];
-               IntStream.range(0, covFfr.length).forEach(i -> {
+       private static Double[] getResponses(Future<FederatedResponse>[] ffr) {
+               Double[] fr = new Double[ffr.length];
+               IntStream.range(0, fr.length).forEach(i -> {
                        try {
-                               cov[i] = ((ScalarObject) 
covFfr[i].get().getData()[0]).getDoubleValue();
-                               mean1[i] = ((ScalarObject) 
mean1Ffr[1].get().getData()[0]).getDoubleValue();
-                               mean2[i] = ((ScalarObject) 
mean2Ffr[2].get().getData()[0]).getDoubleValue();
+                               fr[i] = ((ScalarObject) 
ffr[i].get().getData()[0]).getDoubleValue();
                        }
                        catch(Exception e) {
                                throw new 
DMLRuntimeException("CovarianceFEDInstruction: incorrect means or cov.");
                        }
                });
 
-               return new ImmutableTriple<>(cov, mean1, mean2);
+               return fr;
        }
 
        private static double aggCov(Double[] covValues, Double[] mean1, 
Double[] mean2, FederatedRange[] ranges) {
-               double cov = covValues[0];
-               long size1 = ranges[0].getSize();
-               double mean = (mean1[0] + mean2[0]) / 2;
+               long[] sizes = new long[ranges.length];
+               for (int i = 0; i < ranges.length; i++) {
+                       sizes[i] = ranges[i].getSize();
+               }
+               
+               // calculate global means
+               double totalMeanX = 0;
+               double totalMeanY = 0;
+               int totalCount = 0;
+               for (int i = 0; i < mean1.length; i++) {
+                       totalMeanX += mean1[i] * sizes[i];
+                       totalMeanY += mean2[i] * sizes[i];
+                       totalCount += sizes[i];
+               }
+
+               totalMeanX /= totalCount;
+               totalMeanY /= totalCount;
+
+               // calculate global covariance
+               double cov = 0;
+               for (int i = 0; i < covValues.length; i++) {
+                       cov += (sizes[i] - 1) * covValues[i];
+                       cov += sizes[i] * (mean1[i] - totalMeanX) * (mean2[i] - 
totalMeanY);
+               }
+               return cov / (totalCount - 1); // adjusting for degrees of 
freedom
+       }
 
-               for(int i = 0; i < covValues.length - 1; i++) {
-                       long size2 = ranges[i+1].getSize();
-                       double nextMean = (mean1[i+1] + mean2[i+1]) / 2;
-                       double newMean = (size1 * mean + size2 * nextMean) / 
(size1 + size2);
+       private static double aggWeightedCov(Double[] covValues, Double[] 
mean1, Double[] mean2, Double[] weights) {
+               // calculate global weighted means
+               double totalWeightedMeanX = 0;
+               double totalWeightedMeanY = 0;
+               double totalWeight = 0;
+               for (int i = 0; i < mean1.length; i++) {
+                       totalWeight += weights[i];
+                       totalWeightedMeanX += mean1[i] * weights[i];
+                       totalWeightedMeanY += mean2[i] * weights[i];
+               }
 
-                       cov = (size1 * cov + size2 * covValues[i+1] + size1 * 
(mean - newMean) * (mean - newMean)
-                               + size2 * (nextMean - newMean) * (nextMean - 
newMean)) / (size1 + size2);
+               totalWeightedMeanX /= totalWeight;
+               totalWeightedMeanY /= totalWeight;
 
-                       mean = newMean;
-                       size1 = size1 + size2;
+               // calculate global weighted covariance
+               double cov = 0;
+               for (int i = 0; i < covValues.length; i++) {
+                       cov += (weights[i] - 1) * covValues[i];
+                       cov += weights[i] * (mean1[i] - totalWeightedMeanX) * 
(mean2[i] - totalWeightedMeanY);
+               }
+               return cov / (totalWeight - 1); // adjusting for degrees of 
freedom
+       }
+
+       private Future<FederatedResponse>[] processMean(MatrixObject mo1, 
MatrixLineagePair moLin3, int var){
+               String[] parts = instString.split("°");
+               Future<FederatedResponse>[] meanTmp = null;
+               if (moLin3 == null) {
+                       String meanInstr = instString.replace(getOpcode(), 
getOpcode().replace("cov", "uamean"));
+                       meanInstr = meanInstr.replace((var == 0 ? parts[2] : 
parts[3]) + "°", "");
+                       meanInstr = meanInstr.replace(parts[4], 
parts[4].replace("FP64", "STRING°16"));
+
+                       //create federated commands for aggregation
+                       FederatedRequest meanFr1 = 
FederationUtils.callInstruction(meanInstr, output,
+                               new CPOperand[]{var == 0 ? input2 : input1}, 
new long[]{mo1.getFedMapping().getID()});
+                       FederatedRequest meanFr2 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, meanFr1.getID());
+                       FederatedRequest meanFr3 = 
mo1.getFedMapping().cleanup(getTID(), meanFr1.getID());
+
+                       meanTmp = mo1.getFedMapping().execute(getTID(), 
meanFr1, meanFr2, meanFr3);
+               }
+               else {
+                       // multiply input X by weights W element-wise
+                       String multOutput = incrementVar(parts[4], 1);
+                       String multInstr = instString
+                               .replace(getOpcode(), 
getOpcode().replace("cov", "*"))
+                               .replace((var == 0 ? parts[2] : parts[3]) + 
"°", "")
+                               .replace(parts[5], multOutput);
+
+                       CPOperand multOutputCPOp = new CPOperand(
+                               multOutput.substring(0, 
multOutput.indexOf("·")),
+                               mo1.getValueType(), mo1.getDataType()
+                       );
+
+                       FederatedRequest multFr = 
FederationUtils.callInstruction(
+                               multInstr,
+                               multOutputCPOp,
+                               new CPOperand[]{var == 0 ? input2 : input1, 
input3},
+                               new long[]{mo1.getFedMapping().getID(), 
moLin3.getFedMapping().getID()}
+                       );
+
+                       // calculate the sum of the obtained vector
+                       String[] partsMult = multInstr.split("°");
+                       String sumInstr1Output = incrementVar(multOutput, 1)
+                               .replace("m", "")
+                               .replace("MATRIX", "SCALAR");
+                       String sumInstr1 = multInstr
+                               .replace(partsMult[1], "uak+")
+                               .replace(partsMult[3] + "°", "")
+                               .replace(partsMult[4], sumInstr1Output)
+                               .replace(partsMult[2], multOutput);
+
+                       FederatedRequest sumFr1 = 
FederationUtils.callInstruction(
+                               sumInstr1,
+                               new CPOperand(
+                                       sumInstr1Output.substring(0, 
sumInstr1Output.indexOf("·")),
+                                       output.getValueType(), 
output.getDataType()
+                               ),
+                               new CPOperand[]{multOutputCPOp},
+                               new long[]{multFr.getID()}
+                       );
+
+                       // calculate the sum of weights
+                       String[] partsSum1 = sumInstr1.split("°");
+                       String sumInstr2Output = incrementVar(sumInstr1Output, 
1);
+                       String sumInstr2 = sumInstr1
+                               .replace(partsSum1[2], parts[4])
+                               .replace(partsSum1[3], sumInstr2Output);
+
+                       FederatedRequest sumFr2 = 
FederationUtils.callInstruction(
+                               sumInstr2,
+                               new CPOperand(
+                                       sumInstr2Output.substring(0, 
sumInstr2Output.indexOf("·")),
+                                       output.getValueType(), 
output.getDataType()
+                               ),
+                               new CPOperand[]{input3},
+                               new long[]{moLin3.getFedMapping().getID()}
+                       );
+
+                       // divide sum(X*W) by sum(W)
+                       String[] partsSum2 = sumInstr2.split("°");
+                       String divInstrOutput = incrementVar(sumInstr2Output, 
1);
+                       String divInstrInput1 = 
partsSum2[2].replace(partsSum2[2], sumInstr1Output + "·false");
+                       String divInstrInput2 = 
partsSum2[3].replace(partsSum2[3], sumInstr2Output + "·false");
+                       String divInstr = partsSum2[0] + "°" + 
partsSum2[1].replace("uak+", "/") + "°" +
+                                       divInstrInput1 + "°" + divInstrInput2 + 
"°" + divInstrOutput + "°" + partsSum2[4];
+
+                       FederatedRequest divFr1 = 
FederationUtils.callInstruction(
+                               divInstr,
+                               new CPOperand(
+                                       divInstrOutput.substring(0, 
divInstrOutput.indexOf("·")),
+                                       output.getValueType(), 
output.getDataType()
+                               ),
+                               new CPOperand[]{
+                                       new CPOperand(
+                                               sumInstr1Output.substring(0, 
sumInstr1Output.indexOf("·")),
+                                               output.getValueType(), 
output.getDataType(), output.isLiteral()
+                                       ),
+                                       new CPOperand(
+                                               sumInstr2Output.substring(0, 
sumInstr2Output.indexOf("·")),
+                                               output.getValueType(), 
output.getDataType(), output.isLiteral()
+                                       )
+                               },
+                               new long[]{sumFr1.getID(), sumFr2.getID()}
+                       );
+                       FederatedRequest divFr2 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, divFr1.getID());
+                       FederatedRequest divFr3 = 
mo1.getFedMapping().cleanup(getTID(), multFr.getID(), sumFr1.getID(), 
sumFr2.getID(), divFr1.getID(), divFr2.getID());
+
+                       meanTmp = mo1.getFedMapping().execute(getTID(), multFr, 
sumFr1, sumFr2, divFr1, divFr2, divFr3);
                }
-               return cov;
+               return meanTmp;
        }
 
-       private Future<FederatedResponse>[] processMean(MatrixObject mo1, int 
var){
+       private Future<FederatedResponse>[] processMean(MatrixObject mo1, int 
var, long weightsID){
                String[] parts = instString.split("°");
-               String meanInstr = instString.replace(getOpcode(), 
getOpcode().replace("cov", "uamean"));
-               meanInstr = meanInstr.replace((var == 0 ? parts[2] : parts[3]) 
+ "°", "");
-               meanInstr = meanInstr.replace(parts[4], 
parts[4].replace("FP64", "STRING°16"));
                Future<FederatedResponse>[] meanTmp = null;
 
-               //create federated commands for aggregation
-               FederatedRequest meanFr1 = 
FederationUtils.callInstruction(meanInstr, output,
-                       new CPOperand[]{var == 0 ? input2 : input1}, new 
long[]{mo1.getFedMapping().getID()});
-               FederatedRequest meanFr2 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, meanFr1.getID());
-               FederatedRequest meanFr3 = 
mo1.getFedMapping().cleanup(getTID(), meanFr1.getID());
-               meanTmp = mo1.getFedMapping().execute(getTID(), meanFr1, 
meanFr2, meanFr3);
+               // multiply input X by weights W element-wise
+               String multOutput = (var == 0 ? incrementVar(parts[2], 5) : 
incrementVar(parts[3], 3));
+               String multInstr = instString
+                       .replace(getOpcode(), getOpcode().replace("cov", "*"))
+                       .replace((var == 0 ? parts[2] : parts[3]) + "°", "")
+                       .replace(parts[4], String.valueOf(weightsID) + 
"·MATRIX·FP64")
+                       .replace(parts[5], multOutput);
+
+               CPOperand multOutputCPOp = new CPOperand(
+                       multOutput.substring(0, multOutput.indexOf("·")),
+                       mo1.getValueType(), mo1.getDataType()
+               );
+
+               FederatedRequest multFr = FederationUtils.callInstruction(
+                       multInstr,
+                       multOutputCPOp,
+                       new CPOperand[]{var == 0 ? input2 : input1, input3},
+                       new long[]{mo1.getFedMapping().getID(), weightsID}
+               );
+
+               // calculate the sum of the obtained vector
+               String[] partsMult = multInstr.split("°");
+               String sumInstr1Output = incrementVar(multOutput, 1)
+                       .replace("m", "")
+                       .replace("MATRIX", "SCALAR");
+               String sumInstr1 = multInstr
+                       .replace(partsMult[1], "uak+")
+                       .replace(partsMult[3] + "°", "")
+                       .replace(partsMult[4], sumInstr1Output)
+                       .replace(partsMult[2], multOutput);
+
+               FederatedRequest sumFr1 = FederationUtils.callInstruction(
+                       sumInstr1,
+                       new CPOperand(
+                               sumInstr1Output.substring(0, 
sumInstr1Output.indexOf("·")),
+                               output.getValueType(), output.getDataType()
+                       ),
+                       new CPOperand[]{multOutputCPOp},
+                       new long[]{multFr.getID()}
+               );
+
+               // calculate the sum of weights
+               String[] partsSum1 = sumInstr1.split("°");
+               String sumInstr2Output = incrementVar(sumInstr1Output, 1);
+               String sumInstr2 = sumInstr1
+                       .replace(partsSum1[2], String.valueOf(weightsID) + 
"·MATRIX·FP64")
+                       .replace(partsSum1[3], sumInstr2Output);
+
+               FederatedRequest sumFr2 = FederationUtils.callInstruction(
+                       sumInstr2,
+                       new CPOperand(
+                               sumInstr2Output.substring(0, 
sumInstr2Output.indexOf("·")),
+                               output.getValueType(), output.getDataType()
+                       ),
+                       new CPOperand[]{input3},
+                       new long[]{weightsID}
+               );
+
+               // divide sum(X*W) by sum(W)
+               String[] partsSum2 = sumInstr2.split("°");
+               String divInstrOutput = incrementVar(sumInstr2Output, 1);
+               String divInstrInput1 = partsSum2[2].replace(partsSum2[2], 
sumInstr1Output + "·false");
+               String divInstrInput2 = partsSum2[3].replace(partsSum2[3], 
sumInstr2Output + "·false");
+               String divInstr = partsSum2[0] + "°" + 
partsSum2[1].replace("uak+", "/") + "°" +
+                               divInstrInput1 + "°" + divInstrInput2 + "°" + 
divInstrOutput + "°" + partsSum2[4];
+
+               FederatedRequest divFr1 = FederationUtils.callInstruction(
+                       divInstr,
+                       new CPOperand(
+                               divInstrOutput.substring(0, 
divInstrOutput.indexOf("·")),
+                               output.getValueType(), output.getDataType()
+                       ),
+                       new CPOperand[]{
+                               new CPOperand(
+                                       sumInstr1Output.substring(0, 
sumInstr1Output.indexOf("·")),
+                                       output.getValueType(), 
output.getDataType(), output.isLiteral()
+                               ),
+                               new CPOperand(
+                                       sumInstr2Output.substring(0, 
sumInstr2Output.indexOf("·")),
+                                       output.getValueType(), 
output.getDataType(), output.isLiteral()
+                               )
+                       },
+                       new long[]{sumFr1.getID(), sumFr2.getID()}
+               );
+               FederatedRequest divFr2 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, divFr1.getID());
+               FederatedRequest divFr3 = mo1.getFedMapping().cleanup(getTID(), 
multFr.getID(), sumFr1.getID(), sumFr2.getID(), divFr1.getID(), divFr2.getID());
+
+               meanTmp = mo1.getFedMapping().execute(getTID(), multFr, sumFr1, 
sumFr2, divFr1, divFr2, divFr3);
                return meanTmp;
        }
 
+       private Future<FederatedResponse>[] getWeightsSum(MatrixLineagePair 
moLin3, long weightsID, String instString, FederationMap fedMap) {
+               Future<FederatedResponse>[] weightsSumTmp = null;
+
+               String[] parts = instString.split("°");
+               if (!instString.contains("pREADW")) {
+                       String sumInstr = "CP°uak+°" + parts[4] + "°" + 
parts[5] + "°" + parts[6];
+
+                       FederatedRequest sumFr = 
FederationUtils.callInstruction(
+                               sumInstr,
+                               new CPOperand(
+                                       parts[5].substring(0, 
parts[5].indexOf("·")),
+                                       output.getValueType(),
+                                       output.getDataType()
+                               ),
+                               new CPOperand[]{input3},
+                               new long[]{weightsID}
+                       );
+                       FederatedRequest sumFr2 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, sumFr.getID());
+                       FederatedRequest sumFr3 = 
moLin3.getFedMapping().cleanup(getTID(), sumFr.getID());
+
+                       weightsSumTmp = fedMap.execute(getTID(), sumFr, sumFr2, 
sumFr3);
+               }
+               else {
+                       String sumInstr = "CP°uak+°" + 
String.valueOf(weightsID) + "·MATRIX·FP64" + "°" + parts[5] + "°" + parts[6];
+                       FederatedRequest sumFr = 
FederationUtils.callInstruction(
+                               sumInstr,
+                               new CPOperand(
+                                       parts[5].substring(0, 
parts[5].indexOf("·")),
+                                       output.getValueType(),
+                                       output.getDataType()
+                               ),
+                               new CPOperand[]{input3},
+                               new long[]{weightsID}
+                       );
+                       FederatedRequest sumFr2 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, sumFr.getID());
+                       FederatedRequest sumFr3 = fedMap.cleanup(getTID(), 
sumFr.getID());
+
+                       weightsSumTmp = fedMap.execute(getTID(), sumFr, sumFr2, 
sumFr3);
+               }
+               return weightsSumTmp;
+       }
+
+       private static String incrementVar(String str, int inc) {
+               StringBuilder strOut = new StringBuilder(str);
+               Pattern pattern = Pattern.compile("\\d+");
+               Matcher matcher = pattern.matcher(strOut);
+               if (matcher.find()) {
+                       int num = Integer.parseInt(matcher.group()) + inc;
+                       int start = matcher.start();
+                       int end = matcher.end();
+                       strOut.replace(start, end, String.valueOf(num));
+               }
+               return strOut.toString();
+       }
+
        private static class COVFunction extends FederatedUDF {
 
                private static final long serialVersionUID = 
-501036588060113499L;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
index 136cdde7f9..4f23c641fa 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
@@ -43,6 +43,9 @@ public class FederatedCovarianceTest extends AutomatedTestBase {
 
        private final static String TEST_NAME1 = "FederatedCovarianceTest";
        private final static String TEST_NAME2 = 
"FederatedCovarianceAlignedTest";
+       private final static String TEST_NAME3 = 
"FederatedCovarianceWeightedTest";
+       private final static String TEST_NAME4 = 
"FederatedCovarianceAlignedWeightedTest";
+       private final static String TEST_NAME5 = 
"FederatedCovarianceAllAlignedWeightedTest";
        private final static String TEST_DIR = "functions/federated/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedCovarianceTest.class.getSimpleName() + "/";
 
@@ -64,19 +67,37 @@ public class FederatedCovarianceTest extends 
AutomatedTestBase {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
                addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
+               addTestConfiguration(TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"}));
+               addTestConfiguration(TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S.scalar"}));
+               addTestConfiguration(TEST_NAME5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S.scalar"}));
        }
 
+       /** Runs runCovarianceTest in single-node mode with alignedFedInput = false. */
        @Test
        public void testCovCP() {
-               runCovTest(ExecMode.SINGLE_NODE, false);
+               runCovarianceTest(ExecMode.SINGLE_NODE, false);
        }
 
+       /** Runs runCovarianceTest in single-node mode with alignedFedInput = true. */
        @Test
        public void testAlignedCovCP() {
-               runCovTest(ExecMode.SINGLE_NODE, true);
+               runCovarianceTest(ExecMode.SINGLE_NODE, true);
        }
 
-       private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
+       /** Weighted cov in single-node mode: X federated, Y and W local matrices. */
+       @Test
+       public void testCovarianceWeightedCP() {
+               final boolean alignedInput = false;
+               final boolean alignedWeights = false;
+               runWeightedCovarianceTest(ExecMode.SINGLE_NODE, alignedInput, alignedWeights);
+       }
+
+       /** Weighted cov in single-node mode: X and Y federated with aligned rows, W local. */
+       @Test
+       public void testAlignedCovarianceWeightedCP() {
+               final boolean alignedInput = true;
+               final boolean alignedWeights = false;
+               runWeightedCovarianceTest(ExecMode.SINGLE_NODE, alignedInput, alignedWeights);
+       }
+
+       /** Weighted cov in single-node mode: X, Y and W all federated with aligned rows. */
+       @Test
+       public void testAllAlignedCovarianceWeightedCP() {
+               final boolean alignedInput = true;
+               final boolean alignedWeights = true;
+               runWeightedCovarianceTest(ExecMode.SINGLE_NODE, alignedInput, alignedWeights);
+       }
+
+       private void runCovarianceTest(ExecMode execMode, boolean 
alignedFedInput) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                ExecMode platformOld = rtplatform;
 
@@ -190,4 +211,176 @@ public class FederatedCovarianceTest extends 
AutomatedTestBase {
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
                }
        }
+
+       /**
+        * Runs a weighted covariance {@code cov(X, Y, W)} where X is split row-wise into four
+        * partitions held by four local federated workers, and compares the result with a
+        * reference script that reads the same partitions as local matrices.
+        *
+        * @param execMode       execution mode for the DML runs
+        * @param alignedInput   if true, Y is federated with the same row partitioning as X;
+        *                       otherwise Y is a local rows x cols matrix
+        * @param alignedWeights only relevant if alignedInput is true: if true, W is federated with
+        *                       the same partitioning as X and Y; otherwise W is a local matrix
+        */
+       private void runWeightedCovarianceTest(ExecMode execMode, boolean alignedInput, boolean alignedWeights) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               ExecMode platformOld = rtplatform;
+
+               if(rtplatform == ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               String TEST_NAME = !alignedInput ? TEST_NAME3 : (!alignedWeights ? TEST_NAME4 : TEST_NAME5);
+               getAndLoadTestConfiguration(TEST_NAME);
+
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // every worker holds one r x c row partition of the rows x cols inputs
+               int r = rows / 4;
+               int c = cols;
+               MatrixCharacteristics mcPart = new MatrixCharacteristics(r, c, blocksize, r * c);
+               // full-size local inputs are generated with sparsity 1, so nnz covers all rows * cols cells
+               MatrixCharacteristics mcFull = new MatrixCharacteristics(rows, cols, blocksize, (long) rows * cols);
+
+               // X partitions X1..X4 with values in [1, 5]
+               writeRandomParts("X", r, c, 1, 5, mcPart, 3, 7, 8, 9);
+
+               // full-size weights with values in [0, 1]; seed distinct from the X partition seeds
+               double[][] W = getRandomMatrix(rows, c, 0, 1, 1, 23);
+               writeInputMatrixWithMTD("W", W, false, mcFull);
+
+               // empty script name because we don't execute any script, just start the workers
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+
+               Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
+               Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
+               Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
+               Process t4 = startLocalFedWorker(port4);
+
+               try {
+                       if(!isAlive(t1, t2, t3, t4))
+                               throw new RuntimeException("Failed starting federated worker");
+
+                       rtplatform = execMode;
+                       if(rtplatform == ExecMode.SPARK)
+                               DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+                       TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+                       loadTestConfiguration(config);
+
+                       if(alignedInput) {
+                               // Y partitions aligned with X; the seeds must differ from X's, otherwise Y == X
+                               writeRandomParts("Y", r, c, 1, 5, mcPart, 13, 17, 18, 19);
+
+                               if(!alignedWeights) {
+                                       // reference run on local matrices
+                                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                                       programArgs = new String[] {"-stats", "100", "-args",
+                                               input("X1"), input("X2"), input("X3"), input("X4"),
+                                               input("Y1"), input("Y2"), input("Y3"), input("Y4"),
+                                               input("W"), expected("S")};
+                                       runTest(null);
+
+                                       // federated X and Y, local W
+                                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                                       programArgs = new String[] {"-stats", "100", "-nvargs",
+                                               "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                                               "in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+                                               "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                                               "in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
+                                               "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+                                               "in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")),
+                                               "in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
+                                               "in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")),
+                                               "in_W1=" + input("W"), "rows=" + rows, "cols=" + cols, "out_S=" + output("S")};
+                                       runTest(null);
+                               }
+                               else {
+                                       // weight partitions aligned with X and Y
+                                       writeRandomParts("W", r, c, 0, 1, mcPart, 23, 27, 28, 29);
+
+                                       // reference run on local matrices
+                                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                                       programArgs = new String[] {"-stats", "100", "-args",
+                                               input("X1"), input("X2"), input("X3"), input("X4"),
+                                               input("Y1"), input("Y2"), input("Y3"), input("Y4"),
+                                               input("W1"), input("W2"), input("W3"), input("W4"), expected("S")};
+                                       runTest(null);
+
+                                       // federated X, Y and W
+                                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                                       programArgs = new String[] {"-stats", "100", "-nvargs",
+                                               "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                                               "in_Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+                                               "in_W1=" + TestUtils.federatedAddress(port1, input("W1")),
+                                               "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                                               "in_Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
+                                               "in_W2=" + TestUtils.federatedAddress(port2, input("W2")),
+                                               "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+                                               "in_Y3=" + TestUtils.federatedAddress(port3, input("Y3")),
+                                               "in_W3=" + TestUtils.federatedAddress(port3, input("W3")),
+                                               "in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
+                                               "in_Y4=" + TestUtils.federatedAddress(port4, input("Y4")),
+                                               "in_W4=" + TestUtils.federatedAddress(port4, input("W4")),
+                                               "rows=" + rows, "cols=" + cols, "out_S=" + output("S")};
+                                       runTest(null);
+                               }
+                       }
+                       else {
+                               // full-size local Y; seed distinct from X and W
+                               double[][] Y = getRandomMatrix(rows, c, 1, 5, 1, 13);
+                               writeInputMatrixWithMTD("Y", Y, false, mcFull);
+
+                               // reference run on local matrices
+                               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                               programArgs = new String[] {"-stats", "100", "-args",
+                                       input("X1"), input("X2"), input("X3"), input("X4"),
+                                       input("Y"), input("W"), expected("S")};
+                               runTest(null);
+
+                               // federated X, local Y and W
+                               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                               programArgs = new String[] {"-stats", "100", "-nvargs",
+                                       "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                                       "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                                       "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+                                       "in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
+                                       "in_W1=" + input("W"), "Y=" + input("Y"),
+                                       "rows=" + rows, "cols=" + cols, "out_S=" + output("S")};
+                               runTest(null);
+                       }
+
+                       // compare via files
+                       compareResults(1e-2);
+                       Assert.assertTrue(heavyHittersContainsString("fed_cov"));
+               }
+               finally {
+                       TestUtils.shutdownThreads(t1, t2, t3, t4);
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+
+       /**
+        * Writes one random r x c matrix (values in [min, max], sparsity 1) per seed, named
+        * prefix1, prefix2, ... in seed order, each with the given metadata.
+        */
+       private void writeRandomParts(String prefix, int r, int c, double min, double max,
+               MatrixCharacteristics mc, long... seeds) {
+               for(int i = 0; i < seeds.length; i++)
+                       writeInputMatrixWithMTD(prefix + (i + 1), getRandomMatrix(r, c, min, max, 1, seeds[i]), false, mc);
+       }
 }
diff --git 
a/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTest.dml
 
b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTest.dml
new file mode 100644
index 0000000000..da9db2f4de
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTest.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X and Y: $rows x $cols matrices, each split row-wise into four equal
+# partitions on four workers; each partition is given by a begin and an
+# end (row, col) pair.
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+  ranges=list(
+    list(0, 0),         list($rows/4, $cols),
+    list($rows/4, 0),   list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols),
+    list(3*$rows/4, 0), list($rows, $cols)));
+
+Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
+  ranges=list(
+    list(0, 0),         list($rows/4, $cols),
+    list($rows/4, 0),   list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols),
+    list(3*$rows/4, 0), list($rows, $cols)));
+
+# weights are a plain local matrix
+W = read($in_W1);
+
+covXY = cov(X, Y, W);
+write(covXY, $out_S);
\ No newline at end of file
diff --git 
a/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTestReference.dml
 
b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTestReference.dml
new file mode 100644
index 0000000000..ee4062f7e6
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTestReference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Local reference: reassemble X and Y from their four row partitions,
+# read the full weights matrix directly.
+X = rbind(read($1), read($2),
+          read($3), read($4));
+Y = rbind(read($5), read($6),
+          read($7), read($8));
+W = read($9);
+
+covXY = cov(X, Y, W);
+write(covXY, $10);
\ No newline at end of file
diff --git 
a/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTest.dml
 
b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTest.dml
new file mode 100644
index 0000000000..22029de451
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTest.dml
@@ -0,0 +1,38 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X, Y and W: $rows x $cols matrices, each split row-wise into the same
+# four equal partitions on four workers; each partition is given by a
+# begin and an end (row, col) pair.
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+  ranges=list(
+    list(0, 0),         list($rows/4, $cols),
+    list($rows/4, 0),   list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols),
+    list(3*$rows/4, 0), list($rows, $cols)));
+
+Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
+  ranges=list(
+    list(0, 0),         list($rows/4, $cols),
+    list($rows/4, 0),   list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols),
+    list(3*$rows/4, 0), list($rows, $cols)));
+
+W = federated(addresses=list($in_W1, $in_W2, $in_W3, $in_W4),
+  ranges=list(
+    list(0, 0),         list($rows/4, $cols),
+    list($rows/4, 0),   list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols),
+    list(3*$rows/4, 0), list($rows, $cols)));
+
+covXY = cov(X, Y, W);
+write(covXY, $out_S);
\ No newline at end of file
diff --git 
a/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTestReference.dml
 
b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTestReference.dml
new file mode 100644
index 0000000000..10c18f5a33
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTestReference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Local reference: reassemble X, Y and W from their four row partitions.
+X = rbind(read($1), read($2),
+          read($3), read($4));
+Y = rbind(read($5), read($6),
+          read($7), read($8));
+W = rbind(read($9), read($10),
+          read($11), read($12));
+
+covXY = cov(X, Y, W);
+write(covXY, $13);
\ No newline at end of file
diff --git 
a/src/test/scripts/functions/federated/FederatedCovarianceWeightedTest.dml 
b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTest.dml
new file mode 100644
index 0000000000..3ba2d5b15f
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X: $rows x $cols matrix split row-wise into four equal partitions on
+# four workers; each partition is given by a begin and an end (row, col) pair.
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+  ranges=list(
+    list(0, 0),         list($rows/4, $cols),
+    list($rows/4, 0),   list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols),
+    list(3*$rows/4, 0), list($rows, $cols)));
+
+# Y and the weights are plain local matrices
+Y = read($Y);
+W = read($in_W1);
+
+covXY = cov(X, Y, W);
+write(covXY, $out_S);
\ No newline at end of file
diff --git 
a/src/test/scripts/functions/federated/FederatedCovarianceWeightedTestReference.dml
 
b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTestReference.dml
new file mode 100644
index 0000000000..db1dc7c526
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTestReference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Local reference: reassemble X from its four row partitions,
+# read Y and the weights directly.
+X = rbind(read($1), read($2),
+          read($3), read($4));
+Y = read($5);
+W = read($6);
+
+covXY = cov(X, Y, W);
+write(covXY, $7);
\ No newline at end of file


Reply via email to