This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new bcea57e [MINOR] Federated Rewriter Function Fix
bcea57e is described below
commit bcea57e767f90bf77cee90edf767b9ad73b86b28
Author: sebwrede <[email protected]>
AuthorDate: Fri Oct 15 12:14:57 2021 +0200
[MINOR] Federated Rewriter Function Fix
This commit extends the federated rewriter so that it also rewrites FunctionStatementBlocks.
Before this commit, FunctionStatementBlocks were not rewritten, since they are
only referenced from StatementBlocks through FunctionOps.
---
.../hops/rewrite/IPAPassRewriteFederatedPlan.java | 54 ++--
.../hops/rewrite/RewriteFederatedExecution.java | 11 +-
.../privacy/algorithms/FederatedL2SVMTest.java | 299 +++++++++++----------
3 files changed, 195 insertions(+), 169 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
index 377ebb1..3d8fd2f 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
@@ -24,6 +24,7 @@ import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.ReorgOp;
@@ -84,7 +85,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
*/
@Override
public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph
fgraph, FunctionCallSizeInfo fcallSizes) {
- rewriteStatementBlocks(prog.getStatementBlocks());
+ rewriteStatementBlocks(prog, prog.getStatementBlocks());
return false;
}
@@ -93,13 +94,14 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
* by setting the federated output value of each hop in the statement
blocks.
* The method calls the contained statement blocks recursively.
*
+ * @param prog dml program
* @param sbs list of statement blocks
* @return list of statement blocks with the federated output value
updated for each hop
*/
- public ArrayList<StatementBlock>
rewriteStatementBlocks(List<StatementBlock> sbs) {
+ public ArrayList<StatementBlock> rewriteStatementBlocks(DMLProgram
prog, List<StatementBlock> sbs) {
ArrayList<StatementBlock> rewrittenStmBlocks = new
ArrayList<>();
for ( StatementBlock stmBlock : sbs )
-
rewrittenStmBlocks.addAll(rewriteStatementBlock(stmBlock));
+ rewrittenStmBlocks.addAll(rewriteStatementBlock(prog,
stmBlock));
return rewrittenStmBlocks;
}
@@ -108,66 +110,80 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
* by setting the federated output value of each hop in the statement
blocks.
* The method calls the contained statement blocks recursively.
*
+ * @param prog dml program
* @param sb statement block
* @return list of statement blocks with the federated output value
updated for each hop
*/
- public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock
sb) {
+ public ArrayList<StatementBlock> rewriteStatementBlock(DMLProgram prog,
StatementBlock sb) {
if ( sb instanceof WhileStatementBlock)
- return rewriteWhileStatementBlock((WhileStatementBlock)
sb);
+ return rewriteWhileStatementBlock(prog,
(WhileStatementBlock) sb);
else if ( sb instanceof IfStatementBlock)
- return rewriteIfStatementBlock((IfStatementBlock) sb);
+ return rewriteIfStatementBlock(prog, (IfStatementBlock)
sb);
else if ( sb instanceof ForStatementBlock){
// This also includes ParForStatementBlocks
- return rewriteForStatementBlock((ForStatementBlock) sb);
+ return rewriteForStatementBlock(prog,
(ForStatementBlock) sb);
}
else if ( sb instanceof FunctionStatementBlock)
- return
rewriteFunctionStatementBlock((FunctionStatementBlock) sb);
+ return rewriteFunctionStatementBlock(prog,
(FunctionStatementBlock) sb);
else {
// StatementBlock type (no subclass)
- selectFederatedExecutionPlan(sb.getHops());
+ return rewriteDefaultStatementBlock(prog, sb);
}
- return new ArrayList<>(Collections.singletonList(sb));
}
- private ArrayList<StatementBlock>
rewriteWhileStatementBlock(WhileStatementBlock whileSB){
+ private ArrayList<StatementBlock> rewriteWhileStatementBlock(DMLProgram
prog, WhileStatementBlock whileSB){
Hop whilePredicateHop = whileSB.getPredicateHops();
selectFederatedExecutionPlan(whilePredicateHop);
for ( Statement stm : whileSB.getStatements() ){
WhileStatement whileStm = (WhileStatement) stm;
-
whileStm.setBody(rewriteStatementBlocks(whileStm.getBody()));
+ whileStm.setBody(rewriteStatementBlocks(prog,
whileStm.getBody()));
}
return new ArrayList<>(Collections.singletonList(whileSB));
}
- private ArrayList<StatementBlock>
rewriteIfStatementBlock(IfStatementBlock ifSB){
+ private ArrayList<StatementBlock> rewriteIfStatementBlock(DMLProgram
prog, IfStatementBlock ifSB){
selectFederatedExecutionPlan(ifSB.getPredicateHops());
for ( Statement statement : ifSB.getStatements() ){
IfStatement ifStatement = (IfStatement) statement;
-
ifStatement.setIfBody(rewriteStatementBlocks(ifStatement.getIfBody()));
-
ifStatement.setElseBody(rewriteStatementBlocks(ifStatement.getElseBody()));
+ ifStatement.setIfBody(rewriteStatementBlocks(prog,
ifStatement.getIfBody()));
+ ifStatement.setElseBody(rewriteStatementBlocks(prog,
ifStatement.getElseBody()));
}
return new ArrayList<>(Collections.singletonList(ifSB));
}
- private ArrayList<StatementBlock>
rewriteForStatementBlock(ForStatementBlock forSB){
+ private ArrayList<StatementBlock> rewriteForStatementBlock(DMLProgram
prog, ForStatementBlock forSB){
selectFederatedExecutionPlan(forSB.getFromHops());
selectFederatedExecutionPlan(forSB.getToHops());
selectFederatedExecutionPlan(forSB.getIncrementHops());
for ( Statement statement : forSB.getStatements() ){
ForStatement forStatement = ((ForStatement)statement);
-
forStatement.setBody(rewriteStatementBlocks(forStatement.getBody()));
+ forStatement.setBody(rewriteStatementBlocks(prog,
forStatement.getBody()));
}
return new ArrayList<>(Collections.singletonList(forSB));
}
- private ArrayList<StatementBlock>
rewriteFunctionStatementBlock(FunctionStatementBlock funcSB){
+ private ArrayList<StatementBlock>
rewriteFunctionStatementBlock(DMLProgram prog, FunctionStatementBlock funcSB){
for ( Statement statement : funcSB.getStatements() ){
FunctionStatement funcStm = (FunctionStatement)
statement;
-
funcStm.setBody(rewriteStatementBlocks(funcStm.getBody()));
+ funcStm.setBody(rewriteStatementBlocks(prog,
funcStm.getBody()));
}
return new ArrayList<>(Collections.singletonList(funcSB));
}
+	/**
+	 * Selects the federated execution plan for every hop of a basic statement block.
+	 * FunctionOps are not planned directly; instead the body of the called function
+	 * is looked up in the DML program and rewritten recursively.
+	 *
+	 * @param prog dml program used to resolve called functions
+	 * @param sb statement block without nested statement blocks
+	 * @return list containing only the given statement block
+	 */
+	private ArrayList<StatementBlock> rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb){
+		if ( sb.getHops() != null && !sb.getHops().isEmpty() ){
+			for ( Hop sbHop : sb.getHops() ){
+				if ( sbHop instanceof FunctionOp ){
+					FunctionOp funcOp = (FunctionOp) sbHop;
+					// Resolve the function in the namespace recorded on the FunctionOp instead of
+					// only the builtin dictionary (user-defined/imported functions are presumably
+					// registered elsewhere). A failed lookup yields null, which must not be passed
+					// into rewriteStatementBlock (it would fail on sb.getHops()).
+					FunctionStatementBlock sbFuncBlock = prog.getFunctionStatementBlock(
+						funcOp.getFunctionNamespace(), funcOp.getFunctionName());
+					// NOTE(review): a recursive DML function re-enters this path for its own body
+					// without any visited check - confirm recursion is excluded upstream.
+					if ( sbFuncBlock != null )
+						rewriteStatementBlock(prog, sbFuncBlock);
+				}
+				else selectFederatedExecutionPlan(sbHop);
+			}
+		}
+		return new ArrayList<>(Collections.singletonList(sb));
+	}
+
/**
* Sets FederatedOutput field of all hops in DAG starting from given
root.
* The FederatedOutput chosen for root is the minimum cost HopRel found
in memo table for the given root.
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
index 75fa735..4de2557 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
@@ -56,15 +56,16 @@ public class RewriteFederatedExecution extends
HopRewriteRule {
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots,
ProgramRewriteStatus state) {
- if ( roots == null )
- return null;
- for ( Hop root : roots )
- visitHop(root);
+ if ( roots != null )
+ for ( Hop root : roots )
+ rewriteHopDAG(root, state);
return roots;
}
@Override public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus
state) {
- return null;
+ if ( root != null )
+ visitHop(root);
+ return root;
}
private void visitHop(Hop hop){
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
index e543d9e..cbadd98 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
@@ -19,9 +19,10 @@
package org.apache.sysds.test.functions.privacy.algorithms;
+import java.util.Arrays;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.junit.Assert;
-import org.junit.Ignore;
import org.junit.Test;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
@@ -31,12 +32,15 @@ import
org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import org.apache.wink.json4j.JSONException;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
@net.jcip.annotations.NotThreadSafe
+@RunWith(value = Parameterized.class)
public class FederatedL2SVMTest extends AutomatedTestBase {
private final static String TEST_DIR = "functions/federated/";
@@ -47,57 +51,62 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
private int rows = 100;
private int cols = 10;
- @Override
- public void setUp() {
+ @Parameterized.Parameter()
+ public boolean fedOutCompilation;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][]{
+ {false},
+ {true}
+ });
+ }
+
+ @Override public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
}
// PrivateAggregation Single Input
- @Test
- @Ignore
- public void federatedL2SVMCPPrivateAggregationX1() throws JSONException
{
+ @Test public void federatedL2SVMCPPrivateAggregationX1() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null,
+ PrivacyLevel.PrivateAggregation);
}
- @Test
- @Ignore
- public void federatedL2SVMCPPrivateAggregationX2() throws JSONException
{
+ @Test public void federatedL2SVMCPPrivateAggregationX2() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null,
+ PrivacyLevel.PrivateAggregation);
}
- @Test
- public void federatedL2SVMCPPrivateAggregationY() throws JSONException {
+ @Test public void federatedL2SVMCPPrivateAggregationY() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null,
+ PrivacyLevel.PrivateAggregation);
}
// Private Single Input
- @Test
- public void federatedL2SVMCPPrivateFederatedX1() throws JSONException {
+ @Test public void federatedL2SVMCPPrivateFederatedX1() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private, false, null, true,
+ DMLRuntimeException.class);
}
- @Test
- public void federatedL2SVMCPPrivateFederatedX2() throws JSONException {
+ @Test public void federatedL2SVMCPPrivateFederatedX2() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private, false, null, true,
+ DMLRuntimeException.class);
}
- @Test
- public void federatedL2SVMCPPrivateFederatedY() throws JSONException {
+ @Test public void federatedL2SVMCPPrivateFederatedY() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.Private));
federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.Private);
@@ -105,230 +114,212 @@ public class FederatedL2SVMTest extends
AutomatedTestBase {
// Setting Privacy of Matrix (Throws Exception)
- @Test
- public void federatedL2SVMCPPrivateMatrixX1() throws JSONException {
+ @Test public void federatedL2SVMCPPrivateMatrixX1() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, null,
privacyConstraints, PrivacyLevel.Private,
- false, null, false, null);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, null,
privacyConstraints, PrivacyLevel.Private, false, null, false,
+ null);
}
- @Test
- public void federatedL2SVMCPPrivateMatrixX2() throws JSONException {
+ @Test public void federatedL2SVMCPPrivateMatrixX2() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, null,
privacyConstraints, PrivacyLevel.Private,
- false, null, false, null);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, null,
privacyConstraints, PrivacyLevel.Private, false, null, false,
+ null);
}
- @Test
- public void federatedL2SVMCPPrivateMatrixY() throws JSONException {
+ @Test public void federatedL2SVMCPPrivateMatrixY() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, null,
privacyConstraints, PrivacyLevel.Private,
- false, null, false, null);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, null,
privacyConstraints, PrivacyLevel.Private, false, null, false,
+ null);
}
- @Test
- public void federatedL2SVMCPPrivateFederatedAndMatrixX1() throws
JSONException {
+ @Test public void federatedL2SVMCPPrivateFederatedAndMatrixX1() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
privacyConstraints, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
privacyConstraints, PrivacyLevel.Private, false,
+ null, true, DMLRuntimeException.class);
}
- @Test
- public void federatedL2SVMCPPrivateFederatedAndMatrixX2() throws
JSONException {
+ @Test public void federatedL2SVMCPPrivateFederatedAndMatrixX2() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
privacyConstraints, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
privacyConstraints, PrivacyLevel.Private, false,
+ null, true, DMLRuntimeException.class);
}
- @Test
- public void federatedL2SVMCPPrivateFederatedAndMatrixY() throws
JSONException {
+ @Test public void federatedL2SVMCPPrivateFederatedAndMatrixY() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
privacyConstraints, PrivacyLevel.Private,
- false, null, false, null);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
privacyConstraints, PrivacyLevel.Private, false,
+ null, false, null);
}
// Privacy Level Private Combinations
- @Test
- public void federatedL2SVMCPPrivateFederatedX1X2() throws JSONException
{
+ @Test public void federatedL2SVMCPPrivateFederatedX1X2() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.Private));
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private, false, null, true,
+ DMLRuntimeException.class);
}
- @Test
- public void federatedL2SVMCPPrivateFederatedX1Y() throws JSONException {
+ @Test public void federatedL2SVMCPPrivateFederatedX1Y() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.Private));
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private, false, null, true,
+ DMLRuntimeException.class);
}
- @Test
- public void federatedL2SVMCPPrivateFederatedX2Y() throws JSONException {
+ @Test public void federatedL2SVMCPPrivateFederatedX2Y() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.Private));
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private, false, null, true,
+ DMLRuntimeException.class);
}
- @Test
- public void federatedL2SVMCPPrivateFederatedX1X2Y() throws
JSONException {
+ @Test public void federatedL2SVMCPPrivateFederatedX1X2Y() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.Private));
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.Private));
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private, false, null, true,
+ DMLRuntimeException.class);
}
// Privacy Level PrivateAggregation Combinations
- @Test
- public void federatedL2SVMCPPrivateAggregationFederatedX1X2() throws
JSONException {
+ @Test public void federatedL2SVMCPPrivateAggregationFederatedX1X2() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null,
+ PrivacyLevel.PrivateAggregation);
}
- @Test
- public void federatedL2SVMCPPrivateAggregationFederatedX1Y() throws
JSONException {
+ @Test public void federatedL2SVMCPPrivateAggregationFederatedX1Y() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null,
+ PrivacyLevel.PrivateAggregation);
}
- @Test
- public void federatedL2SVMCPPrivateAggregationFederatedX2Y() throws
JSONException {
+ @Test public void federatedL2SVMCPPrivateAggregationFederatedX2Y() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null,
+ PrivacyLevel.PrivateAggregation);
}
- @Test
- public void federatedL2SVMCPPrivateAggregationFederatedX1X2Y() throws
JSONException {
+ @Test public void federatedL2SVMCPPrivateAggregationFederatedX1X2Y() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null,
+ PrivacyLevel.PrivateAggregation);
}
// Privacy Level Combinations
- @Test
- public void federatedL2SVMCPPrivatePrivateAggregationFederatedX1X2()
throws JSONException {
+ @Test public void
federatedL2SVMCPPrivatePrivateAggregationFederatedX1X2() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.Private));
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private, false, null, true,
+ DMLRuntimeException.class);
}
- @Test
- public void federatedL2SVMCPPrivatePrivateAggregationFederatedX1Y()
throws JSONException {
+ @Test public void
federatedL2SVMCPPrivatePrivateAggregationFederatedX1Y() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.Private));
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private, false, null, true,
+ DMLRuntimeException.class);
}
- @Test
- public void federatedL2SVMCPPrivatePrivateAggregationFederatedX2Y()
throws JSONException {
+ @Test public void
federatedL2SVMCPPrivatePrivateAggregationFederatedX2Y() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.Private));
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private, false, null, true,
+ DMLRuntimeException.class);
}
- @Test
- public void federatedL2SVMCPPrivatePrivateAggregationFederatedYX1()
throws JSONException {
+ @Test public void
federatedL2SVMCPPrivatePrivateAggregationFederatedYX1() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.Private));
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.Private);
}
- @Test
- public void federatedL2SVMCPPrivatePrivateAggregationFederatedYX2()
throws JSONException {
+ @Test public void
federatedL2SVMCPPrivatePrivateAggregationFederatedYX2() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.Private));
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.Private);
}
- @Test
- public void federatedL2SVMCPPrivatePrivateAggregationFederatedX2X1()
throws JSONException {
+ @Test public void
federatedL2SVMCPPrivatePrivateAggregationFederatedX2X1() {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.Private));
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private, false, null, true,
+ DMLRuntimeException.class);
}
// Require Federated Workers to return matrix
- @Test
- public void federatedL2SVMCPPrivateAggregationX1Exception() throws
JSONException {
- rows = 1000; cols = 1;
+ @Test public void federatedL2SVMCPPrivateAggregationX1Exception() {
+ rows = 1000;
+ cols = 1;
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null,
+ PrivacyLevel.PrivateAggregation);
}
- @Test
- public void federatedL2SVMCPPrivateAggregationX2Exception() throws
JSONException {
- rows = 1000; cols = 1;
+ @Test public void federatedL2SVMCPPrivateAggregationX2Exception() {
+ rows = 1000;
+ cols = 1;
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null,
+ PrivacyLevel.PrivateAggregation);
}
- @Test
- public void federatedL2SVMCPPrivateX1Exception() throws JSONException {
- rows = 1000; cols = 1;
+ @Test public void federatedL2SVMCPPrivateX1Exception() {
+ rows = 1000;
+ cols = 1;
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private, false, null, true,
+ DMLRuntimeException.class);
}
- @Test
- public void federatedL2SVMCPPrivateX2Exception() throws JSONException {
- rows = 1000; cols = 1;
+ @Test public void federatedL2SVMCPPrivateX2Exception() {
+ rows = 1000;
+ cols = 1;
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.Private));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private, false, null, true,
+ DMLRuntimeException.class);
}
- private void federatedL2SVMNoException(Types.ExecMode execMode,
Map<String,
- PrivacyConstraint> privacyConstraintsFederated,
Map<String, PrivacyConstraint> privacyConstraintsMatrix,
- PrivacyLevel expectedPrivacyLevel)
- throws JSONException
- {
- federatedL2SVM(execMode, privacyConstraintsFederated,
privacyConstraintsMatrix, expectedPrivacyLevel,
- false, null, false, null);
+ private void federatedL2SVMNoException(Types.ExecMode execMode,
+ Map<String, PrivacyConstraint> privacyConstraintsFederated,
+ Map<String, PrivacyConstraint> privacyConstraintsMatrix,
PrivacyLevel expectedPrivacyLevel) {
+ federatedL2SVM(execMode, privacyConstraintsFederated,
privacyConstraintsMatrix, expectedPrivacyLevel, false,
+ null, false, null);
}
private void federatedL2SVM(Types.ExecMode execMode, Map<String,
PrivacyConstraint> privacyConstraintsFederated,
- Map<String, PrivacyConstraint>
privacyConstraintsMatrix, PrivacyLevel expectedPrivacyLevel,
- boolean exception1, Class<?> expectedException1,
boolean exception2, Class<?> expectedException2 )
- throws JSONException
- {
+ Map<String, PrivacyConstraint> privacyConstraintsMatrix,
PrivacyLevel expectedPrivacyLevel, boolean exception1,
+ Class<?> expectedException1, boolean exception2, Class<?>
expectedException2) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
rtplatform = execMode;
@@ -351,24 +342,40 @@ public class FederatedL2SVMTest extends AutomatedTestBase
{
Y[i][0] = (Y[i][0] > 0) ? 1 : -1;
// Write privacy constraints of normal matrix
- if ( privacyConstraintsMatrix != null ){
- writeInputMatrixWithMTD("MX1", X1, false, new
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols),
privacyConstraintsMatrix.get("X1"));
- writeInputMatrixWithMTD("MX2", X2, false, new
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols),
privacyConstraintsMatrix.get("X2"));
- writeInputMatrixWithMTD("MY", Y, false, new
MatrixCharacteristics(rows, 1, blocksize, rows),
privacyConstraintsMatrix.get("Y"));
- } else {
- writeInputMatrixWithMTD("MX1", X1, false, new
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
- writeInputMatrixWithMTD("MX2", X2, false, new
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ if(privacyConstraintsMatrix != null) {
+ writeInputMatrixWithMTD("MX1", X1, false,
+ new MatrixCharacteristics(halfRows,
cols, blocksize, halfRows * cols),
+ privacyConstraintsMatrix.get("X1"));
+ writeInputMatrixWithMTD("MX2", X2, false,
+ new MatrixCharacteristics(halfRows,
cols, blocksize, halfRows * cols),
+ privacyConstraintsMatrix.get("X2"));
+ writeInputMatrixWithMTD("MY", Y, false, new
MatrixCharacteristics(rows, 1, blocksize, rows),
+ privacyConstraintsMatrix.get("Y"));
+ }
+ else {
+ writeInputMatrixWithMTD("MX1", X1, false,
+ new MatrixCharacteristics(halfRows,
cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("MX2", X2, false,
+ new MatrixCharacteristics(halfRows,
cols, blocksize, halfRows * cols));
writeInputMatrixWithMTD("MY", Y, false, new
MatrixCharacteristics(rows, 1, blocksize, rows));
}
// Write privacy constraints of federated matrix
- if ( privacyConstraintsFederated != null ){
- writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols),
privacyConstraintsFederated.get("X1"));
- writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols),
privacyConstraintsFederated.get("X2"));
- writeInputMatrixWithMTD("Y", Y, false, new
MatrixCharacteristics(rows, 1, blocksize, rows),
privacyConstraintsFederated.get("Y"));
- } else {
- writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
- writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ if(privacyConstraintsFederated != null) {
+ writeInputMatrixWithMTD("X1", X1, false,
+ new MatrixCharacteristics(halfRows,
cols, blocksize, halfRows * cols),
+ privacyConstraintsFederated.get("X1"));
+ writeInputMatrixWithMTD("X2", X2, false,
+ new MatrixCharacteristics(halfRows,
cols, blocksize, halfRows * cols),
+ privacyConstraintsFederated.get("X2"));
+ writeInputMatrixWithMTD("Y", Y, false, new
MatrixCharacteristics(rows, 1, blocksize, rows),
+ privacyConstraintsFederated.get("Y"));
+ }
+ else {
+ writeInputMatrixWithMTD("X1", X1, false,
+ new MatrixCharacteristics(halfRows,
cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("X2", X2, false,
+ new MatrixCharacteristics(halfRows,
cols, blocksize, halfRows * cols));
writeInputMatrixWithMTD("Y", Y, false, new
MatrixCharacteristics(rows, 1, blocksize, rows));
}
@@ -388,25 +395,27 @@ public class FederatedL2SVMTest extends AutomatedTestBase
{
runTest(true, exception1, expectedException1, -1);
// Run actual dml script with federated matrix
+ OptimizerUtils.FEDERATED_COMPILATION =
fedOutCompilation;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-checkPrivacy",
- "-nvargs", "in_X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+ programArgs = new String[] {"-stats", "-checkPrivacy",
"-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2,
input("X2")), "rows=" + rows, "cols=" + cols,
"in_Y=" + input("Y"), "single=FALSE", "out=" +
output("Z")};
-
+
runTest(true, exception2, expectedException2, -1);
- if ( !(exception1 || exception2) ) {
+ if(!(exception1 || exception2)) {
compareResults(1e-9);
}
- if ( expectedPrivacyLevel != null)
+ if(expectedPrivacyLevel != null)
Assert.assertTrue(checkedPrivacyConstraintsContains(expectedPrivacyLevel));
}
finally {
TestUtils.shutdownThreads(t1, t2);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ OptimizerUtils.FEDERATED_COMPILATION = false;
}
}
}