This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 13611b0d92 [SYSTEMDS-3018] Support function calls in FedAll and Heuristic planners
13611b0d92 is described below

commit 13611b0d92b2423c18f3ac9725a93b39b2720aa0
Author: Kevin Innerebner <[email protected]>
AuthorDate: Thu May 8 16:40:40 2025 +0200

    [SYSTEMDS-3018] Support function calls in FedAll and Heuristic planners
    
    Closes #1666.
---
 .../hops/fedplanner/FederatedPlannerFedAll.java    | 53 ++++++++++++++++------
 .../hops/fedplanner/FederatedPlannerUtils.java     | 23 ----------
 .../fedplanning/FederatedL2SVMPlanningTest.java    |  3 --
 3 files changed, 40 insertions(+), 39 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
index 4bf2e5606a..59967b7cf1 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
@@ -19,13 +19,13 @@
 
 package org.apache.sysds.hops.fedplanner;
 
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
 import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.common.Types.OpOpData;
 import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.FunctionOp;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.fedplanner.FTypes.FType;
 import org.apache.sysds.hops.ipa.FunctionCallGraph;
@@ -89,14 +89,14 @@ public class FederatedPlannerFedAll extends AFederatedPlanner {
                else if (sb instanceof WhileStatementBlock) {
                        WhileStatementBlock wsb = (WhileStatementBlock) sb;
                        WhileStatement wstmt = 
(WhileStatement)wsb.getStatement(0);
-                       rRewriteHop(wsb.getPredicateHops(), new HashMap<>(), 
Collections.emptyMap());
+                       rRewriteHop(wsb.getPredicateHops(), new HashMap<>(), 
new HashMap<>(), sb.getDMLProg());
                        for (StatementBlock csb : wstmt.getBody())
                                rRewriteStatementBlock(csb, fedVars);
                }
                else if (sb instanceof IfStatementBlock) {
                        IfStatementBlock isb = (IfStatementBlock) sb;
                        IfStatement istmt = (IfStatement)isb.getStatement(0);
-                       rRewriteHop(isb.getPredicateHops(), new HashMap<>(), 
Collections.emptyMap());
+                       rRewriteHop(isb.getPredicateHops(), new HashMap<>(), 
new HashMap<>(), sb.getDMLProg());
                        for (StatementBlock csb : istmt.getIfBody())
                                rRewriteStatementBlock(csb, fedVars);
                        for (StatementBlock csb : istmt.getElseBody())
@@ -105,9 +105,9 @@ public class FederatedPlannerFedAll extends AFederatedPlanner {
                else if (sb instanceof ForStatementBlock) { //incl parfor
                        ForStatementBlock fsb = (ForStatementBlock) sb;
                        ForStatement fstmt = (ForStatement)fsb.getStatement(0);
-                       rRewriteHop(fsb.getFromHops(), new HashMap<>(), 
Collections.emptyMap());
-                       rRewriteHop(fsb.getToHops(), new HashMap<>(), 
Collections.emptyMap());
-                       rRewriteHop(fsb.getIncrementHops(), new HashMap<>(), 
Collections.emptyMap());
+                       rRewriteHop(fsb.getFromHops(), new HashMap<>(), new 
HashMap<>(), sb.getDMLProg());
+                       rRewriteHop(fsb.getToHops(), new HashMap<>(), new 
HashMap<>(), sb.getDMLProg());
+                       rRewriteHop(fsb.getIncrementHops(), new HashMap<>(), 
new HashMap<>(), sb.getDMLProg());
                        for (StatementBlock csb : fstmt.getBody())
                                rRewriteStatementBlock(csb, fedVars);
                }
@@ -117,9 +117,7 @@ public class FederatedPlannerFedAll extends AFederatedPlanner {
                        Map<Long, FType> fedHops = new HashMap<>();
                        if( sb.getHops() != null )
                                for( Hop c : sb.getHops() )
-                                       rRewriteHop(c, fedHops, fedVars);
-                       
-                       //TODO handle function calls
+                                       rRewriteHop(c, fedHops, fedVars, 
sb.getDMLProg());
                        
                        //propagate federated outputs across DAGs
                        if( sb.getHops() != null )
@@ -129,19 +127,31 @@ public class FederatedPlannerFedAll extends AFederatedPlanner {
                }
        }
        
-       private void rRewriteHop(Hop hop, Map<Long, FType> memo, Map<String, 
FType> fedVars) {
-               if( memo.containsKey(hop.getHopID()) )
+       private void rRewriteHop(Hop hop, Map<Long, FType> memo, Map<String, 
FType> fedVars, DMLProgram program) {
+               if( hop == null || memo.containsKey(hop.getHopID()) )
                        return; //already processed
                
                //process children first
                for( Hop c : hop.getInput() )
-                       rRewriteHop(c, memo, fedVars);
+                       rRewriteHop(c, memo, fedVars, program);
                
                //handle specific operators (except transient writes)
-               if( HopRewriteUtils.isData(hop, OpOpData.FEDERATED) )
+               if(hop instanceof FunctionOp) {
+                       String funcName = ((FunctionOp) hop).getFunctionName();
+                       String funcNamespace = ((FunctionOp) 
hop).getFunctionNamespace();
+                       FunctionStatementBlock sbFuncBlock = 
program.getFunctionDictionary(funcNamespace).getFunction(funcName);
+                       FunctionStatement funcStatement = (FunctionStatement) 
sbFuncBlock.getStatement(0);
+
+                       Map<String, FType> funcFedVars = 
createFunctionFedVarTable((FunctionOp) hop, memo);
+                       rRewriteStatementBlock(sbFuncBlock, funcFedVars);
+                       mapFunctionOutputs((FunctionOp) hop, funcStatement, 
funcFedVars, fedVars);
+               }
+               else if( HopRewriteUtils.isData(hop, OpOpData.FEDERATED) )
                        memo.put(hop.getHopID(), deriveFType((DataOp)hop));
                else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) )
                        memo.put(hop.getHopID(), fedVars.get(hop.getName()));
+               else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTWRITE) )
+                       fedVars.put(hop.getName(), memo.get(hop.getHopID()));
                else if( allowsFederated(hop, memo) ) {
                        hop.setForcedExecType(ExecType.FED);
                        memo.put(hop.getHopID(), getFederatedOut(hop, memo));
@@ -151,4 +161,21 @@ public class FederatedPlannerFedAll extends AFederatedPlanner {
                else // memoization as processed, but not federated
                        memo.put(hop.getHopID(), null);
        }
+       
+       static private Map<String, FType> createFunctionFedVarTable(FunctionOp 
hop, Map<Long, FType> memo) {
+               Map<String, Hop> funcParamMap = 
FederatedPlannerUtils.getParamMap(hop);
+               Map<String, FType> funcFedVars = new HashMap<>();
+               funcParamMap.forEach((key, value) -> {
+                       funcFedVars.put(key, memo.get(value.getHopID()));
+               });
+               return funcFedVars;
+       }
+
+       private void mapFunctionOutputs(FunctionOp sbHop, FunctionStatement 
funcStatement,
+               Map<String, FType> funcFedVars, Map<String, FType> callFedVars) 
{
+               for(int i = 0; i < sbHop.getOutputVariableNames().length; ++i) {
+                       FType outputFType = 
funcFedVars.get(funcStatement.getOutputParams().get(i).getName());
+                       callFedVars.put(sbHop.getOutputVariableNames()[i], 
outputFType);
+               }
+       }
 }
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
index 42c5f648f1..5951bd313a 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
@@ -21,7 +21,6 @@ package org.apache.sysds.hops.fedplanner;
 
 import org.apache.sysds.hops.FunctionOp;
 import org.apache.sysds.hops.Hop;
-import org.apache.sysds.parser.FunctionStatement;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
 
@@ -76,26 +75,4 @@ public class FederatedPlannerUtils {
                }
                return paramMap;
        }
-
-       /**
-        * Saves the HOPs (TWrite) of the function return values for
-        * the variable name used when calling the function.
-        *
-        * Example:
-        * <code>
-        *     f = function() return (matrix[double] model) {a = rand(1, 1);}
-        *     b = f();
-        * </code>
-        * This function saves the HOP writing to <code>a</code> for identifier 
<code>b</code>.
-        *
-        * @param sbHop The <code>FunctionOp</code> for the call
-        * @param funcStatement The <code>FunctionStatement</code> of the 
called function
-        * @param transientWrites map of transient writes
-        */
-       public static void mapFunctionOutputs(FunctionOp sbHop, 
FunctionStatement funcStatement, Map<String,Hop> transientWrites) {
-               for (int i = 0; i < sbHop.getOutputVariableNames().length; ++i) 
{
-                       Hop outputWrite = 
transientWrites.get(funcStatement.getOutputParams().get(i).getName());
-                       transientWrites.put(sbHop.getOutputVariableNames()[i], 
outputWrite);
-               }
-       }
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
index 3e8f8719a6..b37386c5d0 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
@@ -27,7 +27,6 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
-import org.junit.Ignore;
 import org.junit.Test;
 
 import java.io.File;
@@ -72,7 +71,6 @@ public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
        }
 
        @Test
-       @Ignore //TODO
        public void runL2SVMFunctionFOUTTest(){
                String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_+*",
                        "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
@@ -81,7 +79,6 @@ public class FederatedL2SVMPlanningTest extends AutomatedTestBase {
        }
 
        @Test
-       @Ignore //TODO
        public void runL2SVMFunctionHeuristicTest(){
                String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*"};
                setTestConf("SystemDS-config-heuristic.xml");

Reply via email to