Author: srowen
Date: Tue Mar 8 08:57:52 2011
New Revision: 1079299
URL: http://svn.apache.org/viewvc?rev=1079299&view=rev
Log:
MAHOUT-541 faster factorization for SVDRecommender
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java?rev=1079299&r1=1079298&r2=1079299&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java Tue Mar 8 08:57:52 2011
@@ -17,6 +17,11 @@
package org.apache.mahout.cf.taste.impl.recommender.svd;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
@@ -27,43 +32,31 @@ import org.apache.mahout.common.RandomUt
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Random;
+/** Calculates the SVD using an Expectation Maximization algorithm. */
+public final class ExpectationMaximizationSVDFactorizer extends AbstractFactorizer {
-/**
- * Uses Single Value Decomposition to find the main features of the data set. Thanks to Simon Funk for the hints
- * in the implementation, {@see http://sifter.org/~simon/journal/20061211.html}.
- */
-public class ExpectationMaximizationSVDFactorizer extends AbstractFactorizer {
+ private static final Logger log = LoggerFactory.getLogger(ExpectationMaximizationSVDFactorizer.class);
private final Random random;
-
private final double learningRate;
/** Parameter used to prevent overfitting. 0.02 is a good value. */
private final double preventOverfitting;
-
/** number of features used to compute this factorization */
private final int numFeatures;
/** number of iterations */
private final int numIterations;
-
/** user singular vectors */
private final double[][] leftVectors;
-
/** item singular vectors */
private final double[][] rightVectors;
-
private final DataModel dataModel;
- private final List<Preference> cachedPreferences;
-
- private static final Logger log = LoggerFactory.getLogger(ExpectationMaximizationSVDFactorizer.class);
+ private final List<SVDPreference> cachedPreferences;
+ private final double defaultValue;
public ExpectationMaximizationSVDFactorizer(DataModel dataModel,
int numFeatures,
int numIterations) throws TasteException {
- // use the default parameters from the old SVDRecommender implementation
+ // use the default parameters from the old SVDRecommender implementation
this(dataModel, numFeatures, 0.005, 0.02, 0.005, numIterations);
}
@@ -86,7 +79,7 @@ public class ExpectationMaximizationSVDF
rightVectors = new double[dataModel.getNumItems()][numFeatures];
double average = getAveragePreference();
- double defaultValue = Math.sqrt((average - 1.0) / numFeatures);
+ defaultValue = Math.sqrt((average - 1.0) / numFeatures);
for (int feature = 0; feature < numFeatures; feature++) {
for (int userIndex = 0; userIndex < dataModel.getNumUsers(); userIndex++) {
@@ -96,32 +89,34 @@ public class ExpectationMaximizationSVDF
rightVectors[itemIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * randomNoise;
}
}
-
- cachedPreferences = new ArrayList<Preference>(dataModel.getNumUsers());
+ cachedPreferences = new ArrayList<SVDPreference>(dataModel.getNumUsers());
}
@Override
public Factorization factorize() throws TasteException {
- log.info("starting to compute the factorization...");
-
cachePreferences();
- for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
- log.info("iteration {}", currentIteration);
- nextTrainStep();
- }
-
- log.info("finished computation of the factorization...");
- return createFactorization(leftVectors, rightVectors);
- }
-
- void cachePreferences() throws TasteException {
- cachedPreferences.clear();
- LongPrimitiveIterator it = dataModel.getUserIDs();
- while (it.hasNext()) {
- for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
- cachedPreferences.add(pref);
+ double rmse = (dataModel.getMaxPreference() - dataModel.getMinPreference());
+ Collections.shuffle(cachedPreferences, random);
+ for (int ii = 0; ii < numFeatures; ii++) {
+ for (int i = 0; (i < numIterations); i++) {
+ double err = 0.0;
+ for (SVDPreference pref : cachedPreferences) {
+ int useridx = userIndex(pref.getUserID());
+ int itemidx = itemIndex(pref.getItemID());
+ err += Math.pow(train(useridx, itemidx, ii, pref), 2.0);
+ }
+ rmse = Math.sqrt(err / cachedPreferences.size());
+ }
+ if (ii < numFeatures - 1) {
+ for (SVDPreference pref : cachedPreferences) {
+ int useridx = userIndex(pref.getUserID());
+ int itemidx = itemIndex(pref.getItemID());
+ buildCache(useridx, itemidx, ii, pref);
+ }
}
+ log.info("Finished training feature {} with RMSE {}.", ii, rmse);
}
+ return createFactorization(leftVectors, rightVectors);
}
double getAveragePreference() throws TasteException {
@@ -135,31 +130,44 @@ public class ExpectationMaximizationSVDF
return average.getAverage();
}
- void nextTrainStep() {
- Collections.shuffle(cachedPreferences, random);
- for (int feature = 0; feature < numFeatures; feature++) {
- for (Preference pref : cachedPreferences) {
- train(userIndex(pref.getUserID()), itemIndex(pref.getItemID()), feature, pref.getValue());
+ private double train(int i, int j, int f, SVDPreference pref) {
+ double[] leftVectorI = leftVectors[i];
+ double[] rightVectorJ = rightVectors[j];
+ double prediction = predictRating(i, j, f, pref, true);
+ double err = pref.getValue() - prediction;
+ leftVectorI[f] += learningRate * (err * rightVectorJ[f] - preventOverfitting * leftVectorI[f]);
+ rightVectorJ[f] += learningRate * (err * leftVectorI[f] - preventOverfitting * rightVectorJ[f]);
+ return err;
+ }
+
+ private void buildCache(int i, int j, int k, SVDPreference pref) {
+ pref.setCache(predictRating(i, j, k, pref, false));
+ }
+
+ private double predictRating(int i, int j, int f, SVDPreference pref, boolean trailing) {
+ float minPreference = dataModel.getMinPreference();
+ float maxPreference = dataModel.getMaxPreference();
+ double sum = pref.getCache();
+ sum += leftVectors[i][f] * rightVectors[j][f];
+ if (trailing) {
+ sum += (numFeatures - f - 1) * (defaultValue * defaultValue);
+ if (sum > dataModel.getMaxPreference()) {
+ sum = maxPreference;
+ } else if (sum < minPreference) {
+ sum = minPreference;
}
}
+ return sum;
}
- double getDotProduct(int userIndex, int itemIndex) {
- double result = 1.0;
- for (int feature = 0; feature < this.numFeatures; feature++) {
- result += leftVectors[userIndex][feature] * rightVectors[itemIndex][feature];
+ private void cachePreferences() throws TasteException {
+ cachedPreferences.clear();
+ LongPrimitiveIterator it = dataModel.getUserIDs();
+ while (it.hasNext()) {
+ for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
+ cachedPreferences.add(new SVDPreference(pref.getUserID(), pref.getItemID(), pref.getValue(), 0.0));
+ }
}
- return result;
- }
-
- void train(int userIndex, int itemIndex, int currentFeature, double value) {
- double err = value - getDotProduct(userIndex, itemIndex);
- double[] leftVector = leftVectors[userIndex];
- double[] rightVector = rightVectors[itemIndex];
- leftVector[currentFeature] +=
- learningRate * (err * rightVector[currentFeature] - preventOverfitting * leftVector[currentFeature]);
- rightVector[currentFeature] +=
- learningRate * (err * leftVector[currentFeature] - preventOverfitting * rightVector[currentFeature]);
}
}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java?rev=1079299&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java Tue Mar 8 08:57:52 2011
@@ -0,0 +1,42 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+
+final class SVDPreference extends GenericPreference {
+
+ private double cache;
+
+ SVDPreference(long userID, long itemID, float value, double cache) {
+ super(userID, itemID, value);
+ Preconditions.checkArgument(!Double.isNaN(cache), "Invalid cache value: " + cache);
+ this.cache = cache;
+ }
+
+ public double getCache() {
+ return cache;
+ }
+
+ public void setCache(double value) {
+ Preconditions.checkArgument(!Double.isNaN(value), "Invalid cache value: " + value);
+ this.cache = value;
+ }
+
+}