This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 45de519829 [MINOR] Fix index bounds checks on recompilation literal replacement 45de519829 is described below commit 45de519829860f515f6f40871e9f4a4ce92d5cec Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Fri Apr 19 10:46:41 2024 +0200 [MINOR] Fix index bounds checks on recompilation literal replacement --- .../org/apache/sysds/hops/recompile/LiteralReplacement.java | 4 ++++ .../instructions/spark/MatrixIndexingSPInstruction.java | 10 +++++++--- .../functions/indexing/UnboundedScalarRightIndexingTest.java | 8 ++++---- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java b/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java index 1a40ada94e..441a80ceb6 100644 --- a/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java +++ b/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java @@ -253,6 +253,10 @@ public class LiteralReplacement if( mo.getNumRows()*mo.getNumColumns() < REPLACE_LITERALS_MAX_MATRIX_SIZE ) { MatrixBlock mBlock = mo.acquireRead(); + if( rlval>mo.getNumRows() || clval>mo.getNumColumns() ) { + throw new DMLRuntimeException("Scalar indexing out-of-bounds:" + + " ["+rlval+", "+clval+"] in "+mo.getDataCharacteristics()); + } double value = mBlock.get((int)rlval-1,(int)clval-1); mo.release(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java index 3c8583d34c..e97336a8a6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java @@ -90,11 +90,15 @@ public class MatrixIndexingSPInstruction extends IndexingSPInstruction { long cu = ec.getScalarInput(colUpper).getLongValue(); IndexRange ixrange = new IndexRange(rl, ru, cl, cu); + //check bounds + DataCharacteristics mcIn = sec.getDataCharacteristics(input1.getName()); + if( mcIn.dimsKnown() && (ru>mcIn.getRows() || cu>mcIn.getCols()) ) + throw new DMLRuntimeException("Index range out of bounds: "+ixrange+" "+mcIn); + //right indexing if( opcode.equalsIgnoreCase(RightIndex.OPCODE) ) { //update and check output dimensions - DataCharacteristics mcIn = sec.getDataCharacteristics(input1.getName()); DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName()); mcOut.set(ru-rl+1, cu-cl+1, mcIn.getBlocksize(), mcIn.getBlocksize()); mcOut.setNonZerosBound(Math.min(mcOut.getLength(), mcIn.getNonZerosBound())); @@ -114,7 +118,7 @@ public class MatrixIndexingSPInstruction extends IndexingSPInstruction { //put output RDD handle into symbol table sec.setRDDHandleForVariable(output.getName(), out); - sec.addLineageRDD(output.getName(), input1.getName()); + sec.addLineageRDD(output.getName(), input1.getName()); } } //left indexing @@ -129,7 +133,7 @@ public class MatrixIndexingSPInstruction extends IndexingSPInstruction { //update and check output dimensions DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName()); - DataCharacteristics mcLeft = ec.getDataCharacteristics(input1.getName()); + DataCharacteristics mcLeft = mcIn; mcOut.set(mcLeft.getRows(), mcLeft.getCols(), mcLeft.getBlocksize(), mcLeft.getBlocksize()); checkValidOutputDimensions(mcOut); diff --git a/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java b/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java index 8fc9bf6d9e..d32e7865ec 100644 --- a/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java +++ b/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java @@ -73,10 +73,10 @@ public class UnboundedScalarRightIndexingTest extends AutomatedTestBase DMLScript.USE_LOCAL_SPARK_CONFIG = true; try { - TestConfiguration config = getTestConfiguration(TEST_NAME); - loadTestConfiguration(config); - - String RI_HOME = SCRIPT_DIR + TEST_DIR; + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String RI_HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = RI_HOME + TEST_NAME + ".dml"; programArgs = new String[]{ "-args", String.valueOf(val) }; fullRScriptName = RI_HOME + TEST_NAME + ".R";