This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c171170b84 [SYSTEMDS-3921] Initial join ordering rewrites
c171170b84 is described below
commit c171170b84433a61bf80d9e6df58079d44a46841
Author: migraine-user <[email protected]>
AuthorDate: Sat Mar 28 15:00:49 2026 +0100
[SYSTEMDS-3921] Initial join ordering rewrites
Closes #2424.
---
scripts/builtin/raJoin.dml | 2 +-
.../java/org/apache/sysds/hops/OptimizerUtils.java | 1 +
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 4 +-
.../sysds/hops/rewrite/RewriteJoinReordering.java | 626 +++++++++++++++++++++
.../test/functions/rewrite/RewriteRaJoinTest.java | 101 ++++
src/test/scripts/functions/rewrite/raJoin.dml | 33 ++
6 files changed, 765 insertions(+), 2 deletions(-)
diff --git a/scripts/builtin/raJoin.dml b/scripts/builtin/raJoin.dml
index 7fa7572a36..5d3335277d 100644
--- a/scripts/builtin/raJoin.dml
+++ b/scripts/builtin/raJoin.dml
@@ -27,7 +27,7 @@
# A Matrix of left input data [shape: N x M]
# colA Integer indicating the column index of matrix A to execute inner
join command
# B Matrix of right left data [shape: N x M]
-# colA Integer indicating the column index of matrix B to execute inner
join command
+# colB Integer indicating the column index of matrix B to execute inner
join command
# method Join implementation method (nested-loop, sort-merge, hash, hash2)
#
------------------------------------------------------------------------------
#
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 9ba3ea3ed7..f9e09c852c 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -302,6 +302,7 @@ public class OptimizerUtils
*/
public static boolean ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
+ public static boolean ALLOW_JOIN_REORDERING_REWRITE = false;
/**
* Enable prefetch and broadcast. Prefetch asynchronously calls
acquireReadAndRelease() to trigger remote
* operations, which would otherwise make the next instruction wait
till completion. Broadcast allows
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 98534f5d8c..d84ae107d7 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -77,6 +77,7 @@ public class ProgramRewriter{
//add static HOP DAG rewrite rules
_dagRuleSet.add( new RewriteRemoveReadAfterWrite()
); //dependency: before blocksize
_dagRuleSet.add( new RewriteBlockSizeAndReblock()
);
+
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
_dagRuleSet.add( new
RewriteRemoveUnnecessaryCasts() );
if(
OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
@@ -93,7 +94,6 @@ public class ProgramRewriter{
if( OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE )
_dagRuleSet.add( new
RewriteQuantizationFusedCompression() );
-
//add statement block rewrite rules
if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
_sbRuleSet.add( new
RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding
@@ -119,6 +119,8 @@ public class ProgramRewriter{
_sbRuleSet.add( new MarkForLineageReuse()
);
_sbRuleSet.add( new
RewriteRemoveTransformEncodeMeta() );
_dagRuleSet.add( new RewriteNonScalarPrint()
);
+ if( OptimizerUtils.ALLOW_JOIN_REORDERING_REWRITE )
+ _sbRuleSet.add( new RewriteJoinReordering() );
}
// DYNAMIC REWRITES (which do require size information)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteJoinReordering.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteJoinReordering.java
new file mode 100644
index 0000000000..39192d8df2
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteJoinReordering.java
@@ -0,0 +1,626 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+package org.apache.sysds.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.FunctionOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.VariableSet;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.parser.DataIdentifier;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+
+public class RewriteJoinReordering extends StatementBlockRewriteRule {
+	// Thrown when the base relations a join depends on cannot be determined.
+	// The rewrite catches it and skips the affected join instead of failing.
+	// Declared static so that instances do not capture the enclosing rule.
+	@SuppressWarnings("serial")
+	private static final class UnknownCanonicalJoinException extends RuntimeException {
+		private UnknownCanonicalJoinException() {
+			super();
+		}
+	}
+
+	// Thrown when the dimension information of a non-raJoin HOP cannot be
+	// determined. The rewrite catches it and skips the affected join.
+	// Declared static so that instances do not capture the enclosing rule.
+	@SuppressWarnings("serial")
+	private static final class UnknownDimensionInfoException extends RuntimeException {
+		private UnknownDimensionInfoException() {
+			super();
+		}
+	}
+
+	/**
+	 * Checks whether a hop is a call to the builtin raJoin function.
+	 *
+	 * @param node hop to inspect
+	 * @return true if node is a FunctionOp in namespace ".builtinNS" named "m_raJoin"
+	 */
+	private boolean isRaJoin(Hop node) {
+		if (!(node instanceof FunctionOp call))
+			return false;
+		boolean inBuiltinNamespace = call.getFunctionNamespace().equals(".builtinNS");
+		return inBuiltinNamespace && call.getFunctionName().equals("m_raJoin");
+	}
+
+	/**
+	 * Checks whether a hop is a literal of value type INT64.
+	 *
+	 * @param node hop to inspect
+	 * @return true if node is a LiteralOp whose value type is INT64
+	 */
+	private boolean isLiteralInt(Hop node) {
+		return (node instanceof LiteralOp) && node.getValueType() == ValueType.INT64;
+	}
+
+	/**
+	 * Checks whether both dimensions of a hop are known, i.e., positive.
+	 *
+	 * @param hop hop to inspect
+	 * @return true if the number of rows and the number of columns are both > 0
+	 */
+	private boolean isKnownMatrix(Hop hop) {
+		if (hop.getDim1() <= 0)
+			return false;
+		return hop.getDim2() > 0;
+	}
+
+ // Always returns false.
+ // NOTE(review): presumably this tells the rewriter that the rule never splits
+ // a DAG into several statement blocks - confirm against the contract of
+ // StatementBlockRewriteRule.
+ @Override
+ public boolean createsSplitDag() {
+ return false;
+ }
+
+	/**
+	 * Recursively collects all raJoin calls reachable from a statement block.
+	 * Function, while, if/else and for (incl. parfor) blocks are descended into;
+	 * for every other block the root hops are scanned for raJoin calls.
+	 *
+	 * @param hopToSb a mapping from each collected raJoin hop to the statement
+	 *                block that contains it
+	 * @param sb      current statement block to search from
+	 * @param joinMap a mapping from the bound output variable name to the index
+	 *                of the join in the `joins` list
+	 * @param joins   a list to accumulate all found raJoins
+	 */
+	private void collectRaJoin(HashMap<Hop, StatementBlock> hopToSb, StatementBlock sb,
+		HashMap<String, Integer> joinMap, ArrayList<FunctionOp> joins) {
+		if (sb instanceof FunctionStatementBlock fsb) {
+			FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
+			for (StatementBlock sbi : fstmt.getBody())
+				collectRaJoin(hopToSb, sbi, joinMap, joins);
+		} else if (sb instanceof WhileStatementBlock wsb) {
+			WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
+			for (StatementBlock sbi : wstmt.getBody())
+				collectRaJoin(hopToSb, sbi, joinMap, joins);
+		} else if (sb instanceof IfStatementBlock isb) {
+			IfStatement istmt = (IfStatement) isb.getStatement(0);
+			for (StatementBlock sbi : istmt.getIfBody())
+				collectRaJoin(hopToSb, sbi, joinMap, joins);
+			for (StatementBlock sbi : istmt.getElseBody())
+				collectRaJoin(hopToSb, sbi, joinMap, joins);
+		} else if (sb instanceof ForStatementBlock forSb) { // incl parfor
+			ForStatement fstmt = (ForStatement) forSb.getStatement(0);
+			for (StatementBlock sbi : fstmt.getBody())
+				collectRaJoin(hopToSb, sbi, joinMap, joins);
+		} else { // generic (last-level)
+			ArrayList<Hop> roots = sb.getHops();
+			// defensive: a block without a hop list has nothing to collect
+			if (roots == null)
+				return;
+			// check for raJoins among the root hops of this block
+			for (Hop hop : roots) {
+				if (isRaJoin(hop))
+					processRaJoin(sb, hopToSb, (FunctionOp) hop, joinMap, joins);
+			}
+		}
+	}
+
+	/**
+	 * Registers an raJoin hop in the intermediate collections, provided that
+	 * it has the shape the rewrite can handle: five inputs (A, colA, B, colB,
+	 * method) where both column indices are integer literals and the method
+	 * is a literal. Joins that do not qualify are silently ignored.
+	 *
+	 * @param sb      statement block that contains the raJoin hop
+	 * @param hopToSb a mapping from each collected raJoin hop to the statement
+	 *                block that contains it
+	 * @param fhop    the raJoin hop
+	 * @param joinMap a mapping from the bound output variable name to the index
+	 *                of the join in the `joins` list
+	 * @param joins   a list to accumulate all found raJoins
+	 */
+	private void processRaJoin(StatementBlock sb, HashMap<Hop, StatementBlock> hopToSb, FunctionOp fhop,
+		HashMap<String, Integer> joinMap, ArrayList<FunctionOp> joins) {
+		// later stages read inputs 0..4 unconditionally
+		if (fhop.getInput().size() < 5)
+			return;
+		Hop acol = fhop.getInput(1);
+		Hop bcol = fhop.getInput(3);
+		Hop method = fhop.getInput(4);
+		// only support literal values, including the method that is later
+		// read via a cast to LiteralOp
+		if (!isLiteralInt(acol) || !isLiteralInt(bcol) || !(method instanceof LiteralOp))
+			return;
+
+		for (String varName : fhop.getOutputVariableNames())
+			joinMap.put(varName, joins.size());
+		joins.add(fhop);
+		hopToSb.put(fhop, sb);
+	}
+
+	/**
+	 * Find the topological order of all joins. Every join index appears
+	 * exactly once, and a join that consumes the output of another join
+	 * appears before that producing join.
+	 *
+	 * @param joinMap a mapping from the bound output variable name to the index
+	 *                of the join in the `joins` list
+	 * @param joins   all collected raJoins
+	 * @return the topological order of joins as indices of `joins`
+	 */
+	private ArrayList<Integer> topoOrder(HashMap<String, Integer> joinMap, ArrayList<FunctionOp> joins) {
+		ArrayList<Integer> topoOrder = new ArrayList<>();
+		boolean[] visited = new boolean[joins.size()];
+		for (int i = 0; i < joins.size(); i++) {
+			// joins already reached through a consumer must not be emitted twice,
+			// otherwise a producer can end up in front of its consumer
+			if (!visited[i])
+				dfsOrder(joinMap, joins, topoOrder, visited, i);
+		}
+		// dfsOrder emits producers first; reverse to get consumers first
+		Collections.reverse(topoOrder);
+		return topoOrder;
+	}
+
+	/**
+	 * DFS step of the topological sort: visits the joins that produce the
+	 * inputs of join i, then appends i itself.
+	 *
+	 * @param joinMap a mapping from the bound output variable name to the index
+	 *                of the join in the `joins` list
+	 * @param joins   all collected raJoins
+	 * @param order   output list the visited join indices are appended to
+	 * @param visited marks the joins that have been entered already
+	 * @param i       the current join index we are at
+	 */
+	private void dfsOrder(HashMap<String, Integer> joinMap, ArrayList<FunctionOp> joins,
+		ArrayList<Integer> order, boolean[] visited, int i) {
+		visited[i] = true;
+		FunctionOp current = joins.get(i);
+		// left (A) and right (B) matrix inputs of the join
+		Hop[] operands = { current.getInput(0), current.getInput(2) };
+		for (Hop operand : operands) {
+			// base matrices end the recursion
+			if (isKnownMatrix(operand))
+				continue;
+			Integer producer = joinMap.get(operand.getName());
+			if (producer == null)
+				throw new UnknownCanonicalJoinException();
+			if (!visited[producer])
+				dfsOrder(joinMap, joins, order, visited, producer);
+		}
+		order.add(i);
+	}
+
+	/**
+	 * Rewrites all root joins in the given order and afterwards removes the
+	 * joins that were marked as visited from their statement blocks.
+	 *
+	 * @param sbs     statement block list handed on to rewriteRoot
+	 * @param hopToSb a mapping from each raJoin hop to its statement block
+	 * @param joinMap a mapping from the bound output variable name to the index
+	 *                of the join in the `joins` list
+	 * @param joins   all raJoins
+	 * @param order   topological order of joins
+	 */
+	private void rewriteRoots(ArrayList<StatementBlock> sbs, HashMap<Hop, StatementBlock> hopToSb,
+		HashMap<String, Integer> joinMap, ArrayList<FunctionOp> joins, ArrayList<Integer> order) {
+		boolean[] visited = new boolean[joins.size()];
+		for (int rootIdx : order) {
+			if (visited[rootIdx])
+				continue;
+			try {
+				rewriteRoot(sbs, hopToSb, joinMap, joins, visited, rootIdx);
+			} catch (UnknownCanonicalJoinException | UnknownDimensionInfoException skipped) {
+				// local failure: try rewriting the next root, anything else propagates
+			}
+		}
+
+		// detach every visited join from its inputs
+		HashSet<Hop> consumedHops = new HashSet<>();
+		for (int idx = 0; idx < joins.size(); idx++) {
+			if (visited[idx]) {
+				consumedHops.add(joins.get(idx));
+				HopRewriteUtils.cleanupUnreferenced(joins.get(idx));
+			}
+		}
+		// and drop it from the root hops of its statement block
+		for (var entry : hopToSb.entrySet()) {
+			if (consumedHops.contains(entry.getKey()))
+				entry.getValue().getHops().remove(entry.getKey());
+		}
+	}
+
+	// Custom representation of nested join calls: a plan is either a single
+	// base relation or a binary join of two sub-plans.
+	sealed interface JoinNode permits BaseNode, BinaryNode {}
+
+	// Leaf of a plan; i is the index of the base relation in the bases list.
+	private record BaseNode(int i) implements JoinNode {}
+
+	// Join of two sub-plans on leftCol/rightCol using the given join method.
+	private record BinaryNode(JoinNode left, long leftCol, JoinNode right, long rightCol, String method)
+		implements JoinNode {}
+
+	// Output dimensions and accumulated cost of the plan stored in node.
+	private record Cost(long dim1, long dim2, long cost, JoinNode node) {}
+
+ // Rewrite a single root
+ private void rewriteRoot(ArrayList<StatementBlock> sbs, HashMap<Hop,
StatementBlock> hopToSb,
+ HashMap<String, Integer> joinMap, ArrayList<FunctionOp>
joins, boolean[] visited, int rootIndex) {
+ // get bases traversal = base relations(matrices)
+ FunctionOp root = joins.get(rootIndex);
+ ArrayList<Hop> bases = new ArrayList<>();
+ ArrayList<Long> basesLengthPrefixSum = new ArrayList<>();
+ ArrayList<CanonicalJoin> canonicalJoins = new ArrayList<>();
+ dfsInorder(joinMap, joins, canonicalJoins, visited, bases,
basesLengthPrefixSum, rootIndex);
+ // convert all joins to joins between base relations.
+ HashMap<BitSet, Cost> dp = new HashMap<>();
+ for (int i = 0; i < bases.size() - 1; i++) {
+ BitSet leftBS = new BitSet();
+ BitSet rightBS = new BitSet();
+ leftBS.set(i);
+ rightBS.set(i + 1);
+ CanonicalJoin validJoin = getValidJoin(canonicalJoins,
leftBS, rightBS);
+ if (validJoin == null) {
+ continue;
+ }
+ BitSet bs = new BitSet(bases.size());
+ bs.set(i);
+ bs.set(i + 1);
+ Hop left = bases.get(i);
+ Hop right = bases.get(i + 1);
+
+ long dim1 = left.getDim1() * right.getDim1();
+ long dim2 = left.getDim2() + right.getDim2();
+ long cost = dim1 * dim2;
+
+ long leftCol = validJoin.acol;
+ long rightCol = validJoin.bcol;
+
+ JoinNode joinNode = new BinaryNode(new BaseNode(i),
leftCol, new BaseNode(i + 1), rightCol, validJoin.method);
+ dp.put(bs, new Cost(dim1, dim2, cost, joinNode));
+ }
+ for (int intervalLength = 2; intervalLength < bases.size();
intervalLength++) {
+ // join base relation from the left
+ for (int start = 1; start + intervalLength <=
bases.size(); start++) {
+ BitSet leftBS = new BitSet(bases.size());
+ leftBS.set(start - 1);
+ BitSet rightBS = new BitSet(bases.size());
+ rightBS.set(start, start + intervalLength);
+ if (dp.get(rightBS) == null) {
+ continue;
+ }
+ CanonicalJoin validJoin =
getValidJoin(canonicalJoins, leftBS, rightBS);
+ if (validJoin == null) {
+ continue;
+ }
+
+ BitSet bs = new BitSet(bases.size());
+ bs.set(start - 1, start + intervalLength);
+
+ Hop left = bases.get(start - 1);
+
+ Cost right = dp.get(rightBS);
+
+ long dim1 = left.getDim1() * right.dim1;
+ long dim2 = left.getDim2() + right.dim2;
+ long cost = dim1 * dim2 + right.cost;
+
+ long leftCol = validJoin.acol;
+ long rightCol =
getRelativeCol(basesLengthPrefixSum, start, validJoin.bBaseIndex,
validJoin.bcol);
+ JoinNode joinNode = new BinaryNode(new
BaseNode(start - 1), leftCol, right.node, rightCol, validJoin.method);
+ dp.put(bs, new Cost(dim1, dim2, cost,
joinNode));
+ }
+ // join base relation from the right
+ for (int start = 0; start + intervalLength + 1 <=
bases.size(); start++) {
+ BitSet leftBS = new BitSet(bases.size());
+ leftBS.set(start, start + intervalLength);
+ BitSet rightBS = new BitSet(bases.size());
+ rightBS.set(start + intervalLength);
+ BitSet bs = new BitSet(bases.size());
+ bs.set(start, start + intervalLength + 1);
+
+ if (dp.get(leftBS) == null)
+ continue;
+ CanonicalJoin validJoin =
getValidJoin(canonicalJoins, leftBS, rightBS);
+ if (validJoin == null)
+ continue;
+ BitSet leftBs = new BitSet(bases.size());
+ leftBs.set(start, start + intervalLength);
+ Cost left = dp.get(leftBs);
+
+ Hop right = bases.get(start + intervalLength);
+
+ long dim1 = left.dim1 * right.getDim1();
+ long dim2 = left.dim2 + right.getDim2();
+ long cost = dim1 * dim2 + left.cost;
+
+ if (dp.get(bs) == null || cost <
dp.get(bs).cost) {
+ long leftCol =
getRelativeCol(basesLengthPrefixSum, start, validJoin.aBaseIndex,
validJoin.acol);
+ long rightCol = validJoin.bcol;
+ JoinNode joinNode = new
BinaryNode(left.node(), leftCol, new BaseNode(start + intervalLength), rightCol,
+ validJoin.method);
+ dp.put(bs, new Cost(dim1, dim2, cost,
joinNode));
+ }
+ }
+ }
+ BitSet fullBs = new BitSet(bases.size());
+ fullBs.set(0, bases.size());
+ JoinNode optimalJoin = dp.get(fullBs).node;
+ // System.out.println("optimalJoin: " + optimalJoin);
+
+ // rewire the nodes.
+ StatementBlock rootSb = hopToSb.get(root);
+ ArrayList<Hop> rootSbHops = hopToSb.get(root).getHops();
+
+ ArrayList<DataOp> intermediateWrites = new ArrayList<>();
+ Hop newHop = generateHop(root, intermediateWrites, bases,
optimalJoin);
+
+ // remove and replace root
+ for (int i = 0; i < rootSbHops.size(); i++) {
+ if (rootSbHops.get(i) == root) {
+ rootSbHops.set(i, newHop);
+ }
+ }
+ HopRewriteUtils.rewireAllParentChildReferences(root, newHop);
+
+ // remove all consumed joins that now aren't used
+ HashSet<Hop> consumed = new HashSet<>();
+ for (int j = 0; j < joins.size(); j++)
+ if (visited[j])
+ consumed.add(joins.get(j));
+
+ rootSbHops.removeIf(consumed::contains);
+
+ // rootSbHops.addAll(0,intermediateWrites);
+ // add new Sb containing TWrites to right before it is consumed
+ StatementBlock newSb = createIntermediateStatementBlock(rootSb,
intermediateWrites);
+ sbs.add(sbs.indexOf(rootSb), newSb);
+ }
+
+	// Get the column number relative to the current relation starting at
+	// `intervalStart`: subtracts the prefix sum in front of the interval and
+	// adds the prefix sum in front of base `baseIndex`.
+	long getRelativeCol(ArrayList<Long> prefixSum, int intervalStart, int baseIndex, long col) {
+		long beforeInterval = (intervalStart >= 1) ? prefixSum.get(intervalStart - 1) : 0L;
+		long beforeBase = (baseIndex >= 1) ? prefixSum.get(baseIndex - 1) : 0L;
+		return col - beforeInterval + beforeBase;
+	}
+
+	// modified from RewriteHoistLoopInvariantOperations.java
+	// Creates a last-level statement block that holds the given transient
+	// writes and registers the written variables as live-out of the new
+	// block and as live-in of the original block.
+	private StatementBlock createIntermediateStatementBlock(StatementBlock originalSb,
+		List<DataOp> intermediateWrites) {
+		// create empty last-level statement block
+		StatementBlock block = new StatementBlock();
+		block.setDMLProg(originalSb.getDMLProg());
+		block.setParseInfo(originalSb);
+		block.setLiveIn(new VariableSet(originalSb.liveIn()));
+		block.setLiveOut(new VariableSet(originalSb.liveIn()));
+
+		// put custom hops
+		block.setHops(new ArrayList<>(intermediateWrites));
+
+		// live variable analysis
+		for (DataOp write : intermediateWrites) {
+			String name = write.getName();
+			Hop source = write.getInput(0);
+			DataIdentifier meta = new DataIdentifier(name);
+			meta.setDimensions(source.getDim1(), source.getDim2());
+			meta.setBlocksize(source.getBlocksize());
+			meta.setDataType(source.getDataType());
+			meta.setValueType(source.getValueType());
+			block.liveOut().addVariable(name, meta);
+			originalSb.liveIn().addVariable(name, meta);
+		}
+
+		return block;
+	}
+
+	// Turns a function call into a transient write of its first output
+	// variable (recorded in intermediateWrites) and returns a transient read
+	// of that variable; any other hop is returned unchanged.
+	private Hop materialize(Hop hop, ArrayList<DataOp> intermediateWrites) {
+		if (hop instanceof FunctionOp call) {
+			String outName = call.getOutputVariableNames()[0];
+			DataOp write = HopRewriteUtils.createTransientWrite(outName, call);
+			intermediateWrites.add(write);
+			return HopRewriteUtils.createTransientRead(outName, call);
+		}
+		return hop;
+	}
+
+	/**
+	 * Generate the Hop to replace the existing root. Base nodes map to their
+	 * base relation; binary nodes become new m_raJoin function calls whose
+	 * non-root results are materialized as transient write/read pairs.
+	 *
+	 * @param root               root of the current rewrite if `optimalJoin`
+	 *                           corresponds to the root, otherwise null
+	 * @param intermediateWrites collects the transient writes created on the way
+	 * @param bases              base relations, indexed by BaseNode
+	 * @param optimalJoin        the current JoinNode we are constructing
+	 * @return the base relation, the new root call, or a transient read of a
+	 *         non-root join result
+	 */
+	private Hop generateHop(FunctionOp root, ArrayList<DataOp> intermediateWrites, ArrayList<Hop> bases,
+		JoinNode optimalJoin) {
+		if (optimalJoin instanceof BaseNode leaf)
+			return bases.get(leaf.i);
+		BinaryNode joinPlan = (BinaryNode) optimalJoin;
+		boolean isRoot = (root != null);
+
+		// build the inputs in argument order: A, colA, B, colB, method
+		Hop a = materialize(generateHop(null, intermediateWrites, bases, joinPlan.left), intermediateWrites);
+		Hop colA = new LiteralOp(joinPlan.leftCol);
+		Hop b = materialize(generateHop(null, intermediateWrites, bases, joinPlan.right), intermediateWrites);
+		Hop colB = new LiteralOp(joinPlan.rightCol);
+		Hop method = new LiteralOp(joinPlan.method);
+
+		ArrayList<Hop> inputs = new ArrayList<>();
+		inputs.add(a);
+		inputs.add(colA);
+		inputs.add(b);
+		inputs.add(colB);
+		inputs.add(method);
+
+		// the root keeps the outputs of the original call, inner joins
+		// write to a fresh temporary variable
+		String[] outputNames;
+		ArrayList<Hop> outputHops;
+		if (isRoot) {
+			outputNames = root.getOutputVariableNames();
+			outputHops = root.getOutputs();
+		} else {
+			outputNames = new String[] { "_rajoin_reorder_tmp_" + a.getHopID() + "_" + b.getHopID() };
+			outputHops = new ArrayList<>();
+		}
+
+		String[] inputNames = new String[] { "A", "colA", "B", "colB", "method" };
+		FunctionOp call = new FunctionOp(FunctionOp.FunctionType.DML, ".builtinNS", "m_raJoin",
+			inputNames, inputs, outputNames, outputHops);
+		call.setDim2(a.getDim2() + b.getDim2());
+		call.setDataType(DataType.MATRIX);
+		call.setValueType(ValueType.FP64);
+
+		// Return a TRead if it is not the root.
+		return isRoot ? call : materialize(call, intermediateWrites);
+	}
+
+	/**
+	 * Get a join that is applicable to left and right, i.e., the first
+	 * canonical join whose left base is contained in `left` and whose right
+	 * base is contained in `right`.
+	 *
+	 * @param canonicalJoins the joins between base relations
+	 * @param left           the bitset representing the left side of the raJoin
+	 * @param right          the bitset representing the right side of the raJoin
+	 * @return the first applicable join, or null if there is none
+	 */
+	private CanonicalJoin getValidJoin(ArrayList<CanonicalJoin> canonicalJoins, BitSet left, BitSet right) {
+		return canonicalJoins.stream()
+			.filter(candidate -> left.get(candidate.aBaseIndex) && right.get(candidate.bBaseIndex))
+			.findFirst()
+			.orElse(null);
+	}
+
+ private record IntPair(int left, int right) {
+ };
+
+ // representation of the dependencies on the bases and its indices for
a given
+ // raJoin
+ private record CanonicalJoin(int aBaseIndex, long acol, int bBaseIndex,
long bcol, String method) {
+ };
+
+ /**
+ * Inorder traversal of an raJoin
+ *
+ * @param joinMap
+ * @param joins
+ * @param canonicaljoins
+ * @return inclusive [left, right] range of the indices of `joins` that
the
+ * current join corresponds to
+ */
+ private IntPair dfsInorder(HashMap<String, Integer> joinMap,
ArrayList<FunctionOp> joins,
+ ArrayList<CanonicalJoin> cannonicalJoins, boolean[]
visited, ArrayList<Hop> bases,
+ ArrayList<Long> basesLengthPrefixSum, int i) {
+ visited[i] = true;
+ FunctionOp join = joins.get(i);
+ Hop a = join.getInput(0);
+ long acol = ((LiteralOp) join.getInput(1)).getLongValue();
+
+ Hop b = join.getInput(2);
+ long bcol = ((LiteralOp) join.getInput(3)).getLongValue();
+
+ String method = ((LiteralOp) join.getInput(4)).getStringValue();
+ IntPair aPair;
+ if (isKnownMatrix(a)) {
+ bases.add(a);
+ basesLengthPrefixSum
+ .add((basesLengthPrefixSum.size() > 0 ?
basesLengthPrefixSum.get(basesLengthPrefixSum.size() - 1) : 0)
+ + a.getDim2());
+ aPair = new IntPair(bases.size() - 1, bases.size() - 1);
+ } else {
+ Integer aIndex = joinMap.get(a.getName());
+ if (aIndex == null)
+ throw new UnknownDimensionInfoException();
+ aPair = dfsInorder(joinMap, joins, cannonicalJoins,
visited, bases, basesLengthPrefixSum, aIndex);
+ }
+ IntPair bPair;
+ if (isKnownMatrix(b)) {
+ bases.add(b);
+ basesLengthPrefixSum
+ .add((basesLengthPrefixSum.size() > 0 ?
basesLengthPrefixSum.get(basesLengthPrefixSum.size() - 1) : 0)
+ + b.getDim2());
+ bPair = new IntPair(bases.size() - 1, bases.size() - 1);
+ } else {
+ Integer bIndex = joinMap.get(b.getName());
+ if (bIndex == null)
+ throw new UnknownDimensionInfoException();
+ bPair = dfsInorder(joinMap, joins, cannonicalJoins,
visited, bases, basesLengthPrefixSum, bIndex);
+ }
+ int aBaseIndex = -1;
+ for (int j = aPair.left; j <= aPair.right; j++) {
+ if (acol <= basesLengthPrefixSum.get(j)) {
+ // if (j - 1 >= 0) acol -=
basesLengthPrefixSum.get(j-1);
+ aBaseIndex = j;
+ break;
+ }
+ }
+
+ int bBaseIndex = -1;
+ for (int j = bPair.left; j <= bPair.right; j++) {
+ if (bcol <= basesLengthPrefixSum.get(j)) {
+ // if (j - 1 >= 0) bcol -=
basesLengthPrefixSum.get(j-1);
+ bBaseIndex = j;
+ break;
+ }
+ }
+ acol = getRelativeCol(basesLengthPrefixSum, aPair.left,
aBaseIndex, acol);
+ bcol = getRelativeCol(basesLengthPrefixSum, bPair.left,
bBaseIndex, bcol);
+
+ // throw an error and do not rewrite if we cannot figure out
the dependencies.
+ if (aBaseIndex < 0 || bBaseIndex < 0) {
+ throw new UnknownCanonicalJoinException();
+ }
+ cannonicalJoins.add(new CanonicalJoin(aBaseIndex, acol,
bBaseIndex, bcol, method));
+ return new IntPair(aPair.left, bPair.right);
+ }
+
+	// No per-block rewrite: returns a new list that contains only the given block.
+	@Override
+	public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
+		return new ArrayList<>(Collections.singletonList(sb));
+	}
+
+	/**
+	 * Collects all raJoin calls of the given statement blocks and rewrites
+	 * their join order. The rewrite works on a copy of the given list and
+	 * returns that copy, because rewriting a root inserts a new statement
+	 * block (holding the writes of intermediate join results) into it.
+	 *
+	 * @param sbs   statement blocks to rewrite
+	 * @param state rewrite status (not used by this rule)
+	 * @return the statement blocks including all inserted blocks
+	 */
+	@Override
+	public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
+		HashMap<Hop, StatementBlock> hopToSb = new HashMap<>();
+		HashMap<String, Integer> joinMap = new HashMap<>();
+		ArrayList<FunctionOp> joins = new ArrayList<>();
+		ArrayList<StatementBlock> sbsA = new ArrayList<>(sbs);
+		for (StatementBlock sb : sbsA)
+			collectRaJoin(hopToSb, sb, joinMap, joins);
+		try {
+			ArrayList<Integer> order = topoOrder(joinMap, joins);
+			rewriteRoots(sbsA, hopToSb, joinMap, joins, order);
+		} catch (UnknownCanonicalJoinException | UnknownDimensionInfoException skipped) {
+			// the joins cannot be analyzed: leave them as they are,
+			// any other exception propagates
+		}
+		// return the working copy, otherwise the inserted blocks are lost
+		return sbsA;
+	}
+}
\ No newline at end of file
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRaJoinTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRaJoinTest.java
new file mode 100644
index 0000000000..f05970e312
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRaJoinTest.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/**
+ * Runs the raJoin test script on three small input matrices and compares
+ * the result of the nested join against the expected matrix.
+ */
+public class RewriteRaJoinTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "raJoin";
+	private final static String TEST_DIR = "functions/rewrite/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + RewriteRaJoinTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "OUT" });
+		addTestConfiguration(TEST_NAME, config);
+	}
+
+	@Test
+	public void testRaJoin() {
+		// Load test configuration (sets up temp input/output folders)
+		getAndLoadTestConfiguration(TEST_NAME);
+
+		// Create inputs and write them together with their metadata
+		double[][] A = { { 1, 1 }, { 1, 1 } };
+		double[][] B = { { 3, 2, 1 }, { 1, 2, 3 }, { 3, 1, 2 } };
+		double[][] C = { { 1, 1, 1, 1 }, { 2, 2, 2, 2 }, { 3, 3, 3, 3 }, { 4, 4, 4, 4 } };
+		writeInputMatrixWithMTD("A", A, true, new MatrixCharacteristics(2, 2, -1, -1));
+		writeInputMatrixWithMTD("B", B, true, new MatrixCharacteristics(3, 3, -1, -1));
+		writeInputMatrixWithMTD("C", C, true, new MatrixCharacteristics(4, 4, -1, -1));
+
+		programArgs = new String[] { "-explain", "hops", "-stats", "-args",
+			input("A"), input("B"), input("C"), output("OUT") };
+		fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml";
+
+		// Execute single threaded
+		ExecMode oldPlatform = setExecMode(ExecMode.SINGLE_NODE);
+		try {
+			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+			HashMap<CellIndex, Double> actual = readDMLMatrixFromOutputDir("OUT");
+			System.out.println("Result matrix:");
+			for (CellIndex cell : actual.keySet())
+				System.out.println(cell + " -> " + actual.get(cell));
+
+			double[][] expected = {
+				{ 1, 1, 3, 2, 1, 2, 2, 2, 2 },
+				{ 1, 1, 3, 2, 1, 2, 2, 2, 2 }
+			};
+			HashMap<CellIndex, Double> expectedMap = TestUtils.convert2DDoubleArrayToHashMap(expected);
+			TestUtils.compareMatrices(expectedMap, actual, 1e-10, "expected", "actual");
+		} finally {
+			rtplatform = oldPlatform;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/rewrite/raJoin.dml
b/src/test/scripts/functions/rewrite/raJoin.dml
new file mode 100644
index 0000000000..f0dbc68d9c
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/raJoin.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Test script: joins three input matrices with two nested raJoin calls
+# and writes the result to the path given as the fourth argument.
+# A = matrix(2, rows=2,cols=2)
+# B = matrix(3, rows=3,cols=3)
+# C = matrix(4, rows=4,cols=4)
+# Inputs are read from the paths given as arguments 1 to 3.
+A = read($1)
+B = read($2)
+C = read($3)
+# A = matrix("1 1 1 1", rows=2, cols=2)
+# B = matrix("3 2 1 1 2 3 3 1 2", rows=3, cols=3)
+# C = matrix("1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4", rows=4, cols=4)
+# Inner join: column 2 of B with column 3 of C.
+# Outer join: column 1 of A with column 3 of the inner join result.
+ans = raJoin(A, 1, raJoin(B, 2, C, 3, "nested-loop"), 3, "nested-loop")
+# ans = raJoin(raJoin(A,1,B,3, "nested-loop"),4,C,3, "nested-loop")
+write(ans, $4)
\ No newline at end of file