This is an automated email from the ASF dual-hosted git repository. zhangzp pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push: new fe338b19 [FLINK-31623] Fix DataStreamUtils#sample with approximate uniform sampling fe338b19 is described below commit fe338b194b73fd51218f4d842fa7b0065fb76c56 Author: Fan Hong <hongfa...@gmail.com> AuthorDate: Mon Apr 3 15:15:15 2023 +0800 [FLINK-31623] Fix DataStreamUtils#sample with approximate uniform sampling This closes #227. --- .../flink/ml/common/datastream/DataStreamUtils.java | 19 ++++++++++++++----- .../ml/common/datastream/DataStreamUtilsTest.java | 15 +++++++++++++++ 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java index 691e7704..eb4ec6ca 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java @@ -275,7 +275,10 @@ public class DataStreamUtils { } /** - * Performs a uniform sampling over the elements in a bounded data stream. + * Performs an approximate uniform sampling over the elements in a bounded data stream. The + * difference between the probabilities of any two data points being sampled is bounded by + * O(numSamples * p * p / (M * M)), where p is the parallelism of the input stream and M is the + * total number of data points that the input stream contains. * * <p>This method takes samples without replacement. If the number of elements in the stream is * smaller than expected number of samples, all elements will be included in the sample. 
@@ -288,13 +291,19 @@ public class DataStreamUtils { public static <T> DataStream<T> sample(DataStream<T> input, int numSamples, long randomSeed) { int inputParallelism = input.getParallelism(); - return input.transform( - "samplingOperator", + // The maximum difference of number of data points in each partition after calling + // `rebalance` is `inputParallelism`. As a result, extra `inputParallelism` data points are + // sampled for each partition in the first round. + int firstRoundNumSamples = + Math.min((numSamples / inputParallelism) + inputParallelism, numSamples); + return input.rebalance() + .transform( + "firstRoundSampling", input.getType(), - new SamplingOperator<>(numSamples, randomSeed)) + new SamplingOperator<>(firstRoundNumSamples, randomSeed)) .setParallelism(inputParallelism) .transform( - "samplingOperator", + "secondRoundSampling", input.getType(), new SamplingOperator<>(numSamples, randomSeed)) .setParallelism(1) diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java index d3f8a95e..7b3e8b3a 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java @@ -100,6 +100,21 @@ public class DataStreamUtilsTest { assertEquals(Integer.toString(190 + env.getParallelism()), stringSum.get(0)); } + @Test + public void testSample() throws Exception { + int numSamples = 10; + int[] totalMinusOneChoices = new int[] {0, 5, 9, 10, 11, 20, 30, 40, 200}; + for (int totalMinusOne : totalMinusOneChoices) { + DataStream<Long> dataStream = + env.fromParallelCollection( + new NumberSequenceIterator(0L, totalMinusOne), Types.LONG); + DataStream<Long> result = DataStreamUtils.sample(dataStream, numSamples, 0); + //noinspection unchecked + List<String> sampled = 
IteratorUtils.toList(result.executeAndCollect()); + assertEquals(Math.min(numSamples, totalMinusOne + 1), sampled.size()); + } + } + @Test public void testGenerateBatchData() throws Exception { DataStream<Long> dataStream =