Author: gsingers
Date: Tue Nov 1 18:59:35 2011
New Revision: 1196206
URL: http://svn.apache.org/viewvc?rev=1196206&view=rev
Log:
MAHOUT-857: hook in support for SGD to the 20 newsgroups
Added:
mahout/trunk/examples/bin/classify-20newsgroups.sh (contents, props
changed)
- copied, changed from r1196037,
mahout/trunk/examples/bin/build-20news-bayes.sh
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/NewsgroupHelper.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
Removed:
mahout/trunk/examples/bin/build-20news-bayes.sh
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
Copied: mahout/trunk/examples/bin/classify-20newsgroups.sh (from r1196037,
mahout/trunk/examples/bin/build-20news-bayes.sh)
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/bin/classify-20newsgroups.sh?p2=mahout/trunk/examples/bin/classify-20newsgroups.sh&p1=mahout/trunk/examples/bin/build-20news-bayes.sh&r1=1196037&r2=1196206&rev=1196206&view=diff
==============================================================================
--- mahout/trunk/examples/bin/build-20news-bayes.sh (original)
+++ mahout/trunk/examples/bin/classify-20newsgroups.sh Tue Nov 1 18:59:35 2011
@@ -17,7 +17,7 @@
#
#
-# Downloads the 20newsgroups dataset, trains and tests a bayes classifier.
+# Downloads the 20newsgroups dataset, trains and tests a classifier.
#
# To run: change into the mahout directory and type:
# examples/bin/classify-20newsgroups.sh
@@ -29,6 +29,20 @@ fi
START_PATH=`pwd`
WORK_DIR=/tmp/mahout-work-${USER}
if [ "$1" = "-ni" ]; then
  # Non-interactive mode: run the naive bayes task, which is what this script
  # always did before the task menu was added.
  alg=naivebayes
else
  algorithm=( naivebayes sgd clean )

  echo "Please select a number to choose the corresponding task to run"
  echo "1. ${algorithm[0]}"
  echo "2. ${algorithm[1]}"
  echo "3. ${algorithm[2]} -- cleans up the work area in $WORK_DIR"
  read -p "Enter your choice : " choice

  # Only 1..3 map to a task. Anything else ("0", "abc", empty) makes
  # $choice-1 evaluate to -1, which bash 4.3+ resolves to the LAST entry
  # ("clean") and would delete the work area; older bash errors instead.
  case "$choice" in
    1|2|3) ;;
    *)
      echo "Invalid choice '$choice': expected 1, 2 or 3" >&2
      exit 1
      ;;
  esac

  echo "ok. You chose $choice and we'll use ${algorithm[$choice-1]}"
  alg=${algorithm[$choice-1]}
fi
echo "creating work directory at ${WORK_DIR}"
@@ -49,60 +63,72 @@ cd $START_PATH
cd ../..
set -e
-echo "Preparing Training Data"
-./bin/mahout org.apache.mahout.classifier.bayes.PrepareTwentyNewsgroups \
- -p ${WORK_DIR}/20news-bydate/20news-bydate-train \
- -o ${WORK_DIR}/20news-bydate/bayes-train-input \
- -a org.apache.mahout.vectorizer.DefaultAnalyzer \
- -c UTF-8
-
-echo "Preparing Test Data"
-
-./bin/mahout org.apache.mahout.classifier.bayes.PrepareTwentyNewsgroups \
- -p ${WORK_DIR}/20news-bydate/20news-bydate-test \
- -o ${WORK_DIR}/20news-bydate/bayes-test-input \
- -a org.apache.mahout.vectorizer.DefaultAnalyzer \
- -c UTF-8
-
-TEST_METHOD="sequential"
-
-# if we're set up to run on a cluster..
-if [ "$HADOOP_HOME" != "" ]; then
- # mapreduce test method used on hadoop
- TEST_METHOD="mapreduce"
-
- set +e
- hadoop dfs -rmr \
- ${WORK_DIR}/20news-bydate/bayes-train-input
-
- hadoop dfs -rmr \
- ${WORK_DIR}/20news-bydate/bayes-test-input
-
- set -e
- hadoop dfs -put \
- ${WORK_DIR}/20news-bydate/bayes-train-input \
- ${WORK_DIR}/20news-bydate/bayes-train-input
-
- hadoop dfs -put \
- ${WORK_DIR}/20news-bydate/bayes-test-input \
- ${WORK_DIR}/20news-bydate/bayes-test-input
-fi
# Location of the SGD model: TrainNewsGroups is expected to write it here,
# TestNewsGroups reads it back, and the "clean" task deletes it.
SGD_MODEL=/tmp/news-group.model

case "$alg" in
  naivebayes)
    echo "Preparing Training Data"
    ./bin/mahout org.apache.mahout.classifier.bayes.PrepareTwentyNewsgroups \
      -p ${WORK_DIR}/20news-bydate/20news-bydate-train \
      -o ${WORK_DIR}/20news-bydate/bayes-train-input \
      -a org.apache.mahout.vectorizer.DefaultAnalyzer \
      -c UTF-8

    echo "Preparing Test Data"

    ./bin/mahout org.apache.mahout.classifier.bayes.PrepareTwentyNewsgroups \
      -p ${WORK_DIR}/20news-bydate/20news-bydate-test \
      -o ${WORK_DIR}/20news-bydate/bayes-test-input \
      -a org.apache.mahout.vectorizer.DefaultAnalyzer \
      -c UTF-8

    TEST_METHOD="sequential"

    # if we're set up to run on a cluster..
    if [ "$HADOOP_HOME" != "" ]; then
      # mapreduce test method used on hadoop
      TEST_METHOD="mapreduce"

      # Tolerate -rmr failures (e.g. nothing to remove yet).
      set +e
      hadoop dfs -rmr \
        ${WORK_DIR}/20news-bydate/bayes-train-input

      hadoop dfs -rmr \
        ${WORK_DIR}/20news-bydate/bayes-test-input

      set -e
      hadoop dfs -put \
        ${WORK_DIR}/20news-bydate/bayes-train-input \
        ${WORK_DIR}/20news-bydate/bayes-train-input

      hadoop dfs -put \
        ${WORK_DIR}/20news-bydate/bayes-test-input \
        ${WORK_DIR}/20news-bydate/bayes-test-input
    fi

    ./bin/mahout trainclassifier \
      -i ${WORK_DIR}/20news-bydate/bayes-train-input \
      -o ${WORK_DIR}/20news-bydate/bayes-model \
      -type bayes \
      -ng 1 \
      -source hdfs

    ./bin/mahout testclassifier \
      -m ${WORK_DIR}/20news-bydate/bayes-model \
      -d ${WORK_DIR}/20news-bydate/bayes-test-input \
      -type bayes \
      -ng 1 \
      -source hdfs \
      -method ${TEST_METHOD}
    ;;
  sgd)
    # Train only when no model is left over from a previous run.
    if [ ! -e "${SGD_MODEL}" ]; then
      echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
      ./bin/mahout org.apache.mahout.classifier.sgd.TrainNewsGroups \
        ${WORK_DIR}/20news-bydate/20news-bydate-train/
    fi
    echo "Testing on ${WORK_DIR}/20news-bydate/20news-bydate-test/ with model: ${SGD_MODEL}"
    ./bin/mahout org.apache.mahout.classifier.sgd.TestNewsGroups \
      --input ${WORK_DIR}/20news-bydate/20news-bydate-test/ \
      --model "${SGD_MODEL}"
    ;;
  clean)
    rm -rf "${WORK_DIR}"
    rm -rf "${SGD_MODEL}"
    ;;
  *)
    echo "Unknown task '$alg'" >&2
    exit 1
    ;;
esac

# The work directory is no longer removed automatically; run the "clean" task
# to delete it together with the SGD model.
Propchange: mahout/trunk/examples/bin/classify-20newsgroups.sh
------------------------------------------------------------------------------
svn:executable = *
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/NewsgroupHelper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/NewsgroupHelper.java?rev=1196206&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/NewsgroupHelper.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/NewsgroupHelper.java
Tue Nov 1 18:59:35 2011
@@ -0,0 +1,100 @@
+package org.apache.mahout.classifier.sgd;
+
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.ConcurrentHashMultiset;
+import com.google.common.collect.Multiset;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.util.Version;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.Reader;
+import java.io.StringReader;
+import java.text.SimpleDateFormat;
+import java.util.Collection;
+import java.util.Date;
+import java.util.Locale;
+import java.util.Random;
+
+/**
+ *
+ *
+ **/
+public class NewsgroupHelper {
+
+ static final Random rand = RandomUtils.getRandom();
+ static final SimpleDateFormat[] DATE_FORMATS = {
+ new SimpleDateFormat("", Locale.ENGLISH),
+ new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH),
+ new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH)
+ };
+ static final Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);
+ static final FeatureVectorEncoder encoder = new
StaticWordValueEncoder("body");
+ static final FeatureVectorEncoder bias = new
ConstantValueEncoder("Intercept");
+ public static final int FEATURES = 10000;
+ // 1997-01-15 00:01:00 GMT
+ static final long DATE_REFERENCE = 853286460;
+ static final long MONTH = 30 * 24 * 3600;
+ static final long WEEK = 7 * 24 * 3600;
+
+ static Vector encodeFeatureVector(File file, int actual, int leakType,
Multiset<String> overallCounts) throws IOException {
+ long date = (long) (1000 * (DATE_REFERENCE + actual * MONTH + 1 * WEEK *
rand.nextDouble()));
+ Multiset<String> words = ConcurrentHashMultiset.create();
+
+ BufferedReader reader = Files.newReader(file, Charsets.UTF_8);
+ try {
+ String line = reader.readLine();
+ Reader dateString = new StringReader(DATE_FORMATS[leakType %
3].format(new Date(date)));
+ countWords(analyzer, words, dateString, overallCounts);
+ while (line != null && line.length() > 0) {
+ boolean countHeader = (
+ line.startsWith("From:") || line.startsWith("Subject:") ||
+ line.startsWith("Keywords:") || line.startsWith("Summary:")) &&
leakType < 6;
+ do {
+ Reader in = new StringReader(line);
+ if (countHeader) {
+ countWords(analyzer, words, in, overallCounts);
+ }
+ line = reader.readLine();
+ } while (line != null && line.startsWith(" "));
+ }
+ if (leakType < 3) {
+ countWords(analyzer, words, reader, overallCounts);
+ }
+ } finally {
+ Closeables.closeQuietly(reader);
+ }
+
+ Vector v = new RandomAccessSparseVector(FEATURES);
+ bias.addToVector("", 1, v);
+ for (String word : words.elementSet()) {
+ encoder.addToVector(word, Math.log1p(words.count(word)), v);
+ }
+
+ return v;
+ }
+
+ static void countWords(Analyzer analyzer, Collection<String> words, Reader
in, Multiset<String> overallCounts) throws IOException {
+ TokenStream ts = analyzer.reusableTokenStream("text", in);
+ ts.addAttribute(CharTermAttribute.class);
+ ts.reset();
+ while (ts.incrementToken()) {
+ String s = ts.getAttribute(CharTermAttribute.class).toString();
+ words.add(s);
+ }
+ overallCounts.addAll(words);
+ }
+}
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java?rev=1196206&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
Tue Nov 1 18:59:35 2011
@@ -0,0 +1,148 @@
+package org.apache.mahout.classifier.sgd;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ConfusionMatrix;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.classifier.evaluation.Auc;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+
+/**
+ * Run the 20 news groups test data through SGD, as trained by {@link
org.apache.mahout.classifier.sgd.TrainNewsGroups}.
+ */
+public class TestNewsGroups {
+ protected String inputFile;
+ protected String modelFile;
+
+
+
+ private TestNewsGroups() {
+
+ }
+
+ public static void main(String[] args) throws IOException {
+ TestNewsGroups runner = new TestNewsGroups();
+ if (runner.parseArgs(args)) {
+ runner.run(new PrintWriter(System.out, true));
+ }
+ }
+
+ public void run(PrintWriter output) throws IOException {
+
+ File base = new File(inputFile);
+ //contains the best model
+ OnlineLogisticRegression classifier = ModelSerializer.readBinary(new
FileInputStream(modelFile), OnlineLogisticRegression.class);
+
+
+ Dictionary newsGroups = new Dictionary();
+ Multiset<String> overallCounts = HashMultiset.create();
+
+ List<File> files = Lists.newArrayList();
+ for (File newsgroup : base.listFiles()) {
+ if (newsgroup.isDirectory()) {
+ newsGroups.intern(newsgroup.getName());
+ files.addAll(Arrays.asList(newsgroup.listFiles()));
+ }
+ }
+ System.out.printf("%d test files\n", files.size());
+ ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
+ for (File file : files) {
+ String ng = file.getParentFile().getName();
+
+ int actual = newsGroups.intern(ng);
+ Vector input = NewsgroupHelper.encodeFeatureVector(file, actual, 0,
overallCounts);//no leak type ensures this is a normal vector
+ Vector result = classifier.classifyFull(input);
+ int cat = result.maxValueIndex();
+ double score = result.maxValue();
+ ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat),
score);
+ ra.addInstance(newsGroups.values().get(actual), cr);
+
+ }
+ output.printf("%s\n\n", ra.toString());
+ }
+
+ protected boolean parseArgs(String[] args) {
+ DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+ Option help = builder.withLongName("help").withDescription("print this
list").create();
+
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+ Option inputFileOption = builder.withLongName("input")
+ .withRequired(true)
+
.withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+ .withDescription("where to get training data")
+ .create();
+
+ Option modelFileOption = builder.withLongName("model")
+ .withRequired(true)
+
.withArgument(argumentBuilder.withName("model").withMaximum(1).create())
+ .withDescription("where to get a model")
+ .create();
+
+ Group normalArgs = new GroupBuilder()
+ .withOption(help)
+ .withOption(inputFileOption)
+ .withOption(modelFileOption)
+ .create();
+
+ Parser parser = new Parser();
+ parser.setHelpOption(help);
+ parser.setHelpTrigger("--help");
+ parser.setGroup(normalArgs);
+ parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+ CommandLine cmdLine = parser.parseAndHelp(args);
+
+ if (cmdLine == null) {
+ return false;
+ }
+
+ inputFile = getStringArgument(cmdLine, inputFileOption);
+ modelFile = getStringArgument(cmdLine, modelFileOption);
+ return true;
+ }
+
+ protected boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+ return cmdLine.hasOption(option);
+ }
+
+ protected String getStringArgument(CommandLine cmdLine, Option inputFile) {
+ return (String) cmdLine.getValue(inputFile);
+ }
+}
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=1196206&r1=1196205&r2=1196206&view=diff
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
(original)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
Tue Nov 1 18:59:35 2011
@@ -17,45 +17,24 @@
package org.apache.mahout.classifier.sgd;
-import com.google.common.base.Charsets;
-import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
-import com.google.common.io.Closeables;
-import com.google.common.io.Files;
-import org.apache.lucene.analysis.Analyzer;
-import org.apache.lucene.analysis.TokenStream;
-import org.apache.lucene.analysis.standard.StandardAnalyzer;
-
-import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
-import org.apache.lucene.util.Version;
-import org.apache.mahout.common.RandomUtils;
+
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Matrix;
-import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.DoubleFunction;
-import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
-import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
-import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
-import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
-import java.io.Reader;
-import java.io.StringReader;
-import java.text.SimpleDateFormat;
import java.util.Arrays;
-import java.util.Collection;
import java.util.Collections;
-import java.util.Date;
import java.util.List;
-import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
@@ -107,24 +86,8 @@ import java.util.Set;
*/
public final class TrainNewsGroups {
- private static final int FEATURES = 10000;
- // 1997-01-15 00:01:00 GMT
- private static final long DATE_REFERENCE = 853286460;
- private static final long MONTH = 30 * 24 * 3600;
- private static final long WEEK = 7 * 24 * 3600;
-
- private static final Random rand = RandomUtils.getRandom();
-
private static final String[] LEAK_LABELS = {"none", "month-year",
"day-month-year"};
- private static final SimpleDateFormat[] DATE_FORMATS = {
- new SimpleDateFormat("", Locale.ENGLISH),
- new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH),
- new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH)
- };
-
- private static final Analyzer analyzer = new
StandardAnalyzer(Version.LUCENE_31);
- private static final FeatureVectorEncoder encoder = new
StaticWordValueEncoder("body");
- private static final FeatureVectorEncoder bias = new
ConstantValueEncoder("Intercept");
+
private static Multiset<String> overallCounts;
private TrainNewsGroups() {
@@ -142,8 +105,8 @@ public final class TrainNewsGroups {
Dictionary newsGroups = new Dictionary();
- encoder.setProbes(2);
- AdaptiveLogisticRegression learningAlgorithm = new
AdaptiveLogisticRegression(20, FEATURES, new L1());
+ NewsgroupHelper.encoder.setProbes(2);
+ AdaptiveLogisticRegression learningAlgorithm = new
AdaptiveLogisticRegression(20, NewsgroupHelper.FEATURES, new L1());
learningAlgorithm.setInterval(800);
learningAlgorithm.setAveragingWindow(500);
@@ -163,11 +126,11 @@ public final class TrainNewsGroups {
int k = 0;
double step = 0;
int[] bumps = {1, 2, 5};
- for (File file : files.subList(0, 3000)) {
+ for (File file : files) {
String ng = file.getParentFile().getName();
int actual = newsGroups.intern(ng);
- Vector v = encodeFeatureVector(file, actual, leakType);
+ Vector v = NewsgroupHelper.encodeFeatureVector(file, actual, leakType,
overallCounts);
learningAlgorithm.train(actual, v);
k++;
@@ -261,15 +224,15 @@ public final class TrainNewsGroups {
Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
ModelDissector md = new ModelDissector();
- encoder.setTraceDictionary(traceDictionary);
- bias.setTraceDictionary(traceDictionary);
+ NewsgroupHelper.encoder.setTraceDictionary(traceDictionary);
+ NewsgroupHelper.bias.setTraceDictionary(traceDictionary);
- for (File file : permute(files, rand).subList(0, 500)) {
+ for (File file : permute(files, NewsgroupHelper.rand).subList(0, 500)) {
String ng = file.getParentFile().getName();
int actual = newsGroups.intern(ng);
traceDictionary.clear();
- Vector v = encodeFeatureVector(file, actual, leakType);
+ Vector v = NewsgroupHelper.encodeFeatureVector(file, actual, leakType,
overallCounts);
md.update(v, traceDictionary, model);
}
@@ -282,54 +245,6 @@ public final class TrainNewsGroups {
}
}
- private static Vector encodeFeatureVector(File file, int actual, int
leakType) throws IOException {
- long date = (long) (1000 * (DATE_REFERENCE + actual * MONTH + 1 * WEEK *
rand.nextDouble()));
- Multiset<String> words = ConcurrentHashMultiset.create();
-
- BufferedReader reader = Files.newReader(file, Charsets.UTF_8);
- try {
- String line = reader.readLine();
- Reader dateString = new StringReader(DATE_FORMATS[leakType %
3].format(new Date(date)));
- countWords(analyzer, words, dateString);
- while (line != null && line.length() > 0) {
- boolean countHeader = (
- line.startsWith("From:") || line.startsWith("Subject:") ||
- line.startsWith("Keywords:") || line.startsWith("Summary:")) &&
leakType < 6;
- do {
- Reader in = new StringReader(line);
- if (countHeader) {
- countWords(analyzer, words, in);
- }
- line = reader.readLine();
- } while (line != null && line.startsWith(" "));
- }
- if (leakType < 3) {
- countWords(analyzer, words, reader);
- }
- } finally {
- Closeables.closeQuietly(reader);
- }
-
- Vector v = new RandomAccessSparseVector(FEATURES);
- bias.addToVector("", 1, v);
- for (String word : words.elementSet()) {
- encoder.addToVector(word, Math.log1p(words.count(word)), v);
- }
-
- return v;
- }
-
- private static void countWords(Analyzer analyzer, Collection<String> words,
Reader in) throws IOException {
- TokenStream ts = analyzer.reusableTokenStream("text", in);
- ts.addAttribute(CharTermAttribute.class);
- ts.reset();
- while (ts.incrementToken()) {
- String s = ts.getAttribute(CharTermAttribute.class).toString();
- words.add(s);
- }
- overallCounts.addAll(words);
- }
-
private static List<File> permute(Iterable<File> files, Random rand) {
List<File> r = Lists.newArrayList();
for (File file : files) {