This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 4a48a689af [SYSTEMDS-3889] New simplification rewrite for
matrix-scalar ops
4a48a689af is described below
commit 4a48a689afb56239d3df81926606f720e9501fc4
Author: aarna <[email protected]>
AuthorDate: Fri Jun 13 15:37:33 2025 +0200
[SYSTEMDS-3889] New simplification rewrite for matrix-scalar ops
e.g., a-A-b -> (a-b)-A; a+A-b -> (a-b)+A
Closes #2272.
---
.../RewriteAlgebraicSimplificationStatic.java | 35 ++++++++
...RewriteSimplifyScalarMatrixPMOperationTest.java | 98 ++++++++++++++++++++++
.../rewrite/RewriteScalarMinusMatrixMinusScalar.R | 30 +++++++
.../RewriteScalarMinusMatrixMinusScalar.dml | 28 +++++++
.../rewrite/RewriteScalarPlusMatrixMinusScalar.R | 30 +++++++
.../rewrite/RewriteScalarPlusMatrixMinusScalar.dml | 28 +++++++
6 files changed, 249 insertions(+)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 65c8805c7c..ef5670dda8 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -202,6 +202,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = simplifyNegatedSubtraction(hop, hi, i);
//e.g., -(B-A)->A-B
hi = simplifyTransposeAddition(hop, hi, i);
//e.g., t(A+s1)+s2 -> t(A)+(s1+s2) + potential constant folding
hi = simplifyNotOverComparisons(hop, hi, i);
//e.g., !(A>B) -> (A<=B)
+ hi = simplifyMatrixScalarPMOperation(hop, hi, i);
//e.g., a-A-b -> (a-b)-A; a+A-b -> (a-b)+A
//hi = removeUnecessaryPPred(hop, hi, i);
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
//process childs recursively after rewrites (to
investigate pattern newly created by rewrites)
@@ -212,6 +213,40 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hop.setVisited();
}
+    /**
+     * Reorders a chain of plus/minus operations over two scalars and one matrix
+     * such that the two scalars are combined first:
+     * {@code (a op1 A) op2 b -> (a op2 b) op1 A} with {@code op1, op2} in {+,-},
+     * e.g., {@code a-A-b -> (a-b)-A} and {@code a+A-b -> (a-b)+A}.
+     * The rewritten form needs a single matrix-scalar operation instead of two.
+     *
+     * @param parent parent hop that holds {@code hi} as input at position {@code pos}
+     * @param hi     hop to inspect, the candidate root of the pattern
+     * @param pos    position of {@code hi} in the inputs of {@code parent}
+     * @return the newly created root hop if the pattern matched, otherwise {@code hi}
+     */
+    private static Hop simplifyMatrixScalarPMOperation(Hop parent, Hop hi, int pos) {
+        //pattern root and its left input both need to be binary operations
+        if( !(hi instanceof BinaryOp) || !(hi.getInput(0) instanceof BinaryOp) )
+            return hi;
+
+        BinaryOp outer = (BinaryOp) hi;
+        BinaryOp inner = (BinaryOp) hi.getInput(0);
+        OpOp2 outerOp = outer.getOp();
+        OpOp2 innerOp = inner.getOp();
+        if( (outerOp != OpOp2.PLUS && outerOp != OpOp2.MINUS)
+            || (innerOp != OpOp2.PLUS && innerOp != OpOp2.MINUS) )
+            return hi;
+
+        //a shared inner operation has to be computed anyway, in which case
+        //the rewrite would add a matrix operation instead of removing one
+        if( inner.getParent().size() > 1 )
+            return hi;
+
+        Hop a = inner.getInput(0);
+        Hop A = inner.getInput(1);
+        Hop b = outer.getInput(1);
+        if( !a.getDataType().isScalar() || !b.getDataType().isScalar()
+            || A.getDataType() != DataType.MATRIX )
+            return hi;
+
+        //the outer operator moves to the scalars, the inner operator stays
+        //with the matrix: (a op1 A) op2 b -> (a op2 b) op1 A
+        Hop scalarCombined = HopRewriteUtils.createBinary(a, b, outerOp);
+        Hop result = HopRewriteUtils.createBinary(scalarCombined, A, innerOp);
+
+        HopRewriteUtils.replaceChildReference(parent, hi, result, pos);
+        //detach the replaced hops from a, A, and b if nothing references them
+        //anymore (outer first, as it is the only remaining parent of inner)
+        HopRewriteUtils.cleanupUnreferenced(hi, inner);
+        LOG.debug("Applied simplifyMatrixScalarPMOperation");
+        return result;
+    }
+
private static Hop simplifyTransposeAddition(Hop parent, Hop hi, int
pos) {
//pattern: t(A+s1)+s2 -> t(A)+(s1+s2), and subsequent constant
folding
if (HopRewriteUtils.isBinary(hi, OpOp2.PLUS)
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java
new file mode 100644
index 0000000000..64d3b06544
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/**
+ * Compares the DML results of {@code a-A-b} and {@code a+A-b} against the R
+ * reference scripts, once with and once without algebraic simplification.
+ */
+public class RewriteSimplifyScalarMatrixPMOperationTest extends AutomatedTestBase {
+    private static final String TEST_NAME1 = "RewriteScalarMinusMatrixMinusScalar";
+    private static final String TEST_NAME2 = "RewriteScalarPlusMatrixMinusScalar";
+    private static final String TEST_DIR = "functions/rewrite/";
+    private static final String TEST_CLASS_DIR =
+        TEST_DIR + RewriteSimplifyScalarMatrixPMOperationTest.class.getSimpleName() + "/";
+    private static final int rows = 100;
+    private static final int cols = 100;
+    private static final double eps = 1e-6;
+
+    @Override
+    public void setUp() {
+        TestUtils.clearAssertionInformation();
+        //both scripts share the same configuration layout
+        for( String name : new String[]{TEST_NAME1, TEST_NAME2} ) {
+            addTestConfiguration(name,
+                new TestConfiguration(TEST_CLASS_DIR, name, new String[]{"A", "a", "b", "R"}));
+        }
+    }
+
+    @Test
+    public void testScalarMinusMatrixMinusScalarRewriteEnabled() {
+        runRewriteTest(TEST_NAME1, true);
+    }
+
+    @Test
+    public void testScalarMinusMatrixMinusScalarRewriteDisabled() {
+        runRewriteTest(TEST_NAME1, false);
+    }
+
+    @Test
+    public void testScalarPlusMatrixMinusScalarRewriteEnabled() {
+        runRewriteTest(TEST_NAME2, true);
+    }
+
+    @Test
+    public void testScalarPlusMatrixMinusScalarRewriteDisabled() {
+        runRewriteTest(TEST_NAME2, false);
+    }
+
+    /**
+     * Runs the DML and R script of the given test on identical random inputs
+     * and compares the resulting matrices cell-wise within {@code eps}.
+     *
+     * @param testName       base name of the .dml/.R script pair
+     * @param rewriteEnabled value of the algebraic simplification flag for this run
+     */
+    private void runRewriteTest(String testName, boolean rewriteEnabled) {
+        //remember the global flag to restore it even if the run fails
+        final boolean previousFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+        try {
+            loadTestConfiguration(getTestConfiguration(testName));
+
+            final String scriptBase = SCRIPT_DIR + TEST_DIR + testName;
+            fullDMLScriptName = scriptBase + ".dml";
+            fullRScriptName = scriptBase + ".R";
+            programArgs = new String[]{
+                "-stats", "-args", input("A"), input("a"), input("b"), output("R")};
+            rCmd = getRCmd(inputDir(), expectedDir());
+
+            OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled;
+
+            //fixed seeds: one matrix and two 1x1 inputs
+            writeInputMatrixWithMTD("A", getRandomMatrix(rows, cols, -100, 100, 0.9, 3), true);
+            writeInputMatrixWithMTD("a", getRandomMatrix(1, 1, -10, 10, 1.0, 7), true);
+            writeInputMatrixWithMTD("b", getRandomMatrix(1, 1, -10, 10, 1.0, 5), true);
+
+            runTest(true, false, null, -1);
+            runRScript(true);
+
+            HashMap<MatrixValue.CellIndex, Double> actual = readDMLMatrixFromOutputDir("R");
+            HashMap<MatrixValue.CellIndex, Double> expected = readRMatrixFromExpectedDir("R");
+            TestUtils.compareMatrices(actual, expected, eps, "Stat-DML", "Stat-R");
+        }
+        finally {
+            OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = previousFlag;
+        }
+    }
+}
diff --git
a/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R
b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R
new file mode 100644
index 0000000000..bd9ab23ed2
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Reference result for the DML expression a - A - b, computed in the
+# reordered form (a-b)-A.
+# args[1]: path prefix of the inputs, args[2]: path prefix of the output
+library("Matrix")
+args <- commandArgs(TRUE)
+inPrefix <- args[1]
+outPrefix <- args[2]
+
+A <- as.matrix(readMM(paste0(inPrefix, "A.mtx")))
+a <- as.numeric(readMM(paste0(inPrefix, "a.mtx")))
+b <- as.numeric(readMM(paste0(inPrefix, "b.mtx")))
+
+# combine the two scalars first, then subtract the matrix
+R <- (a - b) - A
+
+writeMM(as(R, "CsparseMatrix"), paste0(outPrefix, "R"))
diff --git
a/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml
b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml
new file mode 100644
index 0000000000..28cdb61dec
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Test script for the rewrite a-A-b -> (a-b)-A.
+# $1: matrix A, $2 and $3: inputs holding the values a and b, $4: output R
+A = read($1);
+# The test harness writes a and b via writeInputMatrixWithMTD, i.e., as 1x1
+# matrices; cast them to scalars so that the expression below has the
+# scalar-matrix-scalar shape the rewrite matches on.
+a = as.scalar(read($2));
+b = as.scalar(read($3));
+
+# Original form: a - A - b
+R = a - A - b;
+
+write(R, $4);
+
diff --git
a/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R
b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R
new file mode 100644
index 0000000000..ec2764bb28
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Reference result for the DML expression a + A - b, computed in the
+# reordered form (a-b)+A.
+# args[1]: path prefix of the inputs, args[2]: path prefix of the output
+library("Matrix")
+args <- commandArgs(TRUE)
+inPrefix <- args[1]
+outPrefix <- args[2]
+
+A <- as.matrix(readMM(paste0(inPrefix, "A.mtx")))
+a <- as.numeric(readMM(paste0(inPrefix, "a.mtx")))
+b <- as.numeric(readMM(paste0(inPrefix, "b.mtx")))
+
+# combine the two scalars first, then add the matrix
+R <- (a - b) + A
+
+writeMM(as(R, "CsparseMatrix"), paste0(outPrefix, "R"))
diff --git
a/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml
b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml
new file mode 100644
index 0000000000..5ba04566ef
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Test script for the rewrite a+A-b -> (a-b)+A.
+# $1: matrix A, $2 and $3: inputs holding the values a and b, $4: output R
+A = read($1);
+# a and b are read from file and cast to scalars
+a = as.scalar(read($2));
+b = as.scalar(read($3));
+
+# Original form: a + A - b
+# (keep the expression in exactly this shape; it is the pattern under test)
+R = a + A - b;
+
+write(R, $4);