This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 55075f8a76 [SYSTEMDS-3784] Fix weighted unary-mm rewrite test cases, part 2
55075f8a76 is described below
commit 55075f8a76f98c7ae2eb1453cd24aec984979629
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Oct 24 13:04:13 2024 +0200
[SYSTEMDS-3784] Fix weighted unary-mm rewrite test cases, part 2
---
...tiReturnParameterizedBuiltinFEDInstruction.java | 1 -
.../RewriteSimplifyWeightedUnaryMMTest.java | 17 +++++-------
.../rewrite/RewriteSimplifyWeightedUnaryMM.R | 31 +++++++---------------
.../rewrite/RewriteSimplifyWeightedUnaryMM.dml | 5 ++--
4 files changed, 18 insertions(+), 36 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index 917226fef7..69e0361ee7 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -126,7 +126,6 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
CPOperand in2 = new CPOperand(parts[2]);
int pos = 3;
boolean metaReturn = true;
- System.out.println(Arrays.toString(parts));
if( parts.length == 7 ) //no need for meta data
metaReturn = new
CPOperand(parts[pos++]).getLiteral().getBooleanValue();
outputs.add(new CPOperand(parts[pos],
Types.ValueType.FP64, Types.DataType.MATRIX));
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java
index 84f7ebfe04..e9bad6736b 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java
@@ -28,7 +28,6 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
-import org.junit.Ignore;
import org.junit.Test;
public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
@@ -60,7 +59,7 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
@Test
public void testWeightedUnaryMMExpRewrite(){
- testRewriteSimplifyWeightedUnaryMM(1, true); //pattern: W *
exp(U%*%t(V))
+ testRewriteSimplifyWeightedUnaryMM(1, true); //pattern: W *
exp(U%*%t(V))
}
@Test
@@ -70,7 +69,7 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
@Test
public void testWeightedUnaryMMAbsRewrite(){
- testRewriteSimplifyWeightedUnaryMM(2, true); //pattern: W *
abs(U%*%t(V))
+ testRewriteSimplifyWeightedUnaryMM(2, true); //pattern: W *
abs(U%*%t(V))
}
@Test
@@ -80,7 +79,7 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
@Test
public void testWeightedUnaryMMSinRewrite(){
- testRewriteSimplifyWeightedUnaryMM(3, true); //pattern: W *
sin(U%*%t(V))
+ testRewriteSimplifyWeightedUnaryMM(3, true); //pattern: W *
sin(U%*%t(V))
}
/**
@@ -95,7 +94,7 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
@Test
public void testWeightedUnaryMMScalarRightRewrite(){
- testRewriteSimplifyWeightedUnaryMM(4, true); //pattern:
(W*(U%*%t(V)))*2
+ testRewriteSimplifyWeightedUnaryMM(4, true); //pattern:
(W*(U%*%t(V)))*2
}
@Test
@@ -105,7 +104,7 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
@Test
public void testWeightedUnaryMMScalarLeftRewrite(){
- testRewriteSimplifyWeightedUnaryMM(5, true); //pattern:
2*(W*(U%*%t(V)))
+ testRewriteSimplifyWeightedUnaryMM(5, true); //pattern:
2*(W*(U%*%t(V)))
}
@Test
@@ -114,9 +113,8 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
}
@Test
- @Ignore //FIXME non-applied rewrite
public void testWeightedUnaryMMMultLeftRewrite(){
- testRewriteSimplifyWeightedUnaryMM(8, true); //pattern: W *
(c * (U%*%t(V)))
+ testRewriteSimplifyWeightedUnaryMM(8, true); //pattern: W * (2
* (U%*%t(V)))
}
@Test
@@ -125,9 +123,8 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
}
@Test
- @Ignore //FIXME non-applied rewrite
public void testWeightedUnaryMMMultRightRewrite(){
- testRewriteSimplifyWeightedUnaryMM(12, true); //pattern: W *
((U%*%t(V)) * c)
+ testRewriteSimplifyWeightedUnaryMM(12, true); //pattern: W *
((U%*%t(V)) * 2)
}
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.R b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.R
index 10015491e6..cdc3a4dafc 100644
--- a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.R
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.R
@@ -33,35 +33,22 @@ U = as.matrix(readMM(paste(args[1], "U.mtx", sep="")))
V = as.matrix(readMM(paste(args[1], "V.mtx", sep="")))
W = as.matrix(readMM(paste(args[1], "W.mtx", sep="")))
type = as.integer(args[2])
-c = 4.0
# Perform operations
-if(type == 1 || type == 14){
+if(type == 1){
R = W * exp(U%*%t(V))
-} else if(type == 2 || type == 15){
+} else if(type == 2){
R = W * abs(U%*%t(V))
-} else if(type == 3 || type == 16){
+} else if(type == 3){
R = W * sin(U%*%t(V))
-} else if(type == 4 || type == 17){
+} else if(type == 4){
R = (W*(U%*%t(V)))*2
-} else if(type == 5 || type == 18){
+} else if(type == 5){
R = 2*(W*(U%*%t(V)))
-} else if(type == 6 || type == 19){
- R = W * (c + U%*%t(V))
-} else if(type == 7 || type == 20){
- R = W * (c - U%*%t(V))
-} else if(type == 8 || type == 21){
- R = W * (c * (U%*%t(V)))
-} else if(type == 9 || type == 22){
- R = W * (c / (U%*%t(V)))
-} else if(type == 10 || type == 23){
- R = W * (U%*%t(V) + c)
-} else if(type == 11 || type == 24){
- R = W * (U%*%t(V) - c)
-} else if(type == 12 || type == 25){
- R = W * ((U%*%t(V)) * c)
-} else if(type == 13 || type == 26){
- R = W * ((U%*%t(V)) / c)
+} else if(type == 8){
+ R = W * (2 * (U%*%t(V)))
+} else if(type == 12){
+ R = W * ((U%*%t(V)) * 2)
}
#Write result matrix R
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml
index bda9da8d06..ea335d81a0 100644
--- a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml
@@ -24,7 +24,6 @@ U = read($1)
V = read($2)
W = read($3)
type = $4
-c = 4.0
# Perform operations
if(type == 1){
@@ -43,10 +42,10 @@ else if(type == 5){
R = 2*(W*(U%*%t(V)))
}
else if(type == 8){
- R = W * (c * (U%*%t(V)))
+ R = W * (2 * (U%*%t(V)))
}
else if(type == 12){
- R = W * ((U%*%t(V)) * c)
+ R = W * ((U%*%t(V)) * 2)
}
# Write the result matrix R