This is an automated email from the ASF dual-hosted git repository. myui pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push: new f8a2b06 [HIVEMALL-288] mf_predict throws SemanticException No matching method with (array<double>, array<double>, int) f8a2b06 is described below commit f8a2b06de1d2c33d1ef1753b3ec8a42a48e6537d Author: Makoto Yui <m...@apache.org> AuthorDate: Thu Dec 12 17:32:27 2019 +0900 [HIVEMALL-288] mf_predict throws SemanticException No matching method with (array<double>, array<double>, int) ## What changes were proposed in this pull request? `mf_predict` throws SemanticException No matching method with (array<double>, array<double>, int) ## What type of PR is it? Bug Fix ## What is the Jira issue? https://issues.apache.org/jira/browse/HIVEMALL-288 ## How was this patch tested? manual tests on EMR ```sql select -- 3 arguments mf_predict(array(cast(1.0 as float),cast(2.0 as float),cast(3.0 as float)), array(cast(1.0 as float),cast(2.0 as float),cast(3.0 as float)), 1), mf_predict(array(1.0,2.0,3.0), array(1.0,2.0,3.0), 1), mf_predict(array(cast(1.0 as DOUBLE),cast(2.0 as DOUBLE),cast(3.0 as DOUBLE)), array(cast(1.0 as DOUBLE),cast(2.0 as DOUBLE),cast(3.0 as DOUBLE)), 1), -- 2 arguments mf_predict(array(1.0,2.0,3.0), array(1.0,2.0,3.0)), -- 4 arguments mf_predict(array(1.0,2.0,3.0), array(1.0,2.0,3.0), 0, 0), -- 5 arguments mf_predict(array(1.0,2.0,3.0), array(1.0,2.0,3.0), 0, 0, 1); ``` ## Checklist (Please remove this section if not needed; check `x` for YES, blank for NO) - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit? - [x] Did you run system tests on Hive (or Spark)? Author: Makoto Yui <m...@apache.org> Closes #224 from myui/HIVEMALL-288. --- .../hivemall/factorization/mf/MFPredictionUDF.java | 204 ++++++++++++--------- .../main/java/hivemall/utils/hadoop/HiveUtils.java | 20 ++ 2 files changed, 142 insertions(+), 82 deletions(-) diff --git a/core/src/main/java/hivemall/factorization/mf/MFPredictionUDF.java b/core/src/main/java/hivemall/factorization/mf/MFPredictionUDF.java index c91e0eb..c73e96f 100644 --- a/core/src/main/java/hivemall/factorization/mf/MFPredictionUDF.java +++ b/core/src/main/java/hivemall/factorization/mf/MFPredictionUDF.java @@ -18,121 +18,161 @@ */ package hivemall.factorization.mf; -import java.util.List; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Preconditions; -import javax.annotation.Nonnull; import javax.annotation.Nullable; +import org.apache.commons.lang.StringUtils; import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDF; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; @Description(name = "mf_predict", - value = "_FUNC_(List<Float> Pu, List<Float> Qi[, double Bu, double Bi[, double mu]]) - Returns the prediction value") + value = "_FUNC_(array<double> Pu, array<double> Qi[, double Bu, double Bi[, double mu]]) - Returns the prediction value") @UDFType(deterministic = true, stateful = false) -public final class MFPredictionUDF extends UDF { +public final class MFPredictionUDF extends GenericUDF { - @Nonnull - public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu, - @Nullable List<FloatWritable> Qi) throws HiveException { - return evaluate(Pu, Qi, null); - } + private ListObjectInspector puOI, qiOI; + private PrimitiveObjectInspector puElemOI, qiElemOI; - @Nonnull - public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu, - @Nullable List<FloatWritable> Qi, @Nullable DoubleWritable mu) throws HiveException { - final double muValue = (mu == null) ? 0.d : mu.get(); - if (Pu == null || Qi == null) { - return new DoubleWritable(muValue); - } + @Nullable + private PrimitiveObjectInspector buOI, biOI, muOI; - final int PuSize = Pu.size(); - final int QiSize = Qi.size(); - // workaround for TD - if (PuSize == 0) { - return new DoubleWritable(muValue); - } else if (QiSize == 0) { - return new DoubleWritable(muValue); + private DoubleWritable result; + + @Override + public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length < 2 || argOIs.length > 5) { + throw new UDFArgumentException("mf_predict takes 2~5 arguments: " + argOIs.length); } - if (QiSize != PuSize) { - throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize); + this.puOI = HiveUtils.asListOI(argOIs, 0); + this.puElemOI = HiveUtils.asFloatingPointOI(puOI.getListElementObjectInspector()); + this.qiOI = HiveUtils.asListOI(argOIs, 1); + this.qiElemOI = HiveUtils.asFloatingPointOI(qiOI.getListElementObjectInspector()); + + switch (argOIs.length) { + case 3: + this.muOI = HiveUtils.asNumberOI(argOIs, 2); + break; + case 4: + this.buOI = HiveUtils.asNumberOI(argOIs, 2); + this.biOI = HiveUtils.asNumberOI(argOIs, 3); + break; + case 5: + this.buOI = HiveUtils.asNumberOI(argOIs, 2); + this.biOI = HiveUtils.asNumberOI(argOIs, 3); + this.muOI = HiveUtils.asNumberOI(argOIs, 4); + break; + default: + break; } - double ret = muValue; - for (int k = 0; k < PuSize; k++) { - FloatWritable Pu_k = Pu.get(k); - if (Pu_k == null) { - continue; + this.result = new DoubleWritable(); + return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + } + + @Override + public Object evaluate(DeferredObject[] args) throws HiveException { + Preconditions.checkArgument(args.length >= 2 && args.length <= 5, args.length); + + @Nullable + double[] pu = HiveUtils.asDoubleArray(args[0].get(), puOI, puElemOI); + @Nullable + double[] qi = HiveUtils.asDoubleArray(args[1].get(), qiOI, qiElemOI); + + double mu = 0.d, bu = 0.d, bi = 0.d; + switch (args.length) { + case 3: { + Object arg2 = args[2].get(); + if (arg2 != null) { + mu = PrimitiveObjectInspectorUtils.getDouble(arg2, muOI); + } + break; + } + case 4: { + Object arg2 = args[2].get(); + if (arg2 != null) { + bu = PrimitiveObjectInspectorUtils.getDouble(arg2, buOI); + } + Object arg3 = args[3].get(); + if (arg3 != null) { + bi = PrimitiveObjectInspectorUtils.getDouble(arg3, biOI); + } + break; } - FloatWritable Qi_k = Qi.get(k); - if (Qi_k == null) { - continue; + case 5: { + Object arg2 = args[2].get(); + if (arg2 != null) { + bu = PrimitiveObjectInspectorUtils.getDouble(arg2, buOI); + } + Object arg3 = args[3].get(); + if (arg3 != null) { + bi = PrimitiveObjectInspectorUtils.getDouble(arg3, biOI); + } + Object arg4 = args[4].get(); + if (arg4 != null) { + mu = PrimitiveObjectInspectorUtils.getDouble(arg4, muOI); + } + break; } - ret += Pu_k.get() * Qi_k.get(); + default: + break; } - return new DoubleWritable(ret); - } - @Nonnull - public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu, - @Nullable List<FloatWritable> Qi, @Nullable DoubleWritable Bu, - @Nullable DoubleWritable Bi) throws HiveException { - return evaluate(Pu, Qi, Bu, Bi, null); + double predicted = mfPredict(pu, qi, bu, bi, mu); + result.set(predicted); + return result; } - @Nonnull - public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu, - @Nullable List<FloatWritable> Qi, @Nullable DoubleWritable Bu, - @Nullable DoubleWritable Bi, @Nullable DoubleWritable mu) throws HiveException { - final double muValue = (mu == null) ? 0.d : mu.get(); - if (Pu == null && Qi == null) { - return new DoubleWritable(muValue); - } - final double BiValue = (Bi == null) ? 0.d : Bi.get(); - final double BuValue = (Bu == null) ? 0.d : Bu.get(); + private static double mfPredict(@Nullable final double[] Pu, @Nullable final double[] Qi, + final double Bu, final double Bi, final double mu) throws UDFArgumentException { if (Pu == null) { - double ret = muValue + BiValue; - return new DoubleWritable(ret); + if (Qi == null) { + return mu; + } else { + return mu + Bi; + } } else if (Qi == null) { - return new DoubleWritable(muValue); + return mu + Bu; } - - final int PuSize = Pu.size(); - final int QiSize = Qi.size(); - // workaround for TD - if (PuSize == 0) { - if (QiSize == 0) { - return new DoubleWritable(muValue); + // workaround for TD + if (Pu.length == 0) { + if (Qi.length == 0) { + return mu; } else { - double ret = muValue + BiValue; - return new DoubleWritable(ret); + return mu + Bi; } - } else if (QiSize == 0) { - double ret = muValue + BuValue; - return new DoubleWritable(ret); + } else if (Qi.length == 0) { + return mu + Bu; } - if (QiSize != PuSize) { - throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize); + if (Pu.length != Qi.length) { + throw new UDFArgumentException( + "|Pu| " + Pu.length + " was not equal to |Qi| " + Qi.length); } - double ret = muValue + BuValue + BiValue; - for (int k = 0; k < PuSize; k++) { - FloatWritable Pu_k = Pu.get(k); - if (Pu_k == null) { - continue; - } - FloatWritable Qi_k = Qi.get(k); - if (Qi_k == null) { - continue; - } - ret += Pu_k.get() * Qi_k.get(); + double ret = mu + Bu + Bi; + for (int k = 0, size = Pu.length; k < size; k++) { + double pu_k = Pu[k]; + double qi_k = Qi[k]; + ret += pu_k * qi_k; } - return new DoubleWritable(ret); + return ret; + } + + @Override + public String getDisplayString(String[] args) { + return "mf_predict(" + StringUtils.join(args, ',') + ')'; } } diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index 38b37a4..293d236 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -1302,6 +1302,26 @@ public final class HiveUtils { } @Nonnull + public static PrimitiveObjectInspector asNumberOI(@Nonnull final ObjectInspector[] argOIs, + final int argIndex) throws UDFArgumentException { + final PrimitiveObjectInspector oi = asPrimitiveObjectInspector(argOIs, argIndex); + switch (oi.getPrimitiveCategory()) { + case BYTE: + case SHORT: + case INT: + case LONG: + case FLOAT: + case DOUBLE: + case DECIMAL: + break; + default: + throw new UDFArgumentTypeException(argIndex, + "Only numeric argument is accepted but " + oi.getTypeName() + " is passed."); + } + return oi; + } + + @Nonnull public static PrimitiveObjectInspector asNumberOI(@Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException { if (argOI.getCategory() != Category.PRIMITIVE) {