This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 6ee0eb4  [SPARK-32280][SPARK-32372][2.4][SQL] 
ResolveReferences.dedupRight should only rewrite attributes for ancestor nodes 
of the conflict plan
6ee0eb4 is described below

commit 6ee0eb40870c92889bc3c627d4b3178033a64a18
Author: yi.wu <yi...@databricks.com>
AuthorDate: Fri Jul 24 04:26:22 2020 +0000

    [SPARK-32280][SPARK-32372][2.4][SQL] ResolveReferences.dedupRight should 
only rewrite attributes for ancestor nodes of the conflict plan
    
    ### What changes were proposed in this pull request?
    
    This PR refactors `ResolveReferences.dedupRight` to make sure it only 
rewrites attributes for ancestor nodes of the conflict plan.
    
    ### Why are the changes needed?
    
    This is a bug fix.
    
    ```scala
    sql("SELECT name, avg(age) as avg_age FROM person GROUP BY name")
      .createOrReplaceTempView("person_a")
    sql("SELECT p1.name, p2.avg_age FROM person p1 JOIN person_a p2 ON p1.name 
= p2.name")
      .createOrReplaceTempView("person_b")
    sql("SELECT * FROM person_a UNION SELECT * FROM person_b")
      .createOrReplaceTempView("person_c")
    sql("SELECT p1.name, p2.avg_age FROM person_c p1 JOIN person_c p2 ON 
p1.name = p2.name").show()
    ```
    When executing the above query, we'll hit the error:
    
    ```scala
    [info]   Failed to analyze query: org.apache.spark.sql.AnalysisException: 
Resolved attribute(s) avg_age#231 missing from 
name#223,avg_age#218,id#232,age#234,name#233 in operator !Project [name#233, 
avg_age#231]. Attribute(s) with the same name appear in the operation: avg_age. 
Please check if the right attribute(s) are used.;;
    ...
    ```
    
    The plan below is the problematic plan, which is the right plan of a `Join` 
operator. It contains plans that conflict with the left plan. In this 
problematic plan, the first `Aggregate` operator (the one under the first child 
of `Union`) becomes a conflict plan compared to the left one and has a rewrite 
attribute pair of `avg_age#218` -> `avg_age#231`. With the current 
`dedupRight` logic, we'll first replace this `Aggregate` with a new one, and 
then rewrite the attribute `avg_age# [...]
    
    ```scala
    :

    :
    +- SubqueryAlias p2
       +- SubqueryAlias person_c
          +- Distinct
             +- Union
                :- Project [name#233, avg_age#231]
                :  +- SubqueryAlias person_a
                :     +- Aggregate [name#233], [name#233, avg(cast(age#234 as 
bigint)) AS avg_age#231]
                :        +- SubqueryAlias person
                :           +- SerializeFromObject 
[knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#232, 
staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, 
fromString, knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS 
name#233, knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#234]
                :              +- ExternalRDD [obj#165]
                +- Project [name#233 AS name#227, avg_age#231 AS avg_age#228]
                   +- Project [name#233, avg_age#231]
                      +- SubqueryAlias person_b
                         +- !Project [name#233, avg_age#231]
                            +- Join Inner, (name#233 = name#223)
                               :- SubqueryAlias p1
                               :  +- SubqueryAlias person
                               :     +- SerializeFromObject 
[knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#232, 
staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, 
fromString, knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS 
name#233, knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#234]
                               :        +- ExternalRDD [obj#165]
                               +- SubqueryAlias p2
                                  +- SubqueryAlias person_a
                                     +- Aggregate [name#223], [name#223, 
avg(cast(age#224 as bigint)) AS avg_age#218]
                                        +- SubqueryAlias person
                                           +- SerializeFromObject 
[knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#222, 
staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, 
fromString, knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS 
name#223, knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#224]
                                              +- ExternalRDD [obj#165]
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, users will no longer hit the error after this fix.
    
    ### How was this patch tested?
    
    Added tests to `SQLQuerySuite`.
    
    Closes #29208 from Ngone51/cherry-pick-spark-32372.
    
    Authored-by: yi.wu <yi...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 54 ++++++++++++++++++----
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 24 ++++++++++
 2 files changed, 68 insertions(+), 10 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index afe7b4f..aaaf707 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -886,17 +886,51 @@ class Analyzer(
            */
           right
         case Some((oldRelation, newRelation)) =>
-          val attributeRewrites = 
AttributeMap(oldRelation.output.zip(newRelation.output))
-          right transformUp {
-            case r if r == oldRelation => newRelation
-          } transformUp {
-            case other => other transformExpressions {
-              case a: Attribute =>
-                dedupAttr(a, attributeRewrites)
-              case s: SubqueryExpression =>
-                s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, 
attributeRewrites))
-            }
+          rewritePlan(right, Map(oldRelation -> newRelation))._1
+      }
+    }
+
+    private def rewritePlan(plan: LogicalPlan, conflictPlanMap: 
Map[LogicalPlan, LogicalPlan])
+      : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
+      if (conflictPlanMap.contains(plan)) {
+        // If the plan is the one that conflict the with left one, we'd
+        // just replace it with the new plan and collect the rewrite
+        // attributes for the parent node.
+        val newRelation = conflictPlanMap(plan)
+        newRelation -> plan.output.zip(newRelation.output)
+      } else {
+        val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
+        val newPlan = plan.mapChildren { child =>
+          // If not, we'd rewrite child plan recursively until we find the
+          // conflict node or reach the leaf node.
+          val (newChild, childAttrMapping) = rewritePlan(child, 
conflictPlanMap)
+          attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
+            // `attrMapping` is not only used to replace the attributes of the 
current `plan`,
+            // but also to be propagated to the parent plans of the current 
`plan`. Therefore,
+            // the `oldAttr` must be part of either `plan.references` (so that 
it can be used to
+            // replace attributes of the current `plan`) or `plan.outputSet` 
(so that it can be
+            // used by those parent plans).
+            (plan.outputSet ++ plan.references).contains(oldAttr)
           }
+          newChild
+        }
+
+        if (attrMapping.isEmpty) {
+          newPlan -> attrMapping
+        } else {
+          assert(!attrMapping.groupBy(_._1.exprId)
+            .exists(_._2.map(_._2.exprId).distinct.length > 1),
+            "Found duplicate rewrite attributes")
+          val attributeRewrites = AttributeMap(attrMapping)
+          // Using attrMapping from the children plans to rewrite their parent 
node.
+          // Note that we shouldn't rewrite a node using attrMapping from its 
sibling nodes.
+          newPlan.transformExpressions {
+            case a: Attribute =>
+              dedupAttr(a, attributeRewrites)
+            case s: SubqueryExpression =>
+              s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, 
attributeRewrites))
+          } -> attrMapping
+        }
       }
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index d0114f6..c424ef8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -3078,6 +3078,30 @@ class SQLQuerySuite extends QueryTest with 
SharedSQLContext {
       checkAnswer(df, Row(1))
     }
   }
+
+  test("SPARK-32372: ResolveReferences.dedupRight should only rewrite 
attributes for ancestor " +
+    "plans of the conflict plan") {
+    sql("SELECT name, avg(age) as avg_age FROM person GROUP BY name")
+      .createOrReplaceTempView("person_a")
+    sql("SELECT p1.name, p2.avg_age FROM person p1 JOIN person_a p2 ON p1.name 
= p2.name")
+      .createOrReplaceTempView("person_b")
+    sql("SELECT * FROM person_a UNION SELECT * FROM person_b")
+      .createOrReplaceTempView("person_c")
+    checkAnswer(
+      sql("SELECT p1.name, p2.avg_age FROM person_c p1 JOIN person_c p2 ON 
p1.name = p2.name"),
+      Row("jim", 20.0) :: Row("mike", 30.0) :: Nil)
+  }
+
+  test("SPARK-32280: Avoid duplicate rewrite attributes when there're multiple 
JOINs") {
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      sql("SELECT 1 AS id").createOrReplaceTempView("A")
+      sql("SELECT id, 'foo' AS kind FROM A").createOrReplaceTempView("B")
+      sql("SELECT l.id as id FROM B AS l LEFT SEMI JOIN B AS r ON l.kind = 
r.kind")
+        .createOrReplaceTempView("C")
+      checkAnswer(sql("SELECT 0 FROM ( SELECT * FROM B JOIN C USING (id)) " +
+        "JOIN ( SELECT * FROM B JOIN C USING (id)) USING (id)"), Row(0))
+    }
+  }
 }
 
 case class Foo(bar: Option[String])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to