Author: tdunning
Date: Wed Sep  1 00:23:41 2010
New Revision: 991407

URL: http://svn.apache.org/viewvc?rev=991407&view=rev
Log:
MAHOUT-494 - Added setters/getters/constructors to allow GSON serialization.  
Removed and added some transient attributes.

Modified:
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
    mahout/trunk/core/src/main/java/org/apache/mahout/ep/Mapping.java
    mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
 Wed Sep  1 00:23:41 2010
@@ -50,15 +50,15 @@ public abstract class AbstractOnlineLogi
   private int step = 0;
 
   // information about how long since coefficient rows were updated.  This allows lazy regularization.
-  protected transient Vector updateSteps;
+  protected Vector updateSteps;
 
   // information about how many updates we have had on a location.  This allows per-term
   // annealing a la confidence weighted learning.
-  protected transient Vector updateCounts;
+  protected Vector updateCounts;
 
   // weight of the prior on beta
   private double lambda = 1e-5;
-  protected transient PriorFunction prior;
+  protected PriorFunction prior;
 
   // can we ignore any further regularization when doing classification?
   private boolean sealed = false;

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
 Wed Sep  1 00:23:41 2010
@@ -54,13 +54,18 @@ public class AdaptiveLogisticRegression 
   private int record = 0;
   private int evaluationInterval = 1000;
 
-  private final List<TrainingExample> buffer = Lists.newArrayList();
+  private List<TrainingExample> buffer = Lists.newArrayList();
   private EvolutionaryProcess<Wrapper> ep;
   private State<Wrapper> best;
   private int threadCount = 20;
   private int poolSize = 20;
-  private final State<Wrapper> seed;
-  private final int numFeatures;
+  private State<Wrapper> seed;
+  private int numFeatures;
+
+  // for GSON
+  @SuppressWarnings({"UnusedDeclaration"})
+  private AdaptiveLogisticRegression() {
+  }
 
   public AdaptiveLogisticRegression(int numCategories, int numFeatures, 
PriorFunction prior) {
     this.numFeatures = numFeatures;
@@ -178,6 +183,65 @@ public class AdaptiveLogisticRegression 
     return best;
   }
 
+  public void setBest(State<Wrapper> best) {
+    this.best = best;
+  }
+
+  public int getRecord() {
+    return record;
+  }
+
+  public void setRecord(int record) {
+    this.record = record;
+  }
+
+  public int getEvaluationInterval() {
+    return evaluationInterval;
+  }
+
+  public int getNumCategories() {
+    return seed.getPayload().getLearner().numCategories();
+  }
+
+  public PriorFunction getPrior() {
+    return seed.getPayload().getLearner().getPrior();
+  }
+
+  public void setEvaluationInterval(int evaluationInterval) {
+    this.evaluationInterval = evaluationInterval;
+  }
+
+  public void setBuffer(List<TrainingExample> buffer) {
+    this.buffer = buffer;
+  }
+
+  public List<TrainingExample> getBuffer() {
+    return buffer;
+  }
+
+  public EvolutionaryProcess<Wrapper> getEp() {
+    return ep;
+  }
+
+  public void setEp(EvolutionaryProcess<Wrapper> ep) {
+    this.ep = ep;
+  }
+
+  public State<Wrapper> getSeed() {
+    return seed;
+  }
+
+  public void setSeed(State<Wrapper> seed) {
+    this.seed = seed;
+  }
+
+  public int getNumFeatures() {
+    return numFeatures;
+  }
+
+  public void setNumFeatures(int numFeatures) {
+    this.numFeatures = numFeatures;
+  }
 
   /**
    * Provides a shim between the EP optimization stuff and the CrossFoldLearner.  The most important
@@ -190,9 +254,6 @@ public class AdaptiveLogisticRegression 
    * offset is done.
    */
   public static class Wrapper implements Payload<Wrapper> {
-    //private static volatile int counter = 0;
-
-    //private volatile int id = counter++;
     private CrossFoldLearner wrapped;
 
     private Wrapper() {
@@ -214,7 +275,7 @@ public class AdaptiveLogisticRegression 
     public void update(double[] params) {
       int i = 0;
       wrapped.lambda(params[i++]);
-      wrapped.learningRate(params[i++]);
+      wrapped.learningRate(params[i]);
 
       wrapped.stepOffset(1);
       wrapped.alpha(1);
@@ -224,7 +285,7 @@ public class AdaptiveLogisticRegression 
     public void setMappings(State<Wrapper> x) {
       int i = 0;
       x.setMap(i++, Mapping.logLimit(1.0e-8, 0.1));
-      x.setMap(i++, Mapping.softLimit(0.001, 10));
+      x.setMap(i, Mapping.softLimit(0.001, 10));
     }
 
     public void train(TrainingExample example) {
@@ -242,9 +303,14 @@ public class AdaptiveLogisticRegression 
   }
 
   public static class TrainingExample {
-    private final long key;
-    private final int actual;
-    private final Vector instance;
+    private long key;
+    private int actual;
+    private Vector instance;
+
+    // for GSON
+    @SuppressWarnings({"UnusedDeclaration"})
+    private TrainingExample() {
+    }
 
     public TrainingExample(long key, int actual, Vector instance) {
       this.key = key;

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
 Wed Sep  1 00:23:41 2010
@@ -19,9 +19,6 @@ import java.util.List;
  * record should be passed with each training example.
  */
 public class CrossFoldLearner extends AbstractVectorClassifier implements 
OnlineLearner {
-  private static volatile int nextId = 0;
-
-  private final int id = nextId++;
   private int record = 0;
   private OnlineAuc auc = new OnlineAuc();
   private double logLikelihood = 0;
@@ -32,7 +29,12 @@ public class CrossFoldLearner extends Ab
   private int numFeatures;
   private PriorFunction prior;
 
-  CrossFoldLearner(int folds, int numCategories, int numFeatures, 
PriorFunction prior) {
+  // pretty much just for GSON
+  @SuppressWarnings({"UnusedDeclaration"})
+  public CrossFoldLearner() {
+  }
+
+  public CrossFoldLearner(int folds, int numCategories, int numFeatures, 
PriorFunction prior) {
     this.numFeatures = numFeatures;
     this.prior = prior;
     for (int i = 0; i < folds; i++) {
@@ -172,4 +174,60 @@ public class CrossFoldLearner extends Ab
     }
     return r;
   }
+
+  public int getRecord() {
+    return record;
+  }
+
+  public void setRecord(int record) {
+    this.record = record;
+  }
+
+  public OnlineAuc getAuc() {
+    return auc;
+  }
+
+  public void setAuc(OnlineAuc auc) {
+    this.auc = auc;
+  }
+
+  public double getLogLikelihood() {
+    return logLikelihood;
+  }
+
+  public void setLogLikelihood(double logLikelihood) {
+    this.logLikelihood = logLikelihood;
+  }
+
+  public List<OnlineLogisticRegression> getModels() {
+    return models;
+  }
+
+  public void addModel(OnlineLogisticRegression model) {
+    models.add(model);
+  }
+
+  public double[] getParameters() {
+    return parameters;
+  }
+
+  public void setParameters(double[] parameters) {
+    this.parameters = parameters;
+  }
+
+  public int getNumFeatures() {
+    return numFeatures;
+  }
+
+  public void setNumFeatures(int numFeatures) {
+    this.numFeatures = numFeatures;
+  }
+
+  public PriorFunction getPrior() {
+    return prior;
+  }
+
+  public void setPrior(PriorFunction prior) {
+    this.prior = prior;
+  }
 }

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java 
(original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java 
Wed Sep  1 00:23:41 2010
@@ -44,15 +44,23 @@ import java.util.concurrent.Future;
  */
 public class EvolutionaryProcess<T extends Payload<T>> {
   // used to execute operations on the population in thread parallel.
-  private ExecutorService pool;
+  private transient ExecutorService pool;
+
+  // threadCount is serialized so that we can reconstruct the thread pool
+  private int threadCount;
 
   // list of members of the population
   private List<State<T>> population;
 
   // how big should the population be.  If this is changed, it will take effect
   // the next time the population is mutated.
+
   private int populationSize;
 
+  public EvolutionaryProcess() {
+    population = Lists.newArrayList();
+  }
+
   /**
    * Creates an evolutionary optimization framework with specified threadiness,
    * population size and initial state.
@@ -62,13 +70,21 @@ public class EvolutionaryProcess<T exten
    */
   public EvolutionaryProcess(int threadCount, int populationSize, State<T> 
seed) {
     this.populationSize = populationSize;
-    pool = Executors.newFixedThreadPool(threadCount);
+    setThreadCount(threadCount);
+    initializePopulation(populationSize, seed);
+  }
+
+  private void initializePopulation(int populationSize, State<T> seed) {
     population = Lists.newArrayList(seed);
     for (int i = 0; i < populationSize; i++) {
       population.add(seed.mutate());
     }
   }
 
+  public void add(State<T> value) {
+    population.add(value);
+  }
+
   /**
    * Nuke all but a few of the current population and then repopulate with
    * variants of the survivors.
@@ -132,6 +148,23 @@ public class EvolutionaryProcess<T exten
     return best;
   }
 
+  public void setThreadCount(int threadCount) {
+    this.threadCount = threadCount;
+    pool = Executors.newFixedThreadPool(threadCount);
+  }
+
+  public int getThreadCount() {
+    return threadCount;
+  }
+
+  public int getPopulationSize() {
+    return populationSize;
+  }
+
+  public List<State<T>> getPopulation() {
+    return population;
+  }
+
   public interface Function<T> {
     double apply(T payload, double[] params);
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/ep/Mapping.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/ep/Mapping.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/ep/Mapping.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/ep/Mapping.java Wed Sep  
1 00:23:41 2010
@@ -7,6 +7,71 @@ import org.apache.mahout.math.function.U
  * reals but have the output limited and squished in convenient (and safe) ways.
  */
 public abstract class Mapping implements UnaryFunction {
+  public static class SoftLimit extends Mapping {
+    private double min;
+    private double max;
+    private double scale;
+
+    @SuppressWarnings({"UnusedDeclaration"})
+    private SoftLimit() {
+    }
+
+    private SoftLimit(double min, double max, double scale) {
+      this.min = min;
+      this.max = max;
+      this.scale = scale;
+    }
+
+    public double apply(double v) {
+      return min + (max - min) * 1 / (1 + Math.exp(-v * scale));
+    }
+
+  }
+
+  public static class LogLimit extends Mapping {
+    private Mapping wrapped;
+
+    @SuppressWarnings({"UnusedDeclaration"})
+    private LogLimit() {
+    }
+
+    private LogLimit(double low, double high) {
+      wrapped = softLimit(Math.log(low), Math.log(high));
+    }
+
+    @Override
+    public double apply(double v) {
+      return Math.exp(wrapped.apply(v));
+    }
+  }
+
+  public static class Exponential extends Mapping {
+    private double scale;
+
+    @SuppressWarnings({"UnusedDeclaration"})
+    private Exponential() {
+    }
+
+    private Exponential(double scale) {
+      this.scale = scale;
+    }
+
+    @Override
+    public double apply(double v) {
+      return Math.exp(v * scale);
+    }
+  }
+
+  public static class Identity extends Mapping {
+    private Identity() {
+    }
+
+    @Override
+    public double apply(double v) {
+      return v;
+    }
+  }
+
   /**
    * Maps input to the open interval (min, max) with 0 going to the mean of min and
    * max.  When scale is large, a larger proportion of values are mapped to points
@@ -18,12 +83,7 @@ public abstract class Mapping implements
    * @return A mapping that satisfies the desired constraint.
    */
   public static Mapping softLimit(final double min, final double max, final 
double scale) {
-    return new Mapping() {
-      @Override
-      public double apply(double v) {
-        return min + (max - min) * 1 / (1 + Math.exp(-v * scale));
-      }
-    };
+    return new SoftLimit(min, max, scale);
   }
 
   /**
@@ -54,14 +114,7 @@ public abstract class Mapping implements
     if (high <= 0) {
       throw new IllegalArgumentException("Upper bound for log limit must be > 
0 but was " + high);
     }
-    return new Mapping() {
-      Mapping wrapped = softLimit(Math.log(low), Math.log(high));
-
-      @Override
-      public double apply(double v) {
-        return Math.exp(wrapped.apply(v));
-      }
-    };
+    return new LogLimit(low, high);
   }
 
   /**
@@ -78,12 +131,7 @@ public abstract class Mapping implements
    * @return A positive value.
    */
   public static Mapping exponential(final double scale) {
-    return new Mapping() {
-      @Override
-      public double apply(double v) {
-        return Math.exp(v * scale);
-      }
-    };
+    return new Exponential(scale);
   }
 
   /**
@@ -91,11 +139,6 @@ public abstract class Mapping implements
    * @return The original value.
    */
   public static Mapping identity() {
-    return new Mapping() {
-      @Override
-      public double apply(double v) {
-        return v;
-      }
-    };
+    return new Identity();
   }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/ep/State.java Wed Sep  1 
00:23:41 2010
@@ -1,5 +1,7 @@
 package org.apache.mahout.ep;
 
+import com.google.common.collect.Lists;
+
 import java.util.Arrays;
 import java.util.Locale;
 import java.util.Random;
@@ -27,9 +29,9 @@ public class State<T extends Payload<T>>
   // object count is kept to break ties in comparison.
   static volatile int objectCount = 0;
 
-  int id = objectCount++;
+  private int id = objectCount++;
 
-  private Random gen = new Random();
+  private transient Random gen = new Random();
 
   // current state
   private double[] params;
@@ -48,7 +50,7 @@ public class State<T extends Payload<T>>
 
   private T payload;
 
-  private State() {
+  public State() {
   }
 
   /**
@@ -124,6 +126,18 @@ public class State<T extends Payload<T>>
     return m == null ? params[i] : m.apply(params[i]);
   }
 
+  public int getId() {
+    return id;
+  }
+
+  public double[] getParams() {
+    return params;
+  }
+
+  public Mapping[] getMaps() {
+    return maps;
+  }
+
   /**
    * Returns all the parameters in mapped form.
    * @return An array of parameters.
@@ -160,6 +174,22 @@ public class State<T extends Payload<T>>
     this.omni = omni;
   }
 
+  public void setId(int id) {
+    this.id = id;
+  }
+
+  public void setStep(double[] step) {
+    this.step = step;
+  }
+
+  public void setMaps(Mapping[] maps) {
+    this.maps = maps;
+  }
+
+  public void setMaps(Iterable<Mapping> maps) {
+    this.maps = Lists.newArrayList(maps).toArray(new Mapping[0]);
+  }
+
   public void setValue(double v) {
     value = v;
   }

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java?rev=991407&r1=991406&r2=991407&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java 
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java 
Wed Sep  1 00:23:41 2010
@@ -41,7 +41,7 @@ public class OnlineAuc {
   public static final int HISTORY = 10;
 
   private ReplacementPolicy policy = ReplacementPolicy.FAIR;
-  private Random random = org.apache.mahout.common.RandomUtils.getRandom();
+  private transient Random random = 
org.apache.mahout.common.RandomUtils.getRandom();
   private final Matrix scores;
   private final Vector averages;
   private final Vector samples;


Reply via email to