This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 52cd037f321 [SPARK-40225][PYTHON] PySpark rdd.takeOrdered should check num and numPartitions 52cd037f321 is described below commit 52cd037f321c1e8e1e4d6c1b76b1b4c50fbe3ccd Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Fri Aug 26 15:26:29 2022 +0900 [SPARK-40225][PYTHON] PySpark rdd.takeOrdered should check num and numPartitions ### What changes were proposed in this pull request? add the validation of `num` and `numPartitions` ### Why are the changes needed? this PR is to keep in line with the scala side 1, the scala side will check whether `num==0` 2, the scala side will check whether `numPartitions==0`, so it accepts an `emptyRDD` ```scala scala> sc.emptyRDD[Int].takeOrdered(3) res2: Array[Int] = Array() ``` while in PySpark: ``` In [1]: sc.emptyRDD().takeOrdered(3) --------------------------------------------------------------------------- ValueError Traceback (most recent call last) Input In [1], in <cell line: 1>() ----> 1 sc.emptyRDD().takeOrdered(3) File ~/.dev/spark-3.3.0-bin-hadoop3/python/pyspark/rdd.py:1823, in RDD.takeOrdered(self, num, key) 1820 def merge(a: List[T], b: List[T]) -> List[T]: 1821 return heapq.nsmallest(num, a + b, key) -> 1823 return self.mapPartitions(lambda it: [heapq.nsmallest(num, it, key)]).reduce(merge) File ~/.dev/spark-3.3.0-bin-hadoop3/python/pyspark/rdd.py:1253, in RDD.reduce(self, f) 1251 if vals: 1252 return reduce(f, vals) -> 1253 raise ValueError("Can not reduce() empty RDD") ValueError: Can not reduce() empty RDD ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? added doctest Closes #37669 from zhengruifeng/py_rdd_take_ordered. 
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/rdd.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 181935aa8d7..b631f141a89 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2732,12 +2732,20 @@ class RDD(Generic[T_co]): [1, 2, 3, 4, 5, 6] >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x) [10, 9, 7, 6, 5, 4] + >>> sc.emptyRDD().takeOrdered(3) + [] """ + if num < 0: + raise ValueError("top N cannot be negative.") - def merge(a: List[T], b: List[T]) -> List[T]: - return heapq.nsmallest(num, a + b, key) + if num == 0 or self.getNumPartitions() == 0: + return [] + else: + + def merge(a: List[T], b: List[T]) -> List[T]: + return heapq.nsmallest(num, a + b, key) - return self.mapPartitions(lambda it: [heapq.nsmallest(num, it, key)]).reduce(merge) + return self.mapPartitions(lambda it: [heapq.nsmallest(num, it, key)]).reduce(merge) def take(self: "RDD[T]", num: int) -> List[T]: """ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org