This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c01e524c298 [SPARK-40334][PS] Implement `GroupBy.prod` c01e524c298 is described below commit c01e524c2985be06027191e51bb94d9ee5637d40 Author: artsiomyudovin <a.yudovin6...@gmail.com> AuthorDate: Mon Sep 26 08:00:20 2022 +0800 [SPARK-40334][PS] Implement `GroupBy.prod` ### What changes were proposed in this pull request? Implement `GroupBy.prod` ### Why are the changes needed? for API coverage ### Does this PR introduce _any_ user-facing change? yes, the new API ``` df = ps.DataFrame({'A': [1, 1, 2, 1, 2], 'B': [np.nan, 2, 3, 4, 5], 'C': [1, 2, 1, 1, 2], 'D': [True, False, True, False, True]}) Groupby one column and return the prod of the remaining columns in each group. df.groupby('A').prod() B C D A 1 8.0 2 0 2 15.0 2 11 df.groupby('A').prod(min_count=3) B C D A 1 NaN 2 0 2 NaN NaN NaN ``` ### How was this patch tested? added UT Closes #37923 from ayudovin/ps_group_by_prod. Authored-by: artsiomyudovin <a.yudovin6...@gmail.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../source/reference/pyspark.pandas/groupby.rst | 1 + python/pyspark/pandas/groupby.py | 106 ++++++++++++++++++++- python/pyspark/pandas/missing/groupby.py | 2 - python/pyspark/pandas/tests/test_groupby.py | 10 ++ 4 files changed, 114 insertions(+), 5 deletions(-) diff --git a/python/docs/source/reference/pyspark.pandas/groupby.rst b/python/docs/source/reference/pyspark.pandas/groupby.rst index 4c29964966c..da1579fd723 100644 --- a/python/docs/source/reference/pyspark.pandas/groupby.rst +++ b/python/docs/source/reference/pyspark.pandas/groupby.rst @@ -74,6 +74,7 @@ Computations / Descriptive Stats GroupBy.median GroupBy.min GroupBy.nth + GroupBy.prod GroupBy.rank GroupBy.sem GroupBy.std diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index 2e5c9ab219a..6d36cfecce6 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -18,7 +18,6 @@ """ A wrapper for 
GroupedData to behave similar to pandas GroupBy. """ - from abc import ABCMeta, abstractmethod import inspect from collections import defaultdict, namedtuple @@ -63,6 +62,7 @@ from pyspark.sql.types import ( StructField, StructType, StringType, + IntegralType, ) from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. @@ -1055,6 +1055,106 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): return self._prepare_return(DataFrame(internal)) + def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameLike: + """ + Compute prod of groups. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + numeric_only : bool, default True + Include only float, int, boolean columns. If None, will attempt to use + everything, then use only numeric data. + + min_count : int, default 0 + The required number of valid values to perform the operation. + If fewer than min_count non-NA values are present the result will be NA. + + Returns + ------- + Series or DataFrame + Computed prod of values within each group. + + See Also + -------- + pyspark.pandas.Series.groupby + pyspark.pandas.DataFrame.groupby + + Examples + -------- + >>> import numpy as np + >>> df = ps.DataFrame( + ... { + ... "A": [1, 1, 2, 1, 2], + ... "B": [np.nan, 2, 3, 4, 5], + ... "C": [1, 2, 1, 1, 2], + ... "D": [True, False, True, False, True], + ... } + ... ) + + Groupby one column and return the prod of the remaining columns in + each group. 
+ + >>> df.groupby('A').prod().sort_index() + B C D + A + 1 8.0 2 0 + 2 15.0 2 1 + + >>> df.groupby('A').prod(min_count=3).sort_index() + B C D + A + 1 NaN 2.0 0.0 + 2 NaN NaN NaN + """ + + self._validate_agg_columns(numeric_only=numeric_only, function_name="prod") + + groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))] + internal, agg_columns, sdf = self._prepare_reduce( + groupkey_names=groupkey_names, + accepted_spark_types=(NumericType, BooleanType), + bool_to_numeric=True, + ) + + psdf: DataFrame = DataFrame(internal) + if len(psdf._internal.column_labels) > 0: + + stat_exprs = [] + for label in psdf._internal.column_labels: + psser = psdf._psser_for(label) + column = psser._dtype_op.nan_to_null(psser).spark.column + data_type = psser.spark.data_type + aggregating = ( + F.product(column).cast("long") + if isinstance(data_type, IntegralType) + else F.product(column) + ) + + if min_count > 0: + prod_scol = F.when( + F.count(F.when(~F.isnull(column), F.lit(0))) < min_count, F.lit(None) + ).otherwise(aggregating) + else: + prod_scol = aggregating + + stat_exprs.append(prod_scol.alias(psser._internal.data_spark_column_names[0])) + + sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs) + + else: + sdf = sdf.select(*groupkey_names).distinct() + + internal = internal.copy( + spark_frame=sdf, + index_spark_columns=[scol_for(sdf, col) for col in groupkey_names], + data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names], + data_fields=None, + ) + + return self._prepare_return(DataFrame(internal)) + def all(self, skipna: bool = True) -> FrameLike: """ Returns True if all values in the group are truthful, else False. @@ -3297,10 +3397,10 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): if not numeric_only: if has_non_numeric: warnings.warn( - "Dropping invalid columns in DataFrameGroupBy.mean is deprecated. " + "Dropping invalid columns in DataFrameGroupBy.%s is deprecated. 
" "In a future version, a TypeError will be raised. " "Before calling .%s, select only columns which should be " - "valid for the function." % function_name, + "valid for the function." % (function_name, function_name), FutureWarning, ) diff --git a/python/pyspark/pandas/missing/groupby.py b/python/pyspark/pandas/missing/groupby.py index 3a0e90c2151..1799fac0033 100644 --- a/python/pyspark/pandas/missing/groupby.py +++ b/python/pyspark/pandas/missing/groupby.py @@ -61,7 +61,6 @@ class MissingPandasLikeDataFrameGroupBy: ohlc = _unsupported_function("ohlc") pct_change = _unsupported_function("pct_change") pipe = _unsupported_function("pipe") - prod = _unsupported_function("prod") resample = _unsupported_function("resample") @@ -93,5 +92,4 @@ class MissingPandasLikeSeriesGroupBy: ohlc = _unsupported_function("ohlc") pct_change = _unsupported_function("pct_change") pipe = _unsupported_function("pipe") - prod = _unsupported_function("prod") resample = _unsupported_function("resample") diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py index 6e4aa6186c6..4a57a3421df 100644 --- a/python/pyspark/pandas/tests/test_groupby.py +++ b/python/pyspark/pandas/tests/test_groupby.py @@ -1433,6 +1433,16 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils): with self.assertRaisesRegex(TypeError, "Invalid index"): self.psdf.groupby("B").nth("x") + def test_prod(self): + for n in [0, 1, 2, 128, -1, -2, -128]: + self._test_stat_func(lambda groupby_obj: groupby_obj.prod(min_count=n)) + self._test_stat_func( + lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n) + ) + self._test_stat_func( + lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n) + ) + def test_cumcount(self): pdf = pd.DataFrame( { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org