GitHub user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545805 --- Diff: core/src/main/java/hivemall/embedding/AliasTableBuilderUDTF.java --- @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.Int2IntOpenHashTable; +import hivemall.utils.hadoop.HiveUtils; + +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.Queue; +import java.util.ArrayDeque; + +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnull; + +public final class AliasTableBuilderUDTF extends GenericUDTF { + private MapObjectInspector negativeTableOI; + private PrimitiveObjectInspector negativeTableKeyOI; + private PrimitiveObjectInspector negativeTableValueOI; + + private int numVocab; + private List<String> index2word; + private Int2IntOpenHashTable A; + private Int2FloatOpenHashTable S; + private boolean isIntElement; + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (!(argOIs.length >= 1)) { + throw new UDFArgumentException( + "_FUNC_(map<string, double>) takes at least one argument"); + } + + this.negativeTableOI = HiveUtils.asMapOI(argOIs[0]); + this.negativeTableValueOI = 
HiveUtils.asFloatingPointOI(negativeTableOI.getMapValueObjectInspector()); + + boolean isIntEmelentOI = HiveUtils.isIntOI((negativeTableOI.getMapKeyObjectInspector())); + + if (isIntEmelentOI) { + this.negativeTableKeyOI = HiveUtils.asIntCompatibleOI(negativeTableOI.getMapKeyObjectInspector()); + } else { + this.negativeTableKeyOI = HiveUtils.asStringOI(negativeTableOI.getMapKeyObjectInspector()); + } + + List<String> fieldNames = new ArrayList<>(); + List<ObjectInspector> fieldOIs = new ArrayList<>(); + fieldNames.add("word"); + + if (isIntEmelentOI) { + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + } else { + fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + } + + fieldNames.add("p"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + fieldNames.add("other"); + if (isIntEmelentOI) { + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + } else { + fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + } + + this.isIntElement = isIntEmelentOI; + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + public void process(Object[] args) throws HiveException { + if (!isIntElement) { + index2word = new ArrayList<>(); + } + + final List<Float> unnormalizedProb = new ArrayList<>(); + int numVocab = 0; + float denom = 0.f; + for (Map.Entry<?, ?> entry : negativeTableOI.getMap(args[0]).entrySet()) { + if (!isIntElement) { + String word = PrimitiveObjectInspectorUtils.getString(entry.getKey(), + negativeTableKeyOI); + index2word.add(word); + } + + float v = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), negativeTableValueOI); + unnormalizedProb.add(v); + denom += v; + numVocab++; + } + + this.numVocab = numVocab; + createAliasTable(numVocab, denom, unnormalizedProb); + } + + private void createAliasTable(final int V, final float denom, + final @Nonnull List<Float> 
unnormalizedProb) { + Int2FloatOpenHashTable S = new Int2FloatOpenHashTable(V); + Int2IntOpenHashTable A = new Int2IntOpenHashTable(V); + + final Queue<Integer> higherBin = new ArrayDeque<>(); + final Queue<Integer> lowerBin = new ArrayDeque<>(); + + for (int i = 0; i < V; i++) { + float v = V * unnormalizedProb.get(i) / denom; + S.put(i, v); + if (v > 1.f) { + higherBin.add(i); + } else { + lowerBin.add(i); + } + } + + while (lowerBin.size() > 0 && higherBin.size() > 0) { + int low = lowerBin.remove(); + int high = higherBin.remove(); + A.put(low, high); + S.put(high, S.get(high) - 1.f + S.get(low)); + if (S.get(high) < 1.f) { + lowerBin.add(high); + } else { + higherBin.add(high); + } + } + this.A = A; + this.S = S; + } + + @Override + public void close() throws HiveException { + if (isIntElement) { + IntWritable word = new IntWritable(); + FloatWritable pro = new FloatWritable(); + IntWritable otherWord = new IntWritable(); + + Object[] res = new Object[3]; + res[0] = word; + res[1] = pro; + res[2] = otherWord; + + for (int i = 0; i < numVocab; i++) { + word.set(i); + pro.set(S.get(i)); + if (A.get(i) == -1) { + otherWord.set(0); + } else { + otherWord.set(A.get(i)); + } + forward(res); + } + } else { + Text word = new Text(); + FloatWritable pro = new FloatWritable(); + Text otherWord = new Text(); + + Object[] res = new Object[3]; --- End diff -- Please declare the local variables that are never reassigned inside the for loop (e.g., `word`, `pro`, `otherWord`, and `res`) as `final`.
---