Author: smarthi
Date: Sat Jun 15 04:40:02 2013
New Revision: 1493313

URL: http://svn.apache.org/r1493313
Log:
MAHOUT-1263:Serialise/Deserialise Lambda value for OnlineLogisticRegression

Modified:
    mahout/trunk/CHANGELOG
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java

Modified: mahout/trunk/CHANGELOG
URL: 
http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1493313&r1=1493312&r2=1493313&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Sat Jun 15 04:40:02 2013
@@ -2,6 +2,8 @@ Mahout Change Log
 
 Release 0.8 - unreleased
 
+  MAHOUT-1263: Serialise/Deserialise Lambda value for OnlineLogisticRegression 
(Mike Davy via smarthi)  
+
   MAHOUT-1258: Another shot at findbugs and checkstyle (ssc)
 
   MAHOUT-1253: Add experiment tools for StreamingKMeans, part 1 (dfilimon)

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java?rev=1493313&r1=1493312&r2=1493313&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
 Sat Jun 15 04:40:02 2013
@@ -135,6 +135,7 @@ public class OnlineLogisticRegression ex
   public void write(DataOutput out) throws IOException {
     out.writeInt(WRITABLE_VERSION);
     out.writeDouble(mu0);
+    out.writeDouble(getLambda()); 
     out.writeDouble(decayFactor);
     out.writeInt(stepOffset);
     out.writeInt(step);
@@ -152,6 +153,7 @@ public class OnlineLogisticRegression ex
     int version = in.readInt();
     if (version == WRITABLE_VERSION) {
       mu0 = in.readDouble();
+      lambda(in.readDouble()); 
       decayFactor = in.readDouble();
       stepOffset = in.readInt();
       step = in.readInt();

Modified: 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java?rev=1493313&r1=1493312&r2=1493313&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
 (original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
 Sat Jun 15 04:40:02 2013
@@ -27,15 +27,22 @@ import org.apache.mahout.math.DenseVecto
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.junit.Assert;
 import org.junit.Test;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
 import java.io.IOException;
+import java.lang.reflect.Field;
 import java.util.Collections;
 import java.util.List;
 import java.util.Random;
 
+
 public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
 
   private static final Logger logger = 
LoggerFactory.getLogger(OnlineLogisticRegressionTest.class);
@@ -184,7 +191,7 @@ public final class OnlineLogisticRegress
     // for permuting data later
     List<Integer> order = Lists.newArrayList();
 
-    for (String line : raw.subList(1,raw.size())) {
+    for (String line : raw.subList(1, raw.size())) {
       // order gets a list of indexes
       order.add(order.size());
 
@@ -262,4 +269,58 @@ public final class OnlineLogisticRegress
     test(getInput(), target, lr, 0.05, 0.3);
   }
 
+  /**
+   * Round-trips a configured model through PolymorphicWritable and checks the restored hyper-parameters.
+   *
+   * @throws Exception if serialization or the reflective field access fails
+   */
+  @Test
+  public void testSerializationAndDeSerialization() throws Exception {
+    OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
+      .lambda(1 * 1.0e-3)
+      .stepOffset(11)
+      .alpha(0.01)
+      .learningRate(50)
+      .decayExponent(-0.02);
+
+    lr.close();
+    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+    DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream);
+    PolymorphicWritable.write(dataOutputStream, lr);
+    byte[] output = byteArrayOutputStream.toByteArray();
+    byteArrayOutputStream.close();
+
+    ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(output);
+    DataInputStream dataInputStream = new DataInputStream(byteArrayInputStream);
+    OnlineLogisticRegression read = PolymorphicWritable.read(dataInputStream, OnlineLogisticRegression.class);
+    read.close();
+
+    // lambda is reachable through its public getter
+    Assert.assertEquals((1.0e-3), read.getLambda(), 1.0e-7);
+
+    // Remaining settings are private: inspect the deserialized copy ("read"), not the original ("lr")
+    // stepOffset
+    Field stepOffset = OnlineLogisticRegression.class.getDeclaredField("stepOffset");
+    stepOffset.setAccessible(true);
+    int stepOffsetVal = (Integer) stepOffset.get(read);
+    Assert.assertEquals(11, stepOffsetVal);
+
+    // decayFactor (configured above via alpha)
+    Field decayFactor = OnlineLogisticRegression.class.getDeclaredField("decayFactor");
+    decayFactor.setAccessible(true);
+    double decayFactorVal = (Double) decayFactor.get(read);
+    Assert.assertEquals(0.01, decayFactorVal, 1.0e-7);
+
+    // mu0 (configured above via learningRate)
+    Field mu0 = OnlineLogisticRegression.class.getDeclaredField("mu0");
+    mu0.setAccessible(true);
+    double mu0Val = (Double) mu0.get(read);
+    Assert.assertEquals(50, mu0Val, 1.0e-7);
+
+    // forgettingExponent (configured above via decayExponent)
+    Field forgettingExponent = OnlineLogisticRegression.class.getDeclaredField("forgettingExponent");
+    forgettingExponent.setAccessible(true);
+    double forgettingExponentVal = (Double) forgettingExponent.get(read);
+    Assert.assertEquals(-0.02, forgettingExponentVal, 1.0e-7);
+  }
 }
\ No newline at end of file


Reply via email to