This is an automated email from the ASF dual-hosted git repository. arnabp20 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new e671cc30bc [SYSTEMDS-3621] Maintain loop-dependency ratio in statementblock headers e671cc30bc is described below commit e671cc30bc6852090126f9f57458707210398f1c Author: Arnab Phani <phaniar...@gmail.com> AuthorDate: Mon Apr 22 17:17:37 2024 +0200 [SYSTEMDS-3621] Maintain loop-dependency ratio in statementblock headers This patch adds a flag in the loop header to maintain the ratio of loop-dependent HOP dags. We extended the MarkForLineageReuse rewrite for this purpose. The function header updates the average of loop-dependency ratio of the child blocks. The aim is to extend this rewrite to update the ratio for all block headers and tune the delay factor for lineage caching accordingly. --- .../sysds/hops/rewrite/MarkForLineageReuse.java | 22 ++++++++++++++-------- .../sysds/parser/FunctionStatementBlock.java | 9 +++++++++ .../org/apache/sysds/parser/StatementBlock.java | 10 ++++++++++ 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/MarkForLineageReuse.java b/src/main/java/org/apache/sysds/hops/rewrite/MarkForLineageReuse.java index f4fc86ce24..e782976efd 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/MarkForLineageReuse.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/MarkForLineageReuse.java @@ -85,35 +85,41 @@ public class MarkForLineageReuse extends StatementBlockRewriteRule } else if (sb instanceof WhileStatementBlock) { WhileStatement wstmt = (WhileStatement)sb.getStatement(0); - rUnmarkLoopDepVarsSB(wstmt.getBody(), newdepsbs, loopVar); + rUnmarkLoopDepVarsSB(wstmt.getBody(), newdepsbs, loopVar); } else if (sb instanceof IfStatementBlock) { IfStatement ifstmt = (IfStatement)sb.getStatement(0); - rUnmarkLoopDepVarsSB(ifstmt.getIfBody(), newdepsbs, loopVar); + rUnmarkLoopDepVarsSB(ifstmt.getIfBody(), newdepsbs, loopVar); if (ifstmt.getElseBody() != null) - rUnmarkLoopDepVarsSB(ifstmt.getElseBody(), newdepsbs, 
loopVar); + rUnmarkLoopDepVarsSB(ifstmt.getElseBody(), newdepsbs, loopVar); } else if (sb instanceof FunctionStatementBlock) { FunctionStatement fnstmt = (FunctionStatement)sb.getStatement(0); rUnmarkLoopDepVarsSB(fnstmt.getBody(), newdepsbs, loopVar); + ((FunctionStatementBlock) sb).setAvgLoopDepRatio(); } else { - if (sb.getHops() != null) - for (int j=0; j<sb.variablesUpdated().getSize(); j++) { + if (sb.getHops() != null) { + for(int j = 0; j < sb.variablesUpdated().getSize(); j++) { + int hopCount = 0; HashSet<String> newdeproots = new HashSet<>(deproots); - for (Hop hop : sb.getHops()) { + for(Hop hop : sb.getHops()) { // find the loop dependent DAG roots Hop.resetVisitStatus(sb.getHops()); HashSet<Long> dephops = new HashSet<>(); rUnmarkLoopDepVars(hop, loopVar, newdeproots, dephops); + if (dephops.size() > 0) + hopCount++; } - if (!deproots.isEmpty() && deproots.equals(newdeproots)) - // break if loop dependent DAGs are converged to a unvarying set + sb.setLoopDepRatio((double)hopCount/sb.getHops().size()); + if(!deproots.isEmpty() && deproots.equals(newdeproots)) + // break if loop dependent DAGs are converged to an unvarying set break; else // iterate to propagate the loop dependents across all the DAGs in this SB deproots.addAll(newdeproots); } + } } } deproots.addAll(newdepsbs); diff --git a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java index 4b2e28578b..8601147869 100644 --- a/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java +++ b/src/main/java/org/apache/sysds/parser/FunctionStatementBlock.java @@ -255,6 +255,15 @@ public class FunctionStatementBlock extends StatementBlock implements FunctionBl return _nondeterministic; } + // Set the loop dependent hop ratio as the average of all SBs in this function + public void setAvgLoopDepRatio() { + double totDep = 0; + FunctionStatement fstmt = (FunctionStatement) _statements.get(0); + for (var sb : 
fstmt.getBody()) + totDep += sb.getLoopDepRatio(); + this.setLoopDepRatio(totDep/fstmt.getBody().size()); + } + @Override public FunctionBlock cloneFunctionBlock() { return ProgramConverter diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java b/src/main/java/org/apache/sysds/parser/StatementBlock.java index 00117aeb2c..e8658e359e 100644 --- a/src/main/java/org/apache/sysds/parser/StatementBlock.java +++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java @@ -67,6 +67,7 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo private HashMap<Lop.Type, List<Lop.Type>> _checkpointPositions = null; protected double repetitions = 1; + private double loopDepRatio = 0; //ratio of loop dependent HOP dags public final static double DEFAULT_LOOP_REPETITIONS = 10; public StatementBlock() { @@ -180,6 +181,15 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo return _splitDag; } + public double getLoopDepRatio() { + return loopDepRatio; + } + + // maintain the ratio of loop-dependent HOP dags in this block + public void setLoopDepRatio(double dep) { + loopDepRatio = dep; + } + private static boolean isMergeablePrintStatement(Statement stmt) { return ( stmt instanceof PrintStatement && (((PrintStatement)stmt).getType() == PRINTTYPE.STOP || ((PrintStatement)stmt).getType() == PRINTTYPE.ASSERT) );