This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 1d1b00825c [SYSTEMDS-3789] Fix federated covariance (missing weighted case)
1d1b00825c is described below
commit 1d1b00825c7160e2cbb36307c5fa8a96c9779c44
Author: Grigorii <[email protected]>
AuthorDate: Sun Nov 17 12:32:18 2024 +0100
[SYSTEMDS-3789] Fix federated covariance (missing weighted case)
Closes #2137.
---
.../instructions/fed/CovarianceFEDInstruction.java | 413 ++++++++++++++++++---
.../primitives/part5/FederatedCovarianceTest.java | 199 +++++++++-
.../FederatedCovarianceAlignedWeightedTest.dml | 35 ++
...ratedCovarianceAlignedWeightedTestReference.dml | 27 ++
.../FederatedCovarianceAllAlignedWeightedTest.dml | 38 ++
...edCovarianceAllAlignedWeightedTestReference.dml | 27 ++
.../federated/FederatedCovarianceWeightedTest.dml | 31 ++
.../FederatedCovarianceWeightedTestReference.dml | 27 ++
8 files changed, 739 insertions(+), 58 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
index 9e79b024a2..719fd91588 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
@@ -24,8 +24,9 @@ import java.util.List;
import java.util.Optional;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
-import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -103,45 +104,71 @@ public class CovarianceFEDInstruction extends BinaryFEDInstruction {
private void processAlignedFedCov(ExecutionContext ec, MatrixObject
mo1, MatrixObject mo2,
MatrixLineagePair moLin3) {
FederatedRequest fr1;
- if(moLin3 == null)
+ if(moLin3 == null) {
fr1 = FederationUtils.callInstruction(instString,
output,
new CPOperand[]{input1, input2}, new
long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});
- else
+ }
+ else {
fr1 = FederationUtils.callInstruction(instString,
output,
new CPOperand[]{input1, input2, input3}, new
long[]{mo1.getFedMapping().getID(),
mo2.getFedMapping().getID(),
moLin3.getFedMapping().getID()});
-
+ }
+
FederatedRequest fr2 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(),
fr1.getID());
Future<FederatedResponse>[] covTmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
//means
- Future<FederatedResponse>[] meanTmp1 = processMean(mo1, 0);
- Future<FederatedResponse>[] meanTmp2 = processMean(mo2, 1);
+ Future<FederatedResponse>[] meanTmp1 = processMean(mo1, moLin3,
0);
+ Future<FederatedResponse>[] meanTmp2 = processMean(mo2, moLin3,
1);
- ImmutableTriple<Double[], Double[], Double[]> res =
getResponses(covTmp, meanTmp1, meanTmp2);
+ Double[] cov = getResponses(covTmp);
+ Double[] mean1 = getResponses(meanTmp1);
+ Double[] mean2 = getResponses(meanTmp2);
- double result = aggCov(res.left, res.middle, res.right,
mo1.getFedMapping().getFederatedRanges());
- ec.setVariable(output.getName(), new DoubleObject(result));
+ if (moLin3 == null) {
+ double result = aggCov(cov, mean1, mean2,
mo1.getFedMapping().getFederatedRanges());
+ ec.setVariable(output.getName(), new
DoubleObject(result));
+ }
+ else {
+ Future<FederatedResponse>[] weightsSumTmp =
getWeightsSum(moLin3, moLin3.getFedMapping().getID(), instString,
moLin3.getFedMapping());
+ Double[] weights = getResponses(weightsSumTmp);
+
+ double result = aggWeightedCov(cov, mean1, mean2,
weights);
+ ec.setVariable(output.getName(), new
DoubleObject(result));
+ }
}
private void processFedCovWeights(ExecutionContext ec, MatrixObject
mo1, MatrixObject mo2,
MatrixLineagePair moLin3) {
+
+ FederatedRequest[] fr1 =
mo1.getFedMapping().broadcastSliced(moLin3, false);
- FederatedRequest[] fr2 =
mo1.getFedMapping().broadcastSliced(moLin3, false);
- FederatedRequest fr1 =
FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1, input2}, new
long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});
- FederatedRequest fr3 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
- FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(),
fr1.getID());
- Future<FederatedResponse>[] covTmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2[0], fr3, fr4);
+ // the original instruction encodes weights as "pREADW", change
to the new ID
+ String[] parts = instString.split("°");
+ String covInstr = instString.replace(parts[4],
String.valueOf(fr1[0].getID()) + "·MATRIX·FP64");
+
+ FederatedRequest fr2 = FederationUtils.callInstruction(
+ covInstr, output,
+ new CPOperand[]{input1, input2, input3},
+ new long[]{mo1.getFedMapping().getID(),
mo2.getFedMapping().getID(), fr1[0].getID()}
+ );
+ FederatedRequest fr3 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(),
fr2.getID());
+ Future<FederatedResponse>[] covTmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
//means
- Future<FederatedResponse>[] meanTmp1 = processMean(mo1, 0);
- Future<FederatedResponse>[] meanTmp2 = processMean(mo2, 1);
+ Future<FederatedResponse>[] meanTmp1 = processMean(mo1, 0,
fr1[0].getID());
+ Future<FederatedResponse>[] meanTmp2 = processMean(mo2, 1,
fr1[0].getID());
- ImmutableTriple<Double[], Double[], Double[]> res =
getResponses(covTmp, meanTmp1, meanTmp2);
+ Double[] cov = getResponses(covTmp);
+ Double[] mean1 = getResponses(meanTmp1);
+ Double[] mean2 = getResponses(meanTmp2);
- double result = aggCov(res.left, res.middle, res.right,
mo1.getFedMapping().getFederatedRanges());
+ Future<FederatedResponse>[] weightsSumTmp =
getWeightsSum(moLin3, fr1[0].getID(), instString, mo1.getFedMapping());
+ Double[] weights = getResponses(weightsSumTmp);
+
+ double result = aggWeightedCov(cov, mean1, mean2, weights);
ec.setVariable(output.getName(), new DoubleObject(result));
}
@@ -174,11 +201,17 @@ public class CovarianceFEDInstruction extends BinaryFEDInstruction {
}
// with weights
else {
- MatrixBlock wtBlock =
ec.getMatrixInput(input2.getName());
- response =
data.executeFederatedOperation(new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new
CovarianceFEDInstruction.COVWeightsFunction(data.getVarID(),
-
mb.slice(range.getBeginDimsInt()[0], range.getEndDimsInt()[0] - 1),
- cop, wtBlock))).get();
+ MatrixBlock wtBlock =
ec.getMatrixInput(input3.getName());
+ response =
data.executeFederatedOperation(
+ new FederatedRequest(
+
FederatedRequest.RequestType.EXEC_UDF, -1,
+ new
CovarianceFEDInstruction.COVWeightsFunction(
+ data.getVarID(),
+
mb.slice(range.getBeginDimsInt()[0], range.getEndDimsInt()[0] - 1),
+ cop,
wtBlock.slice(range.getBeginDimsInt()[0], range.getEndDimsInt()[0] - 1)
+ )
+ )
+ ).get();
}
if(!response.isSuccessful())
@@ -202,59 +235,329 @@ public class CovarianceFEDInstruction extends BinaryFEDInstruction {
}
}
- private static ImmutableTriple<Double[], Double[], Double[]>
getResponses(Future<FederatedResponse>[] covFfr, Future<FederatedResponse>[]
mean1Ffr, Future<FederatedResponse>[] mean2Ffr) {
- Double[] cov = new Double[covFfr.length];
- Double[] mean1 = new Double[mean1Ffr.length];
- Double[] mean2 = new Double[mean2Ffr.length];
- IntStream.range(0, covFfr.length).forEach(i -> {
+ private static Double[] getResponses(Future<FederatedResponse>[] ffr) {
+ Double[] fr = new Double[ffr.length];
+ IntStream.range(0, fr.length).forEach(i -> {
try {
- cov[i] = ((ScalarObject)
covFfr[i].get().getData()[0]).getDoubleValue();
- mean1[i] = ((ScalarObject)
mean1Ffr[1].get().getData()[0]).getDoubleValue();
- mean2[i] = ((ScalarObject)
mean2Ffr[2].get().getData()[0]).getDoubleValue();
+ fr[i] = ((ScalarObject)
ffr[i].get().getData()[0]).getDoubleValue();
}
catch(Exception e) {
throw new
DMLRuntimeException("CovarianceFEDInstruction: incorrect means or cov.");
}
});
- return new ImmutableTriple<>(cov, mean1, mean2);
+ return fr;
}
private static double aggCov(Double[] covValues, Double[] mean1,
Double[] mean2, FederatedRange[] ranges) {
- double cov = covValues[0];
- long size1 = ranges[0].getSize();
- double mean = (mean1[0] + mean2[0]) / 2;
+ long[] sizes = new long[ranges.length];
+ for (int i = 0; i < ranges.length; i++) {
+ sizes[i] = ranges[i].getSize();
+ }
+
+ // calculate global means
+ double totalMeanX = 0;
+ double totalMeanY = 0;
+ int totalCount = 0;
+ for (int i = 0; i < mean1.length; i++) {
+ totalMeanX += mean1[i] * sizes[i];
+ totalMeanY += mean2[i] * sizes[i];
+ totalCount += sizes[i];
+ }
+
+ totalMeanX /= totalCount;
+ totalMeanY /= totalCount;
+
+ // calculate global covariance
+ double cov = 0;
+ for (int i = 0; i < covValues.length; i++) {
+ cov += (sizes[i] - 1) * covValues[i];
+ cov += sizes[i] * (mean1[i] - totalMeanX) * (mean2[i] -
totalMeanY);
+ }
+ return cov / (totalCount - 1); // adjusting for degrees of
freedom
+ }
- for(int i = 0; i < covValues.length - 1; i++) {
- long size2 = ranges[i+1].getSize();
- double nextMean = (mean1[i+1] + mean2[i+1]) / 2;
- double newMean = (size1 * mean + size2 * nextMean) /
(size1 + size2);
+ private static double aggWeightedCov(Double[] covValues, Double[]
mean1, Double[] mean2, Double[] weights) {
+ // calculate global weighted means
+ double totalWeightedMeanX = 0;
+ double totalWeightedMeanY = 0;
+ double totalWeight = 0;
+ for (int i = 0; i < mean1.length; i++) {
+ totalWeight += weights[i];
+ totalWeightedMeanX += mean1[i] * weights[i];
+ totalWeightedMeanY += mean2[i] * weights[i];
+ }
- cov = (size1 * cov + size2 * covValues[i+1] + size1 *
(mean - newMean) * (mean - newMean)
- + size2 * (nextMean - newMean) * (nextMean -
newMean)) / (size1 + size2);
+ totalWeightedMeanX /= totalWeight;
+ totalWeightedMeanY /= totalWeight;
- mean = newMean;
- size1 = size1 + size2;
+ // calculate global weighted covariance
+ double cov = 0;
+ for (int i = 0; i < covValues.length; i++) {
+ cov += (weights[i] - 1) * covValues[i];
+ cov += weights[i] * (mean1[i] - totalWeightedMeanX) *
(mean2[i] - totalWeightedMeanY);
+ }
+ return cov / (totalWeight - 1); // adjusting for degrees of
freedom
+ }
+
+ private Future<FederatedResponse>[] processMean(MatrixObject mo1,
MatrixLineagePair moLin3, int var){
+ String[] parts = instString.split("°");
+ Future<FederatedResponse>[] meanTmp = null;
+ if (moLin3 == null) {
+ String meanInstr = instString.replace(getOpcode(),
getOpcode().replace("cov", "uamean"));
+ meanInstr = meanInstr.replace((var == 0 ? parts[2] :
parts[3]) + "°", "");
+ meanInstr = meanInstr.replace(parts[4],
parts[4].replace("FP64", "STRING°16"));
+
+ //create federated commands for aggregation
+ FederatedRequest meanFr1 =
FederationUtils.callInstruction(meanInstr, output,
+ new CPOperand[]{var == 0 ? input2 : input1},
new long[]{mo1.getFedMapping().getID()});
+ FederatedRequest meanFr2 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, meanFr1.getID());
+ FederatedRequest meanFr3 =
mo1.getFedMapping().cleanup(getTID(), meanFr1.getID());
+
+ meanTmp = mo1.getFedMapping().execute(getTID(),
meanFr1, meanFr2, meanFr3);
+ }
+ else {
+ // multiply input X by weights W element-wise
+ String multOutput = incrementVar(parts[4], 1);
+ String multInstr = instString
+ .replace(getOpcode(),
getOpcode().replace("cov", "*"))
+ .replace((var == 0 ? parts[2] : parts[3]) +
"°", "")
+ .replace(parts[5], multOutput);
+
+ CPOperand multOutputCPOp = new CPOperand(
+ multOutput.substring(0,
multOutput.indexOf("·")),
+ mo1.getValueType(), mo1.getDataType()
+ );
+
+ FederatedRequest multFr =
FederationUtils.callInstruction(
+ multInstr,
+ multOutputCPOp,
+ new CPOperand[]{var == 0 ? input2 : input1,
input3},
+ new long[]{mo1.getFedMapping().getID(),
moLin3.getFedMapping().getID()}
+ );
+
+ // calculate the sum of the obtained vector
+ String[] partsMult = multInstr.split("°");
+ String sumInstr1Output = incrementVar(multOutput, 1)
+ .replace("m", "")
+ .replace("MATRIX", "SCALAR");
+ String sumInstr1 = multInstr
+ .replace(partsMult[1], "uak+")
+ .replace(partsMult[3] + "°", "")
+ .replace(partsMult[4], sumInstr1Output)
+ .replace(partsMult[2], multOutput);
+
+ FederatedRequest sumFr1 =
FederationUtils.callInstruction(
+ sumInstr1,
+ new CPOperand(
+ sumInstr1Output.substring(0,
sumInstr1Output.indexOf("·")),
+ output.getValueType(),
output.getDataType()
+ ),
+ new CPOperand[]{multOutputCPOp},
+ new long[]{multFr.getID()}
+ );
+
+ // calculate the sum of weights
+ String[] partsSum1 = sumInstr1.split("°");
+ String sumInstr2Output = incrementVar(sumInstr1Output,
1);
+ String sumInstr2 = sumInstr1
+ .replace(partsSum1[2], parts[4])
+ .replace(partsSum1[3], sumInstr2Output);
+
+ FederatedRequest sumFr2 =
FederationUtils.callInstruction(
+ sumInstr2,
+ new CPOperand(
+ sumInstr2Output.substring(0,
sumInstr2Output.indexOf("·")),
+ output.getValueType(),
output.getDataType()
+ ),
+ new CPOperand[]{input3},
+ new long[]{moLin3.getFedMapping().getID()}
+ );
+
+ // divide sum(X*W) by sum(W)
+ String[] partsSum2 = sumInstr2.split("°");
+ String divInstrOutput = incrementVar(sumInstr2Output,
1);
+ String divInstrInput1 =
partsSum2[2].replace(partsSum2[2], sumInstr1Output + "·false");
+ String divInstrInput2 =
partsSum2[3].replace(partsSum2[3], sumInstr2Output + "·false");
+ String divInstr = partsSum2[0] + "°" +
partsSum2[1].replace("uak+", "/") + "°" +
+ divInstrInput1 + "°" + divInstrInput2 +
"°" + divInstrOutput + "°" + partsSum2[4];
+
+ FederatedRequest divFr1 =
FederationUtils.callInstruction(
+ divInstr,
+ new CPOperand(
+ divInstrOutput.substring(0,
divInstrOutput.indexOf("·")),
+ output.getValueType(),
output.getDataType()
+ ),
+ new CPOperand[]{
+ new CPOperand(
+ sumInstr1Output.substring(0,
sumInstr1Output.indexOf("·")),
+ output.getValueType(),
output.getDataType(), output.isLiteral()
+ ),
+ new CPOperand(
+ sumInstr2Output.substring(0,
sumInstr2Output.indexOf("·")),
+ output.getValueType(),
output.getDataType(), output.isLiteral()
+ )
+ },
+ new long[]{sumFr1.getID(), sumFr2.getID()}
+ );
+ FederatedRequest divFr2 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, divFr1.getID());
+ FederatedRequest divFr3 =
mo1.getFedMapping().cleanup(getTID(), multFr.getID(), sumFr1.getID(),
sumFr2.getID(), divFr1.getID(), divFr2.getID());
+
+ meanTmp = mo1.getFedMapping().execute(getTID(), multFr,
sumFr1, sumFr2, divFr1, divFr2, divFr3);
}
- return cov;
+ return meanTmp;
}
- private Future<FederatedResponse>[] processMean(MatrixObject mo1, int
var){
+ private Future<FederatedResponse>[] processMean(MatrixObject mo1, int
var, long weightsID){
String[] parts = instString.split("°");
- String meanInstr = instString.replace(getOpcode(),
getOpcode().replace("cov", "uamean"));
- meanInstr = meanInstr.replace((var == 0 ? parts[2] : parts[3])
+ "°", "");
- meanInstr = meanInstr.replace(parts[4],
parts[4].replace("FP64", "STRING°16"));
Future<FederatedResponse>[] meanTmp = null;
- //create federated commands for aggregation
- FederatedRequest meanFr1 =
FederationUtils.callInstruction(meanInstr, output,
- new CPOperand[]{var == 0 ? input2 : input1}, new
long[]{mo1.getFedMapping().getID()});
- FederatedRequest meanFr2 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, meanFr1.getID());
- FederatedRequest meanFr3 =
mo1.getFedMapping().cleanup(getTID(), meanFr1.getID());
- meanTmp = mo1.getFedMapping().execute(getTID(), meanFr1,
meanFr2, meanFr3);
+ // multiply input X by weights W element-wise
+ String multOutput = (var == 0 ? incrementVar(parts[2], 5) :
incrementVar(parts[3], 3));
+ String multInstr = instString
+ .replace(getOpcode(), getOpcode().replace("cov", "*"))
+ .replace((var == 0 ? parts[2] : parts[3]) + "°", "")
+ .replace(parts[4], String.valueOf(weightsID) +
"·MATRIX·FP64")
+ .replace(parts[5], multOutput);
+
+ CPOperand multOutputCPOp = new CPOperand(
+ multOutput.substring(0, multOutput.indexOf("·")),
+ mo1.getValueType(), mo1.getDataType()
+ );
+
+ FederatedRequest multFr = FederationUtils.callInstruction(
+ multInstr,
+ multOutputCPOp,
+ new CPOperand[]{var == 0 ? input2 : input1, input3},
+ new long[]{mo1.getFedMapping().getID(), weightsID}
+ );
+
+ // calculate the sum of the obtained vector
+ String[] partsMult = multInstr.split("°");
+ String sumInstr1Output = incrementVar(multOutput, 1)
+ .replace("m", "")
+ .replace("MATRIX", "SCALAR");
+ String sumInstr1 = multInstr
+ .replace(partsMult[1], "uak+")
+ .replace(partsMult[3] + "°", "")
+ .replace(partsMult[4], sumInstr1Output)
+ .replace(partsMult[2], multOutput);
+
+ FederatedRequest sumFr1 = FederationUtils.callInstruction(
+ sumInstr1,
+ new CPOperand(
+ sumInstr1Output.substring(0,
sumInstr1Output.indexOf("·")),
+ output.getValueType(), output.getDataType()
+ ),
+ new CPOperand[]{multOutputCPOp},
+ new long[]{multFr.getID()}
+ );
+
+ // calculate the sum of weights
+ String[] partsSum1 = sumInstr1.split("°");
+ String sumInstr2Output = incrementVar(sumInstr1Output, 1);
+ String sumInstr2 = sumInstr1
+ .replace(partsSum1[2], String.valueOf(weightsID) +
"·MATRIX·FP64")
+ .replace(partsSum1[3], sumInstr2Output);
+
+ FederatedRequest sumFr2 = FederationUtils.callInstruction(
+ sumInstr2,
+ new CPOperand(
+ sumInstr2Output.substring(0,
sumInstr2Output.indexOf("·")),
+ output.getValueType(), output.getDataType()
+ ),
+ new CPOperand[]{input3},
+ new long[]{weightsID}
+ );
+
+ // divide sum(X*W) by sum(W)
+ String[] partsSum2 = sumInstr2.split("°");
+ String divInstrOutput = incrementVar(sumInstr2Output, 1);
+ String divInstrInput1 = partsSum2[2].replace(partsSum2[2],
sumInstr1Output + "·false");
+ String divInstrInput2 = partsSum2[3].replace(partsSum2[3],
sumInstr2Output + "·false");
+ String divInstr = partsSum2[0] + "°" +
partsSum2[1].replace("uak+", "/") + "°" +
+ divInstrInput1 + "°" + divInstrInput2 + "°" +
divInstrOutput + "°" + partsSum2[4];
+
+ FederatedRequest divFr1 = FederationUtils.callInstruction(
+ divInstr,
+ new CPOperand(
+ divInstrOutput.substring(0,
divInstrOutput.indexOf("·")),
+ output.getValueType(), output.getDataType()
+ ),
+ new CPOperand[]{
+ new CPOperand(
+ sumInstr1Output.substring(0,
sumInstr1Output.indexOf("·")),
+ output.getValueType(),
output.getDataType(), output.isLiteral()
+ ),
+ new CPOperand(
+ sumInstr2Output.substring(0,
sumInstr2Output.indexOf("·")),
+ output.getValueType(),
output.getDataType(), output.isLiteral()
+ )
+ },
+ new long[]{sumFr1.getID(), sumFr2.getID()}
+ );
+ FederatedRequest divFr2 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, divFr1.getID());
+ FederatedRequest divFr3 = mo1.getFedMapping().cleanup(getTID(),
multFr.getID(), sumFr1.getID(), sumFr2.getID(), divFr1.getID(), divFr2.getID());
+
+ meanTmp = mo1.getFedMapping().execute(getTID(), multFr, sumFr1,
sumFr2, divFr1, divFr2, divFr3);
return meanTmp;
}
+ private Future<FederatedResponse>[] getWeightsSum(MatrixLineagePair
moLin3, long weightsID, String instString, FederationMap fedMap) {
+ Future<FederatedResponse>[] weightsSumTmp = null;
+
+ String[] parts = instString.split("°");
+ if (!instString.contains("pREADW")) {
+ String sumInstr = "CP°uak+°" + parts[4] + "°" +
parts[5] + "°" + parts[6];
+
+ FederatedRequest sumFr =
FederationUtils.callInstruction(
+ sumInstr,
+ new CPOperand(
+ parts[5].substring(0,
parts[5].indexOf("·")),
+ output.getValueType(),
+ output.getDataType()
+ ),
+ new CPOperand[]{input3},
+ new long[]{weightsID}
+ );
+ FederatedRequest sumFr2 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, sumFr.getID());
+ FederatedRequest sumFr3 =
moLin3.getFedMapping().cleanup(getTID(), sumFr.getID());
+
+ weightsSumTmp = fedMap.execute(getTID(), sumFr, sumFr2,
sumFr3);
+ }
+ else {
+ String sumInstr = "CP°uak+°" +
String.valueOf(weightsID) + "·MATRIX·FP64" + "°" + parts[5] + "°" + parts[6];
+ FederatedRequest sumFr =
FederationUtils.callInstruction(
+ sumInstr,
+ new CPOperand(
+ parts[5].substring(0,
parts[5].indexOf("·")),
+ output.getValueType(),
+ output.getDataType()
+ ),
+ new CPOperand[]{input3},
+ new long[]{weightsID}
+ );
+ FederatedRequest sumFr2 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, sumFr.getID());
+ FederatedRequest sumFr3 = fedMap.cleanup(getTID(),
sumFr.getID());
+
+ weightsSumTmp = fedMap.execute(getTID(), sumFr, sumFr2,
sumFr3);
+ }
+ return weightsSumTmp;
+ }
+
+ private static String incrementVar(String str, int inc) {
+ StringBuilder strOut = new StringBuilder(str);
+ Pattern pattern = Pattern.compile("\\d+");
+ Matcher matcher = pattern.matcher(strOut);
+ if (matcher.find()) {
+ int num = Integer.parseInt(matcher.group()) + inc;
+ int start = matcher.start();
+ int end = matcher.end();
+ strOut.replace(start, end, String.valueOf(num));
+ }
+ return strOut.toString();
+ }
+
private static class COVFunction extends FederatedUDF {
private static final long serialVersionUID =
-501036588060113499L;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
index 136cdde7f9..4f23c641fa 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
@@ -43,6 +43,9 @@ public class FederatedCovarianceTest extends AutomatedTestBase {
private final static String TEST_NAME1 = "FederatedCovarianceTest";
private final static String TEST_NAME2 =
"FederatedCovarianceAlignedTest";
+ private final static String TEST_NAME3 =
"FederatedCovarianceWeightedTest";
+ private final static String TEST_NAME4 =
"FederatedCovarianceAlignedWeightedTest";
+ private final static String TEST_NAME5 =
"FederatedCovarianceAllAlignedWeightedTest";
private final static String TEST_DIR = "functions/federated/";
private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedCovarianceTest.class.getSimpleName() + "/";
@@ -64,19 +67,37 @@ public class FederatedCovarianceTest extends AutomatedTestBase {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME4, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME5, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S.scalar"}));
}
@Test
public void testCovCP() {
- runCovTest(ExecMode.SINGLE_NODE, false);
+ runCovarianceTest(ExecMode.SINGLE_NODE, false);
}
@Test
public void testAlignedCovCP() {
- runCovTest(ExecMode.SINGLE_NODE, true);
+ runCovarianceTest(ExecMode.SINGLE_NODE, true);
}
- private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
+ @Test
+ public void testCovarianceWeightedCP() {
+ runWeightedCovarianceTest(ExecMode.SINGLE_NODE, false, false);
+ }
+
+ @Test
+ public void testAlignedCovarianceWeightedCP() {
+ runWeightedCovarianceTest(ExecMode.SINGLE_NODE, true, false);
+ }
+
+ @Test
+ public void testAllAlignedCovarianceWeightedCP() {
+ runWeightedCovarianceTest(ExecMode.SINGLE_NODE, true, true);
+ }
+
+ private void runCovarianceTest(ExecMode execMode, boolean
alignedFedInput) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -190,4 +211,176 @@ public class FederatedCovarianceTest extends AutomatedTestBase {
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}
+
+ private void runWeightedCovarianceTest(ExecMode execMode, boolean
alignedInput, boolean alignedWeights) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ ExecMode platformOld = rtplatform;
+
+ if(rtplatform == ExecMode.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ String TEST_NAME = !alignedInput ? TEST_NAME3 :
(!alignedWeights ? TEST_NAME4 : TEST_NAME5);
+ getAndLoadTestConfiguration(TEST_NAME);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ int r = rows / 4;
+ int c = cols;
+
+ fullDMLScriptName = "";
+
+ // Create 4 random 5x1 matrices
+ double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+ double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+ double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+ double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+ // Create a 20x1 weights matrix
+ double[][] W = getRandomMatrix(rows, c, 0, 1, 1, 3);
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(r, c,
blocksize, r * c);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+ writeInputMatrixWithMTD("X2", X2, false, mc);
+ writeInputMatrixWithMTD("X3", X3, false, mc);
+ writeInputMatrixWithMTD("X4", X4, false, mc);
+
+ writeInputMatrixWithMTD("W", W, false, new
MatrixCharacteristics(rows, cols, blocksize, r * c));
+
+ // empty script name because we don't execute any script, just
start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ int port3 = getRandomAvailablePort();
+ int port4 = getRandomAvailablePort();
+
+ Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
+ Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
+ Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
+ Process t4 = startLocalFedWorker(port4);
+
+ try {
+ if(!isAlive(t1, t2, t3, t4))
+ throw new RuntimeException("Failed starting
federated worker");
+
+ rtplatform = execMode;
+ if(rtplatform == ExecMode.SPARK) {
+ System.out.println(7);
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+
+ TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ if (alignedInput) {
+ // Create 4 random 5x1 matrices
+ double[][] Y1 = getRandomMatrix(r, c, 1, 5, 1,
3);
+ double[][] Y2 = getRandomMatrix(r, c, 1, 5, 1,
7);
+ double[][] Y3 = getRandomMatrix(r, c, 1, 5, 1,
8);
+ double[][] Y4 = getRandomMatrix(r, c, 1, 5, 1,
9);
+
+ writeInputMatrixWithMTD("Y1", Y1, false, mc);
+ writeInputMatrixWithMTD("Y2", Y2, false, mc);
+ writeInputMatrixWithMTD("Y3", Y3, false, mc);
+ writeInputMatrixWithMTD("Y4", Y4, false, mc);
+
+ if (!alignedWeights) {
+ // Run reference dml script with a
normal matrix
+ fullDMLScriptName = HOME + TEST_NAME +
"Reference.dml";
+ programArgs = new String[] { "-stats",
"100", "-args",
+ input("X1"), input("X2"),
input("X3"), input("X4"),
+ input("Y1"), input("Y2"),
input("Y3"), input("Y4"),
+ input("W"), expected("S")
+ };
+ runTest(null);
+
+ // Run the dml script with federated
matrices
+ fullDMLScriptName = HOME + TEST_NAME +
".dml";
+ programArgs = new String[] {"-stats",
"100", "-nvargs",
+ "in_X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+ "in_Y1=" +
TestUtils.federatedAddress(port1, input("Y1")),
+ "in_X2=" +
TestUtils.federatedAddress(port2, input("X2")),
+ "in_Y2=" +
TestUtils.federatedAddress(port2, input("Y2")),
+ "in_X3=" +
TestUtils.federatedAddress(port3, input("X3")),
+ "in_Y3=" +
TestUtils.federatedAddress(port3, input("Y3")),
+ "in_X4=" +
TestUtils.federatedAddress(port4, input("X4")),
+ "in_Y4=" +
TestUtils.federatedAddress(port4, input("Y4")),
+ "in_W1=" + input("W"), "rows="
+ rows, "cols=" + cols, "out_S=" + output("S")};
+ runTest(null);
+ }
+ else {
+ double[][] W1 = getRandomMatrix(r, c,
0, 1, 1, 3);
+ double[][] W2 = getRandomMatrix(r, c,
0, 1, 1, 7);
+ double[][] W3 = getRandomMatrix(r, c,
0, 1, 1, 8);
+ double[][] W4 = getRandomMatrix(r, c,
0, 1, 1, 9);
+
+ writeInputMatrixWithMTD("W1", W1,
false, mc);
+ writeInputMatrixWithMTD("W2", W2,
false, mc);
+ writeInputMatrixWithMTD("W3", W3,
false, mc);
+ writeInputMatrixWithMTD("W4", W4,
false, mc);
+
+ // Run reference dml script with a
normal matrix
+ fullDMLScriptName = HOME + TEST_NAME +
"Reference.dml";
+ programArgs = new String[] {"-stats",
"100", "-args",
+ input("X1"), input("X2"),
input("X3"), input("X4"),
+ input("Y1"), input("Y2"),
input("Y3"), input("Y4"),
+ input("W1"), input("W2"),
input("W3"), input("W4"), expected("S")
+ };
+ runTest(null);
+
+ // Run the dml script with federated
matrices and weights
+ fullDMLScriptName = HOME + TEST_NAME +
".dml";
+ programArgs = new String[] {"-stats",
"100", "-nvargs",
+ "in_X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+ "in_Y1=" +
TestUtils.federatedAddress(port1, input("Y1")),
+ "in_W1=" +
TestUtils.federatedAddress(port1, input("W1")),
+ "in_X2=" +
TestUtils.federatedAddress(port2, input("X2")),
+ "in_Y2=" +
TestUtils.federatedAddress(port2, input("Y2")),
+ "in_W2=" +
TestUtils.federatedAddress(port2, input("W2")),
+ "in_X3=" +
TestUtils.federatedAddress(port3, input("X3")),
+ "in_Y3=" +
TestUtils.federatedAddress(port3, input("Y3")),
+ "in_W3=" +
TestUtils.federatedAddress(port3, input("W3")),
+ "in_X4=" +
TestUtils.federatedAddress(port4, input("X4")),
+ "in_Y4=" +
TestUtils.federatedAddress(port4, input("Y4")),
+ "in_W4=" +
TestUtils.federatedAddress(port4, input("W4")),
+ "rows=" + rows, "cols=" + cols,
"out_S=" + output("S")};
+ runTest(null);
+ }
+
+ }
+ else {
+ // Create a random 20x1 input matrix
+ double[][] Y = getRandomMatrix(rows, c, 1, 5,
1, 3);
+ writeInputMatrixWithMTD("Y", Y, false, new
MatrixCharacteristics(rows, cols, blocksize, r * c));
+
+ // Run reference dml script with a normal matrix
+ fullDMLScriptName = HOME + TEST_NAME +
"Reference.dml";
+ programArgs = new String[] {"-stats", "100",
"-args",
+ input("X1"), input("X2"), input("X3"),
input("X4"),
+ input("Y"), input("W"), expected("S")
+ };
+ runTest(null);
+
+ // Run the dml script with a federated matrix
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100",
"-nvargs",
+ "in_X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" +
TestUtils.federatedAddress(port2, input("X2")),
+ "in_X3=" +
TestUtils.federatedAddress(port3, input("X3")),
+ "in_X4=" +
TestUtils.federatedAddress(port4, input("X4")),
+ "in_W1=" + input("W"), "Y=" +
input("Y"),
+ "rows=" + rows, "cols=" + cols,
"out_S=" + output("S")};
+ runTest(null);
+ }
+
+ // compare via files
+ compareResults(1e-2);
+
Assert.assertTrue(heavyHittersContainsString("fed_cov"));
+
+ }
+ finally {
+ TestUtils.shutdownThreads(t1, t2, t3, t4);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
}
diff --git a/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTest.dml b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTest.dml
new file mode 100644
index 0000000000..da9db2f4de
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTest.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X: federated, row-partitioned into 4 ranges of $rows/4 rows each (5x1 on 4 workers -> 20x1)
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows,
$cols)));
+
+# Y: federated with the same row ranges as X, i.e. aligned with X (5x1 on 4 workers -> 20x1)
+Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows,
$cols)));
+
+W = read($in_W1); # 20x1 weights, read locally (not federated)
+
+s = cov(X, Y, W); # weighted covariance: aligned federated X and Y, local W
+write(s, $out_S);
\ No newline at end of file
diff --git
a/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTestReference.dml
b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTestReference.dml
new file mode 100644
index 0000000000..ee4062f7e6
--- /dev/null
+++
b/src/test/scripts/functions/federated/FederatedCovarianceAlignedWeightedTestReference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2), read($3), read($4)); # 20x1: inputs $1..$4 stacked row-wise
+Y = rbind(read($5), read($6), read($7), read($8)); # 20x1: inputs $5..$8 stacked row-wise
+W = read($9); # 20x1 weights
+
+s = cov(X, Y, W); # local (non-federated) weighted covariance used as the expected value
+write(s, $10);
\ No newline at end of file
diff --git
a/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTest.dml
b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTest.dml
new file mode 100644
index 0000000000..22029de451
--- /dev/null
+++
b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTest.dml
@@ -0,0 +1,38 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X: federated, row-partitioned into 4 ranges of $rows/4 rows each (5x1 on 4 workers -> 20x1)
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows,
$cols)));
+
+# Y: federated with the same row ranges as X, i.e. aligned with X (5x1 on 4 workers -> 20x1)
+Y = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows,
$cols)));
+
+# W: federated weights with the same row ranges as X and Y (5x1 on 4 workers -> 20x1)
+W = federated(addresses=list($in_W1, $in_W2, $in_W3, $in_W4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows,
$cols)));
+
+s = cov(X, Y, W); # weighted covariance with X, Y and W all federated and aligned
+write(s, $out_S);
\ No newline at end of file
diff --git
a/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTestReference.dml
b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTestReference.dml
new file mode 100644
index 0000000000..10c18f5a33
--- /dev/null
+++
b/src/test/scripts/functions/federated/FederatedCovarianceAllAlignedWeightedTestReference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2), read($3), read($4)); # 20x1: inputs $1..$4 stacked row-wise
+Y = rbind(read($5), read($6), read($7), read($8)); # 20x1: inputs $5..$8 stacked row-wise
+W = rbind(read($9), read($10), read($11), read($12)); # 20x1 weights: inputs $9..$12 stacked row-wise
+
+s = cov(X, Y, W); # local (non-federated) weighted covariance used as the expected value
+write(s, $13);
\ No newline at end of file
diff --git
a/src/test/scripts/functions/federated/FederatedCovarianceWeightedTest.dml
b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTest.dml
new file mode 100644
index 0000000000..3ba2d5b15f
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# X: federated, row-partitioned into 4 ranges of $rows/4 rows each (5x1 on 4 workers -> 20x1)
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows,
$cols)));
+
+Y = read($Y); # 20x1, read locally (not federated)
+W = read($in_W1); # 20x1 weights, read locally (not federated)
+
+s = cov(X, Y, W); # weighted covariance: federated X with local Y and W
+write(s, $out_S);
\ No newline at end of file
diff --git
a/src/test/scripts/functions/federated/FederatedCovarianceWeightedTestReference.dml
b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTestReference.dml
new file mode 100644
index 0000000000..db1dc7c526
--- /dev/null
+++
b/src/test/scripts/functions/federated/FederatedCovarianceWeightedTestReference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2), read($3), read($4)); # 20x1: the X1..X4 partitions stacked row-wise
+Y = read($5); # 20x1
+W = read($6); # 20x1 weights
+
+s = cov(X, Y, W); # local (non-federated) weighted covariance used as the expected value
+write(s, $7);
\ No newline at end of file