This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c686313bb7f [SPARK-43671][SPARK-43672][SPARK-43673][SPARK-43674][PS] Fix `CategoricalOps` for Spark Connect
c686313bb7f is described below

commit c686313bb7f2288cdda5b85b33aa4f3ebfea7760
Author: itholic <haejoon....@databricks.com>
AuthorDate: Fri May 26 09:36:30 2023 +0800

    [SPARK-43671][SPARK-43672][SPARK-43673][SPARK-43674][PS] Fix `CategoricalOps` for Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to fix `CategoricalOps` for pandas API on Spark with Spark Connect and re-enable the corresponding parity tests.
    
    This addresses SPARK-43671, SPARK-43672, SPARK-43673, and SPARK-43674 at once, because they all require similar modifications to a single file.
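
    As context, here is a minimal, standalone sketch of the dispatch pattern this change applies (the `_resolve_column_class` helper below is purely illustrative; the actual change inlines this logic in each comparison operator):

    ```python
    from pyspark.sql.column import Column as PySparkColumn
    from pyspark.sql.utils import is_remote


    def _resolve_column_class():
        # Return the Column class whose comparison dunders (__lt__, __le__, ...)
        # are handed to `_compare`: the Spark Connect Column when running against
        # a remote session, otherwise the classic PySpark Column.
        if is_remote():
            from pyspark.sql.connect.column import Column as ConnectColumn

            return ConnectColumn
        return PySparkColumn
    ```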
    
    ### Why are the changes needed?
    
    To support all features for pandas API on Spark with Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, `CategoricalOps.lt`, `CategoricalOps.le`, `CategoricalOps.ge`, and `CategoricalOps.gt` now work as expected on Spark Connect.
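
    As an illustrative sketch (the remote endpoint and data below are placeholders, not part of this change), the following now behaves the same under Spark Connect as it does on classic PySpark:

    ```python
    import pandas as pd

    import pyspark.pandas as ps
    from pyspark.sql import SparkSession

    # Connect to a Spark Connect server; the address is a placeholder.
    SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    pdf = pd.DataFrame(
        {
            "x": pd.Categorical(["a", "b", "c"], categories=["c", "b", "a"], ordered=True),
            "y": pd.Categorical(["b", "b", "a"], categories=["c", "b", "a"], ordered=True),
        }
    )
    psdf = ps.from_pandas(pdf)

    # Ordering comparisons on ordered categoricals are routed through CategoricalOps.
    print((psdf["x"] < psdf["y"]).to_pandas())   # previously failed on Spark Connect
    print((psdf["x"] >= psdf["y"]).to_pandas())
    ```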
    
    ### How was this patch tested?
    
    Uncommented the UTs and tested manually.
    
    Closes #41310 from itholic/SPARK-43671-4.
    
    Authored-by: itholic <haejoon....@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../pandas/data_type_ops/categorical_ops.py        | 57 +++++++++++++++++++---
 .../data_type_ops/test_parity_categorical_ops.py   | 16 ------
 2 files changed, 49 insertions(+), 24 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/categorical_ops.py b/python/pyspark/pandas/data_type_ops/categorical_ops.py
index ad7e46192bf..9f14a4b1ee7 100644
--- a/python/pyspark/pandas/data_type_ops/categorical_ops.py
+++ b/python/pyspark/pandas/data_type_ops/categorical_ops.py
@@ -27,7 +27,8 @@ from pyspark.pandas.base import column_op, IndexOpsMixin
 from pyspark.pandas.data_type_ops.base import _sanitize_list_like, DataTypeOps
 from pyspark.pandas.typedef import pandas_on_spark_type
 from pyspark.sql import functions as F
-from pyspark.sql.column import Column
+from pyspark.sql.column import Column as PySparkColumn
+from pyspark.sql.utils import is_remote
 
 
 class CategoricalOps(DataTypeOps):
@@ -65,33 +66,73 @@ class CategoricalOps(DataTypeOps):
 
     def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return _compare(left, right, Column.__eq__, is_equality_comparison=True)
+        if is_remote():
+            from pyspark.sql.connect.column import Column as ConnectColumn
+
+            Column = ConnectColumn
+        else:
+            Column = PySparkColumn  # type: ignore[assignment]
+        return _compare(
+            left, right, Column.__eq__, is_equality_comparison=True  # type: ignore[arg-type]
+        )
 
     def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return _compare(left, right, Column.__ne__, is_equality_comparison=True)
+        if is_remote():
+            from pyspark.sql.connect.column import Column as ConnectColumn
+
+            Column = ConnectColumn
+        else:
+            Column = PySparkColumn  # type: ignore[assignment]
+        return _compare(
+            left, right, Column.__ne__, is_equality_comparison=True  # type: ignore[arg-type]
+        )
 
     def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return _compare(left, right, Column.__lt__)
+        if is_remote():
+            from pyspark.sql.connect.column import Column as ConnectColumn
+
+            Column = ConnectColumn
+        else:
+            Column = PySparkColumn  # type: ignore[assignment]
+        return _compare(left, right, Column.__lt__)  # type: ignore[arg-type]
 
     def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return _compare(left, right, Column.__le__)
+        if is_remote():
+            from pyspark.sql.connect.column import Column as ConnectColumn
+
+            Column = ConnectColumn
+        else:
+            Column = PySparkColumn  # type: ignore[assignment]
+        return _compare(left, right, Column.__le__)  # type: ignore[arg-type]
 
     def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return _compare(left, right, Column.__gt__)
+        if is_remote():
+            from pyspark.sql.connect.column import Column as ConnectColumn
+
+            Column = ConnectColumn
+        else:
+            Column = PySparkColumn  # type: ignore[assignment]
+        return _compare(left, right, Column.__gt__)  # type: ignore[arg-type]
 
     def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return _compare(left, right, Column.__ge__)
+        if is_remote():
+            from pyspark.sql.connect.column import Column as ConnectColumn
+
+            Column = ConnectColumn
+        else:
+            Column = PySparkColumn  # type: ignore[assignment]
+        return _compare(left, right, Column.__ge__)  # type: ignore[arg-type]
 
 
 def _compare(
     left: IndexOpsLike,
     right: Any,
-    f: Callable[..., Column],
+    f: Callable[..., PySparkColumn],
     *,
     is_equality_comparison: bool = False,
 ) -> SeriesOrIndex:
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py
index dc196060bfc..44243192d50 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_categorical_ops.py
@@ -38,22 +38,6 @@ class CategoricalOpsParityTests(
     def test_eq(self):
         super().test_eq()
 
-    @unittest.skip("TODO(SPARK-43671): Enable CategoricalOps.ge to work with 
Spark Connect.")
-    def test_ge(self):
-        super().test_ge()
-
-    @unittest.skip("TODO(SPARK-43672): Enable CategoricalOps.gt to work with 
Spark Connect.")
-    def test_gt(self):
-        super().test_gt()
-
-    @unittest.skip("TODO(SPARK-43673): Enable CategoricalOps.le to work with 
Spark Connect.")
-    def test_le(self):
-        super().test_le()
-
-    @unittest.skip("TODO(SPARK-43674): Enable CategoricalOps.lt to work with 
Spark Connect.")
-    def test_lt(self):
-        super().test_lt()
-
     @unittest.skip("TODO(SPARK-43675): Enable CategoricalOps.ne to work with 
Spark Connect.")
     def test_ne(self):
         super().test_ne()

