This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 11418431bc [SYSTEMDS-3506] Finalize DecisionTree Predict Methods
11418431bc is described below

commit 11418431bc995a02d6910331d3fe72492f2bb1b6
Author: e-strauss <lathan...@gmx.de>
AuthorDate: Fri Aug 23 20:09:02 2024 +0200

    [SYSTEMDS-3506] Finalize DecisionTree Predict Methods
    
    Closes #2069.
---
 scripts/builtin/decisionTreePredict.dml            |  66 +++++++--
 .../part1/BuiltinDecisionTreePredictTest.java      | 161 +++++++++++++++++----
 .../builtin/part1/BuiltinDecisionTreeTest.java     |   2 +-
 3 files changed, 191 insertions(+), 38 deletions(-)

diff --git a/scripts/builtin/decisionTreePredict.dml 
b/scripts/builtin/decisionTreePredict.dml
index 50a8c2a912..4585d56d08 100644
--- a/scripts/builtin/decisionTreePredict.dml
+++ b/scripts/builtin/decisionTreePredict.dml
@@ -92,7 +92,7 @@ predict_GEMM = function (Matrix[Double] M, Matrix[Double] X)
   [A, B, C, D, E] = createGEMMNodeTensors(M, ncol(X));
 
   # scoring pipline, evaluating all nodes in parallel
-  Y = rowIndexMax(((((X %*% A) < B) %*% C) == D) %*% E);
+  Y = rowIndexMax(((((X %*% A) <= B) %*% C) == D) %*% E);
 }
 
 createTTNodeTensors = function( Matrix[Double] M )
@@ -125,22 +125,64 @@ createGEMMNodeTensors = function( Matrix[Double] M, Int m 
)
   return (Matrix[Double] A, Matrix[Double] B, Matrix[Double] C,
   Matrix[Double] D, Matrix[Double] E)
 {
-  #TODO update for new model layout and generalize
-  stop("GEMM not fully supported yet");
-
-  nin = sum(M[2,]!=0); # num inner nodes
+  M2 = matrix(M, rows=ncol(M)/2, cols=2)
+  NID = seq(1, nrow(M2))
 
   # predicate map [#feat x #inodes] and values [1 x #inodes]
-  I1 = removeEmpty(target=M[3,], margin="cols");
-  A = table(I1, seq(1,nin), m, nin);
-  B = removeEmpty(target=M[6,], margin="cols", select=M[2,]!=0);
+  is_inner = M2[,1]!=0
+  I1 = removeEmpty(target=NID, margin="rows", select=is_inner)
+  pivot = removeEmpty(target=M2[,1], margin="rows", select=is_inner)
+  nin = nrow(I1)
+  A = table(pivot, seq(1,nin), m, nin)
+  B = t(removeEmpty(target=M2[,2], margin="rows", select=is_inner))
 
   # bucket paths [#inodes x #paths] and path sums
-  I2 = (M[2,] == 0)
-  np = ncol(M) - nin;
-  C = matrix("1 -1", rows=1, cols=2); # TODO general case
+  is_leaf = (!is_inner & M2[,2]!=0)
+  leaf_ids = t(removeEmpty(target=NID, margin="rows", select=is_leaf))
+  last_leaf = as.scalar(leaf_ids[1,ncol(leaf_ids)])
+  leaf_classes = removeEmpty(target=M2[,2], margin="rows", select=is_leaf)
+  nl = ncol(leaf_ids)
+
+  # iterate over each inner node and check for each leaf node whether it lies in the 
left subtree (1), in the right subtree (-1), or in neither (0)
+  #              | i |
+  #           /        \
+  #        |2i|       |2i+1|
+  #       /    \     /      \
+  #     |4i| |4i+1| |4i+2| |4i+3|
+  #
+  # left_subtree_of_node(i) = { x | (2^j)*i <= x < (2^j)*i + 2^(j-1), for j 
elem {1, 2, 3, ...}} -> j is the level of tree
+  # right_subtree_of_node(i) = { x | (2^j)*i + 2^(j-1) <= x < (2^j)*(i + 1), for 
j elem {1, 2, 3, ...}}
+
+  C = matrix(0, nin, nl)
+  parfor(i in seq(1, nin)){
+    boundary_left = 2*as.scalar(I1[i, 1]) # initialize the left boundary with 
the left child of the inner node
+    out = matrix(0, 1, nl)
+    step_size = 1
+
+    # iterate each level of tree [log(max_node_id) iterations]
+    while(boundary_left < last_leaf) {
+
+      # left side
+      subset_lower_bound = leaf_ids >= boundary_left
+      boundary_right = boundary_left + step_size
+      subset_upper_bound = leaf_ids < boundary_right
+      ones = subset_lower_bound & subset_upper_bound
+      out = out + ones
+
+      # right side
+      subset_lower_bound = !subset_upper_bound #reuse by inverting
+      boundary_right = boundary_right + step_size
+      subset_upper_bound = leaf_ids < boundary_right
+      ones = subset_lower_bound & subset_upper_bound
+      out = out - ones
+
+      step_size = step_size*2 # with each level the amount of nodes in subtree 
level doubles
+      boundary_left = boundary_left*2 # new left boundary is the left child of 
the previous left boundary
+    }
+    C[i,] = out
+  }
   D = colSums(max(C, 0));
 
   # class map [#paths x #classes]
-  E = table(seq(1,ncol(C)), t(M[4,(ncol(M)-ncol(C)+1):ncol(M)]));
+  E = table(seq(1,ncol(C)),leaf_classes)
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
index 0cecc0e15c..6eb22da335 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
@@ -27,13 +27,12 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
-import org.junit.Ignore;
 import org.junit.Test;
 
 public class BuiltinDecisionTreePredictTest extends AutomatedTestBase {
        private final static String TEST_NAME = "decisionTreePredict";
        private final static String TEST_DIR = "functions/builtin/";
-       private static final String TEST_CLASS_DIR = TEST_DIR + 
BuiltinDecisionTreeTest.class.getSimpleName() + "/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
BuiltinDecisionTreePredictTest.class.getSimpleName() + "/";
 
        private final static double eps = 1e-10;
 
@@ -43,28 +42,58 @@ public class BuiltinDecisionTreePredictTest extends 
AutomatedTestBase {
        }
 
        @Test
-       public void testDecisionTreeTTPredictDefaultCP() {
-               runDecisionTreePredict(true, ExecType.CP, "TT");
+       public void testDecisionTreeTTPredictDefaultCP1() {
+               runDecisionTreePredict(true, ExecType.CP, "TT", 1);
        }
 
+       @Test
+       public void testDecisionTreeTTPredictDefaultCP2() {
+               runDecisionTreePredict(true, ExecType.CP, "TT", 2);
+       }
+
+       @Test
+       public void testDecisionTreeTTPredictDefaultCP3() {
+               runDecisionTreePredict(true, ExecType.CP, "TT", 3);
+       }
+
+       @Test
+       public void testDecisionTreeTTPredictDefaultCP4() {
+               runDecisionTreePredict(true, ExecType.CP, "TT", 4);
+       }
+
+
        @Test
        public void testDecisionTreeTTPredictSP() {
-               runDecisionTreePredict(true, ExecType.SPARK, "TT");
+               runDecisionTreePredict(true, ExecType.SPARK, "TT", 1);
        }
        
        @Test
-       @Ignore
-       public void testDecisionTreeGEMMPredictDefaultCP() {
-               runDecisionTreePredict(true, ExecType.CP, "GEMM");
+       public void testDecisionTreeGEMMPredictDefaultCP1() {
+               runDecisionTreePredict(true, ExecType.CP, "GEMM", 1);
+       }
+
+       @Test
+       public void testDecisionTreeGEMMPredictDefaultCP2() {
+               runDecisionTreePredict(true, ExecType.CP, "GEMM", 2);
        }
 
        @Test
-       @Ignore
+       public void testDecisionTreeGEMMPredictDefaultCP3() {
+               runDecisionTreePredict(true, ExecType.CP, "GEMM", 3);
+       }
+
+       @Test
+       public void testDecisionTreeGEMMPredictDefaultCP4() {
+               runDecisionTreePredict(true, ExecType.CP, "GEMM", 4);
+       }
+
+
+       @Test
        public void testDecisionTreeGEMMPredictSP() {
-               runDecisionTreePredict(true, ExecType.SPARK, "GEMM");
+               runDecisionTreePredict(true, ExecType.SPARK, "GEMM", 1);
        }
 
-       private void runDecisionTreePredict(boolean defaultProb, ExecType 
instType, String strategy) {
+       private void runDecisionTreePredict(boolean defaultProb, ExecType 
instType, String strategy, int test_case) {
                Types.ExecMode platformOld = setExecMode(instType);
                try {
                        loadTestConfiguration(getTestConfiguration(TEST_NAME));
@@ -74,20 +103,102 @@ public class BuiltinDecisionTreePredictTest extends 
AutomatedTestBase {
                        programArgs = new String[] {"-args", input("M"), 
input("X"), strategy, output("Y")};
 
                        //data and model consistent with decision tree test
-                       double[][] X = {
-                                       {3, 1, 2, 1, 5}, 
-                                       {2, 1, 2, 2, 4}, 
-                                       {1, 1, 1, 3, 3},
-                                       {4, 2, 1, 4, 2}, 
-                                       {2, 2, 1, 5, 1},};
-                       double[][] M = {{1.0, 2.0, 0.0, 1.0, 0.0, 2.0}};
-                       
+                       double[][] X = null;
+                       double[][] M = null;
+
                        HashMap<MatrixValue.CellIndex, Double> expected_Y = new 
HashMap<>();
-                       expected_Y.put(new MatrixValue.CellIndex(1, 1), 2.0);
-                       expected_Y.put(new MatrixValue.CellIndex(2, 1), 1.0);
-                       expected_Y.put(new MatrixValue.CellIndex(3, 1), 1.0);
-                       expected_Y.put(new MatrixValue.CellIndex(4, 1), 2.0);
-                       expected_Y.put(new MatrixValue.CellIndex(5, 1), 1.0);
+                       switch(test_case){
+                               case 1:
+                                       double[][] X1 = {
+                                                       {3, 1, 2, 1, 5},
+                                                       {2, 1, 2, 2, 4},
+                                                       {1, 1, 1, 3, 3},
+                                                       {4, 2, 1, 4, 2},
+                                                       {2, 2, 1, 5, 1},};
+                                       double[][] M1 = {{1.0, 2.0, 0.0, 1.0, 
0.0, 2.0}};
+
+                                       expected_Y.put(new 
MatrixValue.CellIndex(1, 1), 2.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(2, 1), 1.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(3, 1), 1.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(4, 1), 2.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(5, 1), 1.0);
+                                       X = X1;
+                                       M = M1;
+                                       break;
+                               case 2:
+                                       double[][] X2 = {
+                                                       {3, 1, 2, 1},
+                                                       {2, 1, 2, 6},
+                                                       {1, 1, 1, 3},
+                                                       {9, 2, 1, 7},
+                                                       {2, 2, 1, 1},};
+                                       double[][] M2 = {{4, 5, 0, 2, 1, 7, 0, 
0, 0, 0, 0, 2, 0, 1}};
+
+                                       expected_Y.put(new 
MatrixValue.CellIndex(1, 1), 2.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(2, 1), 2.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(3, 1), 2.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(4, 1), 1.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(5, 1), 2.0);
+                                       X = X2;
+                                       M = M2;
+                                       break;
+                               case 3:
+                                       double[][] X3 = {
+                                                       {1, 1, 1},
+                                                       {1, 1, 7,},
+                                                       {1, 5, 1},
+                                                       {1, 5, 7,},};
+                                       double[][] M3 = {{1, 5, 2, 4, 2, 4, 3, 
6, 3, 6, 3, 6, 3, 6, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8}};
+
+                                       expected_Y.put(new 
MatrixValue.CellIndex(1, 1), 1.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(2, 1), 2.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(3, 1), 3.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(4, 1), 4.0);
+                                       X = X3;
+                                       M = M3;
+                                       break;
+                               case 4:
+                                       double[][] X4 = {
+                                                       {1, 1, 1, 1},
+                                                       {4, 1, 1, 1},
+                                                       {1, 1, 7, 1},
+                                                       {4, 1, 7, 1},
+                                                       {1, 5, 1, 1},
+                                                       {4, 5, 1, 1},
+                                                       {1, 5, 7, 1},
+                                                       {4, 5, 7, 1},
+                                                       {1, 1, 1, 6},
+                                                       {4, 1, 1, 6},
+                                                       {1, 1, 7, 6},
+                                                       {4, 1, 7, 6},
+                                                       {1, 5, 1, 6},
+                                                       {4, 5, 1, 6},
+                                                       {1, 5, 7, 6},
+                                                       {4, 5, 7, 6},};
+                                       double[][] M4 = {{4, 5, 2, 4, 2, 4, 3, 
6, 3, 6, 3, 6, 3, 6, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1,
+                                                       3, 1, 3, 0, 1, 0, 2, 0, 
3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13,
+                                                       0, 14, 0, 15, 0, 16}};
+
+                                       expected_Y.put(new 
MatrixValue.CellIndex(1, 1), 1.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(2, 1), 2.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(3, 1), 3.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(4, 1), 4.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(5, 1), 5.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(6, 1), 6.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(7, 1), 7.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(8, 1), 8.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(9, 1), 9.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(10, 1), 10.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(11, 1), 11.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(12, 1), 12.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(13, 1), 13.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(14, 1), 14.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(15, 1), 15.0);
+                                       expected_Y.put(new 
MatrixValue.CellIndex(16, 1), 16.0);
+                                       X = X4;
+                                       M = M4;
+                                       break;
+                       }
 
                        writeInputMatrixWithMTD("M", M, true);
                        writeInputMatrixWithMTD("X", X, true);
@@ -98,7 +209,7 @@ public class BuiltinDecisionTreePredictTest extends 
AutomatedTestBase {
                        TestUtils.compareMatrices(expected_Y, actual_Y, eps, 
"Expected-DML", "Actual-DML");
                }
                finally {
-                       rtplatform = platformOld;
+                       resetExecMode(platformOld);
                }
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeTest.java
index f8ac8397cb..a8b3112992 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeTest.java
@@ -86,7 +86,7 @@ public class BuiltinDecisionTreeTest extends 
AutomatedTestBase {
                        TestUtils.compareMatrices(expected_M, actual_M, eps, 
"Expected-DML", "Actual-DML");
                }
                finally {
-                       rtplatform = platformOld;
+                       resetExecMode(platformOld);
                }
        }
 }

Reply via email to