Author: ssc
Date: Mon Oct 29 20:50:45 2012
New Revision: 1403497
URL: http://svn.apache.org/viewvc?rev=1403497&view=rev
Log:
MAHOUT-1089 SGD matrix factorization for rating prediction with user and item
biases
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java?rev=1403497&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java
Mon Oct 29 20:50:45 2012
@@ -0,0 +1,224 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Random;
+
+/**
+ * Matrix factorization with user and item biases for rating prediction, trained with plain vanilla SGD.
+ */
+public class RatingSGDFactorizer extends AbstractFactorizer {
+
+ /** Multiplicative decay factor applied to the learning rate after every iteration */
+ protected final double learningRateDecay;
+ /** Initial learning rate (step size) */
+ protected final double learningRate;
+ /** Parameter used to prevent overfitting. */
+ protected final double preventOverfitting;
+ /** Length of every user/item vector: the requested number of features plus the featureOffset reserved slots */
+ protected final int numFeatures;
+ /** Number of leading vector slots reserved for the global average and the two bias terms; learned features start at this index */
+ protected final int featureOffset = 3;
+ /** Number of iterations */
+ private final int numIterations;
+ /** Standard deviation for random initialization of features */
+ protected final double randomNoise;
+ /** User features */
+ protected double[][] userVectors;
+ /** Item features */
+ protected double[][] itemVectors;
+ /** Data model the preferences are read from */
+ protected final DataModel dataModel;
+ /** User IDs of the cached preferences; entry i pairs with entry i of cachedItemIDs */
+ private long[] cachedUserIDs;
+ /** Item IDs of the cached preferences; entry i pairs with entry i of cachedUserIDs */
+ private long[] cachedItemIDs;
+
+ /** Factor multiplied with the current learning rate when a bias term is updated */
+ protected double biasLearningRate = 0.5;
+ /** Factor multiplied with preventOverfitting to regularize the bias terms */
+ protected double biasReg = 0.1;
+
+ /** place in user vector where the bias is stored */
+ protected static final int USER_BIAS_INDEX = 1;
+ /** place in item vector where the bias is stored */
+ protected static final int ITEM_BIAS_INDEX = 2;
+
+ /**
+  * Creates a factorizer with default hyper-parameters: learning rate 0.01, regularization 0.1,
+  * initialization noise 0.01 and a learning rate decay factor of 1.0.
+  *
+  * @param dataModel     data model holding the preferences to factorize
+  * @param numFeatures   number of features to learn per user and per item
+  * @param numIterations number of passes over the cached preferences
+  * @throws TasteException declared by the constructor this one delegates to
+  */
+ public RatingSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException {
+   this(dataModel,
+       numFeatures,
+       0.01,          // learningRate
+       0.1,           // preventOverfitting
+       0.01,          // randomNoise
+       numIterations,
+       1.0);          // learningRateDecay
+ }
+
+ /**
+  * Creates a factorizer with explicitly chosen hyper-parameters.
+  *
+  * @param dataModel          data model holding the preferences to factorize
+  * @param numFeatures        number of features to learn per user and per item
+  * @param learningRate       initial step size of the gradient updates
+  * @param preventOverfitting regularization parameter
+  * @param randomNoise        standard deviation used for the random feature initialization
+  * @param numIterations      number of passes over the cached preferences
+  * @param learningRateDecay  factor the learning rate is multiplied with after each iteration
+  * @throws TasteException declared by the superclass constructor
+  */
+ public RatingSGDFactorizer(DataModel dataModel, int numFeatures, double learningRate,
+     double preventOverfitting, double randomNoise, int numIterations, double learningRateDecay)
+     throws TasteException {
+   super(dataModel);
+   this.dataModel = dataModel;
+
+   // every vector carries featureOffset reserved slots in front of the learned features
+   this.numFeatures = featureOffset + numFeatures;
+   this.randomNoise = randomNoise;
+
+   this.numIterations = numIterations;
+   this.learningRate = learningRate;
+   this.learningRateDecay = learningRateDecay;
+   this.preventOverfitting = preventOverfitting;
+ }
+
+ /**
+  * Allocates and initializes the user and item vectors, then caches and shuffles the preferences.
+  *
+  * Vector layout (the prediction is the dot product of a user vector and an item vector):
+  * slot 0 holds the global average on the user side and 1 on the item side;
+  * slot USER_BIAS_INDEX holds the user bias (initially 0) on the user side and 1 on the item side;
+  * slot ITEM_BIAS_INDEX holds 1 on the user side and the item bias (initially 0) on the item side;
+  * all slots from featureOffset on are filled with Gaussian noise scaled by randomNoise.
+  *
+  * @throws TasteException if the data model cannot be read
+  */
+ protected void prepareTraining() throws TasteException {
+   Random random = RandomUtils.getRandom();
+   userVectors = new double[dataModel.getNumUsers()][numFeatures];
+   itemVectors = new double[dataModel.getNumItems()][numFeatures];
+
+   double globalAverage = getAveragePreference();
+   for (double[] userVector : userVectors) {
+     userVector[0] = globalAverage;
+     // will store user bias
+     userVector[USER_BIAS_INDEX] = 0;
+     // corresponding item feature contains item bias
+     userVector[ITEM_BIAS_INDEX] = 1;
+     for (int feature = featureOffset; feature < numFeatures; feature++) {
+       userVector[feature] = random.nextGaussian() * randomNoise;
+     }
+   }
+   for (double[] itemVector : itemVectors) {
+     // corresponding user feature contains global average
+     itemVector[0] = 1;
+     // corresponding user feature contains user bias
+     itemVector[USER_BIAS_INDEX] = 1;
+     // will store item bias
+     itemVector[ITEM_BIAS_INDEX] = 0;
+     for (int feature = featureOffset; feature < numFeatures; feature++) {
+       itemVector[feature] = random.nextGaussian() * randomNoise;
+     }
+   }
+
+   cachePreferences();
+   shufflePreferences();
+ }
+
+ /**
+  * Adds up the number of preferences of every user in the data model.
+  *
+  * @return total number of preferences
+  * @throws TasteException if the data model cannot be read
+  */
+ private int countPreferences() throws TasteException {
+   int total = 0;
+   for (LongPrimitiveIterator users = dataModel.getUserIDs(); users.hasNext();) {
+     long userID = users.nextLong();
+     total += dataModel.getPreferencesFromUser(userID).length();
+   }
+   return total;
+ }
+
+ /**
+  * Copies the (user ID, item ID) pair of every preference into the two parallel cache arrays,
+  * in the order the data model returns them.
+  *
+  * @throws TasteException if the data model cannot be read
+  */
+ private void cachePreferences() throws TasteException {
+   // size both arrays up front; they are filled in lock step below
+   int capacity = countPreferences();
+   cachedUserIDs = new long[capacity];
+   cachedItemIDs = new long[capacity];
+
+   int next = 0;
+   for (LongPrimitiveIterator users = dataModel.getUserIDs(); users.hasNext();) {
+     long userID = users.nextLong();
+     PreferenceArray fromUser = dataModel.getPreferencesFromUser(userID);
+     for (Preference preference : fromUser) {
+       cachedUserIDs[next] = userID;
+       cachedItemIDs[next] = preference.getItemID();
+       next++;
+     }
+   }
+ }
+
+ /**
+  * Randomly permutes the cached preferences in place (Durstenfeld shuffle), keeping the
+  * user ID and item ID of each preference paired.
+  */
+ protected void shufflePreferences() {
+   Random random = RandomUtils.getRandom();
+   // walk from the back: position (remaining - 1) is swapped with a random position in [0, remaining)
+   for (int remaining = cachedUserIDs.length; remaining > 1; remaining--) {
+     int swapPos = random.nextInt(remaining);
+     swapCachedPreferences(remaining - 1, swapPos);
+   }
+ }
+
+ /**
+  * Exchanges the cached preferences stored at two positions.
+  *
+  * @param posA first position
+  * @param posB second position
+  */
+ private void swapCachedPreferences(int posA, int posB) {
+   // user IDs
+   long heldUserID = cachedUserIDs[posA];
+   cachedUserIDs[posA] = cachedUserIDs[posB];
+   cachedUserIDs[posB] = heldUserID;
+
+   // item IDs, swapped the same way so both arrays stay aligned
+   long heldItemID = cachedItemIDs[posA];
+   cachedItemIDs[posA] = cachedItemIDs[posB];
+   cachedItemIDs[posB] = heldItemID;
+ }
+
+ /**
+  * Trains the model: after {@link #prepareTraining()}, runs numIterations passes over the cached
+  * preferences, updating the parameters once per preference and multiplying the learning rate
+  * with learningRateDecay after each pass.
+  *
+  * @return factorization built from the trained user and item vectors
+  * @throws TasteException if the data model cannot be read
+  */
+ @Override
+ public Factorization factorize() throws TasteException {
+   prepareTraining();
+
+   double stepSize = learningRate;
+   int iteration = 0;
+   while (iteration < numIterations) {
+     for (int pos = 0; pos < cachedUserIDs.length; pos++) {
+       long userId = cachedUserIDs[pos];
+       long itemId = cachedItemIDs[pos];
+       // the rating itself is not cached; it is looked up again for every update
+       float rating = dataModel.getPreferenceValue(userId, itemId);
+       updateParameters(userId, itemId, rating, stepSize);
+     }
+     stepSize *= learningRateDecay;
+     iteration++;
+   }
+
+   return createFactorization(userVectors, itemVectors);
+ }
+
+ /**
+  * Computes the average over the values of all preferences in the data model.
+  *
+  * @return average preference value
+  * @throws TasteException if the data model cannot be read
+  */
+ double getAveragePreference() throws TasteException {
+   RunningAverage mean = new FullRunningAverage();
+   for (LongPrimitiveIterator users = dataModel.getUserIDs(); users.hasNext();) {
+     PreferenceArray fromUser = dataModel.getPreferencesFromUser(users.nextLong());
+     for (Preference preference : fromUser) {
+       mean.addDatum(preference.getValue());
+     }
+   }
+   return mean.getAverage();
+ }
+
+ /**
+  * Performs one SGD step for a single observed rating: updates the user bias, the item bias and
+  * every feature of the user and item vectors from the prediction error.
+  *
+  * @param userID              ID of the user who gave the rating
+  * @param itemID              ID of the rated item
+  * @param rating              observed rating
+  * @param currentLearningRate step size to use for this update
+  * @throws TasteException declared for overriding implementations and index lookups
+  */
+ protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate)
+     throws TasteException {
+   int userIndex = userIndex(userID);
+   int itemIndex = itemIndex(itemID);
+
+   double[] userVector = userVectors[userIndex];
+   double[] itemVector = itemVectors[itemIndex];
+
+   // the error is computed once, from the parameters as they were before this update
+   double err = rating - predictRating(userIndex, itemIndex);
+
+   // the bias terms use their own, scaled step size and regularization weight
+   double biasStep = biasLearningRate * currentLearningRate;
+   double biasPenalty = biasReg * preventOverfitting;
+
+   // adjust user bias
+   userVector[USER_BIAS_INDEX] += biasStep * (err - biasPenalty * userVector[USER_BIAS_INDEX]);
+
+   // adjust item bias
+   itemVector[ITEM_BIAS_INDEX] += biasStep * (err - biasPenalty * itemVector[ITEM_BIAS_INDEX]);
+
+   // adjust features
+   for (int feature = featureOffset; feature < numFeatures; feature++) {
+     // snapshot both values so the item update does not see the already updated user feature
+     double userFeature = userVector[feature];
+     double itemFeature = itemVector[feature];
+
+     userVector[feature] += currentLearningRate * (err * itemFeature - preventOverfitting * userFeature);
+     itemVector[feature] += currentLearningRate * (err * userFeature - preventOverfitting * itemFeature);
+   }
+ }
+
+ /**
+  * Predicts a rating as the dot product of a user vector and an item vector.
+  *
+  * @param userID index of the user's row in userVectors
+  * @param itemID index of the item's row in itemVectors
+  * @return sum over all vector slots of user value times item value
+  */
+ private double predictRating(int userID, int itemID) {
+   double[] userVector = userVectors[userID];
+   double[] itemVector = itemVectors[itemID];
+
+   double dotProduct = 0;
+   for (int slot = 0; slot < numFeatures; slot++) {
+     dotProduct += userVector[slot] * itemVector[slot];
+   }
+   return dotProduct;
+ }
+}
\ No newline at end of file