This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 195f83ba48 [SYSTEMDS-3708] Additional hash join method for raJoin builtin
195f83ba48 is described below

commit 195f83ba480455a8db0a8c9a5a7a601e62d0ab5d
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Jul 23 19:46:25 2024 +0200

    [SYSTEMDS-3708] Additional hash join method for raJoin builtin
    
    This patch adds broader tests for all join methods (e.g., nested-loop,
    sort-merge) and adds a new hash join method that performs the join
    without sort or loop-based concatenation but currently only supports
    one-to-many joins.
---
 scripts/builtin/raJoin.dml                         |  24 ++++
 .../functions/builtin/part2/BuiltinRaJoinTest.java | 128 ++++++++++++++++++---
 src/test/scripts/functions/builtin/raJoin.dml      |   4 +-
 3 files changed, 137 insertions(+), 19 deletions(-)

diff --git a/scripts/builtin/raJoin.dml b/scripts/builtin/raJoin.dml
index b7b299e7a8..bddf94c019 100644
--- a/scripts/builtin/raJoin.dml
+++ b/scripts/builtin/raJoin.dml
@@ -124,6 +124,30 @@ m_raJoin = function (Matrix[Double] A, Integer colA, Matrix[Double] B,
       Y = matrix(0, rows=0, cols=1)
     }
   }
+  else if( method == "hash" ) {
+    # Ensure histograms are aligned by creating a common set of keys
+    commonKeys = max(max(A[,colA]), max(B[,colB]));
+
+    # Build histograms for the left and right key columns
+    leftHist = table(A[,colA], 1, commonKeys, 1)
+    rightHist = table(B[,colB], 1, commonKeys, 1)
+    hist = leftHist * rightHist;
+
+    # Check for one-to-many
+    if( max(leftHist)>1 )
+      stop("Hash join implementation only supports one-to-many joins: 
"+toString(leftHist));
+
+    # Compute selection matrices P1 (one-side) with row duplication
+    keyPos1 = rowIndexMax(table(A[,colA], seq(1,nrow(A)), commonKeys, nrow(A)))
+    keyPos1 = removeEmpty(target=keyPos1, margin="rows", select=hist);
+    hist = removeEmpty(target=hist, margin="rows");
+    I1 = t(cumsum(rev(t(table(seq(1,nrow(hist)),hist))))) * keyPos1
+    I1 = removeEmpty(target=matrix(I1, nrow(I1)*ncol(I1),1), margin="rows"); # 
keys
+    P1 = table(seq(1,nrow(I1)), I1);
+
+    # Select left rows and concatenate right rows
+    Y = cbind(P1 %*% A, B);
+  }
 }
 
 # Function to perform parallel binary search
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
index f19262b854..e8f803633f 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaJoinTest.java
@@ -41,10 +41,78 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
                addTestConfiguration(TEST_NAME,new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"result"}));
        }
 
-       // TODO test all join methods
+       @Test
+       public void testRaJoinTest1() {
+               testRaJoinTest("nested-loop");
+       }
+       
+       @Test
+       public void testRaJoinTest2() {
+               testRaJoinTest("sort-merge");
+       }
+       
+       @Test
+       public void testRaJoinTestwithDifferentColumn1() {
+               testRaJoinTestwithDifferentColumn("nested-loop");
+       }
+       
+       @Test
+       public void testRaJoinTestwithDifferentColumn2() {
+               testRaJoinTestwithDifferentColumn("sort-merge");
+       }
+       
+       @Test
+       public void testRaJoinTestwithDifferentColumn21() {
+               testRaJoinTestwithDifferentColumn2("nested-loop");
+       }
+       
+       @Test
+       public void testRaJoinTestwithDifferentColumn22() {
+               testRaJoinTestwithDifferentColumn2("sort-merge");
+       }
+       
+       @Test
+       public void testRaJoinTestwithNoMatchingRows1() {
+               testRaJoinTestwithNoMatchingRows("nested-loop");
+       }
+       
+       @Test
+       public void testRaJoinTestwithNoMatchingRows2() {
+               testRaJoinTestwithNoMatchingRows("sort-merge");
+       }
+       
+       @Test
+       public void testRaJoinTestwithAllMatchingRows1() {
+               testRaJoinTestwithAllMatchingRows("nested-loop");
+       }
+       
+       @Test
+       public void testRaJoinTestwithAllMatchingRows2() {
+               testRaJoinTestwithAllMatchingRows("sort-merge");
+       }
        
        @Test
-       public void testRaJoinTest() {
+       public void testRaJoinTestwithAllMatchingRows3() {
+               testRaJoinTestwithAllMatchingRows("hash");
+       }
+       
+       @Test
+       public void testRaJoinTestwithOneToMany1() {
+               testRaJoinTestwithOneToMany("nested-loop");
+       }
+       
+       @Test
+       public void testRaJoinTestwithOneToMany2() {
+               testRaJoinTestwithOneToMany("sort-merge");
+       }
+       
+       @Test
+       public void testRaJoinTestwithOneToMany3() {
+               testRaJoinTestwithOneToMany("hash");
+       }
+       
+       
+       private void testRaJoinTest(String method) {
                //generate actual dataset and variables
                double[][] A = {
                                {1, 2, 3},
@@ -72,11 +140,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
                                {4, 3, 5, 4, 7, 8},
                                {4, 3, 5, 4, 5, 10},
                };
-               runRaJoinTest(A, colA, B, colB, Y);
+               runRaJoinTest(A, colA, B, colB, Y, method);
        }
 
-       @Test
-       public void testRaJoinTestwithDifferentColumn() {
+       private void testRaJoinTestwithDifferentColumn(String method) {
                // Generate actual dataset and variables
                double[][] A = {
                                {1, 5, 3},
@@ -100,11 +167,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
                                {2, 6, 8, 3, 7, 6},
                                {3, 7, 6, 2, 8, 7}
                };
-               runRaJoinTest(A, colA, B, colB, Y);
+               runRaJoinTest(A, colA, B, colB, Y, method);
        }
 
-       @Test
-       public void testRaJoinTestwithDifferentColumn2() {
+       private void testRaJoinTestwithDifferentColumn2(String method) {
                // Generate actual dataset and variables
                double[][] A = {
                                {1, 2, 3, 4, 5},
@@ -127,11 +193,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
                                {6, 7, 8, 9, 10, 1, 10, 200},
                                {21, 22, 23, 24, 25, 50, 25, 500}
                };
-               runRaJoinTest(A, colA, B, colB, Y);
+               runRaJoinTest(A, colA, B, colB, Y, method);
        }
 
-       @Test
-       public void testRaJoinTestwithNoMatchingRows() {
+       private void testRaJoinTestwithNoMatchingRows(String method) {
                // Generate actual dataset and variables
                double[][] A = {
                                {1, 2, 3},
@@ -148,11 +213,10 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
 
                // Expected output matrix (no matching rows)
                double[][] Y = {};
-               runRaJoinTest(A, colA, B, colB, Y);
+               runRaJoinTest(A, colA, B, colB, Y, method);
        }
 
-       @Test
-       public void testRaJoinTestwithAllMatchingRows() {
+       private void testRaJoinTestwithAllMatchingRows(String method) {
                // Generate actual dataset and variables
                double[][] A = {
                                {1, 2, 3},
@@ -173,10 +237,39 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
                                {2, 3, 4, 2, 3, 7},
                                {3, 4, 5, 3, 4, 8}
                };
-               runRaJoinTest(A, colA, B, colB, Y);
+               runRaJoinTest(A, colA, B, colB, Y, method);
+       }
+       
+       private void testRaJoinTestwithOneToMany(String method) {
+               // Generate actual dataset and variables
+               double[][] A = {
+                               {2, 2, 2},
+                               {3, 3, 3},
+                               {4, 4, 4}
+               };
+               double[][] B = {
+                               {2, 1, 1},
+                               {2, 2, 2},
+                               {3, 1, 1},
+                               {3, 2, 2},
+                               {3, 3, 3},
+                               {4, 1, 1}
+               };
+               int colA = 1;
+               int colB = 1;
+
+               double[][] Y = {
+                               {2, 2, 2, 2, 1, 1},
+                               {2, 2, 2, 2, 2, 2},
+                               {3, 3, 3, 3, 1, 1},
+                               {3, 3, 3, 3, 2, 2},
+                               {3, 3, 3, 3, 3, 3},
+                               {4, 4, 4, 4, 1, 1}
+               };
+               runRaJoinTest(A, colA, B, colB, Y, method);
        }
 
-       private void runRaJoinTest(double [][] A, int colA, double [][] B, int 
colB, double [][] Y)
+       private void runRaJoinTest(double [][] A, int colA, double [][] B, int 
colB, double [][] Y, String method)
        {
                ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
                
@@ -187,7 +280,8 @@ public class BuiltinRaJoinTest extends AutomatedTestBase
 
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
                        programArgs = new String[]{"-stats", "-args",
-                               input("A"), String.valueOf(colA), input("B"), 
String.valueOf(colB), output("result") };
+                               input("A"), String.valueOf(colA), input("B"),
+                               String.valueOf(colB), method, output("result") 
};
                        System.out.println(Arrays.deepToString(A));
                        System.out.println(colA);
                        //fullRScriptName = HOME + TEST_NAME + ".R";
diff --git a/src/test/scripts/functions/builtin/raJoin.dml b/src/test/scripts/functions/builtin/raJoin.dml
index 63aa2807c8..eedab1433e 100644
--- a/src/test/scripts/functions/builtin/raJoin.dml
+++ b/src/test/scripts/functions/builtin/raJoin.dml
@@ -24,6 +24,6 @@ colA = as.integer($2)
 B = read($3)
 colB = as.integer($4)
 
-result = raJoin(A, colA, B, colB, "sort-merge");
-write(result, $5);
+result = raJoin(A, colA, B, colB, $5);
+write(result, $6);
 

Reply via email to