This is an automated email from the ASF dual-hosted git repository.

rzo1 pushed a commit to branch fix-eval-tests-after-pos-tagging-conversion
in repository https://gitbox.apache.org/repos/asf/opennlp.git

commit 249197703e1f6d051b79ad8788bdeeb9cf19e273
Author: Richard Zowalla <[email protected]>
AuthorDate: Wed May 29 20:43:08 2024 +0200

    OPENNLP-1564 - Set POSTagFormat to PENN for some of the eval tests. Allow specifying a POSTagFormat for POSTaggerCrossValidator.
---
 .../tools/postag/POSTaggerCrossValidator.java      | 48 ++++++++++++++++++++--
 .../opennlp/tools/eval/ConllXPosTaggerEval.java    |  3 +-
 .../tools/eval/OntoNotes4PosTaggerEval.java        |  4 +-
 .../opennlp/tools/eval/SourceForgeModelEval.java   |  3 +-
 4 files changed, 52 insertions(+), 6 deletions(-)

diff --git 
a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java 
b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java
index 96ffa10b..7fecbe63 100644
--- 
a/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java
+++ 
b/opennlp-tools/src/main/java/opennlp/tools/postag/POSTaggerCrossValidator.java
@@ -31,6 +31,7 @@ public class POSTaggerCrossValidator {
   private final String languageCode;
 
   private final TrainingParameters params;
+  private final POSTagFormat posTagFormat;
 
   private byte[] featureGeneratorBytes;
   private Map<String, Object> resources;
@@ -46,6 +47,7 @@ public class POSTaggerCrossValidator {
   private final Integer tagdicCutoff;
   private File tagDictionaryFile;
 
+
   /**
    * Initializes a {@link POSTaggerCrossValidator} that builds a ngram 
dictionary
    * dynamically. It instantiates a subclass of {@link POSTaggerFactory} using
@@ -57,12 +59,13 @@ public class POSTaggerCrossValidator {
    * @param featureGeneratorBytes The bytes for feature generation.
    * @param resources  Additional resources as key-value map.
    * @param factoryClass The class name used for factory instantiation.
+   * @param format   A valid {@link POSTagFormat}.
    * @param listeners The {@link POSTaggerEvaluationMonitor evaluation 
listeners}.
    */
   public POSTaggerCrossValidator(String languageCode,
                                  TrainingParameters trainParam, File 
tagDictionary,
                                  byte[] featureGeneratorBytes, Map<String, 
Object> resources,
-                                 Integer tagdicCutoff, String factoryClass,
+                                 Integer tagdicCutoff, String factoryClass, 
POSTagFormat format,
                                  POSTaggerEvaluationMonitor... listeners) {
     this.languageCode = languageCode;
     this.params = trainParam;
@@ -72,6 +75,28 @@ public class POSTaggerCrossValidator {
     this.factoryClassName = factoryClass;
     this.tagdicCutoff = tagdicCutoff;
     this.tagDictionaryFile = tagDictionary;
+    this.posTagFormat = format;
+  }
+  /**
+   * Initializes a {@link POSTaggerCrossValidator} that builds an ngram 
dictionary
+   * dynamically. It instantiates a subclass of {@link POSTaggerFactory} using
+   * the tag and the ngram dictionaries.
+   *
+   * @param languageCode An ISO conform language code.
+   * @param trainParam The {@link TrainingParameters} for the context of cross 
validation.
+   * @param tagDictionary The {@link File} that references a {@link 
TagDictionary}.
+   * @param featureGeneratorBytes The bytes for feature generation.
+   * @param resources  Additional resources as key-value map.
+   * @param factoryClass The class name used for factory instantiation.
+   * @param listeners The {@link POSTaggerEvaluationMonitor evaluation 
listeners}.
+   */
+  public POSTaggerCrossValidator(String languageCode,
+                                 TrainingParameters trainParam, File 
tagDictionary,
+                                 byte[] featureGeneratorBytes, Map<String, 
Object> resources,
+                                 Integer tagdicCutoff, String factoryClass,
+                                 POSTaggerEvaluationMonitor... listeners) {
+    this(languageCode, trainParam, tagDictionary, featureGeneratorBytes, 
resources,
+        tagdicCutoff, factoryClass, POSTagFormat.UD, listeners);
   }
 
 
@@ -86,10 +111,27 @@ public class POSTaggerCrossValidator {
   public POSTaggerCrossValidator(String languageCode,
       TrainingParameters trainParam, POSTaggerFactory factory,
       POSTaggerEvaluationMonitor... listeners) {
+    this(languageCode, trainParam, factory, POSTagFormat.UD, listeners);
+  }
+
+
+  /**
+   * Creates a {@link POSTaggerCrossValidator} using the given {@link 
POSTaggerFactory}.
+   *
+   * @param languageCode An ISO conform language code.
+   * @param trainParam The {@link TrainingParameters} for the context of cross 
validation.
+   * @param factory The {@link POSTaggerFactory} to be used.
+   * @param format   A valid {@link POSTagFormat}.
+   * @param listeners The {@link POSTaggerEvaluationMonitor evaluation 
listeners}.
+   */
+  public POSTaggerCrossValidator(String languageCode,
+                                 TrainingParameters trainParam, 
POSTaggerFactory factory, POSTagFormat format,
+                                 POSTaggerEvaluationMonitor... listeners) {
     this.languageCode = languageCode;
     this.params = trainParam;
     this.listeners = listeners;
     this.factory = factory;
+    this.posTagFormat = format;
     this.tagdicCutoff = null;
   }
 
@@ -142,7 +184,7 @@ public class POSTaggerCrossValidator {
       POSModel model = POSTaggerME.train(languageCode, trainingSampleStream,
           params, this.factory);
 
-      POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model), 
listeners);
+      POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model, 
posTagFormat), listeners);
 
       evaluator.evaluate(trainingSampleStream.getTestSampleStream());
 
@@ -169,5 +211,5 @@ public class POSTaggerCrossValidator {
   public long getWordCount() {
     return wordAccuracy.count();
   }
-  
+
 }
diff --git 
a/opennlp-tools/src/test/java/opennlp/tools/eval/ConllXPosTaggerEval.java 
b/opennlp-tools/src/test/java/opennlp/tools/eval/ConllXPosTaggerEval.java
index ac71faf8..635050ac 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/eval/ConllXPosTaggerEval.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/eval/ConllXPosTaggerEval.java
@@ -31,6 +31,7 @@ import opennlp.tools.formats.ConllXPOSSampleStream;
 import opennlp.tools.postag.POSEvaluator;
 import opennlp.tools.postag.POSModel;
 import opennlp.tools.postag.POSSample;
+import opennlp.tools.postag.POSTagFormat;
 import opennlp.tools.postag.POSTaggerFactory;
 import opennlp.tools.postag.POSTaggerME;
 import opennlp.tools.util.MarkableFileInputStreamFactory;
@@ -73,7 +74,7 @@ public class ConllXPosTaggerEval extends AbstractEvalTest {
     ObjectStream<POSSample> samples = new ConllXPOSSampleStream(
         new MarkableFileInputStreamFactory(testData), StandardCharsets.UTF_8);
 
-    POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model));
+    POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model, 
POSTagFormat.PENN));
     evaluator.evaluate(samples);
 
     Assertions.assertEquals(expectedAccuracy, evaluator.getWordAccuracy(), 
0.0001);
diff --git 
a/opennlp-tools/src/test/java/opennlp/tools/eval/OntoNotes4PosTaggerEval.java 
b/opennlp-tools/src/test/java/opennlp/tools/eval/OntoNotes4PosTaggerEval.java
index 252080d3..984fdf03 100644
--- 
a/opennlp-tools/src/test/java/opennlp/tools/eval/OntoNotes4PosTaggerEval.java
+++ 
b/opennlp-tools/src/test/java/opennlp/tools/eval/OntoNotes4PosTaggerEval.java
@@ -32,6 +32,7 @@ import opennlp.tools.formats.convert.ParseToPOSSampleStream;
 import opennlp.tools.formats.ontonotes.DocumentToLineStream;
 import opennlp.tools.formats.ontonotes.OntoNotesParseSampleStream;
 import opennlp.tools.postag.POSSample;
+import opennlp.tools.postag.POSTagFormat;
 import opennlp.tools.postag.POSTaggerCrossValidator;
 import opennlp.tools.postag.POSTaggerFactory;
 import opennlp.tools.util.ObjectStream;
@@ -59,7 +60,8 @@ public class OntoNotes4PosTaggerEval extends AbstractEvalTest 
{
   private void crossEval(TrainingParameters params, double expectedScore)
       throws IOException {
     try (ObjectStream<POSSample> samples = createPOSSampleStream()) {
-      POSTaggerCrossValidator cv = new POSTaggerCrossValidator("eng", params, 
new POSTaggerFactory());
+      POSTaggerCrossValidator cv = new POSTaggerCrossValidator("eng", params,
+          new POSTaggerFactory(), POSTagFormat.PENN);
       cv.evaluate(samples, 5);
 
       Assertions.assertEquals(expectedScore, cv.getWordAccuracy(), 0.0001d);
diff --git 
a/opennlp-tools/src/test/java/opennlp/tools/eval/SourceForgeModelEval.java 
b/opennlp-tools/src/test/java/opennlp/tools/eval/SourceForgeModelEval.java
index 36005707..3bd4213b 100644
--- a/opennlp-tools/src/test/java/opennlp/tools/eval/SourceForgeModelEval.java
+++ b/opennlp-tools/src/test/java/opennlp/tools/eval/SourceForgeModelEval.java
@@ -44,6 +44,7 @@ import opennlp.tools.parser.ParserFactory;
 import opennlp.tools.parser.ParserModel;
 import opennlp.tools.postag.POSModel;
 import opennlp.tools.postag.POSSample;
+import opennlp.tools.postag.POSTagFormat;
 import opennlp.tools.postag.POSTagger;
 import opennlp.tools.postag.POSTaggerME;
 import opennlp.tools.sentdetect.SentenceDetector;
@@ -352,7 +353,7 @@ public class SourceForgeModelEval extends AbstractEvalTest {
 
     MessageDigest digest = MessageDigest.getInstance(HASH_ALGORITHM);
 
-    POSTagger tagger = new POSTaggerME(model);
+    POSTagger tagger = new POSTaggerME(model, POSTagFormat.PENN);
 
     try (ObjectStream<LeipzigTestSample> lines = createLineWiseStream()) {
 

Reply via email to