min-guk opened a new pull request, #2147:
URL: https://github.com/apache/systemds/pull/2147
This implementation is based on the newly implemented
FederatedPlanCostEstimator and FederatedMemoTable, following the direction we
previously discussed.
## 1. FederatedMemoTable (MemoTable)
```java
public class MemoTable {
    private final Map<Pair<Long, FTypes.FType>, List<FedPlan>> hopMemoTable = new HashMap<>();

    public static class FedPlan {
        @SuppressWarnings("unused")
        private final Hop hopRef;                       // The associated Hop object
        private final double cost;                      // Cost of this federated plan
        @SuppressWarnings("unused")
        private final List<Pair<Long, FType>> planRefs; // References to dependent plans
    }
}
```
The previous FedPlan class structure had several issues:
- A single <HopID, FederatedOutput> pair stored multiple FedPlans as a list
in the MemoTable, so every FedPlan in that list redundantly stored the same hopRef.
- A single <HopID, FederatedOutput> pair had to recompute its computeCost
and accessCost 2^(planRefs+1) times, where planRefs is the number of child plan
references.
- FedPlan did not store its own FederatedOutput.
```java
public class FederatedMemoTable {
    private final Map<Pair<Long, FederatedOutput>, FedPlanVariants> hopMemoTable = new HashMap<>();

    public static class FedPlanVariants {
        protected final Hop hopRef;        // Reference to the associated Hop
        protected double currentCost;      // Current execution cost (compute + memory access)
        protected double netTransferCost;  // Network transfer cost
        protected List<FedPlan> _fedPlanVariants;
    }

    public static class FedPlan {
        private double cumulativeCost;                               // Total cost including child plans
        private final FederatedOutput fedOutType;                    // Output type (FOUT/LOUT)
        private final FedPlanVariants fedPlanVariantList;            // Reference to variant list
        private List<Pair<Long, FederatedOutput>> metaChildFedPlans; // Child plan references
        private List<FedPlan> selectedFedPlans;                      // Selected child plans
    }
}
```
The key points of the redesigned FederatedMemoTable are as follows:
- A single <HopID, FederatedOutput> pair has one FedPlanVariants, which
stores hopRef, currentCost, and netTransferCost once and shares them with all
FedPlans in its _fedPlanVariants list.
- A single <HopID, FederatedOutput> pair calculates its computeCost and
accessCost only once.
- FedPlan stores its own FederatedOutput.
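
To illustrate the sharing idea, here is a minimal standalone sketch. It does not use the SystemDS classes: `Variants`, `Plan`, `OutType`, the placeholder cost of 10.0, and the `-1` "not yet computed" marker are all assumptions made for this example. It only shows that when the per-hop cost lives in the shared container, it is computed once no matter how many variants exist.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for illustration; these are not the SystemDS classes.
class SharedVariantCostSketch {
    enum OutType { FOUT, LOUT }

    /** One instance per <hopId, outType>; holds what all variants share. */
    static class Variants {
        final long hopId;
        final OutType outType;
        final List<Plan> plans = new ArrayList<>();
        private double currentCost = -1; // -1 marks "not yet computed" (assumption)
        int costComputations = 0;        // how often the per-hop cost was computed

        Variants(long hopId, OutType outType) {
            this.hopId = hopId;
            this.outType = outType;
        }

        /** Computes the per-hop cost lazily and caches it for every variant. */
        double getCurrentCost() {
            if (currentCost < 0) {
                currentCost = 10.0; // placeholder for compute + memory access cost
                costComputations++;
            }
            return currentCost;
        }
    }

    /** One variant: a specific FOUT/LOUT assignment for the inputs. */
    static class Plan {
        final Variants shared;
        final List<OutType> childTypes;

        Plan(Variants shared, List<OutType> childTypes) {
            this.shared = shared;
            this.childTypes = childTypes;
        }

        double getCurrentCost() {
            return shared.getCurrentCost();
        }
    }

    /** Creates all 2^numInputs variants for one <hopId, outType> pair. */
    static Variants buildVariants(long hopId, OutType outType, int numInputs) {
        Variants variants = new Variants(hopId, outType);
        for (int mask = 0; mask < (1 << numInputs); mask++) {
            List<OutType> childTypes = new ArrayList<>();
            for (int j = 0; j < numInputs; j++) {
                childTypes.add(((mask >> j) & 1) == 1 ? OutType.FOUT : OutType.LOUT);
            }
            Plan plan = new Plan(variants, childTypes);
            plan.getCurrentCost(); // reads the shared value; computed on the first call only
            variants.plans.add(plan);
        }
        return variants;
    }

    public static void main(String[] args) {
        Variants v = buildVariants(1L, OutType.FOUT, 3);
        System.out.println(v.plans.size() + " variants, cost computed "
            + v.costComputations + " time(s)");
    }
}
```

The sketch uses a negative marker for "not yet computed" so that a hop whose cost is genuinely zero is not recomputed; whether that case matters in practice is an open question here.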
## 2. CostEstimator
```java
// Do not create and allocate any new FedPlan.
// just calculate the cost for given fed plans.
// cost of dependent fedplans in planRefs is already calculated.
public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) {
    double cost = computeFederatedPlanCost(currentPlan.getHopRef());
    for (Pair<Long, FederatedOutput> planRefMeta : currentPlan.getPlanRefs()) {
        FedPlan planRef = memoTable.getFedPlan(planRefMeta.getLeft(), planRefMeta.getRight());
        cost += planRef.getCost();
        if (currentPlan.getFedOutType() != planRef.getFedOutType()) {
            cost += computeHopNetworkAccessCost(planRef.getHopRef().getOutputMemEstimate());
        }
    }
    currentPlan.setCost(cost);
}
```
The previous CostEstimator also had several issues:
- It recalculated the currentHop's cost on every call.
- The optimal FedPlan should minimize the total cost of compute, memory
access, and network access. However, the previous CostEstimator selected the
ref plan with the minimum cost excluding network cost and added the network
cost afterward, so it could not guarantee the minimum-cost FedPlan.
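
A small numeric sketch of why the order matters. The classes and the numbers (child costs 10 and 12, transfer cost 5) are invented for illustration and are not taken from SystemDS. Selecting the cheapest child first and adding the transfer cost afterward yields 15, while comparing candidates with the transfer cost included yields 12.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-ins for illustration; these are not the SystemDS classes.
class ChildSelectionSketch {
    enum OutType { FOUT, LOUT }

    static final class ChildPlan {
        final OutType type;
        final double cost;         // cumulative cost of the child plan itself
        final double transferCost; // paid only if the parent's type differs

        ChildPlan(OutType type, double cost, double transferCost) {
            this.type = type;
            this.cost = cost;
            this.transferCost = transferCost;
        }
    }

    /** Child cost as seen by a parent with the given output type. */
    static double parentViewCost(ChildPlan child, OutType parentType) {
        return child.cost + (child.type != parentType ? child.transferCost : 0.0);
    }

    /** Old order: pick the cheapest child ignoring transfer, then add transfer. */
    static double selectThenAddTransfer(List<ChildPlan> candidates, OutType parentType) {
        ChildPlan best = candidates.get(0);
        for (ChildPlan c : candidates) {
            if (c.cost < best.cost) {
                best = c;
            }
        }
        return parentViewCost(best, parentType);
    }

    /** New order: compare candidates with the transfer cost already included. */
    static double selectWithTransfer(List<ChildPlan> candidates, OutType parentType) {
        double best = Double.POSITIVE_INFINITY;
        for (ChildPlan c : candidates) {
            best = Math.min(best, parentViewCost(c, parentType));
        }
        return best;
    }

    static List<ChildPlan> exampleCandidates() {
        return Arrays.asList(
            new ChildPlan(OutType.FOUT, 10.0, 5.0),
            new ChildPlan(OutType.LOUT, 12.0, 5.0));
    }

    public static void main(String[] args) {
        List<ChildPlan> candidates = exampleCandidates();
        System.out.println(selectThenAddTransfer(candidates, OutType.LOUT)); // 15.0
        System.out.println(selectWithTransfer(candidates, OutType.LOUT));    // 12.0
    }
}
```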
```java
public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) {
    double cumulativeCost = 0;
    Hop currentHop = currentPlan.getHopRef();

    // Step 1: Calculate current node costs if not already computed
    if (currentPlan.getCurrentCost() == 0) {
        // Compute cost for current node (computation + memory access)
        cumulativeCost = computeCurrentCost(currentHop);
        currentPlan.setCurrentCost(cumulativeCost);
        // Calculate potential network transfer cost if federation type changes
        currentPlan.setNetTransferCost(
            computeHopNetworkAccessCost(currentHop.getOutputMemEstimate()));
    } else {
        cumulativeCost = currentPlan.getCurrentCost();
    }

    // Step 2: Process each child plan and add their costs
    for (Pair<Long, FederatedOutput> planRefMeta : currentPlan.getMetaChildFedPlans()) {
        // Find minimum cost child plan considering federation type compatibility
        // Note: This approach might lead to suboptimal or wrong solutions when a child
        // has multiple parents because we're selecting child plans independently for each parent
        FedPlan planRef = memoTable.getMinCostChildFedPlan(
            planRefMeta.getLeft(), planRefMeta.getRight(), currentPlan.getFedOutType());

        // Add child plan cost (includes network transfer cost if federation types differ)
        cumulativeCost += planRef.getParentViewCost(currentPlan.getFedOutType());

        // Store selected child plan
        // Note: Selected plan has minimum parent view cost, not minimum cumulative cost,
        // which means it is highly unlikely to be found through simple pruning after enumeration
        currentPlan.putChildFedPlan(planRef);
    }

    // Step 3: Set final cumulative cost including current node
    currentPlan.setCumulativeCost(cumulativeCost);
}
```
The key points of the redesigned CostEstimator are as follows:
- It calculates the compute cost and access cost of currentHop only once per
<HopID, FederatedOutput> pair.
- When selecting the minimum-cost ref plan, it includes the network cost in the
comparison, ensuring the minimum total cost from that parent's point of view.
- It stores the selected child plans in a list as references.
  - This is because when pruning all at once in the MemoTable later, the
network cost cannot be calculated without knowing the fOutType of each
FedPlan's parent FedPlan, so the optimal-cost plan cannot be identified.
Therefore, pruning in the current MemoTable has been removed.
However, the current CostEstimator may cause two problems because it selects
child plans based only on the cost of a single current plan and its child plan:
1. A child plan can have multiple parent plans, and different parent plans
can select different child plans. The combined plan could therefore require
the same child to have different fOutTypes at once, which is a fed plan that
does not exist.
2. Since a child plan can have multiple parent plans, it should select the
fOutType that minimizes the sum of costs of all parent plans referencing it.
Otherwise, it may select a suboptimal plan.

We need to devise a new algorithm to solve these two problems.
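
Both problems can be reproduced with a toy example. Everything below is hypothetical: the costs, the class names, and in particular the joint cost model (child cost counted once, plus one transfer per parent whose type differs) are assumptions for illustration, not the algorithm proposed in this PR. One child is shared by three parents with output types FOUT, LOUT, and LOUT.

```java
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;

// Hypothetical stand-ins for illustration; these are not the SystemDS classes.
class MultiParentSketch {
    enum OutType { FOUT, LOUT }

    static final double TRANSFER = 5.0; // assumed transfer cost on a type mismatch

    /** Assumed cumulative cost of the shared child for each output type. */
    static double childCost(OutType type) {
        return type == OutType.FOUT ? 10.0 : 12.0;
    }

    static double parentViewCost(OutType childType, OutType parentType) {
        return childCost(childType) + (childType != parentType ? TRANSFER : 0.0);
    }

    /** What a single parent picks when it only looks at its own cost. */
    static OutType independentChoice(OutType parentType) {
        OutType best = OutType.FOUT;
        for (OutType t : OutType.values()) {
            if (parentViewCost(t, parentType) < parentViewCost(best, parentType)) {
                best = t;
            }
        }
        return best;
    }

    /** Distinct child types requested when every parent chooses independently. */
    static Set<OutType> independentChoices(List<OutType> parents) {
        Set<OutType> chosen = EnumSet.noneOf(OutType.class);
        for (OutType p : parents) {
            chosen.add(independentChoice(p));
        }
        return chosen;
    }

    /** Assumed joint model: child cost once, plus a transfer per mismatching parent. */
    static double jointCost(OutType childType, List<OutType> parents) {
        double total = childCost(childType);
        for (OutType p : parents) {
            if (p != childType) {
                total += TRANSFER;
            }
        }
        return total;
    }

    /** Single child type that minimizes the assumed joint cost over all parents. */
    static OutType jointChoice(List<OutType> parents) {
        OutType best = OutType.FOUT;
        for (OutType t : OutType.values()) {
            if (jointCost(t, parents) < jointCost(best, parents)) {
                best = t;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<OutType> parents = Arrays.asList(OutType.FOUT, OutType.LOUT, OutType.LOUT);
        System.out.println("independent choices: " + independentChoices(parents));
        System.out.println("joint choice: " + jointChoice(parents)
            + " at cost " + jointCost(jointChoice(parents), parents));
    }
}
```

With these numbers the parents independently request both FOUT and LOUT for the same child (problem 1), while a single consistent choice of LOUT costs 17 against 20 for FOUT under the assumed joint model (problem 2).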
## 3. FederatedPlanCostEnumerator
```java
public class FederatedPlanCostEnumerator {
    public static FedPlan enumerateFederatedPlanCost(Hop rootHop) {
        FederatedMemoTable memoTable = new FederatedMemoTable();
        enumerateFederatedPlanCost(rootHop, memoTable);
        return getMinCostRootFedPlan(rootHop.getHopID(), memoTable);
    }

    /**
     * Recursively enumerates all possible federated execution plans for a Hop DAG.
     * For each node:
     * 1. First processes all input nodes recursively if not already processed
     * 2. Generates all possible combinations of federation types (FOUT/LOUT) for inputs
     * 3. Creates and evaluates both FOUT and LOUT variants for current node
     *    with each input combination
     *
     * The enumeration uses a bottom-up approach where:
     * - Each input combination is represented by a binary number (i)
     * - Bit j in i determines whether input j is FOUT (1) or LOUT (0)
     * - Total number of combinations is 2^numInputs
     */
    private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) {
        int numInputs = hop.getInput().size();

        // Process all input nodes first if not already in memo table
        for (Hop inputHop : hop.getInput()) {
            if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT)
                && !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) {
                enumerateFederatedPlanCost(inputHop, memoTable);
            }
        }

        // Generate all possible input combinations using binary representation
        // i represents a specific combination of FOUT/LOUT for inputs
        for (int i = 0; i < (1 << numInputs); i++) {
            List<Pair<Long, FederatedOutput>> planChilds = new ArrayList<>();

            // For each input, determine if it should be FOUT or LOUT based on bit j in i
            for (int j = 0; j < numInputs; j++) {
                Hop inputHop = hop.getInput().get(j);
                // If bit j is set (1), use FOUT; otherwise use LOUT
                FederatedOutput childType = ((i & (1 << j)) != 0) ?
                    FederatedOutput.FOUT : FederatedOutput.LOUT;
                planChilds.add(Pair.of(inputHop.getHopID(), childType));
            }

            // Create and evaluate FOUT variant for current input combination
            FedPlan fOutPlan = memoTable.addFedPlan(hop, FederatedOutput.FOUT, planChilds);
            FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable);

            // Create and evaluate LOUT variant for current input combination
            FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds);
            FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable);
        }
    }
}
```
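
The bitmask decoding used by the enumerator can be checked in isolation. The sketch below is standalone and uses a hypothetical `OutType` enum in place of `FederatedOutput`; it follows the same convention as the code above (bit j set means input j is FOUT) and also counts the plans created per hop, which is two output types times 2^numInputs input combinations.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for illustration; OutType replaces FederatedOutput here.
class InputCombinationSketch {
    enum OutType { FOUT, LOUT }

    /** Decodes a combination index: bit j set means input j is FOUT, else LOUT. */
    static List<OutType> decode(int mask, int numInputs) {
        List<OutType> types = new ArrayList<>();
        for (int j = 0; j < numInputs; j++) {
            types.add((mask & (1 << j)) != 0 ? OutType.FOUT : OutType.LOUT);
        }
        return types;
    }

    /** Plans created per hop: FOUT and LOUT variants for each input combination. */
    static int plansPerHop(int numInputs) {
        return 2 * (1 << numInputs);
    }

    public static void main(String[] args) {
        int numInputs = 2;
        for (int mask = 0; mask < (1 << numInputs); mask++) {
            System.out.println(mask + " -> " + decode(mask, numInputs));
        }
        System.out.println("plans per hop: " + plansPerHop(numInputs));
    }
}
```

Because the count doubles with every additional input, hops with many inputs dominate the size of the memo table under this enumeration scheme.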
- I'm not sure how to create complex Hop DAGs similar to real scenarios in
the test code. Could you please provide some reference test code that I can
refer to?