Author: tdunning
Date: Wed Sep 15 06:18:04 2010
New Revision: 997193
URL: http://svn.apache.org/viewvc?rev=997193&view=rev
Log:
Add model reverse engineering for classifiers that extend
AbstractVectorClassifier.
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
mahout/trunk/core/src/main/java/org/apache/mahout/vectors/Dictionary.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java?rev=997193&r1=997192&r2=997193&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
Wed Sep 15 06:18:04 2010
@@ -31,6 +31,17 @@ public abstract class AbstractVectorClas
public abstract Vector classify(Vector instance);
/**
+ * Classify a vector, but don't apply the inverse link function. For
logistic regression
+ * and other generalized linear models, this is just the linear part of the
classification.
+ * @param features A feature vector to be classified.
+ * @return A vector of scores. If transformed by the link function, these
will become probabilities.
+ */
+ public Vector classifyNoLink(Vector features) {
+ throw new UnsupportedOperationException("Classifier " +
this.getClass().getName() +
+ " doesn't support classification without a link");
+ }
+
+ /**
* Classifies a vector in the special case of a binary classifier where
* <code>classify(Vector)</code> would return a vector with only one
element. As such,
using this method can avoid the allocation of a vector.
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=997193&r1=997192&r2=997193&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
Wed Sep 15 06:18:04 2010
@@ -86,6 +86,12 @@ public abstract class AbstractOnlineLogi
}
}
+ public Vector classifyNoLink(Vector instance) {
+ // apply pending regularization to whichever coefficients matter
+ regularize(instance);
+ return beta.times(instance);
+ }
+
/**
* Returns n-1 probabilities, one for each category but the 0-th. The
probability of the 0-th
* category is 1 - sum(this result).
@@ -94,11 +100,7 @@ public abstract class AbstractOnlineLogi
* @return A vector of probabilities, one for each of the first n-1
categories.
*/
public Vector classify(Vector instance) {
- // apply pending regularization to whichever coefficients matter
- regularize(instance);
-
- Vector v = beta.times(instance);
- return logisticLink(v);
+ return logisticLink(classifyNoLink(instance));
}
/**
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=997193&r1=997192&r2=997193&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
Wed Sep 15 06:18:04 2010
@@ -5,6 +5,7 @@ import org.apache.mahout.classifier.Abst
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.BinaryFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.stats.OnlineAuc;
@@ -131,9 +132,18 @@ public class CrossFoldLearner extends Ab
@Override
public Vector classify(Vector instance) {
Vector r = new DenseVector(numCategories() - 1);
- double scale = 1.0 / models.size();
+ BinaryFunction scale = Functions.plusMult(1.0 / models.size());
for (OnlineLogisticRegression model : models) {
- r.assign(model.classify(instance), Functions.plusMult(scale));
+ r.assign(model.classify(instance), scale);
+ }
+ return r;
+ }
+
+ public Vector classifyNoLink(Vector instance) {
+ Vector r = new DenseVector(numCategories() - 1);
+ BinaryFunction scale = Functions.plusMult(1.0 / models.size());
+ for (OnlineLogisticRegression model : models) {
+ r.assign(model.classifyNoLink(instance), scale);
}
return r;
}
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java?rev=997193&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
Wed Sep 15 06:18:04 2010
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Maps;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.QRDecomposition;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectors.Dictionary;
+
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Uses sample data to reverse engineer a feature-hashed model.
+ *
+ * The result gives approximate weights for features and interactions
+ * in the original space.
+ */
+public class ModelDissector {
+ int records = 0;
+ private Dictionary dict;
+ private Matrix a;
+ private Matrix b;
+
+ public ModelDissector(int n) {
+ a = new SparseRowMatrix(new int[]{Integer.MAX_VALUE, Integer.MAX_VALUE},
true);
+ b = new SparseRowMatrix(new int[]{Integer.MAX_VALUE, n});
+
+ dict.intern("Intercept Value");
+ }
+
+ public void addExample(Set<String> features, Vector score) {
+ for (Vector.Element element : score) {
+ b.set(records, element.index(), element.get());
+ }
+
+ for (String feature : features) {
+ int j = dict.intern(feature);
+ a.set(records, j, 1);
+ }
+ records++;
+ }
+
+ public void addExample(Set<String> features, double score) {
+ b.set(records, 0, score);
+
+ a.set(records, 0, 1);
+ for (String feature : features) {
+ int j = dict.intern(feature);
+ a.set(records, j, 1);
+ }
+ records++;
+ }
+
+ public Matrix solve() {
+ Matrix az = a.viewPart(new int[]{0, 0}, new int[]{records, dict.size()});
+ Matrix bz = b.viewPart(new int[]{0, 0}, new int[]{records,
b.columnSize()});
+ QRDecomposition qr = new QRDecomposition(az.transpose().times(az));
+ Matrix x = qr.solve(bz);
+ Map<String, Integer> labels = Maps.newHashMap();
+ int i = 0;
+ for (String s : dict.values()) {
+ labels.put(s, i++);
+ }
+ x.setRowLabelBindings(labels);
+ return x;
+ }
+}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/vectors/Dictionary.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/vectors/Dictionary.java?rev=997193&r1=997192&r2=997193&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/vectors/Dictionary.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/vectors/Dictionary.java
Wed Sep 15 06:18:04 2010
@@ -41,6 +41,10 @@ public class Dictionary {
return new ArrayList<String>(dict.keySet());
}
+ public int size() {
+ return dict.size();
+ }
+
public static Dictionary fromList(List<String> values) {
Dictionary dict = new Dictionary();
for (String value : values) {