Repository: spark Updated Branches: refs/heads/master 219922422 -> cb90617f8
[SPARK-25591][PYSPARK][SQL] Avoid overwriting deserialized accumulator ## What changes were proposed in this pull request? If we use accumulators in more than one UDF, it is possible to overwrite deserialized accumulators and their values. We should check if an accumulator was deserialized before overwriting it in the accumulator registry. ## How was this patch tested? Added test. Closes #22635 from viirya/SPARK-25591. Authored-by: Liang-Chi Hsieh <vii...@gmail.com> Signed-off-by: hyukjinkwon <gurwls...@apache.org> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cb90617f Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cb90617f Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cb90617f Branch: refs/heads/master Commit: cb90617f894fd51a092710271823ec7d1cd3a668 Parents: 2199224 Author: Liang-Chi Hsieh <vii...@gmail.com> Authored: Mon Oct 8 15:18:08 2018 +0800 Committer: hyukjinkwon <gurwls...@apache.org> Committed: Mon Oct 8 15:18:08 2018 +0800 ---------------------------------------------------------------------- python/pyspark/accumulators.py | 12 ++++++++---- python/pyspark/sql/tests.py | 25 +++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/cb90617f/python/pyspark/accumulators.py ---------------------------------------------------------------------- diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 30ad042..00ec094 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -109,10 +109,14 @@ _accumulatorRegistry = {} def _deserialize_accumulator(aid, zero_value, accum_param): from pyspark.accumulators import _accumulatorRegistry - accum = Accumulator(aid, zero_value, accum_param) - accum._deserialized = True - _accumulatorRegistry[aid] = accum - return accum + # If this certain 
accumulator was deserialized, don't overwrite it. + if aid in _accumulatorRegistry: + return _accumulatorRegistry[aid] + else: + accum = Accumulator(aid, zero_value, accum_param) + accum._deserialized = True + _accumulatorRegistry[aid] = accum + return accum class Accumulator(object): http://git-wip-us.apache.org/repos/asf/spark/blob/cb90617f/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d3c29d0..ac87ccd 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3603,6 +3603,31 @@ class SQLTests(ReusedSQLTestCase): self.assertEquals(None, df._repr_html_()) self.assertEquals(expected, df.__repr__()) + # SPARK-25591 + def test_same_accumulator_in_udfs(self): + from pyspark.sql.functions import udf + + data_schema = StructType([StructField("a", IntegerType(), True), + StructField("b", IntegerType(), True)]) + data = self.spark.createDataFrame([[1, 2]], schema=data_schema) + + test_accum = self.sc.accumulator(0) + + def first_udf(x): + test_accum.add(1) + return x + + def second_udf(x): + test_accum.add(100) + return x + + func_udf = udf(first_udf, IntegerType()) + func_udf2 = udf(second_udf, IntegerType()) + data = data.withColumn("out1", func_udf(data["a"])) + data = data.withColumn("out2", func_udf2(data["b"])) + data.collect() + self.assertEqual(test_accum.value, 101) + class HiveSparkSubmitTests(SparkSubmitTests): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org