This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 15f25cb [SPARK-37374][PYTHON] Fix StatCounter to use mergeStats when merging with self 15f25cb is described below commit 15f25cbb39c1b7945c9ffacb164594583c79e1a0 Author: Takuya UESHIN <ues...@databricks.com> AuthorDate: Fri Nov 19 11:15:35 2021 +0900 [SPARK-37374][PYTHON] Fix StatCounter to use mergeStats when merging with self ### What changes were proposed in this pull request? Fixes `StatCounter` to use `mergeStats` instead of `merge` when merging with `self`. ### Why are the changes needed? `StatCounter` should use `mergeStats` instead of `merge` when merging with `self`: `merge` expects a single numeric value, so passing it a copy of a `StatCounter` raises a `TypeError`: ```py >>> from pyspark.statcounter import StatCounter >>> stats = StatCounter([1.0, 2.0, 3.0, 4.0]) >>> stats.mergeStats(stats) Traceback (most recent call last): ... TypeError: unsupported operand type(s) for -: 'StatCounter' and 'float' ``` This is a long-standing bug, but it is usually not hit unless users explicitly call `mergeStats` with `self`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added some tests. Closes #34653 from ueshin/issues/SPARK-37374/statcounter. 
Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- dev/sparktestsupport/modules.py | 1 + python/pyspark/statcounter.py | 2 +- python/pyspark/tests/test_statcounter.py | 104 +++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 1 deletion(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 7d3ebb0..d13be2e 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -402,6 +402,7 @@ pyspark_core = Module( "pyspark.tests.test_readwrite", "pyspark.tests.test_serializers", "pyspark.tests.test_shuffle", + "pyspark.tests.test_statcounter", "pyspark.tests.test_taskcontext", "pyspark.tests.test_util", "pyspark.tests.test_worker", diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index bf40281..a994671 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -59,7 +59,7 @@ class StatCounter(object): raise TypeError("Can only merge StatCounter but got %s" % type(other)) if other is self: # reference equality holds - self.merge(copy.deepcopy(other)) # type: ignore[arg-type] # Avoid overwriting fields in a weird order + self.mergeStats(other.copy()) # Avoid overwriting fields in a weird order else: if self.n == 0: self.mu = other.mu diff --git a/python/pyspark/tests/test_statcounter.py b/python/pyspark/tests/test_statcounter.py new file mode 100644 index 0000000..9651871 --- /dev/null +++ b/python/pyspark/tests/test_statcounter.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from pyspark.statcounter import StatCounter +from pyspark.testing.utils import ReusedPySparkTestCase + + +class StatCounterTests(ReusedPySparkTestCase): + def test_base(self): + stats = self.sc.parallelize([1.0, 2.0, 3.0, 4.0]).stats() + self.assertEqual(stats.count(), 4) + self.assertEqual(stats.max(), 4.0) + self.assertEqual(stats.mean(), 2.5) + self.assertEqual(stats.min(), 1.0) + self.assertAlmostEqual(stats.stdev(), 1.118033988749895) + self.assertAlmostEqual(stats.sampleStdev(), 1.2909944487358056) + self.assertEqual(stats.sum(), 10.0) + self.assertAlmostEqual(stats.variance(), 1.25) + self.assertAlmostEqual(stats.sampleVariance(), 1.6666666666666667) + + def test_as_dict(self): + stats = self.sc.parallelize([1.0, 2.0, 3.0, 4.0]).stats().asDict() + self.assertEqual(stats["count"], 4) + self.assertEqual(stats["max"], 4.0) + self.assertEqual(stats["mean"], 2.5) + self.assertEqual(stats["min"], 1.0) + self.assertAlmostEqual(stats["stdev"], 1.2909944487358056) + self.assertEqual(stats["sum"], 10.0) + self.assertAlmostEqual(stats["variance"], 1.6666666666666667) + + stats = self.sc.parallelize([1.0, 2.0, 3.0, 4.0]).stats().asDict(sample=True) + self.assertEqual(stats["count"], 4) + self.assertEqual(stats["max"], 4.0) + self.assertEqual(stats["mean"], 2.5) + self.assertEqual(stats["min"], 1.0) + self.assertAlmostEqual(stats["stdev"], 1.118033988749895) + self.assertEqual(stats["sum"], 10.0) + self.assertAlmostEqual(stats["variance"], 1.25) + + def test_merge(self): + stats = StatCounter([1.0, 2.0, 3.0, 4.0]) + stats.merge(5.0) + 
self.assertEqual(stats.count(), 5) + self.assertEqual(stats.max(), 5.0) + self.assertEqual(stats.mean(), 3.0) + self.assertEqual(stats.min(), 1.0) + self.assertAlmostEqual(stats.stdev(), 1.414213562373095) + self.assertAlmostEqual(stats.sampleStdev(), 1.5811388300841898) + self.assertEqual(stats.sum(), 15.0) + self.assertAlmostEqual(stats.variance(), 2.0) + self.assertAlmostEqual(stats.sampleVariance(), 2.5) + + def test_merge_stats(self): + stats1 = StatCounter([1.0, 2.0, 3.0, 4.0]) + stats2 = StatCounter([1.0, 2.0, 3.0, 4.0]) + stats = stats1.mergeStats(stats2) + self.assertEqual(stats.count(), 8) + self.assertEqual(stats.max(), 4.0) + self.assertEqual(stats.mean(), 2.5) + self.assertEqual(stats.min(), 1.0) + self.assertAlmostEqual(stats.stdev(), 1.118033988749895) + self.assertAlmostEqual(stats.sampleStdev(), 1.1952286093343936) + self.assertEqual(stats.sum(), 20.0) + self.assertAlmostEqual(stats.variance(), 1.25) + self.assertAlmostEqual(stats.sampleVariance(), 1.4285714285714286) + + def test_merge_stats_with_self(self): + stats = StatCounter([1.0, 2.0, 3.0, 4.0]) + stats.mergeStats(stats) + self.assertEqual(stats.count(), 8) + self.assertEqual(stats.max(), 4.0) + self.assertEqual(stats.mean(), 2.5) + self.assertEqual(stats.min(), 1.0) + self.assertAlmostEqual(stats.stdev(), 1.118033988749895) + self.assertAlmostEqual(stats.sampleStdev(), 1.1952286093343936) + self.assertEqual(stats.sum(), 20.0) + self.assertAlmostEqual(stats.variance(), 1.25) + self.assertAlmostEqual(stats.sampleVariance(), 1.4285714285714286) + + +if __name__ == "__main__": + import unittest + from pyspark.tests.test_statcounter import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) --------------------------------------------------------------------- To unsubscribe, e-mail: 
commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org