This is an automated email from the ASF dual-hosted git repository.

erans pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-math.git

commit c6b4ca908cc53c6d5d89e911391337086062e74a
Author: Gilles Sadowski <gillese...@gmail.com>
AuthorDate: Sat Jan 22 18:53:17 2022 +0100

    MATH-1640: Do not try to outguess the caller.
---
 .../ml/clustering/KMeansPlusPlusClusterer.java     | 24 +++++++++++++------
 .../ml/clustering/KMeansPlusPlusClustererTest.java | 28 +++++++++-------------
 .../clustering/MiniBatchKMeansClustererTest.java   |  4 ++--
 .../evaluation/CalinskiHarabaszTest.java           |  4 ++--
 4 files changed, 32 insertions(+), 28 deletions(-)

diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClusterer.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClusterer.java
index 9890194..57ab663 100644
--- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClusterer.java
+++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClusterer.java
@@ -19,6 +19,7 @@ package org.apache.commons.math4.legacy.ml.clustering;
 
 import org.apache.commons.math4.legacy.exception.NullArgumentException;
 import org.apache.commons.math4.legacy.exception.ConvergenceException;
+import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
 import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
 import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
 import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
@@ -79,7 +80,7 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
      * @param k the number of clusters to split the data into
      */
     public KMeansPlusPlusClusterer(final int k) {
-        this(k, -1);
+        this(k, Integer.MAX_VALUE);
     }
 
     /** Build a clusterer.
@@ -104,8 +105,8 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
      *
      * @param k the number of clusters to split the data into
      * @param maxIterations the maximum number of iterations to run the algorithm for.
-     *   If negative, no maximum will be used.
      * @param measure the distance measure to use
+     * @throws NotStrictlyPositiveException if {@code k <= 0}.
      */
     public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
         this(k, maxIterations, measure, RandomSource.MT_64.create());
@@ -132,20 +133,30 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
      *
      * @param k the number of clusters to split the data into
      * @param maxIterations the maximum number of iterations to run the algorithm for.
-     *   If negative, no maximum will be used.
      * @param measure the distance measure to use
      * @param random random generator to use for choosing initial centers
      * @param emptyStrategy strategy to use for handling empty clusters that
      * may appear during algorithm iterations
+     * @throws NotStrictlyPositiveException if {@code k <= 0} or
+     * {@code maxIterations <= 0}.
      */
-    public KMeansPlusPlusClusterer(final int k, final int maxIterations,
+    public KMeansPlusPlusClusterer(final int k,
+                                   final int maxIterations,
                                    final DistanceMeasure measure,
                                    final UniformRandomProvider random,
                                    final EmptyClusterStrategy emptyStrategy) {
         super(measure);
+
+        if (k <= 0) {
+            throw new NotStrictlyPositiveException(k);
+        }
+        if (maxIterations <= 0) {
+            throw new NotStrictlyPositiveException(maxIterations);
+        }
+
         this.numberOfClusters = k;
         this.maxIterations = maxIterations;
-        this.random        = random;
+        this.random = random;
         this.emptyStrategy = emptyStrategy;
     }
 
@@ -195,8 +206,7 @@ public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T>
         assignPointsToClusters(clusters, points, assignments);
 
         // iterate through updating the centers until we're done
-        final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
-        for (int count = 0; count < max; count++) {
+        for (int count = 0; count < maxIterations; count++) {
             boolean hasEmptyCluster = clusters.stream().anyMatch(cluster->cluster.getPoints().isEmpty());
             List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
             int changes = assignPointsToClusters(newClusters, points, assignments);
diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClustererTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClustererTest.java
index a9e4979..a7f63b7 100644
--- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClustererTest.java
+++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/KMeansPlusPlusClustererTest.java
@@ -133,27 +133,21 @@ public class KMeansPlusPlusClustererTest {
     public void testSmallDistances() {
         // Create a bunch of CloseDoublePoints. Most are identical, but one is different by a
         // small distance.
-        int[] repeatedArray = { 0 };
-        int[] uniqueArray = { 1 };
-        DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
-        DoublePoint uniquePoint = new DoublePoint(uniqueArray);
-
-        Collection<DoublePoint> points = new ArrayList<>();
-        final int NUM_REPEATED_POINTS = 10 * 1000;
-        for (int i = 0; i < NUM_REPEATED_POINTS; ++i) {
+        final int[] repeatedArray = { 0 };
+        final int[] uniqueArray = { 1 };
+        final DoublePoint repeatedPoint = new DoublePoint(repeatedArray);
+        final DoublePoint uniquePoint = new DoublePoint(uniqueArray);
+
+        final Collection<DoublePoint> points = new ArrayList<>();
+        final int numRepeated = 10000;
+        for (int i = 0; i < numRepeated; i++) {
             points.add(repeatedPoint);
         }
         points.add(uniquePoint);
 
-        // Ask a KMeansPlusPlusClusterer to run zero iterations (i.e., to simply choose initial
-        // cluster centers).
-        final int NUM_CLUSTERS = 2;
-        final int NUM_ITERATIONS = 0;
-
-        KMeansPlusPlusClusterer<DoublePoint> clusterer =
-            new KMeansPlusPlusClusterer<>(NUM_CLUSTERS, NUM_ITERATIONS,
-                    new CloseDistance(), random);
-        List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
+        final KMeansPlusPlusClusterer<DoublePoint> clusterer =
+            new KMeansPlusPlusClusterer<>(2, 1, new CloseDistance(), random);
+        final List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
 
         // Check that one of the chosen centers is the unique point.
         boolean uniquePointIsCenter = false;
diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/MiniBatchKMeansClustererTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/MiniBatchKMeansClustererTest.java
index ca6c7d1..8b9c222 100644
--- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/MiniBatchKMeansClustererTest.java
+++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/MiniBatchKMeansClustererTest.java
@@ -58,9 +58,9 @@ public class MiniBatchKMeansClustererTest {
         final UniformRandomProvider rng = RandomSource.MT_64.create();
         List<DoublePoint> data = generateCircles(rng);
         KMeansPlusPlusClusterer<DoublePoint> kMeans =
-            new KMeansPlusPlusClusterer<>(4, -1, DEFAULT_MEASURE, rng);
+            new KMeansPlusPlusClusterer<>(4, Integer.MAX_VALUE, DEFAULT_MEASURE, rng);
         MiniBatchKMeansClusterer<DoublePoint> miniBatchKMeans =
-            new MiniBatchKMeansClusterer<>(4, -1, 100, 3, 300, 10, DEFAULT_MEASURE, rng,
+            new MiniBatchKMeansClusterer<>(4, Integer.MAX_VALUE, 100, 3, 300, 10, DEFAULT_MEASURE, rng,
                                            KMeansPlusPlusClusterer.EmptyClusterStrategy.LARGEST_VARIANCE);
         // Test 100 times between KMeansPlusPlusClusterer and MiniBatchKMeansClusterer
         for (int i = 0; i < 100; i++) {
diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/evaluation/CalinskiHarabaszTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/evaluation/CalinskiHarabaszTest.java
index 03f5295..f92c18d 100644
--- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/evaluation/CalinskiHarabaszTest.java
+++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/ml/clustering/evaluation/CalinskiHarabaszTest.java
@@ -63,7 +63,7 @@ public class CalinskiHarabaszTest {
         double actualBestScore = 0.0;
         for (int i = 0; i < 5; i++) {
             final int k = i + 2;
-            KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, -1, distanceMeasure, rnd);
+            KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, Integer.MAX_VALUE, distanceMeasure, rnd);
             List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points);
             double score = evaluator.score(clusters);
             if (score > expectBestScore) {
@@ -89,7 +89,7 @@ public class CalinskiHarabaszTest {
         double actualBestScore = 0.0;
         for (int i = 0; i < 5; i++) {
             final int k = i + 2;
-            KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, -1, distanceMeasure, rnd);
+            KMeansPlusPlusClusterer<DoublePoint> kMeans = new KMeansPlusPlusClusterer<>(k, Integer.MAX_VALUE, distanceMeasure, rnd);
             List<CentroidCluster<DoublePoint>> clusters = kMeans.cluster(points);
             double score = evaluator.score(clusters);
             if (score > expectBestScore) {

Reply via email to