Author: ogrisel
Date: Fri Jan  6 16:32:01 2012
New Revision: 1228250

URL: http://svn.apache.org/viewvc?rev=1228250&view=rev
Log:
STANBOL-197: write tests first (TDD) for the future classifier learning infrastructure

Modified:
    incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
    incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
    incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java

Modified: 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java?rev=1228250&r1=1228249&r2=1228250&view=diff
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
 (original)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
 Fri Jan  6 16:32:01 2012
@@ -89,7 +89,8 @@ import org.slf4j.LoggerFactory;
                      @Property(name = TopicClassificationEngine.BROADER_FIELD),
                      @Property(name = 
TopicClassificationEngine.MATERIALIZED_PATH_FIELD),
                      @Property(name = 
TopicClassificationEngine.MODEL_UPDATE_DATE_FIELD)})
-public class TopicClassificationEngine extends ConfiguredSolrCoreTracker 
implements EnhancementEngine, ServiceProperties, TopicClassifier {
+public class TopicClassificationEngine extends ConfiguredSolrCoreTracker 
implements EnhancementEngine,
+        ServiceProperties, TopicClassifier {
 
     public static final String ENGINE_ID = 
"org.apache.stanbol.enhancer.engine.id";
 
@@ -401,11 +402,12 @@ public class TopicClassificationEngine e
     }
 
     @Override
-    public void updateModel() throws TrainingSetException {
+    public int updateModel(boolean incremental) throws TrainingSetException {
         checkTrainingSet();
         // TODO:
         // perform a first query to iterate over all the registered topics 
sorted by id (to allow for paging)
         // for each topic find the last update date of the union of the topic 
and it's narrower topic
+        return 0;
     }
 
     protected void checkTrainingSet() throws TrainingSetException {

Modified: 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java?rev=1228250&r1=1228249&r2=1228250&view=diff
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
 (original)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
 Fri Jan  6 16:32:01 2012
@@ -101,9 +101,11 @@ public interface TopicClassifier {
     void setTrainingSet(TrainingSet trainingSet);
 
     /**
-     * Incrementally update the statistical model of the classifier. Note: 
depending on the size of the
-     * dataset and the number of topics to update, this process can take a 
long time and should probably be
-     * wrapped in a dedicated thread if called by a the user interface layer.
+     * Update (incrementally or from scratch) the statistical model of the 
classifier. Note: depending on the
+     * size of the dataset and the number of topics to update, this process 
can take a long time and should
+     * probably be wrapped in a dedicated thread if called by a the user 
interface layer.
+     * 
+     * @return the number of updated topics
      */
-    void updateModel() throws TrainingSetException;
+    int updateModel(boolean incremental) throws TrainingSetException;
 }

Modified: 
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
URL: 
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java?rev=1228250&r1=1228249&r2=1228250&view=diff
==============================================================================
--- 
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
 (original)
+++ 
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
 Fri Jan  6 16:32:01 2012
@@ -38,6 +38,7 @@ import org.apache.solr.client.solrj.embe
 import org.apache.solr.client.solrj.response.QueryResponse;
 import org.apache.solr.common.params.CommonParams;
 import org.apache.stanbol.commons.solr.utils.StreamQueryRequest;
+import org.apache.stanbol.enhancer.topic.SolrTrainingSet;
 import org.apache.stanbol.enhancer.topic.TopicSuggestion;
 import org.junit.After;
 import org.junit.Before;
@@ -53,14 +54,25 @@ public class TopicEngineTest extends Bas
 
     EmbeddedSolrServer classifierSolrServer;
 
+    EmbeddedSolrServer trainingSetSolrServer;
+
     File solrHome;
 
+    SolrTrainingSet trainingSet;
+
+    TopicClassificationEngine classifier;
+
     @Before
-    public void setup() throws IOException, ParserConfigurationException, 
SAXException {
+    public void setup() throws Exception {
         solrHome = File.createTempFile("topicEngineTest_", "_solr_cores");
         solrHome.delete();
         solrHome.mkdir();
         classifierSolrServer = makeEmptyEmbeddedSolrServer(solrHome, 
"classifierserver", "classifier");
+        classifier = 
TopicClassificationEngine.fromParameters(getDefaultClassifierConfigParams());
+
+        trainingSetSolrServer = makeEmptyEmbeddedSolrServer(solrHome, 
"trainingsetserver", "trainingset");
+        trainingSet = new SolrTrainingSet();
+        trainingSet.configure(getDefaultTrainingSetConfigParams());
     }
 
     @After
@@ -68,6 +80,8 @@ public class TopicEngineTest extends Bas
         FileUtils.deleteQuietly(solrHome);
         solrHome = null;
         classifierSolrServer = null;
+        trainingSetSolrServer = null;
+        trainingSet = null;
     }
 
     protected void loadSampleTopicsFromTSV() throws IOException, 
SolrServerException {
@@ -92,26 +106,16 @@ public class TopicEngineTest extends Bas
         log.info(String.format("Indexed test topics in %dms", 
response.getElapsedTime()));
     }
 
-    protected Hashtable<String,Object> getDefaultConfigParams() {
-        Hashtable<String,Object> config = new Hashtable<String,Object>();
-        config.put(TopicClassificationEngine.ENGINE_ID, "test-engine");
-        config.put(TopicClassificationEngine.SOLR_CORE, classifierSolrServer);
-        config.put(TopicClassificationEngine.TOPIC_URI_FIELD, "topic");
-        config.put(TopicClassificationEngine.SIMILARTITY_FIELD, "text");
-        config.put(TopicClassificationEngine.BROADER_FIELD, "broader");
-        return config;
-    }
-
     @Test
     public void testEngineConfiguation() throws ConfigurationException {
-        Hashtable<String,Object> config = getDefaultConfigParams();
-        TopicClassificationEngine engine = 
TopicClassificationEngine.fromParameters(config);
-        assertNotNull(engine);
-        assertEquals(engine.engineId, "test-engine");
-        assertEquals(engine.getActiveSolrServer(), classifierSolrServer);
-        assertEquals(engine.topicUriField, "topic");
-        assertEquals(engine.similarityField, "text");
-        assertEquals(engine.acceptedLanguages, new ArrayList<String>());
+        Hashtable<String,Object> config = getDefaultClassifierConfigParams();
+        TopicClassificationEngine classifier = 
TopicClassificationEngine.fromParameters(config);
+        assertNotNull(classifier);
+        assertEquals(classifier.engineId, "test-engine");
+        assertEquals(classifier.getActiveSolrServer(), classifierSolrServer);
+        assertEquals(classifier.topicUriField, "topic");
+        assertEquals(classifier.similarityField, "text");
+        assertEquals(classifier.acceptedLanguages, new ArrayList<String>());
 
         // check some required attributes
         Hashtable<String,Object> configWithMissingTopicField = new 
Hashtable<String,Object>();
@@ -134,54 +138,54 @@ public class TopicEngineTest extends Bas
         Hashtable<String,Object> configWithAcceptLangage = new 
Hashtable<String,Object>();
         configWithAcceptLangage.putAll(config);
         configWithAcceptLangage.put(TopicClassificationEngine.LANGUAGES, "en, 
fr");
-        engine = 
TopicClassificationEngine.fromParameters(configWithAcceptLangage);
-        assertNotNull(engine);
-        assertEquals(engine.acceptedLanguages, Arrays.asList("en", "fr"));
+        classifier = 
TopicClassificationEngine.fromParameters(configWithAcceptLangage);
+        assertNotNull(classifier);
+        assertEquals(classifier.acceptedLanguages, Arrays.asList("en", "fr"));
     }
 
     @Test
     public void testProgrammaticThesaurusConstruction() throws Exception {
-        TopicClassificationEngine engine = 
TopicClassificationEngine.fromParameters(getDefaultConfigParams());
-
         // Register the roots of the taxonomy
-        engine.addTopic("http://example.com/topics/root1";, null);
-        engine.addTopic("http://example.com/topics/root2";, null);
-        engine.addTopic("http://example.com/topics/root3";, new 
ArrayList<String>());
-        assertEquals(0, 
engine.getBroaderTopics("http://example.com/topics/root1";).size());
-        assertEquals(0, 
engine.getBroaderTopics("http://example.com/topics/root2";).size());
-        assertEquals(0, 
engine.getBroaderTopics("http://example.com/topics/root3";).size());
-        assertEquals(3, engine.getTopicRoots().size());
+        classifier.addTopic("http://example.com/topics/root1";, null);
+        classifier.addTopic("http://example.com/topics/root2";, null);
+        classifier.addTopic("http://example.com/topics/root3";, new 
ArrayList<String>());
+        assertEquals(0, 
classifier.getBroaderTopics("http://example.com/topics/root1";).size());
+        assertEquals(0, 
classifier.getBroaderTopics("http://example.com/topics/root2";).size());
+        assertEquals(0, 
classifier.getBroaderTopics("http://example.com/topics/root3";).size());
+        assertEquals(3, classifier.getTopicRoots().size());
 
         // Register some non root nodes
-        engine.addTopic("http://example.com/topics/node1";,
+        classifier.addTopic("http://example.com/topics/node1";,
             Arrays.asList("http://example.com/topics/root1";, 
"http://example.com/topics/root2";));
-        engine.addTopic("http://example.com/topics/node2";, 
Arrays.asList("http://example.com/topics/root3";));
-        engine.addTopic("http://example.com/topics/node3";,
+        classifier.addTopic("http://example.com/topics/node2";,
+            Arrays.asList("http://example.com/topics/root3";));
+        classifier.addTopic("http://example.com/topics/node3";,
             Arrays.asList("http://example.com/topics/node1";, 
"http://example.com/topics/node2";));
 
         // the root where not impacted
-        assertEquals(0, 
engine.getBroaderTopics("http://example.com/topics/root1";).size());
-        assertEquals(0, 
engine.getBroaderTopics("http://example.com/topics/root2";).size());
-        assertEquals(0, 
engine.getBroaderTopics("http://example.com/topics/root3";).size());
-        assertEquals(3, engine.getTopicRoots().size());
+        assertEquals(0, 
classifier.getBroaderTopics("http://example.com/topics/root1";).size());
+        assertEquals(0, 
classifier.getBroaderTopics("http://example.com/topics/root2";).size());
+        assertEquals(0, 
classifier.getBroaderTopics("http://example.com/topics/root3";).size());
+        assertEquals(3, classifier.getTopicRoots().size());
 
         // the other nodes have the same broader topics as at creation time
-        assertEquals(2, 
engine.getBroaderTopics("http://example.com/topics/node1";).size());
-        assertEquals(1, 
engine.getBroaderTopics("http://example.com/topics/node2";).size());
-        assertEquals(2, 
engine.getBroaderTopics("http://example.com/topics/node3";).size());
+        assertEquals(2, 
classifier.getBroaderTopics("http://example.com/topics/node1";).size());
+        assertEquals(1, 
classifier.getBroaderTopics("http://example.com/topics/node2";).size());
+        assertEquals(2, 
classifier.getBroaderTopics("http://example.com/topics/node3";).size());
 
         // check the induced narrower relationships
-        assertEquals(1, 
engine.getNarrowerTopics("http://example.com/topics/root1";).size());
-        assertEquals(1, 
engine.getNarrowerTopics("http://example.com/topics/root2";).size());
-        assertEquals(1, 
engine.getNarrowerTopics("http://example.com/topics/root3";).size());
-        assertEquals(1, 
engine.getNarrowerTopics("http://example.com/topics/node1";).size());
-        assertEquals(1, 
engine.getNarrowerTopics("http://example.com/topics/node2";).size());
-        assertEquals(0, 
engine.getNarrowerTopics("http://example.com/topics/node3";).size());
+        assertEquals(1, 
classifier.getNarrowerTopics("http://example.com/topics/root1";).size());
+        assertEquals(1, 
classifier.getNarrowerTopics("http://example.com/topics/root2";).size());
+        assertEquals(1, 
classifier.getNarrowerTopics("http://example.com/topics/root3";).size());
+        assertEquals(1, 
classifier.getNarrowerTopics("http://example.com/topics/node1";).size());
+        assertEquals(1, 
classifier.getNarrowerTopics("http://example.com/topics/node2";).size());
+        assertEquals(0, 
classifier.getNarrowerTopics("http://example.com/topics/node3";).size());
     }
 
     @Test
     public void testEmptyIndexTopicClassification() throws Exception {
-        TopicClassificationEngine engine = 
TopicClassificationEngine.fromParameters(getDefaultConfigParams());
+        TopicClassificationEngine engine = TopicClassificationEngine
+                .fromParameters(getDefaultClassifierConfigParams());
         List<TopicSuggestion> suggestedTopics = engine.suggestTopics("This is 
a test.");
         assertNotNull(suggestedTopics);
         assertEquals(suggestedTopics.size(), 0);
@@ -190,8 +194,7 @@ public class TopicEngineTest extends Bas
     @Test
     public void testTopicClassification() throws Exception {
         loadSampleTopicsFromTSV();
-        TopicClassificationEngine engine = 
TopicClassificationEngine.fromParameters(getDefaultConfigParams());
-        List<TopicSuggestion> suggestedTopics = engine
+        List<TopicSuggestion> suggestedTopics = classifier
                 .suggestTopics("The Man Who Shot Liberty Valance is a 1962"
                                + " American Western film directed by John 
Ford,"
                                + " narrated by Charlton Heston and starring 
James"
@@ -201,4 +204,77 @@ public class TopicEngineTest extends Bas
         TopicSuggestion bestSuggestion = suggestedTopics.get(0);
         assertEquals(bestSuggestion.uri, "Category:American_films");
     }
+
+    //@Test
+    public void testTrainClassifierFromExamples() throws Exception {
+
+        // mini taxonomy for news articles
+        String business = "urn:topics/business";
+        String technology = "urn:topics/technology";
+        String apple = "urn:topics/apple";
+        String sport = "urn:topics/sport";
+        String football = "urn:topics/football";
+        String wordcup = "urn:topics/wordcup";
+
+        classifier.addTopic(business, null);
+        classifier.addTopic(technology, null);
+        classifier.addTopic(sport, null);
+        classifier.addTopic(apple, Arrays.asList(business, technology));
+        classifier.addTopic(football, Arrays.asList(sport));
+        classifier.addTopic(wordcup, Arrays.asList(football));
+
+        // train the classifier on an empty dataset
+        classifier.setTrainingSet(trainingSet);
+        assertEquals(6, classifier.updateModel(true));
+
+        // the model is updated but does not predict anything
+        List<TopicSuggestion> suggestions = classifier
+                .suggestTopics("I like the sound of vuvuzula in the morning!");
+        assertEquals(0, suggestions.size());
+
+        // further update of the model leave do not change any topic
+        assertEquals(0, classifier.updateModel(true));
+
+        // lets register some examples
+        trainingSet.registerExample(null, "Money, money, money is the root of 
all evil.",
+            Arrays.asList(business));
+        trainingSet.registerExample(null, "VC invested more money in tech 
startups in 2011.",
+            Arrays.asList(business, technology));
+        trainingSet.registerExample(null, "Apple sold many iPads at a very 
high price"
+                                          + " and made record profits.", 
Arrays.asList(apple, business));
+        trainingSet.registerExample(null, "Manchester United won 3-2 against 
FC Barcelona.",
+            Arrays.asList(football));
+        trainingSet.registerExample(null, "Vuvuzela made the soundtrack of the"
+                                          + " football wordcup of 2010 in 
South Africa.",
+            Arrays.asList(football, wordcup));
+
+        // retrain the model: all 6 topics are impacted by the new examples
+        assertEquals(6, classifier.updateModel(true));
+        suggestions = classifier.suggestTopics("I like the sound of vuvuzula 
in the morning!");
+        assertEquals(3, suggestions.size());
+        assertEquals(wordcup, suggestions.get(0).uri);
+        assertEquals(football, suggestions.get(1).uri);
+        assertEquals(sport, suggestions.get(2).uri);
+    }
+
+    protected Hashtable<String,Object> getDefaultClassifierConfigParams() {
+        Hashtable<String,Object> config = new Hashtable<String,Object>();
+        config.put(TopicClassificationEngine.ENGINE_ID, "test-engine");
+        config.put(TopicClassificationEngine.SOLR_CORE, classifierSolrServer);
+        config.put(TopicClassificationEngine.TOPIC_URI_FIELD, "topic");
+        config.put(TopicClassificationEngine.SIMILARTITY_FIELD, "text");
+        config.put(TopicClassificationEngine.BROADER_FIELD, "broader");
+        return config;
+    }
+
+    protected Hashtable<String,Object> getDefaultTrainingSetConfigParams() {
+        Hashtable<String,Object> config = new Hashtable<String,Object>();
+        config.put(SolrTrainingSet.SOLR_CORE, trainingSetSolrServer);
+        config.put(SolrTrainingSet.TRAINING_SET_ID, "test-training-set");
+        config.put(SolrTrainingSet.EXAMPLE_ID_FIELD, "id");
+        config.put(SolrTrainingSet.EXAMPLE_TEXT_FIELD, "text");
+        config.put(SolrTrainingSet.TOPICS_URI_FIELD, "topics");
+        config.put(SolrTrainingSet.MODIFICATION_DATE_FIELD, "modification_dt");
+        return config;
+    }
 }


Reply via email to