Author: joern
Date: Wed Jan  8 20:39:08 2014
New Revision: 1556629

URL: http://svn.apache.org/r1556629
Log:
OPENNLP-574 Initial work to integrate Mahout's Logistic Regression Classifiers

Added:
    opennlp/addons/mahout-addon/
    opennlp/addons/mahout-addon/pom.xml
    opennlp/addons/mahout-addon/src/
    opennlp/addons/mahout-addon/src/main/
    opennlp/addons/mahout-addon/src/main/java/
    opennlp/addons/mahout-addon/src/main/java/SimpleTest.java
    opennlp/addons/mahout-addon/src/main/java/opennlp/
    opennlp/addons/mahout-addon/src/main/java/opennlp/addons/
    opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/
    
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AbstractOnlineLearnerTrainer.java
    
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AdaptiveLogisticRegressionTrainer.java
    
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/LogisticRegressionTrainer.java
    
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/OnlineLogisticRegressionTrainer.java
    
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/PassiveAggressiveTrainer.java
    
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/VectorClassifierModel.java

Added: opennlp/addons/mahout-addon/pom.xml
URL: 
http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/pom.xml?rev=1556629&view=auto
==============================================================================
--- opennlp/addons/mahout-addon/pom.xml (added)
+++ opennlp/addons/mahout-addon/pom.xml Wed Jan  8 20:39:08 2014
@@ -0,0 +1,86 @@
+<?xml version="1.0" encoding="UTF-8"?>
+
+<!--
+   Licensed to the Apache Software Foundation (ASF) under one
+   or more contributor license agreements.  See the NOTICE file
+   distributed with this work for additional information
+   regarding copyright ownership.  The ASF licenses this file
+   to you under the Apache License, Version 2.0 (the
+   "License"); you may not use this file except in compliance
+   with the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing,
+   software distributed under the License is distributed on an
+   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+   KIND, either express or implied.  See the License for the
+   specific language governing permissions and limitations
+   under the License.    
+-->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+       <modelVersion>4.0.0</modelVersion>
+       
+       <parent>
+           <groupId>org.apache.opennlp</groupId>
+           <artifactId>opennlp</artifactId>
+           <version>1.6.0-SNAPSHOT</version>
+           <relativePath>../opennlp/pom.xml</relativePath>
+       </parent>
+    
+       <artifactId>mahout-addon</artifactId>
+       <packaging>jar</packaging>
+       <name>Apache OpenNLP Mahout Addon</name>
+
+       <dependencies>
+               <dependency>
+                       <groupId>org.apache.opennlp</groupId>
+                       <artifactId>opennlp-tools</artifactId>
+                       <version>1.6.0-SNAPSHOT</version>
+               </dependency>
+               
+               <dependency>
+                       <groupId>org.apache.mahout</groupId>
+                       <artifactId>mahout-core</artifactId>
+                       <version>0.8</version>
+               </dependency>
+
+               <!-- No version is given here; it has to be supplied by the parent POM. -->
+               <dependency>
+                       <groupId>junit</groupId>
+                       <artifactId>junit</artifactId>
+                       <scope>test</scope>
+               </dependency>
+       </dependencies>
+
+       <build>
+               <plugins>
+                       <!-- Copies the dependencies (without version suffix) during the package phase. -->
+                       <plugin>
+                               <groupId>org.apache.maven.plugins</groupId>
+                               <artifactId>maven-dependency-plugin</artifactId>
+                               <version>2.1</version>
+                               <executions>
+                                       <execution>
+                                               <id>copy-dependencies</id>
+                                               <phase>package</phase>
+                                               <goals>
+                                                       <goal>copy-dependencies</goal>
+                                               </goals>
+                                               <configuration>
+                                                       <excludeScope>provided</excludeScope>
+                                                       <stripVersion>true</stripVersion>
+                                               </configuration>
+                                       </execution>
+                               </executions>
+                       </plugin>
+                       <!-- Test execution is switched off for this module. -->
+                       <plugin>
+                               <groupId>org.apache.maven.plugins</groupId>
+                               <artifactId>maven-surefire-plugin</artifactId>
+                               <configuration>
+                                       <skipTests>true</skipTests>
+                                       <argLine>-Xmx512m</argLine>
+                               </configuration>
+                       </plugin>
+               </plugins>
+       </build>
+</project>

Added: opennlp/addons/mahout-addon/src/main/java/SimpleTest.java
URL: 
http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/SimpleTest.java?rev=1556629&view=auto
==============================================================================
--- opennlp/addons/mahout-addon/src/main/java/SimpleTest.java (added)
+++ opennlp/addons/mahout-addon/src/main/java/SimpleTest.java Wed Jan  8 
20:39:08 2014
@@ -0,0 +1,51 @@
+import org.apache.mahout.classifier.sgd.PassiveAggressive;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
+
+/**
+ * Minimal command line experiment with Mahout's {@link PassiveAggressive}
+ * learner: it is trained on two hand-built vectors with three features each
+ * and then asked to classify a vector holding the same values as the first
+ * training example. The resulting score vector is printed to standard out.
+ */
+public class SimpleTest {
+
+  /** Number of features, i.e. the cardinality of the training vectors. */
+  private static final int NUM_FEATURES = 3;
+
+  /** Number of categories the learner distinguishes. */
+  private static final int NUM_CATEGORIES = 2;
+
+  public static void main(String[] args) {
+
+    // Prepare data in vector format ...
+    //
+    // The basic idea is that you create a vector, typically a RandomAccessSparseVector,
+    // and then you use various feature encoders to progressively add features to that vector.
+    // The size of the vector should be large enough to avoid feature collisions as features
+    // are hashed. Here the values are set directly by index, no encoder is involved.
+    //
+    // NOTE: Looks like we need to store the cardinality of the vector in the model ?!
+
+    Vector vector1 = createVector(NUM_FEATURES, 1, 0, 1);
+    Vector vector2 = createVector(NUM_FEATURES, 0, 1, 1);
+
+    // do the training
+    PassiveAggressive pa = new PassiveAggressive(NUM_CATEGORIES, NUM_FEATURES);
+    pa.train(0, vector1);
+    pa.train(1, vector2);
+
+    // The vector to classify is sized by what the learner reports as its feature count.
+    Vector vector = createVector(pa.numFeatures(), 1, 0, 1);
+
+    Vector result = pa.classifyFull(vector);
+
+    System.out.println(result);
+  }
+
+  /**
+   * Creates a sparse vector and assigns the given values to the indexes
+   * 0 to {@code values.length - 1}.
+   *
+   * @param cardinality the size of the vector to create
+   * @param values the values to set, in index order
+   * @return the populated vector
+   */
+  private static Vector createVector(int cardinality, double... values) {
+    Vector vector = new RandomAccessSparseVector(cardinality);
+
+    for (int i = 0; i < values.length; i++) {
+      vector.set(i, values[i]);
+    }
+
+    return vector;
+  }
+}

Added: 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AbstractOnlineLearnerTrainer.java
URL: 
http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AbstractOnlineLearnerTrainer.java?rev=1556629&view=auto
==============================================================================
--- 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AbstractOnlineLearnerTrainer.java
 (added)
+++ 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AbstractOnlineLearnerTrainer.java
 Wed Jan  8 20:39:08 2014
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.addons.mahout;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import opennlp.tools.ml.AbstractEventTrainer;
+import opennlp.tools.ml.model.DataIndexer;
+import opennlp.tools.ml.model.MaxentModel;
+
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
+import org.apache.mahout.classifier.sgd.L1;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Base class for event trainers which delegate the actual learning to a
+ * Mahout {@link org.apache.mahout.classifier.OnlineLearner}.
+ * <p>
+ * It reads the number of training iterations from the train parameters and
+ * offers helpers to feed the indexed events to a learner and to build the
+ * predicate name to index map.
+ */
+abstract class AbstractOnlineLearnerTrainer extends AbstractEventTrainer {
+
+  /** Name of the train parameter which holds the number of iterations. */
+  private static final String ITERATIONS_PARAM = "Iterations";
+
+  /** Number of iterations used when the parameter is not specified. */
+  private static final int DEFAULT_ITERATIONS = 20;
+
+  /** Number of passes over the training data sub-classes should perform. */
+  protected final int iterations;
+  
+  /**
+   * Initializes the trainer and extracts the iteration count.
+   *
+   * @param trainParams the training parameters, "Iterations" is read from it
+   * @param reportMap passed on to the super class
+   * @throws NumberFormatException if the "Iterations" value is not an integer
+   */
+  public AbstractOnlineLearnerTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+    
+    // TODO: Extract parameters here, used by all implementations, e.g. learningRate
+    
+    String iterationsValue = trainParams.get(ITERATIONS_PARAM);
+    
+    if (iterationsValue != null) {
+      iterations = Integer.parseInt(iterationsValue);
+    }
+    else {
+      iterations = DEFAULT_ITERATIONS;
+    }
+  }
+
+  /**
+   * Performs one pass over all indexed events and trains the learner with
+   * one sparse vector per event.
+   *
+   * @param indexer provides the contexts, outcomes and event counts
+   * @param pa the learner to train
+   */
+  protected void trainOnlineLearner(DataIndexer indexer,
+      org.apache.mahout.classifier.OnlineLearner pa) {
+    int cardinality = indexer.getPredLabels().length;
+    int[] outcomes = indexer.getOutcomeList();
+    
+    // Fetch the indexer arrays once, instead of on every loop iteration.
+    int[][] contexts = indexer.getContexts();
+    int[] numTimesEventsSeen = indexer.getNumTimesEventsSeen();
+    
+    for (int i = 0; i < contexts.length; i++) {
+
+      Vector vector = new RandomAccessSparseVector(cardinality);
+      
+      // NOTE(review): every feature of the event is set to the number of times
+      // the event was seen, it is not a per-feature count - confirm this is intended.
+      for (int feature : contexts[i]) {
+        vector.set(feature, numTimesEventsSeen[i]);
+      }
+      
+      pa.train(outcomes[i], vector);
+    }
+  }
+
+  /**
+   * Creates a map from each predicate label to its position in the
+   * predicate label array of the indexer.
+   *
+   * @param indexer provides the predicate labels
+   * @return a new map from predicate label to index
+   */
+  protected Map<String, Integer> createPrepMap(DataIndexer indexer) {
+    Map<String, Integer> predMap = new HashMap<String, Integer>();
+    
+    String[] predLabels = indexer.getPredLabels();
+    for (int i = 0; i < predLabels.length; i++) {
+      predMap.put(predLabels[i], i);
+    }
+    
+    return predMap;
+  }
+  
+  @Override
+  public boolean isSortAndMerge() {
+    return true;
+  }
+}

Added: 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AdaptiveLogisticRegressionTrainer.java
URL: 
http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AdaptiveLogisticRegressionTrainer.java?rev=1556629&view=auto
==============================================================================
--- 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AdaptiveLogisticRegressionTrainer.java
 (added)
+++ 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/AdaptiveLogisticRegressionTrainer.java
 Wed Jan  8 20:39:08 2014
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.addons.mahout;
+
+import java.io.IOException;
+import java.util.Map;
+
+import opennlp.tools.ml.model.DataIndexer;
+import opennlp.tools.ml.model.MaxentModel;
+
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
+import org.apache.mahout.classifier.sgd.L1;
+
+/**
+ * Event trainer which uses Mahout's {@link AdaptiveLogisticRegression} with
+ * an {@link L1} prior and wraps the best learner it found into a
+ * {@link VectorClassifierModel}.
+ */
+public class AdaptiveLogisticRegressionTrainer extends AbstractOnlineLearnerTrainer {
+  
+  public AdaptiveLogisticRegressionTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+  }
+
+  /**
+   * Trains the learner for the configured number of iterations over the
+   * indexed events.
+   *
+   * @param indexer the indexed training events
+   * @return a model backed by the best learner
+   * @throws IOException declared by the overridden method
+   */
+  @Override
+  public MaxentModel doTrain(DataIndexer indexer) throws IOException {
+    
+    // TODO: Lets use the predMap here as well for encoding
+    int numberOfOutcomes = indexer.getOutcomeLabels().length;
+    int numberOfFeatures = indexer.getPredLabels().length;
+    
+    AdaptiveLogisticRegression pa = new AdaptiveLogisticRegression(numberOfOutcomes,
+        numberOfFeatures, new L1());
+    
+    try {
+      // TODO: Make these parameters configurable ...
+      //  what are good values ?! 
+      pa.setInterval(800);
+      pa.setAveragingWindow(500);
+      
+      for (int k = 0; k < iterations; k++) {
+        trainOnlineLearner(indexer, pa);
+        
+        // What should be reported at the end of every iteration ?!
+        System.out.println("Iteration " + (k + 1));
+      }
+    }
+    finally {
+      // Close the learner also when the training fails, so it is never left open.
+      pa.close();
+    }
+    
+    // The best learner is only retrieved after the learner was closed.
+    return new VectorClassifierModel(pa.getBest().getPayload().getLearner(),
+        indexer.getOutcomeLabels(), createPrepMap(indexer));
+  }
+
+  @Override
+  public boolean isSortAndMerge() {
+    return true;
+  }
+}

Added: 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/LogisticRegressionTrainer.java
URL: 
http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/LogisticRegressionTrainer.java?rev=1556629&view=auto
==============================================================================
--- 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/LogisticRegressionTrainer.java
 (added)
+++ 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/LogisticRegressionTrainer.java
 Wed Jan  8 20:39:08 2014
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.addons.mahout;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import opennlp.tools.ml.AbstractEventTrainer;
+import opennlp.tools.ml.model.DataIndexer;
+import opennlp.tools.ml.model.MaxentModel;
+
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
+import org.apache.mahout.classifier.sgd.L1;
+import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
+import org.apache.mahout.classifier.sgd.PassiveAggressive;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Event trainer which uses Mahout's {@link AdaptiveLogisticRegression} with
+ * an {@link L1} prior and wraps the best learner it found into a
+ * {@link VectorClassifierModel}.
+ */
+public class LogisticRegressionTrainer extends AbstractOnlineLearnerTrainer {
+  
+  public LogisticRegressionTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+  }
+
+  /**
+   * Trains the learner for the configured number of iterations over the
+   * indexed events.
+   *
+   * @param indexer the indexed training events
+   * @return a model backed by the best learner
+   * @throws IOException declared by the overridden method
+   */
+  @Override
+  public MaxentModel doTrain(DataIndexer indexer) throws IOException {
+    
+    // TODO: Lets use the predMap here as well for encoding
+    
+    int cardinality = indexer.getPredLabels().length;
+    
+    AdaptiveLogisticRegression pa = new AdaptiveLogisticRegression(
+        indexer.getOutcomeLabels().length, cardinality, new L1());
+    
+    try {
+      pa.setInterval(800);
+      pa.setAveragingWindow(500);
+      
+      // TODO: Should we do both ?! AdaptiveLogisticRegression ?! 
+      
+      for (int k = 0; k < iterations; k++) {
+        trainOnlineLearner(indexer, pa);
+        
+        // What should be reported at the end of every iteration ?!
+        System.out.println("Iteration " + (k + 1));
+      }
+    }
+    finally {
+      // Close the learner also when the training fails, so it is never left open.
+      pa.close();
+    }
+    
+    // The best learner is only retrieved after the learner was closed.
+    return new VectorClassifierModel(pa.getBest().getPayload().getLearner(),
+        indexer.getOutcomeLabels(), createPrepMap(indexer));
+  }
+
+  @Override
+  public boolean isSortAndMerge() {
+    return true;
+  }
+}

Added: 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/OnlineLogisticRegressionTrainer.java
URL: 
http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/OnlineLogisticRegressionTrainer.java?rev=1556629&view=auto
==============================================================================
--- 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/OnlineLogisticRegressionTrainer.java
 (added)
+++ 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/OnlineLogisticRegressionTrainer.java
 Wed Jan  8 20:39:08 2014
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.addons.mahout;
+
+import java.io.IOException;
+import java.util.Map;
+
+import opennlp.tools.ml.model.DataIndexer;
+import opennlp.tools.ml.model.MaxentModel;
+
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
+import org.apache.mahout.classifier.sgd.L1;
+import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
+
+/**
+ * Event trainer which uses Mahout's {@link OnlineLogisticRegression} with an
+ * {@link L1} prior and wraps the trained learner into a
+ * {@link VectorClassifierModel}.
+ */
+public class OnlineLogisticRegressionTrainer extends AbstractOnlineLearnerTrainer {
+  
+  public OnlineLogisticRegressionTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+  }
+
+  /**
+   * Trains the learner for the configured number of iterations over the
+   * indexed events.
+   *
+   * @param indexer the indexed training events
+   * @return a model backed by the trained learner
+   * @throws IOException declared by the overridden method
+   */
+  @Override
+  public MaxentModel doTrain(DataIndexer indexer) throws IOException {
+    
+    // TODO: Lets use the predMap here as well for encoding
+    int numberOfOutcomes = indexer.getOutcomeLabels().length;
+    int numberOfFeatures = indexer.getPredLabels().length;
+    
+    // TODO: Make these parameters configurable ...
+    OnlineLogisticRegression pa = new OnlineLogisticRegression(
+        numberOfOutcomes, numberOfFeatures, new L1());
+
+    try {
+      pa.alpha(1).stepOffset(250).decayExponent(0.9).lambda(3.0e-5)
+          .learningRate(3000);
+      
+      for (int k = 0; k < iterations; k++) {
+        trainOnlineLearner(indexer, pa);
+        
+        // What should be reported at the end of every iteration ?!
+        System.out.println("Iteration " + (k + 1));
+      }
+    }
+    finally {
+      // Close the learner also when the training fails, as the other trainers do.
+      pa.close();
+    }
+    
+    return new VectorClassifierModel(pa, indexer.getOutcomeLabels(),
+        createPrepMap(indexer));
+  }
+
+  @Override
+  public boolean isSortAndMerge() {
+    return true;
+  }
+}

Added: 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/PassiveAggressiveTrainer.java
URL: 
http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/PassiveAggressiveTrainer.java?rev=1556629&view=auto
==============================================================================
--- 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/PassiveAggressiveTrainer.java
 (added)
+++ 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/PassiveAggressiveTrainer.java
 Wed Jan  8 20:39:08 2014
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package opennlp.addons.mahout;
+
+import java.io.IOException;
+import java.util.Map;
+
+import opennlp.tools.ml.model.DataIndexer;
+import opennlp.tools.ml.model.MaxentModel;
+
+import org.apache.mahout.classifier.sgd.PassiveAggressive;
+
+/**
+ * Event trainer which uses Mahout's {@link PassiveAggressive} learner and
+ * wraps the trained learner into a {@link VectorClassifierModel}.
+ */
+public class PassiveAggressiveTrainer extends AbstractOnlineLearnerTrainer {
+  
+  public PassiveAggressiveTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    super(trainParams, reportMap);
+  }
+
+  /**
+   * Trains the learner for the configured number of iterations over the
+   * indexed events.
+   *
+   * @param indexer the indexed training events
+   * @return a model backed by the trained learner
+   * @throws IOException declared by the overridden method
+   */
+  @Override
+  public MaxentModel doTrain(DataIndexer indexer) throws IOException {
+    
+    // TODO: Lets use the predMap here as well for encoding
+    int numberOfOutcomes = indexer.getOutcomeLabels().length;
+    int numberOfFeatures = indexer.getPredLabels().length;
+    
+    PassiveAggressive pa = new PassiveAggressive(numberOfOutcomes, numberOfFeatures);
+    
+    try {
+      for (int k = 0; k < iterations; k++) {
+        trainOnlineLearner(indexer, pa);
+        
+        // What should be reported at the end of every iteration ?!
+        System.out.println("Iteration " + (k + 1));
+      }
+    }
+    finally {
+      // Close the learner also when the training fails, as the other trainers do.
+      pa.close();
+    }
+    
+    return new VectorClassifierModel(pa, indexer.getOutcomeLabels(),
+        createPrepMap(indexer));
+  }
+
+  @Override
+  public boolean isSortAndMerge() {
+    return true;
+  }
+}

Added: 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/VectorClassifierModel.java
URL: 
http://svn.apache.org/viewvc/opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/VectorClassifierModel.java?rev=1556629&view=auto
==============================================================================
--- 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/VectorClassifierModel.java
 (added)
+++ 
opennlp/addons/mahout-addon/src/main/java/opennlp/addons/mahout/VectorClassifierModel.java
 Wed Jan  8 20:39:08 2014
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package opennlp.addons.mahout;
+
+import java.util.Map;
+
+import opennlp.tools.ml.model.MaxentModel;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+// TODO: Would be nice to have an abstract maxent model impl ..
+
+/**
+ * A {@link MaxentModel} which is backed by a Mahout
+ * {@link AbstractVectorClassifier}.
+ * <p>
+ * String features are translated to vector indexes with the predicate map,
+ * the resulting vector is scored by the classifier and the scores are
+ * returned in the order of the outcome labels.
+ */
+public class VectorClassifierModel implements MaxentModel {
+
+  private final AbstractVectorClassifier classifier;
+  private final String[] outcomeLabels;
+  private final Map<String, Integer> predMap;
+  
+  /**
+   * Initializes the model.
+   *
+   * @param pa the classifier used to score the feature vectors
+   * @param outcomeLabels the outcome names, indexed by outcome id
+   * @param predMap maps a feature name to its index in the feature vector
+   */
+  public VectorClassifierModel(AbstractVectorClassifier pa, String outcomeLabels[],
+      Map<String, Integer> predMap) {
+    this.classifier = pa;
+    // Defensive copies, later changes done by the caller must not alter the model.
+    this.outcomeLabels = outcomeLabels.clone();
+    this.predMap = new java.util.HashMap<String, Integer>(predMap);
+  }
+
+  /**
+   * Scores the given features with the classifier.
+   *
+   * @param features the feature names, names missing in the predicate map are skipped
+   * @return one score per category of the classifier
+   */
+  public double[] eval(String[] features) {
+    Vector vector = new RandomAccessSparseVector(predMap.size());
+    
+    for (String feature : features) {
+      Integer featureId = predMap.get(feature);
+      
+      // A feature which occurs multiple times is counted multiple times.
+      if (featureId != null) {
+        vector.set(featureId, vector.get(featureId) + 1);
+      }
+    }
+    
+    Vector resultVector = classifier.classifyFull(vector);
+    
+    double outcomes[] = new double[classifier.numCategories()];
+    
+    for (int i = 0; i < outcomes.length; i++) {
+      outcomes[i] = resultVector.get(i);
+    }
+    
+    return outcomes;
+  }
+
+  /**
+   * Same as {@link #eval(String[])}, the passed array is neither read nor filled.
+   */
+  public double[] eval(String[] context, double[] probs) {
+    return eval(context);
+  }
+
+  /**
+   * Same as {@link #eval(String[])}, the feature values are ignored.
+   */
+  public double[] eval(String[] context, float[] values) {
+    return eval(context);
+  }
+
+  /**
+   * Returns the label of the outcome with the highest score, the first one
+   * wins in case of a tie.
+   */
+  @Override
+  public String getBestOutcome(double[] ocs) {
+    int best = 0;
+    
+    for (int i = 1; i < ocs.length; i++) {
+      if (ocs[i] > ocs[best]) {
+        best = i;
+      }
+    }
+    
+    return outcomeLabels[best];
+  }
+
+  /**
+   * Formats every outcome label together with its score, e.g.
+   * {@code "A[0.2500]  B[0.7500]"}.
+   *
+   * @param outcomes the scores, as returned by {@link #eval(String[])}
+   * @return the formatted outcomes, or an explanatory message if the array
+   *     length does not match the number of outcomes of this model
+   */
+  @Override
+  public String getAllOutcomes(double[] outcomes) {
+    if (outcomes.length != outcomeLabels.length) {
+      return "The double array sent as a parameter to getAllOutcomes() "
+          + "must not have been produced by this model.";
+    }
+    
+    StringBuilder allOutcomes = new StringBuilder();
+    
+    for (int i = 0; i < outcomes.length; i++) {
+      if (i > 0) {
+        allOutcomes.append("  ");
+      }
+      
+      // Locale.ROOT keeps the decimal separator independent of the platform locale.
+      allOutcomes.append(outcomeLabels[i]).append('[')
+          .append(String.format(java.util.Locale.ROOT, "%.4f", outcomes[i]))
+          .append(']');
+    }
+    
+    return allOutcomes.toString();
+  }
+
+  @Override
+  public String getOutcome(int i) {
+    return outcomeLabels[i];
+  }
+
+  /**
+   * Looks up the index of an outcome label.
+   *
+   * @return the index or -1 if the label is unknown
+   */
+  @Override
+  public int getIndex(String outcome) {
+    for (int i = 0; i < outcomeLabels.length; i++) {
+      if (outcomeLabels[i].equals(outcome)) {
+        return i;
+      }
+    }
+    
+    return -1;
+  }
+
+  @Override
+  public int getNumOutcomes() {
+    return outcomeLabels.length;
+  }
+}


Reply via email to