This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 8307afe866 [SYSTEMDS-3921] New Rewrite for Relational Selection
Pushdown
8307afe866 is described below
commit 8307afe8663e8006c881ab2333d83f4a7bd3fe66
Author: mor-gk <[email protected]>
AuthorDate: Sun Mar 29 11:51:44 2026 +0200
[SYSTEMDS-3921] New Rewrite for Relational Selection Pushdown
Closes #2413.
---
.../java/org/apache/sysds/hops/OptimizerUtils.java | 7 +-
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 2 +
.../sysds/hops/rewrite/RewriteRaPushdown.java | 215 +++++++++++++++++++++
.../rewrite/RewritePushdownRaSelectionTest.java | 154 +++++++++++++++
.../rewrite/RewritePushdownRaSelection.dml | 40 ++++
5 files changed, 417 insertions(+), 1 deletion(-)
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index f9e09c852c..f2d26be535 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -335,7 +335,12 @@ public class OptimizerUtils
public static boolean AUTO_GPU_CACHE_EVICTION = true;
- //////////////////////
+ /**
+  * Boolean specifying if relational algebra rewrites are allowed (e.g.
+  * Selection Pushdowns). Disabled by default; when set to true, the
+  * ProgramRewriter registers the statement-block rule RewriteRaPushdown.
+  */
+ public static boolean ALLOW_RA_REWRITES = false;
+
+ //////////////////////
// Optimizer levels //
//////////////////////
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index d84ae107d7..927433c8c7 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -117,6 +117,8 @@ public class ProgramRewriter{
_sbRuleSet.add( new
RewriteMarkLoopVariablesUpdateInPlace() );
if( LineageCacheConfig.getCompAssRW() )
_sbRuleSet.add( new MarkForLineageReuse()
);
+ if( OptimizerUtils.ALLOW_RA_REWRITES )
+ _sbRuleSet.add( new RewriteRaPushdown()
);
_sbRuleSet.add( new
RewriteRemoveTransformEncodeMeta() );
_dagRuleSet.add( new RewriteNonScalarPrint()
);
if( OptimizerUtils.ALLOW_JOIN_REORDERING_REWRITE )
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRaPushdown.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRaPushdown.java
new file mode 100644
index 0000000000..44fa5c686e
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRaPushdown.java
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.FunctionOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.UnaryOp;
+import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.VariableSet;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Rule: Simplify program structure by rewriting relational expressions,
+ * implemented here: Pushdown of Selections before Join.
+ */
+public class RewriteRaPushdown extends StatementBlockRewriteRule
+{
+ @Override
+ public boolean createsSplitDag() {
+ return false;
+ }
+
+ @Override
+ public List<StatementBlock> rewriteStatementBlock(StatementBlock sb,
ProgramRewriteStatus state) {
+ ArrayList<StatementBlock> ret = new ArrayList<>();
+ ret.add(sb);
+ return ret;
+ }
+
+ @Override
+ public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock>
sbs, ProgramRewriteStatus state) {
+ if (sbs == null || sbs.size() <= 1)
+ return sbs;
+
+ ArrayList<StatementBlock> tmpList = new ArrayList<>(sbs);
+ boolean changed = false;
+
+ // iterate over all SBs including a FuncOp with FuncName
m_raJoin
+ for (int i : findFunctionSb(tmpList, "m_raJoin", 0)){
+ StatementBlock sb1 = tmpList.get(i);
+ FunctionOp joinOp = findFunctionOp(sb1.getHops(),
"m_raJoin");
+
+ // iterate over all following SBs including a FuncOp
with FuncName m_raSelection
+ for (int j : findFunctionSb(tmpList, "m_raSelection",
i+1)){
+ StatementBlock sb2 = tmpList.get(j);
+ FunctionOp selOp =
findFunctionOp(sb2.getHops(), "m_raSelection");
+
+ // create deep copy to ensure data consistency
+ FunctionOp tmpJoinOp = (FunctionOp)
Recompiler.deepCopyHopsDag(joinOp);
+ FunctionOp tmpSelOp = (FunctionOp)
Recompiler.deepCopyHopsDag(selOp);
+
+ if (!checkDataDependency(tmpJoinOp,
tmpSelOp)){continue;}
+
+ Hop selColHop = tmpSelOp.getInput(1);
+ long selCol =
getConstantSelectionCol(selColHop);
+ if (selCol <= 0)
+ continue;
+
+ // collect Variable Sets
+ VariableSet joinRead = new
VariableSet(sb1.variablesRead());
+ VariableSet joinUpdated = new
VariableSet(sb1.variablesUpdated());
+ VariableSet selRead = new
VariableSet(sb2.variablesRead());
+
+ // join inputs: [A, colA, B, colB, method]
+ long colsLeft =
tmpJoinOp.getInput(0).getDataCharacteristics().getCols();
+ long colsRight =
tmpJoinOp.getInput(2).getDataCharacteristics().getCols();
+ if (colsLeft <= 0 || colsRight <= 0)
+ continue;
+
+ // decide which side of inner join the
selection belongs to (A / B)
+ int selSideIdx;
+ if (selCol <= colsLeft) {
+ selSideIdx = 0;
+ }
+ else if (selCol <= colsLeft + colsRight) {
+ selSideIdx = 2;
+ LiteralOp adjustedColHop = new
LiteralOp(selCol - colsLeft);
+
adjustedColHop.setName(selColHop.getName());
+
HopRewriteUtils.replaceChildReference(tmpSelOp, selColHop, adjustedColHop, 1);
+ }
+ else { continue; } // invalid column index
+
+ // switch funcOps Output Variables
+ String joinOutVar =
tmpJoinOp.getOutputVariableNames()[0];
+ tmpJoinOp.getOutputVariableNames()[0] =
tmpSelOp.getOutputVariableNames()[0];
+ tmpSelOp.getOutputVariableNames()[0] =
joinOutVar;
+
+ // rewire selection to consume the correct join
input and adjusted column
+ Hop newSelInput =
tmpJoinOp.getInput().get(selSideIdx);
+ HopRewriteUtils.replaceChildReference(tmpSelOp,
tmpSelOp.getInput().get(0), newSelInput, 0);
+
+ // let the join take selection output instead
of raw input
+ Hop newJoinInput =
HopRewriteUtils.createTransientRead(joinOutVar, tmpSelOp);
+
HopRewriteUtils.replaceChildReference(tmpJoinOp, newSelInput, newJoinInput,
selSideIdx);
+
+ //switch StatementBlock-assignments
+ sb1.getHops().remove(joinOp);
+ sb1.getHops().add(tmpSelOp);
+ sb2.getHops().remove(selOp);
+ sb2.getHops().add(tmpJoinOp);
+
+ // modify SB- variable sets
+ VariableSet vs = new VariableSet();
+ vs.addVariable(joinOutVar,
joinUpdated.getVariable(joinOutVar));
+ selRead.removeVariables(vs);
+ selRead.addVariable(newSelInput.getName(),
joinRead.getVariable(newSelInput.getName()));
+
+ // selection now reads the original join inputs
plus its own metadata
+ sb1.setReadVariables(selRead);
+ sb1.setLiveOut(VariableSet.minus(joinUpdated,
selRead));
+ sb1.setLiveIn(selRead);
+ sb1.setGen(selRead);
+
+ // join now consumes the selection output and
produces the output
+ sb2.setReadVariables(sb1.liveOut());
+ sb2.setGen(sb1.liveOut());
+ sb2.setLiveIn(sb1.liveOut());
+
+ // mark change & increment i by 1 (i+1 = now
join-Sb)
+ changed = true;
+ i++;
+
+ LOG.debug("Applied rewrite: pushed
m_raSelection before m_raJoin (blocks lines "
+ + sb1.getBeginLine() + "-" +
sb1.getEndLine() + " and "
+ + sb2.getBeginLine() + "-" +
sb2.getEndLine() + ").");
+ }
+ }
+ return changed ? tmpList : sbs;
+ }
+
+ private List<Integer> findFunctionSb(List<StatementBlock> sbs, String
functionName, int startIdx) {
+ List<Integer> functionSbs = new ArrayList<>();
+
+ for (int i = startIdx; i < sbs.size(); i++) {
+ StatementBlock sb = sbs.get(i);
+
+ // easy preconditions
+ if (!HopRewriteUtils.isLastLevelStatementBlock(sb) ||
sb.isSplitDag()) {
+ continue;
+ }
+
+ // find if StatementBlocks have certain FunctionOp,
continue if not found
+ FunctionOp functionOp = findFunctionOp(sb.getHops(),
functionName);
+
+ // if found, add to list
+ if (functionOp != null) { functionSbs.add(i); }
+ }
+
+ return functionSbs;
+ }
+
+ private boolean checkDataDependency(FunctionOp fOut, FunctionOp fIn){
+ for (String out : fOut.getOutputVariableNames()) {
+ for (Hop h : fIn.getInput()) {
+ if (h.getName().equals(out)){
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ private FunctionOp findFunctionOp(List<Hop> roots, String functionName)
{
+ if (roots == null)
+ return null;
+ Hop.resetVisitStatus(roots, true);
+ for (Hop root : roots) {
+ if (root instanceof FunctionOp funcOp) {
+ if
(funcOp.getFunctionName().equals(functionName))
+ { return funcOp; }
+ }
+ }
+ return null;
+ }
+
+ private long getConstantSelectionCol(Hop selColHop) {
+ if (selColHop instanceof LiteralOp lit)
+ return HopRewriteUtils.getIntValueSafe(lit);
+
+ // Handle casted literals (e.g., type propagation inserted
casts)
+ if (selColHop instanceof UnaryOp uop && uop.getOp() ==
Types.OpOp1.CAST_AS_INT
+ && uop.getInput().get(0) instanceof LiteralOp
lit)
+ return HopRewriteUtils.getIntValueSafe(lit);
+
+ // If hop is a dataop whose input is a literal, try to fold
+ if (selColHop instanceof DataOp dop &&
!dop.getInput().isEmpty() && dop.getInput().get(0) instanceof LiteralOp lit)
+ return HopRewriteUtils.getIntValueSafe(lit);
+
+ return -1; // unknown at rewrite time
+ }
+}
\ No newline at end of file
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRaSelectionTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRaSelectionTest.java
new file mode 100644
index 0000000000..d9bb0f6bdc
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRaSelectionTest.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import java.util.HashMap;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests the relational-algebra selection pushdown rewrite ({@code RewriteRaPushdown}):
+ * runs raJoin followed by raSelection with and without
+ * {@link OptimizerUtils#ALLOW_RA_REWRITES} and compares the result against a
+ * precomputed expectation, for selections on both the left (A) and right (B) input.
+ */
+public class RewritePushdownRaSelectionTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME = "RewritePushdownRaSelection";
+	private static final String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownRaSelectionTest.class.getSimpleName() + "/";
+
+	private static final double eps = 1e-8;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"result"}));
+	}
+
+	/** Baseline: selection on a column of A, rewrites disabled. */
+	@Test
+	public void testRewritePushdownRaSelectionNoRewrite() {
+		int col = 1;
+		String op = Opcodes.EQUAL.toString();
+		double val = 4.0;
+
+		// Expected output matrix
+		double[][] Y = {
+			{4,7,8,4,7,8},
+			{4,7,8,4,5,10},
+			{4,3,5,4,7,8},
+			{4,3,5,4,5,10},
+		};
+
+		testRewritePushdownRaSelection(col, op, val, Y, "nested-loop", false);
+	}
+
+	/** Selection on a column of A, pushed down onto the left join input. */
+	@Test
+	public void testRewritePushdownRaSelection1() {
+		int col = 1;
+		String op = Opcodes.EQUAL.toString();
+		double val = 4.0;
+
+		// Expected output matrix
+		double[][] Y = {
+			{4,7,8,4,7,8},
+			{4,7,8,4,5,10},
+			{4,3,5,4,7,8},
+			{4,3,5,4,5,10},
+		};
+
+		testRewritePushdownRaSelection(col, op, val, Y, "sort-merge", true);
+	}
+
+	/** Selection on a column of B (col > ncol(A)), pushed down onto the right join input. */
+	@Test
+	public void testRewritePushdownRaSelection2() {
+		int col = 5;
+		String op = Opcodes.EQUAL.toString();
+		double val = 7.0;
+
+		// Expected output matrix
+		double[][] Y = {
+			{4,7,8,4,7,8},
+			{4,3,5,4,7,8},
+		};
+
+		testRewritePushdownRaSelection(col, op, val, Y, "sort-merge", true);
+	}
+
+	/** Baseline: selection on a column of B, rewrites disabled. */
+	@Test
+	public void testRewritePushdownRaSelection2NoRewrite() {
+		int col = 5;
+		String op = Opcodes.EQUAL.toString();
+		double val = 7.0;
+
+		// Expected output matrix
+		double[][] Y = {
+			{4,7,8,4,7,8},
+			{4,3,5,4,7,8},
+		};
+
+		testRewritePushdownRaSelection(col, op, val, Y, "nested-loop", false);
+	}
+
+	private void testRewritePushdownRaSelection(int col, String op, double val, double[][] Y,
+		String method, boolean rewrites) {
+
+		//generate actual dataset and variables
+		double[][] A = {
+			{1, 2, 3},
+			{4, 7, 8},
+			{1, 3, 6},
+			{4, 3, 5},
+			{5, 8, 9}
+		};
+		double[][] B = {
+			{1, 2, 9},
+			{3, 7, 6},
+			{2, 8, 5},
+			{4, 7, 8},
+			{4, 5, 10}
+		};
+		int colA = 1;
+		int colB = 1;
+
+		runRewritePushdownRaSelectionTest(A, colA, B, colB, Y, col, op, val, method, rewrites);
+	}
+
+	private void runRewritePushdownRaSelectionTest(double [][] A, int colA, double [][] B, int colB, double [][] Y,
+		int col, String op, double val, String method, boolean rewrites)
+	{
+		Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
+		boolean oldFlag = OptimizerUtils.ALLOW_RA_REWRITES;
+
+		try
+		{
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			// script arguments: $1=A, $2=colA, $3=B, $4=colB, $5=col, $6=op, $7=val, $8=method, $9=output
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[]{"-explain", "hops", "-args",
+				input("A"), String.valueOf(colA), input("B"),
+				String.valueOf(colB), String.valueOf(col), op, String.valueOf(val), method, output("result") };
+			writeInputMatrixWithMTD("A", A, true);
+			writeInputMatrixWithMTD("B", B, true);
+
+			OptimizerUtils.ALLOW_RA_REWRITES = rewrites;
+
+			// run dmlScript
+			runTest(null);
+
+			//compare matrices
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("result");
+			HashMap<CellIndex, Double> expectedOutput = TestUtils.convert2DDoubleArrayToHashMap(Y);
+			TestUtils.compareMatrices(dmlfile, expectedOutput, eps, "Stat-DML", "Expected");
+		}
+		finally {
+			// restore global state for subsequent tests
+			rtplatform = platformOld;
+			OptimizerUtils.ALLOW_RA_REWRITES = oldFlag;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/rewrite/RewritePushdownRaSelection.dml
b/src/test/scripts/functions/rewrite/RewritePushdownRaSelection.dml
new file mode 100644
index 0000000000..ef149acb35
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewritePushdownRaSelection.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Arguments: $1 = path of A, $2 = join column of A, $3 = path of B,
+# $4 = join column of B, $5 = selection column (on the join result),
+# $6 = comparison operator, $7 = comparison value, $8 = join method,
+# $9 = output path.
+A= read($1)
+colA = as.integer($2)
+B = read($3)
+colB = as.integer($4)
+op = $6
+
+C = raJoin(A, colA, B, colB, $8);
+result = raSelection(C, $5, op, $7);
+
+# the above will be rewritten (if OptimizerUtils.ALLOW_RA_REWRITES is enabled) into:
+#
+# C = raSelection(A, col, op, val);
+# result = raJoin(C, colA, B, colB, method);
+# or (if col > ncol(A), i.e., the selection column belongs to B):
+# C = raSelection(B, (col - A.cols), op, val);
+# result = raJoin(A, colA, C, colB, method);
+
+write(result, $9);
+print(toString(result))
\ No newline at end of file