This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 9a318eeccc [SYSTEMDS-3796] Fix robustness federated weighted covariance and tests
9a318eeccc is described below

commit 9a318eeccc3ae1da999f47a3b8f1c4d003ea32bc
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Nov 28 09:38:31 2024 +0100

    [SYSTEMDS-3796] Fix robustness federated weighted covariance and tests
---
 .../controlprogram/federated/FederationMap.java    |  6 ++--
 .../instructions/fed/CovarianceFEDInstruction.java | 41 +++++++---------------
 .../primitives/part5/FederatedCovarianceTest.java  |  7 ++--
 3 files changed, 19 insertions(+), 35 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 91e6c156c4..2574c4f175 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -147,10 +147,8 @@ public class FederationMap {
                return broadcastSliced(data, null, transposed);
        }
 
-       public FederatedRequest[] broadcastSliced(MatrixLineagePair moLin,
-               boolean transposed) {
-               return broadcastSliced(moLin.getMO(), moLin.getLI(),
-                       transposed);
+       public FederatedRequest[] broadcastSliced(MatrixLineagePair moLin, 
boolean transposed) {
+               return broadcastSliced(moLin.getMO(), moLin.getLI(), 
transposed);
        }
 
        /**
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
index d7f28293ce..4d22fd753e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
@@ -114,27 +114,20 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                                new CPOperand[]{input1, input2, input3}, new 
long[]{mo1.getFedMapping().getID(),
                                        mo2.getFedMapping().getID(), 
moLin3.getFedMapping().getID()});
                }
-                       
+
                FederatedRequest fr2 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
                FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), 
fr1.getID());
-               Future<FederatedResponse>[] covTmp = 
mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
-
-               //means
-               Future<FederatedResponse>[] meanTmp1 = processMean(mo1, moLin3, 
0);
-               Future<FederatedResponse>[] meanTmp2 = processMean(mo2, moLin3, 
1);
-
-               Double[] cov = getResponses(covTmp);
-               Double[] mean1 = getResponses(meanTmp1);
-               Double[] mean2 = getResponses(meanTmp2);
+               Double[] cov = 
getResponses(mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3));
+               Double[] mean1 = getResponses(processMean(mo1, moLin3, 0));
+               Double[] mean2 = getResponses(processMean(mo2, moLin3, 1));
 
                if (moLin3 == null) {
                        double result = aggCov(cov, mean1, mean2, 
mo1.getFedMapping().getFederatedRanges());
                        ec.setVariable(output.getName(), new 
DoubleObject(result));
                }
                else {
-                       Future<FederatedResponse>[] weightsSumTmp = 
getWeightsSum(moLin3, moLin3.getFedMapping().getID(), instString, 
moLin3.getFedMapping());
-                       Double[] weights = getResponses(weightsSumTmp);
-                       
+                       Double[] weights = getResponses(
+                               getWeightsSum(moLin3, 
moLin3.getFedMapping().getID(), instString, moLin3.getFedMapping()));
                        double result = aggWeightedCov(cov, mean1, mean2, 
weights);
                        ec.setVariable(output.getName(), new 
DoubleObject(result));
                }
@@ -154,21 +147,13 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                        new CPOperand[]{input1, input2, input3},
                        new long[]{mo1.getFedMapping().getID(), 
mo2.getFedMapping().getID(), fr1[0].getID()}
                );
+               //sequential execution of cov and means for robustness
                FederatedRequest fr3 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
                FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(), 
fr2.getID());
-               Future<FederatedResponse>[] covTmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
-
-               //means
-               Future<FederatedResponse>[] meanTmp1 = processMean(mo1, 0, 
fr1[0].getID());
-               Future<FederatedResponse>[] meanTmp2 = processMean(mo2, 1, 
fr1[0].getID());
-
-               Double[] cov = getResponses(covTmp);
-               Double[] mean1 = getResponses(meanTmp1);
-               Double[] mean2 = getResponses(meanTmp2);
-
-               Future<FederatedResponse>[] weightsSumTmp = 
getWeightsSum(moLin3, fr1[0].getID(), instString, mo1.getFedMapping());
-               Double[] weights = getResponses(weightsSumTmp);
-               
+               Double[] cov = 
getResponses(mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3, fr4));
+               Double[] mean1 = getResponses(processMean(mo1, 0, 
fr1[0].getID()));
+               Double[] mean2 = getResponses(processMean(mo2, 1, 
fr1[0].getID()));
+               Double[] weights = getResponses(getWeightsSum(moLin3, 
fr1[0].getID(), instString, mo1.getFedMapping()));
                double result = aggWeightedCov(cov, mean1, mean2, weights);
                ec.setVariable(output.getName(), new DoubleObject(result));
        }
@@ -243,7 +228,7 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                                fr[i] = ((ScalarObject) 
ffr[i].get().getData()[0]).getDoubleValue();
                        }
                        catch(Exception e) {
-                               throw new 
DMLRuntimeException("CovarianceFEDInstruction: incorrect means or cov.");
+                               throw new 
DMLRuntimeException("CovarianceFEDInstruction: incorrect means or cov.", e);
                        }
                });
 
@@ -302,7 +287,7 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
        }
 
        private Future<FederatedResponse>[] processMean(MatrixObject mo1, 
MatrixLineagePair moLin3, int var){
-               String[] parts = instString.split("°");
+               String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
                Future<FederatedResponse>[] meanTmp = null;
                if (moLin3 == null) {
                        String meanInstr = instString.replace(getOpcode(), 
getOpcode().replace("cov", "uamean"));
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
index 9fb42f23fd..48c9cab632 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
@@ -49,7 +49,7 @@ public class FederatedCovarianceTest extends 
AutomatedTestBase {
        private final static String TEST_DIR = "functions/federated/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedCovarianceTest.class.getSimpleName() + "/";
 
-       private final static int blocksize = 1024;
+       private final static int blocksize = 1000;
        @Parameterized.Parameter
        public int rows;
        @Parameterized.Parameter(1)
@@ -57,8 +57,9 @@ public class FederatedCovarianceTest extends 
AutomatedTestBase {
 
        @Parameterized.Parameters
        public static Collection<Object[]> data() {
-               return Arrays.asList(new Object[][] {{20, 1},
-                       // {100, 1}, {1000, 1}
+               return Arrays.asList(new Object[][] {
+                       {120, 1},
+                       {1100, 1},
                });
        }
 

Reply via email to