OPENNLP-125: Make POS Tagger feature generation configurable
Project: http://git-wip-us.apache.org/repos/asf/opennlp/repo Commit: http://git-wip-us.apache.org/repos/asf/opennlp/commit/e06a5327 Tree: http://git-wip-us.apache.org/repos/asf/opennlp/tree/e06a5327 Diff: http://git-wip-us.apache.org/repos/asf/opennlp/diff/e06a5327 Branch: refs/heads/parser_regression Commit: e06a5327e433a1aa72dda073d418adf825e719f5 Parents: 98a0425 Author: Jörn Kottmann <[email protected]> Authored: Thu Feb 9 18:54:27 2017 +0100 Committer: Jörn Kottmann <[email protected]> Committed: Thu Apr 20 12:40:23 2017 +0200 ---------------------------------------------------------------------- .../namefind/TokenNameFinderTrainerTool.java | 2 +- .../postag/POSTaggerCrossValidatorTool.java | 10 +- .../cmdline/postag/POSTaggerTrainerTool.java | 26 +-- .../tools/cmdline/postag/TrainingParams.java | 13 +- .../postag/ConfigurablePOSContextGenerator.java | 105 +++++++++++ .../opennlp/tools/postag/POSDictionary.java | 8 +- .../java/opennlp/tools/postag/POSModel.java | 40 +++-- .../tools/postag/POSTaggerCrossValidator.java | 44 ++--- .../opennlp/tools/postag/POSTaggerFactory.java | 179 ++++++++++++++++++- .../tools/util/featuregen/GeneratorFactory.java | 12 ++ .../featuregen/PosTaggerFeatureGenerator.java | 62 +++++++ .../tools/postag/pos-default-features.xml | 38 ++++ .../ConfigurablePOSContextGeneratorTest.java | 55 ++++++ .../tools/postag/DummyPOSTaggerFactory.java | 14 +- .../tools/postag/POSTaggerFactoryTest.java | 11 +- 15 files changed, 534 insertions(+), 85 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/main/java/opennlp/tools/cmdline/namefind/TokenNameFinderTrainerTool.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/cmdline/namefind/TokenNameFinderTrainerTool.java b/opennlp-tools/src/main/java/opennlp/tools/cmdline/namefind/TokenNameFinderTrainerTool.java index 5bb18d2..4fb8cb9 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/cmdline/namefind/TokenNameFinderTrainerTool.java +++ b/opennlp-tools/src/main/java/opennlp/tools/cmdline/namefind/TokenNameFinderTrainerTool.java @@ -67,7 +67,7 @@ public final class TokenNameFinderTrainerTool return null; } - static byte[] openFeatureGeneratorBytes(File featureGenDescriptorFile) { + public static byte[] openFeatureGeneratorBytes(File featureGenDescriptorFile) { byte[] featureGeneratorBytes = null; // load descriptor file into memory if (featureGenDescriptorFile != null) { http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java b/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java index d91d4ee..67ad2b9 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java +++ b/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerCrossValidatorTool.java @@ -22,10 +22,12 @@ import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; +import java.util.Map; import opennlp.tools.cmdline.AbstractCrossValidatorTool; import opennlp.tools.cmdline.CmdLineUtil; import opennlp.tools.cmdline.TerminateToolException; +import opennlp.tools.cmdline.namefind.TokenNameFinderTrainerTool; import opennlp.tools.cmdline.params.CVParams; import opennlp.tools.cmdline.params.FineGrainedEvaluatorParams; import opennlp.tools.cmdline.postag.POSTaggerCrossValidatorTool.CVToolParams; @@ -75,10 +77,16 @@ public final class POSTaggerCrossValidatorTool } } + Map<String, Object> resources = TokenNameFinderTrainerTool.loadResources( + params.getResources(), params.getFeaturegen()); + + byte[] featureGeneratorBytes = + TokenNameFinderTrainerTool.openFeatureGeneratorBytes(params.getFeaturegen()); + POSTaggerCrossValidator validator; try { validator = new POSTaggerCrossValidator(params.getLang(), mlParams, - params.getDict(), params.getNgram(), params.getTagDictCutoff(), + params.getDict(), featureGeneratorBytes, resources, params.getTagDictCutoff(), params.getFactory(), missclassifiedListener, reportListener); validator.evaluate(sampleStream, params.getFolds()); http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java b/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java index 1e6fb54..b922176 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java +++ b/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/POSTaggerTrainerTool.java @@ -19,13 +19,14 @@ package opennlp.tools.cmdline.postag; import java.io.File; import java.io.IOException; +import java.util.Map; import opennlp.tools.cmdline.AbstractTrainerTool; import opennlp.tools.cmdline.CmdLineUtil; import opennlp.tools.cmdline.TerminateToolException; +import opennlp.tools.cmdline.namefind.TokenNameFinderTrainerTool; import opennlp.tools.cmdline.params.TrainingToolParams; import opennlp.tools.cmdline.postag.POSTaggerTrainerTool.TrainerToolParams; -import opennlp.tools.dictionary.Dictionary; import opennlp.tools.ml.TrainerFactory; import opennlp.tools.postag.MutableTagDictionary; import opennlp.tools.postag.POSModel; @@ -66,25 +67,16 @@ public final class POSTaggerTrainerTool File modelOutFile = params.getModel(); CmdLineUtil.checkOutputFile("pos tagger model", modelOutFile); - Dictionary ngramDict = null; + Map<String, Object> resources = TokenNameFinderTrainerTool.loadResources( + params.getResources(), params.getFeaturegen()); - Integer ngramCutoff = params.getNgram(); - - if (ngramCutoff != null) { - System.err.print("Building ngram dictionary ... "); - try { - ngramDict = POSTaggerME.buildNGramDictionary(sampleStream, ngramCutoff); - sampleStream.reset(); - } catch (IOException e) { - throw new TerminateToolException(-1, - "IO error while building NGram Dictionary: " + e.getMessage(), e); - } - System.err.println("done"); - } + byte[] featureGeneratorBytes = + TokenNameFinderTrainerTool.openFeatureGeneratorBytes(params.getFeaturegen()); POSTaggerFactory postaggerFactory; try { - postaggerFactory = POSTaggerFactory.create(params.getFactory(), ngramDict, null); + postaggerFactory = POSTaggerFactory.create(params.getFactory(), featureGeneratorBytes, + resources, null); } catch (InvalidFormatException e) { throw new TerminateToolException(-1, e.getMessage(), e); } @@ -95,7 +87,7 @@ public final class POSTaggerTrainerTool .createTagDictionary(params.getDict())); } catch (IOException e) { throw new TerminateToolException(-1, - "IO error while loading POS Dictionary: " + e.getMessage(), e); + "IO error while loading POS Dictionary", e); } } http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/TrainingParams.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/TrainingParams.java b/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/TrainingParams.java index 690b359..31d5e48 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/TrainingParams.java +++ b/opennlp-tools/src/main/java/opennlp/tools/cmdline/postag/TrainingParams.java @@ -29,14 +29,17 @@ import opennlp.tools.cmdline.params.BasicTrainingParams; * Note: Do not use this class, internal use only! */ interface TrainingParams extends BasicTrainingParams { - @ParameterDescription(valueName = "dictionaryPath", description = "The XML tag dictionary file") + @ParameterDescription(valueName = "featuregenFile", description = "The feature generator descriptor file") @OptionalParameter - File getDict(); + File getFeaturegen(); + + @ParameterDescription(valueName = "resourcesDir", description = "The resources directory") + @OptionalParameter + File getResources(); - @ParameterDescription(valueName = "cutoff", - description = "NGram cutoff. If not specified will not create ngram dictionary.") + @ParameterDescription(valueName = "dictionaryPath", description = "The XML tag dictionary file") @OptionalParameter - Integer getNgram(); + File getDict(); @ParameterDescription(valueName = "tagDictCutoff", description = "TagDictionary cutoff. If specified will create/expand a mutable TagDictionary") http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/main/java/opennlp/tools/postag/ConfigurablePOSContextGenerator.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/ConfigurablePOSContextGenerator.java b/opennlp-tools/src/main/java/opennlp/tools/postag/ConfigurablePOSContextGenerator.java new file mode 100644 index 0000000..e6b65df --- /dev/null +++ b/opennlp-tools/src/main/java/opennlp/tools/postag/ConfigurablePOSContextGenerator.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.postag; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import opennlp.tools.util.Cache; +import opennlp.tools.util.featuregen.AdaptiveFeatureGenerator; + +/** + * A context generator for the POS Tagger. + */ +public class ConfigurablePOSContextGenerator implements POSContextGenerator { + + private Cache<String, String[]> contextsCache; + private Object wordsKey; + + private final AdaptiveFeatureGenerator featureGenerator; + + /** + * Initializes the current instance. + * + * @param cacheSize + */ + public ConfigurablePOSContextGenerator(int cacheSize, AdaptiveFeatureGenerator featureGenerator) { + this.featureGenerator = Objects.requireNonNull(featureGenerator, "featureGenerator must not be null"); + + if (cacheSize > 0) { + contextsCache = new Cache<>(cacheSize); + } + } + + /** + * Initializes the current instance. + * + */ + public ConfigurablePOSContextGenerator(AdaptiveFeatureGenerator featureGenerator) { + this(0, featureGenerator); + } + + /** + * Returns the context for making a pos tag decision at the specified token index + * given the specified tokens and previous tags. + * @param index The index of the token for which the context is provided. + * @param tokens The tokens in the sentence. + * @param tags The tags assigned to the previous words in the sentence. + * @return The context for making a pos tag decision at the specified token index + * given the specified tokens and previous tags. + */ + public String[] getContext(int index, String[] tokens, String[] tags, + Object[] additionalContext) { + + String tagprev = null; + String tagprevprev = null; + + if (index - 1 >= 0) { + tagprev = tags[index - 1]; + + if (index - 2 >= 0) { + tagprevprev = tags[index - 2]; + } + } + + String cacheKey = index + tagprev + tagprevprev; + if (contextsCache != null) { + if (wordsKey == tokens) { + String[] cachedContexts = contextsCache.get(cacheKey); + if (cachedContexts != null) { + return cachedContexts; + } + } + else { + contextsCache.clear(); + wordsKey = tokens; + } + } + + List<String> e = new ArrayList<>(); + + featureGenerator.createFeatures(e, tokens, index, tags); + + String[] contexts = e.toArray(new String[e.size()]); + if (contextsCache != null) { + contextsCache.put(cacheKey, contexts); + } + return contexts; + } +} http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/main/java/opennlp/tools/postag/POSDictionary.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSDictionary.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSDictionary.java index 5f5eb25..90d51c1 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSDictionary.java +++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSDictionary.java @@ -32,12 +32,13 @@ import opennlp.tools.dictionary.serializer.Entry; import opennlp.tools.util.InvalidFormatException; import opennlp.tools.util.StringList; import opennlp.tools.util.StringUtil; +import opennlp.tools.util.model.SerializableArtifact; /** * Provides a means of determining which tags are valid for a particular word * based on a tag dictionary read from a file. */ -public class POSDictionary implements Iterable<String>, MutableTagDictionary { +public class POSDictionary implements Iterable<String>, MutableTagDictionary, SerializableArtifact { private Map<String, String[]> dictionary; @@ -265,4 +266,9 @@ public class POSDictionary implements Iterable<String>, MutableTagDictionary { public boolean isCaseSensitive() { return this.caseSensitive; } + + @Override + public Class<?> getArtifactSerializerClass() { + return POSTaggerFactory.POSDictionarySerializer.class; + } } http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java index bfe5c90..f81092b 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java +++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSModel.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.io.InputStream; import java.net.URL; import java.util.Map; +import java.util.Objects; import java.util.Properties; import opennlp.tools.dictionary.Dictionary; @@ -32,6 +33,7 @@ import opennlp.tools.util.BaseToolFactory; import opennlp.tools.util.InvalidFormatException; import opennlp.tools.util.model.ArtifactSerializer; import opennlp.tools.util.model.BaseModel; +import opennlp.tools.util.model.ByteArraySerializer; /** * The {@link POSModel} is the model used @@ -42,18 +44,23 @@ import opennlp.tools.util.model.BaseModel; public final class POSModel extends BaseModel { private static final String COMPONENT_NAME = "POSTaggerME"; - static final String POS_MODEL_ENTRY_NAME = "pos.model"; + static final String GENERATOR_DESCRIPTOR_ENTRY_NAME = "generator.featuregen"; public POSModel(String languageCode, SequenceClassificationModel<String> posModel, Map<String, String> manifestInfoEntries, POSTaggerFactory posFactory) { super(COMPONENT_NAME, languageCode, manifestInfoEntries, posFactory); - if (posModel == null) - throw new IllegalArgumentException("The maxentPosModel param must not be null!"); + artifactMap.put(POS_MODEL_ENTRY_NAME, + Objects.requireNonNull(posModel, "posModel must not be null")); + + artifactMap.put(GENERATOR_DESCRIPTOR_ENTRY_NAME, posFactory.getFeatureGenerator()); + + for (Map.Entry<String, Object> resource : posFactory.getResources().entrySet()) { + artifactMap.put(resource.getKey(), resource.getValue()); + } - artifactMap.put(POS_MODEL_ENTRY_NAME, posModel); // TODO: This fails probably for the sequence model ... ?! // checkArtifactMap(); } @@ -68,13 +75,18 @@ public final class POSModel extends BaseModel { super(COMPONENT_NAME, languageCode, manifestInfoEntries, posFactory); - if (posModel == null) - throw new IllegalArgumentException("The maxentPosModel param must not be null!"); + Objects.requireNonNull(posModel, "posModel must not be null"); Properties manifest = (Properties) artifactMap.get(MANIFEST_ENTRY); manifest.setProperty(BeamSearch.BEAM_SIZE_PARAMETER, Integer.toString(beamSize)); artifactMap.put(POS_MODEL_ENTRY_NAME, posModel); + artifactMap.put(GENERATOR_DESCRIPTOR_ENTRY_NAME, posFactory.getFeatureGenerator()); + + for (Map.Entry<String, Object> resource : posFactory.getResources().entrySet()) { + artifactMap.put(resource.getKey(), resource.getValue()); + } + checkArtifactMap(); } @@ -96,14 +108,6 @@ public final class POSModel extends BaseModel { } @Override - @SuppressWarnings("rawtypes") - protected void createArtifactSerializers( - Map<String, ArtifactSerializer> serializers) { - - super.createArtifactSerializers(serializers); - } - - @Override protected void validateArtifactMap() throws InvalidFormatException { super.validateArtifactMap(); @@ -114,6 +118,7 @@ public final class POSModel extends BaseModel { /** * @deprecated use getPosSequenceModel instead. This method will be removed soon. + * Only required for Parser 1.5.x backward compatibility. Newer models don't need this anymore. */ @Deprecated public MaxentModel getPosModel() { @@ -151,6 +156,13 @@ public final class POSModel extends BaseModel { return (POSTaggerFactory) this.toolFactory; } + @Override + protected void createArtifactSerializers(Map<String, ArtifactSerializer> serializers) { + super.createArtifactSerializers(serializers); + + serializers.put("featuregen", new ByteArraySerializer()); + } + /** * Retrieves the ngram dictionary. * http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java index 3010e03..a35bbb6 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java +++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java @@ -19,6 +19,7 @@ package opennlp.tools.postag; import java.io.File; import java.io.IOException; +import java.util.Map; import opennlp.tools.dictionary.Dictionary; import opennlp.tools.util.ObjectStream; @@ -32,7 +33,8 @@ public class POSTaggerCrossValidator { private final TrainingParameters params; - private Integer ngramCutoff; + private byte[] featureGeneratorBytes; + private Map<String, Object> resources; private Mean wordAccuracy = new Mean(); private POSTaggerEvaluationMonitor[] listeners; @@ -51,18 +53,21 @@ public class POSTaggerCrossValidator { * the tag and the ngram dictionaries. */ public POSTaggerCrossValidator(String languageCode, - TrainingParameters trainParam, File tagDictionary, - Integer ngramCutoff, Integer tagdicCutoff, String factoryClass, - POSTaggerEvaluationMonitor... listeners) { + TrainingParameters trainParam, File tagDictionary, + byte[] featureGeneratorBytes, Map<String, Object> resources, + Integer tagdicCutoff, String factoryClass, + POSTaggerEvaluationMonitor... listeners) { this.languageCode = languageCode; this.params = trainParam; - this.ngramCutoff = ngramCutoff; + this.featureGeneratorBytes = featureGeneratorBytes; + this.resources = resources; this.listeners = listeners; this.factoryClassName = factoryClass; this.tagdicCutoff = tagdicCutoff; this.tagDictionaryFile = tagDictionary; } + /** * Creates a {@link POSTaggerCrossValidator} using the given * {@link POSTaggerFactory}. @@ -74,7 +79,6 @@ public class POSTaggerCrossValidator { this.params = trainParam; this.listeners = listeners; this.factory = factory; - this.ngramCutoff = null; this.tagdicCutoff = null; } @@ -98,33 +102,18 @@ public class POSTaggerCrossValidator { CrossValidationPartitioner.TrainingSampleStream<POSSample> trainingSampleStream = partitioner .next(); - if (this.factory == null) { - this.factory = POSTaggerFactory.create(this.factoryClassName, null, - null); - } - - Dictionary ngramDict = this.factory.getDictionary(); - if (ngramDict == null) { - if (this.ngramCutoff != null) { - System.err.print("Building ngram dictionary ... "); - ngramDict = POSTaggerME.buildNGramDictionary(trainingSampleStream, - this.ngramCutoff); - trainingSampleStream.reset(); - System.err.println("done"); - } - this.factory.setDictionary(ngramDict); - } if (this.tagDictionaryFile != null && this.factory.getTagDictionary() == null) { this.factory.setTagDictionary(this.factory .createTagDictionary(tagDictionaryFile)); } + + TagDictionary dict = null; if (this.tagdicCutoff != null) { - TagDictionary dict = this.factory.getTagDictionary(); + dict = this.factory.getTagDictionary(); if (dict == null) { dict = this.factory.createEmptyTagDictionary(); - this.factory.setTagDictionary(dict); } if (dict instanceof MutableTagDictionary) { POSTaggerME.populatePOSDictionary(trainingSampleStream, (MutableTagDictionary)dict, @@ -136,6 +125,12 @@ public class POSTaggerCrossValidator { trainingSampleStream.reset(); } + if (this.factory == null) { + this.factory = POSTaggerFactory.create(this.factoryClassName, null, null); + } + + factory.init(featureGeneratorBytes, resources, dict); + POSModel model = POSTaggerME.train(languageCode, trainingSampleStream, params, this.factory); @@ -148,7 +143,6 @@ public class POSTaggerCrossValidator { if (this.tagdicCutoff != null) { this.factory.setTagDictionary(null); } - } } http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java index eb5466e..37143c9 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java +++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerFactory.java @@ -17,6 +17,8 @@ package opennlp.tools.postag; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileInputStream; import java.io.IOException; @@ -29,10 +31,15 @@ import java.util.Set; import opennlp.tools.dictionary.Dictionary; import opennlp.tools.ml.model.AbstractModel; +import opennlp.tools.namefind.TokenNameFinderFactory; import opennlp.tools.util.BaseToolFactory; import opennlp.tools.util.InvalidFormatException; import opennlp.tools.util.SequenceValidator; +import opennlp.tools.util.Version; import opennlp.tools.util.ext.ExtensionLoader; +import opennlp.tools.util.featuregen.AdaptiveFeatureGenerator; +import opennlp.tools.util.featuregen.AggregatedFeatureGenerator; +import opennlp.tools.util.featuregen.GeneratorFactory; import opennlp.tools.util.model.ArtifactSerializer; import opennlp.tools.util.model.UncloseableInputStream; @@ -44,7 +51,10 @@ public class POSTaggerFactory extends BaseToolFactory { private static final String TAG_DICTIONARY_ENTRY_NAME = "tags.tagdict"; private static final String NGRAM_DICTIONARY_ENTRY_NAME = "ngram.dictionary"; + protected Dictionary ngramDictionary; + private byte[] featureGeneratorBytes; + private Map<String, Object> resources; protected TagDictionary posDictionary; /** @@ -60,23 +70,127 @@ public class POSTaggerFactory extends BaseToolFactory { * * @param ngramDictionary * @param posDictionary + * + * @deprecated this constructor is here for backward compatibility and + * is not functional anymore in the training of 1.8.x series models */ - public POSTaggerFactory(Dictionary ngramDictionary, - TagDictionary posDictionary) { + @Deprecated + public POSTaggerFactory(Dictionary ngramDictionary, TagDictionary posDictionary) { this.init(ngramDictionary, posDictionary); + + // TODO: This could be made functional by creating some default feature generation + // which uses the dictionary ... + } + + public POSTaggerFactory(byte[] featureGeneratorBytes, final Map<String, Object> resources, + TagDictionary posDictionary) { + this.featureGeneratorBytes = featureGeneratorBytes; + + if (this.featureGeneratorBytes == null) { + this.featureGeneratorBytes = loadDefaultFeatureGeneratorBytes(); + } + + this.resources = resources; + this.posDictionary = posDictionary; } + @Deprecated // will be removed when only 8 series models are supported protected void init(Dictionary ngramDictionary, TagDictionary posDictionary) { this.ngramDictionary = ngramDictionary; this.posDictionary = posDictionary; } + protected void init(byte[] featureGeneratorBytes, final Map<String, Object> resources, + TagDictionary posDictionary) { + this.featureGeneratorBytes = featureGeneratorBytes; + this.resources = resources; + this.posDictionary = posDictionary; + } + private static byte[] loadDefaultFeatureGeneratorBytes() { + + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + try (InputStream in = TokenNameFinderFactory.class.getResourceAsStream( + "/opennlp/tools/postag/pos-default-features.xml")) { + + if (in == null) { + throw new IllegalStateException("Classpath must contain pos-default-features.xml file!"); + } + + byte[] buf = new byte[1024]; + int len; + while ((len = in.read(buf)) > 0) { + bytes.write(buf, 0, len); + } + } + catch (IOException e) { + throw new IllegalStateException("Failed reading from pos-default-features.xml file on classpath!"); + } + + return bytes.toByteArray(); + } + + /** + * Creates the {@link AdaptiveFeatureGenerator}. Usually this + * is a set of generators contained in the {@link AggregatedFeatureGenerator}. + * + * Note: + * The generators are created on every call to this method. + * + * @return the feature generator or null if there is no descriptor in the model + */ + public AdaptiveFeatureGenerator createFeatureGenerators() { + + if (featureGeneratorBytes == null && artifactProvider != null) { + featureGeneratorBytes = artifactProvider.getArtifact( + POSModel.GENERATOR_DESCRIPTOR_ENTRY_NAME); + } + + if (featureGeneratorBytes == null) { + featureGeneratorBytes = loadDefaultFeatureGeneratorBytes(); + } + + InputStream descriptorIn = new ByteArrayInputStream(featureGeneratorBytes); + + AdaptiveFeatureGenerator generator; + try { + generator = GeneratorFactory.create(descriptorIn, key -> { + if (artifactProvider != null) { + return artifactProvider.getArtifact(key); + } + else { + return resources.get(key); + } + }); + } catch (InvalidFormatException e) { + // It is assumed that the creation of the feature generation does not + // fail after it succeeded once during model loading. + + // But it might still be possible that such an exception is thrown, + // in this case the caller should not be forced to handle the exception + // and a Runtime Exception is thrown instead. + + // If the re-creation of the feature generation fails it is assumed + // that this can only be caused by a programming mistake and therefore + // throwing a Runtime Exception is reasonable + + throw new IllegalStateException(); // FeatureGeneratorCreationError(e); + } catch (IOException e) { + throw new IllegalStateException("Reading from mem cannot result in an I/O error", e); + } + + return generator; + } + @Override @SuppressWarnings("rawtypes") public Map<String, ArtifactSerializer> createArtifactSerializersMap() { Map<String, ArtifactSerializer> serializers = super.createArtifactSerializersMap(); - POSDictionarySerializer.register(serializers); - // the ngram Dictionary uses a base serializer, we don't need to add it here. + + // NOTE: This is only needed for old models and this if can be removed if support is dropped + if (Version.currentVersion().getMinor() < 8) { + POSDictionarySerializer.register(serializers); + } + return serializers; } @@ -111,18 +225,37 @@ public class POSTaggerFactory extends BaseToolFactory { this.posDictionary = dictionary; } + protected Map<String, Object> getResources() { + + + if (resources != null) { + return resources; + } + + return Collections.emptyMap(); + } + + protected byte[] getFeatureGenerator() { + return featureGeneratorBytes; + } + public TagDictionary getTagDictionary() { if (this.posDictionary == null && artifactProvider != null) this.posDictionary = artifactProvider.getArtifact(TAG_DICTIONARY_ENTRY_NAME); return this.posDictionary; } + /** + * @deprecated this will be reduced in visibility and later removed + */ + @Deprecated public Dictionary getDictionary() { if (this.ngramDictionary == null && artifactProvider != null) this.ngramDictionary = artifactProvider.getArtifact(NGRAM_DICTIONARY_ENTRY_NAME); return this.ngramDictionary; } + @Deprecated public void setDictionary(Dictionary ngramDict) { if (artifactProvider != null) { throw new IllegalStateException( @@ -132,10 +265,14 @@ public class POSTaggerFactory extends BaseToolFactory { } public POSContextGenerator getPOSContextGenerator() { - return new DefaultPOSContextGenerator(0, getDictionary()); + return getPOSContextGenerator(0); } public POSContextGenerator getPOSContextGenerator(int cacheSize) { + if (Version.currentVersion().getMinor() >= 8) { + return new ConfigurablePOSContextGenerator(cacheSize, createFeatureGenerators()); + } + return new DefaultPOSContextGenerator(cacheSize, getDictionary()); } @@ -143,7 +280,9 @@ public class POSTaggerFactory extends BaseToolFactory { return new DefaultPOSSequenceValidator(getTagDictionary()); } - static class POSDictionarySerializer implements ArtifactSerializer<POSDictionary> { + // TODO: This should not be done anymore for 8 models, they can just + // use the SerializableArtifact interface + public static class POSDictionarySerializer implements ArtifactSerializer<POSDictionary> { public POSDictionary create(InputStream in) throws IOException { return POSDictionary.create(new UncloseableInputStream(in)); @@ -218,6 +357,7 @@ public class POSTaggerFactory extends BaseToolFactory { } + @Deprecated public static POSTaggerFactory create(String subclassName, Dictionary ngramDictionary, TagDictionary posDictionary) throws InvalidFormatException { @@ -233,11 +373,34 @@ public class POSTaggerFactory extends BaseToolFactory { } catch (Exception e) { String msg = "Could not instantiate the " + subclassName + ". The initialization throw an exception."; - System.err.println(msg); - e.printStackTrace(); throw new InvalidFormatException(msg, e); } + } + + public static POSTaggerFactory create(String subclassName, byte[] featureGeneratorBytes, + Map<String, Object> resources, TagDictionary posDictionary) + throws InvalidFormatException { + + POSTaggerFactory theFactory; + + if (subclassName == null) { + // will create the default factory + theFactory = new POSTaggerFactory(null, posDictionary); + } + else { + try { + theFactory = ExtensionLoader.instantiateExtension( + POSTaggerFactory.class, subclassName); + } catch (Exception e) { + String msg = "Could not instantiate the " + subclassName + + ". The initialization throw an exception."; + throw new InvalidFormatException(msg, e); + } + } + + theFactory.init(featureGeneratorBytes, resources, posDictionary); + return theFactory; } public TagDictionary createEmptyTagDictionary() { http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/main/java/opennlp/tools/util/featuregen/GeneratorFactory.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/util/featuregen/GeneratorFactory.java b/opennlp-tools/src/main/java/opennlp/tools/util/featuregen/GeneratorFactory.java index ef08cfb..a1ac72b 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/util/featuregen/GeneratorFactory.java +++ b/opennlp-tools/src/main/java/opennlp/tools/util/featuregen/GeneratorFactory.java @@ -489,6 +489,17 @@ public class GeneratorFactory { } } + static class PosTaggerFeatureGeneratorFactory implements XmlFeatureGeneratorFactory { + public AdaptiveFeatureGenerator create(Element generatorElement, + FeatureGeneratorResourceProvider resourceManager) { + return new PosTaggerFeatureGenerator(); + } + + static void register(Map<String, XmlFeatureGeneratorFactory> factoryMap) { + factoryMap.put("postagger", new PosTaggerFeatureGeneratorFactory()); + } + } + /** * @see WindowFeatureGenerator */ @@ -658,6 +669,7 @@ public class GeneratorFactory { TokenFeatureGeneratorFactory.register(factories); BigramNameFeatureGeneratorFactory.register(factories); TokenPatternFeatureGeneratorFactory.register(factories); + PosTaggerFeatureGeneratorFactory.register(factories); PrefixFeatureGeneratorFactory.register(factories); SuffixFeatureGeneratorFactory.register(factories); WindowFeatureGeneratorFactory.register(factories); http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/main/java/opennlp/tools/util/featuregen/PosTaggerFeatureGenerator.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/java/opennlp/tools/util/featuregen/PosTaggerFeatureGenerator.java b/opennlp-tools/src/main/java/opennlp/tools/util/featuregen/PosTaggerFeatureGenerator.java new file mode 100644 index 0000000..c32baec --- /dev/null +++ b/opennlp-tools/src/main/java/opennlp/tools/util/featuregen/PosTaggerFeatureGenerator.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.util.featuregen; + +import java.util.List; + +public class PosTaggerFeatureGenerator implements AdaptiveFeatureGenerator { + + private final String SB = "S=begin"; + + @Override + public void createFeatures(List<String> features, String[] tokens, int index, + String[] tags) { + + String prev, prevprev = null; + String tagprev, tagprevprev; + tagprev = tagprevprev = null; + + if (index - 1 >= 0) { + prev = tokens[index - 1]; + tagprev = tags[index - 1]; + + if (index - 2 >= 0) { + prevprev = tokens[index - 2]; + tagprevprev = tags[index - 2]; + } + else { + prevprev = SB; + } + } + else { + prev = SB; + } + + // add the words and pos's of the surrounding context + if (prev != null) { + if (tagprev != null) { + features.add("t=" + tagprev); + } + if (prevprev != null) { + if (tagprevprev != null) { + features.add("t2=" + tagprevprev + "," + tagprev); + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/main/resources/opennlp/tools/postag/pos-default-features.xml ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/main/resources/opennlp/tools/postag/pos-default-features.xml b/opennlp-tools/src/main/resources/opennlp/tools/postag/pos-default-features.xml new file mode 100644 index 0000000..0be1fc8 --- /dev/null +++ b/opennlp-tools/src/main/resources/opennlp/tools/postag/pos-default-features.xml @@ -0,0 +1,38 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + + +<!-- Default pos tagger feature generator configuration --> +<generators> + <cache> + <generators> + <definition/> + <suffix/> + <prefix/> + <window prevLength = "2" nextLength = "2"> + <token/> + </window> + <window prevLength = "2" nextLength = "2"> + <sentence begin="true" end="false"/> + </window> + <tokenclass/> + <postagger/> + </generators> + </cache> +</generators> http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/test/java/opennlp/tools/postag/ConfigurablePOSContextGeneratorTest.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/test/java/opennlp/tools/postag/ConfigurablePOSContextGeneratorTest.java b/opennlp-tools/src/test/java/opennlp/tools/postag/ConfigurablePOSContextGeneratorTest.java new file mode 100644 index 0000000..f00e855 --- /dev/null +++ b/opennlp-tools/src/test/java/opennlp/tools/postag/ConfigurablePOSContextGeneratorTest.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package opennlp.tools.postag; + +import org.junit.Assert; +import org.junit.Test; + +import opennlp.tools.util.featuregen.AdaptiveFeatureGenerator; +import opennlp.tools.util.featuregen.TokenFeatureGenerator; + +public class ConfigurablePOSContextGeneratorTest { + + private void testContextGeneration(int cacheSize) { + AdaptiveFeatureGenerator fg = new TokenFeatureGenerator(); + ConfigurablePOSContextGenerator cg = new ConfigurablePOSContextGenerator(cacheSize, fg); + + String[] tokens = new String[] {"a", "b", "c", "d", "e"}; + String[] tags = new String[] {"t_a", "t_b", "t_c", "t_d", "t_e"}; + + cg.getContext(0, tokens, tags, null); + + Assert.assertEquals(1, cg.getContext(0, tokens, tags, null).length); + Assert.assertEquals("w=a", cg.getContext(0, tokens, tags, null)[0]); + Assert.assertEquals("w=b", cg.getContext(1, tokens, tags, null)[0]); + Assert.assertEquals("w=c", cg.getContext(2, tokens, tags, null)[0]); + Assert.assertEquals("w=d", cg.getContext(3, tokens, tags, null)[0]); + Assert.assertEquals("w=e", cg.getContext(4, tokens, tags, null)[0]); + } + + @Test + public void testWithoutCache() { + testContextGeneration(0); + } + + @Test + public void testWithCache() { + testContextGeneration(3); + } + +} http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/test/java/opennlp/tools/postag/DummyPOSTaggerFactory.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/test/java/opennlp/tools/postag/DummyPOSTaggerFactory.java b/opennlp-tools/src/test/java/opennlp/tools/postag/DummyPOSTaggerFactory.java index e0ce2a6..91228fc 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/postag/DummyPOSTaggerFactory.java +++ b/opennlp-tools/src/test/java/opennlp/tools/postag/DummyPOSTaggerFactory.java @@ -36,8 +36,8 @@ public class DummyPOSTaggerFactory extends POSTaggerFactory { public DummyPOSTaggerFactory() { } - public DummyPOSTaggerFactory(Dictionary ngramDictionary, DummyPOSDictionary posDictionary) { - super(ngramDictionary, null); + public DummyPOSTaggerFactory(DummyPOSDictionary posDictionary) { + super(null, null, null); this.dict = posDictionary; } @@ -81,7 +81,7 @@ public class DummyPOSTaggerFactory extends POSTaggerFactory { } - static class DummyPOSDictionarySerializer implements ArtifactSerializer<DummyPOSDictionary> { + public static class DummyPOSDictionarySerializer implements ArtifactSerializer<DummyPOSDictionary> { public DummyPOSDictionary create(InputStream in) throws IOException { return DummyPOSDictionary.create(new UncloseableInputStream(in)); @@ -106,6 +106,9 @@ public class DummyPOSTaggerFactory extends POSTaggerFactory { private POSDictionary dict; + public DummyPOSDictionary() { + } + public DummyPOSDictionary(POSDictionary dict) { this.dict = dict; } @@ -123,6 +126,9 @@ public class DummyPOSTaggerFactory extends POSTaggerFactory { return dict.getTags(word); } + @Override + public Class<?> getArtifactSerializerClass() { + return DummyPOSDictionarySerializer.class; + } } - } http://git-wip-us.apache.org/repos/asf/opennlp/blob/e06a5327/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerFactoryTest.java ---------------------------------------------------------------------- diff --git a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerFactoryTest.java b/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerFactoryTest.java index edb20b3..b98d3bf 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerFactoryTest.java +++ b/opennlp-tools/src/test/java/opennlp/tools/postag/POSTaggerFactoryTest.java @@ -25,7 +25,6 @@ import java.nio.charset.StandardCharsets; import org.junit.Assert; import org.junit.Test; -import opennlp.tools.dictionary.Dictionary; import opennlp.tools.formats.ResourceAsStreamFactory; import opennlp.tools.postag.DummyPOSTaggerFactory.DummyPOSContextGenerator; import opennlp.tools.postag.DummyPOSTaggerFactory.DummyPOSDictionary; @@ -62,9 +61,8 @@ public class POSTaggerFactoryTest { DummyPOSDictionary posDict = new DummyPOSDictionary( POSDictionary.create(POSDictionaryTest.class .getResourceAsStream("TagDictionaryCaseSensitive.xml"))); - Dictionary dic = POSTaggerME.buildNGramDictionary(createSampleStream(), 0); - POSModel posModel = trainPOSModel(new DummyPOSTaggerFactory(dic, posDict)); + POSModel posModel = trainPOSModel(new DummyPOSTaggerFactory(posDict)); POSTaggerFactory factory = posModel.getFactory(); Assert.assertTrue(factory.getTagDictionary() instanceof DummyPOSDictionary); @@ -81,22 +79,18 @@ public class POSTaggerFactoryTest { Assert.assertTrue(factory.getTagDictionary() instanceof DummyPOSDictionary); Assert.assertTrue(factory.getPOSContextGenerator() instanceof DummyPOSContextGenerator); Assert.assertTrue(factory.getSequenceValidator() instanceof DummyPOSSequenceValidator); - Assert.assertTrue(factory.getDictionary() != null); } @Test public void testPOSTaggerWithDefaultFactory() throws IOException { POSDictionary posDict = POSDictionary.create(POSDictionaryTest.class .getResourceAsStream("TagDictionaryCaseSensitive.xml")); - Dictionary dic = POSTaggerME.buildNGramDictionary(createSampleStream(), 0); - - POSModel posModel = trainPOSModel(new POSTaggerFactory(dic, posDict)); + POSModel posModel = trainPOSModel(new POSTaggerFactory(null, null, posDict)); POSTaggerFactory factory = posModel.getFactory(); Assert.assertTrue(factory.getTagDictionary() instanceof POSDictionary); Assert.assertTrue(factory.getPOSContextGenerator() != null); Assert.assertTrue(factory.getSequenceValidator() instanceof DefaultPOSSequenceValidator); - Assert.assertTrue(factory.getDictionary() != null); ByteArrayOutputStream out = new ByteArrayOutputStream(); posModel.serialize(out); @@ -108,7 +102,6 @@ public class POSTaggerFactoryTest { Assert.assertTrue(factory.getTagDictionary() instanceof POSDictionary); Assert.assertTrue(factory.getPOSContextGenerator() != null); Assert.assertTrue(factory.getSequenceValidator() instanceof DefaultPOSSequenceValidator); - Assert.assertTrue(factory.getDictionary() != null); } @Test(expected = InvalidFormatException.class)
