This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6ff760d  [SPARK-37154][PYTHON] Inline hints for pyspark.rdd
6ff760d is described below

commit 6ff760d483124b121d79c3a2d5fdc3ee3f27dd00
Author: zero323 <mszymkiew...@gmail.com>
AuthorDate: Sat Feb 19 11:32:13 2022 +0100

    [SPARK-37154][PYTHON] Inline hints for pyspark.rdd
    
    ### What changes were proposed in this pull request?
    
    This PR proposes migrating the type hints for `pyspark.rdd` from a stub file to inline annotations, as sketched below.
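    
    A rough, hypothetical sketch of the difference (illustrative only, not code from this patch):
    
        from typing import Callable, Generic, TypeVar
    
        T = TypeVar("T")
        T_co = TypeVar("T_co", covariant=True)
        U = TypeVar("U")
    
        # Before: the signature lived only in a stub (pyspark/rdd.pyi) that type
        # checkers read instead of the unannotated rdd.py, e.g.
        #     class RDD(Generic[T]):
        #         def map(self, f: Callable[[T], U], preservesPartitioning: bool = ...) -> RDD[U]: ...
        #
        # After: the same annotation is written inline in rdd.py itself, so the
        # implementation and its types live in one place.
        class RDD(Generic[T_co]):
            def map(self: "RDD[T]", f: Callable[[T], U], preservesPartitioning: bool = False) -> "RDD[U]":
                ...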
    
    ### Why are the changes needed?
    
    Part of the ongoing effort to migrate type stubs to inline hints.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests + new data tests.
    
    Closes #35252 from zero323/SPARK-37154.
    
    Authored-by: zero323 <mszymkiew...@gmail.com>
    Signed-off-by: zero323 <mszymkiew...@gmail.com>
---
 python/pyspark/_typing.pyi                     |  14 +-
 python/pyspark/conf.py                         |  12 +
 python/pyspark/context.py                      |   3 +-
 python/pyspark/rdd.py                          | 987 ++++++++++++++++++-------
 python/pyspark/rdd.pyi                         | 481 ------------
 python/pyspark/serializers.py                  |   7 +
 python/pyspark/sql/_typing.pyi                 |   4 +-
 python/pyspark/sql/pandas/_typing/__init__.pyi |   5 +-
 python/pyspark/tests/typing/test_rdd.yml       |  77 +-
 9 files changed, 814 insertions(+), 776 deletions(-)

diff --git a/python/pyspark/_typing.pyi b/python/pyspark/_typing.pyi
index 9a36c89..6cc0926 100644
--- a/python/pyspark/_typing.pyi
+++ b/python/pyspark/_typing.pyi
@@ -17,17 +17,27 @@
 # under the License.
 
 from typing import Callable, Iterable, Sized, TypeVar, Union
-from typing_extensions import Protocol
+from typing_extensions import Literal, Protocol
+
+from numpy import int32, int64, float32, float64, ndarray
 
 F = TypeVar("F", bound=Callable)
 T_co = TypeVar("T_co", covariant=True)
 
 PrimitiveType = Union[bool, float, int, str]
 
+NonUDFType = Literal[0]
+
 class SupportsIAdd(Protocol):
     def __iadd__(self, other: SupportsIAdd) -> SupportsIAdd: ...
 
 class SupportsOrdering(Protocol):
-    def __le__(self, other: SupportsOrdering) -> bool: ...
+    def __lt__(self, other: SupportsOrdering) -> bool: ...
 
 class SizedIterable(Protocol, Sized, Iterable[T_co]): ...
+
+S = TypeVar("S", bound=SupportsOrdering)
+
+NumberOrArray = TypeVar(
+    "NumberOrArray", float, int, complex, int32, int64, float32, float64, ndarray
+)
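
As an aside (illustrative, not part of the patch): a Protocol bound like `SupportsOrdering` above only requires `__lt__`, which is the comparison that `sorted()` and `list.sort()` rely on. A minimal sketch of how such a bound can be used, mirroring the `S` TypeVar:

    from typing import List, Protocol, TypeVar

    class SupportsOrdering(Protocol):
        def __lt__(self, other: "SupportsOrdering") -> bool: ...

    S = TypeVar("S", bound=SupportsOrdering)

    def smallest(items: List[S]) -> S:
        # Only "<" is required of the elements, so the Protocol bound is sufficient.
        result = items[0]
        for item in items[1:]:
            if item < result:
                result = item
        return result

    smallest([3, 1, 2])        # ints are orderable
    smallest(["b", "a", "c"])  # so are strings, dates, and other "<"-comparable types
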
diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py
index 536e1f8..a9e26966 100644
--- a/python/pyspark/conf.py
+++ b/python/pyspark/conf.py
@@ -203,6 +203,18 @@ class SparkConf:
             self.set(k, v)
         return self
 
+    @overload
+    def get(self, key: str) -> Optional[str]:
+        ...
+
+    @overload
+    def get(self, key: str, defaultValue: None) -> Optional[str]:
+        ...
+
+    @overload
+    def get(self, key: str, defaultValue: str) -> str:
+        ...
+
     def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]:
         """Get the configured value for some key, or return a default otherwise."""
         if defaultValue is None:  # Py4J doesn't call the right get() if we pass None
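
A brief illustration (not from the patch) of what these overloads buy at type-checking time; hypothetical snippet, noting that `SparkConf()` can be constructed without a running SparkContext:

    from pyspark import SparkConf

    conf = SparkConf()
    master = conf.get("spark.master")             # a checker infers Optional[str]
    app = conf.get("spark.app.name", "demo-app")  # a checker infers str (str default given)
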
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 3db9630..68f748e 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -35,6 +35,7 @@ from typing import (
     List,
     NoReturn,
     Optional,
+    Sequence,
     Tuple,
     Type,
     TYPE_CHECKING,
@@ -1421,7 +1422,7 @@ class SparkContext:
         self,
         rdd: "RDD[T]",
         partitionFunc: Callable[[Iterable[T]], Iterable[U]],
-        partitions: Optional[List[int]] = None,
+        partitions: Optional[Sequence[int]] = None,
         allowLocal: bool = False,
     ) -> List[U]:
         """
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 97b87ea..7cb887f 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -30,6 +30,26 @@ from collections import defaultdict
 from itertools import chain
 from functools import reduce
 from math import sqrt, log, isinf, isnan, pow, ceil
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    Hashable,
+    Iterable,
+    Iterator,
+    IO,
+    List,
+    NoReturn,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+    TypeVar,
+    cast,
+    overload,
+    TYPE_CHECKING,
+)
 
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import (
@@ -40,6 +60,7 @@ from pyspark.serializers import (
     CloudPickleSerializer,
     PairDeserializer,
     CPickleSerializer,
+    Serializer,
     pack_long,
     read_int,
     write_int,
@@ -67,6 +88,41 @@ from pyspark.shuffle import (
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.util import fail_on_stopiteration, _parse_memory
 
+
+if TYPE_CHECKING:
+    import socket
+    import io
+
+    from pyspark._typing import NonUDFType
+    from pyspark._typing import S, NumberOrArray
+    from pyspark.context import SparkContext
+    from pyspark.sql.pandas._typing import (
+        PandasScalarUDFType,
+        PandasGroupedMapUDFType,
+        PandasGroupedAggUDFType,
+        PandasWindowAggUDFType,
+        PandasScalarIterUDFType,
+        PandasMapIterUDFType,
+        PandasCogroupedMapUDFType,
+        ArrowMapIterUDFType,
+    )
+    from pyspark.sql.dataframe import DataFrame
+    from pyspark.sql.types import AtomicType, StructType
+    from pyspark.sql._typing import AtomicValue, RowLike, SQLBatchedUDFType
+
+    from py4j.java_gateway import JavaObject  # type: ignore[import]
+    from py4j.java_collections import JavaArray  # type: ignore[import]
+
+T = TypeVar("T")
+T_co = TypeVar("T_co", covariant=True)
+U = TypeVar("U")
+K = TypeVar("K", bound=Hashable)
+V = TypeVar("V")
+V1 = TypeVar("V1")
+V2 = TypeVar("V2")
+V3 = TypeVar("V3")
+
+
 __all__ = ["RDD"]
 
 
@@ -79,21 +135,21 @@ class PythonEvalType:
     These values should match values in org.apache.spark.api.python.PythonEvalType.
     """
 
-    NON_UDF = 0
+    NON_UDF: "NonUDFType" = 0
 
-    SQL_BATCHED_UDF = 100
+    SQL_BATCHED_UDF: "SQLBatchedUDFType" = 100
 
-    SQL_SCALAR_PANDAS_UDF = 200
-    SQL_GROUPED_MAP_PANDAS_UDF = 201
-    SQL_GROUPED_AGG_PANDAS_UDF = 202
-    SQL_WINDOW_AGG_PANDAS_UDF = 203
-    SQL_SCALAR_PANDAS_ITER_UDF = 204
-    SQL_MAP_PANDAS_ITER_UDF = 205
-    SQL_COGROUPED_MAP_PANDAS_UDF = 206
-    SQL_MAP_ARROW_ITER_UDF = 207
+    SQL_SCALAR_PANDAS_UDF: "PandasScalarUDFType" = 200
+    SQL_GROUPED_MAP_PANDAS_UDF: "PandasGroupedMapUDFType" = 201
+    SQL_GROUPED_AGG_PANDAS_UDF: "PandasGroupedAggUDFType" = 202
+    SQL_WINDOW_AGG_PANDAS_UDF: "PandasWindowAggUDFType" = 203
+    SQL_SCALAR_PANDAS_ITER_UDF: "PandasScalarIterUDFType" = 204
+    SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205
+    SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206
+    SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207
 
 
-def portable_hash(x):
+def portable_hash(x: Hashable) -> int:
     """
     This function returns consistent hash code for builtin types, especially
     for None and tuple with None.
@@ -137,7 +193,11 @@ class BoundedFloat(float):
     100.0
     """
 
-    def __new__(cls, mean, confidence, low, high):
+    confidence: float
+    low: float
+    high: float
+
+    def __new__(cls, mean: float, confidence: float, low: float, high: float) -> "BoundedFloat":
         obj = float.__new__(cls, mean)
         obj.confidence = confidence
         obj.low = low
@@ -145,7 +205,7 @@ class BoundedFloat(float):
         return obj
 
 
-def _create_local_socket(sock_info):
+def _create_local_socket(sock_info: "JavaArray") -> "io.BufferedRWPair":
     """
     Create a local socket that can be used to load deserialized data from the JVM
 
@@ -158,8 +218,10 @@ def _create_local_socket(sock_info):
     -------
     sockfile file descriptor of the local socket
     """
-    port = sock_info[0]
-    auth_secret = sock_info[1]
+    sockfile: "io.BufferedRWPair"
+    sock: "socket.socket"
+    port: int = sock_info[0]
+    auth_secret: str = sock_info[1]
     sockfile, sock = local_connect_and_auth(port, auth_secret)
     # The RDD materialization time is unpredictable, if we set a timeout for socket reading
     # operation, it will very possibly fail. See SPARK-18281.
@@ -167,7 +229,7 @@ def _create_local_socket(sock_info):
     return sockfile
 
 
-def _load_from_socket(sock_info, serializer):
+def _load_from_socket(sock_info: "JavaArray", serializer: Serializer) -> Iterator[Any]:
     """
     Connect to a local socket described by sock_info and use the given serializer to yield data
 
@@ -188,18 +250,21 @@ def _load_from_socket(sock_info, serializer):
     return serializer.load_stream(sockfile)
 
 
-def _local_iterator_from_socket(sock_info, serializer):
+def _local_iterator_from_socket(sock_info: "JavaArray", serializer: Serializer) -> Iterator[Any]:
     class PyLocalIterable:
         """Create a synchronous local iterable over a socket"""
 
-        def __init__(self, _sock_info, _serializer):
+        def __init__(self, _sock_info: "JavaArray", _serializer: Serializer):
+            port: int
+            auth_secret: str
+            jsocket_auth_server: "JavaObject"
             port, auth_secret, self.jsocket_auth_server = _sock_info
             self._sockfile = _create_local_socket((port, auth_secret))
             self._serializer = _serializer
-            self._read_iter = iter([])  # Initialize as empty iterator
+            self._read_iter: Iterator[Any] = iter([])  # Initialize as empty iterator
             self._read_status = 1
 
-        def __iter__(self):
+        def __iter__(self) -> Iterator[Any]:
             while self._read_status == 1:
                 # Request next partition data from Java
                 write_int(1, self._sockfile)
@@ -218,7 +283,7 @@ def _local_iterator_from_socket(sock_info, serializer):
                 elif self._read_status == -1:
                     self.jsocket_auth_server.getResult()
 
-        def __del__(self):
+        def __del__(self) -> None:
             # If local iterator is not fully consumed,
             if self._read_status == 1:
                 try:
@@ -236,22 +301,22 @@ def _local_iterator_from_socket(sock_info, serializer):
 
 
 class Partitioner:
-    def __init__(self, numPartitions, partitionFunc):
+    def __init__(self, numPartitions: int, partitionFunc: Callable[[Any], int]):
         self.numPartitions = numPartitions
         self.partitionFunc = partitionFunc
 
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         return (
             isinstance(other, Partitioner)
             and self.numPartitions == other.numPartitions
             and self.partitionFunc == other.partitionFunc
         )
 
-    def __call__(self, k):
+    def __call__(self, k: Any) -> int:
         return self.partitionFunc(k) % self.numPartitions
 
 
-class RDD:
+class RDD(Generic[T_co]):
 
     """
     A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
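
An illustrative consequence of parameterizing the class (hypothetical snippet; assumes a running SparkContext `sc` and a checker such as mypy):

    ints = sc.parallelize([1, 2, 3])         # inferred as RDD[int]
    strs = ints.map(str)                      # inferred as RDD[str]
    pairs = strs.map(lambda s: (s, len(s)))   # inferred as RDD[Tuple[str, int]]
    pairs.keys()                              # RDD[str] -- pair-RDD methods see K and V
    # ints.reduceByKey(lambda a, b: a + b)    # flagged: elements are not key/value pairs
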
@@ -259,7 +324,12 @@ class RDD:
     operated on in parallel.
     """
 
-    def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(CPickleSerializer())):
+    def __init__(
+        self,
+        jrdd: "JavaObject",
+        ctx: "SparkContext",
+        jrdd_deserializer: Serializer = AutoBatchedSerializer(CPickleSerializer()),
+    ):
         self._jrdd = jrdd
         self.is_cached = False
         self.is_checkpointed = False
@@ -267,21 +337,21 @@ class RDD:
         self.ctx = ctx
         self._jrdd_deserializer = jrdd_deserializer
         self._id = jrdd.id()
-        self.partitioner = None
+        self.partitioner: Optional[Partitioner] = None
 
-    def _pickled(self):
+    def _pickled(self: "RDD[T]") -> "RDD[T]":
         return self._reserialize(AutoBatchedSerializer(CPickleSerializer()))
 
-    def id(self):
+    def id(self) -> int:
         """
         A unique ID for this RDD (within its SparkContext).
         """
         return self._id
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return self._jrdd.toString()
 
-    def __getnewargs__(self):
+    def __getnewargs__(self) -> NoReturn:
         # This method is called when attempting to pickle an RDD, which is always an error:
         raise RuntimeError(
             "It appears that you are attempting to broadcast an RDD or 
reference an RDD from an "
@@ -293,13 +363,13 @@ class RDD:
         )
 
     @property
-    def context(self):
+    def context(self) -> "SparkContext":
         """
         The :class:`SparkContext` that this RDD was created on.
         """
         return self.ctx
 
-    def cache(self):
+    def cache(self: "RDD[T]") -> "RDD[T]":
         """
         Persist this RDD with the default storage level (`MEMORY_ONLY`).
         """
@@ -307,7 +377,7 @@ class RDD:
         self.persist(StorageLevel.MEMORY_ONLY)
         return self
 
-    def persist(self, storageLevel=StorageLevel.MEMORY_ONLY):
+    def persist(self: "RDD[T]", storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) -> "RDD[T]":
         """
         Set this RDD's storage level to persist its values across operations
         after the first time it is computed. This can only be used to assign
@@ -325,7 +395,7 @@ class RDD:
         self._jrdd.persist(javaStorageLevel)
         return self
 
-    def unpersist(self, blocking=False):
+    def unpersist(self: "RDD[T]", blocking: bool = False) -> "RDD[T]":
         """
         Mark the RDD as non-persistent, and remove all blocks for it from
         memory and disk.
@@ -338,7 +408,7 @@ class RDD:
         self._jrdd.unpersist(blocking)
         return self
 
-    def checkpoint(self):
+    def checkpoint(self) -> None:
         """
         Mark this RDD for checkpointing. It will be saved to a file inside the
         checkpoint directory set with :meth:`SparkContext.setCheckpointDir` and
@@ -350,13 +420,13 @@ class RDD:
         self.is_checkpointed = True
         self._jrdd.rdd().checkpoint()
 
-    def isCheckpointed(self):
+    def isCheckpointed(self) -> bool:
         """
         Return whether this RDD is checkpointed and materialized, either reliably or locally.
         """
         return self._jrdd.rdd().isCheckpointed()
 
-    def localCheckpoint(self):
+    def localCheckpoint(self) -> None:
         """
         Mark this RDD for local checkpointing using Spark's existing caching layer.
 
@@ -377,7 +447,7 @@ class RDD:
         """
         self._jrdd.rdd().localCheckpoint()
 
-    def isLocallyCheckpointed(self):
+    def isLocallyCheckpointed(self) -> bool:
         """
         Return whether this RDD is marked for local checkpointing.
 
@@ -385,17 +455,17 @@ class RDD:
         """
         return self._jrdd.rdd().isLocallyCheckpointed()
 
-    def getCheckpointFile(self):
+    def getCheckpointFile(self) -> Optional[str]:
         """
         Gets the name of the file to which this RDD was checkpointed
 
         Not defined if RDD is checkpointed locally.
         """
         checkpointFile = self._jrdd.rdd().getCheckpointFile()
-        if checkpointFile.isDefined():
-            return checkpointFile.get()
 
-    def map(self, f, preservesPartitioning=False):
+        return checkpointFile.get() if checkpointFile.isDefined() else None
+
+    def map(self: "RDD[T]", f: Callable[[T], U], preservesPartitioning: bool = False) -> "RDD[U]":
         """
         Return a new RDD by applying a function to each element of this RDD.
 
@@ -406,12 +476,14 @@ class RDD:
         [('a', 1), ('b', 1), ('c', 1)]
         """
 
-        def func(_, iterator):
+        def func(_: int, iterator: Iterable[T]) -> Iterable[U]:
             return map(fail_on_stopiteration(f), iterator)
 
         return self.mapPartitionsWithIndex(func, preservesPartitioning)
 
-    def flatMap(self, f, preservesPartitioning=False):
+    def flatMap(
+        self: "RDD[T]", f: Callable[[T], Iterable[U]], preservesPartitioning: bool = False
+    ) -> "RDD[U]":
         """
         Return a new RDD by first applying a function to all elements of this
         RDD, and then flattening the results.
@@ -425,12 +497,14 @@ class RDD:
         [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
         """
 
-        def func(s, iterator):
+        def func(_: int, iterator: Iterable[T]) -> Iterable[U]:
             return chain.from_iterable(map(fail_on_stopiteration(f), iterator))
 
         return self.mapPartitionsWithIndex(func, preservesPartitioning)
 
-    def mapPartitions(self, f, preservesPartitioning=False):
+    def mapPartitions(
+        self: "RDD[T]", f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning: bool = False
+    ) -> "RDD[U]":
         """
         Return a new RDD by applying a function to each partition of this RDD.
 
@@ -442,12 +516,16 @@ class RDD:
         [3, 7]
         """
 
-        def func(s, iterator):
+        def func(_: int, iterator: Iterable[T]) -> Iterable[U]:
             return f(iterator)
 
         return self.mapPartitionsWithIndex(func, preservesPartitioning)
 
-    def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+    def mapPartitionsWithIndex(
+        self: "RDD[T]",
+        f: Callable[[int, Iterable[T]], Iterable[U]],
+        preservesPartitioning: bool = False,
+    ) -> "RDD[U]":
         """
         Return a new RDD by applying a function to each partition of this RDD,
         while tracking the index of the original partition.
@@ -461,7 +539,11 @@ class RDD:
         """
         return PipelinedRDD(self, f, preservesPartitioning)
 
-    def mapPartitionsWithSplit(self, f, preservesPartitioning=False):
+    def mapPartitionsWithSplit(
+        self: "RDD[T]",
+        f: Callable[[int, Iterable[T]], Iterable[U]],
+        preservesPartitioning: bool = False,
+    ) -> "RDD[U]":
         """
 
         Return a new RDD by applying a function to each partition of this RDD,
@@ -484,7 +566,7 @@ class RDD:
         )
         return self.mapPartitionsWithIndex(f, preservesPartitioning)
 
-    def getNumPartitions(self):
+    def getNumPartitions(self) -> int:
         """
         Returns the number of partitions in RDD
 
@@ -496,7 +578,7 @@ class RDD:
         """
         return self._jrdd.partitions().size()
 
-    def filter(self, f):
+    def filter(self: "RDD[T]", f: Callable[[T], bool]) -> "RDD[T]":
         """
         Return a new RDD containing only the elements that satisfy a predicate.
 
@@ -507,12 +589,12 @@ class RDD:
         [2, 4]
         """
 
-        def func(iterator):
+        def func(iterator: Iterable[T]) -> Iterable[T]:
             return filter(fail_on_stopiteration(f), iterator)
 
         return self.mapPartitions(func, True)
 
-    def distinct(self, numPartitions=None):
+    def distinct(self: "RDD[T]", numPartitions: Optional[int] = None) -> "RDD[T]":
         """
         Return a new RDD containing the distinct elements in this RDD.
 
@@ -527,7 +609,9 @@ class RDD:
             .map(lambda x: x[0])
         )
 
-    def sample(self, withReplacement, fraction, seed=None):
+    def sample(
+        self: "RDD[T]", withReplacement: bool, fraction: float, seed: Optional[int] = None
+    ) -> "RDD[T]":
         """
         Return a sampled subset of this RDD.
 
@@ -556,7 +640,9 @@ class RDD:
         assert fraction >= 0.0, "Negative fraction value: %s" % fraction
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
 
-    def randomSplit(self, weights, seed=None):
+    def randomSplit(
+        self: "RDD[T]", weights: Sequence[Union[int, float]], seed: Optional[int] = None
+    ) -> "List[RDD[T]]":
         """
         Randomly splits this RDD with the provided weights.
 
@@ -593,7 +679,9 @@ class RDD:
         ]
 
     # this is ported from scala/spark/RDD.scala
-    def takeSample(self, withReplacement, num, seed=None):
+    def takeSample(
+        self: "RDD[T]", withReplacement: bool, num: int, seed: Optional[int] = None
+    ) -> List[T]:
         """
         Return a fixed-size sampled subset of this RDD.
 
@@ -651,7 +739,9 @@ class RDD:
         return samples[0:num]
 
     @staticmethod
-    def _computeFractionForSampleSize(sampleSizeLowerBound, total, withReplacement):
+    def _computeFractionForSampleSize(
+        sampleSizeLowerBound: int, total: int, withReplacement: bool
+    ) -> float:
         """
         Returns a sampling rate that guarantees a sample of
         size >= sampleSizeLowerBound 99.99% of the time.
@@ -683,7 +773,7 @@ class RDD:
             gamma = -log(delta) / total
             return min(1, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction))
 
-    def union(self, other):
+    def union(self: "RDD[T]", other: "RDD[U]") -> "RDD[Union[T, U]]":
         """
         Return the union of this RDD and another one.
 
@@ -694,7 +784,9 @@ class RDD:
         [1, 1, 2, 3, 1, 1, 2, 3]
         """
         if self._jrdd_deserializer == other._jrdd_deserializer:
-            rdd = RDD(self._jrdd.union(other._jrdd), self.ctx, self._jrdd_deserializer)
+            rdd: "RDD[Union[T, U]]" = RDD(
+                self._jrdd.union(other._jrdd), self.ctx, self._jrdd_deserializer
+            )
         else:
             # These RDDs contain data in different serialized formats, so we
             # must normalize them to the default serializer.
@@ -708,7 +800,7 @@ class RDD:
             rdd.partitioner = self.partitioner
         return rdd
 
-    def intersection(self, other):
+    def intersection(self: "RDD[T]", other: "RDD[T]") -> "RDD[T]":
         """
         Return the intersection of this RDD and another one. The output will
         not contain any duplicate elements, even if the input RDDs did.
@@ -731,14 +823,14 @@ class RDD:
             .keys()
         )
 
-    def _reserialize(self, serializer=None):
+    def _reserialize(self: "RDD[T]", serializer: Optional[Serializer] = None) -> "RDD[T]":
         serializer = serializer or self.ctx.serializer
         if self._jrdd_deserializer != serializer:
             self = self.map(lambda x: x, preservesPartitioning=True)
             self._jrdd_deserializer = serializer
         return self
 
-    def __add__(self, other):
+    def __add__(self: "RDD[T]", other: "RDD[U]") -> "RDD[Union[T, U]]":
         """
         Return the union of this RDD and another one.
 
@@ -752,9 +844,43 @@ class RDD:
             raise TypeError
         return self.union(other)
 
+    @overload
     def repartitionAndSortWithinPartitions(
-        self, numPartitions=None, partitionFunc=portable_hash, ascending=True, keyfunc=lambda x: x
-    ):
+        self: "RDD[Tuple[S, V]]",
+        numPartitions: Optional[int] = ...,
+        partitionFunc: Callable[["S"], int] = ...,
+        ascending: bool = ...,
+    ) -> "RDD[Tuple[S, V]]":
+        ...
+
+    @overload
+    def repartitionAndSortWithinPartitions(
+        self: "RDD[Tuple[K, V]]",
+        numPartitions: Optional[int],
+        partitionFunc: Callable[[K], int],
+        ascending: bool,
+        keyfunc: Callable[[K], "S"],
+    ) -> "RDD[Tuple[K, V]]":
+        ...
+
+    @overload
+    def repartitionAndSortWithinPartitions(
+        self: "RDD[Tuple[K, V]]",
+        numPartitions: Optional[int] = ...,
+        partitionFunc: Callable[[K], int] = ...,
+        ascending: bool = ...,
+        *,
+        keyfunc: Callable[[K], "S"],
+    ) -> "RDD[Tuple[K, V]]":
+        ...
+
+    def repartitionAndSortWithinPartitions(
+        self: "RDD[Tuple[Any, Any]]",
+        numPartitions: Optional[int] = None,
+        partitionFunc: Callable[[Any], int] = portable_hash,
+        ascending: bool = True,
+        keyfunc: Callable[[Any], Any] = lambda x: x,
+    ) -> "RDD[Tuple[Any, Any]]":
         """
         Repartition the RDD according to the given partitioner and, within each resulting partition,
         sort records by their keys.
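
A hedged sketch of why the keyword-only overload exists (hypothetical snippet; assumes a running SparkContext `sc`): when `keyfunc` is supplied, only the derived keys need to be orderable, not the original keys:

    rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)])
    rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, True)  # keys sorted directly
    rdd.repartitionAndSortWithinPartitions(2, keyfunc=lambda k: -k)   # keyword-only: sort by a derived key
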
@@ -772,13 +898,45 @@ class RDD:
         memory = self._memory_limit()
         serializer = self._jrdd_deserializer
 
-        def sortPartition(iterator):
+        def sortPartition(iterator: Iterable[Tuple[K, V]]) -> Iterable[Tuple[K, V]]:
             sort = ExternalSorter(memory * 0.9, serializer).sorted
             return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending)))
 
         return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True)
 
-    def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
+    @overload
+    def sortByKey(
+        self: "RDD[Tuple[S, V]]",
+        ascending: bool = ...,
+        numPartitions: Optional[int] = ...,
+    ) -> "RDD[Tuple[K, V]]":
+        ...
+
+    @overload
+    def sortByKey(
+        self: "RDD[Tuple[K, V]]",
+        ascending: bool,
+        numPartitions: int,
+        keyfunc: Callable[[K], "S"],
+    ) -> "RDD[Tuple[K, V]]":
+        ...
+
+    @overload
+    def sortByKey(
+        self: "RDD[Tuple[K, V]]",
+        ascending: bool = ...,
+        numPartitions: Optional[int] = ...,
+        *,
+        keyfunc: Callable[[K], "S"],
+    ) -> "RDD[Tuple[K, V]]":
+        ...
+
+    def sortByKey(
+        self: "RDD[Tuple[K, V]]",
+        ascending: Optional[bool] = True,
+        numPartitions: Optional[int] = None,
+        keyfunc: Callable[[Any], Any] = lambda x: x,
+    ) -> "RDD[Tuple[K, V]]":
         """
         Sorts this RDD, which is assumed to consist of (key, value) pairs.
 
@@ -802,7 +960,7 @@ class RDD:
         memory = self._memory_limit()
         serializer = self._jrdd_deserializer
 
-        def sortPartition(iterator):
+        def sortPartition(iterator: Iterable[Tuple[K, V]]) -> Iterable[Tuple[K, V]]:
             sort = ExternalSorter(memory * 0.9, serializer).sorted
             return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending)))
 
@@ -829,16 +987,21 @@ class RDD:
             for i in range(0, numPartitions - 1)
         ]
 
-        def rangePartitioner(k):
+        def rangePartitioner(k: K) -> int:
             p = bisect.bisect_left(bounds, keyfunc(k))
             if ascending:
                 return p
             else:
-                return numPartitions - 1 - p
+                return numPartitions - 1 - p  # type: ignore[operator]
 
         return self.partitionBy(numPartitions, rangePartitioner).mapPartitions(sortPartition, True)
 
-    def sortBy(self, keyfunc, ascending=True, numPartitions=None):
+    def sortBy(
+        self: "RDD[T]",
+        keyfunc: Callable[[T], "S"],
+        ascending: bool = True,
+        numPartitions: Optional[int] = None,
+    ) -> "RDD[T]":
         """
         Sorts this RDD by the given keyfunc
 
@@ -850,9 +1013,13 @@ class RDD:
         >>> sc.parallelize(tmp).sortBy(lambda x: x[1]).collect()
         [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
         """
-        return self.keyBy(keyfunc).sortByKey(ascending, numPartitions).values()
+        return (
+            self.keyBy(keyfunc)  # type: ignore[type-var]
+            .sortByKey(ascending, numPartitions)
+            .values()
+        )
 
-    def glom(self):
+    def glom(self: "RDD[T]") -> "RDD[List[T]]":
         """
         Return an RDD created by coalescing all elements within each partition
         into a list.
@@ -864,12 +1031,12 @@ class RDD:
         [[1, 2], [3, 4]]
         """
 
-        def func(iterator):
+        def func(iterator: Iterable[T]) -> Iterable[List[T]]:
             yield list(iterator)
 
         return self.mapPartitions(func)
 
-    def cartesian(self, other):
+    def cartesian(self: "RDD[T]", other: "RDD[U]") -> "RDD[Tuple[T, U]]":
         """
         Return the Cartesian product of this RDD and another one, that is, the
         RDD of all pairs of elements ``(a, b)`` where ``a`` is in `self` and
@@ -885,7 +1052,12 @@ class RDD:
         deserializer = CartesianDeserializer(self._jrdd_deserializer, other._jrdd_deserializer)
         return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
 
-    def groupBy(self, f, numPartitions=None, partitionFunc=portable_hash):
+    def groupBy(
+        self: "RDD[T]",
+        f: Callable[[T], K],
+        numPartitions: Optional[int] = None,
+        partitionFunc: Callable[[K], int] = portable_hash,
+    ) -> "RDD[Tuple[K, Iterable[T]]]":
         """
         Return an RDD of grouped items.
 
@@ -898,7 +1070,9 @@ class RDD:
         """
         return self.map(lambda x: (f(x), x)).groupByKey(numPartitions, partitionFunc)
 
-    def pipe(self, command, env=None, checkCode=False):
+    def pipe(
+        self, command: str, env: Optional[Dict[str, str]] = None, checkCode: bool = False
+    ) -> "RDD[str]":
         """
         Return an RDD created by piping elements to a forked external process.
 
@@ -919,10 +1093,10 @@ class RDD:
         if env is None:
             env = dict()
 
-        def func(iterator):
+        def func(iterator: Iterable[T]) -> Iterable[str]:
             pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
 
-            def pipe_objs(out):
+            def pipe_objs(out: IO[bytes]) -> None:
                 for obj in iterator:
                     s = str(obj).rstrip("\n") + "\n"
                     out.write(s.encode("utf-8"))
@@ -930,7 +1104,7 @@ class RDD:
 
             Thread(target=pipe_objs, args=[pipe.stdin]).start()
 
-            def check_return_code():
+            def check_return_code() -> Iterable[int]:
                 pipe.wait()
                 if checkCode and pipe.returncode:
                     raise RuntimeError(
@@ -942,13 +1116,15 @@ class RDD:
                         yield i
 
             return (
-                x.rstrip(b"\n").decode("utf-8")
-                for x in chain(iter(pipe.stdout.readline, b""), check_return_code())
+                cast(bytes, x).rstrip(b"\n").decode("utf-8")
+                for x in chain(
+                    iter(cast(IO[bytes], pipe.stdout).readline, b""), check_return_code()
+                )
             )
 
         return self.mapPartitions(func)
 
-    def foreach(self, f):
+    def foreach(self: "RDD[T]", f: Callable[[T], None]) -> None:
         """
         Applies a function to all elements of this RDD.
 
@@ -959,14 +1135,14 @@ class RDD:
         """
         f = fail_on_stopiteration(f)
 
-        def processPartition(iterator):
+        def processPartition(iterator: Iterable[T]) -> Iterable[Any]:
             for x in iterator:
                 f(x)
             return iter([])
 
         self.mapPartitions(processPartition).count()  # Force evaluation
 
-    def foreachPartition(self, f):
+    def foreachPartition(self: "RDD[T]", f: Callable[[Iterable[T]], None]) -> None:
         """
         Applies a function to each partition of this RDD.
 
@@ -978,16 +1154,16 @@ class RDD:
         >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f)
         """
 
-        def func(it):
+        def func(it: Iterable[T]) -> Iterable[Any]:
             r = f(it)
             try:
-                return iter(r)
+                return iter(r)  # type: ignore[call-overload]
             except TypeError:
                 return iter([])
 
         self.mapPartitions(func).count()  # Force evaluation
 
-    def collect(self):
+    def collect(self: "RDD[T]") -> List[T]:
         """
         Return a list that contains all of the elements in this RDD.
 
@@ -997,10 +1173,13 @@ class RDD:
         to be small, as all the data is loaded into the driver's memory.
         """
         with SCCallSiteSync(self.context):
+            assert self.ctx._jvm is not None
             sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
         return list(_load_from_socket(sock_info, self._jrdd_deserializer))
 
-    def collectWithJobGroup(self, groupId, description, interruptOnCancel=False):
+    def collectWithJobGroup(
+        self: "RDD[T]", groupId: str, description: str, interruptOnCancel: bool = False
+    ) -> "List[T]":
         """
         When collect rdd, use this method to specify job group.
 
@@ -1015,12 +1194,13 @@ class RDD:
         )
 
         with SCCallSiteSync(self.context):
+            assert self.ctx._jvm is not None
             sock_info = self.ctx._jvm.PythonRDD.collectAndServeWithJobGroup(
                 self._jrdd.rdd(), groupId, description, interruptOnCancel
             )
         return list(_load_from_socket(sock_info, self._jrdd_deserializer))
 
-    def reduce(self, f):
+    def reduce(self: "RDD[T]", f: Callable[[T, T], T]) -> T:
         """
         Reduces the elements of this RDD using the specified commutative and
         associative binary operator. Currently reduces partitions locally.
@@ -1039,7 +1219,7 @@ class RDD:
         """
         f = fail_on_stopiteration(f)
 
-        def func(iterator):
+        def func(iterator: Iterable[T]) -> Iterable[T]:
             iterator = iter(iterator)
             try:
                 initial = next(iterator)
@@ -1052,7 +1232,7 @@ class RDD:
             return reduce(f, vals)
         raise ValueError("Can not reduce() empty RDD")
 
-    def treeReduce(self, f, depth=2):
+    def treeReduce(self: "RDD[T]", f: Callable[[T, T], T], depth: int = 2) -> T:
         """
         Reduces the elements of this RDD in a multi-level tree pattern.
 
@@ -1080,22 +1260,26 @@ class RDD:
         if depth < 1:
             raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
 
-        zeroValue = None, True  # Use the second entry to indicate whether this is a dummy value.
+        # Use the second entry to indicate whether this is a dummy value.
+        zeroValue: Tuple[T, bool] = (  # type: ignore[assignment]
+            None,
+            True,
+        )
 
-        def op(x, y):
+        def op(x: Tuple[T, bool], y: Tuple[T, bool]) -> Tuple[T, bool]:
             if x[1]:
                 return y
             elif y[1]:
                 return x
             else:
-                return f(x[0], y[0]), False
+                return f(x[0], y[0]), False  # type: ignore[arg-type]
 
         reduced = self.map(lambda x: (x, False)).treeAggregate(zeroValue, op, op, depth)
         if reduced[1]:
             raise ValueError("Cannot reduce empty RDD.")
         return reduced[0]
 
-    def fold(self, zeroValue, op):
+    def fold(self: "RDD[T]", zeroValue: T, op: Callable[[T, T], T]) -> T:
         """
         Aggregate the elements of each partition, and then the results for all
         the partitions, using a given associative function and a neutral "zero value."
@@ -1120,7 +1304,7 @@ class RDD:
         """
         op = fail_on_stopiteration(op)
 
-        def func(iterator):
+        def func(iterator: Iterable[T]) -> Iterable[T]:
             acc = zeroValue
             for obj in iterator:
                 acc = op(acc, obj)
@@ -1132,7 +1316,9 @@ class RDD:
         vals = self.mapPartitions(func).collect()
         return reduce(op, vals, zeroValue)
 
-    def aggregate(self, zeroValue, seqOp, combOp):
+    def aggregate(
+        self: "RDD[T]", zeroValue: U, seqOp: Callable[[U, T], U], combOp: Callable[[U, U], U]
+    ) -> U:
         """
         Aggregate the elements of each partition, and then the results for all
         the partitions, using a given combine functions and a neutral "zero
@@ -1158,7 +1344,7 @@ class RDD:
         seqOp = fail_on_stopiteration(seqOp)
         combOp = fail_on_stopiteration(combOp)
 
-        def func(iterator):
+        def func(iterator: Iterable[T]) -> Iterable[U]:
             acc = zeroValue
             for obj in iterator:
                 acc = seqOp(acc, obj)
@@ -1170,7 +1356,13 @@ class RDD:
         vals = self.mapPartitions(func).collect()
         return reduce(combOp, vals, zeroValue)
 
-    def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
+    def treeAggregate(
+        self: "RDD[T]",
+        zeroValue: U,
+        seqOp: Callable[[U, T], U],
+        combOp: Callable[[U, U], U],
+        depth: int = 2,
+    ) -> U:
         """
         Aggregates the elements of this RDD in a multi-level tree
         pattern.
@@ -1199,7 +1391,7 @@ class RDD:
         if self.getNumPartitions() == 0:
             return zeroValue
 
-        def aggregatePartition(iterator):
+        def aggregatePartition(iterator: Iterable[T]) -> Iterable[U]:
             acc = zeroValue
             for obj in iterator:
                 acc = seqOp(acc, obj)
@@ -1211,10 +1403,10 @@ class RDD:
         # If creating an extra level doesn't help reduce the wall-clock time, we stop the tree
         # aggregation.
         while numPartitions > scale + numPartitions / scale:
-            numPartitions /= scale
+            numPartitions /= scale  # type: ignore[assignment]
             curNumPartitions = int(numPartitions)
 
-            def mapPartition(i, iterator):
+            def mapPartition(i: int, iterator: Iterable[U]) -> Iterable[Tuple[int, U]]:
                 for obj in iterator:
                     yield (i % curNumPartitions, obj)
 
@@ -1226,7 +1418,15 @@ class RDD:
 
         return partiallyAggregated.reduce(combOp)
 
-    def max(self, key=None):
+    @overload
+    def max(self: "RDD[S]") -> "S":
+        ...
+
+    @overload
+    def max(self: "RDD[T]", key: Callable[[T], "S"]) -> T:
+        ...
+
+    def max(self: "RDD[T]", key: Optional[Callable[[T], "S"]] = None) -> T:
         """
         Find the maximum item in this RDD.
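
An illustrative use of the two `max` overloads (hypothetical snippet; assumes a running SparkContext `sc`):

    nums = sc.parallelize([3, 7, 12])
    nums.max()         # 12 -- first overload: the elements themselves must support "<"
    nums.max(key=str)  # 7  -- second overload: any element type, ordered by key(x)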
 
@@ -1244,10 +1444,18 @@ class RDD:
         5.0
         """
         if key is None:
-            return self.reduce(max)
-        return self.reduce(lambda a, b: max(a, b, key=key))
+            return self.reduce(max)  # type: ignore[arg-type]
+        return self.reduce(lambda a, b: max(a, b, key=key))  # type: ignore[arg-type]
 
-    def min(self, key=None):
+    @overload
+    def min(self: "RDD[S]") -> "S":
+        ...
+
+    @overload
+    def min(self: "RDD[T]", key: Callable[[T], "S"]) -> T:
+        ...
+
+    def min(self: "RDD[T]", key: Optional[Callable[[T], "S"]] = None) -> T:
         """
         Find the minimum item in this RDD.
 
@@ -1265,10 +1473,10 @@ class RDD:
         10.0
         """
         if key is None:
-            return self.reduce(min)
-        return self.reduce(lambda a, b: min(a, b, key=key))
+            return self.reduce(min)  # type: ignore[arg-type]
+        return self.reduce(lambda a, b: min(a, b, key=key))  # type: ignore[arg-type]
 
-    def sum(self):
+    def sum(self: "RDD[NumberOrArray]") -> "NumberOrArray":
         """
         Add up the elements in this RDD.
 
@@ -1277,9 +1485,11 @@ class RDD:
         >>> sc.parallelize([1.0, 2.0, 3.0]).sum()
         6.0
         """
-        return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add)
+        return self.mapPartitions(lambda x: [sum(x)]).fold(  # type: ignore[return-value]
+            0, operator.add
+        )
 
-    def count(self):
+    def count(self) -> int:
         """
         Return the number of elements in this RDD.
 
@@ -1290,18 +1500,22 @@ class RDD:
         """
         return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
 
-    def stats(self):
+    def stats(self: "RDD[NumberOrArray]") -> StatCounter:
         """
         Return a :class:`StatCounter` object that captures the mean, variance
         and count of the RDD's elements in one operation.
         """
 
-        def redFunc(left_counter, right_counter):
+        def redFunc(left_counter: StatCounter, right_counter: StatCounter) -> StatCounter:
             return left_counter.mergeStats(right_counter)
 
-        return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(redFunc)
+        return self.mapPartitions(lambda i: [StatCounter(i)]).reduce(  # type: ignore[arg-type]
+            redFunc
+        )
 
-    def histogram(self, buckets):
+    def histogram(
+        self: "RDD[S]", buckets: Union[int, List["S"], Tuple["S", ...]]
+    ) -> Tuple[Sequence["S"], List[int]]:
         """
         Compute a histogram using the provided buckets. The buckets
         are all open to the right except for the last which is closed.
@@ -1345,7 +1559,7 @@ class RDD:
                 raise ValueError("number of buckets must be >= 1")
 
             # filter out non-comparable elements
-            def comparable(x):
+            def comparable(x: Any) -> bool:
                 if x is None:
                     return False
                 if type(x) is float and isnan(x):
@@ -1355,7 +1569,7 @@ class RDD:
             filtered = self.filter(comparable)
 
             # faster than stats()
-            def minmax(a, b):
+            def minmax(a: Tuple["S", "S"], b: Tuple["S", "S"]) -> Tuple["S", "S"]:
                 return min(a[0], b[0]), max(a[1], b[1])
 
             try:
@@ -1369,7 +1583,7 @@ class RDD:
                 return [minv, maxv], [filtered.count()]
 
             try:
-                inc = (maxv - minv) / buckets
+                inc = (maxv - minv) / buckets  # type: ignore[operator]
             except TypeError:
                 raise TypeError("Can not generate buckets with non-number in RDD")
 
@@ -1378,8 +1592,8 @@ class RDD:
 
             # keep them as integer if possible
             inc = int(inc)
-            if inc * buckets != maxv - minv:
-                inc = (maxv - minv) * 1.0 / buckets
+            if inc * buckets != maxv - minv:  # type: ignore[operator]
+                inc = (maxv - minv) * 1.0 / buckets  # type: ignore[operator]
 
             buckets = [i * inc + minv for i in range(buckets)]
             buckets.append(maxv)  # fix accumulated error
@@ -1403,35 +1617,47 @@ class RDD:
             even = False
             inc = None
             try:
-                steps = [buckets[i + 1] - buckets[i] for i in range(len(buckets) - 1)]
+                steps = [
+                    buckets[i + 1] - buckets[i]  # type: ignore[operator]
+                    for i in range(len(buckets) - 1)
+                ]
             except TypeError:
                 pass  # objects in buckets do not support '-'
             else:
                 if max(steps) - min(steps) < 1e-10:  # handle precision errors
                     even = True
-                    inc = (maxv - minv) / (len(buckets) - 1)
+                    inc = (maxv - minv) / (len(buckets) - 1)  # type: ignore[operator]
 
         else:
             raise TypeError("buckets should be a list or tuple or number(int or long)")
 
-        def histogram(iterator):
-            counters = [0] * len(buckets)
+        def histogram(iterator: Iterable["S"]) -> Iterable[List[int]]:
+            counters = [0] * len(buckets)  # type: ignore[arg-type]
             for i in iterator:
-                if i is None or (type(i) is float and isnan(i)) or i > maxv or i < minv:
+                if (
+                    i is None
+                    or (isinstance(i, float) and isnan(i))  # type: ignore[arg-type]
+                    or i > maxv
+                    or i < minv
+                ):
                     continue
-                t = int((i - minv) / inc) if even else bisect.bisect_right(buckets, i) - 1
+                t = (
+                    int((i - minv) / inc)  # type: ignore[operator]
+                    if even
+                    else bisect.bisect_right(buckets, i) - 1  # type: ignore[arg-type]
+                )
                 counters[t] += 1
             # add last two together
             last = counters.pop()
             counters[-1] += last
             return [counters]
 
-        def mergeCounters(a, b):
+        def mergeCounters(a: List[int], b: List[int]) -> List[int]:
             return [i + j for i, j in zip(a, b)]
 
         return buckets, self.mapPartitions(histogram).reduce(mergeCounters)
 
-    def mean(self):
+    def mean(self: "RDD[NumberOrArray]") -> "NumberOrArray":
         """
         Compute the mean of this RDD's elements.
 
@@ -1440,9 +1666,9 @@ class RDD:
         >>> sc.parallelize([1, 2, 3]).mean()
         2.0
         """
-        return self.stats().mean()
+        return self.stats().mean()  # type: ignore[return-value]
 
-    def variance(self):
+    def variance(self: "RDD[NumberOrArray]") -> "NumberOrArray":
         """
         Compute the variance of this RDD's elements.
 
@@ -1451,9 +1677,9 @@ class RDD:
         >>> sc.parallelize([1, 2, 3]).variance()
         0.666...
         """
-        return self.stats().variance()
+        return self.stats().variance()  # type: ignore[return-value]
 
-    def stdev(self):
+    def stdev(self: "RDD[NumberOrArray]") -> "NumberOrArray":
         """
         Compute the standard deviation of this RDD's elements.
 
@@ -1462,9 +1688,9 @@ class RDD:
         >>> sc.parallelize([1, 2, 3]).stdev()
         0.816...
         """
-        return self.stats().stdev()
+        return self.stats().stdev()  # type: ignore[return-value]
 
-    def sampleStdev(self):
+    def sampleStdev(self: "RDD[NumberOrArray]") -> "NumberOrArray":
         """
         Compute the sample standard deviation of this RDD's elements (which
         corrects for bias in estimating the standard deviation by dividing by
@@ -1475,9 +1701,9 @@ class RDD:
         >>> sc.parallelize([1, 2, 3]).sampleStdev()
         1.0
         """
-        return self.stats().sampleStdev()
+        return self.stats().sampleStdev()  # type: ignore[return-value]
 
-    def sampleVariance(self):
+    def sampleVariance(self: "RDD[NumberOrArray]") -> "NumberOrArray":
         """
         Compute the sample variance of this RDD's elements (which corrects
         for bias in estimating the variance by dividing by N-1 instead of N).
@@ -1487,9 +1713,9 @@ class RDD:
         >>> sc.parallelize([1, 2, 3]).sampleVariance()
         1.0
         """
-        return self.stats().sampleVariance()
+        return self.stats().sampleVariance()  # type: ignore[return-value]
 
-    def countByValue(self):
+    def countByValue(self: "RDD[K]") -> Dict[K, int]:
         """
         Return the count of each unique value in this RDD as a dictionary of
         (value, count) pairs.
@@ -1500,20 +1726,28 @@ class RDD:
         [(1, 2), (2, 3)]
         """
 
-        def countPartition(iterator):
-            counts = defaultdict(int)
+        def countPartition(iterator: Iterable[K]) -> Iterable[Dict[K, int]]:
+            counts: Dict[K, int] = defaultdict(int)
             for obj in iterator:
                 counts[obj] += 1
             yield counts
 
-        def mergeMaps(m1, m2):
+        def mergeMaps(m1: Dict[K, int], m2: Dict[K, int]) -> Dict[K, int]:
             for k, v in m2.items():
                 m1[k] += v
             return m1
 
         return self.mapPartitions(countPartition).reduce(mergeMaps)
 
-    def top(self, num, key=None):
+    @overload
+    def top(self: "RDD[S]", num: int) -> List["S"]:
+        ...
+
+    @overload
+    def top(self: "RDD[T]", num: int, key: Callable[[T], "S"]) -> List[T]:
+        ...
+
+    def top(self: "RDD[T]", num: int, key: Optional[Callable[[T], "S"]] = None) -> List[T]:
         """
         Get the top N elements from an RDD.
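
Similarly, an illustrative use of the two `top` overloads (hypothetical snippet; assumes a running SparkContext `sc`):

    nums = sc.parallelize([10, 4, 2, 12, 3])
    nums.top(3)            # [12, 10, 4] -- first overload: elements must support "<"
    words = sc.parallelize(["aa", "b", "cccc"])
    words.top(1, key=len)  # ['cccc']    -- second overload: ordering taken from key(x)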
 
@@ -1534,15 +1768,23 @@ class RDD:
         [4, 3, 2]
         """
 
-        def topIterator(iterator):
+        def topIterator(iterator: Iterable[T]) -> Iterable[List[T]]:
             yield heapq.nlargest(num, iterator, key=key)
 
-        def merge(a, b):
+        def merge(a: List[T], b: List[T]) -> List[T]:
             return heapq.nlargest(num, a + b, key=key)
 
         return self.mapPartitions(topIterator).reduce(merge)
 
-    def takeOrdered(self, num, key=None):
+    @overload
+    def takeOrdered(self: "RDD[S]", num: int) -> List["S"]:
+        ...
+
+    @overload
+    def takeOrdered(self: "RDD[T]", num: int, key: Callable[[T], "S"]) -> List[T]:
+        ...
+
+    def takeOrdered(self: "RDD[T]", num: int, key: Optional[Callable[[T], "S"]] = None) -> List[T]:
         """
         Get the N elements from an RDD ordered in ascending order or as
         specified by the optional key function.
@@ -1560,12 +1802,12 @@ class RDD:
         [10, 9, 7, 6, 5, 4]
         """
 
-        def merge(a, b):
+        def merge(a: List[T], b: List[T]) -> List[T]:
             return heapq.nsmallest(num, a + b, key)
 
         return self.mapPartitions(lambda it: [heapq.nsmallest(num, it, key)]).reduce(merge)
 
-    def take(self, num):
+    def take(self: "RDD[T]", num: int) -> List[T]:
         """
         Take the first num elements of the RDD.
 
@@ -1589,7 +1831,7 @@ class RDD:
         >>> sc.parallelize(range(100), 100).filter(lambda x: x > 90).take(3)
         [91, 92, 93]
         """
-        items = []
+        items: List[T] = []
         totalParts = self.getNumPartitions()
         partsScanned = 0
 
@@ -1612,7 +1854,7 @@ class RDD:
 
             left = num - len(items)
 
-            def takeUpToNumLeft(iterator):
+            def takeUpToNumLeft(iterator: Iterable[T]) -> Iterable[T]:
                 iterator = iter(iterator)
                 taken = 0
                 while taken < left:
@@ -1630,7 +1872,7 @@ class RDD:
 
         return items[:num]
 
-    def first(self):
+    def first(self: "RDD[T]") -> T:
         """
         Return the first element in this RDD.
 
@@ -1648,7 +1890,7 @@ class RDD:
             return rs[0]
         raise ValueError("RDD is empty")
 
-    def isEmpty(self):
+    def isEmpty(self) -> bool:
         """
         Returns true if and only if the RDD contains no elements at all.
 
@@ -1665,7 +1907,12 @@ class RDD:
         """
         return self.getNumPartitions() == 0 or len(self.take(1)) == 0
 
-    def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
+    def saveAsNewAPIHadoopDataset(
+        self: "RDD[Tuple[K, V]]",
+        conf: Dict[str, str],
+        keyConverter: Optional[str] = None,
+        valueConverter: Optional[str] = None,
+    ) -> None:
         """
         Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file
         system, using the new Hadoop OutputFormat API (mapreduce package). Keys/values are
@@ -1683,20 +1930,22 @@ class RDD:
         """
         jconf = self.ctx._dictToJavaMap(conf)
         pickledRDD = self._pickled()
+        assert self.ctx._jvm is not None
+
         self.ctx._jvm.PythonRDD.saveAsHadoopDataset(
             pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, True
         )
 
     def saveAsNewAPIHadoopFile(
-        self,
-        path,
-        outputFormatClass,
-        keyClass=None,
-        valueClass=None,
-        keyConverter=None,
-        valueConverter=None,
-        conf=None,
-    ):
+        self: "RDD[Tuple[K, V]]",
+        path: str,
+        outputFormatClass: str,
+        keyClass: Optional[str] = None,
+        valueClass: Optional[str] = None,
+        keyConverter: Optional[str] = None,
+        valueConverter: Optional[str] = None,
+        conf: Optional[Dict[str, str]] = None,
+    ) -> None:
         """
         Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file
         system, using the new Hadoop OutputFormat API (mapreduce package). Key and value types
@@ -1725,6 +1974,8 @@ class RDD:
         """
         jconf = self.ctx._dictToJavaMap(conf)
         pickledRDD = self._pickled()
+        assert self.ctx._jvm is not None
+
         self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(
             pickledRDD._jrdd,
             True,
@@ -1737,7 +1988,12 @@ class RDD:
             jconf,
         )
 
-    def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
+    def saveAsHadoopDataset(
+        self: "RDD[Tuple[K, V]]",
+        conf: Dict[str, str],
+        keyConverter: Optional[str] = None,
+        valueConverter: Optional[str] = None,
+    ) -> None:
         """
         Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file
         system, using the old Hadoop OutputFormat API (mapred package). Keys/values are
@@ -1755,21 +2011,23 @@ class RDD:
         """
         jconf = self.ctx._dictToJavaMap(conf)
         pickledRDD = self._pickled()
+        assert self.ctx._jvm is not None
+
         self.ctx._jvm.PythonRDD.saveAsHadoopDataset(
             pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, False
         )
 
     def saveAsHadoopFile(
-        self,
-        path,
-        outputFormatClass,
-        keyClass=None,
-        valueClass=None,
-        keyConverter=None,
-        valueConverter=None,
-        conf=None,
-        compressionCodecClass=None,
-    ):
+        self: "RDD[Tuple[K, V]]",
+        path: str,
+        outputFormatClass: str,
+        keyClass: Optional[str] = None,
+        valueClass: Optional[str] = None,
+        keyConverter: Optional[str] = None,
+        valueConverter: Optional[str] = None,
+        conf: Optional[Dict[str, str]] = None,
+        compressionCodecClass: Optional[str] = None,
+    ) -> None:
         """
         Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file
         system, using the old Hadoop OutputFormat API (mapred package). Key and value types
@@ -1803,6 +2061,8 @@ class RDD:
         """
         jconf = self.ctx._dictToJavaMap(conf)
         pickledRDD = self._pickled()
+        assert self.ctx._jvm is not None
+
         self.ctx._jvm.PythonRDD.saveAsHadoopFile(
             pickledRDD._jrdd,
             True,
@@ -1816,7 +2076,9 @@ class RDD:
             compressionCodecClass,
         )
 
-    def saveAsSequenceFile(self, path, compressionCodecClass=None):
+    def saveAsSequenceFile(
+        self: "RDD[Tuple[K, V]]", path: str, compressionCodecClass: Optional[str] = None
+    ) -> None:
         """
         Output a Python RDD of key-value pairs (of form ``RDD[(K, V)]``) to any Hadoop file
         system, using the "org.apache.hadoop.io.Writable" types that we convert from the
@@ -1834,11 +2096,13 @@ class RDD:
             i.e. "org.apache.hadoop.io.compress.GzipCodec" (None by default)
         """
         pickledRDD = self._pickled()
+        assert self.ctx._jvm is not None
+
         self.ctx._jvm.PythonRDD.saveAsSequenceFile(
             pickledRDD._jrdd, True, path, compressionCodecClass
         )
 
-    def saveAsPickleFile(self, path, batchSize=10):
+    def saveAsPickleFile(self, path: str, batchSize: int = 10) -> None:
         """
         Save this RDD as a SequenceFile of serialized objects. The serializer
         used is :class:`pyspark.serializers.CPickleSerializer`, default batch size
@@ -1853,13 +2117,14 @@ class RDD:
         >>> sorted(sc.pickleFile(tmpFile.name, 5).map(str).collect())
         ['1', '2', 'rdd', 'spark']
         """
+        ser: Serializer
         if batchSize == 0:
             ser = AutoBatchedSerializer(CPickleSerializer())
         else:
             ser = BatchedSerializer(CPickleSerializer(), batchSize)
         self._reserialize(ser)._jrdd.saveAsObjectFile(path)
 
-    def saveAsTextFile(self, path, compressionCodecClass=None):
+    def saveAsTextFile(self, path: str, compressionCodecClass: Optional[str] = None) -> None:
         """
         Save this RDD as a text file, using string representations of elements.
 
@@ -1904,16 +2169,20 @@ class RDD:
         'bar\\nfoo\\n'
         """
 
-        def func(split, iterator):
+        def func(split: int, iterator: Iterable[Any]) -> Iterable[bytes]:
             for x in iterator:
-                if not isinstance(x, (str, bytes)):
-                    x = str(x)
-                if isinstance(x, str):
-                    x = x.encode("utf-8")
-                yield x
+                if isinstance(x, bytes):
+                    yield x
+                elif isinstance(x, str):
+                    yield x.encode("utf-8")
+                else:
+                    yield str(x).encode("utf-8")
 
         keyed = self.mapPartitionsWithIndex(func)
-        keyed._bypass_serializer = True
+        keyed._bypass_serializer = True  # type: ignore[attr-defined]
+
+        assert self.ctx._jvm is not None
+
         if compressionCodecClass:
             compressionCodec = self.ctx._jvm.java.lang.Class.forName(compressionCodecClass)
             keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path, compressionCodec)
@@ -1922,7 +2191,7 @@ class RDD:
 
     # Pair functions
 
-    def collectAsMap(self):
+    def collectAsMap(self: "RDD[Tuple[K, V]]") -> Dict[K, V]:
         """
         Return the key-value pairs in this RDD to the master as a dictionary.
 
@@ -1941,7 +2210,7 @@ class RDD:
         """
         return dict(self.collect())
 
-    def keys(self):
+    def keys(self: "RDD[Tuple[K, V]]") -> "RDD[K]":
         """
         Return an RDD with the keys of each tuple.
 
@@ -1953,7 +2222,7 @@ class RDD:
         """
         return self.map(lambda x: x[0])
 
-    def values(self):
+    def values(self: "RDD[Tuple[K, V]]") -> "RDD[V]":
         """
         Return an RDD with the values of each tuple.
 
@@ -1965,7 +2234,12 @@ class RDD:
         """
         return self.map(lambda x: x[1])
 
-    def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash):
+    def reduceByKey(
+        self: "RDD[Tuple[K, V]]",
+        func: Callable[[V, V], V],
+        numPartitions: Optional[int] = None,
+        partitionFunc: Callable[[K], int] = portable_hash,
+    ) -> "RDD[Tuple[K, V]]":
         """
         Merge the values for each key using an associative and commutative reduce function.
 
@@ -1985,7 +2259,7 @@ class RDD:
         """
         return self.combineByKey(lambda x: x, func, func, numPartitions, partitionFunc)
 
-    def reduceByKeyLocally(self, func):
+    def reduceByKeyLocally(self: "RDD[Tuple[K, V]]", func: Callable[[V, V], V]) -> Dict[K, V]:
         """
         Merge the values for each key using an associative and commutative reduce function, but
         return the results immediately to the master as a dictionary.
@@ -2002,20 +2276,20 @@ class RDD:
         """
         func = fail_on_stopiteration(func)
 
-        def reducePartition(iterator):
-            m = {}
+        def reducePartition(iterator: Iterable[Tuple[K, V]]) -> Iterable[Dict[K, V]]:
+            m: Dict[K, V] = {}
             for k, v in iterator:
                 m[k] = func(m[k], v) if k in m else v
             yield m
 
-        def mergeMaps(m1, m2):
+        def mergeMaps(m1: Dict[K, V], m2: Dict[K, V]) -> Dict[K, V]:
             for k, v in m2.items():
                 m1[k] = func(m1[k], v) if k in m1 else v
             return m1
 
         return self.mapPartitions(reducePartition).reduce(mergeMaps)
 
-    def countByKey(self):
+    def countByKey(self: "RDD[Tuple[K, V]]") -> Dict[K, int]:
         """
         Count the number of elements for each key, and return the result to the
         master as a dictionary.
@@ -2028,7 +2302,11 @@ class RDD:
         """
         return self.map(lambda x: x[0]).countByValue()
 
-    def join(self, other, numPartitions=None):
+    def join(
+        self: "RDD[Tuple[K, V]]",
+        other: "RDD[Tuple[K, U]]",
+        numPartitions: Optional[int] = None,
+    ) -> "RDD[Tuple[K, Tuple[V, U]]]":
         """
         Return an RDD containing all pairs of elements with matching keys in
         `self` and `other`.
@@ -2047,7 +2325,11 @@ class RDD:
         """
         return python_join(self, other, numPartitions)
 
-    def leftOuterJoin(self, other, numPartitions=None):
+    def leftOuterJoin(
+        self: "RDD[Tuple[K, V]]",
+        other: "RDD[Tuple[K, U]]",
+        numPartitions: Optional[int] = None,
+    ) -> "RDD[Tuple[K, Tuple[V, Optional[U]]]]":
         """
         Perform a left outer join of `self` and `other`.
 
@@ -2066,7 +2348,11 @@ class RDD:
         """
         return python_left_outer_join(self, other, numPartitions)
 
-    def rightOuterJoin(self, other, numPartitions=None):
+    def rightOuterJoin(
+        self: "RDD[Tuple[K, V]]",
+        other: "RDD[Tuple[K, U]]",
+        numPartitions: Optional[int] = None,
+    ) -> "RDD[Tuple[K, Tuple[Optional[V], U]]]":
         """
         Perform a right outer join of `self` and `other`.
 
@@ -2085,7 +2371,11 @@ class RDD:
         """
         return python_right_outer_join(self, other, numPartitions)
 
-    def fullOuterJoin(self, other, numPartitions=None):
+    def fullOuterJoin(
+        self: "RDD[Tuple[K, V]]",
+        other: "RDD[Tuple[K, U]]",
+        numPartitions: Optional[int] = None,
+    ) -> "RDD[Tuple[K, Tuple[Optional[V], Optional[U]]]]":
         """
         Perform a full outer join of `self` and `other`.
 
@@ -2111,7 +2401,11 @@ class RDD:
     # TODO: add option to control map-side combining
     # portable_hash is used as default, because builtin hash of None is different
     # cross machines.
-    def partitionBy(self, numPartitions, partitionFunc=portable_hash):
+    def partitionBy(
+        self: "RDD[Tuple[K, V]]",
+        numPartitions: Optional[int],
+        partitionFunc: Callable[[K], int] = portable_hash,
+    ) -> "RDD[Tuple[K, V]]":
         """
         Return a copy of the RDD partitioned using the specified partitioner.
 
@@ -2138,13 +2432,13 @@ class RDD:
 
         limit = self._memory_limit() / 2
 
-        def add_shuffle_key(split, iterator):
+        def add_shuffle_key(split: int, iterator: Iterable[Tuple[K, V]]) -> Iterable[bytes]:
 
             buckets = defaultdict(list)
-            c, batch = 0, min(10 * numPartitions, 1000)
+            c, batch = 0, min(10 * numPartitions, 1000)  # type: ignore[operator]
 
             for k, v in iterator:
-                buckets[partitionFunc(k) % numPartitions].append((k, v))
+                buckets[partitionFunc(k) % numPartitions].append((k, v))  # type: ignore[operator]
                 c += 1
 
                 # check used memory and avg size of chunk of objects
@@ -2160,7 +2454,7 @@ class RDD:
                     avg = int(size / n) >> 20
                     # let 1M < avg < 10M
                     if avg < 1:
-                        batch = min(sys.maxsize, batch * 1.5)
+                        batch = min(sys.maxsize, batch * 1.5)  # type: ignore[assignment]
                     elif avg > 10:
                         batch = max(int(batch / 1.5), 1)
                     c = 0
@@ -2170,24 +2464,26 @@ class RDD:
                 yield outputSerializer.dumps(items)
 
         keyed = self.mapPartitionsWithIndex(add_shuffle_key, preservesPartitioning=True)
-        keyed._bypass_serializer = True
+        keyed._bypass_serializer = True  # type: ignore[attr-defined]
+        assert self.ctx._jvm is not None
+
         with SCCallSiteSync(self.context):
             pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
             jpartitioner = self.ctx._jvm.PythonPartitioner(numPartitions, id(partitionFunc))
         jrdd = self.ctx._jvm.PythonRDD.valueOfPair(pairRDD.partitionBy(jpartitioner))
-        rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
+        rdd: "RDD[Tuple[K, V]]" = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         rdd.partitioner = partitioner
         return rdd
 
     # TODO: add control over map-side aggregation
     def combineByKey(
-        self,
-        createCombiner,
-        mergeValue,
-        mergeCombiners,
-        numPartitions=None,
-        partitionFunc=portable_hash,
-    ):
+        self: "RDD[Tuple[K, V]]",
+        createCombiner: Callable[[V], U],
+        mergeValue: Callable[[U, V], U],
+        mergeCombiners: Callable[[U, U], U],
+        numPartitions: Optional[int] = None,
+        partitionFunc: Callable[[K], int] = portable_hash,
+    ) -> "RDD[Tuple[K, U]]":
         """
         Generic function to combine the elements for each key using a custom
         set of aggregation functions.
@@ -2238,7 +2534,7 @@ class RDD:
         memory = self._memory_limit()
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
 
-        def combineLocally(iterator):
+        def combineLocally(iterator: Iterable[Tuple[K, V]]) -> Iterable[Tuple[K, U]]:
             merger = ExternalMerger(agg, memory * 0.9, serializer)
             merger.mergeValues(iterator)
             return merger.items()
@@ -2246,7 +2542,7 @@ class RDD:
         locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True)
         shuffled = locally_combined.partitionBy(numPartitions, partitionFunc)
 
-        def _mergeCombiners(iterator):
+        def _mergeCombiners(iterator: Iterable[Tuple[K, U]]) -> Iterable[Tuple[K, U]]:
             merger = ExternalMerger(agg, memory, serializer)
             merger.mergeCombiners(iterator)
             return merger.items()
@@ -2254,8 +2550,13 @@ class RDD:
         return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True)
 
     def aggregateByKey(
-        self, zeroValue, seqFunc, combFunc, numPartitions=None, partitionFunc=portable_hash
-    ):
+        self: "RDD[Tuple[K, V]]",
+        zeroValue: U,
+        seqFunc: Callable[[U, V], U],
+        combFunc: Callable[[U, U], U],
+        numPartitions: Optional[int] = None,
+        partitionFunc: Callable[[K], int] = portable_hash,
+    ) -> "RDD[Tuple[K, U]]":
         """
         Aggregate the values of each key, using given combine functions and a neutral
         "zero value". This function can return a different result type, U, than the type
@@ -2266,14 +2567,20 @@ class RDD:
         allowed to modify and return their first argument instead of creating a new U.
         """
 
-        def createZero():
+        def createZero() -> U:
             return copy.deepcopy(zeroValue)
 
         return self.combineByKey(
             lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions, partitionFunc
         )
 
-    def foldByKey(self, zeroValue, func, numPartitions=None, partitionFunc=portable_hash):
+    def foldByKey(
+        self: "RDD[Tuple[K, V]]",
+        zeroValue: V,
+        func: Callable[[V, V], V],
+        numPartitions: Optional[int] = None,
+        partitionFunc: Callable[[K], int] = portable_hash,
+    ) -> "RDD[Tuple[K, V]]":
         """
         Merge the values for each key using an associative function "func"
         and a neutral "zeroValue" which may be added to the result an
@@ -2288,18 +2595,22 @@ class RDD:
         [('a', 2), ('b', 1)]
         """
 
-        def createZero():
+        def createZero() -> V:
             return copy.deepcopy(zeroValue)
 
         return self.combineByKey(
             lambda v: func(createZero(), v), func, func, numPartitions, partitionFunc
         )
 
-    def _memory_limit(self):
+    def _memory_limit(self) -> int:
         return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
 
     # TODO: support variant with custom partitioner
-    def groupByKey(self, numPartitions=None, partitionFunc=portable_hash):
+    def groupByKey(
+        self: "RDD[Tuple[K, V]]",
+        numPartitions: Optional[int] = None,
+        partitionFunc: Callable[[K], int] = portable_hash,
+    ) -> "RDD[Tuple[K, Iterable[V]]]":
         """
         Group the values for each key in the RDD into a single sequence.
         Hash-partitions the resulting RDD with numPartitions partitions.
@@ -2319,14 +2630,14 @@ class RDD:
         [('a', [1, 1]), ('b', [1])]
         """
 
-        def createCombiner(x):
+        def createCombiner(x: V) -> List[V]:
             return [x]
 
-        def mergeValue(xs, x):
+        def mergeValue(xs: List[V], x: V) -> List[V]:
             xs.append(x)
             return xs
 
-        def mergeCombiners(a, b):
+        def mergeCombiners(a: List[V], b: List[V]) -> List[V]:
             a.extend(b)
             return a
 
@@ -2334,7 +2645,7 @@ class RDD:
         serializer = self._jrdd_deserializer
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
 
-        def combine(iterator):
+        def combine(iterator: Iterable[Tuple[K, V]]) -> Iterable[Tuple[K, List[V]]]:
             merger = ExternalMerger(agg, memory * 0.9, serializer)
             merger.mergeValues(iterator)
             return merger.items()
@@ -2342,14 +2653,16 @@ class RDD:
         locally_combined = self.mapPartitions(combine, preservesPartitioning=True)
         shuffled = locally_combined.partitionBy(numPartitions, partitionFunc)
 
-        def groupByKey(it):
+        def groupByKey(it: Iterable[Tuple[K, List[V]]]) -> Iterable[Tuple[K, List[V]]]:
             merger = ExternalGroupBy(agg, memory, serializer)
             merger.mergeCombiners(it)
             return merger.items()
 
         return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable)
 
-    def flatMapValues(self, f):
+    def flatMapValues(
+        self: "RDD[Tuple[K, V]]", f: Callable[[V], Iterable[U]]
+    ) -> "RDD[Tuple[K, U]]":
         """
         Pass each value in the key-value pair RDD through a flatMap function
         without changing the keys; this also retains the original RDD's
@@ -2363,12 +2676,12 @@ class RDD:
         [('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')]
         """
 
-        def flat_map_fn(kv):
+        def flat_map_fn(kv: Tuple[K, V]) -> Iterable[Tuple[K, U]]:
             return ((kv[0], x) for x in f(kv[1]))
 
         return self.flatMap(flat_map_fn, preservesPartitioning=True)
 
-    def mapValues(self, f):
+    def mapValues(self: "RDD[Tuple[K, V]]", f: Callable[[V], U]) -> "RDD[Tuple[K, U]]":
         """
         Pass each value in the key-value pair RDD through a map function
         without changing the keys; this also retains the original RDD's
@@ -2382,12 +2695,45 @@ class RDD:
         [('a', 3), ('b', 1)]
         """
 
-        def map_values_fn(kv):
+        def map_values_fn(kv: Tuple[K, V]) -> Tuple[K, U]:
             return kv[0], f(kv[1])
 
         return self.map(map_values_fn, preservesPartitioning=True)
 
-    def groupWith(self, other, *others):
+    @overload
+    def groupWith(
+        self: "RDD[Tuple[K, V]]", other: "RDD[Tuple[K, V1]]"
+    ) -> "RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[V1]]]]":
+        ...
+
+    @overload
+    def groupWith(
+        self: "RDD[Tuple[K, V]]", other: "RDD[Tuple[K, V1]]", __o1: "RDD[Tuple[K, V2]]"
+    ) -> "RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[V1], ResultIterable[V2]]]]":
+        ...
+
+    @overload
+    def groupWith(
+        self: "RDD[Tuple[K, V]]",
+        other: "RDD[Tuple[K, V1]]",
+        _o1: "RDD[Tuple[K, V2]]",
+        _o2: "RDD[Tuple[K, V3]]",
+    ) -> """RDD[
+        Tuple[
+            K,
+            Tuple[
+                ResultIterable[V],
+                ResultIterable[V1],
+                ResultIterable[V2],
+                ResultIterable[V3],
+            ],
+        ]
+    ]""":
+        ...
+
+    def groupWith(  # type: ignore[misc]
+        self: "RDD[Tuple[Any, Any]]", other: "RDD[Tuple[Any, Any]]", *others: "RDD[Tuple[Any, Any]]"
+    ) -> "RDD[Tuple[Any, Tuple[ResultIterable[Any], ...]]]":
         """
         Alias for cogroup but with support for multiple RDDs.
 
@@ -2404,7 +2750,11 @@ class RDD:
         return python_cogroup((self, other) + others, numPartitions=None)
 
     # TODO: add variant with custom partitioner
-    def cogroup(self, other, numPartitions=None):
+    def cogroup(
+        self: "RDD[Tuple[K, V]]",
+        other: "RDD[Tuple[K, U]]",
+        numPartitions: Optional[int] = None,
+    ) -> "RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[U]]]]":
         """
         For each key k in `self` or `other`, return a resulting RDD that
         contains a tuple with the list of values for that key in `self` as
@@ -2419,7 +2769,12 @@ class RDD:
         """
         return python_cogroup((self, other), numPartitions)
 
-    def sampleByKey(self, withReplacement, fractions, seed=None):
+    def sampleByKey(
+        self: "RDD[Tuple[K, V]]",
+        withReplacement: bool,
+        fractions: Dict[K, Union[float, int]],
+        seed: Optional[int] = None,
+    ) -> "RDD[Tuple[K, V]]":
         """
         Return a subset of this RDD sampled by key (via stratified sampling).
         Create a sample of this RDD using variable sampling rates for
@@ -2443,7 +2798,11 @@ class RDD:
             RDDStratifiedSampler(withReplacement, fractions, seed).func, True
         )
 
-    def subtractByKey(self, other, numPartitions=None):
+    def subtractByKey(
+        self: "RDD[Tuple[K, V]]",
+        other: "RDD[Tuple[K, Any]]",
+        numPartitions: Optional[int] = None,
+    ) -> "RDD[Tuple[K, V]]":
         """
         Return each (key, value) pair in `self` that has no pair with matching
         key in `other`.
@@ -2456,13 +2815,17 @@ class RDD:
         [('b', 4), ('b', 5)]
         """
 
-        def filter_func(pair):
+        def filter_func(pair: Tuple[K, Tuple[V, Any]]) -> bool:
             key, (val1, val2) = pair
-            return val1 and not val2
+            return val1 and not val2  # type: ignore[return-value]
 
-        return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0])
+        return (
+            self.cogroup(other, numPartitions)
+            .filter(filter_func)  # type: ignore[arg-type]
+            .flatMapValues(lambda x: x[0])
+        )
 
-    def subtract(self, other, numPartitions=None):
+    def subtract(self: "RDD[T]", other: "RDD[T]", numPartitions: Optional[int] = None) -> "RDD[T]":
         """
         Return each value in `self` that is not contained in `other`.
 
@@ -2477,7 +2840,7 @@ class RDD:
         rdd = other.map(lambda x: (x, True))
         return self.map(lambda x: (x, True)).subtractByKey(rdd, numPartitions).keys()
 
-    def keyBy(self, f):
+    def keyBy(self: "RDD[T]", f: Callable[[T], K]) -> "RDD[Tuple[K, T]]":
         """
         Creates tuples of the elements in this RDD by applying `f`.
 
@@ -2490,7 +2853,7 @@ class RDD:
         """
         return self.map(lambda x: (f(x), x))
 
-    def repartition(self, numPartitions):
+    def repartition(self: "RDD[T]", numPartitions: int) -> "RDD[T]":
         """
          Return a new RDD that has exactly numPartitions partitions.
 
@@ -2511,7 +2874,7 @@ class RDD:
         """
         return self.coalesce(numPartitions, shuffle=True)
 
-    def coalesce(self, numPartitions, shuffle=False):
+    def coalesce(self: "RDD[T]", numPartitions: int, shuffle: bool = False) -> "RDD[T]":
         """
         Return a new RDD that is reduced into `numPartitions` partitions.
 
@@ -2535,7 +2898,7 @@ class RDD:
             jrdd = self._jrdd.coalesce(numPartitions, shuffle)
         return RDD(jrdd, self.ctx, jrdd_deserializer)
 
-    def zip(self, other):
+    def zip(self: "RDD[T]", other: "RDD[U]") -> "RDD[Tuple[T, U]]":
         """
         Zips this RDD with another one, returning key-value pairs with the
         first element in each RDD second element in each RDD, etc. Assumes
@@ -2551,12 +2914,12 @@ class RDD:
         [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
         """
 
-        def get_batch_size(ser):
+        def get_batch_size(ser: Serializer) -> int:
             if isinstance(ser, BatchedSerializer):
                 return ser.batchSize
             return 1  # not batched
 
-        def batch_as(rdd, batchSize):
+        def batch_as(rdd: "RDD[V]", batchSize: int) -> "RDD[V]":
             return rdd._reserialize(BatchedSerializer(CPickleSerializer(), batchSize))
 
         my_batch = get_batch_size(self._jrdd_deserializer)
@@ -2579,7 +2942,7 @@ class RDD:
         deserializer = PairDeserializer(self._jrdd_deserializer, other._jrdd_deserializer)
         return RDD(pairRDD, self.ctx, deserializer)
 
-    def zipWithIndex(self):
+    def zipWithIndex(self: "RDD[T]") -> "RDD[Tuple[T, int]]":
         """
         Zips this RDD with its element indices.
 
@@ -2602,13 +2965,13 @@ class RDD:
             for i in range(len(nums) - 1):
                 starts.append(starts[-1] + nums[i])
 
-        def func(k, it):
+        def func(k: int, it: Iterable[T]) -> Iterable[Tuple[T, int]]:
             for i, v in enumerate(it, starts[k]):
                 yield v, i
 
         return self.mapPartitionsWithIndex(func)
 
-    def zipWithUniqueId(self):
+    def zipWithUniqueId(self: "RDD[T]") -> "RDD[Tuple[T, int]]":
         """
         Zips this RDD with generated unique Long ids.
 
@@ -2624,21 +2987,20 @@ class RDD:
         """
         n = self.getNumPartitions()
 
-        def func(k, it):
+        def func(k: int, it: Iterable[T]) -> Iterable[Tuple[T, int]]:
             for i, v in enumerate(it):
                 yield v, i * n + k
 
         return self.mapPartitionsWithIndex(func)
 
-    def name(self):
+    def name(self) -> Optional[str]:
         """
         Return the name of this RDD.
         """
         n = self._jrdd.name()
-        if n:
-            return n
+        return n if n else None
 
-    def setName(self, name):
+    def setName(self: "RDD[T]", name: str) -> "RDD[T]":
         """
         Assign a name to this RDD.
 
@@ -2651,15 +3013,15 @@ class RDD:
         self._jrdd.setName(name)
         return self
 
-    def toDebugString(self):
+    def toDebugString(self) -> Optional[bytes]:
         """
         A description of this RDD and its recursive dependencies for debugging.
         """
         debug_string = self._jrdd.toDebugString()
-        if debug_string:
-            return debug_string.encode("utf-8")
 
-    def getStorageLevel(self):
+        return debug_string.encode("utf-8") if debug_string else None
+
+    def getStorageLevel(self) -> StorageLevel:
         """
         Get the RDD's current storage level.
 
@@ -2681,7 +3043,7 @@ class RDD:
         )
         return storage_level
 
-    def _defaultReducePartitions(self):
+    def _defaultReducePartitions(self) -> int:
         """
         Returns the default number of partitions to use during reduce tasks (e.g., groupBy).
         If spark.default.parallelism is set, then we'll use the value from SparkContext
@@ -2696,7 +3058,7 @@ class RDD:
         else:
             return self.getNumPartitions()
 
-    def lookup(self, key):
+    def lookup(self: "RDD[Tuple[K, V]]", key: K) -> List[V]:
         """
         Return the list of values in the RDD for key `key`. This operation
         is done efficiently if the RDD has a known partitioner by only
@@ -2724,16 +3086,18 @@ class RDD:
 
         return values.collect()
 
-    def _to_java_object_rdd(self):
+    def _to_java_object_rdd(self) -> "JavaObject":
         """Return a JavaRDD of Object by unpickling
 
         It will convert each Python object into Java object by Pickle, 
whenever the
         RDD is serialized in batch or not.
         """
         rdd = self._pickled()
+        assert self.ctx._jvm is not None
+
         return self.ctx._jvm.SerDeUtil.pythonToJava(rdd._jrdd, True)
 
-    def countApprox(self, timeout, confidence=0.95):
+    def countApprox(self, timeout: int, confidence: float = 0.95) -> int:
         """
         Approximate version of count() that returns a potentially incomplete
         result within a timeout, even if not all tasks have finished.
@@ -2747,7 +3111,9 @@ class RDD:
         drdd = self.mapPartitions(lambda it: [float(sum(1 for i in it))])
         return int(drdd.sumApprox(timeout, confidence))
 
-    def sumApprox(self, timeout, confidence=0.95):
+    def sumApprox(
+        self: "RDD[Union[float, int]]", timeout: int, confidence: float = 0.95
+    ) -> BoundedFloat:
         """
         Approximate operation to return the sum within a timeout
         or meet the confidence.
@@ -2760,11 +3126,14 @@ class RDD:
         True
         """
         jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd()
+        assert self.ctx._jvm is not None
         jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd())
         r = jdrdd.sumApprox(timeout, confidence).getFinalValue()
         return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high())
 
-    def meanApprox(self, timeout, confidence=0.95):
+    def meanApprox(
+        self: "RDD[Union[float, int]]", timeout: int, confidence: float = 0.95
+    ) -> BoundedFloat:
         """
         Approximate operation to return the mean within a timeout
         or meet the confidence.
@@ -2777,11 +3146,12 @@ class RDD:
         True
         """
         jrdd = self.map(float)._to_java_object_rdd()
+        assert self.ctx._jvm is not None
         jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd())
         r = jdrdd.meanApprox(timeout, confidence).getFinalValue()
         return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high())
 
-    def countApproxDistinct(self, relativeSD=0.05):
+    def countApproxDistinct(self: "RDD[T]", relativeSD: float = 0.05) -> int:
         """
         Return approximate number of distinct elements in the RDD.
 
@@ -2814,7 +3184,7 @@ class RDD:
         hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF)
         return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD)
 
-    def toLocalIterator(self, prefetchPartitions=False):
+    def toLocalIterator(self: "RDD[T]", prefetchPartitions: bool = False) -> Iterator[T]:
         """
         Return an iterator that contains all of the elements in this RDD.
         The iterator will consume as much memory as the largest partition in this RDD.
@@ -2832,13 +3202,15 @@ class RDD:
         >>> [x for x in rdd.toLocalIterator()]
         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
         """
+        assert self.ctx._jvm is not None
+
         with SCCallSiteSync(self.context):
             sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(
                 self._jrdd.rdd(), prefetchPartitions
             )
         return _local_iterator_from_socket(sock_info, self._jrdd_deserializer)
 
-    def barrier(self):
+    def barrier(self: "RDD[T]") -> "RDDBarrier[T]":
         """
         Marks the current stage as a barrier stage, where Spark must launch all tasks together.
         In case of a task failure, instead of only restarting the failed task, Spark will abort the
@@ -2868,13 +3240,13 @@ class RDD:
         """
         return RDDBarrier(self)
 
-    def _is_barrier(self):
+    def _is_barrier(self) -> bool:
         """
         Whether this RDD is in a barrier stage.
         """
         return self._jrdd.rdd().isBarrier()
 
-    def withResources(self, profile):
+    def withResources(self: "RDD[T]", profile: ResourceProfile) -> "RDD[T]":
         """
         Specify a :class:`pyspark.resource.ResourceProfile` to use when calculating this RDD.
         This is only supported on certain cluster managers and currently requires dynamic
@@ -2891,6 +3263,8 @@ class RDD:
         if profile._java_resource_profile is not None:
             jrp = profile._java_resource_profile
         else:
+            assert self.ctx._jvm is not None
+
             builder = self.ctx._jvm.org.apache.spark.resource.ResourceProfileBuilder()
             ereqs = ExecutorResourceRequests(self.ctx._jvm, profile._executor_resource_requests)
             treqs = TaskResourceRequests(self.ctx._jvm, profile._task_resource_requests)
@@ -2901,7 +3275,7 @@ class RDD:
         self._jrdd.withResources(jrp)
         return self
 
-    def getResourceProfile(self):
+    def getResourceProfile(self) -> Optional[ResourceProfile]:
         """
         Get the :class:`pyspark.resource.ResourceProfile` specified with this RDD or None
         if it wasn't specified.
@@ -2923,11 +3297,38 @@ class RDD:
         else:
             return None
 
+    @overload
+    def toDF(
+        self: "RDD[RowLike]",
+        schema: Optional[Union[List[str], Tuple[str, ...]]] = None,
+        sampleRatio: Optional[float] = None,
+    ) -> "DataFrame":
+        ...
+
+    @overload
+    def toDF(
+        self: "RDD[RowLike]", schema: Optional[Union["StructType", str]] = None
+    ) -> "DataFrame":
+        ...
+
+    @overload
+    def toDF(
+        self: "RDD[AtomicValue]",
+        schema: Union["AtomicType", str],
+    ) -> "DataFrame":
+        ...
+
+    def toDF(
+        self: "RDD[Any]", schema: Optional[Any] = None, sampleRatio: Optional[float] = None
+    ) -> "DataFrame":
+        raise RuntimeError("""RDD.toDF was called before SparkSession was initialized.""")
 
-def _prepare_for_python_RDD(sc, command):
+
+def _prepare_for_python_RDD(sc: "SparkContext", command: Any) -> Tuple[bytes, Any, Any, Any]:
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
     pickled_command = ser.dumps(command)
+    assert sc._jvm is not None
     if len(pickled_command) > sc._jvm.PythonUtils.getBroadcastThreshold(sc._jsc):  # Default 1M
         # The broadcast will have same life cycle as created PythonRDD
         broadcast = sc.broadcast(pickled_command)
@@ -2937,11 +3338,14 @@ def _prepare_for_python_RDD(sc, command):
     return pickled_command, broadcast_vars, sc.environment, sc._python_includes
 
 
-def _wrap_function(sc, func, deserializer, serializer, profiler=None):
+def _wrap_function(
+    sc: "SparkContext", func: Callable, deserializer: Any, serializer: Any, profiler: Any = None
+) -> "JavaObject":
     assert deserializer, "deserializer should not be empty"
     assert serializer, "serializer should not be empty"
     command = (func, profiler, deserializer, serializer)
     pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+    assert sc._jvm is not None
     return sc._jvm.PythonFunction(
         bytearray(pickled_command),
         env,
@@ -2953,7 +3357,7 @@ def _wrap_function(sc, func, deserializer, serializer, profiler=None):
     )
 
 
-class RDDBarrier:
+class RDDBarrier(Generic[T]):
 
     """
     Wraps an RDD in a barrier stage, which forces Spark to launch tasks of this stage together.
@@ -2966,10 +3370,12 @@ class RDDBarrier:
     This API is experimental
     """
 
-    def __init__(self, rdd):
+    def __init__(self, rdd: RDD[T]):
         self.rdd = rdd
 
-    def mapPartitions(self, f, preservesPartitioning=False):
+    def mapPartitions(
+        self, f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning: bool = False
+    ) -> RDD[U]:
         """
         Returns a new RDD by applying a function to each partition of the wrapped RDD,
         where tasks are launched together in a barrier stage.
@@ -2983,12 +3389,16 @@ class RDDBarrier:
         This API is experimental
         """
 
-        def func(s, iterator):
+        def func(s: int, iterator: Iterable[T]) -> Iterable[U]:
             return f(iterator)
 
         return PipelinedRDD(self.rdd, func, preservesPartitioning, isFromBarrier=True)
 
-    def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+    def mapPartitionsWithIndex(
+        self,
+        f: Callable[[int, Iterable[T]], Iterable[U]],
+        preservesPartitioning: bool = False,
+    ) -> RDD[U]:
         """
         Returns a new RDD by applying a function to each partition of the wrapped RDD, while
         tracking the index of the original partition. And all tasks are launched together
@@ -3005,7 +3415,7 @@ class RDDBarrier:
         return PipelinedRDD(self.rdd, f, preservesPartitioning, isFromBarrier=True)
 
 
-class PipelinedRDD(RDD):
+class PipelinedRDD(RDD[U], Generic[T, U]):
 
     """
     Examples
@@ -3027,7 +3437,13 @@ class PipelinedRDD(RDD):
     20
     """
 
-    def __init__(self, prev, func, preservesPartitioning=False, isFromBarrier=False):
+    def __init__(
+        self,
+        prev: RDD[T],
+        func: Callable[[int, Iterable[T]], Iterable[U]],
+        preservesPartitioning: bool = False,
+        isFromBarrier: bool = False,
+    ):
         if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
             # This transformation is the first in its stage:
             self.func = func
@@ -3035,9 +3451,9 @@ class PipelinedRDD(RDD):
             self._prev_jrdd = prev._jrdd
             self._prev_jrdd_deserializer = prev._jrdd_deserializer
         else:
-            prev_func = prev.func
+            prev_func: Callable[[int, Iterable[V]], Iterable[T]] = prev.func
 
-            def pipeline_func(split, iterator):
+            def pipeline_func(split: int, iterator: Iterable[V]) -> Iterable[U]:
                 return func(split, prev_func(split, iterator))
 
             self.func = pipeline_func
@@ -3049,18 +3465,18 @@ class PipelinedRDD(RDD):
         self.is_checkpointed = False
         self.ctx = prev.ctx
         self.prev = prev
-        self._jrdd_val = None
+        self._jrdd_val: Optional["JavaObject"] = None
         self._id = None
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
         self.partitioner = prev.partitioner if self.preservesPartitioning else None
         self.is_barrier = isFromBarrier or prev._is_barrier()
 
-    def getNumPartitions(self):
+    def getNumPartitions(self) -> int:
         return self._prev_jrdd.partitions().size()
 
     @property
-    def _jrdd(self):
+    def _jrdd(self) -> "JavaObject":
         if self._jrdd_val:
             return self._jrdd_val
         if self._bypass_serializer:
@@ -3074,29 +3490,32 @@ class PipelinedRDD(RDD):
         wrapped_func = _wrap_function(
             self.ctx, self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer, profiler
         )
+
+        assert self.ctx._jvm is not None
         python_rdd = self.ctx._jvm.PythonRDD(
             self._prev_jrdd.rdd(), wrapped_func, self.preservesPartitioning, self.is_barrier
         )
         self._jrdd_val = python_rdd.asJavaRDD()
 
         if profiler:
+            assert self._jrdd_val is not None
             self._id = self._jrdd_val.id()
             self.ctx.profiler_collector.add_profiler(self._id, profiler)
         return self._jrdd_val
 
-    def id(self):
+    def id(self) -> int:
         if self._id is None:
             self._id = self._jrdd.id()
         return self._id
 
-    def _is_pipelinable(self):
+    def _is_pipelinable(self) -> bool:
         return not (self.is_cached or self.is_checkpointed or self.has_resource_profile)
 
-    def _is_barrier(self):
+    def _is_barrier(self) -> bool:
         return self.is_barrier
 
 
-def _test():
+def _test() -> None:
     import doctest
     from pyspark.context import SparkContext
 
diff --git a/python/pyspark/rdd.pyi b/python/pyspark/rdd.pyi
deleted file mode 100644
index c4eddbf..0000000
--- a/python/pyspark/rdd.pyi
+++ /dev/null
@@ -1,481 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Generic,
-    Hashable,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Tuple,
-    Union,
-    TypeVar,
-)
-from typing_extensions import Literal
-
-from numpy import int32, int64, float32, float64, ndarray
-
-from pyspark._typing import SupportsOrdering
-from pyspark.sql.pandas._typing import (
-    PandasScalarUDFType,
-    PandasScalarIterUDFType,
-    PandasGroupedMapUDFType,
-    PandasCogroupedMapUDFType,
-    PandasGroupedAggUDFType,
-    PandasMapIterUDFType,
-    ArrowMapIterUDFType,
-)
-import pyspark.context
-from pyspark.resultiterable import ResultIterable
-from pyspark.serializers import Serializer
-from pyspark.storagelevel import StorageLevel
-from pyspark.resource.requests import (  # noqa: F401
-    ExecutorResourceRequests,
-    TaskResourceRequests,
-)
-from pyspark.resource.profile import ResourceProfile
-from pyspark.statcounter import StatCounter
-from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.types import AtomicType, StructType
-from pyspark.sql._typing import AtomicValue, RowLike
-from py4j.java_gateway import JavaObject  # type: ignore[import]
-
-T = TypeVar("T")
-T_co = TypeVar("T_co", covariant=True)
-U = TypeVar("U")
-K = TypeVar("K", bound=Hashable)
-V = TypeVar("V")
-V1 = TypeVar("V1")
-V2 = TypeVar("V2")
-V3 = TypeVar("V3")
-O = TypeVar("O", bound=SupportsOrdering)
-NumberOrArray = TypeVar(
-    "NumberOrArray", float, int, complex, int32, int64, float32, float64, ndarray
-)
-
-def portable_hash(x: Hashable) -> int: ...
-
-class PythonEvalType:
-    NON_UDF: Literal[0]
-    SQL_BATCHED_UDF: Literal[100]
-    SQL_SCALAR_PANDAS_UDF: PandasScalarUDFType
-    SQL_GROUPED_MAP_PANDAS_UDF: PandasGroupedMapUDFType
-    SQL_GROUPED_AGG_PANDAS_UDF: PandasGroupedAggUDFType
-    SQL_WINDOW_AGG_PANDAS_UDF: Literal[203]
-    SQL_SCALAR_PANDAS_ITER_UDF: PandasScalarIterUDFType
-    SQL_MAP_PANDAS_ITER_UDF: PandasMapIterUDFType
-    SQL_COGROUPED_MAP_PANDAS_UDF: PandasCogroupedMapUDFType
-    SQL_MAP_ARROW_ITER_UDF: ArrowMapIterUDFType
-
-class BoundedFloat(float):
-    def __new__(cls, mean: float, confidence: float, low: float, high: float) -> BoundedFloat: ...
-
-class Partitioner:
-    numPartitions: int
-    partitionFunc: Callable[[Any], int]
-    def __init__(self, numPartitions: int, partitionFunc: Callable[[Any], int]) -> None: ...
-    def __eq__(self, other: Any) -> bool: ...
-    def __call__(self, k: Any) -> int: ...
-
-class RDD(Generic[T_co]):
-    is_cached: bool
-    is_checkpointed: bool
-    ctx: pyspark.context.SparkContext
-    partitioner: Optional[Partitioner]
-    def __init__(
-        self,
-        jrdd: JavaObject,
-        ctx: pyspark.context.SparkContext,
-        jrdd_deserializer: Serializer = ...,
-    ) -> None: ...
-    def id(self) -> int: ...
-    def __getnewargs__(self) -> Any: ...
-    @property
-    def context(self) -> pyspark.context.SparkContext: ...
-    def cache(self) -> RDD[T_co]: ...
-    def persist(self, storageLevel: StorageLevel = ...) -> RDD[T_co]: ...
-    def unpersist(self, blocking: bool = ...) -> RDD[T_co]: ...
-    def checkpoint(self) -> None: ...
-    def isCheckpointed(self) -> bool: ...
-    def localCheckpoint(self) -> None: ...
-    def isLocallyCheckpointed(self) -> bool: ...
-    def getCheckpointFile(self) -> Optional[str]: ...
-    def map(self, f: Callable[[T_co], U], preservesPartitioning: bool = ...) -> RDD[U]: ...
-    def flatMap(
-        self, f: Callable[[T_co], Iterable[U]], preservesPartitioning: bool = ...
-    ) -> RDD[U]: ...
-    def mapPartitions(
-        self, f: Callable[[Iterable[T_co]], Iterable[U]], preservesPartitioning: bool = ...
-    ) -> RDD[U]: ...
-    def mapPartitionsWithIndex(
-        self,
-        f: Callable[[int, Iterable[T_co]], Iterable[U]],
-        preservesPartitioning: bool = ...,
-    ) -> RDD[U]: ...
-    def mapPartitionsWithSplit(
-        self,
-        f: Callable[[int, Iterable[T_co]], Iterable[U]],
-        preservesPartitioning: bool = ...,
-    ) -> RDD[U]: ...
-    def getNumPartitions(self) -> int: ...
-    def filter(self, f: Callable[[T_co], bool]) -> RDD[T_co]: ...
-    def distinct(self, numPartitions: Optional[int] = ...) -> RDD[T_co]: ...
-    def sample(
-        self, withReplacement: bool, fraction: float, seed: Optional[int] = ...
-    ) -> RDD[T_co]: ...
-    def randomSplit(
-        self, weights: List[Union[int, float]], seed: Optional[int] = ...
-    ) -> List[RDD[T_co]]: ...
-    def takeSample(
-        self, withReplacement: bool, num: int, seed: Optional[int] = ...
-    ) -> List[T_co]: ...
-    def union(self, other: RDD[U]) -> RDD[Union[T_co, U]]: ...
-    def intersection(self, other: RDD[T_co]) -> RDD[T_co]: ...
-    def __add__(self, other: RDD[T_co]) -> RDD[T_co]: ...
-    @overload
-    def repartitionAndSortWithinPartitions(
-        self: RDD[Tuple[O, V]],
-        numPartitions: Optional[int] = ...,
-        partitionFunc: Callable[[O], int] = ...,
-        ascending: bool = ...,
-    ) -> RDD[Tuple[O, V]]: ...
-    @overload
-    def repartitionAndSortWithinPartitions(
-        self: RDD[Tuple[K, V]],
-        numPartitions: Optional[int],
-        partitionFunc: Callable[[K], int],
-        ascending: bool,
-        keyfunc: Callable[[K], O],
-    ) -> RDD[Tuple[K, V]]: ...
-    @overload
-    def repartitionAndSortWithinPartitions(
-        self: RDD[Tuple[K, V]],
-        numPartitions: Optional[int] = ...,
-        partitionFunc: Callable[[K], int] = ...,
-        ascending: bool = ...,
-        *,
-        keyfunc: Callable[[K], O],
-    ) -> RDD[Tuple[K, V]]: ...
-    @overload
-    def sortByKey(
-        self: RDD[Tuple[O, V]],
-        ascending: bool = ...,
-        numPartitions: Optional[int] = ...,
-    ) -> RDD[Tuple[K, V]]: ...
-    @overload
-    def sortByKey(
-        self: RDD[Tuple[K, V]],
-        ascending: bool,
-        numPartitions: int,
-        keyfunc: Callable[[K], O],
-    ) -> RDD[Tuple[K, V]]: ...
-    @overload
-    def sortByKey(
-        self: RDD[Tuple[K, V]],
-        ascending: bool = ...,
-        numPartitions: Optional[int] = ...,
-        *,
-        keyfunc: Callable[[K], O],
-    ) -> RDD[Tuple[K, V]]: ...
-    def sortBy(
-        self,
-        keyfunc: Callable[[T_co], O],
-        ascending: bool = ...,
-        numPartitions: Optional[int] = ...,
-    ) -> RDD[T_co]: ...
-    def glom(self) -> RDD[List[T_co]]: ...
-    def cartesian(self, other: RDD[U]) -> RDD[Tuple[T_co, U]]: ...
-    def groupBy(
-        self,
-        f: Callable[[T_co], K],
-        numPartitions: Optional[int] = ...,
-        partitionFunc: Callable[[K], int] = ...,
-    ) -> RDD[Tuple[K, Iterable[T_co]]]: ...
-    def pipe(
-        self, command: str, env: Optional[Dict[str, str]] = ..., checkCode: 
bool = ...
-    ) -> RDD[str]: ...
-    def foreach(self, f: Callable[[T_co], None]) -> None: ...
-    def foreachPartition(self, f: Callable[[Iterable[T_co]], None]) -> None: ...
-    def collect(self) -> List[T_co]: ...
-    def collectWithJobGroup(
-        self, groupId: str, description: str, interruptOnCancel: bool = ...
-    ) -> List[T_co]: ...
-    def reduce(self, f: Callable[[T_co, T_co], T_co]) -> T_co: ...
-    def treeReduce(self, f: Callable[[T_co, T_co], T_co], depth: int = ...) -> T_co: ...
-    def fold(self, zeroValue: T, op: Callable[[T_co, T_co], T_co]) -> T_co: ...
-    def aggregate(
-        self, zeroValue: U, seqOp: Callable[[U, T_co], U], combOp: Callable[[U, U], U]
-    ) -> U: ...
-    def treeAggregate(
-        self,
-        zeroValue: U,
-        seqOp: Callable[[U, T_co], U],
-        combOp: Callable[[U, U], U],
-        depth: int = ...,
-    ) -> U: ...
-    @overload
-    def max(self: RDD[O]) -> O: ...
-    @overload
-    def max(self, key: Callable[[T_co], O]) -> T_co: ...
-    @overload
-    def min(self: RDD[O]) -> O: ...
-    @overload
-    def min(self, key: Callable[[T_co], O]) -> T_co: ...
-    def sum(self: RDD[NumberOrArray]) -> NumberOrArray: ...
-    def count(self) -> int: ...
-    def stats(self: RDD[NumberOrArray]) -> StatCounter: ...
-    def histogram(
-        self, buckets: Union[int, List[T_co], Tuple[T_co, ...]]
-    ) -> Tuple[List[T_co], List[int]]: ...
-    def mean(self: RDD[NumberOrArray]) -> NumberOrArray: ...
-    def variance(self: RDD[NumberOrArray]) -> NumberOrArray: ...
-    def stdev(self: RDD[NumberOrArray]) -> NumberOrArray: ...
-    def sampleStdev(self: RDD[NumberOrArray]) -> NumberOrArray: ...
-    def sampleVariance(self: RDD[NumberOrArray]) -> NumberOrArray: ...
-    def countByValue(self: RDD[K]) -> Dict[K, int]: ...
-    @overload
-    def top(self: RDD[O], num: int) -> List[O]: ...
-    @overload
-    def top(self, num: int, key: Callable[[T_co], O]) -> List[T_co]: ...
-    @overload
-    def takeOrdered(self: RDD[O], num: int) -> List[O]: ...
-    @overload
-    def takeOrdered(self, num: int, key: Callable[[T_co], O]) -> List[T_co]: ...
-    def take(self, num: int) -> List[T_co]: ...
-    def first(self) -> T_co: ...
-    def isEmpty(self) -> bool: ...
-    def saveAsNewAPIHadoopDataset(
-        self: RDD[Tuple[K, V]],
-        conf: Dict[str, str],
-        keyConverter: Optional[str] = ...,
-        valueConverter: Optional[str] = ...,
-    ) -> None: ...
-    def saveAsNewAPIHadoopFile(
-        self: RDD[Tuple[K, V]],
-        path: str,
-        outputFormatClass: str,
-        keyClass: Optional[str] = ...,
-        valueClass: Optional[str] = ...,
-        keyConverter: Optional[str] = ...,
-        valueConverter: Optional[str] = ...,
-        conf: Optional[Dict[str, str]] = ...,
-    ) -> None: ...
-    def saveAsHadoopDataset(
-        self: RDD[Tuple[K, V]],
-        conf: Dict[str, str],
-        keyConverter: Optional[str] = ...,
-        valueConverter: Optional[str] = ...,
-    ) -> None: ...
-    def saveAsHadoopFile(
-        self: RDD[Tuple[K, V]],
-        path: str,
-        outputFormatClass: str,
-        keyClass: Optional[str] = ...,
-        valueClass: Optional[str] = ...,
-        keyConverter: Optional[str] = ...,
-        valueConverter: Optional[str] = ...,
-        conf: Optional[str] = ...,
-        compressionCodecClass: Optional[str] = ...,
-    ) -> None: ...
-    def saveAsSequenceFile(
-        self: RDD[Tuple[K, V]], path: str, compressionCodecClass: Optional[str] = ...
-    ) -> None: ...
-    def saveAsPickleFile(self, path: str, batchSize: int = ...) -> None: ...
-    def saveAsTextFile(self, path: str, compressionCodecClass: Optional[str] = ...) -> None: ...
-    def collectAsMap(self: RDD[Tuple[K, V]]) -> Dict[K, V]: ...
-    def keys(self: RDD[Tuple[K, V]]) -> RDD[K]: ...
-    def values(self: RDD[Tuple[K, V]]) -> RDD[V]: ...
-    def reduceByKey(
-        self: RDD[Tuple[K, V]],
-        func: Callable[[V, V], V],
-        numPartitions: Optional[int] = ...,
-        partitionFunc: Callable[[K], int] = ...,
-    ) -> RDD[Tuple[K, V]]: ...
-    def reduceByKeyLocally(self: RDD[Tuple[K, V]], func: Callable[[V, V], V]) -> Dict[K, V]: ...
-    def countByKey(self: RDD[Tuple[K, V]]) -> Dict[K, int]: ...
-    def join(
-        self: RDD[Tuple[K, V]],
-        other: RDD[Tuple[K, U]],
-        numPartitions: Optional[int] = ...,
-    ) -> RDD[Tuple[K, Tuple[V, U]]]: ...
-    def leftOuterJoin(
-        self: RDD[Tuple[K, V]],
-        other: RDD[Tuple[K, U]],
-        numPartitions: Optional[int] = ...,
-    ) -> RDD[Tuple[K, Tuple[V, Optional[U]]]]: ...
-    def rightOuterJoin(
-        self: RDD[Tuple[K, V]],
-        other: RDD[Tuple[K, U]],
-        numPartitions: Optional[int] = ...,
-    ) -> RDD[Tuple[K, Tuple[Optional[V], U]]]: ...
-    def fullOuterJoin(
-        self: RDD[Tuple[K, V]],
-        other: RDD[Tuple[K, U]],
-        numPartitions: Optional[int] = ...,
-    ) -> RDD[Tuple[K, Tuple[Optional[V], Optional[U]]]]: ...
-    def partitionBy(
-        self: RDD[Tuple[K, V]],
-        numPartitions: int,
-        partitionFunc: Callable[[K], int] = ...,
-    ) -> RDD[Tuple[K, V]]: ...
-    def combineByKey(
-        self: RDD[Tuple[K, V]],
-        createCombiner: Callable[[V], U],
-        mergeValue: Callable[[U, V], U],
-        mergeCombiners: Callable[[U, U], U],
-        numPartitions: Optional[int] = ...,
-        partitionFunc: Callable[[K], int] = ...,
-    ) -> RDD[Tuple[K, U]]: ...
-    def aggregateByKey(
-        self: RDD[Tuple[K, V]],
-        zeroValue: U,
-        seqFunc: Callable[[U, V], U],
-        combFunc: Callable[[U, U], U],
-        numPartitions: Optional[int] = ...,
-        partitionFunc: Callable[[K], int] = ...,
-    ) -> RDD[Tuple[K, U]]: ...
-    def foldByKey(
-        self: RDD[Tuple[K, V]],
-        zeroValue: V,
-        func: Callable[[V, V], V],
-        numPartitions: Optional[int] = ...,
-        partitionFunc: Callable[[K], int] = ...,
-    ) -> RDD[Tuple[K, V]]: ...
-    def groupByKey(
-        self: RDD[Tuple[K, V]],
-        numPartitions: Optional[int] = ...,
-        partitionFunc: Callable[[K], int] = ...,
-    ) -> RDD[Tuple[K, Iterable[V]]]: ...
-    def flatMapValues(
-        self: RDD[Tuple[K, V]], f: Callable[[V], Iterable[U]]
-    ) -> RDD[Tuple[K, U]]: ...
-    def mapValues(self: RDD[Tuple[K, V]], f: Callable[[V], U]) -> RDD[Tuple[K, U]]: ...
-    @overload
-    def groupWith(
-        self: RDD[Tuple[K, V]], __o: RDD[Tuple[K, V1]]
-    ) -> RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[V1]]]]: ...
-    @overload
-    def groupWith(
-        self: RDD[Tuple[K, V]], __o1: RDD[Tuple[K, V1]], __o2: RDD[Tuple[K, V2]]
-    ) -> RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[V1], ResultIterable[V2]]]]: ...
-    @overload
-    def groupWith(
-        self: RDD[Tuple[K, V]],
-        other1: RDD[Tuple[K, V1]],
-        other2: RDD[Tuple[K, V2]],
-        other3: RDD[Tuple[K, V3]],
-    ) -> RDD[
-        Tuple[
-            K,
-            Tuple[
-                ResultIterable[V],
-                ResultIterable[V1],
-                ResultIterable[V2],
-                ResultIterable[V3],
-            ],
-        ]
-    ]: ...
-    def cogroup(
-        self: RDD[Tuple[K, V]],
-        other: RDD[Tuple[K, U]],
-        numPartitions: Optional[int] = ...,
-    ) -> RDD[Tuple[K, Tuple[ResultIterable[V], ResultIterable[U]]]]: ...
-    def sampleByKey(
-        self: RDD[Tuple[K, V]],
-        withReplacement: bool,
-        fractions: Dict[K, Union[float, int]],
-        seed: Optional[int] = ...,
-    ) -> RDD[Tuple[K, V]]: ...
-    def subtractByKey(
-        self: RDD[Tuple[K, V]],
-        other: RDD[Tuple[K, U]],
-        numPartitions: Optional[int] = ...,
-    ) -> RDD[Tuple[K, V]]: ...
-    def subtract(self, other: RDD[T_co], numPartitions: Optional[int] = ...) -> RDD[T_co]: ...
-    def keyBy(self, f: Callable[[T_co], K]) -> RDD[Tuple[K, T_co]]: ...
-    def repartition(self, numPartitions: int) -> RDD[T_co]: ...
-    def coalesce(self, numPartitions: int, shuffle: bool = ...) -> RDD[T_co]: ...
-    def zip(self, other: RDD[U]) -> RDD[Tuple[T_co, U]]: ...
-    def zipWithIndex(self) -> RDD[Tuple[T_co, int]]: ...
-    def zipWithUniqueId(self) -> RDD[Tuple[T_co, int]]: ...
-    def name(self) -> str: ...
-    def setName(self, name: str) -> RDD[T_co]: ...
-    def toDebugString(self) -> bytes: ...
-    def getStorageLevel(self) -> StorageLevel: ...
-    def lookup(self: RDD[Tuple[K, V]], key: K) -> List[V]: ...
-    def countApprox(self, timeout: int, confidence: float = ...) -> int: ...
-    def sumApprox(
-        self: RDD[Union[float, int]], timeout: int, confidence: float = ...
-    ) -> BoundedFloat: ...
-    def meanApprox(
-        self: RDD[Union[float, int]], timeout: int, confidence: float = ...
-    ) -> BoundedFloat: ...
-    def countApproxDistinct(self, relativeSD: float = ...) -> int: ...
-    def toLocalIterator(self, prefetchPartitions: bool = ...) -> Iterator[T_co]: ...
-    def barrier(self) -> RDDBarrier[T_co]: ...
-    def withResources(self, profile: ResourceProfile) -> RDD[T_co]: ...
-    def getResourceProfile(self) -> Optional[ResourceProfile]: ...
-    @overload
-    def toDF(
-        self: RDD[RowLike],
-        schema: Optional[Union[List[str], Tuple[str, ...]]] = ...,
-        sampleRatio: Optional[float] = ...,
-    ) -> DataFrame: ...
-    @overload
-    def toDF(self: RDD[RowLike], schema: Optional[Union[StructType, str]] = ...) -> DataFrame: ...
-    @overload
-    def toDF(
-        self: RDD[AtomicValue],
-        schema: Union[AtomicType, str],
-    ) -> DataFrame: ...
-
-class RDDBarrier(Generic[T]):
-    rdd: RDD[T]
-    def __init__(self, rdd: RDD[T]) -> None: ...
-    def mapPartitions(
-        self, f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning: bool = ...
-    ) -> RDD[U]: ...
-    def mapPartitionsWithIndex(
-        self,
-        f: Callable[[int, Iterable[T]], Iterable[U]],
-        preservesPartitioning: bool = ...,
-    ) -> RDD[U]: ...
-
-class PipelinedRDD(RDD[U], Generic[T, U]):
-    func: Callable[[T], U]
-    preservesPartitioning: bool
-    is_cached: bool
-    is_checkpointed: bool
-    ctx: pyspark.context.SparkContext
-    prev: RDD[T]
-    partitioner: Optional[Partitioner]
-    is_barrier: bool
-    def __init__(
-        self,
-        prev: RDD[T],
-        func: Callable[[Iterable[T]], Iterable[U]],
-        preservesPartitioning: bool = ...,
-        isFromBarrier: bool = ...,
-    ) -> None: ...
-    def getNumPartitions(self) -> int: ...
-    def id(self) -> int: ...
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index a0941afd..628ef18 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -100,6 +100,13 @@ class Serializer:
         """
         raise NotImplementedError
 
+    def dumps(self, obj):
+        """
+        Serialize an object into a byte array.
+        When batching is used, this will be called with an array of objects.
+        """
+        raise NotImplementedError
+
     def _load_stream_without_unbatching(self, stream):
         """
         Return an iterator of deserialized batches (iterable) of objects from the input stream.
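
[Editor's aside, not part of the patch: the new Serializer.dumps hook above serializes one object (or, with batching, a list of objects) into bytes. A hedged sketch of a subclass overriding it; JsonSerializer is a made-up name used only to illustrate the hook and is not a complete serializer.]

    import json
    from pyspark.serializers import Serializer

    class JsonSerializer(Serializer):
        def dumps(self, obj):
            # With batching enabled, `obj` may be a list of objects rather than a single one.
            return json.dumps(obj).encode("utf-8")
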
diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi
index 2adae6c..209bb70 100644
--- a/python/pyspark/sql/_typing.pyi
+++ b/python/pyspark/sql/_typing.pyi
@@ -25,7 +25,7 @@ from typing import (
     TypeVar,
     Union,
 )
-from typing_extensions import Protocol
+from typing_extensions import Literal, Protocol
 
 import datetime
 import decimal
@@ -56,6 +56,8 @@ AtomicValue = TypeVar(
 
 RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], pyspark.sql.types.Row)
 
+SQLBatchedUDFType = Literal[100]
+
 class SupportsOpen(Protocol):
     def open(self, partition_id: int, epoch_id: int) -> bool: ...
 
diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi
index d3796f4..906703d 100644
--- a/python/pyspark/sql/pandas/_typing/__init__.pyi
+++ b/python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -42,11 +42,12 @@ DataFrameOrSeriesLike_ = TypeVar("DataFrameOrSeriesLike_", bound=DataFrameOrSeri
 
 # UDF annotations
 PandasScalarUDFType = Literal[200]
-PandasScalarIterUDFType = Literal[204]
 PandasGroupedMapUDFType = Literal[201]
-PandasCogroupedMapUDFType = Literal[206]
 PandasGroupedAggUDFType = Literal[202]
+PandasWindowAggUDFType = Literal[203]
+PandasScalarIterUDFType = Literal[204]
 PandasMapIterUDFType = Literal[205]
+PandasCogroupedMapUDFType = Literal[206]
 ArrowMapIterUDFType = Literal[207]
 
 class PandasVariadicScalarToScalarFunction(Protocol):
diff --git a/python/pyspark/tests/typing/test_rdd.yml b/python/pyspark/tests/typing/test_rdd.yml
index 749ad53..4896582 100644
--- a/python/pyspark/tests/typing/test_rdd.yml
+++ b/python/pyspark/tests/typing/test_rdd.yml
@@ -18,11 +18,11 @@
 - case: toDF
   main: |
     from pyspark.sql.types import (
-      IntegerType,
-      Row,
-      StructType,
-      StringType,
-      StructField,
+        IntegerType,
+        Row,
+        StructType,
+        StringType,
+        StructField,
     )
     from collections import namedtuple
     from pyspark.sql import SparkSession
@@ -60,3 +60,70 @@
     rdd_named_tuple.toDF(sampleRatio=0.4)
     rdd_named_tuple.toDF(["a", "b"], sampleRatio=0.4)
     rdd_named_tuple.toDF(struct)
+
+
+- case: rddMethods
+  main: |
+    from operator import add
+    from typing import Iterable, Set, Tuple
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.builder.getOrCreate()
+    sc = spark.sparkContext
+
+    def f1(x: int) -> str:
+        return str(x)
+
+    reveal_type(sc.range(10).map(f1))
+
+    def f2(x: int) -> Iterable[int]:
+        return range(x)
+
+    reveal_type(sc.range(10).flatMap(f2))
+
+    reveal_type(sc.parallelize([("a", 1), ("b", 0)]).filter(lambda x: x[1] != 0))
+
+    reveal_type(sc.parallelize([("a", 1), ("b", 0)]).max())
+
+    reveal_type(sc.range(10).reduce(add))
+
+    def seq_func(xs: Set[str], x: int) -> Set[str]:
+        xs.add(str(x % 11))
+        return xs
+
+    def comb_func(xs: Set[str], ys: Set[str]) -> Set[str]:
+        xs.update(ys)
+        return xs
+
+    zero: Set[str] = set()
+
+    reveal_type(sc.parallelize([("a", 1)]).aggregateByKey(zero, seq_func, comb_func))
+
+  out: |
+     main:11: note: Revealed type is "pyspark.rdd.RDD[builtins.str*]"
+     main:16: note: Revealed type is "pyspark.rdd.RDD[builtins.int*]"
+     main:18: note: Revealed type is "pyspark.rdd.RDD[Tuple[builtins.str, builtins.int]]"
+     main:20: note: Revealed type is "Tuple[builtins.str, builtins.int]"
+     main:22: note: Revealed type is "builtins.int"
+     main:34: note: Revealed type is "pyspark.rdd.RDD[Tuple[builtins.str, builtins.set[builtins.str]]]"
+
+- case: rddMethodsErrors
+  main: |
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.builder.getOrCreate()
+    sc = spark.sparkContext
+
+    def f1(x: str) -> str:
+        return x
+
+    sc.range(10).map(f1)
+
+    def f2(x: int) -> str:
+        return str(x)
+
+    sc.range(10).reduce(f2)
+
+  out: |
+    main:9: error: Argument 1 to "map" of "RDD" has incompatible type "Callable[[str], str]"; expected "Callable[[int], str]"  [arg-type]
+    main:14: error: Argument 1 to "reduce" of "RDD" has incompatible type "Callable[[int], str]"; expected "Callable[[int, int], int]"  [arg-type]

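[Editor's aside, not part of the patch: the groupWith overloads added in rdd.py above give precise result types for multi-RDD cogroups. A small usage sketch, with a local SparkContext and sample data assumed for illustration only:]

    from pyspark import SparkContext

    sc = SparkContext.getOrCreate()
    w = sc.parallelize([("a", 5), ("b", 6)])
    x = sc.parallelize([("a", 1), ("b", 4)])
    y = sc.parallelize([("a", 2)])

    # mypy infers RDD[Tuple[str, Tuple[ResultIterable[int], ResultIterable[int], ResultIterable[int]]]]
    grouped = w.groupWith(x, y)
    print(sorted((k, [sorted(vs) for vs in v]) for k, v in grouped.collect()))

    sc.stop()
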
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
