This is an automated email from the ASF dual-hosted git repository. rzo1 pushed a commit to branch fix-eval-tests-after-pos-tagging-conversion in repository https://gitbox.apache.org/repos/asf/opennlp.git
commit 249197703e1f6d051b79ad8788bdeeb9cf19e273 Author: Richard Zowalla <[email protected]> AuthorDate: Wed May 29 20:43:08 2024 +0200 OPENNLP-1564 - Correct POSTagFormat to PENN format for some of the eval tests. Allow specification of POSTagFormat for POSTaggerCrossValidator. --- .../tools/postag/POSTaggerCrossValidator.java | 48 ++++++++++++++++++++-- .../opennlp/tools/eval/ConllXPosTaggerEval.java | 3 +- .../tools/eval/OntoNotes4PosTaggerEval.java | 4 +- .../opennlp/tools/eval/SourceForgeModelEval.java | 3 +- 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java index 96ffa10b..7fecbe63 100644 --- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java +++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java @@ -31,6 +31,7 @@ public class POSTaggerCrossValidator { private final String languageCode; private final TrainingParameters params; + private final POSTagFormat posTagFormat; private byte[] featureGeneratorBytes; private Map<String, Object> resources; @@ -46,6 +47,7 @@ public class POSTaggerCrossValidator { private final Integer tagdicCutoff; private File tagDictionaryFile; + /** * Initializes a {@link POSTaggerCrossValidator} that builds a ngram dictionary * dynamically. It instantiates a subclass of {@link POSTaggerFactory} using @@ -57,12 +59,13 @@ public class POSTaggerCrossValidator { * @param featureGeneratorBytes The bytes for feature generation. * @param resources Additional resources as key-value map. * @param factoryClass The class name used for factory instantiation. + * @param format A valid {@link POSTagFormat}. * @param listeners The {@link POSTaggerEvaluationMonitor evaluation listeners}. */ public POSTaggerCrossValidator(String languageCode, TrainingParameters trainParam, File tagDictionary, byte[] featureGeneratorBytes, Map<String, Object> resources, - Integer tagdicCutoff, String factoryClass, + Integer tagdicCutoff, String factoryClass, POSTagFormat format, POSTaggerEvaluationMonitor... listeners) { this.languageCode = languageCode; this.params = trainParam; @@ -72,6 +75,28 @@ public class POSTaggerCrossValidator { this.factoryClassName = factoryClass; this.tagdicCutoff = tagdicCutoff; this.tagDictionaryFile = tagDictionary; + this.posTagFormat = format; + } + /** + * Initializes a {@link POSTaggerCrossValidator} that builds a ngram dictionary + * dynamically. It instantiates a subclass of {@link POSTaggerFactory} using + * the tag and the ngram dictionaries. + * + * @param languageCode An ISO conform language code. + * @param trainParam The {@link TrainingParameters} for the context of cross validation. + * @param tagDictionary The {@link File} that references the a {@link TagDictionary}. + * @param featureGeneratorBytes The bytes for feature generation. + * @param resources Additional resources as key-value map. + * @param factoryClass The class name used for factory instantiation. + * @param listeners The {@link POSTaggerEvaluationMonitor evaluation listeners}. + */ + public POSTaggerCrossValidator(String languageCode, + TrainingParameters trainParam, File tagDictionary, + byte[] featureGeneratorBytes, Map<String, Object> resources, + Integer tagdicCutoff, String factoryClass, + POSTaggerEvaluationMonitor... listeners) { + this(languageCode, trainParam, tagDictionary, featureGeneratorBytes, resources, + tagdicCutoff, factoryClass, POSTagFormat.UD, listeners); } @@ -86,10 +111,27 @@ public class POSTaggerCrossValidator { public POSTaggerCrossValidator(String languageCode, TrainingParameters trainParam, POSTaggerFactory factory, POSTaggerEvaluationMonitor... listeners) { + this(languageCode, trainParam, factory, POSTagFormat.UD, listeners); + } + + + /** + * Creates a {@link POSTaggerCrossValidator} using the given {@link POSTaggerFactory}. + * + * @param languageCode An ISO conform language code. + * @param trainParam The {@link TrainingParameters} for the context of cross validation. + * @param factory The {@link POSTaggerFactory} to be used. + * @param format A valid {@link POSTagFormat}. + * @param listeners The {@link POSTaggerEvaluationMonitor evaluation listeners}. + */ + public POSTaggerCrossValidator(String languageCode, + TrainingParameters trainParam, POSTaggerFactory factory, POSTagFormat format, + POSTaggerEvaluationMonitor... listeners) { this.languageCode = languageCode; this.params = trainParam; this.listeners = listeners; this.factory = factory; + this.posTagFormat = format; this.tagdicCutoff = null; } @@ -142,7 +184,7 @@ public class POSTaggerCrossValidator { POSModel model = POSTaggerME.train(languageCode, trainingSampleStream, params, this.factory); - POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model), listeners); + POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model, posTagFormat), listeners); evaluator.evaluate(trainingSampleStream.getTestSampleStream()); @@ -169,5 +211,5 @@ public class POSTaggerCrossValidator { public long getWordCount() { return wordAccuracy.count(); } - + } diff --git a/opennlp-tools/src/test/java/opennlp/tools/eval/ConllXPosTaggerEval.java b/opennlp-tools/src/test/java/opennlp/tools/eval/ConllXPosTaggerEval.java index ac71faf8..635050ac 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/eval/ConllXPosTaggerEval.java +++ b/opennlp-tools/src/test/java/opennlp/tools/eval/ConllXPosTaggerEval.java @@ -31,6 +31,7 @@ import opennlp.tools.formats.ConllXPOSSampleStream; import opennlp.tools.postag.POSEvaluator; import opennlp.tools.postag.POSModel; import opennlp.tools.postag.POSSample; +import opennlp.tools.postag.POSTagFormat; import opennlp.tools.postag.POSTaggerFactory; import opennlp.tools.postag.POSTaggerME; import opennlp.tools.util.MarkableFileInputStreamFactory; @@ -73,7 +74,7 @@ public class ConllXPosTaggerEval extends AbstractEvalTest { ObjectStream<POSSample> samples = new ConllXPOSSampleStream( new MarkableFileInputStreamFactory(testData), StandardCharsets.UTF_8); - POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model)); + POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model, POSTagFormat.PENN)); evaluator.evaluate(samples); Assertions.assertEquals(expectedAccuracy, evaluator.getWordAccuracy(), 0.0001); diff --git a/opennlp-tools/src/test/java/opennlp/tools/eval/OntoNotes4PosTaggerEval.java b/opennlp-tools/src/test/java/opennlp/tools/eval/OntoNotes4PosTaggerEval.java index 252080d3..984fdf03 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/eval/OntoNotes4PosTaggerEval.java +++ b/opennlp-tools/src/test/java/opennlp/tools/eval/OntoNotes4PosTaggerEval.java @@ -32,6 +32,7 @@ import opennlp.tools.formats.convert.ParseToPOSSampleStream; import opennlp.tools.formats.ontonotes.DocumentToLineStream; import opennlp.tools.formats.ontonotes.OntoNotesParseSampleStream; import opennlp.tools.postag.POSSample; +import opennlp.tools.postag.POSTagFormat; import opennlp.tools.postag.POSTaggerCrossValidator; import opennlp.tools.postag.POSTaggerFactory; import opennlp.tools.util.ObjectStream; @@ -59,7 +60,8 @@ public class OntoNotes4PosTaggerEval extends AbstractEvalTest { private void crossEval(TrainingParameters params, double expectedScore) throws IOException { try (ObjectStream<POSSample> samples = createPOSSampleStream()) { - POSTaggerCrossValidator cv = new POSTaggerCrossValidator("eng", params, new POSTaggerFactory()); + POSTaggerCrossValidator cv = new POSTaggerCrossValidator("eng", params, + new POSTaggerFactory(), POSTagFormat.PENN); cv.evaluate(samples, 5); Assertions.assertEquals(expectedScore, cv.getWordAccuracy(), 0.0001d); diff --git a/opennlp-tools/src/test/java/opennlp/tools/eval/SourceForgeModelEval.java b/opennlp-tools/src/test/java/opennlp/tools/eval/SourceForgeModelEval.java index 36005707..3bd4213b 100644 --- a/opennlp-tools/src/test/java/opennlp/tools/eval/SourceForgeModelEval.java +++ b/opennlp-tools/src/test/java/opennlp/tools/eval/SourceForgeModelEval.java @@ -44,6 +44,7 @@ import opennlp.tools.parser.ParserFactory; import opennlp.tools.parser.ParserModel; import opennlp.tools.postag.POSModel; import opennlp.tools.postag.POSSample; +import opennlp.tools.postag.POSTagFormat; import opennlp.tools.postag.POSTagger; import opennlp.tools.postag.POSTaggerME; import opennlp.tools.sentdetect.SentenceDetector; @@ -352,7 +353,7 @@ public class SourceForgeModelEval extends AbstractEvalTest { MessageDigest digest = MessageDigest.getInstance(HASH_ALGORITHM); - POSTagger tagger = new POSTaggerME(model); + POSTagger tagger = new POSTaggerME(model, POSTagFormat.PENN); try (ObjectStream<LeipzigTestSample> lines = createLineWiseStream()) {
