This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 4318d429f8 [SYSTEMDS-3506] Additional decisionTreePredict GEMM strategy
4318d429f8 is described below
commit 4318d429f814854256414ec4ada61c558d1cfee5
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Mar 17 20:32:20 2023 +0100
[SYSTEMDS-3506] Additional decisionTreePredict GEMM strategy
This patch adds, besides the tree traversal (TT), the missing general
matrix multiplication (GEMM) scoring strategy, which can improve
execution time despite redundancy due to a higher degree of parallelism.
---
scripts/builtin/decisionTreePredict.dml | 40 ++++++++++++++++++++--
.../part1/BuiltinDecisionTreePredictTest.java | 14 ++++++--
2 files changed, 49 insertions(+), 5 deletions(-)
diff --git a/scripts/builtin/decisionTreePredict.dml
b/scripts/builtin/decisionTreePredict.dml
index 8194e85785..d54e784e19 100644
--- a/scripts/builtin/decisionTreePredict.dml
+++ b/scripts/builtin/decisionTreePredict.dml
@@ -52,11 +52,13 @@
# Y Matrix containing the predicted labels for X
# ------------------------------------------------------------------
-m_decisionTreePredict = function(Matrix[Double] M, Matrix[Double] X, String
strategy)
+m_decisionTreePredict = function(Matrix[Double] M, Matrix[Double] X, String
strategy="TT")
return (Matrix[Double] Y)
{
if( strategy == "TT" )
Y = predict_TT(M, X);
+ else if( strategy == "GEMM" )
+ Y = predict_GEMM(M, X);
else {
print ("No such strategy" + strategy)
Y = matrix("0", rows=0, cols=0)
@@ -67,7 +69,7 @@ predict_TT = function (Matrix[Double] M, Matrix[Double] X)
return (Matrix[Double] Y)
{
# initialization of model tensors and parameters
- [N_L, N_R, N_F, N_T] = createNodeTensors(M)
+ [N_L, N_R, N_F, N_T] = createTTNodeTensors(M)
nr = nrow(X); n = ncol(M);
tree_depth = ceiling(log(n+1,2)) # max depth
@@ -91,7 +93,17 @@ predict_TT = function (Matrix[Double] M, Matrix[Double] X)
Y = t(table(seq(1,nr), Ti, nr, n) %*% t(M[4,]));
}
-createNodeTensors = function( Matrix[Double] M )
+predict_GEMM = function (Matrix[Double] M, Matrix[Double] X)
+ return (Matrix[Double] Y)
+{
+ # initialization of model tensors and parameters
+ [A, B, C, D, E] = createGEMMNodeTensors(M, ncol(X));
+
+ # scoring pipeline, evaluating all nodes in parallel
+ Y = t(rowIndexMax(((((X %*% A) < B) %*% C) == D) %*% E));
+}
+
+createTTNodeTensors = function( Matrix[Double] M )
return ( Matrix[Double] N_L, Matrix[Double] N_R, Matrix[Double] N_F,
Matrix[Double] N_T)
{
N = M[1,] # all tree nodes
@@ -108,3 +120,25 @@ createNodeTensors = function( Matrix[Double] M )
N_F = ifelse(M[3,]!=0, M[3,], 1);
N_T = M[6,]; # threshold values for inner nodes, otherwise 0
}
+
+createGEMMNodeTensors = function( Matrix[Double] M, Int m )
+ return (Matrix[Double] A, Matrix[Double] B, Matrix[Double] C,
+ Matrix[Double] D, Matrix[Double] E)
+{
+ nin = sum(M[2,]!=0); # num inner nodes
+
+ # predicate map [#feat x #inodes] and values [1 x #inodes]
+ I1 = removeEmpty(target=M[3,], margin="cols");
+ A = table(I1, seq(1,nin), m, nin);
+ B = removeEmpty(target=M[6,], margin="cols", select=M[2,]!=0);
+
+ # bucket paths [#inodes x #paths] and path sums
+ I2 = (M[2,] == 0)
+ np = ncol(M) - nin;
+ C = matrix("1 1 -1 -1 1 -1 0 0 0 0 1 -1",
+ rows=3, cols=4); # TODO general case
+ D = colSums(max(C, 0));
+
+ # class map [#paths x #classes]
+ E = table(seq(1,ncol(C)), t(M[4,(ncol(M)-ncol(C)+1):ncol(M)]));
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
index 04c1a53737..187529e7d7 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
@@ -42,14 +42,24 @@ public class BuiltinDecisionTreePredictTest extends
AutomatedTestBase {
}
@Test
- public void testDecisionTreePredictDefaultCP() {
+ public void testDecisionTreeTTPredictDefaultCP() {
runDecisionTreePredict(true, ExecType.CP, "TT");
}
@Test
- public void testDecisionTreePredictSP() {
+ public void testDecisionTreeTTPredictSP() {
runDecisionTreePredict(true, ExecType.SPARK, "TT");
}
+
+ @Test
+ public void testDecisionTreeGEMMPredictDefaultCP() {
+ runDecisionTreePredict(true, ExecType.CP, "GEMM");
+ }
+
+ @Test
+ public void testDecisionTreeGEMMPredictSP() {
+ runDecisionTreePredict(true, ExecType.SPARK, "GEMM");
+ }
private void runDecisionTreePredict(boolean defaultProb, ExecType
instType, String strategy) {
Types.ExecMode platformOld = setExecMode(instType);