This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 2fd76994860 [SPARK-39228][PYTHON][PS] Implement `skipna` of `Series.argmax` 2fd76994860 is described below commit 2fd769948604eff38d90974017971434484897d6 Author: Xinrong Meng <xinrong.m...@databricks.com> AuthorDate: Fri May 27 09:25:21 2022 +0800 [SPARK-39228][PYTHON][PS] Implement `skipna` of `Series.argmax` ### What changes were proposed in this pull request? Implement `skipna` of `Series.argmax` ### Why are the changes needed? Increase pandas API coverage. ### Does this PR introduce _any_ user-facing change? Yes. `skipna` of `Series.argmax` is supported as below. ```py >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0, 'Unknown': np.nan, ... 'Cinnamon Toast Crunch': 120.0, 'Cocoa Puff': 110.0}) >>> s Corn Flakes 100.0 Almond Delight 110.0 Unknown NaN Cinnamon Toast Crunch 120.0 Cocoa Puff 110.0 dtype: float64 >>> s.argmax(skipna=True) 3 >>> s.argmax(skipna=False) -1 ``` ### How was this patch tested? Unit tests. Closes #36599 from xinrong-databricks/argmax.skipna. Authored-by: Xinrong Meng <xinrong.m...@databricks.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/pandas/series.py | 50 +++++++++++++++++++++--------- python/pyspark/pandas/tests/test_series.py | 9 +++++- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index 29afcbe956e..653b4812cad 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -6249,13 +6249,21 @@ class Series(Frame, IndexOpsMixin, Generic[T]): ps.concat([psser, self.loc[self.isnull()].spark.transform(lambda _: SF.lit(-1))]), ) - def argmax(self) -> int: + def argmax(self, axis: Axis = None, skipna: bool = True) -> int: """ Return int position of the largest value in the Series. If the maximum is achieved in multiple locations, the first row position is returned. + Parameters + ---------- + axis : {{None}} + Dummy argument for consistency with Series. + skipna : bool, default True + Exclude NA/null values. If the entire Series is NA, the result + will be NA. + Returns ------- int @@ -6265,36 +6273,50 @@ class Series(Frame, IndexOpsMixin, Generic[T]): -------- Consider dataset containing cereal calories - >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0, + >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0, 'Unknown': np.nan, ... 'Cinnamon Toast Crunch': 120.0, 'Cocoa Puff': 110.0}) - >>> s # doctest: +SKIP + >>> s Corn Flakes 100.0 Almond Delight 110.0 + Unknown NaN Cinnamon Toast Crunch 120.0 Cocoa Puff 110.0 dtype: float64 - >>> s.argmax() # doctest: +SKIP - 2 + >>> s.argmax() + 3 + + >>> s.argmax(skipna=False) + -1 """ + axis = validate_axis(axis, none_axis=0) + if axis == 1: + raise ValueError("axis can only be 0 or 'index'") sdf = self._internal.spark_frame.select(self.spark.column, NATURAL_ORDER_COLUMN_NAME) + seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__") + sdf = InternalFrame.attach_distributed_sequence_column( + sdf, + seq_col_name, + ) + scol = scol_for(sdf, self._internal.data_spark_column_names[0]) + + if skipna: + sdf = sdf.orderBy(scol.desc_nulls_last(), NATURAL_ORDER_COLUMN_NAME) + else: + sdf = sdf.orderBy(scol.desc_nulls_first(), NATURAL_ORDER_COLUMN_NAME) + max_value = sdf.select( - F.max(scol_for(sdf, self._internal.data_spark_column_names[0])), + F.first(scol), F.first(NATURAL_ORDER_COLUMN_NAME), ).head() + if max_value[1] is None: raise ValueError("attempt to get argmax of an empty sequence") elif max_value[0] is None: return -1 - # We should remember the natural sequence started from 0 - seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__") - sdf = InternalFrame.attach_distributed_sequence_column( - sdf.drop(NATURAL_ORDER_COLUMN_NAME), seq_col_name - ) + # If the maximum is achieved in multiple locations, the first row position is returned. - return sdf.filter( - scol_for(sdf, self._internal.data_spark_column_names[0]) == max_value[0] - ).head()[0] + return sdf.filter(scol == max_value[0]).head()[0] def argmin(self) -> int: """ diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py index f39e1900dd3..7631444ee5d 100644 --- a/python/pyspark/pandas/tests/test_series.py +++ b/python/pyspark/pandas/tests/test_series.py @@ -2987,9 +2987,10 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils): name="Koalas", ) psser = ps.from_pandas(pser) - self.assert_eq(pser.argmin(), psser.argmin()) self.assert_eq(pser.argmax(), psser.argmax()) + self.assert_eq(pser.argmax(skipna=False), psser.argmax(skipna=False)) + self.assert_eq((pser + 1).argmax(skipna=False), (psser + 1).argmax(skipna=False)) # MultiIndex pser.index = pd.MultiIndex.from_tuples( @@ -2998,15 +2999,21 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils): psser = ps.from_pandas(pser) self.assert_eq(pser.argmin(), psser.argmin()) self.assert_eq(pser.argmax(), psser.argmax()) + self.assert_eq(pser.argmax(skipna=False), psser.argmax(skipna=False)) # Null Series self.assert_eq(pd.Series([np.nan]).argmin(), ps.Series([np.nan]).argmin()) self.assert_eq(pd.Series([np.nan]).argmax(), ps.Series([np.nan]).argmax()) + self.assert_eq( + pd.Series([np.nan]).argmax(skipna=False), ps.Series([np.nan]).argmax(skipna=False) + ) with self.assertRaisesRegex(ValueError, "attempt to get argmin of an empty sequence"): ps.Series([]).argmin() with self.assertRaisesRegex(ValueError, "attempt to get argmax of an empty sequence"): ps.Series([]).argmax() + with self.assertRaisesRegex(ValueError, "axis can only be 0 or 'index'"): + psser.argmax(axis=1) def test_backfill(self): pdf = pd.DataFrame({"x": [np.nan, 2, 3, 4, np.nan, 6]}) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org