This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4fb4648ce3d [SPARK-38879][PYTHON][TEST] Improve the test coverage for pyspark/rddsampler.py
4fb4648ce3d is described below

commit 4fb4648ce3d7fab65ccfceb86cb6c839d0c921da
Author: Kumar, Pralabh <[email protected]>
AuthorDate: Wed Apr 27 10:11:16 2022 +0900

    [SPARK-38879][PYTHON][TEST] Improve the test coverage for pyspark/rddsampler.py
    
    ### What changes were proposed in this pull request?
    This PR adds test cases for `RDDSampler` and `RDDStratifiedSampler` in pyspark/rddsampler.py.
    
    ### Why are the changes needed?
    To cover corner cases and increase test coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    No - test only
    
    ### How was this patch tested?
    CI in this PR should test it out
    
    Closes #36342 from pralabhkumar/rk_rdd_sampler.
    
    Lead-authored-by: Kumar, Pralabh <[email protected]>
    Co-authored-by: pralabhkumar <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 dev/sparktestsupport/modules.py         |  1 +
 python/pyspark/tests/test_rddsampler.py | 66 +++++++++++++++++++++++++++++++++
 2 files changed, 67 insertions(+)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 5514df11f9a..ed1eeb9b807 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -397,6 +397,7 @@ pyspark_core = Module(
         "pyspark.tests.test_profiler",
         "pyspark.tests.test_rdd",
         "pyspark.tests.test_rddbarrier",
+        "pyspark.tests.test_rddsampler",
         "pyspark.tests.test_readwrite",
         "pyspark.tests.test_serializers",
         "pyspark.tests.test_shuffle",
diff --git a/python/pyspark/tests/test_rddsampler.py b/python/pyspark/tests/test_rddsampler.py
new file mode 100644
index 00000000000..b504c4ab980
--- /dev/null
+++ b/python/pyspark/tests/test_rddsampler.py
@@ -0,0 +1,66 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.testing.utils import ReusedPySparkTestCase
+from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
+
+
+class RDDSamplerTests(ReusedPySparkTestCase):
+    def test_rdd_sampler_func(self):
+        # SPARK-38879: RDDSampler.func without replacement keeps roughly `fraction`
+        # of the elements; with replacement it can emit an element more than once.
+        rdd = self.sc.parallelize(range(20), 2)
+        sample_count = rdd.mapPartitionsWithIndex(RDDSampler(False, 0.4, 10).func).count()
+        self.assertGreater(sample_count, 3)
+        self.assertLess(sample_count, 10)
+        sample_data = rdd.mapPartitionsWithIndex(RDDSampler(True, 1, 10).func).collect()
+        sample_data.sort()
+        # After sorting, any repeated element shows up as an equal adjacent pair.
+        self.assertTrue(
+            any(sample_data[i] == sample_data[i - 1] for i in range(1, len(sample_data)))
+        )
+
+    def test_rdd_stratified_sampler_func(self):
+        # SPARK-38879: RDDStratifiedSampler.func samples each key with the
+        # fraction given for that key in `fractions`.
+
+        fractions = {"a": 0.8, "b": 0.2}
+        rdd = self.sc.parallelize(fractions.keys()).cartesian(self.sc.parallelize(range(0, 100)))
+        sample_data = dict(
+            rdd.mapPartitionsWithIndex(
+                RDDStratifiedSampler(False, fractions, 10).func, True
+            ).countByKey()
+        )
+        # Each key has 100 rows; "a" (0.8) should be sampled more often than
+        # "b" (0.2), landing near 80 and 20 respectively.
+        self.assertGreater(sample_data["a"], sample_data["b"])
+        self.assertGreater(sample_data["a"], 60)
+        self.assertLess(sample_data["a"], 90)
+        self.assertGreater(sample_data["b"], 15)
+        self.assertLess(sample_data["b"], 30)
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.tests.test_rddsampler import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:  # xmlrunner is optional; fall back to unittest's default runner
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to