This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 78b23cf417 [SYSTEMDS-3841] Additional multiLogReg tests to check mmchain rewrite
78b23cf417 is described below
commit 78b23cf41756a8c86583c0d61f2a1ced812e565d
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Feb 17 16:06:00 2025 +0100
[SYSTEMDS-3841] Additional multiLogReg tests to check mmchain rewrite
As it turns out, there was no bug preventing mmchain from being applied
for the builtin function multiLogReg; we simply apply this rewrite only
for binary classification, not for multi-class classification (where only
codegen is capable of doing so). We have now added tests for both binary
and multi-class classification, including a check that mmchain is
correctly applied.
---
.../builtin/part2/BuiltinMultiLogRegTest.java | 69 ++++++++++++++++------
1 file changed, 50 insertions(+), 19 deletions(-)
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMultiLogRegTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMultiLogRegTest.java
index 34f08e588c..6a5c64ffb8 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMultiLogRegTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMultiLogRegTest.java
@@ -19,11 +19,10 @@
package org.apache.sysds.test.functions.builtin.part2;
+import org.junit.Assert;
import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ExecType;
-import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -50,58 +49,90 @@ public class BuiltinMultiLogRegTest extends AutomatedTestBase {
@Test
public void testMultiLogRegInterceptCP0() {
- runMultiLogeRegTest( 0, tol, 1.0, maxIter, maxInnerIter,
ExecType.CP);
+ runMultiLogeRegTest(0, tol, 1.0, maxIter, maxInnerIter,
ExecType.CP);
}
@Test
public void testMultiLogRegInterceptCP1() {
- runMultiLogeRegTest( 1, tol, 1.0, maxIter, maxInnerIter,
ExecType.CP);
+ runMultiLogeRegTest(1, tol, 1.0, maxIter, maxInnerIter,
ExecType.CP);
}
@Test
public void testMultiLogRegInterceptCP2() {
- runMultiLogeRegTest( 2, tol, 1.0, maxIter, maxInnerIter,
ExecType.CP);
+ runMultiLogeRegTest(2, tol, 1.0, maxIter, maxInnerIter,
ExecType.CP);
}
+ @Test
+ public void testMultiLogRegBinInterceptCP0() {
+ runMultiLogeRegTest(0, tol, 1.0, maxIter, maxInnerIter, 2,
ExecType.CP);
+ }
+ @Test
+ public void testMultiLogRegBinInterceptCP1() {
+ runMultiLogeRegTest(1, tol, 1.0, maxIter, maxInnerIter, 2,
ExecType.CP);
+ }
+ @Test
+ public void testMultiLogRegBinInterceptCP2() {
+ runMultiLogeRegTest(2, tol, 1.0, maxIter, maxInnerIter, 2,
ExecType.CP);
+ }
+
@Test
public void testMultiLogRegInterceptSpark0() {
- runMultiLogeRegTest( 0, tol, 1.0, maxIter, maxInnerIter,
ExecType.SPARK);
+ runMultiLogeRegTest(0, tol, 1.0, maxIter, maxInnerIter,
ExecType.SPARK);
}
@Test
public void testMultiLogRegInterceptSpark1() {
- runMultiLogeRegTest( 1, tol, 1.0, maxIter, maxInnerIter,
ExecType.SPARK);
+ runMultiLogeRegTest(1, tol, 1.0, maxIter, maxInnerIter,
ExecType.SPARK);
}
@Test
public void testMultiLogRegInterceptSpark2() {
runMultiLogeRegTest(2, tol, 1.0, maxIter, maxInnerIter,
ExecType.SPARK);
}
+
+ @Test
+ public void testMultiLogRegBinInterceptSpark0() {
+ runMultiLogeRegTest(0, tol, 1.0, maxIter, maxInnerIter, 2,
ExecType.SPARK);
+ }
+ @Test
+ public void testMultiLogRegBinInterceptSpark1() {
+ runMultiLogeRegTest(1, tol, 1.0, maxIter, maxInnerIter, 2,
ExecType.SPARK);
+ }
+ @Test
+ public void testMultiLogRegBinInterceptSpark2() {
+ runMultiLogeRegTest(2, tol, 1.0, maxIter, maxInnerIter, 2,
ExecType.SPARK);
+ }
- private void runMultiLogeRegTest( int inc, double tol, double reg, int
maxOut, int maxIn, ExecType instType) {
+ private void runMultiLogeRegTest(int inc, double tol, double reg, int
maxOut, int maxIn, ExecType instType) {
+ runMultiLogeRegTest(inc, tol, reg, maxOut, maxIn, 6, instType);
+ }
+
+ private void runMultiLogeRegTest(int inc, double tol, double reg,
+ int maxOut, int maxIn, int numClasses, ExecType instType)
+ {
Types.ExecMode platformOld = setExecMode(instType);
- boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-
try {
loadTestConfiguration(getTestConfiguration(TEST_NAME));
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[]{"-nvargs", "X=" +
input("X"), "Y=" + input("Y"), "output=" + output("betas"),
- "inc=" +
String.valueOf(inc).toUpperCase(), "tol=" + tol, "reg=" + reg, "maxOut=" +
maxOut, "maxIn="+maxIn, "verbose=FALSE"};
+ programArgs = new String[]{"-stats","-nvargs",
+ "X=" + input("X"), "Y=" + input("Y"), "output="
+ output("betas"),
+ "inc=" + String.valueOf(inc).toUpperCase(),
"tol=" + tol,
+ "reg=" + reg, "maxOut=" + maxOut,
"maxIn="+maxIn, "verbose=FALSE"};
double[][] X = getRandomMatrix(rows, colsX, 0, 1,
sparse, -1);
- double[][] Y = getRandomMatrix(rows, 1, 0, 5, 1, -1);
+ double[][] Y = getRandomMatrix(rows, 1, 0,
numClasses-1, 1, -1);
Y = TestUtils.round(Y);
writeInputMatrixWithMTD("X", X, true);
writeInputMatrixWithMTD("Y", Y, true);
runTest(true, false, null, -1);
+
+ if(numClasses == 2) {
+ String opcode = instType==ExecType.SPARK ?
"sp_mapmmchain" : "mmchain";
+
Assert.assertTrue(heavyHittersContainsString(opcode));
+ }
}
finally {
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
- OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
- OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true;
- OptimizerUtils.ALLOW_OPERATOR_FUSION = true;
+ resetExecMode(platformOld);
}
}
}