[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Github user coveralls commented on the issue: https://github.com/apache/incubator-hivemall/pull/116 [![Coverage Status](https://coveralls.io/builds/13472784/badge)](https://coveralls.io/builds/13472784) Coverage decreased (-0.6%) to 40.508% when pulling **0b163fade6f2d26ce918211c94a78c9a3b648cbe on nzw0301:skipgram** into **1e42387576fabbb326d451f4a00ac22d57828711 on apache:master**. ---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Github user coveralls commented on the issue: https://github.com/apache/incubator-hivemall/pull/116 [![Coverage Status](https://coveralls.io/builds/13472440/badge)](https://coveralls.io/builds/13472440) Coverage decreased (-0.6%) to 40.505% when pulling **8696f5ff668adf758d3545bab5885e51ce7d053e on nzw0301:skipgram** into **1e42387576fabbb326d451f4a00ac22d57828711 on apache:master**. ---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Github user nzw0301 commented on the issue: https://github.com/apache/incubator-hivemall/pull/116 @myui I resolved conflicts. ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141556886 --- Diff: core/src/main/java/hivemall/embedding/SkipGramModel.java --- @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnull; +import java.util.List; + +public final class SkipGramModel extends AbstractWord2VecModel { +protected SkipGramModel(final int dim, final int win, final int neg, final int iter, --- End diff -- Lots of hyperparameters in the constructor. Consider using a hyperparameter class, as seen in https://github.com/apache/incubator-hivemall/blob/master/core/src/main/java/hivemall/fm/FMHyperParameters.java ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141556621 --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java --- @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.UDTFWithOptions; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.OpenHashTable; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Arrays; +import java.util.ArrayList; + +@Description( +name = "train_word2vec", +value = "_FUNC_(array> negative_table, array doc [, const string options]) - Returns a prediction model") +public class Word2VecUDTF extends UDTFWithOptions { +protected transient AbstractWord2VecModel model; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +private OpenHashTable word2index; + +@Nonnegative +private int dim; +@Nonnegative +private int win; +@Nonnegative +private int neg; +@Nonnegative +private int iter; +private boolean skipgram; +private boolean isStringInput; + +private 
Int2FloatOpenHashTable S; +private int[] aliasWordIds; + +private ListObjectInspector negativeTableOI; +private ListObjectInspector negativeTableElementListOI; +private PrimitiveObjectInspector negativeTableElementOI; + +private ListObjectInspector docOI; +private PrimitiveObjectInspector wordOI; + +@Override +public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { +final int numArgs = argOIs.length; + +if (numArgs != 3) { +throw new UDFArgumentException(getClass().getSimpleName() ++ " takes 3 arguments: [, constant string options]: " ++ Arrays.toString(argOIs)); +} + +processOptions(argOIs); + +this.negativeTableOI = HiveUtils.asListOI(argOIs[0]); +this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector()); +this.docOI = HiveUtils.asListOI(argOIs[1]); + +this.isStringInput = HiveUtils.isStringListOI(argOIs[1]); + +if (isStringInput) { +this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector()); +this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector()); +} else { +this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObj
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141556391 --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java --- @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; + +public abstract class AbstractWord2VecModel { +// cached sigmoid function parameters +protected static final int MAX_SIGMOID = 6; +protected static final int SIGMOID_TABLE_SIZE = 1000; +protected float[] sigmoidTable; + + +@Nonnegative +protected int dim; +protected int win; +protected int neg; +protected int iter; + +// learning rate parameters +@Nonnegative +protected float lr; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +@Nonnegative +protected long wordCount; +@Nonnegative +private long lastWordCount; + +protected PRNG rnd; + +protected Int2FloatOpenHashTable contextWeights; +protected Int2FloatOpenHashTable inputWeights; +protected Int2FloatOpenHashTable S; +protected int[] aliasWordId; + +protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter, +final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S, +final int[] aliasWordId) { +this.win = win; +this.neg = neg; +this.iter = iter; +this.dim = dim; +this.startingLR = this.lr = startingLR; +this.numTrainWords = numTrainWords; + +// alias sampler for negative sampling +this.S = S; +this.aliasWordId = aliasWordId; + +this.wordCount = 0L; +this.lastWordCount = 0L; +this.rnd = RandomNumberGeneratorFactory.createPRNG(1001); + +this.sigmoidTable = initSigmoidTable(); + +// TODO how to estimate size +this.inputWeights = new Int2FloatOpenHashTable(10578 * dim); --- End diff -- 2^n or 1024 * 10 is more understandable. ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user nzw0301 commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141553131 --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java --- @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.UDTFWithOptions; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.OpenHashTable; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Arrays; +import java.util.ArrayList; + +@Description( +name = "train_word2vec", +value = "_FUNC_(array> negative_table, array doc [, const string options]) - Returns a prediction model") +public class Word2VecUDTF extends UDTFWithOptions { +protected transient AbstractWord2VecModel model; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +private OpenHashTable word2index; + +@Nonnegative +private int dim; +@Nonnegative +private int win; +@Nonnegative +private int neg; +@Nonnegative +private int iter; +private boolean skipgram; +private boolean isStringInput; + +private 
Int2FloatOpenHashTable S; +private int[] aliasWordIds; + +private ListObjectInspector negativeTableOI; +private ListObjectInspector negativeTableElementListOI; +private PrimitiveObjectInspector negativeTableElementOI; + +private ListObjectInspector docOI; +private PrimitiveObjectInspector wordOI; + +@Override +public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { +final int numArgs = argOIs.length; + +if (numArgs != 3) { +throw new UDFArgumentException(getClass().getSimpleName() ++ " takes 3 arguments: [, constant string options]: " ++ Arrays.toString(argOIs)); +} + +processOptions(argOIs); + +this.negativeTableOI = HiveUtils.asListOI(argOIs[0]); +this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector()); +this.docOI = HiveUtils.asListOI(argOIs[1]); + +this.isStringInput = HiveUtils.isStringListOI(argOIs[1]); + +if (isStringInput) { +this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector()); +this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector()); +} else { +this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElement
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user nzw0301 commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141551510 --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java --- @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; + +public abstract class AbstractWord2VecModel { +// cached sigmoid function parameters +protected static final int MAX_SIGMOID = 6; +protected static final int SIGMOID_TABLE_SIZE = 1000; +protected float[] sigmoidTable; + + +@Nonnegative +protected int dim; +protected int win; +protected int neg; +protected int iter; + +// learning rate parameters +@Nonnegative +protected float lr; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +@Nonnegative +protected long wordCount; +@Nonnegative +private long lastWordCount; + +protected PRNG rnd; + +protected Int2FloatOpenHashTable contextWeights; +protected Int2FloatOpenHashTable inputWeights; +protected Int2FloatOpenHashTable S; +protected int[] aliasWordId; + +protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter, +final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S, +final int[] aliasWordId) { +this.win = win; +this.neg = neg; +this.iter = iter; +this.dim = dim; +this.startingLR = this.lr = startingLR; +this.numTrainWords = numTrainWords; + +// alias sampler for negative sampling +this.S = S; +this.aliasWordId = aliasWordId; + +this.wordCount = 0L; +this.lastWordCount = 0L; +this.rnd = RandomNumberGeneratorFactory.createPRNG(1001); + +this.sigmoidTable = initSigmoidTable(); + +// TODO how to estimate size +this.inputWeights = new Int2FloatOpenHashTable(10578 * dim); --- End diff -- There is no reason. ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141550040 --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java --- @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; + +public abstract class AbstractWord2VecModel { +// cached sigmoid function parameters +protected static final int MAX_SIGMOID = 6; +protected static final int SIGMOID_TABLE_SIZE = 1000; +protected float[] sigmoidTable; + + +@Nonnegative +protected int dim; +protected int win; +protected int neg; +protected int iter; + +// learning rate parameters +@Nonnegative +protected float lr; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +@Nonnegative +protected long wordCount; +@Nonnegative +private long lastWordCount; + +protected PRNG rnd; + +protected Int2FloatOpenHashTable contextWeights; +protected Int2FloatOpenHashTable inputWeights; +protected Int2FloatOpenHashTable S; +protected int[] aliasWordId; + +protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter, +final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S, +final int[] aliasWordId) { +this.win = win; +this.neg = neg; +this.iter = iter; +this.dim = dim; +this.startingLR = this.lr = startingLR; +this.numTrainWords = numTrainWords; + +// alias sampler for negative sampling +this.S = S; +this.aliasWordId = aliasWordId; + +this.wordCount = 0L; +this.lastWordCount = 0L; +this.rnd = RandomNumberGeneratorFactory.createPRNG(1001); + +this.sigmoidTable = initSigmoidTable(); + +// TODO how to estimate size +this.inputWeights = new Int2FloatOpenHashTable(10578 * dim); --- End diff -- What's `10578`? ---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on the issue: https://github.com/apache/incubator-hivemall/pull/116 @nzw0301 Please rebase onto master, resolving the conflicts mentioned above. ---
[GitHub] incubator-hivemall issue #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on the issue: https://github.com/apache/incubator-hivemall/pull/116 `What type of PR is it? => Improvement` should be `Feature`. ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543095 --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java --- @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; + +public abstract class AbstractWord2VecModel { +// cached sigmoid function parameters +protected static final int MAX_SIGMOID = 6; +protected static final int SIGMOID_TABLE_SIZE = 1000; +protected float[] sigmoidTable; + + +@Nonnegative +protected int dim; +protected int win; +protected int neg; +protected int iter; + +// learning rate parameters +@Nonnegative +protected float lr; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +@Nonnegative +protected long wordCount; +@Nonnegative +private long lastWordCount; + +protected PRNG rnd; + +protected Int2FloatOpenHashTable contextWeights; +protected Int2FloatOpenHashTable inputWeights; +protected Int2FloatOpenHashTable S; +protected int[] aliasWordId; + +protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter, --- End diff -- add `@Nonnegative` for each constructor argument and caller methods. ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141546522 --- Diff: core/src/main/java/hivemall/embedding/AliasTableBuilderUDTF.java --- @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.Int2IntOpenHashTable; +import hivemall.utils.hadoop.HiveUtils; + +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.Queue; +import java.util.ArrayDeque; + +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnull; + +public final class AliasTableBuilderUDTF extends GenericUDTF { --- End diff -- Add Javadoc comment for the class referring papers. ``` - A. J. Walker, New Fast Method for Generating Discrete Random Numbers with Arbitrary Frequency Distributions, in Electronics Letters 10, no. 8, pp. 127-128, 1974. - A. J. Walker, An Efficient Method for Generating Discrete Random Variables with General Distributions. ACM Transactions on Mathematical Software 3, no. 3, pp. 253-256, 1977. ``` ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545805 --- Diff: core/src/main/java/hivemall/embedding/AliasTableBuilderUDTF.java --- @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.Int2IntOpenHashTable; +import hivemall.utils.hadoop.HiveUtils; + +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.Queue; +import java.util.ArrayDeque; + +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnull; + +public final class AliasTableBuilderUDTF extends GenericUDTF { +private MapObjectInspector negativeTableOI; +private PrimitiveObjectInspector negativeTableKeyOI; +private PrimitiveObjectInspector negativeTableValueOI; + +private int numVocab; +private List index2word; +private Int2IntOpenHashTable A; +private Int2FloatOpenHashTable S; +private boolean isIntElement; + +@Override +public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { +if (!(argOIs.length >= 1)) { +throw new UDFArgumentException( +"_FUNC_(map) takes at least one argument"); +} + +this.negativeTableOI = HiveUtils.asMapOI(argOIs[0]); +this.negativeTableValueOI = HiveUtils.asFloatingPointOI(negativeTableOI.getMapValueObjectInspector()); + 
+boolean isIntEmelentOI = HiveUtils.isIntOI((negativeTableOI.getMapKeyObjectInspector())); + +if (isIntEmelentOI) { +this.negativeTableKeyOI = HiveUtils.asIntCompatibleOI(negativeTableOI.getMapKeyObjectInspector()); +} else { +this.negativeTableKeyOI = HiveUtils.asStringOI(negativeTableOI.getMapKeyObjectInspector()); +} + +List fieldNames = new ArrayList<>(); +List fieldOIs = new ArrayList<>(); +fieldNames.add("word"); + +if (isIntEmelentOI) { + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); +} else { + fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); +} + +fieldNames.add("p"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + +fieldNames.add("other"); +if (isIntEmelentOI) { + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); +} else { + fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); +} + +this.isIntElement = isIntEmelentOI; +return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); +} + +@Override +public void process(Object[] args) throws HiveException { +if (!isIntElement) { +
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141547369 --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java --- @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.UDTFWithOptions; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.OpenHashTable; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Arrays; +import java.util.ArrayList; + +@Description( +name = "train_word2vec", +value = "_FUNC_(array> negative_table, array doc [, const string options]) - Returns a prediction model") +public class Word2VecUDTF extends UDTFWithOptions { +protected transient AbstractWord2VecModel model; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +private OpenHashTable word2index; + +@Nonnegative +private int dim; +@Nonnegative +private int win; +@Nonnegative +private int neg; +@Nonnegative +private int iter; +private boolean skipgram; +private boolean isStringInput; + +private 
Int2FloatOpenHashTable S; +private int[] aliasWordIds; + +private ListObjectInspector negativeTableOI; +private ListObjectInspector negativeTableElementListOI; +private PrimitiveObjectInspector negativeTableElementOI; + +private ListObjectInspector docOI; +private PrimitiveObjectInspector wordOI; + +@Override +public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { +final int numArgs = argOIs.length; + +if (numArgs != 3) { +throw new UDFArgumentException(getClass().getSimpleName() ++ " takes 3 arguments: [, constant string options]: " ++ Arrays.toString(argOIs)); +} + +processOptions(argOIs); + +this.negativeTableOI = HiveUtils.asListOI(argOIs[0]); +this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector()); +this.docOI = HiveUtils.asListOI(argOIs[1]); + +this.isStringInput = HiveUtils.isStringListOI(argOIs[1]); + +if (isStringInput) { +this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector()); +this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector()); +} else { +this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObj
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543986 --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java --- @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.UDTFWithOptions; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.OpenHashTable; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Arrays; +import java.util.ArrayList; + +@Description( +name = "train_word2vec", +value = "_FUNC_(array> negative_table, array doc [, const string options]) - Returns a prediction model") +public class Word2VecUDTF extends UDTFWithOptions { +protected transient AbstractWord2VecModel model; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +private OpenHashTable word2index; + +@Nonnegative +private int dim; +@Nonnegative +private int win; +@Nonnegative +private int neg; +@Nonnegative +private int iter; +private boolean skipgram; +private boolean isStringInput; + +private 
Int2FloatOpenHashTable S; +private int[] aliasWordIds; + +private ListObjectInspector negativeTableOI; +private ListObjectInspector negativeTableElementListOI; +private PrimitiveObjectInspector negativeTableElementOI; + +private ListObjectInspector docOI; +private PrimitiveObjectInspector wordOI; + +@Override +public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { +final int numArgs = argOIs.length; + +if (numArgs != 3) { +throw new UDFArgumentException(getClass().getSimpleName() ++ " takes 3 arguments: [, constant string options]: " ++ Arrays.toString(argOIs)); +} + +processOptions(argOIs); + +this.negativeTableOI = HiveUtils.asListOI(argOIs[0]); +this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector()); +this.docOI = HiveUtils.asListOI(argOIs[1]); + +this.isStringInput = HiveUtils.isStringListOI(argOIs[1]); + +if (isStringInput) { +this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector()); +this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector()); +} else { +this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObj
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545337 --- Diff: docs/gitbook/embedding/word2vec.md --- @@ -0,0 +1,399 @@ + + +Word Embedding is a powerful tool for many tasks, +e.g. finding similar words, +feature vectors for supervised machine learning task and word analogy, +such as `king - man + woman =~ queen`. +In word embedding, +each word represents a low dimension and dense vector. +**Skip-Gram** and **Continuous Bag-of-words** (CBoW) are the most popular algorithms to obtain good word embeddings (a.k.a word2vec). + +The papers introduce the method are as follows: + +- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality +](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013. +- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013. + +Hivemall provides two type algorithms: Skip-gram and CBoW with negative sampling. +Hivemall enables you to train your sequence data such as, +but not limited to, documents based on word2vec. +This article gives usage instructions of the feature. + + + +> Note +> This feature is supported from Hivemall v0.5-rc.? or later. + +# Prepare document data + +Assume that you already have `docs` table which contains many documents as string format with unique index: + +```sql +select * FROM docs; +``` + +| docId | doc | +|:: |:| +| 0 | "Alice was beginning to get very tired of sitting by her sister on the bank ..." | +| ... | ... | + +First, each document is split into words by tokenize function like a [`tokenize`](../misc/tokenizer.html). + +```sql +drop table docs_words; +create table docs_words as + select +docid, +tokenize(doc, true) as words + FROM +docs +; +``` + +This table shows tokenized document. 
+ +| docId | doc | +|:: |:| +| 0 | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] | +| ... | ... | + +Then, you count frequency up per word and remove low frequency words from the vocabulary. +To remove low frequency words is optional preprocessing, but this process is effective to train word vector fastly. + +```sql +set hivevar:mincount=5; + +drop table freq; +create table freq as +select + row_number() over () - 1 as wordid, + word, + freq +from ( + select +word, +COUNT(*) as freq + from +docs_words + LATERAL VIEW explode(words) lTable as word + group by +word +) t +where freq >= ${mincount} +; +``` + +Hivemall's word2vec supports two type words; string and int. +String type tends to use huge memory during training. +On the other hand, int type tends to use less memory. +If you train on small dataset, we recommend using string type, +because memory usage can be ignored and HiveQL is more simple. +If you train on large dataset, we recommend using int type, +because it saves memory during training. + +# Create sub-sampling table + +Sub-sampling table is stored a sub-sampling probability per word. + +The sub-sampling probability of word $$w_i$$ is computed by the following equation: + +$$ +\begin{aligned} +f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)} +\end{aligned} +$$ + +During word2vec training, --- End diff -- remove line break after `,`. ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141544782 --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java --- @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.UDTFWithOptions; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.OpenHashTable; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Arrays; +import java.util.ArrayList; + +@Description( +name = "train_word2vec", +value = "_FUNC_(array> negative_table, array doc [, const string options]) - Returns a prediction model") +public class Word2VecUDTF extends UDTFWithOptions { +protected transient AbstractWord2VecModel model; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +private OpenHashTable word2index; + +@Nonnegative +private int dim; +@Nonnegative +private int win; +@Nonnegative +private int neg; +@Nonnegative +private int iter; +private boolean skipgram; +private boolean isStringInput; + +private 
Int2FloatOpenHashTable S; +private int[] aliasWordIds; + +private ListObjectInspector negativeTableOI; +private ListObjectInspector negativeTableElementListOI; +private PrimitiveObjectInspector negativeTableElementOI; + +private ListObjectInspector docOI; +private PrimitiveObjectInspector wordOI; + +@Override +public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { +final int numArgs = argOIs.length; + +if (numArgs != 3) { +throw new UDFArgumentException(getClass().getSimpleName() ++ " takes 3 arguments: [, constant string options]: " ++ Arrays.toString(argOIs)); +} + +processOptions(argOIs); + +this.negativeTableOI = HiveUtils.asListOI(argOIs[0]); +this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector()); +this.docOI = HiveUtils.asListOI(argOIs[1]); + +this.isStringInput = HiveUtils.isStringListOI(argOIs[1]); + +if (isStringInput) { +this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector()); +this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector()); +} else { +this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObj
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141546893 --- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java --- @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnull; +import java.util.List; + +public final class CBoWModel extends AbstractWord2VecModel { +protected CBoWModel(final int dim, final int win, final int neg, final int iter, +final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S, +final int[] aliasWordId) { +super(dim, win, neg, iter, startingLR, numTrainWords, S, aliasWordId); +} + +protected void trainOnDoc(@Nonnull final int[] doc) { +final int vecDim = dim; +final int numNegative = neg; +final PRNG _rnd = rnd; +final Int2FloatOpenHashTable _S = S; +final int[] _aliasWordId = aliasWordId; +float label, gradient; + +// reuse instance +int windowSize, k, numContext, targetWord, inWord, positiveWord; + +updateLearningRate(); + +int docLength = doc.length; --- End diff -- `final int docLength` ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543209 --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java --- @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; + +public abstract class AbstractWord2VecModel { +// cached sigmoid function parameters +protected static final int MAX_SIGMOID = 6; +protected static final int SIGMOID_TABLE_SIZE = 1000; +protected float[] sigmoidTable; + + +@Nonnegative +protected int dim; +protected int win; +protected int neg; +protected int iter; + +// learning rate parameters +@Nonnegative +protected float lr; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +@Nonnegative +protected long wordCount; +@Nonnegative +private long lastWordCount; + +protected PRNG rnd; + +protected Int2FloatOpenHashTable contextWeights; +protected Int2FloatOpenHashTable inputWeights; +protected Int2FloatOpenHashTable S; +protected int[] aliasWordId; + +protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter, +final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S, +final int[] aliasWordId) { +this.win = win; +this.neg = neg; +this.iter = iter; +this.dim = dim; +this.startingLR = this.lr = startingLR; +this.numTrainWords = numTrainWords; + +// alias sampler for negative sampling +this.S = S; +this.aliasWordId = aliasWordId; + +this.wordCount = 0L; +this.lastWordCount = 0L; +this.rnd = RandomNumberGeneratorFactory.createPRNG(1001); + +this.sigmoidTable = initSigmoidTable(); + +// TODO how to estimate size +this.inputWeights = new Int2FloatOpenHashTable(10578 * dim); +this.inputWeights.defaultReturnValue(0.f); +this.contextWeights = new Int2FloatOpenHashTable(10578 * dim); +this.contextWeights.defaultReturnValue(0.f); +} + +private static float[] initSigmoidTable() { +float[] sigmoidTable = new float[SIGMOID_TABLE_SIZE]; +for (int i = 
0; i < SIGMOID_TABLE_SIZE; i++) { +float x = ((float) i / SIGMOID_TABLE_SIZE * 2 - 1) * (float) MAX_SIGMOID; +sigmoidTable[i] = 1.f / ((float) Math.exp(-x) + 1.f); +} +return sigmoidTable; +} + +protected void initWordWeights(final int wordId) { +for (int i = 0; i < dim; i++) { +inputWeights.put(wordId * dim + i, ((float) rnd.nextDouble() - 0.5f) / dim); +} +} + +protected static float sigmoid(final float v, final float[] sigmoidTable) { --- End diff -- `@Nonnull` for sigmoidTable ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545257 --- Diff: docs/gitbook/embedding/word2vec.md --- @@ -0,0 +1,399 @@ + + +Word Embedding is a powerful tool for many tasks, +e.g. finding similar words, +feature vectors for supervised machine learning task and word analogy, +such as `king - man + woman =~ queen`. +In word embedding, +each word represents a low dimension and dense vector. +**Skip-Gram** and **Continuous Bag-of-words** (CBoW) are the most popular algorithms to obtain good word embeddings (a.k.a word2vec). + +The papers introduce the method are as follows: + +- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality +](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013. +- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013. + +Hivemall provides two type algorithms: Skip-gram and CBoW with negative sampling. +Hivemall enables you to train your sequence data such as, +but not limited to, documents based on word2vec. +This article gives usage instructions of the feature. + + + +> Note +> This feature is supported from Hivemall v0.5-rc.? or later. + +# Prepare document data + +Assume that you already have `docs` table which contains many documents as string format with unique index: + +```sql +select * FROM docs; +``` + +| docId | doc | +|:: |:| +| 0 | "Alice was beginning to get very tired of sitting by her sister on the bank ..." | +| ... | ... | + +First, each document is split into words by tokenize function like a [`tokenize`](../misc/tokenizer.html). + +```sql +drop table docs_words; +create table docs_words as + select +docid, +tokenize(doc, true) as words + FROM +docs +; +``` + +This table shows tokenized document. 
+ +| docId | doc | +|:: |:| +| 0 | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] | +| ... | ... | + +Then, you count frequency up per word and remove low frequency words from the vocabulary. +To remove low frequency words is optional preprocessing, but this process is effective to train word vector fastly. + +```sql +set hivevar:mincount=5; + +drop table freq; +create table freq as +select + row_number() over () - 1 as wordid, + word, + freq +from ( + select +word, +COUNT(*) as freq + from +docs_words + LATERAL VIEW explode(words) lTable as word + group by +word +) t +where freq >= ${mincount} +; +``` + +Hivemall's word2vec supports two type words; string and int. +String type tends to use huge memory during training. +On the other hand, int type tends to use less memory. +If you train on small dataset, we recommend using string type, +because memory usage can be ignored and HiveQL is more simple. +If you train on large dataset, we recommend using int type, +because it saves memory during training. + +# Create sub-sampling table + +Sub-sampling table is stored a sub-sampling probability per word. + +The sub-sampling probability of word $$w_i$$ is computed by the following equation: + +$$ +\begin{aligned} +f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)} +\end{aligned} +$$ + +During word2vec training, +not sub-sampled words are ignored. +It works to train fastly and to consider the imbalance the rare words and frequent words by reducing frequent words. +The smaller `sample` value set, +the fewer words are used during training. 
+ +```sql +set hivevar:sample=1e-4; + +drop table subsampling_table; +create table subsampling_table as +with stats as ( + select +sum(freq) as numTrainWords + FROM +freq +) +select + l.wordid, + l.word, + sqrt(${sample}/(l.freq/r.numTrainWords)) + ${sample}/(l.freq/r.numTrainWords) as p +from + freq l +cross join + stats r +; +``` + +```sql +select * FROM subsampling_table order by p; +``` + +| wordid | word | p | +|:: | :: |::| +| 48645 | the | 0.04013665| +| 11245 | of | 0.052463654| +| 16368 | and | 0.0638| +| 61938 | 00 | 0.068162076| +| 19977 | in | 0.071441144| +| 83599 | 0| 0.07528994| +| 95017 | a| 0.07559573| +| 1
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141544983 --- Diff: docs/gitbook/embedding/word2vec.md --- @@ -0,0 +1,399 @@ + + +Word Embedding is a powerful tool for many tasks, +e.g. finding similar words, +feature vectors for supervised machine learning task and word analogy, +such as `king - man + woman =~ queen`. +In word embedding, +each word represents a low dimension and dense vector. +**Skip-Gram** and **Continuous Bag-of-words** (CBoW) are the most popular algorithms to obtain good word embeddings (a.k.a word2vec). + +The papers introduce the method are as follows: + +- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality +](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013. +- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013. + +Hivemall provides two type algorithms: Skip-gram and CBoW with negative sampling. +Hivemall enables you to train your sequence data such as, +but not limited to, documents based on word2vec. +This article gives usage instructions of the feature. + + + +> Note +> This feature is supported from Hivemall v0.5-rc.? or later. + +# Prepare document data + +Assume that you already have `docs` table which contains many documents as string format with unique index: + +```sql +select * FROM docs; +``` + +| docId | doc | +|:: |:| +| 0 | "Alice was beginning to get very tired of sitting by her sister on the bank ..." | +| ... | ... | + +First, each document is split into words by tokenize function like a [`tokenize`](../misc/tokenizer.html). + +```sql +drop table docs_words; +create table docs_words as + select +docid, +tokenize(doc, true) as words + FROM +docs +; +``` + +This table shows tokenized document. 
+ +| docId | doc | +|:: |:| +| 0 | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] | +| ... | ... | + +Then, you count frequency up per word and remove low frequency words from the vocabulary. +To remove low frequency words is optional preprocessing, but this process is effective to train word vector fastly. + +```sql +set hivevar:mincount=5; + +drop table freq; +create table freq as +select + row_number() over () - 1 as wordid, + word, + freq +from ( + select +word, +COUNT(*) as freq + from +docs_words + LATERAL VIEW explode(words) lTable as word + group by +word +) t +where freq >= ${mincount} +; +``` + +Hivemall's word2vec supports two type words; string and int. +String type tends to use huge memory during training. +On the other hand, int type tends to use less memory. +If you train on small dataset, we recommend using string type, +because memory usage can be ignored and HiveQL is more simple. +If you train on large dataset, we recommend using int type, +because it saves memory during training. + +# Create sub-sampling table + +Sub-sampling table is stored a sub-sampling probability per word. + +The sub-sampling probability of word $$w_i$$ is computed by the following equation: + +$$ +\begin{aligned} +f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)} +\end{aligned} +$$ + +During word2vec training, +not sub-sampled words are ignored. +It works to train fastly and to consider the imbalance the rare words and frequent words by reducing frequent words. +The smaller `sample` value set, +the fewer words are used during training. 
+ +```sql +set hivevar:sample=1e-4; + +drop table subsampling_table; +create table subsampling_table as +with stats as ( + select +sum(freq) as numTrainWords + FROM +freq +) +select + l.wordid, + l.word, + sqrt(${sample}/(l.freq/r.numTrainWords)) + ${sample}/(l.freq/r.numTrainWords) as p +from + freq l +cross join + stats r +; +``` + +```sql +select * FROM subsampling_table order by p; +``` + +| wordid | word | p | +|:: | :: |::| +| 48645 | the | 0.04013665| +| 11245 | of | 0.052463654| +| 16368 | and | 0.0638| +| 61938 | 00 | 0.068162076| +| 19977 | in | 0.071441144| +| 83599 | 0| 0.07528994| +| 95017 | a| 0.07559573| +| 1
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543643 --- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java --- @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnull; +import java.util.List; + +public final class CBoWModel extends AbstractWord2VecModel { +protected CBoWModel(final int dim, final int win, final int neg, final int iter, +final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S, +final int[] aliasWordId) { +super(dim, win, neg, iter, startingLR, numTrainWords, S, aliasWordId); +} + +protected void trainOnDoc(@Nonnull final int[] doc) { +final int vecDim = dim; +final int numNegative = neg; +final PRNG _rnd = rnd; +final Int2FloatOpenHashTable _S = S; +final int[] _aliasWordId = aliasWordId; +float label, gradient; + +// reuse instance +int windowSize, k, numContext, targetWord, inWord, positiveWord; + +updateLearningRate(); + +int docLength = doc.length; +for (int t = 0; t < iter; t++) { +for (int positiveWordPosition = 0; positiveWordPosition < docLength; positiveWordPosition++) { +windowSize = _rnd.nextInt(win) + 1; + +numContext = windowSize * 2 + Math.min(0, positiveWordPosition - windowSize) ++ Math.min(0, docLength - positiveWordPosition - windowSize - 1); + +float[] gradVec = new float[vecDim]; --- End diff -- add `final` for `gradVec` and `averageVec`. ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141546846 --- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java --- @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnull; +import java.util.List; + +public final class CBoWModel extends AbstractWord2VecModel { +protected CBoWModel(final int dim, final int win, final int neg, final int iter, --- End diff -- add a blank line before constructor. ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545219 --- Diff: docs/gitbook/embedding/word2vec.md --- @@ -0,0 +1,399 @@ + + +Word Embedding is a powerful tool for many tasks, +e.g. finding similar words, +feature vectors for supervised machine learning task and word analogy, +such as `king - man + woman =~ queen`. +In word embedding, +each word represents a low dimension and dense vector. +**Skip-Gram** and **Continuous Bag-of-words** (CBoW) are the most popular algorithms to obtain good word embeddings (a.k.a word2vec). + +The papers introduce the method are as follows: + +- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality +](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013. +- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013. + +Hivemall provides two type algorithms: Skip-gram and CBoW with negative sampling. +Hivemall enables you to train your sequence data such as, +but not limited to, documents based on word2vec. +This article gives usage instructions of the feature. + + + +> Note +> This feature is supported from Hivemall v0.5-rc.? or later. + +# Prepare document data + +Assume that you already have `docs` table which contains many documents as string format with unique index: + +```sql +select * FROM docs; +``` + +| docId | doc | +|:: |:| +| 0 | "Alice was beginning to get very tired of sitting by her sister on the bank ..." | +| ... | ... | + +First, each document is split into words by tokenize function like a [`tokenize`](../misc/tokenizer.html). + +```sql +drop table docs_words; +create table docs_words as + select +docid, +tokenize(doc, true) as words + FROM +docs +; +``` + +This table shows tokenized document. 
+ +| docId | words | +|:---:|:---| +| 0 | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] | +| ... | ... | + +Then, count the frequency of each word and remove low-frequency words from the vocabulary. +Removing low-frequency words is an optional preprocessing step, but it makes training word vectors faster. + +```sql +set hivevar:mincount=5; + +drop table freq; +create table freq as +select + row_number() over () - 1 as wordid, + word, + freq +from ( + select +word, +COUNT(*) as freq + from +docs_words + LATERAL VIEW explode(words) lTable as word + group by +word +) t +where freq >= ${mincount} +; +``` + +Hivemall's word2vec supports two types of words: string and int. +The string type tends to use a large amount of memory during training. +On the other hand, the int type tends to use less memory. +If you train on a small dataset, we recommend using the string type, +because memory usage is negligible and the HiveQL is simpler. +If you train on a large dataset, we recommend using the int type, +because it saves memory during training. + +# Create sub-sampling table + +The sub-sampling table stores a sub-sampling probability for each word. + +The sub-sampling probability of word $$w_i$$ is computed by the following equation: + +$$ +\begin{aligned} +f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)} +\end{aligned} +$$ + +During word2vec training, +words that are not selected by sub-sampling are ignored. +This speeds up training and mitigates the imbalance between rare and frequent words by down-sampling frequent words. +The smaller the `sample` value, +the fewer words are used during training.
+ +```sql +set hivevar:sample=1e-4; + +drop table subsampling_table; +create table subsampling_table as +with stats as ( + select +sum(freq) as numTrainWords + FROM +freq +) +select + l.wordid, + l.word, + sqrt(${sample}/(l.freq/r.numTrainWords)) + ${sample}/(l.freq/r.numTrainWords) as p +from + freq l +cross join + stats r +; +``` + +```sql +select * FROM subsampling_table order by p; +``` + +| wordid | word | p | +|:: | :: |::| +| 48645 | the | 0.04013665| +| 11245 | of | 0.052463654| +| 16368 | and | 0.0638| +| 61938 | 00 | 0.068162076| +| 19977 | in | 0.071441144| +| 83599 | 0| 0.07528994| +| 95017 | a| 0.07559573| +| 1
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543757 --- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java --- @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnull; +import java.util.List; + +public final class CBoWModel extends AbstractWord2VecModel { +protected CBoWModel(final int dim, final int win, final int neg, final int iter, +final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S, +final int[] aliasWordId) { +super(dim, win, neg, iter, startingLR, numTrainWords, S, aliasWordId); +} + +protected void trainOnDoc(@Nonnull final int[] doc) { +final int vecDim = dim; +final int numNegative = neg; +final PRNG _rnd = rnd; +final Int2FloatOpenHashTable _S = S; +final int[] _aliasWordId = aliasWordId; +float label, gradient; + +// reuse instance +int windowSize, k, numContext, targetWord, inWord, positiveWord; + +updateLearningRate(); + +int docLength = doc.length; +for (int t = 0; t < iter; t++) { +for (int positiveWordPosition = 0; positiveWordPosition < docLength; positiveWordPosition++) { +windowSize = _rnd.nextInt(win) + 1; + +numContext = windowSize * 2 + Math.min(0, positiveWordPosition - windowSize) ++ Math.min(0, docLength - positiveWordPosition - windowSize - 1); + +float[] gradVec = new float[vecDim]; +float[] averageVec = new float[vecDim]; + +// collect context words +for (int contextPosition = positiveWordPosition - windowSize; contextPosition < positiveWordPosition ++ windowSize + 1; contextPosition++) { +if (contextPosition == positiveWordPosition || contextPosition < 0 +|| contextPosition >= docLength) { +continue; +} + +inWord = doc[contextPosition]; + +// average vector of input word vectors +if (!inputWeights.containsKey(inWord * vecDim)) { +initWordWeights(inWord); +} + +for (int i = 0; i < vecDim; i++) { +averageVec[i] += inputWeights.get(inWord * vecDim + i) / numContext; +} +} +positiveWord = doc[positiveWordPosition]; +// negative sampling +for (int d = 0; d < numNegative + 1; d++) { +if (d == 0) { +targetWord 
= positiveWord; +label = 1.f; +} else { +do { +k = _rnd.nextInt(_S.size()); +if (_S.get(k) > _rnd.nextDouble()) { +targetWord = k; +} else { +targetWord = _aliasWordId[k]; +} +} while (targetWord == positiveWord); +label = 0.f; +} + +gradient = grad(label, averageVec, targetWord) * lr; +for (int i = 0; i < vecDim; i++) { +gradVec[i] += gradient * contextWeights.get(targetWord * vecDim + i); +contextWeights.put(targetWord * vecDim + i, +contextWeight
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141547506 --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java --- @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.UDTFWithOptions; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.OpenHashTable; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Arrays; +import java.util.ArrayList; + +@Description( +name = "train_word2vec", +value = "_FUNC_(array> negative_table, array doc [, const string options]) - Returns a prediction model") +public class Word2VecUDTF extends UDTFWithOptions { +protected transient AbstractWord2VecModel model; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +private OpenHashTable word2index; + +@Nonnegative +private int dim; +@Nonnegative +private int win; +@Nonnegative +private int neg; +@Nonnegative +private int iter; +private boolean skipgram; +private boolean isStringInput; + +private 
Int2FloatOpenHashTable S; +private int[] aliasWordIds; + +private ListObjectInspector negativeTableOI; +private ListObjectInspector negativeTableElementListOI; +private PrimitiveObjectInspector negativeTableElementOI; + +private ListObjectInspector docOI; +private PrimitiveObjectInspector wordOI; + +@Override +public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { +final int numArgs = argOIs.length; + +if (numArgs != 3) { +throw new UDFArgumentException(getClass().getSimpleName() ++ " takes 3 arguments: [, constant string options]: " ++ Arrays.toString(argOIs)); +} + +processOptions(argOIs); + +this.negativeTableOI = HiveUtils.asListOI(argOIs[0]); +this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector()); +this.docOI = HiveUtils.asListOI(argOIs[1]); + +this.isStringInput = HiveUtils.isStringListOI(argOIs[1]); + +if (isStringInput) { +this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector()); +this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector()); +} else { +this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObj
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141547708 --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java --- @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.UDTFWithOptions; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.OpenHashTable; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Arrays; +import java.util.ArrayList; + +@Description( +name = "train_word2vec", +value = "_FUNC_(array> negative_table, array doc [, const string options]) - Returns a prediction model") +public class Word2VecUDTF extends UDTFWithOptions { +protected transient AbstractWord2VecModel model; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +private OpenHashTable word2index; + +@Nonnegative +private int dim; +@Nonnegative +private int win; +@Nonnegative +private int neg; +@Nonnegative +private int iter; +private boolean skipgram; +private boolean isStringInput; + +private 
Int2FloatOpenHashTable S; +private int[] aliasWordIds; + +private ListObjectInspector negativeTableOI; +private ListObjectInspector negativeTableElementListOI; +private PrimitiveObjectInspector negativeTableElementOI; + +private ListObjectInspector docOI; +private PrimitiveObjectInspector wordOI; + +@Override +public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { +final int numArgs = argOIs.length; + +if (numArgs != 3) { +throw new UDFArgumentException(getClass().getSimpleName() ++ " takes 3 arguments: [, constant string options]: " ++ Arrays.toString(argOIs)); +} + +processOptions(argOIs); + +this.negativeTableOI = HiveUtils.asListOI(argOIs[0]); +this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector()); +this.docOI = HiveUtils.asListOI(argOIs[1]); + +this.isStringInput = HiveUtils.isStringListOI(argOIs[1]); + +if (isStringInput) { +this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector()); +this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector()); +} else { +this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObj
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545448 --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java --- @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.UDTFWithOptions; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.OpenHashTable; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Arrays; +import java.util.ArrayList; + +@Description( +name = "train_word2vec", +value = "_FUNC_(array> negative_table, array doc [, const string options]) - Returns a prediction model") +public class Word2VecUDTF extends UDTFWithOptions { +protected transient AbstractWord2VecModel model; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +private OpenHashTable word2index; + +@Nonnegative +private int dim; +@Nonnegative +private int win; +@Nonnegative +private int neg; +@Nonnegative +private int iter; +private boolean skipgram; +private boolean isStringInput; + +private 
Int2FloatOpenHashTable S; +private int[] aliasWordIds; + +private ListObjectInspector negativeTableOI; +private ListObjectInspector negativeTableElementListOI; +private PrimitiveObjectInspector negativeTableElementOI; + +private ListObjectInspector docOI; +private PrimitiveObjectInspector wordOI; + +@Override +public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { +final int numArgs = argOIs.length; + +if (numArgs != 3) { +throw new UDFArgumentException(getClass().getSimpleName() ++ " takes 3 arguments: [, constant string options]: " ++ Arrays.toString(argOIs)); +} + +processOptions(argOIs); + +this.negativeTableOI = HiveUtils.asListOI(argOIs[0]); +this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector()); +this.docOI = HiveUtils.asListOI(argOIs[1]); + +this.isStringInput = HiveUtils.isStringListOI(argOIs[1]); + +if (isStringInput) { +this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector()); +this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector()); +} else { +this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObj
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543514 --- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java --- @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnull; +import java.util.List; + +public final class CBoWModel extends AbstractWord2VecModel { +protected CBoWModel(final int dim, final int win, final int neg, final int iter, +final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S, +final int[] aliasWordId) { +super(dim, win, neg, iter, startingLR, numTrainWords, S, aliasWordId); +} + +protected void trainOnDoc(@Nonnull final int[] doc) { +final int vecDim = dim; +final int numNegative = neg; +final PRNG _rnd = rnd; +final Int2FloatOpenHashTable _S = S; --- End diff -- Member variable should be `_S` and local variable should be `S`. `_rnd`, `_aliasWordId` as well. ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141546656 --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java --- @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; + +public abstract class AbstractWord2VecModel { +// cached sigmoid function parameters +protected static final int MAX_SIGMOID = 6; +protected static final int SIGMOID_TABLE_SIZE = 1000; +protected float[] sigmoidTable; + + +@Nonnegative +protected int dim; +protected int win; +protected int neg; +protected int iter; + +// learning rate parameters +@Nonnegative +protected float lr; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +@Nonnegative +protected long wordCount; +@Nonnegative +private long lastWordCount; + +protected PRNG rnd; + +protected Int2FloatOpenHashTable contextWeights; +protected Int2FloatOpenHashTable inputWeights; +protected Int2FloatOpenHashTable S; +protected int[] aliasWordId; + +protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter, +final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S, +final int[] aliasWordId) { +this.win = win; +this.neg = neg; +this.iter = iter; +this.dim = dim; +this.startingLR = this.lr = startingLR; +this.numTrainWords = numTrainWords; + +// alias sampler for negative sampling +this.S = S; +this.aliasWordId = aliasWordId; + +this.wordCount = 0L; +this.lastWordCount = 0L; +this.rnd = RandomNumberGeneratorFactory.createPRNG(1001); + +this.sigmoidTable = initSigmoidTable(); + +// TODO how to estimate size +this.inputWeights = new Int2FloatOpenHashTable(10578 * dim); +this.inputWeights.defaultReturnValue(0.f); +this.contextWeights = new Int2FloatOpenHashTable(10578 * dim); +this.contextWeights.defaultReturnValue(0.f); +} + +private static float[] initSigmoidTable() { --- End diff -- `@Nonnull` for return ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141545135 --- Diff: docs/gitbook/embedding/word2vec.md --- @@ -0,0 +1,399 @@ + + +Word Embedding is a powerful tool for many tasks, +e.g. finding similar words, +feature vectors for supervised machine learning tasks, and word analogy, +such as `king - man + woman =~ queen`. +In word embedding, +each word is represented as a low-dimensional dense vector. +**Skip-Gram** and **Continuous Bag-of-words** (CBoW) are the most popular algorithms to obtain good word embeddings (a.k.a. word2vec). + +The papers that introduce these methods are as follows: + +- T. Mikolov, et al., [Distributed Representations of Words and Phrases and Their Compositionality +](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf). NIPS, 2013. +- T. Mikolov, et al., [Efficient Estimation of Word Representations in Vector Space](https://arxiv.org/abs/1301.3781). ICLR, 2013. + +Hivemall provides two types of algorithms: Skip-gram and CBoW, both with negative sampling. +Hivemall enables you to train word2vec on your sequence data such as, +but not limited to, documents. +This article gives usage instructions for the feature. + + + +> Note +> This feature is supported from Hivemall v0.5-rc.? or later. + +# Prepare document data + +Assume that you already have a `docs` table that contains documents in string format, each with a unique index: + +```sql +select * FROM docs; +``` + +| docId | doc | +|:---:|:---| +| 0 | "Alice was beginning to get very tired of sitting by her sister on the bank ..." | +| ... | ... | + +First, each document is split into words by a tokenizer function such as [`tokenize`](../misc/tokenizer.html). + +```sql +drop table docs_words; +create table docs_words as + select +docid, +tokenize(doc, true) as words + FROM +docs +; +``` + +This table shows the tokenized documents.
+ +| docId | words | +|:---:|:---| +| 0 | ["alice", "was", "beginning", "to", "get", "very", "tired", "of", "sitting", "by", "her", "sister", "on", "the", "bank", ...] | +| ... | ... | + +Then, count the frequency of each word and remove low-frequency words from the vocabulary. +Removing low-frequency words is an optional preprocessing step, but it makes training word vectors faster. + +```sql +set hivevar:mincount=5; + +drop table freq; +create table freq as +select + row_number() over () - 1 as wordid, + word, + freq +from ( + select +word, +COUNT(*) as freq + from +docs_words + LATERAL VIEW explode(words) lTable as word + group by +word +) t +where freq >= ${mincount} +; +``` + +Hivemall's word2vec supports two types of words: string and int. +The string type tends to use a large amount of memory during training. +On the other hand, the int type tends to use less memory. +If you train on a small dataset, we recommend using the string type, +because memory usage is negligible and the HiveQL is simpler. +If you train on a large dataset, we recommend using the int type, +because it saves memory during training. + +# Create sub-sampling table + +The sub-sampling table stores a sub-sampling probability for each word. + +The sub-sampling probability of word $$w_i$$ is computed by the following equation: + +$$ +\begin{aligned} +f(w_i) = \sqrt{\frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)}} + \frac{\mathrm{sample}}{freq(w_i)/\sum freq(w)} +\end{aligned} +$$ + +During word2vec training, +words that are not selected by sub-sampling are ignored. +This speeds up training and mitigates the imbalance between rare and frequent words by down-sampling frequent words. +The smaller the `sample` value, +the fewer words are used during training.
+ +```sql +set hivevar:sample=1e-4; + +drop table subsampling_table; +create table subsampling_table as +with stats as ( + select +sum(freq) as numTrainWords + FROM +freq +) +select + l.wordid, + l.word, + sqrt(${sample}/(l.freq/r.numTrainWords)) + ${sample}/(l.freq/r.numTrainWords) as p +from + freq l +cross join + stats r +; +``` + +```sql +select * FROM subsampling_table order by p; +``` + +| wordid | word | p | +|:: | :: |::| +| 48645 | the | 0.04013665| +| 11245 | of | 0.052463654| +| 16368 | and | 0.0638| +| 61938 | 00 | 0.068162076| +| 19977 | in | 0.071441144| +| 83599 | 0| 0.07528994| +| 95017 | a| 0.07559573| +| 1
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543945 --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java --- @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.UDTFWithOptions; +import hivemall.utils.collections.IMapIterator; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; +import hivemall.utils.collections.maps.OpenHashTable; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; +import java.util.Arrays; +import java.util.ArrayList; + +@Description( +name = "train_word2vec", +value = "_FUNC_(array> negative_table, array doc [, const string options]) - Returns a prediction model") +public class Word2VecUDTF extends UDTFWithOptions { +protected transient AbstractWord2VecModel model; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +private OpenHashTable word2index; + +@Nonnegative +private int dim; +@Nonnegative +private int win; +@Nonnegative +private int neg; +@Nonnegative +private int iter; +private boolean skipgram; +private boolean isStringInput; + +private 
Int2FloatOpenHashTable S; +private int[] aliasWordIds; + +private ListObjectInspector negativeTableOI; +private ListObjectInspector negativeTableElementListOI; +private PrimitiveObjectInspector negativeTableElementOI; + +private ListObjectInspector docOI; +private PrimitiveObjectInspector wordOI; + +@Override +public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { +final int numArgs = argOIs.length; + +if (numArgs != 3) { +throw new UDFArgumentException(getClass().getSimpleName() ++ " takes 3 arguments: [, constant string options]: " ++ Arrays.toString(argOIs)); +} + +processOptions(argOIs); + +this.negativeTableOI = HiveUtils.asListOI(argOIs[0]); +this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector()); +this.docOI = HiveUtils.asListOI(argOIs[1]); + +this.isStringInput = HiveUtils.isStringListOI(argOIs[1]); + +if (isStringInput) { +this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector()); +this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector()); +} else { +this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObj
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141542877 --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java --- @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; + +public abstract class AbstractWord2VecModel { +// cached sigmoid function parameters +protected static final int MAX_SIGMOID = 6; +protected static final int SIGMOID_TABLE_SIZE = 1000; +protected float[] sigmoidTable; + + --- End diff -- remove unnecessary blank line ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543245 --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java --- @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; + +public abstract class AbstractWord2VecModel { +// cached sigmoid function parameters +protected static final int MAX_SIGMOID = 6; +protected static final int SIGMOID_TABLE_SIZE = 1000; +protected float[] sigmoidTable; + + +@Nonnegative +protected int dim; +protected int win; +protected int neg; +protected int iter; + +// learning rate parameters +@Nonnegative +protected float lr; +@Nonnegative +private float startingLR; +@Nonnegative +private long numTrainWords; +@Nonnegative +protected long wordCount; +@Nonnegative +private long lastWordCount; + +protected PRNG rnd; + +protected Int2FloatOpenHashTable contextWeights; +protected Int2FloatOpenHashTable inputWeights; +protected Int2FloatOpenHashTable S; +protected int[] aliasWordId; + +protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter, +final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S, +final int[] aliasWordId) { +this.win = win; +this.neg = neg; +this.iter = iter; +this.dim = dim; +this.startingLR = this.lr = startingLR; +this.numTrainWords = numTrainWords; + +// alias sampler for negative sampling +this.S = S; +this.aliasWordId = aliasWordId; + +this.wordCount = 0L; +this.lastWordCount = 0L; +this.rnd = RandomNumberGeneratorFactory.createPRNG(1001); + +this.sigmoidTable = initSigmoidTable(); + +// TODO how to estimate size +this.inputWeights = new Int2FloatOpenHashTable(10578 * dim); +this.inputWeights.defaultReturnValue(0.f); +this.contextWeights = new Int2FloatOpenHashTable(10578 * dim); +this.contextWeights.defaultReturnValue(0.f); +} + +private static float[] initSigmoidTable() { +float[] sigmoidTable = new float[SIGMOID_TABLE_SIZE]; +for (int i = 
0; i < SIGMOID_TABLE_SIZE; i++) { +float x = ((float) i / SIGMOID_TABLE_SIZE * 2 - 1) * (float) MAX_SIGMOID; +sigmoidTable[i] = 1.f / ((float) Math.exp(-x) + 1.f); +} +return sigmoidTable; +} + +protected void initWordWeights(final int wordId) { +for (int i = 0; i < dim; i++) { +inputWeights.put(wordId * dim + i, ((float) rnd.nextDouble() - 0.5f) / dim); +} +} + +protected static float sigmoid(final float v, final float[] sigmoidTable) { +if (v > MAX_SIGMOID) { +return 1.f; +} else if (v < -MAX_SIGMOID) { +return 0.f; +} else { +return sigmoidTable[(int) ((v + MAX_SIGMOID) * (SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2))]; +} +} + +protected void updateLearningRate() { +// TODO: valid lr? --- End diff -- remove this TODO comment and blank lines. ---
[GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Github user myui commented on a diff in the pull request: https://github.com/apache/incubator-hivemall/pull/116#discussion_r141542968 --- Diff: core/src/main/java/hivemall/embedding/AbstractWord2VecModel.java --- @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.embedding; + +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.utils.collections.maps.Int2FloatOpenHashTable; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.util.List; + +public abstract class AbstractWord2VecModel { +// cached sigmoid function parameters +protected static final int MAX_SIGMOID = 6; +protected static final int SIGMOID_TABLE_SIZE = 1000; +protected float[] sigmoidTable; + + +@Nonnegative +protected int dim; +protected int win; --- End diff -- `@Nonnegative` for each variable (win, neg, iter). ---