This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2fd76994860 [SPARK-39228][PYTHON][PS] Implement `skipna` of `Series.argmax`
2fd76994860 is described below

commit 2fd769948604eff38d90974017971434484897d6
Author: Xinrong Meng <xinrong.m...@databricks.com>
AuthorDate: Fri May 27 09:25:21 2022 +0800

    [SPARK-39228][PYTHON][PS] Implement `skipna` of `Series.argmax`
    
    ### What changes were proposed in this pull request?
    Implement the `skipna` parameter of `Series.argmax`.
    
    ### Why are the changes needed?
    Increase pandas API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The `skipna` parameter of `Series.argmax` is now supported, as shown below.
    
    ```py
            >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0, 'Unknown': np.nan,
            ...                'Cinnamon Toast Crunch': 120.0, 'Cocoa Puff': 110.0})
            >>> s
            Corn Flakes              100.0
            Almond Delight           110.0
            Unknown                    NaN
            Cinnamon Toast Crunch    120.0
            Cocoa Puff               110.0
            dtype: float64
    
            >>> s.argmax(skipna=True)
            3
            >>> s.argmax(skipna=False)
            -1
    ```
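    
    This matches plain pandas at the version the tests run against, where `argmax(skipna=False)` on a Series containing NaN returns the -1 sentinel rather than raising (the updated unit tests assert the parity; later pandas releases deprecate that sentinel). A minimal parity sketch, assuming only `pandas` and `numpy`:
    
    ```py
    import numpy as np
    import pandas as pd
    
    pser = pd.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0, 'Unknown': np.nan,
                      'Cinnamon Toast Crunch': 120.0, 'Cocoa Puff': 110.0})
    print(pser.argmax(skipna=True))   # 3  -- NaN skipped, 120.0 is the largest
    print(pser.argmax(skipna=False))  # -1 -- NaN present, sentinel returned
    ```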
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #36599 from xinrong-databricks/argmax.skipna.
    
    Authored-by: Xinrong Meng <xinrong.m...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/pandas/series.py            | 50 +++++++++++++++++++++---------
 python/pyspark/pandas/tests/test_series.py |  9 +++++-
 2 files changed, 44 insertions(+), 15 deletions(-)
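
At its core, the change in the first diff below replaces the old max-then-filter approach with a single ordering pass: it attaches the distributed sequence column up front, sorts descending with `desc_nulls_last()` for `skipna=True` or `desc_nulls_first()` for `skipna=False`, reads the maximum off the first row, and maps a null at the top of the sort order to the -1 sentinel. A minimal standalone sketch of that ordering trick (not the commit's exact code: the explicit `pos` column here is a hypothetical stand-in for the internal `__distributed_sequence_column__`):

```py
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

# Toy frame mirroring the docstring example; `pos` plays the role of the
# attached distributed sequence column (row position within the Series).
df = spark.createDataFrame(
    [(0, 100.0), (1, 110.0), (2, None), (3, 120.0), (4, 110.0)],
    ["pos", "value"],
)

# skipna=True: nulls sort last, so the first row holds the true maximum;
# the secondary sort on `pos` keeps the first occurrence on ties.
print(df.orderBy(F.col("value").desc_nulls_last(), "pos").head()["pos"])  # 3

# skipna=False: nulls sort first; a null on top maps to the -1 sentinel,
# mirroring the `max_value[0] is None` branch in the diff.
top = df.orderBy(F.col("value").desc_nulls_first(), "pos").head()
print(top["pos"] if top["value"] is not None else -1)  # -1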

diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 29afcbe956e..653b4812cad 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -6249,13 +6249,21 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
             ps.concat([psser, self.loc[self.isnull()].spark.transform(lambda _: SF.lit(-1))]),
         )
 
-    def argmax(self) -> int:
+    def argmax(self, axis: Axis = None, skipna: bool = True) -> int:
         """
         Return int position of the largest value in the Series.
 
         If the maximum is achieved in multiple locations,
         the first row position is returned.
 
+        Parameters
+        ----------
+        axis : {{None}}
+            Dummy argument for consistency with Series.
+        skipna : bool, default True
+            Exclude NA/null values. If the entire Series is NA, the result
+            will be NA.
+
         Returns
         -------
         int
@@ -6265,36 +6273,50 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         --------
         Consider dataset containing cereal calories
 
-        >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0,
+        >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0, 'Unknown': np.nan,
         ...                'Cinnamon Toast Crunch': 120.0, 'Cocoa Puff': 110.0})
-        >>> s  # doctest: +SKIP
+        >>> s
         Corn Flakes              100.0
         Almond Delight           110.0
+        Unknown                    NaN
         Cinnamon Toast Crunch    120.0
         Cocoa Puff               110.0
         dtype: float64
 
-        >>> s.argmax()  # doctest: +SKIP
-        2
+        >>> s.argmax()
+        3
+
+        >>> s.argmax(skipna=False)
+        -1
         """
+        axis = validate_axis(axis, none_axis=0)
+        if axis == 1:
+            raise ValueError("axis can only be 0 or 'index'")
         sdf = self._internal.spark_frame.select(self.spark.column, NATURAL_ORDER_COLUMN_NAME)
+        seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
+        sdf = InternalFrame.attach_distributed_sequence_column(
+            sdf,
+            seq_col_name,
+        )
+        scol = scol_for(sdf, self._internal.data_spark_column_names[0])
+
+        if skipna:
+            sdf = sdf.orderBy(scol.desc_nulls_last(), NATURAL_ORDER_COLUMN_NAME)
+        else:
+            sdf = sdf.orderBy(scol.desc_nulls_first(), NATURAL_ORDER_COLUMN_NAME)
+
         max_value = sdf.select(
-            F.max(scol_for(sdf, self._internal.data_spark_column_names[0])),
+            F.first(scol),
             F.first(NATURAL_ORDER_COLUMN_NAME),
         ).head()
+
         if max_value[1] is None:
             raise ValueError("attempt to get argmax of an empty sequence")
         elif max_value[0] is None:
             return -1
-        # We should remember the natural sequence started from 0
-        seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
-        sdf = InternalFrame.attach_distributed_sequence_column(
-            sdf.drop(NATURAL_ORDER_COLUMN_NAME), seq_col_name
-        )
+
         # If the maximum is achieved in multiple locations, the first row position is returned.
-        return sdf.filter(
-            scol_for(sdf, self._internal.data_spark_column_names[0]) == max_value[0]
-        ).head()[0]
+        return sdf.filter(scol == max_value[0]).head()[0]
 
     def argmin(self) -> int:
         """
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index f39e1900dd3..7631444ee5d 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -2987,9 +2987,10 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
             name="Koalas",
         )
         psser = ps.from_pandas(pser)
-
         self.assert_eq(pser.argmin(), psser.argmin())
         self.assert_eq(pser.argmax(), psser.argmax())
+        self.assert_eq(pser.argmax(skipna=False), psser.argmax(skipna=False))
+        self.assert_eq((pser + 1).argmax(skipna=False), (psser + 1).argmax(skipna=False))
 
         # MultiIndex
         pser.index = pd.MultiIndex.from_tuples(
@@ -2998,15 +2999,21 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         psser = ps.from_pandas(pser)
         self.assert_eq(pser.argmin(), psser.argmin())
         self.assert_eq(pser.argmax(), psser.argmax())
+        self.assert_eq(pser.argmax(skipna=False), psser.argmax(skipna=False))
 
         # Null Series
         self.assert_eq(pd.Series([np.nan]).argmin(), ps.Series([np.nan]).argmin())
         self.assert_eq(pd.Series([np.nan]).argmax(), ps.Series([np.nan]).argmax())
+        self.assert_eq(
+            pd.Series([np.nan]).argmax(skipna=False), ps.Series([np.nan]).argmax(skipna=False)
+        )
 
         with self.assertRaisesRegex(ValueError, "attempt to get argmin of an empty sequence"):
             ps.Series([]).argmin()
         with self.assertRaisesRegex(ValueError, "attempt to get argmax of an empty sequence"):
             ps.Series([]).argmax()
+        with self.assertRaisesRegex(ValueError, "axis can only be 0 or 'index'"):
+            psser.argmax(axis=1)
 
     def test_backfill(self):
         pdf = pd.DataFrame({"x": [np.nan, 2, 3, 4, np.nan, 6]})

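
For completeness, the new `axis` validation exercised by the last test above surfaces to users roughly as follows (a hypothetical REPL transcript, assuming the usual `import pyspark.pandas as ps` convention):

```py
>>> import pyspark.pandas as ps
>>> ps.Series([1.0, 2.0]).argmax(axis=1)
Traceback (most recent call last):
    ...
ValueError: axis can only be 0 or 'index'
```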
