This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 381dd7943e5 [SPARK-41222][CONNECT][PYTHON] Unify the typing definitions
381dd7943e5 is described below

commit 381dd7943e52483b1a10cb6d15c980e375631052
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Nov 24 10:44:08 2022 +0800

[SPARK-41222][CONNECT][PYTHON] Unify the typing definitions

### What changes were proposed in this pull request?
1. Remove `typing/__init__.pyi`.
2. Rename `ColumnOrString` to `ColumnOrName`, to be the same as PySpark.

### Why are the changes needed?
1. There are two typing files now, `_typing.py` and `typing/__init__.pyi`; they are used in different files, which is very confusing.
2. The definitions of `LiteralType` are different; the old one in `_typing.py` was never used.
3. Both `ColumnOrString` and `ColumnOrName` are used now.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing UTs.

Closes #38757 from zhengruifeng/connect_typing.

Authored-by: Ruifeng Zheng <ruife...@apache.org>
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/_typing.py | 41 +++++++++++++++++++++++---
 python/pyspark/sql/connect/client.py | 25 +++++++---------
 python/pyspark/sql/connect/column.py | 5 +++-
 python/pyspark/sql/connect/dataframe.py | 12 ++++----
 python/pyspark/sql/connect/function_builder.py | 10 +++++--
 python/pyspark/sql/connect/plan.py | 4 +--
 python/pyspark/sql/connect/readwriter.py | 5 +---
 python/pyspark/sql/connect/typing/__init__.pyi | 35 ----------------------
 8 files changed, 66 insertions(+), 71 deletions(-)

diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 4e69b2e4aa5..262d71fcea1 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -14,8 +14,41 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
# -from typing import Union -from datetime import date, time, datetime -PrimitiveType = Union[str, int, bool, float] -LiteralType = Union[PrimitiveType, Union[date, time, datetime]] +import sys + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + +from typing import Union, Optional +import datetime +import decimal + +from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, Column +from pyspark.sql.connect.function_builder import UserDefinedFunction + +ExpressionOrString = Union[Expression, str] + +ColumnOrName = Union[Column, str] + +PrimitiveType = Union[bool, float, int, str] + +OptionalPrimitiveType = Optional[PrimitiveType] + +LiteralType = PrimitiveType + +DecimalLiteral = decimal.Decimal + +DateTimeLiteral = Union[datetime.datetime, datetime.date] + + +class FunctionBuilderCallable(Protocol): + def __call__(self, *_: ExpressionOrString) -> ScalarFunctionExpression: + ... + + +class UserDefinedFunctionCallable(Protocol): + def __call__(self, *_: ColumnOrName) -> UserDefinedFunction: + ... diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index fdcf34b7a47..deb9ef6f3be 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -18,7 +18,6 @@ import logging import os -import typing import urllib.parse import uuid @@ -35,9 +34,7 @@ from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.plan import SQL, Range from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType -from typing import Optional, Any, Union - -NumericType = typing.Union[int, float] +from typing import Iterable, Optional, Any, Union, List, Tuple, Dict logging.basicConfig(level=logging.INFO) @@ -74,7 +71,7 @@ class ChannelBuilder: # Python's built-in parser. 
tmp_url = "http" + url[2:] self.url = urllib.parse.urlparse(tmp_url) - self.params: typing.Dict[str, str] = {} + self.params: Dict[str, str] = {} if len(self.url.path) > 0 and self.url.path != "/": raise AttributeError( f"Path component for connection URI must be empty: {self.url.path}" @@ -102,7 +99,7 @@ class ChannelBuilder: f"Target destination {self.url.netloc} does not match '<host>:<port>' pattern" ) - def metadata(self) -> typing.Iterable[typing.Tuple[str, str]]: + def metadata(self) -> Iterable[Tuple[str, str]]: """ Builds the GRPC specific metadata list to be injected into the request. All parameters will be converted to metadata except ones that are explicitly used @@ -198,7 +195,7 @@ class ChannelBuilder: class MetricValue: - def __init__(self, name: str, value: NumericType, type: str): + def __init__(self, name: str, value: Union[int, float], type: str): self._name = name self._type = type self._value = value @@ -211,7 +208,7 @@ class MetricValue: return self._name @property - def value(self) -> NumericType: + def value(self) -> Union[int, float]: return self._value @property @@ -220,7 +217,7 @@ class MetricValue: class PlanMetrics: - def __init__(self, name: str, id: int, parent: int, metrics: typing.List[MetricValue]): + def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]): self._name = name self._id = id self._parent_id = parent @@ -242,7 +239,7 @@ class PlanMetrics: return self._parent_id @property - def metrics(self) -> typing.List[MetricValue]: + def metrics(self) -> List[MetricValue]: return self._metrics @@ -252,7 +249,7 @@ class AnalyzeResult: self.explain_string = explain @classmethod - def fromProto(cls, pb: typing.Any) -> "AnalyzeResult": + def fromProto(cls, pb: Any) -> "AnalyzeResult": return AnalyzeResult(pb.schema, pb.explain_string) @@ -306,9 +303,7 @@ class RemoteSparkSession(object): self._execute_and_fetch(req) return name - def _build_metrics( - self, metrics: "pb2.ExecutePlanResponse.Metrics" - ) -> 
typing.List[PlanMetrics]: + def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]: return [ PlanMetrics( x.name, @@ -450,7 +445,7 @@ class RemoteSparkSession(object): return rd.read_pandas() return None - def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> typing.Optional[pandas.DataFrame]: + def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> Optional[pandas.DataFrame]: import pandas as pd m: Optional[pb2.ExecutePlanResponse.Metrics] = None diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index c4ffc54c20b..36f38e0ded2 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -22,7 +22,6 @@ import decimal import datetime import pyspark.sql.connect.proto as proto -from pyspark.sql.connect._typing import PrimitiveType if TYPE_CHECKING: from pyspark.sql.connect.client import RemoteSparkSession @@ -33,6 +32,8 @@ def _bin_op( name: str, doc: str = "binary function", reverse: bool = False ) -> Callable[["Column", Any], "Expression"]: def _(self: "Column", other: Any) -> "Expression": + from pyspark.sql.connect._typing import PrimitiveType + if isinstance(other, get_args(PrimitiveType)): other = LiteralExpression(other) if not reverse: @@ -70,6 +71,8 @@ class Expression(object): """Returns a binary expression with the current column as the left side and the other expression as the right side. 
""" + from pyspark.sql.connect._typing import PrimitiveType + if isinstance(other, get_args(PrimitiveType)): other = LiteralExpression(other) return ScalarFunctionExpression("==", self, other) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 7b42bdf747b..ff14945db0f 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -44,11 +44,9 @@ from pyspark.sql.types import ( ) if TYPE_CHECKING: - from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString, LiteralType + from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString, LiteralType from pyspark.sql.connect.client import RemoteSparkSession -ColumnOrName = Union[Column, str] - class GroupingFrame(object): @@ -308,7 +306,7 @@ class DataFrame(object): plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session ) - def drop(self, *cols: "ColumnOrString") -> "DataFrame": + def drop(self, *cols: "ColumnOrName") -> "DataFrame": _cols = list(cols) if any(not isinstance(c, (str, Column)) for c in _cols): raise TypeError( @@ -342,7 +340,7 @@ class DataFrame(object): """ return self.head() - def groupBy(self, *cols: "ColumnOrString") -> GroupingFrame: + def groupBy(self, *cols: "ColumnOrName") -> GroupingFrame: return GroupingFrame(self, *cols) @overload @@ -414,13 +412,13 @@ class DataFrame(object): def offset(self, n: int) -> "DataFrame": return DataFrame.withPlan(plan.Offset(child=self._plan, offset=n), session=self._session) - def sort(self, *cols: "ColumnOrString") -> "DataFrame": + def sort(self, *cols: "ColumnOrName") -> "DataFrame": """Sort by a specific column""" return DataFrame.withPlan( plan.Sort(self._plan, columns=list(cols), is_global=True), session=self._session ) - def sortWithinPartitions(self, *cols: "ColumnOrString") -> "DataFrame": + def sortWithinPartitions(self, *cols: "ColumnOrName") -> "DataFrame": """Sort within each partition by a specific column""" 
return DataFrame.withPlan( plan.Sort(self._plan, columns=list(cols), is_global=False), session=self._session diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py index e116e493954..4a2688d6a0d 100644 --- a/python/pyspark/sql/connect/function_builder.py +++ b/python/pyspark/sql/connect/function_builder.py @@ -28,9 +28,13 @@ from pyspark.sql.connect.column import ( if TYPE_CHECKING: - from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString + from pyspark.sql.connect._typing import ( + ColumnOrName, + ExpressionOrString, + FunctionBuilderCallable, + UserDefinedFunctionCallable, + ) from pyspark.sql.connect.client import RemoteSparkSession - from pyspark.sql.connect.typing import FunctionBuilderCallable, UserDefinedFunctionCallable def _build(name: str, *args: "ExpressionOrString") -> ScalarFunctionExpression: @@ -103,7 +107,7 @@ class UserDefinedFunction(Expression): def _create_udf( function: Any, return_type: Union[str, pyspark.sql.types.DataType] ) -> "UserDefinedFunctionCallable": - def wrapper(*cols: "ColumnOrString") -> UserDefinedFunction: + def wrapper(*cols: "ColumnOrName") -> UserDefinedFunction: return UserDefinedFunction(func=function, return_type=return_type, args=cols) return wrapper diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index ffb0ce080b3..8aadc3dc4fa 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -36,7 +36,7 @@ from pyspark.sql.connect.column import ( if TYPE_CHECKING: - from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString + from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString from pyspark.sql.connect.client import RemoteSparkSession @@ -58,7 +58,7 @@ class LogicalPlan(object): return exp def to_attr_or_expression( - self, col: "ColumnOrString", session: "RemoteSparkSession" + self, col: "ColumnOrName", session: "RemoteSparkSession" ) -> 
proto.Expression: """Returns either an instance of an unresolved attribute or the serialized expression value of the column.""" diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index 66e48eeab76..27aa023ae47 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -18,17 +18,14 @@ from typing import Dict, Optional -from pyspark.sql.connect.column import PrimitiveType from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.connect.plan import Read, DataSource from pyspark.sql.utils import to_str - -OptionalPrimitiveType = Optional[PrimitiveType] - from typing import TYPE_CHECKING if TYPE_CHECKING: + from pyspark.sql.connect._typing import OptionalPrimitiveType from pyspark.sql.connect.client import RemoteSparkSession diff --git a/python/pyspark/sql/connect/typing/__init__.pyi b/python/pyspark/sql/connect/typing/__init__.pyi deleted file mode 100644 index 43cc28701da..00000000000 --- a/python/pyspark/sql/connect/typing/__init__.pyi +++ /dev/null @@ -1,35 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from typing_extensions import Protocol -from typing import Union -from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, Column -from pyspark.sql.connect.function_builder import UserDefinedFunction - -ExpressionOrString = Union[str, Expression] - -ColumnOrString = Union[str, Column] - -PrimitiveType = Union[bool, float, int, str] - -LiteralType = PrimitiveType - -class FunctionBuilderCallable(Protocol): - def __call__(self, *_: ExpressionOrString) -> ScalarFunctionExpression: ... - -class UserDefinedFunctionCallable(Protocol): - def __call__(self, *_: ColumnOrString) -> UserDefinedFunction: ... --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org