This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6f0a73e457d [SPARK-43692][SPARK-43693][SPARK-43694][SPARK-43695][PS] Fix `StringOps` for Spark Connect
6f0a73e457d is described below

commit 6f0a73e457dd3c49a4adce996d7201010cdd2651
Author: itholic <haejoon....@databricks.com>
AuthorDate: Sun May 28 08:41:44 2023 +0800

    [SPARK-43692][SPARK-43693][SPARK-43694][SPARK-43695][PS] Fix `StringOps` for Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to fix the `StringOps` tests for the pandas API on Spark with Spark Connect.
    
    This includes SPARK-43692, SPARK-43693, SPARK-43694, and SPARK-43695 at once, because they all involve similar modifications in a single file.
    
    ### Why are the changes needed?
    
    To support all features of the pandas API on Spark with Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, `StringOps.lt`, `StringOps.le`, `StringOps.ge`, and `StringOps.gt` now work as expected on Spark Connect.
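
    For example (a minimal sketch, assuming a Spark Connect session is already active; the sample values are illustrative, not from this patch):

        import pyspark.pandas as ps

        psser = ps.Series(["apple", "banana", "cherry"])
        # String comparisons now resolve the proper Column class under
        # Spark Connect instead of failing.
        print((psser < "banana").tolist())   # [True, False, False]
        print((psser >= "banana").tolist())  # [False, True, True]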
    
    ### How was this patch tested?
    
    Uncommented the UTs and tested manually.
    
    Closes #41308 from itholic/SPARK-43692-5.
    
    Authored-by: itholic <haejoon....@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/pandas/data_type_ops/string_ops.py      | 18 +++++-------------
 .../connect/data_type_ops/test_parity_string_ops.py    | 16 ----------------
 python/pyspark/sql/utils.py                            | 17 +++++++++++++++++
 3 files changed, 22 insertions(+), 29 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/string_ops.py b/python/pyspark/pandas/data_type_ops/string_ops.py
index 0b9eb87a163..e5818cb4635 100644
--- a/python/pyspark/pandas/data_type_ops/string_ops.py
+++ b/python/pyspark/pandas/data_type_ops/string_ops.py
@@ -22,6 +22,7 @@ from pandas.api.types import CategoricalDtype
 
 from pyspark.sql import functions as F
 from pyspark.sql.types import IntegralType, StringType
+from pyspark.sql.utils import pyspark_column_op
 
 from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
 from pyspark.pandas.base import column_op, IndexOpsMixin
@@ -34,7 +35,6 @@ from pyspark.pandas.data_type_ops.base import (
 )
 from pyspark.pandas.spark import functions as SF
 from pyspark.pandas.typedef import extension_dtypes, pandas_on_spark_type
-from pyspark.sql import Column
 from pyspark.sql.types import BooleanType
 
 
@@ -104,28 +104,20 @@ class StringOps(DataTypeOps):
             raise TypeError("Multiplication can not be applied to given 
types.")
 
     def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
-        from pyspark.pandas.base import column_op
-
         _sanitize_list_like(right)
-        return column_op(Column.__lt__)(left, right)
+        return pyspark_column_op("__lt__")(left, right)
 
     def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
-        from pyspark.pandas.base import column_op
-
         _sanitize_list_like(right)
-        return column_op(Column.__le__)(left, right)
+        return pyspark_column_op("__le__")(left, right)
 
     def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
-        from pyspark.pandas.base import column_op
-
         _sanitize_list_like(right)
-        return column_op(Column.__ge__)(left, right)
+        return pyspark_column_op("__ge__")(left, right)
 
     def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
-        from pyspark.pandas.base import column_op
-
         _sanitize_list_like(right)
-        return column_op(Column.__gt__)(left, right)
+        return pyspark_column_op("__gt__")(left, right)
 
     def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
         dtype, spark_type = pandas_on_spark_type(dtype)
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py
index 9abfe1d1e09..2d81db1c701 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_string_ops.py
@@ -34,22 +34,6 @@ class StringOpsParityTests(
     def test_astype(self):
         super().test_astype()
 
-    @unittest.skip("TODO(SPARK-43692): Fix StringOps.ge to work with Spark 
Connect.")
-    def test_ge(self):
-        super().test_ge()
-
-    @unittest.skip("TODO(SPARK-43693): Fix StringOps.gt to work with Spark 
Connect.")
-    def test_gt(self):
-        super().test_gt()
-
-    @unittest.skip("TODO(SPARK-43694): Fix StringOps.le to work with Spark 
Connect.")
-    def test_le(self):
-        super().test_le()
-
-    @unittest.skip("TODO(SPARK-43695): Fix StringOps.lt to work with Spark 
Connect.")
-    def test_lt(self):
-        super().test_lt()
-
     @unittest.skip(
         "TODO(SPARK-43621): Enable pyspark.pandas.spark.functions.repeat in 
Spark Connect."
     )
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 85cd5e4e9e7..841ceb4fa1d 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -45,6 +45,7 @@ from pyspark.find_spark_home import _find_spark_home
 if TYPE_CHECKING:
     from pyspark.sql.session import SparkSession
     from pyspark.sql.dataframe import DataFrame
+    from pyspark.pandas._typing import SeriesOrIndex
 
 has_numpy = False
 try:
@@ -234,3 +235,19 @@ def try_remote_observation(f: FuncT) -> FuncT:
         return f(*args, **kwargs)
 
     return cast(FuncT, wrapped)
+
+
+def pyspark_column_op(func_name: str) -> Callable[..., "SeriesOrIndex"]:
+    """
+    Wrapper function for column_op to get proper Column class.
+    """
+    from pyspark.pandas.base import column_op
+    from pyspark.sql.column import Column as PySparkColumn
+
+    if is_remote():
+        from pyspark.sql.connect.column import Column as ConnectColumn
+
+        Column = ConnectColumn
+    else:
+        Column = PySparkColumn  # type: ignore[assignment]
+    return column_op(getattr(Column, func_name))
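
For context, the new `pyspark_column_op` helper resolves the `Column` class at call time, so the same `StringOps` code path works with and without Spark Connect. A minimal standalone sketch of that dispatch pattern (the function name `resolve_lt` is hypothetical; `is_remote` is the same check used in the diff above):

    from pyspark.sql.utils import is_remote

    def resolve_lt():
        # Pick the Column implementation that matches the active session:
        # Spark Connect sessions use pyspark.sql.connect.column.Column,
        # classic sessions use pyspark.sql.column.Column.
        if is_remote():
            from pyspark.sql.connect.column import Column
        else:
            from pyspark.sql.column import Column
        return getattr(Column, "__lt__")  # the unbound "<" operator

`pyspark_column_op("__lt__")` then wraps the resolved operator with `column_op`, which is why `StringOps.lt` and the other comparisons now behave the same on both session types.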


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
