Author: gsingers
Date: Tue Nov  1 18:59:35 2011
New Revision: 1196206

URL: http://svn.apache.org/viewvc?rev=1196206&view=rev
Log:
MAHOUT-857: hook in support for SGD to the 20 newsgroups

Added:
    mahout/trunk/examples/bin/classify-20newsgroups.sh   (contents, props changed)
      - copied, changed from r1196037, mahout/trunk/examples/bin/build-20news-bayes.sh
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/NewsgroupHelper.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
Removed:
    mahout/trunk/examples/bin/build-20news-bayes.sh
Modified:
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java

Copied: mahout/trunk/examples/bin/classify-20newsgroups.sh (from r1196037, mahout/trunk/examples/bin/build-20news-bayes.sh)
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/bin/classify-20newsgroups.sh?p2=mahout/trunk/examples/bin/classify-20newsgroups.sh&p1=mahout/trunk/examples/bin/build-20news-bayes.sh&r1=1196037&r2=1196206&rev=1196206&view=diff
==============================================================================
--- mahout/trunk/examples/bin/build-20news-bayes.sh (original)
+++ mahout/trunk/examples/bin/classify-20newsgroups.sh Tue Nov  1 18:59:35 2011
@@ -17,7 +17,7 @@
 #
 
 #
-# Downloads the 20newsgroups dataset, trains and tests a bayes classifier. 
+# Downloads the 20newsgroups dataset, trains and tests a classifier.
 #
 # To run:  change into the mahout directory and type:
 #  examples/bin/build-20news.sh
@@ -29,6 +29,20 @@ fi
 START_PATH=`pwd`
 
 WORK_DIR=/tmp/mahout-work-${USER}
+if [ "$1" = "-ni" ]; then
+  alg=rec
+else
+  algorithm=( naivebayes sgd clean)
+
+  echo "Please select a number to choose the corresponding task to run"
+  echo "1. ${algorithm[0]}"
+  echo "2. ${algorithm[1]}"
+  echo "3. ${algorithm[2]} -- cleans up the work area in $WORK_DIR"
+  read -p "Enter your choice : " choice
+
+  echo "ok. You chose $choice and we'll use ${algorithm[$choice-1]}"
+  alg=${algorithm[$choice-1]}
+fi
 
 echo "creating work directory at ${WORK_DIR}"
 
@@ -49,60 +63,72 @@ cd $START_PATH
 cd ../..
 
 set -e
-echo "Preparing Training Data"
-./bin/mahout org.apache.mahout.classifier.bayes.PrepareTwentyNewsgroups \
-  -p ${WORK_DIR}/20news-bydate/20news-bydate-train \
-  -o ${WORK_DIR}/20news-bydate/bayes-train-input \
-  -a org.apache.mahout.vectorizer.DefaultAnalyzer \
-  -c UTF-8
-
-echo "Preparing Test Data"
-
-./bin/mahout org.apache.mahout.classifier.bayes.PrepareTwentyNewsgroups \
-  -p ${WORK_DIR}/20news-bydate/20news-bydate-test \
-  -o ${WORK_DIR}/20news-bydate/bayes-test-input \
-  -a org.apache.mahout.vectorizer.DefaultAnalyzer \
-  -c UTF-8 
-
-TEST_METHOD="sequential"
-
-# if we're set up to run on a cluster..
-if [ "$HADOOP_HOME" != "" ]; then
-    # mapreduce test method used on hadoop
-    TEST_METHOD="mapreduce"
-
-    set +e 
-    hadoop dfs -rmr \
-      ${WORK_DIR}/20news-bydate/bayes-train-input 
-
-    hadoop dfs -rmr \
-      ${WORK_DIR}/20news-bydate/bayes-test-input
-
-    set -e
-    hadoop dfs -put \
-      ${WORK_DIR}/20news-bydate/bayes-train-input \
-      ${WORK_DIR}/20news-bydate/bayes-train-input 
-
-    hadoop dfs -put \
-      ${WORK_DIR}/20news-bydate/bayes-test-input \
-      ${WORK_DIR}/20news-bydate/bayes-test-input
-fi
 
+if [ "x$alg" == "xnaivebayes" ]; then
+  echo "Preparing Training Data"
+  ./bin/mahout org.apache.mahout.classifier.bayes.PrepareTwentyNewsgroups \
+    -p ${WORK_DIR}/20news-bydate/20news-bydate-train \
+    -o ${WORK_DIR}/20news-bydate/bayes-train-input \
+    -a org.apache.mahout.vectorizer.DefaultAnalyzer \
+    -c UTF-8
+
+  echo "Preparing Test Data"
+
+  ./bin/mahout org.apache.mahout.classifier.bayes.PrepareTwentyNewsgroups \
+    -p ${WORK_DIR}/20news-bydate/20news-bydate-test \
+    -o ${WORK_DIR}/20news-bydate/bayes-test-input \
+    -a org.apache.mahout.vectorizer.DefaultAnalyzer \
+    -c UTF-8
+
+  TEST_METHOD="sequential"
+
+  # if we're set up to run on a cluster..
+  if [ "$HADOOP_HOME" != "" ]; then
+      # mapreduce test method used on hadoop
+      TEST_METHOD="mapreduce"
+
+      set +e
+      hadoop dfs -rmr \
+        ${WORK_DIR}/20news-bydate/bayes-train-input
+
+      hadoop dfs -rmr \
+        ${WORK_DIR}/20news-bydate/bayes-test-input
+
+      set -e
+      hadoop dfs -put \
+        ${WORK_DIR}/20news-bydate/bayes-train-input \
+        ${WORK_DIR}/20news-bydate/bayes-train-input
+
+      hadoop dfs -put \
+        ${WORK_DIR}/20news-bydate/bayes-test-input \
+        ${WORK_DIR}/20news-bydate/bayes-test-input
+  fi
 
-./bin/mahout trainclassifier \
-  -i ${WORK_DIR}/20news-bydate/bayes-train-input \
-  -o ${WORK_DIR}/20news-bydate/bayes-model \
-  -type bayes \
-  -ng 1 \
-  -source hdfs
-
-./bin/mahout testclassifier \
-  -m ${WORK_DIR}/20news-bydate/bayes-model \
-  -d ${WORK_DIR}/20news-bydate/bayes-test-input \
-  -type bayes \
-  -ng 1 \
-  -source hdfs \
-  -method ${TEST_METHOD}
 
+  ./bin/mahout trainclassifier \
+    -i ${WORK_DIR}/20news-bydate/bayes-train-input \
+    -o ${WORK_DIR}/20news-bydate/bayes-model \
+    -type bayes \
+    -ng 1 \
+    -source hdfs
+
+  ./bin/mahout testclassifier \
+    -m ${WORK_DIR}/20news-bydate/bayes-model \
+    -d ${WORK_DIR}/20news-bydate/bayes-test-input \
+    -type bayes \
+    -ng 1 \
+    -source hdfs \
+    -method ${TEST_METHOD}
+elif [ "x$alg" == "xsgd" ]; then
+  if [ ! -e "/tmp/news-group.model" ]; then
+    echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
+    ./bin/mahout org.apache.mahout.classifier.sgd.TrainNewsGroups 
${WORK_DIR}/20news-bydate/20news-bydate-train/
+  fi
+  echo "Testing on ${WORK_DIR}/20news-bydate/20news-bydate-test/ with model: 
/tmp/news-group.model"
+  ./bin/mahout org.apache.mahout.classifier.sgd.TestNewsGroups --input 
${WORK_DIR}/20news-bydate/20news-bydate-test/ --model /tmp/news-group.model
+elif [ "x$alg" == "xclean" ]; then
+  rm -rf ${WORK_DIR}
+  rm -rf /tmp/news-group.model
+fi
 # Remove the work directory
-rm -rf ${WORK_DIR}
+#

Propchange: mahout/trunk/examples/bin/classify-20newsgroups.sh
------------------------------------------------------------------------------
    svn:executable = *

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/NewsgroupHelper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/NewsgroupHelper.java?rev=1196206&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/NewsgroupHelper.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/NewsgroupHelper.java Tue Nov  1 18:59:35 2011
@@ -0,0 +1,100 @@
+package org.apache.mahout.classifier.sgd;
+
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.ConcurrentHashMultiset;
+import com.google.common.collect.Multiset;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.util.Version;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.Reader;
+import java.io.StringReader;
+import java.text.SimpleDateFormat;
+import java.util.Collection;
+import java.util.Date;
+import java.util.Locale;
+import java.util.Random;
+
+/**
+ *
+ *
+ **/
+public class NewsgroupHelper {
+
+  static final Random rand = RandomUtils.getRandom();
+  static final SimpleDateFormat[] DATE_FORMATS = {
+    new SimpleDateFormat("", Locale.ENGLISH),
+    new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH),
+    new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH)
+  };
+  static final Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);
+  static final FeatureVectorEncoder encoder = new 
StaticWordValueEncoder("body");
+  static final FeatureVectorEncoder bias = new 
ConstantValueEncoder("Intercept");
+  public static final int FEATURES = 10000;
+  // 1997-01-15 00:01:00 GMT
+  static final long DATE_REFERENCE = 853286460;
+  static final long MONTH = 30 * 24 * 3600;
+  static final long WEEK = 7 * 24 * 3600;
+
+  static Vector encodeFeatureVector(File file, int actual, int leakType, 
Multiset<String> overallCounts) throws IOException {
+    long date = (long) (1000 * (DATE_REFERENCE + actual * MONTH + 1 * WEEK * 
rand.nextDouble()));
+    Multiset<String> words = ConcurrentHashMultiset.create();
+
+    BufferedReader reader = Files.newReader(file, Charsets.UTF_8);
+    try {
+      String line = reader.readLine();
+      Reader dateString = new StringReader(DATE_FORMATS[leakType % 
3].format(new Date(date)));
+      countWords(analyzer, words, dateString, overallCounts);
+      while (line != null && line.length() > 0) {
+        boolean countHeader = (
+          line.startsWith("From:") || line.startsWith("Subject:") ||
+            line.startsWith("Keywords:") || line.startsWith("Summary:")) && 
leakType < 6;
+        do {
+          Reader in = new StringReader(line);
+          if (countHeader) {
+            countWords(analyzer, words, in, overallCounts);
+          }
+          line = reader.readLine();
+        } while (line != null && line.startsWith(" "));
+      }
+      if (leakType < 3) {
+        countWords(analyzer, words, reader, overallCounts);
+      }
+    } finally {
+      Closeables.closeQuietly(reader);
+    }
+
+    Vector v = new RandomAccessSparseVector(FEATURES);
+    bias.addToVector("", 1, v);
+    for (String word : words.elementSet()) {
+      encoder.addToVector(word, Math.log1p(words.count(word)), v);
+    }
+
+    return v;
+  }
+
+  static void countWords(Analyzer analyzer, Collection<String> words, Reader 
in, Multiset<String> overallCounts) throws IOException {
+    TokenStream ts = analyzer.reusableTokenStream("text", in);
+    ts.addAttribute(CharTermAttribute.class);
+    ts.reset();
+    while (ts.incrementToken()) {
+      String s = ts.getAttribute(CharTermAttribute.class).toString();
+      words.add(s);
+    }
+    overallCounts.addAll(words);
+  }
+}

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java?rev=1196206&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java Tue Nov  1 18:59:35 2011
@@ -0,0 +1,148 @@
+package org.apache.mahout.classifier.sgd;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ConfusionMatrix;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.classifier.evaluation.Auc;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+
+/**
+ * Run the 20 news groups test data through SGD, as trained by {@link 
org.apache.mahout.classifier.sgd.TrainNewsGroups}.
+ */
+public class TestNewsGroups {
+  protected String inputFile;
+  protected String modelFile;
+
+
+
+  private TestNewsGroups() {
+
+  }
+
+  public static void main(String[] args) throws IOException {
+    TestNewsGroups runner = new TestNewsGroups();
+    if (runner.parseArgs(args)) {
+      runner.run(new PrintWriter(System.out, true));
+    }
+  }
+
+  public void run(PrintWriter output) throws IOException {
+
+    File base = new File(inputFile);
+    //contains the best model
+    OnlineLogisticRegression classifier = ModelSerializer.readBinary(new 
FileInputStream(modelFile), OnlineLogisticRegression.class);
+
+
+    Dictionary newsGroups = new Dictionary();
+    Multiset<String> overallCounts = HashMultiset.create();
+
+    List<File> files = Lists.newArrayList();
+    for (File newsgroup : base.listFiles()) {
+      if (newsgroup.isDirectory()) {
+        newsGroups.intern(newsgroup.getName());
+        files.addAll(Arrays.asList(newsgroup.listFiles()));
+      }
+    }
+    System.out.printf("%d test files\n", files.size());
+    ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
+    for (File file : files) {
+      String ng = file.getParentFile().getName();
+
+      int actual = newsGroups.intern(ng);
+      Vector input = NewsgroupHelper.encodeFeatureVector(file, actual, 0, 
overallCounts);//no leak type ensures this is a normal vector
+      Vector result = classifier.classifyFull(input);
+      int cat = result.maxValueIndex();
+      double score = result.maxValue();
+      ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), 
score);
+      ra.addInstance(newsGroups.values().get(actual), cr);
+
+    }
+    output.printf("%s\n\n", ra.toString());
+  }
+
+  protected boolean parseArgs(String[] args) {
+    DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+    Option help = builder.withLongName("help").withDescription("print this 
list").create();
+
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+    Option inputFileOption = builder.withLongName("input")
+            .withRequired(true)
+            
.withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+            .withDescription("where to get training data")
+            .create();
+
+    Option modelFileOption = builder.withLongName("model")
+            .withRequired(true)
+            
.withArgument(argumentBuilder.withName("model").withMaximum(1).create())
+            .withDescription("where to get a model")
+            .create();
+
+    Group normalArgs = new GroupBuilder()
+            .withOption(help)
+            .withOption(inputFileOption)
+            .withOption(modelFileOption)
+            .create();
+
+    Parser parser = new Parser();
+    parser.setHelpOption(help);
+    parser.setHelpTrigger("--help");
+    parser.setGroup(normalArgs);
+    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+    CommandLine cmdLine = parser.parseAndHelp(args);
+
+    if (cmdLine == null) {
+      return false;
+    }
+
+    inputFile = getStringArgument(cmdLine, inputFileOption);
+    modelFile = getStringArgument(cmdLine, modelFileOption);
+    return true;
+  }
+
+  protected boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+    return cmdLine.hasOption(option);
+  }
+
+  protected String getStringArgument(CommandLine cmdLine, Option inputFile) {
+    return (String) cmdLine.getValue(inputFile);
+  }
+}

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=1196206&r1=1196205&r2=1196206&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java Tue Nov  1 18:59:35 2011
@@ -17,45 +17,24 @@
 
 package org.apache.mahout.classifier.sgd;
 
-import com.google.common.base.Charsets;
-import com.google.common.collect.ConcurrentHashMultiset;
 import com.google.common.collect.HashMultiset;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Multiset;
 import com.google.common.collect.Ordering;
-import com.google.common.io.Closeables;
-import com.google.common.io.Files;
-import org.apache.lucene.analysis.Analyzer;
-import org.apache.lucene.analysis.TokenStream;
-import org.apache.lucene.analysis.standard.StandardAnalyzer;
-
-import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
-import org.apache.lucene.util.Version;
-import org.apache.mahout.common.RandomUtils;
+
 import org.apache.mahout.ep.State;
 import org.apache.mahout.math.Matrix;
-import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.function.Functions;
 import org.apache.mahout.math.function.DoubleFunction;
-import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
 import org.apache.mahout.vectorizer.encoders.Dictionary;
-import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
-import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
 
-import java.io.BufferedReader;
 import java.io.File;
 import java.io.IOException;
-import java.io.Reader;
-import java.io.StringReader;
-import java.text.SimpleDateFormat;
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.Collections;
-import java.util.Date;
 import java.util.List;
-import java.util.Locale;
 import java.util.Map;
 import java.util.Random;
 import java.util.Set;
@@ -107,24 +86,8 @@ import java.util.Set;
  */
 public final class TrainNewsGroups {
 
-  private static final int FEATURES = 10000;
-  // 1997-01-15 00:01:00 GMT
-  private static final long DATE_REFERENCE = 853286460;
-  private static final long MONTH = 30 * 24 * 3600;
-  private static final long WEEK = 7 * 24 * 3600;
-
-  private static final Random rand = RandomUtils.getRandom();
-
   private static final String[] LEAK_LABELS = {"none", "month-year", 
"day-month-year"};
-  private static final SimpleDateFormat[] DATE_FORMATS = {
-    new SimpleDateFormat("", Locale.ENGLISH),
-    new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH),
-    new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH)
-  };
-
-  private static final Analyzer analyzer = new 
StandardAnalyzer(Version.LUCENE_31);
-  private static final FeatureVectorEncoder encoder = new 
StaticWordValueEncoder("body");
-  private static final FeatureVectorEncoder bias = new 
ConstantValueEncoder("Intercept");
+
   private static Multiset<String> overallCounts;
 
   private TrainNewsGroups() {
@@ -142,8 +105,8 @@ public final class TrainNewsGroups {
 
     Dictionary newsGroups = new Dictionary();
 
-    encoder.setProbes(2);
-    AdaptiveLogisticRegression learningAlgorithm = new 
AdaptiveLogisticRegression(20, FEATURES, new L1());
+    NewsgroupHelper.encoder.setProbes(2);
+    AdaptiveLogisticRegression learningAlgorithm = new 
AdaptiveLogisticRegression(20, NewsgroupHelper.FEATURES, new L1());
     learningAlgorithm.setInterval(800);
     learningAlgorithm.setAveragingWindow(500);
 
@@ -163,11 +126,11 @@ public final class TrainNewsGroups {
     int k = 0;
     double step = 0;
     int[] bumps = {1, 2, 5};
-    for (File file : files.subList(0, 3000)) {
+    for (File file : files) {
       String ng = file.getParentFile().getName();
       int actual = newsGroups.intern(ng);
 
-      Vector v = encodeFeatureVector(file, actual, leakType);
+      Vector v = NewsgroupHelper.encodeFeatureVector(file, actual, leakType, 
overallCounts);
       learningAlgorithm.train(actual, v);
 
       k++;
@@ -261,15 +224,15 @@ public final class TrainNewsGroups {
     Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
     ModelDissector md = new ModelDissector();
 
-    encoder.setTraceDictionary(traceDictionary);
-    bias.setTraceDictionary(traceDictionary);
+    NewsgroupHelper.encoder.setTraceDictionary(traceDictionary);
+    NewsgroupHelper.bias.setTraceDictionary(traceDictionary);
 
-    for (File file : permute(files, rand).subList(0, 500)) {
+    for (File file : permute(files, NewsgroupHelper.rand).subList(0, 500)) {
       String ng = file.getParentFile().getName();
       int actual = newsGroups.intern(ng);
 
       traceDictionary.clear();
-      Vector v = encodeFeatureVector(file, actual, leakType);
+      Vector v = NewsgroupHelper.encodeFeatureVector(file, actual, leakType, 
overallCounts);
       md.update(v, traceDictionary, model);
     }
 
@@ -282,54 +245,6 @@ public final class TrainNewsGroups {
     }
   }
 
-  private static Vector encodeFeatureVector(File file, int actual, int 
leakType) throws IOException {
-    long date = (long) (1000 * (DATE_REFERENCE + actual * MONTH + 1 * WEEK * 
rand.nextDouble()));
-    Multiset<String> words = ConcurrentHashMultiset.create();
-
-    BufferedReader reader = Files.newReader(file, Charsets.UTF_8);
-    try {
-      String line = reader.readLine();
-      Reader dateString = new StringReader(DATE_FORMATS[leakType % 
3].format(new Date(date)));
-      countWords(analyzer, words, dateString);
-      while (line != null && line.length() > 0) {
-        boolean countHeader = (
-          line.startsWith("From:") || line.startsWith("Subject:") ||
-            line.startsWith("Keywords:") || line.startsWith("Summary:")) && 
leakType < 6;
-        do {
-          Reader in = new StringReader(line);
-          if (countHeader) {
-            countWords(analyzer, words, in);
-          }
-          line = reader.readLine();
-        } while (line != null && line.startsWith(" "));
-      }
-      if (leakType < 3) {
-        countWords(analyzer, words, reader);
-      }
-    } finally {
-      Closeables.closeQuietly(reader);
-    }
-
-    Vector v = new RandomAccessSparseVector(FEATURES);
-    bias.addToVector("", 1, v);
-    for (String word : words.elementSet()) {
-      encoder.addToVector(word, Math.log1p(words.count(word)), v);
-    }
-
-    return v;
-  }
-
-  private static void countWords(Analyzer analyzer, Collection<String> words, 
Reader in) throws IOException {
-    TokenStream ts = analyzer.reusableTokenStream("text", in);
-    ts.addAttribute(CharTermAttribute.class);
-    ts.reset();
-    while (ts.incrementToken()) {
-      String s = ts.getAttribute(CharTermAttribute.class).toString();
-      words.add(s);
-    }
-    overallCounts.addAll(words);
-  }
-
   private static List<File> permute(Iterable<File> files, Random rand) {
     List<File> r = Lists.newArrayList();
     for (File file : files) {


Reply via email to