This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8c401be  [SPARK-35901][PYTHON] Refine type hints in pyspark.pandas.window
8c401be is described below

commit 8c401beb806267d4c23aeb27ab8898dcc3a0f98d
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Mon Jun 28 12:23:32 2021 +0900

    [SPARK-35901][PYTHON] Refine type hints in pyspark.pandas.window
    
    ### What changes were proposed in this pull request?
    
    Refines type hints in `pyspark.pandas.window`.
    
    Also, some refactoring is included to clean up the type hierarchy of `Rolling` and `Expanding`.
    
    ### Why are the changes needed?
    
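    We can use stricter type hints for functions in `pyspark.pandas.window` by making the window classes generic over the frame type (`Generic[T_Frame]`), so that e.g. `Series.rolling(...).sum()` is typed as returning a `Series`.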
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #33097 from ueshin/issues/SPARK-35901/window.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/pandas/frame.py   |  14 +++
 python/pyspark/pandas/generic.py |  18 +--
 python/pyspark/pandas/groupby.py |  22 ++--
 python/pyspark/pandas/series.py  |  14 +++
 python/pyspark/pandas/window.py  | 249 ++++++++++++++++++---------------------
 5 files changed, 166 insertions(+), 151 deletions(-)

diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 7f26346..6b6301a 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -11676,6 +11676,20 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """
         return DataFrame(pd.DataFrame.from_dict(data, orient=orient, 
dtype=dtype, columns=columns))
 
+    # Override the `groupby` to specify the actual return type annotation.
+    def groupby(
+        self,
+        by: Union[Any, Tuple, "Series", List[Union[Any, Tuple, "Series"]]],
+        axis: Union[int, str] = 0,
+        as_index: bool = True,
+        dropna: bool = True,
+    ) -> "DataFrameGroupBy":
+        return cast(
+            "DataFrameGroupBy", super().groupby(by=by, axis=axis, 
as_index=as_index, dropna=dropna)
+        )
+
+    groupby.__doc__ = Frame.groupby.__doc__
+
     def _build_groupby(
         self, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool
     ) -> "DataFrameGroupBy":
diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index 3a33295..a0c3f23 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -67,13 +67,13 @@ from pyspark.pandas.utils import (
     validate_axis,
     SPARK_CONF_ARROW_ENABLED,
 )
-from pyspark.pandas.window import Rolling, Expanding
 
 if TYPE_CHECKING:
     from pyspark.pandas.frame import DataFrame  # noqa: F401 (SPARK-34943)
     from pyspark.pandas.indexes.base import Index  # noqa: F401 (SPARK-34943)
     from pyspark.pandas.groupby import GroupBy  # noqa: F401 (SPARK-34943)
     from pyspark.pandas.series import Series  # noqa: F401 (SPARK-34943)
+    from pyspark.pandas.window import Rolling, Expanding  # noqa: F401 
(SPARK-34943)
 
 
 T_Frame = TypeVar("T_Frame", bound="Frame")
@@ -2508,7 +2508,9 @@ class Frame(object, metaclass=ABCMeta):
             return tuple(last_valid_row)
 
     # TODO: 'center', 'win_type', 'on', 'axis' parameter should be implemented.
-    def rolling(self, window: int, min_periods: Optional[int] = None) -> 
Rolling:
+    def rolling(
+        self: T_Frame, window: int, min_periods: Optional[int] = None
+    ) -> "Rolling[T_Frame]":
         """
         Provide rolling transformations.
 
@@ -2533,13 +2535,13 @@ class Frame(object, metaclass=ABCMeta):
         -------
         a Window sub-classed for the particular operation
         """
-        return Rolling(
-            cast(Union["Series", "DataFrame"], self), window=window, 
min_periods=min_periods
-        )
+        from pyspark.pandas.window import Rolling
+
+        return Rolling(self, window=window, min_periods=min_periods)
 
     # TODO: 'center' and 'axis' parameter should be implemented.
     #   'axis' implementation, refer https://github.com/pyspark.pandas/pull/607
-    def expanding(self, min_periods: int = 1) -> Expanding:
+    def expanding(self: T_Frame, min_periods: int = 1) -> "Expanding[T_Frame]":
         """
         Provide expanding transformations.
 
@@ -2557,7 +2559,9 @@ class Frame(object, metaclass=ABCMeta):
         -------
         a Window sub-classed for the particular operation
         """
-        return Expanding(cast(Union["Series", "DataFrame"], self), 
min_periods=min_periods)
+        from pyspark.pandas.window import Expanding
+
+        return Expanding(self, min_periods=min_periods)
 
     def get(self, key: Any, default: Optional[Any] = None) -> Any:
         """
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 860540e..1620c8c 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -41,6 +41,7 @@ from typing import (
     TypeVar,
     Union,
     cast,
+    TYPE_CHECKING,
 )
 
 import pandas as pd
@@ -85,9 +86,12 @@ from pyspark.pandas.utils import (
     verify_temp_column_name,
 )
 from pyspark.pandas.spark.utils import as_nullable_spark_type, 
force_decimal_precision_scale
-from pyspark.pandas.window import RollingGroupby, ExpandingGroupby
 from pyspark.pandas.exceptions import DataError
 
+if TYPE_CHECKING:
+    from pyspark.pandas.window import RollingGroupby, ExpandingGroupby  # 
noqa: F401 (SPARK-34943)
+
+
 # to keep it the same as pandas
 NamedAgg = namedtuple("NamedAgg", ["column", "aggfunc"])
 
@@ -2320,7 +2324,7 @@ class GroupBy(Generic[T_Frame], metaclass=ABCMeta):
 
         return self._reduce_for_stat_function(stat_function, 
only_numeric=False)
 
-    def rolling(self, window: int, min_periods: Optional[int] = None) -> 
RollingGroupby:
+    def rolling(self, window: int, min_periods: Optional[int] = None) -> 
"RollingGroupby[T_Frame]":
         """
         Return an rolling grouper, providing rolling
         functionality per group.
@@ -2345,11 +2349,11 @@ class GroupBy(Generic[T_Frame], metaclass=ABCMeta):
         Series.groupby
         DataFrame.groupby
         """
-        return RollingGroupby(
-            cast(Union[SeriesGroupBy, DataFrameGroupBy], self), window, 
min_periods=min_periods
-        )
+        from pyspark.pandas.window import RollingGroupby
 
-    def expanding(self, min_periods: int = 1) -> ExpandingGroupby:
+        return RollingGroupby(self, window, min_periods=min_periods)
+
+    def expanding(self, min_periods: int = 1) -> "ExpandingGroupby[T_Frame]":
         """
         Return an expanding grouper, providing expanding
         functionality per group.
@@ -2369,9 +2373,9 @@ class GroupBy(Generic[T_Frame], metaclass=ABCMeta):
         Series.groupby
         DataFrame.groupby
         """
-        return ExpandingGroupby(
-            cast(Union[SeriesGroupBy, DataFrameGroupBy], self), 
min_periods=min_periods
-        )
+        from pyspark.pandas.window import ExpandingGroupby
+
+        return ExpandingGroupby(self, min_periods=min_periods)
 
     def get_group(self, name: Union[Any, Tuple, List[Union[Any, Tuple]]]) -> 
T_Frame:
         """
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 3de7243..b4e95ad 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -6216,6 +6216,20 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         result = unpack_scalar(self._internal.spark_frame.select(scol))
         return result if result is not None else np.nan
 
+    # Override the `groupby` to specify the actual return type annotation.
+    def groupby(
+        self,
+        by: Union[Any, Tuple, "Series", List[Union[Any, Tuple, "Series"]]],
+        axis: Union[int, str_type] = 0,
+        as_index: bool = True,
+        dropna: bool = True,
+    ) -> "SeriesGroupBy":
+        return cast(
+            "SeriesGroupBy", super().groupby(by=by, axis=axis, 
as_index=as_index, dropna=dropna)
+        )
+
+    groupby.__doc__ = Frame.groupby.__doc__
+
     def _build_groupby(
         self, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool
     ) -> "SeriesGroupBy":
diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py
index 8c9a59d..b1ee83f 100644
--- a/python/pyspark/pandas/window.py
+++ b/python/pyspark/pandas/window.py
@@ -14,15 +14,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from abc import ABCMeta, abstractmethod
 from functools import partial
 from typing import (  # noqa: F401 (SPARK-34943)
     Any,
-    Union,
-    TYPE_CHECKING,
     Callable,
+    Generic,
     List,
-    cast,
     Optional,
+    TypeVar,
 )
 
 from pyspark.sql import Window
@@ -42,18 +42,15 @@ from pyspark.pandas.utils import scol_for
 from pyspark.sql.column import Column
 from pyspark.sql.window import WindowSpec
 
-if TYPE_CHECKING:
-    from pyspark.pandas.frame import DataFrame  # noqa: F401 (SPARK-34943)
-    from pyspark.pandas.series import Series  # noqa: F401 (SPARK-34943)
-    from pyspark.pandas.groupby import SeriesGroupBy  # noqa: F401 
(SPARK-34943)
-    from pyspark.pandas.groupby import DataFrameGroupBy  # noqa: F401 
(SPARK-34943)
+from pyspark.pandas.generic import Frame
+from pyspark.pandas.groupby import GroupBy
 
 
-class RollingAndExpanding(object):
-    def __init__(
-        self, psdf_or_psser: Union["Series", "DataFrame"], window: WindowSpec, 
min_periods: int
-    ):
-        self._psdf_or_psser = psdf_or_psser
+T_Frame = TypeVar("T_Frame", bound=Frame)
+
+
+class RollingAndExpanding(Generic[T_Frame], metaclass=ABCMeta):
+    def __init__(self, window: WindowSpec, min_periods: int):
         self._window = window
         # This unbounded Window is later used to handle 'min_periods' for now.
         self._unbounded_window = 
Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
@@ -61,28 +58,20 @@ class RollingAndExpanding(object):
         )
         self._min_periods = min_periods
 
-    def _apply_as_series_or_frame(
-        self, func: Callable[[Column], Column]
-    ) -> Union["Series", "DataFrame"]:
+    @abstractmethod
+    def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) -> 
T_Frame:
         """
         Wraps a function that handles Spark column in order
         to support it in both pandas-on-Spark Series and DataFrame.
         Note that the given `func` name should be same as the API's method 
name.
         """
-        raise NotImplementedError(
-            "A class that inherits this class should implement this method "
-            "to handle the index and columns of output."
-        )
+        pass
 
-    def count(self) -> Union["Series", "DataFrame"]:
-        def count(scol: Column) -> Column:
-            return F.count(scol).over(self._window)
-
-        return cast(
-            Union["Series", "DataFrame"], 
self._apply_as_series_or_frame(count).astype("float64")
-        )
+    @abstractmethod
+    def count(self) -> T_Frame:
+        pass
 
-    def sum(self) -> Union["Series", "DataFrame"]:
+    def sum(self) -> T_Frame:
         def sum(scol: Column) -> Column:
             return F.when(
                 F.row_number().over(self._unbounded_window) >= 
self._min_periods,
@@ -91,7 +80,7 @@ class RollingAndExpanding(object):
 
         return self._apply_as_series_or_frame(sum)
 
-    def min(self) -> Union["Series", "DataFrame"]:
+    def min(self) -> T_Frame:
         def min(scol: Column) -> Column:
             return F.when(
                 F.row_number().over(self._unbounded_window) >= 
self._min_periods,
@@ -100,7 +89,7 @@ class RollingAndExpanding(object):
 
         return self._apply_as_series_or_frame(min)
 
-    def max(self) -> Union["Series", "DataFrame"]:
+    def max(self) -> T_Frame:
         def max(scol: Column) -> Column:
             return F.when(
                 F.row_number().over(self._unbounded_window) >= 
self._min_periods,
@@ -109,7 +98,7 @@ class RollingAndExpanding(object):
 
         return self._apply_as_series_or_frame(max)
 
-    def mean(self) -> Union["Series", "DataFrame"]:
+    def mean(self) -> T_Frame:
         def mean(scol: Column) -> Column:
             return F.when(
                 F.row_number().over(self._unbounded_window) >= 
self._min_periods,
@@ -118,7 +107,7 @@ class RollingAndExpanding(object):
 
         return self._apply_as_series_or_frame(mean)
 
-    def std(self) -> Union["Series", "DataFrame"]:
+    def std(self) -> T_Frame:
         def std(scol: Column) -> Column:
             return F.when(
                 F.row_number().over(self._unbounded_window) >= 
self._min_periods,
@@ -127,7 +116,7 @@ class RollingAndExpanding(object):
 
         return self._apply_as_series_or_frame(std)
 
-    def var(self) -> Union["Series", "DataFrame"]:
+    def var(self) -> T_Frame:
         def var(scol: Column) -> Column:
             return F.when(
                 F.row_number().over(self._unbounded_window) >= 
self._min_periods,
@@ -137,15 +126,12 @@ class RollingAndExpanding(object):
         return self._apply_as_series_or_frame(var)
 
 
-class Rolling(RollingAndExpanding):
+class RollingLike(RollingAndExpanding[T_Frame]):
     def __init__(
         self,
-        psdf_or_psser: Union["Series", "DataFrame"],
         window: int,
         min_periods: Optional[int] = None,
     ):
-        from pyspark.pandas import DataFrame, Series
-
         if window < 0:
             raise ValueError("window must be >= 0")
         if (min_periods is not None) and (min_periods < 0):
@@ -155,17 +141,37 @@ class Rolling(RollingAndExpanding):
             #  a value.
             min_periods = window
 
+        window_spec = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
+            Window.currentRow - (window - 1), Window.currentRow
+        )
+
+        super().__init__(window_spec, min_periods)
+
+    def count(self) -> T_Frame:
+        def count(scol: Column) -> Column:
+            return F.count(scol).over(self._window)
+
+        return self._apply_as_series_or_frame(count).astype("float64")  # 
type: ignore
+
+
+class Rolling(RollingLike[T_Frame]):
+    def __init__(
+        self,
+        psdf_or_psser: T_Frame,
+        window: int,
+        min_periods: Optional[int] = None,
+    ):
+        from pyspark.pandas.frame import DataFrame
+        from pyspark.pandas.series import Series
+
+        super().__init__(window, min_periods)
+
         if not isinstance(psdf_or_psser, (DataFrame, Series)):
             raise TypeError(
                 "psdf_or_psser must be a series or dataframe; however, got: %s"
                 % type(psdf_or_psser)
             )
-
-        window_spec = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
-            Window.currentRow - (window - 1), Window.currentRow
-        )
-
-        super().__init__(psdf_or_psser, window_spec, min_periods)
+        self._psdf_or_psser = psdf_or_psser
 
     def __getattr__(self, item: str) -> Any:
         if hasattr(MissingPandasLikeRolling, item):
@@ -176,15 +182,13 @@ class Rolling(RollingAndExpanding):
                 return partial(property_or_func, self)
         raise AttributeError(item)
 
-    def _apply_as_series_or_frame(
-        self, func: Callable[[Column], Column]
-    ) -> Union["Series", "DataFrame"]:
+    def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) -> 
T_Frame:
         return self._psdf_or_psser._apply_series_op(
             lambda psser: psser._with_new_scol(func(psser.spark.column)),  # 
TODO: dtype?
             should_resolve=True,
         )
 
-    def count(self) -> Union["Series", "DataFrame"]:
+    def count(self) -> T_Frame:
         """
         The rolling count of any non-NaN observations inside the window.
 
@@ -233,7 +237,7 @@ class Rolling(RollingAndExpanding):
         """
         return super().count()
 
-    def sum(self) -> Union["Series", "DataFrame"]:
+    def sum(self) -> T_Frame:
         """
         Calculate rolling summation of given DataFrame or Series.
 
@@ -311,7 +315,7 @@ class Rolling(RollingAndExpanding):
         """
         return super().sum()
 
-    def min(self) -> Union["Series", "DataFrame"]:
+    def min(self) -> T_Frame:
         """
         Calculate the rolling minimum.
 
@@ -389,7 +393,7 @@ class Rolling(RollingAndExpanding):
         """
         return super().min()
 
-    def max(self) -> Union["Series", "DataFrame"]:
+    def max(self) -> T_Frame:
         """
         Calculate the rolling maximum.
 
@@ -466,7 +470,7 @@ class Rolling(RollingAndExpanding):
         """
         return super().max()
 
-    def mean(self) -> Union["Series", "DataFrame"]:
+    def mean(self) -> T_Frame:
         """
         Calculate the rolling mean of the values.
 
@@ -544,7 +548,7 @@ class Rolling(RollingAndExpanding):
         """
         return super().mean()
 
-    def std(self) -> Union["Series", "DataFrame"]:
+    def std(self) -> T_Frame:
         """
         Calculate rolling standard deviation.
 
@@ -594,7 +598,7 @@ class Rolling(RollingAndExpanding):
         """
         return super().std()
 
-    def var(self) -> Union["Series", "DataFrame"]:
+    def var(self) -> T_Frame:
         """
         Calculate unbiased rolling variance.
 
@@ -645,27 +649,14 @@ class Rolling(RollingAndExpanding):
         return super().var()
 
 
-class RollingGroupby(Rolling):
+class RollingGroupby(RollingLike[T_Frame]):
     def __init__(
         self,
-        groupby: Union["SeriesGroupBy", "DataFrameGroupBy"],
+        groupby: GroupBy[T_Frame],
         window: int,
         min_periods: Optional[int] = None,
     ):
-        from pyspark.pandas.groupby import SeriesGroupBy
-        from pyspark.pandas.groupby import DataFrameGroupBy
-
-        if isinstance(groupby, SeriesGroupBy):
-            psdf_or_psser = groupby._psser  # type: Union[DataFrame, Series]
-        elif isinstance(groupby, DataFrameGroupBy):
-            psdf_or_psser = groupby._psdf
-        else:
-            raise TypeError(
-                "groupby must be a SeriesGroupBy or DataFrameGroupBy; "
-                "however, got: %s" % type(groupby)
-            )
-
-        super().__init__(psdf_or_psser, window, min_periods)
+        super().__init__(window, min_periods)
 
         self._groupby = groupby
         self._window = self._window.partitionBy(*[ser.spark.column for ser in 
groupby._groupkeys])
@@ -682,17 +673,13 @@ class RollingGroupby(Rolling):
                 return partial(property_or_func, self)
         raise AttributeError(item)
 
-    def _apply_as_series_or_frame(
-        self, func: Callable[[Column], Column]
-    ) -> Union["Series", "DataFrame"]:
+    def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) -> 
T_Frame:
         """
         Wraps a function that handles Spark column in order
         to support it in both pandas-on-Spark Series and DataFrame.
         Note that the given `func` name should be same as the API's method 
name.
         """
         from pyspark.pandas import DataFrame
-        from pyspark.pandas.series import first_series
-        from pyspark.pandas.groupby import SeriesGroupBy
 
         groupby = self._groupby
         psdf = groupby._psdf
@@ -755,13 +742,9 @@ class RollingGroupby(Rolling):
             data_fields=[c._internal.data_fields[0] for c in applied],
         )
 
-        ret = DataFrame(internal)  # type: DataFrame
-        if isinstance(groupby, SeriesGroupBy):
-            return first_series(ret)
-        else:
-            return ret
+        return groupby._cleanup_and_return(DataFrame(internal))
 
-    def count(self) -> Union["Series", "DataFrame"]:
+    def count(self) -> T_Frame:
         """
         The rolling count of any non-NaN observations inside the window.
 
@@ -815,7 +798,7 @@ class RollingGroupby(Rolling):
         """
         return super().count()
 
-    def sum(self) -> Union["Series", "DataFrame"]:
+    def sum(self) -> T_Frame:
         """
         The rolling summation of any non-NaN observations inside the window.
 
@@ -869,7 +852,7 @@ class RollingGroupby(Rolling):
         """
         return super().sum()
 
-    def min(self) -> Union["Series", "DataFrame"]:
+    def min(self) -> T_Frame:
         """
         The rolling minimum of any non-NaN observations inside the window.
 
@@ -923,7 +906,7 @@ class RollingGroupby(Rolling):
         """
         return super().min()
 
-    def max(self) -> Union["Series", "DataFrame"]:
+    def max(self) -> T_Frame:
         """
         The rolling maximum of any non-NaN observations inside the window.
 
@@ -977,7 +960,7 @@ class RollingGroupby(Rolling):
         """
         return super().max()
 
-    def mean(self) -> Union["Series", "DataFrame"]:
+    def mean(self) -> T_Frame:
         """
         The rolling mean of any non-NaN observations inside the window.
 
@@ -1031,7 +1014,7 @@ class RollingGroupby(Rolling):
         """
         return super().mean()
 
-    def std(self) -> Union["Series", "DataFrame"]:
+    def std(self) -> T_Frame:
         """
         Calculate rolling standard deviation.
 
@@ -1050,7 +1033,7 @@ class RollingGroupby(Rolling):
         """
         return super().std()
 
-    def var(self) -> Union["Series", "DataFrame"]:
+    def var(self) -> T_Frame:
         """
         Calculate unbiased rolling variance.
 
@@ -1070,24 +1053,40 @@ class RollingGroupby(Rolling):
         return super().var()
 
 
-class Expanding(RollingAndExpanding):
-    def __init__(self, psdf_or_psser: Union["Series", "DataFrame"], 
min_periods: int = 1):
-        from pyspark.pandas import DataFrame, Series
-
+class ExpandingLike(RollingAndExpanding[T_Frame]):
+    def __init__(self, min_periods: int = 1):
         if min_periods < 0:
             raise ValueError("min_periods must be >= 0")
 
+        window = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
+            Window.unboundedPreceding, Window.currentRow
+        )
+
+        super().__init__(window, min_periods)
+
+    def count(self) -> T_Frame:
+        def count(scol: Column) -> Column:
+            return F.when(
+                F.row_number().over(self._unbounded_window) >= 
self._min_periods,
+                F.count(scol).over(self._window),
+            ).otherwise(F.lit(None))
+
+        return self._apply_as_series_or_frame(count).astype("float64")  # 
type: ignore
+
+
+class Expanding(ExpandingLike[T_Frame]):
+    def __init__(self, psdf_or_psser: T_Frame, min_periods: int = 1):
+        from pyspark.pandas.frame import DataFrame
+        from pyspark.pandas.series import Series
+
+        super().__init__(min_periods)
+
         if not isinstance(psdf_or_psser, (DataFrame, Series)):
             raise TypeError(
                 "psdf_or_psser must be a series or dataframe; however, got: %s"
                 % type(psdf_or_psser)
             )
-
-        window = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(
-            Window.unboundedPreceding, Window.currentRow
-        )
-
-        super().__init__(psdf_or_psser, window, min_periods)
+        self._psdf_or_psser = psdf_or_psser
 
     def __getattr__(self, item: str) -> Any:
         if hasattr(MissingPandasLikeExpanding, item):
@@ -1104,7 +1103,7 @@ class Expanding(RollingAndExpanding):
 
     _apply_as_series_or_frame = Rolling._apply_as_series_or_frame
 
-    def count(self) -> Union["Series", "DataFrame"]:
+    def count(self) -> T_Frame:
         """
         The expanding count of any non-NaN observations inside the window.
 
@@ -1143,16 +1142,9 @@ class Expanding(RollingAndExpanding):
         2  2.0
         3  3.0
         """
+        return super().count()
 
-        def count(scol: Column) -> Column:
-            return F.when(
-                F.row_number().over(self._unbounded_window) >= 
self._min_periods,
-                F.count(scol).over(self._window),
-            ).otherwise(F.lit(None))
-
-        return self._apply_as_series_or_frame(count).astype("float64")  # 
type: ignore
-
-    def sum(self) -> Union["Series", "DataFrame"]:
+    def sum(self) -> T_Frame:
         """
         Calculate expanding summation of given DataFrame or Series.
 
@@ -1214,7 +1206,7 @@ class Expanding(RollingAndExpanding):
         """
         return super().sum()
 
-    def min(self) -> Union["Series", "DataFrame"]:
+    def min(self) -> T_Frame:
         """
         Calculate the expanding minimum.
 
@@ -1251,7 +1243,7 @@ class Expanding(RollingAndExpanding):
         """
         return super().min()
 
-    def max(self) -> Union["Series", "DataFrame"]:
+    def max(self) -> T_Frame:
         """
         Calculate the expanding maximum.
 
@@ -1287,7 +1279,7 @@ class Expanding(RollingAndExpanding):
         """
         return super().max()
 
-    def mean(self) -> Union["Series", "DataFrame"]:
+    def mean(self) -> T_Frame:
         """
         Calculate the expanding mean of the values.
 
@@ -1331,7 +1323,7 @@ class Expanding(RollingAndExpanding):
         """
         return super().mean()
 
-    def std(self) -> Union["Series", "DataFrame"]:
+    def std(self) -> T_Frame:
         """
         Calculate expanding standard deviation.
 
@@ -1381,7 +1373,7 @@ class Expanding(RollingAndExpanding):
         """
         return super().std()
 
-    def var(self) -> Union["Series", "DataFrame"]:
+    def var(self) -> T_Frame:
         """
         Calculate unbiased expanding variance.
 
@@ -1432,22 +1424,9 @@ class Expanding(RollingAndExpanding):
         return super().var()
 
 
-class ExpandingGroupby(Expanding):
-    def __init__(self, groupby: Union["SeriesGroupBy", "DataFrameGroupBy"], 
min_periods: int = 1):
-        from pyspark.pandas.groupby import SeriesGroupBy
-        from pyspark.pandas.groupby import DataFrameGroupBy
-
-        if isinstance(groupby, SeriesGroupBy):
-            psdf_or_psser = groupby._psser  # type: Union[DataFrame, Series]
-        elif isinstance(groupby, DataFrameGroupBy):
-            psdf_or_psser = groupby._psdf
-        else:
-            raise TypeError(
-                "groupby must be a SeriesGroupBy or DataFrameGroupBy; "
-                "however, got: %s" % type(groupby)
-            )
-
-        super().__init__(psdf_or_psser, min_periods)
+class ExpandingGroupby(ExpandingLike[T_Frame]):
+    def __init__(self, groupby: GroupBy[T_Frame], min_periods: int = 1):
+        super().__init__(min_periods)
 
         self._groupby = groupby
         self._window = self._window.partitionBy(*[ser.spark.column for ser in 
groupby._groupkeys])
@@ -1464,9 +1443,9 @@ class ExpandingGroupby(Expanding):
                 return partial(property_or_func, self)
         raise AttributeError(item)
 
-    _apply_as_series_or_frame = RollingGroupby._apply_as_series_or_frame  # 
type: ignore
+    _apply_as_series_or_frame = RollingGroupby._apply_as_series_or_frame
 
-    def count(self) -> Union["Series", "DataFrame"]:
+    def count(self) -> T_Frame:
         """
         The expanding count of any non-NaN observations inside the window.
 
@@ -1520,7 +1499,7 @@ class ExpandingGroupby(Expanding):
         """
         return super().count()
 
-    def sum(self) -> Union["Series", "DataFrame"]:
+    def sum(self) -> T_Frame:
         """
         Calculate expanding summation of given DataFrame or Series.
 
@@ -1574,7 +1553,7 @@ class ExpandingGroupby(Expanding):
         """
         return super().sum()
 
-    def min(self) -> Union["Series", "DataFrame"]:
+    def min(self) -> T_Frame:
         """
         Calculate the expanding minimum.
 
@@ -1628,7 +1607,7 @@ class ExpandingGroupby(Expanding):
         """
         return super().min()
 
-    def max(self) -> Union["Series", "DataFrame"]:
+    def max(self) -> T_Frame:
         """
         Calculate the expanding maximum.
 
@@ -1681,7 +1660,7 @@ class ExpandingGroupby(Expanding):
         """
         return super().max()
 
-    def mean(self) -> Union["Series", "DataFrame"]:
+    def mean(self) -> T_Frame:
         """
         Calculate the expanding mean of the values.
 
@@ -1735,7 +1714,7 @@ class ExpandingGroupby(Expanding):
         """
         return super().mean()
 
-    def std(self) -> Union["Series", "DataFrame"]:
+    def std(self) -> T_Frame:
         """
         Calculate expanding standard deviation.
 
@@ -1755,7 +1734,7 @@ class ExpandingGroupby(Expanding):
         """
         return super().std()
 
-    def var(self) -> Union["Series", "DataFrame"]:
+    def var(self) -> T_Frame:
         """
         Calculate unbiased expanding variance.
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to