This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 37b3e93  [SYSTEMDS-3018] Federated Cost Estimation for Repetitions
37b3e93 is described below

commit 37b3e934ddc9d686d8f6ede9f689038a998ff87a
Author: sebwrede <[email protected]>
AuthorDate: Tue Feb 15 11:48:33 2022 +0100

    [SYSTEMDS-3018] Federated Cost Estimation for Repetitions
    
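    This commit changes the federated plan cost estimation for while/for/if
    statement blocks. Instead of scaling block costs after the fact, repetition
    estimates are propagated down to the contained hops (loop bodies are
    multiplied by the estimated iteration count, if/else bodies are weighted by
    one half) and applied when each hop's federated cost is computed.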
    Closes #1547.
---
 src/main/java/org/apache/sysds/hops/Hop.java       |  22 +++
 .../java/org/apache/sysds/hops/OptimizerUtils.java |  23 +--
 .../org/apache/sysds/hops/cost/CostEstimator.java  |   4 +-
 .../org/apache/sysds/hops/cost/FederatedCost.java  |  38 ++---
 .../sysds/hops/cost/FederatedCostEstimator.java    |  62 +++----
 .../java/org/apache/sysds/hops/cost/HopRel.java    |   2 +-
 .../hops/ipa/IPAPassRewriteFederatedPlan.java      |   7 +-
 .../java/org/apache/sysds/parser/DMLProgram.java   |   6 +
 .../org/apache/sysds/parser/ForStatementBlock.java |  18 +-
 .../sysds/parser/FunctionStatementBlock.java       |   9 +
 .../org/apache/sysds/parser/IfStatementBlock.java  |  15 +-
 .../org/apache/sysds/parser/StatementBlock.java    |  33 ++++
 .../apache/sysds/parser/WhileStatementBlock.java   |  15 +-
 .../runtime/controlprogram/WhileProgramBlock.java  |   6 +-
 .../fedplanning/FederatedCostEstimatorTest.java    | 181 ++++++++++++++++-----
 15 files changed, 326 insertions(+), 115 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/Hop.java 
b/src/main/java/org/apache/sysds/hops/Hop.java
index 037bfa5..003492f 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -93,6 +93,7 @@ public abstract class Hop implements ParseInfo {
         */
        protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
        protected FederatedCost _federatedCost = new FederatedCost();
+       protected double repetitions = 1;
 
        /**
         * Field defining if prefetch should be activated for operation.
@@ -996,6 +997,15 @@ public abstract class Hop implements ParseInfo {
                _federatedCost = cost;
        }
 
+       /**
+        * Reset federated cost of this hop and all children of this hop.
+        */
+       public void resetFederatedCost(){
+               _federatedCost = new FederatedCost();
+               for ( Hop input : getInput() )
+                       input.resetFederatedCost();
+       }
+
        public void setUpdateType(UpdateType update){
                _updateType = update;
        }
@@ -1539,6 +1549,18 @@ public abstract class Hop implements ParseInfo {
                return ret;
        }
 
+       public void updateRepetitionEstimates(double repetitions){
+               if ( !federatedCostInitialized() ){
+                       this.repetitions = repetitions;
+                       for ( Hop input : getInput() )
+                               input.updateRepetitionEstimates(repetitions);
+               }
+       }
+
+       public double getRepetitions(){
+               return repetitions;
+       }
+
        /**
         * Clones the attributes of that and copies it over to this.
         * 
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 47b5822..4d48df6 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -1289,17 +1289,18 @@ public class OptimizerUtils
                if( fpb.getStatementBlock()==null )
                        return defaultValue;
                ForStatementBlock fsb = (ForStatementBlock) 
fpb.getStatementBlock();
-               try {
-                       HashMap<Long,Long> memo = new HashMap<>();
-                       long from = 
rEvalSimpleLongExpression(fsb.getFromHops().getInput().get(0), memo);
-                       long to = 
rEvalSimpleLongExpression(fsb.getToHops().getInput().get(0), memo);
-                       long increment = (fsb.getIncrementHops()==null) ? (from 
< to) ? 1 : -1 : 
-                               
rEvalSimpleLongExpression(fsb.getIncrementHops().getInput().get(0), memo);
-                       if( from != Long.MAX_VALUE && to != Long.MAX_VALUE && 
increment != Long.MAX_VALUE )
-                               return 
(int)Math.ceil(((double)(to-from+1))/increment);
-               }
-               catch(Exception ex){}
-               return defaultValue;
+               return getNumIterations(fsb, defaultValue);
+       }
+
+       public static long getNumIterations(ForStatementBlock fsb, long 
defaultValue){
+               HashMap<Long,Long> memo = new HashMap<>();
+               long from = 
rEvalSimpleLongExpression(fsb.getFromHops().getInput().get(0), memo);
+               long to = 
rEvalSimpleLongExpression(fsb.getToHops().getInput().get(0), memo);
+               long increment = (fsb.getIncrementHops()==null) ? (from < to) ? 
1 : -1 :
+                       
rEvalSimpleLongExpression(fsb.getIncrementHops().getInput().get(0), memo);
+               if( from != Long.MAX_VALUE && to != Long.MAX_VALUE && increment 
!= Long.MAX_VALUE )
+                       return (int)Math.ceil(((double)(to-from+1))/increment);
+               else return defaultValue;
        }
        
        public static long getNumIterations(ForProgramBlock fpb, 
LocalVariableMap vars, long defaultValue) {
diff --git a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java 
b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
index 03948d4..497b807 100644
--- a/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/CostEstimator.java
@@ -116,7 +116,7 @@ public abstract class CostEstimator
                                for( ProgramBlock pb2 : tmp.getChildBlocks() )
                                        ret += rGetTimeEstimate(pb2, stats, 
memoFunc, recursive);
                        
-                       ret *= getNumIterations(stats, tmp);
+                       ret *= getNumIterations(tmp);
                }
                else if ( pb instanceof FunctionProgramBlock ) {
                        FunctionProgramBlock tmp = (FunctionProgramBlock) pb;
@@ -413,7 +413,7 @@ public abstract class CostEstimator
                vs[2] = _unknownStats;
        }
                
-       private static long getNumIterations(HashMap<String,VarStats> stats, 
ForProgramBlock pb) {
+       private static long getNumIterations(ForProgramBlock pb) {
                return OptimizerUtils.getNumIterations(pb, DEFAULT_NUMITER);
        }
 
diff --git a/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java 
b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
index f4f8db4..8831fdc 100644
--- a/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCost.java
@@ -30,15 +30,20 @@ public class FederatedCost {
        protected double _outputTransferCost = 0;
        protected double _inputTotalCost = 0;
 
+       protected double _repetitions = 1;
+       protected double _totalCost;
+
        public FederatedCost(){}
 
        public FederatedCost(double readCost, double inputTransferCost, double 
outputTransferCost,
-               double computeCost, double inputTotalCost){
+               double computeCost, double inputTotalCost, double repetitions){
                _readCost = readCost;
                _inputTransferCost = inputTransferCost;
                _outputTransferCost = outputTransferCost;
                _computeCost = computeCost;
                _inputTotalCost = inputTotalCost;
+               _repetitions = repetitions;
+               _totalCost = calcTotal();
        }
 
        /**
@@ -46,15 +51,15 @@ public class FederatedCost {
         * @return total cost
         */
        public double getTotal(){
-               return _computeCost + _readCost + _inputTransferCost + 
_outputTransferCost + _inputTotalCost;
+               return _totalCost;
        }
 
-       /**
-        * Multiply the input costs by the number of times the costs are 
repeated.
-        * @param repetitionNumber number of repetitions of the costs
-        */
-       public void addRepetitionCost(int repetitionNumber){
-               _inputTotalCost *= repetitionNumber;
+       private double calcTotal(){
+               return (_computeCost + _readCost + _inputTransferCost + 
_outputTransferCost) * _repetitions + _inputTotalCost;
+       }
+
+       private void updateTotal(){
+               this._totalCost = calcTotal();
        }
 
        /**
@@ -75,6 +80,7 @@ public class FederatedCost {
         */
        public void addInputTotalCost(double additionalCost){
                _inputTotalCost += additionalCost;
+               updateTotal();
        }
 
        /**
@@ -82,19 +88,7 @@ public class FederatedCost {
         * @param federatedCost input cost from which the total is retrieved
         */
        public void addInputTotalCost(FederatedCost federatedCost){
-               _inputTotalCost += federatedCost.getTotal();
-       }
-
-       /**
-        * Add costs of FederatedCost object to this object's current costs.
-        * @param additionalCost object to add to this object
-        */
-       public void addFederatedCost(FederatedCost additionalCost){
-               _readCost += additionalCost._readCost;
-               _inputTransferCost += additionalCost._inputTransferCost;
-               _outputTransferCost += additionalCost._outputTransferCost;
-               _computeCost += additionalCost._computeCost;
-               _inputTotalCost += additionalCost._inputTotalCost;
+               addInputTotalCost(federatedCost.getTotal());
        }
 
        @Override
@@ -110,6 +104,8 @@ public class FederatedCost {
                builder.append(_outputTransferCost);
                builder.append("\n inputTotalCost: ");
                builder.append(_inputTotalCost);
+               builder.append("\n repetitions: ");
+               builder.append(_repetitions);
                builder.append("\n total cost: ");
                builder.append(getTotal());
                return builder.toString();
diff --git 
a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java 
b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
index 96a33d4..400caa9 100644
--- a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysds.hops.cost;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.ipa.MemoTable;
 import org.apache.sysds.parser.DMLProgram;
@@ -39,14 +41,13 @@ import java.util.ArrayList;
  * Cost estimator for federated executions with methods and constants for 
going through DML programs to estimate costs.
  */
 public class FederatedCostEstimator {
-       public int DEFAULT_MEMORY_ESTIMATE = 8;
-       public int DEFAULT_ITERATION_NUMBER = 15;
-       public double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024*1024*1024; 
//Default network bandwidth in bytes per second
-       public double WORKER_COMPUTE_BANDWIDTH_FLOPS = 2.5*1024*1024*1024; 
//Default compute bandwidth in FLOPS
-       public double WORKER_DEGREE_OF_PARALLELISM = 8; //Default number of 
parallel processes for workers
-       public double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5*1024*1024*1024; 
//Default read bandwidth in bytes per second
+       private static final Log LOG = 
LogFactory.getLog(FederatedCostEstimator.class.getName());
 
-       public boolean printCosts = false; //Temporary for debugging purposes
+       public static int DEFAULT_MEMORY_ESTIMATE = 8;
+       public static double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 
1024*1024*1024; //Default network bandwidth in bytes per second
+       public static double WORKER_COMPUTE_BANDWIDTH_FLOPS = 
2.5*1024*1024*1024; //Default compute bandwidth in FLOPS
+       public static double WORKER_DEGREE_OF_PARALLELISM = 8; //Default number 
of parallel processes for workers
+       public static double WORKER_READ_BANDWIDTH_BYTES_PS = 
3.5*1024*1024*1024; //Default read bandwidth in bytes per second
 
        /**
         * Estimate cost of given DML program in bytes.
@@ -54,6 +55,7 @@ public class FederatedCostEstimator {
         * @return federated cost object with cost estimate in bytes
         */
        public FederatedCost costEstimate(DMLProgram dmlProgram){
+               dmlProgram.updateRepetitionEstimates();
                FederatedCost programTotalCost = new FederatedCost();
                for ( StatementBlock stmBlock : dmlProgram.getStatementBlocks() 
)
                        
programTotalCost.addInputTotalCost(costEstimate(stmBlock).getTotal());
@@ -74,12 +76,9 @@ public class FederatedCostEstimator {
                                for ( StatementBlock bodyBlock : 
whileStatement.getBody() )
                                        
whileSBCost.addInputTotalCost(costEstimate(bodyBlock));
                        }
-                       whileSBCost.addRepetitionCost(DEFAULT_ITERATION_NUMBER);
                        return whileSBCost;
                }
                else if ( sb instanceof IfStatementBlock){
-                       //Get cost of if-block + else-block and divide by two
-                       // since only one of the code blocks will be executed 
in the end
                        IfStatementBlock ifSB = (IfStatementBlock) sb;
                        FederatedCost ifSBCost = new FederatedCost();
                        for ( Statement statement : ifSB.getStatements() ){
@@ -89,7 +88,6 @@ public class FederatedCostEstimator {
                                for ( StatementBlock elseBodySB : 
ifStatement.getElseBody() )
                                        
ifSBCost.addInputTotalCost(costEstimate(elseBodySB));
                        }
-                       
ifSBCost.setInputTotalCost(ifSBCost.getInputTotalCost()/2);
                        
ifSBCost.addInputTotalCost(costEstimate(ifSB.getPredicateHops()));
                        return ifSBCost;
                }
@@ -106,7 +104,6 @@ public class FederatedCostEstimator {
                                for ( StatementBlock forStatementBlockBody : 
forStatement.getBody() )
                                        
forSBCost.addInputTotalCost(costEstimate(forStatementBlockBody));
                        }
-                       forSBCost.addRepetitionCost(forSB.getEstimateReps());
                        return forSBCost;
                }
                else if ( sb instanceof FunctionStatementBlock){
@@ -182,12 +179,13 @@ public class FederatedCostEstimator {
                                
root.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
                        double readCost = 
root.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_READ_BANDWIDTH_BYTES_PS;
 
+                       double rootRepetitions = root.getRepetitions();
                        FederatedCost rootFedCost =
-                               new FederatedCost(readCost, inputTransferCost, 
outputTransferCost, computingCost, inputCosts);
+                               new FederatedCost(readCost, inputTransferCost, 
outputTransferCost, computingCost, inputCosts, rootRepetitions);
                        root.setFederatedCost(rootFedCost);
 
-                       if ( printCosts )
-                               printCosts(root);
+                       if ( LOG.isDebugEnabled() )
+                               LOG.debug(getCostInfo(root));
 
                        return rootFedCost;
                }
@@ -199,7 +197,7 @@ public class FederatedCostEstimator {
         * @param hopRelMemo memo table of HopRels for calculating input costs
         * @return cost estimation of Hop DAG starting from given root HopRel
         */
-       public FederatedCost costEstimate(HopRel root, MemoTable hopRelMemo){
+       public static FederatedCost costEstimate(HopRel root, MemoTable 
hopRelMemo){
                // Check if root is in memo table.
                if ( hopRelMemo.containsHopRel(root) ){
                        return root.getCostObject();
@@ -234,7 +232,8 @@ public class FederatedCostEstimator {
                                
root.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
                        double readCost = 
root.hopRef.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_READ_BANDWIDTH_BYTES_PS;
 
-                       return new FederatedCost(readCost, inputTransferCost, 
outputTransferCost, computingCost, inputCosts);
+                       double rootRepetitions = root.hopRef.getRepetitions();
+                       return new FederatedCost(readCost, inputTransferCost, 
outputTransferCost, computingCost, inputCosts, rootRepetitions);
                }
        }
 
@@ -247,7 +246,7 @@ public class FederatedCostEstimator {
         * @param root hopRel for which cost is estimated
         * @return input transfer cost estimate
         */
-       private double inputTransferCostEstimate(boolean hasFederatedInput, 
HopRel root){
+       private static double inputTransferCostEstimate(boolean 
hasFederatedInput, HopRel root){
                if ( hasFederatedInput )
                        return root.inputDependency.stream()
                                .filter(input -> 
(root.hopRef.isFederatedDataOp()) ? input.hasFederatedOutput() : 
input.hasLocalOutput() )
@@ -275,18 +274,21 @@ public class FederatedCostEstimator {
        }
 
        /**
-        * Prints costs and information about root for debugging purposes
-        * @param root hop for which information is printed
+        * Return costs and information about root for debugging purposes.
+        * @param root hop for which information is returned
+        * @return information about root cost
         */
-       private static void printCosts(Hop root){
-               System.out.println("===============================");
-               System.out.println(root);
-               System.out.println("Is federated: " + root.isFederated());
-               System.out.println("Has federated output: " + 
root.hasFederatedOutput());
-               System.out.println(root.getText());
-               System.out.println("Pure computeCost: " + 
ComputeCost.getHOPComputeCost(root));
-               System.out.println("Dim1: " + root.getDim1() + " Dim2: " + 
root.getDim2());
-               System.out.println(root.getFederatedCost().toString());
-               System.out.println("===============================");
+       private static String getCostInfo(Hop root){
+               String sep = System.getProperty("line.separator");
+               StringBuilder costInfo = new StringBuilder();
+               costInfo
+                       .append(root).append(sep)
+                       .append("Is federated: ").append(root.isFederated())
+                       .append(" Has federated output: 
").append(root.hasFederatedOutput())
+                       .append(root.getText()).append(sep)
+                       .append("Pure computeCost: " + 
ComputeCost.getHOPComputeCost(root))
+                       .append(" Dim1: " + root.getDim1() + " Dim2: " + 
root.getDim2()).append(sep)
+                       .append(root.getFederatedCost().toString()).append(sep);
+               return costInfo.toString();
        }
 }
diff --git a/src/main/java/org/apache/sysds/hops/cost/HopRel.java 
b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
index b1cc6dd..bd5ee85 100644
--- a/src/main/java/org/apache/sysds/hops/cost/HopRel.java
+++ b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
@@ -55,7 +55,7 @@ public class HopRel {
                hopRef = associatedHop;
                this.fedOut = fedOut;
                setInputDependency(hopRelMemo);
-               cost = new FederatedCostEstimator().costEstimate(this, 
hopRelMemo);
+               cost = FederatedCostEstimator.costEstimate(this, hopRelMemo);
        }
 
        /**
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
index db313af..383be42 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
@@ -69,6 +69,10 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
         */
        private final static List<Hop> terminalHops = new ArrayList<>();
 
+       public List<Hop> getTerminalHops(){
+               return terminalHops;
+       }
+
        /**
         * Indicates if an IPA pass is applicable for the current configuration.
         * The configuration depends on OptimizerUtils.FEDERATED_COMPILATION.
@@ -93,6 +97,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
        @Override
        public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph,
                FunctionCallSizeInfo fcallSizes) {
+               prog.updateRepetitionEstimates();
                rewriteStatementBlocks(prog, prog.getStatementBlocks());
                setFinalFedouts();
                return false;
@@ -178,7 +183,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
        }
 
        private ArrayList<StatementBlock> 
rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb) {
-               if(sb.getHops() != null && !sb.getHops().isEmpty()) {
+               if(sb.hasHops()) {
                        for(Hop sbHop : sb.getHops()) {
                                if(sbHop instanceof FunctionOp) {
                                        String funcName = ((FunctionOp) 
sbHop).getFunctionName();
diff --git a/src/main/java/org/apache/sysds/parser/DMLProgram.java 
b/src/main/java/org/apache/sysds/parser/DMLProgram.java
index 498a59d..2edffa7 100644
--- a/src/main/java/org/apache/sysds/parser/DMLProgram.java
+++ b/src/main/java/org/apache/sysds/parser/DMLProgram.java
@@ -201,6 +201,12 @@ public class DMLProgram
                        throw new RuntimeException(ex);
                }
        }
+
+       public void updateRepetitionEstimates(){
+               for ( StatementBlock stmBlock : getStatementBlocks() ){
+                       stmBlock.updateRepetitionEstimates(1);
+               }
+       }
        
        @Override
        public String toString(){
diff --git a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
index 1acd1ac..b21b9b5 100644
--- a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
@@ -447,6 +447,20 @@ public class ForStatementBlock extends StatementBlock
                        }
                }
                
-               return 10;
+               return (int) DEFAULT_LOOP_REPETITIONS;
        }
-}
\ No newline at end of file
+
+       @Override
+       public void updateRepetitionEstimates(double repetitions){
+               this.repetitions = repetitions * getEstimateReps();
+               _fromHops.updateRepetitionEstimates(this.repetitions);
+               _toHops.updateRepetitionEstimates(this.repetitions);
+               _incrementHops.updateRepetitionEstimates(this.repetitions);
+               for(Statement statement : getStatements()) {
+                       List<StatementBlock> children = ((ForStatement) 
statement).getBody();
+                       for ( StatementBlock stmBlock : children ){
+                               
stmBlock.updateRepetitionEstimates(this.repetitions);
+                       }
+               }
+       }
+}
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
index cc7ab64..ed70c69 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
@@ -258,4 +258,13 @@ public class FunctionStatementBlock extends StatementBlock 
implements FunctionBl
                return ProgramConverter
                        .createDeepCopyFunctionStatementBlock(this, new 
HashSet<>(), new HashSet<>());
        }
+
+       @Override
+       public void updateRepetitionEstimates(double repetitions){
+               for (Statement stm : getStatements()){
+                       for (StatementBlock block : ((FunctionStatement) 
stm).getBody()){
+                               block.updateRepetitionEstimates(repetitions);
+                       }
+               }
+       }
 }
diff --git a/src/main/java/org/apache/sysds/parser/IfStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
index 4762a14..bae78ca 100644
--- a/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/IfStatementBlock.java
@@ -502,7 +502,20 @@ public class IfStatementBlock extends StatementBlock
                liveInReturn.addVariables(_liveIn);
                return liveInReturn;
        }
-       
+
+       @Override
+       public void updateRepetitionEstimates(double repetitions){
+               this.repetitions = repetitions;
+               getPredicateHops().updateRepetitionEstimates(this.repetitions);
+               for ( Statement statement : getStatements() ){
+                       IfStatement ifStatement = (IfStatement) statement;
+                       double blockLevelReps = repetitions / 2;
+                       for ( StatementBlock ifBodySB : ifStatement.getIfBody() 
)
+                               
ifBodySB.updateRepetitionEstimates(blockLevelReps);
+                       for ( StatementBlock elseBodySB : 
ifStatement.getElseBody() )
+                               
elseBodySB.updateRepetitionEstimates(blockLevelReps);
+               }
+       }
        
        /////////
        // materialized hops recompilation flags
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java 
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 4f8cd1b..6e9545c 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -29,6 +29,7 @@ import java.util.Map.Entry;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.FunctionOp;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.recompile.Recompiler;
 import org.apache.sysds.hops.rewrite.StatementBlockRewriteRule;
@@ -64,6 +65,9 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
        private boolean _splitDag = false;
        private boolean _nondeterministic = false;
 
+       protected double repetitions = 1;
+       public final static double DEFAULT_LOOP_REPETITIONS = 10;
+
        public StatementBlock() {
                _ID = getNextSBID();
                _name = "SB"+_ID;
@@ -1238,6 +1242,35 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                return liveInReturn;
        }
 
+       public boolean hasHops(){
+               return getHops() != null && !getHops().isEmpty();
+       }
+
+       /**
+        * Updates the repetition estimate for this statement block
+        * and all contained hops. FunctionStatementBlocks are loaded
+        * from the function dictionary and repetitions are estimated
+        * for the contained statement blocks.
+        *
+        * This method is overridden in the subclasses of StatementBlock.
+        * @param repetitions estimated for this statement block
+        */
+       public void updateRepetitionEstimates(double repetitions){
+               this.repetitions = repetitions;
+               if ( hasHops() ){
+                       for ( Hop root : getHops() ){
+                               // Set repetitionNum for hops recursively
+                               if(root instanceof FunctionOp) {
+                                       String funcName = ((FunctionOp) 
root).getFunctionName();
+                                       FunctionStatementBlock sbFuncBlock = 
getDMLProg().getBuiltinFunctionDictionary().getFunction(funcName);
+                                       
sbFuncBlock.updateRepetitionEstimates(repetitions);
+                               }
+                               else
+                                       
root.updateRepetitionEstimates(repetitions);
+                       }
+               }
+       }
+
        ///////////////////////////////////////////////////////////////
        // validate error handling (consistent for all expressions)
 
diff --git a/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
index 7a09242..b28e682 100644
--- a/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/WhileStatementBlock.java
@@ -22,6 +22,7 @@ package org.apache.sysds.parser;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.Hop;
@@ -317,6 +318,18 @@ public class WhileStatementBlock extends StatementBlock
                
                return liveInReturn;
        }
+
+       @Override
+       public void updateRepetitionEstimates(double repetitions){
+               this.repetitions = repetitions * DEFAULT_LOOP_REPETITIONS;
+               getPredicateHops().updateRepetitionEstimates(this.repetitions);
+               for(Statement statement : getStatements()) {
+                       List<StatementBlock> children = 
((WhileStatement)statement).getBody();
+                       for ( StatementBlock stmBlock : children ){
+                               
stmBlock.updateRepetitionEstimates(this.repetitions);
+                       }
+               }
+       }
        
        /////////
        // materialized hops recompilation flags
@@ -331,4 +344,4 @@ public class WhileStatementBlock extends StatementBlock
        public boolean requiresPredicateRecompilation() {
                return _requiresPredicateRecompile;
        }
-}
\ No newline at end of file
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
index cc916de..4695b94 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
@@ -20,8 +20,12 @@
 package org.apache.sysds.runtime.controlprogram;
 
 import java.util.ArrayList;
+import java.util.List;
 
 import org.apache.sysds.hops.Hop;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
 import org.apache.sysds.parser.WhileStatementBlock;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ValueType;
@@ -151,4 +155,4 @@ public class WhileProgramBlock extends ProgramBlock
        public String printBlockErrorLocation(){
                return "ERROR: Runtime error in while program block generated 
from while statement block between lines " + _beginLine + " and " + _endLine + 
" -- ";
        }
-}
\ No newline at end of file
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
index 906ed1f..b8ad989 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.test.functions.privacy.fedplanning;
 
+import net.jcip.annotations.NotThreadSafe;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.conf.ConfigurationManager;
@@ -32,15 +33,21 @@ import org.apache.sysds.hops.NaryOp;
 import org.apache.sysds.hops.ReorgOp;
 import org.apache.sysds.hops.cost.FederatedCost;
 import org.apache.sysds.hops.cost.FederatedCostEstimator;
+import org.apache.sysds.hops.ipa.FunctionCallGraph;
+import org.apache.sysds.hops.ipa.IPAPassRewriteFederatedPlan;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DMLTranslator;
 import org.apache.sysds.parser.LanguageException;
 import org.apache.sysds.parser.ParserFactory;
 import org.apache.sysds.parser.ParserWrapper;
+import org.apache.sysds.parser.StatementBlock;
 import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
+import org.junit.After;
 import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
 import org.junit.Test;
 
 import java.io.FileNotFoundException;
@@ -51,6 +58,7 @@ import java.util.Set;
 
 import static org.apache.sysds.common.Types.OpOp2.MULT;
 
+@NotThreadSafe
 public class FederatedCostEstimatorTest extends AutomatedTestBase {
 
        private static final String TEST_DIR = "functions/privacy/fedplanning/";
@@ -58,13 +66,36 @@ public class FederatedCostEstimatorTest extends 
AutomatedTestBase {
        private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedCostEstimatorTest.class.getSimpleName() + "/";
        FederatedCostEstimator fedCostEstimator = new FederatedCostEstimator();
 
+       private static double COMPUTE_FLOPS;
+       private static double READ_PS;
+       private static double NETWORK_PS;
+
        @Override
        public void setUp() {}
 
+       @BeforeClass
+       public static void storeConstants(){
+               COMPUTE_FLOPS = 
FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS;
+               READ_PS = FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS;
+               NETWORK_PS = 
FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS;
+       }
+
+       @Before
+       public void setConstants(){
+               FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
+               FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
+               FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;
+       }
+
+       @After
+       public void resetConstants(){
+               FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 
COMPUTE_FLOPS;
+               FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = READ_PS;
+               FederatedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 
NETWORK_PS;
+       }
+
        @Test
        public void simpleBinary() {
-               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-               fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
 
                /*
                 * HOP                  Occurences              ComputeCost     
        ReadCost        ComputeCostFinal        ReadCostFinal
@@ -75,70 +106,87 @@ public class FederatedCostEstimatorTest extends 
AutomatedTestBase {
                 * TOSTRING             1                               1       
                        800                     0.0625                          
80
                 * UnaryOp              1                               1       
                        8                       0.0625                          
0.8
                 */
-               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS 
*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
-               double readCost = (2*64+1600+800+8) / 
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+               double computeCost = (16+2*100+100+1+1) / 
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * 
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+               double readCost = (2*64+1600+800+8) / 
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
 
                double expectedCost = computeCost + readCost;
                runTest("BinaryCostEstimatorTest.dml", false, expectedCost);
        }
 
        @Test
+       public void simpleBinaryHopRelTest() {
+               runHopRelTest("BinaryCostEstimatorTest.dml", false);
+       }
+
+       @Test
        public void ifElseTest(){
-               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-               fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS 
*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
-               double readCost = (2*64+1600+800+8) / 
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+               double computeCost = (16+2*100+100+1+1) / 
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * 
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+               double readCost = (2*64+1600+800+8) / 
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
                double expectedCost = ((computeCost + readCost + 0.8 + 0.0625 + 
0.0625) / 2) + 0.0625 + 0.8 + 0.0625;
                runTest("IfElseCostEstimatorTest.dml", false, expectedCost);
        }
 
        @Test
+       public void ifElseHopRelTest(){
+               runHopRelTest("IfElseCostEstimatorTest.dml", false);
+       }
+
+       @Test
        public void whileTest(){
-               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-               fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS 
*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
-               double readCost = (2*64+1600+800+8) / 
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
-               double expectedCost = (computeCost + readCost + 0.0625) * 
fedCostEstimator.DEFAULT_ITERATION_NUMBER + 0.0625 + 0.8;
+               double computeCost = (16+2*100+100+1+1) / 
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * 
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+               double readCost = (2*64+1600+800+8) / 
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+               double expectedCost = (computeCost + readCost + 0.0625 + 0.0625 
+ 0.8) * StatementBlock.DEFAULT_LOOP_REPETITIONS;
                runTest("WhileCostEstimatorTest.dml", false, expectedCost);
        }
 
        @Test
+       public void whileHopRelTest(){
+               runHopRelTest("WhileCostEstimatorTest.dml", false);
+       }
+
+       @Test
        public void forLoopTest(){
-               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-               fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS 
*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
-               double readCost = (2*64+1600+800+8) / 
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+               double computeCost = (16+2*100+100+1+1) / 
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * 
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+               double readCost = (2*64+1600+800+8) / 
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
                double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 
0.0625 + 0.0625 + 0.8 + 0.0625;
                double expectedCost = (computeCost + readCost + predicateCost) 
* 5;
                runTest("ForLoopCostEstimatorTest.dml", false, expectedCost);
        }
 
        @Test
+       public void forLoopHopRelTest(){
+               runHopRelTest("ForLoopCostEstimatorTest.dml", false);
+       }
+
+       @Test
        public void parForLoopTest(){
-               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-               fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS 
*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
-               double readCost = (2*64+1600+800+8) / 
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+               double computeCost = (16+2*100+100+1+1) / 
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * 
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+               double readCost = (2*64+1600+800+8) / 
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
                double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 
0.0625 + 0.0625 + 0.8 + 0.0625;
                double expectedCost = (computeCost + readCost + predicateCost) 
* 5;
                runTest("ParForLoopCostEstimatorTest.dml", false, expectedCost);
        }
 
        @Test
+       public void parForLoopHopRelTest(){
+               runHopRelTest("ParForLoopCostEstimatorTest.dml", false);
+       }
+
+       @Test
        public void functionTest(){
-               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-               fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS 
*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
-               double readCost = (2*64+1600+800+8) / 
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
+               double computeCost = (16+2*100+100+1+1) / 
(FederatedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS * 
FederatedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+               double readCost = (2*64+1600+800+8) / 
(FederatedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
                double expectedCost = (computeCost + readCost);
                runTest("FunctionCostEstimatorTest.dml", false, expectedCost);
        }
 
        @Test
+       public void functionHopRelTest(){
+               runHopRelTest("FunctionCostEstimatorTest.dml", false);
+       }
+
+       @Test
        public void federatedMultiply() {
-               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
-               fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-               fedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;
 
                double literalOpCost = 10*0.0625;
                double naryOpCostSpecial = (0.125+2.2);
@@ -224,27 +272,72 @@ public class FederatedCostEstimatorTest extends 
AutomatedTestBase {
                
hops.stream().map(Hop::getClass).distinct().forEach(System.out::println);
        }
 
+       private DMLProgram testSetup(String scriptFilename) throws IOException{
+               setTestConfig(scriptFilename);
+               String dmlScriptString = readScript(scriptFilename);
+
+               //parsing, dependency analysis and constructing hops (step 3 
and 4 in DMLScript.java)
+               ParserWrapper parser = ParserFactory.createParser();
+               DMLProgram prog = 
parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new 
HashMap<>());
+               DMLTranslator dmlt = new DMLTranslator(prog);
+               dmlt.liveVariableAnalysis(prog);
+               dmlt.validateParseTree(prog);
+               dmlt.constructHops(prog);
+               if ( 
scriptFilename.equals("FederatedMultiplyCostEstimatorTest.dml")){
+                       modifyFedouts(prog);
+                       dmlt.rewriteHopsDAG(prog);
+                       hops = new HashSet<>();
+                       prog.getStatementBlocks().forEach(stmBlock -> 
stmBlock.getHops().forEach(this::addHop));
+               }
+               return prog;
+       }
+
+       private void compareResults(DMLProgram prog) {
+               IPAPassRewriteFederatedPlan rewriter = new 
IPAPassRewriteFederatedPlan();
+               rewriter.rewriteProgram(prog, new FunctionCallGraph(prog), 
null);
+
+               double actualCost = 0;
+               for ( Hop root : rewriter.getTerminalHops() ){
+                       actualCost += root.getFederatedCost().getTotal();
+               }
+
+
+               rewriter.getTerminalHops().forEach(Hop::resetFederatedCost);
+               fedCostEstimator = new FederatedCostEstimator();
+               double expectedCost = 0;
+               for ( Hop root : rewriter.getTerminalHops() )
+                       expectedCost += 
fedCostEstimator.costEstimate(root).getTotal();
+               Assert.assertEquals(expectedCost, actualCost, 0.0001);
+       }
+
+       private void runHopRelTest( String scriptFilename, boolean 
expectedException ) {
+               boolean raisedException = false;
+               try
+               {
+                       DMLProgram prog = testSetup(scriptFilename);
+                       compareResults(prog);
+               }
+               catch(LanguageException ex) {
+                       raisedException = true;
+                       if(raisedException!=expectedException)
+                               ex.printStackTrace();
+               }
+               catch(Exception ex2) {
+                       ex2.printStackTrace();
+                       throw new RuntimeException(ex2);
+               }
+
+               Assert.assertEquals("Expected exception does not match raised 
exception",
+                       expectedException, raisedException);
+       }
+
        private void runTest( String scriptFilename, boolean expectedException, 
double expectedCost ) {
                boolean raisedException = false;
                try
                {
-                       setTestConfig(scriptFilename);
-                       String dmlScriptString = readScript(scriptFilename);
-
-                       //parsing, dependency analysis and constructing hops 
(step 3 and 4 in DMLScript.java)
-                       ParserWrapper parser = ParserFactory.createParser();
-                       DMLProgram prog = 
parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new 
HashMap<>());
-                       DMLTranslator dmlt = new DMLTranslator(prog);
-                       dmlt.liveVariableAnalysis(prog);
-                       dmlt.validateParseTree(prog);
-                       dmlt.constructHops(prog);
-                       if ( 
scriptFilename.equals("FederatedMultiplyCostEstimatorTest.dml")){
-                               modifyFedouts(prog);
-                               dmlt.rewriteHopsDAG(prog);
-                               hops = new HashSet<>();
-                               prog.getStatementBlocks().forEach(stmBlock -> 
stmBlock.getHops().forEach(this::addHop));
-                       }
+                       DMLProgram prog = testSetup(scriptFilename);
 
+                       fedCostEstimator = new FederatedCostEstimator();
                        FederatedCost actualCost = 
fedCostEstimator.costEstimate(prog);
                        Assert.assertEquals(expectedCost, 
actualCost.getTotal(), 0.0001);
                }

Reply via email to