This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 13611b0d92 [SYSTEMDS-3018] Support function calls in FedAll and Heuristic planners
13611b0d92 is described below
commit 13611b0d92b2423c18f3ac9725a93b39b2720aa0
Author: Kevin Innerebner <[email protected]>
AuthorDate: Thu May 8 16:40:40 2025 +0200
[SYSTEMDS-3018] Support function calls in FedAll and Heuristic planners
Closes #1666.
---
.../hops/fedplanner/FederatedPlannerFedAll.java | 53 ++++++++++++++++------
.../hops/fedplanner/FederatedPlannerUtils.java | 23 ----------
.../fedplanning/FederatedL2SVMPlanningTest.java | 3 --
3 files changed, 40 insertions(+), 39 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
index 4bf2e5606a..59967b7cf1 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
@@ -19,13 +19,13 @@
package org.apache.sysds.hops.fedplanner;
-import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
@@ -89,14 +89,14 @@ public class FederatedPlannerFedAll extends AFederatedPlanner {
else if (sb instanceof WhileStatementBlock) {
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt =
(WhileStatement)wsb.getStatement(0);
- rRewriteHop(wsb.getPredicateHops(), new HashMap<>(),
Collections.emptyMap());
+ rRewriteHop(wsb.getPredicateHops(), new HashMap<>(),
new HashMap<>(), sb.getDMLProg());
for (StatementBlock csb : wstmt.getBody())
rRewriteStatementBlock(csb, fedVars);
}
else if (sb instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement)isb.getStatement(0);
- rRewriteHop(isb.getPredicateHops(), new HashMap<>(),
Collections.emptyMap());
+ rRewriteHop(isb.getPredicateHops(), new HashMap<>(),
new HashMap<>(), sb.getDMLProg());
for (StatementBlock csb : istmt.getIfBody())
rRewriteStatementBlock(csb, fedVars);
for (StatementBlock csb : istmt.getElseBody())
@@ -105,9 +105,9 @@ public class FederatedPlannerFedAll extends AFederatedPlanner {
else if (sb instanceof ForStatementBlock) { //incl parfor
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement)fsb.getStatement(0);
- rRewriteHop(fsb.getFromHops(), new HashMap<>(),
Collections.emptyMap());
- rRewriteHop(fsb.getToHops(), new HashMap<>(),
Collections.emptyMap());
- rRewriteHop(fsb.getIncrementHops(), new HashMap<>(),
Collections.emptyMap());
+ rRewriteHop(fsb.getFromHops(), new HashMap<>(), new
HashMap<>(), sb.getDMLProg());
+ rRewriteHop(fsb.getToHops(), new HashMap<>(), new
HashMap<>(), sb.getDMLProg());
+ rRewriteHop(fsb.getIncrementHops(), new HashMap<>(),
new HashMap<>(), sb.getDMLProg());
for (StatementBlock csb : fstmt.getBody())
rRewriteStatementBlock(csb, fedVars);
}
@@ -117,9 +117,7 @@ public class FederatedPlannerFedAll extends AFederatedPlanner {
Map<Long, FType> fedHops = new HashMap<>();
if( sb.getHops() != null )
for( Hop c : sb.getHops() )
- rRewriteHop(c, fedHops, fedVars);
-
- //TODO handle function calls
+ rRewriteHop(c, fedHops, fedVars,
sb.getDMLProg());
//propagate federated outputs across DAGs
if( sb.getHops() != null )
@@ -129,19 +127,31 @@ public class FederatedPlannerFedAll extends AFederatedPlanner {
}
}
- private void rRewriteHop(Hop hop, Map<Long, FType> memo, Map<String,
FType> fedVars) {
- if( memo.containsKey(hop.getHopID()) )
+ private void rRewriteHop(Hop hop, Map<Long, FType> memo, Map<String,
FType> fedVars, DMLProgram program) {
+ if( hop == null || memo.containsKey(hop.getHopID()) )
return; //already processed
//process children first
for( Hop c : hop.getInput() )
- rRewriteHop(c, memo, fedVars);
+ rRewriteHop(c, memo, fedVars, program);
//handle specific operators (except transient writes)
- if( HopRewriteUtils.isData(hop, OpOpData.FEDERATED) )
+ if(hop instanceof FunctionOp) {
+ String funcName = ((FunctionOp) hop).getFunctionName();
+ String funcNamespace = ((FunctionOp)
hop).getFunctionNamespace();
+ FunctionStatementBlock sbFuncBlock =
program.getFunctionDictionary(funcNamespace).getFunction(funcName);
+ FunctionStatement funcStatement = (FunctionStatement)
sbFuncBlock.getStatement(0);
+
+ Map<String, FType> funcFedVars =
createFunctionFedVarTable((FunctionOp) hop, memo);
+ rRewriteStatementBlock(sbFuncBlock, funcFedVars);
+ mapFunctionOutputs((FunctionOp) hop, funcStatement,
funcFedVars, fedVars);
+ }
+ else if( HopRewriteUtils.isData(hop, OpOpData.FEDERATED) )
memo.put(hop.getHopID(), deriveFType((DataOp)hop));
else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) )
memo.put(hop.getHopID(), fedVars.get(hop.getName()));
+ else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTWRITE) )
+ fedVars.put(hop.getName(), memo.get(hop.getHopID()));
else if( allowsFederated(hop, memo) ) {
hop.setForcedExecType(ExecType.FED);
memo.put(hop.getHopID(), getFederatedOut(hop, memo));
@@ -151,4 +161,21 @@ public class FederatedPlannerFedAll extends AFederatedPlanner {
else // memoization as processed, but not federated
memo.put(hop.getHopID(), null);
}
+
+ static private Map<String, FType> createFunctionFedVarTable(FunctionOp
hop, Map<Long, FType> memo) {
+ Map<String, Hop> funcParamMap =
FederatedPlannerUtils.getParamMap(hop);
+ Map<String, FType> funcFedVars = new HashMap<>();
+ funcParamMap.forEach((key, value) -> {
+ funcFedVars.put(key, memo.get(value.getHopID()));
+ });
+ return funcFedVars;
+ }
+
+ private void mapFunctionOutputs(FunctionOp sbHop, FunctionStatement
funcStatement,
+ Map<String, FType> funcFedVars, Map<String, FType> callFedVars)
{
+ for(int i = 0; i < sbHop.getOutputVariableNames().length; ++i) {
+ FType outputFType =
funcFedVars.get(funcStatement.getOutputParams().get(i).getName());
+ callFedVars.put(sbHop.getOutputVariableNames()[i],
outputFType);
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
index 42c5f648f1..5951bd313a 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
@@ -21,7 +21,6 @@ package org.apache.sysds.hops.fedplanner;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
-import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
@@ -76,26 +75,4 @@ public class FederatedPlannerUtils {
}
return paramMap;
}
-
- /**
- * Saves the HOPs (TWrite) of the function return values for
- * the variable name used when calling the function.
- *
- * Example:
- * <code>
- * f = function() return (matrix[double] model) {a = rand(1, 1);}
- * b = f();
- * </code>
- * This function saves the HOP writing to <code>a</code> for identifier
<code>b</code>.
- *
- * @param sbHop The <code>FunctionOp</code> for the call
- * @param funcStatement The <code>FunctionStatement</code> of the
called function
- * @param transientWrites map of transient writes
- */
- public static void mapFunctionOutputs(FunctionOp sbHop,
FunctionStatement funcStatement, Map<String,Hop> transientWrites) {
- for (int i = 0; i < sbHop.getOutputVariableNames().length; ++i)
{
- Hop outputWrite =
transientWrites.get(funcStatement.getOutputParams().get(i).getName());
- transientWrites.put(sbHop.getOutputVariableNames()[i],
outputWrite);
- }
- }
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
index 3e8f8719a6..b37386c5d0 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
@@ -27,7 +27,6 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import org.junit.Ignore;
import org.junit.Test;
import java.io.File;
@@ -72,7 +71,6 @@ public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
}
@Test
- @Ignore //TODO
public void runL2SVMFunctionFOUTTest(){
String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*", "fed_tak+*", "fed_+*",
"fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
@@ -81,7 +79,6 @@ public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
}
@Test
- @Ignore //TODO
public void runL2SVMFunctionHeuristicTest(){
String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*"};
setTestConf("SystemDS-config-heuristic.xml");