Author: tdunning
Date: Tue Sep 4 02:18:58 2012
New Revision: 1380432
URL: http://svn.apache.org/viewvc?rev=1380432&view=rev
Log:
MAHOUT-1059 - Abstract the idea of a cached length.
Added:
mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/FileBasedSparseBinaryMatrixTest.java
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java?rev=1380432&r1=1380431&r2=1380432&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
(original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
Tue Sep 4 02:18:58 2012
@@ -25,7 +25,7 @@ import org.apache.mahout.math.function.F
import java.util.Iterator;
/** Implementations of generic capabilities like sum of elements and dot
products */
-public abstract class AbstractVector implements Vector {
+public abstract class AbstractVector implements Vector, LengthCachingVector {
private static final double LOG2 = Math.log(2.0);
@@ -186,7 +186,18 @@ public abstract class AbstractVector imp
}
// TODO: check the numNonDefault elements to further optimize
- Vector result = like().assign(this);
+ Vector result;
+ if (isDense()) {
+ result = like().assign(this);
+ } else {
+ result = like();
+ Iterator<Element> i = this.iterateNonZero();
+ while (i.hasNext()) {
+ final Element element = i.next();
+ result.setQuick(element.index(), element.get());
+ }
+ }
+
Iterator<Element> iter = that.iterateNonZero();
while (iter.hasNext()) {
Element thatElement = iter.next();
@@ -282,33 +293,44 @@ public abstract class AbstractVector imp
}
@Override
+ public void setLengthSquared(double d2) {
+ lengthSquared = d2;
+ }
+
+ @Override
public double getDistanceSquared(Vector v) {
if (size != v.size()) {
throw new CardinalityException(size, v.size());
}
// if this and v has a cached lengthSquared, dot product is quickest way
to compute this.
- if (lengthSquared >= 0 && v instanceof AbstractVector &&
((AbstractVector)v).lengthSquared >= 0) {
+ if (lengthSquared >= 0 && v instanceof LengthCachingVector &&
v.getLengthSquared() >= 0) {
return lengthSquared + v.getLengthSquared() - 2 * this.dot(v);
}
+ Vector sparseAccessed;
Vector randomlyAccessed;
- Iterator<Element> it;
- double d = 0.0;
if (lengthSquared >= 0.0) {
- it = v.iterateNonZero();
randomlyAccessed = this;
- d += lengthSquared;
+ sparseAccessed = v;
} else { // TODO: could be further optimized, figure out which one is
smaller, etc
- it = iterateNonZero();
randomlyAccessed = v;
- d += v.getLengthSquared();
+ sparseAccessed = this;
}
+
+ Iterator<Element> it = sparseAccessed.iterateNonZero();
+ double d = randomlyAccessed.getLengthSquared();
+ double d2 = 0;
+ double dot = 0;
while (it.hasNext()) {
Element e = it.next();
double value = e.get();
- d += value * (value - 2.0 * randomlyAccessed.getQuick(e.index()));
+ d2 += value * value;
+ dot += value * randomlyAccessed.getQuick(e.index());
+ }
+ if (sparseAccessed instanceof LengthCachingVector) {
+ ((LengthCachingVector) sparseAccessed).setLengthSquared(d2);
}
//assert d > -1.0e-9; // round-off errors should never be too far off!
- return Math.abs(d);
+ return Math.abs(d + d2 - 2 * dot);
}
@Override
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java?rev=1380432&r1=1380431&r2=1380432&view=diff
==============================================================================
---
mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
(original)
+++
mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
Tue Sep 4 02:18:58 2012
@@ -25,8 +25,11 @@ import java.util.Iterator;
/**
* A delegating vector provides an easy way to decorate vectors with weights
or id's and such while
* keeping all of the Vector functionality.
+ *
+ * This vector implements LengthCachingVector because almost all delegates
cache the length and
+ * the cost of false positives is very low.
*/
-public class DelegatingVector implements Vector {
+public class DelegatingVector implements Vector, LengthCachingVector {
protected Vector delegate;
public DelegatingVector(int size) {
@@ -123,6 +126,17 @@ public class DelegatingVector implements
return delegate.getLengthSquared();
}
+ // Not normally called: the delegate vector is the one that needs this value, and
+ // it will call its own version of this method. In fact, if the delegate is
+ // also a delegating vector, the same logic applies recursively down to the first
+ // non-delegating vector. This makes this very hard to test except in trivial ways.
+ @Override
+ public void setLengthSquared(double d2) {
+ if (delegate instanceof LengthCachingVector) {
+ ((LengthCachingVector) delegate).setLengthSquared(d2);
+ }
+ }
+
@Override
public double getDistanceSquared(Vector v) {
return delegate.getDistanceSquared(v);
Added:
mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java?rev=1380432&view=auto
==============================================================================
---
mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java
(added)
+++
mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java
Tue Sep 4 02:18:58 2012
@@ -0,0 +1,15 @@
+package org.apache.mahout.math;
+
+/**
+ * Interface for vectors that may cache their squared length and allow that cached value to be overwritten (it declares two methods, so it is not a pure marker).
+ */
+interface LengthCachingVector {
+ public double getLengthSquared();
+
+ /**
+ * This is a very dangerous method to call. Passing in a wrong value can
+ * completely screw up distance computations and normalization.
+ * @param d2 The new value for the squared length cache.
+ */
+ public void setLengthSquared(double d2);
+}
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java?rev=1380432&r1=1380431&r2=1380432&view=diff
==============================================================================
---
mahout/trunk/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java
(original)
+++
mahout/trunk/math/src/main/java/org/apache/mahout/math/MatrixVectorView.java
Tue Sep 4 02:18:58 2012
@@ -29,6 +29,12 @@ public class MatrixVectorView extends Ab
private int column;
private int rowStride;
private int columnStride;
+ private boolean isDense = true;
+
+ public MatrixVectorView(Matrix matrix, int row, int column, int rowStride,
int columnStride, boolean isDense) {
+ this(matrix, row, column, rowStride, columnStride);
+ this.isDense = isDense;
+ }
public MatrixVectorView(Matrix matrix, int row, int column, int rowStride,
int columnStride) {
super(viewSize(matrix, row, column, rowStride, columnStride));
@@ -64,7 +70,7 @@ public class MatrixVectorView extends Ab
*/
@Override
public boolean isDense() {
- return true;
+ return isDense;
}
/**
Modified:
mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java?rev=1380432&r1=1380431&r2=1380432&view=diff
==============================================================================
---
mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java
(original)
+++
mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java
Tue Sep 4 02:18:58 2012
@@ -1,6 +1,7 @@
package org.apache.mahout.math;
import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.jet.random.Normal;
import org.junit.Test;
@@ -15,6 +16,9 @@ import java.util.Random;
* confidence that it is working correctly.
*/
public abstract class AbstractVectorTest<T extends Vector> extends
MahoutTestCase {
+
+ private static final double FUZZ = 1e-13;
+
public abstract T vectorToTest(int size);
@Test
@@ -35,58 +39,99 @@ public abstract class AbstractVectorTest
Vector sv1 = new RandomAccessSparseVector(v1);
Vector sv2 = new RandomAccessSparseVector(v2);
- assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(v2)), 1e-13);
- assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(dv2)), 1e-13);
- assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(sv2)), 1e-13);
- assertEquals(0, dv1.plus(dv2).getDistanceSquared(sv1.plus(v2)), 1e-13);
-
- assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(v2)), 1e-13);
- assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(dv2)), 1e-13);
- assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(sv2)), 1e-13);
- assertEquals(0, dv1.minus(dv2).getDistanceSquared(sv1.minus(v2)), 1e-13);
+ assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(v2)), FUZZ);
+ assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(dv2)), FUZZ);
+ assertEquals(0, dv1.plus(dv2).getDistanceSquared(v1.plus(sv2)), FUZZ);
+ assertEquals(0, dv1.plus(dv2).getDistanceSquared(sv1.plus(v2)), FUZZ);
+
+ assertEquals(0, dv1.times(dv2).getDistanceSquared(v1.times(v2)), FUZZ);
+ assertEquals(0, dv1.times(dv2).getDistanceSquared(v1.times(dv2)), FUZZ);
+ assertEquals(0, dv1.times(dv2).getDistanceSquared(v1.times(sv2)), FUZZ);
+ assertEquals(0, dv1.times(dv2).getDistanceSquared(sv1.times(v2)), FUZZ);
+
+ assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(v2)), FUZZ);
+ assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(dv2)), FUZZ);
+ assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.minus(sv2)), FUZZ);
+ assertEquals(0, dv1.minus(dv2).getDistanceSquared(sv1.minus(v2)), FUZZ);
double z = gen.nextDouble();
assertEquals(0, dv1.divide(z).getDistanceSquared(v1.divide(z)), 1e-12);
assertEquals(0, dv1.times(z).getDistanceSquared(v1.times(z)), 1e-12);
assertEquals(0, dv1.plus(z).getDistanceSquared(v1.plus(z)), 1e-12);
- assertEquals(dv1.dot(dv2), v1.dot(v2), 1e-13);
- assertEquals(dv1.dot(dv2), v1.dot(dv2), 1e-13);
- assertEquals(dv1.dot(dv2), v1.dot(sv2), 1e-13);
- assertEquals(dv1.dot(dv2), sv1.dot(v2), 1e-13);
- assertEquals(dv1.dot(dv2), dv1.dot(v2), 1e-13);
-
- assertEquals(dv1.getLengthSquared(), v1.getLengthSquared(), 1e-13);
- assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(v2),
1e-13);
- assertEquals(dv1.getDistanceSquared(dv2), dv1.getDistanceSquared(v2),
1e-13);
- assertEquals(dv1.getDistanceSquared(dv2), sv1.getDistanceSquared(v2),
1e-13);
- assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(dv2),
1e-13);
- assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(sv2),
1e-13);
+ assertEquals(dv1.dot(dv2), v1.dot(v2), FUZZ);
+ assertEquals(dv1.dot(dv2), v1.dot(dv2), FUZZ);
+ assertEquals(dv1.dot(dv2), v1.dot(sv2), FUZZ);
+ assertEquals(dv1.dot(dv2), sv1.dot(v2), FUZZ);
+ assertEquals(dv1.dot(dv2), dv1.dot(v2), FUZZ);
+
+ // first attempt has no cached squared lengths
+ assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(v2), FUZZ);
+ assertEquals(dv1.getDistanceSquared(dv2), dv1.getDistanceSquared(v2),
FUZZ);
+ assertEquals(dv1.getDistanceSquared(dv2), sv1.getDistanceSquared(v2),
FUZZ);
+ assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(dv2),
FUZZ);
+ assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(sv2),
FUZZ);
+
+ // now repeat with cached squared lengths
+ assertEquals(dv1.getLengthSquared(), v1.getLengthSquared(), FUZZ);
+ assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(v2), FUZZ);
+ assertEquals(dv1.getDistanceSquared(dv2), dv1.getDistanceSquared(v2),
FUZZ);
+ assertEquals(dv1.getDistanceSquared(dv2), sv1.getDistanceSquared(v2),
FUZZ);
+ assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(dv2),
FUZZ);
+ assertEquals(dv1.getDistanceSquared(dv2), v1.getDistanceSquared(sv2),
FUZZ);
- assertEquals(dv1.minValue(), v1.minValue(), 1e-13);
+ assertEquals(dv1.minValue(), v1.minValue(), FUZZ);
assertEquals(dv1.minValueIndex(), v1.minValueIndex());
- assertEquals(dv1.maxValue(), v1.maxValue(), 1e-13);
+ assertEquals(dv1.maxValue(), v1.maxValue(), FUZZ);
assertEquals(dv1.maxValueIndex(), v1.maxValueIndex());
Vector nv1 = v1.normalize();
- assertEquals(0, dv1.getDistanceSquared(v1), 1e-13);
- assertEquals(1, nv1.norm(2), 1e-13);
- assertEquals(0, dv1.normalize().getDistanceSquared(nv1), 1e-13);
+ assertEquals(0, dv1.getDistanceSquared(v1), FUZZ);
+ assertEquals(1, nv1.norm(2), FUZZ);
+ assertEquals(0, dv1.normalize().getDistanceSquared(nv1), FUZZ);
nv1 = v1.normalize(1);
- assertEquals(0, dv1.getDistanceSquared(v1), 1e-13);
- assertEquals(1, nv1.norm(1), 1e-13);
- assertEquals(0, dv1.normalize(1).getDistanceSquared(nv1), 1e-13);
-
- assertEquals(dv1.norm(0), v1.norm(0), 1e-13);
- assertEquals(dv1.norm(1), v1.norm(1), 1e-13);
- assertEquals(dv1.norm(1.5), v1.norm(1.5), 1e-13);
- assertEquals(dv1.norm(2), v1.norm(2), 1e-13);
-
- // assign double, function, vector x function
-
+ assertEquals(0, dv1.getDistanceSquared(v1), FUZZ);
+ assertEquals(1, nv1.norm(1), FUZZ);
+ assertEquals(0, dv1.normalize(1).getDistanceSquared(nv1), FUZZ);
+
+ assertEquals(dv1.norm(0), v1.norm(0), FUZZ);
+ assertEquals(dv1.norm(1), v1.norm(1), FUZZ);
+ assertEquals(dv1.norm(1.5), v1.norm(1.5), FUZZ);
+ assertEquals(dv1.norm(2), v1.norm(2), FUZZ);
+
+ assertEquals(dv1.zSum(), v1.zSum(), FUZZ);
+
+ assertEquals(3.1 * v1.size(), v1.assign(3.1).zSum(), FUZZ);
+ assertEquals(0, v1.plus(-3.1).norm(1), FUZZ);
+ v1.assign(dv1);
+ assertEquals(0, v1.getDistanceSquared(dv1), FUZZ);
+
+ assertEquals(dv1.zSum() - dv1.size() * 3.4,
v1.assign(Functions.minus(3.4)).zSum(), FUZZ);
+ assertEquals(dv1.zSum() - dv1.size() * 4.5, v1.assign(Functions.MINUS,
1.1).zSum(), FUZZ);
+ v1.assign(dv1);
+
+ assertEquals(0, dv1.minus(dv2).getDistanceSquared(v1.assign(v2,
Functions.MINUS)), FUZZ);
+ v1.assign(dv1);
+
+ assertEquals(dv1.norm(2), Math.sqrt(v1.aggregate(Functions.PLUS,
Functions.pow(2))), FUZZ);
+ assertEquals(dv1.dot(dv2), v1.aggregate(v2, Functions.PLUS,
Functions.MULT), FUZZ);
+
+ assertEquals(dv1.viewPart(5, 10).zSum(), v1.viewPart(5, 10).zSum(), FUZZ);
+
+ Vector v3 = v1.clone();
+ assertEquals(0, v1.getDistanceSquared(v3), FUZZ);
+ assertFalse(v1 == v3);
+ v3.assign(0);
+ assertEquals(0, dv1.getDistanceSquared(v1), FUZZ);
+ assertEquals(0, v3.getLengthSquared(), FUZZ);
+
+ dv1.assign(Functions.ABS);
+ v1.assign(Functions.ABS);
+ assertEquals(0, dv1.logNormalize().getDistanceSquared(v1.logNormalize()),
FUZZ);
+ assertEquals(0,
dv1.logNormalize(1.5).getDistanceSquared(v1.logNormalize(1.5)), FUZZ);
// aggregate
@@ -99,5 +144,7 @@ public abstract class AbstractVectorTest
assertEquals(dv1.get(element.index()), v1.get(element.index()), 0);
assertEquals(dv1.get(element.index()), v1.getQuick(element.index()), 0);
}
+
+
}
}
Added:
mahout/trunk/math/src/test/java/org/apache/mahout/math/FileBasedSparseBinaryMatrixTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/FileBasedSparseBinaryMatrixTest.java?rev=1380432&view=auto
==============================================================================
---
mahout/trunk/math/src/test/java/org/apache/mahout/math/FileBasedSparseBinaryMatrixTest.java
(added)
+++
mahout/trunk/math/src/test/java/org/apache/mahout/math/FileBasedSparseBinaryMatrixTest.java
Tue Sep 4 02:18:58 2012
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.random.MultiNormal;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Random;
+
+import static junit.framework.Assert.assertEquals;
+import static org.junit.Assume.assumeNotNull;
+
+/**
+ * Round-trip tests for {@link FileBasedSparseBinaryMatrix}: a matrix is written to a
+ * temporary file with {@code writeMatrix} and read back with {@code setData}, then
+ * compared against the in-memory original.
+ */
+public class FileBasedSparseBinaryMatrixTest {
+  // The original sizing estimate was 10 million rows x 40 columns x 8 bytes = 3.2GB of data;
+  // COLUMNS is 1000 here, so that figure is stale (TODO confirm the intended width).
+  // We need >2GB to stress the file-based matrix implementation.
+  private static final int ROWS = 10 * 1000 * 1000;
+  private static final int COLUMNS = 1000;
+
+  /**
+   * Writes a ROWS x COLUMNS sparse matrix with 1000 pseudo-randomly placed cells to a
+   * temporary file, reads it back, and checks each of those cells.
+   *
+   * <p>Skipped unless the {@code runSlowTests} system property is set.
+   *
+   * @throws IOException if the temporary file cannot be created, written or read
+   */
+  @Test
+  public void testBigMatrix() throws IOException {
+    // only run this test if -DrunSlowTests is used.  Also requires 4GB or more of heap.
+    assumeNotNull(System.getProperty("runSlowTests"));
+
+    Matrix m0 = new SparseRowMatrix(ROWS, COLUMNS);
+    Random gen = new Random(1);
+    for (int i = 0; i < 1000; i++) {
+      m0.set(gen.nextInt(ROWS), gen.nextInt(COLUMNS), matrixValue(i));
+    }
+    final File f = File.createTempFile("foo", ".m");
+    f.deleteOnExit();
+    System.out.printf("Starting to write to %s\n", f.getAbsolutePath());
+    FileBasedSparseBinaryMatrix.writeMatrix(f, m0);
+    System.out.printf("done\n");
+    System.out.printf("File is %.1f MB\n", f.length() / 1e6);
+
+    FileBasedSparseBinaryMatrix m1 = new FileBasedSparseBinaryMatrix(ROWS, COLUMNS);
+    System.out.printf("Starting read\n");
+    m1.setData(f, false);
+    // Re-seed so the same (row, column) sequence is replayed in the same order as above.
+    // NOTE(review): this assumes the 1000 generated cells are distinct; a repeated cell
+    // would have been overwritten by the later value during the write loop.
+    gen = new Random(1);
+    for (int i = 0; i < 1000; i++) {
+      assertEquals(matrixValue(i), m1.get(gen.nextInt(ROWS), gen.nextInt(COLUMNS)), 0.0);
+    }
+    System.out.printf("done\n");
+  }
+
+  /** Deterministic cell value for the i-th generated cell, in the range 0..9999 for i >= 0. */
+  private int matrixValue(int i) {
+    return (i * 88513) % 10000;
+  }
+
+  /**
+   * Writes a small 10 x 21 binary matrix to a temporary file, loads it with
+   * {@code setData(f, true)}, and checks that every row matches the original.
+   *
+   * @throws IOException if the temporary file cannot be created, written or read
+   */
+  @Test
+  public void testSetData() throws IOException {
+    // lower-case locals so the class-level ROWS constant is not shadowed
+    final int rows = 10;
+    final int cols = 21;
+    File f = File.createTempFile("matrix", ".m");
+    f.deleteOnExit();
+
+    Random gen = RandomUtils.getRandom();
+    Matrix m0 = new SparseRowMatrix(rows, cols);
+    for (MatrixSlice row : m0) {
+      // number of set() calls per row: ceiling of an exponential draw with mean 15;
+      // column indexes may repeat, so the row can hold fewer than len ones
+      int len = (int) Math.ceil(-15 * Math.log(1 - gen.nextDouble()));
+      for (int i = 0; i < len; i++) {
+        row.vector().set(gen.nextInt(cols), 1);
+      }
+    }
+    FileBasedSparseBinaryMatrix.writeMatrix(f, m0);
+
+    FileBasedSparseBinaryMatrix m = new FileBasedSparseBinaryMatrix(rows, cols);
+    m.setData(f, true);
+
+    for (MatrixSlice row : m) {
+      final Vector diff = row.vector().minus(m0.viewRow(row.index()));
+      final double error = diff.norm(1);
+      if (error > 1e-14) {
+        // print the offending difference before the assertion fails
+        System.out.printf("%s\n", diff);
+      }
+      assertEquals(0, error, 1e-14);
+    }
+  }
+}