Author: srowen
Date: Fri Nov 4 11:20:03 2011
New Revision: 1197510
URL: http://svn.apache.org/viewvc?rev=1197510&view=rev
Log:
MAHOUT-838 Add confusion matrix dumper
Added:
mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/
mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java
mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
mahout/trunk/src/conf/driver.classes.props
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=1197510&r1=1197509&r2=1197510&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
Fri Nov 4 11:20:03 2011
@@ -19,35 +19,50 @@ package org.apache.mahout.classifier;
import java.util.Collection;
import java.util.Collections;
-import java.util.LinkedHashMap;
import java.util.Map;
-import com.google.common.collect.Maps;
import org.apache.commons.lang.StringUtils;
-import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
/**
* The ConfusionMatrix Class stores the result of Classification of a Test Dataset.
*
+ * The fact of whether there is a default is not stored. A row of zeros is the only indicator that there is no default.
+ *
* See http://en.wikipedia.org/wiki/Confusion_matrix for background
*/
public class ConfusionMatrix {
-
- private final Map<String,Integer> labelMap = new LinkedHashMap<String,Integer>();
+ private final Map<String,Integer> labelMap = Maps.newLinkedHashMap();
private final int[][] confusionMatrix;
private String defaultLabel = "unknown";
public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
confusionMatrix = new int[labels.size() + 1][labels.size() + 1];
this.defaultLabel = defaultLabel;
+ int i = 0;
for (String label : labels) {
- labelMap.put(label, labelMap.size());
+ labelMap.put(label, i++);
}
- labelMap.put(defaultLabel, labelMap.size());
+ labelMap.put(defaultLabel, i);
+ }
+
+ /**
+ * Rebuilds a confusion matrix from the {@link Matrix} form produced by {@link #getMatrix()}.
+ * The default label is taken to be the one bound to the last row/column, which is where
+ * {@link #ConfusionMatrix(Collection, String)} places it.
+ */
+ public ConfusionMatrix(Matrix m) {
+ confusionMatrix = new int[m.numRows()][m.numRows()];
+ setMatrix(m);
+ for (Map.Entry<String,Integer> entry : labelMap.entrySet()) {
+ if (entry.getValue() == confusionMatrix.length - 1) {
+ defaultLabel = entry.getKey();
+ }
+ }
}
public int[][] getConfusionMatrix() {
@@ -76,7 +91,7 @@ public class ConfusionMatrix {
return confusionMatrix[labelId][labelId];
}
- public double getTotal(String label) {
+ public int getTotal(String label) {
int labelId = labelMap.get(label);
int labelTotal = 0;
for (int i = 0; i < labelMap.size(); i++) {
@@ -94,25 +109,25 @@ public class ConfusionMatrix {
}
public int getCount(String correctLabel, String classifiedLabel) {
- Preconditions.checkArgument(labelMap.containsKey(correctLabel),
- "Label not found: " + correctLabel);
- Preconditions.checkArgument(labelMap.containsKey(classifiedLabel),
- "Label not found: " + classifiedLabel);
+ Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label not found: " + correctLabel);
+ Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
int correctId = labelMap.get(correctLabel);
int classifiedId = labelMap.get(classifiedLabel);
return confusionMatrix[correctId][classifiedId];
}
public void putCount(String correctLabel, String classifiedLabel, int count) {
- Preconditions.checkArgument(labelMap.containsKey(correctLabel),
- "Label not found: " + correctLabel);
- Preconditions.checkArgument(labelMap.containsKey(classifiedLabel),
- "Label not found: " + classifiedLabel);
+ Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label not found: " + correctLabel);
+ Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
int correctId = labelMap.get(correctLabel);
int classifiedId = labelMap.get(classifiedLabel);
confusionMatrix[correctId][classifiedId] = count;
}
+ public String getDefaultLabel() {
+ return defaultLabel;
+ }
+
public void incrementCount(String correctLabel, String classifiedLabel, int count) {
putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel));
}
@@ -132,45 +147,69 @@ public class ConfusionMatrix {
}
public Matrix getMatrix() {
- int length = confusionMatrix.length;
- Matrix m = new DenseMatrix(length, length);
- for (int r = 0; r < length; r++) {
- for (int c = 0; c < length; c++) {
- m.set(r, c, confusionMatrix[r][c]);
- }
- }
- Map<String,Integer> labels = Maps.newHashMap();
- for (Map.Entry<String, Integer> entry : labelMap.entrySet()) {
- labels.put(entry.getKey(), entry.getValue());
- }
- m.setRowLabelBindings(labels);
- m.setColumnLabelBindings(labels);
- return m;
+ int length = confusionMatrix.length;
+ Matrix m = new DenseMatrix(length, length);
+ for (int r = 0; r < length; r++) {
+ for (int c = 0; c < length; c++) {
+ m.set(r, c, confusionMatrix[r][c]);
+ }
+ }
+ Map<String,Integer> labels = Maps.newHashMap(labelMap);
+ m.setRowLabelBindings(labels);
+ m.setColumnLabelBindings(labels);
+ return m;
}
-
+
public void setMatrix(Matrix m) {
- int length = confusionMatrix.length;
- if (m.numRows() != m.numCols()) {
- throw new CardinalityException(m.numRows(), m.numCols());
- }
- if (m.numRows() != length) {
- throw new CardinalityException(m.numRows(), length);
- }
- for (int r = 0; r < length; r++) {
- for (int c = 0; c < length; c++) {
- confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
- }
- }
- Map<String,Integer> labels = m.getRowLabelBindings();
- if (labels == null) {
+ int length = confusionMatrix.length;
+ if (m.numRows() != m.numCols()) {
+ throw new IllegalArgumentException(
+ String.format("ConfusionMatrix: matrix(%d,%d) must be square", m.numRows(), m.numCols()));
+ }
+ if (m.numRows() != length) {
+ throw new IllegalArgumentException(
+ String.format("ConfusionMatrix: matrix(%d,%d) must be %dx%d", m.numRows(), m.numCols(), length, length));
+ }
+ for (int r = 0; r < length; r++) {
+ for (int c = 0; c < length; c++) {
+ confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
+ }
+ }
+ Map<String,Integer> labels = m.getRowLabelBindings();
+ if (labels == null) {
labels = m.getColumnLabelBindings();
}
- labelMap.clear();
- if (labels != null) {
- labelMap.putAll(labels);
- }
+ if (labels != null) {
+ String[] sorted = sortLabels(labels);
+ verifyLabels(length, sorted);
+ labelMap.clear();
+ for (int i = 0; i < length; i++) {
+ labelMap.put(sorted[i], i);
+ }
+ }
+ }
+
+ private static String[] sortLabels(Map<String,Integer> labels) {
+ String[] sorted = new String[labels.size()];
+ for (Map.Entry<String,Integer> entry : labels.entrySet()) {
+ int index = entry.getValue();
+ Preconditions.checkArgument(index >= 0 && index < sorted.length, "Label index out of range: %s -> %s", entry.getKey(), index);
+ sorted[index] = entry.getKey();
+ }
+ return sorted;
+ }
+
+ private static void verifyLabels(int length, String[] sorted) {
+ Preconditions.checkArgument(sorted.length == length, "One label, one row");
+ for (int i = 0; i < length; i++) {
+ Preconditions.checkArgument(sorted[i] != null, "One label, one row");
+ }
}
+ /**
+ * This is overloaded. toString() is not a formatted report you print for a manager :)
+ * Assume that if there are no default assignments, the default feature was not used
+ */
@Override
public String toString() {
StringBuilder returnString = new StringBuilder(200);
@@ -178,26 +217,37 @@ public class ConfusionMatrix {
returnString.append("Confusion Matrix\n");
returnString.append("-------------------------------------------------------").append('\n');
+ int unclassified = getTotal(defaultLabel);
for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
+ if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
+ continue;
+ }
+
returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t');
}
returnString.append("<--Classified as").append('\n');
-
for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
+ if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
+ continue;
+ }
String correctLabel = entry.getKey();
int labelTotal = 0;
for (String classifiedLabel : this.labelMap.keySet()) {
+ if (classifiedLabel.equals(defaultLabel) && unclassified == 0) {
+ continue;
+ }
returnString.append(
- StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
+ StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
labelTotal += getCount(correctLabel, classifiedLabel);
}
returnString.append(" | ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t')
- .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
- .append(" = ").append(correctLabel).append('\n');
+ .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
+ .append(" = ").append(correctLabel).append('\n');
+ }
+ if (unclassified > 0) {
+ returnString.append("Default Category: ").append(defaultLabel).append(": ").append(unclassified).append('\n');
}
- returnString.append("Default Category: ").append(defaultLabel).append(": ").append(
- labelMap.get(defaultLabel)).append('\n');
returnString.append('\n');
return returnString.toString();
}
@@ -212,5 +262,5 @@ public class ConfusionMatrix {
} while (val > 0);
return returnString.toString();
}
-
+
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=1197510&r1=1197509&r2=1197510&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
Fri Nov 4 11:20:03 2011
@@ -33,6 +33,7 @@ import org.apache.commons.cli2.builder.D
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ConfusionMatrix;
import org.apache.mahout.classifier.ResultAnalyzer;
import org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesClassifierDriver;
import org.apache.mahout.common.CommandLineUtil;
@@ -104,9 +105,14 @@ public final class TestClassifier {
"Method of Classification: sequential|mapreduce. Default Value: sequential").withShortName("method")
.create();
+ Option confusionMatrixOpt = obuilder.withLongName("confusionMatrix").withRequired(false).withArgument(
+ abuilder.withName("confusionMatrix").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Export ConfusionMatrix as SequenceFile").withShortName("cm").create();
+
Group group = gbuilder.withName("Options").withOption(defaultCatOpt).withOption(dirOpt).withOption(
encodingOpt).withOption(gramSizeOpt).withOption(pathOpt).withOption(typeOpt).withOption(dataSourceOpt)
- .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt).withOption(alphaOpt).create();
+ .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt).withOption(alphaOpt)
+ .withOption(confusionMatrixOpt).create();
try {
Parser parser = new Parser();
@@ -163,6 +169,11 @@ public final class TestClassifier {
classificationMethod = (String) cmdLine.getValue(methodOpt);
}
+ String confusionMatrixFile = null;
+ if (cmdLine.hasOption(confusionMatrixOpt)) {
+ confusionMatrixFile = (String) cmdLine.getValue(confusionMatrixOpt);
+ }
+
params.setGramSize(gramSize);
params.set("verbose", Boolean.toString(verbose));
params.setBasePath(modelBasePath);
@@ -172,6 +183,7 @@ public final class TestClassifier {
params.set("encoding", encoding);
params.set("alpha_i", alphaI);
params.set("testDirPath", testDirPath);
+ params.set("confusionMatrix", confusionMatrixFile);
if ("sequential".equalsIgnoreCase(classificationMethod)) {
classifySequential(params);
@@ -253,12 +265,10 @@ public final class TestClassifier {
}
lineNum++;
}
- /*
- * log.info("{}\t{}\t{}/{}", new Object[] {correctLabel,
- * resultAnalyzer.getConfusionMatrix().getAccuracy(correctLabel),
- * resultAnalyzer.getConfusionMatrix().getCorrect(correctLabel),
- * resultAnalyzer.getConfusionMatrix().getTotal(correctLabel)});
- */
+ ConfusionMatrix matrix = resultAnalyzer.getConfusionMatrix();
+ log.info("ConfusionMatrix: {}", matrix);
+ BayesClassifierDriver.confusionMatrixSeqFileExport(params, matrix);
+
log.info("Classified instances from {}", file.getName());
if (verbose) {
log.info("Performance stats {}", operationStats.toString());
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java?rev=1197510&r1=1197509&r2=1197510&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesClassifierDriver.java
Fri Nov 4 11:20:03 2011
@@ -21,10 +21,15 @@ import java.io.IOException;
import java.util.Map;
import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
@@ -38,6 +43,7 @@ import org.apache.mahout.common.Paramete
import org.apache.mahout.common.StringTuple;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.MatrixWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -83,6 +89,7 @@ public final class BayesClassifierDriver
Path outputFiles = new Path(outPath, "part*");
ConfusionMatrix matrix = readResult(outputFiles, conf, params);
log.info("{}", matrix);
+ confusionMatrixSeqFileExport(params, matrix);
}
public static ConfusionMatrix readResult(Path pathPattern, Configuration conf, Parameters params) {
@@ -117,6 +124,33 @@ public final class BayesClassifierDriver
}
}
return matrix;
+ }
+
+ /**
+ * Writes {@code matrix} as a single (file name, MatrixWritable) pair to the SequenceFile named by the
+ * "confusionMatrix" parameter. Does nothing when that parameter is not set.
+ */
+ public static void confusionMatrixSeqFileExport(Parameters params, ConfusionMatrix matrix) throws IOException {
+ String exportFile = params.get("confusionMatrix");
+ if (exportFile == null) {
+ return;
+ }
+ Configuration conf = new Configuration();
+ Path exportPath = new Path(exportFile);
+ FileSystem fs = exportPath.getFileSystem(conf);
+ // embed file name as sequence key- useful for tuning classifiers
+ Text key = new Text(exportPath.getName());
+ MatrixWritable mw = new MatrixWritable(matrix.getMatrix());
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, exportPath, Text.class, MatrixWritable.class);
+ boolean threw = true;
+ try {
+ writer.append(key, mw);
+ threw = false;
+ } finally {
+ // only swallow a close failure if append already failed; otherwise it may mean lost data
+ Closeables.close(writer, threw);
+ }
+ }
}
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java?rev=1197510&r1=1197509&r2=1197510&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
Fri Nov 4 11:20:03 2011
@@ -123,15 +123,15 @@ public class MatrixWritable implements W
}
if (hasLabels) {
- Map<String,Integer> columnLabelBindings = Maps.newHashMap();
- Map<String,Integer> rowLabelBindings = Maps.newHashMap();
- readLabels(in, columnLabelBindings, rowLabelBindings);
- if (!columnLabelBindings.isEmpty()) {
- r.setColumnLabelBindings(columnLabelBindings);
- }
- if (!rowLabelBindings.isEmpty()) {
- r.setRowLabelBindings(rowLabelBindings);
- }
+ Map<String,Integer> columnLabelBindings = Maps.newHashMap();
+ Map<String,Integer> rowLabelBindings = Maps.newHashMap();
+ readLabels(in, columnLabelBindings, rowLabelBindings);
+ if (!columnLabelBindings.isEmpty()) {
+ r.setColumnLabelBindings(columnLabelBindings);
+ }
+ if (!rowLabelBindings.isEmpty()) {
+ r.setRowLabelBindings(rowLabelBindings);
+ }
}
return r;
@@ -159,7 +159,7 @@ public class MatrixWritable implements W
VectorWritable.writeVector(out, matrix.viewRow(i), false);
}
if ((flags & FLAG_LABELS) != 0) {
- writeLabelBindings(out, matrix.getColumnLabelBindings(),
matrix.getRowLabelBindings());
+ writeLabelBindings(out, matrix.getColumnLabelBindings(),
matrix.getRowLabelBindings());
}
}
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java?rev=1197510&r1=1197509&r2=1197510&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
Fri Nov 4 11:20:03 2011
@@ -17,6 +17,7 @@
package org.apache.mahout.math;
+import com.google.common.collect.Maps;
import com.google.common.io.Closeables;
import org.apache.hadoop.io.Writable;
import org.junit.Test;
@@ -26,7 +27,6 @@ import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
-import java.util.HashMap;
import java.util.Map;
public final class MatrixWritableTest extends MahoutTestCase {
@@ -36,13 +36,14 @@ public final class MatrixWritableTest ex
Matrix m = new SparseMatrix(5, 5);
m.set(1, 2, 3.0);
m.set(3, 4, 5.0);
- Map<String, Integer> bindings = new HashMap<String, Integer>();
+ Map<String, Integer> bindings = Maps.newHashMap();
bindings.put("A", 0);
bindings.put("B", 1);
bindings.put("C", 2);
bindings.put("D", 3);
bindings.put("default", 4);
m.setRowLabelBindings(bindings);
+ m.setColumnLabelBindings(bindings);
doTestMatrixWritableEquals(m);
}
@@ -51,12 +52,13 @@ public final class MatrixWritableTest ex
Matrix m = new DenseMatrix(5,5);
m.set(1, 2, 3.0);
m.set(3, 4, 5.0);
- Map<String, Integer> bindings = new HashMap<String, Integer>();
+ Map<String, Integer> bindings = Maps.newHashMap();
bindings.put("A", 0);
bindings.put("B", 1);
bindings.put("C", 2);
bindings.put("D", 3);
bindings.put("default", 4);
+ m.setRowLabelBindings(bindings);
m.setColumnLabelBindings(bindings);
doTestMatrixWritableEquals(m);
}
@@ -66,7 +68,9 @@ public final class MatrixWritableTest ex
MatrixWritable matrixWritable2 = new MatrixWritable();
writeAndRead(matrixWritable, matrixWritable2);
Matrix m2 = matrixWritable2.get();
- compareMatrices(m, m2); // not sure this works?
+ compareMatrices(m, m2);
+ doCheckBindings(m2.getRowLabelBindings());
+ doCheckBindings(m2.getColumnLabelBindings());
}
private static void compareMatrices(Matrix m, Matrix m2) {
@@ -98,6 +102,16 @@ public final class MatrixWritableTest ex
}
}
+ /** Asserts that round-tripped bindings exist and map each test label back to the index it was given. */
+ private static void doCheckBindings(Map<String,Integer> labels) {
+ assertNotNull("Missing label bindings", labels);
+ assertEquals("Bad binding for label A", Integer.valueOf(0), labels.get("A"));
+ assertEquals("Bad binding for label B", Integer.valueOf(1), labels.get("B"));
+ assertEquals("Bad binding for label C", Integer.valueOf(2), labels.get("C"));
+ assertEquals("Bad binding for label D", Integer.valueOf(3), labels.get("D"));
+ assertEquals("Bad binding for label default", Integer.valueOf(4), labels.get("default"));
+ }
+
private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
Added:
mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java?rev=1197510&view=auto
==============================================================================
---
mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
(added)
+++
mahout/trunk/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java
Fri Nov 4 11:20:03 2011
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.PrintStream;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+/**
+ * Export a ConfusionMatrix in various text formats:
+ * ToString version
+ * Grayscale HTML table
+ * Summary HTML table
+ * Table of counts
+ * all with optional HTML wrappers
+ *
+ * Input format: Hadoop SequenceFile with Text key and MatrixWritable value, 1 pair
+ *
+ * Intended to consume ConfusionMatrix SequenceFile output by Bayes
+ * TestClassifier class
+ */
+public final class ConfusionMatrixDumper extends AbstractJob {
+
+ // HTML wrapper - default CSS
+ private static final String HEADER = "<html>" +
+ "<head>\n" +
+ "<title>TITLE</title>\n" +
+ "<style type='text/css'> \n" +
+ "table\n" +
+ "{\n" +
+ "border:3px solid black; text-align:left;\n" +
+ "}\n" +
+ "th.normalHeader\n" +
+ "{\n" +
+ "border:1px solid black;border-collapse:collapse;text-align:center;background-color:white\n" +
+ "}\n" +
+ "th.tallHeader\n" +
+ "{\n" +
+ "border:1px solid black;border-collapse:collapse;text-align:center;background-color:white; height:6em\n" +
+ "}\n" +
+ "tr.label\n" +
+ "{\n" +
+ "border:1px solid black;border-collapse:collapse;text-align:center;background-color:white\n" +
+ "}\n" +
+ "tr.row\n" +
+ "{\n" +
+ "border:1px solid gray;text-align:center;background-color:snow\n" +
+ "}\n" +
+ "td\n" +
+ "{\n" +
+ "min-width:2em\n" +
+ "}\n" +
+ "td.cell\n" +
+ "{\n" +
+ "border:1px solid black;text-align:right;background-color:snow\n" +
+ "}\n" +
+ "td.empty\n" +
+ "{\n" +
+ "border:0px;text-align:right;background-color:snow\n" +
+ "}\n" +
+ "td.white\n" +
+ "{\n" +
+ "border:0px solid black;text-align:right;background-color:white\n" +
+ "}\n" +
+ "td.black\n" +
+ "{\n" +
+ "border:0px solid red;text-align:right;background-color:black;color:white\n" +
+ "}\n" +
+ "td.gray1\n" +
+ "{\n" +
+ "border:0px solid green;text-align:right; background-color:LightGray\n" +
+ "}\n" +
+ "td.gray2\n" +
+ "{\n" +
+ "border:0px solid blue;text-align:right;background-color:DarkGray\n" +
+ "}\n" +
+ "td.gray3\n" +
+ "{\n" +
+ "border:0px solid red;text-align:right;background-color:gray\n" +
+ "}\n" +
+ "th" +
+ "{\n" +
+ " text-align: center;\n" +
+ " vertical-align: bottom;\n" +
+ " padding-bottom: 3px;\n" +
+ " padding-left: 5px;\n" +
+ " padding-right: 5px;\n" +
+ "}\n" +
+ " .verticalText\n" +
+ " {\n" +
+ " text-align: center;\n" +
+ " vertical-align: middle;\n" +
+ " width: 20px;\n" +
+ " margin: 0px;\n" +
+ " padding: 0px;\n" +
+ " padding-left: 3px;\n" +
+ " padding-right: 3px;\n" +
+ " padding-top: 10px;\n" +
+ " white-space: nowrap;\n" +
+ " -webkit-transform: rotate(-90deg); \n" +
+ " -moz-transform: rotate(-90deg); \n" +
+ " }\n" +
+ "</style>\n" +
+ "</head>\n" +
+ "<body>\n";
+ private static final String FOOTER = "</body></html>";
+
+ // CSS style names.
+ private static final String CSS_TABLE = "table";
+ private static final String CSS_LABEL = "label";
+ private static final String CSS_TALL_HEADER = "tall";
+ private static final String CSS_VERTICAL = "verticalText";
+ private static final String CSS_CELL = "cell";
+ private static final String CSS_EMPTY = "empty";
+ // Ordered from lightest to darkest background (white, LightGray, DarkGray, gray, black in HEADER).
+ private static final String[] CSS_GRAY_CELLS = {"white", "gray1", "gray2", "gray3", "black"};
+
+ private ConfusionMatrixDumper() {}
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new ConfusionMatrixDumper(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws IOException {
+ addInputOption();
+ addOption("output", "o", "Output path", null); // AbstractJob output feature requires param
+ addOption(DefaultOptionCreator.overwriteOption().create());
+ addFlag("html", null, "Create complete HTML page");
+ addFlag("text", null, "Dump simple text");
+ Map<String, String> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ Path inputPath = getInputPath();
+ String outputFile = parsedArgs.containsKey("--output") ? parsedArgs.get("--output") : null;
+ boolean text = parsedArgs.containsKey("--text");
+ boolean wrapHtml = parsedArgs.containsKey("--html");
+ PrintStream out = getPrintStream(outputFile);
+ try {
+ if (text) {
+ exportText(inputPath, out);
+ } else {
+ exportTable(inputPath, out, wrapHtml);
+ }
+ out.flush();
+ } finally {
+ // close file output even if the export fails, but never close System.out
+ if (out != System.out) {
+ out.close();
+ }
+ }
+ return 0;
+ }
+
+ private static void exportText(Path inputPath, PrintStream out) throws IOException {
+ MatrixWritable mw = new MatrixWritable();
+ Text key = new Text();
+ readSeqFile(inputPath, key, mw);
+ Matrix m = mw.get();
+ ConfusionMatrix cm = new ConfusionMatrix(m);
+ out.println(cm.toString());
+ }
+
+ private static void exportTable(Path inputPath, PrintStream out, boolean wrapHtml) throws IOException {
+ MatrixWritable mw = new MatrixWritable();
+ Text key = new Text();
+ readSeqFile(inputPath, key, mw);
+ String fileName = inputPath.getName();
+ Matrix m = mw.get();
+ ConfusionMatrix cm = new ConfusionMatrix(m);
+ if (wrapHtml) {
+ printHeader(out, fileName);
+ }
+ out.println("<p/>");
+ printSummaryTable(cm, out);
+ out.println("<p/>");
+ printGrayTable(cm, out);
+ out.println("<p/>");
+ printCountsTable(cm, out);
+ out.println("<p/>");
+ printTextInBox(cm, out);
+ out.println("<p/>");
+ if (wrapHtml) {
+ printFooter(out);
+ }
+ }
+
+ /** Labels to display: all of them, minus the default label when nothing was assigned to it. */
+ private static List<String> stripDefault(ConfusionMatrix cm) {
+ List<String> stripped = Lists.newArrayList(cm.getLabels().iterator());
+ String defaultLabel = cm.getDefaultLabel();
+ // a matrix read back from a file need not contain the default label at all
+ if (!stripped.contains(defaultLabel) || cm.getTotal(defaultLabel) > 0) {
+ return stripped;
+ }
+ stripped.remove(defaultLabel);
+ return stripped;
+ }
+
+ /** Reads the first (key, matrix) pair of the SequenceFile at {@code path}, on the FileSystem the path names. */
+ private static void readSeqFile(Path path, Text key, MatrixWritable m) throws IOException {
+ Configuration conf = new Configuration();
+ FileSystem fs = path.getFileSystem(conf);
+ SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
+ try {
+ if (!reader.next(key, m)) {
+ throw new IOException("No confusion matrix found in " + path);
+ }
+ } finally {
+ Closeables.closeQuietly(reader);
+ }
+ }
+
+ /** Returns a stream on the local file {@code outputFilename} (truncated), or System.out when it is null. */
+ private static PrintStream getPrintStream(String outputFilename) throws IOException {
+ if (outputFilename == null) {
+ return System.out;
+ }
+ // java.io is used here, so this always writes to the local file system, not HDFS
+ OutputStream os = new FileOutputStream(new File(outputFilename));
+ return new PrintStream(os);
+ }
+
+ private static int getLabelTotal(ConfusionMatrix cm, String rowLabel) {
+ Iterator<String> iter = cm.getLabels().iterator();
+ int count = 0;
+ while (iter.hasNext()) {
+ count += cm.getCount(rowLabel, iter.next());
+ }
+ return count;
+ }
+
+ // HTML generator code
+
+ private static void printTextInBox(ConfusionMatrix cm, PrintStream out) {
+ out.println("<div style='width:90%;overflow:scroll;'>");
+ out.println("<pre>");
+ out.println(escapeHtml(cm.toString()));
+ out.println("</pre>");
+ out.println("</div>");
+ }
+
+ public static void printSummaryTable(ConfusionMatrix cm, PrintStream out) {
+ format("<table class='%s'>\n", out, CSS_TABLE);
+ format("<tr class='%s'>", out, CSS_LABEL);
+ out.println("<td>Label</td><td>Total</td><td>Correct</td><td>%</td>");
+ out.println("</tr>");
+ List<String> labels = stripDefault(cm);
+ for (String label : labels) {
+ printSummaryRow(cm, out, label);
+ }
+ out.println("</table>");
+ }
+
+ private static void printSummaryRow(ConfusionMatrix cm, PrintStream out, String label) {
+ format("<tr class='%s'>", out, CSS_CELL);
+ int correct = cm.getCorrect(label);
+ double accuracy = cm.getAccuracy(label);
+ int count = getLabelTotal(cm, label);
+ format("<td class='%s'>%s</td><td>%d</td><td>%d</td><td>%d</td>",
+ out, CSS_CELL, escapeHtml(label), count, correct, (int) Math.round(accuracy));
+ out.println("</tr>");
+ }
+
+ public static void printGrayTable(ConfusionMatrix cm, PrintStream out) {
+ format("<table class='%s'>\n", out, CSS_TABLE);
+ printCountsHeader(cm, out, true);
+ printGrayRows(cm, out);
+ out.println("</table>");
+ }
+
+ /**
+ * Print each value in a four-value grayscale based on count/max.
+ * Gives a mostly white matrix with grays in misclassified, and black in diagonal.
+ * TODO: Using the sqrt(count/max) as the rating is more stringent
+ */
+ private static void printGrayRows(ConfusionMatrix cm, PrintStream out) {
+ List<String> labels = stripDefault(cm);
+ for (String label : labels) {
+ printGrayRow(cm, out, labels, label);
+ }
+ }
+
+ private static void printGrayRow(ConfusionMatrix cm, PrintStream out, Iterable<String> labels, String rowLabel) {
+ format("<tr class='%s'>", out, CSS_LABEL);
+ format("<td>%s</td>", out, escapeHtml(rowLabel));
+ int total = getLabelTotal(cm, rowLabel);
+ for (String columnLabel : labels) {
+ printGrayCell(cm, out, total, rowLabel, columnLabel);
+ }
+ out.println("</tr>");
+ }
+
+ // assign white/light/medium/dark to 0,1/4,1/2,3/4 of total number of inputs
+ // assign black to count = total, meaning complete success
+ // alternative rating is to use sqrt(total) instead of total - this is more drastic
+ private static void printGrayCell(ConfusionMatrix cm,
+ PrintStream out,
+ int total,
+ String rowLabel,
+ String columnLabel) {
+
+ int count = cm.getCount(rowLabel, columnLabel);
+ if (count == 0) {
+ format("<td class='%s'></td>", out, CSS_EMPTY);
+ } else {
+ // 0 is white, full is black, everything else gray
+ int rating = (int) ((count / (double) total) * 4);
+ String css = CSS_GRAY_CELLS[rating];
+ format("<td class='%s' title='%s'>%s</td>", out, css, escapeHtml(columnLabel), count);
+ }
+ }
+
+ public static void printCountsTable(ConfusionMatrix cm, PrintStream out) {
+ format("<table class='%s'>\n", out, CSS_TABLE);
+ printCountsHeader(cm, out, false);
+ printCountsRows(cm, out);
+ out.println("</table>");
+ }
+
+ private static void printCountsRows(ConfusionMatrix cm, PrintStream out) {
+ List<String> labels = stripDefault(cm);
+ for (String label : labels) {
+ printCountsRow(cm, out, labels, label);
+ }
+ }
+
+ private static void printCountsRow(ConfusionMatrix cm, PrintStream out, Iterable<String> labels, String rowLabel) {
+ out.println("<tr>");
+ format("<td class='%s'>%s</td>", out, CSS_LABEL, escapeHtml(rowLabel));
+ for (String columnLabel : labels) {
+ printCountsCell(cm, out, rowLabel, columnLabel);
+ }
+ out.println("</tr>");
+ }
+
+ private static void printCountsCell(ConfusionMatrix cm, PrintStream out, String rowLabel, String columnLabel) {
+ int count = cm.getCount(rowLabel, columnLabel);
+ String s = count == 0 ? "" : Integer.toString(count);
+ format("<td class='%s' title='%s'>%s</td>", out, CSS_CELL, escapeHtml(columnLabel), s);
+ }
+
+ private static void printCountsHeader(ConfusionMatrix cm, PrintStream out, boolean vertical) {
+ List<String> labels = stripDefault(cm);
+ int longest = getLongestHeader(labels);
+ if (vertical) {
+ // vertical column headers, rotated via the verticalText CSS rule
+ out.format("<tr class='%s' style='height:%dem'><th> </th>\n", CSS_TALL_HEADER, longest / 2);
+ for (String label : labels) {
+ out.format("<th><div class='%s'>%s</div></th>", CSS_VERTICAL, escapeHtml(label));
+ }
+ out.println("</tr>");
+ } else {
+ // header - empty cell in upper left
+ out.format("<tr class='%s'><td class='%s'></td>\n", CSS_TABLE, CSS_LABEL);
+ for (String label : labels) {
+ out.format("<td>%s</td>", escapeHtml(label));
+ }
+ out.println("</tr>");
+ }
+ }
+
+ private static int getLongestHeader(Iterable<String> labels) {
+ int max = 0;
+ for (String label : labels) {
+ max = Math.max(label.length(), max);
+ }
+ return max;
+ }
+
+ private static void format(String format, PrintStream out, Object ... args) {
+ String format2 = String.format(format, args);
+ out.println(format2);
+ }
+
+ /** Escapes HTML metacharacters; labels come from the input file and may contain them. */
+ private static String escapeHtml(CharSequence s) {
+ StringBuilder sb = new StringBuilder(s.length());
+ for (int i = 0; i < s.length(); i++) {
+ char c = s.charAt(i);
+ switch (c) {
+ case '<': sb.append("&lt;"); break;
+ case '>': sb.append("&gt;"); break;
+ case '&': sb.append("&amp;"); break;
+ case '\'': sb.append("&#39;"); break;
+ case '"': sb.append("&quot;"); break;
+ default: sb.append(c);
+ }
+ }
+ return sb.toString();
+ }
+
+ public static void printHeader(PrintStream out, CharSequence title) {
+ out.println(HEADER.replace("TITLE", escapeHtml(title)));
+ }
+
+ public static void printFooter(PrintStream out) {
+ out.println(FOOTER);
+ }
+
+}
Modified: mahout/trunk/src/conf/driver.classes.props
URL:
http://svn.apache.org/viewvc/mahout/trunk/src/conf/driver.classes.props?rev=1197510&r1=1197509&r2=1197510&view=diff
==============================================================================
--- mahout/trunk/src/conf/driver.classes.props (original)
+++ mahout/trunk/src/conf/driver.classes.props Fri Nov 4 11:20:03 2011
@@ -46,4 +46,6 @@ org.apache.mahout.classifier.sequencelea
org.apache.mahout.classifier.sequencelearning.hmm.RandomSequenceGenerator = hmmpredict : Generate random sequence of observations by given HMM
org.apache.mahout.utils.SplitInput = split : Split Input data into test and train sets
org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob = trainnb : Train the Vector-based Bayes classifier
-org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver = testnb : Test the Vector-based Bayes classifier
\ No newline at end of file
+org.apache.mahout.classifier.naivebayes.test.TestNaiveBayesDriver = testnb : Test the Vector-based Bayes classifier
+org.apache.mahout.classifier.ConfusionMatrixDumper = cmdump : Dump confusion matrix in HTML or text formats
+org.apache.mahout.utils.MatrixDumper = matrixdump : Dump matrix in CSV format
\ No newline at end of file