Author: ssc
Date: Tue Nov 8 10:13:22 2011
New Revision: 1199171
URL: http://svn.apache.org/viewvc?rev=1199171&view=rev
Log:
MAHOUT-877 Enable the parallel ALS recommender to use implicit feedback data
Added:
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java
- copied, changed from r1199136,
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolverTest.java
- copied, changed from r1199136,
mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java
Removed:
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java?rev=1199171&r1=1199170&r2=1199171&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
Tue Nov 8 10:13:22 2011
@@ -20,6 +20,7 @@ package org.apache.mahout.cf.taste.hadoo
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
@@ -46,7 +47,8 @@ import org.apache.mahout.math.RandomAcce
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.als.AlternateLeastSquaresSolver;
+import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
+import
org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import java.io.IOException;
@@ -56,13 +58,15 @@ import java.util.Map;
import java.util.Random;
/**
- * <p>MapReduce implementation of the factorization algorithm described in
"Large-scale Parallel Collaborative Filtering for the Netflix Prize"
- * available at
+ * <p>MapReduce implementation of the two factorization algorithms described in
+ *
+ * <p>"Large-scale Parallel Collaborative Filtering for the Netflix Prize"
available at
*
http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf.</p>
*
- * <p>Implements a parallel algorithm that uses "Alternating-Least-Squares
with Weighted-λ-Regularization"
- * to factorize the preference-matrix </p>
+ * <p>"Collaborative Filtering for Implicit Feedback Datasets" available at
+ * http://research.yahoo.com/pub/2433</p>
*
+ * </p>
* <p>Command line arguments specific to this class are:</p>
*
* <ol>
@@ -77,11 +81,14 @@ public class ParallelALSFactorizationJob
static final String NUM_FEATURES =
ParallelALSFactorizationJob.class.getName() + ".numFeatures";
static final String LAMBDA = ParallelALSFactorizationJob.class.getName() +
".lambda";
+ static final String ALPHA = ParallelALSFactorizationJob.class.getName() +
".alpha";
static final String FEATURE_MATRIX =
ParallelALSFactorizationJob.class.getName() + ".featureMatrix";
+ private boolean implicitFeedback;
private int numIterations;
private int numFeatures;
private double lambda;
+ private double alpha;
public static void main(String[] args) throws Exception {
ToolRunner.run(new ParallelALSFactorizationJob(), args);
@@ -92,8 +99,10 @@ public class ParallelALSFactorizationJob
addInputOption();
addOutputOption();
- addOption("lambda", "l", "regularization parameter", true);
- addOption("numFeatures", "f", "dimension of the feature space", true);
+ addOption("lambda", null, "regularization parameter", true);
+ addOption("implicitFeedback", null, "data consists of implicit feedback?",
String.valueOf(false));
+ addOption("alpha", null, "confidence parameter (only used on implicit
feedback)", String.valueOf(40));
+ addOption("numFeatures", null, "dimension of the feature space", true);
addOption("numIterations", null, "number of iterations", true);
Map<String,String> parsedArgs = parseArguments(args);
@@ -104,6 +113,8 @@ public class ParallelALSFactorizationJob
numFeatures = Integer.parseInt(parsedArgs.get("--numFeatures"));
numIterations = Integer.parseInt(parsedArgs.get("--numIterations"));
lambda = Double.parseDouble(parsedArgs.get("--lambda"));
+ alpha = Double.parseDouble(parsedArgs.get("--alpha"));
+ implicitFeedback =
Boolean.parseBoolean(parsedArgs.get("--implicitFeedback"));
/*
* compute the factorization A = U M'
@@ -143,7 +154,7 @@ public class ParallelALSFactorizationJob
for (int currentIteration = 0; currentIteration < numIterations;
currentIteration++) {
/* broadcast M, read A row-wise, recompute U row-wise */
runSolver(pathToUserRatings(), pathToU(currentIteration),
pathToM(currentIteration - 1));
- /* broadcast U, read A' row-wise, recompute I row-wise */
+ /* broadcast U, read A' row-wise, recompute M row-wise */
runSolver(pathToItemRatings(), pathToM(currentIteration),
pathToU(currentIteration));
}
@@ -191,28 +202,34 @@ public class ParallelALSFactorizationJob
private void runSolver(Path ratings, Path output, Path pathToUorI)
throws ClassNotFoundException, IOException, InterruptedException {
- Job solverForUorI = prepareJob(ratings, output,
SequenceFileInputFormat.class, SolveMapper.class, IntWritable.class,
+
+ Class<? extends Mapper> solverMapper = implicitFeedback ?
+ SolveImplicitFeedbackMapper.class : SolveExplicitFeedbackMapper.class;
+
+ Job solverForUorI = prepareJob(ratings, output,
SequenceFileInputFormat.class, solverMapper, IntWritable.class,
VectorWritable.class, SequenceFileOutputFormat.class);
- solverForUorI.getConfiguration().set(LAMBDA, String.valueOf(lambda));
- solverForUorI.getConfiguration().setInt(NUM_FEATURES, numFeatures);
- solverForUorI.getConfiguration().set(FEATURE_MATRIX,
pathToUorI.toString());
+ Configuration solverConf = solverForUorI.getConfiguration();
+ solverConf.set(LAMBDA, String.valueOf(lambda));
+ solverConf.set(ALPHA, String.valueOf(alpha));
+ solverConf.setInt(NUM_FEATURES, numFeatures);
+ solverConf.set(FEATURE_MATRIX, pathToUorI.toString());
solverForUorI.waitForCompletion(true);
}
- static class SolveMapper extends
Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+ static class SolveExplicitFeedbackMapper extends
Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {
private double lambda;
private int numFeatures;
private OpenIntObjectHashMap<Vector> UorM;
- private AlternateLeastSquaresSolver solver;
+ private AlternatingLeastSquaresSolver solver;
@Override
protected void setup(Mapper.Context ctx) throws IOException,
InterruptedException {
lambda = Double.parseDouble(ctx.getConfiguration().get(LAMBDA));
numFeatures = ctx.getConfiguration().getInt(NUM_FEATURES, -1);
- solver = new AlternateLeastSquaresSolver();
+ solver = new AlternatingLeastSquaresSolver();
Path UOrIPath = new Path(ctx.getConfiguration().get(FEATURE_MATRIX));
@@ -237,6 +254,35 @@ public class ParallelALSFactorizationJob
}
}
+ static class SolveImplicitFeedbackMapper extends
Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ private ImplicitFeedbackAlternatingLeastSquaresSolver solver;
+
+ @Override
+ protected void setup(Mapper.Context ctx) throws IOException,
InterruptedException {
+ double lambda = Double.parseDouble(ctx.getConfiguration().get(LAMBDA));
+ double alpha = Double.parseDouble(ctx.getConfiguration().get(ALPHA));
+ int numFeatures = ctx.getConfiguration().getInt(NUM_FEATURES, -1);
+
+ Path YPath = new Path(ctx.getConfiguration().get(FEATURE_MATRIX));
+ OpenIntObjectHashMap<Vector> Y = ALSUtils.readMatrixByRows(YPath,
ctx.getConfiguration());
+
+ solver = new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures,
lambda, alpha, Y);
+
+ Preconditions.checkArgument(numFeatures > 0, "numFeatures was not set
correctly!");
+ }
+
+ @Override
+ protected void map(IntWritable userOrItemID, VectorWritable
ratingsWritable, Context ctx)
+ throws IOException, InterruptedException {
+ Vector ratings = new SequentialAccessSparseVector(ratingsWritable.get());
+
+ Vector uiOrmj = solver.solve(ratings);
+
+ ctx.write(userOrItemID, new VectorWritable(uiOrmj));
+ }
+ }
+
static class AverageRatingMapper extends
Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {
@Override
protected void map(IntWritable r, VectorWritable v, Context ctx) throws
IOException, InterruptedException {
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java?rev=1199171&r1=1199170&r2=1199171&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
Tue Nov 8 10:13:22 2011
@@ -28,7 +28,7 @@ import org.apache.mahout.cf.taste.model.
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.als.AlternateLeastSquaresSolver;
+import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -74,20 +74,20 @@ public class ALSWRFactorizer extends Abs
private final double[][] U;
Features(ALSWRFactorizer factorizer) throws TasteException {
- this.dataModel = factorizer.dataModel;
- this.numFeatures = factorizer.numFeatures;
+ dataModel = factorizer.dataModel;
+ numFeatures = factorizer.numFeatures;
Random random = RandomUtils.getRandom();
- M = new double[this.dataModel.getNumItems()][this.numFeatures];
- LongPrimitiveIterator itemIDsIterator = this.dataModel.getItemIDs();
+ M = new double[dataModel.getNumItems()][numFeatures];
+ LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
while (itemIDsIterator.hasNext()) {
long itemID = itemIDsIterator.nextLong();
int itemIDIndex = factorizer.itemIndex(itemID);
M[itemIDIndex][0] = averateRating(itemID);
- for (int feature = 1; feature < this.numFeatures; feature++) {
+ for (int feature = 1; feature < numFeatures; feature++) {
M[itemIDIndex][feature] = random.nextDouble() * 0.1;
}
}
- U = new double[this.dataModel.getNumUsers()][this.numFeatures];
+ U = new double[dataModel.getNumUsers()][numFeatures];
}
double[][] getM() {
@@ -98,11 +98,11 @@ public class ALSWRFactorizer extends Abs
return U;
}
- DenseVector getUserFeatureColumn(int index) {
+ Vector getUserFeatureColumn(int index) {
return new DenseVector(U[index]);
}
- DenseVector getItemFeatureColumn(int index) {
+ Vector getItemFeatureColumn(int index) {
return new DenseVector(M[index]);
}
@@ -133,7 +133,7 @@ public class ALSWRFactorizer extends Abs
@Override
public Factorization factorize() throws TasteException {
log.info("starting to compute the factorization...");
- final AlternateLeastSquaresSolver solver = new
AlternateLeastSquaresSolver();
+ final AlternatingLeastSquaresSolver solver = new
AlternatingLeastSquaresSolver();
final Features features = new Features(this);
for (int iteration = 0; iteration < numIterations; iteration++) {
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java?rev=1199171&r1=1199170&r2=1199171&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
Tue Nov 8 10:13:22 2011
@@ -28,6 +28,7 @@ import org.apache.mahout.math.MatrixSlic
import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.hadoop.MathHelper;
+import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -39,6 +40,23 @@ public class ParallelALSFactorizationJob
private static final Logger log =
LoggerFactory.getLogger(ParallelALSFactorizationJobTest.class);
+ File inputFile;
+ File outputDir;
+ File tmpDir;
+
+ Configuration conf;
+
+ @Before
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ inputFile = getTestTempFile("prefs.txt");
+ outputDir = getTestTempDir("output");
+ outputDir.delete();
+ tmpDir = getTestTempDir("tmp");
+
+ conf = new Configuration();
+ }
/**
* small integration test that runs the full job
@@ -58,11 +76,6 @@ public class ParallelALSFactorizationJob
@Test
public void completeJobToyExample() throws Exception {
- File inputFile = getTestTempFile("prefs.txt");
- File outputDir = getTestTempDir("output");
- outputDir.delete();
- File tmpDir = getTestTempDir("tmp");
-
Double na = Double.NaN;
Matrix preferences = new SparseRowMatrix(4, 4, new Vector[] {
new DenseVector(new double[] { 5.0, 5.0, 2.0, na }),
@@ -70,8 +83,39 @@ public class ParallelALSFactorizationJob
new DenseVector(new double[] { na, 5.0, na, 3.0 }),
new DenseVector(new double[] { 3.0, na, na, 5.0 }) });
- StringBuilder prefsAsText = new StringBuilder();
- String separator = "";
+ writeLines(inputFile, preferencesAsText(preferences));
+
+ ParallelALSFactorizationJob alsFactorization = new
ParallelALSFactorizationJob();
+ alsFactorization.setConf(conf);
+
+ int numFeatures = 3;
+ int numIterations = 5;
+ double lambda = 0.065;
+
+ alsFactorization.run(new String[] { "--input",
inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
+ "--tempDir", tmpDir.getAbsolutePath(), "--lambda",
String.valueOf(lambda),
+ "--numFeatures", String.valueOf(numFeatures), "--numIterations",
String.valueOf(numIterations) });
+
+ Matrix u = MathHelper.readMatrix(conf, new
Path(outputDir.getAbsolutePath(), "U/part-m-00000"),
+ preferences.numRows(), numFeatures);
+ Matrix m = MathHelper.readMatrix(conf, new
Path(outputDir.getAbsolutePath(), "M/part-m-00000"),
+ preferences.numCols(), numFeatures);
+
+ StringBuilder info = new StringBuilder();
+ info.append("\nA - users x items\n\n");
+ info.append(nice(preferences));
+ info.append("\nU - users x features\n\n");
+ info.append(nice(u));
+ info.append("\nM - items x features\n\n");
+ info.append(nice(m));
+ Matrix Ak = u.times(m.transpose());
+ info.append("\nAk - users x items\n\n");
+ info.append(nice(Ak));
+ info.append("\n");
+
+ log.info(info.toString());
+
+ RunningAverage avg = new FullRunningAverage();
Iterator<MatrixSlice> sliceIterator = preferences.iterateAll();
while (sliceIterator.hasNext()) {
MatrixSlice slice = sliceIterator.next();
@@ -79,75 +123,119 @@ public class ParallelALSFactorizationJob
while (elementIterator.hasNext()) {
Vector.Element e = elementIterator.next();
if (!Double.isNaN(e.get())) {
-
prefsAsText.append(separator).append(slice.index()).append(',').append(e.index()).append(',').append(e.get());
- separator = "\n";
+ double pref = e.get();
+ double estimate = u.viewRow(slice.index()).dot(m.viewRow(e.index()));
+ double err = pref - estimate;
+ avg.addDatum(err * err);
+ log.info("Comparing preference of user [{}] towards item [{}], was
[{}] estimate is [{}]",
+ new Object[] { slice.index(), e.index(), pref, estimate });
}
}
}
- log.info("Input matrix:\n{}", prefsAsText);
- writeLines(inputFile, prefsAsText.toString());
+ double rmse = Math.sqrt(avg.getAverage());
+ log.info("RMSE: {}", rmse);
- ParallelALSFactorizationJob alsFactorization = new
ParallelALSFactorizationJob();
+ assertTrue(rmse < 0.2);
+ }
- Configuration conf = new Configuration();
- conf.set("mapred.input.dir", inputFile.getAbsolutePath());
- conf.set("mapred.output.dir", outputDir.getAbsolutePath());
- conf.setBoolean("mapred.output.compress", false);
+ @Test
+ public void completeJobImplicitToyExample() throws Exception {
+ Matrix observations = new SparseRowMatrix(4, 4, new Vector[] {
+ new DenseVector(new double[] { 5.0, 5.0, 2.0, 0 }),
+ new DenseVector(new double[] { 2.0, 0, 3.0, 5.0 }),
+ new DenseVector(new double[] { 0, 5.0, 0, 3.0 }),
+ new DenseVector(new double[] { 3.0, 0, 0, 5.0 }) });
+
+ Matrix preferences = new SparseRowMatrix(4, 4, new Vector[] {
+ new DenseVector(new double[] { 1.0, 1.0, 1.0, 0 }),
+ new DenseVector(new double[] { 1.0, 0, 1.0, 1.0 }),
+ new DenseVector(new double[] { 0, 1.0, 0, 1.0 }),
+ new DenseVector(new double[] { 1.0, 0, 0, 1.0 }) });
+
+ writeLines(inputFile, preferencesAsText(observations));
+
+ ParallelALSFactorizationJob alsFactorization = new
ParallelALSFactorizationJob();
alsFactorization.setConf(conf);
+
int numFeatures = 3;
int numIterations = 5;
double lambda = 0.065;
- alsFactorization.run(new String[] { "--tempDir", tmpDir.getAbsolutePath(),
"--lambda", String.valueOf(lambda),
+ double alpha = 20;
+
+ alsFactorization.run(new String[] { "--input",
inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
+ "--tempDir", tmpDir.getAbsolutePath(), "--lambda",
String.valueOf(lambda),
+ "--implicitFeedback", String.valueOf(true), "--alpha",
String.valueOf(alpha),
"--numFeatures", String.valueOf(numFeatures), "--numIterations",
String.valueOf(numIterations) });
Matrix u = MathHelper.readMatrix(conf, new
Path(outputDir.getAbsolutePath(), "U/part-m-00000"),
- preferences.numRows(), numFeatures);
+ observations.numRows(), numFeatures);
Matrix m = MathHelper.readMatrix(conf, new
Path(outputDir.getAbsolutePath(), "M/part-m-00000"),
- preferences.numCols(), numFeatures);
+ observations.numCols(), numFeatures);
- System.out.println("A - users x items\n");
- for (int n = 0; n < preferences.numRows(); n++) {
- System.out.println(ALSUtils.nice(preferences.viewRow(n)));
- }
- System.out.println("\nU - users x features\n");
- for (int n = 0; n < u.numRows(); n++) {
- System.out.println(ALSUtils.nice(u.viewRow(n)));
- }
- System.out.println("\nM - items x features\n");
- for (int n = 0; n < m.numRows(); n++) {
- System.out.println(ALSUtils.nice(m.viewRow(n)));
- }
+ StringBuilder info = new StringBuilder();
+ info.append("\nObservations - users x items\n");
+ info.append(nice(observations));
+ info.append("\nA - users x items\n\n");
+ info.append(nice(preferences));
+ info.append("\nU - users x features\n\n");
+ info.append(nice(u));
+ info.append("\nM - items x features\n\n");
+ info.append(nice(m));
Matrix Ak = u.times(m.transpose());
- System.out.println("\nAk - users x items\n");
- for (int n = 0; n < Ak.numRows(); n++) {
- System.out.println(ALSUtils.nice(Ak.viewRow(n)));
- }
-
- System.out.println();
+ info.append("\nAk - users x items\n\n");
+ info.append(nice(Ak));
+ info.append("\n");
+ log.info(info.toString());
RunningAverage avg = new FullRunningAverage();
- sliceIterator = preferences.iterateAll();
+ Iterator<MatrixSlice> sliceIterator = preferences.iterateAll();
while (sliceIterator.hasNext()) {
MatrixSlice slice = sliceIterator.next();
- Iterator<Vector.Element> elementIterator =
slice.vector().iterateNonZero();
+ Iterator<Vector.Element> elementIterator = slice.vector().iterator();
while (elementIterator.hasNext()) {
Vector.Element e = elementIterator.next();
if (!Double.isNaN(e.get())) {
double pref = e.get();
double estimate = u.viewRow(slice.index()).dot(m.viewRow(e.index()));
- double err = pref - estimate;
- avg.addDatum(err * err);
- log.info("Comparing preference of user [{}] towards item [{}], was
[{}] estimate is [{}]",
- new Object[] { slice.index(), e.index(), pref, estimate });
+ double confidence = 1 + alpha * observations.getQuick(slice.index(),
e.index());
+ double err = confidence * (pref - estimate) * (pref - estimate);
+ avg.addDatum(err);
+ log.info("Comparing preference of user [{}] towards item [{}], was
[{}] with confidence [{}] " +
+ "estimate is [{}]", new Object[] { slice.index(), e.index(),
pref, confidence, estimate });
}
}
}
double rmse = Math.sqrt(avg.getAverage());
log.info("RMSE: {}", rmse);
- assertTrue(rmse < 0.2);
+ assertTrue(rmse < 0.4);
}
+ protected String preferencesAsText(Matrix preferences) {
+ StringBuilder prefsAsText = new StringBuilder();
+ String separator = "";
+ Iterator<MatrixSlice> sliceIterator = preferences.iterateAll();
+ while (sliceIterator.hasNext()) {
+ MatrixSlice slice = sliceIterator.next();
+ Iterator<Vector.Element> elementIterator =
slice.vector().iterateNonZero();
+ while (elementIterator.hasNext()) {
+ Vector.Element e = elementIterator.next();
+ if (!Double.isNaN(e.get())) {
+
prefsAsText.append(separator).append(slice.index()).append(',').append(e.index()).append(',').append(e.get());
+ separator = "\n";
+ }
+ }
+ }
+ return prefsAsText.toString();
+ }
+
+ protected StringBuilder nice(Matrix matrix) {
+ StringBuilder info = new StringBuilder();
+ for (int n = 0; n < matrix.numRows(); n++) {
+ info.append(ALSUtils.nice(matrix.viewRow(n))).append("\n");
+ }
+ return info;
+ }
}
Copied:
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java
(from r1199136,
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java)
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java?p2=mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java&p1=mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java&r1=1199136&r2=1199171&rev=1199171&view=diff
==============================================================================
---
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternateLeastSquaresSolver.java
(original)
+++
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java
Tue Nov 8 10:13:22 2011
@@ -30,7 +30,7 @@ import java.util.Iterator;
* See <a
href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
* this paper.</a>
*/
-public class AlternateLeastSquaresSolver {
+public class AlternatingLeastSquaresSolver {
public Vector solve(Iterable<Vector> featureVectors, Vector ratingVector,
double lambda, int numFeatures) {
Added:
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java?rev=1199171&view=auto
==============================================================================
---
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java
(added)
+++
mahout/trunk/math/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java
Tue Nov 8 10:13:22 2011
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.als;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.QRDecomposition;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.list.IntArrayList;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+
+import java.util.Iterator;
+
+/** see <a href="http://research.yahoo.com/pub/2433">Collaborative Filtering
for Implicit Feedback Datasets</a> */
+public class ImplicitFeedbackAlternatingLeastSquaresSolver {
+
+ private final int numFeatures;
+ private final double alpha;
+ private final double lambda;
+
+ private final OpenIntObjectHashMap<Vector> Y;
+ private final Matrix YtransposeY;
+
+ public ImplicitFeedbackAlternatingLeastSquaresSolver(int numFeatures, double
lambda, double alpha,
+ OpenIntObjectHashMap Y) {
+ this.numFeatures = numFeatures;
+ this.lambda = lambda;
+ this.alpha = alpha;
+ this.Y = Y;
+ YtransposeY = YtransposeY(Y);
+ }
+
+ public Vector solve(Vector ratings) {
+ return solve(YtransposeY.plus(YtransponseCuMinusIYPlusLambdaI(ratings)),
YtransponseCuPu(ratings));
+ }
+
+ private Vector solve(Matrix A, Matrix y) {
+ return new QRDecomposition(A).solve(y).viewColumn(0);
+ }
+
+ protected double confidence(double rating) {
+ return 1 + alpha * rating;
+ }
+
+ /* Y' Y */
+ private Matrix YtransposeY(OpenIntObjectHashMap<Vector> Y) {
+
+ Matrix compactedY = new DenseMatrix(Y.size(), numFeatures);
+ IntArrayList indexes = Y.keys();
+ indexes.quickSort();
+
+ int row = 0;
+ for (int index : indexes.elements()) {
+ compactedY.assignRow(row++, Y.get(index));
+ }
+
+ return compactedY.transpose().times(compactedY);
+ }
+
+ /** Y' (Cu - I) Y + λ I */
+ private Matrix YtransponseCuMinusIYPlusLambdaI(Vector userRatings) {
+ Preconditions.checkArgument(userRatings.isSequentialAccess(), "need
sequential access to ratings!");
+
+ /* (Cu -I) Y */
+ OpenIntObjectHashMap<Vector> CuMinusIY = new
OpenIntObjectHashMap<Vector>();
+ Iterator<Vector.Element> ratings = userRatings.iterateNonZero();
+ while (ratings.hasNext()) {
+ Vector.Element e = ratings.next();
+ CuMinusIY.put(e.index(), Y.get(e.index()).times(confidence(e.get()) -
1));
+ }
+
+ Matrix YtransponseCuMinusIY = new DenseMatrix(numFeatures, numFeatures);
+
+ /* Y' (Cu -I) Y by outer products */
+ ratings = userRatings.iterateNonZero();
+ while (ratings.hasNext()) {
+ Vector.Element e = ratings.next();
+ for (Vector.Element feature : Y.get(e.index())) {
+ Vector partial = CuMinusIY.get(e.index()).times(feature.get());
+ YtransponseCuMinusIY.viewRow(feature.index()).assign(partial,
Functions.PLUS);
+ }
+ }
+
+ /* Y' (Cu - I) Y + λ I add lambda on the diagonal */
+ for (int feature = 0; feature < numFeatures; feature++) {
+ YtransponseCuMinusIY.setQuick(feature, feature,
YtransponseCuMinusIY.getQuick(feature, feature) + lambda);
+ }
+
+ return YtransponseCuMinusIY;
+ }
+
+ /** Y' Cu p(u) */
+ private Matrix YtransponseCuPu(Vector userRatings) {
+ Preconditions.checkArgument(userRatings.isSequentialAccess(), "need
sequential access to ratings!");
+
+ Vector YtransponseCuPu = new DenseVector(numFeatures);
+
+ Iterator<Vector.Element> ratings = userRatings.iterateNonZero();
+ while (ratings.hasNext()) {
+ Vector.Element e = ratings.next();
+ YtransponseCuPu.assign((Y.get(e.index()).times(confidence(e.get()))),
Functions.PLUS);
+ }
+
+ return columnVectorAsMatrix(YtransponseCuPu);
+ }
+
+ private Matrix columnVectorAsMatrix(Vector v) {
+ Matrix matrix = new DenseMatrix(numFeatures, 1);
+ for (Vector.Element e : v)
+ matrix.setQuick(e.index(), 0, e.get());
+ return matrix;
+ }
+
+}
Copied:
mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolverTest.java
(from r1199136,
mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java)
URL:
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolverTest.java?p2=mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolverTest.java&p1=mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java&r1=1199136&r2=1199171&rev=1199171&view=diff
==============================================================================
---
mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternateLeastSquaresSolverTest.java
(original)
+++
mahout/trunk/math/src/test/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolverTest.java
Tue Nov 8 10:13:22 2011
@@ -29,15 +29,15 @@ import org.junit.Test;
import java.util.Arrays;
-public class AlternateLeastSquaresSolverTest extends MahoutTestCase {
+public class AlternatingLeastSquaresSolverTest extends MahoutTestCase {
- private AlternateLeastSquaresSolver solver;
+ private AlternatingLeastSquaresSolver solver;
@Override
@Before
public void setUp() throws Exception {
super.setUp();
- solver = new AlternateLeastSquaresSolver();
+ solver = new AlternatingLeastSquaresSolver();
}
@Test