This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new f8a2b06  [HIVEMALL-288] mf_predict throws SemanticException No matching method with (array<double>, array<double>, int)
f8a2b06 is described below

commit f8a2b06de1d2c33d1ef1753b3ec8a42a48e6537d
Author: Makoto Yui <m...@apache.org>
AuthorDate: Thu Dec 12 17:32:27 2019 +0900

    [HIVEMALL-288] mf_predict throws SemanticException No matching method with (array<double>, array<double>, int)
    
    ## What changes were proposed in this pull request?
    
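    `mf_predict` throws `SemanticException` (No matching method with `(array<double>, array<double>, int)`) because the old `UDF` only declared `evaluate` overloads taking `List<FloatWritable>` vectors and `DoubleWritable` bias terms. This patch reimplements `MFPredictionUDF` as a `GenericUDF` that accepts `Pu`/`Qi` arrays with floating-point elements and accepts the bias arguments (`Bu`, `Bi`, `mu`) as any numeric type.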
    
    ## What type of PR is it?
    
    Bug Fix
    
    ## What is the Jira issue?
    
    https://issues.apache.org/jira/browse/HIVEMALL-288
    
    ## How was this patch tested?
    
    Manual tests on Amazon EMR with the following query:
    
    ```sql
    select
      -- 3 arguments
      mf_predict(array(cast(1.0 as float),cast(2.0 as float),cast(3.0 as float)), array(cast(1.0 as float),cast(2.0 as float),cast(3.0 as float)), 1),
      mf_predict(array(1.0,2.0,3.0), array(1.0,2.0,3.0), 1),
      mf_predict(array(cast(1.0 as DOUBLE),cast(2.0 as DOUBLE),cast(3.0 as DOUBLE)), array(cast(1.0 as DOUBLE),cast(2.0 as DOUBLE),cast(3.0 as DOUBLE)), 1),
      -- 2 arguments
      mf_predict(array(1.0,2.0,3.0), array(1.0,2.0,3.0)),
      -- 4 arguments
      mf_predict(array(1.0,2.0,3.0), array(1.0,2.0,3.0), 0, 0),
      -- 5 arguments
      mf_predict(array(1.0,2.0,3.0), array(1.0,2.0,3.0), 0, 0, 1);
    ```
    
    ## Checklist
    
    (Please remove this section if not needed; check `x` for YES, blank for NO)
    
    - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit?
    - [x] Did you run system tests on Hive (or Spark)?
    
    Author: Makoto Yui <m...@apache.org>
    
    Closes #224 from myui/HIVEMALL-288.
---
 .../hivemall/factorization/mf/MFPredictionUDF.java | 204 ++++++++++++---------
 .../main/java/hivemall/utils/hadoop/HiveUtils.java |  20 ++
 2 files changed, 142 insertions(+), 82 deletions(-)

diff --git a/core/src/main/java/hivemall/factorization/mf/MFPredictionUDF.java 
b/core/src/main/java/hivemall/factorization/mf/MFPredictionUDF.java
index c91e0eb..c73e96f 100644
--- a/core/src/main/java/hivemall/factorization/mf/MFPredictionUDF.java
+++ b/core/src/main/java/hivemall/factorization/mf/MFPredictionUDF.java
@@ -18,121 +18,161 @@
  */
 package hivemall.factorization.mf;
 
-import java.util.List;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
 
-import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
+import org.apache.commons.lang.StringUtils;
 import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDF;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
-import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
 
 @Description(name = "mf_predict",
-        value = "_FUNC_(List<Float> Pu, List<Float> Qi[, double Bu, double 
Bi[, double mu]]) - Returns the prediction value")
+        value = "_FUNC_(array<double> Pu, array<double> Qi[, double Bu, double 
Bi[, double mu]]) - Returns the prediction value")
 @UDFType(deterministic = true, stateful = false)
-public final class MFPredictionUDF extends UDF {
+public final class MFPredictionUDF extends GenericUDF {
 
-    @Nonnull
-    public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
-            @Nullable List<FloatWritable> Qi) throws HiveException {
-        return evaluate(Pu, Qi, null);
-    }
+    private ListObjectInspector puOI, qiOI;
+    private PrimitiveObjectInspector puElemOI, qiElemOI;
 
-    @Nonnull
-    public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
-            @Nullable List<FloatWritable> Qi, @Nullable DoubleWritable mu) 
throws HiveException {
-        final double muValue = (mu == null) ? 0.d : mu.get();
-        if (Pu == null || Qi == null) {
-            return new DoubleWritable(muValue);
-        }
+    @Nullable
+    private PrimitiveObjectInspector buOI, biOI, muOI;
 
-        final int PuSize = Pu.size();
-        final int QiSize = Qi.size();
-        // workaround for TD
-        if (PuSize == 0) {
-            return new DoubleWritable(muValue);
-        } else if (QiSize == 0) {
-            return new DoubleWritable(muValue);
+    private DoubleWritable result;
+
+    @Override
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws 
UDFArgumentException {
+        if (argOIs.length < 2 || argOIs.length > 5) {
+            throw new UDFArgumentException("mf_predict takes 2~5 arguments: " 
+ argOIs.length);
         }
 
-        if (QiSize != PuSize) {
-            throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| 
" + QiSize);
+        this.puOI = HiveUtils.asListOI(argOIs, 0);
+        this.puElemOI = 
HiveUtils.asFloatingPointOI(puOI.getListElementObjectInspector());
+        this.qiOI = HiveUtils.asListOI(argOIs, 1);
+        this.qiElemOI = 
HiveUtils.asFloatingPointOI(qiOI.getListElementObjectInspector());
+
+        switch (argOIs.length) {
+            case 3:
+                this.muOI = HiveUtils.asNumberOI(argOIs, 2);
+                break;
+            case 4:
+                this.buOI = HiveUtils.asNumberOI(argOIs, 2);
+                this.biOI = HiveUtils.asNumberOI(argOIs, 3);
+                break;
+            case 5:
+                this.buOI = HiveUtils.asNumberOI(argOIs, 2);
+                this.biOI = HiveUtils.asNumberOI(argOIs, 3);
+                this.muOI = HiveUtils.asNumberOI(argOIs, 4);
+                break;
+            default:
+                break;
         }
 
-        double ret = muValue;
-        for (int k = 0; k < PuSize; k++) {
-            FloatWritable Pu_k = Pu.get(k);
-            if (Pu_k == null) {
-                continue;
+        this.result = new DoubleWritable();
+        return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+    }
+
+    @Override
+    public Object evaluate(DeferredObject[] args) throws HiveException {
+        Preconditions.checkArgument(args.length >= 2 && args.length <= 5, 
args.length);
+
+        @Nullable
+        double[] pu = HiveUtils.asDoubleArray(args[0].get(), puOI, puElemOI);
+        @Nullable
+        double[] qi = HiveUtils.asDoubleArray(args[1].get(), qiOI, qiElemOI);
+
+        double mu = 0.d, bu = 0.d, bi = 0.d;
+        switch (args.length) {
+            case 3: {
+                Object arg2 = args[2].get();
+                if (arg2 != null) {
+                    mu = PrimitiveObjectInspectorUtils.getDouble(arg2, muOI);
+                }
+                break;
+            }
+            case 4: {
+                Object arg2 = args[2].get();
+                if (arg2 != null) {
+                    bu = PrimitiveObjectInspectorUtils.getDouble(arg2, buOI);
+                }
+                Object arg3 = args[3].get();
+                if (arg3 != null) {
+                    bi = PrimitiveObjectInspectorUtils.getDouble(arg3, biOI);
+                }
+                break;
             }
-            FloatWritable Qi_k = Qi.get(k);
-            if (Qi_k == null) {
-                continue;
+            case 5: {
+                Object arg2 = args[2].get();
+                if (arg2 != null) {
+                    bu = PrimitiveObjectInspectorUtils.getDouble(arg2, buOI);
+                }
+                Object arg3 = args[3].get();
+                if (arg3 != null) {
+                    bi = PrimitiveObjectInspectorUtils.getDouble(arg3, biOI);
+                }
+                Object arg4 = args[4].get();
+                if (arg4 != null) {
+                    mu = PrimitiveObjectInspectorUtils.getDouble(arg4, muOI);
+                }
+                break;
             }
-            ret += Pu_k.get() * Qi_k.get();
+            default:
+                break;
         }
-        return new DoubleWritable(ret);
-    }
 
-    @Nonnull
-    public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
-            @Nullable List<FloatWritable> Qi, @Nullable DoubleWritable Bu,
-            @Nullable DoubleWritable Bi) throws HiveException {
-        return evaluate(Pu, Qi, Bu, Bi, null);
+        double predicted = mfPredict(pu, qi, bu, bi, mu);
+        result.set(predicted);
+        return result;
     }
 
-    @Nonnull
-    public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
-            @Nullable List<FloatWritable> Qi, @Nullable DoubleWritable Bu,
-            @Nullable DoubleWritable Bi, @Nullable DoubleWritable mu) throws 
HiveException {
-        final double muValue = (mu == null) ? 0.d : mu.get();
-        if (Pu == null && Qi == null) {
-            return new DoubleWritable(muValue);
-        }
-        final double BiValue = (Bi == null) ? 0.d : Bi.get();
-        final double BuValue = (Bu == null) ? 0.d : Bu.get();
+    private static double mfPredict(@Nullable final double[] Pu, @Nullable 
final double[] Qi,
+            final double Bu, final double Bi, final double mu) throws 
UDFArgumentException {
         if (Pu == null) {
-            double ret = muValue + BiValue;
-            return new DoubleWritable(ret);
+            if (Qi == null) {
+                return mu;
+            } else {
+                return mu + Bi;
+            }
         } else if (Qi == null) {
-            return new DoubleWritable(muValue);
+            return mu + Bu;
         }
-
-        final int PuSize = Pu.size();
-        final int QiSize = Qi.size();
-        // workaround for TD        
-        if (PuSize == 0) {
-            if (QiSize == 0) {
-                return new DoubleWritable(muValue);
+        // workaround for TD
+        if (Pu.length == 0) {
+            if (Qi.length == 0) {
+                return mu;
             } else {
-                double ret = muValue + BiValue;
-                return new DoubleWritable(ret);
+                return mu + Bi;
             }
-        } else if (QiSize == 0) {
-            double ret = muValue + BuValue;
-            return new DoubleWritable(ret);
+        } else if (Qi.length == 0) {
+            return mu + Bu;
         }
 
-        if (QiSize != PuSize) {
-            throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| 
" + QiSize);
+        if (Pu.length != Qi.length) {
+            throw new UDFArgumentException(
+                "|Pu| " + Pu.length + " was not equal to |Qi| " + Qi.length);
         }
 
-        double ret = muValue + BuValue + BiValue;
-        for (int k = 0; k < PuSize; k++) {
-            FloatWritable Pu_k = Pu.get(k);
-            if (Pu_k == null) {
-                continue;
-            }
-            FloatWritable Qi_k = Qi.get(k);
-            if (Qi_k == null) {
-                continue;
-            }
-            ret += Pu_k.get() * Qi_k.get();
+        double ret = mu + Bu + Bi;
+        for (int k = 0, size = Pu.length; k < size; k++) {
+            double pu_k = Pu[k];
+            double qi_k = Qi[k];
+            ret += pu_k * qi_k;
         }
-        return new DoubleWritable(ret);
+        return ret;
+    }
+
+    @Override
+    public String getDisplayString(String[] args) {
+        return "mf_predict(" + StringUtils.join(args, ',') + ')';
     }
 
 }
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java 
b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index 38b37a4..293d236 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -1302,6 +1302,26 @@ public final class HiveUtils {
     }
 
     @Nonnull
+    public static PrimitiveObjectInspector asNumberOI(@Nonnull final 
ObjectInspector[] argOIs,
+            final int argIndex) throws UDFArgumentException {
+        final PrimitiveObjectInspector oi = asPrimitiveObjectInspector(argOIs, 
argIndex);
+        switch (oi.getPrimitiveCategory()) {
+            case BYTE:
+            case SHORT:
+            case INT:
+            case LONG:
+            case FLOAT:
+            case DOUBLE:
+            case DECIMAL:
+                break;
+            default:
+                throw new UDFArgumentTypeException(argIndex,
+                    "Only numeric argument is accepted but " + 
oi.getTypeName() + " is passed.");
+        }
+        return oi;
+    }
+
+    @Nonnull
     public static PrimitiveObjectInspector asNumberOI(@Nonnull final 
ObjectInspector argOI)
             throws UDFArgumentTypeException {
         if (argOI.getCategory() != Category.PRIMITIVE) {

Reply via email to