Repository: incubator-hivemall Updated Branches: refs/heads/v0.5.0 2958af0af -> c742ce58e
[HIVEMALL-172] Change tree_predict 3rd argument to accept string options Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/c742ce58 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/c742ce58 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/c742ce58 Branch: refs/heads/v0.5.0 Commit: c742ce58e94913bf446c3b296a24415676f9ac3b Parents: 2958af0 Author: Makoto Yui <m...@apache.org> Authored: Thu Feb 8 17:36:50 2018 +0900 Committer: Makoto Yui <m...@apache.org> Committed: Thu Feb 8 17:36:50 2018 +0900 ---------------------------------------------------------------------- .../hivemall/smile/tools/TreePredictUDF.java | 63 ++++++++++++++------ docs/gitbook/binaryclass/news20_rf.md | 5 +- docs/gitbook/binaryclass/titanic_rf.md | 10 ++-- docs/gitbook/multiclass/iris_randomforest.md | 8 ++- 4 files changed, 60 insertions(+), 26 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c742ce58/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java index 46b8758..ea3bc29 100644 --- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java +++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java @@ -18,6 +18,7 @@ */ package hivemall.smile.tools; +import hivemall.UDFWithOptions; import hivemall.math.vector.DenseVector; import hivemall.math.vector.SparseVector; import hivemall.math.vector.Vector; @@ -37,11 +38,12 @@ import java.util.List; import javax.annotation.Nonnull; import javax.annotation.Nullable; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -53,12 +55,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspe import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; -@Description( - name = "tree_predict", - value = "_FUNC_(string modelId, string model, array<double|string> features [, const boolean classification])" - + " - Returns a prediction result of a random forest") +@Description(name = "tree_predict", + value = "_FUNC_(string modelId, string model, array<double|string> features [, const string options | const boolean classification=false])" + + " - Returns a prediction result of a random forest" + + " in <int value, array<double> posteriori> for classification and <double> for regression") @UDFType(deterministic = true, stateful = false) -public final class TreePredictUDF extends GenericUDF { +public final class TreePredictUDF extends UDFWithOptions { private boolean classification; private StringObjectInspector modelOI; @@ -72,9 +74,25 @@ public final class TreePredictUDF extends GenericUDF { private transient Evaluator evaluator; @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("c", "classification", false, + "Predict as classification [default: not enabled]"); + return opts; + } + + @Override + protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException { + CommandLine cl = parseOptions(optionValue); + + this.classification = cl.hasOption("classification"); + return cl; + } + + @Override public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { if (argOIs.length != 3 && argOIs.length != 4) { - throw new UDFArgumentException("_FUNC_ takes 3 or 4 arguments"); + throw new UDFArgumentException("tree_predict takes 3 or 4 arguments"); } this.modelOI = HiveUtils.asStringOI(argOIs[1]); @@ -89,15 +107,25 @@ public final class TreePredictUDF extends GenericUDF { this.denseInput = false; } else { throw new UDFArgumentException( - "_FUNC_ takes array<double> or array<string> for the second argument: " + "tree_predict takes array<double> or array<string> for the second argument: " + listOI.getTypeName()); } - boolean classification = false; if (argOIs.length == 4) { - classification = HiveUtils.getConstBoolean(argOIs[3]); + ObjectInspector argOI3 = argOIs[3]; + if (HiveUtils.isConstBoolean(argOI3)) { + this.classification = HiveUtils.getConstBoolean(argOI3); + } else if (HiveUtils.isConstString(argOI3)) { + String opts = HiveUtils.getConstString(argOI3); + processOptions(opts); + } else { + throw new UDFArgumentException( + "tree_predict expects <const boolean> or <const string> for the fourth argument: " + + argOI3.getTypeName()); + } + } else { + this.classification = false; } - this.classification = classification; if (classification) { List<String> fieldNames = new ArrayList<String>(2); @@ -105,7 +133,8 @@ public final class TreePredictUDF extends GenericUDF { fieldNames.add("value"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); fieldNames.add("posteriori"); - fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } else { return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; @@ -116,7 +145,7 @@ public final class TreePredictUDF extends GenericUDF { public Object evaluate(@Nonnull DeferredObject[] arguments) throws HiveException { Object arg0 = arguments[0].get(); if (arg0 == null) { - throw new HiveException("ModelId was null"); + throw new HiveException("modelId should not be null"); } // Not using string OI for backward compatibilities String modelId = arg0.toString(); @@ -134,8 +163,8 @@ public final class TreePredictUDF extends GenericUDF { this.featuresProbe = parseFeatures(arg2, featuresProbe); if (evaluator == null) { - this.evaluator = classification ? new ClassificationEvaluator() - : new RegressionEvaluator(); + this.evaluator = + classification ? new ClassificationEvaluator() : new RegressionEvaluator(); } return evaluator.evaluate(modelId, model, featuresProbe); } @@ -192,8 +221,8 @@ public final class TreePredictUDF extends GenericUDF { } if (feature.indexOf(':') != -1) { - throw new UDFArgumentException("Invaliad feature format `<index>:<value>`: " - + col); + throw new UDFArgumentException( + "Invaliad feature format `<index>:<value>`: " + col); } final int colIndex = Integer.parseInt(feature); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c742ce58/docs/gitbook/binaryclass/news20_rf.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/binaryclass/news20_rf.md b/docs/gitbook/binaryclass/news20_rf.md index fd0b475..327939b 100644 --- a/docs/gitbook/binaryclass/news20_rf.md +++ b/docs/gitbook/binaryclass/news20_rf.md @@ -47,7 +47,7 @@ from ## Prediction ```sql -SET hivevar:classification=true; +-- SET hivevar:classification=true; drop table rf_predicted; create table rf_predicted @@ -60,7 +60,8 @@ FROM ( SELECT rowid, m.model_weight, - tree_predict(m.model_id, m.model, t.features, ${classification}) as predicted + tree_predict(m.model_id, m.model, t.features, "-classification") as predicted + -- tree_predict(m.model_id, m.model, t.features, ${classification}) as predicted FROM rf_model m LEFT OUTER JOIN -- CROSS JOIN http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c742ce58/docs/gitbook/binaryclass/titanic_rf.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/binaryclass/titanic_rf.md b/docs/gitbook/binaryclass/titanic_rf.md index 29784e0..2b54074 100644 --- a/docs/gitbook/binaryclass/titanic_rf.md +++ b/docs/gitbook/binaryclass/titanic_rf.md @@ -175,7 +175,7 @@ from # Prediction ```sql -SET hivevar:classification=true; +-- SET hivevar:classification=true; set hive.auto.convert.join=true; SET hive.mapjoin.optimized.hashtable=false; SET mapred.reduce.tasks=16; @@ -202,7 +202,8 @@ FROM ( -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, - tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + tree_predict(p.model_id, p.model, t.features, "-classification") as predicted + -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later FROM ( SELECT @@ -319,7 +320,7 @@ from > [116.12055542977338,960.8569891444097,291.08765260103837,469.74671636586226,163.721292772701,120.784769882858,847.9769298113661,554.4617571355476,346.3500941757221,97.42593940113392] > 0.1838351822503962 ```sql -SET hivevar:classification=true; +-- SET hivevar:classification=true; SET hive.mapjoin.optimized.hashtable=false; SET mapred.reduce.tasks=16; @@ -345,7 +346,8 @@ FROM ( -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, - tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + tree_predict(p.model_id, p.model, t.features, "-classification") as predicted + -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later FROM ( SELECT http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c742ce58/docs/gitbook/multiclass/iris_randomforest.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/multiclass/iris_randomforest.md b/docs/gitbook/multiclass/iris_randomforest.md index b421297..bfc197f 100644 --- a/docs/gitbook/multiclass/iris_randomforest.md +++ b/docs/gitbook/multiclass/iris_randomforest.md @@ -206,7 +206,7 @@ from # Prediction ```sql -set hivevar:classification=true; +-- set hivevar:classification=true; set hive.auto.convert.join=true; set hive.mapjoin.optimized.hashtable=false; @@ -225,7 +225,8 @@ FROM ( -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, - tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + tree_predict(p.model_id, p.model, t.features, "-classification") as predicted + -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later FROM model p @@ -265,7 +266,8 @@ FROM ( -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, - tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + tree_predict(p.model_id, p.model, t.features, "-classification") as predicted + -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted as predicted -- to use the old model in v0.5-rc.1 or later FROM ( SELECT