This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 45de519829 [MINOR] Fix index bounds checks on recompilation literal replacement
45de519829 is described below

commit 45de519829860f515f6f40871e9f4a4ce92d5cec
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Fri Apr 19 10:46:41 2024 +0200

    [MINOR] Fix index bounds checks on recompilation literal replacement
---
 .../org/apache/sysds/hops/recompile/LiteralReplacement.java    |  4 ++++
 .../instructions/spark/MatrixIndexingSPInstruction.java        | 10 +++++++---
 .../functions/indexing/UnboundedScalarRightIndexingTest.java   |  8 ++++----
 3 files changed, 15 insertions(+), 7 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java 
b/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java
index 1a40ada94e..441a80ceb6 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java
@@ -253,6 +253,10 @@ public class LiteralReplacement
                                if( mo.getNumRows()*mo.getNumColumns() < 
REPLACE_LITERALS_MAX_MATRIX_SIZE )
                                {
                                        MatrixBlock mBlock = mo.acquireRead();
+                                       if( rlval>mo.getNumRows() || 
clval>mo.getNumColumns() ) {
+                                               throw new 
DMLRuntimeException("Scalar indexing out-of-bounds:"
+                                                       + " ["+rlval+", 
"+clval+"] in "+mo.getDataCharacteristics());
+                                       }
                                        double value = 
mBlock.get((int)rlval-1,(int)clval-1);
                                        mo.release();
                                        
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
index 3c8583d34c..e97336a8a6 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
@@ -90,11 +90,15 @@ public class MatrixIndexingSPInstruction extends 
IndexingSPInstruction {
                long cu = ec.getScalarInput(colUpper).getLongValue();
                IndexRange ixrange = new IndexRange(rl, ru, cl, cu);
                
+               //check bounds
+               DataCharacteristics mcIn = 
sec.getDataCharacteristics(input1.getName());
+               if( mcIn.dimsKnown() && (ru>mcIn.getRows() || 
cu>mcIn.getCols()) )
+                       throw new DMLRuntimeException("Index range out of 
bounds: "+ixrange+" "+mcIn);
+               
                //right indexing
                if( opcode.equalsIgnoreCase(RightIndex.OPCODE) )
                {
                        //update and check output dimensions
-                       DataCharacteristics mcIn = 
sec.getDataCharacteristics(input1.getName());
                        DataCharacteristics mcOut = 
sec.getDataCharacteristics(output.getName());
                        mcOut.set(ru-rl+1, cu-cl+1, mcIn.getBlocksize(), 
mcIn.getBlocksize());
                        mcOut.setNonZerosBound(Math.min(mcOut.getLength(), 
mcIn.getNonZerosBound()));
@@ -114,7 +118,7 @@ public class MatrixIndexingSPInstruction extends 
IndexingSPInstruction {
                                
                                //put output RDD handle into symbol table
                                sec.setRDDHandleForVariable(output.getName(), 
out);
-                               sec.addLineageRDD(output.getName(), 
input1.getName());  
+                               sec.addLineageRDD(output.getName(), 
input1.getName());
                        }
                }
                //left indexing
@@ -129,7 +133,7 @@ public class MatrixIndexingSPInstruction extends 
IndexingSPInstruction {
                        
                        //update and check output dimensions
                        DataCharacteristics mcOut = 
sec.getDataCharacteristics(output.getName());
-                       DataCharacteristics mcLeft = 
ec.getDataCharacteristics(input1.getName());
+                       DataCharacteristics mcLeft = mcIn;
                        mcOut.set(mcLeft.getRows(), mcLeft.getCols(), 
mcLeft.getBlocksize(), mcLeft.getBlocksize());
                        checkValidOutputDimensions(mcOut);
                        
diff --git 
a/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java
 
b/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java
index 8fc9bf6d9e..d32e7865ec 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java
@@ -73,10 +73,10 @@ public class UnboundedScalarRightIndexingTest extends 
AutomatedTestBase
                        DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 
                try {
-                   TestConfiguration config = getTestConfiguration(TEST_NAME);
-                   loadTestConfiguration(config);
-               
-               String RI_HOME = SCRIPT_DIR + TEST_DIR;
+                       TestConfiguration config = 
getTestConfiguration(TEST_NAME);
+                       loadTestConfiguration(config);
+
+                       String RI_HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
                        programArgs = new String[]{ "-args", 
String.valueOf(val) };
                        fullRScriptName = RI_HOME + TEST_NAME + ".R";

Reply via email to