Author: tdunning
Date: Tue Sep  4 02:18:58 2012
New Revision: 1380432

URL: http://svn.apache.org/viewvc?rev=1380432&view=rev
Log:
MAHOUT-1059 - Abstract the idea of a cached length.

Added:
    
mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java
    
mahout/trunk/math/src/test/java/org/apache/mahout/math/FileBasedSparseBinaryMatrixTest.java
Modified:
    mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java
    
mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java

Modified: 
mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java?rev=1380432&r1=1380431&r2=1380432&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java 
(original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java 
Tue Sep  4 02:18:58 2012
@@ -25,7 +25,7 @@ import org.apache.mahout.math.function.F
 import java.util.Iterator;
 
 /** Implementations of generic capabilities like sum of elements and dot 
products */
-public abstract class AbstractVector implements Vector {
+public abstract class AbstractVector implements Vector, LengthCachingVector {
   
   private static final double LOG2 = Math.log(2.0);
 
@@ -186,7 +186,18 @@ public abstract class AbstractVector imp
     }
 
     // TODO: check the numNonDefault elements to further optimize
-    Vector result = like().assign(this);
+    Vector result;
+    if (isDense()) {
+      result = like().assign(this);
+    } else {
+      result = like();
+      Iterator<Element> i = this.iterateNonZero();
+      while (i.hasNext()) {
+        final Element element = i.next();
+        result.setQuick(element.index(), element.get());
+      }
+    }
+
     Iterator<Element> iter = that.iterateNonZero();
     while (iter.hasNext()) {
       Element thatElement = iter.next();
@@ -282,33 +293,44 @@ public abstract class AbstractVector imp
   }
 
   @Override
+  public void setLengthSquared(double d2) {
+    lengthSquared = d2;
+  }
+
+  @Override
   public double getDistanceSquared(Vector v) {
     if (size != v.size()) {
       throw new CardinalityException(size, v.size());
     }
     // if this and v has a cached lengthSquared, dot product is quickest way 
to compute this.
-    if (lengthSquared >= 0 && v instanceof AbstractVector && 
((AbstractVector)v).lengthSquared >= 0) {
+    if (lengthSquared >= 0 && v instanceof LengthCachingVector && 
v.getLengthSquared() >= 0) {
       return lengthSquared + v.getLengthSquared() - 2 * this.dot(v);
     }
+    Vector sparseAccessed;
     Vector randomlyAccessed;
-    Iterator<Element> it;
-    double d = 0.0;
     if (lengthSquared >= 0.0) {
-      it = v.iterateNonZero();
       randomlyAccessed = this;
-      d += lengthSquared;
+      sparseAccessed = v;
     } else { // TODO: could be further optimized, figure out which one is 
smaller, etc
-      it = iterateNonZero();
       randomlyAccessed = v;
-      d += v.getLengthSquared();
+      sparseAccessed = this;
     }
+
+    Iterator<Element> it = sparseAccessed.iterateNonZero();
+    double d = randomlyAccessed.getLengthSquared();
+    double d2 = 0;
+    double dot = 0;
     while (it.hasNext()) {
       Element e = it.next();
       double value = e.get();
-      d += value * (value - 2.0 * randomlyAccessed.getQuick(e.index()));
+      d2 += value * value;
+      dot += value * randomlyAccessed.getQuick(e.index());
+    }
+    if (sparseAccessed instanceof LengthCachingVector) {
+      ((LengthCachingVector) sparseAccessed).setLengthSquared(d2);
     }
     //assert d > -1.0e-9; // round-off errors should never be too far off!
-    return Math.abs(d);
+    return Math.abs(d + d2 - 2 * dot);
   }
 
   @Override

Modified: 
mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java?rev=1380432&r1=1380431&r2=1380432&view=diff
==============================================================================
--- 
mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java 
(original)
+++ 
mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java 
Tue Sep  4 02:18:58 2012
@@ -25,8 +25,11 @@ import java.util.Iterator;
 /**
  * A delegating vector provides an easy way to decorate vectors with weights 
or id's and such while
  * keeping all of the Vector functionality.
+ *
+ * This vector implements LengthCachingVector because almost all delegates 
cache the length and
+ * the cost of false positives is very low.
  */
-public class DelegatingVector implements Vector {
+public class DelegatingVector implements Vector, LengthCachingVector {
   protected Vector delegate;
 
   public DelegatingVector(int size) {
@@ -123,6 +126,17 @@ public class DelegatingVector implements
     return delegate.getLengthSquared();
   }
 
+  // Not normally called: the delegate is the vector whose cache needs updating, and
+  // it calls its own version of this method.  If the delegate is itself a delegating
+  // vector, the same logic applies recursively down to the first non-delegating
+  // vector.  This makes the method very hard to test except in trivial ways.
+  @Override
+  public void setLengthSquared(double d2) {
+    if (delegate instanceof LengthCachingVector) {
+      ((LengthCachingVector) delegate).setLengthSquared(d2);
+    }
+  }
+
   @Override
   public double getDistanceSquared(Vector v) {
     return delegate.getDistanceSquared(v);

Added: 
mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java?rev=1380432&view=auto
==============================================================================
--- 
mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java 
(added)
+++ 
mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java 
Tue Sep  4 02:18:58 2012
@@ -0,0 +1,15 @@
+package org.apache.mahout.math;
+
+/**
+ * Interface for vectors that may cache their squared length.
+ */
+interface LengthCachingVector {
+  public double getLengthSquared();
+
+  /**
+   * This is a very dangerous method to call.  Passing in a wrong value can
+   * completely screw up distance computations and normalization.
+   * @param d2  The new value for the squared length cache.
+   */
+  public void setLengthSquared(double d2);
+}

Modified: 
mahout/trunk/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java?rev=1380432&r1=1380431&r2=1380432&view=diff
==============================================================================
--- 
mahout/trunk/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java 
(original)
+++ 
mahout/trunk/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java 
Tue Sep  4 02:18:58 2012
@@ -29,6 +29,12 @@ public class MatrixVectorView extends Ab
   private int column;
   private int rowStride;
   private int columnStride;
+  private boolean isDense = true;
+
+  public MatrixVectorView(Matrix matrix, int row, int column, int rowStride, 
int columnStride, boolean isDense) {
+    this(matrix, row, column, rowStride, columnStride);
+    this.isDense = isDense;
+  }
 
   public MatrixVectorView(Matrix matrix, int row, int column, int rowStride, 
int columnStride) {
     super(viewSize(matrix, row, column, rowStride, columnStride));
@@ -64,7 +70,7 @@ public class MatrixVectorView extends Ab
    */
   @Override
   public boolean isDense() {
-    return true;
+    return isDense;
   }
 
   /**

Modified: 
mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java?rev=1380432&r1=1380431&r2=1380432&view=diff
==============================================================================
--- 
mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java 
(original)
+++ 
mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java 
Tue Sep  4 02:18:58 2012
@@ -1,6 +1,7 @@
 package org.apache.mahout.math;
 
 import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.function.Functions;
 import org.apache.mahout.math.jet.random.Normal;
 import org.junit.Test;
 
@@ -15,6 +16,9 @@ import java.util.Random;
  * confidence that it is working correctly.
  */
 public abstract class AbstractVectorTest<T extends Vector> extends 
MahoutTestCase {
+
+  private static final double FUZZ = 1e-13;
+
   public abstract T vectorToTest(int size);
 
   @Test
@@ -35,58 +39,99 @@ public abstract class AbstractVectorTest
     Vector sv1 = new RandomAccessSparseVector(v1);
     Vector sv2 = new RandomAccessSparseVector(v2);
 
-    assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(v2)), 1e-13);
-    assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(dv2)), 1e-13);
-    assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(sv2)), 1e-13);
-    assertEquals(0, dv1.plus(dv2).getDistanceSquared(sv1.plus(v2)), 1e-13);
-
-    assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(v2)), 1e-13);
-    assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(dv2)), 1e-13);
-    assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(sv2)), 1e-13);
-    assertEquals(0, dv1.minus(dv2).getDistanceSquared(sv1.minus(v2)), 1e-13);
+    assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(v2)), FUZZ);
+    assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(dv2)), FUZZ);
+    assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(sv2)), FUZZ);
+    assertEquals(0, dv1.plus(dv2).getDistanceSquared(sv1.plus(v2)), FUZZ);
+
+    assertEquals(0, dv1.times(dv2).getDistanceSquared(v1.times(v2)), FUZZ);
+    assertEquals(0, dv1.times(dv2).getDistanceSquared(v1.times(dv2)), FUZZ);
+    assertEquals(0, dv1.times(dv2).getDistanceSquared(v1.times(sv2)), FUZZ);
+    assertEquals(0, dv1.times(dv2).getDistanceSquared(sv1.times(v2)), FUZZ);
+
+    assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(v2)), FUZZ);
+    assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(dv2)), FUZZ);
+    assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(sv2)), FUZZ);
+    assertEquals(0, dv1.minus(dv2).getDistanceSquared(sv1.minus(v2)), FUZZ);
 
     double z = gen.nextDouble();
     assertEquals(0, dv1.divide(z).getDistanceSquared(v1.divide(z)), 1e-12);
     assertEquals(0, dv1.times(z).getDistanceSquared(v1.times(z)), 1e-12);
     assertEquals(0, dv1.plus(z).getDistanceSquared(v1.plus(z)), 1e-12);
 
-    assertEquals(dv1.dot(dv2), v1.dot(v2), 1e-13);
-    assertEquals(dv1.dot(dv2), v1.dot(dv2), 1e-13);
-    assertEquals(dv1.dot(dv2), v1.dot(sv2), 1e-13);
-    assertEquals(dv1.dot(dv2), sv1.dot(v2), 1e-13);
-    assertEquals(dv1.dot(dv2), dv1.dot(v2), 1e-13);
-
-    assertEquals(dv1.getLengthSquared(), v1.getLengthSquared(), 1e-13);
-    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(v2), 
1e-13);
-    assertEquals(dv1.getDistanceSquared(dv2), dv1.getDistanceSquared(v2), 
1e-13);
-    assertEquals(dv1.getDistanceSquared(dv2), sv1.getDistanceSquared(v2), 
1e-13);
-    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(dv2), 
1e-13);
-    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(sv2), 
1e-13);
+    assertEquals(dv1.dot(dv2), v1.dot(v2), FUZZ);
+    assertEquals(dv1.dot(dv2), v1.dot(dv2), FUZZ);
+    assertEquals(dv1.dot(dv2), v1.dot(sv2), FUZZ);
+    assertEquals(dv1.dot(dv2), sv1.dot(v2), FUZZ);
+    assertEquals(dv1.dot(dv2), dv1.dot(v2), FUZZ);
+
+    // first attempt has no cached distances
+    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(v2), FUZZ);
+    assertEquals(dv1.getDistanceSquared(dv2), dv1.getDistanceSquared(v2), 
FUZZ);
+    assertEquals(dv1.getDistanceSquared(dv2), sv1.getDistanceSquared(v2), 
FUZZ);
+    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(dv2), 
FUZZ);
+    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(sv2), 
FUZZ);
+
+    // now repeat with cached sizes
+    assertEquals(dv1.getLengthSquared(), v1.getLengthSquared(), FUZZ);
+    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(v2), FUZZ);
+    assertEquals(dv1.getDistanceSquared(dv2), dv1.getDistanceSquared(v2), 
FUZZ);
+    assertEquals(dv1.getDistanceSquared(dv2), sv1.getDistanceSquared(v2), 
FUZZ);
+    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(dv2), 
FUZZ);
+    assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(sv2), 
FUZZ);
 
-    assertEquals(dv1.minValue(), v1.minValue(), 1e-13);
+    assertEquals(dv1.minValue(), v1.minValue(), FUZZ);
     assertEquals(dv1.minValueIndex(), v1.minValueIndex());
 
-    assertEquals(dv1.maxValue(), v1.maxValue(), 1e-13);
+    assertEquals(dv1.maxValue(), v1.maxValue(), FUZZ);
     assertEquals(dv1.maxValueIndex(), v1.maxValueIndex());
 
     Vector nv1 = v1.normalize();
 
-    assertEquals(0, dv1.getDistanceSquared(v1), 1e-13);
-    assertEquals(1, nv1.norm(2), 1e-13);
-    assertEquals(0, dv1.normalize().getDistanceSquared(nv1), 1e-13);
+    assertEquals(0, dv1.getDistanceSquared(v1), FUZZ);
+    assertEquals(1, nv1.norm(2), FUZZ);
+    assertEquals(0, dv1.normalize().getDistanceSquared(nv1), FUZZ);
 
     nv1 = v1.normalize(1);
-    assertEquals(0, dv1.getDistanceSquared(v1), 1e-13);
-    assertEquals(1, nv1.norm(1), 1e-13);
-    assertEquals(0, dv1.normalize(1).getDistanceSquared(nv1), 1e-13);
-
-    assertEquals(dv1.norm(0), v1.norm(0), 1e-13);
-    assertEquals(dv1.norm(1), v1.norm(1), 1e-13);
-    assertEquals(dv1.norm(1.5), v1.norm(1.5), 1e-13);
-    assertEquals(dv1.norm(2), v1.norm(2), 1e-13);
-
-    // assign double, function, vector x function
-
+    assertEquals(0, dv1.getDistanceSquared(v1), FUZZ);
+    assertEquals(1, nv1.norm(1), FUZZ);
+    assertEquals(0, dv1.normalize(1).getDistanceSquared(nv1), FUZZ);
+
+    assertEquals(dv1.norm(0), v1.norm(0), FUZZ);
+    assertEquals(dv1.norm(1), v1.norm(1), FUZZ);
+    assertEquals(dv1.norm(1.5), v1.norm(1.5), FUZZ);
+    assertEquals(dv1.norm(2), v1.norm(2), FUZZ);
+
+    assertEquals(dv1.zSum(), v1.zSum(), FUZZ);
+
+    assertEquals(3.1 * v1.size(), v1.assign(3.1).zSum(), FUZZ);
+    assertEquals(0, v1.plus(-3.1).norm(1), FUZZ);
+    v1.assign(dv1);
+    assertEquals(0, v1.getDistanceSquared(dv1), FUZZ);
+
+    assertEquals(dv1.zSum() - dv1.size() * 3.4, 
v1.assign(Functions.minus(3.4)).zSum(), FUZZ);
+    assertEquals(dv1.zSum() - dv1.size() * 4.5, v1.assign(Functions.MINUS, 
1.1).zSum(), FUZZ);
+    v1.assign(dv1);
+
+    assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.assign(v2, 
Functions.MINUS)), FUZZ);
+    v1.assign(dv1);
+
+    assertEquals(dv1.norm(2), Math.sqrt(v1.aggregate(Functions.PLUS, 
Functions.pow(2))), FUZZ);
+    assertEquals(dv1.dot(dv2), v1.aggregate(v2, Functions.PLUS, 
Functions.MULT), FUZZ);
+
+    assertEquals(dv1.viewPart(5, 10).zSum(), v1.viewPart(5, 10).zSum(), FUZZ);
+
+    Vector v3 = v1.clone();
+    assertEquals(0, v1.getDistanceSquared(v3), FUZZ);
+    assertFalse(v1 == v3);
+    v3.assign(0);
+    assertEquals(0, dv1.getDistanceSquared(v1), FUZZ);
+    assertEquals(0, v3.getLengthSquared(), FUZZ);
+
+    dv1.assign(Functions.ABS);
+    v1.assign(Functions.ABS);
+    assertEquals(0, dv1.logNormalize().getDistanceSquared(v1.logNormalize()), 
FUZZ);
+    assertEquals(0, 
dv1.logNormalize(1.5).getDistanceSquared(v1.logNormalize(1.5)), FUZZ);
 
     // aggregate
 
@@ -99,5 +144,7 @@ public abstract class AbstractVectorTest
       assertEquals(dv1.get(element.index()), v1.get(element.index()), 0);
       assertEquals(dv1.get(element.index()), v1.getQuick(element.index()), 0);
     }
+
+
   }
 }

Added: 
mahout/trunk/math/src/test/java/org/apache/mahout/math/FileBasedSparseBinaryMatrixTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/FileBasedSparseBinaryMatrixTest.java?rev=1380432&view=auto
==============================================================================
--- 
mahout/trunk/math/src/test/java/org/apache/mahout/math/FileBasedSparseBinaryMatrixTest.java
 (added)
+++ 
mahout/trunk/math/src/test/java/org/apache/mahout/math/FileBasedSparseBinaryMatrixTest.java
 Tue Sep  4 02:18:58 2012
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.random.MultiNormal;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Random;
+
+import static junit.framework.Assert.assertEquals;
+import static org.junit.Assume.assumeNotNull;
+
+public class FileBasedSparseBinaryMatrixTest {
+  // 10 million rows x 40 columns x 8 bytes = 3.2GB of data (note: COLUMNS below is 1000, not 40)
+  // we need >2GB to stress the file based matrix implementation
+  private static final int ROWS = 10 * 1000 * 1000;
+  private static final int COLUMNS = 1000;
+
+  @Test
+  public void testBigMatrix() throws IOException {
+    // intended to run only with -DrunSlowTests (guard below is commented out); needs 4GB or more of heap.
+//    assumeNotNull(System.getProperty("runSlowTests"));
+
+    Matrix m0 = new SparseRowMatrix(ROWS, COLUMNS);
+    Random gen = new Random(1);
+    for (int i = 0; i < 1000; i++) {
+      m0.set(gen.nextInt(ROWS), gen.nextInt(COLUMNS), matrixValue(i));
+    }
+    final File f = File.createTempFile("foo", ".m");
+    f.deleteOnExit();
+    System.out.printf("Starting to write to %s\n", f.getAbsolutePath());
+    FileBasedSparseBinaryMatrix.writeMatrix(f, m0);
+    System.out.printf("done\n");
+    System.out.printf("File is %.1f MB\n", f.length() / 1e6);
+
+    FileBasedSparseBinaryMatrix m1 = new FileBasedSparseBinaryMatrix(ROWS, 
COLUMNS);
+    System.out.printf("Starting read\n");
+    m1.setData(f, false);
+    gen = new Random(1);
+    for (int i = 0; i < 1000; i++) {
+      assertEquals(matrixValue(i), m1.get(gen.nextInt(ROWS), 
gen.nextInt(COLUMNS)), 0.0);
+    }
+    System.out.printf("done\n");
+  }
+
+  private int matrixValue(int i) {
+    return (i * 88513) % 10000;
+  }
+
+  @Test
+  public void testSetData() throws IOException {
+    final int ROWS = 10;
+    final int COLS = 21;
+    File f = File.createTempFile("matrix", ".m");
+    f.deleteOnExit();
+
+    Random gen = RandomUtils.getRandom();
+    Matrix m0 = new SparseRowMatrix(ROWS, COLS);
+    for (MatrixSlice row : m0) {
+      int len = (int) Math.ceil(-15 * Math.log(1 - gen.nextDouble()));
+      for (int i = 0; i < len; i++) {
+        row.vector().set(gen.nextInt(COLS), 1);
+      }
+    }
+    FileBasedSparseBinaryMatrix.writeMatrix(f, m0);
+
+    FileBasedSparseBinaryMatrix m = new FileBasedSparseBinaryMatrix(ROWS, 
COLS);
+    m.setData(f, true);
+
+    for (MatrixSlice row : m) {
+      final Vector diff = row.vector().minus(m0.viewRow(row.index()));
+      final double error = diff.norm(1);
+      if (error > 1e-14) {
+        System.out.printf("%s\n", diff);
+      }
+      assertEquals(0, error, 1e-14);
+    }
+  }
+
+
+}


Reply via email to