xinrong-databricks commented on code in PR #36599:
URL: https://github.com/apache/spark/pull/36599#discussion_r880753241


##########
python/pyspark/pandas/series.py:
##########
@@ -6255,36 +6261,47 @@ def argmax(self) -> int:
         --------
         Consider dataset containing cereal calories
 
-        >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0,
+        >>> s = ps.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0, 'Unknown': np.nan,
         ...                'Cinnamon Toast Crunch': 120.0, 'Cocoa Puff': 110.0})
-        >>> s  # doctest: +SKIP
+        >>> s
         Corn Flakes              100.0
         Almond Delight           110.0
+        Unknown                    NaN
         Cinnamon Toast Crunch    120.0
         Cocoa Puff               110.0
         dtype: float64
 
-        >>> s.argmax()  # doctest: +SKIP
-        2
+        >>> s.argmax()
+        3
+
+        >>> s.argmax(skipna=False)
+        -1
         """
         sdf = self._internal.spark_frame.select(self.spark.column, NATURAL_ORDER_COLUMN_NAME)
+        seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
+        sdf = InternalFrame.attach_distributed_sequence_column(
+            sdf,
+            seq_col_name,
+        )
+        scol = scol_for(sdf, self._internal.data_spark_column_names[0])
+
+        if skipna:
+            sdf = sdf.orderBy(scol.desc_nulls_last(), NATURAL_ORDER_COLUMN_NAME)
+        else:
+            sdf = sdf.orderBy(scol.desc_nulls_first(), NATURAL_ORDER_COLUMN_NAME)
+
         max_value = sdf.select(
-            F.max(scol_for(sdf, self._internal.data_spark_column_names[0])),
+            F.first(scol),
             F.first(NATURAL_ORDER_COLUMN_NAME),
         ).head()
+
         if max_value[1] is None:
             raise ValueError("attempt to get argmax of an empty sequence")
         elif max_value[0] is None:
             return -1
-        # We should remember the natural sequence started from 0
-        seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
-        sdf = InternalFrame.attach_distributed_sequence_column(
-            sdf.drop(NATURAL_ORDER_COLUMN_NAME), seq_col_name
-        )
+
         # If the maximum is achieved in multiple locations, the first row position is returned.
-        return sdf.filter(
-            scol_for(sdf, self._internal.data_spark_column_names[0]) == max_value[0]
-        ).head()[0]
+        return sdf.filter(scol == max_value[0]).head()[0]

Review Comment:
   Thank you @itholic @zhengruifeng!
   
   BTW, the `max_by` documentation could explain more clearly how ties are handled ;)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to