OPENNLP-986 - Stupid Backoff as default LM discounting
Project: http://git-wip-us.apache.org/repos/asf/opennlp/repo Commit: http://git-wip-us.apache.org/repos/asf/opennlp/commit/b1316eb4 Tree: http://git-wip-us.apache.org/repos/asf/opennlp/tree/b1316eb4 Diff: http://git-wip-us.apache.org/repos/asf/opennlp/diff/b1316eb4 Branch: refs/heads/parser_regression Commit: b1316eb4b488479db6e4dc5f4bb6606cb07dca08 Parents: 89e6af0 Author: Tommaso Teofili <[email protected]> Authored: Tue Feb 14 14:49:09 2017 +0100 Committer: Jörn Kottmann <[email protected]> Committed: Thu Apr 20 12:40:20 2017 +0200 ---------------------------------------------------------------------- .../tools/languagemodel/NGramLanguageModel.java | 74 +++++--------------- .../java/opennlp/tools/ngram/NGramUtils.java | 3 +- .../LanguageModelEvaluationTest.java | 2 +- .../languagemodel/NgramLanguageModelTest.java | 15 ++-- 4 files changed, 28 insertions(+), 66 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/opennlp/blob/b1316eb4/opennlp-tools/src/main/java/opennlp/tools/languagemodel/NGramLanguageModel.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/languagemodel/NGramLanguageModel.java b/opennlp-tools/src/main/java/opennlp/tools/languagemodel/NGramLanguageModel.java index e11c107..501c1bc 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/languagemodel/NGramLanguageModel.java +++ b/opennlp-tools/src/main/java/opennlp/tools/languagemodel/NGramLanguageModel.java @@ -26,52 +26,30 @@ import opennlp.tools.util.StringList; /** * A {@link opennlp.tools.languagemodel.LanguageModel} based on a {@link opennlp.tools.ngram.NGramModel} - * using Laplace smoothing probability estimation to get the probabilities of the ngrams. - * See also {@link NGramUtils#calculateLaplaceSmoothingProbability( - *opennlp.tools.util.StringList, Iterable, int, Double)}. + * using Stupid Backoff to get the probabilities of the ngrams. */ public class NGramLanguageModel extends NGramModel implements LanguageModel { private static final int DEFAULT_N = 3; - private static final double DEFAULT_K = 1d; private final int n; - private final double k; public NGramLanguageModel() { - this(DEFAULT_N, DEFAULT_K); + this(DEFAULT_N); } public NGramLanguageModel(int n) { - this(n, DEFAULT_K); - } - - public NGramLanguageModel(double k) { - this(DEFAULT_N, k); - } - - public NGramLanguageModel(int n, double k) { this.n = n; - this.k = k; } public NGramLanguageModel(InputStream in) throws IOException { - this(in, DEFAULT_N, DEFAULT_K); - } - - public NGramLanguageModel(InputStream in, double k) throws IOException { - this(in, DEFAULT_N, k); - } - - public NGramLanguageModel(InputStream in, int n) throws IOException { - this(in, n, DEFAULT_K); + this(in, DEFAULT_N); } - public NGramLanguageModel(InputStream in, int n, double k) + public NGramLanguageModel(InputStream in, int n) throws IOException { super(in); this.n = n; - this.k = k; } @Override @@ -79,24 +57,13 @@ public class NGramLanguageModel extends NGramModel implements LanguageModel { double probability = 0d; if (size() > 0) { for (StringList ngram : NGramUtils.getNGrams(sample, n)) { - StringList nMinusOneToken = NGramUtils - .getNMinusOneTokenFirst(ngram); - if (size() > 1000000) { - // use stupid backoff - probability += Math.log( - getStupidBackoffProbability(ngram, nMinusOneToken)); - } else { - // use laplace smoothing - probability += Math.log( - getLaplaceSmoothingProbability(ngram, nMinusOneToken)); + double score = stupidBackoff(ngram); + probability += Math.log(score); + if (Double.isNaN(probability)) { + probability = 0d; } } - if (Double.isNaN(probability)) { - probability = 0d; - } else if (probability != 0) { - probability = Math.exp(probability); - } - + probability = Math.exp(probability); } return probability; } @@ -125,24 +92,21 @@ public class NGramLanguageModel extends NGramModel implements LanguageModel { return token; } - private double getLaplaceSmoothingProbability(StringList ngram, - StringList nMinusOneToken) { - return (getCount(ngram) + k) / (getCount(nMinusOneToken) + k * size()); - } - - private double getStupidBackoffProbability(StringList ngram, - StringList nMinusOneToken) { + private double stupidBackoff(StringList ngram) { int count = getCount(ngram); + StringList nMinusOneToken = NGramUtils.getNMinusOneTokenFirst(ngram); if (nMinusOneToken == null || nMinusOneToken.size() == 0) { - return count / size(); + return (double) count / (double) size(); } else if (count > 0) { - return ((double) count) / ((double) getCount( - nMinusOneToken)); // maximum likelihood probability + double countM1 = getCount(nMinusOneToken); + if (countM1 == 0d) { + countM1 = size(); // to avoid Infinite if n-1grams do not exist + } + return (double) count / countM1; } else { - StringList nextNgram = NGramUtils.getNMinusOneTokenLast(ngram); - return 0.4d * getStupidBackoffProbability(nextNgram, - NGramUtils.getNMinusOneTokenFirst(nextNgram)); + return 0.4 * stupidBackoff(NGramUtils.getNMinusOneTokenLast(ngram)); } + } } http://git-wip-us.apache.org/repos/asf/opennlp/blob/b1316eb4/opennlp-tools/src/main/java/opennlp/tools/ngram/NGramUtils.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/ngram/NGramUtils.java b/opennlp-tools/src/main/java/opennlp/tools/ngram/NGramUtils.java index 0132c92..e41291f 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/ngram/NGramUtils.java +++ b/opennlp-tools/src/main/java/opennlp/tools/ngram/NGramUtils.java @@ -34,13 +34,12 @@ public class NGramUtils { * * @param ngram the ngram to get the probability for * @param set the vocabulary - * @param size the size of the vocabulary * @param k the smoothing factor * @return the Laplace smoothing probability * @see <a href="https://en.wikipedia.org/wiki/Additive_smoothing">Additive Smoothing</a> */ public static double calculateLaplaceSmoothingProbability(StringList ngram, - Iterable<StringList> set, int size, Double k) { + Iterable<StringList> set, Double k) { return (count(ngram, set) + k) / (count(getNMinusOneTokenFirst(ngram), set) + k * 1); } http://git-wip-us.apache.org/repos/asf/opennlp/blob/b1316eb4/opennlp-tools/src/test/java/opennlp/tools/languagemodel/LanguageModelEvaluationTest.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/test/java/opennlp/tools/languagemodel/LanguageModelEvaluationTest.java b/opennlp-tools/src/test/java/opennlp/tools/languagemodel/LanguageModelEvaluationTest.java index b6c3f01..d4e8e37 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/languagemodel/LanguageModelEvaluationTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/languagemodel/LanguageModelEvaluationTest.java @@ -54,7 +54,7 @@ public class LanguageModelEvaluationTest { NGramLanguageModel trigramLM = new NGramLanguageModel(3); for (StringList sentence : trainingVocabulary) { - trigramLM.add(sentence, 2, 3); + trigramLM.add(sentence, 1, 3); } double trigramPerplexity = LanguageModelTestUtils.getPerplexity(trigramLM, testVocabulary, 3); http://git-wip-us.apache.org/repos/asf/opennlp/blob/b1316eb4/opennlp-tools/src/test/java/opennlp/tools/languagemodel/NgramLanguageModelTest.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/test/java/opennlp/tools/languagemodel/NgramLanguageModelTest.java b/opennlp-tools/src/test/java/opennlp/tools/languagemodel/NgramLanguageModelTest.java index 7ffbf27..2ac1f5e 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/languagemodel/NgramLanguageModelTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/languagemodel/NgramLanguageModelTest.java @@ -22,7 +22,6 @@ import java.util.Arrays; import java.util.List; import org.apache.commons.io.IOUtils; - import org.junit.Assert; import org.junit.Test; @@ -47,7 +46,7 @@ public class NgramLanguageModelTest { public void testRandomVocabularyAndSentence() throws Exception { NGramLanguageModel model = new NGramLanguageModel(); for (StringList sentence : LanguageModelTestUtils.generateRandomVocabulary(10)) { - model.add(sentence, 2, 3); + model.add(sentence, 1, 3); } double probability = model.calculateProbability(LanguageModelTestUtils.generateRandomSentence()); Assert.assertTrue("a probability measure should be between 0 and 1 [was " @@ -71,7 +70,7 @@ public class NgramLanguageModelTest { @Test public void testBigramProbabilityNoSmoothing() throws Exception { - NGramLanguageModel model = new NGramLanguageModel(2, 0); + NGramLanguageModel model = new NGramLanguageModel(2); model.add(new StringList("<s>", "I", "am", "Sam", "</s>"), 1, 2); model.add(new StringList("<s>", "Sam", "I", "am", "</s>"), 1, 2); model.add(new StringList("<s>", "I", "do", "not", "like", "green", "eggs", "and", "ham", "</s>"), 1, 2); @@ -94,16 +93,16 @@ public class NgramLanguageModelTest { @Test public void testTrigram() throws Exception { NGramLanguageModel model = new NGramLanguageModel(3); - model.add(new StringList("I", "see", "the", "fox"), 2, 3); - model.add(new StringList("the", "red", "house"), 2, 3); - model.add(new StringList("I", "saw", "something", "nice"), 2, 3); + model.add(new StringList("I", "see", "the", "fox"), 1, 3); + model.add(new StringList("the", "red", "house"), 1, 3); + model.add(new StringList("I", "saw", "something", "nice"), 1, 3); double probability = model.calculateProbability(new StringList("I", "saw", "the", "red", "house")); Assert.assertTrue("a probability measure should be between 0 and 1 [was " + probability + "]", probability >= 0 && probability <= 1); StringList tokens = model.predictNextTokens(new StringList("I", "saw")); Assert.assertNotNull(tokens); - Assert.assertEquals(new StringList("something", "nice"), tokens); + Assert.assertEquals(new StringList("something"), tokens); } @Test @@ -128,7 +127,7 @@ public class NgramLanguageModelTest { double probability = languageModel.calculateProbability(new StringList("The", "brown", "fox", "jumped")); Assert.assertTrue("a probability measure should be between 0 and 1 [was " + probability + "]", probability >= 0 && probability <= 1); - StringList tokens = languageModel.predictNextTokens(new StringList("fox")); + StringList tokens = languageModel.predictNextTokens(new StringList("the","brown","fox")); Assert.assertNotNull(tokens); Assert.assertEquals(new StringList("jumped"), tokens); }
