This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 4facab16c1 [SYSTEMDS-3726,3725] Fix loop recompile-once and rewrites
4facab16c1 is described below
commit 4facab16c1ad12583e47af1760da9a5c0c9e3b03
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Aug 25 13:56:47 2024 +0200
[SYSTEMDS-3726,3725] Fix loop recompile-once and rewrites
This patch fixes the new loop recompile-once feature. Instead of passing
the list of loop body blocks, we now pass the for/while blocks themselves
so that size propagation can correctly understand the loop semantics and
deoptimize mismatching sizes accordingly. Furthermore, this patch also
fixes a slightly too aggressive rewrite for right indexing into empty
matrices: the rewrite is now only applied when all indices are literals,
which ensures it does not hide index-out-of-bounds exceptions that would
otherwise occur.
---
src/main/java/org/apache/sysds/hops/BinaryOp.java | 17 ++++++++-------
.../apache/sysds/hops/recompile/Recompiler.java | 16 ++++++--------
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 4 ++++
.../RewriteAlgebraicSimplificationDynamic.java | 15 +++++++------
.../runtime/controlprogram/ForProgramBlock.java | 9 ++++++--
.../controlprogram/FunctionProgramBlock.java | 10 ++++++++-
.../runtime/controlprogram/WhileProgramBlock.java | 9 ++++++--
src/main/java/org/apache/sysds/utils/Explain.java | 25 +++++++++++++++-------
.../indexing/UnboundedScalarRightIndexingTest.java | 1 -
.../test/functions/misc/SizePropagationTest.java | 2 --
.../functions/recompile/remove_empty_recompile.dml | 12 +++++------
11 files changed, 74 insertions(+), 46 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 954f0919ab..a47d6238be 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -371,18 +371,19 @@ public class BinaryOp extends MultiThreadedHop {
Lop append = null;
if( dt1==DataType.MATRIX || dt1==DataType.FRAME )
{
- long rlen = cbind ? getInput().get(0).getDim1() :
(getInput().get(0).dimsKnown() && getInput().get(1).dimsKnown()) ?
-
getInput().get(0).getDim1()+getInput().get(1).getDim1() : -1;
- long clen = cbind ? ((getInput().get(0).dimsKnown() &&
getInput().get(1).dimsKnown()) ?
-
getInput().get(0).getDim2()+getInput().get(1).getDim2() : -1) :
getInput().get(0).getDim2();
-
+ long rlen = cbind ? getInput(0).getDim1() :
(getInput(0).dimsKnown() && getInput(1).dimsKnown()) ?
+ getInput(0).getDim1()+getInput(1).getDim1() :
-1;
+ long clen = cbind ? ((getInput(0).dimsKnown() &&
getInput().get(1).dimsKnown()) ?
+ getInput(0).getDim2()+getInput(1).getDim2() :
-1) : getInput(0).getDim2();
+
if(et == ExecType.SPARK) {
- append =
constructSPAppendLop(getInput().get(0), getInput().get(1), getDataType(),
getValueType(), cbind, this);
+ append = constructSPAppendLop(getInput(0),
getInput(1), getDataType(), getValueType(), cbind, this);
append.getOutputParameters().setDimensions(rlen, clen, getBlocksize(),
getNnz());
}
else { //CP
- Lop offset = createOffsetLop(
getInput().get(0), cbind ); //offset 1st input
- append = new
Append(getInput().get(0).constructLops(), getInput().get(1).constructLops(),
offset, getDataType(), getValueType(), cbind, et);
+ Lop offset = createOffsetLop( getInput(0),
cbind ); //offset 1st input
+ append = new
Append(getInput(0).constructLops(), getInput(1).constructLops(),
+ offset, getDataType(), getValueType(),
cbind, et);
append.getOutputParameters().setDimensions(rlen, clen, getBlocksize(),
getNnz());
}
}
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index 7b79b495ae..a56c630c52 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -37,7 +37,6 @@ import org.apache.hadoop.fs.Path;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.api.jmlc.JMLCUtils;
import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.OpOp1;
@@ -459,7 +458,7 @@ public class Recompiler {
System.out.println("EXPLAIN RECOMPILE \nPRED (line
"+hops.getBeginLine()+"):\n" + Explain.explain(inst,1));
}
- public static void recompileProgramBlockHierarchy(
ArrayList<ProgramBlock> pbs, LocalVariableMap vars, long tid, boolean inplace,
ResetType resetRecompile ) {
+ public static void recompileProgramBlockHierarchy( List<ProgramBlock>
pbs, LocalVariableMap vars, long tid, boolean inplace, ResetType resetRecompile
) {
//function recompilation via two-phase approach due to
challenges
//of unclear reconciliation of arbitrary complex control flow
@@ -788,7 +787,7 @@ public class Recompiler {
}
//handle sparsity change
if( mcOld.getNonZeros() !=
mc.getNonZeros() ) {
- lnnz=-1; //unknown
+ lnnz=-1; //unknown
requiresRecompile = true;
}
@@ -832,7 +831,7 @@ public class Recompiler {
}
//handle sparsity change
if( dcOld.getNonZeros() !=
dc.getNonZeros() ) {
- lnnz = -1;
+ lnnz = -1;
requiresRecompile = true;
}
@@ -894,7 +893,7 @@ public class Recompiler {
}
//handle sparsity change
if( mcOld.getNonZeros() !=
mc.getNonZeros() ) {
- lnnz = -1; //unknown
+ lnnz = -1; //unknown
}
MatrixObject moNew =
createOutputMatrix(ldim1, ldim2, lnnz);
@@ -1554,7 +1553,7 @@ public class Recompiler {
}
public static void recompileFunctionOnceIfNeeded(boolean recompileOnce,
- ArrayList<ProgramBlock> childBlocks, long tid, ExecutionContext
ec)
+ List<ProgramBlock> childBlocks, long tid, boolean inplace,
ResetType reset, ExecutionContext ec)
{
try {
if( ConfigurationManager.isDynamicRecompilation()
@@ -1568,10 +1567,7 @@ public class Recompiler {
// function will be recompiled for every
execution.
// (2) without reset, there would be no benefit
in recompiling the entire function
LocalVariableMap tmp = (LocalVariableMap)
ec.getVariables().clone();
- boolean codegen =
ConfigurationManager.isCodegenEnabled();
- boolean singlenode =
DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE;
- ResetType reset = (codegen || singlenode) ?
ResetType.RESET_KNOWN_DIMS : ResetType.RESET;
-
Recompiler.recompileProgramBlockHierarchy(childBlocks, tmp, tid, false, reset);
+
Recompiler.recompileProgramBlockHierarchy(childBlocks, tmp, tid, inplace,
reset);
if( DMLScript.STATISTICS ){
long t1 = System.nanoTime();
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 1754b72b5e..fa0984b3e9 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -49,6 +49,10 @@ import org.apache.sysds.runtime.lineage.LineageCacheConfig;
public class ProgramRewriter{
private static final boolean CHECK = false;
+ static {
+
//Logger.getLogger("org.apache.sysds.hops.rewrite").setLevel(Level.DEBUG);
+ }
+
private ArrayList<HopRewriteRule> _dagRuleSet = null;
private ArrayList<StatementBlockRewriteRule> _sbRuleSet = null;
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index fea525703d..1a4c4ecebd 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -146,7 +146,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
rule_AlgebraicSimplification(hi, descendFirst);
//see below
//apply actual simplification rewrites (of childs incl
checks)
- hi = removeEmptyRightIndexing(hop, hi, i);
//e.g., X[,1] -> matrix(0,ru-rl+1,cu-cl+1), if nnz(X)==0
+ hi = removeEmptyRightIndexing(hop, hi, i);
//e.g., X[,1] -> matrix(0,ru-rl+1,cu-cl+1), if nnz(X)==0 and known indices
hi = removeUnnecessaryRightIndexing(hop, hi, i);
//e.g., X[,1] -> X, if output == input size
hi = removeEmptyLeftIndexing(hop, hi, i);
//e.g., X[,1]=Y -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0 and nnz(Y)==0
hi = removeUnnecessaryLeftIndexing(hop, hi, i);
//e.g., X[,1]=Y -> Y, if output == input dims
@@ -214,10 +214,13 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
private static Hop removeEmptyRightIndexing(Hop parent, Hop hi, int
pos)
{
if( hi instanceof IndexingOp &&
hi.getDataType()==DataType.MATRIX ) //indexing op
- {
- Hop input = hi.getInput().get(0);
- if( input.getNnz()==0 && //nnz input known and empty
- HopRewriteUtils.isDimsKnown(hi)) //output dims known
+ {
+ Hop input = hi.getInput(0);
+ if( input.getNnz()==0 //nnz input known and empty
+ && HopRewriteUtils.isDimsKnown(hi) //output
dims known
+ //we also check for known indices to ensure
correct error handling of out-of-bounds indexing
+ && hi.getInput(1) instanceof LiteralOp &&
hi.getInput(2) instanceof LiteralOp
+ && hi.getInput(3) instanceof LiteralOp &&
hi.getInput(4) instanceof LiteralOp)
{
//remove unnecessary right indexing
Hop hnew =
HopRewriteUtils.createDataGenOpByVal( new LiteralOp(hi.getDim1()),
@@ -2498,7 +2501,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
hi = hnew;
- LOG.debug("Applied
simplifyEmptyBinaryOperation");
+ LOG.debug("Applied
simplifyEmptyBinaryOperation (line "+hi.getBeginLine()+").");
}
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/ForProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/ForProgramBlock.java
index 67ff28fbbe..073aa448cc 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ForProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ForProgramBlock.java
@@ -20,12 +20,15 @@
package org.apache.sysds.runtime.controlprogram;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Iterator;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.hops.recompile.Recompiler.ResetType;
import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
@@ -120,9 +123,11 @@ public class ForProgramBlock extends ProgramBlock
UpdateType[] flags = prepareUpdateInPlaceVariables(ec,
_tid);
//dynamically recompile entire loop body (according to
loop inputs)
- if( getStatementBlock() != null )
+ //pass loop not just child blocks for correct size
propagation
+ StatementBlock sb = getStatementBlock();
+ if( sb != null )
Recompiler.recompileFunctionOnceIfNeeded(
- getStatementBlock().isRecompileOnce(),
_childBlocks, _tid, ec);
+ sb.isRecompileOnce(),
Arrays.asList(this), _tid, true, ResetType.RESET_KNOWN_DIMS, ec);
// compute and store the number of distinct paths
if (DMLScript.LINEAGE_DEDUP)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
index 22c6d03128..61b466888f 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
@@ -24,9 +24,13 @@ import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.FunctionBlock;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.hops.recompile.Recompiler.ResetType;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
@@ -100,7 +104,11 @@ public class FunctionProgramBlock extends ProgramBlock
implements FunctionBlock
public void execute(ExecutionContext ec)
{
//dynamically recompile entire function body (according to
function inputs)
- Recompiler.recompileFunctionOnceIfNeeded(isRecompileOnce(),
_childBlocks, _tid, ec);
+ boolean codegen = ConfigurationManager.isCodegenEnabled();
+ boolean singlenode = DMLScript.getGlobalExecMode() ==
ExecMode.SINGLE_NODE;
+ ResetType reset = (codegen || singlenode) ?
ResetType.RESET_KNOWN_DIMS : ResetType.RESET;
+ Recompiler.recompileFunctionOnceIfNeeded(
+ isRecompileOnce(), _childBlocks, _tid, false, reset,
ec);
// for each program block
try {
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
index 38e5aa46be..7f54300cd2 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java
@@ -20,9 +20,12 @@
package org.apache.sysds.runtime.controlprogram;
import java.util.ArrayList;
+import java.util.Arrays;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.hops.recompile.Recompiler.ResetType;
+import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ValueType;
@@ -101,9 +104,11 @@ public class WhileProgramBlock extends ProgramBlock
UpdateType[] flags = prepareUpdateInPlaceVariables(ec,
_tid);
//dynamically recompile entire loop body (according to
loop inputs)
- if( getStatementBlock() != null )
+ //pass loop not just child blocks for correct size
propagation
+ StatementBlock sb = getStatementBlock();
+ if( sb != null )
Recompiler.recompileFunctionOnceIfNeeded(
- getStatementBlock().isRecompileOnce(),
_childBlocks, _tid, ec);
+ sb.isRecompileOnce(),
Arrays.asList(this), _tid, true, ResetType.RESET_KNOWN_DIMS, ec);
// compute and store the number of distinct paths
if (DMLScript.LINEAGE_DEDUP)
diff --git a/src/main/java/org/apache/sysds/utils/Explain.java
b/src/main/java/org/apache/sysds/utils/Explain.java
index a779a45022..f60947bed2 100644
--- a/src/main/java/org/apache/sysds/utils/Explain.java
+++ b/src/main/java/org/apache/sysds/utils/Explain.java
@@ -26,6 +26,7 @@ import java.util.HashSet;
import java.util.Map;
import java.util.List;
import java.util.Map.Entry;
+import java.util.Set;
import java.util.Stack;
import org.apache.commons.lang3.mutable.MutableInt;
@@ -297,16 +298,16 @@ public class Explain
return sb.toString();
}
-
+
public static String explain( ProgramBlock pb ) {
return explainProgramBlock(pb, 0);
}
- public static String explain( ArrayList<Instruction> inst ) {
+ public static String explain( List<Instruction> inst ) {
return explainInstructions(inst, 0);
}
- public static String explain( ArrayList<Instruction> inst, int level ) {
+ public static String explain( List<Instruction> inst, int level ) {
return explainInstructions(inst, level);
}
@@ -318,11 +319,11 @@ public class Explain
return explainStatementBlock(sb, 0);
}
- public static String explainHops( ArrayList<Hop> hops ) {
+ public static String explainHops( List<Hop> hops ) {
return explainHops(hops, 0);
}
- public static String explainHops( ArrayList<Hop> hops, int level ) {
+ public static String explainHops( List<Hop> hops, int level ) {
StringBuilder sb = new StringBuilder();
Hop.resetVisitStatus(hops);
for( Hop hop : hops )
@@ -720,6 +721,14 @@ public class Explain
//////////////
// internal explain RUNTIME
+
+ public static String explainProgramBlocks( List<ProgramBlock> pbs ) {
+ StringBuilder sb = new StringBuilder();
+ for(ProgramBlock pb : pbs)
+ sb.append(explain(pb));
+ return sb.toString();
+ }
+
private static String explainProgramBlock( ProgramBlock pb, int level )
{
StringBuilder sb = new StringBuilder();
@@ -797,7 +806,7 @@ public class Explain
return sb.toString();
}
- private static String explainInstructions( ArrayList<Instruction>
instSet, int level ) {
+ private static String explainInstructions( List<Instruction> instSet,
int level ) {
StringBuilder sb = new StringBuilder();
String offsetInst = createOffset(level);
for( Instruction inst : instSet ) {
@@ -921,7 +930,7 @@ public class Explain
* if true, count Spark instructions and Spark reblock
* instructions
*/
- private static void countCompiledInstructions( ArrayList<Instruction>
instSet, ExplainCounts counts, boolean CP, boolean SP )
+ private static void countCompiledInstructions( List<Instruction>
instSet, ExplainCounts counts, boolean CP, boolean SP )
{
for( Instruction inst : instSet )
{
@@ -938,7 +947,7 @@ public class Explain
}
}
- public static String explainFunctionCallGraph(FunctionCallGraph fgraph,
HashSet<String> fstack, String fkey, int level)
+ public static String explainFunctionCallGraph(FunctionCallGraph fgraph,
Set<String> fstack, String fkey, int level)
{
StringBuilder builder = new StringBuilder();
String offset = createOffset(level);
diff --git
a/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java
b/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java
index d32e7865ec..9c855e22e9 100644
---
a/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java
@@ -49,7 +49,6 @@ public class UnboundedScalarRightIndexingTest extends
AutomatedTestBase
runRightIndexingTest(ExecType.SPARK, 7);
}
-
@Test
public void testRightIndexingCPZero() {
runRightIndexingTest(ExecType.CP, 0);
diff --git
a/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java
b/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java
index bfa95e9efe..4b4a76aa19 100644
---
a/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java
@@ -28,7 +28,6 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
-import org.junit.Ignore;
import java.util.HashMap;
@@ -81,7 +80,6 @@ public class SizePropagationTest extends AutomatedTestBase
}
@Test
- @Ignore //FIXME deeper issue of incorrect size propagation during
recompile?
public void testSizePropagationLoopIx2Rewrites() {
testSizePropagation( TEST_NAME3, true, N-2 );
}
diff --git a/src/test/scripts/functions/recompile/remove_empty_recompile.dml
b/src/test/scripts/functions/recompile/remove_empty_recompile.dml
index a4ee8be7cb..ddb1bd8e4a 100644
--- a/src/test/scripts/functions/recompile/remove_empty_recompile.dml
+++ b/src/test/scripts/functions/recompile/remove_empty_recompile.dml
@@ -20,8 +20,8 @@
#-------------------------------------------------------------
-execFun = function(Matrix[Double] X, Integer type)
- return (Matrix[Double] R)
+execFun = function(Matrix[Double] X, Integer type)
+ return (Matrix[Double] R)
{
R = X;
@@ -32,7 +32,7 @@ execFun = function(Matrix[Double] X, Integer type)
R = round(X);
}
if( type==2 ){
- R = t(X);
+ R = t(X);
}
if( type==3 ){
R = X*(X-1);
@@ -49,7 +49,7 @@ execFun = function(Matrix[Double] X, Integer type)
if( type==7 ){
R = X-(X+2);
}
- if( type==8 ){
+ if( type==8 ){
R = (X+2)-X;
}
if( type==9 ){
@@ -59,7 +59,7 @@ execFun = function(Matrix[Double] X, Integer type)
R = (X-1)%*%X;
}
if( type==11 ){
- R = X[1:(nrow(X)-1), 1:(ncol(X)-1)];
+ R = X[1:19, 1:19];
}
if( type==12 ){
X[1,] = X[2,];
@@ -69,4 +69,4 @@ execFun = function(Matrix[Double] X, Integer type)
X = read($1);
R = execFun( X, $2 )
-write(R, $3);
\ No newline at end of file
+write(R, $3);