This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 12d8cd70af [SYSTEMDS-3784] Fix weighted unary-mm rewrite test cases
12d8cd70af is described below
commit 12d8cd70afa2156bda74c0b8e5d6d11d27e75c2a
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Oct 23 15:11:01 2024 +0200
[SYSTEMDS-3784] Fix weighted unary-mm rewrite test cases
---
.../RewriteAlgebraicSimplificationDynamic.java | 7 +-
.../RewriteSimplifyWeightedUnaryMMTest.java | 174 +++------------------
.../rewrite/RewriteSimplifyWeightedUnaryMM.dml | 72 +--------
3 files changed, 32 insertions(+), 221 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 5d894170df..396c40d114 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -85,8 +85,11 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
private static OpOp2[] LOOKUP_VALID_WDIVMM_BINARY = new
OpOp2[]{OpOp2.MULT, OpOp2.DIV};
//valid unary and binary operators for wumm
- private static OpOp1[] LOOKUP_VALID_WUMM_UNARY = new OpOp1[]{OpOp1.ABS,
OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.EXP, OpOp1.LOG, OpOp1.SQRT,
OpOp1.SIGMOID, OpOp1.SPROP};
- private static OpOp2[] LOOKUP_VALID_WUMM_BINARY = new
OpOp2[]{OpOp2.MULT, OpOp2.POW};
+ private static OpOp1[] LOOKUP_VALID_WUMM_UNARY = new OpOp1[]{
+ OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.EXP,
OpOp1.LOG,
+ OpOp1.SQRT, OpOp1.SIN, OpOp1.COS, OpOp1.SIGMOID, OpOp1.SPROP};
+ private static OpOp2[] LOOKUP_VALID_WUMM_BINARY = new OpOp2[]{
+ OpOp2.MULT, OpOp2.POW};
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots,
ProgramRewriteStatus state) {
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java
index aab2970913..84f7ebfe04 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java
@@ -19,10 +19,16 @@
package org.apache.sysds.test.functions.rewrite;
+import java.util.HashMap;
+
import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
import org.junit.Test;
public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
@@ -31,9 +37,8 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
private static final String TEST_CLASS_DIR =
TEST_DIR +
RewriteSimplifyWeightedUnaryMMTest.class.getSimpleName() + "/";
- private static final int rows = 100;
- private static final int cols = 100;
- //private static final double eps = Math.pow(10, -7);
+ private static final int rows = 1123; //larger than blocksize needed
+ private static final int cols = 1245;
@Override
public void setUp() {
@@ -103,166 +108,28 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
testRewriteSimplifyWeightedUnaryMM(5, true); //pattern:
2*(W*(U%*%t(V)))
}
- /**
- * These tests cover the case for the third pattern
- * W * sop(U%*%t(V), c) or W * sop(U%*%t(V), c), where
- * sop stands for scalar operation (+, -, *, /) and c represents
- * some constant scalar.
- * */
-
- @Test
- public void testWeightedUnaryMMAddLeftNoRewrite(){
- testRewriteSimplifyWeightedUnaryMM(6, false);
- }
-
- @Test
- public void testWeightedUnaryMMAddLeftRewrite(){
- testRewriteSimplifyWeightedUnaryMM(6, true); //pattern: W *
(c + U%*%t(V))
- }
-
- @Test
- public void testWeightedUnaryMMMinusLeftNoRewrite(){
- testRewriteSimplifyWeightedUnaryMM(7, false);
- }
-
- @Test
- public void testWeightedUnaryMMMinusLeftRewrite(){
- testRewriteSimplifyWeightedUnaryMM(7, true); //pattern: W *
(c - U%*%t(V))
- }
-
@Test
public void testWeightedUnaryMMMultLeftNoRewrite(){
testRewriteSimplifyWeightedUnaryMM(8, false);
}
@Test
+ @Ignore //FIXME non-applied rewrite
public void testWeightedUnaryMMMultLeftRewrite(){
testRewriteSimplifyWeightedUnaryMM(8, true); //pattern: W *
(c * (U%*%t(V)))
}
- @Test
- public void testWeightedUnaryMMDivLeftNoRewrite(){
- testRewriteSimplifyWeightedUnaryMM(9, false);
- }
-
- @Test
- public void testWeightedUnaryMMDivLeftRewrite(){
- testRewriteSimplifyWeightedUnaryMM(9, true); //pattern: W *
(c / (U%*%t(V)))
- }
-
- // Same pattern but scalar from right instead of left
-
- @Test
- public void testWeightedUnaryMMAddRightNoRewrite(){
- testRewriteSimplifyWeightedUnaryMM(10, false);
- }
-
- @Test
- public void testWeightedUnaryMMAddRightRewrite(){
- testRewriteSimplifyWeightedUnaryMM(10, true); //pattern: W *
(U%*%t(V) + c)
- }
-
- @Test
- public void testWeightedUnaryMMMinusRightNoRewrite(){
- testRewriteSimplifyWeightedUnaryMM(11, false);
- }
-
- @Test
- public void testWeightedUnaryMMMinusRightRewrite(){
- testRewriteSimplifyWeightedUnaryMM(11, true); //pattern: W *
(U%*%t(V) - c)
- }
-
@Test
public void testWeightedUnaryMMMulRightNoRewrite(){
testRewriteSimplifyWeightedUnaryMM(12, false);
}
@Test
+ @Ignore //FIXME non-applied rewrite
public void testWeightedUnaryMMMultRightRewrite(){
testRewriteSimplifyWeightedUnaryMM(12, true); //pattern: W *
((U%*%t(V)) * c)
}
- @Test
- public void testWeightedUnaryMMDivRightNoRewrite(){
- testRewriteSimplifyWeightedUnaryMM(13, false);
- }
-
- @Test
- public void testWeightedUnaryMMDivRightRewrite(){
- testRewriteSimplifyWeightedUnaryMM(13, true); //pattern: W *
((U%*%t(V)) / c)
- }
-
- /**
- * Here, we omit the transpose in the dml script. The rewrite should
catch the missing transpose
- * and replace V with t(V).
- **/
-
- @Test
- public void testWeightedUnaryMMExpNoTranspose(){
- testRewriteSimplifyWeightedUnaryMM(14, true); //pattern: W *
exp(U%*%V)
- }
-
- @Test
- public void testWeightedUnaryMMAbsNoTranspose(){
- testRewriteSimplifyWeightedUnaryMM(15, true); //pattern: W *
abs(U%*%V)
- }
-
- @Test
- public void testWeightedUnaryMMSinNoTranspose(){
- testRewriteSimplifyWeightedUnaryMM(16, true); //pattern: W *
sin(U%*%V)
- }
-
- @Test
- public void testWeightedUnaryMMScalarRightNoTranspose(){
- testRewriteSimplifyWeightedUnaryMM(17, true); //pattern:
(W*(U%*%V))*2
- }
-
- @Test
- public void testWeightedUnaryMMScalarLeftNoTranspose(){
- testRewriteSimplifyWeightedUnaryMM(18, true); //pattern:
2*(W*(U%*%V))
- }
-
- @Test
- public void testWeightedUnaryMMAddLeftNoTranspose(){
- testRewriteSimplifyWeightedUnaryMM(19, true); //pattern: W *
(c + U%*%V)
- }
-
- @Test
- public void testWeightedUnaryMMMinusLeftNoTranspose(){
- testRewriteSimplifyWeightedUnaryMM(20, true); //pattern: W *
(c - U%*%V)
- }
-
- @Test
- public void testWeightedUnaryMMMultLeftNoTranspose(){
- testRewriteSimplifyWeightedUnaryMM(21, true); //pattern: W *
(c * (U%*%V))
- }
-
- @Test
- public void testWeightedUnaryMMDivLeftNoTranspose(){
- testRewriteSimplifyWeightedUnaryMM(22, true); //pattern: W *
(c / (U%*%V))
- }
-
- @Test
- public void testWeightedUnaryMMAddRightNoTranspose(){
- testRewriteSimplifyWeightedUnaryMM(23, true); //pattern: W *
(U%*%V + c)
- }
-
- @Test
- public void testWeightedUnaryMMMinusRightNoTranspose(){
- testRewriteSimplifyWeightedUnaryMM(24, true); //pattern: W *
(U%*%V - c)
- }
-
- @Test
- public void testWeightedUnaryMMMultRightNoTranspose(){
- testRewriteSimplifyWeightedUnaryMM(25, true); //pattern: W *
((U%*%V) * c)
- }
-
- @Test
- public void testWeightedUnaryMMDivRightNoTranspose(){
- testRewriteSimplifyWeightedUnaryMM(26, true); //pattern: W *
((U%*%V) / c)
- }
-
-
private void testRewriteSimplifyWeightedUnaryMM(int ID, boolean
rewrites) {
boolean oldFlag1 =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -280,11 +147,13 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
rewrites;
OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;
+ Recompiler.reinitRecompiler();
//create matrices
- double[][] U = getRandomMatrix(rows, cols, -1, 1,
0.80d, 3);
- double[][] V = getRandomMatrix(rows, cols, -1, 1,
0.70d, 4);
- double[][] W = getRandomMatrix(rows, cols, -1, 1,
0.60d, 5);
+ int rank = 50;
+ double[][] U = getRandomMatrix(rows, rank, -1, 1,
0.80d, 3);
+ double[][] V = getRandomMatrix(cols, rank, -1, 1,
0.70d, 4);
+ double[][] W = getRandomMatrix(rows, cols, -1, 1,
0.01d, 5);
writeInputMatrixWithMTD("U", U, true);
writeInputMatrixWithMTD("V", V, true);
writeInputMatrixWithMTD("W", W, true);
@@ -293,15 +162,10 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
runRScript(true);
//compare matrices
-// FIXME
-// HashMap<MatrixValue.CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("R");
-// HashMap<MatrixValue.CellIndex, Double> rfile =
readRMatrixFromExpectedDir("R");
-// TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
-// if(rewrites)
-//
Assert.assertTrue(heavyHittersContainsString("wumm"));
-// else
-//
Assert.assertFalse(heavyHittersContainsString("wumm"));
-
+ HashMap<MatrixValue.CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("R");
+ HashMap<MatrixValue.CellIndex, Double> rfile =
readRMatrixFromExpectedDir("R");
+ TestUtils.compareMatrices(dmlfile, rfile, 1e-8,
"Stat-DML", "Stat-R");
+
Assert.assertTrue(heavyHittersContainsString("wumm")==rewrites);
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
oldFlag1;
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml
index 300d2d11ea..bda9da8d06 100644
--- a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml
@@ -28,83 +28,27 @@ c = 4.0
# Perform operations
if(type == 1){
- R = W * exp(U%*%t(V))
+ R = W * exp(U%*%t(V))
}
else if(type == 2){
- R = W * abs(U%*%t(V))
+ R = W * abs(U%*%t(V))
}
else if(type == 3){
- R = W * sin(U%*%t(V))
+ R = W * sin(U%*%t(V))
}
else if(type == 4){
- R = (W*(U%*%t(V)))*2
+ R = (W*(U%*%t(V)))*2
}
else if(type == 5){
- R = 2*(W*(U%*%t(V)))
-}
-else if(type == 6){
- R = W * (c + U%*%t(V))
-}
-else if(type == 7){
- R = W * (c - U%*%t(V))
+ R = 2*(W*(U%*%t(V)))
}
else if(type == 8){
- R = W * (c * (U%*%t(V)))
-}
-else if(type == 9){
- R = W * (c / (U%*%t(V)))
-}
-else if(type == 10){
- R = W * (U%*%t(V) + c)
-}
-else if(type == 11){
- R = W * (U%*%t(V) - c)
+ R = W * (c * (U%*%t(V)))
}
else if(type == 12){
- R = W * ((U%*%t(V)) * c)
-}
-else if(type == 13){
- R = W * ((U%*%t(V)) / c)
-}
-else if(type == 14){
- R = W * exp(U%*%V)
-}
-else if(type == 15){
- R = W * abs(U%*%V)
-}
-else if(type == 16){
- R = W * sin(U%*%V)
-}
-else if(type == 17){
- R = (W*(U%*%V))*2
-}
-else if(type == 18){
- R = 2*(W*(U%*%V))
-}
-else if(type == 19){
- R = W * (c + U%*%V)
-}
-else if(type == 20){
- R = W * (c - U%*%V)
-}
-else if(type == 21){
- R = W * (c * (U%*%V))
-}
-else if(type == 22){
- R = W * (c / (U%*%V))
-}
-else if(type == 23){
- R = W * (U%*%V + c)
-}
-else if(type == 24){
- R = W * (U%*%V - c)
-}
-else if(type == 25){
- R = W * ((U%*%V) * c)
-}
-else if(type == 26){
- R = W * ((U%*%V) / c)
+ R = W * ((U%*%t(V)) * c)
}
# Write the result matrix R
write(R, $5)
+