Author: ogrisel
Date: Thu Jan  5 16:20:19 2012
New Revision: 1227674

URL: http://svn.apache.org/viewvc?rev=1227674&view=rev
Log:
STANBOL-197: implement batching on SolrTrainingSet

Modified:
    incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java
    incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java

Modified: incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java
URL: http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java?rev=1227674&r1=1227673&r2=1227674&view=diff
==============================================================================
--- incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java	(original)
+++ incubator/stanbol/trunk/enhancer/engines/topic/src/main/java/org/apache/stanbol/enhancer/topic/SolrTrainingSet.java	Thu Jan  5 16:20:19 2012
@@ -37,6 +37,7 @@ import org.apache.felix.scr.annotations.
 import org.apache.solr.client.solrj.SolrQuery;
 import org.apache.solr.client.solrj.SolrServer;
 import org.apache.solr.client.solrj.SolrServerException;
+import org.apache.solr.client.solrj.response.QueryResponse;
 import org.apache.solr.common.SolrDocument;
 import org.apache.solr.common.SolrInputDocument;
 import org.osgi.framework.InvalidSyntaxException;
@@ -168,29 +169,50 @@ public class SolrTrainingSet extends Con
         SolrServer solrServer = getActiveSolrServer();
         SolrQuery query = new SolrQuery();
         List<String> parts = new ArrayList<String>();
+        String q = "";
         if (topics.isEmpty()) {
-            query.setQuery("*:*");
+            q += "*:*";
         } else if (positive) {
             for (String topic : topics) {
                 // use a nested query to avoid string escaping issues with 
special solr chars
                 parts.add("_query_:\"{!field f=" + topicUrisField + "}" + 
topic + "\"");
             }
-            query.setQuery(StringUtils.join(parts, " OR "));
+            if (offset != null) {
+                q += "(";
+            }
+            q += StringUtils.join(parts, " OR ");
+            if (offset != null) {
+                q += ")";
+            }
         } else {
             for (String topic : topics) {
                 // use a nested query to avoid string escaping issues with 
special solr chars
                 parts.add("-_query_:\"{!field f=" + topicUrisField + "}" + 
topic + "\"");
             }
-            query.setQuery(StringUtils.join(parts, " AND "));
+            q += StringUtils.join(parts, " AND ");
         }
+        if (offset != null) {
+            q += " AND " + exampleIdField + ":[" + offset.toString() + " TO 
*]";
+        }
+        query.setQuery(q);
+        query.addSortField(exampleIdField, SolrQuery.ORDER.asc);
+        query.set("rows", batchSize + 1);
+        String nextExampleId = null;
         try {
-            for (SolrDocument result : solrServer.query(query).getResults()) {
-                Collection<Object> textValues = 
result.getFieldValues(exampleTextField);
-                if (textValues == null) {
-                    continue;
-                }
-                for (Object value : textValues) {
-                    items.add(value.toString());
+            int count = 0;
+            QueryResponse response = solrServer.query(query);
+            for (SolrDocument result : response.getResults()) {
+                if (count == batchSize) {
+                    nextExampleId = 
result.getFirstValue(exampleIdField).toString();
+                } else {
+                    count++;
+                    Collection<Object> textValues = 
result.getFieldValues(exampleTextField);
+                    if (textValues == null) {
+                        continue;
+                    }
+                    for (Object value : textValues) {
+                        items.add(value.toString());
+                    }
                 }
             }
         } catch (SolrServerException e) {
@@ -199,7 +221,7 @@ public class SolrTrainingSet extends Con
                 StringUtils.join(topics, "', '"), solrCoreId);
             throw new TrainingSetException(msg, e);
         }
-        return new Batch<String>(items, false, null);
+        return new Batch<String>(items, nextExampleId != null, nextExampleId);
     }
 
     @Override

Modified: incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java
URL: http://svn.apache.org/viewvc/incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java?rev=1227674&r1=1227673&r2=1227674&view=diff
==============================================================================
--- incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java	(original)
+++ incubator/stanbol/trunk/enhancer/engines/topic/src/test/java/org/apache/stanbol/enhancer/engine/topic/TrainingSetTest.java	Thu Jan  5 16:20:19 2012
@@ -22,7 +22,9 @@ import java.io.File;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.Hashtable;
+import java.util.Set;
 
 import javax.xml.parsers.ParserConfigurationException;
 
@@ -110,25 +112,65 @@ public class TrainingSetTest extends Bas
         assertEquals(1, examples.items.size());
         assertEquals(examples.items, Arrays.asList("Text of example3."));
         assertFalse(examples.hasMore);
+
+        // Test example update by adding topic3 to example1. The results of 
the previous query should remain
+        // the same (inplace update).
+        trainingSet.registerExample("example1", "Text of example1.", 
Arrays.asList(TOPIC_1, TOPIC_3));
+        examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, 
TOPIC_3), null);
+        assertEquals(2, examples.items.size());
+        assertEquals(examples.items, Arrays.asList("Text of example1.", "Text 
of example2."));
+        assertFalse(examples.hasMore);
     }
 
-    // @Test
-    public void testBatchingExamples() throws ConfigurationException, 
TrainingSetException {
+    @Test
+    public void testBatchingPositiveExamples() throws ConfigurationException, 
TrainingSetException {
+        Set<String> expectedCollectedText = new HashSet<String>();
+        Set<String> collectedText = new HashSet<String>();
         for (int i = 0; i < 28; i++) {
-            trainingSet.registerExample("example" + i, "Text of example" + i + 
".", Arrays.asList(TOPIC_1));
+            String text = "Text of example" + i + ".";
+            trainingSet.registerExample("example-" + i, text, 
Arrays.asList(TOPIC_1));
+            expectedCollectedText.add(text);
         }
         trainingSet.setBatchSize(10);
-        Batch<String> examples = 
trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1), null);
+        Batch<String> examples = 
trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, TOPIC_2), null);
         assertEquals(10, examples.items.size());
+        collectedText.addAll(examples.items);
         assertTrue(examples.hasMore);
 
-        examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1), 
examples.nextOffset);
+        examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, 
TOPIC_2), examples.nextOffset);
         assertEquals(10, examples.items.size());
+        collectedText.addAll(examples.items);
         assertTrue(examples.hasMore);
 
-        examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1), 
examples.nextOffset);
+        examples = trainingSet.getPositiveExamples(Arrays.asList(TOPIC_1, 
TOPIC_2), examples.nextOffset);
         assertEquals(8, examples.items.size());
+        collectedText.addAll(examples.items);
         assertFalse(examples.hasMore);
+
+        assertEquals(expectedCollectedText, collectedText);
+    }
+
+    @Test
+    public void testBatchingNegativeExamplesAndAutoId() throws 
ConfigurationException, TrainingSetException {
+        Set<String> expectedCollectedText = new HashSet<String>();
+        Set<String> collectedText = new HashSet<String>();
+        for (int i = 0; i < 17; i++) {
+            String text = "Text of example" + i + ".";
+            trainingSet.registerExample(null, text, Arrays.asList(TOPIC_1));
+            expectedCollectedText.add(text);
+        }
+        trainingSet.setBatchSize(10);
+        Batch<String> examples = 
trainingSet.getNegativeExamples(Arrays.asList(TOPIC_2), null);
+        assertEquals(10, examples.items.size());
+        collectedText.addAll(examples.items);
+        assertTrue(examples.hasMore);
+
+        examples = trainingSet.getNegativeExamples(Arrays.asList(TOPIC_2), 
examples.nextOffset);
+        assertEquals(7, examples.items.size());
+        collectedText.addAll(examples.items);
+        assertFalse(examples.hasMore);
+
+        assertEquals(expectedCollectedText, collectedText);
     }
 
     protected Hashtable<String,Object> getDefaultConfigParams() {


Reply via email to