This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 9484f1116c [SYSTEMDS-3817] Printing of matrices, frames, lists w/o 
toString
9484f1116c is described below

commit 9484f1116c5f07b3938bc9010e3663ea289140f6
Author: ReneEnjilian <[email protected]>
AuthorDate: Sat Jan 18 11:41:46 2025 +0100

    [SYSTEMDS-3817] Printing of matrices, frames, lists w/o toString
    
    Closes #2180.
---
 .../apache/sysds/hops/rewrite/ProgramRewriter.java |   3 +-
 .../RewriteAlgebraicSimplificationStatic.java      |  15 ---
 .../sysds/hops/rewrite/RewriteNonScalarPrint.java  |  67 +++++++++++++
 .../org/apache/sysds/parser/StatementBlock.java    |   4 +-
 .../rewrite/RewriteNonScalarPrintTest.java         | 109 +++++++++++++++++++++
 .../functions/rewrite/RewriteNonScalarPrint.dml    |  46 +++++++++
 6 files changed, 226 insertions(+), 18 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 8433a5d90b..b08d836efe 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -115,7 +115,8 @@ public class ProgramRewriter{
                        if( LineageCacheConfig.getCompAssRW() )
                                _sbRuleSet.add(  new MarkForLineageReuse()      
                 );
                        _sbRuleSet.add(      new 
RewriteRemoveTransformEncodeMeta()          );
-               }
+                       _dagRuleSet.add( new RewriteNonScalarPrint()            
             );
+               }
                
                // DYNAMIC REWRITES (which do require size information)
                if( dynamicRewrites )
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 72aa05ad16..6988ee5839 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -197,7 +197,6 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifyNotOverComparisons(hop, hi, i);         
//e.g., !(A>B) -> (A<=B)
                        //hi = removeUnecessaryPPred(hop, hi, i);            
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
 
-                       hi = fixNonScalarPrint(hop, hi, i);                  
//e.g., print(m) -> print(toString(m))
 
                        //process childs recursively after rewrites (to 
investigate pattern newly created by rewrites)
                        if( !descendFirst )
@@ -2131,20 +2130,6 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
                return hi;
        }
 
-       private static Hop fixNonScalarPrint(Hop parent, Hop hi, int pos) {
-               if(HopRewriteUtils.isUnary(parent, OpOp1.PRINT) && 
!hi.getDataType().isScalar()) {
-                       LinkedHashMap<String, Hop> args = new LinkedHashMap<>();
-                       args.put("target", hi);
-                       Hop newHop = 
HopRewriteUtils.createParameterizedBuiltinOp(
-                                       hi, args, ParamBuiltinOp.TOSTRING);
-                       HopRewriteUtils.replaceChildReference(parent, hi, 
newHop, pos);
-                       hi = newHop;
-                       LOG.debug("Applied fixNonScalarPrint (line " + 
hi.getBeginLine() + ")");
-               }
-
-               return hi;
-       }
-
        /**
         * NOTE: currently disabled since this rewrite is INVALID in the
         * presence of NaNs (because (NaN!=NaN) is true). 
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteNonScalarPrint.java 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteNonScalarPrint.java
new file mode 100644
index 0000000000..66d7707dd7
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteNonScalarPrint.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.Hop;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+
+/**
+ * HOP DAG rewrite that makes printing of non-scalar values (e.g., matrices,
+ * frames, lists) executable by wrapping the printed input into a toString
+ * operation: {@code print(X) -> print(toString(X))}.
+ * <p>
+ * Only the given DAG roots are inspected; a PRINT operator that is not
+ * itself a root is left untouched.
+ */
+public class RewriteNonScalarPrint extends HopRewriteRule
+{
+	@Override
+	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
+		if( roots == null )
+			return null;
+		for( Hop root : roots )
+			rewriteHopDAG(root, state);
+		return roots;
+	}
+
+	@Override
+	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+		if( root != null )
+			rewritePrintNonScalar(root);
+		return root;
+	}
+
+	/**
+	 * Replaces the non-scalar input X of a unary PRINT operator by toString(X),
+	 * modifying the given hop's input reference in place.
+	 *
+	 * @param hop candidate operator; anything but a unary PRINT is ignored
+	 */
+	private static void rewritePrintNonScalar(Hop hop) {
+		if( !HopRewriteUtils.isUnary(hop, Types.OpOp1.PRINT) )
+			return;
+
+		Hop input = hop.getInput().get(0);
+		if( input.getDataType().isScalar() )
+			return; // scalars are printed as-is
+
+		LinkedHashMap<String, Hop> args = new LinkedHashMap<>();
+		args.put("target", input);
+		Hop toStringOp = HopRewriteUtils.createParameterizedBuiltinOp(
+			input, args, Types.ParamBuiltinOp.TOSTRING);
+		HopRewriteUtils.replaceChildReference(hop, input, toStringOp, 0);
+		// include the script line number to ease locating the rewritten print
+		LOG.debug("Applied non-scalar print rewrite on hop ID = " + hop.getHopID()
+			+ " (line " + hop.getBeginLine() + ")");
+	}
+}
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java 
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 82501f63c5..b81a603e7c 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -894,11 +894,11 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
                                        DataType outputDatatype = 
expression.getOutput().getDataType();
                                        switch (outputDatatype) {
                                                case SCALAR:
-                                                       break;
                                                case MATRIX:
-                                               case TENSOR:
                                                case FRAME:
                                                case LIST:
+                                                       break;
+                                               case TENSOR:
                                                        
pstmt.raiseValidateError("Print statements can only print scalars. To print a " 
+ outputDatatype + ", please wrap it in a toString() function.", conditional);
                                                default:
                                                        
pstmt.raiseValidateError("Print statements can only print scalars. Input 
datatype was: " + outputDatatype, conditional);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNonScalarPrintTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNonScalarPrintTest.java
new file mode 100644
index 0000000000..5fe546200f
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNonScalarPrintTest.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RewriteNonScalarPrintTest extends AutomatedTestBase {
+
+	private static final String TEST_NAME = "RewriteNonScalarPrint";
+	private static final String TEST_DIR = "functions/rewrite/";
+
+	private static final String TEST_CLASS_DIR =
+		TEST_DIR + RewriteNonScalarPrintTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
+	}
+
+	@Test
+	public void testNonScalarPrintMatrix() {
+		testRewriteNonScalarPrint(1);
+	}
+
+	@Test
+	public void testNonScalarPrintFrame() {
+		testRewriteNonScalarPrint(2);
+	}
+
+	@Test
+	public void testNonScalarPrintList() {
+		testRewriteNonScalarPrint(3);
+	}
+
+	@Test
+	public void testNonScalarPrintMatrixRow() {
+		testRewriteNonScalarPrint(4);
+	}
+
+	@Test
+	public void testNonScalarPrintMatrixCol() {
+		testRewriteNonScalarPrint(5);
+	}
+
+	/**
+	 * Runs RewriteNonScalarPrint.dml for one case and checks that the value
+	 * was printed as expected and that exactly one toString instruction was
+	 * executed. The script itself contains no toString call, so this count
+	 * shows that print(X) was rewritten to print(toString(X)).
+	 *
+	 * @param ID case selector passed to the script as $1: 1 matrix, 2 frame,
+	 *           3 list, 4 matrix row slice, 5 matrix column slice
+	 */
+	private void testRewriteNonScalarPrint(int ID) {
+		TestConfiguration config = getTestConfiguration(TEST_NAME);
+		loadTestConfiguration(config);
+		setOutputBuffering(true);
+
+		String HOME = SCRIPT_DIR + TEST_DIR;
+		fullDMLScriptName = HOME + TEST_NAME + ".dml";
+		programArgs = new String[] {"-stats", "-args", String.valueOf(ID), output("R")};
+		String fullOut = runTest(true, false, null, -1).toString();
+
+		// print the entire output before asserting, so it is visible on failure
+		System.out.println(fullOut);
+
+		// match only the script output, not the statistics appended by -stats
+		int idxStats = fullOut.indexOf("SystemDS Statistics:");
+		String userOutput = (idxStats >= 0) ? fullOut.substring(0, idxStats) : fullOut;
+
+		String expected = getExpectedOutput(ID);
+		Assert.assertTrue("Expected output not found for case " + ID + ":\n" + expected
+			+ "\nActual output:\n" + userOutput, userOutput.contains(expected));
+		Assert.assertEquals("Number of executed toString instructions for case " + ID,
+			1, Statistics.getCPHeavyHitterCount("toString"));
+	}
+
+	/**
+	 * Returns the output fragment the script must print for the given case.
+	 *
+	 * @param ID case selector, see {@link #testRewriteNonScalarPrint(int)}
+	 * @return expected substring of the script output
+	 */
+	private static String getExpectedOutput(int ID) {
+		switch(ID) {
+			case 1: return "1.000 2.000 3.000\n4.000 5.000 6.000\n7.000 8.000 9.000";
+			case 2: return "# FRAME: nrow = 3, ncol = 3\n# C1 C2 C3\n# INT32 INT32 INT32\n1 2 3\n4 5 6\n7 8 9";
+			case 3: return "[1, 2, 3]";
+			case 4: return "1.000 2.000 3.000\n";
+			case 5: return "1.000\n4.000\n7.000";
+			default: throw new IllegalArgumentException("Unknown test case ID: " + ID);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml 
b/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml
new file mode 100644
index 0000000000..cf4c3e79c2
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNonScalarPrint.dml
@@ -0,0 +1,46 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Inputs: $1 selects which non-scalar value is printed (see cases below)
+type = $1
+A = matrix("1 2 3 4 5 6 7 8 9", rows=3, cols=3)
+A_list = list(1, 2, 3)
+
+if(type==1){    # standard matrix case: print(matrix)
+    print(A)
+}
+else if(type==2){  # standard frame case: print(frame)
+    A_frame = as.frame(A)
+    print(A_frame)
+}
+else if(type==3){   # standard list case: print(list)
+    print(A_list)
+}
+else if(type==4){   # slice row from matrix (1 x 3)
+    A_row = A[1,]
+    print(A_row)    # NOTE(review): print(A[1,]) on the inline slice reportedly produces incorrect output, hence the temporary A_row
+}
+else if(type==5){   # slice column from matrix (3 x 1)
+    A_col = A[,1]
+    print(A_col)
+}
+

Reply via email to