Author: ogrisel
Date: Fri Jan 13 18:58:24 2012
New Revision: 1231246

URL: http://svn.apache.org/viewvc?rev=1231246&view=rev
Log:
STANBOL-197: WIP: TDD for cross validation-based evaluation

Modified:
    
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
    
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java
    
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
    
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java

Modified: 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java?rev=1231246&r1=1231245&r2=1231246&view=diff
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
 (original)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
 Fri Jan 13 18:58:24 2012
@@ -677,8 +677,12 @@ public class TopicClassificationEngine e
 
     @Override
     public void setCrossValidationInfo(int foldIndex, int foldCount) {
-        // TODO Auto-generated method stub
-
+        if (foldIndex > foldCount - 1) {
+            throw new IllegalArgumentException(String.format(
+                "foldIndex=%d should be smaller than foldCount=%d - 1", 
foldIndex, foldCount));
+        }
+        cvFoldIndex = foldIndex;
+        cvFoldCount = foldCount;
     }
 
     @Override
@@ -693,13 +697,36 @@ public class TopicClassificationEngine e
 
     }
 
-    public void updatePerformanceEstimates(boolean incremental) throws 
ClassifierException, TrainingSetException {
-        
+    public int updatePerformanceEstimates(boolean incremental) throws 
ClassifierException,
+                                                              
TrainingSetException {
+        int updatedTopics = 0;
+        // TODO
+        return updatedTopics;
     }
 
     @Override
     public ClassificationReport getPerformanceEstimates(String topic) throws 
ClassifierException {
-        // TODO Auto-generated method stub
-        return null;
+
+        SolrServer solrServer = getActiveSolrServer();
+        SolrQuery query = new SolrQuery(entryIdField + ":" + METADATA_ENTRY + 
" AND " + topicUriField + ":"
+                                        + ClientUtils.escapeQueryChars(topic));
+        try {
+            SolrDocumentList results = solrServer.query(query).getResults();
+            if (results.isEmpty()) {
+                throw new ClassifierException(String.format("%s is not a 
registered topic", topic));
+            }
+            SolrDocument metadata = results.get(0);
+            float precision = (Float) metadata.getFirstValue(precisionField);
+            float recall = (Float) metadata.getFirstValue(recallField);
+            float f1 = (Float) metadata.getFirstValue(f1Field);
+            // int positiveSupport = (Integer) metadata.getFirstValue(po);
+            // int negativeSupport = 0;
+            Date evaluationDate = (Date) 
metadata.getFirstValue(modelEvaluationDateField);
+            boolean uptodate = evaluationDate != null;
+            return new ClassificationReport(precision, recall, f1, 0, 0, 
uptodate, evaluationDate);
+        } catch (SolrServerException e) {
+            throw new ClassifierException(String.format("Error fetching the 
performance report for topic "
+                                                        + topic));
+        }
     }
 }

Modified: 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java?rev=1231246&r1=1231245&r2=1231246&view=diff
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java
 (original)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java
 Fri Jan 13 18:58:24 2012
@@ -17,6 +17,7 @@
 package org.apache.stanbol.enhancer.topic;
 
 import java.util.ArrayList;
+import java.util.Date;
 import java.util.List;
 
 /**
@@ -65,6 +66,10 @@ public class ClassificationReport {
      */
     public final int negativeSupport;
 
+    public final boolean uptodate;
+
+    public final Date evaluationDate;
+
     public final List<String> falsePositiveExampleIds = new 
ArrayList<String>();
 
     public final List<String> falseNegativeExampleIds = new 
ArrayList<String>();
@@ -73,12 +78,16 @@ public class ClassificationReport {
                                 float recall,
                                 float f1,
                                 int positiveSupport,
-                                int negativeSupport) {
+                                int negativeSupport,
+                                boolean uptodate,
+                                Date evaluationDate) {
         this.precision = precision;
         this.recall = recall;
         this.f1 = f1;
         this.positiveSupport = positiveSupport;
         this.negativeSupport = negativeSupport;
+        this.uptodate = uptodate;
+        this.evaluationDate = evaluationDate;
     }
 
 }

Modified: 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java?rev=1231246&r1=1231245&r2=1231246&view=diff
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
 (original)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
 Fri Jan 13 18:58:24 2012
@@ -112,9 +112,11 @@ public interface TopicClassifier {
     /**
      * Perform k-fold cross validation of the model to compute estimates of 
the precision, recall and f1
      * score.
+     * 
+     * @return number of updated topics
      */
-    public void updatePerformanceEstimates(boolean incremental) throws 
ClassifierException,
-                                                               
TrainingSetException;
+    public int updatePerformanceEstimates(boolean incremental) throws 
ClassifierException,
+                                                              
TrainingSetException;
 
     /**
      * Tell the classifier which slice of data to keep aside while training 
for model evaluation using k-folds

Modified: 
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java?rev=1231246&r1=1231245&r2=1231246&view=diff
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
 (original)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
 Fri Jan 13 18:58:24 2012
@@ -17,6 +17,7 @@
 package org.apache.stanbol.enhancer.engine.topic;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
@@ -28,17 +29,24 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Hashtable;
 import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.TreeMap;
 
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.io.IOUtils;
+import org.apache.commons.lang.StringUtils;
 import org.apache.solr.client.solrj.SolrQuery;
 import org.apache.solr.client.solrj.SolrServerException;
 import org.apache.solr.client.solrj.embedded.EmbeddedSolrServer;
 import org.apache.solr.client.solrj.response.QueryResponse;
 import org.apache.solr.common.params.CommonParams;
 import org.apache.stanbol.commons.solr.utils.StreamQueryRequest;
+import org.apache.stanbol.enhancer.topic.ClassificationReport;
+import org.apache.stanbol.enhancer.topic.ClassifierException;
 import org.apache.stanbol.enhancer.topic.SolrTrainingSet;
 import org.apache.stanbol.enhancer.topic.TopicSuggestion;
+import org.apache.stanbol.enhancer.topic.TrainingSetException;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -321,12 +329,145 @@ public class TopicEngineTest extends Bas
         classifier.addTopic(law, null);
         assertEquals(1, classifier.updateModel(true));
         assertEquals(0, classifier.updateModel(true));
-        
+
         // registering new subtopics invalidate the models of the parent as 
well
         classifier.addTopic("urn:topics/sportsmafia", Arrays.asList(football, 
business));
         assertEquals(3, classifier.updateModel(true));
         assertEquals(0, classifier.updateModel(true));
-        
+
+    }
+
+    //@Test
+    public void testCrossValidation() throws Exception {
+        // seed a pseudo random number generator for reproducible tests
+        Random rng = new Random(0);
+        ClassificationReport performanceEstimates;
+
+        // build an artificial data set used for training models and evaluation
+        int numberOfTopics = 10;
+        int numberOfDocuments = 100;
+        int vocabSizeMin = 10;
+        int vocabSizeMax = 25; // we are using the alphabet as base terms
+        initArtificialTrainingSet(numberOfTopics, numberOfDocuments, 
vocabSizeMin, vocabSizeMax, rng);
+
+        // by default the reports are not computed
+        performanceEstimates = classifier.getPerformanceEstimates("urn:t/001");
+        assertFalse(performanceEstimates.uptodate);
+        performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
+        assertFalse(performanceEstimates.uptodate);
+        performanceEstimates = classifier.getPerformanceEstimates("urn:t/003");
+        assertFalse(performanceEstimates.uptodate);
+
+        try {
+            classifier.getPerformanceEstimates("urn:doesnotexist");
+            fail("Should have raised a ClassifierException");
+        } catch (ClassifierException e) {
+            // expected
+        }
+
+        // let's evaluate all the topics and check some of the resulting reports
+        assertEquals(numberOfTopics, 
classifier.updatePerformanceEstimates(true));
+        performanceEstimates = classifier.getPerformanceEstimates("urn:t/001");
+        assertTrue(performanceEstimates.uptodate);
+        assertGreater(performanceEstimates.precision, 0.8f);
+        assertGreater(performanceEstimates.recall, 0.8f);
+        assertGreater(performanceEstimates.f1, 0.8f);
+        assertGreater(performanceEstimates.positiveSupport, 10);
+        assertGreater(performanceEstimates.negativeSupport, 90);
+        assertNotNull(performanceEstimates.evaluationDate);
+
+        performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
+        assertTrue(performanceEstimates.uptodate);
+        assertGreater(performanceEstimates.precision, 0.8f);
+        assertGreater(performanceEstimates.recall, 0.8f);
+        assertGreater(performanceEstimates.f1, 0.8f);
+        assertGreater(performanceEstimates.positiveSupport, 10);
+        assertGreater(performanceEstimates.negativeSupport, 90);
+        assertNotNull(performanceEstimates.evaluationDate);
+
+        performanceEstimates = classifier.getPerformanceEstimates("urn:t/003");
+        assertTrue(performanceEstimates.uptodate);
+        assertGreater(performanceEstimates.precision, 0.8f);
+        assertGreater(performanceEstimates.recall, 0.8f);
+        assertGreater(performanceEstimates.f1, 0.8f);
+        assertGreater(performanceEstimates.positiveSupport, 10);
+        assertGreater(performanceEstimates.negativeSupport, 90);
+        assertNotNull(performanceEstimates.evaluationDate);
+
+        // TODO: test model invalidation by registering a sub topic manually
+    }
+
+    protected void assertGreater(float large, float small) {
+        if (small > large) {
+            throw new AssertionError(String.format("Expected %f to be greater 
than %f.", large, small));
+        }
+    }
+
+    protected void initArtificialTrainingSet(int numberOfTopics,
+                                             int numberOfDocuments,
+                                             int vocabSizeMin,
+                                             int vocabSizeMax,
+                                             Random rng) throws 
ClassifierException, TrainingSetException {
+        // define some artificial topics and register them to the classifier
+        char[] alphabet = "abcdefghijklmnopqrstuvwxyz".toCharArray();
+        String[] topics = new String[numberOfTopics];
+        Map<String,String[]> vocabularies = new TreeMap<String,String[]>();
+        for (int i = 0; i < numberOfTopics; i++) {
+            String topic = String.format("urn:t/%03d", i);
+            topics[i] = topic;
+            classifier.addTopic(topic, null);
+            int vocSize = rng.nextInt(vocabSizeMax + 1 - vocabSizeMin) + 
vocabSizeMin;
+            String[] terms = new String[vocSize];
+
+            for (int j = 0; j < vocSize; j++) {
+                // define some artificial vocabulary for each topic to 
automatically generate random examples
+                // with some topic structure
+                // if i = 1, will generate: ["a1", "b1", "c1", ...]
+                terms[j] = alphabet[j] + String.valueOf(i);
+            }
+            vocabularies.put(topic, terms);
+        }
+        classifier.setTrainingSet(trainingSet);
+
+        // build a random data set where each example has a couple of dominating 
topics and other words
+        for (int i = 0; i < numberOfDocuments; i++) {
+            List<String> documentTerms = new ArrayList<String>();
+
+            // add terms from some dominant topics that are used as 
classification targets
+            int numberOfDominantTopics = rng.nextInt(4) + 1; // between 1 and 
4 topics
+            List<String> documentTopics = new ArrayList<String>();
+            for (int j = 0; j < numberOfDominantTopics; j++) {
+                String topic = randomTopicAndTerms(topics, vocabularies, 
documentTerms, 50, 100, rng);
+                documentTopics.add(topic);
+            }
+            // add terms from some non-dominant topics
+            for (int j = 0; j < 10; j++) {
+                String topic = randomTopicAndTerms(topics, vocabularies, 
documentTerms, 1, 10, rng);
+                documentTopics.add(topic);
+            }
+            // add some non discriminative terms not linked to any topic
+            for (int k = 0; k < 100; k++) {
+                
documentTerms.add(String.valueOf(alphabet[rng.nextInt(alphabet.length)]));
+            }
+            // register the generated example in the training set
+            trainingSet.registerExample(String.format("example_%03d", i),
+                StringUtils.join(documentTerms, " "), Arrays.asList(topics));
+        }
+    }
+
+    protected String randomTopicAndTerms(String[] topics,
+                                         Map<String,String[]> vocabularies,
+                                         List<String> documentTerms,
+                                         int min,
+                                         int max,
+                                         Random rng) {
+        String topic = topics[rng.nextInt(topics.length)];
+        String[] terms = vocabularies.get(topic);
+        int numberOfDominantTopicTerm = rng.nextInt(max + 1 - min) + min;
+        for (int k = 0; k < numberOfDominantTopicTerm; k++) {
+            documentTerms.add(terms[rng.nextInt(terms.length)]);
+        }
+        return topic;
     }
 
     protected Hashtable<String,Object> getDefaultClassifierConfigParams() {


Reply via email to