Github user nzw0301 commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141553131 --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java --- @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.UDTFWithOptions; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.OpenHashTable; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Arrays; +import java.util.ArrayList; + +@Description( + name = "train_word2vec", + value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model") +public class Word2VecUDTF extends UDTFWithOptions { + protected transient AbstractWord2VecModel model; + @Nonnegative + private float startingLR; + @Nonnegative + private long numTrainWords; + private OpenHashTable<String, Integer> word2index; + + @Nonnegative + private int dim; + @Nonnegative + private int win; + @Nonnegative + private int neg; + @Nonnegative + private int iter; + 
private boolean skipgram; + private boolean isStringInput; + + private Int2FloatOpenHashTable S; + private int[] aliasWordIds; + + private ListObjectInspector negativeTableOI; + private ListObjectInspector negativeTableElementListOI; + private PrimitiveObjectInspector negativeTableElementOI; + + private ListObjectInspector docOI; + private PrimitiveObjectInspector wordOI; + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + final int numArgs = argOIs.length; + + if (numArgs != 3) { + throw new UDFArgumentException(getClass().getSimpleName() + + " takes 3 arguments: [, constant string options]: " + + Arrays.toString(argOIs)); + } + + processOptions(argOIs); + + this.negativeTableOI = HiveUtils.asListOI(argOIs[0]); + this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector()); + this.docOI = HiveUtils.asListOI(argOIs[1]); + + this.isStringInput = HiveUtils.isStringListOI(argOIs[1]); + + if (isStringInput) { + this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector()); + this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector()); + } else { + this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector()); + this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector()); + } + + List<String> fieldNames = new ArrayList<>(); + List<ObjectInspector> fieldOIs = new ArrayList<>(); + + fieldNames.add("word"); + fieldNames.add("i"); + fieldNames.add("wi"); + + if (isStringInput) { + fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + } else { + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + } + + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + this.model = null; + 
this.word2index = null; + this.S = null; + this.aliasWordIds = null; + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + public void process(Object[] args) throws HiveException { + if (model == null) { + parseNegativeTable(args[0]); + this.model = createModel(); + } + + List<?> rawDoc = docOI.getList(args[1]); + + // parse rawDoc + final int docLength = rawDoc.size(); + final int[] doc = new int[docLength]; + if (isStringInput) { + for (int i = 0; i < docLength; i++) { + doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI)); + } + } else { + for (int i = 0; i < docLength; i++) { + doc[i] = PrimitiveObjectInspectorUtils.getInt(rawDoc.get(i), wordOI); + } + } + + model.trainOnDoc(doc); + } + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("n", "numTrainWords", true, + "The number of words in the documents. It is used to update learning rate"); + opts.addOption("dim", "dimension", true, "The number of vector dimension [default: 100]"); + opts.addOption("win", "window", true, "Context window size [default: 5]"); + opts.addOption("neg", "negative", true, + "The number of negative sampled words per word [default: 5]"); + opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]"); + opts.addOption("model", "modelName", true, + "The model name of word2vec: skipgram or cbow [default: skipgram]"); + opts.addOption( + "lr", --- End diff -- I see. Should the `longOpt` remain `learningRate`, or should this field be removed?
---