This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new dc8e36db8c [SYSTEMDS-3708] Additional hash2 join method in raJoin builtin dc8e36db8c is described below commit dc8e36db8cf332c2f74511c56a39a13a31932af4 Author: gghsu <ppp432...@gmail.com> AuthorDate: Fri Aug 16 12:53:58 2024 +0200 [SYSTEMDS-3708] Additional hash2 join method in raJoin builtin Closes #2056. --- scripts/builtin/raJoin.dml | 78 +++++++++++++++++++++- .../functions/builtin/part2/BuiltinRaJoinTest.java | 23 +++++++ 2 files changed, 100 insertions(+), 1 deletion(-) diff --git a/scripts/builtin/raJoin.dml b/scripts/builtin/raJoin.dml index 2750b598f3..bd0ff9d91d 100644 --- a/scripts/builtin/raJoin.dml +++ b/scripts/builtin/raJoin.dml @@ -28,7 +28,7 @@ # colA Integer indicating the column index of matrix A to execute inner join command # B Matrix of right left data [shape: N x M] # colA Integer indicating the column index of matrix B to execute inner join command -# method Join implementation method (nested-loop, sort-merge, hash) +# method Join implementation method (nested-loop, sort-merge, hash, hash2) # ------------------------------------------------------------------------------ # # OUTPUT: @@ -59,6 +59,7 @@ m_raJoin = function (Matrix[Double] A, Integer colA, Matrix[Double] B, } } } + # The sort-merge method is from original paper: Qery Processing on Tensor Computation Runtime, section 5-2 else if (method == "sort-merge") { # get join key columns left = A[, colA] @@ -143,6 +144,81 @@ m_raJoin = function (Matrix[Double] A, Integer colA, Matrix[Double] B, # Select left rows and concatenate right rows Y = cbind(P1 %*% A, B); } + # The hash2 method is from the original paper: Qery Processing on Tensor Computation Runtime, section 5-3 + else if ( method == "hash2" ) { + # Get join key columns + left = A[,colA] + right = B [,colB] + + # Compute indexes and hash values + leftIdx = seq(1, nrow(A)) + rightIdx = seq(1, nrow(B)) + m = max(max(left),max(right)) + 1; # Assuming a large hash table size + #m = 100 + leftHash = left %% m + rightHash = right %% m + + # Build histogram of hash values for left join keys + hashBincount = table( leftHash, 1, max(leftHash), 1 ) + + #Initialize output indexes + leftOutIdx = matrix(0,0,1) + rightOutIdx = matrix(0,0,1) + + # Check for one-to-many + if( max(hashBincount) > 1 ) + stop("Hash join implementation only supports one-to-many joins: "+toString(hashBincount)) + + # Build and probe hash table + # Initialize hash table + hashTable=matrix(0,m,1) + + # Create a select or matrix and use matrix multiplication to place values + hashTable = t(table(seq(1,nrow(leftIdx)), leftHash, nrow(leftIdx), nrow(hashTable))) %*% leftIdx + + # Update lefHash to skip scattered values for future iterations by setting their hashes to m + leftIdxSct = removeEmpty(target=seq(1,nrow(hashTable)), margin="rows", select=(hashTable>=1)) + selectedMatrix = table(seq(1, nrow(leftIdxSct)), leftIdxSct, nrow(leftIdxSct), nrow(hashTable)) + leftHash = t(selectedMatrix) %*% matrix(m, rows=nrow(leftIdxSct), cols=1, byrow=TRUE) + + #Probe hash table and get the left and right indexes + validLeftIdx = matrix(0,0,1) + validRightIdx = matrix(0,0,1) + + lefCandIdx = table(seq(1, nrow(rightHash)), rightHash, nrow(rightHash), nrow(hashTable)) %*% hashTable + validKeyMask = (lefCandIdx>0) + + # Check if non matching + if( as.scalar(colSums(validKeyMask)) > 0 ){ + validLeftIdx = removeEmpty(target=lefCandIdx, margin="rows", select=validKeyMask) + validRightIdx = removeEmpty(target=rightIdx, margin="rows", select=validKeyMask) + + # Find matching join keys + selectedValidLeftIdx = table(seq(1,nrow(validLeftIdx)), validLeftIdx, nrow(validLeftIdx), nrow(left)) %*% left + selectedValidRightIdx = table(seq(1,nrow(validRightIdx)), validRightIdx, nrow(validRightIdx), cols=nrow(right)) %*% right + + matchMask = ( selectedValidLeftIdx == selectedValidRightIdx ) + if ( as.scalar(colSums(matchMask[,1])) > 0) { + leftMatchIdx = removeEmpty(target=validLeftIdx, margin="rows", select=matchMask) + rightMatchIdx = removeEmpty(target=validRightIdx, margin="rows", select=matchMask) + + #Append indexes to global results + leftOutIdx = rbind(leftOutIdx, removeEmpty(target=leftMatchIdx, margin="rows")) + rightOutIdx = rbind(rightOutIdx, removeEmpty(target=rightMatchIdx, margin="rows")) + } + } + + # Create output + if ( nrow(leftOutIdx) == 0 | nrow(rightOutIdx) == 0 ) { + Y = matrix(0, rows=0, cols=1) + } + else { + Y = matrix(0, rows=nrow(leftOutIdx), cols=ncol(A)+ncol(B)) + for( j in 1:nrow(leftOutIdx) ) { + Y[j, ] = cbind( A[as.scalar(leftOutIdx[j]), ], B[as.scalar(rightOutIdx[j]), ] ) + } + } + } } # Function to perform parallel binary search diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java index e2ea8a6512..97b820d2b1 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java @@ -60,6 +60,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase public void testRaJoinTestwithDifferentColumn2() { testRaJoinTestwithDifferentColumn("sort-merge"); } + @Test + public void testRaJoinTestwithDifferentColumn3() { + testRaJoinTestwithDifferentColumn("hash2"); + } @Test public void testRaJoinTestwithDifferentColumn21() { @@ -70,6 +74,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase public void testRaJoinTestwithDifferentColumn22() { testRaJoinTestwithDifferentColumn2("sort-merge"); } + @Test + public void testRaJoinTestwithDifferentColumn23() { + testRaJoinTestwithDifferentColumn2("hash2"); + } @Test public void testRaJoinTestwithNoMatchingRows1() { @@ -80,6 +88,11 @@ public class BuiltinRaJoinTest extends AutomatedTestBase public void testRaJoinTestwithNoMatchingRows2() { testRaJoinTestwithNoMatchingRows("sort-merge"); } + + @Test + public void testRaJoinTestwithNoMatchingRows3() { + testRaJoinTestwithNoMatchingRows("hash2"); + } @Test public void testRaJoinTestwithAllMatchingRows1() { @@ -95,6 +108,11 @@ public class BuiltinRaJoinTest extends AutomatedTestBase public void testRaJoinTestwithAllMatchingRows3() { testRaJoinTestwithAllMatchingRows("hash"); } + + @Test + public void testRaJoinTestwithAllMatchingRows4() { + testRaJoinTestwithAllMatchingRows("hash2"); + } @Test public void testRaJoinTestwithOneToMany1() { @@ -110,6 +128,11 @@ public class BuiltinRaJoinTest extends AutomatedTestBase public void testRaJoinTestwithOneToMany3() { testRaJoinTestwithOneToMany("hash"); } + + @Test + public void testRaJoinTestwithOneToMany4() { + testRaJoinTestwithOneToMany("hash2"); + } private void testRaJoinTest(String method) {