This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6ddb5ae5766 [SPARK-39081][PYTHON][SQL] Implement DataFrame.resample and Series.resample
6ddb5ae5766 is described below

commit 6ddb5ae57665e10596182a0ed1d7c683be36078e
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu May 12 09:16:43 2022 +0900

    [SPARK-39081][PYTHON][SQL] Implement DataFrame.resample and Series.resample
    
    ### What changes were proposed in this pull request?
    Implement DataFrame.resample and Series.resample
    
    ### Why are the changes needed?
    To increase pandas API coverage in PySpark
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, new methods DataFrame.resample and Series.resample are added
    
    for example:
    ```
    In [3]:
       ...: dates = [
       ...:     datetime.datetime(2011, 12, 31),
       ...:     datetime.datetime(2012, 1, 2),
       ...:     pd.NaT,
       ...:     datetime.datetime(2013, 5, 3),
       ...:     datetime.datetime(2022, 5, 3),
       ...: ]
       ...: pdf = pd.DataFrame(np.ones(len(dates)), index=pd.DatetimeIndex(dates), columns=['A'])
       ...: psdf = ps.from_pandas(pdf)
       ...: psdf.resample('3Y').sum().sort_index()
       ...:
    Out[3]:
                  A
    2011-12-31  1.0
    2014-12-31  2.0
    2017-12-31  0.0
    2020-12-31  0.0
    2023-12-31  1.0
    
    ```
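
    A Series can be resampled the same way. For instance (expected output, following the pandas semantics the implementation mirrors):
    ```
    In [4]:
       ...: idx = pd.date_range('2022-01-01', periods=6, freq='D')
       ...: psser = ps.Series([1, 2, 3, 4, 5, 6], index=idx, name='V')
       ...: psser.resample('2D').sum().sort_index()
       ...:
    Out[4]:
    2022-01-01     3.0
    2022-01-03     7.0
    2022-01-05    11.0
    Name: V, dtype: float64
    ```
    A datetime column can also be used instead of the index through the `on` parameter, e.g. `psdf.resample('2D', on=psdf.X).sum()` (where `X` is a datetime column).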
    
    ### How was this patch tested?
    Added unit tests (python/pyspark/pandas/tests/test_resample.py)
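
    A minimal sketch of the comparison style used (illustrative only; the full suite is python/pyspark/pandas/tests/test_resample.py, which checks many rule/closed/label combinations with assert_eq):
    ```
    import numpy as np
    import pandas as pd
    import pyspark.pandas as ps

    # Hourly data over 10 days; the new API is checked against pandas.
    pdf = pd.DataFrame(
        np.random.rand(240, 2),
        index=pd.date_range("2022-01-01", periods=240, freq="H"),
        columns=["A", "B"],
    )
    psdf = ps.from_pandas(pdf)

    for rule in ["3H", "D", "2D"]:
        expected = pdf.resample(rule).sum().sort_index()
        result = psdf.resample(rule).sum().sort_index().to_pandas()
        assert np.allclose(expected.to_numpy(), result.to_numpy())
    ```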
    
    Closes #36420 from zhengruifeng/impl_resample.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 dev/sparktestsupport/modules.py                    |   1 +
 .../docs/source/reference/pyspark.pandas/frame.rst |   1 +
 .../source/reference/pyspark.pandas/series.rst     |   1 +
 .../pandas_on_spark/supported_pandas_api.rst       |   8 +-
 python/pyspark/pandas/frame.py                     |  70 +++
 python/pyspark/pandas/missing/frame.py             |   1 -
 python/pyspark/pandas/missing/resample.py          | 105 +++++
 python/pyspark/pandas/missing/series.py            |   1 -
 python/pyspark/pandas/resample.py                  | 500 +++++++++++++++++++++
 python/pyspark/pandas/series.py                    | 135 ++++++
 python/pyspark/pandas/tests/test_resample.py       | 281 ++++++++++++
 .../spark/sql/api/python/PythonSQLUtils.scala      |  21 +
 12 files changed, 1121 insertions(+), 4 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index ed1eeb9b807..fc9e2ced9a9 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -643,6 +643,7 @@ pyspark_pandas = Module(
         "pyspark.pandas.tests.test_ops_on_diff_frames_groupby_expanding",
         "pyspark.pandas.tests.test_ops_on_diff_frames_groupby_rolling",
         "pyspark.pandas.tests.test_repr",
+        "pyspark.pandas.tests.test_resample",
         "pyspark.pandas.tests.test_reshape",
         "pyspark.pandas.tests.test_rolling",
         "pyspark.pandas.tests.test_series_conversion",
diff --git a/python/docs/source/reference/pyspark.pandas/frame.rst b/python/docs/source/reference/pyspark.pandas/frame.rst
index 75a8941ad78..bcf9694a7ae 100644
--- a/python/docs/source/reference/pyspark.pandas/frame.rst
+++ b/python/docs/source/reference/pyspark.pandas/frame.rst
@@ -260,6 +260,7 @@ Time series-related
 .. autosummary::
    :toctree: api/
 
+   DataFrame.resample
    DataFrame.shift
    DataFrame.first_valid_index
    DataFrame.last_valid_index
diff --git a/python/docs/source/reference/pyspark.pandas/series.rst b/python/docs/source/reference/pyspark.pandas/series.rst
index 48f67192349..1cf63c1a8ae 100644
--- a/python/docs/source/reference/pyspark.pandas/series.rst
+++ b/python/docs/source/reference/pyspark.pandas/series.rst
@@ -257,6 +257,7 @@ Time series-related
    :toctree: api/
 
    Series.asof
+   Series.resample
    Series.shift
    Series.first_valid_index
    Series.last_valid_index
diff --git a/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst b/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
index f63b4a2f05d..8b456445a1e 100644
--- a/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
+++ b/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
@@ -362,7 +362,9 @@ Supported DataFrame APIs
 +--------------------------------------------+-------------+--------------------------------------+
 | :func:`replace`                            | P           | ``regex``, ``method``                |
 +--------------------------------------------+-------------+--------------------------------------+
-| resample                                   | N           |                                      |
+| :func:`resample`                           | P           |``axis``, ``convention``, ``kind``    |
+|                                            |             | ``loffset``, ``base``, ``level``     |
+|                                            |             | ``origin``, ``offset``               |
 +--------------------------------------------+-------------+--------------------------------------+
 | :func:`reset_index`                        | Y           |                                      |
 +--------------------------------------------+-------------+--------------------------------------+
@@ -1016,7 +1018,9 @@ Supported Series APIs
 | :func:`replace`                 | P                 | ``inplace``, ``limit``, ``regex``,        |
 |                                 |                   | ``method``                                |
 +---------------------------------+-------------------+-------------------------------------------+
-| resample                        | N                 |                                           |
+| :func:`resample`                | P                 |``axis``, ``convention``, ``kind``         |
+|                                 |                   | ``loffset``, ``base``, ``level``          |
+|                                 |                   | ``origin``, ``offset``                    |
 +---------------------------------+-------------------+-------------------------------------------+
 | :func:`reset_index`             | Y                 |                                           |
 +---------------------------------+-------------------+-------------------------------------------+
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 8527477b7a2..f3ef0b15879 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -138,6 +138,7 @@ if TYPE_CHECKING:
     from pyspark.sql._typing import OptionalPrimitiveType
 
     from pyspark.pandas.groupby import DataFrameGroupBy
+    from pyspark.pandas.resample import DataFrameResampler
     from pyspark.pandas.indexes import Index
     from pyspark.pandas.series import Series
 
@@ -12660,6 +12661,75 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
         return DataFrameGroupBy._build(self, by, as_index=as_index, dropna=dropna)
 
+    def resample(
+        self,
+        rule: str,
+        closed: Optional[str] = None,
+        label: Optional[str] = None,
+        on: Optional["Series"] = None,
+    ) -> "DataFrameResampler":
+        """
+        Resample time-series data.
+
+        Convenience method for frequency conversion and resampling of time series.
+        The object must have a datetime-like index (only `DatetimeIndex` is supported for now),
+        or the caller must pass the label of a datetime-like
+        series/index to the ``on`` keyword parameter.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        rule : str
+            The offset string or object representing target conversion.
+            Currently, supported units are {'Y', 'A', 'M', 'D', 'H',
+            'T', 'MIN', 'S'}.
+        closed : {{'right', 'left'}}, default None
+            Which side of bin interval is closed. The default is 'left'
+            for all frequency offsets except for 'A', 'Y' and 'M' which all
+            have a default of 'right'.
+        label : {{'right', 'left'}}, default None
+            Which bin edge label to label bucket with. The default is 'left'
+            for all frequency offsets except for 'A', 'Y' and 'M' which all
+            have a default of 'right'.
+        on : Series, optional
+            For a DataFrame, column to use instead of index for resampling.
+            Column must be datetime-like.
+
+        Returns
+        -------
+        DataFrameResampler
+
+        See Also
+        --------
+        Series.resample : Resample a Series.
+        groupby : Group by mapping, function, label, or list of labels.
+        """
+        from pyspark.pandas.indexes import DatetimeIndex
+        from pyspark.pandas.resample import DataFrameResampler
+
+        if on is None and not isinstance(self.index, DatetimeIndex):
+            raise NotImplementedError("resample currently works only for 
DatetimeIndex")
+        if on is not None and not isinstance(as_spark_type(on.dtype), 
TimestampType):
+            raise NotImplementedError("resample currently works only for 
TimestampType")
+
+        agg_columns: List[ps.Series] = []
+        for column_label in self._internal.column_labels:
+            if isinstance(self._internal.spark_type_for(column_label), (NumericType, BooleanType)):
+                agg_columns.append(self._psser_for(column_label))
+
+        if len(agg_columns) == 0:
+            raise ValueError("No available aggregation columns!")
+
+        return DataFrameResampler(
+            psdf=self,
+            resamplekey=on,
+            rule=rule,
+            closed=closed,
+            label=label,
+            agg_columns=agg_columns,
+        )
+
     def _to_internal_pandas(self) -> pd.DataFrame:
         """
         Return a pandas DataFrame directly from _internal to avoid overhead of copy.
diff --git a/python/pyspark/pandas/missing/frame.py b/python/pyspark/pandas/missing/frame.py
index cd5e447cf0b..ddacd3ca12f 100644
--- a/python/pyspark/pandas/missing/frame.py
+++ b/python/pyspark/pandas/missing/frame.py
@@ -42,7 +42,6 @@ class _MissingPandasLikeDataFrame:
     infer_objects = _unsupported_function("infer_objects")
     mode = _unsupported_function("mode")
     reorder_levels = _unsupported_function("reorder_levels")
-    resample = _unsupported_function("resample")
     set_axis = _unsupported_function("set_axis")
     to_feather = _unsupported_function("to_feather")
     to_gbq = _unsupported_function("to_gbq")
diff --git a/python/pyspark/pandas/missing/resample.py b/python/pyspark/pandas/missing/resample.py
new file mode 100644
index 00000000000..91932797205
--- /dev/null
+++ b/python/pyspark/pandas/missing/resample.py
@@ -0,0 +1,105 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from pyspark.pandas.missing import unsupported_function, unsupported_property
+
+
+def _unsupported_function(method_name, deprecated=False, reason=""):
+    return unsupported_function(
+        class_name="pd.resample.Resampler",
+        method_name=method_name,
+        deprecated=deprecated,
+        reason=reason,
+    )
+
+
+def _unsupported_property(property_name, deprecated=False, reason=""):
+    return unsupported_property(
+        class_name="pd.resample.Resampler",
+        property_name=property_name,
+        deprecated=deprecated,
+        reason=reason,
+    )
+
+
+class MissingPandasLikeDataFrameResampler:
+    # NOTE: Please update the document "Supported pandas APIs" when implementing the new API.
+    # Documentation path: `python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst`.
+
+    # Properties
+    groups = _unsupported_property("groups")
+    indices = _unsupported_property("indices")
+
+    # Functions
    get_group = _unsupported_function("get_group")
+    apply = _unsupported_function("apply")
+    aggregate = _unsupported_function("aggregate")
+    transform = _unsupported_function("transform")
+    pipe = _unsupported_function("pipe")
+    ffill = _unsupported_function("ffill")
+    backfill = _unsupported_function("backfill")
+    bfill = _unsupported_function("bfill")
+    pad = _unsupported_function("pad")
+    nearest = _unsupported_function("nearest")
+    fillna = _unsupported_function("fillna")
+    asfreq = _unsupported_function("asfreq")
+    interpolate = _unsupported_function("interpolate")
+    count = _unsupported_function("count")
+    nunique = _unsupported_function("nunique")
+    first = _unsupported_function("first")
+    last = _unsupported_function("last")
+    median = _unsupported_function("median")
+    ohlc = _unsupported_function("ohlc")
+    prod = _unsupported_function("prod")
+    size = _unsupported_function("size")
+    sem = _unsupported_function("sem")
+    quantile = _unsupported_function("quantile")
+
+
+class MissingPandasLikeSeriesResampler:
+    # NOTE: Please update the document "Supported pandas APIs" when implementing the new API.
+    # Documentation path: `python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst`.
+
+    # Properties
+    groups = _unsupported_property("groups")
+    indices = _unsupported_property("indices")
+
+    # Functions
    get_group = _unsupported_function("get_group")
+    apply = _unsupported_function("apply")
+    aggregate = _unsupported_function("aggregate")
+    transform = _unsupported_function("transform")
+    pipe = _unsupported_function("pipe")
+    ffill = _unsupported_function("ffill")
+    backfill = _unsupported_function("backfill")
+    bfill = _unsupported_function("bfill")
+    pad = _unsupported_function("pad")
+    nearest = _unsupported_function("nearest")
+    fillna = _unsupported_function("fillna")
+    asfreq = _unsupported_function("asfreq")
+    interpolate = _unsupported_function("interpolate")
+    count = _unsupported_function("count")
+    nunique = _unsupported_function("nunique")
+    first = _unsupported_function("first")
+    last = _unsupported_function("last")
+    median = _unsupported_function("median")
+    ohlc = _unsupported_function("ohlc")
+    prod = _unsupported_function("prod")
+    size = _unsupported_function("size")
+    sem = _unsupported_function("sem")
+    quantile = _unsupported_function("quantile")
diff --git a/python/pyspark/pandas/missing/series.py b/python/pyspark/pandas/missing/series.py
index ce06ceff16a..3fafbbe048f 100644
--- a/python/pyspark/pandas/missing/series.py
+++ b/python/pyspark/pandas/missing/series.py
@@ -39,7 +39,6 @@ class MissingPandasLikeSeries:
     convert_dtypes = _unsupported_function("convert_dtypes")
     infer_objects = _unsupported_function("infer_objects")
     reorder_levels = _unsupported_function("reorder_levels")
-    resample = _unsupported_function("resample")
     searchsorted = _unsupported_function("searchsorted")
     set_axis = _unsupported_function("set_axis")
     to_hdf = _unsupported_function("to_hdf")
diff --git a/python/pyspark/pandas/resample.py b/python/pyspark/pandas/resample.py
new file mode 100644
index 00000000000..4743e0e1a9c
--- /dev/null
+++ b/python/pyspark/pandas/resample.py
@@ -0,0 +1,500 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A wrapper for ResampledData to behave similar to pandas Resampler.
+"""
+from abc import ABCMeta
+from distutils.version import LooseVersion
+from functools import partial
+from typing import (
+    Any,
+    Generic,
+    List,
+    Optional,
+)
+
+import numpy as np
+
+import pandas as pd
+from pandas.tseries.frequencies import to_offset
+
+if LooseVersion(pd.__version__) >= LooseVersion("1.3.0"):
+    from pandas.core.common import _builtin_table  # type: ignore[attr-defined]
+else:
+    from pandas.core.base import SelectionMixin
+
+    _builtin_table = SelectionMixin._builtin_table  # type: ignore[attr-defined]
+
+from pyspark import SparkContext
+from pyspark.sql import Column, functions as F
+from pyspark.sql.types import (
+    NumericType,
+    StructField,
+    TimestampType,
+)
+
+from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
+from pyspark.pandas._typing import FrameLike
+from pyspark.pandas.frame import DataFrame
+from pyspark.pandas.internal import (
+    InternalField,
+    InternalFrame,
+    SPARK_DEFAULT_INDEX_NAME,
+)
+from pyspark.pandas.missing.resample import (
+    MissingPandasLikeDataFrameResampler,
+    MissingPandasLikeSeriesResampler,
+)
+from pyspark.pandas.series import Series, first_series
+from pyspark.pandas.utils import (
+    scol_for,
+    verify_temp_column_name,
+)
+
+
+class Resampler(Generic[FrameLike], metaclass=ABCMeta):
+    """
+    Class for resampling datetimelike data, a groupby-like operation.
+
+    It's easiest to use obj.resample(...) to use Resampler.
+
+    Parameters
+    ----------
+    psdf : DataFrame
+
+    Returns
+    -------
+    a Resampler of the appropriate type
+
+    Notes
+    -----
+    After resampling, see aggregate, apply, and transform functions.
+    """
+
+    def __init__(
+        self,
+        psdf: DataFrame,
+        resamplekey: Optional[Series],
+        rule: str,
+        closed: Optional[str] = None,
+        label: Optional[str] = None,
+        agg_columns: List[Series] = [],
+    ):
+        self._psdf = psdf
+        self._resamplekey = resamplekey
+
+        self._offset = to_offset(rule)
+        if self._offset.rule_code not in ["A-DEC", "M", "D", "H", "T", "S"]:
+            raise ValueError("rule code {} is not 
supported".format(self._offset.rule_code))
+        if not self._offset.n > 0:  # type: ignore[attr-defined]
+            raise ValueError("rule offset must be positive")
+
+        if closed is None:
+            self._closed = "right" if self._offset.rule_code in ["A-DEC", "M"] 
else "left"
+        elif closed in ["left", "right"]:
+            self._closed = closed
+        else:
+            raise ValueError("invalid closed: '{}'".format(closed))
+
+        if label is None:
+            self._label = "right" if self._offset.rule_code in ["A-DEC", "M"] 
else "left"
+        elif label in ["left", "right"]:
+            self._label = label
+        else:
+            raise ValueError("invalid label: '{}'".format(label))
+
+        self._agg_columns = agg_columns
+
+    @property
+    def _resamplekey_scol(self) -> Column:
+        if self._resamplekey is None:
+            return self._psdf.index.spark.column
+        else:
+            return self._resamplekey.spark.column
+
+    @property
+    def _agg_columns_scols(self) -> List[Column]:
+        return [s.spark.column for s in self._agg_columns]
+
+    def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
+        sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
+        origin_scol = F.lit(origin)
+        (rule_code, n) = (self._offset.rule_code, self._offset.n)  # type: ignore[attr-defined]
+        left_closed, right_closed = (self._closed == "left", self._closed == "right")
+        left_labeled, right_labeled = (self._label == "left", self._label == "right")
+
+        if rule_code == "A-DEC":
+            assert (
+                origin.month == 12
+                and origin.day == 31
+                and origin.hour == 0
+                and origin.minute == 0
+                and origin.second == 0
+            )
+
+            diff = F.year(ts_scol) - F.year(origin_scol)
+            mod = F.lit(0) if n == 1 else (diff % n)
+            edge_cond = (mod == 0) & (F.month(ts_scol) == 12) & 
(F.dayofmonth(ts_scol) == 31)
+
+            edge_label = F.year(ts_scol)
+            if left_closed and right_labeled:
+                edge_label += n
+            elif right_closed and left_labeled:
+                edge_label -= n
+
+            if left_labeled:
+                non_edge_label = F.when(mod == 0, F.year(ts_scol) - 
n).otherwise(
+                    F.year(ts_scol) - mod
+                )
+            else:
+                non_edge_label = F.when(mod == 0, F.year(ts_scol)).otherwise(
+                    F.year(ts_scol) - (mod - n)
+                )
+
+            return F.to_timestamp(
+                F.make_date(
+                    F.when(edge_cond, edge_label).otherwise(non_edge_label), 
F.lit(12), F.lit(31)
+                )
+            )
+
+        elif rule_code == "M":
+            assert (
+                origin.is_month_end
+                and origin.hour == 0
+                and origin.minute == 0
+                and origin.second == 0
+            )
+
+            diff = (
+                (F.year(ts_scol) - F.year(origin_scol)) * 12
+                + F.month(ts_scol)
+                - F.month(origin_scol)
+            )
+            mod = F.lit(0) if n == 1 else (diff % n)
+            edge_cond = (mod == 0) & (F.dayofmonth(ts_scol) == 
F.dayofmonth(F.last_day(ts_scol)))
+
+            truncated_ts_scol = F.date_trunc("MONTH", ts_scol)
+            edge_label = truncated_ts_scol
+            if left_closed and right_labeled:
+                edge_label += sql_utils.makeInterval("MONTH", F.lit(n)._jc)
+            elif right_closed and left_labeled:
+                edge_label -= sql_utils.makeInterval("MONTH", F.lit(n)._jc)
+
+            if left_labeled:
+                non_edge_label = F.when(
+                    mod == 0,
+                    truncated_ts_scol - sql_utils.makeInterval("MONTH", 
F.lit(n)._jc),
+                ).otherwise(truncated_ts_scol - 
sql_utils.makeInterval("MONTH", mod._jc))
+            else:
+                non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
+                    truncated_ts_scol - sql_utils.makeInterval("MONTH", (mod - 
n)._jc)
+                )
+
+            return F.to_timestamp(
+                F.last_day(F.when(edge_cond, 
edge_label).otherwise(non_edge_label))
+            )
+
+        elif rule_code == "D":
+            assert origin.hour == 0 and origin.minute == 0 and origin.second 
== 0
+
+            if n == 1:
+                # NOTE: the logic to process '1D' is different from the cases with n>1,
+                # since hour/minute/second parts are taken into account to determine edges!
+                edge_cond = (
+                    (F.hour(ts_scol) == 0) & (F.minute(ts_scol) == 0) & 
(F.second(ts_scol) == 0)
+                )
+
+                if left_closed and left_labeled:
+                    return F.date_trunc("DAY", ts_scol)
+                elif left_closed and right_labeled:
+                    return F.date_trunc("DAY", F.date_add(ts_scol, 1))
+                elif right_closed and left_labeled:
+                    return F.when(edge_cond, F.date_trunc("DAY", 
F.date_sub(ts_scol, 1))).otherwise(
+                        F.date_trunc("DAY", ts_scol)
+                    )
+                else:
+                    return F.when(edge_cond, F.date_trunc("DAY", 
ts_scol)).otherwise(
+                        F.date_trunc("DAY", F.date_add(ts_scol, 1))
+                    )
+
+            else:
+                diff = F.datediff(end=ts_scol, start=origin_scol)
+                mod = diff % n
+
+                edge_cond = mod == 0
+
+                truncated_ts_scol = F.date_trunc("DAY", ts_scol)
+                edge_label = truncated_ts_scol
+                if left_closed and right_labeled:
+                    edge_label = F.date_add(truncated_ts_scol, n)
+                elif right_closed and left_labeled:
+                    edge_label = F.date_sub(truncated_ts_scol, n)
+
+                if left_labeled:
+                    non_edge_label = F.date_sub(truncated_ts_scol, mod)
+                else:
+                    non_edge_label = F.date_sub(truncated_ts_scol, mod - n)
+
+                return F.when(edge_cond, edge_label).otherwise(non_edge_label)
+
+        elif rule_code in ["H", "T", "S"]:
+            unit_mapping = {"H": "HOUR", "T": "MINUTE", "S": "SECOND"}
+            unit_str = unit_mapping[rule_code]
+
+            truncated_ts_scol = F.date_trunc(unit_str, ts_scol)
+            diff = sql_utils.timestampDiff(unit_str, origin_scol._jc, 
truncated_ts_scol._jc)
+            mod = F.lit(0) if n == 1 else (diff % F.lit(n))
+
+            if rule_code == "H":
+                assert origin.minute == 0 and origin.second == 0
+                edge_cond = (mod == 0) & (F.minute(ts_scol) == 0) & 
(F.second(ts_scol) == 0)
+            elif rule_code == "T":
+                assert origin.second == 0
+                edge_cond = (mod == 0) & (F.second(ts_scol) == 0)
+            else:
+                edge_cond = mod == 0
+
+            edge_label = truncated_ts_scol
+            if left_closed and right_labeled:
+                edge_label += sql_utils.makeInterval(unit_str, F.lit(n)._jc)
+            elif right_closed and left_labeled:
+                edge_label -= sql_utils.makeInterval(unit_str, F.lit(n)._jc)
+
+            if left_labeled:
+                non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
+                    truncated_ts_scol - sql_utils.makeInterval(unit_str, 
mod._jc)
+                )
+            else:
+                non_edge_label = F.when(
+                    mod == 0,
+                    truncated_ts_scol + sql_utils.makeInterval(unit_str, 
F.lit(n)._jc),
+                ).otherwise(truncated_ts_scol - 
sql_utils.makeInterval(unit_str, (mod - n)._jc))
+
+            return F.when(edge_cond, edge_label).otherwise(non_edge_label)
+
+        else:
+            raise ValueError("Got the unexpected unit {}".format(rule_code))
+
+    def _downsample(self, f: str) -> DataFrame:
+        """
+        Downsample the data and apply the aggregation ``f`` to each bin.
+
+        Parameters
+        ----------
+        f : str
+            Name of the aggregation to apply to each bin, e.g. 'sum', 'max', 'mean'.
+        """
+
+        # a simple example to illustrate the computation:
+        #   dates = [
+        #         datetime.datetime(2012, 1, 2),
+        #         datetime.datetime(2012, 5, 3),
+        #         datetime.datetime(2022, 5, 3),
+        #   ]
+        #   index = pd.DatetimeIndex(dates)
+        #   pdf = pd.DataFrame(np.array([1,2,3]), index=index, columns=['A'])
+        #   pdf.resample('3Y').max()
+        #                 A
+        #   2012-12-31  2.0
+        #   2015-12-31  NaN
+        #   2018-12-31  NaN
+        #   2021-12-31  NaN
+        #   2024-12-31  3.0
+        #
+        # in this case:
+        # 1, obtain an origin point to bin all timestamps; here we get 2009-12-31
+        # from the minimum timestamp (2012-01-02);
+        # 2, the default intervals for 'Y' are right-closed, so the intervals are:
+        # (2009-12-31, 2012-12-31], (2012-12-31, 2015-12-31], (2015-12-31, 2018-12-31], ...
+        # 3, bin all timestamps; for example, 2022-05-03 belongs to the interval
+        # (2021-12-31, 2024-12-31], and since the default label is 'right', it is labeled
+        # with the right edge 2024-12-31;
+        # 4, some intervals may end up with no input rows after this binning, so we need to
+        # pad the dataframe to avoid missing results like 2015-12-31, 2018-12-31 and 2021-12-31;
+        # 5, union the binned dataframe and the padded dataframe, and apply the aggregation
+        # 'max' to get the final results.
+
+        # one action to obtain the range, in the future we may cache it in the index.
+        ts_min, ts_max = (
+            self._psdf._internal.spark_frame.select(
+                F.min(self._resamplekey_scol), F.max(self._resamplekey_scol)
+            )
+            .toPandas()
+            .iloc[0]
+        )
+
+        # the logic to obtain an origin point to bin the timestamps is too complex to follow,
+        # here just use Pandas' resample on a 1-length series to get it.
+        ts_origin = (
+            pd.Series([0], index=[ts_min])
+            .resample(rule=self._offset.freqstr, closed=self._closed, label="left")
+            .sum()
+            .index[0]
+        )
+        assert ts_origin <= ts_min
+
+        bin_col_name = "__tmp_resample_bin_col__"
+        bin_col_label = verify_temp_column_name(self._psdf, bin_col_name)
+        bin_col_field = InternalField(
+            dtype=np.dtype("datetime64[ns]"),
+            struct_field=StructField(bin_col_name, TimestampType(), True),
+        )
+        bin_scol = self._bin_time_stamp(
+            ts_origin,
+            self._resamplekey_scol,
+        )
+
+        agg_columns = [
+            psser for psser in self._agg_columns if (isinstance(psser.spark.data_type, NumericType))
+        ]
+        assert len(agg_columns) > 0
+
+        # in the binning side, label the timestamps according to the origin and the freq(rule)
+        bin_sdf = self._psdf._internal.spark_frame.select(
+            F.col(SPARK_DEFAULT_INDEX_NAME),
+            bin_scol.alias(bin_col_name),
+            *[psser.spark.column for psser in agg_columns],
+        )
+
+        # in the padding side, insert necessary points
+        # again, directly apply Pandas' resample on a 2-length series to obtain the indices
+        pad_sdf = (
+            ps.from_pandas(
+                pd.Series([0, 0], index=[ts_min, ts_max])
+                .resample(rule=self._offset.freqstr, closed=self._closed, label=self._label)
+                .sum()
+                .index
+            )
+            ._internal.spark_frame.select(F.col(SPARK_DEFAULT_INDEX_NAME).alias(bin_col_name))
+            .where((ts_min <= F.col(bin_col_name)) & (F.col(bin_col_name) <= ts_max))
+        )
+
+        # union the above two spark dataframes.
+        sdf = bin_sdf.unionByName(pad_sdf, allowMissingColumns=True).where(
+            ~F.isnull(F.col(bin_col_name))
+        )
+
+        internal = InternalFrame(
+            spark_frame=sdf,
+            index_spark_columns=[scol_for(sdf, SPARK_DEFAULT_INDEX_NAME)],
+            data_spark_columns=[F.col(bin_col_name)]
+            + [scol_for(sdf, psser._internal.data_spark_column_names[0]) for psser in agg_columns],
+            column_labels=[bin_col_label] + [psser._column_label for psser in agg_columns],
+            data_fields=[bin_col_field]
+            + [psser._internal.data_fields[0].copy(nullable=True) for psser in agg_columns],
+            column_label_names=self._psdf._internal.column_label_names,
+        )
+        psdf: DataFrame = DataFrame(internal)
+
+        groupby = psdf.groupby(psdf._psser_for(bin_col_label), dropna=False)
+        downsampled = getattr(groupby, f)()
+        downsampled.index.name = None
+
+        return downsampled
+
+
+class DataFrameResampler(Resampler[DataFrame]):
+    def __init__(
+        self,
+        psdf: DataFrame,
+        resamplekey: Optional[Series],
+        rule: str,
+        closed: Optional[str] = None,
+        label: Optional[str] = None,
+        agg_columns: List[Series] = [],
+    ):
+        super().__init__(
+            psdf=psdf,
+            resamplekey=resamplekey,
+            rule=rule,
+            closed=closed,
+            label=label,
+            agg_columns=agg_columns,
+        )
+
+    def __getattr__(self, item: str) -> Any:
+        if hasattr(MissingPandasLikeDataFrameResampler, item):
+            property_or_func = getattr(MissingPandasLikeDataFrameResampler, item)
+            if isinstance(property_or_func, property):
+                return property_or_func.fget(self)
+            else:
+                return partial(property_or_func, self)
+
+    def min(self) -> DataFrame:
+        return self._downsample("min")
+
+    def max(self) -> DataFrame:
+        return self._downsample("max")
+
+    def sum(self) -> DataFrame:
+        return self._downsample("sum").fillna(0.0)
+
+    def mean(self) -> DataFrame:
+        return self._downsample("mean")
+
+    def std(self) -> DataFrame:
+        return self._downsample("std")
+
+    def var(self) -> DataFrame:
+        return self._downsample("var")
+
+
+class SeriesResampler(Resampler[Series]):
+    def __init__(
+        self,
+        psdf: DataFrame,
+        resamplekey: Optional[Series],
+        rule: str,
+        closed: Optional[str] = None,
+        label: Optional[str] = None,
+        agg_columns: List[Series] = [],
+    ):
+        super().__init__(
+            psdf=psdf,
+            resamplekey=resamplekey,
+            rule=rule,
+            closed=closed,
+            label=label,
+            agg_columns=agg_columns,
+        )
+
+    def __getattr__(self, item: str) -> Any:
+        if hasattr(MissingPandasLikeSeriesResampler, item):
+            property_or_func = getattr(MissingPandasLikeSeriesResampler, item)
+            if isinstance(property_or_func, property):
+                return property_or_func.fget(self)
+            else:
+                return partial(property_or_func, self)
+
+    def min(self) -> Series:
+        return first_series(self._downsample("min"))
+
+    def max(self) -> Series:
+        return first_series(self._downsample("max"))
+
+    def sum(self) -> Series:
+        return first_series(self._downsample("sum").fillna(0.0))
+
+    def mean(self) -> Series:
+        return first_series(self._downsample("mean"))
+
+    def std(self) -> Series:
+        return first_series(self._downsample("std"))
+
+    def var(self) -> Series:
+        return first_series(self._downsample("var"))
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index b8915e160c1..d748b99f344 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -118,11 +118,13 @@ from pyspark.pandas.typedef import (
     SeriesType,
     create_type_for_series_type,
 )
+from pyspark.pandas.typedef.typehints import as_spark_type
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName
 
     from pyspark.pandas.groupby import SeriesGroupBy
+    from pyspark.pandas.resample import SeriesResampler
     from pyspark.pandas.indexes import Index
     from pyspark.pandas.spark.accessors import SparkIndexOpsMethods
 
@@ -6900,6 +6902,139 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
 
         return SeriesGroupBy._build(self, by, as_index=as_index, dropna=dropna)
 
+    def resample(
+        self,
+        rule: str_type,
+        closed: Optional[str_type] = None,
+        label: Optional[str_type] = None,
+        on: Optional["Series"] = None,
+    ) -> "SeriesResampler":
+        """
+        Resample time-series data.
+
+        Convenience method for frequency conversion and resampling of time series.
+        The object must have a datetime-like index (only `DatetimeIndex` is supported for now),
+        or the caller must pass the label of a datetime-like
+        series/index to the ``on`` keyword parameter.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        rule : str
+            The offset string or object representing target conversion.
+            Currently, supported units are {'Y', 'A', 'M', 'D', 'H',
+            'T', 'MIN', 'S'}.
+        closed : {{'right', 'left'}}, default None
+            Which side of bin interval is closed. The default is 'left'
+            for all frequency offsets except for 'A', 'Y' and 'M' which all
+            have a default of 'right'.
+        label : {{'right', 'left'}}, default None
+            Which bin edge label to label bucket with. The default is 'left'
+            for all frequency offsets except for 'A', 'Y' and 'M' which all
+            have a default of 'right'.
+        on : Series, optional
+            For a DataFrame, column to use instead of index for resampling.
+            Column must be datetime-like.
+
+        Returns
+        -------
+        SeriesResampler
+
+
+        Examples
+        --------
+        Start by creating a series with 9 one minute timestamps.
+
+        >>> index = pd.date_range('1/1/2000', periods=9, freq='T')
+        >>> series = ps.Series(range(9), index=index, name='V')
+        >>> series
+        2000-01-01 00:00:00    0
+        2000-01-01 00:01:00    1
+        2000-01-01 00:02:00    2
+        2000-01-01 00:03:00    3
+        2000-01-01 00:04:00    4
+        2000-01-01 00:05:00    5
+        2000-01-01 00:06:00    6
+        2000-01-01 00:07:00    7
+        2000-01-01 00:08:00    8
+        Name: V, dtype: int64
+
+        Downsample the series into 3 minute bins and sum the values
+        of the timestamps falling into a bin.
+
+        >>> series.resample('3T').sum().sort_index()
+        2000-01-01 00:00:00     3.0
+        2000-01-01 00:03:00    12.0
+        2000-01-01 00:06:00    21.0
+        Name: V, dtype: float64
+
+        Downsample the series into 3 minute bins as above, but label each
+        bin using the right edge instead of the left. Please note that the
+        value in the bucket used as the label is not included in the bucket,
+        which it labels. For example, in the original series the
+        bucket ``2000-01-01 00:03:00`` contains the value 3, but the summed
+        value in the resampled bucket with the label ``2000-01-01 00:03:00``
+        does not include 3 (if it did, the summed value would be 6, not 3).
+        To include this value close the right side of the bin interval as
+        illustrated in the example below this one.
+
+        >>> series.resample('3T', label='right').sum().sort_index()
+        2000-01-01 00:03:00     3.0
+        2000-01-01 00:06:00    12.0
+        2000-01-01 00:09:00    21.0
+        Name: V, dtype: float64
+
+        Downsample the series into 3 minute bins as above, but close the right
+        side of the bin interval.
+
+        >>> series.resample('3T', label='right', closed='right').sum().sort_index()
+        2000-01-01 00:00:00     0.0
+        2000-01-01 00:03:00     6.0
+        2000-01-01 00:06:00    15.0
+        2000-01-01 00:09:00    15.0
+        Name: V, dtype: float64
+
+        Upsample the series into 30 second bins.
+
+        >>> series.resample('30S').sum().sort_index()[0:5]   # Select first 5 rows
+        2000-01-01 00:00:00    0.0
+        2000-01-01 00:00:30    0.0
+        2000-01-01 00:01:00    1.0
+        2000-01-01 00:01:30    0.0
+        2000-01-01 00:02:00    2.0
+        Name: V, dtype: float64
+
+        See Also
+        --------
+        DataFrame.resample : Resample a DataFrame.
+        groupby : Group by mapping, function, label, or list of labels.
+        """
+        from pyspark.pandas.indexes import DatetimeIndex
+        from pyspark.pandas.resample import SeriesResampler
+
+        if on is None and not isinstance(self.index, DatetimeIndex):
+            raise NotImplementedError("resample currently works only for 
DatetimeIndex")
+        if on is not None and not isinstance(as_spark_type(on.dtype), 
TimestampType):
+            raise NotImplementedError("resample currently works only for 
TimestampType")
+
+        agg_columns: List[ps.Series] = []
+        column_label = self._internal.column_labels[0]
+        if isinstance(self._internal.spark_type_for(column_label), (NumericType, BooleanType)):
+            agg_columns.append(self)
+
+        if len(agg_columns) == 0:
+            raise ValueError("No available aggregation columns!")
+
+        return SeriesResampler(
+            psdf=self._psdf,
+            resamplekey=on,
+            rule=rule,
+            closed=closed,
+            label=label,
+            agg_columns=agg_columns,
+        )
+
     def __getitem__(self, key: Any) -> Any:
         try:
             if (isinstance(key, slice) and any(type(n) == int for n in [key.start, key.stop])) or (
diff --git a/python/pyspark/pandas/tests/test_resample.py b/python/pyspark/pandas/tests/test_resample.py
new file mode 100644
index 00000000000..e9359b0a8a7
--- /dev/null
+++ b/python/pyspark/pandas/tests/test_resample.py
@@ -0,0 +1,281 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import unittest
+import inspect
+import datetime
+import numpy as np
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.pandas.exceptions import PandasNotImplementedError, DataError
+from pyspark.pandas.missing.resample import (
+    MissingPandasLikeDataFrameResampler,
+    MissingPandasLikeSeriesResampler,
+)
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+
+
+class ResampleTest(PandasOnSparkTestCase, TestUtils):
+    @property
+    def pdf1(self):
+        np.random.seed(11)
+        dates = [
+            pd.NaT,
+            datetime.datetime(2011, 12, 31),
+            datetime.datetime(2011, 12, 31, 0, 0, 1),
+            datetime.datetime(2011, 12, 31, 23, 59, 59),
+            datetime.datetime(2012, 1, 1),
+            datetime.datetime(2012, 1, 1, 0, 0, 1),
+            pd.NaT,
+            datetime.datetime(2012, 1, 1, 23, 59, 59),
+            datetime.datetime(2012, 1, 2),
+            pd.NaT,
+            datetime.datetime(2012, 1, 30, 23, 59, 59),
+            datetime.datetime(2012, 1, 31),
+            datetime.datetime(2012, 1, 31, 0, 0, 1),
+            datetime.datetime(2012, 3, 31),
+            datetime.datetime(2013, 5, 3),
+            datetime.datetime(2022, 5, 3),
+        ]
+        return pd.DataFrame(
+            np.random.rand(len(dates), 2), index=pd.DatetimeIndex(dates), 
columns=list("AB")
+        )
+
+    @property
+    def pdf2(self):
+        np.random.seed(22)
+        dates = [
+            datetime.datetime(2022, 5, 1, 4, 5, 6),
+            datetime.datetime(2022, 5, 3),
+            datetime.datetime(2022, 5, 3, 23, 59, 59),
+            datetime.datetime(2022, 5, 4),
+            pd.NaT,
+            datetime.datetime(2022, 5, 4, 0, 0, 1),
+            datetime.datetime(2022, 5, 11),
+        ]
+        return pd.DataFrame(
+            np.random.rand(len(dates), 2), index=pd.DatetimeIndex(dates), 
columns=list("AB")
+        )
+
+    @property
+    def pdf3(self):
+        np.random.seed(22)
+        index = pd.date_range(start="2011-01-02", end="2022-05-01", freq="1D")
+        return pd.DataFrame(np.random.rand(len(index), 2), index=index, 
columns=list("AB"))
+
+    @property
+    def pdf4(self):
+        np.random.seed(33)
+        index = pd.date_range(start="2020-12-12", end="2022-05-01", freq="1H")
+        return pd.DataFrame(np.random.rand(len(index), 2), index=index, 
columns=list("AB"))
+
+    @property
+    def pdf5(self):
+        np.random.seed(44)
+        index = pd.date_range(start="2021-12-30 03:04:05", end="2022-01-02 
06:07:08", freq="1T")
+        return pd.DataFrame(np.random.rand(len(index), 2), index=index, 
columns=list("AB"))
+
+    @property
+    def pdf6(self):
+        np.random.seed(55)
+        index = pd.date_range(start="2022-05-02 03:04:05", end="2022-05-03 
06:07:08", freq="1S")
+        return pd.DataFrame(np.random.rand(len(index), 2), index=index, 
columns=list("AB"))
+
+    @property
+    def psdf1(self):
+        return ps.from_pandas(self.pdf1)
+
+    @property
+    def psdf2(self):
+        return ps.from_pandas(self.pdf2)
+
+    @property
+    def psdf3(self):
+        return ps.from_pandas(self.pdf3)
+
+    @property
+    def psdf4(self):
+        return ps.from_pandas(self.pdf4)
+
+    @property
+    def psdf5(self):
+        return ps.from_pandas(self.pdf5)
+
+    @property
+    def psdf6(self):
+        return ps.from_pandas(self.pdf6)
+
+    def test_resample_error(self):
+        psdf = ps.range(10)
+
+        with self.assertRaisesRegex(
+            NotImplementedError, "resample currently works only for 
DatetimeIndex"
+        ):
+            psdf.resample("3Y").sum()
+
+        dates = [
+            datetime.datetime(2012, 1, 2),
+            datetime.datetime(2012, 5, 3),
+            datetime.datetime(2022, 5, 3),
+            pd.NaT,
+        ]
+        pdf = pd.DataFrame(np.ones(len(dates)), index=pd.DatetimeIndex(dates), 
columns=["A"])
+        psdf = ps.from_pandas(pdf)
+
+        with self.assertRaisesRegex(ValueError, "rule code W-SUN is not 
supported"):
+            psdf.A.resample("3W").sum()
+
+        with self.assertRaisesRegex(ValueError, "rule offset must be 
positive"):
+            psdf.A.resample("0Y").sum()
+
+        with self.assertRaisesRegex(ValueError, "invalid closed: 'middle'"):
+            psdf.A.resample("3Y", closed="middle").sum()
+
+        with self.assertRaisesRegex(ValueError, "invalid label: 'both'"):
+            psdf.A.resample("3Y", label="both").sum()
+
+    def test_missing(self):
+        pdf_r = self.psdf1.resample("3Y")
+        pser_r = self.psdf1.A.resample("3Y")
+
+        # DataFrameResampler functions
+        missing_functions = inspect.getmembers(
+            MissingPandasLikeDataFrameResampler, inspect.isfunction
+        )
+        unsupported_functions = [
+            name for (name, type_) in missing_functions if type_.__name__ == 
"unsupported_function"
+        ]
+        for name in unsupported_functions:
+            with self.assertRaisesRegex(
+                PandasNotImplementedError,
+                "method.*Resampler.*{}.*not implemented( yet\\.|\\. 
.+)".format(name),
+            ):
+                getattr(pdf_r, name)()
+
+        # SeriesResampler functions
+        missing_functions = 
inspect.getmembers(MissingPandasLikeSeriesResampler, inspect.isfunction)
+        unsupported_functions = [
+            name for (name, type_) in missing_functions if type_.__name__ == 
"unsupported_function"
+        ]
+        for name in unsupported_functions:
+            with self.assertRaisesRegex(
+                PandasNotImplementedError,
+                "method.*Resampler.*{}.*not implemented( yet\\.|\\. 
.+)".format(name),
+            ):
+                getattr(pser_r, name)()
+
+        # DataFrameResampler properties
+        missing_properties = inspect.getmembers(
+            MissingPandasLikeDataFrameResampler, lambda o: isinstance(o, 
property)
+        )
+        unsupported_properties = [
+            name
+            for (name, type_) in missing_properties
+            if type_.fget.__name__ == "unsupported_property"
+        ]
+        for name in unsupported_properties:
+            with self.assertRaisesRegex(
+                PandasNotImplementedError,
+                "property.*Resampler.*{}.*not implemented( yet\\.|\\. 
.+)".format(name),
+            ):
+                getattr(pdf_r, name)
+
+        # SeriesResampler properties
+        missing_properties = inspect.getmembers(
+            MissingPandasLikeSeriesResampler, lambda o: isinstance(o, property)
+        )
+        unsupported_properties = [
+            name
+            for (name, type_) in missing_properties
+            if type_.fget.__name__ == "unsupported_property"
+        ]
+        for name in unsupported_properties:
+            with self.assertRaisesRegex(
+                PandasNotImplementedError,
+                "property.*Resampler.*{}.*not implemented( yet\\.|\\. 
.+)".format(name),
+            ):
+                getattr(pser_r, name)
+
+    def _test_resample(self, pobj, psobj, rules, funcs):
+        for rule in rules:
+            for func in funcs:
+                for closed in [None, "left", "right"]:
+                    for label in [None, "left", "right"]:
+                        p_resample = pobj.resample(rule=rule, closed=closed, 
label=label)
+                        ps_resample = psobj.resample(rule=rule, closed=closed, 
label=label)
+                        self.assert_eq(
+                            getattr(p_resample, func)().sort_index(),
+                            getattr(ps_resample, func)().sort_index(),
+                            almost=True,
+                        )
+
+    def test_dataframe_resample(self):
+        self._test_resample(
+            self.pdf1,
+            self.psdf1,
+            ["Y", "3Y", "M", "9M", "D", "17D"],
+            ["min", "max", "sum", "mean", "std", "var"],
+        )
+        self._test_resample(self.pdf2, self.psdf2, ["3A", "A", "11M", "D"], 
["sum"])
+        self._test_resample(self.pdf3, self.psdf3, ["27H", "1D", "2D", "1M"], 
["sum"])
+        self._test_resample(self.pdf4, self.psdf4, ["1H", "5H", "D", "2D"], 
["sum"])
+        self._test_resample(self.pdf5, self.psdf5, ["1T", "2T", "5MIN", "1H", 
"2H", "D"], ["sum"])
+        self._test_resample(self.pdf6, self.psdf6, ["1S", "2S", "1MIN", "H", 
"2H"], ["sum"])
+
+    def test_series_resample(self):
+        self._test_resample(self.pdf1.A, self.psdf1.A, ["4Y"], ["sum"])
+        self._test_resample(self.pdf2.A, self.psdf2.A, ["13M"], ["sum"])
+        self._test_resample(self.pdf3.A, self.psdf3.A, ["18H"], ["sum"])
+        self._test_resample(self.pdf4.A, self.psdf4.A, ["6D"], ["sum"])
+        self._test_resample(self.pdf5.A, self.psdf5.A, ["47T"], ["sum"])
+        self._test_resample(self.pdf6.A, self.psdf6.A, ["37S"], ["sum"])
+
+    def test_resample_on(self):
+        np.random.seed(77)
+        dates = [
+            datetime.datetime(2022, 5, 1, 4, 5, 6),
+            datetime.datetime(2022, 5, 3),
+            datetime.datetime(2022, 5, 3, 23, 59, 59),
+            datetime.datetime(2022, 5, 4),
+            pd.NaT,
+            datetime.datetime(2022, 5, 4, 0, 0, 1),
+            datetime.datetime(2022, 5, 11),
+        ]
+        pdf = pd.DataFrame(
+            np.random.rand(len(dates), 3), index=pd.DatetimeIndex(dates), 
columns=list("ABC")
+        )
+        pdf["X"] = pd.DatetimeIndex(dates)
+        psdf = ps.from_pandas(pdf)
+        self.assert_eq(
+            pdf.resample("2D", on="X").sum().sort_index(),
+            psdf.resample("2D", on=psdf.X).sum().sort_index(),
+            almost=True,
+        )
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.test_resample import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 2cc595ed2bf..19cfbd0bbe7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.api.python
 
 import java.io.InputStream
 import java.nio.channels.Channels
+import java.util.Locale
 
 import net.razorvine.pickle.Pickler
 
@@ -101,6 +102,26 @@ private[sql] object PythonSQLUtils extends Logging {
   def lastNonNull(e: Column): Column = Column(LastNonNull(e.expr))
 
   def nullIndex(e: Column): Column = Column(NullIndex(e.expr))
+
+  def makeInterval(unit: String, e: Column): Column = {
+    val zero = MakeInterval(years = Literal(0), months = Literal(0), weeks = Literal(0),
+      days = Literal(0), hours = Literal(0), mins = Literal(0), secs = Literal(0))
+
+    unit.toUpperCase(Locale.ROOT) match {
+      case "YEAR" => Column(zero.copy(years = e.expr))
+      case "MONTH" => Column(zero.copy(months = e.expr))
+      case "WEEK" => Column(zero.copy(weeks = e.expr))
+      case "DAY" => Column(zero.copy(days = e.expr))
+      case "HOUR" => Column(zero.copy(hours = e.expr))
+      case "MINUTE" => Column(zero.copy(mins = e.expr))
+      case "SECOND" => Column(zero.copy(secs = e.expr))
+      case _ => throw new IllegalStateException(s"Got the unexpected unit '$unit'.")
+    }
+  }
+
+  def timestampDiff(unit: String, start: Column, end: Column): Column = {
+    Column(TimestampDiff(unit, start.expr, end.expr))
+  }
 }
 
 /**

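For reference, the Python side reaches these two new helpers through the py4j bridge, in the same way python/pyspark/pandas/resample.py does. A minimal sketch (assumes an active SparkSession; PythonSQLUtils is internal API, and the timestamps below are only illustrative):

```
from pyspark import SparkContext
from pyspark.sql import Column, SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils

df = spark.range(1).select(F.to_timestamp(F.lit("2022-05-03 10:00:00")).alias("ts"))

# makeInterval("MONTH", n) builds an interval column of n months
one_month = Column(sql_utils.makeInterval("MONTH", F.lit(1)._jc))

# timestampDiff counts whole units between two timestamp columns
hour_diff = Column(
    sql_utils.timestampDiff(
        "HOUR", F.to_timestamp(F.lit("2022-05-03 00:00:00"))._jc, F.col("ts")._jc
    )
)

df.select((F.col("ts") + one_month).alias("ts_plus_1m"), hour_diff.alias("hours")).show()
```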
