This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new dc8e36db8c [SYSTEMDS-3708] Additional hash2 join method in raJoin
builtin
dc8e36db8c is described below
commit dc8e36db8cf332c2f74511c56a39a13a31932af4
Author: gghsu <[email protected]>
AuthorDate: Fri Aug 16 12:53:58 2024 +0200
[SYSTEMDS-3708] Additional hash2 join method in raJoin builtin
Closes #2056.
---
scripts/builtin/raJoin.dml | 78 +++++++++++++++++++++-
.../functions/builtin/part2/BuiltinRaJoinTest.java | 23 +++++++
2 files changed, 100 insertions(+), 1 deletion(-)
diff --git a/scripts/builtin/raJoin.dml b/scripts/builtin/raJoin.dml
index 2750b598f3..bd0ff9d91d 100644
--- a/scripts/builtin/raJoin.dml
+++ b/scripts/builtin/raJoin.dml
@@ -28,7 +28,7 @@
# colA Integer indicating the column index of matrix A to execute inner
join command
# B Matrix of right left data [shape: N x M]
# colA Integer indicating the column index of matrix B to execute inner
join command
-# method Join implementation method (nested-loop, sort-merge, hash)
+# method Join implementation method (nested-loop, sort-merge, hash, hash2)
#
------------------------------------------------------------------------------
#
# OUTPUT:
@@ -59,6 +59,7 @@ m_raJoin = function (Matrix[Double] A, Integer colA,
Matrix[Double] B,
}
}
}
+ # The sort-merge method is from original paper: Query Processing on Tensor
Computation Runtime, section 5-2
else if (method == "sort-merge") {
# get join key columns
left = A[, colA]
@@ -143,6 +144,81 @@ m_raJoin = function (Matrix[Double] A, Integer colA,
Matrix[Double] B,
# Select left rows and concatenate right rows
Y = cbind(P1 %*% A, B);
}
+ # The hash2 method is from the original paper: Query Processing on Tensor
Computation Runtime, section 5-3
+ else if ( method == "hash2" ) {
+ # Get join key columns
+ left = A[,colA]
+ right = B [,colB]
+
+ # Compute indexes and hash values
+ leftIdx = seq(1, nrow(A))
+ rightIdx = seq(1, nrow(B))
+ m = max(max(left),max(right)) + 1; # Assuming a large hash table size
+ #m = 100
+ leftHash = left %% m
+ rightHash = right %% m
+
+ # Build histogram of hash values for left join keys
+ hashBincount = table( leftHash, 1, max(leftHash), 1 )
+
+ #Initialize output indexes
+ leftOutIdx = matrix(0,0,1)
+ rightOutIdx = matrix(0,0,1)
+
+ # Check for one-to-many
+ if( max(hashBincount) > 1 )
+ stop("Hash join implementation only supports one-to-many joins:
"+toString(hashBincount))
+
+ # Build and probe hash table
+ # Initialize hash table
+ hashTable=matrix(0,m,1)
+
+ # Create a selector matrix and use matrix multiplication to place values
+ hashTable = t(table(seq(1,nrow(leftIdx)), leftHash, nrow(leftIdx),
nrow(hashTable))) %*% leftIdx
+
+ # Update leftHash to skip scattered values for future iterations by setting
their hashes to m
+ leftIdxSct = removeEmpty(target=seq(1,nrow(hashTable)), margin="rows",
select=(hashTable>=1))
+ selectedMatrix = table(seq(1, nrow(leftIdxSct)), leftIdxSct,
nrow(leftIdxSct), nrow(hashTable))
+ leftHash = t(selectedMatrix) %*% matrix(m, rows=nrow(leftIdxSct), cols=1,
byrow=TRUE)
+
+ #Probe hash table and get the left and right indexes
+ validLeftIdx = matrix(0,0,1)
+ validRightIdx = matrix(0,0,1)
+
+ lefCandIdx = table(seq(1, nrow(rightHash)), rightHash, nrow(rightHash),
nrow(hashTable)) %*% hashTable
+ validKeyMask = (lefCandIdx>0)
+
+ # Continue only if at least one right key found a candidate in the hash table
+ if( as.scalar(colSums(validKeyMask)) > 0 ){
+ validLeftIdx = removeEmpty(target=lefCandIdx, margin="rows",
select=validKeyMask)
+ validRightIdx = removeEmpty(target=rightIdx, margin="rows",
select=validKeyMask)
+
+ # Find matching join keys
+ selectedValidLeftIdx = table(seq(1,nrow(validLeftIdx)), validLeftIdx,
nrow(validLeftIdx), nrow(left)) %*% left
+ selectedValidRightIdx = table(seq(1,nrow(validRightIdx)), validRightIdx,
nrow(validRightIdx), cols=nrow(right)) %*% right
+
+ matchMask = ( selectedValidLeftIdx == selectedValidRightIdx )
+ if ( as.scalar(colSums(matchMask[,1])) > 0) {
+ leftMatchIdx = removeEmpty(target=validLeftIdx, margin="rows",
select=matchMask)
+ rightMatchIdx = removeEmpty(target=validRightIdx, margin="rows",
select=matchMask)
+
+ #Append indexes to global results
+ leftOutIdx = rbind(leftOutIdx, removeEmpty(target=leftMatchIdx,
margin="rows"))
+ rightOutIdx = rbind(rightOutIdx, removeEmpty(target=rightMatchIdx,
margin="rows"))
+ }
+ }
+
+ # Create output
+ if ( nrow(leftOutIdx) == 0 | nrow(rightOutIdx) == 0 ) {
+ Y = matrix(0, rows=0, cols=1)
+ }
+ else {
+ Y = matrix(0, rows=nrow(leftOutIdx), cols=ncol(A)+ncol(B))
+ for( j in 1:nrow(leftOutIdx) ) {
+ Y[j, ] = cbind( A[as.scalar(leftOutIdx[j]), ],
B[as.scalar(rightOutIdx[j]), ] )
+ }
+ }
+ }
}
# Function to perform parallel binary search
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
index e2ea8a6512..97b820d2b1 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
@@ -60,6 +60,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
public void testRaJoinTestwithDifferentColumn2() {
testRaJoinTestwithDifferentColumn("sort-merge");
}
+ @Test
+ public void testRaJoinTestwithDifferentColumn3() {
+ testRaJoinTestwithDifferentColumn("hash2");
+ }
@Test
public void testRaJoinTestwithDifferentColumn21() {
@@ -70,6 +74,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
public void testRaJoinTestwithDifferentColumn22() {
testRaJoinTestwithDifferentColumn2("sort-merge");
}
+ @Test
+ public void testRaJoinTestwithDifferentColumn23() {
+ testRaJoinTestwithDifferentColumn2("hash2");
+ }
@Test
public void testRaJoinTestwithNoMatchingRows1() {
@@ -80,6 +88,11 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
public void testRaJoinTestwithNoMatchingRows2() {
testRaJoinTestwithNoMatchingRows("sort-merge");
}
+
+ @Test
+ public void testRaJoinTestwithNoMatchingRows3() {
+ testRaJoinTestwithNoMatchingRows("hash2");
+ }
@Test
public void testRaJoinTestwithAllMatchingRows1() {
@@ -95,6 +108,11 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
public void testRaJoinTestwithAllMatchingRows3() {
testRaJoinTestwithAllMatchingRows("hash");
}
+
+ @Test
+ public void testRaJoinTestwithAllMatchingRows4() {
+ testRaJoinTestwithAllMatchingRows("hash2");
+ }
@Test
public void testRaJoinTestwithOneToMany1() {
@@ -110,6 +128,11 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
public void testRaJoinTestwithOneToMany3() {
testRaJoinTestwithOneToMany("hash");
}
+
+ @Test
+ public void testRaJoinTestwithOneToMany4() {
+ testRaJoinTestwithOneToMany("hash2");
+ }
private void testRaJoinTest(String method) {