Author: srowen
Date: Tue Mar  8 08:57:52 2011
New Revision: 1079299

URL: http://svn.apache.org/viewvc?rev=1079299&view=rev
Log:
MAHOUT-541 faster factorization for SVDRecommender

Added:
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java
Modified:
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java?rev=1079299&r1=1079298&r2=1079299&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java Tue Mar  8 08:57:52 2011
@@ -17,6 +17,11 @@
 
 package org.apache.mahout.cf.taste.impl.recommender.svd;
 
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
 import org.apache.mahout.cf.taste.common.TasteException;
 import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
 import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
@@ -27,43 +32,31 @@ import org.apache.mahout.common.RandomUt
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Random;
+/** Calculates the SVD one feature at a time, following Simon Funk (http://sifter.org/~simon/journal/20061211.html). */
+public final class ExpectationMaximizationSVDFactorizer extends AbstractFactorizer {
 
-/**
- * Uses Single Value Decomposition to find the main features of the data set. Thanks to Simon Funk for the hints
- * in the implementation, {@see http://sifter.org/~simon/journal/20061211.html}.
- */
-public class ExpectationMaximizationSVDFactorizer extends AbstractFactorizer {
+  private static final Logger log = LoggerFactory.getLogger(ExpectationMaximizationSVDFactorizer.class);
 
   private final Random random;
-
   private final double learningRate;
   /** Parameter used to prevent overfitting. 0.02 is a good value. */
   private final double preventOverfitting;
-
   /** number of features used to compute this factorization */
   private final int numFeatures;
   /** number of iterations */
   private final int numIterations;
-
   /** user singular vectors */
   private final double[][] leftVectors;
-
   /** item singular vectors */
   private final double[][] rightVectors;
-
   private final DataModel dataModel;
-  private final List<Preference> cachedPreferences;
-
-  private static final Logger log = LoggerFactory.getLogger(ExpectationMaximizationSVDFactorizer.class);
+  private final List<SVDPreference> cachedPreferences;
+  private final double defaultValue;
 
   public ExpectationMaximizationSVDFactorizer(DataModel dataModel,
                                                int numFeatures,
                                                int numIterations) throws TasteException {
-    // use the default parameters from the old SVDRecommender implementation    
+    // use the default parameters from the old SVDRecommender implementation
     this(dataModel, numFeatures, 0.005, 0.02, 0.005, numIterations);
   }
 
@@ -86,7 +79,7 @@ public class ExpectationMaximizationSVDF
     rightVectors = new double[dataModel.getNumItems()][numFeatures];
 
     double average = getAveragePreference();
-    double defaultValue = Math.sqrt((average - 1.0) / numFeatures);
+    defaultValue = Math.sqrt(Math.max(0.0, average - 1.0) / numFeatures); // max(): sqrt would be NaN if average < 1
 
     for (int feature = 0; feature < numFeatures; feature++) {
       for (int userIndex = 0; userIndex < dataModel.getNumUsers(); userIndex++) {
@@ -96,32 +89,34 @@ public class ExpectationMaximizationSVDF
         rightVectors[itemIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * randomNoise;
       }
     }
-
-    cachedPreferences = new ArrayList<Preference>(dataModel.getNumUsers());
+    cachedPreferences = new ArrayList<SVDPreference>(dataModel.getNumUsers());
   }
 
   @Override
   public Factorization factorize() throws TasteException {
-    log.info("starting to compute the factorization...");
-
     cachePreferences();
-    for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
-      log.info("iteration {}", currentIteration);
-      nextTrainStep();
-    }
-
-    log.info("finished computation of the factorization...");
-    return createFactorization(leftVectors, rightVectors);
-  }
-
-  void cachePreferences() throws TasteException {
-    cachedPreferences.clear();
-    LongPrimitiveIterator it = dataModel.getUserIDs();
-    while (it.hasNext()) {
-      for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
-        cachedPreferences.add(pref);
+    // NaN until at least one training pass has actually measured the error.
+    double rmse = Double.NaN;
+    Collections.shuffle(cachedPreferences, random);
+    for (int feature = 0; feature < numFeatures; feature++) {
+      for (int iteration = 0; iteration < numIterations; iteration++) {
+        // One stochastic gradient descent pass over all preferences, adjusting only this feature.
+        double sumSquaredErrors = 0.0;
+        for (SVDPreference pref : cachedPreferences) {
+          double err = train(userIndex(pref.getUserID()), itemIndex(pref.getItemID()), feature, pref);
+          sumSquaredErrors += err * err;
+        }
+        rmse = Math.sqrt(sumSquaredErrors / cachedPreferences.size());
+      }
+      // Fold the trained feature into each preference's cached partial prediction (unneeded after the last).
+      if (feature < numFeatures - 1) {
+        for (SVDPreference pref : cachedPreferences) {
+          buildCache(userIndex(pref.getUserID()), itemIndex(pref.getItemID()), feature, pref);
+        }
       }
+      log.info("Finished training feature {} with RMSE {}.", feature, rmse);
     }
+    return createFactorization(leftVectors, rightVectors);
   }
 
   double getAveragePreference() throws TasteException {
@@ -135,31 +130,60 @@ public class ExpectationMaximizationSVDF
     return average.getAverage();
   }
 
-  void nextTrainStep() {
-    Collections.shuffle(cachedPreferences, random);
-    for (int feature = 0; feature < numFeatures; feature++) {
-      for (Preference pref : cachedPreferences) {
-        train(userIndex(pref.getUserID()), itemIndex(pref.getItemID()), feature, pref.getValue());
+  /**
+   * Performs one stochastic gradient descent step on feature {@code f} for a single preference.
+   *
+   * @param i index of the preference's user in {@code leftVectors}
+   * @param j index of the preference's item in {@code rightVectors}
+   * @param f feature being trained
+   * @param pref the preference; its cache holds the contribution of the features before {@code f}
+   * @return prediction error (actual minus predicted value), measured before this step
+   */
+  private double train(int i, int j, int f, SVDPreference pref) {
+    double err = pref.getValue() - predictRating(i, j, f, pref, true);
+    // Read both factors before updating either, so each update uses the other's pre-step value.
+    double userFactor = leftVectors[i][f];
+    double itemFactor = rightVectors[j][f];
+    leftVectors[i][f] += learningRate * (err * itemFactor - preventOverfitting * userFactor);
+    rightVectors[j][f] += learningRate * (err * userFactor - preventOverfitting * itemFactor);
+    return err;
+  }
+
+  /** Adds the trained feature {@code f}'s contribution to the cached partial prediction of {@code pref}. */
+  private void buildCache(int i, int j, int f, SVDPreference pref) {
+    pref.setCache(predictRating(i, j, f, pref, false));
+  }
+
+  /**
+   * Predicts {@code pref} as its cached contribution of the features before {@code f} plus feature {@code f}.
+   * If {@code trailing} is true, each not-yet-trained feature after {@code f} is estimated as
+   * {@code defaultValue * defaultValue} (the product of the noise-free initial values), and the result is
+   * clamped to the data model's preference range.
+   */
+  private double predictRating(int i, int j, int f, SVDPreference pref, boolean trailing) {
+    double sum = pref.getCache() + leftVectors[i][f] * rightVectors[j][f];
+    if (trailing) {
+      sum += (numFeatures - f - 1) * (defaultValue * defaultValue);
+      float minPreference = dataModel.getMinPreference();
+      float maxPreference = dataModel.getMaxPreference();
+      if (sum > maxPreference) {
+        sum = maxPreference;
+      } else if (sum < minPreference) {
+        sum = minPreference;
       }
     }
+    return sum;
   }
 
-  double getDotProduct(int userIndex, int itemIndex) {
-    double result = 1.0;
-    for (int feature = 0; feature < this.numFeatures; feature++) {
-      result += leftVectors[userIndex][feature] * rightVectors[itemIndex][feature];
+  /** Wraps every preference of every user in the data model as an {@link SVDPreference} with a zero cache. */
+  private void cachePreferences() throws TasteException {
+    cachedPreferences.clear();
+    LongPrimitiveIterator it = dataModel.getUserIDs();
+    while (it.hasNext()) {
+      for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
+        cachedPreferences.add(new SVDPreference(pref.getUserID(), pref.getItemID(), pref.getValue(), 0.0));
+      }
     }
-    return result;
-  }
-
-  void train(int userIndex, int itemIndex, int currentFeature, double value) {
-    double err = value - getDotProduct(userIndex, itemIndex);
-    double[] leftVector = leftVectors[userIndex];
-    double[] rightVector = rightVectors[itemIndex];
-    leftVector[currentFeature] +=
-        learningRate * (err * rightVector[currentFeature] - preventOverfitting * leftVector[currentFeature]);
-    rightVector[currentFeature] +=
-        learningRate * (err * leftVector[currentFeature] - preventOverfitting * rightVector[currentFeature]);
   }
 
 }

Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java?rev=1079299&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java Tue Mar  8 08:57:52 2011
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+
+/**
+ * A {@link GenericPreference} that additionally carries a cached partial prediction, which
+ * {@link ExpectationMaximizationSVDFactorizer} uses to avoid recomputing the contribution of
+ * features it has already trained.
+ */
+final class SVDPreference extends GenericPreference {
+
+  /** Cached partial prediction; never NaN or infinite. */
+  private double cache;
+
+  SVDPreference(long userID, long itemID, float value, double cache) {
+    super(userID, itemID, value);
+    this.cache = checkCache(cache);
+  }
+
+  /** @return the cached partial prediction */
+  public double getCache() {
+    return cache;
+  }
+
+  /**
+   * Replaces the cached partial prediction.
+   *
+   * @throws IllegalArgumentException if {@code value} is NaN or infinite
+   */
+  public void setCache(double value) {
+    this.cache = checkCache(value);
+  }
+
+  /**
+   * Validates a cache value. The error message is built only on failure, because the factorizer
+   * calls {@link #setCache(double)} for every preference after each trained feature but the last.
+   */
+  private static double checkCache(double value) {
+    if (Double.isNaN(value) || Double.isInfinite(value)) {
+      throw new IllegalArgumentException("Invalid cache value: " + value);
+    }
+    return value;
+  }
+
+}


Reply via email to