This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 20a8fc87d67 [SPARK-43631][CONNECT][PS] Enable Series.interpolate with Spark Connect 20a8fc87d67 is described below commit 20a8fc87d67c842ac3386dc6ae0c53a9533900c2 Author: itholic <haejoon....@databricks.com> AuthorDate: Tue Jun 27 14:05:42 2023 -0700 [SPARK-43631][CONNECT][PS] Enable Series.interpolate with Spark Connect ### What changes were proposed in this pull request? This PR proposes to add `LastNonNull` and `NullIndex` to SparkConnectPlanner to enable `Series.interpolate`. ### Why are the changes needed? To increase pandas API coverage. ### Does this PR introduce _any_ user-facing change? Yes, `Series.interpolate` will be available from this fix. ### How was this patch tested? Reusing the existing unit tests. Closes #41670 from itholic/interpolate. Authored-by: itholic <haejoon....@databricks.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../sql/connect/planner/SparkConnectPlanner.scala | 8 +++++++ python/pyspark/pandas/series.py | 9 ++++--- python/pyspark/pandas/spark/functions.py | 28 ++++++++++++++++++++++ .../tests/connect/test_parity_generic_functions.py | 4 +++- python/pyspark/sql/utils.py | 14 ++++++++++- 5 files changed, 56 insertions(+), 7 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index c19fc5fe90e..ff158990560 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1768,6 +1768,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { val ignoreNA = extractBoolean(children(2), "ignoreNA") Some(EWM(children(0), alpha, ignoreNA)) + case "last_non_null" if fun.getArgumentsCount == 1 => + 
val children = fun.getArgumentsList.asScala.map(transformExpression) + Some(LastNonNull(children(0))) + + case "null_index" if fun.getArgumentsCount == 1 => + val children = fun.getArgumentsList.asScala.map(transformExpression) + Some(NullIndex(children(0))) + // ML-specific functions case "vector_to_array" if fun.getArgumentsCount == 2 => val expr = transformExpression(fun.getArguments(0)) diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index 0f1e814946a..95ca92e7878 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -53,7 +53,6 @@ from pandas.api.types import ( # type: ignore[attr-defined] CategoricalDtype, ) from pandas.tseries.frequencies import DateOffset -from pyspark import SparkContext from pyspark.sql import functions as F, Column as PySparkColumn, DataFrame as SparkDataFrame from pyspark.sql.types import ( ArrayType, @@ -70,7 +69,7 @@ from pyspark.sql.types import ( TimestampType, ) from pyspark.sql.window import Window -from pyspark.sql.utils import get_column_class +from pyspark.sql.utils import get_column_class, get_window_class from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. 
from pyspark.pandas._typing import Axis, Dtype, Label, Name, Scalar, T @@ -2257,10 +2256,10 @@ class Series(Frame, IndexOpsMixin, Generic[T]): return self._psdf.copy()._psser_for(self._column_label) scol = self.spark.column - sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils - last_non_null = PySparkColumn(sql_utils.lastNonNull(scol._jc)) - null_index = PySparkColumn(sql_utils.nullIndex(scol._jc)) + last_non_null = SF.last_non_null(scol) + null_index = SF.null_index(scol) + Window = get_window_class() window_forward = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween( Window.unboundedPreceding, Window.currentRow ) diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py index 06d5692238d..44650fd4d20 100644 --- a/python/pyspark/pandas/spark/functions.py +++ b/python/pyspark/pandas/spark/functions.py @@ -157,3 +157,31 @@ def ewm(col: Column, alpha: float, ignore_na: bool) -> Column: else: sc = SparkContext._active_spark_context return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na)) + + +def last_non_null(col: Column) -> Column: + if is_remote(): + from pyspark.sql.connect.functions import _invoke_function_over_columns + + return _invoke_function_over_columns( # type: ignore[return-value] + "last_non_null", + col, # type: ignore[arg-type] + ) + + else: + sc = SparkContext._active_spark_context + return Column(sc._jvm.PythonSQLUtils.lastNonNull(col._jc)) + + +def null_index(col: Column) -> Column: + if is_remote(): + from pyspark.sql.connect.functions import _invoke_function_over_columns + + return _invoke_function_over_columns( # type: ignore[return-value] + "null_index", + col, # type: ignore[arg-type] + ) + + else: + sc = SparkContext._active_spark_context + return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc)) diff --git a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py index 
d2c05893ae2..1bf2650d874 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py +++ b/python/pyspark/pandas/tests/connect/test_parity_generic_functions.py @@ -24,7 +24,9 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils class GenericFunctionsParityTests( GenericFunctionsTestsMixin, TestUtils, PandasOnSparkTestUtils, ReusedConnectTestCase ): - @unittest.skip("TODO(SPARK-43631): Enable Series.interpolate with Spark Connect.") + @unittest.skip( + "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." + ) def test_interpolate(self): super().test_interpolate() diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 7ecfa65dcd1..608ed7e9ac9 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -46,6 +46,7 @@ if TYPE_CHECKING: from pyspark.sql.session import SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.column import Column + from pyspark.sql.window import Window from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex has_numpy = False @@ -188,7 +189,7 @@ def try_remote_window(f: FuncT) -> FuncT: def wrapped(*args: Any, **kwargs: Any) -> Any: if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: - from pyspark.sql.connect.window import Window + from pyspark.sql.connect.window import Window # type: ignore[misc] return getattr(Window, f.__name__)(*args, **kwargs) else: @@ -282,3 +283,14 @@ def get_dataframe_class() -> Type["DataFrame"]: return ConnectDataFrame # type: ignore[return-value] else: return PySparkDataFrame + + +def get_window_class() -> Type["Window"]: + from pyspark.sql.window import Window as PySparkWindow + + if is_remote(): + from pyspark.sql.connect.window import Window as ConnectWindow + + return ConnectWindow # type: ignore[return-value] + else: + return PySparkWindow --------------------------------------------------------------------- To unsubscribe, e-mail: 
commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org