This is an automated email from the ASF dual-hosted git repository.
rzo1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opennlp.git
The following commit(s) were added to refs/heads/main by this push:
new 7a7e013f OPENNLP-1564 - Fix Evaluation Tests after POSFormat Change
(#603)
7a7e013f is described below
commit 7a7e013f7a431022f06c9f17e3c19aea8888f904
Author: Richard Zowalla <[email protected]>
AuthorDate: Thu May 30 07:01:58 2024 +0200
OPENNLP-1564 - Fix Evaluation Tests after POSFormat Change (#603)
* OPENNLP-1564 - Correct POSTagFormat to PENN format for some of the eval
tests. Allow specification of POSTagFormat for POSTaggerCrossValidator.
* OPENNLP-1564 - In case we are loading a POSModel from a
POSTaggerNameFeatureGenerator (e.g. defined in XML), we need to actually guess
the format out of a given POSModel to initialize the POSTagger correctly.
---
.../opennlp/tools/postag/POSTagFormatMapper.java | 12 ++++++
.../tools/postag/POSTaggerCrossValidator.java | 48 ++++++++++++++++++++--
.../featuregen/POSTaggerNameFeatureGenerator.java | 4 +-
.../opennlp/tools/eval/ConllXPosTaggerEval.java | 3 +-
.../tools/eval/OntoNotes4PosTaggerEval.java | 4 +-
.../opennlp/tools/eval/SourceForgeModelEval.java | 3 +-
6 files changed, 66 insertions(+), 8 deletions(-)
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java
b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java
index e02cb520..b01a36ac 100644
--- a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java
+++ b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTagFormatMapper.java
@@ -206,4 +206,16 @@ public class POSTagFormatMapper {
return POSTagFormat.UNKNOWN;
}
}
+
+ /**
+ * Guesses the {@link POSTagFormat} of a given {@link POSModel}
+ * @param posModel must not be {@code null}.
+ * @return the guessed {@link POSTagFormat}.
+ */
+ public static POSTagFormat guessFormat(POSModel posModel) {
+ Objects.requireNonNull(posModel, "POSModel must not be NULL.");
+ Objects.requireNonNull(posModel.getPosSequenceModel(), "POSSequenceModel must not be NULL.");
+ final POSTagFormatMapper mapper = new
POSTagFormatMapper(posModel.getPosSequenceModel().getOutcomes());
+ return mapper.getGuessedFormat();
+ }
}
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java
b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java
index 96ffa10b..7fecbe63 100644
---
a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java
+++
b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java
@@ -31,6 +31,7 @@ public class POSTaggerCrossValidator {
private final String languageCode;
private final TrainingParameters params;
+ private final POSTagFormat posTagFormat;
private byte[] featureGeneratorBytes;
private Map<String, Object> resources;
@@ -46,6 +47,7 @@ public class POSTaggerCrossValidator {
private final Integer tagdicCutoff;
private File tagDictionaryFile;
+
/**
* Initializes a {@link POSTaggerCrossValidator} that builds a ngram
dictionary
* dynamically. It instantiates a subclass of {@link POSTaggerFactory} using
@@ -57,12 +59,13 @@ public class POSTaggerCrossValidator {
* @param featureGeneratorBytes The bytes for feature generation.
* @param resources Additional resources as key-value map.
* @param factoryClass The class name used for factory instantiation.
+ * @param format A valid {@link POSTagFormat}.
* @param listeners The {@link POSTaggerEvaluationMonitor evaluation
listeners}.
*/
public POSTaggerCrossValidator(String languageCode,
TrainingParameters trainParam, File
tagDictionary,
byte[] featureGeneratorBytes, Map<String,
Object> resources,
- Integer tagdicCutoff, String factoryClass,
+ Integer tagdicCutoff, String factoryClass,
POSTagFormat format,
POSTaggerEvaluationMonitor... listeners) {
this.languageCode = languageCode;
this.params = trainParam;
@@ -72,6 +75,28 @@ public class POSTaggerCrossValidator {
this.factoryClassName = factoryClass;
this.tagdicCutoff = tagdicCutoff;
this.tagDictionaryFile = tagDictionary;
+ this.posTagFormat = format;
+ }
+ /**
+ * Initializes a {@link POSTaggerCrossValidator} that builds a ngram
dictionary
+ * dynamically. It instantiates a subclass of {@link POSTaggerFactory} using
+ * the tag and the ngram dictionaries.
+ *
+ * @param languageCode An ISO conform language code.
+ * @param trainParam The {@link TrainingParameters} for the context of cross
validation.
+ * @param tagDictionary The {@link File} that references a {@link
TagDictionary}.
+ * @param featureGeneratorBytes The bytes for feature generation.
+ * @param resources Additional resources as key-value map.
+ * @param factoryClass The class name used for factory instantiation.
+ * @param listeners The {@link POSTaggerEvaluationMonitor evaluation
listeners}.
+ */
+ public POSTaggerCrossValidator(String languageCode,
+ TrainingParameters trainParam, File
tagDictionary,
+ byte[] featureGeneratorBytes, Map<String,
Object> resources,
+ Integer tagdicCutoff, String factoryClass,
+ POSTaggerEvaluationMonitor... listeners) {
+ this(languageCode, trainParam, tagDictionary, featureGeneratorBytes,
resources,
+ tagdicCutoff, factoryClass, POSTagFormat.UD, listeners);
}
@@ -86,10 +111,27 @@ public class POSTaggerCrossValidator {
public POSTaggerCrossValidator(String languageCode,
TrainingParameters trainParam, POSTaggerFactory factory,
POSTaggerEvaluationMonitor... listeners) {
+ this(languageCode, trainParam, factory, POSTagFormat.UD, listeners);
+ }
+
+
+ /**
+ * Creates a {@link POSTaggerCrossValidator} using the given {@link
POSTaggerFactory}.
+ *
+ * @param languageCode An ISO conform language code.
+ * @param trainParam The {@link TrainingParameters} for the context of cross
validation.
+ * @param factory The {@link POSTaggerFactory} to be used.
+ * @param format A valid {@link POSTagFormat}.
+ * @param listeners The {@link POSTaggerEvaluationMonitor evaluation
listeners}.
+ */
+ public POSTaggerCrossValidator(String languageCode,
+ TrainingParameters trainParam,
POSTaggerFactory factory, POSTagFormat format,
+ POSTaggerEvaluationMonitor... listeners) {
this.languageCode = languageCode;
this.params = trainParam;
this.listeners = listeners;
this.factory = factory;
+ this.posTagFormat = format;
this.tagdicCutoff = null;
}
@@ -142,7 +184,7 @@ public class POSTaggerCrossValidator {
POSModel model = POSTaggerME.train(languageCode, trainingSampleStream,
params, this.factory);
- POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model),
listeners);
+ POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model,
posTagFormat), listeners);
evaluator.evaluate(trainingSampleStream.getTestSampleStream());
@@ -169,5 +211,5 @@ public class POSTaggerCrossValidator {
public long getWordCount() {
return wordAccuracy.count();
}
-
+
}
diff --git
a/opennlp-tools/src/main/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGenerator.java
b/opennlp-tools/src/main/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGenerator.java
index cc30ad68..9b67f684 100644
---
a/opennlp-tools/src/main/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGenerator.java
+++
b/opennlp-tools/src/main/java/opennlp/tools/util/featuregen/POSTaggerNameFeatureGenerator.java
@@ -22,6 +22,7 @@ import java.util.Arrays;
import java.util.List;
import opennlp.tools.postag.POSModel;
+import opennlp.tools.postag.POSTagFormatMapper;
import opennlp.tools.postag.POSTagger;
import opennlp.tools.postag.POSTaggerME;
@@ -50,8 +51,7 @@ public class POSTaggerNameFeatureGenerator implements
AdaptiveFeatureGenerator {
* @param aPosModel a POSTagger model.
*/
public POSTaggerNameFeatureGenerator(POSModel aPosModel) {
-
- this.posTagger = new POSTaggerME(aPosModel);
+ this.posTagger = new POSTaggerME(aPosModel,
POSTagFormatMapper.guessFormat(aPosModel));
}
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/eval/ConllXPosTaggerEval.java
b/opennlp-tools/src/test/java/opennlp/tools/eval/ConllXPosTaggerEval.java
index ac71faf8..635050ac 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/eval/ConllXPosTaggerEval.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/eval/ConllXPosTaggerEval.java
@@ -31,6 +31,7 @@ import opennlp.tools.formats.ConllXPOSSampleStream;
import opennlp.tools.postag.POSEvaluator;
import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSSample;
+import opennlp.tools.postag.POSTagFormat;
import opennlp.tools.postag.POSTaggerFactory;
import opennlp.tools.postag.POSTaggerME;
import opennlp.tools.util.MarkableFileInputStreamFactory;
@@ -73,7 +74,7 @@ public class ConllXPosTaggerEval extends AbstractEvalTest {
ObjectStream<POSSample> samples = new ConllXPOSSampleStream(
new MarkableFileInputStreamFactory(testData), StandardCharsets.UTF_8);
- POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model));
+ POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model,
POSTagFormat.PENN));
evaluator.evaluate(samples);
Assertions.assertEquals(expectedAccuracy, evaluator.getWordAccuracy(),
0.0001);
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/eval/OntoNotes4PosTaggerEval.java
b/opennlp-tools/src/test/java/opennlp/tools/eval/OntoNotes4PosTaggerEval.java
index 252080d3..984fdf03 100644
---
a/opennlp-tools/src/test/java/opennlp/tools/eval/OntoNotes4PosTaggerEval.java
+++
b/opennlp-tools/src/test/java/opennlp/tools/eval/OntoNotes4PosTaggerEval.java
@@ -32,6 +32,7 @@ import opennlp.tools.formats.convert.ParseToPOSSampleStream;
import opennlp.tools.formats.ontonotes.DocumentToLineStream;
import opennlp.tools.formats.ontonotes.OntoNotesParseSampleStream;
import opennlp.tools.postag.POSSample;
+import opennlp.tools.postag.POSTagFormat;
import opennlp.tools.postag.POSTaggerCrossValidator;
import opennlp.tools.postag.POSTaggerFactory;
import opennlp.tools.util.ObjectStream;
@@ -59,7 +60,8 @@ public class OntoNotes4PosTaggerEval extends AbstractEvalTest
{
private void crossEval(TrainingParameters params, double expectedScore)
throws IOException {
try (ObjectStream<POSSample> samples = createPOSSampleStream()) {
- POSTaggerCrossValidator cv = new POSTaggerCrossValidator("eng", params,
new POSTaggerFactory());
+ POSTaggerCrossValidator cv = new POSTaggerCrossValidator("eng", params,
+ new POSTaggerFactory(), POSTagFormat.PENN);
cv.evaluate(samples, 5);
Assertions.assertEquals(expectedScore, cv.getWordAccuracy(), 0.0001d);
diff --git
a/opennlp-tools/src/test/java/opennlp/tools/eval/SourceForgeModelEval.java
b/opennlp-tools/src/test/java/opennlp/tools/eval/SourceForgeModelEval.java
index 36005707..3bd4213b 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/eval/SourceForgeModelEval.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/eval/SourceForgeModelEval.java
@@ -44,6 +44,7 @@ import opennlp.tools.parser.ParserFactory;
import opennlp.tools.parser.ParserModel;
import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSSample;
+import opennlp.tools.postag.POSTagFormat;
import opennlp.tools.postag.POSTagger;
import opennlp.tools.postag.POSTaggerME;
import opennlp.tools.sentdetect.SentenceDetector;
@@ -352,7 +353,7 @@ public class SourceForgeModelEval extends AbstractEvalTest {
MessageDigest digest = MessageDigest.getInstance(HASH_ALGORITHM);
- POSTagger tagger = new POSTaggerME(model);
+ POSTagger tagger = new POSTaggerME(model, POSTagFormat.PENN);
try (ObjectStream<LeipzigTestSample> lines = createLineWiseStream()) {