This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 4318d429f8 [SYSTEMDS-3506] Additional decisionTreePredict GEMM strategy
4318d429f8 is described below

commit 4318d429f814854256414ec4ada61c558d1cfee5
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Mar 17 20:32:20 2023 +0100

    [SYSTEMDS-3506] Additional decisionTreePredict GEMM strategy
    
    This patch adds the missing generic matrix multiplication (GEMM)
    scoring strategy alongside the existing tree traversal (TT) strategy.
    Despite its redundant computation, GEMM can improve execution time
    due to its higher degree of parallelism.
---
 scripts/builtin/decisionTreePredict.dml            | 40 ++++++++++++++++++++--
 .../part1/BuiltinDecisionTreePredictTest.java      | 14 ++++++--
 2 files changed, 49 insertions(+), 5 deletions(-)

diff --git a/scripts/builtin/decisionTreePredict.dml b/scripts/builtin/decisionTreePredict.dml
index 8194e85785..d54e784e19 100644
--- a/scripts/builtin/decisionTreePredict.dml
+++ b/scripts/builtin/decisionTreePredict.dml
@@ -52,11 +52,13 @@
 # Y     Matrix containing the predicted labels for X 
 # ------------------------------------------------------------------
 
-m_decisionTreePredict = function(Matrix[Double] M, Matrix[Double] X, String strategy)
+m_decisionTreePredict = function(Matrix[Double] M, Matrix[Double] X, String strategy="TT")
   return (Matrix[Double] Y) 
 {
   if( strategy == "TT" )
     Y = predict_TT(M, X);
+  else if( strategy == "GEMM" )
+    Y = predict_GEMM(M, X);
   else {
     print ("No such strategy" + strategy)
     Y = matrix("0", rows=0, cols=0)
@@ -67,7 +69,7 @@ predict_TT = function (Matrix[Double] M, Matrix[Double] X)
   return (Matrix[Double] Y)
 {
   # initialization of model tensors and parameters
-  [N_L, N_R, N_F, N_T] = createNodeTensors(M)
+  [N_L, N_R, N_F, N_T] = createTTNodeTensors(M)
   nr = nrow(X); n = ncol(M);
   tree_depth = ceiling(log(n+1,2)) # max depth
 
@@ -91,7 +93,17 @@ predict_TT = function (Matrix[Double] M, Matrix[Double] X)
   Y = t(table(seq(1,nr), Ti, nr, n) %*%  t(M[4,]));
 }
 
-createNodeTensors = function( Matrix[Double] M )
+predict_GEMM = function (Matrix[Double] M, Matrix[Double] X)
+  return (Matrix[Double] Y)
+{
+  # initialization of model tensors and parameters
+  [A, B, C, D, E] = createGEMMNodeTensors(M, ncol(X));
+
+  # scoring pipline, evaluating all nodes in parallel
+  Y = t(rowIndexMax(((((X %*% A) < B) %*% C) == D) %*% E));
+}
+
+createTTNodeTensors = function( Matrix[Double] M )
   return ( Matrix[Double] N_L, Matrix[Double] N_R, Matrix[Double] N_F, Matrix[Double] N_T)
 {
   N = M[1,] # all tree nodes
@@ -108,3 +120,25 @@ createNodeTensors = function( Matrix[Double] M )
   N_F = ifelse(M[3,]!=0, M[3,], 1);
   N_T = M[6,]; # threshold values for inner nodes, otherwise 0
 }
+
+createGEMMNodeTensors = function( Matrix[Double] M, Int m )
+  return (Matrix[Double] A, Matrix[Double] B, Matrix[Double] C,
+  Matrix[Double] D, Matrix[Double] E)
+{
+  nin = sum(M[2,]!=0); # num inner nodes
+
+  # predicate map [#feat x #inodes] and values [1 x #inodes]
+  I1 = removeEmpty(target=M[3,], margin="cols");
+  A = table(I1, seq(1,nin), m, nin);
+  B = removeEmpty(target=M[6,], margin="cols", select=M[2,]!=0);
+
+  # bucket paths [#inodes x #paths] and path sums
+  I2 = (M[2,] == 0)
+  np = ncol(M) - nin;
+  C = matrix("1 1 -1 -1 1 -1 0 0 0 0 1 -1",
+       rows=3, cols=4); # TODO general case
+  D = colSums(max(C, 0));
+
+  # class map [#paths x #classes]
+  E = table(seq(1,ncol(C)), t(M[4,(ncol(M)-ncol(C)+1):ncol(M)]));
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
index 04c1a53737..187529e7d7 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
@@ -42,14 +42,24 @@ public class BuiltinDecisionTreePredictTest extends AutomatedTestBase {
        }
 
        @Test
-       public void testDecisionTreePredictDefaultCP() {
+       public void testDecisionTreeTTPredictDefaultCP() {
                runDecisionTreePredict(true, ExecType.CP, "TT");
        }
 
        @Test
-       public void testDecisionTreePredictSP() {
+       public void testDecisionTreeTTPredictSP() {
                runDecisionTreePredict(true, ExecType.SPARK, "TT");
        }
+       
+       @Test
+       public void testDecisionTreeGEMMPredictDefaultCP() {
+               runDecisionTreePredict(true, ExecType.CP, "GEMM");
+       }
+
+       @Test
+       public void testDecisionTreeGEMMPredictSP() {
+               runDecisionTreePredict(true, ExecType.SPARK, "GEMM");
+       }
 
        private void runDecisionTreePredict(boolean defaultProb, ExecType instType, String strategy) {
                Types.ExecMode platformOld = setExecMode(instType);

Reply via email to