This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 65ea7f3189 [SYSTEMDS-3018] Add Function Parameters to Cost-Based Federated Planner
65ea7f3189 is described below
commit 65ea7f318957127e3e75f5bc8cc7d1b5a356c885
Author: sebwrede <[email protected]>
AuthorDate: Tue May 17 10:32:12 2022 +0200
[SYSTEMDS-3018] Add Function Parameters to Cost-Based Federated Planner
This commit also:
- Adds null checks to the repetition estimate updates of for and while statement blocks
- Adds transient writes and function calls to the terminal hops
- Edits the transpose FEDInstruction so that an LOUT output binds the federated output mapping correctly
- Edits the L2SVM fed planning test to prepare for L2SVM function call tests
Closes #1618.
---
.../java/org/apache/sysds/hops/AggBinaryOp.java | 4 +-
.../hops/fedplanner/FederatedPlannerCostbased.java | 122 ++++++++++++++-------
.../apache/sysds/hops/rewrite/HopRewriteUtils.java | 4 +-
.../org/apache/sysds/parser/ForStatementBlock.java | 9 +-
.../apache/sysds/parser/WhileStatementBlock.java | 3 +-
.../instructions/fed/ReorgFEDInstruction.java | 4 +-
.../fedplanning/FederatedL2SVMPlanningTest.java | 46 ++++++--
.../fedplanning/FederatedMultiplyPlanningTest.java | 1 -
.../FederatedL2SVMFunctionPlanningTest.dml | 36 ++++++
...FederatedL2SVMFunctionPlanningTestReference.dml | 35 ++++++
10 files changed, 203 insertions(+), 61 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 403a0466f0..16b42c4840 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -44,7 +44,7 @@ import org.apache.sysds.lops.PMMJ;
import org.apache.sysds.lops.PMapMult;
import org.apache.sysds.lops.Transform;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -677,7 +677,7 @@ public class AggBinaryOp extends MultiThreadedHop {
setLineNumbers(mult);
//result transpose (dimensions set outside)
- ExecType outTransposeExecType = ( _federatedOutput ==
FEDInstruction.FederatedOutput.FOUT ) ?
+ ExecType outTransposeExecType = ( _federatedOutput ==
FederatedOutput.FOUT ) ?
ExecType.FED : ExecType.CP;
Lop out = new Transform(mult, ReOrgOp.TRANS, getDataType(),
getValueType(), outTransposeExecType, k);
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index e9a25206f8..1f9abb4c18 100644
---
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -78,7 +78,7 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
@Override
public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph,
FunctionCallSizeInfo fcallSizes ) {
prog.updateRepetitionEstimates();
- rewriteStatementBlocks(prog, prog.getStatementBlocks());
+ rewriteStatementBlocks(prog, prog.getStatementBlocks(), null);
setFinalFedouts();
updateExplain();
}
@@ -89,12 +89,13 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
*
* @param prog dml program
* @param sbs list of statement blocks
+ * @param paramMap map of parameters in function call
* @return list of statement blocks with the federated output value
updated for each hop
*/
- private ArrayList<StatementBlock> rewriteStatementBlocks(DMLProgram
prog, List<StatementBlock> sbs) {
+ private ArrayList<StatementBlock> rewriteStatementBlocks(DMLProgram
prog, List<StatementBlock> sbs, Map<String, Hop> paramMap) {
ArrayList<StatementBlock> rewrittenStmBlocks = new
ArrayList<>();
for(StatementBlock stmBlock : sbs)
- rewrittenStmBlocks.addAll(rewriteStatementBlock(prog,
stmBlock));
+ rewrittenStmBlocks.addAll(rewriteStatementBlock(prog,
stmBlock, paramMap));
return rewrittenStmBlocks;
}
@@ -104,79 +105,99 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
*
* @param prog dml program
* @param sb statement block
+ * @param paramMap map of parameters in function call
* @return list of statement blocks with the federated output value
updated for each hop
*/
- public ArrayList<StatementBlock> rewriteStatementBlock(DMLProgram prog,
StatementBlock sb) {
+ public ArrayList<StatementBlock> rewriteStatementBlock(DMLProgram prog,
StatementBlock sb, Map<String, Hop> paramMap) {
if(sb instanceof WhileStatementBlock)
- return rewriteWhileStatementBlock(prog,
(WhileStatementBlock) sb);
+ return rewriteWhileStatementBlock(prog,
(WhileStatementBlock) sb, paramMap);
else if(sb instanceof IfStatementBlock)
- return rewriteIfStatementBlock(prog, (IfStatementBlock)
sb);
+ return rewriteIfStatementBlock(prog, (IfStatementBlock)
sb, paramMap);
else if(sb instanceof ForStatementBlock) {
// This also includes ParForStatementBlocks
- return rewriteForStatementBlock(prog,
(ForStatementBlock) sb);
+ return rewriteForStatementBlock(prog,
(ForStatementBlock) sb, paramMap);
}
else if(sb instanceof FunctionStatementBlock)
- return rewriteFunctionStatementBlock(prog,
(FunctionStatementBlock) sb);
+ return rewriteFunctionStatementBlock(prog,
(FunctionStatementBlock) sb, paramMap);
else {
// StatementBlock type (no subclass)
- return rewriteDefaultStatementBlock(prog, sb);
+ return rewriteDefaultStatementBlock(prog, sb, paramMap);
}
}
- private ArrayList<StatementBlock> rewriteWhileStatementBlock(DMLProgram
prog, WhileStatementBlock whileSB) {
+ private ArrayList<StatementBlock> rewriteWhileStatementBlock(DMLProgram
prog, WhileStatementBlock whileSB, Map<String, Hop> paramMap) {
Hop whilePredicateHop = whileSB.getPredicateHops();
- selectFederatedExecutionPlan(whilePredicateHop);
+ selectFederatedExecutionPlan(whilePredicateHop, paramMap);
for(Statement stm : whileSB.getStatements()) {
WhileStatement whileStm = (WhileStatement) stm;
- whileStm.setBody(rewriteStatementBlocks(prog,
whileStm.getBody()));
+ whileStm.setBody(rewriteStatementBlocks(prog,
whileStm.getBody(), paramMap));
}
return new ArrayList<>(Collections.singletonList(whileSB));
}
- private ArrayList<StatementBlock> rewriteIfStatementBlock(DMLProgram
prog, IfStatementBlock ifSB) {
- selectFederatedExecutionPlan(ifSB.getPredicateHops());
+ private ArrayList<StatementBlock> rewriteIfStatementBlock(DMLProgram
prog, IfStatementBlock ifSB, Map<String, Hop> paramMap) {
+ selectFederatedExecutionPlan(ifSB.getPredicateHops(), paramMap);
for(Statement statement : ifSB.getStatements()) {
IfStatement ifStatement = (IfStatement) statement;
- ifStatement.setIfBody(rewriteStatementBlocks(prog,
ifStatement.getIfBody()));
- ifStatement.setElseBody(rewriteStatementBlocks(prog,
ifStatement.getElseBody()));
+ ifStatement.setIfBody(rewriteStatementBlocks(prog,
ifStatement.getIfBody(), paramMap));
+ ifStatement.setElseBody(rewriteStatementBlocks(prog,
ifStatement.getElseBody(), paramMap));
}
return new ArrayList<>(Collections.singletonList(ifSB));
}
- private ArrayList<StatementBlock> rewriteForStatementBlock(DMLProgram
prog, ForStatementBlock forSB) {
- selectFederatedExecutionPlan(forSB.getFromHops());
- selectFederatedExecutionPlan(forSB.getToHops());
- selectFederatedExecutionPlan(forSB.getIncrementHops());
+ private ArrayList<StatementBlock> rewriteForStatementBlock(DMLProgram
prog, ForStatementBlock forSB, Map<String, Hop> paramMap) {
+ selectFederatedExecutionPlan(forSB.getFromHops(), paramMap);
+ selectFederatedExecutionPlan(forSB.getToHops(), paramMap);
+ selectFederatedExecutionPlan(forSB.getIncrementHops(),
paramMap);
for(Statement statement : forSB.getStatements()) {
ForStatement forStatement = ((ForStatement) statement);
- forStatement.setBody(rewriteStatementBlocks(prog,
forStatement.getBody()));
+ forStatement.setBody(rewriteStatementBlocks(prog,
forStatement.getBody(), paramMap));
}
return new ArrayList<>(Collections.singletonList(forSB));
}
- private ArrayList<StatementBlock>
rewriteFunctionStatementBlock(DMLProgram prog, FunctionStatementBlock funcSB) {
+ private ArrayList<StatementBlock>
rewriteFunctionStatementBlock(DMLProgram prog, FunctionStatementBlock funcSB,
Map<String, Hop> paramMap) {
for(Statement statement : funcSB.getStatements()) {
FunctionStatement funcStm = (FunctionStatement)
statement;
- funcStm.setBody(rewriteStatementBlocks(prog,
funcStm.getBody()));
+ funcStm.setBody(rewriteStatementBlocks(prog,
funcStm.getBody(), paramMap));
}
return new ArrayList<>(Collections.singletonList(funcSB));
}
- private ArrayList<StatementBlock>
rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb) {
+ private ArrayList<StatementBlock>
rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb, Map<String,
Hop> paramMap) {
if(sb.hasHops()) {
for(Hop sbHop : sb.getHops()) {
+ selectFederatedExecutionPlan(sbHop, paramMap);
if(sbHop instanceof FunctionOp) {
String funcName = ((FunctionOp)
sbHop).getFunctionName();
+ Map<String, Hop> funcParamMap =
getParamMap((FunctionOp) sbHop);
+ if ( paramMap != null && funcParamMap
!= null)
+ funcParamMap.putAll(paramMap);
+ paramMap = funcParamMap;
FunctionStatementBlock sbFuncBlock =
prog.getBuiltinFunctionDictionary().getFunction(funcName);
- rewriteStatementBlock(prog,
sbFuncBlock);
+ rewriteStatementBlock(prog,
sbFuncBlock, paramMap);
}
- else
- selectFederatedExecutionPlan(sbHop);
}
}
return new ArrayList<>(Collections.singletonList(sb));
}
+ /**
+ * Return parameter map containing the mapping from parameter name to
input hop
+ * for all parameters of the function hop.
+ * @param funcOp hop for which the mapping of parameter names to input
hops are made
+ * @return parameter map or empty map if function has no parameters
+ */
+ private Map<String,Hop> getParamMap(FunctionOp funcOp){
+ String[] inputNames = funcOp.getInputVariableNames();
+ Map<String,Hop> paramMap = new HashMap<>();
+ if ( inputNames != null ){
+ for ( int i = 0; i < funcOp.getInput().size(); i++ )
+ paramMap.put(inputNames[i],funcOp.getInput(i));
+ }
+ return paramMap;
+ }
+
/**
* Set final fedouts of all hops starting from terminal hops.
*/
@@ -266,21 +287,23 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
* The cost estimates of the hops are also updated when FederatedOutput
is updated in the hops.
*
* @param roots starting point for going through the Hop DAG to update
the FederatedOutput fields.
+ * @param paramMap map of parameters in function call
*/
@SuppressWarnings("unused")
- private void selectFederatedExecutionPlan(ArrayList<Hop> roots){
+ private void selectFederatedExecutionPlan(ArrayList<Hop> roots,
Map<String, Hop> paramMap){
for ( Hop root : roots )
- selectFederatedExecutionPlan(root);
+ selectFederatedExecutionPlan(root, paramMap);
}
/**
* Select federated execution plan for every Hop in the DAG starting
from given root.
*
* @param root starting point for going through the Hop DAG to update
the federatedOutput fields
+ * @param paramMap map of parameters in function call
*/
- private void selectFederatedExecutionPlan(Hop root) {
+ private void selectFederatedExecutionPlan(Hop root, Map<String, Hop>
paramMap) {
if ( root != null ){
- visitFedPlanHop(root);
+ visitFedPlanHop(root, paramMap);
if ( HopRewriteUtils.isTerminalHop(root) )
terminalHops.add(root);
}
@@ -290,17 +313,18 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
* Go through the Hop DAG and set the FederatedOutput field and cost
estimate for each Hop from leaf to given currentHop.
*
* @param currentHop the Hop from which the DAG is visited
+ * @param paramMap map of parameters in function call
*/
- private void visitFedPlanHop(Hop currentHop) {
+ private void visitFedPlanHop(Hop currentHop, Map<String, Hop> paramMap)
{
// If the currentHop is in the hopRelMemo table, it means that
it has been visited
if(hopRelMemo.containsHop(currentHop))
return;
debugLog(currentHop);
// If the currentHop has input, then the input should be
visited depth-first
for(Hop input : currentHop.getInput())
- visitFedPlanHop(input);
+ visitFedPlanHop(input, paramMap);
// Put FOUT and LOUT HopRels into the memo table
- ArrayList<HopRel> hopRels = getFedPlans(currentHop);
+ ArrayList<HopRel> hopRels = getFedPlans(currentHop, paramMap);
// Put NONE HopRel into memo table if no FOUT or LOUT HopRels
were added
if(hopRels.isEmpty())
hopRels.add(getNONEHopRel(currentHop));
@@ -319,17 +343,14 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
/**
* Get the alternative plans regarding the federated output for given
currentHop.
* @param currentHop for which alternative federated plans are generated
+ * @param paramMap map of parameters in function call
* @return list of alternative plans
*/
- private ArrayList<HopRel> getFedPlans(Hop currentHop){
+ private ArrayList<HopRel> getFedPlans(Hop currentHop, Map<String, Hop>
paramMap){
ArrayList<HopRel> hopRels = new ArrayList<>();
ArrayList<Hop> inputHops = currentHop.getInput();
- if ( HopRewriteUtils.isData(currentHop,
Types.OpOpData.TRANSIENTREAD) ){
- Hop tWriteHop =
transientWrites.get(currentHop.getName());
- if ( tWriteHop == null )
- throw new DMLRuntimeException("Transient write
not found for " + currentHop);
- inputHops = new
ArrayList<>(Collections.singletonList(tWriteHop));
- }
+ if ( HopRewriteUtils.isData(currentHop,
Types.OpOpData.TRANSIENTREAD) )
+ inputHops = getTransientInputs(currentHop, paramMap);
if ( HopRewriteUtils.isData(currentHop,
Types.OpOpData.TRANSIENTWRITE) )
transientWrites.put(currentHop.getName(), currentHop);
if ( HopRewriteUtils.isData(currentHop,
Types.OpOpData.FEDERATED) )
@@ -341,6 +362,25 @@ public class FederatedPlannerCostbased extends
AFederatedPlanner {
return hopRels;
}
+ /**
+ * Get transient inputs from either paramMap or transientWrites.
+ * Inputs from paramMap has higher priority than inputs from
transientWrites.
+ * @param currentHop hop for which inputs are read from maps
+ * @param paramMap of local parameters
+ * @return inputs of currentHop
+ */
+ private ArrayList<Hop> getTransientInputs(Hop currentHop, Map<String,
Hop> paramMap){
+ Hop tWriteHop = null;
+ if ( paramMap != null)
+ tWriteHop = paramMap.get(currentHop.getName());
+ if ( tWriteHop == null )
+ tWriteHop = transientWrites.get(currentHop.getName());
+ if ( tWriteHop == null )
+ throw new DMLRuntimeException("Transient write not
found for " + currentHop);
+ else
+ return new
ArrayList<>(Collections.singletonList(tWriteHop));
+ }
+
/**
* Generate a collection of FOUT HopRels representing the different
possible FType outputs.
* For each FType output, only the minimum cost input combination is
chosen.
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 6dbc5e35b6..d10a43e810 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -1167,7 +1167,9 @@ public class HopRewriteUtils {
public static boolean isTerminalHop(Hop hop){
return isUnary(hop, OpOp1.PRINT)
|| isNary(hop, OpOpN.PRINTF)
- || isData(hop, OpOpData.PERSISTENTWRITE);
+ || isData(hop, OpOpData.PERSISTENTWRITE)
+ || isData(hop, OpOpData.TRANSIENTWRITE)
+ || hop instanceof FunctionOp;
}
public static boolean isMatrixMultiply(Hop hop) {
diff --git a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
index b21b9b58a6..ce31ae9bcf 100644
--- a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
@@ -453,9 +453,12 @@ public class ForStatementBlock extends StatementBlock
@Override
public void updateRepetitionEstimates(double repetitions){
this.repetitions = repetitions * getEstimateReps();
- _fromHops.updateRepetitionEstimates(this.repetitions);
- _toHops.updateRepetitionEstimates(this.repetitions);
- _incrementHops.updateRepetitionEstimates(this.repetitions);
+ if ( _fromHops != null )
+ _fromHops.updateRepetitionEstimates(this.repetitions);
+ if ( _toHops != null )
+ _toHops.updateRepetitionEstimates(this.repetitions);
+ if ( _incrementHops != null )
+
_incrementHops.updateRepetitionEstimates(this.repetitions);
for(Statement statement : getStatements()) {
List<StatementBlock> children = ((ForStatement)
statement).getBody();
for ( StatementBlock stmBlock : children ){
diff --git a/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
b/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
index b28e6825bb..8a92f3bc23 100644
--- a/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
@@ -322,7 +322,8 @@ public class WhileStatementBlock extends StatementBlock
@Override
public void updateRepetitionEstimates(double repetitions){
this.repetitions = repetitions * DEFAULT_LOOP_REPETITIONS;
- getPredicateHops().updateRepetitionEstimates(this.repetitions);
+ if ( getPredicateHops() != null )
+
getPredicateHops().updateRepetitionEstimates(this.repetitions);
for(Statement statement : getStatements()) {
List<StatementBlock> children =
((WhileStatement)statement).getBody();
for ( StatementBlock stmBlock : children ){
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index aff69a24a6..2f9e26a2a1 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -104,7 +104,7 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
if( !mo1.isFederated() )
throw new DMLRuntimeException("Federated Reorg: "
+ "Federated input expected, but invoked w/
"+mo1.isFederated());
- if ( !( mo1.isFederated(FType.COL) ||
mo1.isFederated(FType.ROW) || mo1.isFederated(FType.PART) ) )
+ if ( !( mo1.isFederated(FType.COL) ||
mo1.isFederated(FType.ROW) ) )
throw new DMLRuntimeException("Federation type " +
mo1.getFedMapping().getType()
+ " is not supported for Reorg processing");
@@ -126,7 +126,7 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
FederatedRequest getRequest = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
Future<FederatedResponse>[] execResponse =
mo1.getFedMapping().execute(getTID(), true, fr1, getRequest);
ec.setMatrixOutput(output.getName(),
- FederationUtils.bind(execResponse,
mo1.isFederated(FType.COL)));
+ FederationUtils.bind(execResponse,
mo1.isFederated(FType.ROW)));
}
} else if ( mo1.isFederated(FType.PART) ){
throw new DMLRuntimeException("Operation with opcode "
+ instOpcode + " is not supported with PART input");
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
index 1ba9966773..60ab0d93ce 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
@@ -28,6 +28,7 @@ import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
import org.junit.Test;
import java.io.File;
@@ -41,6 +42,7 @@ public class FederatedL2SVMPlanningTest extends
AutomatedTestBase {
private final static String TEST_DIR = "functions/privacy/fedplanning/";
private final static String TEST_NAME = "FederatedL2SVMPlanningTest";
+ private final static String TEST_NAME_2 =
"FederatedL2SVMFunctionPlanningTest";
private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedL2SVMPlanningTest.class.getSimpleName() + "/";
private static File TEST_CONF_FILE;
@@ -52,6 +54,7 @@ public class FederatedL2SVMPlanningTest extends
AutomatedTestBase {
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ addTestConfiguration(TEST_NAME_2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_2, new String[] {"Z"}));
}
@Test
@@ -59,24 +62,47 @@ public class FederatedL2SVMPlanningTest extends
AutomatedTestBase {
String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*", "fed_tak+*", "fed_+*",
"fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
setTestConf("SystemDS-config-fout.xml");
- loadAndRunTest(expectedHeavyHitters);
+ loadAndRunTest(expectedHeavyHitters, TEST_NAME);
}
@Test
public void runL2SVMHeuristicTest(){
String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*"};
setTestConf("SystemDS-config-heuristic.xml");
- loadAndRunTest(expectedHeavyHitters);
+ loadAndRunTest(expectedHeavyHitters, TEST_NAME);
}
@Test
public void runL2SVMCostBasedTest(){
- //String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*", "fed_tak+*", "fed_+*",
- // "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*", "fed_tak+*", "fed_+*",
- "fed_max", "fed_1-*", "fed_>"};
+ "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+ setTestConf("SystemDS-config-cost-based.xml");
+ loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+ }
+
+ @Test
+ @Ignore
+ public void runL2SVMFunctionFOUTTest(){
+ String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*", "fed_tak+*", "fed_+*",
+ "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
+ setTestConf("SystemDS-config-fout.xml");
+ loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+ }
+
+ @Test
+ @Ignore
+ public void runL2SVMFunctionHeuristicTest(){
+ String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*"};
+ setTestConf("SystemDS-config-heuristic.xml");
+ loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
+ }
+
+ @Test
+ public void runL2SVMFunctionCostBasedTest(){
+ String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*", "fed_tak+*", "fed_+*",
+ "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
setTestConf("SystemDS-config-cost-based.xml");
- loadAndRunTest(expectedHeavyHitters);
+ loadAndRunTest(expectedHeavyHitters, TEST_NAME_2);
}
private void setTestConf(String test_conf){
@@ -117,7 +143,7 @@ public class FederatedL2SVMPlanningTest extends
AutomatedTestBase {
writeStandardMatrix(matrixName, seed, halfRows,
privacyConstraint);
}
- private void loadAndRunTest(String[] expectedHeavyHitters){
+ private void loadAndRunTest(String[] expectedHeavyHitters, String
testName){
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
@@ -126,7 +152,7 @@ public class FederatedL2SVMPlanningTest extends
AutomatedTestBase {
Thread t1 = null, t2 = null;
try {
- getAndLoadTestConfiguration(TEST_NAME);
+ getAndLoadTestConfiguration(testName);
String HOME = SCRIPT_DIR + TEST_DIR;
writeInputMatrices();
@@ -137,7 +163,7 @@ public class FederatedL2SVMPlanningTest extends
AutomatedTestBase {
t2 = startLocalFedWorkerThread(port2);
// Run actual dml script with federated matrix
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ fullDMLScriptName = HOME + testName + ".dml";
programArgs = new String[] { "-stats", "-explain",
"hops", "-nvargs",
"X1=" + TestUtils.federatedAddress(port1,
input("X1")),
"X2=" + TestUtils.federatedAddress(port2,
input("X2")),
@@ -145,7 +171,7 @@ public class FederatedL2SVMPlanningTest extends
AutomatedTestBase {
runTest(true, false, null, -1);
// Run reference dml script with normal matrix
- fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ fullDMLScriptName = HOME + testName + "Reference.dml";
programArgs = new String[] {"-nvargs", "X1=" +
input("X1"), "X2=" + input("X2"),
"Y=" + input("Y"), "Z=" + expected("Z")};
runTest(true, false, null, -1);
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index e8d16f6bcb..b9a3a14fd5 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -130,7 +130,6 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
}
@Test
- @Ignore
public void federatedMultiplyDoubleHop() {
String[] expectedHeavyHitters = new String[]{"fed_*",
"fed_fedinit", "fed_r'", "fed_ba+*"};
federatedTwoMatricesSingleNodeTest(TEST_NAME_7,
expectedHeavyHitters);
diff --git
a/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMFunctionPlanningTest.dml
b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMFunctionPlanningTest.dml
new file mode 100644
index 0000000000..134d1b35c2
--- /dev/null
+++
b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMFunctionPlanningTest.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+ maxii = 20
+ verbose = FALSE
+ columnId = -1
+ Y = read($Y)
+ X = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+ intercept = FALSE
+ epsilon = 1e-12
+ reg = 1
+ maxIterations = 100
+
+ model = l2svm(X=X, Y=Y, intercept = FALSE, epsilon = epsilon, reg = reg,
maxIterations = maxIterations)
+
+ write(model, $Z)
diff --git
a/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMFunctionPlanningTestReference.dml
b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMFunctionPlanningTestReference.dml
new file mode 100644
index 0000000000..7fec5d2a20
--- /dev/null
+++
b/src/test/scripts/functions/privacy/fedplanning/FederatedL2SVMFunctionPlanningTestReference.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+ maxii = 20
+ verbose = FALSE
+ columnId = -1
+ Y = read($Y)
+ X = rbind(read($X1), read($X2))
+ intercept = FALSE
+ epsilon = 1e-12
+ reg = 1
+ maxIterations = 100
+
+ model = l2svm(X=X, Y=Y, intercept = FALSE, epsilon = epsilon, reg = reg,
maxIterations = maxIterations)
+
+ write(model, $Z)