This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 001da5d003c [SPARK-43671][PS][FOLLOWUP] Refine `CategoricalOps` functions
001da5d003c is described below

commit 001da5d003caef3cda9978d35967ade55837e0bc
Author: itholic <haejoon....@databricks.com>
AuthorDate: Sun May 28 08:44:16 2023 +0800

    [SPARK-43671][PS][FOLLOWUP] Refine `CategoricalOps` functions
    
    ### What changes were proposed in this pull request?
    
    This PR is a follow-up to SPARK-43671. It refines the `CategoricalOps`
    comparison functions to use the `pyspark_column_op` utility, which cleans up the code.
    
    ### Why are the changes needed?
    
    To avoid scattering `is_remote` checks across many places, which makes future maintenance harder.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it's a code cleanup only.
    
    ### How was this patch tested?
    
    The existing CI should pass.
    
    Closes #41326 from itholic/categorical_followup.
    
    Authored-by: itholic <haejoon....@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../pandas/data_type_ops/categorical_ops.py        | 69 +++++-----------------
 1 file changed, 14 insertions(+), 55 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/categorical_ops.py 
b/python/pyspark/pandas/data_type_ops/categorical_ops.py
index 9f14a4b1ee7..66e181a6079 100644
--- a/python/pyspark/pandas/data_type_ops/categorical_ops.py
+++ b/python/pyspark/pandas/data_type_ops/categorical_ops.py
@@ -16,19 +16,18 @@
 #
 
 from itertools import chain
-from typing import cast, Any, Callable, Union
+from typing import cast, Any, Union
 
 import pandas as pd
 import numpy as np
 from pandas.api.types import is_list_like, CategoricalDtype  # type: 
ignore[attr-defined]
 
 from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
-from pyspark.pandas.base import column_op, IndexOpsMixin
+from pyspark.pandas.base import IndexOpsMixin
 from pyspark.pandas.data_type_ops.base import _sanitize_list_like, DataTypeOps
 from pyspark.pandas.typedef import pandas_on_spark_type
 from pyspark.sql import functions as F
-from pyspark.sql.column import Column as PySparkColumn
-from pyspark.sql.utils import is_remote
+from pyspark.sql.utils import pyspark_column_op
 
 
 class CategoricalOps(DataTypeOps):
@@ -66,73 +65,33 @@ class CategoricalOps(DataTypeOps):
 
     def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return _compare(
-            left, right, Column.__eq__, is_equality_comparison=True  # type: 
ignore[arg-type]
-        )
+        return _compare(left, right, "__eq__", is_equality_comparison=True)
 
     def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return _compare(
-            left, right, Column.__ne__, is_equality_comparison=True  # type: 
ignore[arg-type]
-        )
+        return _compare(left, right, "__ne__", is_equality_comparison=True)
 
     def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return _compare(left, right, Column.__lt__)  # type: ignore[arg-type]
+        return _compare(left, right, "__lt__")
 
     def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return _compare(left, right, Column.__le__)  # type: ignore[arg-type]
+        return _compare(left, right, "__le__")
 
     def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return _compare(left, right, Column.__gt__)  # type: ignore[arg-type]
+        return _compare(left, right, "__gt__")
 
     def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return _compare(left, right, Column.__ge__)  # type: ignore[arg-type]
+        return _compare(left, right, "__ge__")
 
 
 def _compare(
     left: IndexOpsLike,
     right: Any,
-    f: Callable[..., PySparkColumn],
+    func_name: str,
     *,
     is_equality_comparison: bool = False,
 ) -> SeriesOrIndex:
@@ -143,7 +102,7 @@ def _compare(
     ----------
     left: A Categorical operand
     right: The other operand to compare with
-    f : The Spark Column function to apply
+    func_name: The Spark Column function name to apply
     is_equality_comparison: True if it is equality comparison, ie. == or !=. 
False by default.
 
     Returns
@@ -158,15 +117,15 @@ def _compare(
         if hash(left.dtype) != hash(right.dtype):
             raise TypeError("Categoricals can only be compared if 'categories' 
are the same.")
         if cast(CategoricalDtype, left.dtype).ordered:
-            return column_op(f)(left, right)
+            return pyspark_column_op(func_name)(left, right)
         else:
-            return column_op(f)(_to_cat(left), _to_cat(right))
+            return pyspark_column_op(func_name)(_to_cat(left), _to_cat(right))
     elif not is_list_like(right):
         categories = cast(CategoricalDtype, left.dtype).categories
         if right not in categories:
             raise TypeError("Cannot compare a Categorical with a scalar, which 
is not a category.")
         right_code = categories.get_loc(right)
-        return column_op(f)(left, right_code)
+        return pyspark_column_op(func_name)(left, right_code)
     else:
         raise TypeError("Cannot compare a Categorical with the given type.")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to