This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 2c2fedc0b26f53597679e9be2b77e523167444ce
Author: baunsgaard <[email protected]>
AuthorDate: Tue Sep 14 13:15:41 2021 +0200

    [SYSTEMDS-3018] Federated parameter server tests: print output only if the test fails
---
 .../federated/paramserv/AvgModelFederatedParamservTest.java    | 10 ++++------
 .../functions/federated/paramserv/FederatedParamservTest.java  |  8 +++-----
 .../federated/paramserv/NbatchesFederatedParamservTest.java    | 10 ++++------
 .../functions/federated/primitives/FederatedRCBindTest.java    |  1 -
 .../primitives/FederatedWeightedUnaryMatrixMultTest.java       |  2 +-
 5 files changed, 12 insertions(+), 19 deletions(-)

diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
index 66482f3..702f632 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
@@ -24,8 +24,6 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
@@ -39,7 +37,7 @@ import org.junit.runners.Parameterized;
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class AvgModelFederatedParamservTest extends AutomatedTestBase {
-       private static final Log LOG = 
LogFactory.getLog(AvgModelFederatedParamservTest.class.getName());
+       // private static final Log LOG = 
LogFactory.getLog(AvgModelFederatedParamservTest.class.getName());
        private final static String TEST_DIR = "functions/federated/paramserv/";
        private final static String TEST_NAME = 
"AvgModelFederatedParamservTest";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
AvgModelFederatedParamservTest.class.getSimpleName() + "/";
@@ -131,7 +129,7 @@ public class AvgModelFederatedParamservTest extends 
AutomatedTestBase {
                // config
                getAndLoadTestConfiguration(TEST_NAME);
                String HOME = SCRIPT_DIR + TEST_DIR;
-               setOutputBuffering(false);
+               setOutputBuffering(true);
 
                int C = 1, Hin = 28, Win = 28;
                int numLabels = 10;
@@ -201,8 +199,8 @@ public class AvgModelFederatedParamservTest extends 
AutomatedTestBase {
                                "modelAvg=" +  
Boolean.toString(modelAvg).toUpperCase()));
 
                        programArgs = programArgsList.toArray(new String[0]);
-                       LOG.debug(runTest(null));
-                       Assert.assertEquals(0, 
Statistics.getNoOfExecutedSPInst());
+                       String log = runTest(null).toString();
+                       Assert.assertEquals("Test Failed \n" + log, 0, 
Statistics.getNoOfExecutedSPInst());
 
                        // shut down threads
                        for(int i = 0; i < _numFederatedWorkers; i++) {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index c316214..fd40275 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -24,8 +24,6 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
@@ -39,7 +37,7 @@ import org.junit.runners.Parameterized;
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class FederatedParamservTest extends AutomatedTestBase {
-       private static final Log LOG = 
LogFactory.getLog(FederatedParamservTest.class.getName());
+       // private static final Log LOG = 
LogFactory.getLog(FederatedParamservTest.class.getName());
        private final static String TEST_DIR = "functions/federated/paramserv/";
        private final static String TEST_NAME = "FederatedParamservTest";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedParamservTest.class.getSimpleName() + "/";
@@ -199,8 +197,8 @@ public class FederatedParamservTest extends 
AutomatedTestBase {
                                        "seed=" + _seed));
 
                        programArgs = programArgsList.toArray(new String[0]);
-                       LOG.debug(runTest(null));
-                       Assert.assertEquals(0, 
Statistics.getNoOfExecutedSPInst());
+                       String log = runTest(null).toString();
+                       Assert.assertEquals("Test Failed \n" + log, 0, 
Statistics.getNoOfExecutedSPInst());
                        
                        // shut down threads
                        for(int i = 0; i < _numFederatedWorkers; i++) {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
index e2e4f20..9b9f9bb 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
@@ -24,8 +24,6 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
@@ -39,7 +37,7 @@ import org.junit.runners.Parameterized;
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class NbatchesFederatedParamservTest extends AutomatedTestBase {
-       private static final Log LOG = 
LogFactory.getLog(NbatchesFederatedParamservTest.class.getName());
+       // private static final Log LOG = 
LogFactory.getLog(NbatchesFederatedParamservTest.class.getName());
        private final static String TEST_DIR = "functions/federated/paramserv/";
        private final static String TEST_NAME = 
"NbatchesFederatedParamservTest";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
NbatchesFederatedParamservTest.class.getSimpleName() + "/";
@@ -111,7 +109,7 @@ public class NbatchesFederatedParamservTest extends 
AutomatedTestBase {
                // config
                getAndLoadTestConfiguration(TEST_NAME);
                String HOME = SCRIPT_DIR + TEST_DIR;
-               setOutputBuffering(false);
+               setOutputBuffering(true);
 
                int C = 1, Hin = 28, Win = 28;
                int numLabels = 10;
@@ -181,8 +179,8 @@ public class NbatchesFederatedParamservTest extends 
AutomatedTestBase {
                                "nbatches=" + _nbatches));
 
                        programArgs = programArgsList.toArray(new String[0]);
-                       LOG.debug(runTest(null));
-                       Assert.assertEquals(0, 
Statistics.getNoOfExecutedSPInst());
+                       String log = runTest(null).toString();
+                       Assert.assertEquals("Test Failed \n" + log,0, 
Statistics.getNoOfExecutedSPInst());
 
                        // shut down threads
                        for(int i = 0; i < _numFederatedWorkers; i++) {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
index 1470274..04b668d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
@@ -28,7 +28,6 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
index 8bc9fee..9086edf 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
@@ -182,7 +182,7 @@ public class FederatedWeightedUnaryMatrixMultTest extends 
AutomatedTestBase
                        runTest(true, false, null, -1);
 
                        // compare the results via files
-                       HashMap<CellIndex, Double> refResults   = 
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+                       HashMap<CellIndex, Double> refResults = 
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
                        HashMap<CellIndex, Double> fedResults = 
readDMLMatrixFromOutputDir(OUTPUT_NAME);
                        TestUtils.compareMatrices(fedResults, refResults, 
TOLERANCE, "Fed", "Ref");
 

Reply via email to