This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new e671cc30bc [SYSTEMDS-3621] Maintain loop-dependency ratio in
statementblock headers
e671cc30bc is described below
commit e671cc30bc6852090126f9f57458707210398f1c
Author: Arnab Phani <[email protected]>
AuthorDate: Mon Apr 22 17:17:37 2024 +0200
[SYSTEMDS-3621] Maintain loop-dependency ratio in statementblock headers
This patch adds a field to the statement block headers that maintains
the ratio of loop-dependent HOP DAGs. We extended the MarkForLineageReuse
rewrite for this purpose. The function statement block header stores the
average loop-dependency ratio of its child blocks.
The aim is to extend this rewrite to update the ratio for all block
headers and to tune the delay factor for lineage caching accordingly.
---
.../sysds/hops/rewrite/MarkForLineageReuse.java | 22 ++++++++++++++--------
.../sysds/parser/FunctionStatementBlock.java | 9 +++++++++
.../org/apache/sysds/parser/StatementBlock.java | 10 ++++++++++
3 files changed, 33 insertions(+), 8 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/MarkForLineageReuse.java
b/src/main/java/org/apache/sysds/hops/rewrite/MarkForLineageReuse.java
index f4fc86ce24..e782976efd 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/MarkForLineageReuse.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/MarkForLineageReuse.java
@@ -85,35 +85,41 @@ public class MarkForLineageReuse extends
StatementBlockRewriteRule
}
else if (sb instanceof WhileStatementBlock) {
WhileStatement wstmt =
(WhileStatement)sb.getStatement(0);
- rUnmarkLoopDepVarsSB(wstmt.getBody(),
newdepsbs, loopVar);
+ rUnmarkLoopDepVarsSB(wstmt.getBody(),
newdepsbs, loopVar);
}
else if (sb instanceof IfStatementBlock) {
IfStatement ifstmt =
(IfStatement)sb.getStatement(0);
-
rUnmarkLoopDepVarsSB(ifstmt.getIfBody(), newdepsbs, loopVar);
+
rUnmarkLoopDepVarsSB(ifstmt.getIfBody(), newdepsbs, loopVar);
if (ifstmt.getElseBody() != null)
-
rUnmarkLoopDepVarsSB(ifstmt.getElseBody(), newdepsbs, loopVar);
+
rUnmarkLoopDepVarsSB(ifstmt.getElseBody(), newdepsbs, loopVar);
}
else if (sb instanceof FunctionStatementBlock) {
FunctionStatement fnstmt =
(FunctionStatement)sb.getStatement(0);
rUnmarkLoopDepVarsSB(fnstmt.getBody(),
newdepsbs, loopVar);
+ ((FunctionStatementBlock)
sb).setAvgLoopDepRatio();
}
else {
- if (sb.getHops() != null)
- for (int j=0;
j<sb.variablesUpdated().getSize(); j++) {
+ if (sb.getHops() != null) {
+ for(int j = 0; j <
sb.variablesUpdated().getSize(); j++) {
+ int hopCount = 0;
HashSet<String>
newdeproots = new HashSet<>(deproots);
- for (Hop hop :
sb.getHops()) {
+ for(Hop hop :
sb.getHops()) {
// find the
loop dependent DAG roots
Hop.resetVisitStatus(sb.getHops());
HashSet<Long>
dephops = new HashSet<>();
rUnmarkLoopDepVars(hop, loopVar, newdeproots, dephops);
+ if
(dephops.size() > 0)
+
hopCount++;
}
- if (!deproots.isEmpty()
&& deproots.equals(newdeproots))
- // break if
loop dependent DAGs are converged to a unvarying set
+
sb.setLoopDepRatio((double)hopCount/sb.getHops().size());
+ if(!deproots.isEmpty()
&& deproots.equals(newdeproots))
+ // break if
loop dependent DAGs are converged to an unvarying set
break;
else
// iterate to
propagate the loop dependents across all the DAGs in this SB
deproots.addAll(newdeproots);
}
+ }
}
}
deproots.addAll(newdepsbs);
diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
index 4b2e28578b..8601147869 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java
@@ -255,6 +255,15 @@ public class FunctionStatementBlock extends StatementBlock
implements FunctionBl
return _nondeterministic;
}
+ // Set this function's loop-dependency ratio to the average of the ratios stored
on the blocks directly in its body (nested blocks count only via their parent's ratio)
+ public void setAvgLoopDepRatio() {
+ double totDep = 0;
+ FunctionStatement fstmt = (FunctionStatement)
_statements.get(0);
+ for (var sb : fstmt.getBody())
+ totDep += sb.getLoopDepRatio(); // blocks whose ratio was never set report the field default 0
+ this.setLoopDepRatio(totDep/fstmt.getBody().size()); // NOTE(review): an empty body gives 0.0/0 = NaN
+ }
+
@Override
public FunctionBlock cloneFunctionBlock() {
return ProgramConverter
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 00117aeb2c..e8658e359e 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -67,6 +67,7 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
private HashMap<Lop.Type, List<Lop.Type>> _checkpointPositions = null;
protected double repetitions = 1;
+ private double loopDepRatio = 0; //ratio of loop dependent HOP dags
public final static double DEFAULT_LOOP_REPETITIONS = 10;
public StatementBlock() {
@@ -180,6 +181,15 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
return _splitDag;
}
+ public double getLoopDepRatio() { // ratio of loop-dependent HOP DAGs last stored for this block; 0 until set
+ return loopDepRatio;
+ }
+
+ // Store the ratio of loop-dependent HOP DAGs in this block (value is stored as-is, no range check)
+ public void setLoopDepRatio(double dep) {
+ loopDepRatio = dep;
+ }
+
private static boolean isMergeablePrintStatement(Statement stmt) {
return ( stmt instanceof PrintStatement &&
(((PrintStatement)stmt).getType() == PRINTTYPE.STOP ||
((PrintStatement)stmt).getType() == PRINTTYPE.ASSERT) );