This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 5a8c979  [SYSTEMDS-3076] Additional hop rewrites for colMeans sequences
5a8c979 is described below

commit 5a8c979bdadc8285a63979e5894d80cf6d94dcd8
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Jul 29 20:34:32 2021 +0200

    [SYSTEMDS-3076] Additional hop rewrites for colMeans sequences
    
    This patch adds two new rewrites that help remove unnecessary operations
    in PCA with shifting (see functions/compress/WorkloadAlgorithmTest):
    
    1) colSums(X) / N -> colMeans(X), if N = nrow(X) (precondition for rewrite 2, and fewer ops)
    2) colMeans((X-colMeans(X))/...) -> matrix(0,1,ncol(X))
    
    After these rewrites have been applied, various additional rewrites
    trigger to remove unnecessary terms with empty matrix multiplications.
---
 src/main/java/org/apache/sysds/common/Types.java   |  7 ++-
 .../RewriteAlgebraicSimplificationDynamic.java     | 60 ++++++++++++++++++++--
 .../RewriteAlgebraicSimplificationStatic.java      |  2 +-
 3 files changed, 62 insertions(+), 7 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index badf307..da53091 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -144,7 +144,12 @@ public class Types
                RowCol, // full aggregate
                Row,    // row aggregate (e.g., rowSums)
                Col;    // column aggregate (e.g., colSums)
-               
+               public boolean isRow() {
+                       return this == Row;
+               }
+               public boolean isCol() {
+                       return this == Col;
+               }
                @Override
                public String toString() {
                        switch(this) {
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 269050c..5a29e6b 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -160,13 +160,15 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                                hi = removeUnnecessaryAppendTSMM(hop, hi, i);   
  //e.g., X = t(rbind(A,B,C)) %*% rbind(A,B,C) -> t(A)%*%A + t(B)%*%B + t(C)%*%C
                        if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
                                hi = fuseDatagenAndReorgOperation(hop, hi, i);  
  //e.g., t(rand(rows=10,cols=1)) -> rand(rows=1,cols=10), if one dim=1
-                       hi = simplifyColwiseAggregate(hop, hi, i);        
//e.g., colsums(X) -> sum(X) or X, if col/row vector
-                       hi = simplifyRowwiseAggregate(hop, hi, i);        
//e.g., rowsums(X) -> sum(X) or X, if row/col vector
+                       hi = simplifyColwiseAggregate(hop, hi, i);        
//e.g., colSums(X) -> sum(X) or X, if col/row vector
+                       hi = simplifyRowwiseAggregate(hop, hi, i);        
//e.g., rowSums(X) -> sum(X) or X, if row/col vector
+                       hi = simplifyMeanAggregation(hop, hi, i);         
//e.g., colSums(X)/N -> colMeans(X) if N = nrow(X)
                        hi = simplifyColSumsMVMult(hop, hi, i);           
//e.g., colSums(X*Y) -> t(Y) %*% X, if Y col vector
                        hi = simplifyRowSumsMVMult(hop, hi, i);           
//e.g., rowSums(X*Y) -> X %*% t(Y), if Y row vector
                        hi = simplifyUnnecessaryAggregate(hop, hi, i);    
//e.g., sum(X) -> as.scalar(X), if 1x1 dims
                        hi = simplifyEmptyAggregate(hop, hi, i);          
//e.g., sum(X) -> 0, if nnz(X)==0
-                       hi = simplifyEmptyUnaryOperation(hop, hi, i);     
//e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0                   
+                       hi = simplifyEmptyColMeans(hop, hi, i);           
//e.g., colMeans(X-colMeans(X)) if none or scaling by scalars/col-vectors
+                       hi = simplifyEmptyUnaryOperation(hop, hi, i);     
//e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0
                        hi = simplifyEmptyReorgOperation(hop, hi, i);     
//e.g., t(X) -> matrix(0, ncol(X), nrow(X)) 
                        hi = simplifyEmptySortOperation(hop, hi, i);      
//e.g., order(X) -> seq(1, nrow(X)), if nnz(X)==0 
                        hi = simplifyEmptyMatrixMult(hop, hi, i);         
//e.g., X%*%Y -> matrix(0,...), if nnz(Y)==0 | X if Y==matrix(1,1,1)
@@ -722,6 +724,30 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
                return hi;
        }
        
+       private static Hop simplifyMeanAggregation( Hop parent, Hop hi, int pos 
) {
+               // colSums(X)/N -> colMeans(X), if N = nrow(X), all directions 
but different vals
+               if( HopRewriteUtils.isBinary(hi, OpOp2.DIV)
+                       && HopRewriteUtils.isAggUnaryOp(hi.getInput(0), 
AggOp.SUM)
+                       && hi.getInput(0).getParent().size()==1 //prevent 
repeated scans
+                       && hi.getInput(1).getDataType().isScalar())
+               {
+                       AggUnaryOp agg = (AggUnaryOp)hi.getInput(0);
+                       Hop in = agg.getInput(0);
+                       Hop N = hi.getInput(1);
+                       if( (agg.getDirection().isRow() && 
HopRewriteUtils.isSizeExpressionOf(N, in, false))
+                               || (agg.getDirection().isCol() && 
HopRewriteUtils.isSizeExpressionOf(N, in, true)) )
+                       {
+                               HopRewriteUtils.replaceChildReference(parent, 
hi, agg, pos);
+                               HopRewriteUtils.cleanupUnreferenced(hi, N);
+                               agg.setOp(AggOp.MEAN);
+                               hi = agg;
+                               LOG.debug("Applied simplifyMeanAggregation");
+                       }
+               }
+               
+               return hi;
+       }
+       
        private static Hop simplifyColSumsMVMult( Hop parent, Hop hi, int pos ) 
        {
                //colSums(X*Y) -> t(Y) %*% X, if Y col vector; additional 
transpose later
@@ -821,7 +847,7 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
        
        private static Hop simplifyEmptyAggregate(Hop parent, Hop hi, int pos) 
        {
-               if( hi instanceof AggUnaryOp  ) 
+               if( hi instanceof AggUnaryOp )
                {
                        AggUnaryOp uhi = (AggUnaryOp)hi;
                        Hop input = uhi.getInput().get(0);
@@ -848,6 +874,30 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
                return hi;
        }
        
+       private static Hop simplifyEmptyColMeans(Hop parent, Hop hi, int pos) 
+       {
+               if( hi.dimsKnown() && HopRewriteUtils.isAggUnaryOp(hi, 
AggOp.MEAN, Direction.Col) ) {
+                       Hop in = hi.getInput(0);
+                       //colMeans(X-colMeans(X)) without scaling
+                       boolean apply = HopRewriteUtils.isBinary(in, 
OpOp2.MINUS)
+                               && HopRewriteUtils.isAggUnaryOp(in.getInput(1), 
AggOp.MEAN, Direction.Col)
+                               && in.getInput(0) == 
in.getInput(1).getInput(0); //requires CSE
+                       //colMeans((X-colMeans(X))/colSds(X)) if scaling by 
scalars/col-vectors
+                       apply = apply || (HopRewriteUtils.isBinary(in, 
OpOp2.DIV, OpOp2.MULT)
+                               && in.getInput(1).getDim1()==1 //row vector
+                               && HopRewriteUtils.isBinary(in.getInput(0), 
OpOp2.MINUS)
+                               && 
HopRewriteUtils.isAggUnaryOp(in.getInput(0).getInput(1), AggOp.MEAN, 
Direction.Col)
+                               && in.getInput(0).getInput(0) == 
in.getInput(0).getInput(1).getInput(0));
+                       if( apply ) {
+                               Hop hnew = HopRewriteUtils.createDataGenOp(hi, 
hi, 0); //empty
+                               HopRewriteUtils.replaceChildReference(parent, 
hi, hnew, pos);
+                               hi = hnew;
+                               LOG.debug("Applied simplifyEmptyColMeans");
+                       }
+               }
+               return hi;
+       }
+       
        private static Hop simplifyEmptyUnaryOperation(Hop parent, Hop hi, int 
pos) 
        {
                if( hi instanceof UnaryOp  ) 
@@ -866,7 +916,7 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
                                        
                                        LOG.debug("Applied 
simplifyEmptyUnaryOperation");
                                }
-                       }                       
+                       }
                }
                
                return hi;
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index f0d9dea..56854ff 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -158,7 +158,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifyBinaryMatrixScalarOperation(hop, hi, 
i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
                        hi = pushdownUnaryAggTransposeOperation(hop, hi, i); 
//e.g., colSums(t(X)) -> t(rowSums(X))
                        hi = pushdownCSETransposeScalarOperation(hop, hi, 
i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
-                       hi = pushdownSumBinaryMult(hop, hi, i);              
//e.g., sum(lamda*X) -> lamda*sum(X)
+                       hi = pushdownSumBinaryMult(hop, hi, i);              
//e.g., sum(lambda*X) -> lambda*sum(X)
                        hi = simplifyUnaryPPredOperation(hop, hi, i);        
//e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
                        hi = simplifyTransposedAppend(hop, hi, i);           
//e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
                        if(OptimizerUtils.ALLOW_OPERATOR_FUSION)

Reply via email to