Author: srowen
Date: Wed Jul 6 20:17:58 2011
New Revision: 1143542
URL: http://svn.apache.org/viewvc?rev=1143542&view=rev
Log:
MAHOUT-696: add command-line tools (train, run, validate) for AdaptiveLogisticRegression
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
mahout/trunk/src/conf/driver.classes.props
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
Wed Jul 6 20:17:58 2011
@@ -70,7 +70,7 @@ public class CsvRecordFactory implements
// crude CSV value splitter. This will fail if any double quoted strings
have
// commas inside. Also, escaped quotes will not be unescaped. Good enough
for now.
- private final Splitter COMMA =
Splitter.on(',').trimResults(CharMatcher.is('"'));
+ private static final Splitter COMMA =
Splitter.on(',').trimResults(CharMatcher.is('"'));
private static final Map<String, Class<? extends FeatureVectorEncoder>>
TYPE_DICTIONARY =
ImmutableMap.<String, Class<? extends FeatureVectorEncoder>>builder()
@@ -87,6 +87,10 @@ public class CsvRecordFactory implements
private int target;
private final Dictionary targetDictionary;
+
+  // Optional column that identifies each CSV line: its name and resolved index (-1 if unset).
+ private String idName;
+ private int id = -1;
private List<Integer> predictors;
private Map<Integer, FeatureVectorEncoder> predictorEncoders;
@@ -109,6 +113,11 @@ public class CsvRecordFactory implements
targetDictionary = new Dictionary();
}
+  /**
+   * Creates a factory that also remembers the name of an id column, so that the id of
+   * each data line can later be read back with {@code getIdString}.
+   * <p>
+   * The column index for {@code idName} is looked up when the header line is processed.
+   *
+   * @param targetName name of the target column
+   * @param idName name of the column identifying each line; may be null for no id column
+   * @param typeMap predictor column names mapped to their encoder type names
+   */
+  public CsvRecordFactory(String targetName, String idName, Map<String, String> typeMap) {
+    this(targetName, typeMap);
+    this.idName = idName;
+  }
+
/**
* Defines the values and thus the encoding of values of the target
variables. Note
* that any values of the target variable not present in this list will be
given the
@@ -165,6 +174,11 @@ public class CsvRecordFactory implements
// record target column and establish dictionary for decoding target
target = vars.get(targetName);
+
+ // record id column
+ if (idName != null){
+ id = vars.get(idName);
+ }
// create list of predictor column numbers
predictors = Lists.newArrayList(Collections2.transform(typeMap.keySet(),
new Function<String, Integer>() {
@@ -244,6 +258,69 @@ public class CsvRecordFactory implements
}
return targetValue;
}
+
+  /**
+   * Decodes one CSV data line and adds its predictor features to {@code featureVector};
+   * optionally also decodes the target column.
+   * <p>
+   * When classifying production data that carries no target value, call this with
+   * {@code returnTarget = false}.
+   *
+   * @param line the raw CSV line
+   * @param featureVector receives the encoded features; should be zeroed before the call
+   * @param returnTarget whether to decode and return the target value
+   * @return the interned target code, capped at {@code maxTargetValue - 1}, or -1 when
+   *         {@code returnTarget} is false
+   */
+  public int processLine(CharSequence line, Vector featureVector, boolean returnTarget) {
+    List<String> fields = Lists.newArrayList(COMMA.split(line));
+
+    int code = -1;
+    if (returnTarget) {
+      // codes beyond the configured number of categories are folded into the last one
+      code = Math.min(targetDictionary.intern(fields.get(target)), maxTargetValue - 1);
+    }
+
+    for (Integer column : predictors) {
+      String raw = null;
+      if (column >= 0) {
+        raw = fields.get(column);
+      }
+      predictorEncoders.get(column).addToVector(raw, featureVector);
+    }
+    return code;
+  }
+
+  /**
+   * Returns the raw, un-encoded target value of a CSV data line.
+   *
+   * @param line a data line read from the CSV file
+   * @return the text of the target column, with surrounding double quotes trimmed
+   */
+  public String getTargetString(CharSequence line) {
+    return Lists.newArrayList(COMMA.split(line)).get(target);
+  }
+
+  /**
+   * Maps an encoded target code back to the label it was interned from.
+   *
+   * @param code the integer code assigned to a target label during training
+   * @return the label with that code, or {@code null} if no known label has it
+   */
+  public String getTargetLabel(int code) {
+    for (String candidate : targetDictionary.values()) {
+      int candidateCode = targetDictionary.intern(candidate);
+      if (code == candidateCode) {
+        return candidate;
+      }
+    }
+    return null;
+  }
+
+  /**
+   * Returns the value of the id column of a CSV data line.
+   * <p>
+   * The id column index is resolved from the id column name only when the header line is
+   * processed, so an id name must be set before that for this method to work.
+   *
+   * @param line a data line read from the CSV file
+   * @return the text of the id column, with surrounding double quotes trimmed
+   * @throws IllegalStateException if no id column index has been resolved
+   */
+  public String getIdString(CharSequence line) {
+    if (id < 0) {
+      // without this check values.get(-1) fails with an unhelpful IndexOutOfBoundsException
+      throw new IllegalStateException("No id column has been resolved (idName=" + idName
+          + "); set the id name before processing the header line");
+    }
+    return Lists.newArrayList(COMMA.split(line)).get(id);
+  }
/**
* Returns a list of the names of the predictor variables.
@@ -284,4 +361,12 @@ public class CsvRecordFactory implements
return r;
}
+  /**
+   * @return the name of the column used to identify each CSV line, or null if none
+   */
+  public String getIdName() {
+    return this.idName;
+  }
+
+  /**
+   * Sets the name of the id column. The column index is only looked up when the header
+   * line is processed, so this must be called before that to take effect.
+   *
+   * @param idName the id column name; null means no id column is resolved
+   */
+  public void setIdName(String idName) {
+    this.idName = idName;
+  }
+
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/common/MahoutTestCase.java
Wed Jul 6 20:17:58 2011
@@ -17,7 +17,12 @@
package org.apache.mahout.common;
-import java.io.*;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
import java.lang.reflect.Field;
import com.google.common.base.Charsets;
@@ -106,15 +111,15 @@ public abstract class MahoutTestCase ext
* Find a declared field in a class or one of it's super classes
*/
private static Field findDeclaredField(Class<?> inClass, String fieldname) throws NoSuchFieldException {
- if (Object.class.equals(inClass)) {
- throw new NoSuchFieldException();
- }
- for (Field field : inClass.getDeclaredFields()) {
- if (field.getName().equalsIgnoreCase(fieldname)) {
- return field;
+    // Walk up the hierarchy iteratively. getSuperclass() returns null for interfaces and
+    // primitive types, so a null class also ends the search instead of causing an NPE.
+    Class<?> current = inClass;
+    while (current != null && !Object.class.equals(current)) {
+      for (Field field : current.getDeclaredFields()) {
+        if (field.getName().equalsIgnoreCase(fieldname)) {
+          return field;
+        }
}
+      current = current.getSuperclass();
}
- return findDeclaredField(inClass.getSuperclass(), fieldname);
+    throw new NoSuchFieldException(fieldname);
}
/**
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java?rev=1143542&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java
Wed Jul 6 20:17:58 2011
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.io.Closeables;
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
+import org.apache.mahout.math.stats.GroupedOnlineAuc;
+import org.apache.mahout.math.stats.OnlineAuc;
+
+/**
+ * {@link LogisticModelParameters} extended with the settings of an
+ * {@link AdaptiveLogisticRegression}: interval, averaging window, thread count, prior
+ * function (plus its constructor option) and AUC evaluator. {@link #write(DataOutput)}
+ * stores these parameters followed by the model itself, and {@link #loadFromFile(File)}
+ * reads both back.
+ */
+public class AdaptiveLogisticModelParameters extends LogisticModelParameters {
+
+  private AdaptiveLogisticRegression alr;
+  // NOTE(review): TrainAdaptiveLogistic declares --interval/--window defaults of 500/800,
+  // the reverse of these field defaults; confirm which pair is intended.
+  private int interval = 800;
+  private int averageWindow = 500;
+  private int threads = 4;
+  // Prior name, case-insensitive: L1, L2, UP, TP or EBP (see createPrior).
+  private String prior = "L1";
+  // Constructor argument for the TP and EBP priors; NaN means "not set".
+  private double priorOption = Double.NaN;
+  // AUC evaluator name, case-insensitive: GLOBAL or GROUPED; null gives a null evaluator.
+  private String auc = null;
+
+  /**
+   * Returns the model described by these parameters, creating and configuring it on the
+   * first call. Later calls return the same instance.
+   */
+  public AdaptiveLogisticRegression createAdaptiveLogisticRegression() {
+    if (alr == null) {
+      alr = new AdaptiveLogisticRegression(getMaxTargetCategories(), getNumFeatures(),
+          createPrior(prior, priorOption));
+      alr.setInterval(interval);
+      alr.setAveragingWindow(averageWindow);
+      alr.setThreadCount(threads);
+      alr.setAucEvaluator(createAUC(auc));
+    }
+    return alr;
+  }
+
+  /**
+   * Validates the parameter combination.
+   *
+   * @throws IllegalArgumentException if the TP or EBP prior is chosen without a prior option
+   */
+  public void checkParameters() {
+    if (prior != null && Double.isNaN(priorOption)) {
+      String name = prior.trim();
+      if ("TP".equalsIgnoreCase(name) || "EBP".equalsIgnoreCase(name)) {
+        throw new IllegalArgumentException("You must specify a double value for TPrior and ElasticBandPrior.");
+      }
+    }
+  }
+
+  /**
+   * Builds the prior function named by {@code cmd} (case-insensitive, surrounding
+   * whitespace ignored). Returns null for a null or unrecognized name.
+   */
+  private static PriorFunction createPrior(String cmd, double priorOption) {
+    if (cmd == null) {
+      return null;
+    }
+    // equalsIgnoreCase is locale-independent, unlike toUpperCase() with the default locale
+    String name = cmd.trim();
+    if ("L1".equalsIgnoreCase(name)) {
+      return new L1();
+    }
+    if ("L2".equalsIgnoreCase(name)) {
+      return new L2();
+    }
+    if ("UP".equalsIgnoreCase(name)) {
+      return new UniformPrior();
+    }
+    if ("TP".equalsIgnoreCase(name)) {
+      return new TPrior(priorOption);
+    }
+    if ("EBP".equalsIgnoreCase(name)) {
+      return new ElasticBandPrior(priorOption);
+    }
+    return null;
+  }
+
+  /**
+   * Builds the AUC evaluator named by {@code cmd} (GLOBAL or GROUPED, case-insensitive).
+   * Returns null for a null or unrecognized name.
+   */
+  private static OnlineAuc createAUC(String cmd) {
+    if (cmd == null) {
+      return null;
+    }
+    String name = cmd.trim();
+    if ("GLOBAL".equalsIgnoreCase(name)) {
+      return new GlobalOnlineAuc();
+    }
+    if ("GROUPED".equalsIgnoreCase(name)) {
+      return new GroupedOnlineAuc();
+    }
+    return null;
+  }
+
+  /**
+   * Closes the model (if one was created), records the target categories collected by the
+   * CSV record factory and writes the parameters plus model to {@code out}. The stream is
+   * flushed but not closed; the caller owns it.
+   */
+  @Override
+  public void saveTo(OutputStream out) throws IOException {
+    if (alr != null) {
+      alr.close();
+    }
+    setTargetCategories(getCsvRecordFactory().getTargetCategories());
+    DataOutputStream dataOut = new DataOutputStream(out);
+    write(dataOut);
+    dataOut.flush();
+  }
+
+  /**
+   * Serializes the parameters followed by the model.
+   *
+   * @throws IllegalStateException if no model has been created yet
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    if (alr == null) {
+      // fail before writing anything rather than with an NPE on a half-written stream
+      throw new IllegalStateException(
+          "No AdaptiveLogisticRegression to write; call createAdaptiveLogisticRegression() first");
+    }
+    out.writeUTF(getTargetVariable());
+    out.writeInt(getTypeMap().size());
+    for (Map.Entry<String, String> entry : getTypeMap().entrySet()) {
+      out.writeUTF(entry.getKey());
+      out.writeUTF(entry.getValue());
+    }
+    out.writeInt(getNumFeatures());
+    out.writeInt(getMaxTargetCategories());
+    out.writeInt(getTargetCategories().size());
+    for (String category : getTargetCategories()) {
+      out.writeUTF(category);
+    }
+
+    out.writeInt(interval);
+    out.writeInt(averageWindow);
+    out.writeInt(threads);
+    // writeUTF rejects null, and auc defaults to null
+    out.writeUTF(encodeNullable(prior));
+    out.writeDouble(priorOption);
+    out.writeUTF(encodeNullable(auc));
+
+    // the CSV record factory itself is not written
+    alr.write(out);
+  }
+
+  /** Reads back what {@link #write(DataOutput)} wrote, replacing any existing model. */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    setTargetVariable(in.readUTF());
+    int typeMapSize = in.readInt();
+    Map<String, String> typeMap = new HashMap<String, String>(typeMapSize);
+    for (int i = 0; i < typeMapSize; i++) {
+      String key = in.readUTF();
+      String value = in.readUTF();
+      typeMap.put(key, value);
+    }
+    setTypeMap(typeMap);
+
+    setNumFeatures(in.readInt());
+    setMaxTargetCategories(in.readInt());
+    int targetCategoriesSize = in.readInt();
+    List<String> targetCategories = new ArrayList<String>(targetCategoriesSize);
+    for (int i = 0; i < targetCategoriesSize; i++) {
+      targetCategories.add(in.readUTF());
+    }
+    setTargetCategories(targetCategories);
+
+    interval = in.readInt();
+    averageWindow = in.readInt();
+    threads = in.readInt();
+    prior = decodeNullable(in.readUTF());
+    priorOption = in.readDouble();
+    auc = decodeNullable(in.readUTF());
+
+    alr = new AdaptiveLogisticRegression();
+    alr.readFields(in);
+  }
+
+  /** Maps null to the empty string so it can be stored with writeUTF. */
+  private static String encodeNullable(String value) {
+    return value == null ? "" : value;
+  }
+
+  /** Inverse of encodeNullable: the empty string reads back as null. */
+  private static String decodeNullable(String value) {
+    return value.isEmpty() ? null : value;
+  }
+
+  private static AdaptiveLogisticModelParameters loadFromStream(InputStream in) throws IOException {
+    AdaptiveLogisticModelParameters result = new AdaptiveLogisticModelParameters();
+    result.readFields(new DataInputStream(in));
+    return result;
+  }
+
+  /** Loads parameters and model previously saved with {@link #saveTo(OutputStream)}. */
+  public static AdaptiveLogisticModelParameters loadFromFile(File in) throws IOException {
+    InputStream input = new FileInputStream(in);
+    try {
+      return loadFromStream(input);
+    } finally {
+      Closeables.closeQuietly(input);
+    }
+  }
+
+  public int getInterval() {
+    return interval;
+  }
+
+  public void setInterval(int interval) {
+    this.interval = interval;
+  }
+
+  public int getAverageWindow() {
+    return averageWindow;
+  }
+
+  public void setAverageWindow(int averageWindow) {
+    this.averageWindow = averageWindow;
+  }
+
+  public int getThreads() {
+    return threads;
+  }
+
+  public void setThreads(int threads) {
+    this.threads = threads;
+  }
+
+  public String getPrior() {
+    return prior;
+  }
+
+  public void setPrior(String prior) {
+    this.prior = prior;
+  }
+
+  public String getAuc() {
+    return auc;
+  }
+
+  public void setAuc(String auc) {
+    this.auc = auc;
+  }
+
+  public double getPriorOption() {
+    return priorOption;
+  }
+
+  public void setPriorOption(double priorOption) {
+    this.priorOption = priorOption;
+  }
+
+}
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
(original)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java
Wed Jul 6 20:17:58 2011
@@ -216,6 +216,10 @@ public class LogisticModelParameters imp
maxTargetCategories = targetCategories.size();
}
+  /**
+   * Returns the target category labels. The internal list is returned, not a copy.
+   */
+  public List<String> getTargetCategories() {
+    return targetCategories;
+  }
+
public void setUseBias(boolean useBias) {
this.useBias = useBias;
}
@@ -232,6 +236,10 @@ public class LogisticModelParameters imp
return typeMap;
}
+  /**
+   * Replaces the variable-type map (presumably predictor name to type name; confirm
+   * against callers). The map is stored as given, not copied.
+   */
+  public void setTypeMap(Map<String, String> map) {
+    typeMap = map;
+  }
+
public int getNumFeatures() {
return numFeatures;
}
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java?rev=1143542&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java
Wed Jul 6 20:17:58 2011
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.HashMap;
+import java.util.Locale;
+import java.util.Map;
+
+/**
+ * Command-line tool that scores a CSV file with a model saved by
+ * {@link AdaptiveLogisticModelParameters#saveTo} and writes {@code id,target,score}
+ * rows to an output file.
+ */
+public final class RunAdaptiveLogistic {
+
+  private static String inputFile;
+  private static String modelFile;
+  private static String outputFile;
+  private static String idColumn;
+  private static boolean maxScoreOnly;
+
+  private RunAdaptiveLogistic() {
+  }
+
+  public static void main(String[] args) throws IOException {
+    // PrintWriter(OutputStream) buffers without auto-flushing; without a flush all
+    // console output would be lost when the JVM exits.
+    PrintWriter output = new PrintWriter(System.out, true);
+    try {
+      mainToOutput(args, output);
+    } finally {
+      output.flush();
+    }
+  }
+
+  static void mainToOutput(String[] args, PrintWriter output) throws IOException {
+    if (!parseArgs(args)) {
+      return;
+    }
+    AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters.loadFromFile(new File(modelFile));
+
+    CsvRecordFactory csv = lmp.getCsvRecordFactory();
+    // must happen before the header line is processed, which resolves the id column
+    csv.setIdName(idColumn);
+
+    AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();
+
+    State<Wrapper, CrossFoldLearner> best = lr.getBest();
+    if (best == null) {
+      output.printf("%s\n", "AdaptiveLogisticRegression has not been trained properly.");
+      output.flush();
+      return;
+    }
+    CrossFoldLearner learner = best.getPayload().getLearner();
+
+    int k;
+    BufferedReader in = TrainAdaptiveLogistic.open(inputFile);
+    try {
+      BufferedWriter out = new BufferedWriter(new FileWriter(outputFile));
+      try {
+        k = scoreRecords(in, out, csv, learner, lmp.getNumFeatures(), output);
+      } finally {
+        out.close();
+      }
+    } finally {
+      in.close();
+    }
+    output.printf(Locale.ENGLISH, "%d records processed totally.\n", k);
+    output.flush();
+  }
+
+  /**
+   * Writes the header row, then one row per (record, label) pair with the learner's score.
+   * The first input line is treated as the CSV header.
+   *
+   * @return the number of data records processed
+   */
+  private static int scoreRecords(BufferedReader in, BufferedWriter out, CsvRecordFactory csv,
+      CrossFoldLearner learner, int numFeatures, PrintWriter output) throws IOException {
+    out.write(idColumn + ",target,score");
+    out.newLine();
+
+    csv.firstLine(in.readLine());
+
+    Map<String, Double> results = new HashMap<String, Double>();
+    int k = 0;
+    String line = in.readLine();
+    while (line != null) {
+      Vector v = new SequentialAccessSparseVector(numFeatures);
+      csv.processLine(line, v, false);
+      Vector scores = learner.classifyFull(v);
+      results.clear();
+      if (maxScoreOnly) {
+        results.put(csv.getTargetLabel(scores.maxValueIndex()), scores.maxValue());
+      } else {
+        for (int i = 0; i < scores.size(); i++) {
+          results.put(csv.getTargetLabel(i), scores.get(i));
+        }
+      }
+
+      // the id is the same for every label of this record; parse it once
+      String id = csv.getIdString(line);
+      for (Map.Entry<String, Double> entry : results.entrySet()) {
+        out.write(id + ',' + entry.getKey() + ',' + entry.getValue());
+        out.newLine();
+      }
+      k++;
+      if (k % 100 == 0) {
+        output.printf(Locale.ENGLISH, "%d records processed \n", k);
+      }
+      line = in.readLine();
+    }
+    return k;
+  }
+
+  private static boolean parseArgs(String[] args) {
+    DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+    Option help = builder.withLongName("help")
+        .withDescription("print this list").create();
+
+    Option quiet = builder.withLongName("quiet")
+        .withDescription("be extra quiet").create();
+
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+    Option inputFileOption = builder
+        .withLongName("input")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+        .withDescription("where to get the data to classify").create();
+
+    Option modelFileOption = builder
+        .withLongName("model")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
+        .withDescription("where to get the trained model").create();
+
+    Option outputFileOption = builder
+        .withLongName("output")
+        .withRequired(true)
+        .withDescription("the file path to output scores")
+        .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
+        .create();
+
+    Option idColumnOption = builder
+        .withLongName("idcolumn")
+        .withRequired(true)
+        .withDescription("the name of the id column for each record")
+        .withArgument(argumentBuilder.withName("idcolumn").withMaximum(1).create())
+        .create();
+
+    Option maxScoreOnlyOption = builder
+        .withLongName("maxscoreonly")
+        .withDescription("only output the target label with max scores")
+        .create();
+
+    Group normalArgs = new GroupBuilder()
+        .withOption(help).withOption(quiet)
+        .withOption(inputFileOption).withOption(modelFileOption)
+        .withOption(outputFileOption).withOption(idColumnOption)
+        .withOption(maxScoreOnlyOption)
+        .create();
+
+    Parser parser = new Parser();
+    parser.setHelpOption(help);
+    parser.setHelpTrigger("--help");
+    parser.setGroup(normalArgs);
+    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+    CommandLine cmdLine = parser.parseAndHelp(args);
+
+    if (cmdLine == null) {
+      return false;
+    }
+
+    inputFile = getStringArgument(cmdLine, inputFileOption);
+    modelFile = getStringArgument(cmdLine, modelFileOption);
+    outputFile = getStringArgument(cmdLine, outputFileOption);
+    idColumn = getStringArgument(cmdLine, idColumnOption);
+    maxScoreOnly = getBooleanArgument(cmdLine, maxScoreOnlyOption);
+    return true;
+  }
+
+  private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+    return cmdLine.hasOption(option);
+  }
+
+  private static String getStringArgument(CommandLine cmdLine, Option option) {
+    return (String) cmdLine.getValue(option);
+  }
+
+}
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
(original)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
Wed Jul 6 20:17:58 2011
@@ -33,7 +33,7 @@ import org.apache.mahout.classifier.eval
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
-import java.io.PrintStream;
+import java.io.PrintWriter;
import java.util.Locale;
public final class RunLogistic {
@@ -43,12 +43,15 @@ public final class RunLogistic {
private static boolean showAuc;
private static boolean showScores;
private static boolean showConfusion;
- static PrintStream output = System.out;
private RunLogistic() {
}
public static void main(String[] args) throws IOException {
+    // PrintWriter(OutputStream) buffers without auto-flushing; flush so that output is not
+    // lost when the JVM exits.
+    PrintWriter output = new PrintWriter(System.out, true);
+    try {
+      mainToOutput(args, output);
+    } finally {
+      output.flush();
+    }
+  }
+
+ static void mainToOutput(String[] args, PrintWriter output) throws
IOException {
if (parseArgs(args)) {
if (!showAuc && !showConfusion && !showScores) {
showAuc = true;
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java?rev=1143542&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java
Wed Jul 6 20:17:58 2011
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.io.PrintWriter;
+import java.util.List;
+import java.util.Locale;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Lists;
+import com.google.common.io.Resources;
+
+public final class TrainAdaptiveLogistic {
+
+ private static String inputFile;
+ private static String outputFile;
+ private static AdaptiveLogisticModelParameters lmp;
+ private static int passes;
+ private static boolean showperf;
+ private static int skipperfnum = 99;
+ private static AdaptiveLogisticRegression model;
+
+  private TrainAdaptiveLogistic() {
+    // static entry points only
+  }
+
+  public static void main(String[] args) throws IOException {
+    // PrintWriter(OutputStream) buffers without auto-flushing; flush so that progress and
+    // model output are not lost when the JVM exits.
+    PrintWriter output = new PrintWriter(System.out, true);
+    try {
+      mainToOutput(args, output);
+    } finally {
+      output.flush();
+    }
+  }
+
+  /**
+   * Trains an {@link AdaptiveLogisticRegression} on the CSV input for the requested number
+   * of passes, saves the parameters and model to the output file and prints a summary of
+   * the first model of the best learner found.
+   *
+   * @param args command-line arguments (see {@code --help})
+   * @param output receives progress messages and the model summary
+   */
+  static void mainToOutput(String[] args, PrintWriter output) throws IOException {
+    if (!parseArgs(args)) {
+      return;
+    }
+
+    CsvRecordFactory csv = lmp.getCsvRecordFactory();
+    model = lmp.createAdaptiveLogisticRegression();
+    State<Wrapper, CrossFoldLearner> best;
+    CrossFoldLearner learner = null;
+
+    int k = 0;
+    for (int pass = 0; pass < passes; pass++) {
+      BufferedReader in = open(inputFile);
+      try {
+        // read variable names
+        csv.firstLine(in.readLine());
+
+        String line = in.readLine();
+        while (line != null) {
+          // for each new line, get target and predictors
+          Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
+          int targetValue = csv.processLine(line, input);
+          // update model
+          model.train(targetValue, input);
+          k++;
+
+          if (showperf && (k % (skipperfnum + 1) == 0)) {
+            best = model.getBest();
+            if (best != null) {
+              learner = best.getPayload().getLearner();
+            }
+            if (learner != null) {
+              double averageCorrect = learner.percentCorrect();
+              double averageLL = learner.logLikelihood();
+              // fixed locale so decimals are formatted like every other line of output
+              output.printf(Locale.ENGLISH, "%d\t%.3f\t%.2f\n", k, averageLL, averageCorrect * 100);
+            } else {
+              output.printf(Locale.ENGLISH, "%10d %2d %s\n", k, targetValue,
+                  "AdaptiveLogisticRegression has not found a good model ......");
+            }
+          }
+          line = in.readLine();
+        }
+      } finally {
+        in.close();
+      }
+    }
+
+    best = model.getBest();
+    if (best != null) {
+      learner = best.getPayload().getLearner();
+    }
+    if (learner == null) {
+      output.printf(Locale.ENGLISH, "%s\n", "AdaptiveLogisticRegression has not successfully trained any model.");
+      output.flush();
+      return;
+    }
+
+    OutputStream modelOutput = new FileOutputStream(outputFile);
+    try {
+      lmp.saveTo(modelOutput);
+    } finally {
+      modelOutput.close();
+    }
+
+    printModel(csv, learner, output);
+    output.flush();
+  }
+
+  /**
+   * Prints the feature count, a symbolic "target ~ weight*predictor + ..." formula and the
+   * beta matrix of the learner's first model.
+   */
+  private static void printModel(CsvRecordFactory csv, CrossFoldLearner learner, PrintWriter output) {
+    OnlineLogisticRegression lr = learner.getModels().get(0);
+    output.printf(Locale.ENGLISH, "%d\n", lmp.getNumFeatures());
+    output.printf(Locale.ENGLISH, "%s ~ ", lmp.getTargetVariable());
+    String sep = "";
+    for (String v : csv.getTraceDictionary().keySet()) {
+      double weight = predictorWeight(lr, 0, csv, v);
+      if (weight != 0) {
+        output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v);
+        sep = " + ";
+      }
+    }
+    output.printf("\n");
+
+    for (int row = 0; row < lr.getBeta().numRows(); row++) {
+      for (String key : csv.getTraceDictionary().keySet()) {
+        double weight = predictorWeight(lr, row, csv, key);
+        if (weight != 0) {
+          output.printf(Locale.ENGLISH, "%20s %.5f\n", key, weight);
+        }
+      }
+      for (int column = 0; column < lr.getBeta().numCols(); column++) {
+        output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column));
+      }
+      output.println();
+    }
+  }
+
+  /**
+   * Sums the coefficients of one beta row over every column that the record factory's
+   * trace dictionary associates with the given predictor.
+   */
+  private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) {
+    double sum = 0.0;
+    for (Integer col : csv.getTraceDictionary().get(predictor)) {
+      sum += lr.getBeta().get(row, col);
+    }
+    return sum;
+  }
+
+ private static boolean parseArgs(String[] args) {
+ DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+ Option help = builder.withLongName("help")
+ .withDescription("print this list").create();
+
+ Option quiet = builder.withLongName("quiet")
+ .withDescription("be extra quiet").create();
+
+
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+ Option showperf = builder
+ .withLongName("showperf")
+ .withDescription("output performance measures during training")
+ .create();
+
+ Option inputFile = builder
+ .withLongName("input")
+ .withRequired(true)
+ .withArgument(
+ argumentBuilder.withName("input").withMaximum(1)
+ .create())
+ .withDescription("where to get training data").create();
+
+ Option outputFile = builder
+ .withLongName("output")
+ .withRequired(true)
+ .withArgument(
+ argumentBuilder.withName("output").withMaximum(1)
+ .create())
+ .withDescription("where to write the model content").create();
+
+ Option threads = builder.withLongName("threads")
+ .withArgument(
+ argumentBuilder.withName("threads").withDefault("4").create())
+ .withDescription("the number of threads AdaptiveLogisticRegression
uses")
+ .create();
+
+
+ Option predictors = builder.withLongName("predictors")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("predictors").create())
+ .withDescription("a list of predictor variables").create();
+
+ Option types = builder
+ .withLongName("types")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("types").create())
+ .withDescription(
+ "a list of predictor variable types (numeric, word, or text)")
+ .create();
+
+ Option target = builder
+ .withLongName("target")
+ .withDescription("the name of the target variable")
+ .withRequired(true)
+ .withArgument(
+ argumentBuilder.withName("target").withMaximum(1)
+ .create())
+ .create();
+
+ Option targetCategories = builder
+ .withLongName("categories")
+ .withDescription("the number of target categories to be considered")
+ .withRequired(true)
+ .withArgument(
+ argumentBuilder.withName("categories").withMaximum(1)
+ .create())
+ .create();
+
+
+ Option features = builder
+ .withLongName("features")
+ .withDescription("the number of internal hashed features to use")
+ .withArgument(
+ argumentBuilder.withName("numFeatures")
+ .withDefault("1000").withMaximum(1).create())
+ .create();
+
+ Option passes = builder
+ .withLongName("passes")
+ .withDescription("the number of times to pass over the input data")
+ .withArgument(
+ argumentBuilder.withName("passes").withDefault("2")
+ .withMaximum(1).create())
+ .create();
+
+ Option interval = builder.withLongName("interval")
+ .withArgument(
+ argumentBuilder.withName("interval").withDefault("500").create())
+ .withDescription("the interval property of AdaptiveLogisticRegression")
+ .create();
+
+ Option window = builder.withLongName("window")
+ .withArgument(
+ argumentBuilder.withName("window").withDefault("800").create())
+ .withDescription("the average propery of AdaptiveLogisticRegression")
+ .create();
+
+ Option skipperfnum = builder.withLongName("skipperfnum")
+ .withArgument(
+ argumentBuilder.withName("skipperfnum").withDefault("99").create())
+ .withDescription("show performance measures every (skipperfnum + 1)
rows")
+ .create();
+
+ Option prior = builder.withLongName("prior")
+ .withArgument(
+ argumentBuilder.withName("prior").withDefault("L1").create())
+ .withDescription("the prior algorithm to use: L1, L2, ebp, tp, up")
+ .create();
+
+ Option priorOption = builder.withLongName("prioroption")
+ .withArgument(
+ argumentBuilder.withName("prioroption").create())
+ .withDescription("constructor parameter for ElasticBandPrior and
TPrior")
+ .create();
+
+ Option auc = builder.withLongName("auc")
+ .withArgument(
+ argumentBuilder.withName("auc").withDefault("global").create())
+ .withDescription("the auc to use: global or grouped")
+ .create();
+
+
+
+ Group normalArgs = new GroupBuilder().withOption(help)
+ .withOption(quiet).withOption(inputFile).withOption(outputFile)
+ .withOption(target).withOption(targetCategories)
+ .withOption(predictors).withOption(types).withOption(passes)
+ .withOption(interval).withOption(window).withOption(threads)
+ .withOption(prior).withOption(features).withOption(showperf)
+ .withOption(skipperfnum).withOption(priorOption).withOption(auc)
+ .create();
+
+ Parser parser = new Parser();
+ parser.setHelpOption(help);
+ parser.setHelpTrigger("--help");
+ parser.setGroup(normalArgs);
+ parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+ CommandLine cmdLine = parser.parseAndHelp(args);
+
+ if (cmdLine == null) {
+ return false;
+ }
+
+ TrainAdaptiveLogistic.inputFile = getStringArgument(cmdLine, inputFile);
+ TrainAdaptiveLogistic.outputFile = getStringArgument(cmdLine,
+ outputFile);
+
+ List<String> typeList = Lists.newArrayList();
+ for (Object x : cmdLine.getValues(types)) {
+ typeList.add(x.toString());
+ }
+
+ List<String> predictorList = Lists.newArrayList();
+ for (Object x : cmdLine.getValues(predictors)) {
+ predictorList.add(x.toString());
+ }
+
+ lmp = new AdaptiveLogisticModelParameters();
+ lmp.setTargetVariable(getStringArgument(cmdLine, target));
+ lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories));
+ lmp.setNumFeatures(getIntegerArgument(cmdLine, features));
+ lmp.setInterval(getIntegerArgument(cmdLine, interval));
+ lmp.setAverageWindow(getIntegerArgument(cmdLine, window));
+ lmp.setThreads(getIntegerArgument(cmdLine, threads));
+ lmp.setAuc(getStringArgument(cmdLine, auc));
+ lmp.setPrior(getStringArgument(cmdLine, prior));
+ if (cmdLine.getValue(priorOption) != null) {
+ lmp.setPriorOption(getDoubleArgument(cmdLine, priorOption));
+ }
+ lmp.setTypeMap(predictorList, typeList);
+ TrainAdaptiveLogistic.showperf = getBooleanArgument(cmdLine, showperf);
+ TrainAdaptiveLogistic.skipperfnum = getIntegerArgument(cmdLine,
skipperfnum);
+ TrainAdaptiveLogistic.passes = getIntegerArgument(cmdLine, passes);
+
+ lmp.checkParameters();
+
+ return true;
+ }
+
+  /**
+   * Returns the single string value supplied for {@code option}, or {@code null} when the option
+   * has no value (see the null check on the prior option in parseArgs).
+   *
+   * @param cmdLine the parsed command line
+   * @param option the option whose value is wanted; deliberately not named after any static field
+   *     of this class, because the helper is used for every string option
+   */
+  private static String getStringArgument(CommandLine cmdLine, Option option) {
+    return (String) cmdLine.getValue(option);
+  }
+
+  /**
+   * Reports whether the flag {@code option} appeared on the command line.
+   *
+   * @param cmdLine the parsed command line
+   * @param option a value-less switch such as --showperf
+   * @return true if the switch was given
+   */
+  private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
+    boolean present = cmdLine.hasOption(option);
+    return present;
+  }
+
+  /**
+   * Parses the single value supplied for {@code option} as a decimal int.
+   *
+   * @param cmdLine the parsed command line
+   * @param option any integer-valued option (features, passes, threads, ...), not only --features
+   * @return the parsed value
+   * @throws NumberFormatException if the option has no value or the value is not an integer
+   */
+  private static int getIntegerArgument(CommandLine cmdLine, Option option) {
+    return Integer.parseInt((String) cmdLine.getValue(option));
+  }
+
+  /**
+   * Returns the single value supplied for {@code option}, parsed as a double. The caller must make
+   * sure the option actually has a value; a missing value is not handled here.
+   */
+  private static double getDoubleArgument(CommandLine cmdLine, Option option) {
+    String text = (String) cmdLine.getValue(option);
+    return Double.parseDouble(text);
+  }
+
+  /** Returns the {@link AdaptiveLogisticRegression} currently held in this driver's static state. */
+  public static AdaptiveLogisticRegression getModel() {
+    return TrainAdaptiveLogistic.model;
+  }
+
+  /**
+   * Returns the model parameters held in this driver's static state; parseArgs assigns them an
+   * {@link AdaptiveLogisticModelParameters} instance.
+   */
+  public static LogisticModelParameters getParameters() {
+    return TrainAdaptiveLogistic.lmp;
+  }
+
+  /**
+   * Opens {@code inputFile} as UTF-8 text. The name is first resolved as a classpath resource; if
+   * that lookup fails it is treated as a path on the local file system.
+   */
+  static BufferedReader open(String inputFile) throws IOException {
+    InputStream source = openResourceOrFile(inputFile);
+    return new BufferedReader(new InputStreamReader(source, Charsets.UTF_8));
+  }
+
+  /** Opens the classpath resource called {@code name}, falling back to the file of that name. */
+  private static InputStream openResourceOrFile(String name) throws IOException {
+    try {
+      return Resources.getResource(name).openStream();
+    } catch (IllegalArgumentException resourceNotFound) {
+      // Resources.getResource signals a missing resource with IllegalArgumentException.
+      return new FileInputStream(new File(name));
+    }
+  }
+
+}
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
(original)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
Wed Jul 6 20:17:58 2011
@@ -40,7 +40,7 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
-import java.io.PrintStream;
+import java.io.PrintWriter;
import java.util.List;
import java.util.Locale;
@@ -53,16 +53,18 @@ public final class TrainLogistic {
private static String inputFile;
private static String outputFile;
private static LogisticModelParameters lmp;
-
private static int passes;
private static boolean scores;
private static OnlineLogisticRegression model;
- static PrintStream output = System.out;
private TrainLogistic() {
}
public static void main(String[] args) throws IOException {
+ mainToOutput(args, new PrintWriter(System.out));
+ }
+
+ static void mainToOutput(String[] args, PrintWriter output) throws
IOException {
if (parseArgs(args)) {
double logPEstimate = 0;
int samples = 0;
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java?rev=1143542&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java
Wed Jul 6 20:17:58 2011
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.Locale;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.mahout.classifier.ConfusionMatrix;
+import org.apache.mahout.classifier.evaluation.Auc;
+import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+/*
+ * Auc and averageLikelihood are always shown if possible, if the number of
target value is more than 2,
+ * then Auc and entropy matirx are not shown regardless the value of showAuc
and showEntropy
+ * the user passes, because the current implementation does not support them
on two value targets.
+ * */
+public final class ValidateAdaptiveLogistic {
+
+ private static String inputFile;
+ private static String modelFile;
+ private static boolean showAuc;
+ private static boolean showScores;
+ private static boolean showConfusion;
+
+ private ValidateAdaptiveLogistic() {
+ }
+
+ public static void main(String[] args) throws IOException {
+ mainToOutput(args, new PrintWriter(System.out));
+ }
+
+ static void mainToOutput(String[] args, PrintWriter output) throws
IOException {
+ if (parseArgs(args)) {
+ if (!showAuc && !showConfusion && !showScores) {
+ showAuc = true;
+ showConfusion = true;
+ }
+
+ Auc collector = null;
+ AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters
+ .loadFromFile(new File(modelFile));
+ CsvRecordFactory csv = lmp.getCsvRecordFactory();
+ AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();
+
+ if (lmp.getTargetCategories().size() <=2 ) {
+ collector = new Auc();
+ }
+
+ OnlineSummarizer slh = new OnlineSummarizer();
+ ConfusionMatrix cm = new ConfusionMatrix(lmp.getTargetCategories(),
"unknown");
+
+
+ State<Wrapper, CrossFoldLearner> best = lr.getBest();
+ if (best == null) {
+ output.printf("%s\n",
+ "AdaptiveLogisticRegression has not be trained probably.");
+ return;
+ }
+ CrossFoldLearner learner = best.getPayload().getLearner();
+
+ BufferedReader in = TrainLogistic.open(inputFile);
+ String line = in.readLine();
+ csv.firstLine(line);
+ line = in.readLine();
+ if (showScores) {
+ output.printf(Locale.ENGLISH, "\"%s\", \"%s\", \"%s\", \"%s\"\n",
+ "target", "model-output", "log-likelihood", "average-likelihood");
+ }
+ while (line != null) {
+ Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
+ //TODO: How to avoid extra target values not shown in the training
process.
+ int target = csv.processLine(line, v);
+ double likelihood = learner.logLikelihood(target, v);
+ double score = learner.classifyFull(v).maxValue();
+
+ slh.add(likelihood);
+ cm.addInstance(csv.getTargetString(line), csv.getTargetLabel(target));
+
+ if (showScores) {
+ output.printf(Locale.ENGLISH, "%8d, %.12f, %.13f, %.13f\n", target,
+ score, learner.logLikelihood(target, v), slh.getMean());
+ }
+ if (collector != null) {
+ collector.add(target, score);
+ }
+ line = in.readLine();
+ }
+
+ output.printf(Locale.ENGLISH,"\nLog-likelihood:");
+ output.printf(Locale.ENGLISH, "Min=%.2f, Max=%.2f, Mean=%.2f,
Median=%.2f\n",
+ slh.getMin(), slh.getMax(), slh.getMean(), slh.getMedian());
+
+ if (collector != null) {
+ output.printf(Locale.ENGLISH, "\nAUC = %.2f\n", collector.auc());
+ }
+
+ if (showConfusion) {
+ output.printf(Locale.ENGLISH, "\n%s\n\n", cm.toString());
+
+ if (collector != null){
+ Matrix m = collector.entropy();
+ output.printf(Locale.ENGLISH,
+ "Entropy Matrix: [[%.1f, %.1f], [%.1f, %.1f]]\n", m.get(0, 0),
+ m.get(1, 0), m.get(0, 1), m.get(1, 1));
+ }
+ }
+
+ }
+ }
+
+ private static boolean parseArgs(String[] args) {
+ DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+ Option help = builder.withLongName("help")
+ .withDescription("print this list").create();
+
+ Option quiet = builder.withLongName("quiet")
+ .withDescription("be extra quiet").create();
+
+ Option auc = builder.withLongName("auc").withDescription("print AUC")
+ .create();
+ Option confusion = builder.withLongName("confusion")
+ .withDescription("print confusion matrix").create();
+
+ Option scores = builder.withLongName("scores")
+ .withDescription("print scores").create();
+
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+ Option inputFileOption = builder
+ .withLongName("input")
+ .withRequired(true)
+ .withArgument(
+ argumentBuilder.withName("input").withMaximum(1)
+ .create())
+ .withDescription("where to get validate data").create();
+
+ Option modelFileOption = builder
+ .withLongName("model")
+ .withRequired(true)
+ .withArgument(
+ argumentBuilder.withName("model").withMaximum(1)
+ .create())
+ .withDescription("where to get the trained model").create();
+
+ Group normalArgs = new GroupBuilder().withOption(help)
+ .withOption(quiet).withOption(auc).withOption(scores)
+ .withOption(confusion).withOption(inputFileOption)
+ .withOption(modelFileOption).create();
+
+ Parser parser = new Parser();
+ parser.setHelpOption(help);
+ parser.setHelpTrigger("--help");
+ parser.setGroup(normalArgs);
+ parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+ CommandLine cmdLine = parser.parseAndHelp(args);
+
+ if (cmdLine == null) {
+ return false;
+ }
+
+ inputFile = getStringArgument(cmdLine, inputFileOption);
+ modelFile = getStringArgument(cmdLine, modelFileOption);
+ showAuc = getBooleanArgument(cmdLine, auc);
+ showScores = getBooleanArgument(cmdLine, scores);
+ showConfusion = getBooleanArgument(cmdLine, confusion);
+
+ return true;
+ }
+
+ private static boolean getBooleanArgument(CommandLine cmdLine, Option
option) {
+ return cmdLine.hasOption(option);
+ }
+
+ private static String getStringArgument(CommandLine cmdLine, Option
inputFile) {
+ return (String) cmdLine.getValue(inputFile);
+ }
+
+}
Modified:
mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
---
mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
(original)
+++
mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
Wed Jul 6 20:17:58 2011
@@ -28,15 +28,11 @@ import org.apache.mahout.math.DenseVecto
import org.apache.mahout.math.Vector;
import org.junit.Test;
-import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
-import java.io.IOException;
import java.io.InputStream;
-import java.io.PrintStream;
-import java.lang.reflect.Field;
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
+import java.io.PrintWriter;
+import java.io.StringWriter;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -47,16 +43,19 @@ public class TrainLogisticTest extends M
public void example13_1() throws Exception {
String outputFile = getTestTempFile("model").getAbsolutePath();
- String trainOut = runMain(TrainLogistic.class, new String[]{
- "--input", "donut.csv",
- "--output", outputFile,
- "--target", "color", "--categories", "2",
- "--predictors", "x", "y",
- "--types", "numeric",
- "--features", "20",
- "--passes", "100",
- "--rate", "50"
- });
+ StringWriter sw = new StringWriter();
+ PrintWriter pw = new PrintWriter(sw);
+ TrainLogistic.mainToOutput(new String[]{
+ "--input", "donut.csv",
+ "--output", outputFile,
+ "--target", "color", "--categories", "2",
+ "--predictors", "x", "y",
+ "--types", "numeric",
+ "--features", "20",
+ "--passes", "100",
+ "--rate", "50"
+ }, pw);
+ String trainOut = sw.toString();
assertTrue(trainOut.contains("x -0.7"));
assertTrue(trainOut.contains("y -0.4"));
@@ -87,20 +86,26 @@ public class TrainLogisticTest extends M
Closeables.closeQuietly(in);
}
- String output = runMain(RunLogistic.class, new String[]{
+ sw = new StringWriter();
+ pw = new PrintWriter(sw);
+ RunLogistic.mainToOutput(new String[]{
"--input", "donut.csv",
"--model", outputFile,
"--auc",
"--confusion"
- });
- assertTrue(output.contains("AUC = 0.57"));
- assertTrue(output.contains("confusion: [[27.0, 13.0], [0.0, 0.0]]"));
+ }, pw);
+ trainOut = sw.toString();
+ assertTrue(trainOut.contains("AUC = 0.57"));
+ assertTrue(trainOut.contains("confusion: [[27.0, 13.0], [0.0, 0.0]]"));
}
@Test
public void example13_2() throws Exception {
String outputFile = getTestTempFile("model").getAbsolutePath();
- String trainOut = runMain(TrainLogistic.class, new String[]{
+
+ StringWriter sw = new StringWriter();
+ PrintWriter pw = new PrintWriter(sw);
+ TrainLogistic.mainToOutput(new String[]{
"--input", "donut.csv",
"--output", outputFile,
"--target", "color",
@@ -110,59 +115,34 @@ public class TrainLogisticTest extends M
"--features", "20",
"--passes", "100",
"--rate", "50"
- });
+ }, pw);
+ String trainOut = sw.toString();
assertTrue(trainOut.contains("a 0."));
assertTrue(trainOut.contains("b -1."));
assertTrue(trainOut.contains("c -25."));
- String output = runMain(RunLogistic.class, new String[]{
+ sw = new StringWriter();
+ pw = new PrintWriter(sw);
+ RunLogistic.mainToOutput(new String[]{
"--input", "donut.csv",
"--model", outputFile,
"--auc",
"--confusion"
- });
- assertTrue(output.contains("AUC = 1.00"));
-
- String heldout = runMain(RunLogistic.class, new String[]{
+ }, pw);
+ trainOut = sw.toString();
+ assertTrue(trainOut.contains("AUC = 1.00"));
+
+ sw = new StringWriter();
+ pw = new PrintWriter(sw);
+ RunLogistic.mainToOutput(new String[]{
"--input", "donut-test.csv",
"--model", outputFile,
"--auc",
"--confusion"
- });
- assertTrue(heldout.contains("AUC = 0.9"));
- }
-
- /**
- * Runs a class with a public static void main method. We assume that there
is an accessible
- * field named "output" that we can change to redirect output.
- *
- *
- * @param clazz contains the main method.
- * @param args contains the command line arguments
- * @return The contents to standard out as a string.
- * @throws IOException Not possible, but must be declared.
- * @throws NoSuchFieldException If there isn't an output field.
- * @throws IllegalAccessException If the output field isn't
accessible by us.
- * @throws NoSuchMethodException If there isn't a main method.
- * @throws InvocationTargetException If the main method throws an
exception.
- */
- private static String runMain(Class<?> clazz, String[] args)
- throws NoSuchFieldException, IllegalAccessException,
NoSuchMethodException, InvocationTargetException {
- ByteArrayOutputStream trainOutput = new ByteArrayOutputStream();
- PrintStream printStream = new PrintStream(trainOutput);
-
- try {
- Field outputField = clazz.getDeclaredField("output");
- Method main = clazz.getMethod("main", args.getClass());
-
- outputField.set(null, printStream);
- Object[] argList = {args};
- main.invoke(null, argList);
- return new String(trainOutput.toByteArray(), Charsets.UTF_8);
- } finally {
- Closeables.closeQuietly(printStream);
- }
+ }, pw);
+ trainOut = sw.toString();
+ assertTrue(trainOut.contains("AUC = 0.9"));
}
private static void verifyModel(LogisticModelParameters lmp,
Modified: mahout/trunk/src/conf/driver.classes.props
URL:
http://svn.apache.org/viewvc/mahout/trunk/src/conf/driver.classes.props?rev=1143542&r1=1143541&r2=1143542&view=diff
==============================================================================
--- mahout/trunk/src/conf/driver.classes.props (original)
+++ mahout/trunk/src/conf/driver.classes.props Wed Jul 6 20:17:58 2011
@@ -31,6 +31,9 @@ org.apache.mahout.cf.taste.hadoop.item.R
org.apache.mahout.classifier.sgd.TrainLogistic = trainlogistic : Train a
logistic regression using stochastic gradient descent
org.apache.mahout.classifier.sgd.RunLogistic = runlogistic : Run a logistic
regression model against CSV data
org.apache.mahout.classifier.sgd.PrintResourceOrFile = cat : Print a file or
resource as the logistic regression models would see it
+org.apache.mahout.classifier.sgd.TrainAdaptiveLogistic = trainAdaptiveLogistic : Train an AdaptiveLogisticRegression model
+org.apache.mahout.classifier.sgd.ValidateAdaptiveLogistic = validateAdaptiveLogistic : Validate an AdaptiveLogisticRegression model against a hold-out data set
+org.apache.mahout.classifier.sgd.RunAdaptiveLogistic = runAdaptiveLogistic : Score new production data using a properly trained and validated AdaptiveLogisticRegression model
org.apache.mahout.classifier.bayes.WikipediaXmlSplitter = wikipediaXMLSplitter
: Reads wikipedia data and creates ch
org.apache.mahout.classifier.bayes.WikipediaDatasetCreatorDriver =
wikipediaDataSetCreator : Splits data set of wikipedia wrt feature like country
org.apache.mahout.math.hadoop.stochasticsvd.SSVDCli = ssvd : Stochastic SVD