This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 195f83ba48 [SYSTEMDS-3708] Additional hash join method for raJoin builtin
195f83ba48 is described below
commit 195f83ba480455a8db0a8c9a5a7a601e62d0ab5d
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Jul 23 19:46:25 2024 +0200
[SYSTEMDS-3708] Additional hash join method for raJoin builtin
This patch adds broader tests for all join methods (e.g., nested-loop,
sort-merge) and adds a new hash join method that performs the join
without sorting or loop-based concatenation. The hash join currently
supports only one-to-many joins, i.e., the join keys of the left input
must be unique.
---
scripts/builtin/raJoin.dml | 24 ++++
.../functions/builtin/part2/BuiltinRaJoinTest.java | 128 ++++++++++++++++++---
src/test/scripts/functions/builtin/raJoin.dml | 4 +-
3 files changed, 137 insertions(+), 19 deletions(-)
diff --git a/scripts/builtin/raJoin.dml b/scripts/builtin/raJoin.dml
index b7b299e7a8..bddf94c019 100644
--- a/scripts/builtin/raJoin.dml
+++ b/scripts/builtin/raJoin.dml
@@ -124,6 +124,30 @@ m_raJoin = function (Matrix[Double] A, Integer colA,
Matrix[Double] B,
Y = matrix(0, rows=0, cols=1)
}
}
+ else if( method == "hash" ) {
+ # Ensure histograms are aligned by creating a common set of keys
+ commonKeys = max(max(A[,colA]), max(B[,colB]));
+
+ # Build histograms for the left and right key columns
+ leftHist = table(A[,colA], 1, commonKeys, 1)
+ rightHist = table(B[,colB], 1, commonKeys, 1)
+ hist = leftHist * rightHist;
+
+ # Check for one-to-many
+ if( max(leftHist)>1 )
+ stop("Hash join implementation only supports one-to-many joins:
"+toString(leftHist));
+
+ # Compute selection matrices P1 (one-side) with row duplication
+ keyPos1 = rowIndexMax(table(A[,colA], seq(1,nrow(A)), commonKeys, nrow(A)))
+ keyPos1 = removeEmpty(target=keyPos1, margin="rows", select=hist);
+ hist = removeEmpty(target=hist, margin="rows");
+ I1 = t(cumsum(rev(t(table(seq(1,nrow(hist)),hist))))) * keyPos1
+ I1 = removeEmpty(target=matrix(I1, nrow(I1)*ncol(I1),1), margin="rows"); #
keys
+ P1 = table(seq(1,nrow(I1)), I1);
+
+ # Select left rows and concatenate right rows
+ Y = cbind(P1 %*% A, B);
+ }
}
# Function to perform parallel binary search
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
index f19262b854..e8f803633f 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
@@ -41,10 +41,78 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
addTestConfiguration(TEST_NAME,new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"result"}));
}
- // TODO test all join methods
+ @Test
+ public void testRaJoinTest1() {
+ testRaJoinTest("nested-loop");
+ }
+
+ @Test
+ public void testRaJoinTest2() {
+ testRaJoinTest("sort-merge");
+ }
+
+ @Test
+ public void testRaJoinTestwithDifferentColumn1() {
+ testRaJoinTestwithDifferentColumn("nested-loop");
+ }
+
+ @Test
+ public void testRaJoinTestwithDifferentColumn2() {
+ testRaJoinTestwithDifferentColumn("sort-merge");
+ }
+
+ @Test
+ public void testRaJoinTestwithDifferentColumn21() {
+ testRaJoinTestwithDifferentColumn2("nested-loop");
+ }
+
+ @Test
+ public void testRaJoinTestwithDifferentColumn22() {
+ testRaJoinTestwithDifferentColumn2("sort-merge");
+ }
+
+ @Test
+ public void testRaJoinTestwithNoMatchingRows1() {
+ testRaJoinTestwithNoMatchingRows("nested-loop");
+ }
+
+ @Test
+ public void testRaJoinTestwithNoMatchingRows2() {
+ testRaJoinTestwithNoMatchingRows("sort-merge");
+ }
+
+ @Test
+ public void testRaJoinTestwithAllMatchingRows1() {
+ testRaJoinTestwithAllMatchingRows("nested-loop");
+ }
+
+ @Test
+ public void testRaJoinTestwithAllMatchingRows2() {
+ testRaJoinTestwithAllMatchingRows("sort-merge");
+ }
@Test
- public void testRaJoinTest() {
+ public void testRaJoinTestwithAllMatchingRows3() {
+ testRaJoinTestwithAllMatchingRows("hash");
+ }
+
+ @Test
+ public void testRaJoinTestwithOneToMany1() {
+ testRaJoinTestwithOneToMany("nested-loop");
+ }
+
+ @Test
+ public void testRaJoinTestwithOneToMany2() {
+ testRaJoinTestwithOneToMany("sort-merge");
+ }
+
+ @Test
+ public void testRaJoinTestwithOneToMany3() {
+ testRaJoinTestwithOneToMany("hash");
+ }
+
+
+ private void testRaJoinTest(String method) {
//generate actual dataset and variables
double[][] A = {
{1, 2, 3},
@@ -72,11 +140,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
{4, 3, 5, 4, 7, 8},
{4, 3, 5, 4, 5, 10},
};
- runRaJoinTest(A, colA, B, colB, Y);
+ runRaJoinTest(A, colA, B, colB, Y, method);
}
- @Test
- public void testRaJoinTestwithDifferentColumn() {
+ private void testRaJoinTestwithDifferentColumn(String method) {
// Generate actual dataset and variables
double[][] A = {
{1, 5, 3},
@@ -100,11 +167,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
{2, 6, 8, 3, 7, 6},
{3, 7, 6, 2, 8, 7}
};
- runRaJoinTest(A, colA, B, colB, Y);
+ runRaJoinTest(A, colA, B, colB, Y, method);
}
- @Test
- public void testRaJoinTestwithDifferentColumn2() {
+ private void testRaJoinTestwithDifferentColumn2(String method) {
// Generate actual dataset and variables
double[][] A = {
{1, 2, 3, 4, 5},
@@ -127,11 +193,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
{6, 7, 8, 9, 10, 1, 10, 200},
{21, 22, 23, 24, 25, 50, 25, 500}
};
- runRaJoinTest(A, colA, B, colB, Y);
+ runRaJoinTest(A, colA, B, colB, Y, method);
}
- @Test
- public void testRaJoinTestwithNoMatchingRows() {
+ private void testRaJoinTestwithNoMatchingRows(String method) {
// Generate actual dataset and variables
double[][] A = {
{1, 2, 3},
@@ -148,11 +213,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
// Expected output matrix (no matching rows)
double[][] Y = {};
- runRaJoinTest(A, colA, B, colB, Y);
+ runRaJoinTest(A, colA, B, colB, Y, method);
}
- @Test
- public void testRaJoinTestwithAllMatchingRows() {
+ private void testRaJoinTestwithAllMatchingRows(String method) {
// Generate actual dataset and variables
double[][] A = {
{1, 2, 3},
@@ -173,10 +237,39 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
{2, 3, 4, 2, 3, 7},
{3, 4, 5, 3, 4, 8}
};
- runRaJoinTest(A, colA, B, colB, Y);
+ runRaJoinTest(A, colA, B, colB, Y, method);
+ }
+
+ private void testRaJoinTestwithOneToMany(String method) {
+ // Generate actual dataset and variables
+ double[][] A = {
+ {2, 2, 2},
+ {3, 3, 3},
+ {4, 4, 4}
+ };
+ double[][] B = {
+ {2, 1, 1},
+ {2, 2, 2},
+ {3, 1, 1},
+ {3, 2, 2},
+ {3, 3, 3},
+ {4, 1, 1}
+ };
+ int colA = 1;
+ int colB = 1;
+
+ double[][] Y = {
+ {2, 2, 2, 2, 1, 1},
+ {2, 2, 2, 2, 2, 2},
+ {3, 3, 3, 3, 1, 1},
+ {3, 3, 3, 3, 2, 2},
+ {3, 3, 3, 3, 3, 3},
+ {4, 4, 4, 4, 1, 1}
+ };
+ runRaJoinTest(A, colA, B, colB, Y, method);
}
- private void runRaJoinTest(double [][] A, int colA, double [][] B, int
colB, double [][] Y)
+ private void runRaJoinTest(double [][] A, int colA, double [][] B, int
colB, double [][] Y, String method)
{
ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
@@ -187,7 +280,8 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[]{"-stats", "-args",
- input("A"), String.valueOf(colA), input("B"),
String.valueOf(colB), output("result") };
+ input("A"), String.valueOf(colA), input("B"),
+ String.valueOf(colB), method, output("result")
};
System.out.println(Arrays.deepToString(A));
System.out.println(colA);
//fullRScriptName = HOME + TEST_NAME + ".R";
diff --git a/src/test/scripts/functions/builtin/raJoin.dml
b/src/test/scripts/functions/builtin/raJoin.dml
index 63aa2807c8..eedab1433e 100644
--- a/src/test/scripts/functions/builtin/raJoin.dml
+++ b/src/test/scripts/functions/builtin/raJoin.dml
@@ -24,6 +24,6 @@ colA = as.integer($2)
B = read($3)
colB = as.integer($4)
-result = raJoin(A, colA, B, colB, "sort-merge");
-write(result, $5);
+result = raJoin(A, colA, B, colB, $5);
+write(result, $6);