This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 7ac6e2a [SYSTEMDS-2855] Rework function recompilation on entry (w/ rewrites)
7ac6e2a is described below
commit 7ac6e2a87ffb447dcf0f8064e979bcc2800cec38
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Feb 10 12:50:39 2021 +0100
[SYSTEMDS-2855] Rework function recompilation on entry (w/ rewrites)
This patch makes a major change to the recompilation of functions (i.e.,
functions that have been marked during inter-procedural analysis for
recompile-once). Previously, we applied in-place recompilation to update
size information, but without rewrites, in order to allow a reset for
future function invocations. This caused inconsistencies where function
recompilation did not apply the same rewrites as normal block
recompilation (e.g., it failed to rewrite -t(X) %*% y to -(t(X) %*% y)
and then to -t(t(y)%*%X)). Now we first apply the logic for size
propagation (with a potential reset) and then, if no part of the function
still requires recompilation, apply rewrites in a second pass.
Furthermore, this patch cleans up the somewhat messy passing of
recompilation configurations via a reworked RecompileStatus, and fixes
the local size propagation of matrix multiplications (resetting the
output nnz to unknown before re-deriving it) to allow for a reset with
unknown sizes (to ensure correct results).
---
.../java/org/apache/sysds/hops/AggBinaryOp.java | 25 +--
.../sysds/hops/recompile/RecompileStatus.java | 53 ++++-
.../apache/sysds/hops/recompile/Recompiler.java | 214 +++++++++------------
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 11 ++
.../RewriteAlgebraicSimplificationDynamic.java | 2 +-
.../RewriteAlgebraicSimplificationStatic.java | 11 +-
.../org/apache/sysds/parser/ForStatementBlock.java | 4 +
.../apache/sysds/parser/ParForStatementBlock.java | 1 -
.../controlprogram/FunctionProgramBlock.java | 4 +-
.../parfor/opt/OptimizationWrapper.java | 4 +-
.../fed/AggregateUnaryFEDInstruction.java | 1 -
.../federated/algorithms/FederatedLmPipeline.java | 10 +-
12 files changed, 188 insertions(+), 152 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 5dcc5ee..c279071 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -977,19 +977,19 @@ public class AggBinaryOp extends MultiThreadedHop
//right side cached (no agg if left has just one column block)
if( method == MMultMethod.MAPMM_R &&
getInput().get(0).getDim2() >= 0 //known num columns
- && getInput().get(0).getDim2() <=
getInput().get(0).getBlocksize() )
- {
- ret = false;
- }
-
+ && getInput().get(0).getDim2() <=
getInput().get(0).getBlocksize() )
+ {
+ ret = false;
+ }
+
//left side cached (no agg if right has just one row block)
- if( method == MMultMethod.MAPMM_L && getInput().get(1).getDim1() >= 0
//known num rows
- && getInput().get(1).getDim1() <=
getInput().get(1).getBlocksize() )
- {
- ret = false;
- }
-
- return ret;
+ if( method == MMultMethod.MAPMM_L &&
getInput().get(1).getDim1() >= 0 //known num rows
+ && getInput().get(1).getDim1() <=
getInput().get(1).getBlocksize() )
+ {
+ ret = false;
+ }
+
+ return ret;
}
/**
@@ -1274,6 +1274,7 @@ public class AggBinaryOp extends MultiThreadedHop
if( isMatrixMultiply() ) {
setDim1(input1.getDim1());
setDim2(input2.getDim2());
+ setNnz(-1); // for reset on recompile w/ unknowns
if( input1.getNnz() == 0 || input2.getNnz() == 0 )
setNnz(0);
}
diff --git a/src/main/java/org/apache/sysds/hops/recompile/RecompileStatus.java
b/src/main/java/org/apache/sysds/hops/recompile/RecompileStatus.java
index bdb91fd..edb03d2 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/RecompileStatus.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/RecompileStatus.java
@@ -19,21 +19,39 @@
package org.apache.sysds.hops.recompile;
+import org.apache.sysds.hops.recompile.Recompiler.ResetType;
+import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import java.util.HashMap;
public class RecompileStatus
{
+ //immutable flags for recompilation configurations
+ private final long _tid; // thread-id, 0 if main thread
+ private final boolean _inplace; // in-place recompilation, false
for rewrites
+ private final ResetType _reset; // reset type for program
compilation
+ private final boolean _initialCodegen; // initial codegen compilation
(no recompilation)
+
+ //track if parts of recompiled program still require recompilation
+ private boolean _requiresRecompile = false;
+
+ //collection of extracted statistics for control flow reconciliation
private final HashMap<String, DataCharacteristics> _lastTWrites;
- private final boolean _initialCodegen;
public RecompileStatus() {
- this(false);
+ this(0, true, ResetType.NO_RESET, false);
}
public RecompileStatus(boolean initialCodegen) {
+ this(0, true, ResetType.NO_RESET, initialCodegen);
+ }
+
+ public RecompileStatus(long tid, boolean inplace, ResetType reset,
boolean initialCodegen) {
_lastTWrites = new HashMap<>();
+ _tid = tid;
+ _inplace = inplace;
+ _reset = reset;
_initialCodegen = initialCodegen;
}
@@ -41,13 +59,42 @@ public class RecompileStatus
return _lastTWrites;
}
+ public long getTID() {
+ return _tid;
+ }
+
+ public boolean hasThreadID() {
+ return ProgramBlock.isThreadID(_tid);
+ }
+
+ public boolean isInPlace() {
+ return _inplace;
+ }
+
+ public boolean isReset() {
+ return _reset.isReset();
+ }
+
+ public ResetType getReset() {
+ return _reset;
+ }
+
public boolean isInitialCodegen() {
return _initialCodegen;
}
+
+ public void trackRecompile(boolean flag) {
+ _requiresRecompile |= flag;
+ }
+
+ public boolean requiresRecompile() {
+ return _requiresRecompile;
+ }
@Override
public Object clone() {
- RecompileStatus ret = new RecompileStatus();
+ RecompileStatus ret = new RecompileStatus(
+ _tid, _inplace, _reset, _initialCodegen);
ret._lastTWrites.putAll(_lastTWrites);
return ret;
}
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index 3b15b44..a714266 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -173,28 +173,6 @@ public class Recompiler
{
return recompileHopsDag(sb, hops, new ExecutionContext(vars),
status, inplace, replaceLit, tid);
}
-
- public static ArrayList<Instruction> recompileHopsDag( Hop hop,
ExecutionContext ec,
- RecompileStatus status, boolean inplace, boolean
replaceLit, long tid )
- {
- ArrayList<Instruction> newInst = null;
-
- //need for synchronization as we do temp changes in shared
hops/lops
- synchronized( hop ) {
- newInst = recompile(null, new
ArrayList<>(Arrays.asList(hop)),
- ec, status, inplace, replaceLit, true, false,
true, null, tid);
- }
-
- // replace thread ids in new instructions
- if( ProgramBlock.isThreadID(tid) ) //only in parfor context
- newInst =
ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null,
null, false, false);
-
- // explain recompiled instructions
- if( DMLScript.EXPLAIN == ExplainType.RECOMPILE_RUNTIME )
- logExplainPred(hop, newInst);
-
- return newInst;
- }
public static ArrayList<Instruction> recompileHopsDag( Hop hop,
LocalVariableMap vars,
RecompileStatus status, boolean inplace, boolean
replaceLit, long tid )
@@ -447,11 +425,25 @@ public class Recompiler
System.out.println("EXPLAIN RECOMPILE \nPRED (line
"+hops.getBeginLine()+"):\n" + Explain.explain(inst,1));
}
- public static void recompileProgramBlockHierarchy(
ArrayList<ProgramBlock> pbs, LocalVariableMap vars, long tid, ResetType
resetRecompile ) {
- RecompileStatus status = new RecompileStatus();
+ public static void recompileProgramBlockHierarchy(
ArrayList<ProgramBlock> pbs, LocalVariableMap vars, long tid, boolean inplace,
ResetType resetRecompile ) {
+ //function recompilation via two-phase approach due to
challenges
+ //of unclear reconciliation of arbitrary complex control flow
+
+ // phase 1: normal inplace=true w/o rewrite as usual, but track
requiresRecompile
+ // (preserve variables for potential second pass, otherwise
corrupted stats)
+ RecompileStatus status1 = new RecompileStatus(tid, true,
resetRecompile, false);
synchronized( pbs ) {
for( ProgramBlock pb : pbs )
- rRecompileProgramBlock(pb, vars, status, tid,
resetRecompile);
+ rRecompileProgramBlock(pb, vars, status1);
+
+ // phase 2: if called with inplace-false, run a second
in-place=false pass in
+ // order to apply rewrites (at this point sizes are
already propagated, but for
+ // correctness we call it with an empty symbol table to
avoid invalid size updates)
+ if( !status1.requiresRecompile() && !inplace ) {
+ RecompileStatus status2 = new
RecompileStatus(tid, false, resetRecompile, false);
+ for( ProgramBlock pb : pbs )
+ rRecompileProgramBlock(pb, new
LocalVariableMap(), status2);
+ }
}
}
@@ -635,27 +627,26 @@ public class Recompiler
// private helper functions //
//////////////////////////////
- private static void rRecompileProgramBlock( ProgramBlock pb,
LocalVariableMap vars,
- RecompileStatus status, long tid, ResetType resetRecompile )
+ private static void rRecompileProgramBlock( ProgramBlock pb,
LocalVariableMap vars, RecompileStatus status )
{
if (pb instanceof WhileProgramBlock) {
WhileProgramBlock wpb = (WhileProgramBlock)pb;
WhileStatementBlock wsb = (WhileStatementBlock)
wpb.getStatementBlock();
//recompile predicate
- recompileWhilePredicate(wpb, wsb, vars, status, tid,
resetRecompile);
+ recompileWhilePredicate(wpb, wsb, vars, status);
//remove updated scalars because in loop
removeUpdatedScalars(vars, wsb);
//copy vars for later compare
LocalVariableMap oldVars = (LocalVariableMap)
vars.clone();
RecompileStatus oldStatus = (RecompileStatus)
status.clone();
for (ProgramBlock pb2 : wpb.getChildBlocks())
- rRecompileProgramBlock(pb2, vars, status, tid,
resetRecompile);
+ rRecompileProgramBlock(pb2, vars, status);
if( reconcileUpdatedCallVarsLoops(oldVars, vars, wsb)
| reconcileUpdatedCallVarsLoops(oldStatus,
status, wsb) ) {
//second pass with unknowns if required
- recompileWhilePredicate(wpb, wsb, vars, status,
tid, resetRecompile);
+ recompileWhilePredicate(wpb, wsb, vars, status);
for (ProgramBlock pb2 : wpb.getChildBlocks())
- rRecompileProgramBlock(pb2, vars,
status, tid, resetRecompile);
+ rRecompileProgramBlock(pb2, vars,
status);
}
removeUpdatedScalars(vars, wsb);
}
@@ -663,16 +654,16 @@ public class Recompiler
IfProgramBlock ipb = (IfProgramBlock)pb;
IfStatementBlock isb =
(IfStatementBlock)ipb.getStatementBlock();
//recompile predicate
- recompileIfPredicate(ipb, isb, vars, status, tid,
resetRecompile);
+ recompileIfPredicate(ipb, isb, vars, status);
//copy vars for later compare
LocalVariableMap oldVars = (LocalVariableMap)
vars.clone();
LocalVariableMap varsElse = (LocalVariableMap)
vars.clone();
RecompileStatus oldStatus =
(RecompileStatus)status.clone();
RecompileStatus statusElse =
(RecompileStatus)status.clone();
for( ProgramBlock pb2 : ipb.getChildBlocksIfBody() )
- rRecompileProgramBlock(pb2, vars, status, tid,
resetRecompile);
+ rRecompileProgramBlock(pb2, vars, status);
for( ProgramBlock pb2 : ipb.getChildBlocksElseBody() )
- rRecompileProgramBlock(pb2, varsElse,
statusElse, tid, resetRecompile);
+ rRecompileProgramBlock(pb2, varsElse,
statusElse);
reconcileUpdatedCallVarsIf(oldVars, vars, varsElse,
isb);
reconcileUpdatedCallVarsIf(oldStatus, status,
statusElse, isb);
removeUpdatedScalars(vars, ipb.getStatementBlock());
@@ -681,20 +672,20 @@ public class Recompiler
ForProgramBlock fpb = (ForProgramBlock)pb;
ForStatementBlock fsb = (ForStatementBlock)
fpb.getStatementBlock();
//recompile predicates
- recompileForPredicates(fpb, fsb, vars, status, tid,
resetRecompile);
+ recompileForPredicates(fpb, fsb, vars, status);
//remove updated scalars because in loop
removeUpdatedScalars(vars, fpb.getStatementBlock());
//copy vars for later compare
LocalVariableMap oldVars = (LocalVariableMap)
vars.clone();
RecompileStatus oldStatus = (RecompileStatus)
status.clone();
for( ProgramBlock pb2 : fpb.getChildBlocks() )
- rRecompileProgramBlock(pb2, vars, status, tid,
resetRecompile);
+ rRecompileProgramBlock(pb2, vars, status);
if( reconcileUpdatedCallVarsLoops(oldVars, vars, fsb)
| reconcileUpdatedCallVarsLoops(oldStatus,
status, fsb)) {
//second pass with unknowns if required
- recompileForPredicates(fpb, fsb, vars, status,
tid, resetRecompile);
+ recompileForPredicates(fpb, fsb, vars, status);
for( ProgramBlock pb2 : fpb.getChildBlocks() )
- rRecompileProgramBlock(pb2, vars,
status, tid, resetRecompile);
+ rRecompileProgramBlock(pb2, vars,
status);
}
removeUpdatedScalars(vars, fpb.getStatementBlock());
}
@@ -711,20 +702,22 @@ public class Recompiler
//recompile all for stats propagation and recompile
flags
tmp = Recompiler.recompileHopsDag(
- sb, sb.getHops(), vars, status, true, false,
tid);
+ sb, sb.getHops(), vars, status,
status.isInPlace(), false, status.getTID());
bpb.setInstructions( tmp );
//propagate stats across hops (should be executed on
clone of vars)
- Recompiler.extractDAGOutputStatistics(sb.getHops(),
vars);
+ if( status.isInPlace() )
+
Recompiler.extractDAGOutputStatistics(sb.getHops(), vars);
//reset recompilation flags (w/ special handling
functions)
if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs
&& !containsRootFunctionOp(sb.getHops())
- && resetRecompile.isReset() )
+ && status.isReset() )
{
- Hop.resetRecompilationFlag(sb.getHops(),
ExecType.CP, resetRecompile);
+ Hop.resetRecompilationFlag(sb.getHops(),
ExecType.CP, status.getReset());
sb.updateRecompilationFlag();
}
+ status.trackRecompile(sb.requiresRecompilation());
}
}
@@ -952,91 +945,70 @@ public class Recompiler
//helper functions for predicate recompile
- private static void recompileIfPredicate( IfProgramBlock ipb,
IfStatementBlock isb, LocalVariableMap vars, RecompileStatus status, long tid,
ResetType resetRecompile )
- {
- if( isb == null )
+ private static void recompileIfPredicate( IfProgramBlock ipb,
IfStatementBlock isb, LocalVariableMap vars, RecompileStatus status ) {
+ if( isb == null || isb.getPredicateHops() == null )
return;
-
Hop hops = isb.getPredicateHops();
- if( hops != null ) {
- ArrayList<Instruction> tmp = recompileHopsDag(
- hops, vars, status, true, false, tid);
- ipb.setPredicate( tmp );
- if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs
- && resetRecompile.isReset() ) {
- Hop.resetRecompilationFlag(hops, ExecType.CP,
resetRecompile);
- isb.updatePredicateRecompilationFlag();
- }
- }
+ ArrayList<Instruction> tmp = recompileHopsDag(
+ hops, vars, status, status.isInPlace(), false,
status.getTID());
+ ipb.setPredicate( tmp );
+ if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs &&
status.isReset() ) {
+ Hop.resetRecompilationFlag(hops, ExecType.CP,
status.getReset());
+ isb.updatePredicateRecompilationFlag();
+ }
+ status.trackRecompile(isb.requiresPredicateRecompilation());
}
- private static void recompileWhilePredicate( WhileProgramBlock wpb,
WhileStatementBlock wsb, LocalVariableMap vars, RecompileStatus status, long
tid, ResetType resetRecompile ) {
- if( wsb == null )
+ private static void recompileWhilePredicate( WhileProgramBlock wpb,
WhileStatementBlock wsb, LocalVariableMap vars, RecompileStatus status ) {
+ if( wsb == null || wsb.getPredicateHops() == null )
return;
-
Hop hops = wsb.getPredicateHops();
- if( hops != null ) {
- ArrayList<Instruction> tmp = recompileHopsDag(
- hops, vars, status, true, false, tid);
- wpb.setPredicate( tmp );
- if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs
- && resetRecompile.isReset() ) {
- Hop.resetRecompilationFlag(hops, ExecType.CP,
resetRecompile);
- wsb.updatePredicateRecompilationFlag();
- }
- }
+ ArrayList<Instruction> tmp = recompileHopsDag(
+ hops, vars, status, status.isInPlace(), false,
status.getTID());
+ wpb.setPredicate( tmp );
+ if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs &&
status.isReset() ) {
+ Hop.resetRecompilationFlag(hops, ExecType.CP,
status.getReset());
+ wsb.updatePredicateRecompilationFlag();
+ }
+ status.trackRecompile(wsb.requiresPredicateRecompilation());
}
- private static void recompileForPredicates( ForProgramBlock fpb,
ForStatementBlock fsb, LocalVariableMap vars, RecompileStatus status, long tid,
ResetType resetRecompile ) {
- if( fsb != null )
- {
- Hop fromHops = fsb.getFromHops();
- Hop toHops = fsb.getToHops();
- Hop incrHops = fsb.getIncrementHops();
-
- //handle recompilation flags
- if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs
- && resetRecompile.isReset() )
- {
- if( fromHops != null ) {
- ArrayList<Instruction> tmp =
recompileHopsDag(
- fromHops, vars, status, true,
false, tid);
- fpb.setFromInstructions(tmp);
-
Hop.resetRecompilationFlag(fromHops,ExecType.CP, resetRecompile);
- }
- if( toHops != null ) {
- ArrayList<Instruction> tmp =
recompileHopsDag(
- toHops, vars, status, true,
false, tid);
- fpb.setToInstructions(tmp);
-
Hop.resetRecompilationFlag(toHops,ExecType.CP, resetRecompile);
- }
- if( incrHops != null ) {
- ArrayList<Instruction> tmp =
recompileHopsDag(
- incrHops, vars, status, true,
false, tid);
- fpb.setIncrementInstructions(tmp);
-
Hop.resetRecompilationFlag(incrHops,ExecType.CP, resetRecompile);
- }
- fsb.updatePredicateRecompilationFlags();
- }
- else //no reset of recompilation flags
- {
- if( fromHops != null ) {
- ArrayList<Instruction> tmp =
recompileHopsDag(
- fromHops, vars, status, true,
false, tid);
- fpb.setFromInstructions(tmp);
- }
- if( toHops != null ) {
- ArrayList<Instruction> tmp =
recompileHopsDag(
- toHops, vars, status, true,
false, tid);
- fpb.setToInstructions(tmp);
- }
- if( incrHops != null ) {
- ArrayList<Instruction> tmp =
recompileHopsDag(
- incrHops, vars, status, true,
false, tid);
- fpb.setIncrementInstructions(tmp);
- }
- }
+ private static void recompileForPredicates( ForProgramBlock fpb,
ForStatementBlock fsb, LocalVariableMap vars, RecompileStatus status ) {
+ if( fsb == null )
+ return;
+
+ Hop fromHops = fsb.getFromHops();
+ Hop toHops = fsb.getToHops();
+ Hop incrHops = fsb.getIncrementHops();
+
+ // recompile predicates
+ if( fromHops != null ) {
+ ArrayList<Instruction> tmp = recompileHopsDag(
+ fromHops, vars, status, status.isInPlace(),
false, status.getTID());
+ fpb.setFromInstructions(tmp);
}
+ if( toHops != null ) {
+ ArrayList<Instruction> tmp = recompileHopsDag(
+ toHops, vars, status, status.isInPlace(),
false, status.getTID());
+ fpb.setToInstructions(tmp);
+ }
+ if( incrHops != null ) {
+ ArrayList<Instruction> tmp = recompileHopsDag(
+ incrHops, vars, status, status.isInPlace(),
false, status.getTID());
+ fpb.setIncrementInstructions(tmp);
+ }
+
+ //handle recompilation flags
+ if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs &&
status.isReset() ) {
+ if( fromHops != null )
+ Hop.resetRecompilationFlag(fromHops,
ExecType.CP, status.getReset());
+ if( toHops != null )
+ Hop.resetRecompilationFlag(toHops, ExecType.CP,
status.getReset());
+ if( incrHops != null )
+ Hop.resetRecompilationFlag(incrHops,
ExecType.CP, status.getReset());
+ fsb.updatePredicateRecompilationFlags();
+ }
+ status.trackRecompile(fsb.requiresPredicateRecompilation());
}
public static void rRecompileProgramBlock2Forced( ProgramBlock pb, long
tid, HashSet<String> fnStack, ExecType et ) {
@@ -1142,13 +1114,11 @@ public class Recompiler
}
}
- public static void extractDAGOutputStatistics(ArrayList<Hop> hops,
LocalVariableMap vars)
- {
+ public static void extractDAGOutputStatistics(ArrayList<Hop> hops,
LocalVariableMap vars) {
extractDAGOutputStatistics(hops, vars, true);
}
- public static void extractDAGOutputStatistics(ArrayList<Hop> hops,
LocalVariableMap vars, boolean overwrite)
- {
+ public static void extractDAGOutputStatistics(ArrayList<Hop> hops,
LocalVariableMap vars, boolean overwrite) {
for( Hop hop : hops ) //for all hop roots
extractDAGOutputStatistics(hop, vars, overwrite);
}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 5127384..62ba4ae 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -22,6 +22,8 @@ package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.List;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.CompilerConfig.ConfigType;
@@ -48,11 +50,20 @@ import org.apache.sysds.runtime.lineage.LineageCacheConfig;
*/
public class ProgramRewriter
{
+ private static final boolean LDEBUG = false; //internal local debug
level
private static final boolean CHECK = false;
private ArrayList<HopRewriteRule> _dagRuleSet = null;
private ArrayList<StatementBlockRewriteRule> _sbRuleSet = null;
+ static {
+ // for internal debugging only
+ if( LDEBUG ) {
+ Logger.getLogger("org.apache.sysds.hops.rewrite")
+ .setLevel(Level.DEBUG);
+ }
+ }
+
public ProgramRewriter() {
// by default which is used during initial compile
// apply all (static and dynamic) rewrites
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 519c400..71a7240 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2495,7 +2495,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
hi = minus;
LOG.debug("Applied reorderMinusMatrixMult (line
"+hi.getBeginLine()+").");
- }
+ }
}
return hi;
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 5ab97bf..f0d9dea 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -254,7 +254,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
* handle removal of unnecessary binary operations
*
* X/1 or X*1 or 1*X or X-0 -> X
- * -1*X or X*-1-> -X
+ * -1*X or X*-1-> -X
*
* @param parent parent high-level operator
* @param hi high-level operator
@@ -777,7 +777,6 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
*/
private static Hop simplifyDistributiveBinaryOperation( Hop parent, Hop
hi, int pos )
{
-
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
@@ -810,9 +809,9 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = mult;
applied = true;
- LOG.debug("Applied
simplifyDistributiveBinaryOperation1");
- }
- }
+ LOG.debug("Applied
simplifyDistributiveBinaryOperation1 (line "+hi.getBeginLine()+").");
+ }
+ }
if( !applied && HopRewriteUtils.isBinary(right,
OpOp2.MULT) ) //(X-Y*X) -> (1-Y)*X
{
@@ -831,7 +830,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
HopRewriteUtils.cleanupUnreferenced(hi, right);
hi = mult;
- LOG.debug("Applied
simplifyDistributiveBinaryOperation2");
+ LOG.debug("Applied
simplifyDistributiveBinaryOperation2 (line "+hi.getBeginLine()+").");
}
}
}
diff --git a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
index 092fbb7..c2686da 100644
--- a/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ForStatementBlock.java
@@ -410,6 +410,10 @@ public class ForStatementBlock extends StatementBlock
_requiresToRecompile =
Recompiler.requiresRecompilation(getToHops());
_requiresIncrementRecompile =
Recompiler.requiresRecompilation(getIncrementHops());
}
+ return requiresPredicateRecompilation();
+ }
+
+ public boolean requiresPredicateRecompilation() {
return (_requiresFromRecompile || _requiresToRecompile ||
_requiresIncrementRecompile);
}
diff --git a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
index af88d60..9219ba1 100644
--- a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
@@ -152,7 +152,6 @@ public class ParForStatementBlock extends ForStatementBlock
if( LDEBUG ) {
Logger.getLogger("org.apache.sysds.parser.ParForStatementBlock")
.setLevel(Level.TRACE);
- System.out.println();
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
index cd7f0cc..32ba97f 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
@@ -116,8 +116,8 @@ public class FunctionProgramBlock extends ProgramBlock
implements FunctionBlock
boolean codegen =
ConfigurationManager.isCodegenEnabled();
boolean singlenode =
DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE;
ResetType reset = (codegen || singlenode) ?
ResetType.RESET_KNOWN_DIMS : ResetType.RESET;
-
Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, reset);
-
+
Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, false,
reset);
+
if( DMLScript.STATISTICS ){
long t1 = System.nanoTime();
Statistics.incrementFunRecompileTime(t1-t0);
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
index a329f9c..62f7e41 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
@@ -185,7 +185,7 @@ public class OptimizationWrapper
LocalVariableMap tmp = (LocalVariableMap)
ec.getVariables().clone();
ResetType reset =
ConfigurationManager.isCodegenEnabled() ?
ResetType.RESET_KNOWN_DIMS :
ResetType.RESET;
-
Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, reset);
+
Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, true,
reset);
//inter-procedural optimization (based on
previous recompilation)
if( pb.hasFunctions() ) {
@@ -201,7 +201,7 @@ public class OptimizationWrapper
//reset recompilation
flags according to recompileOnce because it is only safe if function is
recompileOnce
//because then
recompiled for every execution (otherwise potential issues if func also called
outside parfor)
ResetType reset2 =
fpb.isRecompileOnce() ? reset : ResetType.NO_RESET;
-
Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new
LocalVariableMap(), 0, reset2);
+
Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new
LocalVariableMap(), 0, true, reset2);
}
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 4fbe4e6..097b678 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -21,7 +21,6 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.concurrent.Future;
-import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
index 8d4ac8d..14e4de1 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
@@ -28,6 +28,8 @@ import
org.apache.sysds.runtime.transform.encode.EncoderRecode;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
import org.junit.Test;
@net.jcip.annotations.NotThreadSafe
@@ -110,7 +112,7 @@ public class FederatedLmPipeline extends AutomatedTestBase {
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
-
+
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-args", input("X1"),
input("X2"), input("X3"), input("X4"), input("Y"),
@@ -119,7 +121,7 @@ public class FederatedLmPipeline extends AutomatedTestBase {
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-nvargs", "in_X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+ programArgs = new String[] {"-stats", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
"in_X3=" + TestUtils.federatedAddress(port3,
input("X3")),
"in_X4=" + TestUtils.federatedAddress(port4,
input("X4")), "rows=" + rows, "cols=" + (cols + 1),
@@ -129,6 +131,10 @@ public class FederatedLmPipeline extends AutomatedTestBase
{
// compare via files
compareResults(1e-2);
TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+ // check correct federated operations
+
Assert.assertTrue(Statistics.getCPHeavyHitterCount("fed_mmchain")>10);
+
Assert.assertTrue(Statistics.getCPHeavyHitterCount("fed_ba+*")==3);
}
finally {
resetExecMode(oldExec);