Author: ssc
Date: Tue Apr 12 09:40:58 2011
New Revision: 1091345
URL: http://svn.apache.org/viewvc?rev=1091345&view=rev
Log:
MAHOUT-542 missing evaluation classes
Added:
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/DatasetSplitter.java
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/InMemoryFactorizationEvaluator.java
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluator.java
mahout/trunk/utils/src/test/java/org/apache/mahout/utils/eval/
mahout/trunk/utils/src/test/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluatorTest.java
Added:
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/DatasetSplitter.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/DatasetSplitter.java?rev=1091345&view=auto
==============================================================================
---
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/DatasetSplitter.java
(added)
+++
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/DatasetSplitter.java
Tue Apr 12 09:40:58 2011
@@ -0,0 +1,149 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.eval;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.BooleanWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.RandomUtils;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * <p>Split a recommendation dataset into a training and a test set</p>
+ *
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--input (path): Directory containing one or more text files with the
dataset</li>
+ * <li>--output (path): path where output should go</li>
+ * <li>--trainingPercentage (double): percentage of the data to use as
training set (optional, default 0.9)</li>
+ * <li>--probePercentage (double): percentage of the data to use as probe set
(optional, default 0.1)</li>
+ * </ol>
+ */
+public class DatasetSplitter extends AbstractJob {
+
+ private static final String TRAINING_PERCENTAGE =
DatasetSplitter.class.getName() + ".trainingPercentage";
+ private static final String PROBE_PERCENTAGE =
DatasetSplitter.class.getName() + ".probePercentage";
+ private static final String PART_TO_USE = DatasetSplitter.class.getName() +
".partToUse";
+
+ private static final Text INTO_TRAINING_SET = new Text("T");
+ private static final Text INTO_PROBE_SET = new Text("P");
+
+ private static final double DEFAULT_TRAINING_PERCENTAGE = 0.9;
+ private static final double DEFAULT_PROBE_PERCENTAGE = 0.1;
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new DatasetSplitter(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption("trainingPercentage", "t", "percentage of the data to use as
training set (default: " +
+ DEFAULT_TRAINING_PERCENTAGE + ")",
String.valueOf(DEFAULT_TRAINING_PERCENTAGE));
+ addOption("probePercentage", "p", "percentage of the data to use as probe
set (default: " +
+ DEFAULT_PROBE_PERCENTAGE +")",
String.valueOf(DEFAULT_PROBE_PERCENTAGE));
+
+ Map<String, String> parsedArgs = parseArguments(args);
+ double trainingPercentage =
Double.parseDouble(parsedArgs.get("--trainingPercentage"));
+ double probePercentage =
Double.parseDouble(parsedArgs.get("--probePercentage"));
+ String tempDir = parsedArgs.get("--tempDir");
+
+ Path markedPrefs = new Path(tempDir, "markedPreferences");
+ Path trainingSetPath = new Path(getOutputPath(), "trainingSet");
+ Path probeSetPath = new Path(getOutputPath(), "probeSet");
+
+ Job markPreferences = prepareJob(getInputPath(), markedPrefs,
TextInputFormat.class, MarkPreferencesMapper.class,
+ Text.class, Text.class, Reducer.class, Text.class, Text.class,
+ SequenceFileOutputFormat.class);
+ markPreferences.getConfiguration().set(TRAINING_PERCENTAGE,
String.valueOf(trainingPercentage));
+ markPreferences.getConfiguration().set(PROBE_PERCENTAGE,
String.valueOf(probePercentage));
+ markPreferences.waitForCompletion(true);
+
+ Job createTrainingSet = prepareJob(markedPrefs, trainingSetPath,
SequenceFileInputFormat.class,
+ WritePrefsMapper.class, NullWritable.class, Text.class, Reducer.class,
NullWritable.class, Text.class,
+ TextOutputFormat.class);
+ createTrainingSet.getConfiguration().set(PART_TO_USE,
INTO_TRAINING_SET.toString());
+ createTrainingSet.waitForCompletion(true);
+
+ Job createProbeSet = prepareJob(markedPrefs, probeSetPath,
SequenceFileInputFormat.class,
+ WritePrefsMapper.class, NullWritable.class, Text.class, Reducer.class,
NullWritable.class, Text.class,
+ TextOutputFormat.class);
+ createProbeSet.getConfiguration().set(PART_TO_USE,
INTO_PROBE_SET.toString());
+ createProbeSet.waitForCompletion(true);
+
+ return 0;
+ }
+
+ static class MarkPreferencesMapper extends
Mapper<LongWritable,Text,Text,Text> {
+
+ private Random random;
+ private double trainingBound;
+ private double probeBound;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException
{
+ random = RandomUtils.getRandom();
+ trainingBound =
Double.parseDouble(ctx.getConfiguration().get(TRAINING_PERCENTAGE));
+ probeBound = trainingBound +
Double.parseDouble(ctx.getConfiguration().get(PROBE_PERCENTAGE));
+ }
+
+ @Override
+ protected void map(LongWritable key, Text text, Context ctx) throws
IOException, InterruptedException {
+ double randomValue = random.nextDouble();
+ if (randomValue <= trainingBound) {
+ ctx.write(INTO_TRAINING_SET, text);
+ } else if (randomValue <= probeBound) {
+ ctx.write(INTO_PROBE_SET, text);
+ }
+ }
+ }
+
+ static class WritePrefsMapper extends Mapper<Text,Text,NullWritable,Text> {
+
+ private String partToUse;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException
{
+ partToUse = ctx.getConfiguration().get(PART_TO_USE);
+ }
+
+ @Override
+ protected void map(Text key, Text text, Context ctx) throws IOException,
InterruptedException {
+ if (partToUse.equals(key.toString())) {
+ ctx.write(NullWritable.get(), text);
+ }
+ }
+ }
+}
\ No newline at end of file
Added:
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/InMemoryFactorizationEvaluator.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/InMemoryFactorizationEvaluator.java?rev=1091345&view=auto
==============================================================================
---
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/InMemoryFactorizationEvaluator.java
(added)
+++
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/InMemoryFactorizationEvaluator.java
Tue Apr 12 09:40:58 2011
@@ -0,0 +1,168 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.eval;
+
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.nio.charset.Charset;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * <p>Measures the root-mean-squared error of a ratring matrix factorization
against a test set.</p>
+ *
+ * <p>the factorization matrices are read into memory, which makes this job
pretty fast, if you get OutOfMemoryErrors,
+ * use {@link ParallelFactorizationEvaluator} instead</p>
+ *
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--output (path): path where output should go</li>
+ * <li>--pairs (path): path containing the test ratings, each line must be
userID,itemID,rating</li>
+ * <li>--userFeatures (path): path to the user feature matrix</li>
+ * <li>--itemFeatures (path): path to the item feature matrix</li>
+ * </ol>
+ */
+public class InMemoryFactorizationEvaluator extends AbstractJob {
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new InMemoryFactorizationEvaluator(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addOption("pairs", "p", "path containing the test ratings, each line must
be userID,itemID,rating", true);
+ addOption("userFeatures", "u", "path to the user feature matrix", true);
+ addOption("itemFeatures", "i", "path to the item feature matrix", true);
+ addOutputOption();
+
+ Map<String,String> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ Path pairs = new Path(parsedArgs.get("--pairs"));
+ Path userFeatures = new Path(parsedArgs.get("--userFeatures"));
+ Path itemFeatures = new Path(parsedArgs.get("--itemFeatures"));
+
+ Matrix u = readMatrix(userFeatures);
+ Matrix m = readMatrix(itemFeatures);
+
+ FullRunningAverage rmseAvg = new FullRunningAverage();
+ FullRunningAverage maeAvg = new FullRunningAverage();
+ int pairsUsed = 1;
+ Writer writer = new OutputStreamWriter(System.out);
+ try {
+ for (Preference pref : readProbePreferences(pairs)) {
+ int userID = (int) pref.getUserID();
+ int itemID = (int) pref.getItemID();
+
+ double rating = pref.getValue();
+ double estimate = u.getRow(userID).dot(m.getRow(itemID));
+ double err = rating - estimate;
+ rmseAvg.addDatum(err * err);
+ maeAvg.addDatum(Math.abs(err));
+ writer.write("Probe [" + pairsUsed + "], rating of user [" + userID +
"] towards item [" + itemID + "], " +
+ "[" + rating + "] estimated [" + estimate + "]\n");
+ pairsUsed++;
+ }
+ double rmse = Math.sqrt(rmseAvg.getAverage());
+ double mae = maeAvg.getAverage();
+ writer.write("RMSE: " + rmse + ", MAE: " + mae + "\n");
+ } finally {
+ IOUtils.quietClose(writer);
+ }
+ return 0;
+ }
+
+ private Matrix readMatrix(Path dir) throws IOException {
+
+ Matrix matrix = new SparseMatrix(new int[] { Integer.MAX_VALUE,
Integer.MAX_VALUE });
+
+ FileSystem fs = dir.getFileSystem(getConf());
+ for (FileStatus seqFile : fs.globStatus(new Path(dir, "part-*"))) {
+ Path path = seqFile.getPath();
+ SequenceFile.Reader reader = null;
+ try {
+ reader = new SequenceFile.Reader(fs, path, getConf());
+ IntWritable key = new IntWritable();
+ VectorWritable value = new VectorWritable();
+ while (reader.next(key, value)) {
+ int row = key.get();
+ Iterator<Vector.Element> elementsIterator =
value.get().iterateNonZero();
+ while (elementsIterator.hasNext()) {
+ Vector.Element element = elementsIterator.next();
+ matrix.set(row, element.index(), element.get());
+ }
+ }
+ } finally {
+ IOUtils.quietClose(reader);
+ }
+ }
+ return matrix;
+ }
+
+ private List<Preference> readProbePreferences(Path dir) throws IOException {
+
+ List<Preference> preferences = new LinkedList<Preference>();
+ FileSystem fs = dir.getFileSystem(getConf());
+ for (FileStatus seqFile : fs.globStatus(new Path(dir, "part-*"))) {
+ Path path = seqFile.getPath();
+ InputStream in = null;
+ try {
+ in = fs.open(path);
+ BufferedReader reader = new BufferedReader(new InputStreamReader(in,
Charset.forName("UTF-8")));
+ String line;
+ while ((line = reader.readLine()) != null) {
+ String[] tokens = TasteHadoopUtils.splitPrefTokens(line);
+ long userID = Long.parseLong(tokens[0]);
+ long itemID = Long.parseLong(tokens[1]);
+ float value = Float.parseFloat(tokens[2]);
+ preferences.add(new GenericPreference(userID, itemID, value));
+ }
+ } finally {
+ IOUtils.quietClose(in);
+ }
+ }
+ return preferences;
+ }
+}
\ No newline at end of file
Added:
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluator.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluator.java?rev=1091345&view=auto
==============================================================================
---
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluator.java
(added)
+++
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluator.java
Tue Apr 12 09:40:58 2011
@@ -0,0 +1,154 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.eval;
+
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.hadoop.als.PredictionJob;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.IOUtils;
import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;

import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.Charset;
import java.util.Map;
+
+/**
+ * <p>Measures the root-mean-squared error of a ratring matrix factorization
against a test set.</p>
+ *
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--output (path): path where output should go</li>
+ * <li>--pairs (path): path containing the test ratings, each line must be
userID,itemID,rating</li>
+ * <li>--userFeatures (path): path to the user feature matrix</li>
+ * <li>--itemFeatures (path): path to the item feature matrix</li>
+ * </ol>
+ */
+public class ParallelFactorizationEvaluator extends AbstractJob {
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new ParallelFactorizationEvaluator(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addOption("pairs", "p", "path containing the test ratings, each line must
be userID,itemID,rating", true);
+ addOption("userFeatures", "u", "path to the user feature matrix", true);
+ addOption("itemFeatures", "i", "path to the item feature matrix", true);
+ addOutputOption();
+
+ Map<String,String> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ Path tempDir = new Path(parsedArgs.get("--tempDir"));
+ Path predictions = new Path(tempDir, "predictions");
+ Path errors = new Path(tempDir, "errors");
+
+ ToolRunner.run(getConf(), new PredictionJob(), new String[] { "--output",
predictions.toString(),
+ "--pairs", parsedArgs.get("--pairs"), "--userFeatures",
parsedArgs.get("--userFeatures"),
+ "--itemFeatures", parsedArgs.get("--itemFeatures"),
+ "--tempDir", tempDir.toString() });
+
+ Job estimationErrors = prepareJob(new Path(parsedArgs.get("--pairs") + ","
+ predictions.toString()), errors,
+ TextInputFormat.class, PairsWithRatingMapper.class,
IntPairWritable.class, DoubleWritable.class,
+ ErrorReducer.class, DoubleWritable.class, NullWritable.class,
SequenceFileOutputFormat.class);
+ estimationErrors.waitForCompletion(true);
+
+ BufferedWriter writer = null;
+ try {
+ FileSystem fs = FileSystem.get(getOutputPath().toUri(), getConf());
+ FSDataOutputStream outputStream = fs.create(new Path(getOutputPath(),
"rmse.txt"));
+ double rmse = computeRmse(errors);
+ writer = new BufferedWriter(new OutputStreamWriter(outputStream));
+ writer.write(String.valueOf(rmse));
+ } finally {
+ IOUtils.quietClose(writer);
+ }
+
+ return 0;
+ }
+
+ protected double computeRmse(Path errors) {
+ RunningAverage average = new FullRunningAverage();
+ for (Pair<DoubleWritable,NullWritable> entry :
+ new SequenceFileDirIterable<DoubleWritable, NullWritable>(errors,
PathType.LIST, getConf())) {
+ DoubleWritable error = entry.getFirst();
+ average.addDatum(error.get() * error.get());
+ }
+
+ return Math.sqrt(average.getAverage());
+ }
+
+ public static class PairsWithRatingMapper extends
Mapper<LongWritable,Text,IntPairWritable,DoubleWritable> {
+ @Override
+ protected void map(LongWritable key, Text value, Context ctx) throws
IOException, InterruptedException {
+ String[] tokens = TasteHadoopUtils.splitPrefTokens(value.toString());
+ int userIDIndex = TasteHadoopUtils.idToIndex(Long.parseLong(tokens[0]));
+ int itemIDIndex = TasteHadoopUtils.idToIndex(Long.parseLong(tokens[1]));
+ double rating = Double.parseDouble(tokens[2]);
+ ctx.write(new IntPairWritable(userIDIndex, itemIDIndex), new
DoubleWritable(rating));
+ }
+ }
+
+ public static class ErrorReducer extends
Reducer<IntPairWritable,DoubleWritable,DoubleWritable,NullWritable> {
+ @Override
+ protected void reduce(IntPairWritable key, Iterable<DoubleWritable>
ratingAndEstimate, Context ctx)
+ throws IOException, InterruptedException {
+
+ double error = Double.NaN;
+ boolean bothFound = false;
+ for (DoubleWritable ratingOrEstimate : ratingAndEstimate) {
+ if (Double.isNaN(error)) {
+ error = ratingOrEstimate.get();
+ } else {
+ error -= ratingOrEstimate.get();
+ bothFound = true;
+ break;
+ }
+ }
+
+ if (bothFound) {
+ ctx.write(new DoubleWritable(error), NullWritable.get());
+ }
+ }
+ }
+}
\ No newline at end of file
Added:
mahout/trunk/utils/src/test/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluatorTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/utils/src/test/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluatorTest.java?rev=1091345&view=auto
==============================================================================
---
mahout/trunk/utils/src/test/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluatorTest.java
(added)
+++
mahout/trunk/utils/src/test/java/org/apache/mahout/utils/eval/ParallelFactorizationEvaluatorTest.java
Tue Apr 12 09:40:58 2011
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.utils.eval;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.math.hadoop.MathHelper;
+import org.junit.Test;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileReader;
+
+public class ParallelFactorizationEvaluatorTest extends TasteTestCase {
+
+ @Test
+ public void smallIntegration() throws Exception {
+
+ File pairs = getTestTempFile("pairs.txt");
+ File userFeatures = getTestTempFile("userFeatures.seq");
+ File itemFeatures = getTestTempFile("itemFeatures.seq");
+ File tempDir = getTestTempDir("temp");
+ File outputDir = getTestTempDir("out");
+ outputDir.delete();
+
+ Configuration conf = new Configuration();
+ Path inputPath = new Path(pairs.getAbsolutePath());
+ FileSystem fs = FileSystem.get(inputPath.toUri(), conf);
+
+ MathHelper.writeEntries(new double[][] {
+ new double[] { 1.5, -2, 0.3 },
+ new double[] { -0.7, 2, 0.6 },
+ new double[] { -1, 2.5, 3 } }, fs, conf, new
Path(userFeatures.getAbsolutePath()));
+
+ MathHelper.writeEntries(new double [][] {
+ new double[] { 2.3, 0.5, 0 },
+ new double[] { 4.7, -1, 0.2 },
+ new double[] { 0.6, 2, 1.3 } }, fs, conf, new
Path(itemFeatures.getAbsolutePath()));
+
+ writeLines(pairs, "0,0,3", "2,1,-7", "1,0,-2");
+
+ ParallelFactorizationEvaluator evaluator = new
ParallelFactorizationEvaluator();
+ evaluator.setConf(conf);
+ evaluator.run(new String[] { "--output", outputDir.getAbsolutePath(),
"--pairs", pairs.getAbsolutePath(),
+ "--userFeatures", userFeatures.getAbsolutePath(), "--itemFeatures",
itemFeatures.getAbsolutePath(),
+ "--tempDir", tempDir.getAbsolutePath() });
+
+ BufferedReader reader = null;
+ try {
+ reader = new BufferedReader(new FileReader(new File(outputDir,
"rmse.txt")));
+ double rmse = Double.parseDouble(reader.readLine());
+ assertEquals(0.89342, rmse, EPSILON);
+ } finally {
+ IOUtils.quietClose(reader);
+ }
+
+ }
+}
\ No newline at end of file