This is an automated email from the ASF dual-hosted git repository.

sarutak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b874bf5  [SPARK-36894][SPARK-37077][PYTHON] Synchronize RDD.toDF annotations with SparkSession and SQLContext .createDataFrame variants
b874bf5 is described below

commit b874bf5dca4f1b7272f458350eb153e7b272f8c8
Author: zero323 <mszymkiew...@gmail.com>
AuthorDate: Thu Nov 4 02:06:48 2021 +0900

    [SPARK-36894][SPARK-37077][PYTHON] Synchronize RDD.toDF annotations with SparkSession and SQLContext .createDataFrame variants
    
    ### What changes were proposed in this pull request?
    
    This pull request synchronizes `RDD.toDF` annotations with `SparkSession.createDataFrame` and `SQLContext.createDataFrame` variants.
    
    Additionally, it fixes a recent regression in `SQLContext.createDataFrame` (SPARK-37077), where an `RDD` is no longer considered a valid input.
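    
    For illustration, a call along these lines (a hypothetical sketch, assuming an existing `SparkContext` named `sc` and `SQLContext` named `sqlContext`) was rejected by the type checker due to that regression and should type-check again:
    
    ```python
    from pyspark.sql import Row
    
    rdd = sc.parallelize([Row(id=1), Row(id=2)])
    sqlContext.createDataFrame(rdd)  # RDD is accepted as input again
    ```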
    
    ### Why are the changes needed?
    
    - Adds support for providing a `str` schema.
    - Adds support for converting `RDD`s of "atomic" values when a schema is provided (see the sketch below).
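    
    A minimal sketch of calls these annotations are intended to accept (assuming an active `SparkSession` and a `SparkContext` named `sc`):
    
    ```python
    from pyspark.sql.types import IntegerType
    
    # str (DDL) schema for an RDD of tuples
    df1 = sc.parallelize([(1, "a"), (2, "b")]).toDF(schema="id int, name string")
    
    # RDD of "atomic" values with an explicit AtomicType schema
    df2 = sc.parallelize([1, 2, 3]).toDF(schema=IntegerType())
    ```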
    
    Additionally, it introduces a `TypeVar` representing the supported "atomic" values. This was done to avoid an issue with the manual data tests, where the following
    
    ```python
    sc.parallelize([1]).toDF(schema=IntegerType())
    ```
    
    results in
    
    ```
    error: No overload variant of "toDF" of "RDD" matches argument type "IntegerType"  [call-overload]
    note: Possible overload variants:
    note:     def toDF(self, schema: Union[List[str], Tuple[str, ...], None] = ..., sampleRatio: Optional[float] = ...) -> DataFrame
    note:     def toDF(self, schema: Union[StructType, str, None] = ...) -> DataFrame
    ```
    
    when a `Union` type is used (this problem doesn't surface when a non-self bound is used).
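    
    The following sketch (hypothetical, and heavily simplified relative to the stubs in this diff) illustrates the pattern: with a value-restricted `TypeVar` bound through `self`, mypy binds the element type to a single concrete member, so the annotation matches for `RDD[int]`, `RDD[str]`, and so on:
    
    ```python
    from typing import Generic, TypeVar
    
    T = TypeVar("T")
    # Value-restricted TypeVar: self binds to exactly one of these types.
    AtomicValue = TypeVar("AtomicValue", bool, str, int, float)
    
    class DataFrame: ...
    
    class RDD(Generic[T]):
        def toDF(self: "RDD[AtomicValue]", schema: str) -> DataFrame:
            return DataFrame()
    
    RDD[int]().toDF("value int")       # OK: AtomicValue binds to int
    RDD[bytes]().toDF("value binary")  # mypy error: "bytes" is not a valid AtomicValue
    ```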
    
    ### Does this PR introduce _any_ user-facing change?
    
    Type checker only.
    
    Please note that these annotations serve primarily to support documentation, as checks on `self` types are still very limited.
    
    ### How was this patch tested?
    
    Existing tests and manual data tests.
    
    __Note__:
    
    Updated the data tests to reflect the new expected traceback after the reversal in #34477.
    
    Closes #34478 from zero323/SPARK-36894.
    
    Authored-by: zero323 <mszymkiew...@gmail.com>
    Signed-off-by: Kousuke Saruta <saru...@oss.nttdata.com>
---
 python/pyspark/rdd.pyi                           | 15 ++++++---
 python/pyspark/sql/_typing.pyi                   | 11 +++++++
 python/pyspark/sql/context.py                    | 38 ++++++++++------------
 python/pyspark/sql/session.py                    | 40 +++++++++++++++++-------
 python/pyspark/sql/tests/typing/test_session.yml |  8 ++---
 5 files changed, 71 insertions(+), 41 deletions(-)

diff --git a/python/pyspark/rdd.pyi b/python/pyspark/rdd.pyi
index a810a2c..84481d3 100644
--- a/python/pyspark/rdd.pyi
+++ b/python/pyspark/rdd.pyi
@@ -55,8 +55,8 @@ from pyspark.resource.requests import (  # noqa: F401
 from pyspark.resource.profile import ResourceProfile
 from pyspark.statcounter import StatCounter
 from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.types import StructType
-from pyspark.sql._typing import RowLike
+from pyspark.sql.types import AtomicType, StructType
+from pyspark.sql._typing import AtomicValue, RowLike
 from py4j.java_gateway import JavaObject  # type: ignore[import]
 
 T = TypeVar("T")
@@ -445,11 +445,18 @@ class RDD(Generic[T]):
     @overload
     def toDF(
         self: RDD[RowLike],
-        schema: Optional[List[str]] = ...,
+        schema: Optional[Union[List[str], Tuple[str, ...]]] = ...,
         sampleRatio: Optional[float] = ...,
     ) -> DataFrame: ...
     @overload
-    def toDF(self: RDD[RowLike], schema: Optional[StructType] = ...) -> DataFrame: ...
+    def toDF(
+        self: RDD[RowLike], schema: Optional[Union[StructType, str]] = ...
+    ) -> DataFrame: ...
+    @overload
+    def toDF(
+        self: RDD[AtomicValue],
+        schema: Union[AtomicType, str],
+    ) -> DataFrame: ...
 
 class RDDBarrier(Generic[T]):
     rdd: RDD[T]
diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi
index 1a3bd8f..b6b4606 100644
--- a/python/pyspark/sql/_typing.pyi
+++ b/python/pyspark/sql/_typing.pyi
@@ -42,6 +42,17 @@ AtomicDataTypeOrString = Union[pyspark.sql.types.AtomicType, str]
 DataTypeOrString = Union[pyspark.sql.types.DataType, str]
 OptionalPrimitiveType = Optional[PrimitiveType]
 
+AtomicValue = TypeVar(
+    "AtomicValue",
+    datetime.datetime,
+    datetime.date,
+    decimal.Decimal,
+    bool,
+    str,
+    int,
+    float,
+)
+
 RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], pyspark.sql.types.Row)
 
 class SupportsOpen(Protocol):
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 7d27c55..eba2087 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -48,13 +48,11 @@ from pyspark.conf import SparkConf
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import (
-        UserDefinedFunctionLike,
+        AtomicValue,
         RowLike,
-        DateTimeLiteral,
-        LiteralType,
-        DecimalLiteral
+        UserDefinedFunctionLike,
     )
-    from pyspark.sql.pandas._typing import DataFrameLike
+    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
 
 __all__ = ["SQLContext", "HiveContext"]
 
@@ -323,7 +321,8 @@ class SQLContext(object):
     @overload
     def createDataFrame(
         self,
-        data: Iterable["RowLike"],
+        data: Union["RDD[RowLike]", Iterable["RowLike"]],
+        schema: Union[List[str], Tuple[str, ...]] = ...,
         samplingRatio: Optional[float] = ...,
     ) -> DataFrame:
         ...
@@ -331,8 +330,9 @@ class SQLContext(object):
     @overload
     def createDataFrame(
         self,
-        data: Iterable["RowLike"],
-        schema: Union[List[str], Tuple[str, ...]] = ...,
+        data: Union["RDD[RowLike]", Iterable["RowLike"]],
+        schema: Union[StructType, str],
+        *,
         verifySchema: bool = ...,
     ) -> DataFrame:
         ...
@@ -340,7 +340,10 @@ class SQLContext(object):
     @overload
     def createDataFrame(
         self,
-        data: Iterable[Union["DateTimeLiteral", "LiteralType", "DecimalLiteral"]],
+        data: Union[
+            "RDD[AtomicValue]",
+            Iterable["AtomicValue"],
+        ],
         schema: Union[AtomicType, str],
         verifySchema: bool = ...,
     ) -> DataFrame:
@@ -348,23 +351,14 @@ class SQLContext(object):
 
     @overload
     def createDataFrame(
-        self,
-        data: Iterable["RowLike"],
-        schema: Union[StructType, str],
-        verifySchema: bool = ...,
-    ) -> DataFrame:
-        ...
-
-    @overload
-    def createDataFrame(
-        self, data: "DataFrameLike", samplingRatio: Optional[float] = ...
+        self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ...
     ) -> DataFrame:
         ...
 
     @overload
     def createDataFrame(
         self,
-        data: "DataFrameLike",
+        data: "PandasDataFrameLike",
         schema: Union[StructType, str],
         verifySchema: bool = ...,
     ) -> DataFrame:
@@ -372,8 +366,8 @@ class SQLContext(object):
 
     def createDataFrame(  # type: ignore[misc]
         self,
-        data: Iterable["RowLike"],
-        schema: Optional[Union[List[str], Tuple[str, ...]]] = None,
+        data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike"],
+        schema: Optional[Union[AtomicType, StructType, str]] = None,
         samplingRatio: Optional[float] = None,
         verifySchema: bool = True
     ) -> DataFrame:
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 728d658..de0f9e3 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -42,7 +42,7 @@ from pyspark.sql.types import (
 from pyspark.sql.utils import install_exception_handler, is_timestamp_ntz_preferred
 
 if TYPE_CHECKING:
-    from pyspark.sql._typing import DateTimeLiteral, LiteralType, DecimalLiteral, RowLike
+    from pyspark.sql._typing import AtomicValue, RowLike
     from pyspark.sql.catalog import Catalog
     from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
     from pyspark.sql.streaming import StreamingQueryManager
@@ -628,7 +628,8 @@ class SparkSession(SparkConversionMixin):
     @overload
     def createDataFrame(
         self,
-        data: Union["RDD[RowLike]", Iterable["RowLike"]],
+        data: Iterable["RowLike"],
+        schema: Union[List[str], Tuple[str, ...]] = ...,
         samplingRatio: Optional[float] = ...,
     ) -> DataFrame:
         ...
@@ -636,35 +637,52 @@ class SparkSession(SparkConversionMixin):
     @overload
     def createDataFrame(
         self,
-        data: Union["RDD[RowLike]", Iterable["RowLike"]],
+        data: "RDD[RowLike]",
         schema: Union[List[str], Tuple[str, ...]] = ...,
-        verifySchema: bool = ...,
+        samplingRatio: Optional[float] = ...,
     ) -> DataFrame:
         ...
 
     @overload
     def createDataFrame(
         self,
-        data: Union[
-            "RDD[Union[DateTimeLiteral, LiteralType, DecimalLiteral]]",
-            Iterable[Union["DateTimeLiteral", "LiteralType", "DecimalLiteral"]],
-        ],
-        schema: Union[AtomicType, str],
+        data: Iterable["RowLike"],
+        schema: Union[StructType, str],
+        *,
         verifySchema: bool = ...,
     ) -> DataFrame:
         ...
 
     @overload
     def createDataFrame(
+            self,
+            data: "RDD[RowLike]",
+            schema: Union[StructType, str],
+            *,
+            verifySchema: bool = ...,
+    ) -> DataFrame:
+        ...
+
+    @overload
+    def createDataFrame(
         self,
-        data: Union["RDD[RowLike]", Iterable["RowLike"]],
-        schema: Union[StructType, str],
+        data: "RDD[AtomicValue]",
+        schema: Union[AtomicType, str],
         verifySchema: bool = ...,
     ) -> DataFrame:
         ...
 
     @overload
     def createDataFrame(
+            self,
+            data: Iterable["AtomicValue"],
+            schema: Union[AtomicType, str],
+            verifySchema: bool = ...,
+    ) -> DataFrame:
+        ...
+
+    @overload
+    def createDataFrame(
         self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ...
     ) -> DataFrame:
         ...
diff --git a/python/pyspark/sql/tests/typing/test_session.yml b/python/pyspark/sql/tests/typing/test_session.yml
index f06e79e..01a6b28 100644
--- a/python/pyspark/sql/tests/typing/test_session.yml
+++ b/python/pyspark/sql/tests/typing/test_session.yml
@@ -92,12 +92,12 @@
     spark.createDataFrame(data, schema, samplingRatio=0.1)
 
   out: |
-    main:14: error: Argument 1 to "createDataFrame" of "SparkSession" has incompatible type "List[Tuple[str, int]]"; expected "Union[RDD[Union[Union[datetime, date], Union[bool, float, int, str], Decimal]], Iterable[Union[Union[datetime, date], Union[bool, float, int, str], Decimal]]]"  [arg-type]
+    main:14: error: Value of type variable "AtomicValue" of "createDataFrame" of "SparkSession" cannot be "Tuple[str, int]"  [type-var]
    main:18: error: No overload variant of "createDataFrame" of "SparkSession" matches argument types "List[Tuple[str, int]]", "StructType", "float"  [call-overload]
     main:18: note: Possible overload variants:
-    main:18: note:     def [RowLike in (List[Any], Tuple[Any, ...], Row)] createDataFrame(self, data: Union[RDD[RowLike], Iterable[RowLike]], samplingRatio: Optional[float] = ...) -> DataFrame
-    main:18: note:     def [RowLike in (List[Any], Tuple[Any, ...], Row)] createDataFrame(self, data: Union[RDD[RowLike], Iterable[RowLike]], schema: Union[List[str], Tuple[str, ...]] = ..., verifySchema: bool = ...) -> DataFrame
-    main:18: note:     <4 more similar overloads not shown, out of 6 total overloads>
+    main:18: note:     def [RowLike in (List[Any], Tuple[Any, ...], Row)] createDataFrame(self, data: Iterable[RowLike], schema: Union[List[str], Tuple[str, ...]] = ..., samplingRatio: Optional[float] = ...) -> DataFrame
+    main:18: note:     def [RowLike in (List[Any], Tuple[Any, ...], Row)] createDataFrame(self, data: RDD[RowLike], schema: Union[List[str], Tuple[str, ...]] = ..., samplingRatio: Optional[float] = ...) -> DataFrame
+    main:18: note:     <6 more non-matching overloads not shown>
 
 
 - case: createDataFrameFromEmptyRdd
