This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 29b4d92858 [SYSTEMDS-3790] Rework FedPlanner memo table, cost 
estimator, enumerator
29b4d92858 is described below

commit 29b4d9285867bfd7e5b92276ca25137e38d9fb2a
Author: min-guk <[email protected]>
AuthorDate: Sat Dec 21 16:04:56 2024 +0100

    [SYSTEMDS-3790] Rework FedPlanner memo table, cost estimator, enumerator
    
    Closes #2147.
---
 .../sysds/hops/fedplanner/FederatedMemoTable.java  | 288 +++++++++++++++++++++
 .../fedplanner/FederatedPlanCostEnumerator.java    | 136 ++++++++++
 .../fedplanner/FederatedPlanCostEstimator.java     | 116 +++++++++
 .../apache/sysds/hops/fedplanner/MemoTable.java    | 160 ------------
 .../federated/FederatedPlanCostEnumeratorTest.java |  87 +++++++
 .../test/component/federated/MemoTableTest.java    | 186 -------------
 src/test/scripts/functions/federated/cost.dml      |  25 ++
 7 files changed, 652 insertions(+), 346 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
new file mode 100644
index 0000000000..16240f0281
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * A Memoization Table for managing federated plans (FedPlan) based on 
combinations of Hops and fedOutTypes.
+ * This table stores and manages different execution plan variants for each 
Hop and fedOutType combination,
+ * facilitating the optimization of federated execution plans.
+ */
+public class FederatedMemoTable {
+       // Maps Hop ID and fedOutType pairs to their plan variants
+       private final Map<Pair<Long, FederatedOutput>, FedPlanVariants> 
hopMemoTable = new HashMap<>();
+
+       /**
+        * Adds a new federated plan to the memo table.
+        * Creates a new variant list if none exists for the given Hop and 
fedOutType.
+        *
+        * @param hop            The Hop node
+        * @param fedOutType  The federated output type
+        * @param planChilds  List of child plan references
+        * @return                 The newly created FedPlan
+        */
+       public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, 
List<Pair<Long, FederatedOutput>> planChilds) {
+               long hopID = hop.getHopID();
+               FedPlanVariants fedPlanVariantList;
+
+               if (contains(hopID, fedOutType)) {
+                       fedPlanVariantList = hopMemoTable.get(new 
ImmutablePair<>(hopID, fedOutType));
+               } else {
+                       fedPlanVariantList = new FedPlanVariants(hop, 
fedOutType);
+                       hopMemoTable.put(new ImmutablePair<>(hopID, 
fedOutType), fedPlanVariantList);
+               }
+
+               FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList);
+               fedPlanVariantList.addFedPlan(newPlan);
+
+               return newPlan;
+       }
+
+       /**
+        * Retrieves the minimum cost plan among the variants stored for the 
given child Hop ID and output type.
+        * Plans are compared by their total cost (getTotalCost); the parent's 
output type is not considered here.
+        * 
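+        * @param childHopID Hop ID of the child whose plan variants are looked up
+        * @param childFedOutType Federated output type of the child plan variants
+        * @return The variant with the lowest total cost, or null if the variant list is empty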
+        */
+       public FedPlan getMinCostChildFedPlan(long childHopID, FederatedOutput 
childFedOutType) {
+               FedPlanVariants fedPlanVariantList = hopMemoTable.get(new 
ImmutablePair<>(childHopID, childFedOutType));
+               return fedPlanVariantList._fedPlanVariants.stream()
+                               
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
+                               .orElse(null);
+       }
+
+       public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput 
fedOutType) {
+               return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
+       }
+
+       /**
+        * Checks if the memo table contains an entry for a given Hop and 
fedOutType.
+        *
+        * @param hopID   The Hop ID.
+        * @param fedOutType The associated fedOutType.
+        * @return True if the entry exists, false otherwise.
+        */
+       public boolean contains(long hopID, FederatedOutput fedOutType) {
+               return hopMemoTable.containsKey(new ImmutablePair<>(hopID, 
fedOutType));
+       }
+
+       /**
+        * Prunes all entries in the memo table, retaining only the minimum-cost
+        * FedPlan for each entry.
+        */
+       public void pruneMemoTable() {
+               for (Map.Entry<Pair<Long, FederatedOutput>, FedPlanVariants> 
entry : hopMemoTable.entrySet()) {
+                       List<FedPlan> fedPlanList = 
entry.getValue().getFedPlanVariants();
+                       if (fedPlanList.size() > 1) {
+                               // Find the FedPlan with the minimum cost
+                               FedPlan minCostPlan = fedPlanList.stream()
+                                                       
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
+                                               .orElse(null);
+
+                               // Retain only the minimum cost plan
+                               fedPlanList.clear();
+                               fedPlanList.add(minCostPlan);
+                       }
+               }
+       }
+
+       /**
+        * Recursively prints a tree representation of the DAG starting from 
the given root FedPlan.
+        * Includes information about hopID, fedOutType, TotalCost, SelfCost, 
and NetCost for each node.
+        *
+        * @param rootFedPlan The starting point FedPlan to print
+        */
+       public void printFedPlanTree(FedPlan rootFedPlan) {
+               Set<FedPlan> visited = new HashSet<>();
+               printFedPlanTreeRecursive(rootFedPlan, visited, 0, true);
+       }
+
+       /**
+        * Helper method to recursively print the FedPlan tree.
+        *
+        * @param plan  The current FedPlan to print
+        * @param visited Set to keep track of visited FedPlans (prevents 
cycles)
+        * @param depth   The current depth level (passed down to child calls; not yet used for indentation)
+        * @param isLast  Whether this node is the last child of its parent (currently not used in the output)
+        */
+       private void printFedPlanTreeRecursive(FedPlan plan, Set<FedPlan> 
visited, int depth, boolean isLast) {
+               if (plan == null || visited.contains(plan)) {
+                       return;
+               }
+
+               visited.add(plan);
+
+               Hop hop = plan.getHopRef();
+               StringBuilder sb = new StringBuilder();
+
+               // Add FedPlan information
+               sb.append(String.format("(%d) ", plan.getHopRef().getHopID()))
+                               .append(plan.getHopRef().getOpString())
+                               .append(" [")
+                               .append(plan.getFedOutType())
+                               .append("]");
+
+               StringBuilder childs = new StringBuilder();
+               childs.append(" (");
+               boolean childAdded = false;
+               for( Hop input : hop.getInput()){
+                       childs.append(childAdded?",":"");
+                       childs.append(input.getHopID());
+                       childAdded = true;
+               }
+               childs.append(")");
+               if( childAdded )
+                       sb.append(childs.toString());
+                
+                
+               sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
+                               plan.getTotalCost(),
+                               plan.getSelfCost(),
+                               plan.getNetTransferCost()));
+
+               // Add matrix characteristics
+               sb.append(" [")
+                       .append(hop.getDim1()).append(", ")
+                       .append(hop.getDim2()).append(", ")
+                       .append(hop.getBlocksize()).append(", ")
+                       .append(hop.getNnz());
+
+               if (hop.getUpdateType().isInPlace()) {
+                       sb.append(", 
").append(hop.getUpdateType().toString().toLowerCase());
+               }
+               sb.append("]");
+
+               // Add memory estimates
+               sb.append(" [")
+                       
.append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
+                       
.append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
+                       
.append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
+                       
.append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");
+
+               // Add reblock and checkpoint requirements
+               if (hop.requiresReblock() && hop.requiresCheckpoint()) {
+                       sb.append(" [rblk, chkpt]");
+               } else if (hop.requiresReblock()) {
+                       sb.append(" [rblk]");
+               } else if (hop.requiresCheckpoint()) {
+                       sb.append(" [chkpt]");
+               }
+
+               // Add execution type
+               if (hop.getExecType() != null) {
+                       sb.append(", ").append(hop.getExecType());
+               }
+
+               System.out.println(sb);
+
+               // Process child nodes
+               List<Pair<Long, FederatedOutput>> childRefs = 
plan.getChildFedPlans();
+               for (int i = 0; i < childRefs.size(); i++) {
+                       Pair<Long, FederatedOutput> childRef = childRefs.get(i);
+                       FedPlanVariants childVariants = 
getFedPlanVariants(childRef.getLeft(), childRef.getRight());
+                       if (childVariants == null || 
childVariants.getFedPlanVariants().isEmpty())
+                               continue;
+
+                       boolean isLastChild = (i == childRefs.size() - 1);
+                       for (FedPlan childPlan : 
childVariants.getFedPlanVariants()) {
+                               printFedPlanTreeRecursive(childPlan, visited, 
depth + 1, isLastChild);
+                       }
+               }
+       }
+
+       /**
+        * Represents a collection of federated execution plan variants for a 
specific Hop.
+        * Contains cost information and references to the associated plans.
+        */
+       public static class FedPlanVariants {
+               protected final Hop hopRef;              // Reference to the 
associated Hop
+               protected double selfCost;         // Current execution cost 
(compute + memory access)
+               protected double netTransferCost;   // Network transfer cost
+               private final FederatedOutput fedOutType;          // Output 
type (FOUT/LOUT)
+               protected List<FedPlan> _fedPlanVariants;       // List of plan 
variants
+
+               public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) {
+                       this.hopRef = hopRef;
+                       this.fedOutType = fedOutType;
+                       this.selfCost = 0;
+                       this.netTransferCost = 0;
+                       this._fedPlanVariants = new ArrayList<>();
+               }
+
+               public int size() {return _fedPlanVariants.size();}
+               public void addFedPlan(FedPlan fedPlan) 
{_fedPlanVariants.add(fedPlan);}
+               public List<FedPlan> getFedPlanVariants() {return 
_fedPlanVariants;}
+       }
+
+       /**
+        * Represents a single federated execution plan with its associated 
costs and dependencies.
+        * Contains:
+        * 1. selfCost: Cost of current hop (compute + input/output memory 
access)
+        * 2. totalCost: Cumulative cost including this plan and all child plans
+        * 3. netTransferCost: Network transfer cost for this plan to parent 
plan.
+        */
+       public static class FedPlan {
+               private double totalCost;                                 // 
Total cost including child plans
+               private final FedPlanVariants fedPlanVariants;  // Reference to 
variant list
+               private final List<Pair<Long, FederatedOutput>> childFedPlans;  
// Child plan references
+
+               public FedPlan(List<Pair<Long, FederatedOutput>> childFedPlans, 
FedPlanVariants fedPlanVariants) {
+                       this.totalCost = 0;
+                       this.childFedPlans = childFedPlans;
+                       this.fedPlanVariants = fedPlanVariants;
+               }
+
+               public void setTotalCost(double totalCost) {this.totalCost = 
totalCost;}
+               public void setSelfCost(double selfCost) 
{fedPlanVariants.selfCost = selfCost;}
+               public void setNetTransferCost(double netTransferCost) 
{fedPlanVariants.netTransferCost = netTransferCost;}
+
+               public Hop getHopRef() {return fedPlanVariants.hopRef;}
+               public FederatedOutput getFedOutType() {return 
fedPlanVariants.fedOutType;}
+               public double getTotalCost() {return totalCost;}
+               public double getSelfCost() {return fedPlanVariants.selfCost;}
+               private double getNetTransferCost() {return 
fedPlanVariants.netTransferCost;}
+               public List<Pair<Long, FederatedOutput>> getChildFedPlans() 
{return childFedPlans;}
+
+               /**
+                * Calculates the conditional network transfer cost based on 
output type compatibility.
+                * Returns 0 if output types match, otherwise returns the 
network transfer cost.
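+                * @param parentFedOutType Federated output type of the parent plan
+                * @return 0 if parentFedOutType equals this plan's output type, otherwise the network transfer cost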
+                */
+               public double getCondNetTransferCost(FederatedOutput 
parentFedOutType) {
+                       if (parentFedOutType == getFedOutType()) return 0;
+                       return fedPlanVariants.netTransferCost;
+               }
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
new file mode 100644
index 0000000000..73e8d5d693
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Comparator;
+import java.util.Objects;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
+import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants;
+import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+/**
+ * Enumerates and evaluates all possible federated execution plans for a given 
Hop DAG.
+ * Works with FederatedMemoTable to store plan variants and 
FederatedPlanCostEstimator
+ * to compute their costs.
+ */
+public class FederatedPlanCostEnumerator {
+       /**
+        * Entry point for federated plan enumeration. Creates a memo table and 
returns
+        * the minimum cost plan for the entire DAG.
+        * 
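+        * @param rootHop Root Hop of the DAG for which plans are enumerated
+        * @param printTree If true, prints the plan tree starting at the selected plan
+        * @return The minimum cost plan of the root Hop among its FOUT and LOUT variants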
+        */
+       public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean 
printTree) {
+               // Create new memo table to store all plan variants
+               FederatedMemoTable memoTable = new FederatedMemoTable();
+
+               // Recursively enumerate all possible plans
+               enumerateFederatedPlanCost(rootHop, memoTable);
+
+               // Return the minimum cost plan for the root node
+               FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), 
memoTable);
+               memoTable.pruneMemoTable();
+               if (printTree) memoTable.printFedPlanTree(optimalPlan);
+
+               return optimalPlan;
+       }
+
+       /**
+        * Recursively enumerates all possible federated execution plans for a 
Hop DAG.
+        * For each node:
+        * 1. First processes all input nodes recursively if not already 
processed
+        * 2. Generates all possible combinations of federation types 
(FOUT/LOUT) for inputs
+        * 3. Creates and evaluates both FOUT and LOUT variants for current 
node with each input combination
+        * 
+        * The enumeration uses a bottom-up approach where:
+        * - Each input combination is represented by a binary number (i)
+        * - Bit j in i determines whether input j is FOUT (1) or LOUT (0)
+        * - Total number of combinations is 2^numInputs
+        * 
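+        * @param hop Hop whose FOUT and LOUT plans are created after its inputs have been processed
+        * @param memoTable Memo table to which the created plans are added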
+        */
+       private static void enumerateFederatedPlanCost(Hop hop, 
FederatedMemoTable memoTable) {
+               int numInputs = hop.getInput().size();
+
+               // Process all input nodes first if not already in memo table
+               for (Hop inputHop : hop.getInput()) {
+                       if (!memoTable.contains(inputHop.getHopID(), 
FederatedOutput.FOUT) 
+                               && !memoTable.contains(inputHop.getHopID(), 
FederatedOutput.LOUT)) {
+                                       enumerateFederatedPlanCost(inputHop, 
memoTable);
+                       }
+               }
+
+               // Generate all possible input combinations using binary 
representation
+               // i represents a specific combination of FOUT/LOUT for inputs
+               for (int i = 0; i < (1 << numInputs); i++) {
+                       List<Pair<Long, FederatedOutput>> planChilds = new 
ArrayList<>(); 
+
+                       // For each input, determine if it should be FOUT or 
LOUT based on bit j in i
+                       for (int j = 0; j < numInputs; j++) {
+                               Hop inputHop = hop.getInput().get(j);
+                               // If bit j is set (1), use FOUT; otherwise use 
LOUT
+                               FederatedOutput childType = ((i & (1 << j)) != 
0) ?
+                                       FederatedOutput.FOUT : 
FederatedOutput.LOUT;
+                               planChilds.add(Pair.of(inputHop.getHopID(), 
childType));
+                       }
+                       
+                       // Create and evaluate FOUT variant for current input 
combination
+                       FedPlan fOutPlan = memoTable.addFedPlan(hop, 
FederatedOutput.FOUT, planChilds);
+                       
FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable);
+
+                       // Create and evaluate LOUT variant for current input 
combination
+                       FedPlan lOutPlan = memoTable.addFedPlan(hop, 
FederatedOutput.LOUT, planChilds);
+                       
FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable);
+               }
+       }
+
+       /**
+        * Returns the minimum cost plan for the root Hop, comparing both FOUT 
and LOUT variants.
+        * Used to select the final execution plan after enumeration.
+        * 
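+        * @param HopID Hop ID of the root Hop
+        * @param memoTable Memo table holding the FOUT and LOUT plan variants of the root Hop
+        * @return The cheapest FOUT plan if its total cost is strictly lower, otherwise the cheapest LOUT plan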
+        */
+       private static FedPlan getMinCostRootFedPlan(long HopID, 
FederatedMemoTable memoTable) {
+               FedPlanVariants fOutFedPlanVariants = 
memoTable.getFedPlanVariants(HopID, FederatedOutput.FOUT);
+               FedPlanVariants lOutFedPlanVariants = 
memoTable.getFedPlanVariants(HopID, FederatedOutput.LOUT);
+
+               FedPlan minFOutFedPlan = 
fOutFedPlanVariants._fedPlanVariants.stream()
+                                                                       
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
+                                                                       
.orElse(null);
+               FedPlan minlOutFedPlan = 
lOutFedPlanVariants._fedPlanVariants.stream()
+                                                                       
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
+                                                                       
.orElse(null);
+
+               if (Objects.requireNonNull(minFOutFedPlan).getTotalCost()
+                               < 
Objects.requireNonNull(minlOutFedPlan).getTotalCost()) {
+                       return minFOutFedPlan;
+               }
+               return minlOutFedPlan;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
new file mode 100644
index 0000000000..a716c3321d
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.cost.ComputeCost;
+import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
+import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+/**
+ * Cost estimator for federated execution plans.
+ * Calculates computation, memory access, and network transfer costs for 
federated operations.
+ * Works in conjunction with FederatedMemoTable to evaluate different 
execution plan variants.
+ */
+public class FederatedPlanCostEstimator {
+       // Default value is used as a reasonable estimate since we only need
+       // to compare relative costs between different federated plans
+       // Memory bandwidth for local computations (25 GB/s)
+       private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0;
+       // Network bandwidth for data transfers between federated sites (1 Gbps)
+       private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0;
+
+       /**
+        * Computes total cost of federated plan by:
+        * 1. Computing current node cost (if not cached)
+        * 2. Adding minimum-cost child plans
+        * 3. Including network transfer costs when needed
+        *
+        * @param currentPlan Plan to compute cost for
+        * @param memoTable Table containing all plan variants
+        */
+       public static void computeFederatedPlanCost(FedPlan currentPlan, 
FederatedMemoTable memoTable) {
+               double totalCost = 0;
+               Hop currentHop = currentPlan.getHopRef();
+
+               // Step 1: Calculate current node costs if not already computed
+               if (currentPlan.getSelfCost() == 0) {
+                       // Compute cost for current node (computation + memory 
access)
+                       totalCost = computeCurrentCost(currentHop);
+                       currentPlan.setSelfCost(totalCost);
+                       // Calculate potential network transfer cost if 
federation type changes
+                       
currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate()));
+               } else {
+                       totalCost = currentPlan.getSelfCost();
+               }
+               
+               // Step 2: Process each child plan and add their costs
+               for (Pair<Long, FederatedOutput> planRefMeta : 
currentPlan.getChildFedPlans()) {
+                       // Find minimum cost child plan considering federation 
type compatibility
+                       // Note: This approach might lead to suboptimal or 
wrong solutions when a child has multiple parents
+                       // because we're selecting child plans independently 
for each parent
+                       FedPlan planRef = 
memoTable.getMinCostChildFedPlan(planRefMeta.getLeft(), planRefMeta.getRight());
+
+                       // Add child plan cost (includes network transfer cost 
if federation types differ)
+                       totalCost += planRef.getTotalCost() + 
planRef.getCondNetTransferCost(currentPlan.getFedOutType());
+               }
+               
+               // Step 3: Set final cumulative cost including current node
+               currentPlan.setTotalCost(totalCost);
+       }
+
+       /**
+        * Computes the cost for the current Hop node.
+        * 
+        * @param currentHop The Hop node whose cost needs to be computed
+        * @return The total cost for the current node's operation
+        */
+       private static double computeCurrentCost(Hop currentHop){
+               double computeCost = ComputeCost.getHOPComputeCost(currentHop);
+               double inputAccessCost = 
computeHopMemoryAccessCost(currentHop.getInputMemEstimate());
+               double ouputAccessCost = 
computeHopMemoryAccessCost(currentHop.getOutputMemEstimate());
+               
+               // Compute total cost assuming:
+               // 1. Computation and input access can be overlapped (hence 
taking max)
+               // 2. Output access must wait for both to complete (hence 
adding)
+               return Math.max(computeCost, inputAccessCost) + ouputAccessCost;
+       }
+
+       /**
+        * Calculates the memory access cost based on data size and memory 
bandwidth.
+        * 
+        * @param memSize Size of data to be accessed (in bytes)
+        * @return Time cost for memory access (in seconds)
+        */
+       private static double computeHopMemoryAccessCost(double memSize) {
+               return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH;
+       }
+
+       /**
+        * Calculates the network transfer cost based on data size and network 
bandwidth.
+        * Used when federation status changes between parent and child plans.
+        * 
+        * @param memSize Size of data to be transferred (in bytes)
+        * @return Time cost for network transfer (in seconds)
+        */
+       private static double computeHopNetworkAccessCost(double memSize) {
+               return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH;
+       }
+}
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
deleted file mode 100644
index f11b17b984..0000000000
--- a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
+++ /dev/null
@@ -1,160 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.hops.fedplanner;
-
-import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.fedplanner.FTypes.FType;
-import org.apache.commons.lang3.tuple.Pair;
-import org.apache.commons.lang3.tuple.ImmutablePair;
-
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.List;
-import java.util.ArrayList;
-import java.util.Map;
-
-/**
- * A Memoization Table for managing federated plans (`FedPlan`) based on
- * combinations of Hops and FTypes. Each combination is mapped to a list
- * of possible execution plans, allowing for pruning and optimization.
- */
-public class MemoTable {
-
-       // Maps combinations of Hop ID and FType to lists of FedPlans
-       private final Map<Pair<Long, FTypes.FType>, List<FedPlan>> hopMemoTable 
= new HashMap<>();
-       
-       /**
-        * Represents a federated execution plan with its cost and associated 
references.
-        */
-       public static class FedPlan {
-               @SuppressWarnings("unused")
-               private final Hop hopRef;                       // The 
associated Hop object
-               private final double cost;                      // Cost of this 
federated plan
-               @SuppressWarnings("unused")
-               private final List<Pair<Long, FType>> planRefs; // References 
to dependent plans
-
-               public FedPlan(Hop hopRef, double cost, List<Pair<Long, FType>> 
planRefs) {
-                       this.hopRef = hopRef;
-                       this.cost = cost;
-                       this.planRefs = planRefs;
-               }
-
-               public double getCost() {
-                       return cost;
-               }
-       }
-
-       /**
-        * Adds a single FedPlan to the memo table for a given Hop and FType.
-        * If the entry already exists, the new FedPlan is appended to the list.
-        *
-        * @param hop     The Hop object.
-        * @param fType   The associated FType.
-        * @param fedPlan The FedPlan to add.
-        */
-       public void addFedPlan(Hop hop, FType fType, FedPlan fedPlan) {
-               if (contains(hop, fType)) {
-                       List<FedPlan> fedPlanList = get(hop, fType);
-                       fedPlanList.add(fedPlan);
-               } else {
-                       List<FedPlan> fedPlanList = new ArrayList<>();
-                       fedPlanList.add(fedPlan);
-                       hopMemoTable.put(new ImmutablePair<>(hop.getHopID(), 
fType), fedPlanList);
-               }
-       }
-
-       /**
-        * Adds multiple FedPlans to the memo table for a given Hop and FType.
-        * If the entry already exists, the new FedPlans are appended to the 
list.
-        *
-        * @param hop    The Hop object.
-        * @param fType  The associated FType.
-        * @param fedPlanList  The list of FedPlans to add.
-        */
-       public void addFedPlanList(Hop hop, FType fType, List<FedPlan> 
fedPlanList) {
-               if (contains(hop, fType)) {
-                       List<FedPlan> prevFedPlanList = get(hop, fType);
-                       prevFedPlanList.addAll(fedPlanList);
-               } else {
-                       hopMemoTable.put(new ImmutablePair<>(hop.getHopID(), 
fType), fedPlanList);
-               }
-       }
-
-       /**
-        * Retrieves the list of FedPlans associated with a given Hop and FType.
-        *
-        * @param hop   The Hop object.
-        * @param fType The associated FType.
-        * @return The list of FedPlans, or null if no entry exists.
-        */
-       public List<FedPlan> get(Hop hop, FType fType) {
-               return hopMemoTable.get(new ImmutablePair<>(hop.getHopID(), 
fType));
-       }
-
-       /**
-        * Checks if the memo table contains an entry for a given Hop and FType.
-        *
-        * @param hop   The Hop object.
-        * @param fType The associated FType.
-        * @return True if the entry exists, false otherwise.
-        */
-       public boolean contains(Hop hop, FType fType) {
-               return hopMemoTable.containsKey(new 
ImmutablePair<>(hop.getHopID(), fType));
-       }
-
-       /**
-        * Prunes the FedPlans associated with a specific Hop and FType,
-        * keeping only the plan with the minimum cost.
-        *
-        * @param hop   The Hop object.
-        * @param fType The associated FType.
-        */
-       public void prunePlan(Hop hop, FType fType) {
-               prunePlan(hopMemoTable.get(new ImmutablePair<>(hop.getHopID(), 
fType)));
-       }
-
-       /**
-        * Prunes all entries in the memo table, retaining only the minimum-cost
-        * FedPlan for each entry.
-        */
-       public void pruneAll() {
-               for (Map.Entry<Pair<Long, FType>, List<FedPlan>> entry : 
hopMemoTable.entrySet()) {
-                       prunePlan(entry.getValue());
-               }
-       }
-
-       /**
-        * Prunes the given list of FedPlans to retain only the plan with the 
minimum cost.
-        *
-        * @param fedPlanList The list of FedPlans to prune.
-        */
-       private void prunePlan(List<FedPlan> fedPlanList) {
-               if (fedPlanList.size() > 1) {
-                       // Find the FedPlan with the minimum cost
-                       FedPlan minCostPlan = fedPlanList.stream()
-                                       .min(Comparator.comparingDouble(plan -> 
plan.cost))
-                                       .orElse(null);
-
-                       // Retain only the minimum cost plan
-                       fedPlanList.clear();
-                       fedPlanList.add(minCostPlan);
-               }
-       }
-}
diff --git 
a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
 
b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
new file mode 100644
index 0000000000..56de8cf3c4
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.federated;
+
+import java.io.IOException;
+import java.util.HashMap;
+
+import org.apache.sysds.hops.Hop;
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DMLTranslator;
+import org.apache.sysds.parser.ParserFactory;
+import org.apache.sysds.parser.ParserWrapper;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator;
+
+
+/** Smoke test: compiles a DML script into a HOP DAG and runs the federated plan cost enumerator on it. */
+public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase
+{
+       private static final String TEST_DIR = "functions/federated/";
+       private static final String HOME = SCRIPT_DIR + TEST_DIR;
+       private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {}
+
+       @Test
+       public void testDependencyAnalysis1() { runTest("cost.dml"); }
+
+       /** Compiles the script found under HOME and enumerates plan costs from the first HOP of its first statement block. */
+       private void runTest( String scriptFilename ) {
+               // the test name is the script file name without its ".dml" suffix
+               int suffixPos = scriptFilename.lastIndexOf(".dml");
+               String testName = (suffixPos > 0) ? scriptFilename.substring(0, suffixPos) : scriptFilename;
+               TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {});
+               addTestConfiguration(testName, testConfig);
+               loadTestConfiguration(testConfig);
+
+               try {
+                       ConfigurationManager.setLocalConfig(new DMLConfig(getCurConfigFile().getPath()));
+                       String dmlScriptString = DMLScript.readDMLScript(true, HOME + scriptFilename);
+
+                       // parse the script, then run the compilation phases through LOP construction
+                       ParserWrapper parser = ParserFactory.createParser();
+                       DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>());
+                       DMLTranslator dmlt = new DMLTranslator(prog);
+                       dmlt.liveVariableAnalysis(prog);
+                       dmlt.validateParseTree(prog);
+                       dmlt.constructHops(prog);
+                       dmlt.rewriteHopsDAG(prog);
+                       dmlt.constructLops(prog);
+
+                       // TODO: in the current DAG, a Hop's _outputMemEstimate is not initialized, which leads to
+                       // incorrect fedplan generation; this test needs to be adapted, or the cost estimator
+                       // changed to handle the case where _outputMemEstimate is not initialized.
+                       Hop rootHop = prog.getStatementBlocks().get(0).getHops().get(0);
+                       FederatedPlanCostEnumerator.enumerateFederatedPlanCost(rootHop, true);
+               }
+               catch (IOException ex) {
+                       ex.printStackTrace();
+                       Assert.fail();
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java 
b/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java
deleted file mode 100644
index e3928c1263..0000000000
--- a/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java
+++ /dev/null
@@ -1,186 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.test.component.federated;
-
-import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.fedplanner.FTypes;
-import org.apache.sysds.hops.fedplanner.MemoTable;
-import org.apache.sysds.hops.fedplanner.MemoTable.FedPlan;
-import org.apache.commons.lang3.tuple.Pair;
-import org.junit.Before;
-import org.junit.Test;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-import java.util.ArrayList;
-import java.util.List;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
-import static org.mockito.Mockito.when;
-
-public class MemoTableTest {
-       
-       private MemoTable memoTable;
-       
-       @Mock
-       private Hop mockHop1;
-       
-       @Mock
-       private Hop mockHop2;
-       
-       private java.util.Random rand;
-
-       @Before
-       public void setUp() {
-               MockitoAnnotations.openMocks(this);
-               memoTable = new MemoTable();
-               
-               // Set up unique IDs for mock Hops
-               when(mockHop1.getHopID()).thenReturn(1L);
-               when(mockHop2.getHopID()).thenReturn(2L);
-               
-               // Initialize random generator with fixed seed for reproducible 
tests
-               rand = new java.util.Random(42); 
-       }
-       
-       @Test
-       public void testAddAndGetSingleFedPlan() {
-               // Initialize test data
-               List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
-               FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs);
-               
-               // Verify initial state
-               List<FedPlan> result = memoTable.get(mockHop1, 
FTypes.FType.FULL);
-               assertNull("Initial FedPlan list should be null before adding 
any plans", result);
-
-               // Add single FedPlan
-               memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan);
-               
-               // Verify after addition
-               result = memoTable.get(mockHop1, FTypes.FType.FULL);
-               assertNotNull("FedPlan list should exist after adding a plan", 
result);
-               assertEquals("FedPlan list should contain exactly one plan", 1, 
result.size());
-               assertEquals("FedPlan cost should be exactly 10.0", 10.0, 
result.get(0).getCost(), 0.001);
-       }
-       
-       @Test
-       public void testAddMultipleDuplicatedFedPlans() {
-               // Initialize test data with duplicate costs
-               List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
-               List<FedPlan> fedPlans = new ArrayList<>();
-               fedPlans.add(new FedPlan(mockHop1, 10.0, planRefs));  // Unique 
cost
-               fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs));  // First 
duplicate
-               fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs));  // Second 
duplicate
-               
-               // Add multiple plans including duplicates
-               memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans);
-               
-               // Verify handling of duplicate plans
-               List<FedPlan> result = memoTable.get(mockHop1, 
FTypes.FType.FULL);
-               assertNotNull("FedPlan list should exist after adding multiple 
plans", result);
-               assertEquals("FedPlan list should maintain all plans including 
duplicates", 3, result.size());
-       }
-       
-       @Test
-       public void testContains() {
-               // Initialize test data
-               List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
-               FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs);
-               
-               // Verify initial state
-               assertFalse("MemoTable should not contain any entries 
initially", 
-                       memoTable.contains(mockHop1, FTypes.FType.FULL));
-               
-               // Add plan and verify presence
-               memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan);
-               
-               assertTrue("MemoTable should contain entry after adding 
FedPlan", 
-                       memoTable.contains(mockHop1, FTypes.FType.FULL));
-               assertFalse("MemoTable should not contain entries for different 
Hop", 
-                       memoTable.contains(mockHop2, FTypes.FType.FULL));
-       }
-       
-       @Test
-       public void testPrunePlanPruneAll() {
-               // Initialize base test data
-               List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
-               // Create separate FedPlan lists for independent testing of 
each Hop
-               List<FedPlan> fedPlans1 = new ArrayList<>();  // Plans for 
mockHop1
-               List<FedPlan> fedPlans2 = new ArrayList<>();  // Plans for 
mockHop2
-               
-               // Generate random cost FedPlans for both Hops
-               double minCost = Double.MAX_VALUE;
-               int size = 100;
-               for(int i = 0; i < size; i++) {
-                       double cost = rand.nextDouble() * 1000;  // Random cost 
between 0 and 1000
-                       fedPlans1.add(new FedPlan(mockHop1, cost, planRefs));
-                       fedPlans2.add(new FedPlan(mockHop2, cost, planRefs));
-                       minCost = Math.min(minCost, cost);
-               }
-               
-               // Add FedPlan lists to MemoTable
-               memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, 
fedPlans1);
-               memoTable.addFedPlanList(mockHop2, FTypes.FType.FULL, 
fedPlans2);
-               
-               // Test selective pruning on mockHop1
-               memoTable.prunePlan(mockHop1, FTypes.FType.FULL);
-               
-               // Get results for verification
-               List<FedPlan> result1 = memoTable.get(mockHop1, 
FTypes.FType.FULL);
-               List<FedPlan> result2 = memoTable.get(mockHop2, 
FTypes.FType.FULL);
-
-               // Verify selective pruning results
-               assertNotNull("Pruned mockHop1 should maintain a FedPlan list", 
result1);
-               assertEquals("Pruned mockHop1 should contain exactly one 
minimum cost plan", 1, result1.size());
-               assertEquals("Pruned mockHop1's plan should have the minimum 
cost", minCost, result1.get(0).getCost(), 0.001);
-               
-               // Verify unpruned Hop state
-               assertNotNull("Unpruned mockHop2 should maintain a FedPlan 
list", result2);
-               assertEquals("Unpruned mockHop2 should maintain all original 
plans", size, result2.size());
-
-               // Add additional plans to both Hops
-               for(int i = 0; i < size; i++) {
-                       double cost = rand.nextDouble() * 1000;
-                       memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, new 
FedPlan(mockHop1, cost, planRefs));
-                       memoTable.addFedPlan(mockHop2, FTypes.FType.FULL, new 
FedPlan(mockHop2, cost, planRefs));
-                       minCost = Math.min(minCost, cost);
-               }
-
-               // Test global pruning
-               memoTable.pruneAll();
-               
-               // Verify global pruning results
-               assertNotNull("mockHop1 should maintain a FedPlan list after 
global pruning", result1);
-               assertEquals("mockHop1 should contain exactly one minimum cost 
plan after global pruning", 
-                       1, result1.size());
-               assertEquals("mockHop1's plan should have the global minimum 
cost", 
-                       minCost, result1.get(0).getCost(), 0.001);
-
-               assertNotNull("mockHop2 should maintain a FedPlan list after 
global pruning", result2);
-               assertEquals("mockHop2 should contain exactly one minimum cost 
plan after global pruning", 
-                       1, result2.size());
-               assertEquals("mockHop2's plan should have the global minimum 
cost", 
-                       minCost, result2.get(0).getCost(), 0.001);
-       }
-}
diff --git a/src/test/scripts/functions/federated/cost.dml 
b/src/test/scripts/functions/federated/cost.dml
new file mode 100644
index 0000000000..ec34d45bb6
--- /dev/null
+++ b/src/test/scripts/functions/federated/cost.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+a = matrix(7,10,10);  # 10x10 matrix with every cell set to 7
+b = a + a^2;          # element-wise: a plus a squared
+c = sqrt(b);          # element-wise square root of b
+print(sum(c));        # print the sum over all cells of c
\ No newline at end of file


Reply via email to