This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new dc8e36db8c [SYSTEMDS-3708] Additional hash2 join method in raJoin 
builtin
dc8e36db8c is described below

commit dc8e36db8cf332c2f74511c56a39a13a31932af4
Author: gghsu <ppp432...@gmail.com>
AuthorDate: Fri Aug 16 12:53:58 2024 +0200

    [SYSTEMDS-3708] Additional hash2 join method in raJoin builtin
    
    Closes #2056.
---
 scripts/builtin/raJoin.dml                         | 78 +++++++++++++++++++++-
 .../functions/builtin/part2/BuiltinRaJoinTest.java | 23 +++++++
 2 files changed, 100 insertions(+), 1 deletion(-)

diff --git a/scripts/builtin/raJoin.dml b/scripts/builtin/raJoin.dml
index 2750b598f3..bd0ff9d91d 100644
--- a/scripts/builtin/raJoin.dml
+++ b/scripts/builtin/raJoin.dml
@@ -28,7 +28,7 @@
 # colA      Integer indicating the column index of matrix A to execute inner 
join command
 # B         Matrix of right left data [shape: N x M]
 # colA      Integer indicating the column index of matrix B to execute inner 
join command
-# method    Join implementation method (nested-loop, sort-merge, hash)
+# method    Join implementation method (nested-loop, sort-merge, hash, hash2)
 # 
------------------------------------------------------------------------------
 #
 # OUTPUT:
@@ -59,6 +59,7 @@ m_raJoin = function (Matrix[Double] A, Integer colA, 
Matrix[Double] B,
       }
     }
   }
+  # The sort-merge method is from the original paper: Query Processing on Tensor 
Computation Runtime, section 5-2
   else if (method == "sort-merge") {
     # get join key columns
     left = A[, colA]
@@ -143,6 +144,81 @@ m_raJoin = function (Matrix[Double] A, Integer colA, 
Matrix[Double] B,
     # Select left rows and concatenate right rows
     Y = cbind(P1 %*% A, B);
   }
+  # The hash2 method is from the original paper: Query Processing on Tensor 
Computation Runtime, section 5-3
+  else if ( method == "hash2" ) {
+    # Get join key columns
+    left = A[,colA]
+    right = B [,colB]
+
+    # Compute indexes and hash values
+    leftIdx = seq(1, nrow(A))
+    rightIdx = seq(1, nrow(B))
+    m = max(max(left),max(right)) + 1; # Assuming a large hash table size
+    #m = 100
+    leftHash = left %% m
+    rightHash = right %% m
+
+    # Build histogram of hash values for left join keys
+    hashBincount = table( leftHash, 1, max(leftHash), 1 )
+
+    #Initialize output indexes
+    leftOutIdx = matrix(0,0,1)
+    rightOutIdx = matrix(0,0,1)
+
+    # Check for one-to-many
+    if( max(hashBincount) > 1 )
+      stop("Hash join implementation only supports one-to-many joins: 
"+toString(hashBincount))
+
+    # Build and probe hash table
+    # Initialize hash table
+    hashTable=matrix(0,m,1)
+
+    # Create a selector matrix and use matrix multiplication to place values
+    hashTable = t(table(seq(1,nrow(leftIdx)), leftHash, nrow(leftIdx), 
nrow(hashTable))) %*% leftIdx
+
+    # Update leftHash to skip scattered values for future iterations by setting 
their hashes to m
+    leftIdxSct = removeEmpty(target=seq(1,nrow(hashTable)), margin="rows", 
select=(hashTable>=1))
+    selectedMatrix = table(seq(1, nrow(leftIdxSct)), leftIdxSct, 
nrow(leftIdxSct), nrow(hashTable))
+    leftHash = t(selectedMatrix) %*% matrix(m, rows=nrow(leftIdxSct), cols=1, 
byrow=TRUE)
+
+    #Probe hash table and get the left and right indexes
+    validLeftIdx = matrix(0,0,1)
+    validRightIdx = matrix(0,0,1)
+
+    lefCandIdx = table(seq(1, nrow(rightHash)), rightHash, nrow(rightHash), 
nrow(hashTable)) %*% hashTable
+    validKeyMask = (lefCandIdx>0)
+
+    # Check if non matching
+    if( as.scalar(colSums(validKeyMask)) > 0 ){
+      validLeftIdx = removeEmpty(target=lefCandIdx, margin="rows", 
select=validKeyMask)
+      validRightIdx = removeEmpty(target=rightIdx, margin="rows", 
select=validKeyMask)
+
+      # Find matching join keys
+      selectedValidLeftIdx = table(seq(1,nrow(validLeftIdx)), validLeftIdx, 
nrow(validLeftIdx), nrow(left)) %*% left
+      selectedValidRightIdx = table(seq(1,nrow(validRightIdx)), validRightIdx, 
nrow(validRightIdx), cols=nrow(right)) %*% right
+
+      matchMask = ( selectedValidLeftIdx == selectedValidRightIdx )
+      if ( as.scalar(colSums(matchMask[,1])) > 0) {
+        leftMatchIdx = removeEmpty(target=validLeftIdx, margin="rows", 
select=matchMask)
+        rightMatchIdx = removeEmpty(target=validRightIdx, margin="rows", 
select=matchMask)
+
+        #Append indexes to global results
+        leftOutIdx = rbind(leftOutIdx, removeEmpty(target=leftMatchIdx, 
margin="rows"))
+        rightOutIdx = rbind(rightOutIdx, removeEmpty(target=rightMatchIdx, 
margin="rows"))
+      }
+    }
+
+    # Create output
+    if ( nrow(leftOutIdx) == 0 | nrow(rightOutIdx) == 0 ) {
+      Y = matrix(0, rows=0, cols=1)
+    }
+    else {
+      Y = matrix(0, rows=nrow(leftOutIdx), cols=ncol(A)+ncol(B))
+      for( j in 1:nrow(leftOutIdx) ) {
+        Y[j, ] = cbind( A[as.scalar(leftOutIdx[j]), ], 
B[as.scalar(rightOutIdx[j]), ] )
+      }
+    }
+  }
 }
 
 # Function to perform parallel binary search
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
index e2ea8a6512..97b820d2b1 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
@@ -60,6 +60,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
        public void testRaJoinTestwithDifferentColumn2() {
                testRaJoinTestwithDifferentColumn("sort-merge");
        }
+       @Test
+       public void testRaJoinTestwithDifferentColumn3() {
+               testRaJoinTestwithDifferentColumn("hash2");
+       }
        
        @Test
        public void testRaJoinTestwithDifferentColumn21() {
@@ -70,6 +74,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
        public void testRaJoinTestwithDifferentColumn22() {
                testRaJoinTestwithDifferentColumn2("sort-merge");
        }
+       @Test
+       public void testRaJoinTestwithDifferentColumn23() {
+               testRaJoinTestwithDifferentColumn2("hash2");
+       }
        
        @Test
        public void testRaJoinTestwithNoMatchingRows1() {
@@ -80,6 +88,11 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
        public void testRaJoinTestwithNoMatchingRows2() {
                testRaJoinTestwithNoMatchingRows("sort-merge");
        }
+
+       @Test
+       public void testRaJoinTestwithNoMatchingRows3() {
+               testRaJoinTestwithNoMatchingRows("hash2");
+       }
        
        @Test
        public void testRaJoinTestwithAllMatchingRows1() {
@@ -95,6 +108,11 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
        public void testRaJoinTestwithAllMatchingRows3() {
                testRaJoinTestwithAllMatchingRows("hash");
        }
+
+       @Test
+       public void testRaJoinTestwithAllMatchingRows4() {
+               testRaJoinTestwithAllMatchingRows("hash2");
+       }
        
        @Test
        public void testRaJoinTestwithOneToMany1() {
@@ -110,6 +128,11 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
        public void testRaJoinTestwithOneToMany3() {
                testRaJoinTestwithOneToMany("hash");
        }
+
+       @Test
+       public void testRaJoinTestwithOneToMany4() {
+               testRaJoinTestwithOneToMany("hash2");
+       }
        
        
        private void testRaJoinTest(String method) {

Reply via email to