This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 741be73 [SYSTEMDS-3149] Decision Tree Prediction Builtin DIA project WS2021/22 Closes #1506 741be73 is described below commit 741be739c8659e67105a6ba66a972b1b3f7d3d11 Author: Magdalena Hinterkoerner <m.hinterkoer...@student.tugraz.at> AuthorDate: Wed Jan 5 14:26:12 2022 +0100 [SYSTEMDS-3149] Decision Tree Prediction Builtin DIA project WS2021/22 Closes #1506 --- scripts/builtin/decisionTreePredict.dml | 149 +++++++++++++++++++++ .../java/org/apache/sysds/common/Builtins.java | 1 + .../part1/BuiltinDecisionTreePredictTest.java | 87 ++++++++++++ .../functions/builtin/decisionTreePredict.dml | 25 ++++ 4 files changed, 262 insertions(+) diff --git a/scripts/builtin/decisionTreePredict.dml b/scripts/builtin/decisionTreePredict.dml new file mode 100644 index 0000000..48c7f6f --- /dev/null +++ b/scripts/builtin/decisionTreePredict.dml @@ -0,0 +1,149 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +# +# Builtin script implementing prediction based on classification trees with scale features using prediction methods of the +# Hummingbird paper (https://www.usenix.org/system/files/osdi20-nakandala.pdf). +# +# INPUT PARAMETERS: +# --------------------------------------------------------------------------------------------- +# NAME TYPE MEANING +# --------------------------------------------------------------------------------------------- +# M Matrix[Double] Decision tree matrix M, as generated by scripts/builtin/decisionTree.dml, where each column corresponds +# to a node in the learned tree and each row contains the following information: +# M[1,j]: id of node j (in a complete binary tree) +# M[2,j]: Offset (no. of columns) to left child of j if j is an internal node, otherwise 0 +# M[3,j]: Feature index of the feature (scale feature id if the feature is scale or +# categorical feature id if the feature is categorical) +# that node j looks at if j is an internal node, otherwise 0 +# M[4,j]: Type of the feature that node j looks at if j is an internal node: holds +# the same information as R input vector +# M[5,j]: If j is an internal node: 1 if the feature chosen for j is scale, +# otherwise the size of the subset of values +# stored in rows 6,7,... if j is categorical +# If j is a leaf node: number of misclassified samples reaching at node j +# M[6:,j]: If j is an internal node: Threshold the example's feature value is compared +# to is stored at M[6,j] if the feature chosen for j is scale, +# otherwise if the feature chosen for j is categorical rows 6,7,... 
depict the value subset chosen for j
+#                           If j is a leaf node: 1 if j is impure and the number of samples at j > threshold, otherwise 0
+#
+# X         Matrix[Double]  Feature matrix X
+#
+# strategy  String          Prediction strategy, can be one of ["GEMM", "TT", "PTT"], referring to "Generic matrix multiplication",
+#                           "Tree traversal", and "Perfect tree traversal", respectively.
+#                           Only "TT" is implemented in this script; any other value prints a message
+#                           and yields a 0x0 matrix.
+# -------------------------------------------------------------------------------------------
+# OUTPUT:
+# ---------------------------------------------------------------------------------------------
+# NAME      TYPE            MEANING
+# ---------------------------------------------------------------------------------------------
+# Y         Matrix[Double]  Matrix containing the predicted labels for X (1 x nrow(X), one column per row of X)
+# ---------------------------------------------------------------------------------------------
+
+# Entry point: dispatches on the requested prediction strategy.
+m_decisionTreePredict = function(Matrix[Double] M, Matrix[Double] X, String strategy)
+  return (Matrix[Double] Y)
+{
+  if (strategy == "TT") {
+    Y = predict_TT(M, X)
+  }
+  else {
+    # unsupported strategy: report it (with a separator before the name) and return an empty result
+    print ("No such strategy: " + strategy)
+    Y = matrix("0", rows=0, cols=0)
+  }
+}
+
+# Tree-traversal prediction: every row of X starts at column 1 of M and follows the
+# left/right child vectors for (tree_depth - 1) steps; the label is then read from
+# row 4 of M at the column that was reached.
+# NOTE(review): the child vectors hold node ids taken from M[1,] and are used here as
+# column indices, which presumably requires M to store a complete tree in id order - confirm.
+predict_TT = function (Matrix[Double] M, Matrix[Double] X)
+  return (Matrix[Double] Y)
+{
+  Y = matrix(0, rows=1, cols=nrow(X))
+  n = ncol(M)
+  tree_depth = ceiling(log(n+1,2)) # max depth of complete binary tree
+  [N_L, N_R, N_F, N_T] = createNodeTensors(M)
+
+  parfor (k in 1:nrow(X)){
+    # iterate over every sample in X matrix
+    sample = X[k,]
+    current_node = 1
+    cnt = 1
+    # fixed number of steps; createNodeTensors maps a leaf to itself, so extra steps stay on the leaf
+    while (cnt < tree_depth){
+      feature_id = as.scalar(N_F[1, current_node])
+      feature = as.scalar(sample[,feature_id]) # select feature from sample data
+      threshold = as.scalar(N_T[1, current_node])
+
+      if (feature < threshold){
+        # move on to left child node
+        next_node = as.scalar(N_L[1, current_node])
+      } else {
+        # move on to right child node
+        next_node = as.scalar(N_R[1, current_node])
+      }
+      current_node = next_node
+      cnt +=1
+    }
+
+    # row 4 of M at the reached column is emitted as the predicted label
+    class = M[4, current_node]
+    Y[1, k] = class
+  }
+}
+
+# Derives four 1 x ncol(M) row vectors from the tree matrix M, one entry per column of M:
+#   N_L / N_R: left / right child of the node (the node's own index i for a leaf)
+#   N_F:       feature index from M[3,], or 1 where M[3,] is 0
+#   N_T:       threshold from M[6,], or 0 where M[6,] is 0
+createNodeTensors = function( Matrix[Double] M )
+  return ( Matrix[Double] N_L, Matrix[Double] N_R, Matrix[Double] N_F, Matrix[Double] N_T)
+{
+  N = M[1,] # all tree nodes
+  I = M[2,] # list of node offsets to their left children
+  n_nodes = ncol(N)
+
+  N_L = matrix(0, rows=1, cols=n_nodes)
+  N_R = matrix(0, rows=1, cols=n_nodes)
+  N_F = matrix(0, rows=1, cols=n_nodes)
+  N_T = matrix(0, rows=1, cols=n_nodes)
+
+  parfor (i in 1:n_nodes){
+    # if the node is an internal node, add its left and right child to the N_L and N_R tensor, respectively
+    if (as.scalar(I[1,i]) != 0){
+      # the left child sits `offset` columns to the right of column i; its id is read from M[1,]
+      offset = as.scalar(I[1, i])
+      leftChild = as.scalar(N[1, i+offset])
+      N_L[1, i] = N[1, i+offset]
+      # the right child is taken to be the id directly after the left child
+      rightChild = leftChild + 1
+
+      # NOTE(review): when the left child is a leaf and the right child is internal, node i
+      # becomes its own right child; the intent of this special case is not evident - confirm.
+      # leftChild/rightChild are node ids but are used as column indices into I and N here.
+      if (as.scalar(I[1, leftChild]) == 0 & as.scalar(I[1, rightChild]) != 0){
+        rightChild = i
+      }
+      N_R[1, i] = N[1, rightChild]
+    } else {
+      # leaf: both children point back to the node itself
+      N_L[1, i] = as.matrix(i)
+      N_R[1, i] = as.matrix(i)
+    }
+
+    # if the node is an internal node, add index of the feature it evaluates
+    # (otherwise 1 is stored, which is still a valid column index into a sample)
+    if (as.scalar(M[3,i]) != 0){
+      N_F[1, i] = M[3,i]
+    } else {
+      N_F[1, i] = as.matrix(1)
+    }
+
+    # if the node is an internal node, add the threshold of the feature it evaluates
+    if (as.scalar(M[6,i]) != 0){
+      N_T[1, i] = M[6,i]
+    } else {
+      N_T[1, i] = as.matrix(0)
+    }
+  }
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index a124220..85ca3c7 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -107,6 +107,7 @@ public enum Builtins {
     DBSCAN("dbscan", true),
     DBSCANAPPLY("dbscanApply", true),
     DECISIONTREE("decisionTree", true),
+    DECISIONTREEPREDICT("decisionTreePredict", true),
     DECOMPRESS("decompress", false),
     DEEPWALK("deepWalk", true),
     DETECTSCHEMA("detectSchema", false),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java new file mode 100644 index 0000000..04c1a53 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreePredictTest.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.builtin.part1; + +import java.util.HashMap; + +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class BuiltinDecisionTreePredictTest extends AutomatedTestBase { + private final static String TEST_NAME = "decisionTreePredict"; + private final static String TEST_DIR = "functions/builtin/"; + private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinDecisionTreeTest.class.getSimpleName() + "/"; + + private final static double eps = 1e-10; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"})); + } + + @Test + public void testDecisionTreePredictDefaultCP() { + runDecisionTreePredict(true, ExecType.CP, "TT"); + } + + @Test + public void testDecisionTreePredictSP() { + runDecisionTreePredict(true, ExecType.SPARK, "TT"); + } + + private void runDecisionTreePredict(boolean defaultProb, ExecType instType, String strategy) { + Types.ExecMode platformOld = setExecMode(instType); + try { + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-args", input("M"), input("X"), strategy, output("Y")}; + + double[][] X = {{0.5, 7, 0.1}, {0.5, 7, 0.7}, {-1, -0.2, 3}, {-1, -0.2, -0.8}, {-0.3, -0.7, 3}}; + double[][] M = {{1, 2, 3, 4, 5, 6, 7}, {1, 2, 3, 0, 0, 0, 0}, {1, 2, 3, 0, 0, 0, 0}, + {1, 1, 1, 4, 5, 6, 7}, {1, 1, 1, 0, 0, 0, 0}, {0, -0.5, 0.5, 0, 0, 0, 0}}; + + HashMap<MatrixValue.CellIndex, Double> expected_Y = new HashMap<>(); + expected_Y.put(new MatrixValue.CellIndex(1, 1), 6.0); + expected_Y.put(new MatrixValue.CellIndex(1, 2), 7.0); + 
expected_Y.put(new MatrixValue.CellIndex(1, 3), 5.0); + expected_Y.put(new MatrixValue.CellIndex(1, 4), 5.0); + expected_Y.put(new MatrixValue.CellIndex(1, 5), 4.0); + + writeInputMatrixWithMTD("M", M, true); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + + HashMap<MatrixValue.CellIndex, Double> actual_Y = readDMLMatrixFromOutputDir("Y"); + + TestUtils.compareMatrices(expected_Y, actual_Y, eps, "Expected-DML", "Actual-DML"); + } + finally { + rtplatform = platformOld; + } + } +} diff --git a/src/test/scripts/functions/builtin/decisionTreePredict.dml b/src/test/scripts/functions/builtin/decisionTreePredict.dml new file mode 100644 index 0000000..208a827 --- /dev/null +++ b/src/test/scripts/functions/builtin/decisionTreePredict.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +M = read($1); +X = read($2); +Y = decisionTreePredict(M = M, X = X, strategy = $3); +write(Y, $4);