This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 8307afe866 [SYSTEMDS-3921] New Rewrite for Relational Selection 
Pushdown
8307afe866 is described below

commit 8307afe8663e8006c881ab2333d83f4a7bd3fe66
Author: mor-gk <[email protected]>
AuthorDate: Sun Mar 29 11:51:44 2026 +0200

    [SYSTEMDS-3921] New Rewrite for Relational Selection Pushdown
    
    Closes #2413.
---
 .../java/org/apache/sysds/hops/OptimizerUtils.java |   7 +-
 .../apache/sysds/hops/rewrite/ProgramRewriter.java |   2 +
 .../sysds/hops/rewrite/RewriteRaPushdown.java      | 215 +++++++++++++++++++++
 .../rewrite/RewritePushdownRaSelectionTest.java    | 154 +++++++++++++++
 .../rewrite/RewritePushdownRaSelection.dml         |  40 ++++
 5 files changed, 417 insertions(+), 1 deletion(-)

diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index f9e09c852c..f2d26be535 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -335,7 +335,12 @@ public class OptimizerUtils
 
        public static boolean AUTO_GPU_CACHE_EVICTION = true;
 
-       //////////////////////
+    /**
+     * Boolean specifying if relational algebra rewrites are allowed
+     * (e.g., selection pushdowns).
+     */
+    public static boolean ALLOW_RA_REWRITES = false;
+
+    //////////////////////
        // Optimizer levels //
        //////////////////////
 
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index d84ae107d7..927433c8c7 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -117,6 +117,8 @@ public class ProgramRewriter{
                                _sbRuleSet.add(  new 
RewriteMarkLoopVariablesUpdateInPlace()     );
                        if( LineageCacheConfig.getCompAssRW() )
                                _sbRuleSet.add(  new MarkForLineageReuse()      
                 );
+            if( OptimizerUtils.ALLOW_RA_REWRITES )
+                _sbRuleSet.add(  new RewriteRaPushdown()                       
  );
                        _sbRuleSet.add(      new 
RewriteRemoveTransformEncodeMeta()          );
                        _dagRuleSet.add( new RewriteNonScalarPrint()            
             );
                        if( OptimizerUtils.ALLOW_JOIN_REORDERING_REWRITE )
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRaPushdown.java 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRaPushdown.java
new file mode 100644
index 0000000000..44fa5c686e
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRaPushdown.java
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.FunctionOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.UnaryOp;
+import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.VariableSet;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Rule: Simplify program structure by rewriting relational expressions,
+ * implemented here: Pushdown of Selections before Join.
+ */
+public class RewriteRaPushdown extends StatementBlockRewriteRule
+{
+       /**
+        * Reports whether this rule creates split DAGs; this rule always answers {@code false}.
+        *
+        * @return always {@code false}
+        */
+       @Override
+       public boolean createsSplitDag() {
+               return false;
+       }
+
+       /**
+        * Single-block entry point; this rule does nothing at that granularity and
+        * simply hands the given block back wrapped in a fresh, mutable list.
+        *
+        * @param sb    statement block to (not) rewrite
+        * @param state rewrite status, unused here
+        * @return new list whose only element is {@code sb}
+        */
+       @Override
+       public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
+               List<StatementBlock> unchanged = new ArrayList<>(1);
+               unchanged.add(sb);
+               return unchanged;
+       }
+
+       @Override
+       public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> 
sbs, ProgramRewriteStatus state) {
+               if (sbs == null || sbs.size() <= 1)
+                       return sbs;
+
+               ArrayList<StatementBlock> tmpList = new ArrayList<>(sbs);
+               boolean changed = false;
+
+               // iterate over all SBs including a FuncOp with FuncName 
m_raJoin
+               for (int i : findFunctionSb(tmpList, "m_raJoin", 0)){
+                       StatementBlock sb1 = tmpList.get(i);
+                       FunctionOp joinOp = findFunctionOp(sb1.getHops(),  
"m_raJoin");
+
+                       // iterate over all following SBs including a FuncOp 
with FuncName m_raSelection
+                       for (int j : findFunctionSb(tmpList, "m_raSelection", 
i+1)){
+                               StatementBlock sb2 = tmpList.get(j);
+                               FunctionOp selOp = 
findFunctionOp(sb2.getHops(), "m_raSelection");
+
+                               // create deep copy to ensure data consistency
+                               FunctionOp tmpJoinOp = (FunctionOp) 
Recompiler.deepCopyHopsDag(joinOp);
+                               FunctionOp tmpSelOp = (FunctionOp) 
Recompiler.deepCopyHopsDag(selOp);
+
+                               if (!checkDataDependency(tmpJoinOp, 
tmpSelOp)){continue;}
+
+                               Hop selColHop = tmpSelOp.getInput(1);
+                               long selCol = 
getConstantSelectionCol(selColHop);
+                               if (selCol <= 0)
+                                       continue;
+
+                               // collect Variable Sets
+                               VariableSet joinRead = new 
VariableSet(sb1.variablesRead());
+                               VariableSet joinUpdated = new 
VariableSet(sb1.variablesUpdated());
+                               VariableSet selRead = new 
VariableSet(sb2.variablesRead());
+
+                               // join inputs: [A, colA, B, colB, method]
+                               long colsLeft = 
tmpJoinOp.getInput(0).getDataCharacteristics().getCols();
+                               long colsRight = 
tmpJoinOp.getInput(2).getDataCharacteristics().getCols();
+                               if (colsLeft <= 0 || colsRight <= 0)
+                                       continue;
+
+                               // decide which side of inner join the 
selection belongs to (A / B)
+                               int selSideIdx;
+                               if (selCol <= colsLeft) {
+                                       selSideIdx = 0;
+                               }
+                               else if (selCol <= colsLeft + colsRight) {
+                                       selSideIdx = 2;
+                                       LiteralOp adjustedColHop = new 
LiteralOp(selCol - colsLeft);
+                                       
adjustedColHop.setName(selColHop.getName());
+                                       
HopRewriteUtils.replaceChildReference(tmpSelOp, selColHop, adjustedColHop, 1);
+                               }
+                               else { continue; } // invalid column index
+
+                               // switch funcOps Output Variables
+                               String joinOutVar = 
tmpJoinOp.getOutputVariableNames()[0];
+                               tmpJoinOp.getOutputVariableNames()[0] = 
tmpSelOp.getOutputVariableNames()[0];
+                               tmpSelOp.getOutputVariableNames()[0] = 
joinOutVar;
+
+                               // rewire selection to consume the correct join 
input and adjusted column
+                               Hop newSelInput = 
tmpJoinOp.getInput().get(selSideIdx);
+                               HopRewriteUtils.replaceChildReference(tmpSelOp, 
tmpSelOp.getInput().get(0), newSelInput, 0);
+
+                               // let the join take selection output instead 
of raw input
+                               Hop newJoinInput = 
HopRewriteUtils.createTransientRead(joinOutVar, tmpSelOp);
+                               
HopRewriteUtils.replaceChildReference(tmpJoinOp, newSelInput, newJoinInput, 
selSideIdx);
+
+                               //switch StatementBlock-assignments
+                               sb1.getHops().remove(joinOp);
+                               sb1.getHops().add(tmpSelOp);
+                               sb2.getHops().remove(selOp);
+                               sb2.getHops().add(tmpJoinOp);
+
+                               // modify SB- variable sets
+                               VariableSet vs = new VariableSet();
+                               vs.addVariable(joinOutVar, 
joinUpdated.getVariable(joinOutVar));
+                               selRead.removeVariables(vs);
+                               selRead.addVariable(newSelInput.getName(), 
joinRead.getVariable(newSelInput.getName()));
+
+                               // selection now reads the original join inputs 
plus its own metadata
+                               sb1.setReadVariables(selRead);
+                               sb1.setLiveOut(VariableSet.minus(joinUpdated, 
selRead));
+                               sb1.setLiveIn(selRead);
+                               sb1.setGen(selRead);
+
+                               // join now consumes the selection output and 
produces the output
+                               sb2.setReadVariables(sb1.liveOut());
+                               sb2.setGen(sb1.liveOut());
+                               sb2.setLiveIn(sb1.liveOut());
+
+                               // mark change & increment i by 1 (i+1 = now 
join-Sb)
+                               changed = true;
+                               i++;
+
+                               LOG.debug("Applied rewrite: pushed 
m_raSelection before m_raJoin (blocks lines "
+                                               + sb1.getBeginLine() + "-" + 
sb1.getEndLine() + " and "
+                                               + sb2.getBeginLine() + "-" + 
sb2.getEndLine() + ").");
+                       }
+               }
+               return changed ? tmpList : sbs;
+       }
+
+       /**
+        * Collects the positions of all statement blocks, starting at {@code startIdx},
+        * that are last-level, not marked as split DAG, and have a HOP root calling the
+        * given function.
+        *
+        * @param sbs          statement blocks to scan
+        * @param functionName function name to look for
+        * @param startIdx     first list position to inspect
+        * @return matching list positions in ascending order, possibly empty
+        */
+       private List<Integer> findFunctionSb(List<StatementBlock> sbs, String functionName, int startIdx) {
+               List<Integer> hits = new ArrayList<>();
+               for (int pos = startIdx; pos < sbs.size(); pos++) {
+                       StatementBlock candidate = sbs.get(pos);
+                       // easy preconditions before looking at the HOP roots
+                       boolean eligible = HopRewriteUtils.isLastLevelStatementBlock(candidate)
+                                       && !candidate.isSplitDag();
+                       if (eligible && findFunctionOp(candidate.getHops(), functionName) != null)
+                               hits.add(pos);
+               }
+               return hits;
+       }
+
+       /**
+        * Tests whether {@code fIn} consumes something {@code fOut} produces, i.e. whether
+        * any input hop of {@code fIn} carries the name of an output variable of {@code fOut}.
+        *
+        * @param fOut producing function call
+        * @param fIn  potentially consuming function call
+        * @return {@code true} if at least one output name matches an input hop name
+        */
+       private boolean checkDataDependency(FunctionOp fOut, FunctionOp fIn){
+               for (String produced : fOut.getOutputVariableNames()) {
+                       boolean consumed = fIn.getInput().stream()
+                                       .anyMatch(in -> in.getName().equals(produced));
+                       if (consumed)
+                               return true;
+               }
+               return false;
+       }
+
+       /**
+        * Looks through the given HOP roots (roots only, no descent into the DAG) for the
+        * first {@link FunctionOp} with the requested function name. The visit status of
+        * the roots is reset before the scan.
+        *
+        * @param roots        HOP DAG roots of a statement block, may be {@code null}
+        * @param functionName function name to match
+        * @return first matching function op, or {@code null} if there is none
+        */
+       private FunctionOp findFunctionOp(List<Hop> roots, String functionName) {
+               if (roots == null)
+                       return null;
+               Hop.resetVisitStatus(roots, true);
+               for (Hop candidate : roots) {
+                       if (candidate instanceof FunctionOp fop && fop.getFunctionName().equals(functionName))
+                               return fop;
+               }
+               return null;
+       }
+
+       /**
+        * Resolves the selection column to a constant if it is known at rewrite time.
+        * Three shapes are recognized: a plain literal, a {@code CAST_AS_INT} over a
+        * literal, and a data op whose first input is a literal.
+        *
+        * @param selColHop hop providing the selection column index
+        * @return integer value of the literal, or {@code -1} if no literal could be found
+        */
+       private long getConstantSelectionCol(Hop selColHop) {
+               LiteralOp constant = null;
+               if (selColHop instanceof LiteralOp direct) {
+                       constant = direct;
+               }
+               else if (selColHop instanceof UnaryOp castOp && castOp.getOp() == Types.OpOp1.CAST_AS_INT
+                               && castOp.getInput().get(0) instanceof LiteralOp casted) {
+                       // casted literal (e.g., a cast inserted by type propagation)
+                       constant = casted;
+               }
+               else if (selColHop instanceof DataOp dataOp && !dataOp.getInput().isEmpty()
+                               && dataOp.getInput().get(0) instanceof LiteralOp wrapped) {
+                       // data op sitting directly on top of a literal
+                       constant = wrapped;
+               }
+               // -1 signals: unknown at rewrite time
+               return (constant == null) ? -1 : HopRewriteUtils.getIntValueSafe(constant);
+       }
+}
\ No newline at end of file
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRaSelectionTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRaSelectionTest.java
new file mode 100644
index 0000000000..d9bb0f6bdc
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRaSelectionTest.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import java.util.HashMap;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class RewritePushdownRaSelectionTest extends AutomatedTestBase
+{
+       private static final String TEST_NAME = "RewritePushdownRaSelection";
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewritePushdownRaSelectionTest.class.getSimpleName() + "/";
+
+       private static final double eps = 1e-8;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"result"}));
+       }
+
+       /** Baseline with the rewrite switched off: nested-loop join, then selection on column 1 == 4. */
+       @Test
+       public void testRewritePushdownRaSelectionNoRewrite() {
+               final double[][] expected = new double[][] {
+                               {4,7,8,4,7,8},
+                               {4,7,8,4,5,10},
+                               {4,3,5,4,7,8},
+                               {4,3,5,4,5,10},
+               };
+               testRewritePushdownRaSelection(1, Opcodes.EQUAL.toString(), 4.0, expected, "nested-loop", false);
+       }
+
+       /** Rewrite enabled, sort-merge join: selection on column 1 == 4, a column of the left input. */
+       @Test
+       public void testRewritePushdownRaSelection1() {
+               final double[][] expected = new double[][] {
+                               {4,7,8,4,7,8},
+                               {4,7,8,4,5,10},
+                               {4,3,5,4,7,8},
+                               {4,3,5,4,5,10},
+               };
+               testRewritePushdownRaSelection(1, Opcodes.EQUAL.toString(), 4.0, expected, "sort-merge", true);
+       }
+
+       /**
+        * Rewrite enabled, sort-merge join: selection on column 5 == 7. With the
+        * three-column left input used by this test, column 5 belongs to the right input.
+        */
+       @Test
+       public void testRewritePushdownRaSelection2() {
+               final double[][] expected = new double[][] {
+                               {4,7,8,4,7,8},
+                               {4,3,5,4,7,8},
+               };
+               testRewritePushdownRaSelection(5, Opcodes.EQUAL.toString(), 7.0, expected, "sort-merge", true);
+       }
+
+       private void testRewritePushdownRaSelection(int col, String op, double 
val, double[][] Y,
+                                                                               
                String method, boolean rewrites) {
+
+               //generate actual dataset and variables
+               double[][] A = {
+                               {1, 2, 3},
+                               {4, 7, 8},
+                               {1, 3, 6},
+                               {4, 3, 5},
+                               {5, 8, 9}
+               };
+               double[][] B = {
+                               {1, 2, 9},
+                               {3, 7, 6},
+                               {2, 8, 5},
+                               {4, 7, 8},
+                               {4, 5, 10}
+               };
+               int colA = 1;
+               int colB = 1;
+
+               runRewritePushdownRaSelectionTest(A, colA, B, colB, Y, col, op, 
val, method, rewrites);
+       }
+
+
+       private void runRewritePushdownRaSelectionTest(double [][] A, int colA, 
double [][] B, int colB, double [][] Y,
+                                                                               
                   int col, String op, double val, String method, boolean 
rewrites)
+       {
+               Types.ExecMode platformOld = 
setExecMode(Types.ExecMode.SINGLE_NODE);
+               boolean oldFlag = OptimizerUtils.ALLOW_RA_REWRITES;
+
+               try
+               {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{"-explain", "hops", "-args",
+                                       input("A"), String.valueOf(colA), 
input("B"),
+                                       String.valueOf(colB), 
String.valueOf(col), op, String.valueOf(val), method, output("result") };
+                       writeInputMatrixWithMTD("A", A, true);
+                       writeInputMatrixWithMTD("B", B, true);
+
+                       OptimizerUtils.ALLOW_RA_REWRITES = rewrites;
+
+                       // run dmlScript
+                       runTest(null);
+
+                       //compare matrices
+                       HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("result");
+                       HashMap<CellIndex, Double> expectedOutput = 
TestUtils.convert2DDoubleArrayToHashMap(Y);
+                       TestUtils.compareMatrices(dmlfile, expectedOutput, eps, 
"Stat-DML", "Expected");
+               }
+               finally {
+                       rtplatform = platformOld;
+                       OptimizerUtils.ALLOW_RA_REWRITES = oldFlag;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/rewrite/RewritePushdownRaSelection.dml 
b/src/test/scripts/functions/rewrite/RewritePushdownRaSelection.dml
new file mode 100644
index 0000000000..ef149acb35
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewritePushdownRaSelection.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Positional script arguments:
+#   $1  input matrix A              $2  join column of A
+#   $3  input matrix B              $4  join column of B
+#   $5  selection column (indexed over the join output)
+#   $6  selection operator          $7  selection value
+#   $8  join method                 $9  output location of the result
+A= read($1)
+colA = as.integer($2)
+B = read($3)
+colB = as.integer($4)
+op = $6
+
+# join first, then select on the join output
+C = raJoin(A, colA, B, colB, $8);
+result = raSelection(C, $5, op, $7);
+
+# the above will be rewritten into:
+#
+# C = raSelection(A, col, op, val);
+# result = raJoin(C, colA, B, colB, method);
+# or (depending on col):
+# C = raSelection(B, (col - A.cols), op, val);
+# result = raJoin(A, colA, C, colB, method);
+
+write(result, $9);
+print(toString(result))
\ No newline at end of file

Reply via email to