This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new fe338b19 [FLINK-31623] Fix DataStreamUtils#sample with approximate uniform sampling
fe338b19 is described below

commit fe338b194b73fd51218f4d842fa7b0065fb76c56
Author: Fan Hong <hongfa...@gmail.com>
AuthorDate: Mon Apr 3 15:15:15 2023 +0800

    [FLINK-31623] Fix DataStreamUtils#sample with approximate uniform sampling
    
    This closes #227.
---
 .../flink/ml/common/datastream/DataStreamUtils.java   | 19 ++++++++++++++-----
 .../ml/common/datastream/DataStreamUtilsTest.java     | 15 +++++++++++++++
 2 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
index 691e7704..eb4ec6ca 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -275,7 +275,10 @@ public class DataStreamUtils {
     }
 
     /**
-     * Performs a uniform sampling over the elements in a bounded data stream.
+     * Performs an approximate uniform sampling over the elements in a bounded data stream. The
+     * difference of probabilities of two data points being sampled is bounded by O(numSamples * p *
+     * p / (M * M)), where p is the parallelism of the input stream, M is the total number of data
+     * points that the input stream contains.
      *
      * <p>This method takes samples without replacement. If the number of 
elements in the stream is
      * smaller than expected number of samples, all elements will be included 
in the sample.
@@ -288,13 +291,19 @@ public class DataStreamUtils {
     public static <T> DataStream<T> sample(DataStream<T> input, int 
numSamples, long randomSeed) {
         int inputParallelism = input.getParallelism();
 
-        return input.transform(
-                        "samplingOperator",
+        // The maximum difference of number of data points in each partition 
after calling
+        // `rebalance` is `inputParallelism`. As a result, extra 
`inputParallelism` data points are
+        // sampled for each partition in the first round.
+        int firstRoundNumSamples =
+                Math.min((numSamples / inputParallelism) + inputParallelism, 
numSamples);
+        return input.rebalance()
+                .transform(
+                        "firstRoundSampling",
                         input.getType(),
-                        new SamplingOperator<>(numSamples, randomSeed))
+                        new SamplingOperator<>(firstRoundNumSamples, 
randomSeed))
                 .setParallelism(inputParallelism)
                 .transform(
-                        "samplingOperator",
+                        "secondRoundSampling",
                         input.getType(),
                         new SamplingOperator<>(numSamples, randomSeed))
                 .setParallelism(1)
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
index d3f8a95e..7b3e8b3a 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
@@ -100,6 +100,21 @@ public class DataStreamUtilsTest {
         assertEquals(Integer.toString(190 + env.getParallelism()), 
stringSum.get(0));
     }
 
+    @Test
+    public void testSample() throws Exception {
+        int numSamples = 10;
+        int[] totalMinusOneChoices = new int[] {0, 5, 9, 10, 11, 20, 30, 40, 
200};
+        for (int totalMinusOne : totalMinusOneChoices) {
+            DataStream<Long> dataStream =
+                    env.fromParallelCollection(
+                            new NumberSequenceIterator(0L, totalMinusOne), 
Types.LONG);
+            DataStream<Long> result = DataStreamUtils.sample(dataStream, 
numSamples, 0);
+            //noinspection unchecked
+            List<String> sampled = 
IteratorUtils.toList(result.executeAndCollect());
+            assertEquals(Math.min(numSamples, totalMinusOne + 1), 
sampled.size());
+        }
+    }
+
     @Test
     public void testGenerateBatchData() throws Exception {
         DataStream<Long> dataStream =

Reply via email to