This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new bcea57e  [MINOR] Federated Rewriter Function Fix
bcea57e is described below

commit bcea57e767f90bf77cee90edf767b9ad73b86b28
Author: sebwrede <[email protected]>
AuthorDate: Fri Oct 15 12:14:57 2021 +0200

    [MINOR] Federated Rewriter Function Fix
    
    This commit changes the federated rewriter so that it also rewrites
    FunctionStatementBlocks. Before this commit, FunctionStatementBlocks were
    not rewritten, since they appear in StatementBlocks only as FunctionOps.
---
 .../hops/rewrite/IPAPassRewriteFederatedPlan.java  |  54 ++--
 .../hops/rewrite/RewriteFederatedExecution.java    |  11 +-
 .../privacy/algorithms/FederatedL2SVMTest.java     | 299 +++++++++++----------
 3 files changed, 195 insertions(+), 169 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java 
b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
index 377ebb1..3d8fd2f 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
@@ -24,6 +24,7 @@ import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.hops.AggUnaryOp;
 import org.apache.sysds.hops.BinaryOp;
 import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.FunctionOp;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.ReorgOp;
@@ -84,7 +85,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
         */
        @Override
        public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes) {
-               rewriteStatementBlocks(prog.getStatementBlocks());
+               rewriteStatementBlocks(prog, prog.getStatementBlocks());
                return false;
        }
 
@@ -93,13 +94,14 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
         * by setting the federated output value of each hop in the statement 
blocks.
         * The method calls the contained statement blocks recursively.
         *
+        * @param prog dml program
         * @param sbs   list of statement blocks
         * @return list of statement blocks with the federated output value 
updated for each hop
         */
-       public ArrayList<StatementBlock> 
rewriteStatementBlocks(List<StatementBlock> sbs) {
+       public ArrayList<StatementBlock> rewriteStatementBlocks(DMLProgram 
prog, List<StatementBlock> sbs) {
                ArrayList<StatementBlock> rewrittenStmBlocks = new 
ArrayList<>();
                for ( StatementBlock stmBlock : sbs )
-                       
rewrittenStmBlocks.addAll(rewriteStatementBlock(stmBlock));
+                       rewrittenStmBlocks.addAll(rewriteStatementBlock(prog, 
stmBlock));
                return rewrittenStmBlocks;
        }
 
@@ -108,66 +110,80 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
         * by setting the federated output value of each hop in the statement 
blocks.
         * The method calls the contained statement blocks recursively.
         *
+        * @param prog dml program
         * @param sb    statement block
         * @return list of statement blocks with the federated output value 
updated for each hop
         */
-       public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock 
sb) {
+       public ArrayList<StatementBlock> rewriteStatementBlock(DMLProgram prog, 
StatementBlock sb) {
                if ( sb instanceof WhileStatementBlock)
-                       return rewriteWhileStatementBlock((WhileStatementBlock) 
sb);
+                       return rewriteWhileStatementBlock(prog, 
(WhileStatementBlock) sb);
                else if ( sb instanceof IfStatementBlock)
-                       return rewriteIfStatementBlock((IfStatementBlock) sb);
+                       return rewriteIfStatementBlock(prog, (IfStatementBlock) 
sb);
                else if ( sb instanceof ForStatementBlock){
                        // This also includes ParForStatementBlocks
-                       return rewriteForStatementBlock((ForStatementBlock) sb);
+                       return rewriteForStatementBlock(prog, 
(ForStatementBlock) sb);
                }
                else if ( sb instanceof FunctionStatementBlock)
-                       return 
rewriteFunctionStatementBlock((FunctionStatementBlock) sb);
+                       return rewriteFunctionStatementBlock(prog, 
(FunctionStatementBlock) sb);
                else {
                        // StatementBlock type (no subclass)
-                       selectFederatedExecutionPlan(sb.getHops());
+                       return rewriteDefaultStatementBlock(prog, sb);
                }
-               return new ArrayList<>(Collections.singletonList(sb));
        }
 
-       private ArrayList<StatementBlock> 
rewriteWhileStatementBlock(WhileStatementBlock whileSB){
+       private ArrayList<StatementBlock> rewriteWhileStatementBlock(DMLProgram 
prog, WhileStatementBlock whileSB){
                Hop whilePredicateHop = whileSB.getPredicateHops();
                selectFederatedExecutionPlan(whilePredicateHop);
                for ( Statement stm : whileSB.getStatements() ){
                        WhileStatement whileStm = (WhileStatement) stm;
-                       
whileStm.setBody(rewriteStatementBlocks(whileStm.getBody()));
+                       whileStm.setBody(rewriteStatementBlocks(prog, 
whileStm.getBody()));
                }
                return new ArrayList<>(Collections.singletonList(whileSB));
        }
 
-       private ArrayList<StatementBlock> 
rewriteIfStatementBlock(IfStatementBlock ifSB){
+       private ArrayList<StatementBlock> rewriteIfStatementBlock(DMLProgram 
prog, IfStatementBlock ifSB){
                selectFederatedExecutionPlan(ifSB.getPredicateHops());
                for ( Statement statement : ifSB.getStatements() ){
                        IfStatement ifStatement = (IfStatement) statement;
-                       
ifStatement.setIfBody(rewriteStatementBlocks(ifStatement.getIfBody()));
-                       
ifStatement.setElseBody(rewriteStatementBlocks(ifStatement.getElseBody()));
+                       ifStatement.setIfBody(rewriteStatementBlocks(prog, 
ifStatement.getIfBody()));
+                       ifStatement.setElseBody(rewriteStatementBlocks(prog, 
ifStatement.getElseBody()));
                }
                return new ArrayList<>(Collections.singletonList(ifSB));
        }
 
-       private ArrayList<StatementBlock> 
rewriteForStatementBlock(ForStatementBlock forSB){
+       private ArrayList<StatementBlock> rewriteForStatementBlock(DMLProgram 
prog, ForStatementBlock forSB){
                selectFederatedExecutionPlan(forSB.getFromHops());
                selectFederatedExecutionPlan(forSB.getToHops());
                selectFederatedExecutionPlan(forSB.getIncrementHops());
                for ( Statement statement : forSB.getStatements() ){
                        ForStatement forStatement = ((ForStatement)statement);
-                       
forStatement.setBody(rewriteStatementBlocks(forStatement.getBody()));
+                       forStatement.setBody(rewriteStatementBlocks(prog, 
forStatement.getBody()));
                }
                return new ArrayList<>(Collections.singletonList(forSB));
        }
 
-       private ArrayList<StatementBlock> 
rewriteFunctionStatementBlock(FunctionStatementBlock funcSB){
+       private ArrayList<StatementBlock> 
rewriteFunctionStatementBlock(DMLProgram prog, FunctionStatementBlock funcSB){
                for ( Statement statement : funcSB.getStatements() ){
                        FunctionStatement funcStm = (FunctionStatement) 
statement;
-                       
funcStm.setBody(rewriteStatementBlocks(funcStm.getBody()));
+                       funcStm.setBody(rewriteStatementBlocks(prog, 
funcStm.getBody()));
                }
                return new ArrayList<>(Collections.singletonList(funcSB));
        }
 
+       private ArrayList<StatementBlock> 
rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb){
+               if ( sb.getHops() != null && !sb.getHops().isEmpty() ){
+                       for ( Hop sbHop : sb.getHops() ){
+                               if ( sbHop instanceof FunctionOp ){
+                                       String funcName = ((FunctionOp) 
sbHop).getFunctionName();
+                                       FunctionStatementBlock sbFuncBlock = 
prog.getBuiltinFunctionDictionary().getFunction(funcName);
+                                       rewriteStatementBlock(prog, 
sbFuncBlock);
+                               }
+                               else selectFederatedExecutionPlan(sbHop);
+                       }
+               }
+               return new ArrayList<>(Collections.singletonList(sb));
+       }
+
        /**
         * Sets FederatedOutput field of all hops in DAG starting from given 
root.
         * The FederatedOutput chosen for root is the minimum cost HopRel found 
in memo table for the given root.
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
index 75fa735..4de2557 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
@@ -56,15 +56,16 @@ public class RewriteFederatedExecution extends 
HopRewriteRule {
 
        @Override
        public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state) {
-               if ( roots == null )
-                       return null;
-               for ( Hop root : roots )
-                       visitHop(root);
+               if ( roots != null )
+                       for ( Hop root : roots )
+                               rewriteHopDAG(root, state);
                return roots;
        }
 
        @Override public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus 
state) {
-               return null;
+               if ( root != null )
+                       visitHop(root);
+               return root;
        }
 
        private void visitHop(Hop hop){
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
index e543d9e..cbadd98 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
@@ -19,9 +19,10 @@
 
 package org.apache.sysds.test.functions.privacy.algorithms;
 
+import edu.emory.mathcs.backport.java.util.Arrays;
+import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.junit.Assert;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
@@ -31,12 +32,15 @@ import 
org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
-import org.apache.wink.json4j.JSONException;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
 
 @net.jcip.annotations.NotThreadSafe
+@RunWith(value = Parameterized.class)
 public class FederatedL2SVMTest extends AutomatedTestBase {
 
        private final static String TEST_DIR = "functions/federated/";
@@ -47,57 +51,62 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
        private int rows = 100;
        private int cols = 10;
 
-       @Override
-       public void setUp() {
+       @Parameterized.Parameter()
+       public boolean fedOutCompilation;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][]{
+                       {false},
+                       {true}
+               });
+       }
+
+       @Override public void setUp() {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
        }
 
        // PrivateAggregation Single Input
 
-       @Test
-       @Ignore
-       public void federatedL2SVMCPPrivateAggregationX1() throws JSONException 
{
+       @Test public void federatedL2SVMCPPrivateAggregationX1()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
-               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null,
+                       PrivacyLevel.PrivateAggregation);
        }
 
-       @Test
-       @Ignore
-       public void federatedL2SVMCPPrivateAggregationX2() throws JSONException 
{
+       @Test public void federatedL2SVMCPPrivateAggregationX2()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
-               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null,
+                       PrivacyLevel.PrivateAggregation);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateAggregationY() throws JSONException {
+       @Test public void federatedL2SVMCPPrivateAggregationY()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
-               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null,
+                       PrivacyLevel.PrivateAggregation);
        }
 
        // Private Single Input
 
-       @Test
-       public void federatedL2SVMCPPrivateFederatedX1() throws JSONException {
+       @Test public void federatedL2SVMCPPrivateFederatedX1()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private, false, null, true,
+                       DMLRuntimeException.class);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateFederatedX2() throws JSONException {
+       @Test public void federatedL2SVMCPPrivateFederatedX2()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private, false, null, true,
+                       DMLRuntimeException.class);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateFederatedY() throws JSONException {
+       @Test public void federatedL2SVMCPPrivateFederatedY()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.Private));
                federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null, PrivacyLevel.Private);
@@ -105,230 +114,212 @@ public class FederatedL2SVMTest extends 
AutomatedTestBase {
 
        // Setting Privacy of Matrix (Throws Exception)
 
-       @Test
-       public void federatedL2SVMCPPrivateMatrixX1() throws JSONException {
+       @Test public void federatedL2SVMCPPrivateMatrixX1()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, null, 
privacyConstraints, PrivacyLevel.Private,
-                       false, null, false, null);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, null, 
privacyConstraints, PrivacyLevel.Private, false, null, false,
+                       null);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateMatrixX2() throws JSONException {
+       @Test public void federatedL2SVMCPPrivateMatrixX2()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, null, 
privacyConstraints, PrivacyLevel.Private,
-                       false, null, false, null);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, null, 
privacyConstraints, PrivacyLevel.Private, false, null, false,
+                       null);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateMatrixY() throws JSONException {
+       @Test public void federatedL2SVMCPPrivateMatrixY()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, null, 
privacyConstraints, PrivacyLevel.Private,
-                       false, null, false, null);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, null, 
privacyConstraints, PrivacyLevel.Private, false, null, false,
+                       null);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateFederatedAndMatrixX1() throws 
JSONException {
+       @Test public void federatedL2SVMCPPrivateFederatedAndMatrixX1()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
privacyConstraints, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
privacyConstraints, PrivacyLevel.Private, false,
+                       null, true, DMLRuntimeException.class);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateFederatedAndMatrixX2() throws 
JSONException {
+       @Test public void federatedL2SVMCPPrivateFederatedAndMatrixX2()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
privacyConstraints, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
privacyConstraints, PrivacyLevel.Private, false,
+                       null, true, DMLRuntimeException.class);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateFederatedAndMatrixY() throws 
JSONException {
+       @Test public void federatedL2SVMCPPrivateFederatedAndMatrixY()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
privacyConstraints, PrivacyLevel.Private,
-                       false, null, false, null);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
privacyConstraints, PrivacyLevel.Private, false,
+                       null, false, null);
        }
 
        // Privacy Level Private Combinations
 
-       @Test
-       public void federatedL2SVMCPPrivateFederatedX1X2() throws JSONException 
{
+       @Test public void federatedL2SVMCPPrivateFederatedX1X2()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.Private));
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private, false, null, true,
+                       DMLRuntimeException.class);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateFederatedX1Y() throws JSONException {
+       @Test public void federatedL2SVMCPPrivateFederatedX1Y()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.Private));
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private, false, null, true,
+                       DMLRuntimeException.class);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateFederatedX2Y() throws JSONException {
+       @Test public void federatedL2SVMCPPrivateFederatedX2Y()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.Private));
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private, false, null, true,
+                       DMLRuntimeException.class);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateFederatedX1X2Y() throws 
JSONException {
+       @Test public void federatedL2SVMCPPrivateFederatedX1X2Y()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.Private));
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.Private));
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private, false, null, true,
+                       DMLRuntimeException.class);
        }
 
        // Privacy Level PrivateAggregation Combinations
-       @Test
-       public void federatedL2SVMCPPrivateAggregationFederatedX1X2() throws 
JSONException {
+       @Test public void federatedL2SVMCPPrivateAggregationFederatedX1X2()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
-               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null,
+                       PrivacyLevel.PrivateAggregation);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateAggregationFederatedX1Y() throws 
JSONException {
+       @Test public void federatedL2SVMCPPrivateAggregationFederatedX1Y()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
-               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null,
+                       PrivacyLevel.PrivateAggregation);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateAggregationFederatedX2Y() throws 
JSONException {
+       @Test public void federatedL2SVMCPPrivateAggregationFederatedX2Y()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
-               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null,
+                       PrivacyLevel.PrivateAggregation);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateAggregationFederatedX1X2Y() throws 
JSONException {
+       @Test public void federatedL2SVMCPPrivateAggregationFederatedX1X2Y()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
-               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null,
+                       PrivacyLevel.PrivateAggregation);
        }
 
        // Privacy Level Combinations
-       @Test
-       public void federatedL2SVMCPPrivatePrivateAggregationFederatedX1X2() 
throws JSONException {
+       @Test public void 
federatedL2SVMCPPrivatePrivateAggregationFederatedX1X2()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.Private));
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private, false, null, true,
+                       DMLRuntimeException.class);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivatePrivateAggregationFederatedX1Y() 
throws JSONException {
+       @Test public void 
federatedL2SVMCPPrivatePrivateAggregationFederatedX1Y()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.Private));
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private, false, null, true,
+                       DMLRuntimeException.class);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivatePrivateAggregationFederatedX2Y() 
throws JSONException {
+       @Test public void 
federatedL2SVMCPPrivatePrivateAggregationFederatedX2Y()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.Private));
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private, false, null, true,
+                       DMLRuntimeException.class);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivatePrivateAggregationFederatedYX1() 
throws JSONException {
+       @Test public void 
federatedL2SVMCPPrivatePrivateAggregationFederatedYX1()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.Private));
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
                federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null, PrivacyLevel.Private);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivatePrivateAggregationFederatedYX2() 
throws JSONException {
+       @Test public void 
federatedL2SVMCPPrivatePrivateAggregationFederatedYX2()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("Y", new 
PrivacyConstraint(PrivacyLevel.Private));
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
                federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null, PrivacyLevel.Private);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivatePrivateAggregationFederatedX2X1() 
throws JSONException {
+       @Test public void 
federatedL2SVMCPPrivatePrivateAggregationFederatedX2X1()  {
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.Private));
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private, false, null, true,
+                       DMLRuntimeException.class);
        }
 
        // Require Federated Workers to return matrix
 
-       @Test
-       public void federatedL2SVMCPPrivateAggregationX1Exception() throws 
JSONException {
-               rows = 1000; cols = 1;
+       @Test public void federatedL2SVMCPPrivateAggregationX1Exception()  {
+               rows = 1000;
+               cols = 1;
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
-               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null,
+                       PrivacyLevel.PrivateAggregation);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateAggregationX2Exception() throws 
JSONException {
-               rows = 1000; cols = 1;
+       @Test public void federatedL2SVMCPPrivateAggregationX2Exception()  {
+               rows = 1000;
+               cols = 1;
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
-               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+               federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE, 
privacyConstraints, null,
+                       PrivacyLevel.PrivateAggregation);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateX1Exception() throws JSONException {
-               rows = 1000; cols = 1;
+       @Test public void federatedL2SVMCPPrivateX1Exception()  {
+               rows = 1000;
+               cols = 1;
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X1", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private, false, null, true,
+                       DMLRuntimeException.class);
        }
 
-       @Test
-       public void federatedL2SVMCPPrivateX2Exception() throws JSONException {
-               rows = 1000; cols = 1;
+       @Test public void federatedL2SVMCPPrivateX2Exception()  {
+               rows = 1000;
+               cols = 1;
                Map<String, PrivacyConstraint> privacyConstraints = new 
HashMap<>();
                privacyConstraints.put("X2", new 
PrivacyConstraint(PrivacyLevel.Private));
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private,
-                       false, null, true, DMLRuntimeException.class);
+               federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints, 
null, PrivacyLevel.Private, false, null, true,
+                       DMLRuntimeException.class);
        }
 
-       private void federatedL2SVMNoException(Types.ExecMode execMode, 
Map<String,
-                       PrivacyConstraint> privacyConstraintsFederated, 
Map<String, PrivacyConstraint> privacyConstraintsMatrix,
-                       PrivacyLevel expectedPrivacyLevel)
-               throws JSONException
-       {
-               federatedL2SVM(execMode, privacyConstraintsFederated, 
privacyConstraintsMatrix, expectedPrivacyLevel,
-                       false, null, false, null);
+       private void federatedL2SVMNoException(Types.ExecMode execMode,
+               Map<String, PrivacyConstraint> privacyConstraintsFederated,
+               Map<String, PrivacyConstraint> privacyConstraintsMatrix, 
PrivacyLevel expectedPrivacyLevel) {
+               federatedL2SVM(execMode, privacyConstraintsFederated, 
privacyConstraintsMatrix, expectedPrivacyLevel, false,
+                       null, false, null);
        }
 
        private void federatedL2SVM(Types.ExecMode execMode, Map<String, 
PrivacyConstraint> privacyConstraintsFederated,
-                       Map<String, PrivacyConstraint> 
privacyConstraintsMatrix, PrivacyLevel expectedPrivacyLevel, 
-                       boolean exception1, Class<?> expectedException1, 
boolean exception2, Class<?> expectedException2 ) 
-               throws JSONException
-       {
+               Map<String, PrivacyConstraint> privacyConstraintsMatrix, 
PrivacyLevel expectedPrivacyLevel, boolean exception1,
+               Class<?> expectedException1, boolean exception2, Class<?> 
expectedException2) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                Types.ExecMode platformOld = rtplatform;
                rtplatform = execMode;
@@ -351,24 +342,40 @@ public class FederatedL2SVMTest extends AutomatedTestBase 
{
                                Y[i][0] = (Y[i][0] > 0) ? 1 : -1;
 
                        // Write privacy constraints of normal matrix
-                       if ( privacyConstraintsMatrix != null ){
-                               writeInputMatrixWithMTD("MX1", X1, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), 
privacyConstraintsMatrix.get("X1"));
-                               writeInputMatrixWithMTD("MX2", X2, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), 
privacyConstraintsMatrix.get("X2"));
-                               writeInputMatrixWithMTD("MY", Y, false, new 
MatrixCharacteristics(rows, 1, blocksize, rows), 
privacyConstraintsMatrix.get("Y"));
-                       } else {
-                               writeInputMatrixWithMTD("MX1", X1, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
-                               writeInputMatrixWithMTD("MX2", X2, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+                       if(privacyConstraintsMatrix != null) {
+                               writeInputMatrixWithMTD("MX1", X1, false,
+                                       new MatrixCharacteristics(halfRows, 
cols, blocksize, halfRows * cols),
+                                       privacyConstraintsMatrix.get("X1"));
+                               writeInputMatrixWithMTD("MX2", X2, false,
+                                       new MatrixCharacteristics(halfRows, 
cols, blocksize, halfRows * cols),
+                                       privacyConstraintsMatrix.get("X2"));
+                               writeInputMatrixWithMTD("MY", Y, false, new 
MatrixCharacteristics(rows, 1, blocksize, rows),
+                                       privacyConstraintsMatrix.get("Y"));
+                       }
+                       else {
+                               writeInputMatrixWithMTD("MX1", X1, false,
+                                       new MatrixCharacteristics(halfRows, 
cols, blocksize, halfRows * cols));
+                               writeInputMatrixWithMTD("MX2", X2, false,
+                                       new MatrixCharacteristics(halfRows, 
cols, blocksize, halfRows * cols));
                                writeInputMatrixWithMTD("MY", Y, false, new 
MatrixCharacteristics(rows, 1, blocksize, rows));
                        }
 
                        // Write privacy constraints of federated matrix
-                       if ( privacyConstraintsFederated != null ){
-                               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), 
privacyConstraintsFederated.get("X1"));
-                               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols), 
privacyConstraintsFederated.get("X2"));
-                               writeInputMatrixWithMTD("Y", Y, false, new 
MatrixCharacteristics(rows, 1, blocksize, rows), 
privacyConstraintsFederated.get("Y"));
-                       } else {
-                               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
-                               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+                       if(privacyConstraintsFederated != null) {
+                               writeInputMatrixWithMTD("X1", X1, false,
+                                       new MatrixCharacteristics(halfRows, 
cols, blocksize, halfRows * cols),
+                                       privacyConstraintsFederated.get("X1"));
+                               writeInputMatrixWithMTD("X2", X2, false,
+                                       new MatrixCharacteristics(halfRows, 
cols, blocksize, halfRows * cols),
+                                       privacyConstraintsFederated.get("X2"));
+                               writeInputMatrixWithMTD("Y", Y, false, new 
MatrixCharacteristics(rows, 1, blocksize, rows),
+                                       privacyConstraintsFederated.get("Y"));
+                       }
+                       else {
+                               writeInputMatrixWithMTD("X1", X1, false,
+                                       new MatrixCharacteristics(halfRows, 
cols, blocksize, halfRows * cols));
+                               writeInputMatrixWithMTD("X2", X2, false,
+                                       new MatrixCharacteristics(halfRows, 
cols, blocksize, halfRows * cols));
                                writeInputMatrixWithMTD("Y", Y, false, new 
MatrixCharacteristics(rows, 1, blocksize, rows));
                        }
 
@@ -388,25 +395,27 @@ public class FederatedL2SVMTest extends AutomatedTestBase 
{
                        runTest(true, exception1, expectedException1, -1);
 
                        // Run actual dml script with federated matrix
+                       OptimizerUtils.FEDERATED_COMPILATION = 
fedOutCompilation;
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[] {"-checkPrivacy", 
-                               "-nvargs", "in_X1=" + 
TestUtils.federatedAddress(port1, input("X1")),
+                       programArgs = new String[] {"-stats", "-checkPrivacy", 
"-nvargs",
+                               "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
                                "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")), "rows=" + rows, "cols=" + cols,
                                "in_Y=" + input("Y"), "single=FALSE", "out=" + 
output("Z")};
-                       
+
                        runTest(true, exception2, expectedException2, -1);
 
-                       if ( !(exception1 || exception2) ) {
+                       if(!(exception1 || exception2)) {
                                compareResults(1e-9);
                        }
 
-                       if ( expectedPrivacyLevel != null)
+                       if(expectedPrivacyLevel != null)
                                
Assert.assertTrue(checkedPrivacyConstraintsContains(expectedPrivacyLevel));
                }
                finally {
                        TestUtils.shutdownThreads(t1, t2);
                        rtplatform = platformOld;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       OptimizerUtils.FEDERATED_COMPILATION = false;
                }
        }
 }

Reply via email to