This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 147519e495 [SYSTEMDS-3664] New simplification rewrite rev(seq())
147519e495 is described below
commit 147519e49558493e58e14f359befd56fcf74ffda
Author: aarna <[email protected]>
AuthorDate: Sun Mar 16 18:08:49 2025 +0100
[SYSTEMDS-3664] New simplification rewrite rev(seq())
This patch introduces a new simplification rewrite for reversing a
sequence with literal bounds and step, e.g., rev(seq(1,n)) --> seq(n,1)
or rev(seq(1,9,2)) --> seq(9,1,-2).
Closes #2242.
---
.../RewriteAlgebraicSimplificationStatic.java | 54 ++++++++++
.../RewriteSimplifyReverseSequenceStepTest.java | 109 +++++++++++++++++++++
.../rewrite/RewriteSimplifyReverseSequenceStep.dml | 35 +++++++
3 files changed, 198 insertions(+)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 5d867bf0ff..c46bc62400 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -156,6 +156,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = simplifyConstantConjunction(hop, hi, i);
//e.g., a & !a -> FALSE
hi = simplifyReverseOperation(hop, hi, i);
//e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
hi = simplifyReverseSequence(hop, hi, i);
//e.g., rev(seq(1,n)) -> seq(n,1)
+ hi = simplifyReverseSequenceStep(hop, hi, i);
//e.g., rev(seq(1,9,2)) -> seq(9,1,-2)
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
hi = simplifyMultiBinaryToBinaryOperation(hi);
//e.g., 1-X*Y -> X 1-* Y
hi = simplifyDistributiveBinaryOperation(hop, hi,
i);//e.g., (X-Y*X) -> (1-Y)*X
@@ -824,6 +825,59 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
return hi;
}
+
+ private static Hop simplifyReverseSequenceStep(Hop parent, Hop hi, int
pos) {
+ if (HopRewriteUtils.isReorg(hi, ReOrgOp.REV)
+ && hi.getInput(0) instanceof DataGenOp
+ && ((DataGenOp) hi.getInput(0)).getOp() ==
OpOpDG.SEQ
+ && hi.getInput(0).getParent().size() == 1) //
only one consumer
+ {
+ DataGenOp seq = (DataGenOp) hi.getInput(0);
+ Hop from =
seq.getInput().get(seq.getParamIndex(Statement.SEQ_FROM));
+ Hop to =
seq.getInput().get(seq.getParamIndex(Statement.SEQ_TO));
+ Hop incr =
seq.getInput().get(seq.getParamIndex(Statement.SEQ_INCR));
+
+ if (from instanceof LiteralOp && to instanceof
LiteralOp && incr instanceof LiteralOp) {
+ double fromVal = ((LiteralOp)
from).getDoubleValue();
+ double toVal = ((LiteralOp)
to).getDoubleValue();
+ double incrVal = ((LiteralOp)
incr).getDoubleValue();
+
+ // Skip if increment is zero (invalid sequence)
+ if (Math.abs(incrVal) < 1e-10)
+ return hi;
+
+ boolean isValidDirection = false;
+
+ // Checking direction compatibility
+ if ((incrVal > 0 && fromVal <= toVal) ||
(incrVal < 0 && fromVal >= toVal)) {
+ isValidDirection = true;
+ }
+
+ if (isValidDirection) {
+ // Calculate the number of elements and
the last element
+ int numValues =
(int)Math.floor(Math.abs((toVal - fromVal) / incrVal)) + 1;
+ double lastVal = fromVal + (numValues -
1) * incrVal;
+
+ // Create a new sequence based on
actual last value
+ LiteralOp newFrom = new
LiteralOp(lastVal);
+ LiteralOp newTo = new
LiteralOp(fromVal);
+ LiteralOp newIncr = new
LiteralOp(-incrVal);
+
+ // Replace the parameters
+
seq.getInput().set(seq.getParamIndex(Statement.SEQ_FROM), newFrom);
+
seq.getInput().set(seq.getParamIndex(Statement.SEQ_TO), newTo);
+
seq.getInput().set(seq.getParamIndex(Statement.SEQ_INCR), newIncr);
+
+ // Replace the old sequence with the
new one
+
HopRewriteUtils.replaceChildReference(parent, hi, seq, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi,
seq);
+ hi = seq;
+ LOG.debug("Applied
simplifyReverseSequenceStep (line " + hi.getBeginLine() + ").");
+ }
+ }
+ }
+ return hi;
+ }
private static Hop simplifyMultiBinaryToBinaryOperation( Hop hi )
{
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java
new file mode 100644
index 0000000000..7d176c7aa8
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RewriteSimplifyReverseSequenceStepTest extends AutomatedTestBase {
+ private static final String TEST_NAME1 =
"RewriteSimplifyReverseSequenceStep";
+
+ private static final String TEST_DIR = "functions/rewrite/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
RewriteSimplifyReverseSequenceStepTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"}));
+ }
+
+ @Test
+ public void testRewriteReverseSeqStep() {
+ testRewriteReverseSeq(TEST_NAME1, true);
+ }
+
+ @Test
+ public void testNoRewriteReverseSeqStep() {
+ testRewriteReverseSeq(TEST_NAME1, false);
+ }
+
+ private void testRewriteReverseSeq(String testname, boolean rewrites) {
+ boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+ int rows = 10;
+
+ try {
+ TestConfiguration config =
getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[]{"-stats", "-args",
String.valueOf(rows), output("Scalar")};
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
rewrites;
+
+ runTest(true, false, null, -1);
+
+ // Calculate expected sums for each sequence
+ double sum1 = calculateSum(0, rows-1, 1); // A1 =
rev(seq(0, rows-1, 1))
+ double sum2 = calculateSum(0, rows, 2); // A2 =
rev(seq(0, rows, 2))
+ double sum3 = calculateSum(2, rows, 2); // A3 =
rev(seq(2, rows, 2))
+ double sum4 = calculateSum(0, 100, 5); // A4 =
rev(seq(0, 100, 5))
+ double sum5 = calculateSum(15, 5, -0.5);
// A5 = rev(seq(15, 5, -0.5))
+
+ double expected = sum1 + sum2 + sum3 + sum4 + sum5;
+
+ double ret =
readDMLScalarFromOutputDir("Scalar").get(new MatrixValue.CellIndex(1,
1)).doubleValue();
+
+ Assert.assertEquals("Incorrect sum computed", expected,
ret, 1e-10);
+
+ if (rewrites) {
+ // With bidirectional rewrite, REV operations
should be removed
+ Assert.assertFalse("Rewrite should have removed
REV operation!",
+
heavyHittersContainsString("rev"));
+ }
+ }
+ finally {
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+ }
+ }
+
+ // Helper method to calculate sum of a sequence
+ private double calculateSum(double from, double to, double incr) {
+ double sum = 0;
+ int n = 0;
+
+ if ((incr > 0 && from <= to) || (incr < 0 && from >= to)) {
+ // Calculate number of elements in the sequence
+ n = (int)Math.floor(Math.abs((to - from) / incr)) + 1;
+
+ // Calculate the last element in the sequence
+ double last = from + (n - 1) * incr;
+
+ // Use arithmetic sequence sum formula: n * (first +
last) / 2
+ sum = n * (from + last) / 2;
+ }
+
+ return sum;
+ }
+}
diff --git
a/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml
b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml
new file mode 100644
index 0000000000..e8f3314c26
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+rows = as.integer($1)
+
+# Reversed test sequences with positive (A1-A4) and negative (A5) increments; $1 = number of rows used as a sequence bound, $2 = output path of the scalar result
+A1 = rev(seq(0, rows-1, 1)) # Should become seq(rows-1, 0, -1)
+A2 = rev(seq(0, rows, 2)) # Should become seq(rows, 0, -2) (for even rows, e.g., rows=10 in the test)
+A3 = rev(seq(2, rows, 2)) # Should become seq(lastVal, 2, -2), where lastVal is the last value in the sequence
+A4 = rev(seq(0, 100, 5)) # Should become seq(100, 0, -5)
+A5 = rev(seq(15, 5, -0.5)) # Should become seq(5, 15, 0.5)
+
+# Sum all sequences (NOTE(review): sum() is order-insensitive, so the result does not reveal whether rev was removed; that is only observable via -stats. Confirm sum(rev(X)) is not already simplified to sum(X) by another rewrite)
+R = sum(A1) + sum(A2) + sum(A3) + sum(A4) + sum(A5)
+
+# Output
+write(R, $2)
\ No newline at end of file