This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a3be9c0  [SYSTEMDS-3108] Fix size inference for block indexing 
expressions
a3be9c0 is described below

commit a3be9c019a68bdb57f9062018fe4d108bfd75651
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Dec 28 23:45:18 2021 +0100

    [SYSTEMDS-3108] Fix size inference for block indexing expressions
    
    This patch fixes an issue in the existing size inference for block
    indexing expressions such as (with nc being a constant)
    
    R = X[(nc * (i-1) + 1) : (nc * i), ];
    
    Previously, we specifically added this for Kmeans, but the detection
    logic expected i to be a transient read of a variable. Later rewrites
    (e.g., removal of branches and merge of basic blocks) then caused i
    to remain a hop intermediate that is never bound to a logical
    variable, which required a workaround at script level.
    
    This patch makes the rewrite more general so that it (1) works for
    arbitrary variables (transient reads or intermediates) by comparing
    the hops directly (which depends on common subexpression elimination),
    and (2) accepts upper-bound variations such as (nc * i) and (i * nc).
---
 scripts/builtin/kmeans.dml                         |  1 -
 .../java/org/apache/sysds/hops/IndexingOp.java     | 35 +++++++++++-----------
 2 files changed, 17 insertions(+), 19 deletions(-)

diff --git a/scripts/builtin/kmeans.dml b/scripts/builtin/kmeans.dml
index 45b74d1..1e7e9df 100644
--- a/scripts/builtin/kmeans.dml
+++ b/scripts/builtin/kmeans.dml
@@ -210,7 +210,6 @@ m_kmeans = function(Matrix[Double] X, Integer k = 10, 
Integer runs = 10, Integer
              + ";  Avg WCSS = " + avg_wcss + ";  Worst WCSS = " + worst_wcss);
 
     C = All_Centroids [(num_centroids * (best_index - 1) + 1) : (num_centroids 
* best_index), ];
-    while(FALSE){} # workaround to make ncol t(C) known
     D =  -2 * (X %*% t(C)) + t(rowSums (C ^ 2));
     P = (D <= rowMins (D));
     aggr_P = t(cumsum (t(P)));
diff --git a/src/main/java/org/apache/sysds/hops/IndexingOp.java 
b/src/main/java/org/apache/sysds/hops/IndexingOp.java
index 394a5bd..17c097e 100644
--- a/src/main/java/org/apache/sysds/hops/IndexingOp.java
+++ b/src/main/java/org/apache/sysds/hops/IndexingOp.java
@@ -262,36 +262,35 @@ public class IndexingOp extends Hop
        {
                boolean ret = false;
                LiteralOp constant = null;
-               DataOp var = null;
+               Hop var = null;
 
                //handle lower bound
                if( lbound instanceof BinaryOp && 
((BinaryOp)lbound).getOp()==OpOp2.PLUS
-                       && lbound.getInput().get(1) instanceof LiteralOp 
-                       && 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)lbound.getInput().get(1))==1
-                       && lbound.getInput().get(0) instanceof BinaryOp)
+                       && lbound.getInput(1) instanceof LiteralOp 
+                       && 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)lbound.getInput(1))==1
+                       && lbound.getInput(0) instanceof BinaryOp)
                {
-                       BinaryOp lmult = (BinaryOp)lbound.getInput().get(0);
-                       if( lmult.getOp()==OpOp2.MULT && 
lmult.getInput().get(0) instanceof LiteralOp
-                               && lmult.getInput().get(1) instanceof BinaryOp )
+                       BinaryOp lmult = (BinaryOp)lbound.getInput(0);
+                       if( lmult.getOp()==OpOp2.MULT && lmult.getInput(0) 
instanceof LiteralOp
+                               && lmult.getInput(1) instanceof BinaryOp )
                        {
-                               BinaryOp lminus = 
(BinaryOp)lmult.getInput().get(1);
-                               if( lminus.getOp()==OpOp2.MINUS && 
lminus.getInput().get(1) instanceof LiteralOp
-                                       && 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)lminus.getInput().get(1))==1 
-                                       && lminus.getInput().get(0) instanceof 
DataOp )
+                               BinaryOp lminus = (BinaryOp)lmult.getInput(1);
+                               if( lminus.getOp()==OpOp2.MINUS && 
lminus.getInput(1) instanceof LiteralOp
+                                       && 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)lminus.getInput(1))==1 )
                                {
-                                       constant = 
(LiteralOp)lmult.getInput().get(0);
-                                       var = (DataOp) lminus.getInput().get(0);
+                                       constant = (LiteralOp)lmult.getInput(0);
+                                       var = lminus.getInput(0); //any DataOp 
or intermediate hop
                                }
                        }
                }
                
-               //handle upper bound
+               //handle upper bound (general check for var depends on CSE)
                if( var != null && constant != null && ubound instanceof 
BinaryOp 
-                       && ubound.getInput().get(0) instanceof LiteralOp
-                       && ubound.getInput().get(1) instanceof DataOp 
-                       && 
ubound.getInput().get(1).getName().equals(var.getName()) ) 
+                       && ((ubound.getInput(0) instanceof LiteralOp && 
ubound.getInput(1) == var)
+                         ||(ubound.getInput(1) instanceof LiteralOp && 
ubound.getInput(0) == var)) )
                {
-                       LiteralOp constant2 = 
(LiteralOp)ubound.getInput().get(0);
+                       int constIndex = (ubound.getInput(1) == var) ? 0 : 1;
+                       LiteralOp constant2 = 
(LiteralOp)ubound.getInput(constIndex);
                        ret = ( HopRewriteUtils.getDoubleValueSafe(constant) == 
                                        
HopRewriteUtils.getDoubleValueSafe(constant2) );
                }

Reply via email to