This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 55075f8a76 [SYSTEMDS-3784] Fix weighted unary-mm rewrite test cases, part 2
55075f8a76 is described below

commit 55075f8a76f98c7ae2eb1453cd24aec984979629
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Oct 24 13:04:13 2024 +0200

    [SYSTEMDS-3784] Fix weighted unary-mm rewrite test cases, part 2
---
 ...tiReturnParameterizedBuiltinFEDInstruction.java |  1 -
 .../RewriteSimplifyWeightedUnaryMMTest.java        | 17 +++++-------
 .../rewrite/RewriteSimplifyWeightedUnaryMM.R       | 31 +++++++---------------
 .../rewrite/RewriteSimplifyWeightedUnaryMM.dml     |  5 ++--
 4 files changed, 18 insertions(+), 36 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index 917226fef7..69e0361ee7 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -126,7 +126,6 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
                        CPOperand in2 = new CPOperand(parts[2]);
                        int pos = 3;
                        boolean metaReturn = true;
-                       System.out.println(Arrays.toString(parts));
                        if( parts.length == 7 ) //no need for meta data
                                metaReturn = new CPOperand(parts[pos++]).getLiteral().getBooleanValue();
                        outputs.add(new CPOperand(parts[pos], Types.ValueType.FP64, Types.DataType.MATRIX));
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java
index 84f7ebfe04..e9bad6736b 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java
@@ -28,7 +28,6 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Assert;
-import org.junit.Ignore;
 import org.junit.Test;
 
 public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
@@ -60,7 +59,7 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
 
        @Test
        public void testWeightedUnaryMMExpRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(1, true);    //pattern: W * exp(U%*%t(V))
+               testRewriteSimplifyWeightedUnaryMM(1, true); //pattern: W * exp(U%*%t(V))
        }
 
        @Test
@@ -70,7 +69,7 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
 
        @Test
        public void testWeightedUnaryMMAbsRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(2, true);    //pattern: W * abs(U%*%t(V))
+               testRewriteSimplifyWeightedUnaryMM(2, true); //pattern: W * abs(U%*%t(V))
        }
 
        @Test
@@ -80,7 +79,7 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
 
        @Test
        public void testWeightedUnaryMMSinRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(3, true);    //pattern: W * sin(U%*%t(V))
+               testRewriteSimplifyWeightedUnaryMM(3, true); //pattern: W * sin(U%*%t(V))
        }
 
        /**
@@ -95,7 +94,7 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
 
        @Test
        public void testWeightedUnaryMMScalarRightRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(4, true);    //pattern: (W*(U%*%t(V)))*2
+               testRewriteSimplifyWeightedUnaryMM(4, true); //pattern: (W*(U%*%t(V)))*2
        }
 
        @Test
@@ -105,7 +104,7 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
 
        @Test
        public void testWeightedUnaryMMScalarLeftRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(5, true);    //pattern: 2*(W*(U%*%t(V)))
+               testRewriteSimplifyWeightedUnaryMM(5, true); //pattern: 2*(W*(U%*%t(V)))
        }
 
        @Test
@@ -114,9 +113,8 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
        }
 
        @Test
-       @Ignore //FIXME non-applied rewrite
        public void testWeightedUnaryMMMultLeftRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(8, true);    //pattern: W * (c * (U%*%t(V)))
+               testRewriteSimplifyWeightedUnaryMM(8, true); //pattern: W * (2 * (U%*%t(V)))
        }
 
        @Test
@@ -125,9 +123,8 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
        }
 
        @Test
-       @Ignore //FIXME non-applied rewrite
        public void testWeightedUnaryMMMultRightRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(12, true);   //pattern: W * ((U%*%t(V)) * c)
+               testRewriteSimplifyWeightedUnaryMM(12, true); //pattern: W * ((U%*%t(V)) * 2)
        }
 
 
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.R b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.R
index 10015491e6..cdc3a4dafc 100644
--- a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.R
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.R
@@ -33,35 +33,22 @@ U = as.matrix(readMM(paste(args[1], "U.mtx", sep="")))
 V = as.matrix(readMM(paste(args[1], "V.mtx", sep="")))
 W = as.matrix(readMM(paste(args[1], "W.mtx", sep="")))
 type = as.integer(args[2])
-c = 4.0
 
 # Perform operations
-if(type == 1 || type == 14){
+if(type == 1){
     R = W * exp(U%*%t(V))
-} else if(type == 2 || type == 15){
+} else if(type == 2){
     R = W * abs(U%*%t(V))
-} else if(type == 3 || type == 16){
+} else if(type == 3){
     R = W * sin(U%*%t(V))
-} else if(type == 4 || type == 17){
+} else if(type == 4){
     R = (W*(U%*%t(V)))*2
-} else if(type == 5 || type == 18){
+} else if(type == 5){
     R = 2*(W*(U%*%t(V)))
-} else if(type == 6 || type == 19){
-    R = W * (c + U%*%t(V))
-} else if(type == 7 || type == 20){
-    R = W * (c - U%*%t(V))
-} else if(type == 8 || type == 21){
-    R = W * (c * (U%*%t(V)))
-} else if(type == 9 || type == 22){
-    R = W * (c / (U%*%t(V)))
-} else if(type == 10 || type == 23){
-    R = W * (U%*%t(V) + c)
-} else if(type == 11 || type == 24){
-    R = W * (U%*%t(V) - c)
-} else if(type == 12 || type == 25){
-    R = W * ((U%*%t(V)) * c)
-} else if(type == 13 || type == 26){
-    R = W * ((U%*%t(V)) / c)
+} else if(type == 8){
+    R = W * (2 * (U%*%t(V)))
+} else if(type == 12){
+    R = W * ((U%*%t(V)) * 2)
 }
 
 #Write result matrix R
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml
index bda9da8d06..ea335d81a0 100644
--- a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml
@@ -24,7 +24,6 @@ U = read($1)
 V = read($2)
 W = read($3)
 type = $4
-c = 4.0
 
 # Perform operations
 if(type == 1){
@@ -43,10 +42,10 @@ else if(type == 5){
   R = 2*(W*(U%*%t(V)))
 }
 else if(type == 8){
-  R = W * (c * (U%*%t(V)))
+  R = W * (2 * (U%*%t(V)))
 }
 else if(type == 12){
-  R = W * ((U%*%t(V)) * c)
+  R = W * ((U%*%t(V)) * 2)
 }
 
 # Write the result matrix R

Reply via email to