Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/111#discussion_r139901202 --- Diff: core/src/main/java/hivemall/evaluation/HitRateUDAF.java --- @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ +package hivemall.evaluation; + +import hivemall.utils.hadoop.HiveUtils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.io.LongWritable; + + +@Description( + name = "hitrate", + value = "_FUNC_(array rankItems, array correctItems [, const int recommendSize = rankItems.size])" + + " - Returns HitRate") +public final class HitRateUDAF extends AbstractGenericUDAFResolver { + + // prevent instantiation + private HitRateUDAF() {} + + @Override + public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException { + if (typeInfo.length != 2 && typeInfo.length != 3) { + throw new UDFArgumentTypeException(typeInfo.length - 1, + "_FUNC_ takes two or three arguments"); + } + + ListTypeInfo arg1type = HiveUtils.asListTypeInfo(typeInfo[0]); + if (!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) { + throw new UDFArgumentTypeException(0, + "The first argument `array rankItems` is invalid form: " + typeInfo[0]); + } + ListTypeInfo arg2type = HiveUtils.asListTypeInfo(typeInfo[1]); + if (!HiveUtils.isPrimitiveTypeInfo(arg2type.getListElementTypeInfo())) { + throw new UDFArgumentTypeException(1, + "The second argument `array correctItems` is invalid form: " + typeInfo[1]); + } + + return new HitRateUDAF.Evaluator(); + } + + public static class Evaluator extends GenericUDAFEvaluator { + + private ListObjectInspector recommendListOI; + private ListObjectInspector truthListOI; + private WritableIntObjectInspector recommendSizeOI; + + private StructObjectInspector internalMergeOI; + private StructField countField; + private StructField sumField; + + public Evaluator() {} + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { + assert (parameters.length == 2 || parameters.length == 3) : parameters.length; + super.init(mode, parameters); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + this.recommendListOI = (ListObjectInspector) parameters[0]; + this.truthListOI = (ListObjectInspector) parameters[1]; + if (parameters.length == 3) { + this.recommendSizeOI = (WritableIntObjectInspector) parameters[2]; + } + } else {// from partial aggregation + StructObjectInspector soi = (StructObjectInspector) parameters[0]; + this.internalMergeOI = soi; + this.countField = soi.getStructFieldRef("count"); + this.sumField = soi.getStructFieldRef("sum"); + } + + // initialize output + final ObjectInspector outputOI; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial + outputOI = internalMergeOI(); + } else {// terminate + outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + } + return outputOI; + } + + private static StructObjectInspector internalMergeOI() { + ArrayList<String> fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + + fieldNames.add("sum"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("count"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + public HitRateAggregationBuffer getNewAggregationBuffer() throws HiveException { + HitRateAggregationBuffer myAggr = new HitRateAggregationBuffer(); + reset(myAggr); + return myAggr; + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg; + myAggr.reset(); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer) agg; + + List<?> recommendList = recommendListOI.getList(parameters[0]); + if (recommendList == null) { + recommendList = Collections.emptyList(); + } + List<?> truthList = truthListOI.getList(parameters[1]); + if (truthList == null) { + return; + } + + int recommendSize = recommendList.size(); + if (parameters.length == 3) { + recommendSize = recommendSizeOI.get(parameters[2]); + } + if (recommendSize < 0 || recommendSize > recommendList.size()) { --- End diff -- Hit rate should accept `recommendSize > recommendList.size()`.
---