This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 9484f1116c [SYSTEMDS-3817] Printing of matrices, frames, lists w/o
toString
9484f1116c is described below
commit 9484f1116c5f07b3938bc9010e3663ea289140f6
Author: ReneEnjilian <[email protected]>
AuthorDate: Sat Jan 18 11:41:46 2025 +0100
[SYSTEMDS-3817] Printing of matrices, frames, lists w/o toString
Closes #2180.
---
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 3 +-
.../RewriteAlgebraicSimplificationStatic.java | 15 ---
.../sysds/hops/rewrite/RewriteNonScalarPrint.java | 67 +++++++++++++
.../org/apache/sysds/parser/StatementBlock.java | 4 +-
.../rewrite/RewriteNonScalarPrintTest.java | 109 +++++++++++++++++++++
.../functions/rewrite/RewriteNonScalarPrint.dml | 46 +++++++++
6 files changed, 226 insertions(+), 18 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 8433a5d90b..b08d836efe 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -115,7 +115,8 @@ public class ProgramRewriter{
if( LineageCacheConfig.getCompAssRW() )
_sbRuleSet.add( new MarkForLineageReuse()
);
_sbRuleSet.add( new
RewriteRemoveTransformEncodeMeta() );
- }
+ _dagRuleSet.add( new RewriteNonScalarPrint()
);
+ }
// DYNAMIC REWRITES (which do require size information)
if( dynamicRewrites )
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 72aa05ad16..6988ee5839 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -197,7 +197,6 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = simplifyNotOverComparisons(hop, hi, i);
//e.g., !(A>B) -> (A<=B)
//hi = removeUnecessaryPPred(hop, hi, i);
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
- hi = fixNonScalarPrint(hop, hi, i);
//e.g., print(m) -> print(toString(m))
//process childs recursively after rewrites (to
investigate pattern newly created by rewrites)
if( !descendFirst )
@@ -2131,20 +2130,6 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
return hi;
}
- private static Hop fixNonScalarPrint(Hop parent, Hop hi, int pos) {
- if(HopRewriteUtils.isUnary(parent, OpOp1.PRINT) &&
!hi.getDataType().isScalar()) {
- LinkedHashMap<String, Hop> args = new LinkedHashMap<>();
- args.put("target", hi);
- Hop newHop =
HopRewriteUtils.createParameterizedBuiltinOp(
- hi, args, ParamBuiltinOp.TOSTRING);
- HopRewriteUtils.replaceChildReference(parent, hi,
newHop, pos);
- hi = newHop;
- LOG.debug("Applied fixNonScalarPrint (line " +
hi.getBeginLine() + ")");
- }
-
- return hi;
- }
-
/**
* NOTE: currently disabled since this rewrite is INVALID in the
* presence of NaNs (because (NaN!=NaN) is true).
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteNonScalarPrint.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteNonScalarPrint.java
new file mode 100644
index 0000000000..66d7707dd7
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteNonScalarPrint.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.Hop;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+
+/**
+ * HOP DAG rewrite that allows printing non-scalar values (e.g., matrices,
+ * frames, lists) without an explicit toString call, i.e.,
+ * {@code print(X) -> print(toString(X))}.
+ * <p>
+ * Only the given DAG roots are inspected (there is no recursion into inputs):
+ * a root that is a unary PRINT whose first input is not a scalar gets a
+ * TOSTRING parameterized builtin inserted between itself and that input.
+ */
+public class RewriteNonScalarPrint extends HopRewriteRule
+{
+	@Override
+	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
+		if( roots == null )
+			return null;
+		// each root is rewritten in place; the list itself is not modified
+		for( int i = 0; i < roots.size(); i++ )
+			rewriteHopDAG(roots.get(i), state);
+		return roots;
+	}
+
+	@Override
+	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+		if( root == null )
+			return null;
+		rewritePrintNonScalar(root);
+		return root;
+	}
+
+	/**
+	 * Wraps the first input of a unary PRINT hop into a TOSTRING builtin if
+	 * that input is not a scalar; any other hop is left untouched.
+	 * <p>
+	 * NOTE(review): a second application is a no-op only if the inserted
+	 * TOSTRING hop reports a scalar data type; confirm what
+	 * HopRewriteUtils.createParameterizedBuiltinOp assigns.
+	 *
+	 * @param hop candidate hop (a DAG root)
+	 */
+	private void rewritePrintNonScalar(Hop hop) {
+		if( !HopRewriteUtils.isUnary(hop, Types.OpOp1.PRINT) )
+			return;
+		Hop input = hop.getInput().get(0);
+		if( input.getDataType().isScalar() )
+			return; // scalars are printed as-is
+
+		// build toString(target=input) and splice it in as the print's input
+		LinkedHashMap<String, Hop> params = new LinkedHashMap<>();
+		params.put("target", input);
+		Hop toStringHop = HopRewriteUtils.createParameterizedBuiltinOp(
+			input, params, Types.ParamBuiltinOp.TOSTRING);
+		HopRewriteUtils.replaceChildReference(hop, input, toStringHop, 0);
+		LOG.debug("Applied non-scalar print rewrite on hop ID = " + hop.getHopID());
+	}
+
+}
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 82501f63c5..b81a603e7c 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -894,11 +894,11 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
DataType outputDatatype =
expression.getOutput().getDataType();
switch (outputDatatype) {
case SCALAR:
- break;
case MATRIX:
- case TENSOR:
case FRAME:
case LIST:
+ break;
+ case TENSOR:
pstmt.raiseValidateError("Print statements can only print scalars. To print a "
+ outputDatatype + ", please wrap it in a toString() function.", conditional);
default:
pstmt.raiseValidateError("Print statements can only print scalars. Input
datatype was: " + outputDatatype, conditional);
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNonScalarPrintTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNonScalarPrintTest.java
new file mode 100644
index 0000000000..5fe546200f
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNonScalarPrintTest.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RewriteNonScalarPrintTest extends AutomatedTestBase {
+
+ private static final String TEST_NAME = "RewriteNonScalarPrint";
+ private static final String TEST_DIR = "functions/rewrite/";
+
+ private static final String TEST_CLASS_DIR =
+ TEST_DIR + RewriteNonScalarPrintTest.class.getSimpleName() +
"/";
+
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
+ }
+
+ @Test
+ public void testNonScalarPrintMatrix() {
+ testRewriteNonScalarPrint(1);
+ }
+
+ @Test
+ public void testNonScalarPrintFrame() {
+ testRewriteNonScalarPrint(2);
+ }
+
+ @Test
+ public void testNonScalarPrintList() {
+ testRewriteNonScalarPrint(3);
+ }
+
+ @Test
+ public void testNonScalarPrintMatrixRow() {
+ testRewriteNonScalarPrint(4);
+ }
+
+ @Test
+ public void testNonScalarPrintMatrixCol() {
+ testRewriteNonScalarPrint(5);
+ }
+
+ private void testRewriteNonScalarPrint(int ID) {
+ TestConfiguration config = getTestConfiguration(TEST_NAME);
+ loadTestConfiguration(config);
+ setOutputBuffering(true);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "-args",
String.valueOf(ID), output("R")};
+ String fullOut = runTest(true, false, null, -1).toString();
+
+ //Extract or remove "SystemDS Statistics:"
+ int idxStats = fullOut.indexOf("SystemDS Statistics:");
+ String userOutput = (idxStats >= 0) ? fullOut.substring(0,
idxStats) : fullOut;
+
+ String toString = "toString";
+ long numtoString = Statistics.getCPHeavyHitterCount(toString);
+
+ if(ID == 1) {
+ Assert.assertTrue(
+ userOutput.contains("1.000 2.000 3.000\n" +
"4.000 5.000 6.000\n" + "7.000 8.000 9.000") &&
+ numtoString == 1);
+ }
+ else if(ID == 2) {
+ Assert.assertTrue(userOutput.contains(
+ "# FRAME: nrow = 3, ncol = 3\n" + "# C1 C2
C3\n" + "# INT32 INT32 INT32\n" + "1 2 3\n" + "4 5 6\n" +
+ "7 8 9") && numtoString == 1);
+ }
+ else if(ID == 3) {
+ Assert.assertTrue(userOutput.contains("[1, 2, 3]") &&
numtoString == 1);
+ }
+ else if(ID == 4) {
+ Assert.assertTrue(userOutput.contains("1.000 2.000
3.000\n") && numtoString == 1);
+ }
+ else if(ID == 5) {
+ Assert.assertTrue(userOutput.contains("1.000\n" +
"4.000\n" + "7.000") && numtoString == 1);
+ }
+
+ //Print the entire output
+ System.out.println(fullOut);
+ }
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml
b/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml
new file mode 100644
index 0000000000..cf4c3e79c2
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml
@@ -0,0 +1,46 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Inputs: $1 = case (1=matrix, 2=frame, 3=list, 4=row slice, 5=column slice); $2 is not used here
+type = $1 # case selector
+A = matrix("1 2 3 4 5 6 7 8 9", rows=3, cols=3) # 3x3 matrix with rows (1,2,3), (4,5,6), (7,8,9)
+A_list = list(1, 2, 3) # list printed in case 3
+
+if(type==1){ # standard matrix case: print a matrix without toString
+ print(A)
+}
+else if(type==2){ # standard frame case: convert the matrix to a frame and print it
+ A_frame = as.frame(A)
+ print(A_frame)
+}
+else if(type==3){ # standard list case
+ print(A_list)
+}
+else if(type==4){ # slice first row from matrix
+ A_row = A[1,]
+ print(A_row) # NOTE: print(A[1,]) produces incorrect output, so the slice is assigned to A_row first (TODO: track/fix)
+}
+else if(type==5){ # slice first column from matrix
+ A_col = A[,1]
+ print(A_col)
+}
+