This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new f7af63f3ef [SYSTEMDS-3818] Fix parsing of indexing operations (scalar datatype)
f7af63f3ef is described below
commit f7af63f3ef95706cd429a53820a34750cdcb44eb
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Jan 30 17:50:22 2025 +0100
[SYSTEMDS-3818] Fix parsing of indexing operations (scalar datatype)
This patch fixes an edge case of print(X[1,]), where the indexing
operation was mistakenly created with a scalar data type because print
accepts scalar arguments. However, a later rewrite turns this into
print(toString(X[1,])), so typing the indexing result as a scalar
produces incorrect output. We now make the parsing more robust: an
indexing operation is never a scalar unless forced by internal rewrites.
---
src/main/java/org/apache/sysds/parser/DMLTranslator.java | 4 +++-
src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml | 3 +--
2 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index b0673be092..4eff055dbd 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -1686,11 +1686,13 @@ public class DMLTranslator
if (target == null) {
target = createTarget(source);
}
+
//unknown nnz after range indexing (applies to indexing op but
also
//data dependent operations)
target.setNnz(-1);
- Hop indexOp = new IndexingOp(target.getName(),
target.getDataType(), target.getValueType(),
+ DataType dt = target.getDataType().isScalar() ? DataType.MATRIX
: target.getDataType();
+ Hop indexOp = new IndexingOp(target.getName(), dt,
target.getValueType(),
hops.get(source.getName()), ixRange[0], ixRange[1],
ixRange[2], ixRange[3],
source.getRowLowerEqualsUpper(),
source.getColLowerEqualsUpper());
diff --git a/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml
b/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml
index cf4c3e79c2..609cf4f1c7 100644
--- a/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml
+++ b/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml
@@ -36,8 +36,7 @@ else if(type==3){ # standard list case
print(A_list)
}
else if(type==4){ # slice row from matrix
- A_row = A[1,]
- print(A_row) # print(A[1,]) produces incorrect output
+ print(A[1,])
}
else if(type==5){ # slice column from matrix
A_col = A[,1]