Author: ogrisel
Date: Fri Jan 6 16:32:01 2012
New Revision: 1228250
URL: http://svn.apache.org/viewvc?rev=1228250&view=rev
Log:
STANBOL-197: TDD: write tests for the future classifier learning infrastructure
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java?rev=1228250&r1=1228249&r2=1228250&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
Fri Jan 6 16:32:01 2012
@@ -89,7 +89,8 @@ import org.slf4j.LoggerFactory;
@Property(name = TopicClassificationEngine.BROADER_FIELD),
@Property(name =
TopicClassificationEngine.MATERIALIZED_PATH_FIELD),
@Property(name =
TopicClassificationEngine.MODEL_UPDATE_DATE_FIELD)})
-public class TopicClassificationEngine extends ConfiguredSolrCoreTracker
implements EnhancementEngine, ServiceProperties, TopicClassifier {
+public class TopicClassificationEngine extends ConfiguredSolrCoreTracker
implements EnhancementEngine,
+ ServiceProperties, TopicClassifier {
public static final String ENGINE_ID =
"org.apache.stanbol.enhancer.engine.id";
@@ -401,11 +402,12 @@ public class TopicClassificationEngine e
}
@Override
- public void updateModel() throws TrainingSetException {
+ public int updateModel(boolean incremental) throws TrainingSetException {
checkTrainingSet();
// TODO:
// perform a first query to iterate over all the registered topics
sorted by id (to allow for paging)
// for each topic find the last update date of the union of the topic
and it's narrower topic
+ return 0;
}
protected void checkTrainingSet() throws TrainingSetException {
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java?rev=1228250&r1=1228249&r2=1228250&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicClassifier.java
Fri Jan 6 16:32:01 2012
@@ -101,9 +101,11 @@ public interface TopicClassifier {
void setTrainingSet(TrainingSet trainingSet);
/**
- * Incrementally update the statistical model of the classifier. Note:
depending on the size of the
- * dataset and the number of topics to update, this process can take a
long time and should probably be
- * wrapped in a dedicated thread if called by a the user interface layer.
+ * Update (incrementally or from scratch) the statistical model of the
classifier. Note: depending on the
+ * size of the dataset and the number of topics to update, this process
can take a long time and should
+ * probably be wrapped in a dedicated thread if called by the user
interface layer.
+ *
+ * @return the number of updated topics
*/
- void updateModel() throws TrainingSetException;
+ int updateModel(boolean incremental) throws TrainingSetException;
}
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java?rev=1228250&r1=1228249&r2=1228250&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
Fri Jan 6 16:32:01 2012
@@ -38,6 +38,7 @@ import org.apache.solr.client.solrj.embe
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.common.params.CommonParams;
import org.apache.stanbol.commons.solr.utils.StreamQueryRequest;
+import org.apache.stanbol.enhancer.topic.SolrTrainingSet;
import org.apache.stanbol.enhancer.topic.TopicSuggestion;
import org.junit.After;
import org.junit.Before;
@@ -53,14 +54,25 @@ public class TopicEngineTest extends Bas
EmbeddedSolrServer classifierSolrServer;
+ EmbeddedSolrServer trainingSetSolrServer;
+
File solrHome;
+ SolrTrainingSet trainingSet;
+
+ TopicClassificationEngine classifier;
+
@Before
- public void setup() throws IOException, ParserConfigurationException,
SAXException {
+ public void setup() throws Exception {
solrHome = File.createTempFile("topicEngineTest_", "_solr_cores");
solrHome.delete();
solrHome.mkdir();
classifierSolrServer = makeEmptyEmbeddedSolrServer(solrHome,
"classifierserver", "classifier");
+ classifier =
TopicClassificationEngine.fromParameters(getDefaultClassifierConfigParams());
+
+ trainingSetSolrServer = makeEmptyEmbeddedSolrServer(solrHome,
"trainingsetserver", "trainingset");
+ trainingSet = new SolrTrainingSet();
+ trainingSet.configure(getDefaultTrainingSetConfigParams());
}
@After
@@ -68,6 +80,8 @@ public class TopicEngineTest extends Bas
FileUtils.deleteQuietly(solrHome);
solrHome = null;
classifierSolrServer = null;
+ trainingSetSolrServer = null;
+ trainingSet = null;
}
protected void loadSampleTopicsFromTSV() throws IOException,
SolrServerException {
@@ -92,26 +106,16 @@ public class TopicEngineTest extends Bas
log.info(String.format("Indexed test topics in %dms",
response.getElapsedTime()));
}
- protected Hashtable<String,Object> getDefaultConfigParams() {
- Hashtable<String,Object> config = new Hashtable<String,Object>();
- config.put(TopicClassificationEngine.ENGINE_ID, "test-engine");
- config.put(TopicClassificationEngine.SOLR_CORE, classifierSolrServer);
- config.put(TopicClassificationEngine.TOPIC_URI_FIELD, "topic");
- config.put(TopicClassificationEngine.SIMILARTITY_FIELD, "text");
- config.put(TopicClassificationEngine.BROADER_FIELD, "broader");
- return config;
- }
-
@Test
public void testEngineConfiguation() throws ConfigurationException {
- Hashtable<String,Object> config = getDefaultConfigParams();
- TopicClassificationEngine engine =
TopicClassificationEngine.fromParameters(config);
- assertNotNull(engine);
- assertEquals(engine.engineId, "test-engine");
- assertEquals(engine.getActiveSolrServer(), classifierSolrServer);
- assertEquals(engine.topicUriField, "topic");
- assertEquals(engine.similarityField, "text");
- assertEquals(engine.acceptedLanguages, new ArrayList<String>());
+ Hashtable<String,Object> config = getDefaultClassifierConfigParams();
+ TopicClassificationEngine classifier =
TopicClassificationEngine.fromParameters(config);
+ assertNotNull(classifier);
+ assertEquals(classifier.engineId, "test-engine");
+ assertEquals(classifier.getActiveSolrServer(), classifierSolrServer);
+ assertEquals(classifier.topicUriField, "topic");
+ assertEquals(classifier.similarityField, "text");
+ assertEquals(classifier.acceptedLanguages, new ArrayList<String>());
// check some required attributes
Hashtable<String,Object> configWithMissingTopicField = new
Hashtable<String,Object>();
@@ -134,54 +138,54 @@ public class TopicEngineTest extends Bas
Hashtable<String,Object> configWithAcceptLangage = new
Hashtable<String,Object>();
configWithAcceptLangage.putAll(config);
configWithAcceptLangage.put(TopicClassificationEngine.LANGUAGES, "en,
fr");
- engine =
TopicClassificationEngine.fromParameters(configWithAcceptLangage);
- assertNotNull(engine);
- assertEquals(engine.acceptedLanguages, Arrays.asList("en", "fr"));
+ classifier =
TopicClassificationEngine.fromParameters(configWithAcceptLangage);
+ assertNotNull(classifier);
+ assertEquals(classifier.acceptedLanguages, Arrays.asList("en", "fr"));
}
@Test
public void testProgrammaticThesaurusConstruction() throws Exception {
- TopicClassificationEngine engine =
TopicClassificationEngine.fromParameters(getDefaultConfigParams());
-
// Register the roots of the taxonomy
- engine.addTopic("http://example.com/topics/root1", null);
- engine.addTopic("http://example.com/topics/root2", null);
- engine.addTopic("http://example.com/topics/root3", new
ArrayList<String>());
- assertEquals(0,
engine.getBroaderTopics("http://example.com/topics/root1").size());
- assertEquals(0,
engine.getBroaderTopics("http://example.com/topics/root2").size());
- assertEquals(0,
engine.getBroaderTopics("http://example.com/topics/root3").size());
- assertEquals(3, engine.getTopicRoots().size());
+ classifier.addTopic("http://example.com/topics/root1", null);
+ classifier.addTopic("http://example.com/topics/root2", null);
+ classifier.addTopic("http://example.com/topics/root3", new
ArrayList<String>());
+ assertEquals(0,
classifier.getBroaderTopics("http://example.com/topics/root1").size());
+ assertEquals(0,
classifier.getBroaderTopics("http://example.com/topics/root2").size());
+ assertEquals(0,
classifier.getBroaderTopics("http://example.com/topics/root3").size());
+ assertEquals(3, classifier.getTopicRoots().size());
// Register some non root nodes
- engine.addTopic("http://example.com/topics/node1",
+ classifier.addTopic("http://example.com/topics/node1",
Arrays.asList("http://example.com/topics/root1",
"http://example.com/topics/root2"));
- engine.addTopic("http://example.com/topics/node2",
Arrays.asList("http://example.com/topics/root3"));
- engine.addTopic("http://example.com/topics/node3",
+ classifier.addTopic("http://example.com/topics/node2",
+ Arrays.asList("http://example.com/topics/root3"));
+ classifier.addTopic("http://example.com/topics/node3",
Arrays.asList("http://example.com/topics/node1",
"http://example.com/topics/node2"));
// the root where not impacted
- assertEquals(0,
engine.getBroaderTopics("http://example.com/topics/root1").size());
- assertEquals(0,
engine.getBroaderTopics("http://example.com/topics/root2").size());
- assertEquals(0,
engine.getBroaderTopics("http://example.com/topics/root3").size());
- assertEquals(3, engine.getTopicRoots().size());
+ assertEquals(0,
classifier.getBroaderTopics("http://example.com/topics/root1").size());
+ assertEquals(0,
classifier.getBroaderTopics("http://example.com/topics/root2").size());
+ assertEquals(0,
classifier.getBroaderTopics("http://example.com/topics/root3").size());
+ assertEquals(3, classifier.getTopicRoots().size());
// the other nodes have the same broader topics as at creation time
- assertEquals(2,
engine.getBroaderTopics("http://example.com/topics/node1").size());
- assertEquals(1,
engine.getBroaderTopics("http://example.com/topics/node2").size());
- assertEquals(2,
engine.getBroaderTopics("http://example.com/topics/node3").size());
+ assertEquals(2,
classifier.getBroaderTopics("http://example.com/topics/node1").size());
+ assertEquals(1,
classifier.getBroaderTopics("http://example.com/topics/node2").size());
+ assertEquals(2,
classifier.getBroaderTopics("http://example.com/topics/node3").size());
// check the induced narrower relationships
- assertEquals(1,
engine.getNarrowerTopics("http://example.com/topics/root1").size());
- assertEquals(1,
engine.getNarrowerTopics("http://example.com/topics/root2").size());
- assertEquals(1,
engine.getNarrowerTopics("http://example.com/topics/root3").size());
- assertEquals(1,
engine.getNarrowerTopics("http://example.com/topics/node1").size());
- assertEquals(1,
engine.getNarrowerTopics("http://example.com/topics/node2").size());
- assertEquals(0,
engine.getNarrowerTopics("http://example.com/topics/node3").size());
+ assertEquals(1,
classifier.getNarrowerTopics("http://example.com/topics/root1").size());
+ assertEquals(1,
classifier.getNarrowerTopics("http://example.com/topics/root2").size());
+ assertEquals(1,
classifier.getNarrowerTopics("http://example.com/topics/root3").size());
+ assertEquals(1,
classifier.getNarrowerTopics("http://example.com/topics/node1").size());
+ assertEquals(1,
classifier.getNarrowerTopics("http://example.com/topics/node2").size());
+ assertEquals(0,
classifier.getNarrowerTopics("http://example.com/topics/node3").size());
}
@Test
public void testEmptyIndexTopicClassification() throws Exception {
- TopicClassificationEngine engine =
TopicClassificationEngine.fromParameters(getDefaultConfigParams());
+ TopicClassificationEngine engine = TopicClassificationEngine
+ .fromParameters(getDefaultClassifierConfigParams());
List<TopicSuggestion> suggestedTopics = engine.suggestTopics("This is
a test.");
assertNotNull(suggestedTopics);
assertEquals(suggestedTopics.size(), 0);
@@ -190,8 +194,7 @@ public class TopicEngineTest extends Bas
@Test
public void testTopicClassification() throws Exception {
loadSampleTopicsFromTSV();
- TopicClassificationEngine engine =
TopicClassificationEngine.fromParameters(getDefaultConfigParams());
- List<TopicSuggestion> suggestedTopics = engine
+ List<TopicSuggestion> suggestedTopics = classifier
.suggestTopics("The Man Who Shot Liberty Valance is a 1962"
+ " American Western film directed by John
Ford,"
+ " narrated by Charlton Heston and starring
James"
@@ -201,4 +204,77 @@ public class TopicEngineTest extends Bas
TopicSuggestion bestSuggestion = suggestedTopics.get(0);
assertEquals(bestSuggestion.uri, "Category:American_films");
}
+
+ //@Test
+ public void testTrainClassifierFromExamples() throws Exception {
+
+ // mini taxonomy for news articles
+ String business = "urn:topics/business";
+ String technology = "urn:topics/technology";
+ String apple = "urn:topics/apple";
+ String sport = "urn:topics/sport";
+ String football = "urn:topics/football";
+ String wordcup = "urn:topics/wordcup";
+
+ classifier.addTopic(business, null);
+ classifier.addTopic(technology, null);
+ classifier.addTopic(sport, null);
+ classifier.addTopic(apple, Arrays.asList(business, technology));
+ classifier.addTopic(football, Arrays.asList(sport));
+ classifier.addTopic(wordcup, Arrays.asList(football));
+
+ // train the classifier on an empty dataset
+ classifier.setTrainingSet(trainingSet);
+ assertEquals(6, classifier.updateModel(true));
+
+ // the model is updated but does not predict anything
+ List<TopicSuggestion> suggestions = classifier
+ .suggestTopics("I like the sound of vuvuzula in the morning!");
+ assertEquals(0, suggestions.size());
+
+ // further updates of the model do not change any topic
+ assertEquals(0, classifier.updateModel(true));
+
+ // lets register some examples
+ trainingSet.registerExample(null, "Money, money, money is the root of
all evil.",
+ Arrays.asList(business));
+ trainingSet.registerExample(null, "VC invested more money in tech
startups in 2011.",
+ Arrays.asList(business, technology));
+ trainingSet.registerExample(null, "Apple sold many iPads at a very
high price"
+ + " and made record profits.",
Arrays.asList(apple, business));
+ trainingSet.registerExample(null, "Manchester United won 3-2 against
FC Barcelona.",
+ Arrays.asList(football));
+ trainingSet.registerExample(null, "Vuvuzela made the soundtrack of the"
+ + " football wordcup of 2010 in
South Africa.",
+ Arrays.asList(football, wordcup));
+
+ // retrain the model: all 6 topics are impacted by the new examples
+ assertEquals(6, classifier.updateModel(true));
+ suggestions = classifier.suggestTopics("I like the sound of vuvuzula
in the morning!");
+ assertEquals(3, suggestions.size());
+ assertEquals(wordcup, suggestions.get(0).uri);
+ assertEquals(football, suggestions.get(1).uri);
+ assertEquals(sport, suggestions.get(2).uri);
+ }
+
+ protected Hashtable<String,Object> getDefaultClassifierConfigParams() {
+ Hashtable<String,Object> config = new Hashtable<String,Object>();
+ config.put(TopicClassificationEngine.ENGINE_ID, "test-engine");
+ config.put(TopicClassificationEngine.SOLR_CORE, classifierSolrServer);
+ config.put(TopicClassificationEngine.TOPIC_URI_FIELD, "topic");
+ config.put(TopicClassificationEngine.SIMILARTITY_FIELD, "text");
+ config.put(TopicClassificationEngine.BROADER_FIELD, "broader");
+ return config;
+ }
+
+ protected Hashtable<String,Object> getDefaultTrainingSetConfigParams() {
+ Hashtable<String,Object> config = new Hashtable<String,Object>();
+ config.put(SolrTrainingSet.SOLR_CORE, trainingSetSolrServer);
+ config.put(SolrTrainingSet.TRAINING_SET_ID, "test-training-set");
+ config.put(SolrTrainingSet.EXAMPLE_ID_FIELD, "id");
+ config.put(SolrTrainingSet.EXAMPLE_TEXT_FIELD, "text");
+ config.put(SolrTrainingSet.TOPICS_URI_FIELD, "topics");
+ config.put(SolrTrainingSet.MODIFICATION_DATE_FIELD, "modification_dt");
+ return config;
+ }
}