This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 3eb7264  [SPARK-37865][SQL] Fix union deduplication correctness bug
3eb7264 is described below

commit 3eb7264491cb57156d40b9c39f3ba2f932d51594
Author: Karen Feng <karen.f...@databricks.com>
AuthorDate: Wed Mar 9 09:34:01 2022 +0800

    [SPARK-37865][SQL] Fix union deduplication correctness bug
    
    Fixes a correctness bug in `Union` in the case that there are duplicate output columns. Previously, duplicate columns on one side of the union would result in a duplicate column being output on the other side of the union.
    
    To do so, we go through each union child's output and find the duplicates. For each duplicate set, the first duplicate is left alone; all following duplicates are aliased and given a tag, and this tag is used to remove ambiguity during resolution.
    
    As the first duplicate is left alone, the user can still select it, avoiding a breaking change. As the later duplicates are given new expression IDs, this fixes the correctness bug.
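    
    As a rough sketch (simplified; `dedupOutput` is a hypothetical helper name, and the real rule lives in the Analyzer.scala hunk below), the rewrite applied to each affected union child looks like:
    
    ```
    import scala.collection.mutable
    import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId}
    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
    import org.apache.spark.sql.types.MetadataBuilder
    
    // Wraps a union child in a Project that re-aliases non-first duplicates.
    def dedupOutput(child: LogicalPlan): LogicalPlan = {
      val seen = mutable.HashSet[ExprId]()
      val projectList = child.output.map { attr =>
        if (seen.contains(attr.exprId)) {
          // Non-first duplicate: re-alias (Alias allocates a fresh ExprId) and
          // tag it so name resolution can prune it when a name is ambiguous.
          val tagged = new MetadataBuilder()
            .withMetadata(attr.metadata)
            .putNull("__is_duplicate")
            .build()
          Alias(attr, attr.name)(explicitMetadata = Some(tagged))
        } else {
          // First occurrence stays untouched, so selecting it still works.
          seen.add(attr.exprId)
          attr
        }
      }
      Project(projectList, child)
    }
    ```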
    
    Previously, the output of a union with duplicate columns in the children was incorrect.
    
    Example query:
    ```
    SELECT a, a FROM VALUES (1, 1), (1, 2) AS t1(a, b)
    UNION ALL SELECT c, d FROM VALUES (2, 2), (2, 3) AS t2(c, d)
    ```
    
    Result before:
    ```
    a | a
    _ | _
    1 | 1
    1 | 1
    2 | 2
    2 | 2
    ```
    
    Result after:
    ```
    a | a
    _ | _
    1 | 1
    1 | 1
    2 | 2
    2 | 3
    ```
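    
    The same behavior can be reproduced through the DataFrame API (a minimal sketch, assuming a spark-shell session with `spark.implicits._` in scope):
    
    ```
    val df1 = Seq((1, 1), (1, 2)).toDF("a", "b")
    val df2 = Seq((2, 2), (2, 3)).toDF("c", "d")
    // Duplicate output columns on the left side previously caused the right
    // side's distinct columns to collapse too: the last row came back as
    // (2, 2) instead of (2, 3).
    df1.select($"a", $"a").union(df2.select($"c", $"d")).show()
    ```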
    
    Unit tests
    
    Closes #35760 from karenfeng/spark-37865.
    
    Authored-by: Karen Feng <karen.f...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 59ce0a706cb52a54244a747d0a070b61f5cddd1c)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 25 +++++++++
 .../spark/sql/catalyst/expressions/package.scala   |  8 ++-
 .../org/apache/spark/sql/DataFrameSuite.scala      | 63 ++++++++++++++++++++++
 3 files changed, 95 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index aedfb63..2f33394 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1343,6 +1343,31 @@ class Analyzer(
         }
         u.copy(children = newChildren)
 
+      case u @ Union(children, _, _)
+          // if there are duplicate output columns, give them unique expr ids
+          if children.exists(c => c.output.map(_.exprId).distinct.length < c.output.length) =>
+        val newChildren = children.map { c =>
+          if (c.output.map(_.exprId).distinct.length < c.output.length) {
+            val existingExprIds = mutable.HashSet[ExprId]()
+            val projectList = c.output.map { attr =>
+              if (existingExprIds.contains(attr.exprId)) {
+                // replace non-first duplicates with aliases and tag them
+                val newMetadata = new MetadataBuilder().withMetadata(attr.metadata)
+                  .putNull("__is_duplicate").build()
+                Alias(attr, attr.name)(explicitMetadata = Some(newMetadata))
+              } else {
+                // leave first duplicate alone
+                existingExprIds.add(attr.exprId)
+                attr
+              }
+            }
+            Project(projectList, c)
+          } else {
+            c
+          }
+        }
+        u.withNewChildren(newChildren)
+
       // When resolve `SortOrder`s in Sort based on child, don't report errors as
       // we still have chance to resolve it based on its descendants
       case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 8bf1f19..287e934 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -335,8 +335,14 @@ package object expressions  {
         matchWithFourOrMoreQualifierParts(nameParts, resolver)
       }
 
+      val prunedCandidates = if (candidates.size > 1) {
+        candidates.filter(c => !c.metadata.contains("__is_duplicate"))
+      } else {
+        candidates
+      }
+
       def name = UnresolvedAttribute(nameParts).name
-      candidates match {
+      prunedCandidates match {
         case Seq(a) if nestedFields.nonEmpty =>
           // One match, but we also need to extract the requested nested field.
          // The foldLeft adds ExtractValues for every remaining parts of the identifier,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index e5690f3..7984336 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2529,6 +2529,69 @@ class DataFrameSuite extends QueryTest
       checkAnswer(sql("SELECT sum(c1 * c3) + sum(c2 * c3) FROM tbl"), 
Row(2.00000000000) :: Nil)
     }
   }
+
+  test("SPARK-37865: Do not deduplicate union output columns") {
+    val df1 = Seq((1, 1), (1, 2)).toDF("a", "b")
+    val df2 = Seq((2, 2), (2, 3)).toDF("c", "d")
+
+    def sqlQuery(cols1: Seq[String], cols2: Seq[String], distinct: Boolean): String = {
+      val union = if (distinct) {
+        "UNION"
+      } else {
+        "UNION ALL"
+      }
+      s"""
+         |SELECT ${cols1.mkString(",")} FROM VALUES (1, 1), (1, 2) AS t1(a, b)
+         |$union SELECT ${cols2.mkString(",")} FROM VALUES (2, 2), (2, 3) AS t2(c, d)
+         |""".stripMargin
+    }
+
+    Seq(
+      (Seq("a", "a"), Seq("c", "d"), Seq(Row(1, 1), Row(1, 1), Row(2, 2), 
Row(2, 3))),
+      (Seq("a", "b"), Seq("c", "d"), Seq(Row(1, 1), Row(1, 2), Row(2, 2), 
Row(2, 3))),
+      (Seq("a", "b"), Seq("c", "c"), Seq(Row(1, 1), Row(1, 2), Row(2, 2), 
Row(2, 2)))
+    ).foreach { case (cols1, cols2, rows) =>
+      // UNION ALL (non-distinct)
+      val df3 = df1.selectExpr(cols1: _*).union(df2.selectExpr(cols2: _*))
+      checkAnswer(df3, rows)
+
+      val t3 = sqlQuery(cols1, cols2, false)
+      checkAnswer(sql(t3), rows)
+
+      // Avoid breaking change
+      var correctAnswer = rows.map(r => Row(r(0)))
+      checkAnswer(df3.select(df1.col("a")), correctAnswer)
+      checkAnswer(sql(s"select a from ($t3) t3"), correctAnswer)
+
+      // This has always been broken
+      intercept[AnalysisException] {
+        df3.select(df2.col("d")).collect()
+      }
+      intercept[AnalysisException] {
+        sql(s"select d from ($t3) t3")
+      }
+
+      // UNION (distinct)
+      val df4 = df3.distinct
+      checkAnswer(df4, rows.distinct)
+
+      val t4 = sqlQuery(cols1, cols2, true)
+      checkAnswer(sql(t4), rows.distinct)
+
+      // Avoid breaking change
+      correctAnswer = rows.distinct.map(r => Row(r(0)))
+      checkAnswer(df4.select(df1.col("a")), correctAnswer)
+      checkAnswer(sql(s"select a from ($t4) t4"), correctAnswer)
+
+      // This has always been broken
+      intercept[AnalysisException] {
+        df4.select(df2.col("d")).collect()
+      }
+      intercept[AnalysisException] {
+        sql(s"select d from ($t4) t4")
+      }
+    }
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)
