Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141544782 --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java --- @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.UDTFWithOptions; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.OpenHashTable; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Arrays; +import java.util.ArrayList; + +@Description( + name = "train_word2vec", + value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model") +public class Word2VecUDTF extends UDTFWithOptions { + protected transient AbstractWord2VecModel model; + @Nonnegative + private float startingLR; + @Nonnegative + private long numTrainWords; + private OpenHashTable<String, Integer> word2index; + + @Nonnegative + private int dim; + @Nonnegative + private int win; + @Nonnegative + private int neg; + @Nonnegative + private int iter; + 
private boolean skipgram; + private boolean isStringInput; + + private Int2FloatOpenHashTable S; + private int[] aliasWordIds; + + private ListObjectInspector negativeTableOI; + private ListObjectInspector negativeTableElementListOI; + private PrimitiveObjectInspector negativeTableElementOI; + + private ListObjectInspector docOI; + private PrimitiveObjectInspector wordOI; + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + final int numArgs = argOIs.length; + + if (numArgs != 3) { + throw new UDFArgumentException(getClass().getSimpleName() + + " takes 3 arguments: [, constant string options]: " + + Arrays.toString(argOIs)); + } + + processOptions(argOIs); + + this.negativeTableOI = HiveUtils.asListOI(argOIs[0]); + this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector()); + this.docOI = HiveUtils.asListOI(argOIs[1]); + + this.isStringInput = HiveUtils.isStringListOI(argOIs[1]); + + if (isStringInput) { + this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector()); + this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector()); + } else { + this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector()); + this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector()); + } + + List<String> fieldNames = new ArrayList<>(); + List<ObjectInspector> fieldOIs = new ArrayList<>(); + + fieldNames.add("word"); + fieldNames.add("i"); + fieldNames.add("wi"); + + if (isStringInput) { + fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + } else { + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + } + + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + this.model = null; + 
this.word2index = null; + this.S = null; + this.aliasWordIds = null; + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + public void process(Object[] args) throws HiveException { + if (model == null) { + parseNegativeTable(args[0]); + this.model = createModel(); + } + + List<?> rawDoc = docOI.getList(args[1]); + + // parse rawDoc + final int docLength = rawDoc.size(); + final int[] doc = new int[docLength]; + if (isStringInput) { + for (int i = 0; i < docLength; i++) { + doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI)); + } + } else { + for (int i = 0; i < docLength; i++) { + doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI); + } + } + + model.trainOnDoc(doc); + } + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("n", "numTrainWords", true, + "The number of words in the documents. It is used to update learning rate"); + opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]"); + opts.addOption("win", "window", true, "Context window size [default: 5]"); + opts.addOption("neg", "negative", true, + "The number of negative sampled words per word [default: 5]"); + opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]"); + opts.addOption("model", "modelName", true, + "The model name of word2vec: skipgram or cbow [default: skipgram]"); + opts.addOption( + "lr", + "learningRate", + true, + "Initial learning rate of SGD. 
The default value depends on model [default: 0.025 (skipgram), 0.05 (cbow)]"); + + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + int win = 5; + int neg = 5; + int iter = 5; + int dim = 100; + long numTrainWords = 0L; + String modelName = "skipgram"; + float lr = 0.025f; + + if (argOIs.length >= 3) { + String rawArgs = HiveUtils.getConstString(argOIs[2]); + cl = parseOptions(rawArgs); + + numTrainWords = Primitives.parseLong(cl.getOptionValue("n"), numTrainWords); + if (numTrainWords <= 0) { + throw new UDFArgumentException("Argument `int numTrainWords` must be positive: " + + numTrainWords); + } + + dim = Primitives.parseInt(cl.getOptionValue("dim"), dim); + if (dim <= 0.d) { + throw new UDFArgumentException("Argument `int dim` must be positive: " + dim); + } + + win = Primitives.parseInt(cl.getOptionValue("win"), win); + if (win <= 0) { + throw new UDFArgumentException("Argument `int win` must be positive: " + win); + } + + neg = Primitives.parseInt(cl.getOptionValue("neg"), neg); + if (neg < 0) { + throw new UDFArgumentException("Argument `int neg` must be non-negative: " + neg); + } + + iter = Primitives.parseInt(cl.getOptionValue("iter"), iter); + if (iter <= 0) { + throw new UDFArgumentException("Argument `int iter` must be non-negative: " + iter); + } + + modelName = cl.getOptionValue("model", modelName); + if (!(modelName.equals("skipgram") || modelName.equals("cbow"))) { + throw new UDFArgumentException("Argument `string model` must be skipgram or cbow: " + + modelName); + } + + if (modelName.equals("cbow")) { + lr = 0.05f; + } + + lr = Primitives.parseFloat(cl.getOptionValue("lr"), lr); + if (lr <= 0.f) { + throw new UDFArgumentException("Argument `float lr` must be positive: " + lr); + } + } + + this.numTrainWords = numTrainWords; + this.win = win; + this.neg = neg; + this.iter = iter; + this.dim = dim; + this.skipgram = modelName.equals("skipgram"); 
+ this.startingLR = lr; + return cl; + } + + public void close() throws HiveException { + if (model != null) { + forwardModel(); + this.model = null; + this.word2index = null; + this.S = null; + } + } + + private void forwardModel() throws HiveException { + if (isStringInput) { + final Text word = new Text(); + final IntWritable dimIndex = new IntWritable(); + final FloatWritable value = new FloatWritable(); + + final Object[] result = new Object[3]; + result[0] = word; + result[1] = dimIndex; + result[2] = value; + + IMapIterator<String, Integer> iter = word2index.entries(); + while (iter.next() != -1) { + int wordId = iter.getValue(); + if (!model.inputWeights.containsKey(wordId * dim)){ + continue; + } + + word.set(iter.getKey()); + + for (int i = 0; i < dim; i++) { + dimIndex.set(i); + value.set(model.inputWeights.get(wordId * dim + i)); + forward(result); + } + } + } else { + final IntWritable word = new IntWritable(); + final IntWritable dimIndex = new IntWritable(); + final FloatWritable value = new FloatWritable(); + + final Object[] result = new Object[3]; + result[0] = word; + result[1] = dimIndex; + result[2] = value; + + for (int wordId = 0; wordId < aliasWordIds.length; wordId++) { + if (!model.inputWeights.containsKey(wordId * dim)){ + break; + } + word.set(wordId); + for (int i = 0; i < dim; i++) { + dimIndex.set(i); + value.set(model.inputWeights.get(wordId * dim + i)); + forward(result); + } + } + } + } + + private int getWordId(@Nonnull final String word) { + if (word2index.containsKey(word)) { --- End diff -- `word2index` is not guaranteed to be non-null here. Consider passing it in explicitly and checking it up front: ```java private static int getWordId(@Nonnull final String word, @Nonnull final OpenHashTable<String, Integer> word2index) { Preconditions.checkNotNull(word2index); ```
---