Author: ogrisel
Date: Thu Jan 19 18:06:45 2012
New Revision: 1233504
URL: http://svn.apache.org/viewvc?rev=1233504&view=rev
Log:
STANBOL-197: cross validation seems to work
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicSuggestion.java
incubator/stanbol/trunk/enhancer/engines/topic/src/main/resources/classifier/schema.xml
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java?rev=1233504&r1=1233503&r2=1233504&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/engine/topic/TopicClassificationEngine.java
Thu Jan 19 18:06:45 2012
@@ -106,7 +106,6 @@ import org.slf4j.LoggerFactory;
@Property(name =
TopicClassificationEngine.MODEL_UPDATE_DATE_FIELD, value = "last_update_dt"),
@Property(name =
TopicClassificationEngine.PRECISION_FIELD, value = "precision"),
@Property(name = TopicClassificationEngine.RECALL_FIELD,
value = "recall"),
- @Property(name = TopicClassificationEngine.F1_FIELD,
value = "f1"),
@Property(name =
TopicClassificationEngine.MODEL_ENTRY_ID_FIELD, value = "model_entry_id"),
@Property(name =
TopicClassificationEngine.MODEL_EVALUATION_DATE_FIELD, value =
"last_evaluation_dt"),
@Property(name =
TopicClassificationEngine.FALSE_NEGATIVES_FIELD, value = "false_negatives"),
@@ -148,8 +147,6 @@ public class TopicClassificationEngine e
public static final String RECALL_FIELD =
"org.apache.stanbol.enhancer.engine.topic.recallField";
- public static final String F1_FIELD =
"org.apache.stanbol.enhancer.engine.topic.f1Field";
-
public static final String FALSE_POSITIVES_FIELD =
"org.apache.stanbol.enhancer.engine.topic.falsePositivesField";
public static final String FALSE_NEGATIVES_FIELD =
"org.apache.stanbol.enhancer.engine.topic.falseNegativesField";
@@ -160,12 +157,17 @@ public class TopicClassificationEngine e
private static final Logger log =
LoggerFactory.getLogger(TopicClassificationEngine.class);
+ public static final String SOLR_NON_EMPTY_FIELD = "[\"\" TO *]";
+
// TODO: make the following bounds configurable
public int MAX_CHARS_PER_TOPIC = 100000;
public Integer MAX_ROOTS = 1000;
+ public int MAX_SUGGESTIONS = 5; // never suggest more than this: this is
expected to be a reasonable
+ // estimate of the number of topics
occurring in each document
+
protected String engineId;
protected List<String> acceptedLanguages;
@@ -186,10 +188,6 @@ public class TopicClassificationEngine e
protected String recallField;
- protected String f1Field;
-
- protected int numTopics = 10;
-
protected TrainingSet trainingSet;
// the ENTRY_*_FIELD are basically a hack to use a single Solr core to
make documents with partially
@@ -243,7 +241,6 @@ public class TopicClassificationEngine e
acceptedLanguages = getStringListParan(config, LANGUAGES);
precisionField = getRequiredStringParam(config, PRECISION_FIELD);
recallField = getRequiredStringParam(config, RECALL_FIELD);
- f1Field = getRequiredStringParam(config, F1_FIELD);
modelUpdateDateField = getRequiredStringParam(config,
MODEL_UPDATE_DATE_FIELD);
modelEvaluationDateField = getRequiredStringParam(config,
MODEL_EVALUATION_DATE_FIELD);
falsePositivesField = getRequiredStringParam(config,
FALSE_POSITIVES_FIELD);
@@ -342,7 +339,7 @@ public class TopicClassificationEngine e
}
public List<TopicSuggestion> suggestTopics(String text) throws
ClassifierException {
- List<TopicSuggestion> suggestedTopics = new
ArrayList<TopicSuggestion>(numTopics);
+ List<TopicSuggestion> suggestedTopics = new
ArrayList<TopicSuggestion>(MAX_SUGGESTIONS * 3);
SolrServer solrServer = getActiveSolrServer();
SolrQuery query = new SolrQuery();
query.setQueryType("/" + MoreLikeThisParams.MLT);
@@ -350,12 +347,16 @@ public class TopicClassificationEngine e
query.set(MoreLikeThisParams.MATCH_INCLUDE, false);
query.set(MoreLikeThisParams.MIN_DOC_FREQ, 1);
query.set(MoreLikeThisParams.MIN_TERM_FREQ, 1);
+ query.set(MoreLikeThisParams.MAX_QUERY_TERMS, 30);
+ query.set(MoreLikeThisParams.MAX_NUM_TOKENS_PARSED, 10000);
// TODO: find a way to parse the interesting terms and report them
// for debugging / explanation in dedicated RDF data structure.
// query.set(MoreLikeThisParams.INTERESTING_TERMS, "details");
query.set(MoreLikeThisParams.SIMILARITY_FIELDS, similarityField);
query.set(CommonParams.STREAM_BODY, text);
- query.setRows(numTopics);
+ // over-query the number of suggestions to find a statistical cut
based on the curve of the scores of
+ // the top suggestions
+ query.setRows(MAX_SUGGESTIONS * 3);
query.setFields(topicUriField);
query.setIncludeScore(true);
try {
@@ -381,7 +382,28 @@ public class TopicClassificationEngine e
throw new ClassifierException(e);
}
}
- return suggestedTopics;
+ if (suggestedTopics.size() <= 1) {
+ // no need to apply the cutting heuristic
+ return suggestedTopics;
+ }
+ // filter out suggestions that score below some threshold based on
the mean of the top scores
+ float mean = 0.0f;
+ for (TopicSuggestion suggestion : suggestedTopics) {
+ mean += suggestion.score / suggestedTopics.size();
+ }
+ float threshold = 0.25f * suggestedTopics.get(0).score + 0.75f * mean;
+ List<TopicSuggestion> filteredSuggestions = new
ArrayList<TopicSuggestion>();
+ for (TopicSuggestion suggestion : suggestedTopics) {
+ if (filteredSuggestions.size() >= MAX_SUGGESTIONS) {
+ return filteredSuggestions;
+ }
+ if (filteredSuggestions.isEmpty() || suggestion.score > threshold)
{
+ filteredSuggestions.add(suggestion);
+ } else {
+ break;
+ }
+ }
+ return filteredSuggestions;
}
@Override
@@ -447,8 +469,8 @@ public class TopicClassificationEngine e
query.setSortField(topicUriField, SolrQuery.ORDER.asc);
if (broaderField != null) {
// find any topic with an empty broaderField
- query.setParam("q", entryTypeField + ":" + METADATA_ENTRY + " AND
-" + broaderField
- + ":[\"\" TO *]");
+ query.setParam("q", entryTypeField + ":" + METADATA_ENTRY + " AND
-" + broaderField + ":"
+ + SOLR_NON_EMPTY_FIELD);
} else {
// find any topic
query.setQuery(entryTypeField + ":" + METADATA_ENTRY);
@@ -616,19 +638,11 @@ public class TopicClassificationEngine e
}
final boolean incr = incremental;
int updatedTopics = batchOverTopics(new BatchProcessor<SolrDocument>()
{
- int offset = 0;
@Override
public int process(List<SolrDocument> batch) throws
ClassifierException, TrainingSetException {
int processed = 0;
for (SolrDocument result : batch) {
- if ((cvFoldCount != 0) && (offset % cvFoldCount ==
cvFoldIndex)) {
- // we are performing a cross validation session and
this example belong to the test
- // fold hence should be skipped
- offset++;
- continue;
- }
- offset++;
String topicId =
result.getFirstValue(topicUriField).toString();
List<String> impactedTopics = new ArrayList<String>();
impactedTopics.add(topicId);
@@ -675,9 +689,17 @@ public class TopicClassificationEngine e
long start = System.currentTimeMillis();
Batch<String> examples = Batch.emtpyBatch(String.class);
StringBuffer sb = new StringBuffer();
+ int offset = 0;
do {
examples = trainingSet.getPositiveExamples(impactedTopics,
examples.nextOffset);
for (String example : examples.items) {
+ if ((cvFoldCount != 0) && (offset % cvFoldCount ==
cvFoldIndex)) {
+ // we are performing a cross validation session and this
example belongs to the test
+ // fold, hence it should be skipped
+ offset++;
+ continue;
+ }
+ offset++;
sb.append(example);
sb.append("\n\n");
}
@@ -752,7 +774,6 @@ public class TopicClassificationEngine e
config.put(TopicClassificationEngine.MODEL_EVALUATION_DATE_FIELD,
"last_evaluation_dt");
config.put(TopicClassificationEngine.PRECISION_FIELD, "precision");
config.put(TopicClassificationEngine.RECALL_FIELD, "recall");
- config.put(TopicClassificationEngine.F1_FIELD, "f1");
config.put(TopicClassificationEngine.POSITIVE_SUPPORT_FIELD,
"positive_support");
config.put(TopicClassificationEngine.NEGATIVE_SUPPORT_FIELD,
"negative_support");
config.put(TopicClassificationEngine.FALSE_POSITIVES_FIELD,
"false_positives");
@@ -863,23 +884,22 @@ public class TopicClassificationEngine e
}
offset++;
List<TopicSuggestion> suggestedTopics =
classifier.suggestTopics(example);
+ boolean match = false;
for (TopicSuggestion suggestedTopic :
suggestedTopics) {
- boolean match = false;
if (topic.equals(suggestedTopic.uri)) {
match = true;
truePositives++;
break;
}
- if (!match) {
- falseNegatives++;
- // falseNegativeExamples.add(exampleId);
- }
+ }
+ if (!match) {
+ falseNegatives++;
+ // falseNegativeExamples.add(exampleId);
}
}
} while (examples.hasMore); // TODO: put a bound on the
number of examples
List<String> falsePositiveExamples = new
ArrayList<String>();
- int trueNegatives = 0;
int falsePositives = 0;
offset = 0;
examples = Batch.emtpyBatch(String.class);
@@ -895,17 +915,13 @@ public class TopicClassificationEngine e
offset++;
List<TopicSuggestion> suggestedTopics =
classifier.suggestTopics(example);
for (TopicSuggestion suggestedTopic :
suggestedTopics) {
- boolean match = false;
if (topic.equals(suggestedTopic.uri)) {
- match = true;
falsePositives++;
// falsePositiveExamples.add(exampleId);
break;
}
- if (!match) {
- trueNegatives++;
- }
}
+ // we don't need to collect true negatives
}
} while (examples.hasMore); // TODO: put a bound on the
number of examples
@@ -915,14 +931,10 @@ public class TopicClassificationEngine e
precision = truePositives / (float) (truePositives +
falsePositives);
}
float recall = 0;
- if (trueNegatives != 0 || falseNegatives != 0) {
- recall = trueNegatives / (float) (trueNegatives +
falseNegatives);
- }
- float f1 = 0;
- if (precision != 0 || recall != 0) {
- f1 = 2 * precision * recall / (precision + recall);
+ if (truePositives != 0 || falseNegatives != 0) {
+ recall = truePositives / (float) (truePositives +
falseNegatives);
}
- updatePerformanceMetadata(topic, precision, recall, f1,
falsePositiveExamples,
+ updatePerformanceMetadata(topic, precision, recall,
falsePositiveExamples,
falseNegativeExamples);
}
try {
@@ -942,13 +954,12 @@ public class TopicClassificationEngine e
}
/**
- * Update the performance statistics in a metadata entry of a topic. It
ist the responsibility of the
+ * Update the performance statistics in a metadata entry of a topic. It is
the responsibility of the
* caller to commit.
*/
protected void updatePerformanceMetadata(String topicId,
float precision,
float recall,
- float f1,
List<String>
falsePositiveExamples,
List<String>
falseNegativeExamples) throws ClassifierException {
SolrServer solrServer = getActiveSolrServer();
@@ -964,7 +975,6 @@ public class TopicClassificationEngine e
}
addToList(fieldValues, precisionField, precision);
addToList(fieldValues, recallField, recall);
- addToList(fieldValues, f1Field, f1);
// TODO: handle supports too...
// addToList(fieldValues, falsePositivesField,
falsePositiveExamples);
// addToList(fieldValues, falseNegativesField,
falseNegativeExamples);
@@ -1010,12 +1020,11 @@ public class TopicClassificationEngine e
SolrDocument metadata = results.get(0);
Float precision = computeMeanValue(metadata, precisionField);
Float recall = computeMeanValue(metadata, recallField);
- Float f1 = computeMeanValue(metadata, f1Field);
int positiveSupport = computeSumValue(metadata,
positiveSupportField);
int negativeSupport = computeSumValue(metadata,
negativeSupportField);
Date evaluationDate = (Date)
metadata.getFirstValue(modelEvaluationDateField);
boolean uptodate = evaluationDate != null;
- ClassificationReport report = new ClassificationReport(precision,
recall, f1, positiveSupport,
+ ClassificationReport report = new ClassificationReport(precision,
recall, positiveSupport,
negativeSupport, uptodate, evaluationDate);
if (metadata.getFieldValues(falsePositivesField) == null) {
metadata.setField(falsePositivesField, new
ArrayList<Object>());
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java?rev=1233504&r1=1233503&r2=1233504&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/ClassificationReport.java
Thu Jan 19 18:06:45 2012
@@ -76,18 +76,26 @@ public class ClassificationReport {
public ClassificationReport(float precision,
float recall,
- float f1,
int positiveSupport,
int negativeSupport,
boolean uptodate,
Date evaluationDate) {
this.precision = precision;
this.recall = recall;
- this.f1 = f1;
+ if (precision != 0 || recall != 0) {
+ this.f1 = 2 * precision * recall / (precision + recall);
+ } else {
+ this.f1 = 0;
+ }
this.positiveSupport = positiveSupport;
this.negativeSupport = negativeSupport;
this.uptodate = uptodate;
this.evaluationDate = evaluationDate;
}
+ @Override
+ public String toString() {
+ return String.format("ClassificationReport: precision=%f, recall=%f,
f1=%f", precision, recall, f1);
+ }
+
}
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicSuggestion.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicSuggestion.java?rev=1233504&r1=1233503&r2=1233504&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicSuggestion.java
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/TopicSuggestion.java
Thu Jan 19 18:06:45 2012
@@ -30,9 +30,9 @@ public class TopicSuggestion {
public final List<String> paths = new ArrayList<String>();
- public final double score;
+ public final float score;
- public TopicSuggestion(String uri, List<String> paths, double score) {
+ public TopicSuggestion(String uri, List<String> paths, float score) {
this.uri = uri;
if (paths != null) {
this.paths.addAll(paths);
@@ -40,7 +40,7 @@ public class TopicSuggestion {
this.score = score;
}
- public TopicSuggestion(String uri, double score) {
+ public TopicSuggestion(String uri, float score) {
this(uri, null, score);
}
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/main/resources/classifier/schema.xml
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/resources/classifier/schema.xml?rev=1233504&r1=1233503&r2=1233504&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/main/resources/classifier/schema.xml
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/main/resources/classifier/schema.xml
Thu Jan 19 18:06:45 2012
@@ -86,8 +86,6 @@
multiValued="true" />
<field name="recall" type="tfloat" indexed="true" stored="true"
multiValued="true" />
- <field name="f1" type="tfloat" indexed="true" stored="true"
- multiValued="true" />
<field name="last_evaluation_dt" type="tdate" indexed="true"
stored="true" />
<field name="positive_support" type="tint" indexed="false"
Modified:
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
URL:
http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java?rev=1233504&r1=1233503&r2=1233504&view=diff
==============================================================================
---
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
(original)
+++
incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TopicEngineTest.java
Thu Jan 19 18:06:45 2012
@@ -249,26 +249,30 @@ public class TopicEngineTest extends Emb
// check that updating the model incrementally without changing the
dataset won't change anything.
assertEquals(0, classifier.updateModel(true));
- // lets register some examples
- trainingSet.registerExample(null, "Money, money, money is the root of
all evil.",
+ // let's register some examples including stop words as well to limit
statistical artifacts caused by
+ // the small size of the training set.
+ String STOP_WORDS = " the a is are be in at ";
+ trainingSet.registerExample(null, "Money, money, money is the root of
all evil." + STOP_WORDS,
Arrays.asList(business));
- trainingSet.registerExample(null, "VC invested more money in tech
startups in 2011.",
+ trainingSet.registerExample(null, "VC invested more money in tech
startups in 2011." + STOP_WORDS,
Arrays.asList(business, technology));
- trainingSet.registerExample(null, "Apple's iPad is a small handheld
computer with a touch screen UI",
- Arrays.asList(apple, technology));
+ trainingSet.registerExample(null, "Apple's iPad is a small handheld
computer with a touch screen UI"
+ + STOP_WORDS, Arrays.asList(apple,
technology));
trainingSet.registerExample(null, "Apple sold the iPad at a very high
price"
- + " and made record profits.",
Arrays.asList(apple, business));
+ + " and made record profits." +
STOP_WORDS,
+ Arrays.asList(apple, business));
- trainingSet.registerExample(null, "Manchester United won 3-2 against
FC Barcelona.",
+ trainingSet.registerExample(null, "Manchester United won 3-2 against
FC Barcelona." + STOP_WORDS,
Arrays.asList(football));
- trainingSet.registerExample(null, "The 2012 Football Worldcup takes
place in Brazil.",
+ trainingSet.registerExample(null, "The 2012 Football Worldcup takes
place in Brazil." + STOP_WORDS,
Arrays.asList(football, worldcup));
trainingSet.registerExample(null, "Vuvuzela made the soundtrack of the"
- + " football worldcup of 2010 in
South Africa.",
+ + " football worldcup of 2010 in
South Africa." + STOP_WORDS,
Arrays.asList(football, worldcup, music));
- trainingSet.registerExample(null, "Amon Tobin will be live in Paris
soon.", Arrays.asList(music));
+ trainingSet.registerExample(null, "Amon Tobin will be live in Paris
soon." + STOP_WORDS,
+ Arrays.asList(music));
// retrain the model: all topics are recomputed
assertEquals(7, classifier.updateModel(true));
@@ -281,9 +285,9 @@ public class TopicEngineTest extends Emb
assertEquals(football, suggestions.get(2).uri);
assertEquals(sport, suggestions.get(3).uri);
// check that the scores are decreasing:
- assertTrue(suggestions.get(0).score > suggestions.get(1).score);
- assertTrue(suggestions.get(1).score > suggestions.get(2).score);
- assertTrue(suggestions.get(2).score > suggestions.get(3).score);
+ assertTrue(suggestions.get(0).score >= suggestions.get(1).score);
+ assertTrue(suggestions.get(1).score >= suggestions.get(2).score);
+ assertTrue(suggestions.get(2).score >= suggestions.get(3).score);
suggestions = classifier.suggestTopics("Apple is no longer a
startup.");
assertTrue(suggestions.size() >= 3);
@@ -292,11 +296,9 @@ public class TopicEngineTest extends Emb
assertEquals(business, suggestions.get(2).uri);
suggestions = classifier.suggestTopics("You can watch the worldcup on
your iPad.");
- assertTrue(suggestions.size() >= 4);
- assertEquals(worldcup, suggestions.get(0).uri);
- assertEquals(apple, suggestions.get(1).uri);
- assertEquals(football, suggestions.get(2).uri);
- assertEquals(sport, suggestions.get(3).uri);
+ assertTrue(suggestions.size() >= 2);
+ assertEquals(apple, suggestions.get(0).uri);
+ assertEquals(worldcup, suggestions.get(1).uri);
// test incremental update of a single root node
Thread.sleep(10);
@@ -358,25 +360,25 @@ public class TopicEngineTest extends Emb
assertFalse(performanceEstimates.uptodate);
// update the performance metadata manually
- classifier.updatePerformanceMetadata("urn:t/002", 0.76f, 0.60f, 0.67f,
Arrays.asList("ex14", "ex78"),
+ classifier.updatePerformanceMetadata("urn:t/002", 0.76f, 0.60f,
Arrays.asList("ex14", "ex78"),
Arrays.asList("ex34", "ex23", "ex89"));
classifier.getActiveSolrServer().commit();
performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
assertTrue(performanceEstimates.uptodate);
- assertEquals(Float.valueOf(0.76f),
Float.valueOf(performanceEstimates.precision));
- assertEquals(Float.valueOf(0.60f),
Float.valueOf(performanceEstimates.recall));
- assertEquals(Float.valueOf(0.67f),
Float.valueOf(performanceEstimates.f1));
+ assertEquals(0.76f, performanceEstimates.precision, 0.01);
+ assertEquals(0.60f, performanceEstimates.recall, 0.01);
+ assertEquals(0.67f, performanceEstimates.f1, 0.01);
assertTrue(classifier.getBroaderTopics("urn:t/002").contains("urn:t/001"));
// accumulate other folds statistics and compute means of statistics
- classifier.updatePerformanceMetadata("urn:t/002", 0.79f, 0.63f, 0.72f,
Arrays.asList("ex1", "ex5"),
+ classifier.updatePerformanceMetadata("urn:t/002", 0.79f, 0.63f,
Arrays.asList("ex1", "ex5"),
Arrays.asList("ex3", "ex10", "ex11"));
classifier.getActiveSolrServer().commit();
performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
assertTrue(performanceEstimates.uptodate);
- assertEquals(Float.valueOf(0.775f),
Float.valueOf(performanceEstimates.precision));
- assertEquals(Float.valueOf(0.615f),
Float.valueOf(performanceEstimates.recall));
- assertEquals(Float.valueOf(0.69500005f),
Float.valueOf(performanceEstimates.f1));
+ assertEquals(0.775f, performanceEstimates.precision, 0.01);
+ assertEquals(0.615f, performanceEstimates.recall, 0.01);
+ assertEquals(0.695f, performanceEstimates.f1, 0.01);
}
@Test
@@ -388,8 +390,8 @@ public class TopicEngineTest extends Emb
// build an artificial data set used for training models and evaluation
int numberOfTopics = 10;
int numberOfDocuments = 100;
- int vocabSizeMin = 10;
- int vocabSizeMax = 25; // we are using the alphabet as base terms
+ int vocabSizeMin = 20;
+ int vocabSizeMax = 30;
initArtificialTrainingSet(numberOfTopics, numberOfDocuments,
vocabSizeMin, vocabSizeMax, rng);
// by default the reports are not computed
@@ -409,32 +411,17 @@ public class TopicEngineTest extends Emb
// launch an evaluation of the classifier according to the current
state of the training set
assertEquals(numberOfTopics,
classifier.updatePerformanceEstimates(true));
- performanceEstimates = classifier.getPerformanceEstimates("urn:t/001");
- assertTrue(performanceEstimates.uptodate);
- // assertGreater(performanceEstimates.precision, 0.8f);
- // assertGreater(performanceEstimates.recall, 0.8f);
- // assertGreater(performanceEstimates.f1, 0.8f);
- // assertGreater(performanceEstimates.positiveSupport, 10);
- // assertGreater(performanceEstimates.negativeSupport, 90);
- assertNotNull(performanceEstimates.evaluationDate);
-
- performanceEstimates = classifier.getPerformanceEstimates("urn:t/002");
- assertTrue(performanceEstimates.uptodate);
- // assertGreater(performanceEstimates.precision, 0.8f);
- // assertGreater(performanceEstimates.recall, 0.8f);
- // assertGreater(performanceEstimates.f1, 0.8f);
- // assertGreater(performanceEstimates.positiveSupport, 10);
- // assertGreater(performanceEstimates.negativeSupport, 90);
- assertNotNull(performanceEstimates.evaluationDate);
-
- performanceEstimates = classifier.getPerformanceEstimates("urn:t/003");
- assertTrue(performanceEstimates.uptodate);
- // assertGreater(performanceEstimates.precision, 0.8f);
- // assertGreater(performanceEstimates.recall, 0.8f);
- // assertGreater(performanceEstimates.f1, 0.8f);
- // assertGreater(performanceEstimates.positiveSupport, 10);
- // assertGreater(performanceEstimates.negativeSupport, 90);
- assertNotNull(performanceEstimates.evaluationDate);
+ for (int i = 1; i <= numberOfTopics; i++) {
+ String topic = String.format("urn:t/%03d", i);
+ performanceEstimates = classifier.getPerformanceEstimates(topic);
+ assertTrue(performanceEstimates.uptodate);
+ assertGreater(performanceEstimates.precision, 0.5f);
+ assertGreater(performanceEstimates.recall, 0.5f);
+ assertGreater(performanceEstimates.f1, 0.65f);
+ // assertGreater(performanceEstimates.positiveSupport, 10);
+ // assertGreater(performanceEstimates.negativeSupport, 90);
+ assertNotNull(performanceEstimates.evaluationDate);
+ }
// TODO: test model invalidation by registering a sub topic manually
}
@@ -451,22 +438,14 @@ public class TopicEngineTest extends Emb
int vocabSizeMax,
Random rng) throws
ClassifierException, TrainingSetException {
// define some artificial topics and register them to the classifier
- char[] alphabet = "abcdefghijklmnopqrstuvwxyz".toCharArray();
+ String[] stopWords = randomVocabulary(0, 50, 50, rng);
String[] topics = new String[numberOfTopics];
Map<String,String[]> vocabularies = new TreeMap<String,String[]>();
for (int i = 0; i < numberOfTopics; i++) {
String topic = String.format("urn:t/%03d", i + 1);
topics[i] = topic;
classifier.addTopic(topic, null);
- int vocSize = rng.nextInt(vocabSizeMax + 1 - vocabSizeMin) +
vocabSizeMin;
- String[] terms = new String[vocSize];
-
- for (int j = 0; j < vocSize; j++) {
- // define some artificial vocabulary for each topic to
automatically generate random examples
- // with some topic structure
- // if i = 1, will generate: ["a1", "b1", "c1", ...]
- terms[j] = "term_" + alphabet[j] + String.valueOf(i + 1);
- }
+ String[] terms = randomVocabulary(i, vocabSizeMin, vocabSizeMax,
rng);
vocabularies.put(topic, terms);
}
classifier.setTrainingSet(trainingSet);
@@ -476,19 +455,19 @@ public class TopicEngineTest extends Emb
List<String> documentTerms = new ArrayList<String>();
// add terms from some non-dominant topics that are used as
classification target
- int numberOfDominantTopics = 1;// rng.nextInt(4) + 1; // between 1
and 3 topics
+ int numberOfDominantTopics = rng.nextInt(3) + 1; // between 1 and
3 topics
List<String> documentTopics = new ArrayList<String>();
for (int j = 0; j < numberOfDominantTopics; j++) {
- String topic = randomTopicAndTerms(topics, vocabularies,
documentTerms, 50, 100, rng);
+ String topic = randomTopicAndTerms(topics, vocabularies,
documentTerms, 100, 150, rng);
documentTopics.add(topic);
}
// add terms from some non-dominant topics
- for (int j = 0; j < 0; j++) {
- randomTopicAndTerms(topics, vocabularies, documentTerms, 1,
10, rng);
+ for (int j = 0; j < 5; j++) {
+ randomTopicAndTerms(topics, vocabularies, documentTerms, 5,
10, rng);
}
// add some non discriminative terms not linked to any topic
- for (int k = 0; k < 0; k++) {
-
documentTerms.add(String.valueOf(alphabet[rng.nextInt(alphabet.length)]));
+ for (int k = 0; k < 100; k++) {
+
documentTerms.add(String.valueOf(stopWords[rng.nextInt(stopWords.length)]));
}
// register the generated example in the training set
String text = StringUtils.join(documentTerms, " ");
@@ -496,6 +475,18 @@ public class TopicEngineTest extends Emb
}
}
+ protected String[] randomVocabulary(int topicIndex, int vocabSizeMin, int
vocabSizeMax, Random rng) {
+ int vocSize = rng.nextInt(vocabSizeMax + 1 - vocabSizeMin) +
vocabSizeMin;
+ String[] terms = new String[vocSize];
+
+ for (int j = 0; j < vocSize; j++) {
+ // define some artificial vocabulary for each topic to
automatically generate random examples
+ // with some topic structure
+ terms[j] = String.format("term_%03d_t%03d", j, topicIndex + 1);
+ }
+ return terms;
+ }
+
protected String randomTopicAndTerms(String[] topics,
Map<String,String[]> vocabularies,
List<String> documentTerms,
@@ -525,7 +516,6 @@ public class TopicEngineTest extends Emb
config.put(TopicClassificationEngine.MODEL_EVALUATION_DATE_FIELD,
"last_evaluation_dt");
config.put(TopicClassificationEngine.PRECISION_FIELD, "precision");
config.put(TopicClassificationEngine.RECALL_FIELD, "recall");
- config.put(TopicClassificationEngine.F1_FIELD, "f1");
config.put(TopicClassificationEngine.POSITIVE_SUPPORT_FIELD,
"positive_support");
config.put(TopicClassificationEngine.NEGATIVE_SUPPORT_FIELD,
"negative_support");
config.put(TopicClassificationEngine.FALSE_POSITIVES_FIELD,
"false_positives");