This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new e671cc30bc [SYSTEMDS-3621] Maintain loop-dependency ratio in 
statementblock headers
e671cc30bc is described below

commit e671cc30bc6852090126f9f57458707210398f1c
Author: Arnab Phani <phaniar...@gmail.com>
AuthorDate: Mon Apr 22 17:17:37 2024 +0200

    [SYSTEMDS-3621] Maintain loop-dependency ratio in statementblock headers
    
    This patch adds a field to the statement block headers that maintains
    the ratio of loop-dependent HOP DAGs. We extended the MarkForLineageReuse
    rewrite for this purpose. A function's header stores the average
    loop-dependency ratio of its child blocks.
    The aim is to extend this rewrite to update the ratio for all block
    headers and tune the delay factor for lineage caching accordingly.
---
 .../sysds/hops/rewrite/MarkForLineageReuse.java    | 22 ++++++++++++++--------
 .../sysds/parser/FunctionStatementBlock.java       |  9 +++++++++
 .../org/apache/sysds/parser/StatementBlock.java    | 10 ++++++++++
 3 files changed, 33 insertions(+), 8 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/MarkForLineageReuse.java 
b/src/main/java/org/apache/sysds/hops/rewrite/MarkForLineageReuse.java
index f4fc86ce24..e782976efd 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/MarkForLineageReuse.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/MarkForLineageReuse.java
@@ -85,35 +85,41 @@ public class MarkForLineageReuse extends 
StatementBlockRewriteRule
                                }
                                else if (sb instanceof WhileStatementBlock) {
                                        WhileStatement wstmt = 
(WhileStatement)sb.getStatement(0);
-                                       rUnmarkLoopDepVarsSB(wstmt.getBody(), 
newdepsbs, loopVar); 
+                                       rUnmarkLoopDepVarsSB(wstmt.getBody(), 
newdepsbs, loopVar);
                                }
                                else if (sb instanceof IfStatementBlock) {
                                        IfStatement ifstmt = 
(IfStatement)sb.getStatement(0);
-                                       
rUnmarkLoopDepVarsSB(ifstmt.getIfBody(), newdepsbs, loopVar); 
+                                       
rUnmarkLoopDepVarsSB(ifstmt.getIfBody(), newdepsbs, loopVar);
                                        if (ifstmt.getElseBody() != null)
-                                               
rUnmarkLoopDepVarsSB(ifstmt.getElseBody(), newdepsbs, loopVar); 
+                                               
rUnmarkLoopDepVarsSB(ifstmt.getElseBody(), newdepsbs, loopVar);
                                }
                                else if (sb instanceof FunctionStatementBlock) {
                                        FunctionStatement fnstmt = 
(FunctionStatement)sb.getStatement(0);
                                        rUnmarkLoopDepVarsSB(fnstmt.getBody(), 
newdepsbs, loopVar);
+                                       ((FunctionStatementBlock) 
sb).setAvgLoopDepRatio();
                                }
                                else {
-                                       if (sb.getHops() != null)
-                                               for (int j=0; 
j<sb.variablesUpdated().getSize(); j++) {
+                                       if (sb.getHops() != null) {
+                                               for(int j = 0; j < 
sb.variablesUpdated().getSize(); j++) {
+                                                       int hopCount = 0;
                                                        HashSet<String> 
newdeproots = new HashSet<>(deproots);
-                                                       for (Hop hop : 
sb.getHops()) {
+                                                       for(Hop hop : 
sb.getHops()) {
                                                                // find the 
loop dependent DAG roots
                                                                
Hop.resetVisitStatus(sb.getHops());
                                                                HashSet<Long> 
dephops = new HashSet<>();
                                                                
rUnmarkLoopDepVars(hop, loopVar, newdeproots, dephops);
+                                                               if 
(dephops.size() > 0)
+                                                                       
hopCount++;
                                                        }
-                                                       if (!deproots.isEmpty() 
&& deproots.equals(newdeproots))
-                                                               // break if 
loop dependent DAGs are converged to a unvarying set
+                                                       
sb.setLoopDepRatio((double)hopCount/sb.getHops().size());
+                                                       if(!deproots.isEmpty() 
&& deproots.equals(newdeproots))
+                                                               // break if 
loop dependent DAGs are converged to an unvarying set
                                                                break;
                                                        else
                                                                // iterate to 
propagate the loop dependents across all the DAGs in this SB
                                                                
deproots.addAll(newdeproots);
                                                }
+                                       }
                                }
                        }
                        deproots.addAll(newdepsbs);
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java 
b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
index 4b2e28578b..8601147869 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
@@ -255,6 +255,15 @@ public class FunctionStatementBlock extends StatementBlock 
implements FunctionBl
                return _nondeterministic;
        }
 
+       // Set the loop dependent hop ratio as the average of all SBs in this 
function
+       public void setAvgLoopDepRatio() {
+               double totDep = 0;
+               FunctionStatement fstmt = (FunctionStatement) 
_statements.get(0);
+               for (var sb : fstmt.getBody())
+                       totDep += sb.getLoopDepRatio();
+               this.setLoopDepRatio(totDep/fstmt.getBody().size());
+       }
+
        @Override
        public FunctionBlock cloneFunctionBlock() {
                return ProgramConverter
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java 
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 00117aeb2c..e8658e359e 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -67,6 +67,7 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
        private HashMap<Lop.Type, List<Lop.Type>> _checkpointPositions = null;
 
        protected double repetitions = 1;
+       private double loopDepRatio = 0; //ratio of loop dependent HOP dags
        public final static double DEFAULT_LOOP_REPETITIONS = 10;
 
        public StatementBlock() {
@@ -180,6 +181,15 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                return _splitDag;
        }
 
+       public double getLoopDepRatio() {
+               return loopDepRatio;
+       }
+
+       // maintain the ration of loop-dependent HOP dags in this block
+       public void setLoopDepRatio(double dep) {
+               loopDepRatio = dep;
+       }
+
        private static boolean isMergeablePrintStatement(Statement stmt) {
                return ( stmt instanceof PrintStatement &&
                        (((PrintStatement)stmt).getType() == PRINTTYPE.STOP || 
((PrintStatement)stmt).getType() == PRINTTYPE.ASSERT) );

Reply via email to