This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new af2c896289 [SYSTEMDS-1151] Fix error handling left indexing w/ scalars
af2c896289 is described below

commit af2c896289365ea2d7487eab500bf0c43444bb14
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Mar 22 12:59:08 2024 +0100

    [SYSTEMDS-1151] Fix error handling left indexing w/ scalars
    
    This patch makes the error messages for left indexing
    (e.g., A[a:b,c:d] = x) consistent for matrix and scalar right-hand sides.
---
 .../sysds/runtime/matrix/data/MatrixBlock.java     | 43 +++++++++++++++-------
 1 file changed, 29 insertions(+), 14 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 86ee70bc18..701abd1c20 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -4166,19 +4166,10 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
        }
 
        public MatrixBlock leftIndexingOperations(MatrixBlock rhsMatrix,
-                       int rl, int ru, int cl, int cu, MatrixBlock ret, 
UpdateType update) {
+                       int rl, int ru, int cl, int cu, MatrixBlock ret, 
UpdateType update)
+       {
                // Check the validity of bounds
-               if( rl < 0 || rl >= getNumRows() || ru < rl || ru >= 
getNumRows()
-                       || cl < 0 || cl >= getNumColumns() || cu < cl || cu >= 
getNumColumns() ) {
-                       throw new DMLRuntimeException("Invalid values for 
matrix indexing: ["+(rl+1)+":"+(ru+1)+"," 
-                               + (cl+1)+":"+(cu+1)+"] " + "must be within 
matrix dimensions ["+getNumRows()+","+getNumColumns()+"].");
-               }
-               if( (ru-rl+1) != rhsMatrix.getNumRows() || (cu-cl+1) != 
rhsMatrix.getNumColumns() ) {
-                       throw new DMLRuntimeException("Invalid values for 
matrix indexing: " +
-                               "dimensions of the source matrix 
["+rhsMatrix.getNumRows()+"x" + rhsMatrix.getNumColumns() + "] " +
-                               "do not match the shape of the matrix specified 
by indices [" +
-                               (rl+1) +":" + (ru+1) + ", " + (cl+1) + ":" + 
(cu+1) + "] (i.e., ["+(ru-rl+1)+"x"+(cu-cl+1)+"]).");
-               }
+               checkDimsForLeftIndexing(rl, ru, cl, cu, true, rhsMatrix.rlen, 
rhsMatrix.clen);
                
                MatrixBlock result = ret;
                boolean sp = estimateSparsityOnLeftIndexing(rlen, clen, 
nonZeros,
@@ -4260,9 +4251,12 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
         * @param update ?
         * @return matrix block
         */
-       public MatrixBlock leftIndexingOperations(ScalarObject scalar, int rl, 
int cl, MatrixBlock ret, UpdateType update) {
+       public MatrixBlock leftIndexingOperations(ScalarObject scalar, int rl, 
int cl,
+               MatrixBlock ret, UpdateType update)
+       {
                double inVal = scalar.getDoubleValue();
                boolean sp = estimateSparsityOnLeftIndexing(rlen, clen, 
nonZeros, 1, 1, (inVal!=0)?1:0);
+               checkDimsForLeftIndexing(rl, rl, cl, cl, false, -1, -1);
                
                if( !update.isInPlace() ) { //general case
                        if(ret==null)
@@ -4283,7 +4277,28 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                ret.quickSetValue(rl, cl, inVal);
                return ret;
        }
-
+       
+       private void checkDimsForLeftIndexing(int rl, int ru, int cl, int cu,
+               boolean checkSrc, int rhsr, int rhsc)
+       {
+               int rlen = getNumRows(), clen = getNumColumns();
+               if( rl < 0 || rl >= rlen || ru < rl || ru >= rlen
+                       || cl < 0 || cl >= clen || cu < cl || cu >= clen ) {
+                       throw new DMLRuntimeException("Invalid values for 
matrix indexing: "
+                               + "["+(rl+1)+":"+(ru+1)+"," + 
(cl+1)+":"+(cu+1)+"] " 
+                               + "must be within matrix dimensions 
["+rlen+"x"+clen+"].");
+               }
+               if( checkSrc ) {
+                       if( (ru-rl+1) != rhsr || (cu-cl+1) != rhsc ) {
+                               throw new DMLRuntimeException("Invalid values 
for matrix indexing: "
+                                       + "dimensions of the source matrix 
["+rhsr+"x"+rhsc+"] "
+                                       + "do not match the shape of the matrix 
specified by indices "
+                                       + "["+(rl+1)+":"+(ru+1)+", 
"+(cl+1)+":"+(cu+1)+"] "
+                                       + "(i.e., 
["+(ru-rl+1)+"x"+(cu-cl+1)+"]).");
+                       }
+               }
+       }
+       
        @Override
        public final MatrixBlock slice(IndexRange ixrange, MatrixBlock ret) {
                return slice(

Reply via email to