This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 9b6a96dfc0 [SYSTEMDS-3783] Fix wsigmoid rewrite test setup
9b6a96dfc0 is described below

commit 9b6a96dfc0cb0f44d904818edad3ee080b47a11c
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Oct 21 10:47:31 2024 +0200

    [SYSTEMDS-3783] Fix wsigmoid rewrite test setup
    
    The recent addition of various rewrite tests for code coverage left a
    FIXME on the wsigmoid test, which gave incorrect results for all
    variants without transpose. After double-checking, it turned out that
    the test setup made wrong assumptions about when the rewrite should
    apply (missing transpose) and about the shapes of the involved matrices.
---
 .../RewriteSimplifyWeightedSigmoidMMChainsTest.java    | 18 ++++++++++--------
 .../rewrite/RewriteSimplifyWeightedSigmoidMMChains.dml |  2 ++
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChainsTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChainsTest.java
index beae14ee00..fe3f92ae40 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChainsTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChainsTest.java
@@ -19,7 +19,10 @@
 
 package org.apache.sysds.test.functions.rewrite;
 
+import java.util.HashMap;
+
 import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
@@ -32,9 +35,8 @@ public class RewriteSimplifyWeightedSigmoidMMChainsTest 
extends AutomatedTestBas
        private static final String TEST_CLASS_DIR =
                TEST_DIR + 
RewriteSimplifyWeightedSigmoidMMChainsTest.class.getSimpleName() + "/";
 
-       private static final int rows = 100;
+       private static final int rows = 150;
        private static final int cols = 100;
-       //private static final double eps = Math.pow(10, -10);
 
        @Override
        public void setUp() {
@@ -125,8 +127,9 @@ public class RewriteSimplifyWeightedSigmoidMMChainsTest 
extends AutomatedTestBas
                        OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;
 
                        //create matrices
-                       double[][] X = getRandomMatrix(rows, cols, -1, 1, 
0.80d, 3);
-                       double[][] Y = getRandomMatrix(rows, cols, -1, 1, 
0.70d, 4);
+                       int rank = 50;
+                       double[][] X = getRandomMatrix(cols, rank, -1, 1, 
0.80d, 3);
+                       double[][] Y = getRandomMatrix(rows, rank, -1, 1, 
0.70d, 4);
                        double[][] W = getRandomMatrix(rows, cols, -1, 1, 
0.60d, 5);
                        writeInputMatrixWithMTD("X", X, true);
                        writeInputMatrixWithMTD("Y", Y, true);
@@ -136,10 +139,9 @@ public class RewriteSimplifyWeightedSigmoidMMChainsTest 
extends AutomatedTestBas
                        runRScript(true);
 
                        //compare matrices
-                       // FIXME
-                       // HashMap<MatrixValue.CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("R");
-                       // HashMap<MatrixValue.CellIndex, Double> rfile = 
readRMatrixFromExpectedDir("R");
-                       // compareMatrices(dmlfile, rfile, eps, "Stat-DML", 
"Stat-R");
+                       HashMap<MatrixValue.CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("R");
+                       HashMap<MatrixValue.CellIndex, Double> rfile = 
readRMatrixFromExpectedDir("R");
+                       TestUtils.compareMatrices(dmlfile, rfile, 1e-8, 
"Stat-DML", "Stat-R");
 
                        if(rewrites)
                                
Assert.assertTrue(heavyHittersContainsString("wsigmoid"));
diff --git 
a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChains.dml 
b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChains.dml
index e35c12d703..9e1543f9d9 100644
--- 
a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChains.dml
+++ 
b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChains.dml
@@ -25,6 +25,8 @@ Y = read($2)
 W = read($3)
 type = $4
 
+if( type > 4 )
+  X = t(X);
 
 # Perform operations
 if(type == 1){

Reply via email to