This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 29b4d92858 [SYSTEMDS-3790] Rework FedPlanner memo table, cost
estimator, enumerator
29b4d92858 is described below
commit 29b4d9285867bfd7e5b92276ca25137e38d9fb2a
Author: min-guk <[email protected]>
AuthorDate: Sat Dec 21 16:04:56 2024 +0100
[SYSTEMDS-3790] Rework FedPlanner memo table, cost estimator, enumerator
Closes #2147.
---
.../sysds/hops/fedplanner/FederatedMemoTable.java | 288 +++++++++++++++++++++
.../fedplanner/FederatedPlanCostEnumerator.java | 136 ++++++++++
.../fedplanner/FederatedPlanCostEstimator.java | 116 +++++++++
.../apache/sysds/hops/fedplanner/MemoTable.java | 160 ------------
.../federated/FederatedPlanCostEnumeratorTest.java | 87 +++++++
.../test/component/federated/MemoTableTest.java | 186 -------------
src/test/scripts/functions/federated/cost.dml | 25 ++
7 files changed, 652 insertions(+), 346 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
new file mode 100644
index 0000000000..16240f0281
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * A Memoization Table for managing federated plans (FedPlan) based on
combinations of Hops and fedOutTypes.
+ * This table stores and manages different execution plan variants for each
Hop and fedOutType combination,
+ * facilitating the optimization of federated execution plans.
+ */
+public class FederatedMemoTable {
+ // Maps Hop ID and fedOutType pairs to their plan variants
+ private final Map<Pair<Long, FederatedOutput>, FedPlanVariants>
hopMemoTable = new HashMap<>();
+
+ /**
+ * Adds a new federated plan to the memo table.
+ * Creates a new variant list if none exists for the given Hop and
fedOutType.
+ *
+ * @param hop The Hop node
+ * @param fedOutType The federated output type
+ * @param planChilds List of child plan references
+ * @return The newly created FedPlan
+ */
+ public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType,
List<Pair<Long, FederatedOutput>> planChilds) {
+ long hopID = hop.getHopID();
+ FedPlanVariants fedPlanVariantList;
+
+ if (contains(hopID, fedOutType)) {
+ fedPlanVariantList = hopMemoTable.get(new
ImmutablePair<>(hopID, fedOutType));
+ } else {
+ fedPlanVariantList = new FedPlanVariants(hop,
fedOutType);
+ hopMemoTable.put(new ImmutablePair<>(hopID,
fedOutType), fedPlanVariantList);
+ }
+
+ FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList);
+ fedPlanVariantList.addFedPlan(newPlan);
+
+ return newPlan;
+ }
+
+ /**
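+ * Retrieves the minimum total-cost plan stored for a child Hop and fedOutType.
+ * Plans are compared by getTotalCost only; no transfer cost to the parent is added here.
+ *
+ * @param childHopID ID of the child Hop whose plan variants are searched
+ * @param childFedOutType federated output type of the child entry to search
+ * @return the plan variant with the lowest total cost, or null if the entry holds no plans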
+ */
+ public FedPlan getMinCostChildFedPlan(long childHopID, FederatedOutput
childFedOutType) {
+ FedPlanVariants fedPlanVariantList = hopMemoTable.get(new
ImmutablePair<>(childHopID, childFedOutType));
+ return fedPlanVariantList._fedPlanVariants.stream()
+
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
+ .orElse(null);
+ }
+
+ public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput
fedOutType) {
+ return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
+ }
+
+ /**
+ * Checks if the memo table contains an entry for a given Hop and
fedOutType.
+ *
+ * @param hopID The Hop ID.
+ * @param fedOutType The associated fedOutType.
+ * @return True if the entry exists, false otherwise.
+ */
+ public boolean contains(long hopID, FederatedOutput fedOutType) {
+ return hopMemoTable.containsKey(new ImmutablePair<>(hopID,
fedOutType));
+ }
+
+ /**
+ * Prunes all entries in the memo table, retaining only the minimum-cost
+ * FedPlan for each entry.
+ */
+ public void pruneMemoTable() {
+ for (Map.Entry<Pair<Long, FederatedOutput>, FedPlanVariants>
entry : hopMemoTable.entrySet()) {
+ List<FedPlan> fedPlanList =
entry.getValue().getFedPlanVariants();
+ if (fedPlanList.size() > 1) {
+ // Find the FedPlan with the minimum cost
+ FedPlan minCostPlan = fedPlanList.stream()
+
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
+ .orElse(null);
+
+ // Retain only the minimum cost plan
+ fedPlanList.clear();
+ fedPlanList.add(minCostPlan);
+ }
+ }
+ }
+
+ /**
+ * Recursively prints a tree representation of the DAG starting from
the given root FedPlan.
+ * Includes information about hopID, fedOutType, TotalCost, SelfCost,
and NetCost for each node.
+ *
+ * @param rootFedPlan The starting point FedPlan to print
+ */
+ public void printFedPlanTree(FedPlan rootFedPlan) {
+ Set<FedPlan> visited = new HashSet<>();
+ printFedPlanTreeRecursive(rootFedPlan, visited, 0, true);
+ }
+
+ /**
+ * Helper method to recursively print the FedPlan tree.
+ *
+ * @param plan The current FedPlan to print
+ * @param visited Set to keep track of visited FedPlans (prevents
cycles)
+ * @param depth The current depth level (passed down, but not yet used for indentation)
+ * @param isLast Whether this node is the last child of its parent (not yet used in the output)
+ */
+ private void printFedPlanTreeRecursive(FedPlan plan, Set<FedPlan>
visited, int depth, boolean isLast) {
+ if (plan == null || visited.contains(plan)) {
+ return;
+ }
+
+ visited.add(plan);
+
+ Hop hop = plan.getHopRef();
+ StringBuilder sb = new StringBuilder();
+
+ // Add FedPlan information
+ sb.append(String.format("(%d) ", plan.getHopRef().getHopID()))
+ .append(plan.getHopRef().getOpString())
+ .append(" [")
+ .append(plan.getFedOutType())
+ .append("]");
+
+ StringBuilder childs = new StringBuilder();
+ childs.append(" (");
+ boolean childAdded = false;
+ for( Hop input : hop.getInput()){
+ childs.append(childAdded?",":"");
+ childs.append(input.getHopID());
+ childAdded = true;
+ }
+ childs.append(")");
+ if( childAdded )
+ sb.append(childs.toString());
+
+
+ sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
+ plan.getTotalCost(),
+ plan.getSelfCost(),
+ plan.getNetTransferCost()));
+
+ // Add matrix characteristics
+ sb.append(" [")
+ .append(hop.getDim1()).append(", ")
+ .append(hop.getDim2()).append(", ")
+ .append(hop.getBlocksize()).append(", ")
+ .append(hop.getNnz());
+
+ if (hop.getUpdateType().isInPlace()) {
+ sb.append(",
").append(hop.getUpdateType().toString().toLowerCase());
+ }
+ sb.append("]");
+
+ // Add memory estimates
+ sb.append(" [")
+
.append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
+
.append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
+
.append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
+
.append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");
+
+ // Add reblock and checkpoint requirements
+ if (hop.requiresReblock() && hop.requiresCheckpoint()) {
+ sb.append(" [rblk, chkpt]");
+ } else if (hop.requiresReblock()) {
+ sb.append(" [rblk]");
+ } else if (hop.requiresCheckpoint()) {
+ sb.append(" [chkpt]");
+ }
+
+ // Add execution type
+ if (hop.getExecType() != null) {
+ sb.append(", ").append(hop.getExecType());
+ }
+
+ System.out.println(sb);
+
+ // Process child nodes
+ List<Pair<Long, FederatedOutput>> childRefs =
plan.getChildFedPlans();
+ for (int i = 0; i < childRefs.size(); i++) {
+ Pair<Long, FederatedOutput> childRef = childRefs.get(i);
+ FedPlanVariants childVariants =
getFedPlanVariants(childRef.getLeft(), childRef.getRight());
+ if (childVariants == null ||
childVariants.getFedPlanVariants().isEmpty())
+ continue;
+
+ boolean isLastChild = (i == childRefs.size() - 1);
+ for (FedPlan childPlan :
childVariants.getFedPlanVariants()) {
+ printFedPlanTreeRecursive(childPlan, visited,
depth + 1, isLastChild);
+ }
+ }
+ }
+
+ /**
+ * Represents a collection of federated execution plan variants for a
specific Hop.
+ * Contains cost information and references to the associated plans.
+ */
+ public static class FedPlanVariants {
+ protected final Hop hopRef; // Reference to the
associated Hop
+ protected double selfCost; // Current execution cost
(compute + memory access)
+ protected double netTransferCost; // Network transfer cost
+ private final FederatedOutput fedOutType; // Output
type (FOUT/LOUT)
+ protected List<FedPlan> _fedPlanVariants; // List of plan
variants
+
+ public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) {
+ this.hopRef = hopRef;
+ this.fedOutType = fedOutType;
+ this.selfCost = 0;
+ this.netTransferCost = 0;
+ this._fedPlanVariants = new ArrayList<>();
+ }
+
+ public int size() {return _fedPlanVariants.size();}
+ public void addFedPlan(FedPlan fedPlan)
{_fedPlanVariants.add(fedPlan);}
+ public List<FedPlan> getFedPlanVariants() {return
_fedPlanVariants;}
+ }
+
+ /**
+ * Represents a single federated execution plan with its associated
costs and dependencies.
+ * Contains:
+ * 1. selfCost: Cost of current hop (compute + input/output memory
access)
+ * 2. totalCost: Cumulative cost including this plan and all child plans
+ * 3. netTransferCost: Network transfer cost for this plan to parent
plan.
+ */
+ public static class FedPlan {
+ private double totalCost; //
Total cost including child plans
+ private final FedPlanVariants fedPlanVariants; // Reference to
variant list
+ private final List<Pair<Long, FederatedOutput>> childFedPlans;
// Child plan references
+
+ public FedPlan(List<Pair<Long, FederatedOutput>> childFedPlans,
FedPlanVariants fedPlanVariants) {
+ this.totalCost = 0;
+ this.childFedPlans = childFedPlans;
+ this.fedPlanVariants = fedPlanVariants;
+ }
+
+ public void setTotalCost(double totalCost) {this.totalCost =
totalCost;}
+ public void setSelfCost(double selfCost)
{fedPlanVariants.selfCost = selfCost;}
+ public void setNetTransferCost(double netTransferCost)
{fedPlanVariants.netTransferCost = netTransferCost;}
+
+ public Hop getHopRef() {return fedPlanVariants.hopRef;}
+ public FederatedOutput getFedOutType() {return
fedPlanVariants.fedOutType;}
+ public double getTotalCost() {return totalCost;}
+ public double getSelfCost() {return fedPlanVariants.selfCost;}
+ private double getNetTransferCost() {return
fedPlanVariants.netTransferCost;}
+ public List<Pair<Long, FederatedOutput>> getChildFedPlans()
{return childFedPlans;}
+
+ /**
+ * Calculates the conditional network transfer cost based on
output type compatibility.
+ * Returns 0 if output types match, otherwise returns the
network transfer cost.
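+ * @param parentFedOutType federated output type of the parent plan
+ * @return 0 if parentFedOutType equals this plan's fedOutType, otherwise the network transfer cost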
+ */
+ public double getCondNetTransferCost(FederatedOutput
parentFedOutType) {
+ if (parentFedOutType == getFedOutType()) return 0;
+ return fedPlanVariants.netTransferCost;
+ }
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
new file mode 100644
index 0000000000..73e8d5d693
--- /dev/null
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Comparator;
+import java.util.Objects;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
+import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants;
+import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+/**
+ * Enumerates and evaluates all possible federated execution plans for a given
Hop DAG.
+ * Works with FederatedMemoTable to store plan variants and
FederatedPlanCostEstimator
+ * to compute their costs.
+ */
+public class FederatedPlanCostEnumerator {
+ /**
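+ * Entry point for federated plan enumeration. Creates a memo table, enumerates
+ * all plans below the root, prunes the table, and returns the minimum cost root plan.
+ *
+ * @param rootHop root Hop of the DAG to enumerate plans for
+ * @param printTree if true, prints the plan tree of the selected plan to stdout
+ * @return the minimum total-cost plan (FOUT or LOUT) of the root Hop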
+ */
+ public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean
printTree) {
+ // Create new memo table to store all plan variants
+ FederatedMemoTable memoTable = new FederatedMemoTable();
+
+ // Recursively enumerate all possible plans
+ enumerateFederatedPlanCost(rootHop, memoTable);
+
+ // Return the minimum cost plan for the root node
+ FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(),
memoTable);
+ memoTable.pruneMemoTable();
+ if (printTree) memoTable.printFedPlanTree(optimalPlan);
+
+ return optimalPlan;
+ }
+
+ /**
+ * Recursively enumerates all possible federated execution plans for a
Hop DAG.
+ * For each node:
+ * 1. First processes all input nodes recursively if not already
processed
+ * 2. Generates all possible combinations of federation types
(FOUT/LOUT) for inputs
+ * 3. Creates and evaluates both FOUT and LOUT variants for current
node with each input combination
+ *
+ * The enumeration uses a bottom-up approach where:
+ * - Each input combination is represented by a binary number (i)
+ * - Bit j in i determines whether input j is FOUT (1) or LOUT (0)
+ * - Total number of combinations is 2^numInputs
+ *
+ * @param hop Hop whose FOUT and LOUT plan variants are enumerated
+ * @param memoTable memo table that receives the enumerated plans
+ */
+ private static void enumerateFederatedPlanCost(Hop hop,
FederatedMemoTable memoTable) {
+ int numInputs = hop.getInput().size();
+
+ // Process all input nodes first if not already in memo table
+ for (Hop inputHop : hop.getInput()) {
+ if (!memoTable.contains(inputHop.getHopID(),
FederatedOutput.FOUT)
+ && !memoTable.contains(inputHop.getHopID(),
FederatedOutput.LOUT)) {
+ enumerateFederatedPlanCost(inputHop,
memoTable);
+ }
+ }
+
+ // Generate all possible input combinations using binary
representation
+ // i represents a specific combination of FOUT/LOUT for inputs
+ for (int i = 0; i < (1 << numInputs); i++) {
+ List<Pair<Long, FederatedOutput>> planChilds = new
ArrayList<>();
+
+ // For each input, determine if it should be FOUT or
LOUT based on bit j in i
+ for (int j = 0; j < numInputs; j++) {
+ Hop inputHop = hop.getInput().get(j);
+ // If bit j is set (1), use FOUT; otherwise use
LOUT
+ FederatedOutput childType = ((i & (1 << j)) !=
0) ?
+ FederatedOutput.FOUT :
FederatedOutput.LOUT;
+ planChilds.add(Pair.of(inputHop.getHopID(),
childType));
+ }
+
+ // Create and evaluate FOUT variant for current input
combination
+ FedPlan fOutPlan = memoTable.addFedPlan(hop,
FederatedOutput.FOUT, planChilds);
+
FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable);
+
+ // Create and evaluate LOUT variant for current input
combination
+ FedPlan lOutPlan = memoTable.addFedPlan(hop,
FederatedOutput.LOUT, planChilds);
+
FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable);
+ }
+ }
+
+ /**
+ * Returns the minimum cost plan for the root Hop, comparing both FOUT
and LOUT variants.
+ * Used to select the final execution plan after enumeration.
+ *
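+ * @param HopID ID of the root Hop
+ * @param memoTable memo table holding the FOUT and LOUT variants of the root Hop
+ * @return the cheaper of the two minimum-cost plans; the LOUT plan on a tie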
+ */
+ private static FedPlan getMinCostRootFedPlan(long HopID,
FederatedMemoTable memoTable) {
+ FedPlanVariants fOutFedPlanVariants =
memoTable.getFedPlanVariants(HopID, FederatedOutput.FOUT);
+ FedPlanVariants lOutFedPlanVariants =
memoTable.getFedPlanVariants(HopID, FederatedOutput.LOUT);
+
+ FedPlan minFOutFedPlan =
fOutFedPlanVariants._fedPlanVariants.stream()
+
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
+
.orElse(null);
+ FedPlan minlOutFedPlan =
lOutFedPlanVariants._fedPlanVariants.stream()
+
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
+
.orElse(null);
+
+ if (Objects.requireNonNull(minFOutFedPlan).getTotalCost()
+ <
Objects.requireNonNull(minlOutFedPlan).getTotalCost()) {
+ return minFOutFedPlan;
+ }
+ return minlOutFedPlan;
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
new file mode 100644
index 0000000000..a716c3321d
--- /dev/null
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.cost.ComputeCost;
+import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
+import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+/**
+ * Cost estimator for federated execution plans.
+ * Calculates computation, memory access, and network transfer costs for
federated operations.
+ * Works in conjunction with FederatedMemoTable to evaluate different
execution plan variants.
+ */
+public class FederatedPlanCostEstimator {
+ // Default value is used as a reasonable estimate since we only need
+ // to compare relative costs between different federated plans
+ // Memory bandwidth for local computations (25 GB/s)
+ private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0;
+ // Network bandwidth for data transfers between federated sites (1 Gbps)
+ private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0;
+
+ /**
+ * Computes total cost of federated plan by:
+ * 1. Computing current node cost (if not cached)
+ * 2. Adding minimum-cost child plans
+ * 3. Including network transfer costs when needed
+ *
+ * @param currentPlan Plan to compute cost for
+ * @param memoTable Table containing all plan variants
+ */
+ public static void computeFederatedPlanCost(FedPlan currentPlan,
FederatedMemoTable memoTable) {
+ double totalCost = 0;
+ Hop currentHop = currentPlan.getHopRef();
+
+ // Step 1: Calculate current node costs if not already computed
+ if (currentPlan.getSelfCost() == 0) {
+ // Compute cost for current node (computation + memory
access)
+ totalCost = computeCurrentCost(currentHop);
+ currentPlan.setSelfCost(totalCost);
+ // Calculate potential network transfer cost if
federation type changes
+
currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate()));
+ } else {
+ totalCost = currentPlan.getSelfCost();
+ }
+
+ // Step 2: Process each child plan and add their costs
+ for (Pair<Long, FederatedOutput> planRefMeta :
currentPlan.getChildFedPlans()) {
+ // Find minimum cost child plan considering federation
type compatibility
+ // Note: This approach might lead to suboptimal or
wrong solutions when a child has multiple parents
+ // because we're selecting child plans independently
for each parent
+ FedPlan planRef =
memoTable.getMinCostChildFedPlan(planRefMeta.getLeft(), planRefMeta.getRight());
+
+ // Add child plan cost (includes network transfer cost
if federation types differ)
+ totalCost += planRef.getTotalCost() +
planRef.getCondNetTransferCost(currentPlan.getFedOutType());
+ }
+
+ // Step 3: Set final cumulative cost including current node
+ currentPlan.setTotalCost(totalCost);
+ }
+
+ /**
+ * Computes the cost for the current Hop node.
+ *
+ * @param currentHop The Hop node whose cost needs to be computed
+ * @return The total cost for the current node's operation
+ */
+ private static double computeCurrentCost(Hop currentHop){
+ double computeCost = ComputeCost.getHOPComputeCost(currentHop);
+ double inputAccessCost =
computeHopMemoryAccessCost(currentHop.getInputMemEstimate());
+ double ouputAccessCost =
computeHopMemoryAccessCost(currentHop.getOutputMemEstimate());
+
+ // Compute total cost assuming:
+ // 1. Computation and input access can be overlapped (hence
taking max)
+ // 2. Output access must wait for both to complete (hence
adding)
+ return Math.max(computeCost, inputAccessCost) + ouputAccessCost;
+ }
+
+ /**
+ * Calculates the memory access cost based on data size and memory
bandwidth.
+ *
+ * @param memSize Size of data to be accessed (in bytes)
+ * @return Time cost for memory access (in seconds)
+ */
+ private static double computeHopMemoryAccessCost(double memSize) {
+ return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH;
+ }
+
+ /**
+ * Calculates the network transfer cost based on data size and network
bandwidth.
+ * Used when federation status changes between parent and child plans.
+ *
+ * @param memSize Size of data to be transferred (in bytes)
+ * @return Time cost for network transfer (in seconds)
+ */
+ private static double computeHopNetworkAccessCost(double memSize) {
+ return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
deleted file mode 100644
index f11b17b984..0000000000
--- a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java
+++ /dev/null
@@ -1,160 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.hops.fedplanner;
-
-import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.fedplanner.FTypes.FType;
-import org.apache.commons.lang3.tuple.Pair;
-import org.apache.commons.lang3.tuple.ImmutablePair;
-
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.List;
-import java.util.ArrayList;
-import java.util.Map;
-
-/**
- * A Memoization Table for managing federated plans (`FedPlan`) based on
- * combinations of Hops and FTypes. Each combination is mapped to a list
- * of possible execution plans, allowing for pruning and optimization.
- */
-public class MemoTable {
-
- // Maps combinations of Hop ID and FType to lists of FedPlans
- private final Map<Pair<Long, FTypes.FType>, List<FedPlan>> hopMemoTable
= new HashMap<>();
-
- /**
- * Represents a federated execution plan with its cost and associated
references.
- */
- public static class FedPlan {
- @SuppressWarnings("unused")
- private final Hop hopRef; // The
associated Hop object
- private final double cost; // Cost of this
federated plan
- @SuppressWarnings("unused")
- private final List<Pair<Long, FType>> planRefs; // References
to dependent plans
-
- public FedPlan(Hop hopRef, double cost, List<Pair<Long, FType>>
planRefs) {
- this.hopRef = hopRef;
- this.cost = cost;
- this.planRefs = planRefs;
- }
-
- public double getCost() {
- return cost;
- }
- }
-
- /**
- * Adds a single FedPlan to the memo table for a given Hop and FType.
- * If the entry already exists, the new FedPlan is appended to the list.
- *
- * @param hop The Hop object.
- * @param fType The associated FType.
- * @param fedPlan The FedPlan to add.
- */
- public void addFedPlan(Hop hop, FType fType, FedPlan fedPlan) {
- if (contains(hop, fType)) {
- List<FedPlan> fedPlanList = get(hop, fType);
- fedPlanList.add(fedPlan);
- } else {
- List<FedPlan> fedPlanList = new ArrayList<>();
- fedPlanList.add(fedPlan);
- hopMemoTable.put(new ImmutablePair<>(hop.getHopID(),
fType), fedPlanList);
- }
- }
-
- /**
- * Adds multiple FedPlans to the memo table for a given Hop and FType.
- * If the entry already exists, the new FedPlans are appended to the
list.
- *
- * @param hop The Hop object.
- * @param fType The associated FType.
- * @param fedPlanList The list of FedPlans to add.
- */
- public void addFedPlanList(Hop hop, FType fType, List<FedPlan>
fedPlanList) {
- if (contains(hop, fType)) {
- List<FedPlan> prevFedPlanList = get(hop, fType);
- prevFedPlanList.addAll(fedPlanList);
- } else {
- hopMemoTable.put(new ImmutablePair<>(hop.getHopID(),
fType), fedPlanList);
- }
- }
-
- /**
- * Retrieves the list of FedPlans associated with a given Hop and FType.
- *
- * @param hop The Hop object.
- * @param fType The associated FType.
- * @return The list of FedPlans, or null if no entry exists.
- */
- public List<FedPlan> get(Hop hop, FType fType) {
- return hopMemoTable.get(new ImmutablePair<>(hop.getHopID(),
fType));
- }
-
- /**
- * Checks if the memo table contains an entry for a given Hop and FType.
- *
- * @param hop The Hop object.
- * @param fType The associated FType.
- * @return True if the entry exists, false otherwise.
- */
- public boolean contains(Hop hop, FType fType) {
- return hopMemoTable.containsKey(new
ImmutablePair<>(hop.getHopID(), fType));
- }
-
- /**
- * Prunes the FedPlans associated with a specific Hop and FType,
- * keeping only the plan with the minimum cost.
- *
- * @param hop The Hop object.
- * @param fType The associated FType.
- */
- public void prunePlan(Hop hop, FType fType) {
- prunePlan(hopMemoTable.get(new ImmutablePair<>(hop.getHopID(),
fType)));
- }
-
- /**
- * Prunes all entries in the memo table, retaining only the minimum-cost
- * FedPlan for each entry.
- */
- public void pruneAll() {
- for (Map.Entry<Pair<Long, FType>, List<FedPlan>> entry :
hopMemoTable.entrySet()) {
- prunePlan(entry.getValue());
- }
- }
-
- /**
- * Prunes the given list of FedPlans to retain only the plan with the
minimum cost.
- *
- * @param fedPlanList The list of FedPlans to prune.
- */
- private void prunePlan(List<FedPlan> fedPlanList) {
- if (fedPlanList.size() > 1) {
- // Find the FedPlan with the minimum cost
- FedPlan minCostPlan = fedPlanList.stream()
- .min(Comparator.comparingDouble(plan ->
plan.cost))
- .orElse(null);
-
- // Retain only the minimum cost plan
- fedPlanList.clear();
- fedPlanList.add(minCostPlan);
- }
- }
-}
diff --git
a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
new file mode 100644
index 0000000000..56de8cf3c4
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.federated;
+
+import java.io.IOException;
+import java.util.HashMap;
+
+import org.apache.sysds.hops.Hop;
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DMLTranslator;
+import org.apache.sysds.parser.ParserFactory;
+import org.apache.sysds.parser.ParserWrapper;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator;
+
+
+public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase
+{
+ private static final String TEST_DIR = "functions/federated/";
+ private static final String HOME = SCRIPT_DIR + TEST_DIR;
+ private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {}
+
+ @Test
+ public void testDependencyAnalysis1() { runTest("cost.dml"); }
+
+ private void runTest( String scriptFilename ) {
+ int index = scriptFilename.lastIndexOf(".dml");
+ String testName = scriptFilename.substring(0, index > 0 ? index
: scriptFilename.length());
+ TestConfiguration testConfig = new
TestConfiguration(TEST_CLASS_DIR, testName, new String[] {});
+ addTestConfiguration(testName, testConfig);
+ loadTestConfiguration(testConfig);
+
+ try {
+ DMLConfig conf = new
DMLConfig(getCurConfigFile().getPath());
+ ConfigurationManager.setLocalConfig(conf);
+
+ //read script
+ String dmlScriptString = DMLScript.readDMLScript(true,
HOME + scriptFilename);
+
+ //parsing and dependency analysis
+ ParserWrapper parser = ParserFactory.createParser();
+ DMLProgram prog =
parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new
HashMap<>());
+ DMLTranslator dmlt = new DMLTranslator(prog);
+ dmlt.liveVariableAnalysis(prog);
+ dmlt.validateParseTree(prog);
+ dmlt.constructHops(prog);
+ dmlt.rewriteHopsDAG(prog);
+ dmlt.constructLops(prog);
+ /* TODO: In the current DAG, Hop's _outputMemEstimate is not initialized.
+ * This leads to incorrect fedplan generation, so the test code needs to be modified.
+ * If needed, modify the cost estimator to handle an uninitialized _outputMemEstimate.
+ */
+ Hop hops =
prog.getStatementBlocks().get(0).getHops().get(0);
+
FederatedPlanCostEnumerator.enumerateFederatedPlanCost(hops, true);
+ }
+ catch (IOException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java
b/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java
deleted file mode 100644
index e3928c1263..0000000000
--- a/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java
+++ /dev/null
@@ -1,186 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.test.component.federated;
-
-import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.fedplanner.FTypes;
-import org.apache.sysds.hops.fedplanner.MemoTable;
-import org.apache.sysds.hops.fedplanner.MemoTable.FedPlan;
-import org.apache.commons.lang3.tuple.Pair;
-import org.junit.Before;
-import org.junit.Test;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-import java.util.ArrayList;
-import java.util.List;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
-import static org.mockito.Mockito.when;
-
-public class MemoTableTest {
-
- private MemoTable memoTable;
-
- @Mock
- private Hop mockHop1;
-
- @Mock
- private Hop mockHop2;
-
- private java.util.Random rand;
-
- @Before
- public void setUp() {
- MockitoAnnotations.openMocks(this);
- memoTable = new MemoTable();
-
- // Set up unique IDs for mock Hops
- when(mockHop1.getHopID()).thenReturn(1L);
- when(mockHop2.getHopID()).thenReturn(2L);
-
- // Initialize random generator with fixed seed for reproducible
tests
- rand = new java.util.Random(42);
- }
-
- @Test
- public void testAddAndGetSingleFedPlan() {
- // Initialize test data
- List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
- FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs);
-
- // Verify initial state
- List<FedPlan> result = memoTable.get(mockHop1,
FTypes.FType.FULL);
- assertNull("Initial FedPlan list should be null before adding
any plans", result);
-
- // Add single FedPlan
- memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan);
-
- // Verify after addition
- result = memoTable.get(mockHop1, FTypes.FType.FULL);
- assertNotNull("FedPlan list should exist after adding a plan",
result);
- assertEquals("FedPlan list should contain exactly one plan", 1,
result.size());
- assertEquals("FedPlan cost should be exactly 10.0", 10.0,
result.get(0).getCost(), 0.001);
- }
-
- @Test
- public void testAddMultipleDuplicatedFedPlans() {
- // Initialize test data with duplicate costs
- List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
- List<FedPlan> fedPlans = new ArrayList<>();
- fedPlans.add(new FedPlan(mockHop1, 10.0, planRefs)); // Unique
cost
- fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // First
duplicate
- fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // Second
duplicate
-
- // Add multiple plans including duplicates
- memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans);
-
- // Verify handling of duplicate plans
- List<FedPlan> result = memoTable.get(mockHop1,
FTypes.FType.FULL);
- assertNotNull("FedPlan list should exist after adding multiple
plans", result);
- assertEquals("FedPlan list should maintain all plans including
duplicates", 3, result.size());
- }
-
- @Test
- public void testContains() {
- // Initialize test data
- List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
- FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs);
-
- // Verify initial state
- assertFalse("MemoTable should not contain any entries
initially",
- memoTable.contains(mockHop1, FTypes.FType.FULL));
-
- // Add plan and verify presence
- memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan);
-
- assertTrue("MemoTable should contain entry after adding
FedPlan",
- memoTable.contains(mockHop1, FTypes.FType.FULL));
- assertFalse("MemoTable should not contain entries for different
Hop",
- memoTable.contains(mockHop2, FTypes.FType.FULL));
- }
-
- @Test
- public void testPrunePlanPruneAll() {
- // Initialize base test data
- List<Pair<Long, FTypes.FType>> planRefs = new ArrayList<>();
- // Create separate FedPlan lists for independent testing of
each Hop
- List<FedPlan> fedPlans1 = new ArrayList<>(); // Plans for
mockHop1
- List<FedPlan> fedPlans2 = new ArrayList<>(); // Plans for
mockHop2
-
- // Generate random cost FedPlans for both Hops
- double minCost = Double.MAX_VALUE;
- int size = 100;
- for(int i = 0; i < size; i++) {
- double cost = rand.nextDouble() * 1000; // Random cost
between 0 and 1000
- fedPlans1.add(new FedPlan(mockHop1, cost, planRefs));
- fedPlans2.add(new FedPlan(mockHop2, cost, planRefs));
- minCost = Math.min(minCost, cost);
- }
-
- // Add FedPlan lists to MemoTable
- memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL,
fedPlans1);
- memoTable.addFedPlanList(mockHop2, FTypes.FType.FULL,
fedPlans2);
-
- // Test selective pruning on mockHop1
- memoTable.prunePlan(mockHop1, FTypes.FType.FULL);
-
- // Get results for verification
- List<FedPlan> result1 = memoTable.get(mockHop1,
FTypes.FType.FULL);
- List<FedPlan> result2 = memoTable.get(mockHop2,
FTypes.FType.FULL);
-
- // Verify selective pruning results
- assertNotNull("Pruned mockHop1 should maintain a FedPlan list",
result1);
- assertEquals("Pruned mockHop1 should contain exactly one
minimum cost plan", 1, result1.size());
- assertEquals("Pruned mockHop1's plan should have the minimum
cost", minCost, result1.get(0).getCost(), 0.001);
-
- // Verify unpruned Hop state
- assertNotNull("Unpruned mockHop2 should maintain a FedPlan
list", result2);
- assertEquals("Unpruned mockHop2 should maintain all original
plans", size, result2.size());
-
- // Add additional plans to both Hops
- for(int i = 0; i < size; i++) {
- double cost = rand.nextDouble() * 1000;
- memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, new
FedPlan(mockHop1, cost, planRefs));
- memoTable.addFedPlan(mockHop2, FTypes.FType.FULL, new
FedPlan(mockHop2, cost, planRefs));
- minCost = Math.min(minCost, cost);
- }
-
- // Test global pruning
- memoTable.pruneAll();
-
- // Verify global pruning results
- assertNotNull("mockHop1 should maintain a FedPlan list after
global pruning", result1);
- assertEquals("mockHop1 should contain exactly one minimum cost
plan after global pruning",
- 1, result1.size());
- assertEquals("mockHop1's plan should have the global minimum
cost",
- minCost, result1.get(0).getCost(), 0.001);
-
- assertNotNull("mockHop2 should maintain a FedPlan list after
global pruning", result2);
- assertEquals("mockHop2 should contain exactly one minimum cost
plan after global pruning",
- 1, result2.size());
- assertEquals("mockHop2's plan should have the global minimum
cost",
- minCost, result2.get(0).getCost(), 0.001);
- }
-}
diff --git a/src/test/scripts/functions/federated/cost.dml
b/src/test/scripts/functions/federated/cost.dml
new file mode 100644
index 0000000000..ec34d45bb6
--- /dev/null
+++ b/src/test/scripts/functions/federated/cost.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+a = matrix(7,10,10);  # 10x10 matrix with every cell set to 7
+b = a + a^2;  # cell-wise: a plus a squared
+c = sqrt(b);  # cell-wise square root
+print(sum(c));  # print the sum over all cells of c
\ No newline at end of file