This is an automated email from the ASF dual-hosted git repository. sarutak pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b874bf5 [SPARK-36894][SPARK-37077][PYTHON] Synchronize RDD.toDF annotations with SparkSession and SQLContext .createDataFrame variants b874bf5 is described below commit b874bf5dca4f1b7272f458350eb153e7b272f8c8 Author: zero323 <mszymkiew...@gmail.com> AuthorDate: Thu Nov 4 02:06:48 2021 +0900 [SPARK-36894][SPARK-37077][PYTHON] Synchronize RDD.toDF annotations with SparkSession and SQLContext .createDataFrame variants ### What changes were proposed in this pull request? This pull request synchronizes `RDD.toDF` annotations with `SparkSession.createDataFrame` and `SQLContext.createDataFrame` variants. Additionally, it fixes a recent regression in `SQLContext.createDataFrame` (SPARK-37077), where `RDD` is no longer considered a valid input. ### Why are the changes needed? - Adds support for providing `str` schema. - Adds support for converting `RDDs` of "atomic" values, if schema is provided. Additionally, it introduces a `TypeVar` representing supported "atomic" values. This was done to avoid issues with manual data tests, where the following ```python sc.parallelize([1]).toDF(schema=IntegerType()) ``` results in ``` error: No overload variant of "toDF" of "RDD" matches argument type "IntegerType" [call-overload] note: Possible overload variants: note: def toDF(self, schema: Union[List[str], Tuple[str, ...], None] = ..., sampleRatio: Optional[float] = ...) -> DataFrame note: def toDF(self, schema: Union[StructType, str, None] = ...) -> DataFrame ``` when a `Union` type is used (this problem doesn't surface when a non-self bound is used). ### Does this PR introduce _any_ user-facing change? Type checker only. Please note that these annotations serve primarily to support documentation, as checks on `self` types are still very limited. ### How was this patch tested? Existing tests and manual data tests. 
__Note__: Updated data tests to reflect new expected traceback, after reversal in #34477 Closes #34478 from zero323/SPARK-36894. Authored-by: zero323 <mszymkiew...@gmail.com> Signed-off-by: Kousuke Saruta <saru...@oss.nttdata.com> --- python/pyspark/rdd.pyi | 15 ++++++--- python/pyspark/sql/_typing.pyi | 11 +++++++ python/pyspark/sql/context.py | 38 ++++++++++------------ python/pyspark/sql/session.py | 40 +++++++++++++++++------- python/pyspark/sql/tests/typing/test_session.yml | 8 ++--- 5 files changed, 71 insertions(+), 41 deletions(-) diff --git a/python/pyspark/rdd.pyi b/python/pyspark/rdd.pyi index a810a2c..84481d3 100644 --- a/python/pyspark/rdd.pyi +++ b/python/pyspark/rdd.pyi @@ -55,8 +55,8 @@ from pyspark.resource.requests import ( # noqa: F401 from pyspark.resource.profile import ResourceProfile from pyspark.statcounter import StatCounter from pyspark.sql.dataframe import DataFrame -from pyspark.sql.types import StructType -from pyspark.sql._typing import RowLike +from pyspark.sql.types import AtomicType, StructType +from pyspark.sql._typing import AtomicValue, RowLike from py4j.java_gateway import JavaObject # type: ignore[import] T = TypeVar("T") @@ -445,11 +445,18 @@ class RDD(Generic[T]): @overload def toDF( self: RDD[RowLike], - schema: Optional[List[str]] = ..., + schema: Optional[Union[List[str], Tuple[str, ...]]] = ..., sampleRatio: Optional[float] = ..., ) -> DataFrame: ... @overload - def toDF(self: RDD[RowLike], schema: Optional[StructType] = ...) -> DataFrame: ... + def toDF( + self: RDD[RowLike], schema: Optional[Union[StructType, str]] = ... + ) -> DataFrame: ... + @overload + def toDF( + self: RDD[AtomicValue], + schema: Union[AtomicType, str], + ) -> DataFrame: ... 
class RDDBarrier(Generic[T]): rdd: RDD[T] diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi index 1a3bd8f..b6b4606 100644 --- a/python/pyspark/sql/_typing.pyi +++ b/python/pyspark/sql/_typing.pyi @@ -42,6 +42,17 @@ AtomicDataTypeOrString = Union[pyspark.sql.types.AtomicType, str] DataTypeOrString = Union[pyspark.sql.types.DataType, str] OptionalPrimitiveType = Optional[PrimitiveType] +AtomicValue = TypeVar( + "AtomicValue", + datetime.datetime, + datetime.date, + decimal.Decimal, + bool, + str, + int, + float, +) + RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], pyspark.sql.types.Row) class SupportsOpen(Protocol): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 7d27c55..eba2087 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -48,13 +48,11 @@ from pyspark.conf import SparkConf if TYPE_CHECKING: from pyspark.sql._typing import ( - UserDefinedFunctionLike, + AtomicValue, RowLike, - DateTimeLiteral, - LiteralType, - DecimalLiteral + UserDefinedFunctionLike, ) - from pyspark.sql.pandas._typing import DataFrameLike + from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike __all__ = ["SQLContext", "HiveContext"] @@ -323,7 +321,8 @@ class SQLContext(object): @overload def createDataFrame( self, - data: Iterable["RowLike"], + data: Union["RDD[RowLike]", Iterable["RowLike"]], + schema: Union[List[str], Tuple[str, ...]] = ..., samplingRatio: Optional[float] = ..., ) -> DataFrame: ... @@ -331,8 +330,9 @@ class SQLContext(object): @overload def createDataFrame( self, - data: Iterable["RowLike"], - schema: Union[List[str], Tuple[str, ...]] = ..., + data: Union["RDD[RowLike]", Iterable["RowLike"]], + schema: Union[StructType, str], + *, verifySchema: bool = ..., ) -> DataFrame: ... 
@@ -340,7 +340,10 @@ class SQLContext(object): @overload def createDataFrame( self, - data: Iterable[Union["DateTimeLiteral", "LiteralType", "DecimalLiteral"]], + data: Union[ + "RDD[AtomicValue]", + Iterable["AtomicValue"], + ], schema: Union[AtomicType, str], verifySchema: bool = ..., ) -> DataFrame: @@ -348,23 +351,14 @@ class SQLContext(object): @overload def createDataFrame( - self, - data: Iterable["RowLike"], - schema: Union[StructType, str], - verifySchema: bool = ..., - ) -> DataFrame: - ... - - @overload - def createDataFrame( - self, data: "DataFrameLike", samplingRatio: Optional[float] = ... + self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ... ) -> DataFrame: ... @overload def createDataFrame( self, - data: "DataFrameLike", + data: "PandasDataFrameLike", schema: Union[StructType, str], verifySchema: bool = ..., ) -> DataFrame: @@ -372,8 +366,8 @@ class SQLContext(object): def createDataFrame( # type: ignore[misc] self, - data: Iterable["RowLike"], - schema: Optional[Union[List[str], Tuple[str, ...]]] = None, + data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike"], + schema: Optional[Union[AtomicType, StructType, str]] = None, samplingRatio: Optional[float] = None, verifySchema: bool = True ) -> DataFrame: diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 728d658..de0f9e3 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -42,7 +42,7 @@ from pyspark.sql.types import ( from pyspark.sql.utils import install_exception_handler, is_timestamp_ntz_preferred if TYPE_CHECKING: - from pyspark.sql._typing import DateTimeLiteral, LiteralType, DecimalLiteral, RowLike + from pyspark.sql._typing import AtomicValue, RowLike from pyspark.sql.catalog import Catalog from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike from pyspark.sql.streaming import StreamingQueryManager @@ -628,7 +628,8 @@ class SparkSession(SparkConversionMixin): @overload def 
createDataFrame( self, - data: Union["RDD[RowLike]", Iterable["RowLike"]], + data: Iterable["RowLike"], + schema: Union[List[str], Tuple[str, ...]] = ..., samplingRatio: Optional[float] = ..., ) -> DataFrame: ... @@ -636,35 +637,52 @@ class SparkSession(SparkConversionMixin): @overload def createDataFrame( self, - data: Union["RDD[RowLike]", Iterable["RowLike"]], + data: "RDD[RowLike]", schema: Union[List[str], Tuple[str, ...]] = ..., - verifySchema: bool = ..., + samplingRatio: Optional[float] = ..., ) -> DataFrame: ... @overload def createDataFrame( self, - data: Union[ - "RDD[Union[DateTimeLiteral, LiteralType, DecimalLiteral]]", - Iterable[Union["DateTimeLiteral", "LiteralType", "DecimalLiteral"]], - ], - schema: Union[AtomicType, str], + data: Iterable["RowLike"], + schema: Union[StructType, str], + *, verifySchema: bool = ..., ) -> DataFrame: ... @overload def createDataFrame( + self, + data: "RDD[RowLike]", + schema: Union[StructType, str], + *, + verifySchema: bool = ..., + ) -> DataFrame: + ... + + @overload + def createDataFrame( self, - data: Union["RDD[RowLike]", Iterable["RowLike"]], - schema: Union[StructType, str], + data: "RDD[AtomicValue]", + schema: Union[AtomicType, str], verifySchema: bool = ..., ) -> DataFrame: ... @overload def createDataFrame( + self, + data: Iterable["AtomicValue"], + schema: Union[AtomicType, str], + verifySchema: bool = ..., + ) -> DataFrame: + ... + + @overload + def createDataFrame( self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ... ) -> DataFrame: ... 
diff --git a/python/pyspark/sql/tests/typing/test_session.yml b/python/pyspark/sql/tests/typing/test_session.yml index f06e79e..01a6b28 100644 --- a/python/pyspark/sql/tests/typing/test_session.yml +++ b/python/pyspark/sql/tests/typing/test_session.yml @@ -92,12 +92,12 @@ spark.createDataFrame(data, schema, samplingRatio=0.1) out: | - main:14: error: Argument 1 to "createDataFrame" of "SparkSession" has incompatible type "List[Tuple[str, int]]"; expected "Union[RDD[Union[Union[datetime, date], Union[bool, float, int, str], Decimal]], Iterable[Union[Union[datetime, date], Union[bool, float, int, str], Decimal]]]" [arg-type] + main:14: error: Value of type variable "AtomicValue" of "createDataFrame" of "SparkSession" cannot be "Tuple[str, int]" [type-var] main:18: error: No overload variant of "createDataFrame" of "SparkSession" matches argument types "List[Tuple[str, int]]", "StructType", "float" [call-overload] main:18: note: Possible overload variants: - main:18: note: def [RowLike in (List[Any], Tuple[Any, ...], Row)] createDataFrame(self, data: Union[RDD[RowLike], Iterable[RowLike]], samplingRatio: Optional[float] = ...) -> DataFrame - main:18: note: def [RowLike in (List[Any], Tuple[Any, ...], Row)] createDataFrame(self, data: Union[RDD[RowLike], Iterable[RowLike]], schema: Union[List[str], Tuple[str, ...]] = ..., verifySchema: bool = ...) -> DataFrame - main:18: note: <4 more similar overloads not shown, out of 6 total overloads> + main:18: note: def [RowLike in (List[Any], Tuple[Any, ...], Row)] createDataFrame(self, data: Iterable[RowLike], schema: Union[List[str], Tuple[str, ...]] = ..., samplingRatio: Optional[float] = ...) -> DataFrame + main:18: note: def [RowLike in (List[Any], Tuple[Any, ...], Row)] createDataFrame(self, data: RDD[RowLike], schema: Union[List[str], Tuple[str, ...]] = ..., samplingRatio: Optional[float] = ...) 
-> DataFrame + main:18: note: <6 more non-matching overloads not shown> - case: createDataFrameFromEmptyRdd --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org