This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 091144d669 [SYSTEMDS-3798] Fix generality of loop vectorization rewrite
091144d669 is described below

commit 091144d669a803cf03ca953f0a866f5bc967246c
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Dec 13 18:18:59 2024 +0100

    [SYSTEMDS-3798] Fix generality of loop vectorization rewrite
---
 .../hops/rewrite/RewriteForLoopVectorization.java  | 38 +++++++++++++++-------
 .../test/functions/vect/AutoVectorizationTest.java | 24 +++++---------
 2 files changed, 35 insertions(+), 27 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
index 0c09c2efb4..ad06ac2359 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
@@ -138,13 +138,11 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
                                        && right.getInput(0) instanceof 
IndexingOp )
                                {
                                        IndexingOp ix = 
(IndexingOp)right.getInput(0);
-                                       if( ix.isRowLowerEqualsUpper() && 
ix.getInput(1) instanceof DataOp
-                                               && 
ix.getInput(1).getName().equals(itervar) ){
+                                       if( checkItervarIndexing(ix, itervar, 
true) ){
                                                leftScalar = true;
                                                rowIx = true;
                                        }
-                                       else if( ix.isColLowerEqualsUpper() && 
ix.getInput(3) instanceof DataOp
-                                               && 
ix.getInput(3).getName().equals(itervar) ){
+                                       else if( checkItervarIndexing(ix, 
itervar, false) ){
                                                leftScalar = true;
                                                rowIx = false;
                                        }
@@ -157,13 +155,11 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
                                        && left.getInput(0) instanceof 
IndexingOp )
                                {
                                        IndexingOp ix = 
(IndexingOp)left.getInput(0);
-                                       if( ix.isRowLowerEqualsUpper() && 
ix.getInput(1) instanceof DataOp
-                                               && 
ix.getInput(1).getName().equals(itervar) ){
+                                       if( checkItervarIndexing(ix, itervar, 
true) ){
                                                rightScalar = true;
                                                rowIx = true;
                                        }
-                                       else if( ix.isColLowerEqualsUpper() && 
ix.getInput(3) instanceof DataOp
-                                               && 
ix.getInput(3).getName().equals(itervar) ){
+                                       else if( checkItervarIndexing(ix, 
itervar, false) ){
                                                rightScalar = true;
                                                rowIx = false;
                                        }
@@ -231,8 +227,14 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
                                        && root.getName().equals(left.getName())
                                        && right instanceof IndexingOp && 
right.isScalar())
                                {
-                                       leftScalar = true;
-                                       rowIx = true; //row and col
+                                       if( 
checkItervarIndexing((IndexingOp)right, itervar, true) ){
+                                               leftScalar = true;
+                                               rowIx = true;
+                                       }
+                                       else if( 
checkItervarIndexing((IndexingOp)right, itervar, false) ){
+                                               leftScalar = true;
+                                               rowIx = false;
+                                       }
                                }
                                //check for right scalar plus
                                else if( HopRewriteUtils.isValidOp(bop.getOp(), 
MAP_SCALAR_AGGREGATE_SOURCE_OPS)  
@@ -240,8 +242,14 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
                                        && 
root.getName().equals(right.getName()) 
                                        && left instanceof IndexingOp && 
left.isScalar())
                                {
-                                       rightScalar = true;
-                                       rowIx = true; //row and col
+                                       if( 
checkItervarIndexing((IndexingOp)left, itervar, true) ){
+                                               rightScalar = true;
+                                               rowIx = true;
+                                       }
+                                       else if( 
checkItervarIndexing((IndexingOp)left, itervar, false) ){
+                                               rightScalar = true;
+                                               rowIx = false;
+                                       }
                                }
                        }
                }
@@ -461,6 +469,12 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
                return ret;
        }
        
+       private static boolean checkItervarIndexing(IndexingOp ix, String 
itervar, boolean row) {
+               return (row ? ix.isRowLowerEqualsUpper() : ix.isColLowerEqualsUpper()) //single index in the checked dimension
+                       && ix.getInput(row?1:3) instanceof DataOp //input 1 is the row index, input 3 the column index
+                       && ix.getInput(row?1:3).getName().equals(itervar); //index must be the loop iteration variable
+       }
+       
        private static boolean[] checkLeftAndRightIndexing(LeftIndexingOp lix, 
IndexingOp rix, String itervar) {
                boolean[] ret = new boolean[2]; //apply, rowIx
                
diff --git 
a/src/test/java/org/apache/sysds/test/functions/vect/AutoVectorizationTest.java 
b/src/test/java/org/apache/sysds/test/functions/vect/AutoVectorizationTest.java
index 0b0b301224..7771c96a8a 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/vect/AutoVectorizationTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/vect/AutoVectorizationTest.java
@@ -213,42 +213,36 @@ public class AutoVectorizationTest extends 
AutomatedTestBase
                runVectorizationTest( TEST_NAME24 ); 
        }
        
-       /**
-        * 
-        * @param cfc
-        * @param vt
-        */
        private void runVectorizationTest( String testName ) 
        {
                String TEST_NAME = testName;
                
                try
-               {               
+               {
                        TestConfiguration config = 
getTestConfiguration(TEST_NAME);
                        loadTestConfiguration(config);
 
-                   String HOME = SCRIPT_DIR + TEST_DIR;
+                       String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
                        programArgs = new String[]{"-explain","-args", 
input("A"), output("R") };
                        
                        fullRScriptName = HOME + TEST_NAME + ".R";
-                       rCmd = getRCmd(inputDir(), expectedDir());              
+                       rCmd = getRCmd(inputDir(), expectedDir());
                        
                        //generate input
                        double[][] A = getRandomMatrix(rows, cols, 0, 1, 1.0, 
7);
                        writeInputMatrixWithMTD("A", A, true);  
                        
                        //run tests
-               runTest(true, false, null, -1);
-               runRScript(true);
-               
-               //compare results
-               HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("R");
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+                       
+                       //compare results
+                       HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("R");
                        HashMap<CellIndex, Double> rfile  = 
readRMatrixFromExpectedDir("R");
                        TestUtils.compareMatrices(dmlfile, rfile, 1e-14, "DML", 
"R");           
                }
-               catch(Exception ex)
-               {
+               catch(Exception ex) {
                        throw new RuntimeException(ex);
                }
        }

Reply via email to