Author: gsingers
Date: Sat Nov 12 08:19:18 2011
New Revision: 1201223
URL: http://svn.apache.org/viewvc?rev=1201223&view=rev
Log:
MAHOUT-851: add in SGD example for ASF email
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
- copied, changed from r1200329,
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
Modified:
mahout/trunk/examples/bin/build-asf-email.sh
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
Modified: mahout/trunk/examples/bin/build-asf-email.sh
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/bin/build-asf-email.sh?rev=1201223&r1=1201222&r2=1201223&view=diff
==============================================================================
--- mahout/trunk/examples/bin/build-asf-email.sh (original)
+++ mahout/trunk/examples/bin/build-asf-email.sh Sat Nov 12 08:19:18 2011
@@ -109,11 +109,16 @@ elif [ "x$alg" == "xclassification" ]; t
echo "Please select a number to choose the corresponding algorithm to run"
echo "1. ${algorithm[0]}"
echo "2. ${algorithm[1]}"
-# echo "3. ${algorithm[2]}"
+ echo "3. ${algorithm[2]}"
read -p "Enter your choice : " choice
echo "ok. You chose $choice and we'll use ${algorithm[$choice-1]}"
classAlg=${algorithm[$choice-1]}
+
+ if [ "x$classAlg" == "xsgd" ]; then
+ echo "How many labels/projects are there in the data set:"
+ read -p "Enter your choice : " numLabels
+ fi
#Convert mail to be formatted as:
# label\ttext
# One per line
@@ -167,6 +172,7 @@ elif [ "x$alg" == "xclassification" ]; t
TRAIN="$SPLIT/train"
TEST="$SPLIT/test"
TEST_OUT="$CLASS/test-results"
+ MODELS="$CLASS/models"
LABEL="$SPLIT/labels"
if [ "x$OVER" == "xover" ] || [ ! -e "$MAIL_OUT/chunk-0" ]; then
echo "Converting Mail files to Sequence Files"
@@ -182,12 +188,14 @@ elif [ "x$alg" == "xclassification" ]; t
echo "Creating training and test inputs from $SEQ2SPLABEL"
$MAHOUT split --input $SEQ2SPLABEL --trainingOutput $TRAIN --testOutput
$TEST --randomSelectionPct 20 --overwrite --sequenceFiles
fi
- MODEL="$CLASS/model"
+ MODEL="$MODELS/asf.model"
+
echo "Running SGD Training"
- #$MAHOUT trainnb -i $TRAIN -o $MODEL --extractLabels --labelIndex $LABEL
--overwrite
+ $MAHOUT org.apache.mahout.classifier.sgd.TrainASFEmail $TRAIN $MODELS
$numLabels 5000
echo "Running Test"
-#$MAHOUT testnb -i $TEST -o $TEST_OUT -m $MODEL --labelIndex $LABEL --overwrite
+ MODEL="$MODELS/asf.model" # plain assignment; with a leading "$" the shell expands MODEL first and tries to run the result as a command
+ $MAHOUT org.apache.mahout.classifier.sgd.TestASFEmail --input $TEST
--model $MODEL
fi
fi
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java?rev=1201223&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
Sat Nov 12 08:19:18 2011
@@ -0,0 +1,149 @@
+package org.apache.mahout.classifier.sgd;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Multiset;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+/**
+ * Static helpers shared by the SGD example programs; in this change they are
+ * called from TrainNewsGroups and TrainASFEmail.
+ * <ul>
+ * <li>dissect: closes the best CrossFoldLearner, encodes 500 files taken from a
+ * random permutation of the input with a trace dictionary attached, feeds each
+ * vector to a ModelDissector and prints one tab-separated row per weight
+ * returned by md.summary(100).</li>
+ * <li>permute: copies the given files into a new list in random order, using
+ * the supplied Random.</li>
+ * <li>analyzeState: called by the trainers once per training example. Gathers
+ * statistics on the beta matrix of the first model of the best learner (all
+ * zero while there is no best state) and, whenever k is a multiple of
+ * bump * scale, saves that model under /tmp and prints one tab-separated
+ * progress line.</li>
+ * </ul>
+ * NOTE(review): dissect calls subList(0, 500), which throws if fewer than 500
+ * files are supplied. analyzeState always writes to "/tmp/news-group-k.model",
+ * including when it is called from TrainASFEmail. Confirm both are intended.
+ **/
+public class SGDHelper {
+ private static final String[] LEAK_LABELS = {"none", "month-year",
"day-month-year"};
+
+ public static void dissect(int leakType,
+ Dictionary newsGroups,
+ AdaptiveLogisticRegression learningAlgorithm,
+ Iterable<File> files, Multiset<String>
overallCounts) throws IOException {
+ CrossFoldLearner model =
learningAlgorithm.getBest().getPayload().getLearner();
+ model.close();
+
+ Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
+ ModelDissector md = new ModelDissector();
+
+ NewsgroupHelper helper = new NewsgroupHelper();
+ helper.getEncoder().setTraceDictionary(traceDictionary);
+ helper.getBias().setTraceDictionary(traceDictionary);
+
+ for (File file : permute(files, helper.getRandom()).subList(0, 500)) {
+ String ng = file.getParentFile().getName(); // the parent directory name is used as the label
+ int actual = newsGroups.intern(ng);
+
+ traceDictionary.clear(); // start every file with an empty trace
+ Vector v = helper.encodeFeatureVector(file, actual, leakType,
overallCounts);
+ md.update(v, traceDictionary, model);
+ }
+
+ List<String> ngNames = Lists.newArrayList(newsGroups.values());
+ List<ModelDissector.Weight> weights = md.summary(100);
+ System.out.println("============");
+ System.out.println("Model Dissection");
+ for (ModelDissector.Weight w : weights) {
+ System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n",
+ w.getFeature(), w.getWeight(),
ngNames.get(w.getMaxImpact() + 1),
+ w.getCategory(1), w.getWeight(1), w.getCategory(2),
w.getWeight(2));
+ }
+ }
+
+ public static List<File> permute(Iterable<File> files, Random rand) {
+ List<File> r = Lists.newArrayList();
+ for (File file : files) {
+ int i = rand.nextInt(r.size() + 1); // random slot in [0, r.size()]
+ if (i == r.size()) {
+ r.add(file); // chosen slot is the end: just append
+ } else {
+ r.add(r.get(i)); // move the current occupant of slot i to the end
+ r.set(i, file); // and put the new file in its place
+ }
+ }
+ return r;
+ }
+
+ static void analyzeState(SGDInfo info, int leakType, int k,
State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best) throws
IOException {
+ int bump = info.bumps[(int) Math.floor(info.step) % info.bumps.length]; // step grows by 0.25 per report, so each bump serves four consecutive reports
+ int scale = (int) Math.pow(10, Math.floor(info.step / info.bumps.length)); // power of ten that grows after each full pass over bumps
+ double maxBeta;
+ double nonZeros;
+ double positive;
+ double norm;
+
+ double lambda = 0;
+ double mu = 0;
+
+ if (best != null) {
+ CrossFoldLearner state = best.getPayload().getLearner();
+ info.averageCorrect = state.percentCorrect();
+ info.averageLL = state.logLikelihood();
+
+ OnlineLogisticRegression model = state.getModels().get(0);
+ // finish off pending regularization
+ model.close();
+
+ Matrix beta = model.getBeta();
+ maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
+ nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
+ @Override
+ public double apply(double v) {
+ return Math.abs(v) > 1.0e-6 ? 1 : 0; // counts entries whose magnitude exceeds 1e-6
+ }
+ });
+ positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
+ @Override
+ public double apply(double v) {
+ return v > 0 ? 1 : 0; // counts strictly positive entries
+ }
+ });
+ norm = beta.aggregate(Functions.PLUS, Functions.ABS);
+
+ lambda = best.getMappedParams()[0];
+ mu = best.getMappedParams()[1];
+ } else {
+ maxBeta = 0;
+ nonZeros = 0;
+ positive = 0;
+ norm = 0;
+ }
+ if (k % (bump * scale) == 0) { // report only when k is a multiple of the current interval
+ if (best != null) {
+ ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
+ best.getPayload().getLearner().getModels().get(0));
+ }
+
+ info.step += 0.25;
+ System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta,
nonZeros, positive, norm, lambda, mu);
+ System.out.printf("%d\t%.3f\t%.2f\t%s\n",
+ k, info.averageLL, info.averageCorrect * 100, LEAK_LABELS[leakType %
3]);
+ }
+ }
+
+}
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java?rev=1201223&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
Sat Nov 12 08:19:18 2011
@@ -0,0 +1,30 @@
+package org.apache.mahout.classifier.sgd;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+/**
+ * Mutable holder for the progress-reporting state that
+ * SGDHelper.analyzeState reads and updates from one call to the next: the most
+ * recently observed log likelihood and percent correct of the best learner,
+ * plus the step counter and bump table that decide how often a progress line
+ * is printed.
+ **/
+class SGDInfo {
+ double averageLL = 0; // set from CrossFoldLearner.logLikelihood() in SGDHelper.analyzeState
+ double averageCorrect = 0; // set from CrossFoldLearner.percentCorrect(); printed multiplied by 100
+ double step = 0; // advanced by 0.25 each time a progress line is printed
+ int[] bumps = {1, 2, 5}; // interval multipliers, combined with a power of ten derived from step
+
+}
Copied:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
(from r1200329,
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java)
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java?p2=mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java&p1=mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java&r1=1200329&r2=1201223&rev=1201223&view=diff
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
(original)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
Sat Nov 12 08:19:18 2011
@@ -18,7 +18,6 @@
package org.apache.mahout.classifier.sgd;
import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Lists;
import com.google.common.collect.Multiset;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
@@ -28,31 +27,37 @@ import org.apache.commons.cli2.builder.D
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.PrintWriter;
-import java.util.Arrays;
-import java.util.List;
/**
- * Run the 20 news groups test data through SGD, as trained by {@link
org.apache.mahout.classifier.sgd.TrainNewsGroups}.
+ * Run the ASF email test data through SGD, as trained by {@link
TrainASFEmail}.
*/
-public final class TestNewsGroups {
+public final class TestASFEmail {
private String inputFile;
private String modelFile;
- private TestNewsGroups() {
+ private TestASFEmail() {
}
public static void main(String[] args) throws IOException {
- TestNewsGroups runner = new TestNewsGroups();
+ TestASFEmail runner = new TestASFEmail();
if (runner.parseArgs(args)) {
runner.run(new PrintWriter(System.out, true));
}
@@ -65,30 +70,35 @@ public final class TestNewsGroups {
OnlineLogisticRegression classifier = ModelSerializer.readBinary(new
FileInputStream(modelFile), OnlineLogisticRegression.class);
- Dictionary newsGroups = new Dictionary();
+ Dictionary asfDictionary = new Dictionary();
Multiset<String> overallCounts = HashMultiset.create();
-
- List<File> files = Lists.newArrayList();
- for (File newsgroup : base.listFiles()) {
- if (newsgroup.isDirectory()) {
- newsGroups.intern(newsgroup.getName());
- files.addAll(Arrays.asList(newsgroup.listFiles()));
- }
+ Configuration conf = new Configuration();
+ SequenceFileDirIterator<Text, VectorWritable> iter = new
SequenceFileDirIterator<Text, VectorWritable>(new Path(base.toString()),
PathType.LIST, PathFilters.partFilter(),
+ null, true, conf);
+
+ long numItems = 0;
+ while (iter.hasNext()) {
+ Pair<Text, VectorWritable> next = iter.next();
+ asfDictionary.intern(next.getFirst().toString());
+ numItems++;
}
- System.out.printf("%d test files\n", files.size());
- ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
- for (File file : files) {
- String ng = file.getParentFile().getName();
- int actual = newsGroups.intern(ng);
+ System.out.printf("%d test files\n", numItems);
+ ResultAnalyzer ra = new ResultAnalyzer(asfDictionary.values(), "DEFAULT");
+ iter = new SequenceFileDirIterator<Text, VectorWritable>(new
Path(base.toString()), PathType.LIST, PathFilters.partFilter(),
+ null, true, conf);
+ while (iter.hasNext()){
+ Pair<Text, VectorWritable> next = iter.next();
+ String ng = next.getFirst().toString();
+
+ int actual = asfDictionary.intern(ng);
NewsgroupHelper helper = new NewsgroupHelper();
- Vector input = helper.encodeFeatureVector(file, actual, 0,
overallCounts);//no leak type ensures this is a normal vector
- Vector result = classifier.classifyFull(input);
+ Vector result = classifier.classifyFull(next.getSecond().get());
int cat = result.maxValueIndex();
double score = result.maxValue();
- double ll = classifier.logLikelihood(actual, input);
- ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat),
score, ll);
- ra.addInstance(newsGroups.values().get(actual), cr);
+ double ll = classifier.logLikelihood(actual, next.getSecond().get());
+ ClassifierResult cr = new
ClassifierResult(asfDictionary.values().get(cat), score, ll);
+ ra.addInstance(asfDictionary.values().get(actual), cr);
}
output.printf("%s\n\n", ra.toString());
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java?rev=1201223&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
Sat Nov 12 08:19:18 2011
@@ -0,0 +1,126 @@
+package org.apache.mahout.classifier.sgd;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+import com.google.common.collect.Ordering;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Command-line trainer that fits an AdaptiveLogisticRegression model to
+ * already-encoded ASF email vectors and writes the first model of the best
+ * learner to "asf.model" inside the output directory.
+ * <p/>
+ * Arguments, as read by main:
+ * <ol>
+ * <li>args[0]: input location, read through SequenceFileDirIterator with
+ * PathFilters.partFilter() as (Text label, VectorWritable vector) records</li>
+ * <li>args[1]: output directory, created with mkdirs()</li>
+ * <li>args[2]: number of categories passed to AdaptiveLogisticRegression</li>
+ * <li>args[3]: cardinality passed to AdaptiveLogisticRegression</li>
+ * <li>args[4] (optional): leakType, default 0; here it is only passed on to
+ * SGDHelper.analyzeState</li>
+ * </ol>
+ * The input is iterated twice: once to intern every label into the dictionary
+ * and count the records, and once to train.
+ * <p/>
+ * NOTE(review): LEAK_LABELS is not referenced in this class, and overallCounts
+ * is created empty and never added to, so the "Word counts" section prints only
+ * its header. Neither SequenceFileDirIterator is closed here; confirm whether
+ * it needs to be. Fewer than four arguments fails on the args[] accesses.
+ **/
+public class TrainASFEmail {
+ private static final String[] LEAK_LABELS = {"none", "month-year",
"day-month-year"};
+
+ private static Multiset<String> overallCounts;
+
+ private TrainASFEmail() {
+ }
+
+ public static void main(String[] args) throws IOException {
+ File base = new File(args[0]);
+
+ overallCounts = HashMultiset.create();
+ File output = new File(args[1]);
+ output.mkdirs();
+ int numCats = Integer.parseInt(args[2]);
+ int cardinality = Integer.parseInt(args[3]);
+
+ int leakType = 0;
+ if (args.length > 4) {
+ leakType = Integer.parseInt(args[4]);
+ }
+
+ Dictionary asfDictionary = new Dictionary();
+
+
+ AdaptiveLogisticRegression learningAlgorithm = new
AdaptiveLogisticRegression(numCats, cardinality, new L1());
+ learningAlgorithm.setInterval(800);
+ learningAlgorithm.setAveragingWindow(500);
+
+ //We ran seq2encoded and split input already, so let's just build up the
dictionary
+ Configuration conf = new Configuration();
+ SequenceFileDirIterator<Text, VectorWritable> iter = new
SequenceFileDirIterator<Text, VectorWritable>(new Path(base.toString()),
PathType.LIST, PathFilters.partFilter(),
+ null, true, conf);
+ long numItems = 0;
+ while (iter.hasNext()) {
+ Pair<Text, VectorWritable> next = iter.next();
+ asfDictionary.intern(next.getFirst().toString()); // the record key is the label
+ numItems++;
+ }
+
+ System.out.printf("%d training files\n", numItems);
+
+
+ int k = 0;
+ SGDInfo info = new SGDInfo();
+
+ iter = new SequenceFileDirIterator<Text, VectorWritable>(new
Path(base.toString()), PathType.LIST, PathFilters.partFilter(),
+ null, true, conf);
+ while (iter.hasNext()) {
+ Pair<Text, VectorWritable> next = iter.next();
+ String ng = next.getFirst().toString();
+ int actual = asfDictionary.intern(ng);
+ //we already have encoded
+ learningAlgorithm.train(actual, next.getSecond().get());
+ k++; // k counts the examples trained so far
+ State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best =
learningAlgorithm.getBest();
+
+ SGDHelper.analyzeState(info, leakType, k, best);
+ }
+ learningAlgorithm.close();
+ //TODO: how to do dissection, since we aren't processing the files here
+ //SGDHelper.dissect(leakType, asfDictionary, learningAlgorithm, files,
overallCounts);
+ System.out.println("exiting main, writing model to " + output);
+
+ ModelSerializer.writeBinary(output + "/asf.model",
+
learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
+
+ List<Integer> counts = Lists.newArrayList();
+ System.out.printf("Word counts\n");
+ for (String count : overallCounts.elementSet()) {
+ counts.add(overallCounts.count(count));
+ }
+ Collections.sort(counts, Ordering.natural().reverse()); // largest counts first
+ k = 0; // k is reused here as the rank of each printed count
+ for (Integer count : counts) {
+ System.out.printf("%d\t%d\n", k, count);
+ k++;
+ if (k > 1000) {
+ break;
+ }
+ }
+ }
+}
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=1201223&r1=1201222&r2=1201223&view=diff
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
(original)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
Sat Nov 12 08:19:18 2011
@@ -19,15 +19,10 @@ package org.apache.mahout.classifier.sgd
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
-
import org.apache.mahout.ep.State;
-import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.function.Functions;
-import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import java.io.File;
@@ -35,9 +30,6 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import java.util.Set;
/**
* Reads and trains an adaptive logistic regression model on the 20 newsgroups
data.
@@ -45,10 +37,10 @@ import java.util.Set;
* data. The optional second argument, leakType, defines which classes of
features to use.
* Importantly, leakType controls whether a synthetic date is injected into
the data as
* a target leak and if so, how.
- * <p>
+ * <p/>
* The value of leakType % 3 determines whether the target leak is injected
according to
* the following table:
- * <p>
+ * <p/>
* <table>
* <tr><td valign='top'>0</td><td>No leak injected</td></tr>
* <tr><td valign='top'>1</td><td>Synthetic date injected in MMM-yyyy format.
This will be a single token and
@@ -57,16 +49,16 @@ import java.util.Set;
* and thus there are more leak symbols that need to be learned. Ultimately
this is just
* as big a leak as case 1.</td></tr>
* </table>
- * <p>
+ * <p/>
* Leaktype also determines what other text will be indexed. If leakType is
greater
* than or equal to 6, then neither headers nor text body will be used for
features and the leak is the only
* source of data. If leakType is greater than or equal to 3, then subject
words will be used as features.
* If leakType is less than 3, then both subject and body text will be used as
features.
- * <p>
+ * <p/>
* A leakType of 0 gives no leak and all textual features.
- * <p>
+ * <p/>
* See the following table for a summary of commonly used values for leakType
- * <p>
+ * <p/>
* <table>
*
<tr><td><b>leakType</b></td><td><b>Leak?</b></td><td><b>Subject?</b></td><td><b>Body?</b></td></tr>
* <tr><td colspan=4><hr></td></tr>
@@ -86,8 +78,6 @@ import java.util.Set;
*/
public final class TrainNewsGroups {
- private static final String[] LEAK_LABELS = {"none", "month-year",
"day-month-year"};
-
private static Multiset<String> overallCounts;
private TrainNewsGroups() {
@@ -120,13 +110,11 @@ public final class TrainNewsGroups {
}
Collections.shuffle(files);
System.out.printf("%d training files\n", files.size());
-
- double averageLL = 0;
- double averageCorrect = 0;
+ SGDInfo info = new SGDInfo();
int k = 0;
- double step = 0;
- int[] bumps = {1, 2, 5};
+
+
for (File file : files) {
String ng = file.getParentFile().getName();
int actual = newsGroups.intern(ng);
@@ -135,69 +123,16 @@ public final class TrainNewsGroups {
learningAlgorithm.train(actual, v);
k++;
-
- int bump = bumps[(int) Math.floor(step) % bumps.length];
- int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best =
learningAlgorithm.getBest();
- double maxBeta;
- double nonZeros;
- double positive;
- double norm;
-
- double lambda = 0;
- double mu = 0;
-
- if (best != null) {
- CrossFoldLearner state = best.getPayload().getLearner();
- averageCorrect = state.percentCorrect();
- averageLL = state.logLikelihood();
-
- OnlineLogisticRegression model = state.getModels().get(0);
- // finish off pending regularization
- model.close();
-
- Matrix beta = model.getBeta();
- maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
- nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
- @Override
- public double apply(double v) {
- return Math.abs(v) > 1.0e-6 ? 1 : 0;
- }
- });
- positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
- @Override
- public double apply(double v) {
- return v > 0 ? 1 : 0;
- }
- });
- norm = beta.aggregate(Functions.PLUS, Functions.ABS);
-
- lambda = learningAlgorithm.getBest().getMappedParams()[0];
- mu = learningAlgorithm.getBest().getMappedParams()[1];
- } else {
- maxBeta = 0;
- nonZeros = 0;
- positive = 0;
- norm = 0;
- }
- if (k % (bump * scale) == 0) {
- if (learningAlgorithm.getBest() != null) {
- ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
-
learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
- }
-
- step += 0.25;
- System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta,
nonZeros, positive, norm, lambda, mu);
- System.out.printf("%d\t%.3f\t%.2f\t%s\n",
- k, averageLL, averageCorrect * 100, LEAK_LABELS[leakType % 3]);
- }
+
+ SGDHelper.analyzeState(info, leakType, k, best);
}
learningAlgorithm.close();
- dissect(leakType, newsGroups, learningAlgorithm, files);
+ SGDHelper.dissect(leakType, newsGroups, learningAlgorithm, files,
overallCounts);
System.out.println("exiting main");
ModelSerializer.writeBinary("/tmp/news-group.model",
-
learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
+
learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
List<Integer> counts = Lists.newArrayList();
System.out.printf("Word counts\n");
@@ -215,52 +150,5 @@ public final class TrainNewsGroups {
}
}
- private static void dissect(int leakType,
- Dictionary newsGroups,
- AdaptiveLogisticRegression learningAlgorithm,
- Iterable<File> files) throws IOException {
- CrossFoldLearner model =
learningAlgorithm.getBest().getPayload().getLearner();
- model.close();
-
- Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
- ModelDissector md = new ModelDissector();
-
- NewsgroupHelper helper = new NewsgroupHelper();
- helper.getEncoder().setTraceDictionary(traceDictionary);
- helper.getBias().setTraceDictionary(traceDictionary);
-
- for (File file : permute(files, helper.getRandom()).subList(0, 500)) {
- String ng = file.getParentFile().getName();
- int actual = newsGroups.intern(ng);
-
- traceDictionary.clear();
- Vector v = helper.encodeFeatureVector(file, actual, leakType,
overallCounts);
- md.update(v, traceDictionary, model);
- }
-
- List<String> ngNames = Lists.newArrayList(newsGroups.values());
- List<ModelDissector.Weight> weights = md.summary(100);
- System.out.println("============");
- System.out.println("Model Dissection");
- for (ModelDissector.Weight w : weights) {
- System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n",
- w.getFeature(), w.getWeight(),
ngNames.get(w.getMaxImpact() + 1),
- w.getCategory(1), w.getWeight(1), w.getCategory(2),
w.getWeight(2));
- }
- }
-
- private static List<File> permute(Iterable<File> files, Random rand) {
- List<File> r = Lists.newArrayList();
- for (File file : files) {
- int i = rand.nextInt(r.size() + 1);
- if (i == r.size()) {
- r.add(file);
- } else {
- r.add(r.get(i));
- r.set(i, file);
- }
- }
- return r;
- }
}