GitHub user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22635#discussion_r223251175
  
    --- Diff: python/pyspark/sql/tests.py ---
    @@ -3603,6 +3603,31 @@ def test_repr_behaviors(self):
                         self.assertEquals(None, df._repr_html_())
                         self.assertEquals(expected, df.__repr__())
     
    +    # SPARK-25591
    +    def test_same_accumulator_in_udfs(self):
    +        from pyspark.sql.functions import udf
    +
    +        data_schema = StructType([StructField("a", DoubleType(), True),
    +                                  StructField("b", DoubleType(), True)])
    +        data = self.spark.createDataFrame([[1.0, 2.0]], schema=data_schema)
    +
    +        test_accum = self.sc.accumulator(0.0)
    +
    +        def first_udf(x):
    +            test_accum.add(1.0)
    +            return x
    +
    +        def second_udf(x):
    +            test_accum.add(100.0)
    +            return x
    +
    +        func_udf = udf(first_udf, DoubleType())
    +        func_udf2 = udf(second_udf, DoubleType())
    +        data = data.withColumn("out1", func_udf(data["a"]))
    +        data = data.withColumn("out2", func_udf2(data["b"]))
    +        data.collect()
    +        self.assertEqual(test_accum.value, 101)
    --- End diff --
    
    Ok. 


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to