This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new f555a0d80e1 [SPARK-40240][PYTHON] PySpark rdd.takeSample should correctly validate `num > maxSampleSize`
f555a0d80e1 is described below

commit f555a0d80e1858ca30527328ca240b56ae6f415e
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sun Aug 28 07:27:13 2022 +0800

    [SPARK-40240][PYTHON] PySpark rdd.takeSample should correctly validate `num > maxSampleSize`

### What changes were proposed in this pull request?

Make PySpark `rdd.takeSample` behave like the Scala side.

### Why are the changes needed?

`rdd.takeSample` in Spark Core checks `num > maxsize - int(numStDev * sqrt(maxsize))` first, while PySpark may skip this validation:

```scala
scala> sc.range(0, 10).takeSample(false, Int.MaxValue)
java.lang.IllegalArgumentException: requirement failed: Cannot support a sample size > Int.MaxValue - 10.0 * math.sqrt(Int.MaxValue)
  at scala.Predef$.require(Predef.scala:281)
  at org.apache.spark.rdd.RDD.$anonfun$takeSample$1(RDD.scala:620)
  at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
  at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
  at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
  at org.apache.spark.rdd.RDD.takeSample(RDD.scala:615)
  ... 47 elided
```

```python
In [2]: sc.range(0, 10).takeSample(False, sys.maxsize)
Out[2]: [9, 6, 8, 5, 7, 2, 0, 3, 4, 1]
```

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

added doctest

Closes #37683 from zhengruifeng/py_refine_takesample.
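For readers skimming the email, the behavior change can be sketched without a SparkContext. This is a standalone illustration, not the committed code: `take_sample` is a hypothetical helper over a plain Python list, while `NUM_STD_DEV` and `MAX_SAMPLE_SIZE` mirror the `numStDev` and `maxSampleSize` values in `rdd.py`:

```python
import random
import sys
from math import sqrt

# Same bound as RDD.takeSample in python/pyspark/rdd.py: leave
# NUM_STD_DEV standard deviations of headroom below sys.maxsize.
NUM_STD_DEV = 10.0
MAX_SAMPLE_SIZE = sys.maxsize - int(NUM_STD_DEV * sqrt(sys.maxsize))


def take_sample(data, num, seed=None):
    """Hypothetical stand-in for rdd.takeSample over a plain list,
    showing the fixed control flow: validate `num` before any early
    return, so even tiny inputs reject an oversized sample size."""
    if num < 0:
        raise ValueError("Sample size cannot be negative.")
    elif num > MAX_SAMPLE_SIZE:
        raise ValueError("Sample size cannot be greater than %d." % MAX_SAMPLE_SIZE)

    if num == 0 or not data:
        return []

    rand = random.Random(seed)
    if len(data) <= num:
        # Before the fix, this early return ran ahead of the
        # MAX_SAMPLE_SIZE check, which is why the PySpark call in the
        # PR description returned all 10 elements instead of raising.
        samples = list(data)
        rand.shuffle(samples)
        return samples
    return rand.sample(data, num)


# An oversized request now fails fast, matching the Scala behavior:
try:
    take_sample(list(range(10)), sys.maxsize)
except ValueError as err:
    print(err)
```

With the checks reordered, the oversized-`num` path is reached regardless of how small the input is, which is exactly what the moved hunk in the diff below accomplishes.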
Authored-by: Ruifeng Zheng <ruife...@apache.org>
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/rdd.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index b631f141a89..5fe463233a2 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1122,6 +1122,7 @@ class RDD(Generic[T_co]):
 
         Examples
         --------
+        >>> import sys
         >>> rdd = sc.parallelize(range(0, 10))
         >>> len(rdd.takeSample(True, 20, 1))
         20
@@ -1129,12 +1130,19 @@ class RDD(Generic[T_co]):
         5
         >>> len(rdd.takeSample(False, 15, 3))
         10
+        >>> sc.range(0, 10).takeSample(False, sys.maxsize)
+        Traceback (most recent call last):
+            ...
+        ValueError: Sample size cannot be greater than ...
         """
         numStDev = 10.0
-
+        maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize))
         if num < 0:
             raise ValueError("Sample size cannot be negative.")
-        elif num == 0:
+        elif num > maxSampleSize:
+            raise ValueError("Sample size cannot be greater than %d." % maxSampleSize)
+
+        if num == 0 or self.getNumPartitions() == 0:
             return []
 
         initialCount = self.count()
@@ -1149,10 +1157,6 @@ class RDD(Generic[T_co]):
             rand.shuffle(samples)
             return samples
 
-        maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize))
-        if num > maxSampleSize:
-            raise ValueError("Sample size cannot be greater than %d." % maxSampleSize)
-
         fraction = RDD._computeFractionForSampleSize(num, initialCount, withReplacement)
         samples = self.sample(withReplacement, fraction, seed).collect()
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org