Author: ogrisel
Date: Tue Jan 17 18:54:33 2012
New Revision: 1232533

URL: http://svn.apache.org/viewvc?rev=1232533&view=rev
Log:
STANBOL-197: compute precision, recall and f1 score + averaging across CV folds

Modified:
    
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
    
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java

Modified: 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java?rev=1232533&r1=1232532&r2=1232533&view=diff
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
 (original)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
 Tue Jan 17 18:54:33 2012
@@ -26,6 +26,7 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.Date;
 import java.util.Dictionary;
+import java.util.HashMap;
 import java.util.Hashtable;
 import java.util.Iterator;
 import java.util.LinkedHashSet;
@@ -838,11 +839,17 @@ public class TopicClassificationEngine e
 
             @Override
             public int process(List<SolrDocument> batch) throws TrainingSetException, ClassifierException {
+                int offset;
                 for (SolrDocument topicMetadata : batch) {
                     String topic = topicMetadata.getFirstValue(topicUriField).toString();
                     List<String> impactedTopics = new ArrayList<String>();
-                    int offset = 0;
+
                     Batch<String> examples = Batch.emtpyBatch(String.class);
+
+                    List<String> falseNegativeExamples = new ArrayList<String>();
+                    int truePositives = 0;
+                    int falseNegatives = 0;
+                    offset = 0;
                     do {
                         examples = trainingSet.getPositiveExamples(impactedTopics, examples.nextOffset);
                         for (String example : examples.items) {
@@ -853,16 +860,60 @@ public class TopicClassificationEngine e
                             }
                             offset++;
                             if (classifier.suggestTopics(example).contains(topic)) {
-                                // count positive success
+                                truePositives++;
                             } else {
-                                // collect false negatives
+                                falseNegatives++;
+                                // falseNegativeExamples.add(exampleId);
                             }
                         }
                     } while (examples.hasMore); // TODO: put a bound on the number of examples
 
-                    // TODO: handle false positives with negative examples here
+                    List<String> falsePositiveExamples = new ArrayList<String>();
+                    int trueNegatives = 0;
+                    int falsePositives = 0;
+                    offset = 0;
+                    // restart paging from the beginning instead of reusing the nextOffset left by the positive loop
+                    examples = Batch.emtpyBatch(String.class);
+                    do {
+                        examples = trainingSet.getNegativeExamples(impactedTopics, examples.nextOffset);
+                        for (String example : examples.items) {
+                            if (!(offset % foldCount == foldIndex)) {
+                                // TODO: change the dataset API to include exampleId
+                                // this example is not part of the test fold, skip it
+                                offset++;
+                                continue;
+                            }
+                            offset++;
+                            if (classifier.suggestTopics(example).contains(topic)) {
+                                falsePositives++;
+                                // TODO: change the dataset API to include exampleId
+                                // falsePositiveExamples.add(exampleId);
+                            } else {
+                                trueNegatives++;
+                            }
+                        }
+                    } while (examples.hasMore); // TODO: put a bound on the number of examples
 
-                    // TODO: store performance statistics for current model in the original classifier
+                    // compute precision = TP / (TP + FP), recall = TP / (TP + FN) and f1 for the current test fold and topic
+                    float precision = 0;
+                    if (truePositives != 0 || falsePositives != 0) {
+                        precision = truePositives / (float) (truePositives + falsePositives);
+                    }
+                    float recall = 0;
+                    if (truePositives != 0 || falseNegatives != 0) {
+                        recall = truePositives / (float) (truePositives + falseNegatives);
+                    }
+                    float f1 = 0;
+                    if (precision != 0 || recall != 0) {
+                        f1 = 2 * precision * recall / (precision + recall);
+                    }
+                    updatePerformanceMetadata(topic, precision, recall, f1, falsePositiveExamples,
+                        falseNegativeExamples);
+                }
+                try {
+                    getActiveSolrServer().commit();
+                } catch (Exception e) {
+                    throw new ClassifierException(e);
                 }
                 return batch.size();
             }
@@ -874,37 +923,145 @@ public class TopicClassificationEngine e
             cvFoldIndex + 1, cvFoldCount, engineId, (stop - start) / 1000.0, averageF1));
     }
 
+    /**
+     * Append the performance statistics of one evaluation (e.g. one cross-validation fold) to the
+     * metadata entry of a topic and refresh its evaluation date. Previously stored values are kept, so
+     * repeated calls accumulate one value per call, which {@link #getPerformanceEstimates(String)}
+     * averages. It is the responsibility of the caller to commit.
+     *
+     * @param topicId URI of the topic; nothing is updated if the topic has no metadata entry
+     * @param precision precision measured on the evaluated fold
+     * @param recall recall measured on the evaluated fold
+     * @param f1 f1 score measured on the evaluated fold
+     * @param falsePositiveExamples ids of misclassified negative examples (not stored yet)
+     * @param falseNegativeExamples ids of missed positive examples (not stored yet)
+     * @throws ClassifierException if the metadata entry cannot be queried or re-indexed
+     */
+    protected void updatePerformanceMetadata(String topicId,
+                                             float precision,
+                                             float recall,
+                                             float f1,
+                                             List<String> falsePositiveExamples,
+                                             List<String> falseNegativeExamples) throws ClassifierException {
+        SolrServer solrServer = getActiveSolrServer();
+        try {
+            SolrQuery query = new SolrQuery(entryTypeField + ":" + METADATA_ENTRY + " AND " + topicUriField
+                                            + ":" + ClientUtils.escapeQueryChars(topicId));
+            for (SolrDocument result : solrServer.query(query).getResults()) {
+                // there should be only one (or none: tolerated)
+                // fetch any old values to update (all metadata fields are assumed to be stored)
+                Map<String,Collection<Object>> fieldValues = new HashMap<String,Collection<Object>>();
+                for (String fieldName : result.getFieldNames()) {
+                    fieldValues.put(fieldName, result.getFieldValues(fieldName));
+                }
+                addToList(fieldValues, precisionField, precision);
+                addToList(fieldValues, recallField, recall);
+                addToList(fieldValues, f1Field, f1);
+                // TODO: handle supports too...
+                // addToList(fieldValues, falsePositivesField, falsePositiveExamples);
+                // addToList(fieldValues, falseNegativesField, falseNegativeExamples);
+                SolrInputDocument newEntry = new SolrInputDocument();
+                for (Map.Entry<String,Collection<Object>> entry : fieldValues.entrySet()) {
+                    newEntry.addField(entry.getKey(), entry.getValue());
+                }
+                newEntry.setField(modelEvaluationDateField, UTCTimeStamper.nowUtcDate());
+                solrServer.add(newEntry);
+            }
+        } catch (Exception e) {
+            String msg = String.format(
+                "Error updating performance metadata for topic '%s' on Solr Core '%s'", topicId, solrCoreId);
+            throw new ClassifierException(msg, e);
+        }
+    }
+
+    /**
+     * Append a value to the values of a field in the given map. A collection value is flattened (all its
+     * elements are appended). A fresh list is always stored back in the map: the previous one is not mutated.
+     */
+    protected void addToList(Map<String,Collection<Object>> fieldValues, String fieldName, Object value) {
+        Collection<Object> values = new ArrayList<Object>();
+        if (fieldValues.get(fieldName) != null) {
+            values.addAll(fieldValues.get(fieldName));
+        }
+        if (value instanceof Collection) {
+            values.addAll((Collection<?>) value);
+        } else {
+            values.add(value);
+        }
+        fieldValues.put(fieldName, values);
+    }
+
     @Override
-    public ClassificationReport getPerformanceEstimates(String topic) throws ClassifierException {
+    public ClassificationReport getPerformanceEstimates(String topicId) throws ClassifierException {
 
         SolrServer solrServer = getActiveSolrServer();
-        SolrQuery query = new SolrQuery(entryIdField + ":" + METADATA_ENTRY + " AND " + topicUriField + ":"
-                                        + ClientUtils.escapeQueryChars(topic));
+        SolrQuery query = new SolrQuery(entryTypeField + ":" + METADATA_ENTRY + " AND " + topicUriField + ":"
+                                        + ClientUtils.escapeQueryChars(topicId));
         try {
             SolrDocumentList results = solrServer.query(query).getResults();
             if (results.isEmpty()) {
-                throw new ClassifierException(String.format("%s is not a registered topic", topic));
+                throw new ClassifierException(String.format("'%s' is not a registered topic", topicId));
             }
             SolrDocument metadata = results.get(0);
-            float precision = (Float) metadata.getFirstValue(precisionField);
-            float recall = (Float) metadata.getFirstValue(recallField);
-            float f1 = (Float) metadata.getFirstValue(f1Field);
-            int positiveSupport = (Integer) metadata.getFirstValue(positiveSupportField);
-            int negativeSupport = (Integer) metadata.getFirstValue(negativeSupportField);
+            Float precision = computeMeanValue(metadata, precisionField);
+            Float recall = computeMeanValue(metadata, recallField);
+            Float f1 = computeMeanValue(metadata, f1Field);
+            int positiveSupport = computeSumValue(metadata, positiveSupportField);
+            int negativeSupport = computeSumValue(metadata, negativeSupportField);
             Date evaluationDate = (Date) metadata.getFirstValue(modelEvaluationDateField);
             boolean uptodate = evaluationDate != null;
             ClassificationReport report = new ClassificationReport(precision, recall, f1, positiveSupport,
                     negativeSupport, uptodate, evaluationDate);
-            for (Object falsePositiveId : metadata.getFieldValues(FALSE_POSITIVES_FIELD)) {
+            if (metadata.getFieldValues(falsePositivesField) == null) {
+                metadata.setField(falsePositivesField, new ArrayList<Object>());
+            }
+            for (Object falsePositiveId : metadata.getFieldValues(falsePositivesField)) {
                 report.falsePositiveExampleIds.add(falsePositiveId.toString());
             }
-            for (Object falseNegativeId : metadata.getFieldValues(FALSE_NEGATIVES_FIELD)) {
+            if (metadata.getFieldValues(falseNegativesField) == null) {
+                metadata.setField(falseNegativesField, new ArrayList<Object>());
+            }
+            for (Object falseNegativeId : metadata.getFieldValues(falseNegativesField)) {
                 report.falseNegativeExampleIds.add(falseNegativeId.toString());
             }
             return report;
         } catch (SolrServerException e) {
-            throw new ClassifierException(String.format("Error fetching the performance report for topic "
-                                                        + topic));
+            // pass topicId as a format argument rather than concatenating it into the format string:
+            // topic URIs may contain '%' characters; also keep the cause for diagnosis
+            throw new ClassifierException(String.format(
+                "Error fetching the performance report for topic '%s'", topicId), e);
+        }
+    }
+
+    /**
+     * Compute the arithmetic mean of the Float values of a (multi-valued) field of a metadata entry.
+     * Returns 0 when the field is missing or empty.
+     */
+    protected Float computeMeanValue(SolrDocument metadata, String fieldName) {
+        float mean = 0f;
+        Collection<Object> values = metadata.getFieldValues(fieldName);
+        if (values == null || values.isEmpty()) {
+            return mean;
+        }
+        for (Object v : values) {
+            mean += (Float) v / values.size();
+        }
+        return mean;
+    }
+
+    /**
+     * Compute the sum of the Integer values of a (multi-valued) field of a metadata entry. Returns 0 when
+     * the field is missing or empty.
+     */
+    protected Integer computeSumValue(SolrDocument metadata, String fieldName) {
+        int sum = 0;
+        Collection<Object> values = metadata.getFieldValues(fieldName);
+        if (values == null || values.isEmpty()) {
+            return sum;
+        }
+        for (Object v : values) {
+            sum += (Integer) v;
         }
+        return sum;
     }
 }

Modified: 
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java?rev=1232533&r1=1232532&r2=1232533&view=diff
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
 (original)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
 Tue Jan 17 18:54:33 2012
@@ -337,7 +337,48 @@ public class TopicEngineTest extends Bas
 
     }
 
-    //@Test
+    @Test
+    public void testUpdatePerformanceEstimates() throws Exception {
+        ClassificationReport performanceEstimates;
+        // scores are averaged in float arithmetic: compare them with a tolerance, not exact equality
+        final float epsilon = 1e-6f;
+        // no registered topic
+        try {
+            classifier.getPerformanceEstimates("urn:t/001");
+            fail("Should have raised ClassifierException");
+        } catch (ClassifierException e) {
+            // expected
+        }
+
+        // register some topics
+        classifier.addTopic("urn:t/001", null);
+        classifier.addTopic("urn:t/002", Arrays.asList("urn:t/001"));
+        performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
+        assertFalse(performanceEstimates.uptodate);
+
+        // update the performance metadata manually
+        classifier.updatePerformanceMetadata("urn:t/002", 0.76f, 0.60f, 0.67f, Arrays.asList("ex14", "ex78"),
+            Arrays.asList("ex34", "ex23", "ex89"));
+        classifier.getActiveSolrServer().commit();
+        performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
+        assertTrue(performanceEstimates.uptodate);
+        assertEquals(0.76f, performanceEstimates.precision, epsilon);
+        assertEquals(0.60f, performanceEstimates.recall, epsilon);
+        assertEquals(0.67f, performanceEstimates.f1, epsilon);
+        assertTrue(classifier.getBroaderTopics("urn:t/002").contains("urn:t/001"));
+
+        // accumulate other folds statistics and compute means of statistics
+        classifier.updatePerformanceMetadata("urn:t/002", 0.79f, 0.63f, 0.72f, Arrays.asList("ex1", "ex5"),
+            Arrays.asList("ex3", "ex10", "ex11"));
+        classifier.getActiveSolrServer().commit();
+        performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
+        assertTrue(performanceEstimates.uptodate);
+        assertEquals(0.775f, performanceEstimates.precision, epsilon);
+        assertEquals(0.615f, performanceEstimates.recall, epsilon);
+        assertEquals(0.695f, performanceEstimates.f1, epsilon);
+    }
+
+    // @Test
     public void testCrossValidation() throws Exception {
         // seed a pseudo random number generator for reproducible tests
         Random rng = new Random(0);


Reply via email to