Solodye commented on a change in pull request #193: [WIP][HIVEMALL-253-2] map_roulette UDF
URL: https://github.com/apache/incubator-hivemall/pull/193#discussion_r291990428
 
 

 ##########
 File path: core/src/main/java/hivemall/tools/map/MapRouletteUDF.java
 ##########
 @@ -49,164 +39,171 @@
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
 import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
 
+import com.clearspring.analytics.util.Preconditions;
+
 /**
  * The map_roulette returns a map key based on weighted random sampling of map 
values.
  */
-@Description(name = "map_roulette", value = "_FUNC_(Map<K, number> map)"
-        + " - Returns a map key based on weighted random sampling of map 
values")
+@Description(name = "map_roulette",
+        value = "_FUNC_(Map<K, number> map [, (const) int/bigint seed])"
+                + " - Returns a map key based on weighted random sampling of 
map values."
+                + " Average of values is used for null values")
 @UDFType(deterministic = false, stateful = false) // it is false because it 
return value base on probability
 public final class MapRouletteUDF extends GenericUDF {
 
     private transient MapObjectInspector mapOI;
     private transient PrimitiveObjectInspector valueOI;
+    @Nullable
+    private transient PrimitiveObjectInspector seedOI;
+
+    @Nullable
+    private transient Random _rand;
 
     @Override
-    public ObjectInspector initialize(ObjectInspector[] arguments) throws 
UDFArgumentException {
-        if (arguments.length != 1) {
+    public ObjectInspector initialize(ObjectInspector[] argOIs) throws 
UDFArgumentException {
+        if (argOIs.length != 1 && argOIs.length != 2) {
             throw new UDFArgumentLengthException(
-                "Expected exactly one argument for map_roulette: " + 
arguments.length);
+                "Expected exactly one argument for map_roulette: " + 
argOIs.length);
         }
-        if (arguments[0].getCategory() != ObjectInspector.Category.MAP) {
+        if (argOIs[0].getCategory() != ObjectInspector.Category.MAP) {
             throw new UDFArgumentTypeException(0,
-                "Only map type argument is accepted but got " + 
arguments[0].getTypeName());
+                "Only map type argument is accepted but got " + 
argOIs[0].getTypeName());
         }
-        mapOI = HiveUtils.asMapOI(arguments[0]);
-        ObjectInspector keyOI = mapOI.getMapKeyObjectInspector();
-
-        //judge valueOI is a number
-        valueOI = (PrimitiveObjectInspector) 
mapOI.getMapValueObjectInspector();
-        switch (valueOI.getTypeName()) {
-            case INT_TYPE_NAME:
-            case DOUBLE_TYPE_NAME:
-            case BIGINT_TYPE_NAME:
-            case FLOAT_TYPE_NAME:
-            case SMALLINT_TYPE_NAME:
-            case TINYINT_TYPE_NAME:
-            case DECIMAL_TYPE_NAME:
-            case STRING_TYPE_NAME:
-                // Pass an empty map or a map full of {null, null} will get 
string type
-                // An number in string format like "3.5" also support
-                break;
-            default:
+
+        this.mapOI = HiveUtils.asMapOI(argOIs[0]);
+        this.valueOI = 
HiveUtils.asNumberOI(mapOI.getMapValueObjectInspector());
+
+        if (argOIs.length == 2) {
+            ObjectInspector argOI1 = argOIs[1];
+            if (HiveUtils.isIntegerOI(argOI1) == false) {
                 throw new UDFArgumentException(
-                    "Expected a number but get: " + valueOI.getTypeName());
+                    "The second argument of map_roulette must be integer type: 
"
+                            + argOI1.getTypeName());
+            }
+            if (ObjectInspectorUtils.isConstantObjectInspector(argOI1)) {
+                long seed = HiveUtils.getAsConstLong(argOI1);
+                this._rand = new Random(seed); // fixed seed
+            } else {
+                this.seedOI = HiveUtils.asLongCompatibleOI(argOI1);
+            }
+        } else {
+            this._rand = new Random(); // random seed
         }
-        return keyOI;
+
+        return mapOI.getMapKeyObjectInspector();
     }
 
     @Override
     public Object evaluate(DeferredObject[] arguments) throws HiveException {
-        Map<Object, Double> input = processObjectDoubleMap(arguments[0]);
+        Random rand = _rand;
+        if (rand == null) {
+            Object arg1 = arguments[1].get();
+            if (arg1 == null) {
+                rand = new Random();
+            } else {
+                long seed = HiveUtils.getLong(arg1, seedOI);
+                rand = new Random(seed);
+            }
+        }
+
+        Map<Object, Double> input = getObjectDoubleMap(arguments[0], mapOI, 
valueOI);
         if (input == null) {
             return null;
         }
-        return algorithm(input);
-    }
 
-    @Override
-    public String getDisplayString(String[] children) {
-        return "map_roulette(" + StringUtils.join(children, ',') + ")";
+        return rouletteWheelSelection(input, rand);
     }
 
-    /**
-     * Process the data passed by user.
-     * 
-     * @param argument data passed by user
-     * @return If all the value is ,
-     * @throws HiveException If get the wrong weight value like {key = "Wang", 
value = "Zhang"},
-     *         "Zhang" isn't a number ,this Method will throw exception when
-     *         convertPrimitiveToDouble("Zhang", valueOD)
-     */
-    private Map<Object, Double> processObjectDoubleMap(DeferredObject argument)
-            throws HiveException {
-        // get
-        Map<?, ?> m = mapOI.getMap(argument.get());
+    @Nullable
+    private static Map<Object, Double> getObjectDoubleMap(@Nonnull final 
DeferredObject argument,
+            @Nonnull final MapObjectInspector mapOI,
+            @Nonnull final PrimitiveObjectInspector valueOI) throws 
HiveException {
+        final Map<?, ?> m = mapOI.getMap(argument.get());
         if (m == null) {
             return null;
         }
-        if (m.size() == 0) {
+        final int size = m.size();
+        if (size == 0) {
             return null;
         }
-        // convert
-        Map<Object, Double> input = new HashMap<>();
-        Double avg = 0.0;
+
+        final Map<Object, Double> result = new HashMap<>(size);
 
 Review comment:
   `Map<Object, Double> result` is updated later; why declare it `final` here?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to