This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 4d73552abf3 [SPARK-40333][PS] Implement `GroupBy.nth` 4d73552abf3 is described below commit 4d73552abf39c687a1ef1f742fcecdf7492995af Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Thu Sep 8 16:04:26 2022 +0800 [SPARK-40333][PS] Implement `GroupBy.nth` ### What changes were proposed in this pull request? Implement `GroupBy.nth` ### Why are the changes needed? for API coverage ### Does this PR introduce _any_ user-facing change? yes, new API ``` In [4]: import pyspark.pandas as ps In [5]: import numpy as np In [6]: df = ps.DataFrame({'A': [1, 1, 2, 1, 2], 'B': [np.nan, 2, 3, 4, 5], 'C': ['a', 'b', 'c', 'd', 'e']}, columns=['A', 'B', 'C']) In [7]: df.groupby('A').nth(0) Out[7]: B C A 1 NaN a 2 3.0 c In [8]: df.groupby('A').nth(2) Out[8]: B C A 1 4.0 d In [9]: df.C.groupby(df.A).nth(-1) Out[9]: A 1 d 2 e Name: C, dtype: object In [10]: df.C.groupby(df.A).nth(-2) Out[10]: A 1 b 2 c Name: C, dtype: object ``` ### How was this patch tested? added UT Closes #37801 from zhengruifeng/ps_groupby_nth. 
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../source/reference/pyspark.pandas/groupby.rst | 1 + python/pyspark/pandas/groupby.py | 98 ++++++++++++++++++++++ python/pyspark/pandas/missing/groupby.py | 2 - python/pyspark/pandas/tests/test_groupby.py | 11 +++ 4 files changed, 110 insertions(+), 2 deletions(-) diff --git a/python/docs/source/reference/pyspark.pandas/groupby.rst b/python/docs/source/reference/pyspark.pandas/groupby.rst index b331a49b683..24e3bde91f5 100644 --- a/python/docs/source/reference/pyspark.pandas/groupby.rst +++ b/python/docs/source/reference/pyspark.pandas/groupby.rst @@ -73,6 +73,7 @@ Computations / Descriptive Stats GroupBy.mean GroupBy.median GroupBy.min + GroupBy.nth GroupBy.rank GroupBy.sem GroupBy.std diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index 84a5a3377f3..01163b61375 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -895,6 +895,104 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): bool_to_numeric=True, ) + # TODO: 1, 'n' accepts list and slice; 2, implement 'dropna' parameter + def nth(self, n: int) -> FrameLike: + """ + Take the nth row from each group. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + n : int + A single nth value for the row + + Returns + ------- + Series or DataFrame + + Notes + ----- + There is a behavior difference between pandas-on-Spark and pandas: + + * when there is no aggregation column, and `n` is not equal to 0 or -1, + the returned empty dataframe may have an index with different length `__len__`. + + Examples + -------- + >>> df = ps.DataFrame({'A': [1, 1, 2, 1, 2], + ... 
'B': [np.nan, 2, 3, 4, 5]}, columns=['A', 'B']) + >>> g = df.groupby('A') + >>> g.nth(0) + B + A + 1 NaN + 2 3.0 + >>> g.nth(1) + B + A + 1 2.0 + 2 5.0 + >>> g.nth(-1) + B + A + 1 4.0 + 2 5.0 + + See Also + -------- + pyspark.pandas.Series.groupby + pyspark.pandas.DataFrame.groupby + """ + if isinstance(n, slice) or is_list_like(n): + raise NotImplementedError("n doesn't support slice or list for now") + if not isinstance(n, int): + raise TypeError("Invalid index %s" % type(n).__name__) + + groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))] + internal, agg_columns, sdf = self._prepare_reduce( + groupkey_names=groupkey_names, + accepted_spark_types=None, + bool_to_numeric=False, + ) + psdf: DataFrame = DataFrame(internal) + + if len(psdf._internal.column_labels) > 0: + window1 = Window.partitionBy(*groupkey_names).orderBy(NATURAL_ORDER_COLUMN_NAME) + tmp_row_number_col = verify_temp_column_name(sdf, "__tmp_row_number_col__") + if n >= 0: + sdf = ( + psdf._internal.spark_frame.withColumn( + tmp_row_number_col, F.row_number().over(window1) + ) + .where(F.col(tmp_row_number_col) == n + 1) + .drop(tmp_row_number_col) + ) + else: + window2 = Window.partitionBy(*groupkey_names).rowsBetween( + Window.unboundedPreceding, Window.unboundedFollowing + ) + tmp_group_size_col = verify_temp_column_name(sdf, "__tmp_group_size_col__") + sdf = ( + psdf._internal.spark_frame.withColumn( + tmp_group_size_col, F.count(F.lit(0)).over(window2) + ) + .withColumn(tmp_row_number_col, F.row_number().over(window1)) + .where(F.col(tmp_row_number_col) == F.col(tmp_group_size_col) + 1 + n) + .drop(tmp_group_size_col, tmp_row_number_col) + ) + else: + sdf = sdf.select(*groupkey_names).distinct() + + internal = internal.copy( + spark_frame=sdf, + index_spark_columns=[scol_for(sdf, col) for col in groupkey_names], + data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names], + data_fields=None, + ) + + return 
self._prepare_return(DataFrame(internal)) + def all(self, skipna: bool = True) -> FrameLike: """ Returns True if all values in the group are truthful, else False. diff --git a/python/pyspark/pandas/missing/groupby.py b/python/pyspark/pandas/missing/groupby.py index 8ae8a68b5fe..e913835ca72 100644 --- a/python/pyspark/pandas/missing/groupby.py +++ b/python/pyspark/pandas/missing/groupby.py @@ -59,7 +59,6 @@ class MissingPandasLikeDataFrameGroupBy: # Functions boxplot = _unsupported_function("boxplot") ngroup = _unsupported_function("ngroup") - nth = _unsupported_function("nth") ohlc = _unsupported_function("ohlc") pct_change = _unsupported_function("pct_change") pipe = _unsupported_function("pipe") @@ -93,7 +92,6 @@ class MissingPandasLikeSeriesGroupBy: aggregate = _unsupported_function("aggregate") describe = _unsupported_function("describe") ngroup = _unsupported_function("ngroup") - nth = _unsupported_function("nth") ohlc = _unsupported_function("ohlc") pct_change = _unsupported_function("pct_change") pipe = _unsupported_function("pipe") diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py index e76fcf00faf..1076d867344 100644 --- a/python/pyspark/pandas/tests/test_groupby.py +++ b/python/pyspark/pandas/tests/test_groupby.py @@ -1380,6 +1380,17 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils): self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=None)) self._test_stat_func(lambda groupby_obj: groupby_obj.last(numeric_only=True)) + def test_nth(self): + for n in [0, 1, 2, 128, -1, -2, -128]: + self._test_stat_func(lambda groupby_obj: groupby_obj.nth(n)) + + with self.assertRaisesRegex(NotImplementedError, "slice or list"): + self.psdf.groupby("B").nth(slice(0, 2)) + with self.assertRaisesRegex(NotImplementedError, "slice or list"): + self.psdf.groupby("B").nth([0, 1, -1]) + with self.assertRaisesRegex(TypeError, "Invalid index"): + self.psdf.groupby("B").nth("x") + def test_cumcount(self): 
pdf = pd.DataFrame( { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org