Author: ssc
Date: Mon Feb 14 19:40:15 2011
New Revision: 1070622
URL: http://svn.apache.org/viewvc?rev=1070622&view=rev
Log:
MAHOUT-606 Parallelize non-distributed ALSWRFactorizer
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java?rev=1070622&r1=1070621&r2=1070622&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java Mon Feb 14 19:40:15 2011
@@ -34,6 +34,9 @@ import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
/**
* factorizes the rating matrix using "Alternating-Least-Squares with
Weighted-λ-Regularization" as described in
@@ -61,75 +64,146 @@ public class ALSWRFactorizer extends Abs
this.numIterations = numIterations;
}
+ static class Features {
+
+ private final DataModel dataModel;
+ private final int numFeatures;
+
+ private double[][] M;
+ private double[][] U;
+
+ Features(ALSWRFactorizer factorizer) throws TasteException {
+ this.dataModel = factorizer.dataModel;
+ this.numFeatures = factorizer.numFeatures;
+ Random random = RandomUtils.getRandom();
+ M = new double[this.dataModel.getNumItems()][this.numFeatures];
+ LongPrimitiveIterator itemIDsIterator = this.dataModel.getItemIDs();
+ while (itemIDsIterator.hasNext()) {
+ long itemID = itemIDsIterator.nextLong();
+ int itemIDIndex = factorizer.itemIndex(itemID);
+ M[itemIDIndex][0] = averateRating(itemID);
+ for (int feature = 1; feature < this.numFeatures; feature++) {
+ M[itemIDIndex][feature] = random.nextDouble() * 0.1;
+ }
+ }
+ U = new double[this.dataModel.getNumUsers()][this.numFeatures];
+ }
+
+ double[][] getM() {
+ return M;
+ }
+
+ double[][] getU() {
+ return U;
+ }
+
+ DenseVector getUserFeatureColumn(int index) {
+ return new DenseVector(U[index]);
+ }
+
+ DenseVector getItemFeatureColumn(int index) {
+ return new DenseVector(M[index]);
+ }
+
+ void setFeatureColumnInU(int idIndex, Vector vector) {
+ setFeatureColumn(U, idIndex, vector);
+ }
+
+ void setFeatureColumnInM(int idIndex, Vector vector) {
+ setFeatureColumn(M, idIndex, vector);
+ }
+
+ protected void setFeatureColumn(double[][] matrix, int idIndex, Vector
vector) {
+ for (int feature = 0; feature < numFeatures; feature++) {
+ matrix[idIndex][feature] = vector.get(feature);
+ }
+ }
+
+ protected double averateRating(long itemID) throws TasteException {
+ PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
+ RunningAverage avg = new FullRunningAverage();
+ for (Preference pref : prefs) {
+ avg.addDatum(pref.getValue());
+ }
+ return avg.getAverage();
+ }
+ }
+
@Override
public Factorization factorize() throws TasteException {
log.info("starting to compute the factorization...");
- AlternateLeastSquaresSolver solver = new AlternateLeastSquaresSolver();
-
- double[][] M = initializeM();
- double[][] U = null;
+ final AlternateLeastSquaresSolver solver = new
AlternateLeastSquaresSolver();
+ final Features features = new Features(this);
for (int iteration = 0; iteration < numIterations; iteration++) {
log.info("iteration {}", iteration);
/* fix M - compute U */
- U = new double[dataModel.getNumUsers()][numFeatures];
-
+ ExecutorService queue = createQueue();
LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs();
- while (userIDsIterator.hasNext()) {
- long userID = userIDsIterator.nextLong();
- List<Vector> featureVectors = new ArrayList<Vector>();
- LongPrimitiveIterator itemIDsFromUser =
dataModel.getItemIDsFromUser(userID).iterator();
- while (itemIDsFromUser.hasNext()) {
- long itemID = itemIDsFromUser.nextLong();
- featureVectors.add(new DenseVector(M[itemIndex(itemID)]));
+ try {
+ while (userIDsIterator.hasNext()) {
+ final long userID = userIDsIterator.nextLong();
+ final LongPrimitiveIterator itemIDsFromUser =
dataModel.getItemIDsFromUser(userID).iterator();
+ final PreferenceArray userPrefs =
dataModel.getPreferencesFromUser(userID);
+ queue.execute(new Runnable() {
+ @Override
+ public void run() {
+ List<Vector> featureVectors = new ArrayList<Vector>();
+ while (itemIDsFromUser.hasNext()) {
+ long itemID = itemIDsFromUser.nextLong();
+
featureVectors.add(features.getItemFeatureColumn(itemIndex(itemID)));
+ }
+ Vector userFeatures = solver.solve(featureVectors,
ratingVector(userPrefs), lambda, numFeatures);
+ features.setFeatureColumnInU(userIndex(userID), userFeatures);
+ }
+ });
+ }
+ } finally {
+ queue.shutdown();
+ try {
+ queue.awaitTermination(dataModel.getNumUsers(), TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ throw new IllegalStateException("Error when computing user
features", e);
}
- PreferenceArray userPrefs = dataModel.getPreferencesFromUser(userID);
- Vector userFeatures = solver.solve(featureVectors,
ratingVector(userPrefs), lambda, numFeatures);
- setFeatureColumn(U, userIndex(userID), userFeatures);
}
/* fix U - compute M */
- M = new double[dataModel.getNumItems()][numFeatures];
-
+ queue = createQueue();
LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
- while (itemIDsIterator.hasNext()) {
- long itemID = itemIDsIterator.nextLong();
- List<Vector> featureVectors = new ArrayList<Vector>();
- for (Preference pref : dataModel.getPreferencesForItem(itemID)) {
- long userID = pref.getUserID();
- featureVectors.add(new DenseVector(U[userIndex(userID)]));
+ try {
+ while (itemIDsIterator.hasNext()) {
+ final long itemID = itemIDsIterator.nextLong();
+ final PreferenceArray itemPrefs =
dataModel.getPreferencesForItem(itemID);
+ queue.execute(new Runnable() {
+ @Override
+ public void run() {
+ List<Vector> featureVectors = new ArrayList<Vector>();
+ for (Preference pref : itemPrefs) {
+ long userID = pref.getUserID();
+
featureVectors.add(features.getUserFeatureColumn(userIndex(userID)));
+ }
+ Vector itemFeatures = solver.solve(featureVectors,
ratingVector(itemPrefs), lambda, numFeatures);
+ features.setFeatureColumnInM(itemIndex(itemID), itemFeatures);
+ }
+ });
+ }
+ } finally {
+ queue.shutdown();
+ try {
+ queue.awaitTermination(dataModel.getNumItems(), TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ throw new IllegalStateException("Error when computing item
features", e);
}
- PreferenceArray itemPrefs = dataModel.getPreferencesForItem(itemID);
- Vector itemFeatures = solver.solve(featureVectors,
ratingVector(itemPrefs), lambda, numFeatures);
- setFeatureColumn(M, itemIndex(itemID), itemFeatures);
}
}
log.info("finished computation of the factorization...");
- return createFactorization(U, M);
+ return createFactorization(features.getU(), features.getM());
}
- protected double[][] initializeM() throws TasteException {
- Random random = RandomUtils.getRandom();
- double[][] M = new double[dataModel.getNumItems()][numFeatures];
-
- LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
- while (itemIDsIterator.hasNext()) {
- long itemID = itemIDsIterator.nextLong();
- int itemIDIndex = itemIndex(itemID);
- M[itemIDIndex][0] = averateRating(itemID);
- for (int feature = 1; feature < numFeatures; feature++) {
- M[itemIDIndex][feature] = random.nextDouble() * 0.1;
- }
- }
- return M;
- }
-
- protected void setFeatureColumn(double[][] matrix, int idIndex, Vector
vector) {
- for (int feature = 0; feature < numFeatures; feature++) {
- matrix[idIndex][feature] = vector.get(feature);
- }
+ protected ExecutorService createQueue() {
+ return
Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
}
protected Vector ratingVector(PreferenceArray prefs) {
@@ -139,14 +213,4 @@ public class ALSWRFactorizer extends Abs
}
return new DenseVector(ratings);
}
-
- protected double averateRating(long itemID) throws TasteException {
- PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
- RunningAverage avg = new FullRunningAverage();
- for (Preference pref : prefs) {
- avg.addDatum(pref.getValue());
- }
- return avg.getAverage();
- }
-
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java?rev=1070622&r1=1070621&r2=1070622&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java Mon Feb 14 19:40:15 2011
@@ -79,11 +79,12 @@ public class ALSWRFactorizerTest extends
@Test
public void setFeatureColumn() throws Exception {
- double[][] matrix = new double[3][3];
+ ALSWRFactorizer.Features features = new
ALSWRFactorizer.Features(factorizer);
Vector vector = new DenseVector(new double[] { 0.5, 2.0, 1.5 });
int index = 1;
- factorizer.setFeatureColumn(matrix, index, vector);
+ features.setFeatureColumnInM(index, vector);
+ double[][] matrix = features.getM();
assertEquals(vector.get(0), matrix[index][0], EPSILON);
assertEquals(vector.get(1), matrix[index][1], EPSILON);
@@ -104,12 +105,14 @@ public class ALSWRFactorizerTest extends
@Test
public void averageRating() throws Exception {
- assertEquals(2.5, factorizer.averateRating(3l), EPSILON);
+ ALSWRFactorizer.Features features = new
ALSWRFactorizer.Features(factorizer);
+ assertEquals(2.5, features.averateRating(3l), EPSILON);
}
@Test
public void initializeM() throws Exception {
- double[][] M = factorizer.initializeM();
+ ALSWRFactorizer.Features features = new
ALSWRFactorizer.Features(factorizer);
+ double[][] M = features.getM();
assertEquals(3.333333333, M[0][0], EPSILON);
assertEquals(5, M[1][0], EPSILON);