This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new ceb50a2d21 [SYSTEMDS-3665] New rewrite for mmult-add expressions
ceb50a2d21 is described below
commit ceb50a2d2175390267796e0cfd8620ca251c1e3d
Author: ReneEnjilian <[email protected]>
AuthorDate: Sat Jan 20 01:00:12 2024 +0100
[SYSTEMDS-3665] New rewrite for mmult-add expressions
A%*%B + A%*%C -> A%*%(B+C) iff A, B, and C are dense and the target
expression reduces the number of floating-point operations.
Closes #1986.
---
.../RewriteAlgebraicSimplificationDynamic.java | 68 +++++++++----
.../rewrite/RewriteDistributiveMatrixMultTest.java | 107 +++++++++++++++++++++
.../rewrite/RewriteDistributiveMatrixMult.R | 41 ++++++++
.../rewrite/RewriteDistributiveMatrixMult.dml | 32 ++++++
4 files changed, 231 insertions(+), 17 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index e181c60a78..3e1c498f01 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -176,6 +176,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
hi = simplifyScalarMatrixMult(hop, hi, i);
//e.g., X%*%y -> X*as.scalar(y), if y is a 1-1 matrix
hi = simplifyMatrixMultDiag(hop, hi, i);
//e.g., diag(X)%*%Y -> X*Y, if ncol(Y)==1 / -> Y*X if ncol(Y)>1
hi = simplifyDiagMatrixMult(hop, hi, i);
//e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector
+ hi = simplifyDistributiveMatrixMult(hop, hi, i);
//e.g., (A%*%B)+(A%*%C) -> A%*%(B+C)
hi = simplifySumDiagToTrace(hi);
//e.g., sum(diag(X)) -> trace(X); if col vector
hi = simplifyLowerTriExtraction(hop, hi, i);
//e.g., X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri
hi = simplifyConstantCumsum(hop, hi, i);
//e.g., cumsum(matrix(1/n,n,1)) -> seq(1/n, 1, 1/n)
@@ -1137,46 +1138,79 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
return hi;
}
-
- private static Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos)
- {
- if( hi instanceof ReorgOp &&
((ReorgOp)hi).getOp()==ReOrgOp.DIAG && hi.getDim2()==1 ) //diagM2V
+
+ private static Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos) {
+ if(hi instanceof ReorgOp && ((ReorgOp) hi).getOp() ==
ReOrgOp.DIAG && hi.getDim2() == 1) //diagM2V
{
Hop hi2 = hi.getInput().get(0);
- if( HopRewriteUtils.isMatrixMultiply(hi2) ) //X%*%Y
+ if(HopRewriteUtils.isMatrixMultiply(hi2)) //X%*%Y
{
Hop left = hi2.getInput().get(0);
Hop right = hi2.getInput().get(1);
-
+
//create new operators (incl refresh size
inside for transpose)
ReorgOp trans =
HopRewriteUtils.createTranspose(right);
BinaryOp mult =
HopRewriteUtils.createBinary(left, trans, OpOp2.MULT);
AggUnaryOp rowSum =
HopRewriteUtils.createAggUnaryOp(mult, AggOp.SUM, Direction.Row);
-
+
//rehang new subdag under parent node
HopRewriteUtils.replaceChildReference(parent,
hi, rowSum, pos);
HopRewriteUtils.cleanupUnreferenced(hi, hi2);
-
+
hi = rowSum;
LOG.debug("Applied simplifyDiagMatrixMult");
- }
+ }
}
-
+
return hi;
}
-
- private static Hop simplifySumDiagToTrace(Hop hi)
- {
- if( hi instanceof AggUnaryOp )
+
+ private static Hop simplifyDistributiveMatrixMult(Hop parent, Hop hi,
int pos) {
+ // A%*%B + A%*%C -> A%*%(B+C)
+ if(HopRewriteUtils.isBinary(hi, OpOp2.PLUS)
+ && HopRewriteUtils.isMatrixMultiply(hi.getInput(0))
+ && HopRewriteUtils.isMatrixMultiply(hi.getInput(1))
+ && hi.getInput(0).getParent().size() == 1 //single
consumer
+ && hi.getInput(1).getParent().size() == 1 //single
consumer
+ && hi.getInput(0).getInput(0) ==
hi.getInput(1).getInput(0)) //common A
{
+ Hop A = hi.getInput(0).getInput(0);
+ Hop B = hi.getInput(0).getInput(1);
+ Hop C = hi.getInput(1).getInput(1);
+ boolean dense = HopRewriteUtils.isDense(A)
+ && HopRewriteUtils.isDense(B) &&
HopRewriteUtils.isDense(C);
+ //compute floating point and mem bandwidth requirements
and
+ //account for special cases where C might be a column
vector
+ long m = A.getDim1(), n = A.getDim2(), l = B.getDim2(),
o = C.getDim2();
+ long costOriginal = m * n * l + m * n * o + m * l //FLOP
+ + m*n + n*l + n*o + m*l + m*o +
m*l; //I/O ABC+intermediates
+ long costRewrite = n * l + m * n * l //FLOP
+ + m*n + n*l + n*o + n*l + m*l;
//I/O ABC+intermediates
+ //Check that rewrite reduces FLOPs
+ if(dense && costRewrite < costOriginal) {
+ Hop BplusC = HopRewriteUtils.createBinary(B, C,
OpOp2.PLUS);
+ Hop newHop =
HopRewriteUtils.createMatrixMultiply(A, BplusC);
+ if(parent != null) {
+
HopRewriteUtils.replaceChildReference(parent, hi, newHop, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi);
+ hi = newHop;
+ LOG.debug("Applied
simplifyDistributiveMatrixMult (line " + hi.getBeginLine() + ")");
+ }
+ }
+ }
+ return hi;
+ }
+
+ private static Hop simplifySumDiagToTrace(Hop hi) {
+ if(hi instanceof AggUnaryOp) {
AggUnaryOp au = (AggUnaryOp) hi;
- if( au.getOp()==AggOp.SUM &&
au.getDirection()==Direction.RowCol ) //sum
+ if(au.getOp() == AggOp.SUM && au.getDirection() ==
Direction.RowCol) //sum
{
Hop hi2 = au.getInput().get(0);
- if( hi2 instanceof ReorgOp &&
((ReorgOp)hi2).getOp()==ReOrgOp.DIAG && hi2.getDim2()==1 ) //diagM2V
+ if(hi2 instanceof ReorgOp && ((ReorgOp)
hi2).getOp() == ReOrgOp.DIAG && hi2.getDim2() == 1) //diagM2V
{
Hop hi3 = hi2.getInput().get(0);
-
+
//remove diag operator
HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0);
HopRewriteUtils.cleanupUnreferenced(hi2);
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
new file mode 100644
index 0000000000..7f40a2bef3
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+
+public class RewriteDistributiveMatrixMultTest extends AutomatedTestBase {
+ private static final String TEST_NAME1 =
"RewriteDistributiveMatrixMult";
+ private static final String TEST_DIR = "functions/rewrite/";
+ private static final String TEST_CLASS_DIR =
+ TEST_DIR +
RewriteSimplifyRowColSumMVMultTest.class.getSimpleName() + "/";
+
+ private static final int rows = 500;
+ private static final int cols = 500;
+ private static final double eps = Math.pow(10, -10);
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}));
+
+ }
+
+ @Test
+ public void testDistributiveMatrixMultNoRewrite() {
+ testRewriteDistributiveMatrixMult(TEST_NAME1, false);
+ }
+
+ @Test
+ public void testDistributiveMatrixMultRewrite() {
+ testRewriteDistributiveMatrixMult(TEST_NAME1, true);
+ }
+
+ private void testRewriteDistributiveMatrixMult(String testname, boolean
rewrites) {
+ boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+ try {
+ TestConfiguration config =
getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[] {"-stats", "-args",
input("A"), input("B"), input("C"), output("R")};
+
+ fullRScriptName = HOME + testname + ".R";
+ rCmd = getRCmd(inputDir(), expectedDir());
+
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
rewrites;
+ //create dense matrices so that rewrites are possible
+ double[][] A = getRandomMatrix(rows, cols, -1, 1,
0.70d, 7);
+ double[][] B = getRandomMatrix(rows, cols, -1, 1,
0.70d, 6);
+ double[][] C = getRandomMatrix(rows, cols, -1, 1,
0.70d, 3);
+ writeInputMatrixWithMTD("A", A, 174522, true);
+ writeInputMatrixWithMTD("B", B, 174935, true);
+ writeInputMatrixWithMTD("C", C, 174848, true);
+
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare matrices
+ HashMap<CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("R");
+ HashMap<CellIndex, Double> rfile =
readRMatrixFromExpectedDir("R");
+ TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
+
+ //check matrix mult existence
+ String ba = "ba+*";
+ long numMatMul = Statistics.getCPHeavyHitterCount(ba);
+
+ if(rewrites == true) {
+ Assert.assertTrue(numMatMul == 1);
+ }
+ else {
+ Assert.assertTrue(numMatMul == 2);
+ }
+
+ }
+ finally {
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+ }
+
+ }
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteDistributiveMatrixMult.R
b/src/test/scripts/functions/rewrite/RewriteDistributiveMatrixMult.R
new file mode 100644
index 0000000000..7c7a623fbe
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteDistributiveMatrixMult.R
@@ -0,0 +1,41 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read command line arguments
+args <- commandArgs(TRUE)
+
+# Set options for numeric precision
+options(digits=22)
+
+# Load required libraries
+library("Matrix")
+library("matrixStats")
+
+# Read matrices A, B, and C from Matrix Market format files
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+B = as.matrix(readMM(paste(args[1], "B.mtx", sep="")))
+C = as.matrix(readMM(paste(args[1], "C.mtx", sep="")))
+
+# Perform the matrix operation
+R = (A %*% B) + (A %*% C)
+
+# Write the result matrix R in Matrix Market format
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
diff --git
a/src/test/scripts/functions/rewrite/RewriteDistributiveMatrixMult.dml
b/src/test/scripts/functions/rewrite/RewriteDistributiveMatrixMult.dml
new file mode 100644
index 0000000000..fc3a3a8cf2
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteDistributiveMatrixMult.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Load matrices A, B, and C
+A = read($1)
+B = read($2)
+C = read($3)
+
+# Perform the operation
+R = (A %*% B) + (A %*% C)
+
+# Write the result matrix R
+write(R, $4)