This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 001da5d003c  [SPARK-43671][PS][FOLLOWUP] Refine `CategoricalOps` functions
001da5d003c is described below

commit 001da5d003caef3cda9978d35967ade55837e0bc
Author: itholic <haejoon....@databricks.com>
AuthorDate: Sun May 28 08:44:16 2023 +0800

    [SPARK-43671][PS][FOLLOWUP] Refine `CategoricalOps` functions

    ### What changes were proposed in this pull request?

    This PR is a follow-up for SPARK-43671; it refines the functions to use the `pyspark_column_op` util and clean up the code.

    ### Why are the changes needed?

    To avoid checking `is_remote` in too many places, which eases future maintenance.

    ### Does this PR introduce _any_ user-facing change?

    No, it's a code cleanup.

    ### How was this patch tested?

    The existing CI should pass.

    Closes #41326 from itholic/categorical_followup.

    Authored-by: itholic <haejoon....@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../pandas/data_type_ops/categorical_ops.py        | 69 +++++-----------------
 1 file changed, 14 insertions(+), 55 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/categorical_ops.py b/python/pyspark/pandas/data_type_ops/categorical_ops.py
index 9f14a4b1ee7..66e181a6079 100644
--- a/python/pyspark/pandas/data_type_ops/categorical_ops.py
+++ b/python/pyspark/pandas/data_type_ops/categorical_ops.py
@@ -16,19 +16,18 @@
 #
 
 from itertools import chain
-from typing import cast, Any, Callable, Union
+from typing import cast, Any, Union
 
 import pandas as pd
 import numpy as np
 from pandas.api.types import is_list_like, CategoricalDtype  # type: ignore[attr-defined]
 
 from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
-from pyspark.pandas.base import column_op, IndexOpsMixin
+from pyspark.pandas.base import IndexOpsMixin
 from pyspark.pandas.data_type_ops.base import _sanitize_list_like, DataTypeOps
 from pyspark.pandas.typedef import pandas_on_spark_type
 from pyspark.sql import functions as F
-from pyspark.sql.column import Column as PySparkColumn
-from pyspark.sql.utils import is_remote
+from pyspark.sql.utils import pyspark_column_op
 
 
 class CategoricalOps(DataTypeOps):
@@ -66,73 +65,33 @@ class CategoricalOps(DataTypeOps):
 
     def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return _compare(
-            left, right, Column.__eq__, is_equality_comparison=True  # type: ignore[arg-type]
-        )
+        return _compare(left, right, "__eq__", is_equality_comparison=True)
 
     def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return _compare(
-            left, right, Column.__ne__, is_equality_comparison=True  # type: ignore[arg-type]
-        )
+        return _compare(left, right, "__ne__", is_equality_comparison=True)
 
     def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return _compare(left, right, Column.__lt__)  # type: ignore[arg-type]
+        return _compare(left, right, "__lt__")
 
     def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return _compare(left, right, Column.__le__)  # type: ignore[arg-type]
+        return _compare(left, right, "__le__")
 
     def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return _compare(left, right, Column.__gt__)  # type: ignore[arg-type]
+        return _compare(left, right, "__gt__")
 
     def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        if is_remote():
-            from pyspark.sql.connect.column import Column as ConnectColumn
-
-            Column = ConnectColumn
-        else:
-            Column = PySparkColumn  # type: ignore[assignment]
-        return _compare(left, right, Column.__ge__)  # type: ignore[arg-type]
+        return _compare(left, right, "__ge__")
 
 
 def _compare(
     left: IndexOpsLike,
     right: Any,
-    f: Callable[..., PySparkColumn],
+    func_name: str,
     *,
     is_equality_comparison: bool = False,
 ) -> SeriesOrIndex:
@@ -143,7 +102,7 @@ def _compare(
     ----------
     left: A Categorical operand
     right: The other operand to compare with
-    f : The Spark Column function to apply
+    func_name: The Spark Column function name to apply
     is_equality_comparison: True if it is equality comparison, ie. == or !=.
         False by default.
 
     Returns
@@ -158,15 +117,15 @@ def _compare(
         if hash(left.dtype) != hash(right.dtype):
             raise TypeError("Categoricals can only be compared if 'categories' are the same.")
         if cast(CategoricalDtype, left.dtype).ordered:
-            return column_op(f)(left, right)
+            return pyspark_column_op(func_name)(left, right)
         else:
-            return column_op(f)(_to_cat(left), _to_cat(right))
+            return pyspark_column_op(func_name)(_to_cat(left), _to_cat(right))
     elif not is_list_like(right):
         categories = cast(CategoricalDtype, left.dtype).categories
         if right not in categories:
             raise TypeError("Cannot compare a Categorical with a scalar, which is not a category.")
         right_code = categories.get_loc(right)
-        return column_op(f)(left, right_code)
+        return pyspark_column_op(func_name)(left, right_code)
     else:
         raise TypeError("Cannot compare a Categorical with the given type.")
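For context on the refactor: each comparison method previously branched on `is_remote()` to pick the right `Column` class before calling `column_op`, and `pyspark_column_op` centralizes that dispatch. Below is a minimal sketch of such a helper, reconstructed from the branches removed in this commit; the real `pyspark.sql.utils.pyspark_column_op` may differ in details, and the name `pyspark_column_op_sketch` is used to make that clear.

```python
from typing import Any, Callable

from pyspark.sql.utils import is_remote


def pyspark_column_op_sketch(func_name: str) -> Callable[..., Any]:
    """Bind a Column method name to the proper Column class, then wrap it.

    Sketch only: reconstructed from the per-method branches removed in
    this commit; the real pyspark_column_op may differ.
    """
    from pyspark.pandas.base import column_op

    if is_remote():
        # Spark Connect session: use the Connect Column implementation.
        from pyspark.sql.connect.column import Column
    else:
        # Classic session: use the JVM-backed Column.
        from pyspark.sql.column import Column  # type: ignore[assignment]

    # getattr(Column, "__lt__") etc. yields the operator that column_op
    # then applies element-wise to the operands.
    return column_op(getattr(Column, func_name))
```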
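And a hypothetical end-to-end usage of the code path these methods serve; the series values and categories below are made up for illustration:

```python
import pyspark.pandas as ps
from pandas.api.types import CategoricalDtype

# Made-up ordered categorical data for illustration.
dtype = CategoricalDtype(categories=["low", "mid", "high"], ordered=True)
psser = ps.Series(["low", "high", "mid"]).astype(dtype)

# `<` routes through CategoricalOps.lt, which now just calls
# _compare(left, right, "__lt__"); pyspark_column_op picks the
# Column class for either a classic or a Spark Connect session.
print((psser < "mid").to_pandas())  # low: True, high: False, mid: False
```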