Author: ssc
Date: Tue Mar 12 06:43:03 2013
New Revision: 1455420

URL: http://svn.apache.org/r1455420
Log:
MAHOUT-1130 Wrong logic in org.apache.mahout.clustering.kmeans.RandomSeedGenerator

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java?rev=1455420&r1=1455419&r2=1455420&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java	(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java	Tue Mar 12 06:43:03 2013
@@ -21,6 +21,7 @@ import java.io.IOException;
 import java.util.List;
 import java.util.Random;
 
+import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 import com.google.common.io.Closeables;
 import org.apache.hadoop.conf.Configuration;
@@ -45,6 +46,8 @@ import org.slf4j.LoggerFactory;
  * Given an Input Path containing a {@link org.apache.hadoop.io.SequenceFile}, 
randomly select k vectors and
  * write them to the output file as a {@link 
org.apache.mahout.clustering.kmeans.Kluster} representing the
  * initial centroid to use.
+ *
+ * This implementation uses reservoir sampling as described in 
http://en.wikipedia.org/wiki/Reservoir_sampling
  */
 public final class RandomSeedGenerator {
   
@@ -60,6 +63,8 @@ public final class RandomSeedGenerator {
                                  Path output,
                                  int k,
                                  DistanceMeasure measure) throws IOException {
+
+    Preconditions.checkArgument(k > 0);
     // delete the output directory
     FileSystem fs = FileSystem.get(output.toUri(), conf);
     HadoopUtil.delete(conf, output);
@@ -80,7 +85,8 @@ public final class RandomSeedGenerator {
       List<Text> chosenTexts = Lists.newArrayListWithCapacity(k);
       List<ClusterWritable> chosenClusters = Lists.newArrayListWithCapacity(k);
       int nextClusterId = 0;
-      
+
+      int index = 0;
       for (FileStatus fileStatus : inputFiles) {
         if (fileStatus.isDir()) {
           continue;
@@ -98,15 +104,16 @@ public final class RandomSeedGenerator {
             ClusterWritable clusterWritable = new ClusterWritable();
             clusterWritable.setValue(newCluster);
             chosenClusters.add(clusterWritable);
-          } else if (random.nextInt(currentSize + 1) != 0) { // with chance 
1/(currentSize+1) pick new element
-            int indexToRemove = random.nextInt(currentSize); // evict one 
chosen randomly
-            chosenTexts.remove(indexToRemove);
-            chosenClusters.remove(indexToRemove);
-            chosenTexts.add(newText);
-            ClusterWritable clusterWritable = new ClusterWritable();
-            clusterWritable.setValue(newCluster);
-            chosenClusters.add(clusterWritable);
+          } else {
+            int j = random.nextInt(index);
+            if (j < k) {
+              chosenTexts.set(j, newText);
+              ClusterWritable clusterWritable = new ClusterWritable();
+              clusterWritable.setValue(newCluster);
+              chosenClusters.set(j, clusterWritable);
+            }
           }
+          index++;
         }
       }
 


Reply via email to