Author: ssc
Date: Sun Sep 30 20:02:00 2012
New Revision: 1392101
URL: http://svn.apache.org/viewvc?rev=1392101&view=rev
Log:
MAHOUT-1088 biased item-based recommender
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/BiasedItemBasedRecommender.java
Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/BiasedItemBasedRecommender.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/BiasedItemBasedRecommender.java?rev=1392101&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/BiasedItemBasedRecommender.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/BiasedItemBasedRecommender.java Sun Sep 30 20:02:00 2012
@@ -0,0 +1,199 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.apache.mahout.math.map.OpenLongDoubleHashMap;
+
+/**
+ * item-based recommender that uses weighted sum estimation enhanced by
baseline estimates, porting baseline estimation from
+ * the "UserItemBaseline" rating predictor from "mymedialite"
https://github.com/zenogantner/MyMediaLite/
+ */
+public class BiasedItemBasedRecommender extends GenericItemBasedRecommender {
+
+ private final int numSimilarItems;
+
+ private final double averageRating;
+ private final OpenLongDoubleHashMap itemBiases;
+ private final OpenLongDoubleHashMap userBiases;
+
+ private static final int DEFAULT_NUM_SIMILAR_ITEMS = 50;
+ private static final int DEFAULT_NUM_OPTIMIZATION_PASSES = 5;
+ private static final double DEFAULT_USER_BIAS_REGULARIZATION = 10;
+ private static final double DEFAULT_ITEM_BIAS_REGULARIZATION = 5;
+
+ private final ItemSimilarity similarity;
+
+ public BiasedItemBasedRecommender(DataModel dataModel, ItemSimilarity
similarity) throws TasteException {
+ this(dataModel, similarity, DEFAULT_NUM_SIMILAR_ITEMS,
DEFAULT_NUM_OPTIMIZATION_PASSES,
+ DEFAULT_ITEM_BIAS_REGULARIZATION, DEFAULT_USER_BIAS_REGULARIZATION);
+ }
+
+ public BiasedItemBasedRecommender(DataModel dataModel, ItemSimilarity
similarity, int numSimilarItems,
+ int numOptimizationPasses, double itemBiasRegularization, double
userBiasRegularization) throws TasteException {
+ super(dataModel, similarity);
+ this.numSimilarItems = numSimilarItems;
+ this.similarity = similarity;
+
+ averageRating = averageRating();
+
+ itemBiases = new OpenLongDoubleHashMap(getDataModel().getNumItems());
+ userBiases = new OpenLongDoubleHashMap(getDataModel().getNumUsers());
+
+ for (int pass = 0; pass < numOptimizationPasses; pass++) {
+ optimizeItemBiases(itemBiasRegularization);
+ optimizeUserBiases(userBiasRegularization);
+ }
+ }
+
+ private void optimizeItemBiases(double itemBiasRegularization) throws
TasteException {
+ LongPrimitiveIterator itemIDs = getDataModel().getItemIDs();
+ while (itemIDs.hasNext()) {
+ long itemID = itemIDs.nextLong();
+ PreferenceArray preferences =
getDataModel().getPreferencesForItem(itemID);
+ double sum = 0;
+ for (Preference pref : preferences) {
+ sum += pref.getValue() - averageRating;
+ }
+ double bias = sum / (itemBiasRegularization + preferences.length());
+ itemBiases.put(itemID, bias);
+ }
+ }
+
+ private void optimizeUserBiases(double userBiasRegularization) throws
TasteException {
+ LongPrimitiveIterator userIDs = getDataModel().getUserIDs();
+ while (userIDs.hasNext()) {
+ long userID = userIDs.nextLong();
+ PreferenceArray preferences =
getDataModel().getPreferencesFromUser(userID);
+ double sum = 0;
+ for (Preference pref : preferences) {
+ sum += pref.getValue() - averageRating -
itemBiases.get(pref.getItemID());
+ }
+ double bias = sum / (userBiasRegularization + preferences.length());
+ userBiases.put(userID, bias);
+ }
+ }
+
+ private double averageRating() throws TasteException {
+ RunningAverage averageRating = new FullRunningAverage();
+ LongPrimitiveIterator itemIDs = getDataModel().getItemIDs();
+ while (itemIDs.hasNext()) {
+ for (Preference pref :
getDataModel().getPreferencesForItem(itemIDs.next())) {
+ averageRating.addDatum(pref.getValue());
+ }
+ }
+ return averageRating.getAverage();
+ }
+
+ @Override
+ public float estimatePreference(long userID, long itemID) throws
TasteException {
+ PreferenceArray preferencesFromUser =
getDataModel().getPreferencesFromUser(userID);
+ Float actualPref = getPreferenceForItem(preferencesFromUser, itemID);
+ if (actualPref != null) {
+ return actualPref;
+ }
+ return doEstimatePreference(userID, preferencesFromUser, itemID);
+ }
+
+ private static Float getPreferenceForItem(PreferenceArray
preferencesFromUser, long itemID) {
+ int size = preferencesFromUser.length();
+ for (int i = 0; i < size; i++) {
+ if (preferencesFromUser.getItemID(i) == itemID) {
+ return preferencesFromUser.getValue(i);
+ }
+ }
+ return null;
+ }
+
+ protected double baselineEstimate(long userID, long itemID) throws
TasteException {
+ return averageRating + userBiases.get(userID) + itemBiases.get(itemID);
+ }
+
+ @Override
+ protected float doEstimatePreference(long userID, PreferenceArray
preferencesFromUser, long itemID)
+ throws TasteException {
+ double preference = 0.0;
+ double totalSimilarity = 0.0;
+ int count = 0;
+ long[] userIDs = preferencesFromUser.getIDs();
+ float[] ratings = new float[userIDs.length];
+ long[] itemIDs = new long[userIDs.length];
+
+ double[] similarities = similarity.itemSimilarities(itemID, userIDs);
+
+ for (int n = 0; n < preferencesFromUser.length(); n++) {
+ ratings[n] = preferencesFromUser.get(n).getValue();
+ itemIDs[n] = preferencesFromUser.get(n).getItemID();
+ }
+
+ quickSort(similarities, ratings, itemIDs, 0, (similarities.length - 1));
+
+ for (int i = 0; i < Math.min(numSimilarItems, similarities.length); i++) {
+ double theSimilarity = similarities[i];
+ if (!Double.isNaN(theSimilarity)) {
+ if (Double.isInfinite(theSimilarity)) {
+ throw new IllegalStateException();
+ }
+ preference += theSimilarity * (ratings[i] - baselineEstimate(userID,
itemIDs[i]));
+ totalSimilarity += Math.abs(theSimilarity);
+ count++;
+ }
+ }
+
+ if (count <= 1) {
+ return Float.NaN;
+ }
+
+ return (float) (baselineEstimate(userID, itemID) + (preference /
totalSimilarity));
+ }
+
+ //TODO is it possible to do this without recursion?
+ protected void quickSort(double[] similarities, float[] values, long[]
otherValues, int start, int end) {
+ if (start < end) {
+ double pivot = similarities[end];
+ float pivotValue = values[end];
+ int i = start;
+ int j = end;
+ while (i != j) {
+ if (similarities[i] > pivot) {
+ i = i + 1;
+ }
+ else {
+ similarities[j] = similarities[i];
+ values[j] = values[i];
+ otherValues[j] = otherValues[i];
+ similarities[i] = similarities[j - 1];
+ values[i] = values[j - 1];
+ otherValues[i] = otherValues[j - 1];
+ j = j - 1;
+ }
+ }
+ similarities[j] = pivot;
+ values[j] = pivotValue ;
+ quickSort(similarities, values, otherValues, start, j - 1);
+ quickSort(similarities, values, otherValues, j + 1, end);
+ }
+ }
+}
\ No newline at end of file