I have written a command-line prototype of RunAdaptiveLogistic and have two questions: 1. How can I make it invokable from the `mahout` command-line driver? 2. Can you help me fine-tune how the AdaptiveLogisticRegression is created and configured, so that its settings make sense?
On Tue, May 10, 2011 at 11:30 PM, Ted Dunning <ted.dunn...@gmail.com> wrote:
> Great idea. Why don't you implement something like what you need? Others will be happy to contribute improvements.
>
> On Tue, May 10, 2011 at 8:26 AM, XiaoboGu <guxiaobo1...@gmail.com> wrote:
>
>> > There isn't a good command line for this, largely because it is difficult to describe how to convert each CSV field. There are some beginnings of efforts on this, but the results are still limited.
>>
>> In common usage the predictor variables are almost always numbers, or categorical variables encoded as numbers, so a unified CSV file converter is possible for data containing only these data types.
>>
package org.apache.mahout.classifier.sgd;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;

/**
 * Model parameters for training an {@link AdaptiveLogisticRegression}.
 *
 * <p>Extends {@link LogisticModelParameters}, reusing its serialized settings (target
 * categories, feature count, CSV record factory, ...) and adding a factory for the adaptive
 * learner.
 */
public class AdaptiveLogisticModelParameters extends LogisticModelParameters {

  /** Lazily created learner; see {@link #createAdaptiveLogisticRegression()}. */
  private AdaptiveLogisticRegression alr;

  /**
   * Returns the {@link AdaptiveLogisticRegression} described by these parameters, creating it on
   * the first call with an {@link L1} prior.
   *
   * <p>The instance is cached: later calls return the same object, so parameter changes made
   * after the first call are not reflected in it.
   *
   * @return the (possibly cached) learner
   */
  public AdaptiveLogisticRegression createAdaptiveLogisticRegression() {
    if (alr == null) {
      alr = new AdaptiveLogisticRegression(getMaxTargetCategories(), getNumFeatures(), new L1());
    }
    return alr;
  }

  /**
   * Reads parameters previously serialized with the inherited {@code readFields} format.
   *
   * <p>The stream is not closed by this method.
   *
   * @param in stream positioned at the start of the serialized parameters
   * @return the deserialized parameters
   * @throws IOException if reading fails
   */
  public static AdaptiveLogisticModelParameters loadFromStream(InputStream in) throws IOException {
    AdaptiveLogisticModelParameters result = new AdaptiveLogisticModelParameters();
    result.readFields(new DataInputStream(in));
    return result;
  }

  /**
   * Reads parameters from a file and always closes the file afterwards.
   *
   * @param in file containing serialized parameters
   * @return the deserialized parameters
   * @throws IOException if the file cannot be opened or read
   */
  public static AdaptiveLogisticModelParameters loadFromFile(File in) throws IOException {
    // DataInputStream performs many small reads; buffering avoids one system call per read.
    InputStream input = new BufferedInputStream(new FileInputStream(in));
    try {
      return loadFromStream(input);
    } finally {
      input.close();
    }
  }
}
package org.apache.mahout.classifier.sgd;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Locale;

/**
 * Command-line prototype that loads {@link AdaptiveLogisticModelParameters} from
 * {@code --model}, trains the {@link AdaptiveLogisticRegression} they describe on the CSV file
 * given by {@code --input} (first line is the header), and prints a short report.
 *
 * <p>NOTE(review): the trained learner is not written anywhere; add a save step if the result
 * is meant to be reused. {@code --quiet} is accepted but currently has no effect.
 */
public final class RunAdaptiveLogistic {

  private static String inputFile;
  private static String modelFile;
  private static boolean showAuc;
  private static boolean showScores;
  private static boolean showConfusion;

  /** Destination for all report output; package-visible so tests can redirect it. */
  static PrintStream output = System.out;

  private RunAdaptiveLogistic() {
  }

  /**
   * Parses the command line, trains the adaptive learner on the input file and prints the
   * requested report.
   *
   * @param args command-line arguments (see {@link #parseArgs(String[])})
   * @throws IOException if the model or input file cannot be read
   */
  public static void main(String[] args) throws IOException {
    if (!parseArgs(args)) {
      return;
    }
    // With no report flag, default to AUC only: a confusion matrix is not available for
    // AdaptiveLogisticRegression, so defaulting it on would only print a notice.
    if (!showAuc && !showConfusion && !showScores) {
      showAuc = true;
    }

    AdaptiveLogisticModelParameters lmp =
        AdaptiveLogisticModelParameters.loadFromFile(new File(modelFile));
    AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();

    if (showScores) {
      // Reported once, not once per input record.
      output.println("scores can't be shown while using AdaptiveLogisticRegression");
    }

    try {
      train(lmp, lr);
    } finally {
      // close() is the OnlineLearner hook that finishes pending training and releases the
      // learner's resources. NOTE(review): AdaptiveLogisticRegression buffers examples and
      // runs worker threads internally, so skipping this can lose data or keep the JVM alive.
      lr.close();
    }
    report(lr);
  }

  /** Feeds every data line of {@link #inputFile} to {@code lr}, closing the reader afterwards. */
  private static void train(AdaptiveLogisticModelParameters lmp, AdaptiveLogisticRegression lr)
      throws IOException {
    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    BufferedReader in = TrainLogistic.open(inputFile);
    try {
      String header = in.readLine();
      if (header == null) {
        throw new IOException("Input file has no header line: " + inputFile);
      }
      csv.firstLine(header);
      for (String line = in.readLine(); line != null; line = in.readLine()) {
        Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
        int target = csv.processLine(line, v);
        lr.train(target, v);
      }
    } finally {
      in.close();
    }
  }

  /** Prints the summaries selected on the command line for the trained learner. */
  private static void report(AdaptiveLogisticRegression lr) {
    if (showAuc) {
      // AUC as reported by the learner for its current best population member (NaN if none).
      output.printf(Locale.ENGLISH, "AUC = %.2f\n", lr.auc());
    }
    if (showConfusion) {
      output.println("confusion matrix can't be shown while using AdaptiveLogisticRegression");
    }
  }

  /**
   * Parses {@code args} into the static option fields.
   *
   * @return {@code false} if parsing failed or help was requested, {@code true} otherwise
   */
  private static boolean parseArgs(String[] args) {
    DefaultOptionBuilder builder = new DefaultOptionBuilder();

    Option help = builder.withLongName("help").withDescription("print this list").create();
    Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create();
    Option auc = builder.withLongName("auc").withDescription("print AUC").create();
    Option confusion =
        builder.withLongName("confusion").withDescription("print confusion matrix").create();
    Option scores = builder.withLongName("scores").withDescription("print scores").create();

    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
    Option inputFileOption = builder.withLongName("input")
        .withRequired(true)
        .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
        .withDescription("where to get training data")
        .create();

    Option modelFileOption = builder.withLongName("model")
        .withRequired(true)
        .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
        .withDescription("where to get a model")
        .create();

    Group normalArgs = new GroupBuilder()
        .withOption(help)
        .withOption(quiet)
        .withOption(auc)
        .withOption(scores)
        .withOption(confusion)
        .withOption(inputFileOption)
        .withOption(modelFileOption)
        .create();

    Parser parser = new Parser();
    parser.setHelpOption(help);
    parser.setHelpTrigger("--help");
    parser.setGroup(normalArgs);
    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
    CommandLine cmdLine = parser.parseAndHelp(args);

    if (cmdLine == null) {
      return false;
    }

    inputFile = getStringArgument(cmdLine, inputFileOption);
    modelFile = getStringArgument(cmdLine, modelFileOption);
    showAuc = getBooleanArgument(cmdLine, auc);
    showScores = getBooleanArgument(cmdLine, scores);
    showConfusion = getBooleanArgument(cmdLine, confusion);

    return true;
  }

  /** Returns whether the flag {@code option} was given. */
  private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
    return cmdLine.hasOption(option);
  }

  /** Returns the single string value of {@code option}, or {@code null} if absent. */
  private static String getStringArgument(CommandLine cmdLine, Option option) {
    return (String) cmdLine.getValue(option);
  }
}