This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 15f25cb  [SPARK-37374][PYTHON] Fix StatCounter to use mergeStats when merging with self
15f25cb is described below

commit 15f25cbb39c1b7945c9ffacb164594583c79e1a0
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Fri Nov 19 11:15:35 2021 +0900

    [SPARK-37374][PYTHON] Fix StatCounter to use mergeStats when merging with self
    
    ### What changes were proposed in this pull request?
    
    Fixes `StatCounter` to use `mergeStats` instead of `merge` when merging with `self`.
    
    ### Why are the changes needed?
    
    `StatCounter` should use `mergeStats` instead of `merge` when merging with `self`.
    
    ```py
    >>> from pyspark.statcounter import StatCounter
    >>> stats = StatCounter([1.0, 2.0, 3.0, 4.0])
    >>> stats.mergeStats(stats)
    Traceback (most recent call last):
    
    ...
    
    TypeError: unsupported operand type(s) for -: 'StatCounter' and 'float'
    ```
    
    This is a long standing bug but usually this bug won't be hit unless users explicitly use `mergeStats` with `self`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added some tests.
    
    Closes #34653 from ueshin/issues/SPARK-37374/statcounter.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 dev/sparktestsupport/modules.py          |   1 +
 python/pyspark/statcounter.py            |   2 +-
 python/pyspark/tests/test_statcounter.py | 104 +++++++++++++++++++++++++++++++
 3 files changed, 106 insertions(+), 1 deletion(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 7d3ebb0..d13be2e 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -402,6 +402,7 @@ pyspark_core = Module(
         "pyspark.tests.test_readwrite",
         "pyspark.tests.test_serializers",
         "pyspark.tests.test_shuffle",
+        "pyspark.tests.test_statcounter",
         "pyspark.tests.test_taskcontext",
         "pyspark.tests.test_util",
         "pyspark.tests.test_worker",
diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py
index bf40281..a994671 100644
--- a/python/pyspark/statcounter.py
+++ b/python/pyspark/statcounter.py
@@ -59,7 +59,7 @@ class StatCounter(object):
             raise TypeError("Can only merge StatCounter but got %s" % type(other))
 
         if other is self:  # reference equality holds
 -            self.merge(copy.deepcopy(other))  # type: ignore[arg-type]  # Avoid overwriting fields in a weird order
 +            self.mergeStats(other.copy())  # Avoid overwriting fields in a weird order
         else:
             if self.n == 0:
                 self.mu = other.mu
diff --git a/python/pyspark/tests/test_statcounter.py b/python/pyspark/tests/test_statcounter.py
new file mode 100644
index 0000000..9651871
--- /dev/null
+++ b/python/pyspark/tests/test_statcounter.py
@@ -0,0 +1,104 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.statcounter import StatCounter
+from pyspark.testing.utils import ReusedPySparkTestCase
+
+
+class StatCounterTests(ReusedPySparkTestCase):
+    def test_base(self):
+        stats = self.sc.parallelize([1.0, 2.0, 3.0, 4.0]).stats()
+        self.assertEqual(stats.count(), 4)
+        self.assertEqual(stats.max(), 4.0)
+        self.assertEqual(stats.mean(), 2.5)
+        self.assertEqual(stats.min(), 1.0)
+        self.assertAlmostEqual(stats.stdev(), 1.118033988749895)
+        self.assertAlmostEqual(stats.sampleStdev(), 1.2909944487358056)
+        self.assertEqual(stats.sum(), 10.0)
+        self.assertAlmostEqual(stats.variance(), 1.25)
+        self.assertAlmostEqual(stats.sampleVariance(), 1.6666666666666667)
+
+    def test_as_dict(self):
+        stats = self.sc.parallelize([1.0, 2.0, 3.0, 4.0]).stats().asDict()
+        self.assertEqual(stats["count"], 4)
+        self.assertEqual(stats["max"], 4.0)
+        self.assertEqual(stats["mean"], 2.5)
+        self.assertEqual(stats["min"], 1.0)
+        self.assertAlmostEqual(stats["stdev"], 1.2909944487358056)
+        self.assertEqual(stats["sum"], 10.0)
+        self.assertAlmostEqual(stats["variance"], 1.6666666666666667)
+
+        stats = self.sc.parallelize([1.0, 2.0, 3.0, 4.0]).stats().asDict(sample=True)
+        self.assertEqual(stats["count"], 4)
+        self.assertEqual(stats["max"], 4.0)
+        self.assertEqual(stats["mean"], 2.5)
+        self.assertEqual(stats["min"], 1.0)
+        self.assertAlmostEqual(stats["stdev"], 1.118033988749895)
+        self.assertEqual(stats["sum"], 10.0)
+        self.assertAlmostEqual(stats["variance"], 1.25)
+
+    def test_merge(self):
+        stats = StatCounter([1.0, 2.0, 3.0, 4.0])
+        stats.merge(5.0)
+        self.assertEqual(stats.count(), 5)
+        self.assertEqual(stats.max(), 5.0)
+        self.assertEqual(stats.mean(), 3.0)
+        self.assertEqual(stats.min(), 1.0)
+        self.assertAlmostEqual(stats.stdev(), 1.414213562373095)
+        self.assertAlmostEqual(stats.sampleStdev(), 1.5811388300841898)
+        self.assertEqual(stats.sum(), 15.0)
+        self.assertAlmostEqual(stats.variance(), 2.0)
+        self.assertAlmostEqual(stats.sampleVariance(), 2.5)
+
+    def test_merge_stats(self):
+        stats1 = StatCounter([1.0, 2.0, 3.0, 4.0])
+        stats2 = StatCounter([1.0, 2.0, 3.0, 4.0])
+        stats = stats1.mergeStats(stats2)
+        self.assertEqual(stats.count(), 8)
+        self.assertEqual(stats.max(), 4.0)
+        self.assertEqual(stats.mean(), 2.5)
+        self.assertEqual(stats.min(), 1.0)
+        self.assertAlmostEqual(stats.stdev(), 1.118033988749895)
+        self.assertAlmostEqual(stats.sampleStdev(), 1.1952286093343936)
+        self.assertEqual(stats.sum(), 20.0)
+        self.assertAlmostEqual(stats.variance(), 1.25)
+        self.assertAlmostEqual(stats.sampleVariance(), 1.4285714285714286)
+
+    def test_merge_stats_with_self(self):
+        stats = StatCounter([1.0, 2.0, 3.0, 4.0])
+        stats.mergeStats(stats)
+        self.assertEqual(stats.count(), 8)
+        self.assertEqual(stats.max(), 4.0)
+        self.assertEqual(stats.mean(), 2.5)
+        self.assertEqual(stats.min(), 1.0)
+        self.assertAlmostEqual(stats.stdev(), 1.118033988749895)
+        self.assertAlmostEqual(stats.sampleStdev(), 1.1952286093343936)
+        self.assertEqual(stats.sum(), 20.0)
+        self.assertAlmostEqual(stats.variance(), 1.25)
+        self.assertAlmostEqual(stats.sampleVariance(), 1.4285714285714286)
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.tests.test_statcounter import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to