Author: smarthi
Date: Sat Jun 15 04:40:02 2013
New Revision: 1493313

URL: http://svn.apache.org/r1493313
Log:
MAHOUT-1263:Serialise/Deserialise Lambda value for OnlineLogisticRegression

Modified:
    mahout/trunk/CHANGELOG
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java

Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1493313&r1=1493312&r2=1493313&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Sat Jun 15 04:40:02 2013
@@ -2,6 +2,8 @@ Mahout Change Log
 
 Release 0.8 - unreleased
 
+  MAHOUT-1263: Serialise/Deserialise Lambda value for OnlineLogisticRegression (Mike Davy via smarthi)
+
   MAHOUT-1258: Another shot at findbugs and checkstyle (ssc)
 
   MAHOUT-1253: Add experiment tools for StreamingKMeans, part 1 (dfilimon)

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java?rev=1493313&r1=1493312&r2=1493313&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java Sat Jun 15 04:40:02 2013
@@ -135,6 +135,7 @@ public class OnlineLogisticRegression ex
   public void write(DataOutput out) throws IOException {
     out.writeInt(WRITABLE_VERSION);
     out.writeDouble(mu0);
+    out.writeDouble(getLambda());
     out.writeDouble(decayFactor);
     out.writeInt(stepOffset);
     out.writeInt(step);
@@ -152,6 +153,7 @@ public class OnlineLogisticRegression ex
     int version = in.readInt();
     if (version == WRITABLE_VERSION) {
       mu0 = in.readDouble();
+      lambda(in.readDouble());
       decayFactor = in.readDouble();
       stepOffset = in.readInt();
       step = in.readInt();

Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java?rev=1493313&r1=1493312&r2=1493313&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java Sat Jun 15 04:40:02 2013
@@ -27,15 +27,22 @@ import org.apache.mahout.math.DenseVecto
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.junit.Assert;
 import org.junit.Test;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
 import java.io.IOException;
+import java.lang.reflect.Field;
 import java.util.Collections;
 import java.util.List;
 import java.util.Random;
 
+
 public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
 
   private static final Logger logger = LoggerFactory.getLogger(OnlineLogisticRegressionTest.class);
@@ -184,7 +191,7 @@ public final class OnlineLogisticRegress
     // for permuting data later
     List<Integer> order = Lists.newArrayList();
 
-    for (String line : raw.subList(1,raw.size())) {
+    for (String line : raw.subList(1, raw.size())) {
       // order gets a list of indexes
       order.add(order.size());
 
@@ -262,4 +269,58 @@ public final class OnlineLogisticRegress
     test(getInput(), target, lr, 0.05, 0.3);
   }
 
+  /**
+   * Round-trips a configured model through PolymorphicWritable and checks the hyper-parameters
+   * on the deserialized copy, so a field dropped by write()/readFields() makes the test fail.
+   * @throws Exception if (de)serialization or reflective field access fails
+   */
+  @Test
+  public void testSerializationAndDeSerialization() throws Exception {
+    OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
+        .lambda(1.0e-3)
+        .stepOffset(11)
+        .alpha(0.01)
+        .learningRate(50)
+        .decayExponent(-0.02);
+
+    lr.close();
+    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+    DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream);
+    PolymorphicWritable.write(dataOutputStream, lr);
+    byte[] output = byteArrayOutputStream.toByteArray();
+    byteArrayOutputStream.close();
+
+    ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(output);
+    DataInputStream dataInputStream = new DataInputStream(byteArrayInputStream);
+    OnlineLogisticRegression read = PolymorphicWritable.read(dataInputStream, OnlineLogisticRegression.class);
+    read.close();
+
+    //lambda
+    Assert.assertEquals(1.0e-3, read.getLambda(), 1.0e-7);
+
+    // Reflection on the deserialized copy; inspecting lr would only re-check the builder calls
+    //stepOffset
+    Field stepOffset = read.getClass().getDeclaredField("stepOffset");
+    stepOffset.setAccessible(true);
+    int stepOffsetVal = (Integer) stepOffset.get(read);
+    Assert.assertEquals(11, stepOffsetVal);
+
+    //decayFactor (alpha)
+    Field decayFactor = read.getClass().getDeclaredField("decayFactor");
+    decayFactor.setAccessible(true);
+    double decayFactorVal = (Double) decayFactor.get(read);
+    Assert.assertEquals(0.01, decayFactorVal, 1.0e-7);
+
+    //learning rate (mu0)
+    Field mu0 = read.getClass().getDeclaredField("mu0");
+    mu0.setAccessible(true);
+    double mu0Val = (Double) mu0.get(read);
+    Assert.assertEquals(50, mu0Val, 1.0e-7);
+
+    //forgettingExponent (decayExponent)
+    Field forgettingExponent = read.getClass().getDeclaredField("forgettingExponent");
+    forgettingExponent.setAccessible(true);
+    double forgettingExponentVal = (Double) forgettingExponent.get(read);
+    Assert.assertEquals(-0.02, forgettingExponentVal, 1.0e-7);
+  }
 }
\ No newline at end of file