This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new f3b638a48f [SYSTEMDS-3812] Improved rewrites pushdown-sum and rm-reorg
f3b638a48f is described below
commit f3b638a48f843dca3ac9b963100e2146c88e7751
Author: aarna <[email protected]>
AuthorDate: Fri Jan 10 15:46:57 2025 +0100
[SYSTEMDS-3812] Improved rewrites pushdown-sum and rm-reorg
Closes #2176.
---
.../RewriteAlgebraicSimplificationDynamic.java | 106 ++++++++++++++-------
.../functions/aggregate/PushdownSumBinaryTest.java | 36 ++-----
.../rewrite/RewritePushdownSumBinaryMult.java | 5 -
.../rewrite/RewritePushdownSumOnBinaryTest.java | 84 +++++++++++-----
.../rewrite/RewritePushdownSumOnBinary.dml | 27 ++++--
5 files changed, 160 insertions(+), 98 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index d73f8489b6..ddb2252f51 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -381,30 +381,28 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
return hi;
}
-
- private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi,
int pos)
- {
- if( hi instanceof ReorgOp )
- {
+
+ private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi,
int pos) {
+ if( hi instanceof ReorgOp ) {
ReorgOp rop = (ReorgOp) hi;
- Hop input = hi.getInput(0);
+ Hop input = hi.getInput(0);
boolean apply = false;
-
- //equal dims of reshape input and output -> no need for
reshape because
+
+ //equal dims of reshape input and output -> no need for
reshape because
//byrow always refers to both input/output and hence
gives the same result
apply |= (rop.getOp()==ReOrgOp.RESHAPE &&
HopRewriteUtils.isEqualSize(hi, input));
-
- //1x1 dimensions of transpose/reshape -> no need for
reorg
- apply |= ((rop.getOp()==ReOrgOp.TRANS ||
rop.getOp()==ReOrgOp.RESHAPE)
- && rop.getDim1()==1 &&
rop.getDim2()==1);
-
+
+ //1x1 dimensions of transpose/reshape/roll -> no need
for reorg
+ apply |= ((rop.getOp()==ReOrgOp.TRANS ||
rop.getOp()==ReOrgOp.RESHAPE
+ || rop.getOp()==ReOrgOp.ROLL) &&
rop.getDim1()==1 && rop.getDim2()==1);
+
if( apply ) {
HopRewriteUtils.replaceChildReference(parent,
hi, input, pos);
hi = input;
LOG.debug("Applied removeUnnecessaryReorg.");
}
}
-
+
return hi;
}
@@ -1356,44 +1354,78 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
* @param pos position
* @return high-level operator
*/
- private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int
pos)
+ private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int
pos)
{
//all patterns headed by full sum over binary operation
if( hi instanceof AggUnaryOp //full sum root over binaryop
- && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
- && ((AggUnaryOp)hi).getOp() == AggOp.SUM
- && hi.getInput(0) instanceof BinaryOp
- && hi.getInput(0).getParent().size()==1 ) //single
parent
+ &&
((AggUnaryOp)hi).getDirection()==Direction.RowCol
+ && ((AggUnaryOp)hi).getOp() == AggOp.SUM
+ && hi.getInput(0) instanceof BinaryOp
+ && hi.getInput(0).getParent().size()==1 )
//single parent
{
BinaryOp bop = (BinaryOp) hi.getInput(0);
Hop left = bop.getInput(0);
Hop right = bop.getInput(1);
-
- if( HopRewriteUtils.isEqualSize(left, right) //dims(A)
== dims(B)
- && left.getDataType() == DataType.MATRIX
- && right.getDataType() == DataType.MATRIX )
+
+ if( left.getDataType() == DataType.MATRIX
+ && right.getDataType() ==
DataType.MATRIX )
{
OpOp2 applyOp = ( bop.getOp() == OpOp2.PLUS
//pattern a: sum(A+B)->sum(A)+sum(B)
|| bop.getOp() == OpOp2.MINUS )
//pattern b: sum(A-B)->sum(A)-sum(B)
? bop.getOp() : null;
-
+
if( applyOp != null ) {
- //create new subdag sum(A) bop sum(B)
- AggUnaryOp sum1 =
HopRewriteUtils.createSum(left);
- AggUnaryOp sum2 =
HopRewriteUtils.createSum(right);
- BinaryOp newBin =
HopRewriteUtils.createBinary(sum1, sum2, applyOp);
-
- //rewire new subdag
-
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
- HopRewriteUtils.cleanupUnreferenced(hi,
bop);
-
- hi = newBin;
-
- LOG.debug("Applied
pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+").");
+ if (HopRewriteUtils.isEqualSize(left,
right)) {
+ //create new subdag sum(A) bop
sum(B) for equal-sized matrices
+ AggUnaryOp sum1 =
HopRewriteUtils.createSum(left);
+ AggUnaryOp sum2 =
HopRewriteUtils.createSum(right);
+ BinaryOp newBin =
HopRewriteUtils.createBinary(sum1, sum2, applyOp);
+ //rewire new subdag
+
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
+
HopRewriteUtils.cleanupUnreferenced(hi, bop);
+
+ hi = newBin;
+
+ LOG.debug("Applied
pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+").");
+ }
+ // Check if right operand is a vector
(has dimension of 1 in either rows or columns)
+ else if (right.getDim1() == 1 ||
right.getDim2() == 1) {
+ AggUnaryOp sum1 =
HopRewriteUtils.createSum(left);
+ AggUnaryOp sum2 =
HopRewriteUtils.createSum(right);
+
+ // Row vector case (1 x n)
+ if (right.getDim1() == 1) {
+ // Create nrow(A)
operation using dimensions
+ LiteralOp nRows = new
LiteralOp(left.getDim1());
+ BinaryOp scaledSum =
HopRewriteUtils.createBinary(nRows, sum2, OpOp2.MULT);
+ BinaryOp newBin =
HopRewriteUtils.createBinary(sum1, scaledSum, applyOp);
+ //rewire new subdag
+
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
+
HopRewriteUtils.cleanupUnreferenced(hi, bop);
+
+ hi = newBin;
+
+ LOG.debug("Applied
pushdownSumOnAdditiveBinary with row vector (line "+hi.getBeginLine()+").");
+ }
+ // Column vector case (n x 1)
+ else if (right.getDim2() == 1) {
+ // Create ncol(A)
operation using dimensions
+ LiteralOp nCols = new
LiteralOp(left.getDim2());
+ BinaryOp scaledSum =
HopRewriteUtils.createBinary(nCols, sum2, OpOp2.MULT);
+ BinaryOp newBin =
HopRewriteUtils.createBinary(sum1, scaledSum, applyOp);
+ //rewire new subdag
+
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
+
HopRewriteUtils.cleanupUnreferenced(hi, bop);
+
+ hi = newBin;
+
+ LOG.debug("Applied
pushdownSumOnAdditiveBinary with column vector (line "+hi.getBeginLine()+").");
+ }
+ }
}
}
}
-
+
return hi;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/aggregate/PushdownSumBinaryTest.java
b/src/test/java/org/apache/sysds/test/functions/aggregate/PushdownSumBinaryTest.java
index d4ac5fc6dc..3e7286c274 100644
---
a/src/test/java/org/apache/sysds/test/functions/aggregate/PushdownSumBinaryTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/aggregate/PushdownSumBinaryTest.java
@@ -25,10 +25,8 @@ import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.ExecType;
-import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
@@ -89,39 +87,24 @@ public class PushdownSumBinaryTest extends AutomatedTestBase
}
@Test
- public void testPushDownSumPlusNoRewriteSP() {
+ public void testPushDownSumPlusBroadcastSP() {
runPushdownSumOnBinaryTest(TEST_NAME1, false, ExecType.SPARK);
}
@Test
- public void testPushDownSumMinusNoRewriteSP() {
+ public void testPushDownSumMinusBroadcastSP() {
runPushdownSumOnBinaryTest(TEST_NAME2, false, ExecType.SPARK);
}
-
- /**
- *
- * @param testname
- * @param type
- * @param sparse
- * @param instType
- */
+
private void runPushdownSumOnBinaryTest( String testname, boolean
equiDims, ExecType instType)
{
//rtplatform for MR
- ExecMode platformOld = rtplatform;
- switch( instType ){
- case SPARK: rtplatform = ExecMode.SPARK; break;
- default: rtplatform = ExecMode.HYBRID; break;
- }
-
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == ExecMode.SPARK )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ ExecMode platformOld = setExecMode(instType);
try
{
//determine script and function name
- String TEST_NAME = testname;
+ String TEST_NAME = testname;
String TEST_CACHE_DIR = TEST_CACHE_ENABLED ? TEST_NAME
+ "_" + String.valueOf(equiDims) + "/" : "";
TestConfiguration config =
getTestConfiguration(TEST_NAME);
@@ -150,13 +133,10 @@ public class PushdownSumBinaryTest extends
AutomatedTestBase
TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
String lopcode = TEST_NAME.equals(TEST_NAME1) ? "+" :
"-";
- String opcode = equiDims ? lopcode :
Instruction.SP_INST_PREFIX+"map"+lopcode;
- Assert.assertTrue("Non-applied rewrite",
Statistics.getCPHeavyHitterOpCodes().contains(opcode));
+ Assert.assertTrue("Non-applied rewrite",
Statistics.getCPHeavyHitterOpCodes().contains(lopcode));
}
- finally
- {
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ finally {
+ resetExecMode(platformOld);
}
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumBinaryMult.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumBinaryMult.java
index 60ce24f105..cb135e21c8 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumBinaryMult.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumBinaryMult.java
@@ -68,11 +68,6 @@ public class RewritePushdownSumBinaryMult extends
AutomatedTestBase
testRewritePushdownSumBinaryMult( TEST_NAME2, true );
}
- /**
- *
- * @param testname
- * @param rewrites
- */
private void testRewritePushdownSumBinaryMult( String testname, boolean
rewrites )
{
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumOnBinaryTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumOnBinaryTest.java
index 9391af719a..d9459b03a9 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumOnBinaryTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumOnBinaryTest.java
@@ -29,54 +29,94 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-public class RewritePushdownSumOnBinaryTest extends AutomatedTestBase
+public class RewritePushdownSumOnBinaryTest extends AutomatedTestBase
{
private static final String TEST_NAME1 = "RewritePushdownSumOnBinary";
private static final String TEST_DIR = "functions/rewrite/";
private static final String TEST_CLASS_DIR = TEST_DIR +
RewritePushdownSumOnBinaryTest.class.getSimpleName() + "/";
-
+
private static final int rows = 1000;
private static final int cols = 1;
-
+ private static final double eps = 1e-8;
+
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration( TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R1", "R2" }) );
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1,
+ new String[] { "R1", "R2", "R3", "R4" }));
+ }
+
+ @Test
+ public void testRewritePushdownSumOnBinaryNoRewrite() {
+ testRewritePushdownSumOnBinary(TEST_NAME1, false);
+ }
+
+ @Test
+ public void testRewritePushdownSumOnBinary() {
+ testRewritePushdownSumOnBinary(TEST_NAME1, true);
}
@Test
- public void testRewritePushdownSumOnBinaryNoRewrite() {
- testRewritePushdownSumOnBinary( TEST_NAME1, false );
+ public void testRewritePushdownSumOnBinaryRowVector() {
+ testRewritePushdownSumOnBinaryVector(TEST_NAME1, true, true);
}
-
+
@Test
- public void testRewritePushdownSumOnBinary() {
- testRewritePushdownSumOnBinary( TEST_NAME1, true );
+ public void testRewritePushdownSumOnBinaryColVector() {
+ testRewritePushdownSumOnBinaryVector(TEST_NAME1, true, false);
}
-
- private void testRewritePushdownSumOnBinary( String testname, boolean
rewrites )
- {
+
+ private void testRewritePushdownSumOnBinary(String testname, boolean
rewrites) {
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
-
+
try {
TestConfiguration config =
getTestConfiguration(testname);
loadTestConfiguration(config);
-
+
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[]{ "-args",
String.valueOf(rows),
- String.valueOf(cols), output("R1"),
output("R2") };
+
+ programArgs = new String[]{ "-args",
String.valueOf(rows),
+ String.valueOf(cols), output("R1"),
output("R2"),
+ String.valueOf(rows),
String.valueOf(cols) }; // Assuming row and col vectors
+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
rewrites;
- //run performance tests
+ // Run performance tests
runTest(true, false, null, -1);
-
- //compare matrices
- long expect = Math.round(0.5*rows);
+
+ // Compare matrices
+ long expect = Math.round(0.5 * rows);
HashMap<CellIndex, Double> dmlfile1 =
readDMLScalarFromOutputDir("R1");
- Assert.assertEquals(expect, dmlfile1.get(new
CellIndex(1,1)), expect*0.01);
+ Assert.assertEquals(expect, dmlfile1.get(new
CellIndex(1, 1)), eps);
HashMap<CellIndex, Double> dmlfile2 =
readDMLScalarFromOutputDir("R2");
- Assert.assertEquals(expect, dmlfile2.get(new
CellIndex(1,1)), expect*0.01);
+ Assert.assertEquals(expect, dmlfile2.get(new
CellIndex(1, 1)), eps);
+ } finally {
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+ }
+ }
+
+
+ private void testRewritePushdownSumOnBinaryVector(String testname,
boolean rewrites, boolean isRow) {
+ boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+ try {
+ TestConfiguration config =
getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[]{ "-args",
String.valueOf(rows),
+ String.valueOf(cols), output("R3"),
output("R4"),
+ String.valueOf(isRow ? 1 : rows),
String.valueOf(isRow ? cols : 1) };
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
rewrites;
+
+ runTest(true, false, null, -1);
+
+ long expect = Math.round(500); // Expected value for
0.5 + 0.5
+ HashMap<CellIndex, Double> dmlfile3 =
readDMLScalarFromOutputDir("R3");
+ Assert.assertEquals(expect, dmlfile3.get(new
CellIndex(1,1)), eps);
+ HashMap<CellIndex, Double> dmlfile4 =
readDMLScalarFromOutputDir("R4");
+ Assert.assertEquals(expect, dmlfile4.get(new
CellIndex(1,1)), eps);
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
diff --git a/src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml
b/src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml
index d48ac0aad8..0d1b812397 100644
--- a/src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml
@@ -19,15 +19,30 @@
#
#-------------------------------------------------------------
-A = rand(rows=$1, cols=$2, seed=1);
-B = rand(rows=$1, cols=$2, seed=2);
-C = rand(rows=$1, cols=$2, seed=3);
-D = rand(rows=$1, cols=$2, seed=4);
+# Required parameters
+A = matrix(0.5, rows=$1, cols=$2);
+B = matrix(0.5, rows=$1, cols=$2);
+C = matrix(0.5, rows=$1, cols=$2);
+D = matrix(0.5, rows=$1, cols=$2);
+# Set defaults for optional parameters
+rowsV = ifdef($5, 0)
+colsV = ifdef($6, 0)
+
+# Original matrix tests
r1 = sum(A*B + C*D);
r2 = r1;
-print("r1="+r1+", r2="+r2);
+# Vector tests
+if (rowsV != 0 & colsV != 0) {
+ V = matrix(0.5, rows=rowsV, cols=colsV);
+ r3 = sum(A + V);
+ r4 = r3;
+}
+
write(r1, $3);
write(r2, $4);
-
+if (rowsV != 0 & colsV != 0) {
+ write(r3, $5);
+ write(r4, $6);
+}
\ No newline at end of file