This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 11418431bc [SYSTEMDS-3506] Finalize DecisionTree Predict Methods 11418431bc is described below commit 11418431bc995a02d6910331d3fe72492f2bb1b6 Author: e-strauss <lathan...@gmx.de> AuthorDate: Fri Aug 23 20:09:02 2024 +0200 [SYSTEMDS-3506] Finalize DecisionTree Predict Methods Closes #2069. --- scripts/builtin/decisionTreePredict.dml | 66 +++++++-- .../part1/BuiltinDecisionTreePredictTest.java | 161 +++++++++++++++++---- .../builtin/part1/BuiltinDecisionTreeTest.java | 2 +- 3 files changed, 191 insertions(+), 38 deletions(-) diff --git a/scripts/builtin/decisionTreePredict.dml b/scripts/builtin/decisionTreePredict.dml index 50a8c2a912..4585d56d08 100644 --- a/scripts/builtin/decisionTreePredict.dml +++ b/scripts/builtin/decisionTreePredict.dml @@ -92,7 +92,7 @@ predict_GEMM = function (Matrix[Double] M, Matrix[Double] X) [A, B, C, D, E] = createGEMMNodeTensors(M, ncol(X)); # scoring pipline, evaluating all nodes in parallel - Y = rowIndexMax(((((X %*% A) < B) %*% C) == D) %*% E); + Y = rowIndexMax(((((X %*% A) <= B) %*% C) == D) %*% E); } createTTNodeTensors = function( Matrix[Double] M ) @@ -125,22 +125,64 @@ createGEMMNodeTensors = function( Matrix[Double] M, Int m ) return (Matrix[Double] A, Matrix[Double] B, Matrix[Double] C, Matrix[Double] D, Matrix[Double] E) { - #TODO update for new model layout and generalize - stop("GEMM not fully supported yet"); - - nin = sum(M[2,]!=0); # num inner nodes + M2 = matrix(M, rows=ncol(M)/2, cols=2) + NID = seq(1, nrow(M2)) # predicate map [#feat x #inodes] and values [1 x #inodes] - I1 = removeEmpty(target=M[3,], margin="cols"); - A = table(I1, seq(1,nin), m, nin); - B = removeEmpty(target=M[6,], margin="cols", select=M[2,]!=0); + is_inner = M2[,1]!=0 + I1 = removeEmpty(target=NID, margin="rows", select=is_inner) + pivot = removeEmpty(target=M2[,1], margin="rows", select=is_inner) + nin = nrow(I1) + A = table(pivot, seq(1,nin), m, nin) + B = 
t(removeEmpty(target=M2[,2], margin="rows", select=is_inner)) # bucket paths [#inodes x #paths] and path sums - I2 = (M[2,] == 0) - np = ncol(M) - nin; - C = matrix("1 -1", rows=1, cols=2); # TODO general case + is_leaf = (!is_inner & M2[,2]!=0) + leaf_ids = t(removeEmpty(target=NID, margin="rows", select=is_leaf)) + last_leaf = as.scalar(leaf_ids[1,ncol(leaf_ids)]) + leaf_classes = removeEmpty(target=M2[,2], margin="rows", select=is_leaf) + nl = ncol(leaf_ids) + + # iterate over each inner node and check for each leaf node whether it is in the left subtree (1), right subtree (-1) or not included (0) + # | i | + # / \ + # |2i| |2i+1| + # / \ / \ + # |4i| |4i+1| |4i+2| |4i+3| + # + # left_subtree_of_node(i) = { x | (2^j)*i <= x < (2^j)*i + 2^(j-1), for j elem {1, 2, 3, ...}} -> j is the depth below node i + # right_subtree_of_node(i) = { x | (2^j)*i + 2^(j-1) <= x < (2^j)*(i+1), for j elem {1, 2, 3, ...}} + + C = matrix(0, nin, nl) + parfor(i in seq(1, nin)){ + boundary_left = 2*as.scalar(I1[i, 1]) # initialize the left boundary with the left child of the inner node + out = matrix(0, 1, nl) + step_size = 1 + + # iterate each level of tree [log(max_node_id) iterations] + while(boundary_left < last_leaf) { + + # left side + subset_lower_bound = leaf_ids >= boundary_left + boundary_right = boundary_left + step_size + subset_upper_bound = leaf_ids < boundary_right + ones = subset_lower_bound & subset_upper_bound + out = out + ones + + # right side + subset_lower_bound = !subset_upper_bound #reuse by inverting + boundary_right = boundary_right + step_size + subset_upper_bound = leaf_ids < boundary_right + ones = subset_lower_bound & subset_upper_bound + out = out - ones + + step_size = step_size*2 # with each level the number of nodes per subtree level doubles + boundary_left = boundary_left*2 # new left boundary is the left child of the previous left boundary + } + C[i,] = out + } D = colSums(max(C, 0)); # class map [#paths x #classes] - E = table(seq(1,ncol(C)), 
t(M[4,(ncol(M)-ncol(C)+1):ncol(M)])); + E = table(seq(1,ncol(C)),leaf_classes) } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java index 0cecc0e15c..6eb22da335 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java @@ -27,13 +27,12 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; -import org.junit.Ignore; import org.junit.Test; public class BuiltinDecisionTreePredictTest extends AutomatedTestBase { private final static String TEST_NAME = "decisionTreePredict"; private final static String TEST_DIR = "functions/builtin/"; - private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinDecisionTreeTest.class.getSimpleName() + "/"; + private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinDecisionTreePredictTest.class.getSimpleName() + "/"; private final static double eps = 1e-10; @@ -43,28 +42,58 @@ public class BuiltinDecisionTreePredictTest extends AutomatedTestBase { } @Test - public void testDecisionTreeTTPredictDefaultCP() { - runDecisionTreePredict(true, ExecType.CP, "TT"); + public void testDecisionTreeTTPredictDefaultCP1() { + runDecisionTreePredict(true, ExecType.CP, "TT", 1); } + @Test + public void testDecisionTreeTTPredictDefaultCP2() { + runDecisionTreePredict(true, ExecType.CP, "TT", 2); + } + + @Test + public void testDecisionTreeTTPredictDefaultCP3() { + runDecisionTreePredict(true, ExecType.CP, "TT", 3); + } + + @Test + public void testDecisionTreeTTPredictDefaultCP4() { + runDecisionTreePredict(true, ExecType.CP, "TT", 4); + } + + @Test public void testDecisionTreeTTPredictSP() { - 
runDecisionTreePredict(true, ExecType.SPARK, "TT"); + runDecisionTreePredict(true, ExecType.SPARK, "TT", 1); } @Test - @Ignore - public void testDecisionTreeGEMMPredictDefaultCP() { - runDecisionTreePredict(true, ExecType.CP, "GEMM"); + public void testDecisionTreeGEMMPredictDefaultCP1() { + runDecisionTreePredict(true, ExecType.CP, "GEMM", 1); + } + + @Test + public void testDecisionTreeGEMMPredictDefaultCP2() { + runDecisionTreePredict(true, ExecType.CP, "GEMM", 2); } @Test - @Ignore + public void testDecisionTreeGEMMPredictDefaultCP3() { + runDecisionTreePredict(true, ExecType.CP, "GEMM", 3); + } + + @Test + public void testDecisionTreeGEMMPredictDefaultCP4() { + runDecisionTreePredict(true, ExecType.CP, "GEMM", 4); + } + + + @Test public void testDecisionTreeGEMMPredictSP() { - runDecisionTreePredict(true, ExecType.SPARK, "GEMM"); + runDecisionTreePredict(true, ExecType.SPARK, "GEMM", 1); } - private void runDecisionTreePredict(boolean defaultProb, ExecType instType, String strategy) { + private void runDecisionTreePredict(boolean defaultProb, ExecType instType, String strategy, int test_case) { Types.ExecMode platformOld = setExecMode(instType); try { loadTestConfiguration(getTestConfiguration(TEST_NAME)); @@ -74,20 +103,102 @@ public class BuiltinDecisionTreePredictTest extends AutomatedTestBase { programArgs = new String[] {"-args", input("M"), input("X"), strategy, output("Y")}; //data and model consistent with decision tree test - double[][] X = { - {3, 1, 2, 1, 5}, - {2, 1, 2, 2, 4}, - {1, 1, 1, 3, 3}, - {4, 2, 1, 4, 2}, - {2, 2, 1, 5, 1},}; - double[][] M = {{1.0, 2.0, 0.0, 1.0, 0.0, 2.0}}; - + double[][] X = null; + double[][] M = null; + HashMap<MatrixValue.CellIndex, Double> expected_Y = new HashMap<>(); - expected_Y.put(new MatrixValue.CellIndex(1, 1), 2.0); - expected_Y.put(new MatrixValue.CellIndex(2, 1), 1.0); - expected_Y.put(new MatrixValue.CellIndex(3, 1), 1.0); - expected_Y.put(new MatrixValue.CellIndex(4, 1), 2.0); - expected_Y.put(new 
MatrixValue.CellIndex(5, 1), 1.0); + switch(test_case){ + case 1: + double[][] X1 = { + {3, 1, 2, 1, 5}, + {2, 1, 2, 2, 4}, + {1, 1, 1, 3, 3}, + {4, 2, 1, 4, 2}, + {2, 2, 1, 5, 1},}; + double[][] M1 = {{1.0, 2.0, 0.0, 1.0, 0.0, 2.0}}; + + expected_Y.put(new MatrixValue.CellIndex(1, 1), 2.0); + expected_Y.put(new MatrixValue.CellIndex(2, 1), 1.0); + expected_Y.put(new MatrixValue.CellIndex(3, 1), 1.0); + expected_Y.put(new MatrixValue.CellIndex(4, 1), 2.0); + expected_Y.put(new MatrixValue.CellIndex(5, 1), 1.0); + X = X1; + M = M1; + break; + case 2: + double[][] X2 = { + {3, 1, 2, 1}, + {2, 1, 2, 6}, + {1, 1, 1, 3}, + {9, 2, 1, 7}, + {2, 2, 1, 1},}; + double[][] M2 = {{4, 5, 0, 2, 1, 7, 0, 0, 0, 0, 0, 2, 0, 1}}; + + expected_Y.put(new MatrixValue.CellIndex(1, 1), 2.0); + expected_Y.put(new MatrixValue.CellIndex(2, 1), 2.0); + expected_Y.put(new MatrixValue.CellIndex(3, 1), 2.0); + expected_Y.put(new MatrixValue.CellIndex(4, 1), 1.0); + expected_Y.put(new MatrixValue.CellIndex(5, 1), 2.0); + X = X2; + M = M2; + break; + case 3: + double[][] X3 = { + {1, 1, 1}, + {1, 1, 7,}, + {1, 5, 1}, + {1, 5, 7,},}; + double[][] M3 = {{1, 5, 2, 4, 2, 4, 3, 6, 3, 6, 3, 6, 3, 6, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8}}; + + expected_Y.put(new MatrixValue.CellIndex(1, 1), 1.0); + expected_Y.put(new MatrixValue.CellIndex(2, 1), 2.0); + expected_Y.put(new MatrixValue.CellIndex(3, 1), 3.0); + expected_Y.put(new MatrixValue.CellIndex(4, 1), 4.0); + X = X3; + M = M3; + break; + case 4: + double[][] X4 = { + {1, 1, 1, 1}, + {4, 1, 1, 1}, + {1, 1, 7, 1}, + {4, 1, 7, 1}, + {1, 5, 1, 1}, + {4, 5, 1, 1}, + {1, 5, 7, 1}, + {4, 5, 7, 1}, + {1, 1, 1, 6}, + {4, 1, 1, 6}, + {1, 1, 7, 6}, + {4, 1, 7, 6}, + {1, 5, 1, 6}, + {4, 5, 1, 6}, + {1, 5, 7, 6}, + {4, 5, 7, 6},}; + double[][] M4 = {{4, 5, 2, 4, 2, 4, 3, 6, 3, 6, 3, 6, 3, 6, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, + 3, 1, 3, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, + 0, 14, 0, 15, 0, 16}}; + + 
expected_Y.put(new MatrixValue.CellIndex(1, 1), 1.0); + expected_Y.put(new MatrixValue.CellIndex(2, 1), 2.0); + expected_Y.put(new MatrixValue.CellIndex(3, 1), 3.0); + expected_Y.put(new MatrixValue.CellIndex(4, 1), 4.0); + expected_Y.put(new MatrixValue.CellIndex(5, 1), 5.0); + expected_Y.put(new MatrixValue.CellIndex(6, 1), 6.0); + expected_Y.put(new MatrixValue.CellIndex(7, 1), 7.0); + expected_Y.put(new MatrixValue.CellIndex(8, 1), 8.0); + expected_Y.put(new MatrixValue.CellIndex(9, 1), 9.0); + expected_Y.put(new MatrixValue.CellIndex(10, 1), 10.0); + expected_Y.put(new MatrixValue.CellIndex(11, 1), 11.0); + expected_Y.put(new MatrixValue.CellIndex(12, 1), 12.0); + expected_Y.put(new MatrixValue.CellIndex(13, 1), 13.0); + expected_Y.put(new MatrixValue.CellIndex(14, 1), 14.0); + expected_Y.put(new MatrixValue.CellIndex(15, 1), 15.0); + expected_Y.put(new MatrixValue.CellIndex(16, 1), 16.0); + X = X4; + M = M4; + break; + } writeInputMatrixWithMTD("M", M, true); writeInputMatrixWithMTD("X", X, true); @@ -98,7 +209,7 @@ public class BuiltinDecisionTreePredictTest extends AutomatedTestBase { TestUtils.compareMatrices(expected_Y, actual_Y, eps, "Expected-DML", "Actual-DML"); } finally { - rtplatform = platformOld; + resetExecMode(platformOld); } } } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeTest.java index f8ac8397cb..a8b3112992 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeTest.java @@ -86,7 +86,7 @@ public class BuiltinDecisionTreeTest extends AutomatedTestBase { TestUtils.compareMatrices(expected_M, actual_M, eps, "Expected-DML", "Actual-DML"); } finally { - rtplatform = platformOld; + resetExecMode(platformOld); } } }