This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 6644ce3807 [SYSTEMDS-3333] Fix new mmchain-opt rewrite code style and
test
6644ce3807 is described below
commit 6644ce3807795f147d047373988b369f811a2ba8
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon May 4 17:19:30 2026 +0200
[SYSTEMDS-3333] Fix new mmchain-opt rewrite code style and test
* Fixes the remaining method names that violated the naming conventions
(e.g., rule_OptimizeMMChains, testMatrixChainDP_Test*) in the new and
existing mmchain-opt rewrites and their test
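* Fixes the handling of stdout streams under default output buffering
(which is enabled in GitHub Actions): the test now calls
setOutputBuffering(true) and uses the output returned by runTest instead
of redirecting System.out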
---
.../RewriteMatrixMultChainOptimization.java | 8 +--
...ewriteMatrixMultChainOptimizationTranspose.java | 8 +--
...ewriteMatrixMultChainWithTransOptimization.java | 8 +--
.../rewrite/RewriteMatrixChainDPTest.java | 71 +++++++++-------------
4 files changed, 40 insertions(+), 55 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
index 960560c254..884f9e82e8 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
@@ -50,7 +50,7 @@ public class RewriteMatrixMultChainOptimization extends
HopRewriteRule
// Find the optimal order for the chain whose result is the
current HOP
for( Hop h : roots )
- rule_OptimizeMMChains(h, state);
+ ruleOptimizeMMChains(h, state);
return roots;
}
@@ -62,7 +62,7 @@ public class RewriteMatrixMultChainOptimization extends
HopRewriteRule
return null;
// Find the optimal order for the chain whose result is the
current HOP
- rule_OptimizeMMChains(root, state);
+ ruleOptimizeMMChains(root, state);
return root;
}
@@ -73,7 +73,7 @@ public class RewriteMatrixMultChainOptimization extends
HopRewriteRule
*
* @param hop high-level operator
*/
- private void rule_OptimizeMMChains(Hop hop, ProgramRewriteStatus state)
+ private void ruleOptimizeMMChains(Hop hop, ProgramRewriteStatus state)
{
if( hop.isVisited() )
return;
@@ -87,7 +87,7 @@ public class RewriteMatrixMultChainOptimization extends
HopRewriteRule
}
for( Hop hi : hop.getInput() )
- rule_OptimizeMMChains(hi, state);
+ ruleOptimizeMMChains(hi, state);
hop.setVisited();
}
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationTranspose.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationTranspose.java
index b327480609..2d204888c2 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationTranspose.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationTranspose.java
@@ -51,7 +51,7 @@ public class RewriteMatrixMultChainOptimizationTranspose
extends HopRewriteRule
// Find the optimal order for the chain whose result is the
current HOP
for( Hop h : roots )
- rule_OptimizeMMChains(h, state);
+ ruleOptimizeMMChains(h, state);
return roots;
}
@@ -63,7 +63,7 @@ public class RewriteMatrixMultChainOptimizationTranspose
extends HopRewriteRule
return null;
// Find the optimal order for the chain whose result is the
current HOP
- rule_OptimizeMMChains(root, state);
+ ruleOptimizeMMChains(root, state);
return root;
}
@@ -74,7 +74,7 @@ public class RewriteMatrixMultChainOptimizationTranspose
extends HopRewriteRule
*
* @param hop high-level operator
*/
- private void rule_OptimizeMMChains(Hop hop, ProgramRewriteStatus state)
+ private void ruleOptimizeMMChains(Hop hop, ProgramRewriteStatus state)
{
if( !hop.isVisited() ) {
@@ -85,7 +85,7 @@ public class RewriteMatrixMultChainOptimizationTranspose
extends HopRewriteRule
}
for (Hop hi : hop.getInput())
- rule_OptimizeMMChains(hi, state);
+ ruleOptimizeMMChains(hi, state);
hop.setVisited();
}
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainWithTransOptimization.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainWithTransOptimization.java
index d866fe9343..c2f14e0b0d 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainWithTransOptimization.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainWithTransOptimization.java
@@ -45,7 +45,7 @@ public class RewriteMatrixMultChainWithTransOptimization
extends HopRewriteRule
// Find the optimal order for the chain whose result is the
current HOP
for( Hop h : roots )
- rule_OptimizeMMChains(h, state);
+ ruleOptimizeMMChains(h, state);
return roots;
}
@@ -57,7 +57,7 @@ public class RewriteMatrixMultChainWithTransOptimization
extends HopRewriteRule
return null;
// Find the optimal order for the chain whose result is the
current HOP
- rule_OptimizeMMChains(root, state);
+ ruleOptimizeMMChains(root, state);
return root;
}
@@ -69,7 +69,7 @@ public class RewriteMatrixMultChainWithTransOptimization
extends HopRewriteRule
* @param hop The current high-level operator node.
* @param state The rewrite status.
*/
- private void rule_OptimizeMMChains(Hop hop, ProgramRewriteStatus state)
{
+ private void ruleOptimizeMMChains(Hop hop, ProgramRewriteStatus state) {
if (hop.isVisited()) return;
boolean isMatrixMult = HopRewriteUtils.isMatrixMultiply(hop) &&
!((AggBinaryOp) hop).hasLeftPMInput();
@@ -91,7 +91,7 @@ public class RewriteMatrixMultChainWithTransOptimization
extends HopRewriteRule
// .toArray(new Hop[0]) this prevents
ConcurrentModificationException because the optimizer
// may replace or modify parts of the HOP DAG during recursion
for( Hop i : currentHop.getInput().toArray(new Hop[0]) ) {
- rule_OptimizeMMChains(i, state);
+ ruleOptimizeMMChains(i, state);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixChainDPTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixChainDPTest.java
index 66d224af77..64af7415f8 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixChainDPTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixChainDPTest.java
@@ -28,9 +28,6 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import java.io.ByteArrayOutputStream;
-import java.io.PrintStream;
-
public class RewriteMatrixChainDPTest extends AutomatedTestBase {
private static final String TEST_DIR = "functions/rewrite/mmchain/";
@@ -52,76 +49,76 @@ public class RewriteMatrixChainDPTest extends
AutomatedTestBase {
}
@Test
- public void testMatrixChainDP_Test1() {
runTestMatrixChainDP(TEST_CASES[0]); }
+ public void testMatrixChainDPTest1() {
runTestMatrixChainDP(TEST_CASES[0]); }
@Test
- public void testMatrixChainDP_Test2() {
runTestMatrixChainDP(TEST_CASES[1]); }
+ public void testMatrixChainDPTest2() {
runTestMatrixChainDP(TEST_CASES[1]); }
@Test
- public void testMatrixChainDP_Test3() {
runTestMatrixChainDP(TEST_CASES[2]); }
+ public void testMatrixChainDPTest3() {
runTestMatrixChainDP(TEST_CASES[2]); }
@Test
- public void testMatrixChainDP_Test4() {
runTestMatrixChainDP(TEST_CASES[3]); }
+ public void testMatrixChainDPTest4() {
runTestMatrixChainDP(TEST_CASES[3]); }
@Test
- public void testMatrixChainDP_Test5() {
runTestMatrixChainDP(TEST_CASES[4]); }
+ public void testMatrixChainDPTest5() {
runTestMatrixChainDP(TEST_CASES[4]); }
@Test
- public void testMatrixChainDP_Test6() {
runTestMatrixChainDP(TEST_CASES[5]); }
+ public void testMatrixChainDPTest6() {
runTestMatrixChainDP(TEST_CASES[5]); }
@Test
- public void testMatrixChainDP_Test7() {
runTestMatrixChainDP(TEST_CASES[6]); }
+ public void testMatrixChainDPTest7() {
runTestMatrixChainDP(TEST_CASES[6]); }
@Test
- public void testMatrixChainDP_Test8() {
runTestMatrixChainDP(TEST_CASES[7]); }
+ public void testMatrixChainDPTest8() {
runTestMatrixChainDP(TEST_CASES[7]); }
@Test
- public void testMatrixChainDP_Test9() {
runTestMatrixChainDP(TEST_CASES[8]); }
+ public void testMatrixChainDPTest9() {
runTestMatrixChainDP(TEST_CASES[8]); }
@Test
- public void testMatrixChainDP_Test10() {
runTestMatrixChainDP(TEST_CASES[9]); }
+ public void testMatrixChainDPTest10() {
runTestMatrixChainDP(TEST_CASES[9]); }
@Test
- public void testMatrixChainDP_Test11() {
runTestMatrixChainDP(TEST_CASES[10]); }
+ public void testMatrixChainDPTest11() {
runTestMatrixChainDP(TEST_CASES[10]); }
@Test
- public void testMatrixChainDP_Test12() {
runTestMatrixChainDP(TEST_CASES[11]); }
+ public void testMatrixChainDPTest12() {
runTestMatrixChainDP(TEST_CASES[11]); }
@Test
- public void testMatrixChainDP_Test13() {
runTestMatrixChainDP(TEST_CASES[12]); }
+ public void testMatrixChainDPTest13() {
runTestMatrixChainDP(TEST_CASES[12]); }
@Test
- public void testMatrixChainDP_Test14() {
runTestMatrixChainDP(TEST_CASES[13]); }
+ public void testMatrixChainDPTest14() {
runTestMatrixChainDP(TEST_CASES[13]); }
@Test
- public void testMatrixChainDP_Test15() {
runTestMatrixChainDP(TEST_CASES[14]); }
+ public void testMatrixChainDPTest15() {
runTestMatrixChainDP(TEST_CASES[14]); }
@Test
- public void testMatrixChainDP_Test16() {
runTestMatrixChainDP(TEST_CASES[15]); }
+ public void testMatrixChainDPTest16() {
runTestMatrixChainDP(TEST_CASES[15]); }
@Test
- public void testMatrixChainDP_Test17() {
runTestMatrixChainDP(TEST_CASES[16]); }
+ public void testMatrixChainDPTest17() {
runTestMatrixChainDP(TEST_CASES[16]); }
@Test
- public void testMatrixChainDP_Test18() {
runTestMatrixChainDP(TEST_CASES[17]); }
+ public void testMatrixChainDPTest18() {
runTestMatrixChainDP(TEST_CASES[17]); }
@Test
- public void testMatrixChainDP_Test19() {
runTestMatrixChainDP(TEST_CASES[18]); }
+ public void testMatrixChainDPTest19() {
runTestMatrixChainDP(TEST_CASES[18]); }
@Test
- public void testMatrixChainDP_Test20() {
runTestMatrixChainDP(TEST_CASES[19]); }
+ public void testMatrixChainDPTest20() {
runTestMatrixChainDP(TEST_CASES[19]); }
@Test
- public void testMatrixChainDP_Test21() {
runTestMatrixChainDP(TEST_CASES[20]); }
+ public void testMatrixChainDPTest21() {
runTestMatrixChainDP(TEST_CASES[20]); }
@Test
- public void testMatrixChainDP_Test22() {
runTestMatrixChainDP(TEST_CASES[21]); }
+ public void testMatrixChainDPTest22() {
runTestMatrixChainDP(TEST_CASES[21]); }
@Test
- public void testMatrixChainDP_Test23() {
runTestMatrixChainDP(TEST_CASES[22]); }
+ public void testMatrixChainDPTest23() {
runTestMatrixChainDP(TEST_CASES[22]); }
@Test
- public void testMatrixChainDP_Test24()
{runTestMatrixChainDP(TEST_CASES[23]);}
+ public void testMatrixChainDPTest24()
{runTestMatrixChainDP(TEST_CASES[23]);}
private void runTestMatrixChainDP(String testName) {
@@ -144,23 +141,11 @@ public class RewriteMatrixChainDPTest extends
AutomatedTestBase {
programArgs = new String[]{ "-explain", "hops",
"-stats", "-args", output("R") };
- // print HOP DAG
- PrintStream originalOut = System.out;
- ByteArrayOutputStream bos = new ByteArrayOutputStream();
- System.setOut(new PrintStream(bos));
-
- try {
- // Execute the DML script
- runTest(true, false, null, -1);
- } finally {
- System.setOut(originalOut);
- }
-
- String output = bos.toString();
-
- System.out.println("Output for " + testName + ":\n" +
output);
+ // Execute the DML script
+ setOutputBuffering(true);
+ String output = runTest(true, false, null,
-1).toString();
- /* the following uses the intermediate matrices
dimensions to check, wether
+ /* the following uses the intermediate matrices
dimensions to check, whether
* the rewrite rule has found the optimal plan, which
is commented in each script
*/
switch(testName) {