This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new cef8a6f271 [SYSTEMDS-3798,3807] Improved loop vectorization rewrite,
code coverage
cef8a6f271 is described below
commit cef8a6f271ea2b5e9f815a657c4dc13b5e780290
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Dec 13 15:07:31 2024 +0100
[SYSTEMDS-3798,3807] Improved loop vectorization rewrite, code coverage
---
.github/workflows/javaTests.yml | 3 +-
.../hops/rewrite/RewriteForLoopVectorization.java | 77 ++++++++++++++++++++++
.../test/functions/rewrite/RewriteIfElseTest.java | 1 -
.../rewrite/RewriteLoopVectorization.java | 2 -
4 files changed, 79 insertions(+), 4 deletions(-)
diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml
index c2cab87c22..b4341c544c 100644
--- a/.github/workflows/javaTests.yml
+++ b/.github/workflows/javaTests.yml
@@ -86,7 +86,8 @@ jobs:
"**.functions.transform.**","**.functions.unique.**",
"**.functions.unary.matrix.**,**.functions.linearization.**,**.functions.jmlc.**"
]
- java: [11]
+ java: ['11']
+ javadist: ['adopt']
name: ${{ matrix.tests }}
steps:
- name: Checkout Repository
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
index 1d2223dcf9..0c09c2efb4 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
@@ -88,6 +88,9 @@ public class RewriteForLoopVectorization extends
StatementBlockRewriteRule
//e.g., for(i in a:b){s = s +
as.scalar(X[i,2])} -> s = sum(X[a:b,2])
sb = vectorizeScalarAggregate(sb, csb,
from, to, incr, iterVar);
+ //e.g., for(i in a:b){s = s + X[i,2]}
-> s = sum(X[a:b,2])
+ sb = vectorizeScalarAggregate2(sb, csb,
from, to, incr, iterVar);
+
//e.g., for(i in a:b){X[i,2] = Y[i,1] +
Z[i,3]} -> X[a:b,2] = Y[a:b,1] + Z[a:b,3];
sb = vectorizeElementwiseBinary(sb,
csb, from, to, incr, iterVar);
@@ -205,6 +208,80 @@ public class RewriteForLoopVectorization extends
StatementBlockRewriteRule
return ret;
}
+	/**
+	 * Vectorizes a scalar aggregation over a single-cell indexing expression that is
+	 * consumed directly (without an as.scalar cast), e.g.,
+	 * for(i in a:b){s = s + X[i,2]} -> s = s + sum(X[a:b,2]) or
+	 * for(i in a:b){s = s + X[2,i]} -> s = s + sum(X[2,a:b]).
+	 *
+	 * The rewrite is only applied if (1) the loop increment is the literal 1,
+	 * (2) the loop variable is read directly as the (single) row or column index,
+	 * and (3) neither the indexed matrix nor the other index refers to the loop
+	 * variable. Otherwise removing the loop would change the result (strided
+	 * iteration, constant indexes, or dangling references to the loop variable).
+	 *
+	 * @param sb the for statement block, returned unchanged if not applicable
+	 * @param csb the single child statement block of the loop body
+	 * @param from loop predicate lower bound
+	 * @param to loop predicate upper bound
+	 * @param increment loop predicate increment
+	 * @param itervar name of the loop variable
+	 * @return csb (rewritten in place) if the rewrite applied, otherwise sb
+	 */
+	private static StatementBlock vectorizeScalarAggregate2( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar )
+	{
+		StatementBlock ret = sb;
+
+		//check supported increment values: only a unit stride allows the
+		//contiguous range from:to to replace the individual loop iterations
+		if( !(increment instanceof org.apache.sysds.hops.LiteralOp
+			&& ((org.apache.sysds.hops.LiteralOp)increment).getDoubleValue()==1.0) ) {
+			return ret;
+		}
+
+		//check for applicability
+		boolean leftScalar = false;
+		boolean rightScalar = false;
+		boolean rowIx = false; //row or col
+
+		if( csb.getHops()!=null && csb.getHops().size()==1 ) {
+			Hop root = csb.getHops().get(0);
+
+			if( root.getDataType()==DataType.SCALAR && root.getInput(0) instanceof BinaryOp ) {
+				BinaryOp bop = (BinaryOp) root.getInput(0);
+				Hop left = bop.getInput(0);
+				Hop right = bop.getInput(1);
+				boolean validOp = HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
+
+				//check for left scalar plus
+				if( validOp && left instanceof DataOp && left.getDataType() == DataType.SCALAR
+					&& root.getName().equals(left.getName())
+					&& right instanceof IndexingOp && right.isScalar() )
+				{
+					//determine whether the row or column index is driven by the loop variable
+					int pos = getIterVarIndexPos((IndexingOp) right, itervar);
+					leftScalar = (pos > 0);
+					rowIx = (pos == 1);
+				}
+				//check for right scalar plus
+				else if( validOp && right instanceof DataOp && right.getDataType() == DataType.SCALAR
+					&& root.getName().equals(right.getName())
+					&& left instanceof IndexingOp && left.isScalar() )
+				{
+					//determine whether the row or column index is driven by the loop variable
+					int pos = getIterVarIndexPos((IndexingOp) left, itervar);
+					rightScalar = (pos > 0);
+					rowIx = (pos == 1);
+				}
+			}
+		}
+
+		//apply rewrite if possible
+		if( leftScalar || rightScalar ) {
+			Hop root = csb.getHops().get(0);
+			BinaryOp bop = (BinaryOp) root.getInput(0);
+			Hop ix = bop.getInput().get( leftScalar?1:0 );
+			int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
+			AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];
+
+			//replace the single-cell indexing input with a full aggregate over the indexed range
+			AggUnaryOp newSum = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Direction.RowCol);
+			HopRewriteUtils.removeChildReference(bop, ix);
+			HopRewriteUtils.addChildReference(bop, newSum, leftScalar?1:0 );
+
+			//modify indexing expression according to loop predicate from-to
+			//(inputs 1/2 are the row bounds, inputs 3/4 the column bounds)
+			//NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
+			int index1 = rowIx ? 1 : 3;
+			int index2 = rowIx ? 2 : 4;
+			HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index1), from, index1);
+			HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index2), to, index2);
+
+			//update indexing size information
+			if( rowIx )
+				((IndexingOp)ix).setRowLowerEqualsUpper(false);
+			else
+				((IndexingOp)ix).setColLowerEqualsUpper(false);
+			ix.setDataType(DataType.MATRIX);
+			ix.refreshSizeInformation();
+			Hop.resetVisitStatus(csb.getHops(), true);
+
+			ret = csb;
+			LOG.debug("Applied vectorizeScalarSumForLoop2.");
+		}
+
+		return ret;
+	}
+
+	/**
+	 * Determines which index of a single-cell indexing operation is the loop variable.
+	 * Inputs 1/2 of the indexing op are the row bounds and inputs 3/4 the column bounds.
+	 *
+	 * @param ix single-cell indexing operation
+	 * @param itervar name of the loop variable
+	 * @return 1 if the row index is the loop variable, 3 if the column index is the
+	 *   loop variable, and -1 if neither applies or the indexed matrix or the other
+	 *   index also refers to the loop variable
+	 */
+	private static int getIterVarIndexPos( IndexingOp ix, String itervar ) {
+		//the indexed matrix must be independent of the loop variable
+		if( referencesIterVar(ix.getInput(0), itervar) )
+			return -1;
+		if( ix.isRowLowerEqualsUpper() && isIterVarRead(ix.getInput(1), itervar)
+			&& !referencesIterVar(ix.getInput(3), itervar)
+			&& !referencesIterVar(ix.getInput(4), itervar) )
+			return 1;
+		if( ix.isColLowerEqualsUpper() && isIterVarRead(ix.getInput(3), itervar)
+			&& !referencesIterVar(ix.getInput(1), itervar)
+			&& !referencesIterVar(ix.getInput(2), itervar) )
+			return 3;
+		return -1;
+	}
+
+	/** Returns true if the given hop is a direct read of the loop variable. */
+	private static boolean isIterVarRead( Hop h, String itervar ) {
+		return h instanceof DataOp && itervar.equals(h.getName());
+	}
+
+	/** Returns true if the given hop or any of its transitive inputs reads the loop variable. */
+	private static boolean referencesIterVar( Hop h, String itervar ) {
+		if( h == null )
+			return false;
+		if( isIterVarRead(h, itervar) )
+			return true;
+		for( Hop in : h.getInput() )
+			if( referencesIterVar(in, itervar) )
+				return true;
+		return false;
+	}
+
private static StatementBlock vectorizeElementwiseBinary(
StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String
itervar )
{
StatementBlock ret = sb;
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIfElseTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIfElseTest.java
index 087fc49d98..1e7abfb03b 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIfElseTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIfElseTest.java
@@ -23,7 +23,6 @@ import java.util.HashMap;
import org.junit.Assert;
import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.common.Types.ExecType;
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
index 927b0fd666..d9358fef30 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
@@ -22,7 +22,6 @@ package org.apache.sysds.test.functions.rewrite;
import java.util.HashMap;
import org.junit.Assert;
-import org.junit.Ignore;
import org.junit.Test;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
@@ -58,7 +57,6 @@ public class RewriteLoopVectorization extends
AutomatedTestBase
}
@Test
- @Ignore //FIXME: extend loop vectorization rewrite
public void testLoopVectorizationSumRewrite() {
testRewriteLoopVectorizationSum( TEST_NAME1, true );
}