Author: gsingers
Date: Fri Jul 15 17:31:25 2011
New Revision: 1147257
URL: http://svn.apache.org/viewvc?rev=1147257&view=rev
Log:
MAHOUT-763: add map-side distance calculation
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java?rev=1147257&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
Fri Jul 15 17:31:25 2011
@@ -0,0 +1,140 @@
+package org.apache.mahout.math.hadoop.similarity;
+
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.canopy.Canopy;
+import org.apache.mahout.clustering.kmeans.Cluster;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import
org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ *
+ *
+ **/
+public class VectorDistanceMapper extends Mapper<WritableComparable<?>,
VectorWritable, StringTuple, DoubleWritable> {
+ private transient static Logger log =
LoggerFactory.getLogger(VectorDistanceMapper.class);
+ protected DistanceMeasure measure;
+ protected List<NamedVector> seedVectors;
+
+ @Override
+ protected void map(WritableComparable<?> key, VectorWritable value, Context
context) throws IOException, InterruptedException {
+ String keyName;
+ Vector valVec = value.get();
+ if (valVec instanceof NamedVector) {
+ keyName = ((NamedVector) valVec).getName();
+ } else {
+ keyName = key.toString();
+ }
+ for (NamedVector seedVector : seedVectors) {
+ double distance = measure.distance(seedVector, valVec);
+ StringTuple outKey = new StringTuple();
+ outKey.add(seedVector.getName());
+ outKey.add(keyName);
+ context.write(outKey, new DoubleWritable(distance));
+ }
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException,
InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+ try {
+ ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+ measure =
ccl.loadClass(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY))
+ .asSubclass(DistanceMeasure.class).newInstance();
+ measure.configure(conf);
+
+
+ String seedPathStr =
conf.get(VectorDistanceSimilarityJob.SEEDS_PATH_KEY);
+ if (seedPathStr != null && seedPathStr.length() > 0) {
+
+ Path thePath = new Path(seedPathStr, "*");
+ Collection<Path> result = Lists.newArrayList();
+
+ // get all filtered file names in result list
+ FileSystem fs = thePath.getFileSystem(conf);
+ FileStatus[] matches =
fs.listStatus(FileUtil.stat2Paths(fs.globStatus(thePath,
PathFilters.partFilter())),
+ PathFilters.partFilter());
+
+ for (FileStatus match : matches) {
+ result.add(fs.makeQualified(match.getPath()));
+ }
+ seedVectors = new ArrayList<NamedVector>(100);
+ long item = 0;
+ for (Path seedPath : result) {
+ for (Writable value : new
SequenceFileValueIterable<Writable>(seedPath, conf)) {
+ Class<? extends Writable> valueClass = value.getClass();
+ if (valueClass.equals(Cluster.class)) {
+ // get the cluster info
+ Cluster cluster = (Cluster) value;
+ Vector vector = cluster.getCenter();
+ if (vector instanceof NamedVector) {
+ seedVectors.add((NamedVector) vector);
+ } else {
+ seedVectors.add(new NamedVector(vector,
cluster.getIdentifier()));
+ }
+ } else if (valueClass.equals(Canopy.class)) {
+ // get the cluster info
+ Canopy canopy = (Canopy) value;
+ Vector vector = canopy.getCenter();
+ if (vector instanceof NamedVector) {
+ seedVectors.add((NamedVector) vector);
+ } else {
+ seedVectors.add(new NamedVector(vector,
canopy.getIdentifier()));
+ }
+ } else if (valueClass.equals(Vector.class)) {
+ Vector vector = (Vector) value;
+ if (vector instanceof NamedVector) {
+ seedVectors.add((NamedVector) vector);
+ } else {
+ seedVectors.add(new NamedVector(vector, seedPath + "." +
item++));
+ }
+ } else if (valueClass.equals(VectorWritable.class) ||
valueClass.isInstance(VectorWritable.class)) {
+ VectorWritable vw = (VectorWritable) value;
+ Vector vector = vw.get();
+ if (vector instanceof NamedVector) {
+ seedVectors.add((NamedVector) vector);
+ } else {
+ seedVectors.add(new NamedVector(vector, seedPath + "." +
item++));
+ }
+ } else {
+ throw new IllegalStateException("Bad value class: " +
valueClass);
+ }
+ }
+ }
+ if (seedVectors.isEmpty()) {
+ throw new IllegalStateException("No seeds found. Check your path: "
+ seedPathStr);
+ } else {
+ log.info("Seed Vectors size: " + seedVectors.size());
+ }
+ }
+ } catch (ClassNotFoundException e) {
+ throw new IllegalStateException(e);
+ } catch (IllegalAccessException e) {
+ throw new IllegalStateException(e);
+ } catch (InstantiationException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+}
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java?rev=1147257&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
Fri Jul 15 17:31:25 2011
@@ -0,0 +1,112 @@
+package org.apache.mahout.math.hadoop.similarity;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+/**
+ * This class does a Map-side join between seed vectors (the map side can also
be a Cluster) and a list of other vectors
+ * and emits the a tuple of seed id, other id, distance. It is a more generic
version of KMean's mapper
+ */
+public class VectorDistanceSimilarityJob extends AbstractJob {
+ private static final Logger log =
LoggerFactory.getLogger(VectorDistanceSimilarityJob.class);
+ public static final String SEEDS = "seeds";
+ public static final String SEEDS_PATH_KEY = "seedsPath";
+ public static final String DISTANCE_MEASURE_KEY = "vectorDistSim.measure";
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(),
args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.distanceMeasureOption().create());
+ addOption(SEEDS, "s", "The set of vectors to compute distances against.
Must fit in memory on the mapper");
+ addOption(DefaultOptionCreator.overwriteOption().create());
+
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+
+ Path input = getInputPath();
+ Path output = getOutputPath();
+ Path seeds = new Path(getOption(SEEDS));
+ String measureClass =
getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+ if (measureClass == null) {
+ measureClass = SquaredEuclideanDistanceMeasure.class.getName();
+ }
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), output);
+ }
+ ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+ DistanceMeasure measure =
ccl.loadClass(measureClass).asSubclass(DistanceMeasure.class).newInstance();
+ if (getConf() == null) {
+ setConf(new Configuration());
+ }
+ run(getConf(), input, seeds, output, measure);
+ return 0;
+ }
+
+ public static void run(Configuration conf,
+ Path input,
+ Path seeds,
+ Path output,
+ DistanceMeasure measure) throws IOException,
ClassNotFoundException, InterruptedException {
+ conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName());
+ conf.set(SEEDS_PATH_KEY, seeds.toString());
+ Job job = new Job(conf, "Vector Distance Similarity: seeds: " + seeds + "
input: " + input);
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setMapOutputKeyClass(StringTuple.class);
+ job.setOutputKeyClass(StringTuple.class);
+ job.setMapOutputValueClass(DoubleWritable.class);
+ job.setOutputValueClass(DoubleWritable.class);
+ job.setMapperClass(VectorDistanceMapper.class);
+
+ job.setNumReduceTasks(0);
+ FileInputFormat.addInputPath(job, input);
+ FileOutputFormat.setOutputPath(job, output);
+
+ job.setJarByClass(VectorDistanceSimilarityJob.class);
+ HadoopUtil.delete(conf, output);
+ if (!job.waitForCompletion(true)) {
+ throw new InterruptedException("VectorDistance Similarity failed
processing " + seeds);
+ }
+ }
+}
Added:
mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java?rev=1147257&view=auto
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
(added)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
Fri Jul 15 17:31:25 2011
@@ -0,0 +1,136 @@
+package org.apache.mahout.math.hadoop.similarity;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.kmeans.KMeansDriver;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.StringTuple;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.easymock.EasyMock;
import org.junit.Before;
import org.junit.Test;

import java.util.ArrayList;
import java.util.List;
+
+/**
+ *
+ *
+ **/
+public class TestVectorDistanceSimilarityJob extends MahoutTestCase {
+ private FileSystem fs;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ Configuration conf = new Configuration();
+ fs = FileSystem.get(conf);
+ }
+
+ @Test
+ public void testVectorDistanceMapper() throws Exception {
+ Mapper<WritableComparable<?>, VectorWritable, StringTuple,
DoubleWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+ StringTuple tuple;
+ tuple = new StringTuple();
+ tuple.add("foo");
+ tuple.add("123");
+ context.write(tuple, new DoubleWritable(Math.sqrt(2.0)));
+ tuple = new StringTuple();
+ tuple.add("foo2");
+ tuple.add("123");
+ context.write(tuple, new DoubleWritable(1));
+
+ EasyMock.replay(context);
+
+ Vector vector = new RandomAccessSparseVector(2);
+ vector.set(0, 2);
+ vector.set(1, 2);
+
+ VectorDistanceMapper mapper = new VectorDistanceMapper();
+ setField(mapper, "measure", new EuclideanDistanceMeasure());
+ List<NamedVector> seedVectors = new ArrayList<NamedVector>();
+ Vector seed1 = new RandomAccessSparseVector(2);
+ seed1.set(0, 1);
+ seed1.set(1, 1);
+ Vector seed2 = new RandomAccessSparseVector(2);
+ seed2.set(0, 2);
+ seed2.set(1, 1);
+
+ seedVectors.add(new NamedVector(seed1, "foo"));
+ seedVectors.add(new NamedVector(seed2, "foo2"));
+ setField(mapper, "seedVectors", seedVectors);
+
+ mapper.map(new IntWritable(123), new VectorWritable(vector), context);
+
+ EasyMock.verify(context);
+
+ }
+
+ public static final double[][] REFERENCE = {
+ {1, 1}, {2, 1}, {1, 2}, {2, 2}, {3, 3}, {4, 4}, {5, 4}, {4, 5}, {5,
5}
+ };
+
+ public static final double[][] SEEDS = {
+ {1, 1}, {10, 10}
+ };
+
+ @Test
+ public void testRun() throws Exception {
+ Path input = getTestTempDirPath("input");
+ Path output = getTestTempDirPath("output");
+ Path seedsPath = getTestTempDirPath("seeds");
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+ List<VectorWritable> seeds = getPointsWritable(SEEDS);
+ Configuration conf = new Configuration();
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(input,
"file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath,
"part-seeds"), fs, conf);
+ String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION),
input.toString(),
+ optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(),
optKey(DefaultOptionCreator.OUTPUT_OPTION),
+ output.toString(),
optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
EuclideanDistanceMeasure.class.getName()
+ };
+ ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(),
args);
+ }
+
+ public static List<VectorWritable> getPointsWritable(double[][] raw) {
+ List<VectorWritable> points = Lists.newArrayList();
+ for (double[] fr : raw) {
+ Vector vec = new RandomAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(new VectorWritable(vec));
+ }
+ return points;
+ }
+
+}