GitHub user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545805
  
    --- Diff: core/src/main/java/hivemall/embedding/AliasTableBuilderUDTF.java 
---
    @@ -0,0 +1,203 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.Int2IntOpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +
    +import java.util.List;
    +import java.util.ArrayList;
    +import java.util.Map;
    +import java.util.Queue;
    +import java.util.ArrayDeque;
    +
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
    +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
    +import 
org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import 
org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +
    +import javax.annotation.Nonnull;
    +
    +public final class AliasTableBuilderUDTF extends GenericUDTF {
    +    private MapObjectInspector negativeTableOI;
    +    private PrimitiveObjectInspector negativeTableKeyOI;
    +    private PrimitiveObjectInspector negativeTableValueOI;
    +
    +    private int numVocab;
    +    private List<String> index2word;
    +    private Int2IntOpenHashTable A;
    +    private Int2FloatOpenHashTable S;
    +    private boolean isIntElement;
    +
    +    /**
    +     * Validates the argument types and declares the output row schema.
    +     *
    +     * <p>The first argument must be a map. Its keys are treated as int when
    +     * {@link HiveUtils#isIntOI} accepts them and as string otherwise; its values go through
    +     * {@link HiveUtils#asFloatingPointOI}. Rows are declared as {@code (word, p, other)}, where
    +     * {@code word} and {@code other} share the key type (writable int or writable string) and
    +     * {@code p} is a writable float.
    +     *
    +     * @param argOIs object inspectors of the UDTF arguments; only {@code argOIs[0]} is used
    +     * @return the struct inspector describing {@code (word, p, other)}
    +     * @throws UDFArgumentException if no argument is given, or (presumably) when one of the
    +     *         HiveUtils conversion helpers rejects the map, key or value type
    +     */
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        if (argOIs.length < 1) {
    +            throw new UDFArgumentException(
    +                "_FUNC_(map<string, double>) takes at least one argument");
    +        }
    +
    +        this.negativeTableOI = HiveUtils.asMapOI(argOIs[0]);
    +        this.negativeTableValueOI =
    +                HiveUtils.asFloatingPointOI(negativeTableOI.getMapValueObjectInspector());
    +
    +        final ObjectInspector keyOI = negativeTableOI.getMapKeyObjectInspector();
    +        final boolean isIntElementOI = HiveUtils.isIntOI(keyOI);
    +        if (isIntElementOI) {
    +            this.negativeTableKeyOI = HiveUtils.asIntCompatibleOI(keyOI);
    +        } else {
    +            this.negativeTableKeyOI = HiveUtils.asStringOI(keyOI);
    +        }
    +
    +        // "other" carries an alias entry of the same vocabulary, so it uses the same type as "word".
    +        final ObjectInspector wordOI;
    +        if (isIntElementOI) {
    +            wordOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
    +        } else {
    +            wordOI = PrimitiveObjectInspectorFactory.writableStringObjectInspector;
    +        }
    +
    +        final List<String> fieldNames = new ArrayList<>();
    +        final List<ObjectInspector> fieldOIs = new ArrayList<>();
    +        fieldNames.add("word");
    +        fieldOIs.add(wordOI);
    +        fieldNames.add("p");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +        fieldNames.add("other");
    +        fieldOIs.add(wordOI);
    +
    +        this.isIntElement = isIntElementOI;
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    /**
    +     * Reads the map in {@code args[0]} (key to unnormalized probability) and rebuilds the alias
    +     * table state ({@code numVocab}, {@code A}, {@code S} and, for string keys,
    +     * {@code index2word}) that {@link #close()} emits.
    +     *
    +     * <p>For int keys, {@code close()} emits the table position itself as the word. An int key is
    +     * therefore used directly as its table position and must lie in {@code [0, map size)}. String
    +     * keys are numbered in iteration order and recorded in {@code index2word}.
    +     *
    +     * <p>NOTE(review): each call replaces the table built by the previous call, and only the final
    +     * state is emitted from {@code close()}. Confirm that a single input row per task is intended.
    +     *
    +     * @param args UDTF arguments; only {@code args[0]} is read, and a NULL map is ignored
    +     * @throws HiveException if a probability is negative, NaN or infinite, if an int key is out of
    +     *         range, or if the probabilities of a non-empty map do not sum to a finite positive value
    +     */
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        final Map<?, ?> negativeTable = negativeTableOI.getMap(args[0]);
    +        if (negativeTable == null) {
    +            return; // NULL map: nothing to build, keep whatever state already exists
    +        }
    +
    +        final int vocabSize = negativeTable.size();
    +        final float[] probs = new float[vocabSize];
    +        final List<String> words = isIntElement ? null : new ArrayList<String>(vocabSize);
    +        float denom = 0.f;
    +        for (final Map.Entry<?, ?> entry : negativeTable.entrySet()) {
    +            final float v =
    +                    PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), negativeTableValueOI);
    +            if (Float.isNaN(v) || Float.isInfinite(v) || v < 0.f) {
    +                throw new HiveException("Probability must be finite and non-negative: " + v
    +                        + " for key " + entry.getKey());
    +            }
    +
    +            final int index;
    +            if (isIntElement) {
    +                // close() emits the table position as the int word, so the key must be that position.
    +                index = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), negativeTableKeyOI);
    +                if (index < 0 || index >= vocabSize) {
    +                    throw new HiveException("Int key must be in [0, " + vocabSize + "): " + index);
    +                }
    +            } else {
    +                index = words.size();
    +                words.add(PrimitiveObjectInspectorUtils.getString(entry.getKey(), negativeTableKeyOI));
    +            }
    +            probs[index] = v;
    +            denom += v;
    +        }
    +
    +        // A zero total makes every scaled value 0/0 (NaN), and an infinite total collapses them.
    +        if (vocabSize > 0 && (denom <= 0.f || Float.isInfinite(denom))) {
    +            throw new HiveException("Sum of probabilities must be finite and positive: " + denom);
    +        }
    +
    +        final List<Float> unnormalizedProb = new ArrayList<>(vocabSize);
    +        for (final float p : probs) {
    +            unnormalizedProb.add(p);
    +        }
    +
    +        if (words != null) {
    +            this.index2word = words;
    +        }
    +        this.numVocab = vocabSize;
    +        createAliasTable(vocabSize, denom, unnormalizedProb);
    +    }
    +
    +    /**
    +     * Builds an alias table with Vose's alias method and stores it in {@code this.S} and
    +     * {@code this.A}.
    +     *
    +     * <p>Each probability is scaled by {@code V / denom} so the values average 1. Columns below 1
    +     * are topped up from columns above 1. {@code S.get(i)} ends up as the probability of keeping
    +     * column {@code i}, and {@code A.get(i)} as the index used otherwise. Columns that never
    +     * receive an alias are not put into {@code A}; their {@code S} value is pinned to 1.
    +     *
    +     * @param V number of entries; assumes {@code unnormalizedProb} has at least {@code V} elements
    +     * @param denom sum of the first {@code V} weights; assumed positive (not checked here)
    +     * @param unnormalizedProb non-negative weights indexed by table position
    +     */
    +    private void createAliasTable(final int V, final float denom,
    +            final @Nonnull List<Float> unnormalizedProb) {
    +        final Int2FloatOpenHashTable S = new Int2FloatOpenHashTable(V);
    +        final Int2IntOpenHashTable A = new Int2IntOpenHashTable(V);
    +
    +        final Queue<Integer> higherBin = new ArrayDeque<>();
    +        final Queue<Integer> lowerBin = new ArrayDeque<>();
    +
    +        for (int i = 0; i < V; i++) {
    +            final float v = V * unnormalizedProb.get(i) / denom;
    +            S.put(i, v);
    +            if (v > 1.f) {
    +                higherBin.add(i);
    +            } else {
    +                lowerBin.add(i);
    +            }
    +        }
    +
    +        while (!lowerBin.isEmpty() && !higherBin.isEmpty()) {
    +            final int low = lowerBin.remove();
    +            final int high = higherBin.remove();
    +            A.put(low, high);
    +            // The large column donates (1 - S[low]) to fill the small one.
    +            final float rest = S.get(high) - 1.f + S.get(low);
    +            S.put(high, rest);
    +            if (rest < 1.f) {
    +                lowerBin.add(high);
    +            } else {
    +                higherBin.add(high);
    +            }
    +        }
    +
    +        // In exact arithmetic every column still in a bin has exactly 1. Float rounding leaves them
    +        // slightly off, and because they have no alias the residue would go to the substitute that
    +        // close() uses for a missing alias (it maps a get() result of -1 to 0). Pin them to 1 so
    +        // they always keep themselves.
    +        while (!higherBin.isEmpty()) {
    +            S.put(higherBin.remove(), 1.f);
    +        }
    +        while (!lowerBin.isEmpty()) {
    +            S.put(lowerBin.remove(), 1.f);
    +        }
    +
    +        this.A = A;
    +        this.S = S;
    +    }
    +
    +    @Override
    +    public void close() throws HiveException {
    +        if (isIntElement) {
    +            IntWritable word = new IntWritable();
    +            FloatWritable pro = new FloatWritable();
    +            IntWritable otherWord = new IntWritable();
    +
    +            Object[] res = new Object[3];
    +            res[0] = word;
    +            res[1] = pro;
    +            res[2] = otherWord;
    +
    +            for (int i = 0; i < numVocab; i++) {
    +                word.set(i);
    +                pro.set(S.get(i));
    +                if (A.get(i) == -1) {
    +                    otherWord.set(0);
    +                } else {
    +                    otherWord.set(A.get(i));
    +                }
    +                forward(res);
    +            }
    +        } else {
    +            Text word = new Text();
    +            FloatWritable pro = new FloatWritable();
    +            Text otherWord = new Text();
    +
    +            Object[] res = new Object[3];
    --- End diff --
    
    Please declare the local variables that are never reassigned inside the for loop (e.g. `word`, `pro`, `otherWord`, `res`) as `final`.


---

Reply via email to