zhengruifeng commented on code in PR #36599:
URL: https://github.com/apache/spark/pull/36599#discussion_r880050644


##########
python/pyspark/pandas/series.py:
##########
@@ -6255,36 +6261,47 @@ def argmax(self) -> int:
         --------
         Consider dataset containing cereal calories
 
-        >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0,
+        >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0, 'Unknown': np.nan,
         ...                'Cinnamon Toast Crunch': 120.0, 'Cocoa Puff': 110.0})
-        >>> s  # doctest: +SKIP
+        >>> s
         Corn Flakes              100.0
         Almond Delight           110.0
+        Unknown                    NaN
         Cinnamon Toast Crunch    120.0
         Cocoa Puff               110.0
         dtype: float64
 
-        >>> s.argmax()  # doctest: +SKIP
-        2
+        >>> s.argmax()
+        3
+
+        >>> s.argmax(skipna=False)
+        -1
         """
         sdf = self._internal.spark_frame.select(self.spark.column, NATURAL_ORDER_COLUMN_NAME)
+        seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
+        sdf = InternalFrame.attach_distributed_sequence_column(
+            sdf,
+            seq_col_name,
+        )
+        scol = scol_for(sdf, self._internal.data_spark_column_names[0])
+
+        if skipna:
+            sdf = sdf.orderBy(scol.desc_nulls_last(), NATURAL_ORDER_COLUMN_NAME)
+        else:
+            sdf = sdf.orderBy(scol.desc_nulls_first(), NATURAL_ORDER_COLUMN_NAME)
+
         max_value = sdf.select(
-            F.max(scol_for(sdf, self._internal.data_spark_column_names[0])),
+            F.first(scol),
             F.first(NATURAL_ORDER_COLUMN_NAME),
         ).head()
+
         if max_value[1] is None:
             raise ValueError("attempt to get argmax of an empty sequence")
         elif max_value[0] is None:
             return -1
-        # We should remember the natural sequence started from 0
-        seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
-        sdf = InternalFrame.attach_distributed_sequence_column(
-            sdf.drop(NATURAL_ORDER_COLUMN_NAME), seq_col_name
-        )
+
         # If the maximum is achieved in multiple locations, the first row position is returned.
-        return sdf.filter(
-            scol_for(sdf, self._internal.data_spark_column_names[0]) == max_value[0]
-        ).head()[0]
+        return sdf.filter(scol == max_value[0]).head()[0]

Review Comment:
   I gave `max_by` a try here, but found it cannot guarantee that `If the maximum is achieved in multiple locations, the first row position is returned.`
   
   Let's keep the current code. I'll take another look at `max_by`.
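   
   For context, a minimal sketch of the tie issue described above (not code from this PR; it assumes Spark >= 3.3, where `pyspark.sql.functions.max_by` is available, and uses illustrative column names):
   
   ```python
   from pyspark.sql import SparkSession, functions as F
   
   spark = SparkSession.builder.getOrCreate()
   
   # The maximum value 120.0 is achieved at two positions: 1 and 3.
   df = spark.createDataFrame(
       [(0, 100.0), (1, 120.0), (2, 110.0), (3, 120.0)],
       ["pos", "value"],
   )
   
   # max_by returns the `pos` of *some* row holding the max `value`;
   # with a tie it may return either 1 or 3, so the docstring's
   # "first row position" contract is not guaranteed.
   df.select(F.max_by("pos", "value").alias("argmax")).show()
   
   # The orderBy + first approach kept in this PR breaks the tie
   # deterministically via the ordering column (here, `pos`):
   print(df.orderBy(F.col("value").desc_nulls_last(), "pos").head()[0])  # 1
   ```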


