Author: luc Date: Mon Jun 20 19:48:44 2011 New Revision: 1137759 URL: http://svn.apache.org/viewvc?rev=1137759&view=rev Log: added multiple trial runs to the K-means++ clustering algorithm.
JIRA: MATH-548 Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java commons/proper/math/trunk/src/site/xdoc/changes.xml commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClustererTest.java Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java?rev=1137759&r1=1137758&r2=1137759&view=diff ============================================================================== --- commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java (original) +++ commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java Mon Jun 20 19:48:44 2011 @@ -96,6 +96,60 @@ public class KMeansPlusPlusClusterer<T e * of clusters is larger than the number of data points */ public List<Cluster<T>> cluster(final Collection<T> points, final int k, + int numTrials, int maxIterationsPerTrial) + throws MathIllegalArgumentException { + + // at first, we have not found any clusters list yet + List<Cluster<T>> best = null; + double bestVarianceSum = Double.POSITIVE_INFINITY; + + // do several clustering trials + for (int i = 0; i < numTrials; ++i) { + + // compute a clusters list + List<Cluster<T>> clusters = cluster(points, k, maxIterationsPerTrial); + + // compute the variance of the current list + double varianceSum = 0.0; + for (final Cluster<T> cluster : clusters) { + if (!cluster.getPoints().isEmpty()) { + + // compute the distance variance of the current cluster + final T center = cluster.getCenter(); + final Variance stat = new Variance(); + for (final T point : cluster.getPoints()) { + stat.increment(point.distanceFrom(center)); + } + varianceSum += stat.getResult(); + + } + } + + if 
(varianceSum <= bestVarianceSum) { + // this one is the best we have found so far, remember it + best = clusters; + bestVarianceSum = varianceSum; + } + + } + + // return the best clusters list found + return best; + + } + + /** + * Runs the K-means++ clustering algorithm. + * + * @param points the points to cluster + * @param k the number of clusters to split the data into + * @param maxIterations the maximum number of iterations to run the algorithm + * for. If negative, no maximum will be used + * @return a list of clusters containing the points + * @throws MathIllegalArgumentException if the data points are null or the number + * of clusters is larger than the number of data points + */ + public List<Cluster<T>> cluster(final Collection<T> points, final int k, final int maxIterations) throws MathIllegalArgumentException { Modified: commons/proper/math/trunk/src/site/xdoc/changes.xml URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/site/xdoc/changes.xml?rev=1137759&r1=1137758&r2=1137759&view=diff ============================================================================== --- commons/proper/math/trunk/src/site/xdoc/changes.xml (original) +++ commons/proper/math/trunk/src/site/xdoc/changes.xml Mon Jun 20 19:48:44 2011 @@ -52,6 +52,9 @@ The <action> type attribute can be add,u If the output is not quite correct, check for invisible trailing spaces! 
--> <release version="3.0" date="TBD" description="TBD"> + <action dev="luc" type="add" issue="MATH-548"> + K-means++ clustering can now run multiple trials + </action> <action dev="luc" type="add" issue="MATH-591"> Added a way to compute sub-lines intersections, considering sub-lines either as open sets or closed sets Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClustererTest.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClustererTest.java?rev=1137759&r1=1137758&r2=1137759&view=diff ============================================================================== --- commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClustererTest.java (original) +++ commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClustererTest.java Mon Jun 20 19:48:44 2011 @@ -65,7 +65,7 @@ public class KMeansPlusPlusClustererTest }; List<Cluster<EuclideanIntegerPoint>> clusters = - transformer.cluster(Arrays.asList(points), 3, 10); + transformer.cluster(Arrays.asList(points), 3, 5, 10); Assert.assertEquals(3, clusters.size()); boolean cluster1Found = false;