This is an automated email from the ASF dual-hosted git repository. myui pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push: new 176fa07 [HIVEMALL-171] Tracing functionality for prediction of DecisionTrees 176fa07 is described below commit 176fa070c1e2ea3b0737c8150a1302e4cb643816 Author: Makoto Yui <m...@apache.org> AuthorDate: Sat Sep 28 03:39:01 2019 +0900 [HIVEMALL-171] Tracing functionality for prediction of DecisionTrees ## What changes were proposed in this pull request? Introduce `decision_path` UDF providing tracing of decision tree prediction paths ## What type of PR is it? Feature ## What is the Jira issue? https://issues.apache.org/jira/browse/HIVEMALL-171 ## How was this patch tested? unit tests, manual tests on EMR ## How to use this feature? to be described in the user guide ## Checklist - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit? - [x] Did you run system tests on Hive (or Spark)? Author: Makoto Yui <m...@apache.org> Closes #199 from myui/HIVEMALL-171. --- core/src/main/java/hivemall/annotations/Cite.java | 1 + .../smile/classification/DecisionTree.java | 19 +- .../smile/classification/PredictionHandler.java | 35 +- .../hivemall/smile/regression/RegressionTree.java | 31 + .../java/hivemall/smile/tools/DecisionPathUDF.java | 659 +++++++++++++++++++++ .../java/hivemall/smile/tools/TreePredictUDF.java | 2 +- .../main/java/hivemall/utils/lang/ArrayUtils.java | 20 +- .../smile/classification/DecisionTreeTest.java | 80 +++ docs/gitbook/misc/funcs.md | 33 ++ resources/ddl/define-all-as-permanent.hive | 3 + resources/ddl/define-all.hive | 4 + resources/ddl/define-all.spark | 3 + 12 files changed, 879 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/hivemall/annotations/Cite.java b/core/src/main/java/hivemall/annotations/Cite.java index 2b93cd6..7d09320 100644 --- a/core/src/main/java/hivemall/annotations/Cite.java +++ b/core/src/main/java/hivemall/annotations/Cite.java @@ -30,6 +30,7 @@ import javax.annotation.Nullable; public @interface Cite { @Nonnull String description(); + @Nullable String url(); } diff --git a/core/src/main/java/hivemall/smile/classification/DecisionTree.java b/core/src/main/java/hivemall/smile/classification/DecisionTree.java index 95b4b2a..74a99ad 100644 --- a/core/src/main/java/hivemall/smile/classification/DecisionTree.java +++ b/core/src/main/java/hivemall/smile/classification/DecisionTree.java @@ -17,6 +17,10 @@ // https://github.com/haifengl/smile/blob/master/core/src/main/java/smile/classification/DecisionTree.java package hivemall.smile.classification; +import static hivemall.smile.classification.PredictionHandler.Operator.EQ; +import static hivemall.smile.classification.PredictionHandler.Operator.GT; +import static hivemall.smile.classification.PredictionHandler.Operator.LE; +import static hivemall.smile.classification.PredictionHandler.Operator.NE; import static hivemall.smile.utils.SmileExtUtils.NOMINAL; import static hivemall.smile.utils.SmileExtUtils.NUMERIC; import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName; @@ -319,18 +323,23 @@ public class DecisionTree implements Classifier<Vector> { */ public void predict(@Nonnull final Vector x, @Nonnull final PredictionHandler handler) { if (isLeaf()) { - handler.handle(output, posteriori); + handler.visitLeaf(output, posteriori); } else { + final double feature = x.get(splitFeature, Double.NaN); if (quantitativeFeature) { - if (x.get(splitFeature, Double.NaN) <= splitValue) { + if (feature <= splitValue) { + handler.visitBranch(LE, splitFeature, feature, splitValue); trueChild.predict(x, handler); } else { + handler.visitBranch(GT, splitFeature, feature, splitValue); falseChild.predict(x, handler); } } else { - if (x.get(splitFeature, Double.NaN) == splitValue) { + if (feature == splitValue) { + handler.visitBranch(EQ, splitFeature, feature, splitValue); trueChild.predict(x, handler); } else { + handler.visitBranch(NE, splitFeature, feature, splitValue); falseChild.predict(x, handler); } } @@ -1359,6 +1368,10 @@ public class DecisionTree implements Classifier<Vector> { return _root.predict(x); } + public void predict(@Nonnull final Vector x, @Nonnull final PredictionHandler handler) { + _root.predict(x, handler); + } + /** * Predicts the class label of an instance and also calculate a posteriori probabilities. Not * supported. diff --git a/core/src/main/java/hivemall/smile/classification/PredictionHandler.java b/core/src/main/java/hivemall/smile/classification/PredictionHandler.java index 84ef244..6c19641 100644 --- a/core/src/main/java/hivemall/smile/classification/PredictionHandler.java +++ b/core/src/main/java/hivemall/smile/classification/PredictionHandler.java @@ -20,8 +20,39 @@ package hivemall.smile.classification; import javax.annotation.Nonnull; -public interface PredictionHandler { +public abstract class PredictionHandler { - void handle(int output, @Nonnull double[] posteriori); + public enum Operator { + /* = */ EQ, /* != */ NE, /* <= */ LE, /* > */ GT; + + @Override + public String toString() { + switch (this) { + case EQ: + return "="; + case NE: + return "!="; + case LE: + return "<="; + case GT: + return ">"; + default: + throw new IllegalStateException("Unexpected operator: " + this); + } + } + } + + public void init() {}; + + public void visitBranch(@Nonnull Operator op, int splitFeatureIndex, double splitFeature, + double splitValue) {} + + public void visitLeaf(double output) {} + + public void visitLeaf(int output, @Nonnull double[] posteriori) {} + + public <T> T getResult() { + throw new UnsupportedOperationException(); + } } diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java index 764c352..ab2f25f 100755 --- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java +++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java @@ -17,6 +17,10 @@ // https://github.com/haifengl/smile/blob/master/core/src/main/java/smile/regression/RegressionTree.java package hivemall.smile.regression; +import static hivemall.smile.classification.PredictionHandler.Operator.EQ; +import static hivemall.smile.classification.PredictionHandler.Operator.GT; +import static hivemall.smile.classification.PredictionHandler.Operator.LE; +import static hivemall.smile.classification.PredictionHandler.Operator.NE; import static hivemall.smile.utils.SmileExtUtils.NOMINAL; import static hivemall.smile.utils.SmileExtUtils.NUMERIC; import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName; @@ -29,6 +33,7 @@ import hivemall.math.vector.DenseVector; import hivemall.math.vector.SparseVector; import hivemall.math.vector.Vector; import hivemall.math.vector.VectorProcedure; +import hivemall.smile.classification.PredictionHandler; import hivemall.smile.utils.SmileExtUtils; import hivemall.smile.utils.VariableOrder; import hivemall.utils.collections.arrays.SparseIntArray; @@ -274,6 +279,32 @@ public final class RegressionTree implements Regression<Vector> { } } + public double predict(@Nonnull final Vector x, @Nonnull final PredictionHandler handler) { + if (isLeaf()) { + handler.visitLeaf(output); + return output; + } else { + final double feature = x.get(splitFeature, Double.NaN); + if (quantitativeFeature) { + if (feature <= splitValue) { + handler.visitBranch(LE, splitFeature, feature, splitValue); + return trueChild.predict(x); + } else { + handler.visitBranch(GT, splitFeature, feature, splitValue); + return falseChild.predict(x); + } + } else { + if (feature == splitValue) { + handler.visitBranch(EQ, splitFeature, feature, splitValue); + return trueChild.predict(x); + } else { + handler.visitBranch(NE, splitFeature, feature, splitValue); + return falseChild.predict(x); + } + } + } + } + /** * Evaluate the regression tree over an instance. */ diff --git a/core/src/main/java/hivemall/smile/tools/DecisionPathUDF.java b/core/src/main/java/hivemall/smile/tools/DecisionPathUDF.java new file mode 100644 index 0000000..11a05da --- /dev/null +++ b/core/src/main/java/hivemall/smile/tools/DecisionPathUDF.java @@ -0,0 +1,659 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.smile.tools; + +import hivemall.UDFWithOptions; +import hivemall.math.vector.DenseVector; +import hivemall.math.vector.SparseVector; +import hivemall.math.vector.Vector; +import hivemall.smile.classification.DecisionTree; +import hivemall.smile.classification.PredictionHandler; +import hivemall.smile.regression.RegressionTree; +import hivemall.utils.codec.Base91; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.lang.StringUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; +import org.apache.hadoop.io.Text; + +// @formatter:off +@Description(name = "decision_path", + value = "_FUNC_(string modelId, string model, array<double|string> features [, const string options] [, optional array<string> featureNames=null, optional array<string> classNames=null])" + + " - Returns a decision path for each prediction in array<string>", + extended = "SELECT\n" + + " t.passengerid,\n" + + " decision_path(m.model_id, m.model, t.features, '-classification')\n" + + "FROM\n" + + " model_rf m\n" + + " LEFT OUTER JOIN\n" + + " test_rf t;\n" + + "> | 892 | [\"2 [0.0] = 0.0\",\"0 [3.0] = 3.0\",\"1 [696.0] != 107.0\",\"7 [7.8292] <= 7.9104\",\"1 [696.0] != 828.0\",\"1 [696.0] != 391.0\",\"0 [0.961038961038961, 0.03896103896103896]\"] |\n\n" + + "-- Show 100 frequent branches\n" + + "WITH tmp as (\n" + + " SELECT\n" + + " decision_path(m.model_id, m.model, t.features, '-classification -no_verbose -no_leaf', array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'), array('no','yes')) as path\n" + + " FROM\n" + + " model_rf m\n" + + " LEFT OUTER JOIN -- CROSS JOIN\n" + + " test_rf t\n" + + ")\n" + + "select\n" + + " r.branch,\n" + + " count(1) as cnt\n" + + "from\n" + + " tmp l\n" + + " LATERAL VIEW explode(l.path) r as branch\n" + + "group by\n" + + " r.branch\n" + + "order by\n" + + " cnt desc\n" + + "limit 100;") +// @formatter:on +@UDFType(deterministic = true, stateful = false) +public final class DecisionPathUDF extends UDFWithOptions { + + private StringObjectInspector modelOI; + private ListObjectInspector featureListOI; + private PrimitiveObjectInspector featureElemOI; + private boolean denseInput; + + // options + private boolean classification = false; + private boolean summarize = true; + private boolean verbose = true; + private boolean noLeaf = false; + + @Nullable + private String[] featureNames; + @Nullable + private String[] classNames; + + @Nullable + private transient Vector featuresProbe; + + @Nullable + private transient Evaluator evaluator; + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("c", "classification", false, + "Predict as classification [default: not enabled]"); + opts.addOption("no_sumarize", "disable_summarization", false, + "Do not summarize decision paths"); + opts.addOption("no_verbose", "disable_verbose_output", false, + "Disable verbose output [default: verbose]"); + opts.addOption("no_leaf", "disable_leaf_output", false, + "Show leaf value [default: not enabled]"); + return opts; + } + + @Override + protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException { + CommandLine cl = parseOptions(optionValue); + + this.classification = cl.hasOption("classification"); + this.summarize = !cl.hasOption("no_sumarize"); + this.verbose = !cl.hasOption("disable_verbose_output"); + this.noLeaf = cl.hasOption("disable_leaf_output"); + + return cl; + } + + @Override + public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length < 3 || argOIs.length > 6) { + showHelp("tree_predict takes 3 ~ 6 arguments"); + } + + this.modelOI = HiveUtils.asStringOI(argOIs[1]); + + ListObjectInspector listOI = HiveUtils.asListOI(argOIs[2]); + this.featureListOI = listOI; + ObjectInspector elemOI = listOI.getListElementObjectInspector(); + if (HiveUtils.isNumberOI(elemOI)) { + this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI); + this.denseInput = true; + } else if (HiveUtils.isStringOI(elemOI)) { + this.featureElemOI = HiveUtils.asStringOI(elemOI); + this.denseInput = false; + } else { + throw new UDFArgumentException( + "tree_predict takes array<double> or array<string> for the 3rd argument: " + + listOI.getTypeName()); + } + + if (argOIs.length >= 4) { + ObjectInspector argOI3 = argOIs[3]; + if (HiveUtils.isConstString(argOI3)) { + String opts = HiveUtils.getConstString(argOI3); + processOptions(opts); + if (argOIs.length >= 5) { + ObjectInspector argOI4 = argOIs[4]; + if (HiveUtils.isConstStringListOI(argOI4)) { + this.featureNames = HiveUtils.getConstStringArray(argOI4); + if (argOIs.length >= 6) { + ObjectInspector argOI5 = argOIs[5]; + if (HiveUtils.isConstStringListOI(argOI5)) { + if (!classification) { + throw new UDFArgumentException( + "classNames should not be provided for regression"); + } + this.classNames = HiveUtils.getConstStringArray(argOI5); + } else { + throw new UDFArgumentException( + "decision_path expects 'const array<string> classNames' for the 6th argument: " + + argOI5.getTypeName()); + } + } + } else { + throw new UDFArgumentException( + "decision_path expects 'const array<string> featureNames' for the 5th argument: " + + argOI4.getTypeName()); + } + } + } else if (HiveUtils.isConstStringListOI(argOI3)) { + this.featureNames = HiveUtils.getConstStringArray(argOI3); + if (argOIs.length >= 5) { + ObjectInspector argOI4 = argOIs[4]; + if (HiveUtils.isConstStringListOI(argOI4)) { + if (!classification) { + throw new UDFArgumentException( + "classNames should not be provided for regression"); + } + this.classNames = HiveUtils.getConstStringArray(argOI4); + } else { + throw new UDFArgumentException( + "decision_path expects 'const array<string> classNames' for the 5th argument: " + + argOI4.getTypeName()); + } + } + } else { + throw new UDFArgumentException( + "decision_path expects 'const array<string> options' or 'const array<string> featureNames' for the 4th argument: " + + argOI3.getTypeName()); + } + } + + return ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector); + } + + @Override + public List<String> evaluate(@Nonnull DeferredObject[] arguments) throws HiveException { + Object arg0 = arguments[0].get(); + if (arg0 == null) { + throw new HiveException("modelId should not be null"); + } + // Not using string OI for backward compatibilities + String modelId = arg0.toString(); + + Object arg1 = arguments[1].get(); + if (arg1 == null) { + return null; + } + Text model = modelOI.getPrimitiveWritableObject(arg1); + + Object arg2 = arguments[2].get(); + if (arg2 == null) { + throw new HiveException("features was null"); + } + this.featuresProbe = parseFeatures(arg2, featuresProbe); + + if (evaluator == null) { + this.evaluator = classification ? new ClassificationEvaluator(this) + : new RegressionEvaluator(this); + } + return evaluator.evaluate(modelId, model, featuresProbe); + } + + @Nonnull + private Vector parseFeatures(@Nonnull final Object argObj, @Nullable Vector probe) + throws UDFArgumentException { + if (denseInput) { + final int length = featureListOI.getListLength(argObj); + if (probe == null) { + probe = new DenseVector(length); + } else if (length != probe.size()) { + probe = new DenseVector(length); + } + + for (int i = 0; i < length; i++) { + final Object o = featureListOI.getListElement(argObj, i); + if (o == null) { + probe.set(i, 0.d); + } else { + double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI); + probe.set(i, v); + } + } + } else { + if (probe == null) { + probe = new SparseVector(); + } else { + probe.clear(); + } + + final int length = featureListOI.getListLength(argObj); + for (int i = 0; i < length; i++) { + Object o = featureListOI.getListElement(argObj, i); + if (o == null) { + continue; + } + String col = o.toString(); + + final int pos = col.indexOf(':'); + if (pos == 0) { + throw new UDFArgumentException("Invalid feature value representation: " + col); + } + + final String feature; + final double value; + if (pos > 0) { + feature = col.substring(0, pos); + String s2 = col.substring(pos + 1); + value = Double.parseDouble(s2); + } else { + feature = col; + value = 1.d; + } + + if (feature.indexOf(':') != -1) { + throw new UDFArgumentException( + "Invalid feature format `<index>:<value>`: " + col); + } + + final int colIndex = Integer.parseInt(feature); + if (colIndex < 0) { + throw new UDFArgumentException( + "Col index MUST be greater than or equals to 0: " + colIndex); + } + probe.set(colIndex, value); + } + } + return probe; + } + + @Override + public void close() throws IOException { + this.modelOI = null; + this.featureElemOI = null; + this.featureListOI = null; + this.featureNames = null; + this.classNames = null; + this.featuresProbe = null; + this.evaluator = null; + } + + @Override + public String getDisplayString(String[] children) { + return "decision_path(" + StringUtils.join(children, ',') + ")"; + } + + interface Evaluator { + + @Nonnull + List<String> evaluate(@Nonnull String modelId, @Nonnull Text model, + @Nonnull Vector features) throws HiveException; + + } + + static final class ClassificationEvaluator implements Evaluator { + + @Nullable + private final String[] featureNames; + @Nullable + private final String[] classNames; + + @Nonnull + private final List<String> result; + @Nonnull + private final PredictionHandler handler; + + @Nullable + private String prevModelId = null; + private DecisionTree.Node cNode = null; + + ClassificationEvaluator(@Nonnull final DecisionPathUDF udf) { + this.featureNames = udf.featureNames; + this.classNames = udf.classNames; + + final StringBuilder buf = new StringBuilder(); + final ArrayList<String> result = new ArrayList<>(); + this.result = result; + + if (udf.summarize) { + final LinkedHashMap<String, Double> map = new LinkedHashMap<>(); + + this.handler = new PredictionHandler() { + + @Override + public void init() { + map.clear(); + result.clear(); + } + + @Override + public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature, + double splitValue) { + buf.append(resolveFeatureName(splitFeatureIndex)); + if (udf.verbose) { + buf.append(" [" + splitFeature + "] "); + } else { + buf.append(' '); + } + buf.append(op); + if (op == Operator.EQ || op == Operator.NE) { + buf.append(' '); + buf.append(splitValue); + } + String key = buf.toString(); + map.put(key, splitValue); + StringUtils.clear(buf); + } + + @Override + public void visitLeaf(int output, double[] posteriori) { + for (Map.Entry<String, Double> e : map.entrySet()) { + final String key = e.getKey(); + if (key.indexOf('<') == -1 && key.indexOf('>') == -1) { + result.add(key); + } else { + double value = e.getValue().doubleValue(); + result.add(key + ' ' + value); + } + } + if (udf.noLeaf) { + return; + } + + if (udf.verbose) { + buf.append(resolveClassName(output)); + buf.append(' '); + buf.append(Arrays.toString(posteriori)); + result.add(buf.toString()); + StringUtils.clear(buf); + } else { + result.add(resolveClassName(output)); + } + } + + @SuppressWarnings("unchecked") + @Override + public ArrayList<String> getResult() { + return result; + } + + }; + } else { + this.handler = new PredictionHandler() { + + @Override + public void init() { + result.clear(); + } + + @Override + public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature, + double splitValue) { + buf.append(resolveFeatureName(splitFeatureIndex)); + if (udf.verbose) { + buf.append(" [" + splitFeature + "] "); + } else { + buf.append(' '); + } + buf.append(op); + buf.append(' '); + buf.append(splitValue); + result.add(buf.toString()); + StringUtils.clear(buf); + } + + @Override + public void visitLeaf(int output, double[] posteriori) { + if (udf.noLeaf) { + return; + } + + if (udf.verbose) { + buf.append(resolveClassName(output)); + buf.append(' '); + buf.append(Arrays.toString(posteriori)); + result.add(buf.toString()); + StringUtils.clear(buf); + } else { + result.add(resolveClassName(output)); + } + } + + @SuppressWarnings("unchecked") + @Override + public ArrayList<String> getResult() { + return result; + } + + }; + } + } + + @Nonnull + private String resolveFeatureName(final int splitFeatureIndex) { + if (featureNames == null) { + return Integer.toString(splitFeatureIndex); + } else { + return featureNames[splitFeatureIndex]; + } + } + + @Nonnull + private String resolveClassName(final int classLabel) { + if (classNames == null) { + return Integer.toString(classLabel); + } else { + return classNames[classLabel]; + } + } + + @Nonnull + public List<String> evaluate(@Nonnull final String modelId, @Nonnull final Text script, + @Nonnull final Vector features) throws HiveException { + if (!modelId.equals(prevModelId)) { + this.prevModelId = modelId; + int length = script.getLength(); + byte[] b = script.getBytes(); + b = Base91.decode(b, 0, length); + this.cNode = DecisionTree.deserialize(b, b.length, true); + } + Preconditions.checkNotNull(cNode); + + handler.init(); + cNode.predict(features, handler); + return handler.getResult(); + } + + } + + static final class RegressionEvaluator implements Evaluator { + + @Nullable + private final String[] featureNames; + + @Nonnull + private final List<String> result; + @Nonnull + private final PredictionHandler handler; + + @Nullable + private String prevModelId = null; + private RegressionTree.Node rNode = null; + + RegressionEvaluator(@Nonnull final DecisionPathUDF udf) { + this.featureNames = udf.featureNames; + + final StringBuilder buf = new StringBuilder(); + final ArrayList<String> result = new ArrayList<>(); + this.result = result; + + if (udf.summarize) { + final LinkedHashMap<String, Double> map = new LinkedHashMap<>(); + + this.handler = new PredictionHandler() { + + @Override + public void init() { + map.clear(); + result.clear(); + } + + @Override + public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature, + double splitValue) { + buf.append(resolveFeatureName(splitFeatureIndex)); + if (udf.verbose) { + buf.append(" [" + splitFeature + "] "); + } else { + buf.append(' '); + } + buf.append(op); + if (op == Operator.EQ || op == Operator.NE) { + buf.append(' '); + buf.append(splitValue); + } + String key = buf.toString(); + map.put(key, splitValue); + StringUtils.clear(buf); + } + + @Override + public void visitLeaf(double output) { + for (Map.Entry<String, Double> e : map.entrySet()) { + final String key = e.getKey(); + if (key.indexOf('<') == -1 && key.indexOf('>') == -1) { + result.add(key); + } else { + double value = e.getValue().doubleValue(); + result.add(key + ' ' + value); + } + } + if (udf.noLeaf) { + return; + } + + result.add(Double.toString(output)); + } + + @SuppressWarnings("unchecked") + @Override + public ArrayList<String> getResult() { + return result; + } + + }; + } else { + this.handler = new PredictionHandler() { + + @Override + public void init() { + result.clear(); + } + + @Override + public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature, + double splitValue) { + buf.append(resolveFeatureName(splitFeatureIndex)); + if (udf.verbose) { + buf.append(" [" + splitFeature + "] "); + } + buf.append(op); + buf.append(' '); + buf.append(splitValue); + result.add(buf.toString()); + StringUtils.clear(buf); + } + + @Override + public void visitLeaf(double output) { + if (udf.noLeaf) { + return; + } + + result.add(Double.toString(output)); + } + + @SuppressWarnings("unchecked") + @Override + public ArrayList<String> getResult() { + return result; + } + + }; + } + } + + @Nonnull + private String resolveFeatureName(final int splitFeatureIndex) { + if (featureNames == null) { + return Integer.toString(splitFeatureIndex); + } else { + return featureNames[splitFeatureIndex]; + } + } + + @Nonnull + public List<String> evaluate(@Nonnull final String modelId, @Nonnull final Text script, + @Nonnull final Vector features) throws HiveException { + if (!modelId.equals(prevModelId)) { + this.prevModelId = modelId; + int length = script.getLength(); + byte[] b = script.getBytes(); + b = Base91.decode(b, 0, length); + this.rNode = RegressionTree.deserialize(b, b.length, true); + } + Preconditions.checkNotNull(rNode); + + handler.init(); + rNode.predict(features, handler); + return handler.getResult(); + } + } + +} diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java index 511944c..262a28d 100644 --- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java +++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java @@ -284,7 +284,7 @@ public final class TreePredictUDF extends UDFWithOptions { Arrays.fill(result, null); Preconditions.checkNotNull(cNode); cNode.predict(features, new PredictionHandler() { - public void handle(int output, double[] posteriori) { + public void visitLeaf(int output, double[] posteriori) { result[0] = new IntWritable(output); result[1] = WritableUtils.toWritableList(posteriori); } diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java index 4e73ebc..caf21d3 100644 --- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java +++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java @@ -148,17 +148,23 @@ public final class ArrayUtils { return Arrays.asList(v); } - public static <T> void shuffle(@Nonnull final T[] array) { + @Nonnull + public static <T> T[] shuffle(@Nonnull final T[] array) { shuffle(array, array.length); + return array; } - public static <T> void shuffle(@Nonnull final T[] array, final Random rnd) { + @Nonnull + public static <T> T[] shuffle(@Nonnull final T[] array, final Random rnd) { shuffle(array, array.length, rnd); + return array; } - public static <T> void shuffle(@Nonnull final T[] array, final int size) { + @Nonnull + public static <T> T[] shuffle(@Nonnull final T[] array, final int size) { Random rnd = new Random(); shuffle(array, size, rnd); + return array; } /** @@ -166,19 +172,23 @@ public final class ArrayUtils { * * @link http://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle */ - public static <T> void shuffle(@Nonnull final T[] array, final int size, + @Nonnull + public static <T> T[] shuffle(@Nonnull final T[] array, final int size, @Nonnull final Random rnd) { for (int i = size; i > 1; i--) { int randomPosition = rnd.nextInt(i); swap(array, i - 1, randomPosition); } + return array; } - public static void shuffle(@Nonnull final int[] array, @Nonnull final Random rnd) { + @Nonnull + public static int[] shuffle(@Nonnull final int[] array, @Nonnull final Random rnd) { for (int i = array.length; i > 1; i--) { int randomPosition = rnd.nextInt(i); swap(array, i - 1, randomPosition); } + return array; } public static void swap(@Nonnull final Object[] arr, final int i, final int j) { diff --git a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java index c3601eb..9e5ee9a 100644 --- a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java +++ b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java @@ -25,12 +25,17 @@ import hivemall.math.matrix.builders.CSRMatrixBuilder; import hivemall.math.matrix.dense.RowMajorDenseMatrix2d; import hivemall.math.random.PRNG; import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.math.vector.DenseVector; import hivemall.smile.classification.DecisionTree.Node; import hivemall.smile.classification.DecisionTree.SplitRule; import hivemall.smile.tools.TreeExportUDF.Evaluator; import hivemall.smile.tools.TreeExportUDF.OutputType; import hivemall.smile.utils.SmileExtUtils; import hivemall.utils.codec.Base91; +import hivemall.utils.lang.ArrayUtils; +import hivemall.utils.lang.StringUtils; +import hivemall.utils.math.MathUtils; +import smile.data.Attribute; import smile.data.AttributeDataset; import smile.data.NominalAttribute; import smile.data.parser.ArffParser; @@ -43,6 +48,9 @@ import java.io.IOException; import java.io.InputStream; import java.net.URL; import java.text.ParseException; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.Random; import javax.annotation.Nonnull; @@ -99,6 +107,15 @@ public class DecisionTreeTest { } @Test + public void testIrisTracePredict() throws IOException, ParseException { + int responseIndex = 4; + int numLeafs = Integer.MAX_VALUE; + runTracePredict( + "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff", + responseIndex, numLeafs); + } + + @Test public void testIrisDepth4() throws IOException, ParseException { int responseIndex = 4; int numLeafs = 4; @@ -240,6 +257,69 @@ public class DecisionTreeTest { } } + private static void runTracePredict(String datasetUrl, int responseIndex, int numLeafs) + throws IOException, ParseException { + URL url = new URL(datasetUrl); + InputStream is = new BufferedInputStream(url.openStream()); + + ArffParser arffParser = new ArffParser(); + arffParser.setResponseIndex(responseIndex); + + AttributeDataset ds = arffParser.parse(is); + final Attribute[] attrs = ds.attributes(); + final Attribute targetAttr = ds.response(); + + double[][] x = ds.toArray(new double[ds.size()][]); + int[] y = ds.toArray(new int[ds.size()]); + + Random rnd = new Random(43L); + int numTrain = (int) (x.length * 0.7); + int[] index = ArrayUtils.shuffle(MathUtils.permutation(x.length), rnd); + int[] cvTrain = Arrays.copyOf(index, numTrain); + int[] cvTest = Arrays.copyOfRange(index, numTrain, index.length); + + double[][] trainx = Math.slice(x, cvTrain); + int[] trainy = Math.slice(y, cvTrain); + double[][] testx = Math.slice(x, cvTest); + + DecisionTree tree = new DecisionTree(SmileExtUtils.convertAttributeTypes(attrs), + matrix(trainx, false), trainy, numLeafs, RandomNumberGeneratorFactory.createPRNG(43L)); + + final LinkedHashMap<String, Double> map = new LinkedHashMap<>(); + final StringBuilder buf = new StringBuilder(); + for (int i = 0; i < testx.length; i++) { + final DenseVector test = new DenseVector(testx[i]); + tree.predict(test, new PredictionHandler() { + + @Override + public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature, + double splitValue) { + buf.append(attrs[splitFeatureIndex].name); + buf.append(" [" + splitFeature + "] "); + buf.append(op); + buf.append(' '); + buf.append(splitValue); + buf.append('\n'); + + map.put(attrs[splitFeatureIndex].name + " [" + splitFeature + "] " + op, + splitValue); + } + + @Override + public void visitLeaf(int output, double[] posteriori) { + buf.append(targetAttr.toString(output)); + } + }); + + Assert.assertTrue(buf.length() > 0); + Assert.assertFalse(map.isEmpty()); + + StringUtils.clear(buf); + map.clear(); + } + + } + @Test public void testIrisSerializedObj() throws IOException, ParseException, HiveException { URL url = new URL( diff --git a/docs/gitbook/misc/funcs.md b/docs/gitbook/misc/funcs.md index d860dba..e5e9dc8 100644 --- a/docs/gitbook/misc/funcs.md +++ b/docs/gitbook/misc/funcs.md @@ -589,6 +589,39 @@ Reference: <a href="https://papers.nips.cc/paper/3848-adaptive-regularization-of - `train_randomforest_regressor(array<double|string> features, double target [, string options])` - Returns a relation consists of <int model_id, int model_type, string model, array<double> var_importance, double oob_errors, int oob_tests> +- `decision_path(string modelId, string model, array<double|string> features [, const string options] [, optional array<string> featureNames=null, optional array<string> classNames=null])` - Returns a decision path for each prediction in array<string> + ```sql + SELECT + t.passengerid, + decision_path(m.model_id, m.model, t.features, '-classification') + FROM + model_rf m + LEFT OUTER JOIN + test_rf t; + > | 892 | ["2 [0.0] = 0.0","0 [3.0] = 3.0","1 [696.0] != 107.0","7 [7.8292] <= 7.9104","1 [696.0] != 828.0","1 [696.0] != 391.0","0 [0.961038961038961, 0.03896103896103896]"] | + + -- Show 100 frequent branches + WITH tmp as ( + SELECT + decision_path(m.model_id, m.model, t.features, '-classification -no_verbose -no_leaf', array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'), array('no','yes')) as path + FROM + model_rf m + LEFT OUTER JOIN -- CROSS JOIN + test_rf t + ) + select + r.branch, + count(1) as cnt + from + tmp l + LATERAL VIEW explode(l.path) r as branch + group by + r.branch + order by + cnt desc + limit 100; + ``` + - `guess_attribute_types(ANY, ...)` - Returns attribute types ```sql select guess_attribute_types(*) from train limit 1; diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive index 17797a8..343215a 100644 --- a/resources/ddl/define-all-as-permanent.hive +++ b/resources/ddl/define-all-as-permanent.hive @@ -829,6 +829,9 @@ CREATE FUNCTION rf_ensemble as 'hivemall.smile.tools.RandomForestEnsembleUDAF' U DROP FUNCTION IF EXISTS guess_attribute_types; CREATE FUNCTION guess_attribute_types as 'hivemall.smile.tools.GuessAttributesUDF' USING JAR '${hivemall_jar}'; +DROP FUNCTION IF EXISTS decision_path; +CREATE FUNCTION decision_path as 'hivemall.smile.tools.DecisionPathUDF' USING JAR '${hivemall_jar}'; + -------------------- -- Recommendation -- -------------------- diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive index 04e8915..2a9b437 100644 --- a/resources/ddl/define-all.hive +++ b/resources/ddl/define-all.hive @@ -821,6 +821,9 @@ create temporary function rf_ensemble as 'hivemall.smile.tools.RandomForestEnsem drop temporary function if exists guess_attribute_types; create temporary function guess_attribute_types as 'hivemall.smile.tools.GuessAttributesUDF'; +drop temporary function if exists decision_path; +create temporary function decision_path as 'hivemall.smile.tools.DecisionPathUDF'; + -------------------- -- Recommendation -- -------------------- @@ -889,3 +892,4 @@ log(10, n_docs / max2(1,df_t)) + 1.0; create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE) tf * (log(10, n_docs / max2(1,df_t)) + 1.0); + diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark index 19f01bc..d62e3a2 100644 --- a/resources/ddl/define-all.spark +++ b/resources/ddl/define-all.spark @@ -807,6 +807,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION guess_attribute_types AS 'hivemall.smi sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_gradient_tree_boosting_classifier") sqlContext.sql("CREATE TEMPORARY FUNCTION train_gradient_tree_boosting_classifier AS 'hivemall.smile.classification.GradientTreeBoostingClassifierUDTF'") +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS decision_path") +sqlContext.sql("CREATE TEMPORARY FUNCTION decision_path AS 'hivemall.smile.tools.DecisionPathUDF'") + /** * Recommendation */