This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c686313bb7f [SPARK-43671][SPARK-43672][SPARK-43673][SPARK-43674][PS] Fix `CategoricalOps` for Spark Connect c686313bb7f is described below commit c686313bb7f2288cdda5b85b33aa4f3ebfea7760 Author: itholic <haejoon....@databricks.com> AuthorDate: Fri May 26 09:36:30 2023 +0800 [SPARK-43671][SPARK-43672][SPARK-43673][SPARK-43674][PS] Fix `CategoricalOps` for Spark Connect ### What changes were proposed in this pull request? This PR proposes to fix the `CategoricalOps` test for pandas API on Spark with Spark Connect. This includes SPARK-43671, SPARK-43672, SPARK-43673, SPARK-43674 at once, because they all involve similar modifications in a single file. ### Why are the changes needed? To support all features for pandas API on Spark with Spark Connect. ### Does this PR introduce _any_ user-facing change? Yes, `CategoricalOps.lt`, `CategoricalOps.le`, `CategoricalOps.ge`, `CategoricalOps.gt` are now working as expected on Spark Connect. ### How was this patch tested? Uncommented the UTs, and tested manually. Closes #41310 from itholic/SPARK-43671-4. 
Authored-by: itholic <haejoon....@databricks.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../pandas/data_type_ops/categorical_ops.py | 57 +++++++++++++++++++--- .../data_type_ops/test_parity_categorical_ops.py | 16 ------ 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/python/pyspark/pandas/data_type_ops/categorical_ops.py b/python/pyspark/pandas/data_type_ops/categorical_ops.py index ad7e46192bf..9f14a4b1ee7 100644 --- a/python/pyspark/pandas/data_type_ops/categorical_ops.py +++ b/python/pyspark/pandas/data_type_ops/categorical_ops.py @@ -27,7 +27,8 @@ from pyspark.pandas.base import column_op, IndexOpsMixin from pyspark.pandas.data_type_ops.base import _sanitize_list_like, DataTypeOps from pyspark.pandas.typedef import pandas_on_spark_type from pyspark.sql import functions as F -from pyspark.sql.column import Column +from pyspark.sql.column import Column as PySparkColumn +from pyspark.sql.utils import is_remote class CategoricalOps(DataTypeOps): @@ -65,33 +66,73 @@ class CategoricalOps(DataTypeOps): def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) - return _compare(left, right, Column.__eq__, is_equality_comparison=True) + if is_remote(): + from pyspark.sql.connect.column import Column as ConnectColumn + + Column = ConnectColumn + else: + Column = PySparkColumn # type: ignore[assignment] + return _compare( + left, right, Column.__eq__, is_equality_comparison=True # type: ignore[arg-type] + ) def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) - return _compare(left, right, Column.__ne__, is_equality_comparison=True) + if is_remote(): + from pyspark.sql.connect.column import Column as ConnectColumn + + Column = ConnectColumn + else: + Column = PySparkColumn # type: ignore[assignment] + return _compare( + left, right, Column.__ne__, is_equality_comparison=True # type: ignore[arg-type] + ) def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: 
_sanitize_list_like(right) - return _compare(left, right, Column.__lt__) + if is_remote(): + from pyspark.sql.connect.column import Column as ConnectColumn + + Column = ConnectColumn + else: + Column = PySparkColumn # type: ignore[assignment] + return _compare(left, right, Column.__lt__) # type: ignore[arg-type] def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) - return _compare(left, right, Column.__le__) + if is_remote(): + from pyspark.sql.connect.column import Column as ConnectColumn + + Column = ConnectColumn + else: + Column = PySparkColumn # type: ignore[assignment] + return _compare(left, right, Column.__le__) # type: ignore[arg-type] def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) - return _compare(left, right, Column.__gt__) + if is_remote(): + from pyspark.sql.connect.column import Column as ConnectColumn + + Column = ConnectColumn + else: + Column = PySparkColumn # type: ignore[assignment] + return _compare(left, right, Column.__gt__) # type: ignore[arg-type] def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) - return _compare(left, right, Column.__ge__) + if is_remote(): + from pyspark.sql.connect.column import Column as ConnectColumn + + Column = ConnectColumn + else: + Column = PySparkColumn # type: ignore[assignment] + return _compare(left, right, Column.__ge__) # type: ignore[arg-type] def _compare( left: IndexOpsLike, right: Any, - f: Callable[..., Column], + f: Callable[..., PySparkColumn], *, is_equality_comparison: bool = False, ) -> SeriesOrIndex: diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py index dc196060bfc..44243192d50 100644 --- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py +++ 
b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py @@ -38,22 +38,6 @@ class CategoricalOpsParityTests( def test_eq(self): super().test_eq() - @unittest.skip("TODO(SPARK-43671): Enable CategoricalOps.ge to work with Spark Connect.") - def test_ge(self): - super().test_ge() - - @unittest.skip("TODO(SPARK-43672): Enable CategoricalOps.gt to work with Spark Connect.") - def test_gt(self): - super().test_gt() - - @unittest.skip("TODO(SPARK-43673): Enable CategoricalOps.le to work with Spark Connect.") - def test_le(self): - super().test_le() - - @unittest.skip("TODO(SPARK-43674): Enable CategoricalOps.lt to work with Spark Connect.") - def test_lt(self): - super().test_lt() - @unittest.skip("TODO(SPARK-43675): Enable CategoricalOps.ne to work with Spark Connect.") def test_ne(self): super().test_ne() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org