This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 6f0a73e457d [SPARK-43692][SPARK-43693][SPARK-43694][SPARK-43695][PS] Fix `StringOps` for Spark Connect 6f0a73e457d is described below commit 6f0a73e457dd3c49a4adce996d7201010cdd2651 Author: itholic <haejoon....@databricks.com> AuthorDate: Sun May 28 08:41:44 2023 +0800 [SPARK-43692][SPARK-43693][SPARK-43694][SPARK-43695][PS] Fix `StringOps` for Spark Connect ### What changes were proposed in this pull request? This PR proposes to fix the `StringOps` test for pandas API on Spark with Spark Connect. This includes SPARK-43692, SPARK-43693, SPARK-43694, SPARK-43695 at once, because they are all similar modifications in a single file. ### Why are the changes needed? To support all features for pandas API on Spark with Spark Connect. ### Does this PR introduce _any_ user-facing change? Yes, `StringOps.lt`, `StringOps.le`, `StringOps.ge`, `StringOps.gt` are now working as expected on Spark Connect. ### How was this patch tested? Uncommented the UTs and tested manually. Closes #41308 from itholic/SPARK-43692-5.
Authored-by: itholic <haejoon....@databricks.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/pandas/data_type_ops/string_ops.py | 18 +++++------------- .../connect/data_type_ops/test_parity_string_ops.py | 16 ---------------- python/pyspark/sql/utils.py | 17 +++++++++++++++++ 3 files changed, 22 insertions(+), 29 deletions(-) diff --git a/python/pyspark/pandas/data_type_ops/string_ops.py b/python/pyspark/pandas/data_type_ops/string_ops.py index 0b9eb87a163..e5818cb4635 100644 --- a/python/pyspark/pandas/data_type_ops/string_ops.py +++ b/python/pyspark/pandas/data_type_ops/string_ops.py @@ -22,6 +22,7 @@ from pandas.api.types import CategoricalDtype from pyspark.sql import functions as F from pyspark.sql.types import IntegralType, StringType +from pyspark.sql.utils import pyspark_column_op from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex from pyspark.pandas.base import column_op, IndexOpsMixin @@ -34,7 +35,6 @@ from pyspark.pandas.data_type_ops.base import ( ) from pyspark.pandas.spark import functions as SF from pyspark.pandas.typedef import extension_dtypes, pandas_on_spark_type -from pyspark.sql import Column from pyspark.sql.types import BooleanType @@ -104,28 +104,20 @@ class StringOps(DataTypeOps): raise TypeError("Multiplication can not be applied to given types.") def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: - from pyspark.pandas.base import column_op - _sanitize_list_like(right) - return column_op(Column.__lt__)(left, right) + return pyspark_column_op("__lt__")(left, right) def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: - from pyspark.pandas.base import column_op - _sanitize_list_like(right) - return column_op(Column.__le__)(left, right) + return pyspark_column_op("__le__")(left, right) def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: - from pyspark.pandas.base import column_op - _sanitize_list_like(right) - return column_op(Column.__ge__)(left, right) + return pyspark_column_op("__ge__")(left, right)
def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: - from pyspark.pandas.base import column_op - _sanitize_list_like(right) - return column_op(Column.__gt__)(left, right) + return pyspark_column_op("__gt__")(left, right) def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: dtype, spark_type = pandas_on_spark_type(dtype) diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py index 9abfe1d1e09..2d81db1c701 100644 --- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py +++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py @@ -34,22 +34,6 @@ class StringOpsParityTests( def test_astype(self): super().test_astype() - @unittest.skip("TODO(SPARK-43692): Fix StringOps.ge to work with Spark Connect.") - def test_ge(self): - super().test_ge() - - @unittest.skip("TODO(SPARK-43693): Fix StringOps.gt to work with Spark Connect.") - def test_gt(self): - super().test_gt() - - @unittest.skip("TODO(SPARK-43694): Fix StringOps.le to work with Spark Connect.") - def test_le(self): - super().test_le() - - @unittest.skip("TODO(SPARK-43695): Fix StringOps.lt to work with Spark Connect.") - def test_lt(self): - super().test_lt() - @unittest.skip( "TODO(SPARK-43621): Enable pyspark.pandas.spark.functions.repeat in Spark Connect." 
) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 85cd5e4e9e7..841ceb4fa1d 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -45,6 +45,7 @@ from pyspark.find_spark_home import _find_spark_home if TYPE_CHECKING: from pyspark.sql.session import SparkSession from pyspark.sql.dataframe import DataFrame + from pyspark.pandas._typing import SeriesOrIndex has_numpy = False try: @@ -234,3 +235,19 @@ def try_remote_observation(f: FuncT) -> FuncT: return f(*args, **kwargs) return cast(FuncT, wrapped) + + +def pyspark_column_op(func_name: str) -> Callable[..., "SeriesOrIndex"]: + """ + Wrapper function for column_op to get proper Column class. + """ + from pyspark.pandas.base import column_op + from pyspark.sql.column import Column as PySparkColumn + + if is_remote(): + from pyspark.sql.connect.column import Column as ConnectColumn + + Column = ConnectColumn + else: + Column = PySparkColumn # type: ignore[assignment] + return column_op(getattr(Column, func_name)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org