This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 9c8592058cd [SPARK-43684][SPARK-43685][SPARK-43686][SPARK-43691][CONNECT][PS] Fix `(NullOps|NumOps).(eq|ne)` for Spark Connect
9c8592058cd is described below

commit 9c8592058cd1a5cc530fb30e6dc7c5c759ad528d
Author: itholic <haejoon....@databricks.com>
AuthorDate: Wed Jun 14 09:38:22 2023 +0900

[SPARK-43684][SPARK-43685][SPARK-43686][SPARK-43691][CONNECT][PS] Fix `(NullOps|NumOps).(eq|ne)` for Spark Connect

### What changes were proposed in this pull request?

This PR proposes to fix `NullOps.(eq|ne)` and `NumOps.(eq|ne)` for the pandas API on Spark with Spark Connect.

This includes SPARK-43684, SPARK-43685, SPARK-43686 and SPARK-43691 at once, because they all require similar modifications in a single file.

This PR also introduces a new util function, `_is_extension_dtypes`, to check whether the given object has an extension dtype, and applies it to all related functions.

### Why are the changes needed?

The pandas API on Spark with Spark Connect currently behaves differently from pandas, as shown below:

**For `ne`:**

```python
>>> pser = pd.Series([1.0, 2.0, np.nan])
>>> psser = ps.from_pandas(pser)
>>> pser.ne(pser)
0    False
1    False
2     True
dtype: bool

>>> psser.ne(psser)
0    False
1    False
2     None
dtype: bool
```

We expect `True` for the non-equal case, but it returns `None` in Spark Connect. So we should cast `None` to `True` for `ne`.

**For `eq`:**

```python
>>> pser = pd.Series([1.0, 2.0, np.nan])
>>> psser = ps.from_pandas(pser)
>>> pser.eq(pser)
0     True
1     True
2    False
dtype: bool

>>> psser.eq(psser)
0    True
1    True
2    None
dtype: bool
```

We expect `False` for the non-equal case, but it returns `None` in Spark Connect. So we should cast `None` to `False` for `eq`.

### Does this PR introduce _any_ user-facing change?

Yes, `NullOps.eq`, `NullOps.ne`, `NumOps.eq` and `NumOps.ne` now work as expected on Spark Connect.

### How was this patch tested?

Uncommented the existing UTs, and tested manually with vanilla PySpark.

Closes #41514 from itholic/SPARK-43684.

Authored-by: itholic <haejoon....@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/pandas/data_type_ops/base.py        |  9 +++++
 python/pyspark/pandas/data_type_ops/binary_ops.py  |  8 ++---
 .../pandas/data_type_ops/categorical_ops.py        |  6 ++--
 .../pyspark/pandas/data_type_ops/datetime_ops.py   |  8 ++---
 python/pyspark/pandas/data_type_ops/null_ops.py    | 39 ++++++++++------------
 python/pyspark/pandas/data_type_ops/num_ops.py     | 35 +++++++++----------
 python/pyspark/pandas/data_type_ops/string_ops.py  |  8 ++---
 .../pyspark/pandas/data_type_ops/timedelta_ops.py  | 12 +++----
 .../connect/data_type_ops/test_parity_null_ops.py  |  8 -----
 .../connect/data_type_ops/test_parity_num_ops.py   |  8 -----
 .../pandas/tests/data_type_ops/test_null_ops.py    |  1 +
 python/pyspark/sql/utils.py                        | 16 ++++++---
 12 files changed, 75 insertions(+), 83 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index d88eddee26b..18e792e292f 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -219,6 +219,15 @@ def _is_boolean_type(right: Any) -> bool:
     )
 
 
+def _is_extension_dtypes(object: Any) -> bool:
+    """
+    Check whether the type of the given object is an extension dtype or not.
+    Extension dtypes include Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype, BooleanDtype,
+    StringDtype, Float32Dtype and Float64Dtype.
+ """ + return isinstance(getattr(object, "dtype", None), extension_dtypes) + + class DataTypeOps(object, metaclass=ABCMeta): """The base class for binary operations of pandas-on-Spark objects (of different data types).""" diff --git a/python/pyspark/pandas/data_type_ops/binary_ops.py b/python/pyspark/pandas/data_type_ops/binary_ops.py index ba31156178a..f528d3e9ae2 100644 --- a/python/pyspark/pandas/data_type_ops/binary_ops.py +++ b/python/pyspark/pandas/data_type_ops/binary_ops.py @@ -69,19 +69,19 @@ class BinaryOps(DataTypeOps): def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) - return pyspark_column_op("__lt__")(left, right) + return pyspark_column_op("__lt__", left, right) def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) - return pyspark_column_op("__le__")(left, right) + return pyspark_column_op("__le__", left, right) def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) - return pyspark_column_op("__ge__")(left, right) + return pyspark_column_op("__ge__", left, right) def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) - return pyspark_column_op("__gt__")(left, right) + return pyspark_column_op("__gt__", left, right) def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: dtype, spark_type = pandas_on_spark_type(dtype) diff --git a/python/pyspark/pandas/data_type_ops/categorical_ops.py b/python/pyspark/pandas/data_type_ops/categorical_ops.py index 66e181a6079..824666b5819 100644 --- a/python/pyspark/pandas/data_type_ops/categorical_ops.py +++ b/python/pyspark/pandas/data_type_ops/categorical_ops.py @@ -117,15 +117,15 @@ def _compare( if hash(left.dtype) != hash(right.dtype): raise TypeError("Categoricals can only be compared if 'categories' are the same.") if cast(CategoricalDtype, left.dtype).ordered: - return pyspark_column_op(func_name)(left, right) + return pyspark_column_op(func_name, left, right) else: - return pyspark_column_op(func_name)(_to_cat(left), _to_cat(right)) + return pyspark_column_op(func_name, _to_cat(left), _to_cat(right)) elif not is_list_like(right): categories = cast(CategoricalDtype, left.dtype).categories if right not in categories: raise TypeError("Cannot compare a Categorical with a scalar, which is not a category.") right_code = categories.get_loc(right) - return pyspark_column_op(func_name)(left, right_code) + return pyspark_column_op(func_name, left, right_code) else: raise TypeError("Cannot compare a Categorical with the given type.") diff --git a/python/pyspark/pandas/data_type_ops/datetime_ops.py b/python/pyspark/pandas/data_type_ops/datetime_ops.py index c5f4df96bde..ea9b994076b 100644 --- a/python/pyspark/pandas/data_type_ops/datetime_ops.py +++ b/python/pyspark/pandas/data_type_ops/datetime_ops.py @@ -111,19 +111,19 @@ class DatetimeOps(DataTypeOps): def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) - return pyspark_column_op("__lt__")(left, right) + return pyspark_column_op("__lt__", left, right) def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) - return pyspark_column_op("__le__")(left, right) + return pyspark_column_op("__le__", left, right) def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: _sanitize_list_like(right) - return pyspark_column_op("__ge__")(left, right) + return pyspark_column_op("__ge__", left, right) def gt(self, left: IndexOpsLike, right: Any) -> 
         _sanitize_list_like(right)
-        return pyspark_column_op("__gt__")(left, right)
+        return pyspark_column_op("__gt__", left, right)
 
     def prepare(self, col: pd.Series) -> pd.Series:
         """Prepare column when from_pandas."""
diff --git a/python/pyspark/pandas/data_type_ops/null_ops.py b/python/pyspark/pandas/data_type_ops/null_ops.py
index ab86f074b99..329a3790df6 100644
--- a/python/pyspark/pandas/data_type_ops/null_ops.py
+++ b/python/pyspark/pandas/data_type_ops/null_ops.py
@@ -17,7 +17,7 @@
 
 from typing import Any, Union
 
-from pandas.api.types import CategoricalDtype
+from pandas.api.types import CategoricalDtype, is_list_like  # type: ignore[attr-defined]
 
 from pyspark.pandas._typing import Dtype, IndexOpsLike
 from pyspark.pandas.data_type_ops.base import (
@@ -31,7 +31,8 @@ from pyspark.pandas.data_type_ops.base import (
 from pyspark.pandas._typing import SeriesOrIndex
 from pyspark.pandas.typedef import pandas_on_spark_type
 from pyspark.sql.types import BooleanType, StringType
-from pyspark.sql.utils import pyspark_column_op, is_remote
+from pyspark.sql.utils import pyspark_column_op
+from pyspark.pandas.base import IndexOpsMixin
 
 
 class NullOps(DataTypeOps):
@@ -43,37 +44,31 @@ class NullOps(DataTypeOps):
     def pretty_name(self) -> str:
         return "nulls"
 
+    def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
+        # We can directly use `super().eq` when given object is list, tuple, dict or set.
+        if not isinstance(right, IndexOpsMixin) and is_list_like(right):
+            return super().eq(left, right)
+        return pyspark_column_op("__eq__", left, right, fillna=False)
+
+    def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
+        _sanitize_list_like(right)
+        return pyspark_column_op("__ne__", left, right, fillna=True)
+
     def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        result = pyspark_column_op("__lt__")(left, right)
-        if is_remote():
-            # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
-            result = result.fillna(False)
-        return result
+        return pyspark_column_op("__lt__", left, right, fillna=False)
 
     def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        result = pyspark_column_op("__le__")(left, right)
-        if is_remote():
-            # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
-            result = result.fillna(False)
-        return result
+        return pyspark_column_op("__le__", left, right, fillna=False)
 
     def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        result = pyspark_column_op("__ge__")(left, right)
-        if is_remote():
-            # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
-            result = result.fillna(False)
-        return result
+        return pyspark_column_op("__ge__", left, right, fillna=False)
 
     def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        result = pyspark_column_op("__gt__")(left, right)
-        if is_remote():
-            # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
-            result = result.fillna(False)
-        return result
+        return pyspark_column_op("__gt__", left, right, fillna=False)
 
     def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
         dtype, spark_type = pandas_on_spark_type(dtype)
diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py
index 9e7e2037a92..911228a5265 100644
--- a/python/pyspark/pandas/data_type_ops/num_ops.py
+++ b/python/pyspark/pandas/data_type_ops/num_ops.py
@@ -24,6 +24,7 @@ from pandas.api.types import (  # type: ignore[attr-defined]
     is_bool_dtype,
     is_integer_dtype,
     CategoricalDtype,
+    is_list_like,
 )
 
 from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
@@ -213,37 +214,31 @@ class NumericOps(DataTypeOps):
             F.abs(operand.spark.column), field=operand._internal.data_fields[0]
         )
 
+    def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
+        # We can directly use `super().eq` when given object is list, tuple, dict or set.
+        if not isinstance(right, IndexOpsMixin) and is_list_like(right):
+            return super().eq(left, right)
+        return pyspark_column_op("__eq__", left, right, fillna=False)
+
+    def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
+        _sanitize_list_like(right)
+        return pyspark_column_op("__ne__", left, right, fillna=True)
+
     def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        result = pyspark_column_op("__lt__")(left, right)
-        if is_remote():
-            # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
-            result = result.fillna(False)
-        return result
+        return pyspark_column_op("__lt__", left, right, fillna=False)
 
     def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        result = pyspark_column_op("__le__")(left, right)
-        if is_remote():
-            # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
-            result = result.fillna(False)
-        return result
+        return pyspark_column_op("__le__", left, right, fillna=False)
 
     def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        result = pyspark_column_op("__ge__")(left, right)
-        if is_remote():
-            # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
-            result = result.fillna(False)
-        return result
+        return pyspark_column_op("__ge__", left, right, fillna=False)
 
     def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        result = pyspark_column_op("__gt__")(left, right)
-        if is_remote():
-            # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
-            result = result.fillna(False)
-        return result
+        return pyspark_column_op("__gt__", left, right, fillna=False)
 
 
 class IntegralOps(NumericOps):
diff --git a/python/pyspark/pandas/data_type_ops/string_ops.py b/python/pyspark/pandas/data_type_ops/string_ops.py
index e5818cb4635..1c282f20117 100644
--- a/python/pyspark/pandas/data_type_ops/string_ops.py
+++ b/python/pyspark/pandas/data_type_ops/string_ops.py
@@ -105,19 +105,19 @@ class StringOps(DataTypeOps):
 
     def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return pyspark_column_op("__lt__")(left, right)
+        return pyspark_column_op("__lt__", left, right)
 
     def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return pyspark_column_op("__le__")(left, right)
+        return pyspark_column_op("__le__", left, right)
 
     def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return pyspark_column_op("__ge__")(left, right)
+        return pyspark_column_op("__ge__", left, right)
 
     def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return pyspark_column_op("__gt__")(left, right)
+        return pyspark_column_op("__gt__", left, right)
 
     def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
         dtype, spark_type = pandas_on_spark_type(dtype)
diff --git a/python/pyspark/pandas/data_type_ops/timedelta_ops.py b/python/pyspark/pandas/data_type_ops/timedelta_ops.py
index 3e96ebbb13a..7a9da8511e6 100644
--- a/python/pyspark/pandas/data_type_ops/timedelta_ops.py
+++ b/python/pyspark/pandas/data_type_ops/timedelta_ops.py
@@ -72,7 +72,7 @@ class TimedeltaOps(DataTypeOps):
             and isinstance(right.spark.data_type, DayTimeIntervalType)
             or isinstance(right, timedelta)
         ):
-            return pyspark_column_op("__sub__")(left, right)
+            return pyspark_column_op("__sub__", left, right)
         else:
             raise TypeError("Timedelta subtraction can only be applied to timedelta series.")
 
@@ -80,22 +80,22 @@ class TimedeltaOps(DataTypeOps):
         _sanitize_list_like(right)
 
         if isinstance(right, timedelta):
-            return pyspark_column_op("__rsub__")(left, right)
+            return pyspark_column_op("__rsub__", left, right)
         else:
             raise TypeError("Timedelta subtraction can only be applied to timedelta series.")
 
     def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return pyspark_column_op("__lt__")(left, right)
+        return pyspark_column_op("__lt__", left, right)
 
     def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return pyspark_column_op("__le__")(left, right)
+        return pyspark_column_op("__le__", left, right)
 
     def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return pyspark_column_op("__ge__")(left, right)
+        return pyspark_column_op("__ge__", left, right)
 
     def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        return pyspark_column_op("__gt__")(left, right)
+        return pyspark_column_op("__gt__", left, right)
diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py
index 1b53a064971..63b53c02fd7 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_null_ops.py
@@ -29,14 +29,6 @@ class NullOpsParityTests(
     def test_astype(self):
         super().test_astype()
 
-    @unittest.skip("TODO(SPARK-43684): Fix NullOps.eq to work with Spark Connect Column.")
Spark Connect Column.") - def test_eq(self): - super().test_eq() - - @unittest.skip("TODO(SPARK-43685): Fix NullOps.ne to work with Spark Connect Column.") - def test_ne(self): - super().test_ne() - if __name__ == "__main__": from pyspark.pandas.tests.connect.data_type_ops.test_parity_null_ops import * # noqa: F401 diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py index b65873c6ab5..04aa24c4045 100644 --- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py +++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_num_ops.py @@ -34,14 +34,6 @@ class NumOpsParityTests( def test_astype(self): super().test_astype() - @unittest.skip("TODO(SPARK-43686): Enable NumOpsParityTests.test_eq.") - def test_eq(self): - super().test_eq() - - @unittest.skip("TODO(SPARK-43691): Enable NumOpsParityTests.test_ne.") - def test_ne(self): - super().test_ne() - if __name__ == "__main__": from pyspark.pandas.tests.connect.data_type_ops.test_parity_num_ops import * # noqa: F401 diff --git a/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py index 22ea26050bf..19a3e7c0735 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py @@ -138,6 +138,7 @@ class NullOpsTestsMixin: def test_eq(self): pser, psser = self.pser, self.psser self.assert_eq(pser == pser, psser == psser) + self.assert_eq(pser == [None, 1, None], psser == [None, 1, None]) def test_ne(self): pser, psser = self.pser, self.psser diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 841ceb4fa1d..f5a5c88b8d3 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -16,7 +16,7 @@ # import functools import os -from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar +from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar, Union from py4j.java_collections import JavaArray from py4j.java_gateway import ( @@ -45,7 +45,7 @@ from pyspark.find_spark_home import _find_spark_home if TYPE_CHECKING: from pyspark.sql.session import SparkSession from pyspark.sql.dataframe import DataFrame - from pyspark.pandas._typing import SeriesOrIndex + from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex has_numpy = False try: @@ -237,12 +237,15 @@ def try_remote_observation(f: FuncT) -> FuncT: return cast(FuncT, wrapped) -def pyspark_column_op(func_name: str) -> Callable[..., "SeriesOrIndex"]: +def pyspark_column_op( + func_name: str, left: "IndexOpsLike", right: Any, fillna: Any = None +) -> Union["SeriesOrIndex", None]: """ Wrapper function for column_op to get proper Column class. """ from pyspark.pandas.base import column_op from pyspark.sql.column import Column as PySparkColumn + from pyspark.pandas.data_type_ops.base import _is_extension_dtypes if is_remote(): from pyspark.sql.connect.column import Column as ConnectColumn @@ -250,4 +253,9 @@ def pyspark_column_op(func_name: str) -> Callable[..., "SeriesOrIndex"]: Column = ConnectColumn else: Column = PySparkColumn # type: ignore[assignment] - return column_op(getattr(Column, func_name)) + result = column_op(getattr(Column, func_name))(left, right) + # It works as expected on extension dtype, so we don't need to call `fillna` for this case. 
+    if (fillna is not None) and (_is_extension_dtypes(left) or _is_extension_dtypes(right)):
+        fillna = None
+    # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
+    return result.fillna(fillna) if fillna is not None else result
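
For reference, the `fillna` mapping that `pyspark_column_op` now applies can be sketched in plain pandas. This is only an illustrative sketch of the intended semantics, not part of the patch: the helper `fill_comparison_result` and the object-dtype series simulating a nullable Spark Connect result are made up for this example.

```python
import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype

def fill_comparison_result(result: pd.Series, operand: pd.Series, fillna) -> pd.Series:
    # Mirror of the rule introduced in pyspark_column_op: skip the fill when the
    # operand uses an extension dtype (it already keeps <NA>, matching pandas);
    # otherwise replace missing comparison results with the given default
    # (False for eq-like ops, True for ne-like ops).
    if fillna is not None and is_extension_array_dtype(operand.dtype):
        fillna = None
    return result.fillna(fillna) if fillna is not None else result

pser = pd.Series([1.0, 2.0, np.nan])

# Simulated nullable results, as returned for the examples in the commit message.
raw_eq = pd.Series([True, True, None], dtype="object")
raw_ne = pd.Series([False, False, None], dtype="object")

print(fill_comparison_result(raw_eq, pser, fillna=False))  # row 2 becomes False, matching pser.eq(pser)
print(fill_comparison_result(raw_ne, pser, fillna=True))   # row 2 becomes True, matching pser.ne(pser)
```

With this mapping, the Connect results line up with the pandas outputs shown in the commit message, while operands with extension dtypes (for example pandas' nullable `Int64`) are left untouched because their missing comparison results are already the expected pandas behavior.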