This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 9b6a96dfc0 [SYSTEMDS-3783] Fix wsigmoid rewrite test setup
9b6a96dfc0 is described below
commit 9b6a96dfc0cb0f44d904818edad3ee080b47a11c
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Oct 21 10:47:31 2024 +0200
[SYSTEMDS-3783] Fix wsigmoid rewrite test setup
The recent addition of various rewrite tests for code coverage left a
FIXME on the wsigmoid test, which gave incorrect results for all
variants without transpose. After double-checking, it turned out that
the test setup made wrong assumptions about when the rewrite should apply
(missing transpose) and about the shapes of the involved matrices.
---
.../RewriteSimplifyWeightedSigmoidMMChainsTest.java | 18 ++++++++++--------
.../rewrite/RewriteSimplifyWeightedSigmoidMMChains.dml | 2 ++
2 files changed, 12 insertions(+), 8 deletions(-)
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChainsTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChainsTest.java
index beae14ee00..fe3f92ae40 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChainsTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChainsTest.java
@@ -19,7 +19,10 @@
package org.apache.sysds.test.functions.rewrite;
+import java.util.HashMap;
+
import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -32,9 +35,8 @@ public class RewriteSimplifyWeightedSigmoidMMChainsTest
extends AutomatedTestBas
private static final String TEST_CLASS_DIR =
TEST_DIR +
RewriteSimplifyWeightedSigmoidMMChainsTest.class.getSimpleName() + "/";
- private static final int rows = 100;
+ private static final int rows = 150;
private static final int cols = 100;
- //private static final double eps = Math.pow(10, -10);
@Override
public void setUp() {
@@ -125,8 +127,9 @@ public class RewriteSimplifyWeightedSigmoidMMChainsTest
extends AutomatedTestBas
OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;
//create matrices
- double[][] X = getRandomMatrix(rows, cols, -1, 1,
0.80d, 3);
- double[][] Y = getRandomMatrix(rows, cols, -1, 1,
0.70d, 4);
+ int rank = 50;
+ double[][] X = getRandomMatrix(cols, rank, -1, 1,
0.80d, 3);
+ double[][] Y = getRandomMatrix(rows, rank, -1, 1,
0.70d, 4);
double[][] W = getRandomMatrix(rows, cols, -1, 1,
0.60d, 5);
writeInputMatrixWithMTD("X", X, true);
writeInputMatrixWithMTD("Y", Y, true);
@@ -136,10 +139,9 @@ public class RewriteSimplifyWeightedSigmoidMMChainsTest
extends AutomatedTestBas
runRScript(true);
//compare matrices
- // FIXME
- // HashMap<MatrixValue.CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("R");
- // HashMap<MatrixValue.CellIndex, Double> rfile =
readRMatrixFromExpectedDir("R");
- // compareMatrices(dmlfile, rfile, eps, "Stat-DML",
"Stat-R");
+ HashMap<MatrixValue.CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("R");
+ HashMap<MatrixValue.CellIndex, Double> rfile =
readRMatrixFromExpectedDir("R");
+ TestUtils.compareMatrices(dmlfile, rfile, 1e-8,
"Stat-DML", "Stat-R");
if(rewrites)
Assert.assertTrue(heavyHittersContainsString("wsigmoid"));
diff --git
a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChains.dml
b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChains.dml
index e35c12d703..9e1543f9d9 100644
---
a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChains.dml
+++
b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedSigmoidMMChains.dml
@@ -25,6 +25,8 @@ Y = read($2)
W = read($3)
type = $4
+if( type > 4 )
+ X = t(X);
# Perform operations
if(type == 1){