This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 55d6b51af7cd [SPARK-46340][PS][CONNECT][TESTS] Reorganize `EWMTests`
55d6b51af7cd is described below

commit 55d6b51af7cd9108752eea65e7eef13da01118e8
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sun Dec 10 08:53:19 2023 +0800

    [SPARK-46340][PS][CONNECT][TESTS] Reorganize `EWMTests`
    
    ### What changes were proposed in this pull request?
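    Reorganize `EWMTests` by splitting it into three test modules under a new
    `window` package: `test_ewm_error`, `test_ewm_mean`, and
    `test_groupby_ewm_mean`, each with a matching Spark Connect parity test.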
    
    ### Why are the changes needed?
    Break the tests into smaller files, consistent with the layout of the pandas
    tests (see https://github.com/pandas-dev/pandas/tree/main/pandas/tests/window).
    
    ### Does this PR introduce _any_ user-facing change?
    No, test-only.
    
    ### How was this patch tested?
    CI. The new test modules are registered in `dev/sparktestsupport/modules.py`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #44273 from zhengruifeng/ps_test_ewm.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 dev/sparktestsupport/modules.py                    |   8 +-
 .../{test_parity_ewm.py => window/__init__.py}     |  21 --
 .../test_parity_ewm_error.py}                      |  11 +-
 .../test_parity_ewm_mean.py}                       |  11 +-
 .../test_parity_groupby_ewm_mean.py}               |  11 +-
 .../test_parity_ewm.py => window/__init__.py}      |  21 --
 .../pyspark/pandas/tests/window/test_ewm_error.py  |  97 ++++++++++
 .../pyspark/pandas/tests/window/test_ewm_mean.py   | 194 +++++++++++++++++++
 .../test_groupby_ewm_mean.py}                      | 215 +--------------------
 9 files changed, 328 insertions(+), 261 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index e67cfce0f5c0..ca35fdabc0c4 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -729,7 +729,9 @@ pyspark_pandas = Module(
         "pyspark.pandas.tests.test_default_index",
         "pyspark.pandas.tests.test_expanding",
         "pyspark.pandas.tests.test_extension",
-        "pyspark.pandas.tests.test_ewm",
+        "pyspark.pandas.tests.window.test_ewm_error",
+        "pyspark.pandas.tests.window.test_ewm_mean",
+        "pyspark.pandas.tests.window.test_groupby_ewm_mean",
         "pyspark.pandas.tests.test_frame_spark",
         "pyspark.pandas.tests.test_generic_functions",
         "pyspark.pandas.tests.test_frame_interpolate",
@@ -1113,7 +1115,9 @@ pyspark_pandas_connect_part2 = Module(
         "pyspark.pandas.tests.connect.test_parity_series_interpolate",
         "pyspark.pandas.tests.connect.resample.test_parity_frame",
         "pyspark.pandas.tests.connect.resample.test_parity_series",
-        "pyspark.pandas.tests.connect.test_parity_ewm",
+        "pyspark.pandas.tests.connect.window.test_parity_ewm_error",
+        "pyspark.pandas.tests.connect.window.test_parity_ewm_mean",
+        "pyspark.pandas.tests.connect.window.test_parity_groupby_ewm_mean",
         "pyspark.pandas.tests.connect.test_parity_rolling",
         "pyspark.pandas.tests.connect.test_parity_expanding",
         
"pyspark.pandas.tests.connect.test_parity_ops_on_diff_frames_groupby_rolling",
diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py 
b/python/pyspark/pandas/tests/connect/window/__init__.py
similarity index 53%
copy from python/pyspark/pandas/tests/connect/test_parity_ewm.py
copy to python/pyspark/pandas/tests/connect/window/__init__.py
index 748728203337..cce3acad34a4 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
+++ b/python/pyspark/pandas/tests/connect/window/__init__.py
@@ -14,24 +14,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import unittest
-
-from pyspark.pandas.tests.test_ewm import EWMTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
-
-
-class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils, 
ReusedConnectTestCase, TestUtils):
-    pass
-
-
-if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.test_parity_ewm import *  # noqa: F401
-
-    try:
-        import xmlrunner  # type: ignore[import]
-
-        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
-    except ImportError:
-        testRunner = None
-    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py 
b/python/pyspark/pandas/tests/connect/window/test_parity_ewm_error.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/test_parity_ewm.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_ewm_error.py
index 748728203337..7f6b0e8494cf 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_ewm_error.py
@@ -16,17 +16,22 @@
 #
 import unittest
 
-from pyspark.pandas.tests.test_ewm import EWMTestsMixin
+from pyspark.pandas.tests.window.test_ewm_error import EWMErrorMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
 
 
-class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils, 
ReusedConnectTestCase, TestUtils):
+class EWMParityErrorTests(
+    EWMErrorMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+    TestUtils,
+):
     pass
 
 
 if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.test_parity_ewm import *  # noqa: F401
+    from pyspark.pandas.tests.connect.window.test_parity_ewm_error import *  # 
noqa: F401
 
     try:
         import xmlrunner  # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py 
b/python/pyspark/pandas/tests/connect/window/test_parity_ewm_mean.py
similarity index 81%
copy from python/pyspark/pandas/tests/connect/test_parity_ewm.py
copy to python/pyspark/pandas/tests/connect/window/test_parity_ewm_mean.py
index 748728203337..8c7144799bce 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_ewm_mean.py
@@ -16,17 +16,22 @@
 #
 import unittest
 
-from pyspark.pandas.tests.test_ewm import EWMTestsMixin
+from pyspark.pandas.tests.window.test_ewm_mean import EWMMeanMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
 
 
-class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils, 
ReusedConnectTestCase, TestUtils):
+class EWMParityMeanTests(
+    EWMMeanMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+    TestUtils,
+):
     pass
 
 
 if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.test_parity_ewm import *  # noqa: F401
+    from pyspark.pandas.tests.connect.window.test_parity_ewm_mean import *  # 
noqa: F401
 
     try:
         import xmlrunner  # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py 
b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_ewm_mean.py
similarity index 79%
copy from python/pyspark/pandas/tests/connect/test_parity_ewm.py
copy to 
python/pyspark/pandas/tests/connect/window/test_parity_groupby_ewm_mean.py
index 748728203337..76254698b757 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
+++ b/python/pyspark/pandas/tests/connect/window/test_parity_groupby_ewm_mean.py
@@ -16,17 +16,22 @@
 #
 import unittest
 
-from pyspark.pandas.tests.test_ewm import EWMTestsMixin
+from pyspark.pandas.tests.window.test_groupby_ewm_mean import 
GroupByEWMMeanMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
 
 
-class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils, 
ReusedConnectTestCase, TestUtils):
+class EWMParityGroupByMeanTests(
+    GroupByEWMMeanMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+    TestUtils,
+):
     pass
 
 
 if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.test_parity_ewm import *  # noqa: F401
+    from pyspark.pandas.tests.connect.window.test_parity_groupby_ewm_mean 
import *  # noqa: F401
 
     try:
         import xmlrunner  # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/connect/test_parity_ewm.py 
b/python/pyspark/pandas/tests/window/__init__.py
similarity index 53%
rename from python/pyspark/pandas/tests/connect/test_parity_ewm.py
rename to python/pyspark/pandas/tests/window/__init__.py
index 748728203337..cce3acad34a4 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_ewm.py
+++ b/python/pyspark/pandas/tests/window/__init__.py
@@ -14,24 +14,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import unittest
-
-from pyspark.pandas.tests.test_ewm import EWMTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
-
-
-class EWMParityTests(EWMTestsMixin, PandasOnSparkTestUtils, 
ReusedConnectTestCase, TestUtils):
-    pass
-
-
-if __name__ == "__main__":
-    from pyspark.pandas.tests.connect.test_parity_ewm import *  # noqa: F401
-
-    try:
-        import xmlrunner  # type: ignore[import]
-
-        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
-    except ImportError:
-        testRunner = None
-    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/window/test_ewm_error.py 
b/python/pyspark/pandas/tests/window/test_ewm_error.py
new file mode 100644
index 000000000000..02018fb10617
--- /dev/null
+++ b/python/pyspark/pandas/tests/window/test_ewm_error.py
@@ -0,0 +1,97 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pyspark.pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+from pyspark.pandas.window import ExponentialMoving
+
+
+class EWMErrorMixin:
+    def test_ewm_error(self):
+        with self.assertRaisesRegex(
+            TypeError, "psdf_or_psser must be a series or dataframe; however, 
got:.*int"
+        ):
+            ExponentialMoving(1, 2)
+
+        psdf = ps.range(10)
+
+        with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
+            psdf.ewm(min_periods=-1, alpha=0.5).mean()
+
+        with self.assertRaisesRegex(ValueError, "com must be >= 0"):
+            psdf.ewm(com=-0.1).mean()
+
+        with self.assertRaisesRegex(ValueError, "span must be >= 1"):
+            psdf.ewm(span=0.7).mean()
+
+        with self.assertRaisesRegex(ValueError, "halflife must be > 0"):
+            psdf.ewm(halflife=0).mean()
+
+        with self.assertRaisesRegex(ValueError, "alpha must be in"):
+            psdf.ewm(alpha=1.7).mean()
+
+        with self.assertRaisesRegex(ValueError, "Must pass one of com, span, 
halflife, or alpha"):
+            psdf.ewm().mean()
+
+        with self.assertRaisesRegex(
+            ValueError, "com, span, halflife, and alpha are mutually exclusive"
+        ):
+            psdf.ewm(com=0.5, alpha=0.7).mean()
+
+        with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
+            psdf.groupby(psdf.id).ewm(min_periods=-1, alpha=0.5).mean()
+
+        with self.assertRaisesRegex(ValueError, "com must be >= 0"):
+            psdf.groupby(psdf.id).ewm(com=-0.1).mean()
+
+        with self.assertRaisesRegex(ValueError, "span must be >= 1"):
+            psdf.groupby(psdf.id).ewm(span=0.7).mean()
+
+        with self.assertRaisesRegex(ValueError, "halflife must be > 0"):
+            psdf.groupby(psdf.id).ewm(halflife=0).mean()
+
+        with self.assertRaisesRegex(ValueError, "alpha must be in"):
+            psdf.groupby(psdf.id).ewm(alpha=1.7).mean()
+
+        with self.assertRaisesRegex(ValueError, "Must pass one of com, span, 
halflife, or alpha"):
+            psdf.groupby(psdf.id).ewm().mean()
+
+        with self.assertRaisesRegex(
+            ValueError, "com, span, halflife, and alpha are mutually exclusive"
+        ):
+            psdf.groupby(psdf.id).ewm(com=0.5, alpha=0.7).mean()
+
+
+class EWMErrorTests(
+    EWMErrorMixin,
+    PandasOnSparkTestCase,
+    TestUtils,
+):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.pandas.tests.window.test_ewm_error import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/window/test_ewm_mean.py 
b/python/pyspark/pandas/tests/window/test_ewm_mean.py
new file mode 100644
index 000000000000..00750b867610
--- /dev/null
+++ b/python/pyspark/pandas/tests/window/test_ewm_mean.py
@@ -0,0 +1,194 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import numpy as np
+import pandas as pd
+
+import pyspark.pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+
+
+class EWMMeanMixin:
+    def _test_ewm_func(self, f):
+        pser = pd.Series([1, 2, 3], index=np.random.rand(3), name="a")
+        psser = ps.from_pandas(pser)
+        self.assert_eq(getattr(psser.ewm(com=0.2), f)(), 
getattr(pser.ewm(com=0.2), f)())
+        self.assert_eq(
+            getattr(psser.ewm(com=0.2), f)().sum(), getattr(pser.ewm(com=0.2), 
f)().sum()
+        )
+        self.assert_eq(getattr(psser.ewm(span=1.7), f)(), 
getattr(pser.ewm(span=1.7), f)())
+        self.assert_eq(
+            getattr(psser.ewm(span=1.7), f)().sum(), 
getattr(pser.ewm(span=1.7), f)().sum()
+        )
+        self.assert_eq(getattr(psser.ewm(halflife=0.5), f)(), 
getattr(pser.ewm(halflife=0.5), f)())
+        self.assert_eq(
+            getattr(psser.ewm(halflife=0.5), f)().sum(), 
getattr(pser.ewm(halflife=0.5), f)().sum()
+        )
+        self.assert_eq(getattr(psser.ewm(alpha=0.7), f)(), 
getattr(pser.ewm(alpha=0.7), f)())
+        self.assert_eq(
+            getattr(psser.ewm(alpha=0.7), f)().sum(), 
getattr(pser.ewm(alpha=0.7), f)().sum()
+        )
+        self.assert_eq(
+            getattr(psser.ewm(alpha=0.7, min_periods=2), f)(),
+            getattr(pser.ewm(alpha=0.7, min_periods=2), f)(),
+        )
+        self.assert_eq(
+            getattr(psser.ewm(alpha=0.7, min_periods=2), f)().sum(),
+            getattr(pser.ewm(alpha=0.7, min_periods=2), f)().sum(),
+        )
+
+        pdf = pd.DataFrame(
+            {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, 
index=np.random.rand(4)
+        )
+        psdf = ps.from_pandas(pdf)
+        self.assert_eq(getattr(psdf.ewm(com=0.2), f)(), 
getattr(pdf.ewm(com=0.2), f)())
+        self.assert_eq(getattr(psdf.ewm(com=0.2), f)().sum(), 
getattr(pdf.ewm(com=0.2), f)().sum())
+        self.assert_eq(getattr(psdf.ewm(span=1.7), f)(), 
getattr(pdf.ewm(span=1.7), f)())
+        self.assert_eq(
+            getattr(psdf.ewm(span=1.7), f)().sum(), getattr(pdf.ewm(span=1.7), 
f)().sum()
+        )
+        self.assert_eq(getattr(psdf.ewm(halflife=0.5), f)(), 
getattr(pdf.ewm(halflife=0.5), f)())
+        self.assert_eq(
+            getattr(psdf.ewm(halflife=0.5), f)().sum(), 
getattr(pdf.ewm(halflife=0.5), f)().sum()
+        )
+        self.assert_eq(getattr(psdf.ewm(alpha=0.7), f)(), 
getattr(pdf.ewm(alpha=0.7), f)())
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7), f)().sum(), 
getattr(pdf.ewm(alpha=0.7), f)().sum()
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, min_periods=2), f)(),
+            getattr(pdf.ewm(alpha=0.7, min_periods=2), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, min_periods=2), f)().sum(),
+            getattr(pdf.ewm(alpha=0.7, min_periods=2), f)().sum(),
+        )
+
+        pdf = pd.DataFrame(
+            {
+                "s1": [None, 2, 3, 4],
+                "s2": [1, None, 3, 4],
+                "s3": [1, 3, 4, 5],
+                "s4": [1, 0, 3, 4],
+                "s5": [None, None, 1, None],
+                "s6": [None, None, None, None],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+        self.assert_eq(
+            getattr(psdf.ewm(com=0.2, ignore_na=True), f)(),
+            getattr(pdf.ewm(com=0.2, ignore_na=True), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(com=0.2, ignore_na=True), f)().sum(),
+            getattr(pdf.ewm(com=0.2, ignore_na=True), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(com=0.2, ignore_na=False), f)(),
+            getattr(pdf.ewm(com=0.2, ignore_na=False), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(com=0.2, ignore_na=False), f)().sum(),
+            getattr(pdf.ewm(com=0.2, ignore_na=False), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(span=1.7, ignore_na=True), f)(),
+            getattr(pdf.ewm(span=1.7, ignore_na=True), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(span=1.7, ignore_na=True), f)().sum(),
+            getattr(pdf.ewm(span=1.7, ignore_na=True), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(span=1.7, ignore_na=False), f)(),
+            getattr(pdf.ewm(span=1.7, ignore_na=False), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(span=1.7, ignore_na=False), f)().sum(),
+            getattr(pdf.ewm(span=1.7, ignore_na=False), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)(),
+            getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)().sum(),
+            getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)(),
+            getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)().sum(),
+            getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)().sum(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)().sum(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), 
f)().sum(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), 
f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), 
f)().sum(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), 
f)().sum(),
+        )
+
+    def test_ewm_mean(self):
+        self._test_ewm_func("mean")
+
+
+class EWMMeanTests(
+    EWMMeanMixin,
+    PandasOnSparkTestCase,
+    TestUtils,
+):
+    pass
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.pandas.tests.window.test_ewm_mean import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/test_ewm.py 
b/python/pyspark/pandas/tests/window/test_groupby_ewm_mean.py
similarity index 52%
rename from python/pyspark/pandas/tests/test_ewm.py
rename to python/pyspark/pandas/tests/window/test_groupby_ewm_mean.py
index a8886a0af69c..fb29cb8ea04f 100644
--- a/python/pyspark/pandas/tests/test_ewm.py
+++ b/python/pyspark/pandas/tests/window/test_groupby_ewm_mean.py
@@ -19,214 +19,9 @@ import pandas as pd
 
 import pyspark.pandas as ps
 from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
-from pyspark.pandas.window import ExponentialMoving
 
 
-class EWMTestsMixin:
-    def test_ewm_error(self):
-        with self.assertRaisesRegex(
-            TypeError, "psdf_or_psser must be a series or dataframe; however, 
got:.*int"
-        ):
-            ExponentialMoving(1, 2)
-
-        psdf = ps.range(10)
-
-        with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
-            psdf.ewm(min_periods=-1, alpha=0.5).mean()
-
-        with self.assertRaisesRegex(ValueError, "com must be >= 0"):
-            psdf.ewm(com=-0.1).mean()
-
-        with self.assertRaisesRegex(ValueError, "span must be >= 1"):
-            psdf.ewm(span=0.7).mean()
-
-        with self.assertRaisesRegex(ValueError, "halflife must be > 0"):
-            psdf.ewm(halflife=0).mean()
-
-        with self.assertRaisesRegex(ValueError, "alpha must be in"):
-            psdf.ewm(alpha=1.7).mean()
-
-        with self.assertRaisesRegex(ValueError, "Must pass one of com, span, 
halflife, or alpha"):
-            psdf.ewm().mean()
-
-        with self.assertRaisesRegex(
-            ValueError, "com, span, halflife, and alpha are mutually exclusive"
-        ):
-            psdf.ewm(com=0.5, alpha=0.7).mean()
-
-        with self.assertRaisesRegex(ValueError, "min_periods must be >= 0"):
-            psdf.groupby(psdf.id).ewm(min_periods=-1, alpha=0.5).mean()
-
-        with self.assertRaisesRegex(ValueError, "com must be >= 0"):
-            psdf.groupby(psdf.id).ewm(com=-0.1).mean()
-
-        with self.assertRaisesRegex(ValueError, "span must be >= 1"):
-            psdf.groupby(psdf.id).ewm(span=0.7).mean()
-
-        with self.assertRaisesRegex(ValueError, "halflife must be > 0"):
-            psdf.groupby(psdf.id).ewm(halflife=0).mean()
-
-        with self.assertRaisesRegex(ValueError, "alpha must be in"):
-            psdf.groupby(psdf.id).ewm(alpha=1.7).mean()
-
-        with self.assertRaisesRegex(ValueError, "Must pass one of com, span, 
halflife, or alpha"):
-            psdf.groupby(psdf.id).ewm().mean()
-
-        with self.assertRaisesRegex(
-            ValueError, "com, span, halflife, and alpha are mutually exclusive"
-        ):
-            psdf.groupby(psdf.id).ewm(com=0.5, alpha=0.7).mean()
-
-    def _test_ewm_func(self, f):
-        pser = pd.Series([1, 2, 3], index=np.random.rand(3), name="a")
-        psser = ps.from_pandas(pser)
-        self.assert_eq(getattr(psser.ewm(com=0.2), f)(), 
getattr(pser.ewm(com=0.2), f)())
-        self.assert_eq(
-            getattr(psser.ewm(com=0.2), f)().sum(), getattr(pser.ewm(com=0.2), 
f)().sum()
-        )
-        self.assert_eq(getattr(psser.ewm(span=1.7), f)(), 
getattr(pser.ewm(span=1.7), f)())
-        self.assert_eq(
-            getattr(psser.ewm(span=1.7), f)().sum(), 
getattr(pser.ewm(span=1.7), f)().sum()
-        )
-        self.assert_eq(getattr(psser.ewm(halflife=0.5), f)(), 
getattr(pser.ewm(halflife=0.5), f)())
-        self.assert_eq(
-            getattr(psser.ewm(halflife=0.5), f)().sum(), 
getattr(pser.ewm(halflife=0.5), f)().sum()
-        )
-        self.assert_eq(getattr(psser.ewm(alpha=0.7), f)(), 
getattr(pser.ewm(alpha=0.7), f)())
-        self.assert_eq(
-            getattr(psser.ewm(alpha=0.7), f)().sum(), 
getattr(pser.ewm(alpha=0.7), f)().sum()
-        )
-        self.assert_eq(
-            getattr(psser.ewm(alpha=0.7, min_periods=2), f)(),
-            getattr(pser.ewm(alpha=0.7, min_periods=2), f)(),
-        )
-        self.assert_eq(
-            getattr(psser.ewm(alpha=0.7, min_periods=2), f)().sum(),
-            getattr(pser.ewm(alpha=0.7, min_periods=2), f)().sum(),
-        )
-
-        pdf = pd.DataFrame(
-            {"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}, 
index=np.random.rand(4)
-        )
-        psdf = ps.from_pandas(pdf)
-        self.assert_eq(getattr(psdf.ewm(com=0.2), f)(), 
getattr(pdf.ewm(com=0.2), f)())
-        self.assert_eq(getattr(psdf.ewm(com=0.2), f)().sum(), 
getattr(pdf.ewm(com=0.2), f)().sum())
-        self.assert_eq(getattr(psdf.ewm(span=1.7), f)(), 
getattr(pdf.ewm(span=1.7), f)())
-        self.assert_eq(
-            getattr(psdf.ewm(span=1.7), f)().sum(), getattr(pdf.ewm(span=1.7), 
f)().sum()
-        )
-        self.assert_eq(getattr(psdf.ewm(halflife=0.5), f)(), 
getattr(pdf.ewm(halflife=0.5), f)())
-        self.assert_eq(
-            getattr(psdf.ewm(halflife=0.5), f)().sum(), 
getattr(pdf.ewm(halflife=0.5), f)().sum()
-        )
-        self.assert_eq(getattr(psdf.ewm(alpha=0.7), f)(), 
getattr(pdf.ewm(alpha=0.7), f)())
-        self.assert_eq(
-            getattr(psdf.ewm(alpha=0.7), f)().sum(), 
getattr(pdf.ewm(alpha=0.7), f)().sum()
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(alpha=0.7, min_periods=2), f)(),
-            getattr(pdf.ewm(alpha=0.7, min_periods=2), f)(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(alpha=0.7, min_periods=2), f)().sum(),
-            getattr(pdf.ewm(alpha=0.7, min_periods=2), f)().sum(),
-        )
-
-        pdf = pd.DataFrame(
-            {
-                "s1": [None, 2, 3, 4],
-                "s2": [1, None, 3, 4],
-                "s3": [1, 3, 4, 5],
-                "s4": [1, 0, 3, 4],
-                "s5": [None, None, 1, None],
-                "s6": [None, None, None, None],
-            }
-        )
-        psdf = ps.from_pandas(pdf)
-        self.assert_eq(
-            getattr(psdf.ewm(com=0.2, ignore_na=True), f)(),
-            getattr(pdf.ewm(com=0.2, ignore_na=True), f)(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(com=0.2, ignore_na=True), f)().sum(),
-            getattr(pdf.ewm(com=0.2, ignore_na=True), f)().sum(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(com=0.2, ignore_na=False), f)(),
-            getattr(pdf.ewm(com=0.2, ignore_na=False), f)(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(com=0.2, ignore_na=False), f)().sum(),
-            getattr(pdf.ewm(com=0.2, ignore_na=False), f)().sum(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(span=1.7, ignore_na=True), f)(),
-            getattr(pdf.ewm(span=1.7, ignore_na=True), f)(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(span=1.7, ignore_na=True), f)().sum(),
-            getattr(pdf.ewm(span=1.7, ignore_na=True), f)().sum(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(span=1.7, ignore_na=False), f)(),
-            getattr(pdf.ewm(span=1.7, ignore_na=False), f)(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(span=1.7, ignore_na=False), f)().sum(),
-            getattr(pdf.ewm(span=1.7, ignore_na=False), f)().sum(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)(),
-            getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)().sum(),
-            getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)().sum(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)(),
-            getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)().sum(),
-            getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)().sum(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)(),
-            getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)().sum(),
-            getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)().sum(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)(),
-            getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)().sum(),
-            getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)().sum(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(),
-            getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), 
f)().sum(),
-            getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), 
f)().sum(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(),
-            getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(),
-        )
-        self.assert_eq(
-            getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), 
f)().sum(),
-            getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), 
f)().sum(),
-        )
-
-    def test_ewm_mean(self):
-        self._test_ewm_func("mean")
-
+class GroupByEWMMeanMixin:
     def _test_groupby_ewm_func(self, f):
         pser = pd.Series([1, 2, 3, 2], index=np.random.rand(4), name="a")
         psser = ps.from_pandas(pser)
@@ -417,13 +212,17 @@ class EWMTestsMixin:
         self._test_groupby_ewm_func("mean")
 
 
-class EWMTests(EWMTestsMixin, PandasOnSparkTestCase, TestUtils):
+class GroupByEWMMeanTests(
+    GroupByEWMMeanMixin,
+    PandasOnSparkTestCase,
+    TestUtils,
+):
     pass
 
 
 if __name__ == "__main__":
     import unittest
-    from pyspark.pandas.tests.test_ewm import *  # noqa: F401
+    from pyspark.pandas.tests.window.test_groupby_ewm_mean import *  # noqa: 
F401
 
     try:
         import xmlrunner


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to