This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new 176fa07  [HIVEMALL-171] Tracing functionality for prediction of DecisionTrees
176fa07 is described below

commit 176fa070c1e2ea3b0737c8150a1302e4cb643816
Author: Makoto Yui <m...@apache.org>
AuthorDate: Sat Sep 28 03:39:01 2019 +0900

    [HIVEMALL-171] Tracing functionality for prediction of DecisionTrees
    
    ## What changes were proposed in this pull request?
    
    Introduce the `decision_path` UDF, which traces decision tree prediction paths
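    
    The tracing hooks are exposed through the new `PredictionHandler` callbacks
    (`visitBranch`/`visitLeaf`) and `DecisionTree#predict(Vector, PredictionHandler)`.
    Below is a minimal sketch of a custom handler, similar to what `DecisionPathUDF`
    builds internally; the `PathCollector` class name and its output format are
    illustrative only, not part of this patch:
    
    ```java
    import hivemall.smile.classification.PredictionHandler;
    
    import java.util.ArrayList;
    import java.util.List;
    
    // Records one human-readable entry per branch taken, plus the leaf label.
    final class PathCollector extends PredictionHandler {
        private final List<String> path = new ArrayList<>();
    
        @Override
        public void init() {
            path.clear();
        }
    
        @Override
        public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                double splitValue) {
            // e.g. "7 [7.8292] <= 7.9104": feature index, observed value, split condition
            path.add(splitFeatureIndex + " [" + splitFeature + "] " + op + " " + splitValue);
        }
    
        @Override
        public void visitLeaf(int output, double[] posteriori) {
            path.add("leaf: " + output);
        }
    
        @SuppressWarnings("unchecked")
        @Override
        public List<String> getResult() {
            return path;
        }
    }
    // usage: handler.init(); tree.predict(features, handler); List<String> path = handler.getResult();
    ```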
    
    ## What type of PR is it?
    
    Feature
    
    ## What is the Jira issue?
    
    https://issues.apache.org/jira/browse/HIVEMALL-171
    
    ## How was this patch tested?
    
    Unit tests and manual tests on EMR
    
    ## How to use this feature?
    
    To be described in the user guide; the examples added to `docs/gitbook/misc/funcs.md` in this commit show the intended usage
    
    ## Checklist
    
    - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
    - [x] Did you run system tests on Hive (or Spark)?
    
    Author: Makoto Yui <m...@apache.org>
    
    Closes #199 from myui/HIVEMALL-171.
---
 core/src/main/java/hivemall/annotations/Cite.java  |   1 +
 .../smile/classification/DecisionTree.java         |  19 +-
 .../smile/classification/PredictionHandler.java    |  35 +-
 .../hivemall/smile/regression/RegressionTree.java  |  31 +
 .../java/hivemall/smile/tools/DecisionPathUDF.java | 659 +++++++++++++++++++++
 .../java/hivemall/smile/tools/TreePredictUDF.java  |   2 +-
 .../main/java/hivemall/utils/lang/ArrayUtils.java  |  20 +-
 .../smile/classification/DecisionTreeTest.java     |  80 +++
 docs/gitbook/misc/funcs.md                         |  33 ++
 resources/ddl/define-all-as-permanent.hive         |   3 +
 resources/ddl/define-all.hive                      |   4 +
 resources/ddl/define-all.spark                     |   3 +
 12 files changed, 879 insertions(+), 11 deletions(-)

diff --git a/core/src/main/java/hivemall/annotations/Cite.java b/core/src/main/java/hivemall/annotations/Cite.java
index 2b93cd6..7d09320 100644
--- a/core/src/main/java/hivemall/annotations/Cite.java
+++ b/core/src/main/java/hivemall/annotations/Cite.java
@@ -30,6 +30,7 @@ import javax.annotation.Nullable;
 public @interface Cite {
     @Nonnull
     String description();
+
     @Nullable
     String url();
 }
diff --git a/core/src/main/java/hivemall/smile/classification/DecisionTree.java b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
index 95b4b2a..74a99ad 100644
--- a/core/src/main/java/hivemall/smile/classification/DecisionTree.java
+++ b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
@@ -17,6 +17,10 @@
 // https://github.com/haifengl/smile/blob/master/core/src/main/java/smile/classification/DecisionTree.java
 package hivemall.smile.classification;
 
+import static hivemall.smile.classification.PredictionHandler.Operator.EQ;
+import static hivemall.smile.classification.PredictionHandler.Operator.GT;
+import static hivemall.smile.classification.PredictionHandler.Operator.LE;
+import static hivemall.smile.classification.PredictionHandler.Operator.NE;
 import static hivemall.smile.utils.SmileExtUtils.NOMINAL;
 import static hivemall.smile.utils.SmileExtUtils.NUMERIC;
 import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName;
@@ -319,18 +323,23 @@ public class DecisionTree implements Classifier<Vector> {
          */
         public void predict(@Nonnull final Vector x, @Nonnull final PredictionHandler handler) {
             if (isLeaf()) {
-                handler.handle(output, posteriori);
+                handler.visitLeaf(output, posteriori);
             } else {
+                final double feature = x.get(splitFeature, Double.NaN);
                 if (quantitativeFeature) {
-                    if (x.get(splitFeature, Double.NaN) <= splitValue) {
+                    if (feature <= splitValue) {
+                        handler.visitBranch(LE, splitFeature, feature, splitValue);
                         trueChild.predict(x, handler);
                     } else {
+                        handler.visitBranch(GT, splitFeature, feature, splitValue);
                         falseChild.predict(x, handler);
                     }
                 } else {
-                    if (x.get(splitFeature, Double.NaN) == splitValue) {
+                    if (feature == splitValue) {
+                        handler.visitBranch(EQ, splitFeature, feature, splitValue);
                         trueChild.predict(x, handler);
                     } else {
+                        handler.visitBranch(NE, splitFeature, feature, splitValue);
                         falseChild.predict(x, handler);
                     }
                 }
@@ -1359,6 +1368,10 @@ public class DecisionTree implements Classifier<Vector> {
         return _root.predict(x);
     }
 
+    public void predict(@Nonnull final Vector x, @Nonnull final PredictionHandler handler) {
+        _root.predict(x, handler);
+    }
+
     /**
      * Predicts the class label of an instance and also calculate a posteriori probabilities. Not
      * supported.
diff --git a/core/src/main/java/hivemall/smile/classification/PredictionHandler.java b/core/src/main/java/hivemall/smile/classification/PredictionHandler.java
index 84ef244..6c19641 100644
--- a/core/src/main/java/hivemall/smile/classification/PredictionHandler.java
+++ b/core/src/main/java/hivemall/smile/classification/PredictionHandler.java
@@ -20,8 +20,39 @@ package hivemall.smile.classification;
 
 import javax.annotation.Nonnull;
 
-public interface PredictionHandler {
+public abstract class PredictionHandler {
 
-    void handle(int output, @Nonnull double[] posteriori);
+    public enum Operator {
+        /* = */ EQ, /* != */ NE, /* <= */ LE, /* > */ GT;
+
+        @Override
+        public String toString() {
+            switch (this) {
+                case EQ:
+                    return "=";
+                case NE:
+                    return "!=";
+                case LE:
+                    return "<=";
+                case GT:
+                    return ">";
+                default:
+                    throw new IllegalStateException("Unexpected operator: " + this);
+            }
+        }
+    }
+
+    public void init() {}
+
+    public void visitBranch(@Nonnull Operator op, int splitFeatureIndex, double splitFeature,
+            double splitValue) {}
+
+    public void visitLeaf(double output) {}
+
+    public void visitLeaf(int output, @Nonnull double[] posteriori) {}
+
+    public <T> T getResult() {
+        throw new UnsupportedOperationException();
+    }
 
 }
diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
index 764c352..ab2f25f 100755
--- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java
+++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
@@ -17,6 +17,10 @@
 // https://github.com/haifengl/smile/blob/master/core/src/main/java/smile/regression/RegressionTree.java
 package hivemall.smile.regression;
 
+import static hivemall.smile.classification.PredictionHandler.Operator.EQ;
+import static hivemall.smile.classification.PredictionHandler.Operator.GT;
+import static hivemall.smile.classification.PredictionHandler.Operator.LE;
+import static hivemall.smile.classification.PredictionHandler.Operator.NE;
 import static hivemall.smile.utils.SmileExtUtils.NOMINAL;
 import static hivemall.smile.utils.SmileExtUtils.NUMERIC;
 import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName;
@@ -29,6 +33,7 @@ import hivemall.math.vector.DenseVector;
 import hivemall.math.vector.SparseVector;
 import hivemall.math.vector.Vector;
 import hivemall.math.vector.VectorProcedure;
+import hivemall.smile.classification.PredictionHandler;
 import hivemall.smile.utils.SmileExtUtils;
 import hivemall.smile.utils.VariableOrder;
 import hivemall.utils.collections.arrays.SparseIntArray;
@@ -274,6 +279,32 @@ public final class RegressionTree implements Regression<Vector> {
             }
         }
 
+        public double predict(@Nonnull final Vector x, @Nonnull final PredictionHandler handler) {
+            if (isLeaf()) {
+                handler.visitLeaf(output);
+                return output;
+            } else {
+                final double feature = x.get(splitFeature, Double.NaN);
+                if (quantitativeFeature) {
+                    if (feature <= splitValue) {
+                        handler.visitBranch(LE, splitFeature, feature, splitValue);
+                        return trueChild.predict(x, handler);
+                    } else {
+                        handler.visitBranch(GT, splitFeature, feature, splitValue);
+                        return falseChild.predict(x, handler);
+                    }
+                } else {
+                    if (feature == splitValue) {
+                        handler.visitBranch(EQ, splitFeature, feature, splitValue);
+                        return trueChild.predict(x, handler);
+                    } else {
+                        handler.visitBranch(NE, splitFeature, feature, splitValue);
+                        return falseChild.predict(x, handler);
+                    }
+                }
+            }
+        }
+
         /**
          * Evaluate the regression tree over an instance.
          */
diff --git a/core/src/main/java/hivemall/smile/tools/DecisionPathUDF.java b/core/src/main/java/hivemall/smile/tools/DecisionPathUDF.java
new file mode 100644
index 0000000..11a05da
--- /dev/null
+++ b/core/src/main/java/hivemall/smile/tools/DecisionPathUDF.java
@@ -0,0 +1,659 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.smile.tools;
+
+import hivemall.UDFWithOptions;
+import hivemall.math.vector.DenseVector;
+import hivemall.math.vector.SparseVector;
+import hivemall.math.vector.Vector;
+import hivemall.smile.classification.DecisionTree;
+import hivemall.smile.classification.PredictionHandler;
+import hivemall.smile.regression.RegressionTree;
+import hivemall.utils.codec.Base91;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.StringUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
+import org.apache.hadoop.io.Text;
+
+// @formatter:off
+@Description(name = "decision_path",
+        value = "_FUNC_(string modelId, string model, array<double|string> features [, const string options] [, optional array<string> featureNames=null, optional array<string> classNames=null])"
+                + " - Returns a decision path for each prediction in array<string>",
+        extended = "SELECT\n" + 
+                "  t.passengerid,\n" + 
+                "  decision_path(m.model_id, m.model, t.features, '-classification')\n" + 
+                "FROM\n" + 
+                "  model_rf m\n" + 
+                "  LEFT OUTER JOIN\n" + 
+                "  test_rf t;\n" +
+                "> | 892 | [\"2 [0.0] = 0.0\",\"0 [3.0] = 3.0\",\"1 [696.0] != 107.0\",\"7 [7.8292] <= 7.9104\",\"1 [696.0] != 828.0\",\"1 [696.0] != 391.0\",\"0 [0.961038961038961, 0.03896103896103896]\"] |\n\n" +
+                "-- Show 100 frequent branches\n" +
+                "WITH tmp as (\n" + 
+                "  SELECT\n" + 
+                "    decision_path(m.model_id, m.model, t.features, '-classification -no_verbose -no_leaf', array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'), array('no','yes')) as path\n" + 
+                "  FROM\n" + 
+                "    model_rf m\n" + 
+                "    LEFT OUTER JOIN -- CROSS JOIN\n" + 
+                "    test_rf t\n" + 
+                ")\n" + 
+                "select\n" + 
+                "  r.branch,\n" + 
+                "  count(1) as cnt\n" + 
+                "from\n" + 
+                "  tmp l\n" + 
+                "  LATERAL VIEW explode(l.path) r as branch\n" + 
+                "group by\n" + 
+                "  r.branch\n" + 
+                "order by\n" + 
+                "  cnt desc\n" + 
+                "limit 100;")
+// @formatter:on
+@UDFType(deterministic = true, stateful = false)
+public final class DecisionPathUDF extends UDFWithOptions {
+
+    private StringObjectInspector modelOI;
+    private ListObjectInspector featureListOI;
+    private PrimitiveObjectInspector featureElemOI;
+    private boolean denseInput;
+
+    // options
+    private boolean classification = false;
+    private boolean summarize = true;
+    private boolean verbose = true;
+    private boolean noLeaf = false;
+
+    @Nullable
+    private String[] featureNames;
+    @Nullable
+    private String[] classNames;
+
+    @Nullable
+    private transient Vector featuresProbe;
+
+    @Nullable
+    private transient Evaluator evaluator;
+
+    @Override
+    protected Options getOptions() {
+        Options opts = new Options();
+        opts.addOption("c", "classification", false,
+            "Predict as classification [default: not enabled]");
+        opts.addOption("no_summarize", "disable_summarization", false,
+            "Do not summarize decision paths");
+        opts.addOption("no_verbose", "disable_verbose_output", false,
+            "Disable verbose output [default: verbose]");
+        opts.addOption("no_leaf", "disable_leaf_output", false,
+            "Do not show the leaf value in the output [default: leaf value is shown]");
+        return opts;
+    }
+
+    @Override
+    protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
+        CommandLine cl = parseOptions(optionValue);
+
+        this.classification = cl.hasOption("classification");
+        this.summarize = !cl.hasOption("no_summarize");
+        this.verbose = !cl.hasOption("disable_verbose_output");
+        this.noLeaf = cl.hasOption("disable_leaf_output");
+
+        return cl;
+    }
+
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
+        if (argOIs.length < 3 || argOIs.length > 6) {
+            showHelp("decision_path takes 3 ~ 6 arguments");
+        }
+
+        this.modelOI = HiveUtils.asStringOI(argOIs[1]);
+
+        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[2]);
+        this.featureListOI = listOI;
+        ObjectInspector elemOI = listOI.getListElementObjectInspector();
+        if (HiveUtils.isNumberOI(elemOI)) {
+            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+            this.denseInput = true;
+        } else if (HiveUtils.isStringOI(elemOI)) {
+            this.featureElemOI = HiveUtils.asStringOI(elemOI);
+            this.denseInput = false;
+        } else {
+            throw new UDFArgumentException(
+                "decision_path takes array<double> or array<string> for the 3rd argument: "
+                        + listOI.getTypeName());
+        }
+
+        if (argOIs.length >= 4) {
+            ObjectInspector argOI3 = argOIs[3];
+            if (HiveUtils.isConstString(argOI3)) {
+                String opts = HiveUtils.getConstString(argOI3);
+                processOptions(opts);
+                if (argOIs.length >= 5) {
+                    ObjectInspector argOI4 = argOIs[4];
+                    if (HiveUtils.isConstStringListOI(argOI4)) {
+                        this.featureNames = HiveUtils.getConstStringArray(argOI4);
+                        if (argOIs.length >= 6) {
+                            ObjectInspector argOI5 = argOIs[5];
+                            if (HiveUtils.isConstStringListOI(argOI5)) {
+                                if (!classification) {
+                                    throw new UDFArgumentException(
+                                        "classNames should not be provided for regression");
+                                }
+                                this.classNames = HiveUtils.getConstStringArray(argOI5);
+                            } else {
+                                throw new UDFArgumentException(
+                                    "decision_path expects 'const array<string> classNames' for the 6th argument: "
+                                            + argOI5.getTypeName());
+                            }
+                        }
+                    } else {
+                        throw new UDFArgumentException(
+                            "decision_path expects 'const array<string> featureNames' for the 5th argument: "
+                                    + argOI4.getTypeName());
+                    }
+                }
+            } else if (HiveUtils.isConstStringListOI(argOI3)) {
+                this.featureNames = HiveUtils.getConstStringArray(argOI3);
+                if (argOIs.length >= 5) {
+                    ObjectInspector argOI4 = argOIs[4];
+                    if (HiveUtils.isConstStringListOI(argOI4)) {
+                        if (!classification) {
+                            throw new UDFArgumentException(
+                                "classNames should not be provided for regression");
+                        }
+                        this.classNames = HiveUtils.getConstStringArray(argOI4);
+                    } else {
+                        throw new UDFArgumentException(
+                            "decision_path expects 'const array<string> classNames' for the 5th argument: "
+                                    + argOI4.getTypeName());
+                    }
+                }
+            } else {
+                throw new UDFArgumentException(
+                    "decision_path expects 'const array<string> options' or 'const array<string> featureNames' for the 4th argument: "
+                            + argOI3.getTypeName());
+            }
+        }
+
+        return ObjectInspectorFactory.getStandardListObjectInspector(
+            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
+    }
+
+    @Override
+    public List<String> evaluate(@Nonnull DeferredObject[] arguments) throws HiveException {
+        Object arg0 = arguments[0].get();
+        if (arg0 == null) {
+            throw new HiveException("modelId should not be null");
+        }
+        // Not using string OI for backward compatibilities
+        String modelId = arg0.toString();
+
+        Object arg1 = arguments[1].get();
+        if (arg1 == null) {
+            return null;
+        }
+        Text model = modelOI.getPrimitiveWritableObject(arg1);
+
+        Object arg2 = arguments[2].get();
+        if (arg2 == null) {
+            throw new HiveException("features was null");
+        }
+        this.featuresProbe = parseFeatures(arg2, featuresProbe);
+
+        if (evaluator == null) {
+            this.evaluator = classification ? new ClassificationEvaluator(this)
+                    : new RegressionEvaluator(this);
+        }
+        return evaluator.evaluate(modelId, model, featuresProbe);
+    }
+
+    @Nonnull
+    private Vector parseFeatures(@Nonnull final Object argObj, @Nullable Vector probe)
+            throws UDFArgumentException {
+        if (denseInput) {
+            final int length = featureListOI.getListLength(argObj);
+            if (probe == null) {
+                probe = new DenseVector(length);
+            } else if (length != probe.size()) {
+                probe = new DenseVector(length);
+            }
+
+            for (int i = 0; i < length; i++) {
+                final Object o = featureListOI.getListElement(argObj, i);
+                if (o == null) {
+                    probe.set(i, 0.d);
+                } else {
+                    double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
+                    probe.set(i, v);
+                }
+            }
+        } else {
+            if (probe == null) {
+                probe = new SparseVector();
+            } else {
+                probe.clear();
+            }
+
+            final int length = featureListOI.getListLength(argObj);
+            for (int i = 0; i < length; i++) {
+                Object o = featureListOI.getListElement(argObj, i);
+                if (o == null) {
+                    continue;
+                }
+                String col = o.toString();
+
+                final int pos = col.indexOf(':');
+                if (pos == 0) {
+                    throw new UDFArgumentException("Invalid feature value representation: " + col);
+                }
+
+                final String feature;
+                final double value;
+                if (pos > 0) {
+                    feature = col.substring(0, pos);
+                    String s2 = col.substring(pos + 1);
+                    value = Double.parseDouble(s2);
+                } else {
+                    feature = col;
+                    value = 1.d;
+                }
+
+                if (feature.indexOf(':') != -1) {
+                    throw new UDFArgumentException(
+                        "Invalid feature format `<index>:<value>`: " + col);
+                }
+
+                final int colIndex = Integer.parseInt(feature);
+                if (colIndex < 0) {
+                    throw new UDFArgumentException(
+                        "Col index MUST be greater than or equal to 0: " + colIndex);
+                }
+                probe.set(colIndex, value);
+            }
+        }
+        return probe;
+    }
+
+    @Override
+    public void close() throws IOException {
+        this.modelOI = null;
+        this.featureElemOI = null;
+        this.featureListOI = null;
+        this.featureNames = null;
+        this.classNames = null;
+        this.featuresProbe = null;
+        this.evaluator = null;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return "decision_path(" + StringUtils.join(children, ',') + ")";
+    }
+
+    interface Evaluator {
+
+        @Nonnull
+        List<String> evaluate(@Nonnull String modelId, @Nonnull Text model,
+                @Nonnull Vector features) throws HiveException;
+
+    }
+
+    static final class ClassificationEvaluator implements Evaluator {
+
+        @Nullable
+        private final String[] featureNames;
+        @Nullable
+        private final String[] classNames;
+
+        @Nonnull
+        private final List<String> result;
+        @Nonnull
+        private final PredictionHandler handler;
+
+        @Nullable
+        private String prevModelId = null;
+        private DecisionTree.Node cNode = null;
+
+        ClassificationEvaluator(@Nonnull final DecisionPathUDF udf) {
+            this.featureNames = udf.featureNames;
+            this.classNames = udf.classNames;
+
+            final StringBuilder buf = new StringBuilder();
+            final ArrayList<String> result = new ArrayList<>();
+            this.result = result;
+
+            if (udf.summarize) {
+                final LinkedHashMap<String, Double> map = new LinkedHashMap<>();
+
+                this.handler = new PredictionHandler() {
+
+                    @Override
+                    public void init() {
+                        map.clear();
+                        result.clear();
+                    }
+
+                    @Override
+                    public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
+                            double splitValue) {
+                        buf.append(resolveFeatureName(splitFeatureIndex));
+                        if (udf.verbose) {
+                            buf.append(" [" + splitFeature + "] ");
+                        } else {
+                            buf.append(' ');
+                        }
+                        buf.append(op);
+                        if (op == Operator.EQ || op == Operator.NE) {
+                            buf.append(' ');
+                            buf.append(splitValue);
+                        }
+                        String key = buf.toString();
+                        map.put(key, splitValue);
+                        StringUtils.clear(buf);
+                    }
+
+                    @Override
+                    public void visitLeaf(int output, double[] posteriori) {
+                        for (Map.Entry<String, Double> e : map.entrySet()) {
+                            final String key = e.getKey();
+                            if (key.indexOf('<') == -1 && key.indexOf('>') == -1) {
+                                result.add(key);
+                            } else {
+                                double value = e.getValue().doubleValue();
+                                result.add(key + ' ' + value);
+                            }
+                        }
+                        if (udf.noLeaf) {
+                            return;
+                        }
+
+                        if (udf.verbose) {
+                            buf.append(resolveClassName(output));
+                            buf.append(' ');
+                            buf.append(Arrays.toString(posteriori));
+                            result.add(buf.toString());
+                            StringUtils.clear(buf);
+                        } else {
+                            result.add(resolveClassName(output));
+                        }
+                    }
+
+                    @SuppressWarnings("unchecked")
+                    @Override
+                    public ArrayList<String> getResult() {
+                        return result;
+                    }
+
+                };
+            } else {
+                this.handler = new PredictionHandler() {
+
+                    @Override
+                    public void init() {
+                        result.clear();
+                    }
+
+                    @Override
+                    public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
+                            double splitValue) {
+                        buf.append(resolveFeatureName(splitFeatureIndex));
+                        if (udf.verbose) {
+                            buf.append(" [" + splitFeature + "] ");
+                        } else {
+                            buf.append(' ');
+                        }
+                        buf.append(op);
+                        buf.append(' ');
+                        buf.append(splitValue);
+                        result.add(buf.toString());
+                        StringUtils.clear(buf);
+                    }
+
+                    @Override
+                    public void visitLeaf(int output, double[] posteriori) {
+                        if (udf.noLeaf) {
+                            return;
+                        }
+
+                        if (udf.verbose) {
+                            buf.append(resolveClassName(output));
+                            buf.append(' ');
+                            buf.append(Arrays.toString(posteriori));
+                            result.add(buf.toString());
+                            StringUtils.clear(buf);
+                        } else {
+                            result.add(resolveClassName(output));
+                        }
+                    }
+
+                    @SuppressWarnings("unchecked")
+                    @Override
+                    public ArrayList<String> getResult() {
+                        return result;
+                    }
+
+                };
+            }
+        }
+
+        @Nonnull
+        private String resolveFeatureName(final int splitFeatureIndex) {
+            if (featureNames == null) {
+                return Integer.toString(splitFeatureIndex);
+            } else {
+                return featureNames[splitFeatureIndex];
+            }
+        }
+
+        @Nonnull
+        private String resolveClassName(final int classLabel) {
+            if (classNames == null) {
+                return Integer.toString(classLabel);
+            } else {
+                return classNames[classLabel];
+            }
+        }
+
+        @Nonnull
+        public List<String> evaluate(@Nonnull final String modelId, @Nonnull final Text script,
+                @Nonnull final Vector features) throws HiveException {
+            if (!modelId.equals(prevModelId)) {
+                this.prevModelId = modelId;
+                int length = script.getLength();
+                byte[] b = script.getBytes();
+                b = Base91.decode(b, 0, length);
+                this.cNode = DecisionTree.deserialize(b, b.length, true);
+            }
+            Preconditions.checkNotNull(cNode);
+
+            handler.init();
+            cNode.predict(features, handler);
+            return handler.getResult();
+        }
+
+    }
+
+    static final class RegressionEvaluator implements Evaluator {
+
+        @Nullable
+        private final String[] featureNames;
+
+        @Nonnull
+        private final List<String> result;
+        @Nonnull
+        private final PredictionHandler handler;
+
+        @Nullable
+        private String prevModelId = null;
+        private RegressionTree.Node rNode = null;
+
+        RegressionEvaluator(@Nonnull final DecisionPathUDF udf) {
+            this.featureNames = udf.featureNames;
+
+            final StringBuilder buf = new StringBuilder();
+            final ArrayList<String> result = new ArrayList<>();
+            this.result = result;
+
+            if (udf.summarize) {
+                final LinkedHashMap<String, Double> map = new LinkedHashMap<>();
+
+                this.handler = new PredictionHandler() {
+
+                    @Override
+                    public void init() {
+                        map.clear();
+                        result.clear();
+                    }
+
+                    @Override
+                    public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
+                            double splitValue) {
+                        buf.append(resolveFeatureName(splitFeatureIndex));
+                        if (udf.verbose) {
+                            buf.append(" [" + splitFeature + "] ");
+                        } else {
+                            buf.append(' ');
+                        }
+                        buf.append(op);
+                        if (op == Operator.EQ || op == Operator.NE) {
+                            buf.append(' ');
+                            buf.append(splitValue);
+                        }
+                        String key = buf.toString();
+                        map.put(key, splitValue);
+                        StringUtils.clear(buf);
+                    }
+
+                    @Override
+                    public void visitLeaf(double output) {
+                        for (Map.Entry<String, Double> e : map.entrySet()) {
+                            final String key = e.getKey();
+                            if (key.indexOf('<') == -1 && key.indexOf('>') == -1) {
+                                result.add(key);
+                            } else {
+                                double value = e.getValue().doubleValue();
+                                result.add(key + ' ' + value);
+                            }
+                        }
+                        if (udf.noLeaf) {
+                            return;
+                        }
+
+                        result.add(Double.toString(output));
+                    }
+
+                    @SuppressWarnings("unchecked")
+                    @Override
+                    public ArrayList<String> getResult() {
+                        return result;
+                    }
+
+                };
+            } else {
+                this.handler = new PredictionHandler() {
+
+                    @Override
+                    public void init() {
+                        result.clear();
+                    }
+
+                    @Override
+                    public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
+                            double splitValue) {
+                        buf.append(resolveFeatureName(splitFeatureIndex));
+                        if (udf.verbose) {
+                            buf.append(" [" + splitFeature + "] ");
+                        }
+                        buf.append(op);
+                        buf.append(' ');
+                        buf.append(splitValue);
+                        result.add(buf.toString());
+                        StringUtils.clear(buf);
+                    }
+
+                    @Override
+                    public void visitLeaf(double output) {
+                        if (udf.noLeaf) {
+                            return;
+                        }
+
+                        result.add(Double.toString(output));
+                    }
+
+                    @SuppressWarnings("unchecked")
+                    @Override
+                    public ArrayList<String> getResult() {
+                        return result;
+                    }
+
+                };
+            }
+        }
+
+        @Nonnull
+        private String resolveFeatureName(final int splitFeatureIndex) {
+            if (featureNames == null) {
+                return Integer.toString(splitFeatureIndex);
+            } else {
+                return featureNames[splitFeatureIndex];
+            }
+        }
+
+        @Nonnull
+        public List<String> evaluate(@Nonnull final String modelId, @Nonnull final Text script,
+                @Nonnull final Vector features) throws HiveException {
+            if (!modelId.equals(prevModelId)) {
+                this.prevModelId = modelId;
+                int length = script.getLength();
+                byte[] b = script.getBytes();
+                b = Base91.decode(b, 0, length);
+                this.rNode = RegressionTree.deserialize(b, b.length, true);
+            }
+            Preconditions.checkNotNull(rNode);
+
+            handler.init();
+            rNode.predict(features, handler);
+            return handler.getResult();
+        }
+    }
+
+}
diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
index 511944c..262a28d 100644
--- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
+++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
@@ -284,7 +284,7 @@ public final class TreePredictUDF extends UDFWithOptions {
             Arrays.fill(result, null);
             Preconditions.checkNotNull(cNode);
             cNode.predict(features, new PredictionHandler() {
-                public void handle(int output, double[] posteriori) {
+                public void visitLeaf(int output, double[] posteriori) {
                     result[0] = new IntWritable(output);
                     result[1] = WritableUtils.toWritableList(posteriori);
                 }
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index 4e73ebc..caf21d3 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -148,17 +148,23 @@ public final class ArrayUtils {
         return Arrays.asList(v);
     }
 
-    public static <T> void shuffle(@Nonnull final T[] array) {
+    @Nonnull
+    public static <T> T[] shuffle(@Nonnull final T[] array) {
         shuffle(array, array.length);
+        return array;
     }
 
-    public static <T> void shuffle(@Nonnull final T[] array, final Random rnd) {
+    @Nonnull
+    public static <T> T[] shuffle(@Nonnull final T[] array, final Random rnd) {
         shuffle(array, array.length, rnd);
+        return array;
     }
 
-    public static <T> void shuffle(@Nonnull final T[] array, final int size) {
+    @Nonnull
+    public static <T> T[] shuffle(@Nonnull final T[] array, final int size) {
         Random rnd = new Random();
         shuffle(array, size, rnd);
+        return array;
     }
 
     /**
@@ -166,19 +172,23 @@ public final class ArrayUtils {
      * 
      * @link http://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
      */
-    public static <T> void shuffle(@Nonnull final T[] array, final int size,
+    @Nonnull
+    public static <T> T[] shuffle(@Nonnull final T[] array, final int size,
             @Nonnull final Random rnd) {
         for (int i = size; i > 1; i--) {
             int randomPosition = rnd.nextInt(i);
             swap(array, i - 1, randomPosition);
         }
+        return array;
     }
 
-    public static void shuffle(@Nonnull final int[] array, @Nonnull final Random rnd) {
+    @Nonnull
+    public static int[] shuffle(@Nonnull final int[] array, @Nonnull final Random rnd) {
         for (int i = array.length; i > 1; i--) {
             int randomPosition = rnd.nextInt(i);
             swap(array, i - 1, randomPosition);
         }
+        return array;
     }
 
     public static void swap(@Nonnull final Object[] arr, final int i, final int j) {
diff --git a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
index c3601eb..9e5ee9a 100644
--- a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
+++ b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
@@ -25,12 +25,17 @@ import hivemall.math.matrix.builders.CSRMatrixBuilder;
 import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
 import hivemall.math.random.PRNG;
 import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.DenseVector;
 import hivemall.smile.classification.DecisionTree.Node;
 import hivemall.smile.classification.DecisionTree.SplitRule;
 import hivemall.smile.tools.TreeExportUDF.Evaluator;
 import hivemall.smile.tools.TreeExportUDF.OutputType;
 import hivemall.smile.utils.SmileExtUtils;
 import hivemall.utils.codec.Base91;
+import hivemall.utils.lang.ArrayUtils;
+import hivemall.utils.lang.StringUtils;
+import hivemall.utils.math.MathUtils;
+import smile.data.Attribute;
 import smile.data.AttributeDataset;
 import smile.data.NominalAttribute;
 import smile.data.parser.ArffParser;
@@ -43,6 +48,9 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.net.URL;
 import java.text.ParseException;
+import java.util.Arrays;
+import java.util.LinkedHashMap;
+import java.util.Random;
 
 import javax.annotation.Nonnull;
 
@@ -99,6 +107,15 @@ public class DecisionTreeTest {
     }
 
     @Test
+    public void testIrisTracePredict() throws IOException, ParseException {
+        int responseIndex = 4;
+        int numLeafs = Integer.MAX_VALUE;
+        runTracePredict(
+            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff",
+            responseIndex, numLeafs);
+    }
+
+    @Test
     public void testIrisDepth4() throws IOException, ParseException {
         int responseIndex = 4;
         int numLeafs = 4;
@@ -240,6 +257,69 @@ public class DecisionTreeTest {
         }
     }
 
+    private static void runTracePredict(String datasetUrl, int responseIndex, int numLeafs)
+            throws IOException, ParseException {
+        URL url = new URL(datasetUrl);
+        InputStream is = new BufferedInputStream(url.openStream());
+
+        ArffParser arffParser = new ArffParser();
+        arffParser.setResponseIndex(responseIndex);
+
+        AttributeDataset ds = arffParser.parse(is);
+        final Attribute[] attrs = ds.attributes();
+        final Attribute targetAttr = ds.response();
+
+        double[][] x = ds.toArray(new double[ds.size()][]);
+        int[] y = ds.toArray(new int[ds.size()]);
+
+        Random rnd = new Random(43L);
+        int numTrain = (int) (x.length * 0.7);
+        int[] index = ArrayUtils.shuffle(MathUtils.permutation(x.length), rnd);
+        int[] cvTrain = Arrays.copyOf(index, numTrain);
+        int[] cvTest = Arrays.copyOfRange(index, numTrain, index.length);
+
+        double[][] trainx = Math.slice(x, cvTrain);
+        int[] trainy = Math.slice(y, cvTrain);
+        double[][] testx = Math.slice(x, cvTest);
+
+        DecisionTree tree = new DecisionTree(SmileExtUtils.convertAttributeTypes(attrs),
+            matrix(trainx, false), trainy, numLeafs, RandomNumberGeneratorFactory.createPRNG(43L));
+
+        final LinkedHashMap<String, Double> map = new LinkedHashMap<>();
+        final StringBuilder buf = new StringBuilder();
+        for (int i = 0; i < testx.length; i++) {
+            final DenseVector test = new DenseVector(testx[i]);
+            tree.predict(test, new PredictionHandler() {
+
+                @Override
+                public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
+                        double splitValue) {
+                    buf.append(attrs[splitFeatureIndex].name);
+                    buf.append(" [" + splitFeature + "] ");
+                    buf.append(op);
+                    buf.append(' ');
+                    buf.append(splitValue);
+                    buf.append('\n');
+
+                    map.put(attrs[splitFeatureIndex].name + " [" + splitFeature + "] " + op,
+                        splitValue);
+                }
+
+                @Override
+                public void visitLeaf(int output, double[] posteriori) {
+                    buf.append(targetAttr.toString(output));
+                }
+            });
+
+            Assert.assertTrue(buf.length() > 0);
+            Assert.assertFalse(map.isEmpty());
+
+            StringUtils.clear(buf);
+            map.clear();
+        }
+
+    }
+
     @Test
     public void testIrisSerializedObj() throws IOException, ParseException, HiveException {
         URL url = new URL(
diff --git a/docs/gitbook/misc/funcs.md b/docs/gitbook/misc/funcs.md
index d860dba..e5e9dc8 100644
--- a/docs/gitbook/misc/funcs.md
+++ b/docs/gitbook/misc/funcs.md
@@ -589,6 +589,39 @@ Reference: <a href="https://papers.nips.cc/paper/3848-adaptive-regularization-of
 
 - `train_randomforest_regressor(array<double|string> features, double target [, string options])` - Returns a relation consists of &lt;int model_id, int model_type, string model, array&lt;double&gt; var_importance, double oob_errors, int oob_tests&gt;
 
+- `decision_path(string modelId, string model, array<double|string> features [, const string options] [, optional array<string> featureNames=null, optional array<string> classNames=null])` - Returns a decision path for each prediction in array&lt;string&gt;
+  ```sql
+  SELECT
+    t.passengerid,
+    decision_path(m.model_id, m.model, t.features, '-classification')
+  FROM
+    model_rf m
+    LEFT OUTER JOIN
+    test_rf t;
+  > | 892 | ["2 [0.0] = 0.0","0 [3.0] = 3.0","1 [696.0] != 107.0","7 [7.8292] <= 7.9104","1 [696.0] != 828.0","1 [696.0] != 391.0","0 [0.961038961038961, 0.03896103896103896]"] |
+
+  -- Show 100 frequent branches
+  WITH tmp as (
+    SELECT
+      decision_path(m.model_id, m.model, t.features, '-classification -no_verbose -no_leaf', array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'), array('no','yes')) as path
+    FROM
+      model_rf m
+      LEFT OUTER JOIN -- CROSS JOIN
+      test_rf t
+  )
+  select
+    r.branch,
+    count(1) as cnt
+  from
+    tmp l
+    LATERAL VIEW explode(l.path) r as branch
+  group by
+    r.branch
+  order by
+    cnt desc
+  limit 100;
+  ```
+
 - `guess_attribute_types(ANY, ...)` - Returns attribute types
   ```sql
   select guess_attribute_types(*) from train limit 1;
diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive
index 17797a8..343215a 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -829,6 +829,9 @@ CREATE FUNCTION rf_ensemble as 'hivemall.smile.tools.RandomForestEnsembleUDAF' U
 DROP FUNCTION IF EXISTS guess_attribute_types;
 CREATE FUNCTION guess_attribute_types as 'hivemall.smile.tools.GuessAttributesUDF' USING JAR '${hivemall_jar}';
 
+DROP FUNCTION IF EXISTS decision_path;
+CREATE FUNCTION decision_path as 'hivemall.smile.tools.DecisionPathUDF' USING JAR '${hivemall_jar}';
+
 --------------------
 -- Recommendation --
 --------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 04e8915..2a9b437 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -821,6 +821,9 @@ create temporary function rf_ensemble as 'hivemall.smile.tools.RandomForestEnsem
 drop temporary function if exists guess_attribute_types;
 create temporary function guess_attribute_types as 'hivemall.smile.tools.GuessAttributesUDF';
 
+drop temporary function if exists decision_path;
+create temporary function decision_path as 'hivemall.smile.tools.DecisionPathUDF';
+
 --------------------
 -- Recommendation --
 --------------------
@@ -889,3 +892,4 @@ log(10, n_docs / max2(1,df_t)) + 1.0;
 
 create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE)
 tf * (log(10, n_docs / max2(1,df_t)) + 1.0);
+
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 19f01bc..d62e3a2 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -807,6 +807,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION guess_attribute_types AS 'hivemall.smi
 sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS 
train_gradient_tree_boosting_classifier")
 sqlContext.sql("CREATE TEMPORARY FUNCTION 
train_gradient_tree_boosting_classifier AS 
'hivemall.smile.classification.GradientTreeBoostingClassifierUDTF'")
 
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS decision_path")
+sqlContext.sql("CREATE TEMPORARY FUNCTION decision_path AS 'hivemall.smile.tools.DecisionPathUDF'")
+
 /**
  * Recommendation
  */
