This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
     new 7f23a86  [HIVEMALL-121] Add -libsvm formatting option to feature_hashing UDF
7f23a86 is described below

commit 7f23a86d66e0030fb3eb92dddf0cd3b6a1a46bac
Author: Makoto Yui <m...@apache.org>
AuthorDate: Mon Nov 25 19:03:15 2019 +0900

[HIVEMALL-121] Add -libsvm formatting option to feature_hashing UDF

## What changes were proposed in this pull request?

Add a `-libsvm` formatting option for `feature_hashing`.

## What type of PR is it?

Improvement

## What is the Jira issue?

https://issues.apache.org/jira/browse/HIVEMALL-121

## How was this patch tested?

Unit tests and manual tests on EMR.

## How to use this feature?

```sql
select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-libsvm');
> ["4063537:1.0","4063537:1","8459207:2.0"]

select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-features 10 -libsvm');
> ["1:2.0","7:1.0","7:1"]
```

## Checklist

- [x] Did you apply the source code formatter, i.e., `./bin/format_code.sh`, to your commit?
- [x] Did you run system tests on Hive (or Spark)?

Author: Makoto Yui <m...@apache.org>

Closes #216 from myui/HIVEMALL-121.
---
 .../hivemall/ftvec/hashing/FeatureHashingUDF.java | 107 ++++++++++++++++-----
 .../ftvec/hashing/FeatureHashingUDFTest.java      |  65 +++++++++++++
 docs/gitbook/ft_engineering/hashing.md            |  23 ++++-
 docs/gitbook/misc/funcs.md                        |  13 ++-
 4 files changed, 180 insertions(+), 28 deletions(-)

diff --git a/core/src/main/java/hivemall/ftvec/hashing/FeatureHashingUDF.java b/core/src/main/java/hivemall/ftvec/hashing/FeatureHashingUDF.java
index b2d5dac..0bc1c97 100644
--- a/core/src/main/java/hivemall/ftvec/hashing/FeatureHashingUDF.java
+++ b/core/src/main/java/hivemall/ftvec/hashing/FeatureHashingUDF.java
@@ -20,12 +20,16 @@ package hivemall.ftvec.hashing;
 
 import hivemall.HivemallConstants;
 import hivemall.UDFWithOptions;
+import hivemall.annotations.VisibleForTesting;
 import hivemall.utils.hadoop.HiveUtils;
 import hivemall.utils.hashing.MurmurHash3;
 import hivemall.utils.lang.Primitives;
+import hivemall.utils.lang.StringUtils;
 
+import java.io.Serializable;
 import java.util.ArrayList;
-import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.List;
 
 import javax.annotation.Nonnull;
@@ -35,38 +39,47 @@ import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.Options;
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.udf.UDFType;
 import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
-import org.apache.hadoop.io.Text;
 
+//@formatter:off
 @Description(name = "feature_hashing",
         value = "_FUNC_(array<string> features [, const string options])"
-                + " - returns a hashed feature vector in array<string>")
+                + " - returns a hashed feature vector in array<string>",
+        extended = "select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-libsvm');\n"
+                + "> [\"4063537:1.0\",\"4063537:1\",\"8459207:2.0\"]\n"
+                + "\n"
+                + "select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-features 10');\n"
+                + "> [\"7:1.0\",\"7\",\"1:2.0\"]\n"
+                + "\n"
+                + "select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-features 10 -libsvm');\n"
+                + "> [\"1:2.0\",\"7:1.0\",\"7:1\"]\n"
+                + "")
+//@formatter:on
 @UDFType(deterministic = true, stateful = false)
 public final class FeatureHashingUDF extends UDFWithOptions {
+    private static final IndexComparator indexCmp = new IndexComparator();
+
     @Nullable
     private ListObjectInspector _listOI;
+    private boolean _libsvmFormat = false;
     private int _numFeatures = MurmurHash3.DEFAULT_NUM_FEATURES;
 
     @Nullable
-    private List<Text> _returnObj;
+    private transient List<String> _returnObj;
 
     public FeatureHashingUDF() {}
 
     @Override
-    public String getDisplayString(String[] children) {
-        return "feature_hashing(" + Arrays.toString(children) + ')';
-    }
-
-    @Override
     protected Options getOptions() {
         Options opts = new Options();
+        opts.addOption("libsvm", false,
+                "Returns in libsvm format (<index>:<value>)* sorted by index ascending order");
         opts.addOption("features", "num_features", true,
                 "The number of features [default: 16777217 (2^24)]");
         return opts;
@@ -76,6 +89,7 @@
     protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
         CommandLine cl = parseOptions(optionValue);
 
+        this._libsvmFormat = cl.hasOption("libsvm");
         this._numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"), _numFeatures);
         return cl;
     }
@@ -84,8 +98,7 @@
     public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
             throws UDFArgumentException {
         if (argOIs.length != 1 && argOIs.length != 2) {
-            throw new UDFArgumentLengthException(
-                    "The feature_hashing function takes 1 or 2 arguments: " + argOIs.length);
+            showHelp("The feature_hashing function takes 1 or 2 arguments: " + argOIs.length);
         }
         ObjectInspector argOI0 = argOIs[0];
         this._listOI = HiveUtils.isListOI(argOI0) ? (ListObjectInspector) argOI0 : null;
@@ -96,10 +109,10 @@
         }
 
         if (_listOI == null) {
-            return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
+            return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
         } else {
             return ObjectInspectorFactory.getStandardListObjectInspector(
-                    PrimitiveObjectInspectorFactory.writableStringObjectInspector);
+                    PrimitiveObjectInspectorFactory.javaStringObjectInspector);
         }
     }
@@ -118,17 +131,17 @@
     }
 
     @Nonnull
-    private Text evaluateScalar(@Nonnull final Object arg0) {
+    private String evaluateScalar(@Nonnull final Object arg0) {
         String fv = arg0.toString();
-        return new Text(featureHashing(fv, _numFeatures));
+        return featureHashing(fv, _numFeatures, _libsvmFormat);
     }
 
     @Nonnull
-    private List<Text> evaluateList(@Nonnull final Object arg0) {
+    private List<String> evaluateList(@Nonnull final Object arg0) throws HiveException {
         final int len = _listOI.getListLength(arg0);
-        List<Text> list = _returnObj;
+        List<String> list = _returnObj;
         if (list == null) {
-            list = new ArrayList<Text>(len);
+            list = new ArrayList<String>(len);
             this._returnObj = list;
         } else {
             list.clear();
@@ -140,23 +153,40 @@
             if (obj == null) {
                 continue;
             }
-            String fv = obj.toString();
-            Text t = new Text(featureHashing(fv, numFeatures));
-            list.add(t);
+            String fv = featureHashing(obj.toString(), numFeatures, _libsvmFormat);
+            list.add(fv);
         }
+        if (_libsvmFormat) {
+            try {
+                Collections.sort(list, indexCmp);
+            } catch (NumberFormatException e) {
+                throw new HiveException(e);
+            }
+        }
         return list;
     }
 
+    @VisibleForTesting
     @Nonnull
     static String featureHashing(@Nonnull final String fv, final int numFeatures) {
+        return featureHashing(fv, numFeatures, false);
+    }
+
+    @Nonnull
+    static String featureHashing(@Nonnull final String fv, final int numFeatures,
+            final boolean libsvmFormat) {
         final int headPos = fv.indexOf(':');
         if (headPos == -1) {
             if (fv.equals(HivemallConstants.BIAS_CLAUSE)) {
                 return fv;
             }
-            int h = mhash(fv, numFeatures);
-            return String.valueOf(h);
+            final int h = mhash(fv, numFeatures);
+            if (libsvmFormat) {
+                return h + ":1";
+            } else {
+                return String.valueOf(h);
+            }
         } else {
             final int tailPos = fv.lastIndexOf(':');
             if (headPos == tailPos) {
@@ -189,4 +219,33 @@
         return r + 1;
     }
 
+    @Override
+    public String getDisplayString(String[] children) {
+        return "feature_hashing(" + StringUtils.join(children, ',') + ')';
+    }
+
+    private static final class IndexComparator implements Comparator<String>, Serializable {
+        private static final long serialVersionUID = -260142385860586255L;
+
+        @Override
+        public int compare(@Nonnull final String lhs, @Nonnull final String rhs) {
+            int l = getIndex(lhs);
+            int r = getIndex(rhs);
+            return Integer.compare(l, r);
+        }
+
+        private static int getIndex(@Nonnull final String fv) {
+            final int headPos = fv.indexOf(':');
+            final int tailPos = fv.lastIndexOf(':');
+            final String f;
+            if (headPos == tailPos) {
+                f = fv.substring(0, headPos);
+            } else {
+                f = fv.substring(headPos + 1, tailPos);
+            }
+            return Integer.parseInt(f);
+        }
+
+    }
+
 }
diff --git a/core/src/test/java/hivemall/ftvec/hashing/FeatureHashingUDFTest.java b/core/src/test/java/hivemall/ftvec/hashing/FeatureHashingUDFTest.java
index 06277c2..d73e9fc 100644
--- a/core/src/test/java/hivemall/ftvec/hashing/FeatureHashingUDFTest.java
+++ b/core/src/test/java/hivemall/ftvec/hashing/FeatureHashingUDFTest.java
@@ -19,12 +19,18 @@
 package hivemall.ftvec.hashing;
 
 import hivemall.TestUtils;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.hadoop.WritableUtils;
 import hivemall.utils.hashing.MurmurHash3;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
 
 import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
@@ -52,6 +58,65 @@ public class FeatureHashingUDFTest {
     }
 
     @Test
+    public void testBiasLibsvm() {
+        String expected = "0:1.0";
+        String actual =
+                FeatureHashingUDF.featureHashing(expected, MurmurHash3.DEFAULT_NUM_FEATURES, true);
+        Assert.assertEquals(expected, actual);
+
+        expected = "0:1";
+        actual = FeatureHashingUDF.featureHashing(expected, MurmurHash3.DEFAULT_NUM_FEATURES, true);
+        Assert.assertEquals(expected, actual);
+
+        expected = "0:1.1";
+        actual = FeatureHashingUDF.featureHashing(expected, MurmurHash3.DEFAULT_NUM_FEATURES, true);
+        Assert.assertEquals(FeatureHashingUDF.mhash("0", MurmurHash3.DEFAULT_NUM_FEATURES) + ":1.1",
+                actual);
+    }
+
+    @Test
+    public void testEvaluateList() throws HiveException, IOException {
+        FeatureHashingUDF udf = new FeatureHashingUDF();
+
+        udf.initialize(new ObjectInspector[] {ObjectInspectorFactory.getStandardListObjectInspector(
+                PrimitiveObjectInspectorFactory.writableStringObjectInspector)});
+
+        DeferredObject[] args = new DeferredObject[] {new GenericUDF.DeferredJavaObject(
+                WritableUtils.val("apple:3", "orange:2", "banana", "0:1"))};
+
+        List<String> expected = Arrays.asList(
+                FeatureHashingUDF.mhash("apple", MurmurHash3.DEFAULT_NUM_FEATURES) + ":3",
+                FeatureHashingUDF.mhash("orange", MurmurHash3.DEFAULT_NUM_FEATURES) + ":2",
+                Integer.toString(FeatureHashingUDF.mhash("banana", MurmurHash3.DEFAULT_NUM_FEATURES)),
+                "0:1");
+        Assert.assertEquals(expected, udf.evaluate(args));
+
+        udf.close();
+    }
+
+    @Test
+    public void testEvaluateListLibsvm() throws HiveException, IOException {
+        FeatureHashingUDF udf = new FeatureHashingUDF();
+
+        udf.initialize(new ObjectInspector[] {
+                ObjectInspectorFactory.getStandardListObjectInspector(
+                    PrimitiveObjectInspectorFactory.writableStringObjectInspector),
+                HiveUtils.getConstStringObjectInspector("-libsvm")});
+
+        DeferredObject[] args = new DeferredObject[] {new GenericUDF.DeferredJavaObject(
+                WritableUtils.val("apple:3", "orange:2", "banana", "0:1"))};
+
+        List<String> expected = Arrays.asList(
+                FeatureHashingUDF.mhash("apple", MurmurHash3.DEFAULT_NUM_FEATURES) + ":3",
+                FeatureHashingUDF.mhash("orange", MurmurHash3.DEFAULT_NUM_FEATURES) + ":2",
+                FeatureHashingUDF.mhash("banana", MurmurHash3.DEFAULT_NUM_FEATURES) + ":1", "0:1");
+        Collections.sort(expected);
+        Assert.assertEquals(expected, udf.evaluate(args));
+
+        udf.close();
+    }
+
+    @Test
     public void testSerialization() throws HiveException, IOException {
         TestUtils.testGenericUDFSerialization(FeatureHashingUDF.class, new ObjectInspector[] {
diff --git a/docs/gitbook/ft_engineering/hashing.md b/docs/gitbook/ft_engineering/hashing.md
index 8a08b8c..ff3de6e 100644
--- a/docs/gitbook/ft_engineering/hashing.md
+++ b/docs/gitbook/ft_engineering/hashing.md
@@ -52,6 +52,21 @@ select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'));
 > ["4063537:1.0","4063537","8459207:2.0"]
 
 ```sql
+select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-libsvm');
+```
+> ["4063537:1.0","4063537:1","8459207:2.0"]
+
+```sql
+select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-features 10');
+```
+> ["7:1.0","7","1:2.0"]
+
+```sql
+select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-features 10 -libsvm');
+```
+> ["1:2.0","7:1.0","7:1"]
+
+```sql
 select feature_hashing(array(1,2,3));
 ```
 > ["11293631","3322224","4331412"]
@@ -78,14 +93,16 @@ select feature_hashing(array("userid#4505:3.3","movieid#2331:4.999", "movieid#23
 > ["1828616:3.3","6238429:4.999","6238429"]
 
 ```sql
-select feature_hashing(null,'-help');
+select feature_hashing();
 
 usage: feature_hashing(array<string> features [, const string options]) -
        returns a hashed feature vector in array<string> [-features <arg>]
-       [-help]
+       [-libsvm]
 -features,--num_features <arg>   The number of features [default: 16777217
                                  (2^24)]
- -help                            Show function help
+ -libsvm                          Returns in libsvm format
+                                  (<index>:<value>)* sorted by index
+                                  ascending order
 ```
 
 > #### Note
diff --git a/docs/gitbook/misc/funcs.md b/docs/gitbook/misc/funcs.md
index bda8e22..3e5f92b 100644
--- a/docs/gitbook/misc/funcs.md
+++ b/docs/gitbook/misc/funcs.md
@@ -325,6 +325,17 @@ Reference: <a href="https://papers.nips.cc/paper/3848-adaptive-regularization-of
 - `array_hash_values(array<string> values, [string prefix [, int numFeatures], boolean useIndexAsPrefix])` returns hash values in array<int>
 
 - `feature_hashing(array<string> features [, const string options])` - returns a hashed feature vector in array<string>
+  ```sql
+  select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-libsvm');
+  > ["4063537:1.0","4063537:1","8459207:2.0"]
+
+  select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-features 10');
+  > ["7:1.0","7","1:2.0"]
+
+  select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-features 10 -libsvm');
+  > ["1:2.0","7:1.0","7:1"]
+
+  ```
 
 - `mhash(string word)` returns a murmurhash3 INT value starting from 1
 
@@ -636,7 +647,7 @@ Reference: <a href="https://papers.nips.cc/paper/3848-adaptive-regularization-of
 
 # XGBoost
 
-- `train_xgboost(array<string|double> features, int|double target [, string options])` - Returns a relation consists of <string model_id, array<string> pred_model>
+- `train_xgboost(array<string|double> features, <int|double> target, const string options)` - Returns a relation consists of <string model_id, array<string> pred_model>
   ```sql
   SELECT train_xgboost(features, label, '-objective binary:logistic -iters 10')