This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f7af63f3ef [SYSTEMDS-3818] Fix parsing of indexing operations (scalar datatype)
f7af63f3ef is described below

commit f7af63f3ef95706cd429a53820a34750cdcb44eb
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Jan 30 17:50:22 2025 +0100

    [SYSTEMDS-3818] Fix parsing of indexing operations (scalar datatype)
    
    This patch fixes an edge case of print(X[1,]), where the indexing
    operation was mistakenly created with a scalar data type because print
    accepts scalar arguments. However, we later rewrite this expression into
    print(toString(X[1,])). We now simply make the parsing more robust, as
    an indexing operation is never scalar unless forced by internal rewrites.
---
 src/main/java/org/apache/sysds/parser/DMLTranslator.java     | 4 +++-
 src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml | 3 +--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index b0673be092..4eff055dbd 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -1686,11 +1686,13 @@ public class DMLTranslator
                if (target == null) {
                        target = createTarget(source);
                }
+               
                //unknown nnz after range indexing (applies to indexing op but 
also
                //data dependent operations)
                target.setNnz(-1); 
 
-               Hop indexOp = new IndexingOp(target.getName(), 
target.getDataType(), target.getValueType(),
+               DataType dt = target.getDataType().isScalar() ? DataType.MATRIX 
: target.getDataType();
+               Hop indexOp = new IndexingOp(target.getName(), dt, 
target.getValueType(),
                        hops.get(source.getName()), ixRange[0], ixRange[1], 
ixRange[2], ixRange[3],
                        source.getRowLowerEqualsUpper(), 
source.getColLowerEqualsUpper());
 
diff --git a/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml 
b/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml
index cf4c3e79c2..609cf4f1c7 100644
--- a/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml
+++ b/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml
@@ -36,8 +36,7 @@ else if(type==3){   # standard list case
     print(A_list)
 }
 else if(type==4){   # slice row from matrix
-    A_row = A[1,]
-    print(A_row)    # print(A[1,]) produces incorrect output
+    print(A[1,])
 }
 else if(type==5){   # slice column from matrix
     A_col = A[,1]

Reply via email to