Repository: spark
Updated Branches:
  refs/heads/master 219922422 -> cb90617f8


[SPARK-25591][PYSPARK][SQL] Avoid overwriting deserialized accumulator

## What changes were proposed in this pull request?

If the same accumulator is used in more than one UDF, it is possible to overwrite 
the deserialized accumulator and lose its accumulated value. We should check whether 
an accumulator has already been deserialized before overwriting it in the accumulator 
registry.
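
To make the failure mode concrete, here is a minimal, self-contained sketch. `FakeAccumulator`, `_registry`, and the two `deserialize_*` helpers below are hypothetical stand-ins for pyspark's `Accumulator`, the worker-side `_accumulatorRegistry`, and `_deserialize_accumulator` (not the real internals); they only model the registration logic:

```python
_registry = {}  # stand-in for pyspark's worker-side _accumulatorRegistry


class FakeAccumulator(object):
    """Hypothetical stand-in for pyspark.accumulators.Accumulator."""
    def __init__(self, aid, zero_value):
        self.aid = aid
        self._value = zero_value

    def add(self, term):
        self._value += term


def deserialize_pre_fix(aid, zero_value):
    # Pre-fix behavior: always create and register a fresh instance,
    # clobbering any updates already made through the old instance.
    accum = FakeAccumulator(aid, zero_value)
    _registry[aid] = accum
    return accum


def deserialize_post_fix(aid, zero_value):
    # Post-fix behavior: reuse the instance already in the registry.
    if aid in _registry:
        return _registry[aid]
    accum = FakeAccumulator(aid, zero_value)
    _registry[aid] = accum
    return accum


# Two UDFs sharing one accumulator (id 0), pre-fix: the second
# deserialization replaces the registry entry, losing the first update.
a = deserialize_pre_fix(0, 0)
a.add(1)
b = deserialize_pre_fix(0, 0)
b.add(100)
print(_registry[0]._value)  # 100 -- the +1 from the first UDF is lost

# Same sequence post-fix: both updates land on the single instance.
_registry.clear()
a = deserialize_post_fix(0, 0)
a.add(1)
b = deserialize_post_fix(0, 0)
b.add(100)
print(_registry[0]._value)  # 101 -- both updates survive
```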

## How was this patch tested?

Added test.

Closes #22635 from viirya/SPARK-25591.

Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cb90617f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cb90617f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cb90617f

Branch: refs/heads/master
Commit: cb90617f894fd51a092710271823ec7d1cd3a668
Parents: 2199224
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Mon Oct 8 15:18:08 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Mon Oct 8 15:18:08 2018 +0800

----------------------------------------------------------------------
 python/pyspark/accumulators.py | 12 ++++++++----
 python/pyspark/sql/tests.py    | 25 +++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cb90617f/python/pyspark/accumulators.py
----------------------------------------------------------------------
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 30ad042..00ec094 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -109,10 +109,14 @@ _accumulatorRegistry = {}
 
 def _deserialize_accumulator(aid, zero_value, accum_param):
     from pyspark.accumulators import _accumulatorRegistry
-    accum = Accumulator(aid, zero_value, accum_param)
-    accum._deserialized = True
-    _accumulatorRegistry[aid] = accum
-    return accum
+    # If this accumulator has already been deserialized, don't overwrite it.
+    if aid in _accumulatorRegistry:
+        return _accumulatorRegistry[aid]
+    else:
+        accum = Accumulator(aid, zero_value, accum_param)
+        accum._deserialized = True
+        _accumulatorRegistry[aid] = accum
+        return accum
 
 
 class Accumulator(object):
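
With the guard in place, deserializing the same accumulator id twice returns the single registered instance, so updates from multiple UDFs accumulate on it. A quick sanity sketch against the private helper (assuming a local pyspark installation that includes this fix; `_deserialize_accumulator` and `_accumulatorRegistry` are internal names, and `INT_ACCUMULATOR_PARAM` is pyspark's built-in integer AccumulatorParam):

```python
from pyspark.accumulators import (INT_ACCUMULATOR_PARAM,
                                  _accumulatorRegistry,
                                  _deserialize_accumulator)

# Deserialize the "same" accumulator (id 42) twice, as two UDFs
# sharing one accumulator would on a worker.
a = _deserialize_accumulator(42, 0, INT_ACCUMULATOR_PARAM)
b = _deserialize_accumulator(42, 0, INT_ACCUMULATOR_PARAM)
assert a is b  # the existing registry entry is reused, not overwritten

a.add(1)
b.add(100)
# Both updates land on the one registered instance.
assert _accumulatorRegistry[42]._value == 101
```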

http://git-wip-us.apache.org/repos/asf/spark/blob/cb90617f/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d3c29d0..ac87ccd 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3603,6 +3603,31 @@ class SQLTests(ReusedSQLTestCase):
                     self.assertEquals(None, df._repr_html_())
                     self.assertEquals(expected, df.__repr__())
 
+    # SPARK-25591
+    def test_same_accumulator_in_udfs(self):
+        from pyspark.sql.functions import udf
+
+        data_schema = StructType([StructField("a", IntegerType(), True),
+                                  StructField("b", IntegerType(), True)])
+        data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
+
+        test_accum = self.sc.accumulator(0)
+
+        def first_udf(x):
+            test_accum.add(1)
+            return x
+
+        def second_udf(x):
+            test_accum.add(100)
+            return x
+
+        func_udf = udf(first_udf, IntegerType())
+        func_udf2 = udf(second_udf, IntegerType())
+        data = data.withColumn("out1", func_udf(data["a"]))
+        data = data.withColumn("out2", func_udf2(data["b"]))
+        data.collect()
+        self.assertEqual(test_accum.value, 101)
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
 

