dkerschbaumer commented on a change in pull request #1334:
URL: https://github.com/apache/systemds/pull/1334#discussion_r671124041
##########
File path: scripts/builtin/xgboostPredict.dml
##########
@@ -0,0 +1,147 @@
+# INPUT PARAMETERS:
+# ----------------------------------------------------------------------------
+# NAME TYPE DEFAULT MEANING
+# ----------------------------------------------------------------------------
+# X Matrix --- Matrix of feature
vectors we want to predict (X_test)
+# M Matrix --- The model created by
xgboost
+# sml_type Integer 1 Supervised machine
learning type: 1 = Regression (default), 2 = Classification
+# learning_rate Double 0.3 the learning rate used
in the model
+
+# RETURN VALUES
+# ----------------------------------------------------------------------------
+# NAME TYPE DEFAULT MEANING
+# ----------------------------------------------------------------------------
+# P Matrix --- The predictions of the samples using the given
xgboost model. (y_prediction)
+# ----------------------------------------------------------------------------
+
+m_xgboostPredict = function(Matrix[Double] X, Matrix[Double] M, Integer
sml_type = 1, Double learning_rate = 0.3
+) return (Matrix[Double] P) {
+
+ nr_trees = max(M[2,])
+ P = matrix(0, rows=nrow(X), cols=1)
+ initial_prediction = M[6,1]
+ trees_M_offset = calculateTreesOffset(M)
+ if(sml_type == 1) # Regression
+ {
+ for(entry in 1:nrow(X)) # go through each entry in X and calculate the new
prediction
+ {
+ output_values = matrix(0, rows=1, cols=0)
+
+ for(i in 1:nr_trees) # go through all trees
+ {
+ begin_cur_tree = as.scalar(trees_M_offset[i,])
+ if(i == nr_trees)
+ end_cur_tree = ncol(M)
+ else
+ end_cur_tree = as.scalar(trees_M_offset[i+1,]) - 1
+ output_value = getOutputValueForEntryPredict(X[entry,], M[,
begin_cur_tree:end_cur_tree])
+ output_values = cbind(output_values, as.matrix(output_value))
+ }
+ P[entry,] = initial_prediction + learning_rate * sum(output_values)
+ }
+ }
+ else # Classification
+ {
+ assert(sml_type == 2)
+ for(entry in 1:nrow(X)) # go through each entry in X and calculate the new
prediction
Review comment:
done in
https://github.com/apache/systemds/pull/1334/commits/c750df72571574b0c59d2cec7a732d5d6a558ef2
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]