Solodye commented on a change in pull request #193: [WIP][HIVEMALL-253-2] map_roulette UDF
URL: https://github.com/apache/incubator-hivemall/pull/193#discussion_r291990428
########## File path: core/src/main/java/hivemall/tools/map/MapRouletteUDF.java ########## @@ -49,164 +39,171 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import com.clearspring.analytics.util.Preconditions; + /** * The map_roulette returns a map key based on weighted random sampling of map values. */ -@Description(name = "map_roulette", value = "_FUNC_(Map<K, number> map)" - + " - Returns a map key based on weighted random sampling of map values") +@Description(name = "map_roulette", + value = "_FUNC_(Map<K, number> map [, (const) int/bigint seed])" + + " - Returns a map key based on weighted random sampling of map values." 
+ + " Average of values is used for null values") @UDFType(deterministic = false, stateful = false) // it is false because it return value base on probability public final class MapRouletteUDF extends GenericUDF { private transient MapObjectInspector mapOI; private transient PrimitiveObjectInspector valueOI; + @Nullable + private transient PrimitiveObjectInspector seedOI; + + @Nullable + private transient Random _rand; @Override - public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { - if (arguments.length != 1) { + public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length != 1 && argOIs.length != 2) { throw new UDFArgumentLengthException( - "Expected exactly one argument for map_roulette: " + arguments.length); + "Expected exactly one argument for map_roulette: " + argOIs.length); } - if (arguments[0].getCategory() != ObjectInspector.Category.MAP) { + if (argOIs[0].getCategory() != ObjectInspector.Category.MAP) { throw new UDFArgumentTypeException(0, - "Only map type argument is accepted but got " + arguments[0].getTypeName()); + "Only map type argument is accepted but got " + argOIs[0].getTypeName()); } - mapOI = HiveUtils.asMapOI(arguments[0]); - ObjectInspector keyOI = mapOI.getMapKeyObjectInspector(); - - //judge valueOI is a number - valueOI = (PrimitiveObjectInspector) mapOI.getMapValueObjectInspector(); - switch (valueOI.getTypeName()) { - case INT_TYPE_NAME: - case DOUBLE_TYPE_NAME: - case BIGINT_TYPE_NAME: - case FLOAT_TYPE_NAME: - case SMALLINT_TYPE_NAME: - case TINYINT_TYPE_NAME: - case DECIMAL_TYPE_NAME: - case STRING_TYPE_NAME: - // Pass an empty map or a map full of {null, null} will get string type - // An number in string format like "3.5" also support - break; - default: + + this.mapOI = HiveUtils.asMapOI(argOIs[0]); + this.valueOI = HiveUtils.asNumberOI(mapOI.getMapValueObjectInspector()); + + if (argOIs.length == 2) { + ObjectInspector argOI1 = argOIs[1]; 
+ if (HiveUtils.isIntegerOI(argOI1) == false) { throw new UDFArgumentException( - "Expected a number but get: " + valueOI.getTypeName()); + "The second argument of map_roulette must be integer type: " + + argOI1.getTypeName()); + } + if (ObjectInspectorUtils.isConstantObjectInspector(argOI1)) { + long seed = HiveUtils.getAsConstLong(argOI1); + this._rand = new Random(seed); // fixed seed + } else { + this.seedOI = HiveUtils.asLongCompatibleOI(argOI1); + } + } else { + this._rand = new Random(); // random seed } - return keyOI; + + return mapOI.getMapKeyObjectInspector(); } @Override public Object evaluate(DeferredObject[] arguments) throws HiveException { - Map<Object, Double> input = processObjectDoubleMap(arguments[0]); + Random rand = _rand; + if (rand == null) { + Object arg1 = arguments[1].get(); + if (arg1 == null) { + rand = new Random(); + } else { + long seed = HiveUtils.getLong(arg1, seedOI); + rand = new Random(seed); + } + } + + Map<Object, Double> input = getObjectDoubleMap(arguments[0], mapOI, valueOI); if (input == null) { return null; } - return algorithm(input); - } - @Override - public String getDisplayString(String[] children) { - return "map_roulette(" + StringUtils.join(children, ',') + ")"; + return rouletteWheelSelection(input, rand); } - /** - * Process the data passed by user. 
- * - * @param argument data passed by user - * @return If all the value is , - * @throws HiveException If get the wrong weight value like {key = "Wang", value = "Zhang"}, - * "Zhang" isn't a number ,this Method will throw exception when - * convertPrimitiveToDouble("Zhang", valueOD) - */ - private Map<Object, Double> processObjectDoubleMap(DeferredObject argument) - throws HiveException { - // get - Map<?, ?> m = mapOI.getMap(argument.get()); + @Nullable + private static Map<Object, Double> getObjectDoubleMap(@Nonnull final DeferredObject argument, + @Nonnull final MapObjectInspector mapOI, + @Nonnull final PrimitiveObjectInspector valueOI) throws HiveException { + final Map<?, ?> m = mapOI.getMap(argument.get()); if (m == null) { return null; } - if (m.size() == 0) { + final int size = m.size(); + if (size == 0) { return null; } - // convert - Map<Object, Double> input = new HashMap<>(); - Double avg = 0.0; + + final Map<Object, Double> result = new HashMap<>(size); Review comment: `Map<Object, Double> result` will be updated, why use final here? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services