This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 9b268122f68 [SPARK-39293][SQL] Fix the accumulator of ArrayAggregate 
to handle complex types properly
9b268122f68 is described below

commit 9b268122f68718ed46d9ffd97c402c5a1e7db73a
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Thu May 26 10:36:03 2022 +0900

    [SPARK-39293][SQL] Fix the accumulator of ArrayAggregate to handle complex 
types properly
    
    Fix the accumulator of `ArrayAggregate` to handle complex types properly.
    
    The accumulator of `ArrayAggregate` should copy the intermediate result if it is a string, struct, array, or map.
    
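    If the intermediate result of `ArrayAggregate` holds reusable data (for example, a buffer that the next merge step overwrites), earlier entries in the accumulated result can be overwritten by later ones, producing duplicated values, as in this example: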
    
    ```scala
    import org.apache.spark.sql.functions._
    
    val reverse = udf((s: String) => s.reverse)
    
    val df = Seq(Array("abc", "def")).toDF("array")
    val testArray = df.withColumn(
      "agg",
      aggregate(
        col("array"),
        array().cast("array<string>"),
        (acc, s) => concat(acc, array(reverse(s)))))
    
    testArray.show(truncate=false)
    ```
    
    The expected output is:
    
    ```
    +----------+----------+
    |array     |agg       |
    +----------+----------+
    |[abc, def]|[cba, fed]|
    +----------+----------+
    ```
    
    but the actual output was:
    
    ```
    +----------+----------+
    |array     |agg       |
    +----------+----------+
    |[abc, def]|[fed, fed]|
    +----------+----------+
    ```
    
    Yes. This fixes a correctness issue: queries like the example above previously returned wrong results.
    
    Added a test to `DataFrameSuite` covering both array and map accumulators.
    
    Closes #36674 from ueshin/issues/SPARK-39293/array_aggregate.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit d6a11cb4b411c8136eb241aac167bc96990f5421)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 92e82fdf8e2faec5add61e2448f11272dfb19c6e)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 68d69501576ba21e182791aad91b82a1e7282d11)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../catalyst/expressions/higherOrderFunctions.scala   |  2 +-
 .../scala/org/apache/spark/sql/DataFrameSuite.scala   | 19 +++++++++++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index a530ce5da27..4a8c366107c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -752,7 +752,7 @@ case class ArrayAggregate(
       var i = 0
       while (i < arr.numElements()) {
         elementVar.value.set(arr.get(i, elementVar.dataType))
-        accForMergeVar.value.set(mergeForEval.eval(input))
+        
accForMergeVar.value.set(InternalRow.copyValue(mergeForEval.eval(input)))
         i += 1
       }
       accForFinishVar.value.set(accForMergeVar.value.get)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 7984336beba..1d752a675dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2515,6 +2515,25 @@ class DataFrameSuite extends QueryTest
     checkAnswer(df3.select($"*-#&% ?.`a``b.c`"), Row("col1"))
   }
 
+  test("SPARK-39293: The accumulator of ArrayAggregate to handle complex types 
properly") {
+    val reverse = udf((s: String) => s.reverse)
+
+    val df = Seq(Array("abc", "def")).toDF("array")
+    val testArray = df.select(
+      aggregate(
+        col("array"),
+        array().cast("array<string>"),
+        (acc, s) => concat(acc, array(reverse(s)))))
+    checkAnswer(testArray, Row(Array("cba", "fed")) :: Nil)
+
+    val testMap = df.select(
+      aggregate(
+        col("array"),
+        map().cast("map<string, string>"),
+        (acc, s) => map_concat(acc, map(s, reverse(s)))))
+    checkAnswer(testMap, Row(Map("abc" -> "cba", "def" -> "fed")) :: Nil)
+  }
+
   test("SPARK-35886: PromotePrecision should be subexpr replaced") {
     withTable("tbl") {
       sql(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to