This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 9b940f7051 [SYSTEMDS-3797] Fix rewrite for trace on reorg operations
9b940f7051 is described below

commit 9b940f7051ad1b8b216be130cd785ae3165da0b3
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Nov 28 10:39:56 2024 +0100

    [SYSTEMDS-3797] Fix rewrite for trace on reorg operations
    
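    This patch fixes the rewrite that removes unnecessary reorg operations
    below full aggregations (e.g., sum(t(X)) -> sum(X) or sum(rev(X)) -> sum(X)).
    The rewrite was incorrectly also applied to trace aggregations; since
    trace only consumes a subset of values (the diagonal), the reorg cannot
    be dropped there. Furthermore, we generalize this rewrite to eliminate
    all reorg operations that are guaranteed to preserve all values
    (e.g., transpose/reshape/rev/roll). Diag and sort are excluded entirely,
    because diagV2M and diagM2V cannot be distinguished if sizes are unknown,
    and sort may return indices instead of values.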
    
    Thanks to Jannik Lindemann for catching this issue.
---
 src/main/java/org/apache/sysds/common/Types.java             |  4 ++++
 .../hops/rewrite/RewriteAlgebraicSimplificationStatic.java   | 12 +++++-------
 .../rewrite/RewriteSimplifyTraceMatrixMultTest.java          | 12 +++---------
 .../functions/rewrite/RewriteSimplifyTraceMatrixMult.R       |  5 +++++
 .../functions/rewrite/RewriteSimplifyTraceMatrixMult.dml     |  2 ++
 5 files changed, 19 insertions(+), 16 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index e7274b25c4..ba264dea7f 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -751,6 +751,10 @@ public interface Types {
                DIAG, //DIAG_V2M and DIAG_M2V could not be distinguished if 
sizes unknown
                RESHAPE, REV, ROLL, SORT, TRANS;
                
+               public boolean preservesValues() {
+                       return this != DIAG && this != SORT;
+               }
+               
                @Override
                public String toString() {
                        switch(this) {
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 056770dceb..8053ddc78a 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -980,23 +980,21 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
 
        private static Hop simplifyUnaryAggReorgOperation( Hop parent, Hop hi, 
int pos )
        {
-               if(   hi instanceof AggUnaryOp && 
((AggUnaryOp)hi).getDirection()==Direction.RowCol  //full uagg
-                               && hi.getInput().get(0) instanceof ReorgOp  ) 
//reorg operation
+               if( hi instanceof AggUnaryOp && 
((AggUnaryOp)hi).getDirection()==Direction.RowCol 
+                       && ((AggUnaryOp)hi).getOp() != AggOp.TRACE    //full 
uagg
+                       && hi.getInput().get(0) instanceof ReorgOp  ) //reorg 
operation
                {
                        ReorgOp rop = (ReorgOp)hi.getInput().get(0);
-                       if(   (rop.getOp()==ReOrgOp.TRANS || 
rop.getOp()==ReOrgOp.RESHAPE
-                                       || rop.getOp() == ReOrgOp.REV )         
//valid reorg
-                                       && rop.getParent().size()==1 )          
    //uagg only reorg consumer
+                       if( rop.getOp().preservesValues()       //valid reorg
+                               && rop.getParent().size()==1 )      //uagg only 
reorg consumer
                        {
                                Hop input = rop.getInput().get(0);
                                HopRewriteUtils.removeAllChildReferences(hi);
                                HopRewriteUtils.removeAllChildReferences(rop);
                                HopRewriteUtils.addChildReference(hi, input);
-
                                LOG.debug("Applied 
simplifyUnaryAggReorgOperation");
                        }
                }
-
                return hi;
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceMatrixMultTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceMatrixMultTest.java
index 4a81609d3f..c2ae90eec7 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceMatrixMultTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceMatrixMultTest.java
@@ -85,18 +85,12 @@ public class RewriteSimplifyTraceMatrixMultTest extends 
AutomatedTestBase {
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
 
                        //check trace operator existence
-                       String uaktrace = "uaktrace";
-                       long numTrace = 
Statistics.getCPHeavyHitterCount(uaktrace);
-
-                       if(rewrites)
-                               Assert.assertTrue(numTrace == 0);
-                       else
-                               Assert.assertTrue(numTrace == 1);
-
+                       long numTrace = 
Statistics.getCPHeavyHitterCount("uaktrace");
+                       Assert.assertTrue(numTrace == (rewrites ? 1 : 2)); 
+                       Assert.assertTrue(heavyHittersContainsString("rev"));
                }
                finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
                }
-
        }
 }
diff --git 
a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.R 
b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.R
index 2153b2dafd..3bb323986d 100644
--- a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.R
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.R
@@ -36,6 +36,11 @@ B = as.matrix(readMM(paste(args[1], "B.mtx", sep="")))
 
 # Perform the matrix operation
 R = sum(diag(A %*% B))
+rA = A;
+for(i in 1:nrow(rA)) {
+  rA[,i] = rev(rA[,i])
+}
+R = R + sum(diag(rA))
 
 # Write the result scalar R
 write(R, paste(args[2], "R" ,sep=""))
diff --git 
a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.dml 
b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.dml
index 315af97843..7189a7f4e6 100644
--- a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.dml
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.dml
@@ -26,6 +26,8 @@ B = read($2)
 
 # Perform the operation
 R = trace(A %*% B)
+R = R + trace(rev(A))
 
 # Write the result R
 write(R, $3)
+

Reply via email to