This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 4facab16c1 [SYSTEMDS-3726,3725] Fix loop recompile-once and rewrites 4facab16c1 is described below commit 4facab16c1ad12583e47af1760da9a5c0c9e3b03 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Sun Aug 25 13:56:47 2024 +0200 [SYSTEMDS-3726,3725] Fix loop recompile-once and rewrites This patch fixes the new loop recompile-once feature. Instead of passing the list of loop body blocks, we pass the for/while blocks such that the size propagation can correctly understand the loop semantics and deoptimize mismatching sizes accordingly. Furthermore, this patch also restricts a slightly too aggressive rewrite for right indexing in empty matrices in order to ensure this rewrite doesn't hide index-out-of-bounds exceptions that otherwise would occur. --- src/main/java/org/apache/sysds/hops/BinaryOp.java | 17 ++++++++------- .../apache/sysds/hops/recompile/Recompiler.java | 16 ++++++-------- .../apache/sysds/hops/rewrite/ProgramRewriter.java | 4 ++++ .../RewriteAlgebraicSimplificationDynamic.java | 15 +++++++------ .../runtime/controlprogram/ForProgramBlock.java | 9 ++++++-- .../controlprogram/FunctionProgramBlock.java | 10 ++++++++- .../runtime/controlprogram/WhileProgramBlock.java | 9 ++++++-- src/main/java/org/apache/sysds/utils/Explain.java | 25 +++++++++++++++------- .../indexing/UnboundedScalarRightIndexingTest.java | 1 - .../test/functions/misc/SizePropagationTest.java | 2 -- .../functions/recompile/remove_empty_recompile.dml | 12 +++++------ 11 files changed, 74 insertions(+), 46 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 954f0919ab..a47d6238be 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -371,18 +371,19 @@ public class BinaryOp extends MultiThreadedHop { Lop append = null; if( dt1==DataType.MATRIX || dt1==DataType.FRAME ) { - 
long rlen = cbind ? getInput().get(0).getDim1() : (getInput().get(0).dimsKnown() && getInput().get(1).dimsKnown()) ? - getInput().get(0).getDim1()+getInput().get(1).getDim1() : -1; - long clen = cbind ? ((getInput().get(0).dimsKnown() && getInput().get(1).dimsKnown()) ? - getInput().get(0).getDim2()+getInput().get(1).getDim2() : -1) : getInput().get(0).getDim2(); - + long rlen = cbind ? getInput(0).getDim1() : (getInput(0).dimsKnown() && getInput(1).dimsKnown()) ? + getInput(0).getDim1()+getInput(1).getDim1() : -1; + long clen = cbind ? ((getInput(0).dimsKnown() && getInput().get(1).dimsKnown()) ? + getInput(0).getDim2()+getInput(1).getDim2() : -1) : getInput(0).getDim2(); + if(et == ExecType.SPARK) { - append = constructSPAppendLop(getInput().get(0), getInput().get(1), getDataType(), getValueType(), cbind, this); + append = constructSPAppendLop(getInput(0), getInput(1), getDataType(), getValueType(), cbind, this); append.getOutputParameters().setDimensions(rlen, clen, getBlocksize(), getNnz()); } else { //CP - Lop offset = createOffsetLop( getInput().get(0), cbind ); //offset 1st input - append = new Append(getInput().get(0).constructLops(), getInput().get(1).constructLops(), offset, getDataType(), getValueType(), cbind, et); + Lop offset = createOffsetLop( getInput(0), cbind ); //offset 1st input + append = new Append(getInput(0).constructLops(), getInput(1).constructLops(), + offset, getDataType(), getValueType(), cbind, et); append.getOutputParameters().setDimensions(rlen, clen, getBlocksize(), getNnz()); } } diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java index 7b79b495ae..a56c630c52 100644 --- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java +++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java @@ -37,7 +37,6 @@ import org.apache.hadoop.fs.Path; import org.apache.sysds.api.DMLScript; import org.apache.sysds.api.jmlc.JMLCUtils; import 
org.apache.sysds.common.Types.DataType; -import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.FileFormat; import org.apache.sysds.common.Types.OpOp1; @@ -459,7 +458,7 @@ public class Recompiler { System.out.println("EXPLAIN RECOMPILE \nPRED (line "+hops.getBeginLine()+"):\n" + Explain.explain(inst,1)); } - public static void recompileProgramBlockHierarchy( ArrayList<ProgramBlock> pbs, LocalVariableMap vars, long tid, boolean inplace, ResetType resetRecompile ) { + public static void recompileProgramBlockHierarchy( List<ProgramBlock> pbs, LocalVariableMap vars, long tid, boolean inplace, ResetType resetRecompile ) { //function recompilation via two-phase approach due to challenges //of unclear reconciliation of arbitrary complex control flow @@ -788,7 +787,7 @@ public class Recompiler { } //handle sparsity change if( mcOld.getNonZeros() != mc.getNonZeros() ) { - lnnz=-1; //unknown + lnnz=-1; //unknown requiresRecompile = true; } @@ -832,7 +831,7 @@ public class Recompiler { } //handle sparsity change if( dcOld.getNonZeros() != dc.getNonZeros() ) { - lnnz = -1; + lnnz = -1; requiresRecompile = true; } @@ -894,7 +893,7 @@ public class Recompiler { } //handle sparsity change if( mcOld.getNonZeros() != mc.getNonZeros() ) { - lnnz = -1; //unknown + lnnz = -1; //unknown } MatrixObject moNew = createOutputMatrix(ldim1, ldim2, lnnz); @@ -1554,7 +1553,7 @@ public class Recompiler { } public static void recompileFunctionOnceIfNeeded(boolean recompileOnce, - ArrayList<ProgramBlock> childBlocks, long tid, ExecutionContext ec) + List<ProgramBlock> childBlocks, long tid, boolean inplace, ResetType reset, ExecutionContext ec) { try { if( ConfigurationManager.isDynamicRecompilation() @@ -1568,10 +1567,7 @@ public class Recompiler { // function will be recompiled for every execution. 
// (2) without reset, there would be no benefit in recompiling the entire function LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone(); - boolean codegen = ConfigurationManager.isCodegenEnabled(); - boolean singlenode = DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE; - ResetType reset = (codegen || singlenode) ? ResetType.RESET_KNOWN_DIMS : ResetType.RESET; - Recompiler.recompileProgramBlockHierarchy(childBlocks, tmp, tid, false, reset); + Recompiler.recompileProgramBlockHierarchy(childBlocks, tmp, tid, inplace, reset); if( DMLScript.STATISTICS ){ long t1 = System.nanoTime(); diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index 1754b72b5e..fa0984b3e9 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -49,6 +49,10 @@ import org.apache.sysds.runtime.lineage.LineageCacheConfig; public class ProgramRewriter{ private static final boolean CHECK = false; + static { + //Logger.getLogger("org.apache.sysds.hops.rewrite").setLevel(Level.DEBUG); + } + private ArrayList<HopRewriteRule> _dagRuleSet = null; private ArrayList<StatementBlockRewriteRule> _sbRuleSet = null; diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index fea525703d..1a4c4ecebd 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -146,7 +146,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule rule_AlgebraicSimplification(hi, descendFirst); //see below //apply actual simplification rewrites (of childs incl checks) - hi = removeEmptyRightIndexing(hop, hi, i); //e.g., X[,1] -> 
matrix(0,ru-rl+1,cu-cl+1), if nnz(X)==0 + hi = removeEmptyRightIndexing(hop, hi, i); //e.g., X[,1] -> matrix(0,ru-rl+1,cu-cl+1), if nnz(X)==0 and known indices hi = removeUnnecessaryRightIndexing(hop, hi, i); //e.g., X[,1] -> X, if output == input size hi = removeEmptyLeftIndexing(hop, hi, i); //e.g., X[,1]=Y -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0 and nnz(Y)==0 hi = removeUnnecessaryLeftIndexing(hop, hi, i); //e.g., X[,1]=Y -> Y, if output == input dims @@ -214,10 +214,13 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule private static Hop removeEmptyRightIndexing(Hop parent, Hop hi, int pos) { if( hi instanceof IndexingOp && hi.getDataType()==DataType.MATRIX ) //indexing op - { - Hop input = hi.getInput().get(0); - if( input.getNnz()==0 && //nnz input known and empty - HopRewriteUtils.isDimsKnown(hi)) //output dims known + { + Hop input = hi.getInput(0); + if( input.getNnz()==0 //nnz input known and empty + && HopRewriteUtils.isDimsKnown(hi) //output dims known + //we also check for known indices to ensure correct error handling of out-of-bounds indexing + && hi.getInput(1) instanceof LiteralOp && hi.getInput(2) instanceof LiteralOp + && hi.getInput(3) instanceof LiteralOp && hi.getInput(4) instanceof LiteralOp) { //remove unnecessary right indexing Hop hnew = HopRewriteUtils.createDataGenOpByVal( new LiteralOp(hi.getDim1()), @@ -2498,7 +2501,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); hi = hnew; - LOG.debug("Applied simplifyEmptyBinaryOperation"); + LOG.debug("Applied simplifyEmptyBinaryOperation (line "+hi.getBeginLine()+")."); } } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ForProgramBlock.java index 67ff28fbbe..073aa448cc 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ForProgramBlock.java +++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ForProgramBlock.java @@ -20,12 +20,15 @@ package org.apache.sysds.runtime.controlprogram; import java.util.ArrayList; +import java.util.Arrays; import java.util.Iterator; import org.apache.sysds.api.DMLScript; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.recompile.Recompiler; +import org.apache.sysds.hops.recompile.Recompiler.ResetType; import org.apache.sysds.parser.ForStatementBlock; +import org.apache.sysds.parser.StatementBlock; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.DMLScriptException; @@ -120,9 +123,11 @@ public class ForProgramBlock extends ProgramBlock UpdateType[] flags = prepareUpdateInPlaceVariables(ec, _tid); //dynamically recompile entire loop body (according to loop inputs) - if( getStatementBlock() != null ) + //pass loop not just child blocks for correct size propagation + StatementBlock sb = getStatementBlock(); + if( sb != null ) Recompiler.recompileFunctionOnceIfNeeded( - getStatementBlock().isRecompileOnce(), _childBlocks, _tid, ec); + sb.isRecompileOnce(), Arrays.asList(this), _tid, true, ResetType.RESET_KNOWN_DIMS, ec); // compute and store the number of distinct paths if (DMLScript.LINEAGE_DEDUP) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java index 22c6d03128..61b466888f 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java @@ -24,9 +24,13 @@ import java.util.HashSet; import java.util.List; import java.util.stream.Collectors; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.common.Types.FunctionBlock; import org.apache.sysds.common.Types.ValueType; +import 
org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.recompile.Recompiler; +import org.apache.sysds.hops.recompile.Recompiler.ResetType; import org.apache.sysds.parser.DataIdentifier; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.DMLScriptException; @@ -100,7 +104,11 @@ public class FunctionProgramBlock extends ProgramBlock implements FunctionBlock public void execute(ExecutionContext ec) { //dynamically recompile entire function body (according to function inputs) - Recompiler.recompileFunctionOnceIfNeeded(isRecompileOnce(), _childBlocks, _tid, ec); + boolean codegen = ConfigurationManager.isCodegenEnabled(); + boolean singlenode = DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE; + ResetType reset = (codegen || singlenode) ? ResetType.RESET_KNOWN_DIMS : ResetType.RESET; + Recompiler.recompileFunctionOnceIfNeeded( + isRecompileOnce(), _childBlocks, _tid, false, reset, ec); // for each program block try { diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java index 38e5aa46be..7f54300cd2 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/WhileProgramBlock.java @@ -20,9 +20,12 @@ package org.apache.sysds.runtime.controlprogram; import java.util.ArrayList; +import java.util.Arrays; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.recompile.Recompiler; +import org.apache.sysds.hops.recompile.Recompiler.ResetType; +import org.apache.sysds.parser.StatementBlock; import org.apache.sysds.parser.WhileStatementBlock; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.ValueType; @@ -101,9 +104,11 @@ public class WhileProgramBlock extends ProgramBlock UpdateType[] flags = prepareUpdateInPlaceVariables(ec, _tid); //dynamically recompile entire loop body (according to 
loop inputs) - if( getStatementBlock() != null ) + //pass loop not just child blocks for correct size propagation + StatementBlock sb = getStatementBlock(); + if( sb != null ) Recompiler.recompileFunctionOnceIfNeeded( - getStatementBlock().isRecompileOnce(), _childBlocks, _tid, ec); + sb.isRecompileOnce(), Arrays.asList(this), _tid, true, ResetType.RESET_KNOWN_DIMS, ec); // compute and store the number of distinct paths if (DMLScript.LINEAGE_DEDUP) diff --git a/src/main/java/org/apache/sysds/utils/Explain.java b/src/main/java/org/apache/sysds/utils/Explain.java index a779a45022..f60947bed2 100644 --- a/src/main/java/org/apache/sysds/utils/Explain.java +++ b/src/main/java/org/apache/sysds/utils/Explain.java @@ -26,6 +26,7 @@ import java.util.HashSet; import java.util.Map; import java.util.List; import java.util.Map.Entry; +import java.util.Set; import java.util.Stack; import org.apache.commons.lang3.mutable.MutableInt; @@ -297,16 +298,16 @@ public class Explain return sb.toString(); } - + public static String explain( ProgramBlock pb ) { return explainProgramBlock(pb, 0); } - public static String explain( ArrayList<Instruction> inst ) { + public static String explain( List<Instruction> inst ) { return explainInstructions(inst, 0); } - public static String explain( ArrayList<Instruction> inst, int level ) { + public static String explain( List<Instruction> inst, int level ) { return explainInstructions(inst, level); } @@ -318,11 +319,11 @@ public class Explain return explainStatementBlock(sb, 0); } - public static String explainHops( ArrayList<Hop> hops ) { + public static String explainHops( List<Hop> hops ) { return explainHops(hops, 0); } - public static String explainHops( ArrayList<Hop> hops, int level ) { + public static String explainHops( List<Hop> hops, int level ) { StringBuilder sb = new StringBuilder(); Hop.resetVisitStatus(hops); for( Hop hop : hops ) @@ -720,6 +721,14 @@ public class Explain ////////////// // internal explain RUNTIME + + public static 
String explainProgramBlocks( List<ProgramBlock> pbs ) { + StringBuilder sb = new StringBuilder(); + for(ProgramBlock pb : pbs) + sb.append(explain(pb)); + return sb.toString(); + } + private static String explainProgramBlock( ProgramBlock pb, int level ) { StringBuilder sb = new StringBuilder(); @@ -797,7 +806,7 @@ public class Explain return sb.toString(); } - private static String explainInstructions( ArrayList<Instruction> instSet, int level ) { + private static String explainInstructions( List<Instruction> instSet, int level ) { StringBuilder sb = new StringBuilder(); String offsetInst = createOffset(level); for( Instruction inst : instSet ) { @@ -921,7 +930,7 @@ public class Explain * if true, count Spark instructions and Spark reblock * instructions */ - private static void countCompiledInstructions( ArrayList<Instruction> instSet, ExplainCounts counts, boolean CP, boolean SP ) + private static void countCompiledInstructions( List<Instruction> instSet, ExplainCounts counts, boolean CP, boolean SP ) { for( Instruction inst : instSet ) { @@ -938,7 +947,7 @@ public class Explain } } - public static String explainFunctionCallGraph(FunctionCallGraph fgraph, HashSet<String> fstack, String fkey, int level) + public static String explainFunctionCallGraph(FunctionCallGraph fgraph, Set<String> fstack, String fkey, int level) { StringBuilder builder = new StringBuilder(); String offset = createOffset(level); diff --git a/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java b/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java index d32e7865ec..9c855e22e9 100644 --- a/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java +++ b/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java @@ -49,7 +49,6 @@ public class UnboundedScalarRightIndexingTest extends AutomatedTestBase runRightIndexingTest(ExecType.SPARK, 7); } - 
@Test public void testRightIndexingCPZero() { runRightIndexingTest(ExecType.CP, 0); diff --git a/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java b/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java index bfa95e9efe..4b4a76aa19 100644 --- a/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/misc/SizePropagationTest.java @@ -28,7 +28,6 @@ import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; import org.junit.Assert; -import org.junit.Ignore; import java.util.HashMap; @@ -81,7 +80,6 @@ public class SizePropagationTest extends AutomatedTestBase } @Test - @Ignore //FIXME deeper issue of incorrect size propagation during recompile? public void testSizePropagationLoopIx2Rewrites() { testSizePropagation( TEST_NAME3, true, N-2 ); } diff --git a/src/test/scripts/functions/recompile/remove_empty_recompile.dml b/src/test/scripts/functions/recompile/remove_empty_recompile.dml index a4ee8be7cb..ddb1bd8e4a 100644 --- a/src/test/scripts/functions/recompile/remove_empty_recompile.dml +++ b/src/test/scripts/functions/recompile/remove_empty_recompile.dml @@ -20,8 +20,8 @@ #------------------------------------------------------------- -execFun = function(Matrix[Double] X, Integer type) - return (Matrix[Double] R) +execFun = function(Matrix[Double] X, Integer type) + return (Matrix[Double] R) { R = X; @@ -32,7 +32,7 @@ execFun = function(Matrix[Double] X, Integer type) R = round(X); } if( type==2 ){ - R = t(X); + R = t(X); } if( type==3 ){ R = X*(X-1); @@ -49,7 +49,7 @@ execFun = function(Matrix[Double] X, Integer type) if( type==7 ){ R = X-(X+2); } - if( type==8 ){ + if( type==8 ){ R = (X+2)-X; } if( type==9 ){ @@ -59,7 +59,7 @@ execFun = function(Matrix[Double] X, Integer type) R = (X-1)%*%X; } if( type==11 ){ - R = X[1:(nrow(X)-1), 1:(ncol(X)-1)]; + R = X[1:19, 1:19]; } if( 
type==12 ){ X[1,] = X[2,]; @@ -69,4 +69,4 @@ execFun = function(Matrix[Double] X, Integer type) X = read($1); R = execFun( X, $2 ) -write(R, $3); \ No newline at end of file +write(R, $3);