Repository: commons-rng
Updated Branches:
  refs/heads/master 1bbf43bbc -> 477598909


RNG-64: Created SubsetSamplerUtils utility class

Move shared code from the PermutationSampler and CombinationSampler to
the utility class.

Test the PermutationSampler with k < n.


Project: http://git-wip-us.apache.org/repos/asf/commons-rng/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-rng/commit/42a2b516
Tree: http://git-wip-us.apache.org/repos/asf/commons-rng/tree/42a2b516
Diff: http://git-wip-us.apache.org/repos/asf/commons-rng/diff/42a2b516

Branch: refs/heads/master
Commit: 42a2b5160a7d73bc7d4a0beae396ec0daa31c1ae
Parents: 1bbf43b
Author: Alex Herbert <a.herb...@sussex.ac.uk>
Authored: Fri Nov 23 12:51:13 2018 +0000
Committer: Gilles <er...@apache.org>
Committed: Fri Nov 23 14:39:42 2018 +0100

----------------------------------------------------------------------
 .../rng/sampling/CombinationSampler.java        | 64 ++-----------
 .../rng/sampling/PermutationSampler.java        | 35 +------
 .../rng/sampling/SubsetSamplerUtils.java        | 96 ++++++++++++++++++++
 .../rng/sampling/PermutationSamplerTest.java    | 67 ++++++++------
 4 files changed, 148 insertions(+), 114 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-rng/blob/42a2b516/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CombinationSampler.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CombinationSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CombinationSampler.java
index ca95f42..31159c4 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CombinationSampler.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/CombinationSampler.java
@@ -41,17 +41,12 @@ import org.apache.commons.rng.UniformRandomProvider;
 public class CombinationSampler {
     /** Domain of the combination. */
     private final int[] domain;
-    /** Size of the combination. */
-    private final int size;
     /** The number of steps of a full shuffle to perform. */
     private final int steps;
     /**
-     * The position to copy the domain from after a partial shuffle.
-     *
-     * <p>The copy is either in the range [0 : size] or [domain.length - size :
-     * domain.length].
+     * If {@code true}, copy the upper section of the domain after a partial shuffle, otherwise the lower section.
      */
-    private final int from;
+    private final boolean upper;
     /** RNG. */
     private final UniformRandomProvider rng;
 
@@ -77,34 +72,19 @@ public class CombinationSampler {
      * @throws IllegalArgumentException if {@code n <= 0} or {@code k <= 0} or
      *                                  {@code k > n}.
      */
-    public CombinationSampler(UniformRandomProvider rng, int n, int k) {
-        if (n <= 0) {
-            throw new IllegalArgumentException("n <= 0 : n=" + n);
-        }
-        if (k <= 0) {
-            throw new IllegalArgumentException("k <= 0 : k=" + k);
-        }
-        if (k > n) {
-            throw new IllegalArgumentException("k > n : k=" + k + ", n=" + n);
-        }
-
+    public CombinationSampler(UniformRandomProvider rng,
+                              int n,
+                              int k) {
+        SubsetSamplerUtils.checkSubset(n, k);
         domain = PermutationSampler.natural(n);
-        size = k;
         // The sample can be optimised by only performing the first k or (n - k) steps
         // from a full Fisher-Yates shuffle from the end of the domain to the start.
         // The upper positions will then contain a random sample from the domain. The
         // lower half is then by definition also a random sample (just not in a random order).
         // The sample is then picked using the upper or lower half depending which
         // makes the number of steps smaller.
-        if (k <= n / 2) {
-            // Upper half
-            steps = k;
-            from = n - k;
-        } else {
-            // Lower half
-            steps = n - k;
-            from = 0;
-        }
+        upper = k <= n / 2;
+        steps = upper ? k : n - k;
         this.rng = rng;
     }
 
@@ -118,32 +98,6 @@ public class CombinationSampler {
      * @return a random combination.
      */
     public int[] sample() {
-        // Shuffle from the end but limit to a number of steps.
-        // The subset C(n, k) is then either those positions that have
-        // been sampled, or those that haven't been, depending
-        // on the number of steps.
-        // Note: if n==k the number of steps is zero and the result
-        // is just a clone of the domain.
-        for (int i = domain.length - 1,
-                j = 0; i > 0 && j < steps; i--, j++) {
-            // Swap index i with any position down to 0 (including itself)
-            swap(domain, i, rng.nextInt(i + 1));
-        }
-        final int[] result = new int[size];
-        System.arraycopy(domain, from, result, 0, size);
-        return result;
-    }
-
-    /**
-     * Swaps the two specified elements in the specified array.
-     *
-     * @param array the array
-     * @param i     the first index
-     * @param j     the second index
-     */
-    private static void swap(int[] array, int i, int j) {
-        final int tmp = array[i];
-        array[i] = array[j];
-        array[j] = tmp;
+        return SubsetSamplerUtils.partialSample(domain, steps, rng, upper);
     }
 }

http://git-wip-us.apache.org/repos/asf/commons-rng/blob/42a2b516/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/PermutationSampler.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/PermutationSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/PermutationSampler.java
index fe27f33..4b27670 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/PermutationSampler.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/PermutationSampler.java
@@ -17,8 +17,6 @@
 
 package org.apache.commons.rng.sampling;
 
-import java.util.Arrays;
-
 import org.apache.commons.rng.UniformRandomProvider;
 
 /**
@@ -48,22 +46,13 @@ public class PermutationSampler {
      * @param rng Generator of uniformly distributed random numbers.
      * @param n Domain of the permutation.
      * @param k Size of the permutation.
-     * @throws IllegalArgumentException if {@code n < 0} or {@code k <= 0}
+     * @throws IllegalArgumentException if {@code n <= 0} or {@code k <= 0}
      * or {@code k > n}.
      */
     public PermutationSampler(UniformRandomProvider rng,
                               int n,
                               int k) {
-        if (n < 0) {
-            throw new IllegalArgumentException("n < 0 : n=" + n);
-        }
-        if (k <= 0) {
-            throw new IllegalArgumentException("k <= 0 : k=" + k);
-        }
-        if (k > n) {
-            throw new IllegalArgumentException("k > n : k=" + k + ", n=" + n);
-        }
-
+        SubsetSamplerUtils.checkSubset(n, k);
         domain = natural(n);
         size = k;
         this.rng = rng;
@@ -75,8 +64,7 @@ public class PermutationSampler {
      * @see #PermutationSampler(UniformRandomProvider,int,int)
      */
     public int[] sample() {
-        shuffle(rng, domain);
-        return Arrays.copyOf(domain, size);
+        return SubsetSamplerUtils.partialSample(domain, size, rng, true);
     }
 
     /**
@@ -115,7 +103,7 @@ public class PermutationSampler {
             // Do not visit 0 to avoid a swap with itself.
             for (int i = start; i > 0; i--) {
                 // Swap index with any position down to 0
-                swap(list, i, rng.nextInt(i + 1));
+                SubsetSamplerUtils.swap(list, i, rng.nextInt(i + 1));
             }
         } else {
             // Visit all positions from the end to start.
@@ -123,25 +111,12 @@ public class PermutationSampler {
             for (int i = list.length - 1; i > start; i--) {
                 // Swap index with any position down to start.
                 // Note: i - start + 1 is the number of elements remaining.
-                swap(list, i, rng.nextInt(i - start + 1) + start);
+                SubsetSamplerUtils.swap(list, i, rng.nextInt(i - start + 1) + start);
             }
         }
     }
 
     /**
-     * Swaps the two specified elements in the specified array.
-     *
-     * @param array the array
-     * @param i the first index
-     * @param j the second index
-     */
-    private static void swap(int[] array, int i, int j) {
-        final int tmp = array[i];
-        array[i] = array[j];
-        array[j] = tmp;
-    }
-
-    /**
      * Creates an array representing the natural number {@code n}.
      *
      * @param n Natural number.

http://git-wip-us.apache.org/repos/asf/commons-rng/blob/42a2b516/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/SubsetSamplerUtils.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/SubsetSamplerUtils.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/SubsetSamplerUtils.java
new file mode 100644
index 0000000..32d58a1
--- /dev/null
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/SubsetSamplerUtils.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.rng.sampling;
+
+import org.apache.commons.rng.UniformRandomProvider;
+
+/**
+ * Utility class for validating and selecting a subset of a sequence of integers.
+ */
+final class SubsetSamplerUtils {
+
+    /** No public construction. */
+    private SubsetSamplerUtils() {}
+
+    /**
+     * Checks that a subset of length {@code k} drawn from a set of size {@code n} is valid.
+     *
+     * <p>If {@code n <= 0} or {@code k <= 0} or {@code k > n} then the subset
+     * is invalid and an exception is raised.
+     *
+     * @param n   Size of the set.
+     * @param k   Size of the subset.
+     * @throws IllegalArgumentException if {@code n <= 0} or {@code k <= 0} or
+     *                                  {@code k > n}.
+     */
+    static void checkSubset(int n,
+                            int k) {
+        if (n <= 0) {
+            throw new IllegalArgumentException("n <= 0 : n=" + n);
+        }
+        if (k <= 0) {
+            throw new IllegalArgumentException("k <= 0 : k=" + k);
+        }
+        if (k > n) {
+            throw new IllegalArgumentException("k > n : k=" + k + ", n=" + n);
+        }
+    }
+
+    /**
+     * Perform a partial Fisher-Yates shuffle of the domain in-place and return
+     * either the upper fully shuffled section or the remaining lower partially
+     * shuffled section.
+     *
+     * <p>The returned combination will have a length of {@code steps} for
+     * {@code upper=true}, or {@code domain.length - steps} otherwise.
+     *
+     * @param domain The domain (modified in-place by the shuffle).
+     * @param steps  The number of shuffle steps, expected in {@code [0, domain.length]}.
+     * @param rng    Generator of uniformly distributed random numbers.
+     * @param upper  Set to true to return the upper fully shuffled section.
+     * @return a new array holding the selected section of the shuffled domain.
+     */
+    static int[] partialSample(int[] domain,
+                               int steps,
+                               UniformRandomProvider rng,
+                               boolean upper) {
+        // Shuffle from the end, limited to steps (index 0 is skipped: it can only swap with itself).
+        for (int i = domain.length - 1, j = 0; i > 0 && j < steps; i--, j++) {
+            // Swap index i with any position down to 0 (including itself)
+            swap(domain, i, rng.nextInt(i + 1));
+        }
+        final int size = upper ? steps : domain.length - steps;
+        final int from = upper ? domain.length - steps : 0;
+        final int[] result = new int[size];
+        System.arraycopy(domain, from, result, 0, size);
+        return result;
+    }
+
+    /**
+     * Swaps the two specified elements in the specified array.
+     *
+     * @param array the array
+     * @param i     the first index
+     * @param j     the second index
+     */
+    static void swap(int[] array, int i, int j) {
+        final int tmp = array[i];
+        array[i] = array[j];
+        array[j] = tmp;
+    }
+}

http://git-wip-us.apache.org/repos/asf/commons-rng/blob/42a2b516/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/PermutationSamplerTest.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/PermutationSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/PermutationSamplerTest.java
index a299f2c..c887147 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/PermutationSamplerTest.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/PermutationSamplerTest.java
@@ -16,17 +16,12 @@
  */
 package org.apache.commons.rng.sampling;
 
-import java.util.Set;
-import java.util.HashSet;
-import java.util.List;
-import java.util.ArrayList;
 import java.util.Arrays;
 
 import org.junit.Assert;
 import org.junit.Test;
 
 import org.apache.commons.math3.stat.inference.ChiSquareTest;
-import org.apache.commons.math3.util.MathArrays;
 
 import org.apache.commons.rng.UniformRandomProvider;
 import org.apache.commons.rng.simple.RandomSource;
@@ -57,25 +52,27 @@ public class PermutationSamplerTest {
 
     @Test
     public void testSampleChiSquareTest() {
+        final int n = 3;
+        final int k = 3;
         final int[][] p = { { 0, 1, 2 }, { 0, 2, 1 },
                             { 1, 0, 2 }, { 1, 2, 0 },
                             { 2, 0, 1 }, { 2, 1, 0 } };
-        final int len = p.length; 
-        final long[] observed = new long[len];
-        final int numSamples = 6000;
-        final double numExpected = numSamples / (double) len;
-        final double[] expected = new double[len];
-        Arrays.fill(expected, numExpected);
-
-        final PermutationSampler sampler = new PermutationSampler(rng, 3, 3);
-        for (int i = 0; i < numSamples; i++) {
-            observed[findPerm(p, sampler.sample())]++;
-        }
+        runSampleChiSquareTest(n, k, p);
+    }
 
-        // Pass if we cannot reject null hypothesis that distributions are the same.
-        Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
+    @Test
+    public void testSubSampleChiSquareTest() {
+        final int n = 4;
+        final int k = 2;
+        final int[][] p = { { 0, 1 }, { 1, 0 },
+                            { 0, 2 }, { 2, 0 },
+                            { 0, 3 }, { 3, 0 },
+                            { 1, 2 }, { 2, 1 },
+                            { 1, 3 }, { 3, 1 },
+                            { 2, 3 }, { 3, 2 } };
+        runSampleChiSquareTest(n, k, p);
     }
-    
+
     @Test
     public void testSampleBoundaryCase() {
         // Check size = 1 boundary case.
@@ -191,20 +188,32 @@ public class PermutationSamplerTest {
 
     //// Support methods.
 
-    private int findPerm(int[][] p,
-                         int[] samp) {
+    private void runSampleChiSquareTest(int n,
+                                        int k,
+                                        int[][] p) {
+        final int len = p.length;
+        final long[] observed = new long[len];
+        final int numSamples = 6000;
+        final double numExpected = numSamples / (double) len;
+        final double[] expected = new double[len];
+        Arrays.fill(expected, numExpected);
+
+        final PermutationSampler sampler = new PermutationSampler(rng, n, k);
+        for (int i = 0; i < numSamples; i++) {
+            observed[findPerm(p, sampler.sample())]++;
+        }
+
+        // Pass if we cannot reject null hypothesis that distributions are the same.
+        Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
+    }
+
+    private static int findPerm(int[][] p,
+                                int[] samp) {
         for (int i = 0; i < p.length; i++) {
-            boolean good = true;
-            for (int j = 0; j < samp.length; j++) {
-                if (samp[j] != p[i][j]) {
-                    good = false;
-                }
-            }
-            if (good) {
+            if (Arrays.equals(p[i], samp)) {
                 return i;
             }
         }
-
         Assert.fail("Permutation not found");
         return -1;
     }

Reply via email to