Author: tdunning
Date: Wed Sep 1 00:23:41 2010
New Revision: 991407
URL: http://svn.apache.org/viewvc?rev=991407&view=rev
Log:
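MAHOUT-494 - Added setters/getters/constructors to allow GSON serialization.
Removed the transient modifier from some fields (updateSteps, updateCounts and
prior in AbstractOnlineLogisticRegression) so that they are serialized, and
added it to others (the thread pool in EvolutionaryProcess and the Random
instances in State and OnlineAuc) so that they are not.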
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
mahout/trunk/core/src/main/java/org/apache/mahout/ep/Mapping.java
mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java
mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
Wed Sep 1 00:23:41 2010
@@ -50,15 +50,15 @@ public abstract class AbstractOnlineLogi
private int step = 0;
// information about how long since coefficient rows were updated. This
allows lazy regularization.
- protected transient Vector updateSteps;
+ protected Vector updateSteps;
// information about how many updates we have had on a location. This
allows per-term
// annealing a la confidence weighted learning.
- protected transient Vector updateCounts;
+ protected Vector updateCounts;
// weight of the prior on beta
private double lambda = 1e-5;
- protected transient PriorFunction prior;
+ protected PriorFunction prior;
// can we ignore any further regularization when doing classification?
private boolean sealed = false;
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
Wed Sep 1 00:23:41 2010
@@ -54,13 +54,18 @@ public class AdaptiveLogisticRegression
private int record = 0;
private int evaluationInterval = 1000;
- private final List<TrainingExample> buffer = Lists.newArrayList();
+ private List<TrainingExample> buffer = Lists.newArrayList();
private EvolutionaryProcess<Wrapper> ep;
private State<Wrapper> best;
private int threadCount = 20;
private int poolSize = 20;
- private final State<Wrapper> seed;
- private final int numFeatures;
+ private State<Wrapper> seed;
+ private int numFeatures;
+
+ // for GSON
+ @SuppressWarnings({"UnusedDeclaration"})
+ private AdaptiveLogisticRegression() {
+ }
public AdaptiveLogisticRegression(int numCategories, int numFeatures,
PriorFunction prior) {
this.numFeatures = numFeatures;
@@ -178,6 +183,65 @@ public class AdaptiveLogisticRegression
return best;
}
+ public void setBest(State<Wrapper> best) {
+ this.best = best;
+ }
+
+ public int getRecord() {
+ return record;
+ }
+
+ public void setRecord(int record) {
+ this.record = record;
+ }
+
+ public int getEvaluationInterval() {
+ return evaluationInterval;
+ }
+
+ public int getNumCategories() {
+ return seed.getPayload().getLearner().numCategories();
+ }
+
+ public PriorFunction getPrior() {
+ return seed.getPayload().getLearner().getPrior();
+ }
+
+ public void setEvaluationInterval(int evaluationInterval) {
+ this.evaluationInterval = evaluationInterval;
+ }
+
+ public void setBuffer(List<TrainingExample> buffer) {
+ this.buffer = buffer;
+ }
+
+ public List<TrainingExample> getBuffer() {
+ return buffer;
+ }
+
+ public EvolutionaryProcess<Wrapper> getEp() {
+ return ep;
+ }
+
+ public void setEp(EvolutionaryProcess<Wrapper> ep) {
+ this.ep = ep;
+ }
+
+ public State<Wrapper> getSeed() {
+ return seed;
+ }
+
+ public void setSeed(State<Wrapper> seed) {
+ this.seed = seed;
+ }
+
+ public int getNumFeatures() {
+ return numFeatures;
+ }
+
+ public void setNumFeatures(int numFeatures) {
+ this.numFeatures = numFeatures;
+ }
/**
* Provides a shim between the EP optimization stuff and the
CrossFoldLearner. The most important
@@ -190,9 +254,6 @@ public class AdaptiveLogisticRegression
* offset is done.
*/
public static class Wrapper implements Payload<Wrapper> {
- //private static volatile int counter = 0;
-
- //private volatile int id = counter++;
private CrossFoldLearner wrapped;
private Wrapper() {
@@ -214,7 +275,7 @@ public class AdaptiveLogisticRegression
public void update(double[] params) {
int i = 0;
wrapped.lambda(params[i++]);
- wrapped.learningRate(params[i++]);
+ wrapped.learningRate(params[i]);
wrapped.stepOffset(1);
wrapped.alpha(1);
@@ -224,7 +285,7 @@ public class AdaptiveLogisticRegression
public void setMappings(State<Wrapper> x) {
int i = 0;
x.setMap(i++, Mapping.logLimit(1.0e-8, 0.1));
- x.setMap(i++, Mapping.softLimit(0.001, 10));
+ x.setMap(i, Mapping.softLimit(0.001, 10));
}
public void train(TrainingExample example) {
@@ -242,9 +303,14 @@ public class AdaptiveLogisticRegression
}
public static class TrainingExample {
- private final long key;
- private final int actual;
- private final Vector instance;
+ private long key;
+ private int actual;
+ private Vector instance;
+
+ // for GSON
+ @SuppressWarnings({"UnusedDeclaration"})
+ private TrainingExample() {
+ }
public TrainingExample(long key, int actual, Vector instance) {
this.key = key;
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
Wed Sep 1 00:23:41 2010
@@ -19,9 +19,6 @@ import java.util.List;
* record should be passed with each training example.
*/
public class CrossFoldLearner extends AbstractVectorClassifier implements
OnlineLearner {
- private static volatile int nextId = 0;
-
- private final int id = nextId++;
private int record = 0;
private OnlineAuc auc = new OnlineAuc();
private double logLikelihood = 0;
@@ -32,7 +29,12 @@ public class CrossFoldLearner extends Ab
private int numFeatures;
private PriorFunction prior;
- CrossFoldLearner(int folds, int numCategories, int numFeatures,
PriorFunction prior) {
+ // pretty much just for GSON
+ @SuppressWarnings({"UnusedDeclaration"})
+ public CrossFoldLearner() {
+ }
+
+ public CrossFoldLearner(int folds, int numCategories, int numFeatures,
PriorFunction prior) {
this.numFeatures = numFeatures;
this.prior = prior;
for (int i = 0; i < folds; i++) {
@@ -172,4 +174,60 @@ public class CrossFoldLearner extends Ab
}
return r;
}
+
+ public int getRecord() {
+ return record;
+ }
+
+ public void setRecord(int record) {
+ this.record = record;
+ }
+
+ public OnlineAuc getAuc() {
+ return auc;
+ }
+
+ public void setAuc(OnlineAuc auc) {
+ this.auc = auc;
+ }
+
+ public double getLogLikelihood() {
+ return logLikelihood;
+ }
+
+ public void setLogLikelihood(double logLikelihood) {
+ this.logLikelihood = logLikelihood;
+ }
+
+ public List<OnlineLogisticRegression> getModels() {
+ return models;
+ }
+
+ public void addModel(OnlineLogisticRegression model) {
+ models.add(model);
+ }
+
+ public double[] getParameters() {
+ return parameters;
+ }
+
+ public void setParameters(double[] parameters) {
+ this.parameters = parameters;
+ }
+
+ public int getNumFeatures() {
+ return numFeatures;
+ }
+
+ public void setNumFeatures(int numFeatures) {
+ this.numFeatures = numFeatures;
+ }
+
+ public PriorFunction getPrior() {
+ return prior;
+ }
+
+ public void setPrior(PriorFunction prior) {
+ this.prior = prior;
+ }
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
Wed Sep 1 00:23:41 2010
@@ -44,15 +44,23 @@ import java.util.concurrent.Future;
*/
public class EvolutionaryProcess<T extends Payload<T>> {
// used to execute operations on the population in thread parallel.
- private ExecutorService pool;
+ private transient ExecutorService pool;
+
+ // threadCount is serialized so that we can reconstruct the thread pool
+ private int threadCount;
// list of members of the population
private List<State<T>> population;
// how big should the population be. If this is changed, it will take effect
// the next time the population is mutated.
+
private int populationSize;
+ public EvolutionaryProcess() {
+ population = Lists.newArrayList();
+ }
+
/**
* Creates an evolutionary optimization framework with specified threadiness,
* population size and initial state.
@@ -62,13 +70,21 @@ public class EvolutionaryProcess<T exten
*/
public EvolutionaryProcess(int threadCount, int populationSize, State<T>
seed) {
this.populationSize = populationSize;
- pool = Executors.newFixedThreadPool(threadCount);
+ setThreadCount(threadCount);
+ initializePopulation(populationSize, seed);
+ }
+
+ private void initializePopulation(int populationSize, State<T> seed) {
population = Lists.newArrayList(seed);
for (int i = 0; i < populationSize; i++) {
population.add(seed.mutate());
}
}
+ public void add(State<T> value) {
+ population.add(value);
+ }
+
/**
* Nuke all but a few of the current population and then repopulate with
* variants of the survivors.
@@ -132,6 +148,23 @@ public class EvolutionaryProcess<T exten
return best;
}
+ public void setThreadCount(int threadCount) {
+ this.threadCount = threadCount;
+ pool = Executors.newFixedThreadPool(threadCount);
+ }
+
+ public int getThreadCount() {
+ return threadCount;
+ }
+
+ public int getPopulationSize() {
+ return populationSize;
+ }
+
+ public List<State<T>> getPopulation() {
+ return population;
+ }
+
public interface Function<T> {
double apply(T payload, double[] params);
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/ep/Mapping.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/ep/Mapping.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/ep/Mapping.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/ep/Mapping.java Wed Sep
1 00:23:41 2010
@@ -7,6 +7,71 @@ import org.apache.mahout.math.function.U
* reals but have the output limited and squished in convenient (and safe)
ways.
*/
public abstract class Mapping implements UnaryFunction {
+ public static class SoftLimit extends Mapping {
+ private double min;
+ private double max;
+ private double scale;
+
+ @SuppressWarnings({"UnusedDeclaration"})
+ private SoftLimit() {
+ }
+
+ private SoftLimit(double min, double max, double scale) {
+ this.min = min;
+ this.max = max;
+ this.scale = scale;
+ }
+
+ public double apply(double v) {
+ return min + (max - min) * 1 / (1 + Math.exp(-v * scale));
+ }
+
+ }
+
+ public static class LogLimit extends Mapping {
+ private Mapping wrapped;
+
+ @SuppressWarnings({"UnusedDeclaration"})
+ private LogLimit() {
+ }
+
+ private LogLimit(double low, double high) {
+ wrapped = softLimit(Math.log(low), Math.log(high));
+ }
+
+ @Override
+ public double apply(double v) {
+ return Math.exp(wrapped.apply(v));
+ }
+ }
+
+ public static class Exponential extends Mapping {
+ private double scale;
+
+ @SuppressWarnings({"UnusedDeclaration"})
+ private Exponential() {
+ }
+
+ private Exponential(double scale) {
+ this.scale = scale;
+ }
+
+ @Override
+ public double apply(double v) {
+ return Math.exp(v * scale);
+ }
+ }
+
+ public static class Identity extends Mapping {
+ private Identity() {
+ }
+
+ @Override
+ public double apply(double v) {
+ return v;
+ }
+ }
+
/**
* Maps input to the open interval (min, max) with 0 going to the mean of
min and
* max. When scale is large, a larger proportion of values are mapped to
points
@@ -18,12 +83,7 @@ public abstract class Mapping implements
* @return A mapping that satisfies the desired constraint.
*/
public static Mapping softLimit(final double min, final double max, final
double scale) {
- return new Mapping() {
- @Override
- public double apply(double v) {
- return min + (max - min) * 1 / (1 + Math.exp(-v * scale));
- }
- };
+ return new SoftLimit(min, max, scale);
}
/**
@@ -54,14 +114,7 @@ public abstract class Mapping implements
if (high <= 0) {
throw new IllegalArgumentException("Upper bound for log limit must be >
0 but was " + high);
}
- return new Mapping() {
- Mapping wrapped = softLimit(Math.log(low), Math.log(high));
-
- @Override
- public double apply(double v) {
- return Math.exp(wrapped.apply(v));
- }
- };
+ return new LogLimit(low, high);
}
/**
@@ -78,12 +131,7 @@ public abstract class Mapping implements
* @return A positive value.
*/
public static Mapping exponential(final double scale) {
- return new Mapping() {
- @Override
- public double apply(double v) {
- return Math.exp(v * scale);
- }
- };
+ return new Exponential(scale);
}
/**
@@ -91,11 +139,6 @@ public abstract class Mapping implements
* @return The original value.
*/
public static Mapping identity() {
- return new Mapping() {
- @Override
- public double apply(double v) {
- return v;
- }
- };
+ return new Identity();
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java Wed Sep 1
00:23:41 2010
@@ -1,5 +1,7 @@
package org.apache.mahout.ep;
+import com.google.common.collect.Lists;
+
import java.util.Arrays;
import java.util.Locale;
import java.util.Random;
@@ -27,9 +29,9 @@ public class State<T extends Payload<T>>
// object count is kept to break ties in comparison.
static volatile int objectCount = 0;
- int id = objectCount++;
+ private int id = objectCount++;
- private Random gen = new Random();
+ private transient Random gen = new Random();
// current state
private double[] params;
@@ -48,7 +50,7 @@ public class State<T extends Payload<T>>
private T payload;
- private State() {
+ public State() {
}
/**
@@ -124,6 +126,18 @@ public class State<T extends Payload<T>>
return m == null ? params[i] : m.apply(params[i]);
}
+ public int getId() {
+ return id;
+ }
+
+ public double[] getParams() {
+ return params;
+ }
+
+ public Mapping[] getMaps() {
+ return maps;
+ }
+
/**
* Returns all the parameters in mapped form.
* @return An array of parameters.
@@ -160,6 +174,22 @@ public class State<T extends Payload<T>>
this.omni = omni;
}
+ public void setId(int id) {
+ this.id = id;
+ }
+
+ public void setStep(double[] step) {
+ this.step = step;
+ }
+
+ public void setMaps(Mapping[] maps) {
+ this.maps = maps;
+ }
+
+ public void setMaps(Iterable<Mapping> maps) {
+ this.maps = Lists.newArrayList(maps).toArray(new Mapping[0]);
+ }
+
public void setValue(double v) {
value = v;
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
Wed Sep 1 00:23:41 2010
@@ -41,7 +41,7 @@ public class OnlineAuc {
public static final int HISTORY = 10;
private ReplacementPolicy policy = ReplacementPolicy.FAIR;
- private Random random = org.apache.mahout.common.RandomUtils.getRandom();
+ private transient Random random =
org.apache.mahout.common.RandomUtils.getRandom();
private final Matrix scores;
private final Vector averages;
private final Vector samples;