This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new 3eb7264 [SPARK-37865][SQL] Fix union deduplication correctness bug 3eb7264 is described below commit 3eb7264491cb57156d40b9c39f3ba2f932d51594 Author: Karen Feng <karen.f...@databricks.com> AuthorDate: Wed Mar 9 09:34:01 2022 +0800 [SPARK-37865][SQL] Fix union deduplication correctness bug

Fixes a correctness bug in `Union` when a child has duplicate output columns. Previously, duplicate columns on one side of the union caused the corresponding columns on the other side of the union to be output as duplicates too (in the example below, `SELECT c, d` was returned as if it were `SELECT c, c`).

To fix this, we go through each union child's output and find the duplicates. For each set of duplicates, the first occurrence is left alone. All following occurrences are aliased and given a tag (the `__is_duplicate` metadata key); this tag is used to remove ambiguity during resolution. Because the first occurrence is left alone, the user can still select it, avoiding a breaking change. Because the later occurrences are given new expression IDs, this fixes the correctness bug.

This change is needed because the output of a union with duplicate columns in its children was incorrect.

Example query:
```
SELECT a, a FROM VALUES (1, 1), (1, 2) AS t1(a, b)
UNION ALL
SELECT c, d FROM VALUES (2, 2), (2, 3) AS t2(c, d)
```
Result before:
```
a | a
_ | _
1 | 1
1 | 1
2 | 2
2 | 2
```
Result after:
```
a | a
_ | _
1 | 1
1 | 1
2 | 2
2 | 3
```
Tested with unit tests (see `DataFrameSuite`).

Closes #35760 from karenfeng/spark-37865. 
Authored-by: Karen Feng <karen.f...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit 59ce0a706cb52a54244a747d0a070b61f5cddd1c) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 25 +++++++++ .../spark/sql/catalyst/expressions/package.scala | 8 ++- .../org/apache/spark/sql/DataFrameSuite.scala | 63 ++++++++++++++++++++++ 3 files changed, 95 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index aedfb63..2f33394 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1343,6 +1343,31 @@ class Analyzer( } u.copy(children = newChildren) + case u @ Union(children, _, _) + // if there are duplicate output columns, give them unique expr ids + if children.exists(c => c.output.map(_.exprId).distinct.length < c.output.length) => + val newChildren = children.map { c => + if (c.output.map(_.exprId).distinct.length < c.output.length) { + val existingExprIds = mutable.HashSet[ExprId]() + val projectList = c.output.map { attr => + if (existingExprIds.contains(attr.exprId)) { + // replace non-first duplicates with aliases and tag them + val newMetadata = new MetadataBuilder().withMetadata(attr.metadata) + .putNull("__is_duplicate").build() + Alias(attr, attr.name)(explicitMetadata = Some(newMetadata)) + } else { + // leave first duplicate alone + existingExprIds.add(attr.exprId) + attr + } + } + Project(projectList, c) + } else { + c + } + } + u.withNewChildren(newChildren) + // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on its descendants case s @ Sort(ordering, global, child) if child.resolved && !s.resolved => diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 8bf1f19..287e934 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -335,8 +335,14 @@ package object expressions { matchWithFourOrMoreQualifierParts(nameParts, resolver) } + val prunedCandidates = if (candidates.size > 1) { + candidates.filter(c => !c.metadata.contains("__is_duplicate")) + } else { + candidates + } + def name = UnresolvedAttribute(nameParts).name - candidates match { + prunedCandidates match { case Seq(a) if nestedFields.nonEmpty => // One match, but we also need to extract the requested nested field. // The foldLeft adds ExtractValues for every remaining parts of the identifier, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e5690f3..7984336 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2529,6 +2529,69 @@ class DataFrameSuite extends QueryTest checkAnswer(sql("SELECT sum(c1 * c3) + sum(c2 * c3) FROM tbl"), Row(2.00000000000) :: Nil) } } + + test("SPARK-37865: Do not deduplicate union output columns") { + val df1 = Seq((1, 1), (1, 2)).toDF("a", "b") + val df2 = Seq((2, 2), (2, 3)).toDF("c", "d") + + def sqlQuery(cols1: Seq[String], cols2: Seq[String], distinct: Boolean): String = { + val union = if (distinct) { + "UNION" + } else { + "UNION ALL" + } + s""" + |SELECT ${cols1.mkString(",")} FROM VALUES (1, 1), (1, 2) AS t1(a, b) + |$union SELECT ${cols2.mkString(",")} FROM VALUES (2, 2), (2, 3) AS t2(c, d) + |""".stripMargin + } + + Seq( + (Seq("a", "a"), Seq("c", "d"), Seq(Row(1, 1), Row(1, 1), Row(2, 2), Row(2, 3))), + 
(Seq("a", "b"), Seq("c", "d"), Seq(Row(1, 1), Row(1, 2), Row(2, 2), Row(2, 3))), + (Seq("a", "b"), Seq("c", "c"), Seq(Row(1, 1), Row(1, 2), Row(2, 2), Row(2, 2))) + ).foreach { case (cols1, cols2, rows) => + // UNION ALL (non-distinct) + val df3 = df1.selectExpr(cols1: _*).union(df2.selectExpr(cols2: _*)) + checkAnswer(df3, rows) + + val t3 = sqlQuery(cols1, cols2, false) + checkAnswer(sql(t3), rows) + + // Avoid breaking change + var correctAnswer = rows.map(r => Row(r(0))) + checkAnswer(df3.select(df1.col("a")), correctAnswer) + checkAnswer(sql(s"select a from ($t3) t3"), correctAnswer) + + // This has always been broken + intercept[AnalysisException] { + df3.select(df2.col("d")).collect() + } + intercept[AnalysisException] { + sql(s"select d from ($t3) t3") + } + + // UNION (distinct) + val df4 = df3.distinct + checkAnswer(df4, rows.distinct) + + val t4 = sqlQuery(cols1, cols2, true) + checkAnswer(sql(t4), rows.distinct) + + // Avoid breaking change + correctAnswer = rows.distinct.map(r => Row(r(0))) + checkAnswer(df4.select(df1.col("a")), correctAnswer) + checkAnswer(sql(s"select a from ($t4) t4"), correctAnswer) + + // This has always been broken + intercept[AnalysisException] { + df4.select(df2.col("d")).collect() + } + intercept[AnalysisException] { + sql(s"select d from ($t4) t4") + } + } + } } case class GroupByKey(a: Int, b: Int) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org