This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 55d6b51af7cd [SPARK-46340][PS][CONNECT][TESTS] Reorganize `EWMTests` 55d6b51af7cd is described below commit 55d6b51af7cd9108752eea65e7eef13da01118e8 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Sun Dec 10 08:53:19 2023 +0800 [SPARK-46340][PS][CONNECT][TESTS] Reorganize `EWMTests` ### What changes were proposed in this pull request? Reorganize `EWMTests` ### Why are the changes needed? break it into smaller files to be consistent with pandas tests (see https://github.com/pandas-dev/pandas/tree/main/pandas/tests/window ) ### Does this PR introduce _any_ user-facing change? no, test-only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #44273 from zhengruifeng/ps_test_ewm. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- dev/sparktestsupport/modules.py | 8 +- .../{test_parity_ewm.py => window/__init__.py} | 21 -- .../test_parity_ewm_error.py} | 11 +- .../test_parity_ewm_mean.py} | 11 +- .../test_parity_groupby_ewm_mean.py} | 11 +- .../test_parity_ewm.py => window/__init__.py} | 21 -- .../pyspark/pandas/tests/window/test_ewm_error.py | 97 ++++++++++ .../pyspark/pandas/tests/window/test_ewm_mean.py | 194 +++++++++++++++++++ .../test_groupby_ewm_mean.py} | 215 +-------------------- 9 files changed, 328 insertions(+), 261 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index e67cfce0f5c0..ca35fdabc0c4 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -729,7 +729,9 @@ pyspark_pandas = Module( "pyspark.pandas.tests.test_default_index", "pyspark.pandas.tests.test_expanding", "pyspark.pandas.tests.test_extension", - "pyspark.pandas.tests.test_ewm", + "pyspark.pandas.tests.window.test_ewm_error", + "pyspark.pandas.tests.window.test_ewm_mean", + "pyspark.pandas.tests.window.test_groupby_ewm_mean", "pyspark.pandas.tests.test_frame_spark", "pyspark.pandas.tests.test_generic_functions", "pyspark.pandas.tests.test_frame_interpolate", @@ -1113,7 +1115,9 @@ pyspark_pandas_connect_part2 = Module( "pyspark.pandas.tests.connect.test_parity_series_interpolate", "pyspark.pandas.tests.connect.resample.test_parity_frame", "pyspark.pandas.tests.connect.resample.test_parity_series", - "pyspark.pandas.tests.connect.test_parity_ewm", + "pyspark.pandas.tests.connect.window.test_parity_ewm_error", + "pyspark.pandas.tests.connect.window.test_parity_ewm_mean", + "pyspark.pandas.tests.connect.window.test_parity_groupby_ewm_mean", "pyspark.pandas.tests.connect.test_parity_rolling", "pyspark.pandas.tests.connect.test_parity_expanding", "pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_rolling", diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py b/python/pyspark/pandas/tests/connect/window/__init__.py similarity index 53% copy from python/pyspark/pandas/tests/connect/test_parity_ewm.py copy to python/pyspark/pandas/tests/connect/window/__init__.py index 748728203337..cce3acad34a4 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py +++ b/python/pyspark/pandas/tests/connect/window/__init__.py @@ -14,24 +14,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import unittest - -from pyspark.pandas.tests.test_ewm import EWMTestsMixin -from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils - - -class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase, TestUtils): - pass - - -if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_ewm import * # noqa: F401 - - try: - import xmlrunner # type: ignore[import] - - testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) - except ImportError: - testRunner = None - unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py b/python/pyspark/pandas/tests/connect/window/test_parity_ewm_error.py similarity index 81% copy from python/pyspark/pandas/tests/connect/test_parity_ewm.py copy to python/pyspark/pandas/tests/connect/window/test_parity_ewm_error.py index 748728203337..7f6b0e8494cf 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_ewm_error.py @@ -16,17 +16,22 @@ # import unittest -from pyspark.pandas.tests.test_ewm import EWMTestsMixin +from pyspark.pandas.tests.window.test_ewm_error import EWMErrorMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils -class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase, TestUtils): +class EWMParityErrorTests( + EWMErrorMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, + TestUtils, +): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_ewm import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_ewm_error import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py b/python/pyspark/pandas/tests/connect/window/test_parity_ewm_mean.py similarity index 81% copy from python/pyspark/pandas/tests/connect/test_parity_ewm.py copy to python/pyspark/pandas/tests/connect/window/test_parity_ewm_mean.py index 748728203337..8c7144799bce 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_ewm_mean.py @@ -16,17 +16,22 @@ # import unittest -from pyspark.pandas.tests.test_ewm import EWMTestsMixin +from pyspark.pandas.tests.window.test_ewm_mean import EWMMeanMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils -class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase, TestUtils): +class EWMParityMeanTests( + EWMMeanMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, + TestUtils, +): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_ewm import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_ewm_mean import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_ewm_mean.py similarity index 79% copy from python/pyspark/pandas/tests/connect/test_parity_ewm.py copy to python/pyspark/pandas/tests/connect/window/test_parity_groupby_ewm_mean.py index 748728203337..76254698b757 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py +++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_ewm_mean.py @@ -16,17 +16,22 @@ # import unittest -from pyspark.pandas.tests.test_ewm import EWMTestsMixin +from pyspark.pandas.tests.window.test_groupby_ewm_mean import GroupByEWMMeanMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils -class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase, TestUtils): +class EWMParityGroupByMeanTests( + GroupByEWMMeanMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, + TestUtils, +): pass if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_ewm import * # noqa: F401 + from pyspark.pandas.tests.connect.window.test_parity_groupby_ewm_mean import * # noqa: F401 try: import xmlrunner # type: ignore[import] diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py b/python/pyspark/pandas/tests/window/__init__.py similarity index 53% rename from python/pyspark/pandas/tests/connect/test_parity_ewm.py rename to python/pyspark/pandas/tests/window/__init__.py index 748728203337..cce3acad34a4 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py +++ b/python/pyspark/pandas/tests/window/__init__.py @@ -14,24 +14,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import unittest - -from pyspark.pandas.tests.test_ewm import EWMTestsMixin -from pyspark.testing.connectutils import ReusedConnectTestCase -from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils - - -class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils, ReusedConnectTestCase, TestUtils): - pass - - -if __name__ == "__main__": - from pyspark.pandas.tests.connect.test_parity_ewm import * # noqa: F401 - - try: - import xmlrunner # type: ignore[import] - - testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) - except ImportError: - testRunner = None - unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/window/test_ewm_error.py b/python/pyspark/pandas/tests/window/test_ewm_error.py new file mode 100644 index 000000000000..02018fb10617 --- /dev/null +++ b/python/pyspark/pandas/tests/window/test_ewm_error.py @@ -0,0 +1,97 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pyspark.pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils +from pyspark.pandas.window import ExponentialMoving + + +class EWMErrorMixin: + def test_ewm_error(self): + with self.assertRaisesRegex( + TypeError, "psdf_or_psser must be a series or dataframe; however, got:.*int" + ): + ExponentialMoving(1, 2) + + psdf = ps.range(10) + + with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"): + psdf.ewm(min_periods=-1, alpha=0.5).mean() + + with self.assertRaisesRegex(ValueError, "com must be >= 0"): + psdf.ewm(com=-0.1).mean() + + with self.assertRaisesRegex(ValueError, "span must be >= 1"): + psdf.ewm(span=0.7).mean() + + with self.assertRaisesRegex(ValueError, "halflife must be > 0"): + psdf.ewm(halflife=0).mean() + + with self.assertRaisesRegex(ValueError, "alpha must be in"): + psdf.ewm(alpha=1.7).mean() + + with self.assertRaisesRegex(ValueError, "Must pass one of com, span, halflife, or alpha"): + psdf.ewm().mean() + + with self.assertRaisesRegex( + ValueError, "com, span, halflife, and alpha are mutually exclusive" + ): + psdf.ewm(com=0.5, alpha=0.7).mean() + + with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"): + psdf.groupby(psdf.id).ewm(min_periods=-1, alpha=0.5).mean() + + with self.assertRaisesRegex(ValueError, "com must be >= 0"): + psdf.groupby(psdf.id).ewm(com=-0.1).mean() + + with self.assertRaisesRegex(ValueError, "span must be >= 1"): + psdf.groupby(psdf.id).ewm(span=0.7).mean() + + with self.assertRaisesRegex(ValueError, "halflife must be > 0"): + psdf.groupby(psdf.id).ewm(halflife=0).mean() + + with self.assertRaisesRegex(ValueError, "alpha must be in"): + psdf.groupby(psdf.id).ewm(alpha=1.7).mean() + + with self.assertRaisesRegex(ValueError, "Must pass one of com, span, halflife, or alpha"): + psdf.groupby(psdf.id).ewm().mean() + + with self.assertRaisesRegex( + ValueError, "com, span, halflife, and alpha are mutually exclusive" + ): + psdf.groupby(psdf.id).ewm(com=0.5, alpha=0.7).mean() + + +class EWMErrorTests( + EWMErrorMixin, + PandasOnSparkTestCase, + TestUtils, +): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.pandas.tests.window.test_ewm_error import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/window/test_ewm_mean.py b/python/pyspark/pandas/tests/window/test_ewm_mean.py new file mode 100644 index 000000000000..00750b867610 --- /dev/null +++ b/python/pyspark/pandas/tests/window/test_ewm_mean.py @@ -0,0 +1,194 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import numpy as np +import pandas as pd + +import pyspark.pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils + + +class EWMMeanMixin: + def _test_ewm_func(self, f): + pser = pd.Series([1, 2, 3], index=np.random.rand(3), name="a") + psser = ps.from_pandas(pser) + self.assert_eq(getattr(psser.ewm(com=0.2), f)(), getattr(pser.ewm(com=0.2), f)()) + self.assert_eq( + getattr(psser.ewm(com=0.2), f)().sum(), getattr(pser.ewm(com=0.2), f)().sum() + ) + self.assert_eq(getattr(psser.ewm(span=1.7), f)(), getattr(pser.ewm(span=1.7), f)()) + self.assert_eq( + getattr(psser.ewm(span=1.7), f)().sum(), getattr(pser.ewm(span=1.7), f)().sum() + ) + self.assert_eq(getattr(psser.ewm(halflife=0.5), f)(), getattr(pser.ewm(halflife=0.5), f)()) + self.assert_eq( + getattr(psser.ewm(halflife=0.5), f)().sum(), getattr(pser.ewm(halflife=0.5), f)().sum() + ) + self.assert_eq(getattr(psser.ewm(alpha=0.7), f)(), getattr(pser.ewm(alpha=0.7), f)()) + self.assert_eq( + getattr(psser.ewm(alpha=0.7), f)().sum(), getattr(pser.ewm(alpha=0.7), f)().sum() + ) + self.assert_eq( + getattr(psser.ewm(alpha=0.7, min_periods=2), f)(), + getattr(pser.ewm(alpha=0.7, min_periods=2), f)(), + ) + self.assert_eq( + getattr(psser.ewm(alpha=0.7, min_periods=2), f)().sum(), + getattr(pser.ewm(alpha=0.7, min_periods=2), f)().sum(), + ) + + pdf = pd.DataFrame( + {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(getattr(psdf.ewm(com=0.2), f)(), getattr(pdf.ewm(com=0.2), f)()) + self.assert_eq(getattr(psdf.ewm(com=0.2), f)().sum(), getattr(pdf.ewm(com=0.2), f)().sum()) + self.assert_eq(getattr(psdf.ewm(span=1.7), f)(), getattr(pdf.ewm(span=1.7), f)()) + self.assert_eq( + getattr(psdf.ewm(span=1.7), f)().sum(), getattr(pdf.ewm(span=1.7), f)().sum() + ) + self.assert_eq(getattr(psdf.ewm(halflife=0.5), f)(), getattr(pdf.ewm(halflife=0.5), f)()) + self.assert_eq( + getattr(psdf.ewm(halflife=0.5), f)().sum(), getattr(pdf.ewm(halflife=0.5), f)().sum() + ) + self.assert_eq(getattr(psdf.ewm(alpha=0.7), f)(), getattr(pdf.ewm(alpha=0.7), f)()) + self.assert_eq( + getattr(psdf.ewm(alpha=0.7), f)().sum(), getattr(pdf.ewm(alpha=0.7), f)().sum() + ) + self.assert_eq( + getattr(psdf.ewm(alpha=0.7, min_periods=2), f)(), + getattr(pdf.ewm(alpha=0.7, min_periods=2), f)(), + ) + self.assert_eq( + getattr(psdf.ewm(alpha=0.7, min_periods=2), f)().sum(), + getattr(pdf.ewm(alpha=0.7, min_periods=2), f)().sum(), + ) + + pdf = pd.DataFrame( + { + "s1": [None, 2, 3, 4], + "s2": [1, None, 3, 4], + "s3": [1, 3, 4, 5], + "s4": [1, 0, 3, 4], + "s5": [None, None, 1, None], + "s6": [None, None, None, None], + } + ) + psdf = ps.from_pandas(pdf) + self.assert_eq( + getattr(psdf.ewm(com=0.2, ignore_na=True), f)(), + getattr(pdf.ewm(com=0.2, ignore_na=True), f)(), + ) + self.assert_eq( + getattr(psdf.ewm(com=0.2, ignore_na=True), f)().sum(), + getattr(pdf.ewm(com=0.2, ignore_na=True), f)().sum(), + ) + self.assert_eq( + getattr(psdf.ewm(com=0.2, ignore_na=False), f)(), + getattr(pdf.ewm(com=0.2, ignore_na=False), f)(), + ) + self.assert_eq( + getattr(psdf.ewm(com=0.2, ignore_na=False), f)().sum(), + getattr(pdf.ewm(com=0.2, ignore_na=False), f)().sum(), + ) + self.assert_eq( + getattr(psdf.ewm(span=1.7, ignore_na=True), f)(), + getattr(pdf.ewm(span=1.7, ignore_na=True), f)(), + ) + self.assert_eq( + getattr(psdf.ewm(span=1.7, ignore_na=True), f)().sum(), + getattr(pdf.ewm(span=1.7, ignore_na=True), f)().sum(), + ) + self.assert_eq( + getattr(psdf.ewm(span=1.7, ignore_na=False), f)(), + getattr(pdf.ewm(span=1.7, ignore_na=False), f)(), + ) + self.assert_eq( + getattr(psdf.ewm(span=1.7, ignore_na=False), f)().sum(), + getattr(pdf.ewm(span=1.7, ignore_na=False), f)().sum(), + ) + self.assert_eq( + getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)(), + getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)(), + ) + self.assert_eq( + getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)().sum(), + getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)().sum(), + ) + self.assert_eq( + getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)(), + getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)(), + ) + self.assert_eq( + getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)().sum(), + getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)().sum(), + ) + self.assert_eq( + getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)(), + getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)(), + ) + self.assert_eq( + getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)().sum(), + getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)().sum(), + ) + self.assert_eq( + getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)(), + getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)(), + ) + self.assert_eq( + getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)().sum(), + getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)().sum(), + ) + self.assert_eq( + getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(), + getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(), + ) + self.assert_eq( + getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)().sum(), + getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)().sum(), + ) + self.assert_eq( + getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(), + getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(), + ) + self.assert_eq( + getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)().sum(), + getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)().sum(), + ) + + def test_ewm_mean(self): + self._test_ewm_func("mean") + + +class EWMMeanTests( + EWMMeanMixin, + PandasOnSparkTestCase, + TestUtils, +): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.pandas.tests.window.test_ewm_mean import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/test_ewm.py b/python/pyspark/pandas/tests/window/test_groupby_ewm_mean.py similarity index 52% rename from python/pyspark/pandas/tests/test_ewm.py rename to python/pyspark/pandas/tests/window/test_groupby_ewm_mean.py index a8886a0af69c..fb29cb8ea04f 100644 --- a/python/pyspark/pandas/tests/test_ewm.py +++ b/python/pyspark/pandas/tests/window/test_groupby_ewm_mean.py @@ -19,214 +19,9 @@ import pandas as pd import pyspark.pandas as ps from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils -from pyspark.pandas.window import ExponentialMoving -class EWMTestsMixin: - def test_ewm_error(self): - with self.assertRaisesRegex( - TypeError, "psdf_or_psser must be a series or dataframe; however, got:.*int" - ): - ExponentialMoving(1, 2) - - psdf = ps.range(10) - - with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"): - psdf.ewm(min_periods=-1, alpha=0.5).mean() - - with self.assertRaisesRegex(ValueError, "com must be >= 0"): - psdf.ewm(com=-0.1).mean() - - with self.assertRaisesRegex(ValueError, "span must be >= 1"): - psdf.ewm(span=0.7).mean() - - with self.assertRaisesRegex(ValueError, "halflife must be > 0"): - psdf.ewm(halflife=0).mean() - - with self.assertRaisesRegex(ValueError, "alpha must be in"): - psdf.ewm(alpha=1.7).mean() - - with self.assertRaisesRegex(ValueError, "Must pass one of com, span, halflife, or alpha"): - psdf.ewm().mean() - - with self.assertRaisesRegex( - ValueError, "com, span, halflife, and alpha are mutually exclusive" - ): - psdf.ewm(com=0.5, alpha=0.7).mean() - - with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"): - psdf.groupby(psdf.id).ewm(min_periods=-1, alpha=0.5).mean() - - with self.assertRaisesRegex(ValueError, "com must be >= 0"): - psdf.groupby(psdf.id).ewm(com=-0.1).mean() - - with self.assertRaisesRegex(ValueError, "span must be >= 1"): - psdf.groupby(psdf.id).ewm(span=0.7).mean() - - with self.assertRaisesRegex(ValueError, "halflife must be > 0"): - psdf.groupby(psdf.id).ewm(halflife=0).mean() - - with self.assertRaisesRegex(ValueError, "alpha must be in"): - psdf.groupby(psdf.id).ewm(alpha=1.7).mean() - - with self.assertRaisesRegex(ValueError, "Must pass one of com, span, halflife, or alpha"): - psdf.groupby(psdf.id).ewm().mean() - - with self.assertRaisesRegex( - ValueError, "com, span, halflife, and alpha are mutually exclusive" - ): - psdf.groupby(psdf.id).ewm(com=0.5, alpha=0.7).mean() - - def _test_ewm_func(self, f): - pser = pd.Series([1, 2, 3], index=np.random.rand(3), name="a") - psser = ps.from_pandas(pser) - self.assert_eq(getattr(psser.ewm(com=0.2), f)(), getattr(pser.ewm(com=0.2), f)()) - self.assert_eq( - getattr(psser.ewm(com=0.2), f)().sum(), getattr(pser.ewm(com=0.2), f)().sum() - ) - self.assert_eq(getattr(psser.ewm(span=1.7), f)(), getattr(pser.ewm(span=1.7), f)()) - self.assert_eq( - getattr(psser.ewm(span=1.7), f)().sum(), getattr(pser.ewm(span=1.7), f)().sum() - ) - self.assert_eq(getattr(psser.ewm(halflife=0.5), f)(), getattr(pser.ewm(halflife=0.5), f)()) - self.assert_eq( - getattr(psser.ewm(halflife=0.5), f)().sum(), getattr(pser.ewm(halflife=0.5), f)().sum() - ) - self.assert_eq(getattr(psser.ewm(alpha=0.7), f)(), getattr(pser.ewm(alpha=0.7), f)()) - self.assert_eq( - getattr(psser.ewm(alpha=0.7), f)().sum(), getattr(pser.ewm(alpha=0.7), f)().sum() - ) - self.assert_eq( - getattr(psser.ewm(alpha=0.7, min_periods=2), f)(), - getattr(pser.ewm(alpha=0.7, min_periods=2), f)(), - ) - self.assert_eq( - getattr(psser.ewm(alpha=0.7, min_periods=2), f)().sum(), - getattr(pser.ewm(alpha=0.7, min_periods=2), f)().sum(), - ) - - pdf = pd.DataFrame( - {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, index=np.random.rand(4) - ) - psdf = ps.from_pandas(pdf) - self.assert_eq(getattr(psdf.ewm(com=0.2), f)(), getattr(pdf.ewm(com=0.2), f)()) - self.assert_eq(getattr(psdf.ewm(com=0.2), f)().sum(), getattr(pdf.ewm(com=0.2), f)().sum()) - self.assert_eq(getattr(psdf.ewm(span=1.7), f)(), getattr(pdf.ewm(span=1.7), f)()) - self.assert_eq( - getattr(psdf.ewm(span=1.7), f)().sum(), getattr(pdf.ewm(span=1.7), f)().sum() - ) - self.assert_eq(getattr(psdf.ewm(halflife=0.5), f)(), getattr(pdf.ewm(halflife=0.5), f)()) - self.assert_eq( - getattr(psdf.ewm(halflife=0.5), f)().sum(), getattr(pdf.ewm(halflife=0.5), f)().sum() - ) - self.assert_eq(getattr(psdf.ewm(alpha=0.7), f)(), getattr(pdf.ewm(alpha=0.7), f)()) - self.assert_eq( - getattr(psdf.ewm(alpha=0.7), f)().sum(), getattr(pdf.ewm(alpha=0.7), f)().sum() - ) - self.assert_eq( - getattr(psdf.ewm(alpha=0.7, min_periods=2), f)(), - getattr(pdf.ewm(alpha=0.7, min_periods=2), f)(), - ) - self.assert_eq( - getattr(psdf.ewm(alpha=0.7, min_periods=2), f)().sum(), - getattr(pdf.ewm(alpha=0.7, min_periods=2), f)().sum(), - ) - - pdf = pd.DataFrame( - { - "s1": [None, 2, 3, 4], - "s2": [1, None, 3, 4], - "s3": [1, 3, 4, 5], - "s4": [1, 0, 3, 4], - "s5": [None, None, 1, None], - "s6": [None, None, None, None], - } - ) - psdf = ps.from_pandas(pdf) - self.assert_eq( - getattr(psdf.ewm(com=0.2, ignore_na=True), f)(), - getattr(pdf.ewm(com=0.2, ignore_na=True), f)(), - ) - self.assert_eq( - getattr(psdf.ewm(com=0.2, ignore_na=True), f)().sum(), - getattr(pdf.ewm(com=0.2, ignore_na=True), f)().sum(), - ) - self.assert_eq( - getattr(psdf.ewm(com=0.2, ignore_na=False), f)(), - getattr(pdf.ewm(com=0.2, ignore_na=False), f)(), - ) - self.assert_eq( - getattr(psdf.ewm(com=0.2, ignore_na=False), f)().sum(), - getattr(pdf.ewm(com=0.2, ignore_na=False), f)().sum(), - ) - self.assert_eq( - getattr(psdf.ewm(span=1.7, ignore_na=True), f)(), - getattr(pdf.ewm(span=1.7, ignore_na=True), f)(), - ) - self.assert_eq( - getattr(psdf.ewm(span=1.7, ignore_na=True), f)().sum(), - getattr(pdf.ewm(span=1.7, ignore_na=True), f)().sum(), - ) - self.assert_eq( - getattr(psdf.ewm(span=1.7, ignore_na=False), f)(), - getattr(pdf.ewm(span=1.7, ignore_na=False), f)(), - ) - self.assert_eq( - getattr(psdf.ewm(span=1.7, ignore_na=False), f)().sum(), - getattr(pdf.ewm(span=1.7, ignore_na=False), f)().sum(), - ) - self.assert_eq( - getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)(), - getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)(), - ) - self.assert_eq( - getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)().sum(), - getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)().sum(), - ) - self.assert_eq( - getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)(), - getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)(), - ) - self.assert_eq( - getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)().sum(), - getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)().sum(), - ) - self.assert_eq( - getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)(), - getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)(), - ) - self.assert_eq( - getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)().sum(), - getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)().sum(), - ) - self.assert_eq( - getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)(), - getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)(), - ) - self.assert_eq( - getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)().sum(), - getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)().sum(), - ) - self.assert_eq( - getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(), - getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(), - ) - self.assert_eq( - getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)().sum(), - getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)().sum(), - ) - self.assert_eq( - getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(), - getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(), - ) - self.assert_eq( - getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)().sum(), - getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)().sum(), - ) - - def test_ewm_mean(self): - self._test_ewm_func("mean") - +class GroupByEWMMeanMixin: def _test_groupby_ewm_func(self, f): pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a") psser = ps.from_pandas(pser) @@ -417,13 +212,17 @@ class EWMTestsMixin: self._test_groupby_ewm_func("mean") -class EWMTests(EWMTestsMixin, PandasOnSparkTestCase, TestUtils): +class GroupByEWMMeanTests( + GroupByEWMMeanMixin, + PandasOnSparkTestCase, + TestUtils, +): pass if __name__ == "__main__": import unittest - from pyspark.pandas.tests.test_ewm import * # noqa: F401 + from pyspark.pandas.tests.window.test_groupby_ewm_mean import * # noqa: F401 try: import xmlrunner --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org