Author: ssc
Date: Mon Jan 3 09:18:46 2011
New Revision: 1054567
URL: http://svn.apache.org/viewvc?rev=1054567&view=rev
Log:
MAHOUT-572 Non-distributed implementation of ALS-WR matrix factorization
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/als/
mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java
Removed:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVD.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java?rev=1054567&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
Mon Jan 3 09:18:46 2011
@@ -0,0 +1,152 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.als.AlternateLeastSquaresSolver;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Factorizes the rating matrix using "Alternating-Least-Squares with Weighted-λ-Regularization" as described in
+ * the paper "Large-scale Collaborative Filtering for the Netflix Prize", available at
+ * <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
+ * http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf</a>
+ */
+public class ALSWRFactorizer extends AbstractFactorizer {
+
+  private static final Logger log = LoggerFactory.getLogger(ALSWRFactorizer.class);
+
+  private final DataModel dataModel;
+
+  /** number of features used to compute this factorization */
+  private final int numFeatures;
+  /** parameter to control the regularization */
+  private final double lambda;
+  /** number of iterations */
+  private final int numIterations;
+
+  /**
+   * @param dataModel source of the preferences that are factorized
+   * @param numFeatures number of features (columns) of the user and item feature matrices
+   * @param lambda regularization parameter passed through to the least-squares solver
+   * @param numIterations number of alternating "compute U, then compute M" passes
+   * @throws TasteException propagated from the base-class constructor, which reads the user and item IDs
+   */
+  public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations)
+    throws TasteException {
+    super(dataModel);
+    this.dataModel = dataModel;
+    this.numFeatures = numFeatures;
+    this.lambda = lambda;
+    this.numIterations = numIterations;
+  }
+
+  /**
+   * Alternates between solving for the user features U (item features M fixed) and solving for M (U fixed).
+   *
+   * @return the factorization built from the final U and M
+   * @throws TasteException if the data model cannot be read
+   */
+  @Override
+  public Factorization factorize() throws TasteException {
+    log.info("starting to compute the factorization...");
+    AlternateLeastSquaresSolver solver = new AlternateLeastSquaresSolver();
+
+    double[][] M = initializeM();
+    /* start with an all-zero U instead of null, so that zero iterations still give a usable factorization */
+    double[][] U = new double[dataModel.getNumUsers()][numFeatures];
+
+    for (int iteration = 0; iteration < numIterations; iteration++) {
+      log.info("iteration {}", iteration);
+
+      /* fix M - compute U */
+      U = new double[dataModel.getNumUsers()][numFeatures];
+
+      LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs();
+      while (userIDsIterator.hasNext()) {
+        long userID = userIDsIterator.nextLong();
+        PreferenceArray userPrefs = dataModel.getPreferencesFromUser(userID);
+        /*
+         * Take the item feature vectors in exactly the order of userPrefs: ratingVector(userPrefs) uses the
+         * same index order, so the n-th feature vector always belongs to the n-th rating.
+         */
+        List<Vector> featureVectors = new ArrayList<Vector>(userPrefs.length());
+        for (int n = 0; n < userPrefs.length(); n++) {
+          featureVectors.add(new DenseVector(M[itemIndex(userPrefs.get(n).getItemID())]));
+        }
+        Vector userFeatures = solver.solve(featureVectors, ratingVector(userPrefs), lambda, numFeatures);
+        setFeatureColumn(U, userIndex(userID), userFeatures);
+      }
+
+      /* fix U - compute M */
+      M = new double[dataModel.getNumItems()][numFeatures];
+
+      LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
+      while (itemIDsIterator.hasNext()) {
+        long itemID = itemIDsIterator.nextLong();
+        /* fetch the preferences once and use them for both the feature vectors and the ratings */
+        PreferenceArray itemPrefs = dataModel.getPreferencesForItem(itemID);
+        List<Vector> featureVectors = new ArrayList<Vector>(itemPrefs.length());
+        for (int n = 0; n < itemPrefs.length(); n++) {
+          featureVectors.add(new DenseVector(U[userIndex(itemPrefs.get(n).getUserID())]));
+        }
+        Vector itemFeatures = solver.solve(featureVectors, ratingVector(itemPrefs), lambda, numFeatures);
+        setFeatureColumn(M, itemIndex(itemID), itemFeatures);
+      }
+    }
+
+    log.info("finished computation of the factorization...");
+    return createFactorization(U, M);
+  }
+
+  /**
+   * Creates the initial item feature matrix: feature 0 of every item is that item's average rating, all other
+   * features are random values in [0, 0.1).
+   *
+   * @return a matrix with one row per item and {@code numFeatures} columns
+   * @throws TasteException if the data model cannot be read
+   */
+  protected double[][] initializeM() throws TasteException {
+    Random random = RandomUtils.getRandom();
+    double[][] M = new double[dataModel.getNumItems()][numFeatures];
+
+    LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
+    while (itemIDsIterator.hasNext()) {
+      long itemID = itemIDsIterator.nextLong();
+      int itemIDIndex = itemIndex(itemID);
+      M[itemIDIndex][0] = averateRating(itemID);
+      for (int feature = 1; feature < numFeatures; feature++) {
+        M[itemIDIndex][feature] = random.nextDouble() * 0.1;
+      }
+    }
+    return M;
+  }
+
+  /**
+   * Copies the first {@code numFeatures} entries of {@code vector} into row {@code idIndex} of {@code matrix}.
+   */
+  protected void setFeatureColumn(double[][] matrix, int idIndex, Vector vector) {
+    for (int feature = 0; feature < numFeatures; feature++) {
+      matrix[idIndex][feature] = vector.get(feature);
+    }
+  }
+
+  /**
+   * @return a dense vector holding the preference values of {@code prefs} in index order
+   */
+  protected Vector ratingVector(PreferenceArray prefs) {
+    double[] ratings = new double[prefs.length()];
+    for (int n = 0; n < prefs.length(); n++) {
+      ratings[n] = prefs.get(n).getValue();
+    }
+    return new DenseVector(ratings);
+  }
+
+  /**
+   * Computes the average of all preference values for the given item. The misspelt name ("averate") is kept
+   * because the method is protected and subclasses may already override or call it.
+   *
+   * @throws TasteException if the preferences for the item cannot be read
+   */
+  protected double averateRating(long itemID) throws TasteException {
+    PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
+    RunningAverage avg = new FullRunningAverage();
+    for (Preference pref : prefs) {
+      avg.addDatum(pref.getValue());
+    }
+    return avg.getAverage();
+  }
+
+}
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java?rev=1054567&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java
Mon Jan 3 09:18:46 2011
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+
+/**
+ * base class for {@link Factorizer}s, provides ID to index mapping
+ */
+public abstract class AbstractFactorizer implements Factorizer {
+
+ private final FastByIDMap<Integer> userIDMapping;
+ private final FastByIDMap<Integer> itemIDMapping;
+
+ protected AbstractFactorizer(DataModel dataModel) throws TasteException {
+ userIDMapping = createIDMapping(dataModel.getNumUsers(),
dataModel.getUserIDs());
+ itemIDMapping = createIDMapping(dataModel.getNumItems(),
dataModel.getItemIDs());
+ }
+
+ protected Factorization createFactorization(double[][] userFeatures,
double[][] itemFeatures) {
+ return new Factorization(userIDMapping, itemIDMapping, userFeatures,
itemFeatures);
+ }
+
+ protected Integer userIndex(long userID) {
+ return userIDMapping.get(userID);
+ }
+
+ protected Integer itemIndex(long itemID) {
+ return itemIDMapping.get(itemID);
+ }
+
+ private FastByIDMap<Integer> createIDMapping(int size, LongPrimitiveIterator
idIterator) {
+ FastByIDMap<Integer> mapping = new FastByIDMap<Integer>(size);
+ int index = 0;
+ while (idIterator.hasNext()) {
+ mapping.put(idIterator.nextLong(), index++);
+ }
+ return mapping;
+ }
+}
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java?rev=1054567&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVDFactorizer.java
Mon Jan 3 09:18:46 2011
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Uses Singular Value Decomposition to find the main features of the data set. Thanks to Simon Funk for the hints
+ * in the implementation, see <a href="http://sifter.org/~simon/journal/20061211.html">
+ * http://sifter.org/~simon/journal/20061211.html</a>.
+ */
+public class ExpectationMaximizationSVDFactorizer extends AbstractFactorizer {
+
+ private final Random random;
+
+ private final double learningRate;
+ /** Parameter used to prevent overfitting. 0.02 is a good value. */
+ private final double preventOverfitting;
+
+ /** number of features used to compute this factorization */
+ private final int numFeatures;
+ /** number of iterations */
+ private final int numIterations;
+
+ /** user singular vectors */
+ private final double[][] leftVectors;
+
+ /** item singular vectors */
+ private final double[][] rightVectors;
+
+ private final DataModel dataModel;
+ private final List<Preference> cachedPreferences;
+
+ private static final Logger log =
LoggerFactory.getLogger(ExpectationMaximizationSVDFactorizer.class);
+
+ public ExpectationMaximizationSVDFactorizer(DataModel dataModel, int
numFeatures, int numIterations)
+ throws TasteException {
+ /* use the default parameters from the old SVDRecommender implementation */
+ this(dataModel, numFeatures, 0.005, 0.02, 0.005, numIterations);
+ }
+
+ public ExpectationMaximizationSVDFactorizer(DataModel dataModel, int
numFeatures, double learningRate,
+ double preventOverfitting, double randomNoise, int numIterations) throws
TasteException {
+ super(dataModel);
+ random = RandomUtils.getRandom();
+ this.dataModel = dataModel;
+ this.numFeatures = numFeatures;
+ this.numIterations = numIterations;
+
+ this.learningRate = learningRate;
+ this.preventOverfitting = preventOverfitting;
+
+ leftVectors = new double[dataModel.getNumUsers()][numFeatures];
+ rightVectors = new double[dataModel.getNumItems()][numFeatures];
+
+ double average = getAveragePreference();
+ double defaultValue = Math.sqrt((average - 1.0) / numFeatures);
+
+ for (int feature = 0; feature < numFeatures; feature++) {
+ for (int userIndex = 0; userIndex < dataModel.getNumUsers();
userIndex++) {
+ leftVectors[userIndex][feature] = defaultValue + (random.nextDouble()
- 0.5) * randomNoise;
+ }
+ for (int itemIndex = 0; itemIndex < dataModel.getNumItems();
itemIndex++) {
+ rightVectors[itemIndex][feature] = defaultValue + (random.nextDouble()
- 0.5) * randomNoise;
+ }
+ }
+
+ cachedPreferences = new ArrayList<Preference>(dataModel.getNumUsers());
+ }
+
+ @Override
+ public Factorization factorize() throws TasteException {
+ log.info("starting to compute the factorization...");
+
+ cachePreferences();
+ for (int currentIteration = 0; currentIteration < numIterations;
currentIteration++) {
+ log.info("iteration {}", currentIteration);
+ nextTrainStep();
+ }
+
+ log.info("finished computation of the factorization...");
+ return createFactorization(leftVectors, rightVectors);
+ }
+
+ void cachePreferences() throws TasteException {
+ cachedPreferences.clear();
+ LongPrimitiveIterator it = dataModel.getUserIDs();
+ while (it.hasNext()) {
+ for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
+ cachedPreferences.add(pref);
+ }
+ }
+ }
+
+ double getAveragePreference() throws TasteException {
+ RunningAverage average = new FullRunningAverage();
+ LongPrimitiveIterator it = dataModel.getUserIDs();
+ while (it.hasNext()) {
+ for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
+ average.addDatum(pref.getValue());
+ }
+ }
+ return average.getAverage();
+ }
+
+ void nextTrainStep() {
+ Collections.shuffle(cachedPreferences, random);
+ for (int feature = 0; feature < numFeatures; feature++) {
+ for (Preference pref : cachedPreferences) {
+ train(userIndex(pref.getUserID()), itemIndex(pref.getItemID()),
feature, pref.getValue());
+ }
+ }
+ }
+
+ double getDotProduct(int userIndex, int itemIndex) {
+ double result = 1.0;
+ for (int feature = 0; feature < this.numFeatures; feature++) {
+ result += leftVectors[userIndex][feature] *
rightVectors[itemIndex][feature];
+ }
+ return result;
+ }
+
+ void train(int userIndex, int itemIndex, int currentFeature, double value) {
+ double err = value - getDotProduct(userIndex, itemIndex);
+ double[] leftVector = leftVectors[userIndex];
+ double[] rightVector = rightVectors[itemIndex];
+ leftVector[currentFeature] += learningRate *
+ (err * rightVector[currentFeature] - preventOverfitting *
leftVector[currentFeature]);
+ rightVector[currentFeature] += learningRate *
+ (err * leftVector[currentFeature] - preventOverfitting *
rightVector[currentFeature]);
+ }
+
+}
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java?rev=1054567&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java
Mon Jan 3 09:18:46 2011
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+
+/**
+ * a factorization of the rating matrix
+ */
+public class Factorization {
+
+ /** used to find the rows in the user features matrix by userID */
+ private final FastByIDMap<Integer> userIDMapping;
+ /** used to find the rows in the item features matrix by itemID */
+ private final FastByIDMap<Integer> itemIDMapping;
+
+ /** user features matrix */
+ private final double[][] userFeatures;
+ /** item features matrix */
+ private final double[][] itemFeatures;
+
+ public Factorization(FastByIDMap<Integer> userIDMapping,
FastByIDMap<Integer> itemIDMapping, double[][] userFeatures,
+ double[][] itemFeatures) {
+ this.userIDMapping = userIDMapping;
+ this.itemIDMapping = itemIDMapping;
+ this.userFeatures = userFeatures;
+ this.itemFeatures = itemFeatures;
+ }
+
+ public double[] getUserFeatures(long userID) throws NoSuchUserException {
+ Integer index = userIDMapping.get(userID);
+ if (index == null) {
+ throw new NoSuchUserException();
+ }
+ return userFeatures[index];
+ }
+
+ public double[] getItemFeatures(long itemID) throws NoSuchItemException {
+ Integer index = itemIDMapping.get(itemID);
+ if (index == null) {
+ throw new NoSuchItemException();
+ }
+ return itemFeatures[index];
+ }
+
+}
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java?rev=1054567&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java
Mon Jan 3 09:18:46 2011
@@ -0,0 +1,29 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * Implementation must be able to create a factorization of a rating matrix
+ */
+public interface Factorizer {
+
+ /**
+  * Computes a factorization of the rating matrix.
+  *
+  * @return the resulting {@link Factorization}
+  * @throws TasteException declared so that implementations can report failures; the exact conditions are
+  *         implementation-specific
+  */
+ Factorization factorize() throws TasteException;
+
+}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java?rev=1054567&r1=1054566&r2=1054567&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java
Mon Jan 3 09:18:46 2011
@@ -17,171 +17,53 @@
package org.apache.mahout.cf.taste.impl.recommender.svd;
-import java.util.ArrayList;
import java.util.Collection;
-import java.util.Collections;
import java.util.List;
-import java.util.Random;
import java.util.concurrent.Callable;
-import org.apache.mahout.cf.taste.common.NoSuchItemException;
-import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import com.google.common.base.Preconditions;
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
-import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
-import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
-import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.cf.taste.impl.recommender.AbstractRecommender;
import org.apache.mahout.cf.taste.impl.recommender.TopItems;
import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
import org.apache.mahout.cf.taste.recommender.IDRescorer;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
-import org.apache.mahout.common.RandomUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import com.google.common.base.Preconditions;
-
/**
- * <p>
- * A {@link org.apache.mahout.cf.taste.recommender.Recommender} which uses Single Value Decomposition
- * to find the main features of the data set. Thanks to Simon Funk for the hints in the implementation.
+ * A {@link org.apache.mahout.cf.taste.recommender.Recommender} that uses matrix factorization (a projection of users
+ * and items onto a feature space)
*/
public final class SVDRecommender extends AbstractRecommender {
-
- private static final Logger log =
LoggerFactory.getLogger(SVDRecommender.class);
- private static final Random random = RandomUtils.getRandom();
-
+
+ private Factorization factorization;
private final RefreshHelper refreshHelper;
-
- /** Number of features */
- private final int numFeatures;
-
- private final FastByIDMap<Integer> userMap;
- private final FastByIDMap<Integer> itemMap;
- private final ExpectationMaximizationSVD emSvd;
- private final List<Preference> cachedPreferences;
-
- /**
- * @param numFeatures
- * the number of features
- * @param initialSteps
- * number of initial training steps
- */
- public SVDRecommender(DataModel dataModel,
- CandidateItemsStrategy candidateItemsStrategy,
- int numFeatures,
- int initialSteps) throws TasteException {
+
+ private static final Logger log =
LoggerFactory.getLogger(SVDRecommender.class);
+
+ public SVDRecommender(DataModel dataModel, Factorizer factorizer) throws
TasteException {
+ this(dataModel, factorizer, getDefaultCandidateItemsStrategy());
+ }
+
+ public SVDRecommender(DataModel dataModel, Factorizer factorizer,
CandidateItemsStrategy candidateItemsStrategy)
+ throws TasteException {
super(dataModel, candidateItemsStrategy);
-
- this.numFeatures = numFeatures;
-
- int numUsers = dataModel.getNumUsers();
- userMap = new FastByIDMap<Integer>(numUsers);
-
- int idx = 0;
- LongPrimitiveIterator userIterator = dataModel.getUserIDs();
- while (userIterator.hasNext()) {
- userMap.put(userIterator.nextLong(), idx++);
- }
-
- int numItems = dataModel.getNumItems();
- itemMap = new FastByIDMap<Integer>(numItems);
-
- idx = 0;
- LongPrimitiveIterator itemIterator = dataModel.getItemIDs();
- while (itemIterator.hasNext()) {
- itemMap.put(itemIterator.nextLong(), idx++);
- }
-
- double average = getAveragePreference();
- double defaultValue = Math.sqrt((average - 1.0) / numFeatures);
-
- emSvd = new ExpectationMaximizationSVD(numUsers, numItems, numFeatures,
defaultValue);
- cachedPreferences = new ArrayList<Preference>(numUsers);
- recachePreferences();
-
+ factorization = factorizer.factorize();
refreshHelper = new RefreshHelper(new Callable<Object>() {
@Override
public Object call() throws TasteException {
- recachePreferences();
// TODO: train again
return null;
}
});
- refreshHelper.addDependency(dataModel);
-
- train(initialSteps);
- }
-
- public SVDRecommender(DataModel dataModel,
- int numFeatures,
- int initialSteps) throws TasteException {
- this(dataModel, getDefaultCandidateItemsStrategy(), numFeatures,
initialSteps);
- }
-
- private void recachePreferences() throws TasteException {
- cachedPreferences.clear();
- DataModel dataModel = getDataModel();
- LongPrimitiveIterator it = dataModel.getUserIDs();
- while (it.hasNext()) {
- for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
- cachedPreferences.add(pref);
- }
- }
- }
-
- private double getAveragePreference() throws TasteException {
- RunningAverage average = new FullRunningAverage();
- DataModel dataModel = getDataModel();
- LongPrimitiveIterator it = dataModel.getUserIDs();
- while (it.hasNext()) {
- for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
- average.addDatum(pref.getValue());
- }
- }
- return average.getAverage();
- }
-
- public void train(int steps) {
- for (int i = 0; i < steps; i++) {
- nextTrainStep();
- }
- }
-
- private void nextTrainStep() {
- Collections.shuffle(cachedPreferences, random);
- for (int i = 0; i < numFeatures; i++) {
- for (Preference pref : cachedPreferences) {
- int useridx = userMap.get(pref.getUserID());
- int itemidx = itemMap.get(pref.getItemID());
- emSvd.train(useridx, itemidx, i, pref.getValue());
- }
- }
- }
-
- private float predictRating(int user, int item) {
- return (float) emSvd.getDotProduct(user, item);
- }
-
- @Override
- public float estimatePreference(long userID, long itemID) throws
TasteException {
- Integer useridx = userMap.get(userID);
- if (useridx == null) {
- throw new NoSuchUserException();
- }
- Integer itemidx = itemMap.get(itemID);
- if (itemidx == null) {
- throw new NoSuchItemException();
- }
- return predictRating(useridx, itemidx);
+ refreshHelper.addDependency(getDataModel());
}
-
+
@Override
public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer
rescorer) throws TasteException {
Preconditions.checkArgument(howMany >= 1, "howMany must be at least 1");
@@ -189,37 +71,43 @@ public final class SVDRecommender extend
FastIDSet possibleItemIDs = getAllOtherItems(userID);
- TopItems.Estimator<Long> estimator = new Estimator(userID);
-
List<RecommendedItem> topItems = TopItems.getTopItems(howMany,
possibleItemIDs.iterator(), rescorer,
- estimator);
-
+ new Estimator(userID));
log.debug("Recommendations are: {}", topItems);
+
return topItems;
}
-
- @Override
- public void refresh(Collection<Refreshable> alreadyRefreshed) {
- refreshHelper.refresh(alreadyRefreshed);
- }
-
+
+ /**
+ * a preference is estimated by computing the dot-product of the user and
item feature vectors
+ */
@Override
- public String toString() {
- return "SVDRecommender[numFeatures:" + numFeatures + ']';
+ public float estimatePreference(long userID, long itemID) throws
TasteException {
+ double[] userFeatures = factorization.getUserFeatures(userID);
+ double[] itemFeatures = factorization.getItemFeatures(itemID);
+ double estimate = 0;
+ for (int feature = 0; feature < userFeatures.length; feature++) {
+ estimate += userFeatures[feature] * itemFeatures[feature];
+ }
+ return (float) estimate;
}
-
+
private final class Estimator implements TopItems.Estimator<Long> {
-
+
private final long theUserID;
-
+
private Estimator(long theUserID) {
this.theUserID = theUserID;
}
-
+
@Override
public double estimate(Long itemID) throws TasteException {
return estimatePreference(theUserID, itemID);
}
}
-
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ refreshHelper.refresh(alreadyRefreshed);
+ }
}
Added:
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java?rev=1054567&view=auto
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
(added)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
Mon Jan 3 09:18:46 2011
@@ -0,0 +1,149 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+public class ALSWRFactorizerTest extends TasteTestCase {
+
+  private ALSWRFactorizer factorizer;
+  private DataModel dataModel;
+
+  /**
+   * Builds the data model and the factorizer under test from the rating-matrix below, in which
+   * the users (rows) have the IDs 1 to 4 and the items (columns) have the IDs 1 to 4:
+   *
+   * <pre>
+   *           burger  hotdog  berries  icecream
+   * dog         5       5        2        -
+   * rabbit      2       -        3        5
+   * cow         -       5        -        3
+   * donkey      3       -        -        5
+   * </pre>
+   */
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    FastByIDMap<PreferenceArray> userData = new FastByIDMap<PreferenceArray>();
+
+    userData.put(1L, new GenericUserPreferenceArray(Arrays.asList(new Preference[] {
+        new GenericPreference(1L, 1L, 5f),
+        new GenericPreference(1L, 2L, 5f),
+        new GenericPreference(1L, 3L, 2f) })));
+
+    userData.put(2L, new GenericUserPreferenceArray(Arrays.asList(new Preference[] {
+        new GenericPreference(2L, 1L, 2f),
+        new GenericPreference(2L, 3L, 3f),
+        new GenericPreference(2L, 4L, 5f) })));
+
+    userData.put(3L, new GenericUserPreferenceArray(Arrays.asList(new Preference[] {
+        new GenericPreference(3L, 2L, 5f),
+        new GenericPreference(3L, 4L, 3f) })));
+
+    userData.put(4L, new GenericUserPreferenceArray(Arrays.asList(new Preference[] {
+        new GenericPreference(4L, 1L, 3f),
+        new GenericPreference(4L, 4L, 5f) })));
+
+    dataModel = new GenericDataModel(userData);
+    // NOTE(review): the arguments after the data model look like (numFeatures, lambda,
+    // numIterations) - confirm against the ALSWRFactorizer constructor
+    factorizer = new ALSWRFactorizer(dataModel, 3, 0.065, 10);
+  }
+
+  /** the entries of the vector must end up in the row of the matrix selected by the index */
+  @Test
+  public void setFeatureColumn() throws Exception {
+    double[][] matrix = new double[3][3];
+    Vector vector = new DenseVector(new double[] { 0.5, 2.0, 1.5 });
+    int index = 1;
+
+    factorizer.setFeatureColumn(matrix, index, vector);
+
+    assertEquals(vector.get(0), matrix[index][0], EPSILON);
+    assertEquals(vector.get(1), matrix[index][1], EPSILON);
+    assertEquals(vector.get(2), matrix[index][2], EPSILON);
+  }
+
+  /** the rating vector must hold exactly the preference values, in the order of the preferences */
+  @Test
+  public void ratingVector() throws Exception {
+    PreferenceArray prefs = dataModel.getPreferencesFromUser(1);
+
+    Vector ratingVector = factorizer.ratingVector(prefs);
+
+    assertEquals(prefs.length(), ratingVector.getNumNondefaultElements());
+    assertEquals(prefs.get(0).getValue(), ratingVector.get(0), EPSILON);
+    assertEquals(prefs.get(1).getValue(), ratingVector.get(1), EPSILON);
+    assertEquals(prefs.get(2).getValue(), ratingVector.get(2), EPSILON);
+  }
+
+  /** the ratings 2 and 3 given to item 3 (berries) average to 2.5 */
+  @Test
+  public void averageRating() throws Exception {
+    assertEquals(2.5, factorizer.averateRating(3L), EPSILON);
+  }
+
+  /**
+   * the first feature of every row of the initial M must be the average rating of the respective
+   * item, all other features must lie between 0 and 0.1
+   */
+  @Test
+  public void initializeM() throws Exception {
+    double[][] M = factorizer.initializeM();
+
+    assertEquals(3.333333333, M[0][0], EPSILON);
+    assertEquals(5, M[1][0], EPSILON);
+    assertEquals(2.5, M[2][0], EPSILON);
+    assertEquals(4.333333333, M[3][0], EPSILON);
+
+    for (int itemIndex = 0; itemIndex < dataModel.getNumItems(); itemIndex++) {
+      for (int feature = 1; feature < 3; feature++) {
+        assertTrue(M[itemIndex][feature] >= 0);
+        assertTrue(M[itemIndex][feature] <= 0.1);
+      }
+    }
+  }
+
+  /**
+   * factorizes the toy rating-matrix and checks that the root mean squared error between the
+   * known ratings and their estimates stays below 0.2
+   */
+  @Test
+  public void toyExample() throws Exception {
+
+    SVDRecommender svdRecommender = new SVDRecommender(dataModel, factorizer);
+
+    /* a hold out test would be better, but this is just a toy example so we only check that the
+     * factorization is close to the original matrix */
+    RunningAverage avg = new FullRunningAverage();
+    LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+    while (userIDs.hasNext()) {
+      long userID = userIDs.nextLong();
+      for (Preference pref : dataModel.getPreferencesFromUser(userID)) {
+        double rating = pref.getValue();
+        double estimate = svdRecommender.estimatePreference(userID, pref.getItemID());
+        double err = rating - estimate;
+        avg.addDatum(err * err);
+      }
+    }
+
+    double rmse = Math.sqrt(avg.getAverage());
+    assertTrue(rmse < 0.2d);
+  }
+}
Added:
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java?rev=1054567&view=auto
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java
(added)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java
Mon Jan 3 09:18:46 2011
@@ -0,0 +1,82 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.easymock.classextension.EasyMock;
+import org.junit.Test;
+
+import java.util.List;
+
+public class SVDRecommenderTest extends TasteTestCase {
+
+  /**
+   * with mocked feature vectors (0.4, 2) for the user and (1, 0.3) for the item, the estimated
+   * preference must be their dot-product: 0.4 * 1 + 2 * 0.3 = 1
+   */
+  @Test
+  public void estimatePreference() throws Exception {
+    DataModel model = EasyMock.createMock(DataModel.class);
+    Factorizer mockedFactorizer = EasyMock.createMock(Factorizer.class);
+    Factorization mockedFactorization = EasyMock.createMock(Factorization.class);
+
+    double[] userFeatures = { 0.4, 2 };
+    double[] itemFeatures = { 1, 0.3 };
+
+    EasyMock.expect(mockedFactorizer.factorize()).andReturn(mockedFactorization);
+    EasyMock.expect(mockedFactorization.getUserFeatures(1L)).andReturn(userFeatures);
+    EasyMock.expect(mockedFactorization.getItemFeatures(5L)).andReturn(itemFeatures);
+    EasyMock.replay(model, mockedFactorizer, mockedFactorization);
+
+    SVDRecommender recommender = new SVDRecommender(model, mockedFactorizer);
+
+    float estimatedPreference = recommender.estimatePreference(1L, 5L);
+    assertEquals(1, estimatedPreference, EPSILON);
+
+    EasyMock.verify(model, mockedFactorizer, mockedFactorization);
+  }
+
+  /**
+   * both candidate items must be recommended, ordered by their estimated preference: item 3
+   * (estimate 2) comes before item 5 (estimate 1)
+   */
+  @Test
+  public void recommend() throws Exception {
+    DataModel model = EasyMock.createMock(DataModel.class);
+    CandidateItemsStrategy strategy = EasyMock.createMock(CandidateItemsStrategy.class);
+    Factorizer mockedFactorizer = EasyMock.createMock(Factorizer.class);
+    Factorization mockedFactorization = EasyMock.createMock(Factorization.class);
+
+    FastIDSet candidates = new FastIDSet();
+    candidates.add(5L);
+    candidates.add(3L);
+
+    /* the user features are looked up once per candidate item */
+    double[] userFeaturesForItem5 = { 0.4, 2 };
+    double[] featuresOfItem5 = { 1, 0.3 };
+    double[] userFeaturesForItem3 = { 0.4, 2 };
+    double[] featuresOfItem3 = { 2, 0.6 };
+
+    EasyMock.expect(mockedFactorizer.factorize()).andReturn(mockedFactorization);
+    EasyMock.expect(strategy.getCandidateItems(1L, model)).andReturn(candidates);
+    EasyMock.expect(mockedFactorization.getUserFeatures(1L)).andReturn(userFeaturesForItem5);
+    EasyMock.expect(mockedFactorization.getItemFeatures(5L)).andReturn(featuresOfItem5);
+    EasyMock.expect(mockedFactorization.getUserFeatures(1L)).andReturn(userFeaturesForItem3);
+    EasyMock.expect(mockedFactorization.getItemFeatures(3L)).andReturn(featuresOfItem3);
+
+    EasyMock.replay(model, strategy, mockedFactorizer, mockedFactorization);
+
+    SVDRecommender recommender = new SVDRecommender(model, mockedFactorizer, strategy);
+
+    List<RecommendedItem> recommendations = recommender.recommend(1L, 5);
+    assertEquals(2, recommendations.size());
+
+    RecommendedItem first = recommendations.get(0);
+    assertEquals(3L, first.getItemID());
+    assertEquals(2f, first.getValue(), EPSILON);
+
+    RecommendedItem second = recommendations.get(1);
+    assertEquals(5L, second.getItemID());
+    assertEquals(1f, second.getValue(), EPSILON);
+
+    EasyMock.verify(model, strategy, mockedFactorizer, mockedFactorization);
+  }
+}
Added:
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java?rev=1054567&view=auto
==============================================================================
---
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java
(added)
+++
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java
Mon Jan 3 09:18:46 2011
@@ -0,0 +1,89 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.als;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.QRDecomposition;
+import org.apache.mahout.math.Vector;
+
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Computes a single feature vector ui by solving the linear system Ai * ui = Vi with
+ * Ai = MiIi * t(MiIi) + lambda * nui * E and Vi = MiIi * t(R(i,Ii)), where MiIi holds the given
+ * feature vectors as columns, R(i,Ii) holds the given ratings and nui is the number of ratings.
+ *
+ * @see <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
+ *      Large-scale Parallel Collaborative Filtering for the Netflix Prize</a>
+ */
+public class AlternateLeastSquaresSolver {
+
+  /**
+   * Solves the regularized least squares problem for a single feature vector.
+   *
+   * @param featureVectors the feature vectors that belong to the ratings, the n-th feature vector
+   *                       is paired with the n-th rating that the rating vector iterates
+   * @param ratingVector   the ratings, must support sequential access
+   * @param lambda         regularization parameter, is scaled with the number of ratings
+   * @param numFeatures    number of entries that are read from every feature vector
+   * @return the solution ui of Ai * ui = Vi
+   * @throws NullPointerException if featureVectors or ratingVector is null
+   * @throws IllegalArgumentException if featureVectors is empty, if its size differs from the
+   *         number of ratings or if ratingVector does not support sequential access
+   */
+  public Vector solve(List<Vector> featureVectors, Vector ratingVector, double lambda,
+      int numFeatures) {
+
+    Preconditions.checkNotNull(featureVectors, "featureVectors must not be null");
+    Preconditions.checkArgument(!featureVectors.isEmpty(), "featureVectors must not be empty");
+    Preconditions.checkNotNull(ratingVector, "ratingVector must not be null");
+
+    int nui = ratingVector.getNumNondefaultElements();
+    Preconditions.checkArgument(featureVectors.size() == nui,
+        "need exactly one feature vector per rating, but got %s feature vectors for %s ratings",
+        featureVectors.size(), nui);
+
+    Matrix MiIi = createMiIi(featureVectors, numFeatures);
+    Matrix RiIiMaybeTransposed = createRiIiMaybeTransposed(ratingVector);
+
+    /* compute Ai = MiIi * t(MiIi) + lambda * nui * E */
+    Matrix Ai = addLambdaTimesNuiTimesE(MiIi.times(MiIi.transpose()), lambda, nui);
+    /* compute Vi = MIi * t(R(i,Ii)) */
+    Matrix Vi = MiIi.times(RiIiMaybeTransposed);
+    /* compute ui = inverse(Ai) * Vi */
+    return solve(Ai, Vi);
+  }
+
+  /**
+   * Solves Ai * ui = Vi with a QR decomposition of Ai instead of inverting Ai explicitly.
+   *
+   * @return the first column of the solution
+   */
+  Vector solve(Matrix Ai, Matrix Vi) {
+    return new QRDecomposition(Ai).solve(Vi).getColumn(0);
+  }
+
+  /**
+   * Adds lambda * nui to every element on the diagonal of the given matrix. The matrix is
+   * modified in place.
+   *
+   * @param matrix a square matrix
+   * @return the modified input matrix
+   * @throws IllegalArgumentException if the matrix is not square
+   */
+  protected Matrix addLambdaTimesNuiTimesE(Matrix matrix, double lambda, int nui) {
+    Preconditions.checkArgument(matrix.numCols() == matrix.numRows(),
+        "matrix must be square, but has %s rows and %s columns", matrix.numRows(),
+        matrix.numCols());
+    double lambdaTimesNui = lambda * nui;
+    for (int n = 0; n < matrix.numCols(); n++) {
+      matrix.setQuick(n, n, matrix.getQuick(n, n) + lambdaTimesNui);
+    }
+    return matrix;
+  }
+
+  /**
+   * Creates a dense matrix with numFeatures rows whose n-th column holds the first numFeatures
+   * entries of the n-th feature vector.
+   */
+  protected Matrix createMiIi(List<Vector> featureVectors, int numFeatures) {
+    Matrix MiIi = new DenseMatrix(numFeatures, featureVectors.size());
+    for (int n = 0; n < featureVectors.size(); n++) {
+      Vector featureVector = featureVectors.get(n);
+      for (int m = 0; m < numFeatures; m++) {
+        MiIi.setQuick(m, n, featureVector.get(m));
+      }
+    }
+    return MiIi;
+  }
+
+  /**
+   * Copies the ratings into a dense matrix with a single column, in the order in which the
+   * rating vector iterates them.
+   *
+   * @throws IllegalArgumentException if the rating vector does not support sequential access
+   */
+  protected Matrix createRiIiMaybeTransposed(Vector ratingVector) {
+    /* presumably required so that the ratings are iterated in index order and line up with
+     * the columns of MiIi */
+    Preconditions.checkArgument(ratingVector.isSequentialAccess(),
+        "ratingVector must support sequential access");
+    Matrix RiIiMaybeTransposed = new DenseMatrix(ratingVector.getNumNondefaultElements(), 1);
+    Iterator<Vector.Element> ratingsIterator = ratingVector.iterateNonZero();
+    int index = 0;
+    while (ratingsIterator.hasNext()) {
+      Vector.Element elem = ratingsIterator.next();
+      RiIiMaybeTransposed.setQuick(index++, 0, elem.get());
+    }
+    return RiIiMaybeTransposed;
+  }
+}
Added:
mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java?rev=1054567&view=auto
==============================================================================
---
mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java
(added)
+++
mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java
Mon Jan 3 09:18:46 2011
@@ -0,0 +1,98 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.als;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+public class AlternateLeastSquaresSolverTest extends MahoutTestCase {
+
+  private AlternateLeastSquaresSolver solver;
+
+  @Before
+  public void setup() {
+    solver = new AlternateLeastSquaresSolver();
+  }
+
+  /** lambda * nui = 0.2 * 5 = 1 must be added to every diagonal element of the empty matrix */
+  @Test
+  public void addLambdaTimesNuiTimesE() {
+    int nui = 5;
+    double lambda = 0.2;
+    Matrix matrix = new SparseMatrix(new int[] { 5, 5 });
+
+    solver.addLambdaTimesNuiTimesE(matrix, lambda, nui);
+
+    for (int n = 0; n < 5; n++) {
+      assertEquals(1.0, matrix.getQuick(n, n), EPSILON);
+    }
+  }
+
+  /** every feature vector must become one column of the created matrix */
+  @Test
+  public void createMiIi() {
+    Vector f1 = new DenseVector(new double[] { 1, 2, 3 });
+    Vector f2 = new DenseVector(new double[] { 4, 5, 6 });
+
+    Matrix miIi = solver.createMiIi(Arrays.asList(f1, f2), 3);
+
+    assertEquals(1.0, miIi.getQuick(0, 0), EPSILON);
+    assertEquals(2.0, miIi.getQuick(1, 0), EPSILON);
+    assertEquals(3.0, miIi.getQuick(2, 0), EPSILON);
+    assertEquals(4.0, miIi.getQuick(0, 1), EPSILON);
+    assertEquals(5.0, miIi.getQuick(1, 1), EPSILON);
+    assertEquals(6.0, miIi.getQuick(2, 1), EPSILON);
+  }
+
+  /** the three ratings must be copied, in index order, into a matrix with a single column */
+  @Test
+  public void createRiIiMaybeTransposed() {
+    /* the cardinality must cover the highest index that is set below */
+    Vector ratings = new SequentialAccessSparseVector(6);
+    ratings.setQuick(1, 1.0);
+    ratings.setQuick(3, 3.0);
+    ratings.setQuick(5, 5.0);
+
+    Matrix riIiMaybeTransposed = solver.createRiIiMaybeTransposed(ratings);
+    /* the dimensions are compared exactly, a delta would let wrong dimensions pass */
+    assertEquals(1, riIiMaybeTransposed.numCols());
+    assertEquals(3, riIiMaybeTransposed.numRows());
+
+    assertEquals(1.0, riIiMaybeTransposed.getQuick(0, 0), EPSILON);
+    assertEquals(3.0, riIiMaybeTransposed.getQuick(1, 0), EPSILON);
+    assertEquals(5.0, riIiMaybeTransposed.getQuick(2, 0), EPSILON);
+  }
+
+  /** a rating vector without sequential access must be rejected */
+  @Test
+  public void createRiIiMaybeTransposedExceptionOnNonSequentialVector() {
+    /* the cardinality must cover the highest index that is set below */
+    Vector ratings = new RandomAccessSparseVector(6);
+    ratings.setQuick(1, 1.0);
+    ratings.setQuick(3, 3.0);
+    ratings.setQuick(5, 5.0);
+
+    try {
+      solver.createRiIiMaybeTransposed(ratings);
+      fail("expected an IllegalArgumentException for a vector without sequential access");
+    } catch (IllegalArgumentException expected) {
+      /* this is the behavior under test, nothing to do */
+    }
+  }
+
+}