This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 3d97048e457a00af9288a9940f962392df3abcbc
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Oct 31 18:05:54 2020 +0100

    [SYSTEMDS-2679] Fix dim size propagation for federated data ops
    
    This patch fixes two issues that were encountered in federated
    parameter servers. First, we fixed null pointer exceptions in the
    cleanup of parallelized RDDs for the case of redundant cleanups. Second,
    we fixed the size propagation for federated DataOps which so far was
    never refreshed during recompilation after initial parsing. Each of
    these fixes alone would already solve the reported bug, but the fixed
    size propagation is important for all federated use cases in default
    hybrid execution mode.
---
 src/main/java/org/apache/sysds/hops/DataOp.java    |  29 +-
 src/main/java/org/apache/sysds/hops/Hop.java       |   4 +
 .../apache/sysds/hops/rewrite/HopRewriteUtils.java |   8 +-
 src/main/java/org/apache/sysds/lops/Federated.java |   3 +-
 .../context/SparkExecutionContext.java             |   4 +-
 .../federated/io/FederatedReaderTest.java          | 172 +++++------
 .../federated/io/FederatedWriterTest.java          | 192 ++++++------
 .../paramserv/FederatedParamservTest.java          | 325 +++++++++++----------
 8 files changed, 382 insertions(+), 355 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java 
b/src/main/java/org/apache/sysds/hops/DataOp.java
index 0046078..114006f 100644
--- a/src/main/java/org/apache/sysds/hops/DataOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataOp.java
@@ -30,6 +30,7 @@ import org.apache.sysds.common.Types.OpOpData;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.CompilerConfig.ConfigType;
 import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.lops.Data;
 import org.apache.sysds.lops.Federated;
 import org.apache.sysds.lops.Lop;
@@ -37,6 +38,7 @@ import org.apache.sysds.lops.LopProperties.ExecType;
 import org.apache.sysds.lops.LopsException;
 import org.apache.sysds.lops.Sql;
 import org.apache.sysds.parser.DataExpression;
+import static org.apache.sysds.parser.DataExpression.FED_RANGES;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.util.LocalFileUtils;
@@ -270,8 +272,7 @@ public class DataOp extends Hop {
                // construct lops for all input parameters
                HashMap<String, Lop> inputLops = new HashMap<>();
                for (Entry<String, Integer> cur : _paramIndexMap.entrySet()) {
-                       inputLops.put(cur.getKey(), 
getInput().get(cur.getValue())
-                                       .constructLops());
+                       inputLops.put(cur.getKey(), 
getInput().get(cur.getValue()).constructLops());
                }
 
                // Create the lop
@@ -488,21 +489,30 @@ public class DataOp extends Hop {
        }
        
        @Override
-       public void refreshSizeInformation()
-       {
-               if( _op == OpOpData.PERSISTENTWRITE || _op == 
OpOpData.TRANSIENTWRITE )
-               {
+       public void refreshSizeInformation() {
+               if( _op == OpOpData.PERSISTENTWRITE || _op == 
OpOpData.TRANSIENTWRITE ) {
                        Hop input1 = getInput().get(0);
                        setDim1(input1.getDim1());
                        setDim2(input1.getDim2());
                        setNnz(input1.getNnz());
                }
-               else //READ
-               {
+               else if( _op == OpOpData.FEDERATED ) {
+                       Hop ranges = 
getInput().get(getParameterIndex(FED_RANGES));
+                       long nrow = -1, ncol = -1;
+                       for( Hop c : ranges.getInput() ) {
+                               if( !(c.getInput(0) instanceof LiteralOp && 
c.getInput(1) instanceof LiteralOp))
+                                       return; // invalid size inference if 
not all know.
+                               nrow = Math.max(nrow, 
HopRewriteUtils.getIntValueSafe(c.getInput(0)));
+                               ncol = Math.max(ncol, 
HopRewriteUtils.getIntValueSafe(c.getInput(1)));
+                       }
+                       setDim1(nrow);
+                       setDim2(ncol);
+               }
+               else { //READ
                        //do nothing; dimensions updated via set output params
                }
        }
-               
+
        
        /**
         * Explicitly disables recompilation of transient reads, this 
additional information 
@@ -590,5 +600,4 @@ public class DataOp extends Hop {
                        }
                }
        }
-
 }
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java 
b/src/main/java/org/apache/sysds/hops/Hop.java
index 60a4dc5..bb0960d 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -758,6 +758,10 @@ public abstract class Hop implements ParseInfo {
                return _input;
        }
        
+       public Hop getInput(int ix) {
+               return _input.get(ix);
+       }
+       
        public void addInput( Hop h ) {
                _input.add(h);
                h._parent.add(this);
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 1815d69..b1a8799 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -162,10 +162,14 @@ public class HopRewriteUtils
                }
        }
        
+       public static long getIntValueSafe( Hop op ) {
+               return getIntValueSafe((LiteralOp) op);
+       }
+       
        public static long getIntValueSafe( LiteralOp op ) {
                switch( op.getValueType() ) {
-                       case FP64:  return 
UtilFunctions.toLong(op.getDoubleValue());
-                       case INT64:     return op.getLongValue();
+                       case FP64:    return 
UtilFunctions.toLong(op.getDoubleValue());
+                       case INT64:   return op.getLongValue();
                        case BOOLEAN: return op.getBooleanValue() ? 1 : 0;
                        default: return Long.MAX_VALUE;
                }
diff --git a/src/main/java/org/apache/sysds/lops/Federated.java 
b/src/main/java/org/apache/sysds/lops/Federated.java
index 8aacbd7..52b52be 100644
--- a/src/main/java/org/apache/sysds/lops/Federated.java
+++ b/src/main/java/org/apache/sysds/lops/Federated.java
@@ -63,7 +63,6 @@ public class Federated extends Lop {
        
        @Override
        public String toString() {
-               // TODO Federated.toString() lop
-               return null;
+               return "FedInit";
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 2be647d..41ac510 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -1847,8 +1847,8 @@ public class SparkExecutionContext extends 
ExecutionContext
                }
 
                public synchronized void deregisterRDD(int rddID) {
-                       long rddSize = _rdds.remove(rddID);
-                       _size -= rddSize;
+                       Long rddSize = _rdds.remove(rddID);
+                       _size -= (rddSize!=null) ? rddSize : 0;
                }
 
                public synchronized void clear() {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
index a8e4407..c14ac1d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
@@ -38,98 +38,98 @@ import org.junit.runners.Parameterized;
 @net.jcip.annotations.NotThreadSafe
 public class FederatedReaderTest extends AutomatedTestBase {
 
-    // private static final Log LOG = 
LogFactory.getLog(FederatedReaderTest.class.getName());
-    private final static String TEST_DIR = "functions/federated/ioR/";
-    private final static String TEST_NAME = "FederatedReaderTest";
-    private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedReaderTest.class.getSimpleName() + "/";
-    private final static int blocksize = 1024;
-    @Parameterized.Parameter()
-    public int rows;
-    @Parameterized.Parameter(1)
-    public int cols;
-    @Parameterized.Parameter(2)
-    public boolean rowPartitioned;
-    @Parameterized.Parameter(3)
-    public int fedCount;
+       // private static final Log LOG = 
LogFactory.getLog(FederatedReaderTest.class.getName());
+       private final static String TEST_DIR = "functions/federated/ioR/";
+       private final static String TEST_NAME = "FederatedReaderTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedReaderTest.class.getSimpleName() + "/";
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
+       @Parameterized.Parameter(3)
+       public int fedCount;
 
-    @Override
-    public void setUp() {
-        TestUtils.clearAssertionInformation();
-        addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAME));
-    }
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+       }
 
-    @Parameterized.Parameters
-    public static Collection<Object[]> data() {
-        // number of rows or cols has to be >= number of federated locations.
-        return Arrays.asList(new Object[][] {{10, 13, true, 2},});
-    }
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // number of rows or cols has to be >= number of federated 
locations.
+               return Arrays.asList(new Object[][] {{10, 13, true, 2},});
+       }
 
-    @Test
-    public void federatedSinglenodeRead() {
-        federatedRead(Types.ExecMode.SINGLE_NODE);
-    }
+       @Test
+       public void federatedSinglenodeRead() {
+               federatedRead(Types.ExecMode.SINGLE_NODE);
+       }
 
-    public void federatedRead(Types.ExecMode execMode) {
-        boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-        Types.ExecMode platformOld = rtplatform;
-        rtplatform = execMode;
-        if(rtplatform == Types.ExecMode.SPARK) {
-            DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-        }
-        getAndLoadTestConfiguration(TEST_NAME);
+       public void federatedRead(Types.ExecMode execMode) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               Types.ExecMode platformOld = rtplatform;
+               rtplatform = execMode;
+               if(rtplatform == Types.ExecMode.SPARK) {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+               getAndLoadTestConfiguration(TEST_NAME);
 
-        // write input matrices
-        int halfRows = rows / 2;
-        long[][] begins = new long[][] {new long[] {0, 0}, new long[] 
{halfRows, 0}};
-        long[][] ends = new long[][] {new long[] {halfRows, cols}, new long[] 
{rows, cols}};
-        // We have two matrices handled by a single federated worker
-        double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
-        double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
-        writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
-        writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
-        // empty script name because we don't execute any script, just start 
the worker
-        fullDMLScriptName = "";
-        int port1 = getRandomAvailablePort();
-        int port2 = getRandomAvailablePort();
-        Thread t1 = startLocalFedWorkerThread(port1);
-        Thread t2 = startLocalFedWorkerThread(port2);
-        String host = "localhost";
+               // write input matrices
+               int halfRows = rows / 2;
+               long[][] begins = new long[][] {new long[] {0, 0}, new long[] 
{halfRows, 0}};
+               long[][] ends = new long[][] {new long[] {halfRows, cols}, new 
long[] {rows, cols}};
+               // We have two matrices handled by a single federated worker
+               double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+               double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1);
+               Thread t2 = startLocalFedWorkerThread(port2);
+               String host = "localhost";
 
-        MatrixObject fed = 
FederatedTestObjectConstructor.constructFederatedInput(rows,
-            cols,
-            blocksize,
-            host,
-            begins,
-            ends,
-            new int[] {port1, port2},
-            new String[] {input("X1"), input("X2")},
-            input("X.json"));
-        writeInputFederatedWithMTD("X.json", fed, null);
+               MatrixObject fed = 
FederatedTestObjectConstructor.constructFederatedInput(rows,
+                       cols,
+                       blocksize,
+                       host,
+                       begins,
+                       ends,
+                       new int[] {port1, port2},
+                       new String[] {input("X1"), input("X2")},
+                       input("X.json"));
+               writeInputFederatedWithMTD("X.json", fed, null);
 
-        try {
-            // Run reference dml script with normal matrix
-            fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + 
TEST_NAME + (rowPartitioned ? "Row" : "Col")
-                + "Reference.dml";
-            programArgs = new String[] {"-args", input("X1"), input("X2")};
-            String refOut = runTest(null).toString();
-            // Run federated
-            fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + 
TEST_NAME + ".dml";
-            programArgs = new String[] {"-stats", "-args", input("X.json")};
-            String out = runTest(null).toString();
-            // LOG.error(out);
-            Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
-            // Verify output
-            Assert.assertEquals(Double.parseDouble(refOut.split("\n")[0]),
-                Double.parseDouble(out.split("\n")[0]),
-                0.00001);
-        }
-        catch(Exception e) {
-            e.printStackTrace();
-            Assert.assertTrue(false);
-        }
+               try {
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = SCRIPT_DIR + 
"functions/federated/io/" + TEST_NAME + (rowPartitioned ? "Row" : "Col")
+                               + "Reference.dml";
+                       programArgs = new String[] {"-args", input("X1"), 
input("X2")};
+                       String refOut = runTest(null).toString();
+                       // Run federated
+                       fullDMLScriptName = SCRIPT_DIR + 
"functions/federated/io/" + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-stats", "-args", 
input("X.json")};
+                       String out = runTest(null).toString();
+                       // LOG.error(out);
+                       
Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+                       // Verify output
+                       
Assert.assertEquals(Double.parseDouble(refOut.split("\n")[0]),
+                               Double.parseDouble(out.split("\n")[0]),
+                               0.00001);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       Assert.assertTrue(false);
+               }
 
-        TestUtils.shutdownThreads(t1, t2);
-        rtplatform = platformOld;
-        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
-    }
+               TestUtils.shutdownThreads(t1, t2);
+               rtplatform = platformOld;
+               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
index ef92a67..e03474d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
@@ -36,100 +36,100 @@ import org.junit.runners.Parameterized;
 @net.jcip.annotations.NotThreadSafe
 public class FederatedWriterTest extends AutomatedTestBase {
 
-    // private static final Log LOG = 
LogFactory.getLog(FederatedWriterTest.class.getName());
-    private final static String TEST_DIR = "functions/federated/";
-    private final static String TEST_NAME = "FederatedWriterTest";
-    private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedWriterTest.class.getSimpleName() + "/";
-    private final static int blocksize = 1024;
-
-    @Parameterized.Parameter()
-    public int rows;
-    @Parameterized.Parameter(1)
-    public int cols;
-    @Parameterized.Parameter(2)
-    public boolean rowPartitioned;
-    @Parameterized.Parameter(3)
-    public int fedCount;
-
-    @Override
-    public void setUp() {
-        TestUtils.clearAssertionInformation();
-        addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAME));
-    }
-
-    @Parameterized.Parameters
-    public static Collection<Object[]> data() {
-        // number of rows or cols has to be >= number of federated locations.
-        return Arrays.asList(new Object[][] {{10, 13, true, 2},});
-    }
-
-    @Test
-    public void federatedSinglenodeWrite() {
-        federatedWrite(Types.ExecMode.SINGLE_NODE);
-    }
-
-    public void federatedWrite(Types.ExecMode execMode) {
-        boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-        Types.ExecMode platformOld = rtplatform;
-        rtplatform = execMode;
-        if(rtplatform == Types.ExecMode.SPARK) {
-            DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-        }
-        getAndLoadTestConfiguration(TEST_NAME);
-
-        // write input matrices
-        int halfRows = rows / 2;
-        // We have two matrices handled by a single federated worker
-        double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
-        double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
-        writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
-        writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
-        // empty script name because we don't execute any script, just start 
the worker
-        fullDMLScriptName = "";
-        int port1 = getRandomAvailablePort();
-        int port2 = getRandomAvailablePort();
-        Thread t1 = startLocalFedWorkerThread(port1);
-        Thread t2 = startLocalFedWorkerThread(port2);
-
-        try {
-
-            // Run reader and write a federated json to enable the rest of the 
test
-            fullDMLScriptName = SCRIPT_DIR + 
"functions/federated/io/FederatedReaderTestCreate.dml";
-            programArgs = new String[] {"-stats", "-explain", "-args", 
input("X1"), input("X2"), port1 + "", port2 + "",
-                input("X.json")};
-            // String writer = runTest(null).toString();
-            runTest(null);
-            // LOG.error(writer);
-            // LOG.error("Writing Done");
-
-            // Run reference dml script with normal matrix
-            fullDMLScriptName = SCRIPT_DIR + 
"functions/federated/io/FederatedReaderTest.dml";
-            programArgs = new String[] {"-stats", "-args", input("X.json")};
-            String out = runTest(null).toString();
-
-            Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
-
-            fullDMLScriptName = SCRIPT_DIR + 
"functions/federated/io/FederatedReference.dml";
-            // programArgs = new String[] {"-args", input("X1"), input("X2")};
-            programArgs = new String[] {"-stats", "100", "-nvargs",
-                "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
-                "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), 
"rows=" + rows, "cols=" + cols};
-            String refOut = runTest(null).toString();
-
-            // Run federated
-
-            // Verify output
-            Assert.assertEquals(Double.parseDouble(refOut.split("\n")[0]),
-                Double.parseDouble(out.split("\n")[0]),
-                0.00001);
-        }
-        catch(Exception e) {
-            e.printStackTrace();
-            Assert.assertTrue(false);
-        }
-
-        TestUtils.shutdownThreads(t1, t2);
-        rtplatform = platformOld;
-        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
-    }
+       // private static final Log LOG = 
LogFactory.getLog(FederatedWriterTest.class.getName());
+       private final static String TEST_DIR = "functions/federated/";
+       private final static String TEST_NAME = "FederatedWriterTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedWriterTest.class.getSimpleName() + "/";
+       private final static int blocksize = 1024;
+
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
+       @Parameterized.Parameter(3)
+       public int fedCount;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // number of rows or cols has to be >= number of federated 
locations.
+               return Arrays.asList(new Object[][] {{10, 13, true, 2},});
+       }
+
+       @Test
+       public void federatedSinglenodeWrite() {
+               federatedWrite(Types.ExecMode.SINGLE_NODE);
+       }
+
+       public void federatedWrite(Types.ExecMode execMode) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               Types.ExecMode platformOld = rtplatform;
+               rtplatform = execMode;
+               if(rtplatform == Types.ExecMode.SPARK) {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+               getAndLoadTestConfiguration(TEST_NAME);
+
+               // write input matrices
+               int halfRows = rows / 2;
+               // We have two matrices handled by a single federated worker
+               double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+               double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1);
+               Thread t2 = startLocalFedWorkerThread(port2);
+
+               try {
+
+                       // Run reader and write a federated json to enable the 
rest of the test
+                       fullDMLScriptName = SCRIPT_DIR + 
"functions/federated/io/FederatedReaderTestCreate.dml";
+                       programArgs = new String[] {"-stats", "-explain", 
"-args", input("X1"), input("X2"), port1 + "", port2 + "",
+                               input("X.json")};
+                       // String writer = runTest(null).toString();
+                       runTest(null);
+                       // LOG.error(writer);
+                       // LOG.error("Writing Done");
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = SCRIPT_DIR + 
"functions/federated/io/FederatedReaderTest.dml";
+                       programArgs = new String[] {"-stats", "-args", 
input("X.json")};
+                       String out = runTest(null).toString();
+
+                       
Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+
+                       fullDMLScriptName = SCRIPT_DIR + 
"functions/federated/io/FederatedReference.dml";
+                       // programArgs = new String[] {"-args", input("X1"), 
input("X2")};
+                       programArgs = new String[] {"-stats", "100", "-nvargs",
+                               "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")), "rows=" + rows, "cols=" + cols};
+                       String refOut = runTest(null).toString();
+
+                       // Run federated
+
+                       // Verify output
+                       
Assert.assertEquals(Double.parseDouble(refOut.split("\n")[0]),
+                               Double.parseDouble(out.split("\n")[0]),
+                               0.00001);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       Assert.assertTrue(false);
+               }
+
+               TestUtils.shutdownThreads(t1, t2);
+               rtplatform = platformOld;
+               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index 194df09..9b321e4 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -24,172 +24,183 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 
-import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class FederatedParamservTest extends AutomatedTestBase {
-    private final static String TEST_DIR = "functions/federated/paramserv/";
-    private final static String TEST_NAME = "FederatedParamservTest";
-    private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedParamservTest.class.getSimpleName() + "/";
-    private final static int _blocksize = 1024;
-
-    private final String _networkType;
-    private final int _numFederatedWorkers;
-    private final int _examplesPerWorker;
-    private final int _epochs;
-    private final int _batch_size;
-    private final double _eta;
-    private final String _utype;
-    private final String _freq;
-
-    private Types.ExecMode _platformOld;
-
-    // parameters
-    @Parameterized.Parameters
-    public static Collection<Object[]> parameters() {
-        return Arrays.asList(new Object[][] {
-                //Network type, number of federated workers, examples per 
worker, batch size, epochs, learning rate, update type, update frequency
-                {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"},
-                {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
-                {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"},
-                {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
-                {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"},
-                {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
-                {"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"},
-                {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
-                {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"},
-                {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"},
-                {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"},
-                {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"},
-                {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"},
-                {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"},
-                {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"},
-                {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"}
-        });
-    }
-
-    public FederatedParamservTest(String networkType, int numFederatedWorkers, 
int examplesPerWorker, int batch_size, int epochs, double eta, String utype, 
String freq) {
-        _networkType = networkType;
-        _numFederatedWorkers = numFederatedWorkers;
-        _examplesPerWorker = examplesPerWorker;
-        _batch_size = batch_size;
-        _epochs = epochs;
-        _eta = eta;
-        _utype = utype;
-        _freq = freq;
-    }
-
-    @Override
-    public void setUp() {
-        TestUtils.clearAssertionInformation();
-        addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAME));
-
-        _platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
-    }
-
-    @Override
-    public void tearDown() {
-
-        rtplatform = _platformOld;
-    }
-
-    @Test
-    public void federatedParamserv() {
-        // config
-        getAndLoadTestConfiguration(TEST_NAME);
-        String HOME = SCRIPT_DIR + TEST_DIR;
-        setOutputBuffering(true);
-
-        int C = 1, Hin = 28, Win = 28;
-        int numFeatures = C*Hin*Win;
-        int numLabels = 10;
-
-        // dml name
-        fullDMLScriptName = HOME + TEST_NAME + ".dml";
-        // generate program args
-        List<String> programArgsList = new ArrayList<>(Arrays.asList(
-                "-stats",
-                "-nvargs",
-                "examples_per_worker=" + _examplesPerWorker,
-                "num_features=" + numFeatures,
-                "num_labels=" + numLabels,
-                "epochs=" + _epochs,
-                "batch_size=" + _batch_size,
-                "eta=" + _eta,
-                "utype=" + _utype,
-                "freq=" + _freq,
-                "network_type=" + _networkType,
-                "channels=" + C,
-                "hin=" + Hin,
-                "win=" + Win
-        ));
-
-        // for each worker
-        List<Integer> ports = new ArrayList<>();
-        List<Thread> threads = new ArrayList<>();
-        for(int i = 0; i < _numFederatedWorkers; i++) {
-            // write row partitioned features to disk
-            writeInputMatrixWithMTD("X" + i, 
generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win), false,
-                    new MatrixCharacteristics(_examplesPerWorker, numFeatures, 
_blocksize, _examplesPerWorker * numFeatures));
-            // write row partitioned labels to disk
-            writeInputMatrixWithMTD("y" + i, 
generateDummyMNISTLabels(_examplesPerWorker, numLabels), false,
-                    new MatrixCharacteristics(_examplesPerWorker, numLabels, 
_blocksize, _examplesPerWorker * numLabels));
-
-            // start worker
-            ports.add(getRandomAvailablePort());
-            threads.add(startLocalFedWorkerThread(ports.get(i)));
-
-            // add worker to program args
-            programArgsList.add("X" + i + "=" + 
TestUtils.federatedAddress(ports.get(i), input("X" + i)));
-            programArgsList.add("y" + i + "=" + 
TestUtils.federatedAddress(ports.get(i), input("y" + i)));
-        }
-
-        programArgs = programArgsList.toArray(new String[0]);
-        // ByteArrayOutputStream stdout =
-        runTest(null);
-        // System.out.print(stdout.toString());
-
-        // cleanup
-        for(int i = 0; i < _numFederatedWorkers; i++) {
-            TestUtils.shutdownThreads(threads.get(i));
-        }
-    }
-
-    /**
-     * Generates an feature matrix that has the same format as the MNIST 
dataset,
-     * but is completely random and normalized
-     *
-     *  @param numExamples Number of examples to generate
-     *  @param C Channels in the input data
-     *  @param Hin Height in Pixels of the input data
-     *  @param Win Width in Pixels of the input data
-     *  @return a dummy MNIST feature matrix
-     */
-    private double[][] generateDummyMNISTFeatures(int numExamples, int C, int 
Hin, int Win) {
-        // Seed -1 takes the time in milliseconds as a seed
-        // Sparsity 1 means no sparsity
-        return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1);
-    }
-
-    /**
-     * Generates an label matrix that has the same format as the MNIST 
dataset, but is completely random and consists
-     * of one hot encoded vectors as rows
-     *
-     *  @param numExamples Number of examples to generate
-     *  @param numLabels Number of labels to generate
-     *  @return a dummy MNIST lable matrix
-     */
-    private double[][] generateDummyMNISTLabels(int numExamples, int 
numLabels) {
-        // Seed -1 takes the time in milliseconds as a seed
-        // Sparsity 1 means no sparsity
-        return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1);
-    }
+       private final static String TEST_DIR = "functions/federated/paramserv/";
+       private final static String TEST_NAME = "FederatedParamservTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedParamservTest.class.getSimpleName() + "/";
+       private final static int _blocksize = 1024;
+
+       private final String _networkType;
+       private final int _numFederatedWorkers;
+       private final int _examplesPerWorker;
+       private final int _epochs;
+       private final int _batch_size;
+       private final double _eta;
+       private final String _utype;
+       private final String _freq;
+
+       // parameters
+       @Parameterized.Parameters
+       public static Collection<Object[]> parameters() {
+               return Arrays.asList(new Object[][] {
+                       //Network type, number of federated workers, examples 
per worker, batch size, epochs, learning rate, update type, update frequency
+                       {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"},
+                       {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
+                       {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"},
+                       {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
+                       {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"},
+                       {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
+                       {"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"},
+                       {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
+                       {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"},
+                       {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"},
+                       {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"},
+                       {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"},
+                       {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"},
+                       {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"},
+                       {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"},
+                       {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"}
+               });
+       }
+
+       public FederatedParamservTest(String networkType, int 
numFederatedWorkers, int examplesPerWorker, int batch_size, int epochs, double 
eta, String utype, String freq) {
+               _networkType = networkType;
+               _numFederatedWorkers = numFederatedWorkers;
+               _examplesPerWorker = examplesPerWorker;
+               _batch_size = batch_size;
+               _epochs = epochs;
+               _eta = eta;
+               _utype = utype;
+               _freq = freq;
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+       }
+
+       @Test
+       public void federatedParamservSingleNode() {
+               federatedParamserv(ExecMode.SINGLE_NODE);
+       }
+       
+       @Test
+       public void federatedParamservHybrid() {
+               federatedParamserv(ExecMode.HYBRID);
+       }
+       
+       private void federatedParamserv(ExecMode mode) {
+               // config
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+               setOutputBuffering(true);
+
+               int C = 1, Hin = 28, Win = 28;
+               int numFeatures = C*Hin*Win;
+               int numLabels = 10;
+
+               ExecMode platformOld = setExecMode(mode);
+               
+               try {
+               
+                       // dml name
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       // generate program args
+                       List<String> programArgsList = new 
ArrayList<>(Arrays.asList(
+                               "-stats",
+                               "-nvargs",
+                               "examples_per_worker=" + _examplesPerWorker,
+                               "num_features=" + numFeatures,
+                               "num_labels=" + numLabels,
+                               "epochs=" + _epochs,
+                               "batch_size=" + _batch_size,
+                               "eta=" + _eta,
+                               "utype=" + _utype,
+                               "freq=" + _freq,
+                               "network_type=" + _networkType,
+                               "channels=" + C,
+                               "hin=" + Hin,
+                               "win=" + Win
+                       ));
+       
+                       // for each worker
+                       List<Integer> ports = new ArrayList<>();
+                       List<Thread> threads = new ArrayList<>();
+                       for(int i = 0; i < _numFederatedWorkers; i++) {
+                               // write row partitioned features to disk
+                               writeInputMatrixWithMTD("X" + i, 
generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win), false,
+                                               new 
MatrixCharacteristics(_examplesPerWorker, numFeatures, _blocksize, 
_examplesPerWorker * numFeatures));
+                               // write row partitioned labels to disk
+                               writeInputMatrixWithMTD("y" + i, 
generateDummyMNISTLabels(_examplesPerWorker, numLabels), false,
+                                               new 
MatrixCharacteristics(_examplesPerWorker, numLabels, _blocksize, 
_examplesPerWorker * numLabels));
+       
+                               // start worker
+                               ports.add(getRandomAvailablePort());
+                               
threads.add(startLocalFedWorkerThread(ports.get(i)));
+       
+                               // add worker to program args
+                               programArgsList.add("X" + i + "=" + 
TestUtils.federatedAddress(ports.get(i), input("X" + i)));
+                               programArgsList.add("y" + i + "=" + 
TestUtils.federatedAddress(ports.get(i), input("y" + i)));
+                       }
+       
+                       programArgs = programArgsList.toArray(new String[0]);
+                       // ByteArrayOutputStream stdout =
+                       runTest(null);
+                       // System.out.print(stdout.toString());
+                       Assert.assertEquals(0, 
Statistics.getNoOfExecutedSPInst());
+                       
+                       // cleanup
+                       for(int i = 0; i < _numFederatedWorkers; i++) {
+                               TestUtils.shutdownThreads(threads.get(i));
+                       }
+               }
+               finally {
+                       resetExecMode(platformOld);
+               }
+       }
+
+       /**
+        * Generates a feature matrix that has the same format as the MNIST 
dataset,
+        * but is completely random and normalized
+        *
+        *  @param numExamples Number of examples to generate
+        *  @param C Channels in the input data
+        *  @param Hin Height in Pixels of the input data
+        *  @param Win Width in Pixels of the input data
+        *  @return a dummy MNIST feature matrix
+        */
+       private double[][] generateDummyMNISTFeatures(int numExamples, int C, 
int Hin, int Win) {
+               // Seed -1 takes the time in milliseconds as a seed
+               // Sparsity 1 means no sparsity
+               return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1);
+       }
+
+       /**
+        * Generates a label matrix that has the same format as the MNIST 
dataset, but is completely random and consists
+        * of one hot encoded vectors as rows
+        *
+        *  @param numExamples Number of examples to generate
+        *  @param numLabels Number of labels to generate
+        *  @return a dummy MNIST label matrix
+        */
+       private double[][] generateDummyMNISTLabels(int numExamples, int 
numLabels) {
+               // Seed -1 takes the time in milliseconds as a seed
+               // Sparsity 1 means no sparsity
+               return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1);
+       }
 }

Reply via email to