Repository: spark
Updated Branches:
  refs/heads/master 1eebfbe19 -> 755f2f518


[SPARK-20392][SQL][FOLLOWUP] should not add extra AnalysisBarrier

## What changes were proposed in this pull request?

I found this problem while auditing the analyzer code. It's dangerous to 
introduce an extra `AnalysisBarrier` during analysis, as the plan inside it will 
bypass all analysis afterward, which may not be expected. We should only 
preserve existing `AnalysisBarrier` nodes, not introduce new ones.

## How was this patch tested?

existing tests

Author: Wenchen Fan <wenc...@databricks.com>

Closes #20094 from cloud-fan/barrier.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/755f2f51
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/755f2f51
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/755f2f51

Branch: refs/heads/master
Commit: 755f2f5189a08597fddc90b7f9df4ad0ec6bd2ef
Parents: 1eebfbe
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Thu Dec 28 21:33:03 2017 +0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Thu Dec 28 21:33:03 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 191 ++++++++-----------
 .../sql/hive/execution/SQLQuerySuite.scala      |   2 +-
 2 files changed, 84 insertions(+), 109 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/755f2f51/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 10b237f..7f2128e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import scala.annotation.tailrec
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.AnalysisException
@@ -665,14 +664,18 @@ class Analyzer(
      * Generate a new logical plan for the right child with different 
expression IDs
      * for all conflicting attributes.
      */
-    private def dedupRight (left: LogicalPlan, originalRight: LogicalPlan): 
LogicalPlan = {
-      // Remove analysis barrier if any.
-      val right = EliminateBarriers(originalRight)
+    private def dedupRight (left: LogicalPlan, right: LogicalPlan): 
LogicalPlan = {
       val conflictingAttributes = left.outputSet.intersect(right.outputSet)
       logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} 
" +
         s"between $left and $right")
 
       right.collect {
+        // For `AnalysisBarrier`, recursively de-duplicate its child.
+        case oldVersion: AnalysisBarrier
+            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty 
=>
+          val newVersion = dedupRight(left, oldVersion.child)
+          (oldVersion, AnalysisBarrier(newVersion))
+
         // Handle base relations that might appear more than once.
         case oldVersion: MultiInstanceRelation
             if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty 
=>
@@ -710,10 +713,10 @@ class Analyzer(
            * that this rule cannot handle. When that is the case, there must 
be another rule
            * that resolves these conflicts. Otherwise, the analysis will fail.
            */
-          originalRight
+          right
         case Some((oldRelation, newRelation)) =>
           val attributeRewrites = 
AttributeMap(oldRelation.output.zip(newRelation.output))
-          val newRight = right transformUp {
+          right transformUp {
             case r if r == oldRelation => newRelation
           } transformUp {
             case other => other transformExpressions {
@@ -723,7 +726,6 @@ class Analyzer(
                 s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, 
attributeRewrites))
             }
           }
-          AnalysisBarrier(newRight)
       }
     }
 
@@ -958,7 +960,8 @@ class Analyzer(
   protected[sql] def resolveExpression(
       expr: Expression,
       plan: LogicalPlan,
-      throws: Boolean = false) = {
+      throws: Boolean = false): Expression = {
+    if (expr.resolved) return expr
     // Resolve expression in one round.
     // If throws == false or the desired attribute doesn't exist
     // (like try to resolve `a.b` but `a` doesn't exist), fail and return the 
origin one.
@@ -1079,100 +1082,74 @@ class Analyzer(
       case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
       case sa @ Sort(_, _, child: Aggregate) => sa
 
-      case s @ Sort(order, _, originalChild) if !s.resolved && 
originalChild.resolved =>
-        val child = EliminateBarriers(originalChild)
-        try {
-          val newOrder = order.map(resolveExpressionRecursively(_, 
child).asInstanceOf[SortOrder])
-          val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
-          val missingAttrs = requiredAttrs -- child.outputSet
-          if (missingAttrs.nonEmpty) {
-            // Add missing attributes and then project them away after the 
sort.
-            Project(child.output,
-              Sort(newOrder, s.global, addMissingAttr(child, missingAttrs)))
-          } else if (newOrder != order) {
-            s.copy(order = newOrder)
-          } else {
-            s
-          }
-        } catch {
-          // Attempting to resolve it might fail. When this happens, return 
the original plan.
-          // Users will see an AnalysisException for resolution failure of 
missing attributes
-          // in Sort
-          case ae: AnalysisException => s
+      case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+        val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child)
+        val ordering = newOrder.map(_.asInstanceOf[SortOrder])
+        if (child.output == newChild.output) {
+          s.copy(order = ordering)
+        } else {
+          // Add missing attributes and then project them away.
+          val newSort = s.copy(order = ordering, child = newChild)
+          Project(child.output, newSort)
         }
 
-      case f @ Filter(cond, originalChild) if !f.resolved && 
originalChild.resolved =>
-        val child = EliminateBarriers(originalChild)
-        try {
-          val newCond = resolveExpressionRecursively(cond, child)
-          val requiredAttrs = newCond.references.filter(_.resolved)
-          val missingAttrs = requiredAttrs -- child.outputSet
-          if (missingAttrs.nonEmpty) {
-            // Add missing attributes and then project them away.
-            Project(child.output,
-              Filter(newCond, addMissingAttr(child, missingAttrs)))
-          } else if (newCond != cond) {
-            f.copy(condition = newCond)
-          } else {
-            f
-          }
-        } catch {
-          // Attempting to resolve it might fail. When this happens, return 
the original plan.
-          // Users will see an AnalysisException for resolution failure of 
missing attributes
-          case ae: AnalysisException => f
+      case f @ Filter(cond, child) if !f.resolved && child.resolved =>
+        val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), 
child)
+        if (child.output == newChild.output) {
+          f.copy(condition = newCond.head)
+        } else {
+          // Add missing attributes and then project them away.
+          val newFilter = Filter(newCond.head, newChild)
+          Project(child.output, newFilter)
         }
     }
 
-    /**
-     * Add the missing attributes into projectList of Project/Window or 
aggregateExpressions of
-     * Aggregate.
-     */
-    private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): 
LogicalPlan = {
-      if (missingAttrs.isEmpty) {
-        return AnalysisBarrier(plan)
-      }
-      plan match {
-        case p: Project =>
-          val missing = missingAttrs -- p.child.outputSet
-          Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, 
missing))
-        case a: Aggregate =>
-          // all the missing attributes should be grouping expressions
-          // TODO: push down AggregateExpression
-          missingAttrs.foreach { attr =>
-            if (!a.groupingExpressions.exists(_.semanticEquals(attr))) {
-              throw new AnalysisException(s"Can't add $attr to 
${a.simpleString}")
-            }
-          }
-          val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs
-          a.copy(aggregateExpressions = newAggregateExpressions)
-        case g: Generate =>
-          // If join is false, we will convert it to true for getting from the 
child the missing
-          // attributes that its child might have or could have.
-          val missing = missingAttrs -- g.child.outputSet
-          g.copy(join = true, child = addMissingAttr(g.child, missing))
-        case d: Distinct =>
-          throw new AnalysisException(s"Can't add $missingAttrs to $d")
-        case u: UnaryNode =>
-          u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil)
-        case other =>
-          throw new AnalysisException(s"Can't add $missingAttrs to $other")
-      }
-    }
-
-    /**
-     * Resolve the expression on a specified logical plan and it's child 
(recursively), until
-     * the expression is resolved or meet a non-unary node or Subquery.
-     */
-    @tailrec
-    private def resolveExpressionRecursively(expr: Expression, plan: 
LogicalPlan): Expression = {
-      val resolved = resolveExpression(expr, plan)
-      if (resolved.resolved) {
-        resolved
+    private def resolveExprsAndAddMissingAttrs(
+        exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], 
LogicalPlan) = {
+      if (exprs.forall(_.resolved)) {
+        // All given expressions are resolved, no need to continue anymore.
+        (exprs, plan)
       } else {
         plan match {
-          case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] =>
-            resolveExpressionRecursively(resolved, u.child)
-          case other => resolved
+          // For `AnalysisBarrier`, recursively resolve expressions and add 
missing attributes via
+          // its child.
+          case barrier: AnalysisBarrier =>
+            val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(exprs, 
barrier.child)
+            (newExprs, AnalysisBarrier(newChild))
+
+          case p: Project =>
+            val maybeResolvedExprs = exprs.map(resolveExpression(_, p))
+            val (newExprs, newChild) = 
resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child)
+            val missingAttrs = AttributeSet(newExprs) -- 
AttributeSet(maybeResolvedExprs)
+            (newExprs, Project(p.projectList ++ missingAttrs, newChild))
+
+          case a @ Aggregate(groupExprs, aggExprs, child) =>
+            val maybeResolvedExprs = exprs.map(resolveExpression(_, a))
+            val (newExprs, newChild) = 
resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child)
+            val missingAttrs = AttributeSet(newExprs) -- 
AttributeSet(maybeResolvedExprs)
+            if (missingAttrs.forall(attr => 
groupExprs.exists(_.semanticEquals(attr)))) {
+              // All the missing attributes are grouping expressions, valid 
case.
+              (newExprs, a.copy(aggregateExpressions = aggExprs ++ 
missingAttrs, child = newChild))
+            } else {
+              // Need to add non-grouping attributes, invalid case.
+              (exprs, a)
+            }
+
+          case g: Generate =>
+            val maybeResolvedExprs = exprs.map(resolveExpression(_, g))
+            val (newExprs, newChild) = 
resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child)
+            (newExprs, g.copy(join = true, child = newChild))
+
+          // For `Distinct` and `SubqueryAlias`, we can't recursively resolve 
and add attributes
+          // via its children.
+          case u: UnaryNode if !u.isInstanceOf[Distinct] && 
!u.isInstanceOf[SubqueryAlias] =>
+            val maybeResolvedExprs = exprs.map(resolveExpression(_, u))
+            val (newExprs, newChild) = 
resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child)
+            (newExprs, u.withNewChildren(Seq(newChild)))
+
+          // For other operators, we can't recursively resolve and add 
attributes via its children.
+          case other =>
+            (exprs.map(resolveExpression(_, other)), other)
         }
       }
     }
@@ -1404,18 +1381,16 @@ class Analyzer(
    */
   object ResolveAggregateFunctions extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
-      case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: 
Aggregate)) =>
-        apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier)
-      case filter @ Filter(havingCondition,
-             aggregate @ Aggregate(grouping, originalAggExprs, child))
-          if aggregate.resolved =>
+      case Filter(cond, AnalysisBarrier(agg: Aggregate)) =>
+        apply(Filter(cond, agg)).mapChildren(AnalysisBarrier)
+      case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, 
child)) if agg.resolved =>
 
         // Try resolving the condition of the filter as though it is in the 
aggregate clause
         try {
           val aggregatedCondition =
             Aggregate(
               grouping,
-              Alias(havingCondition, "havingCondition")() :: Nil,
+              Alias(cond, "havingCondition")() :: Nil,
               child)
           val resolvedOperator = execute(aggregatedCondition)
           def resolvedAggregateFilter =
@@ -1436,7 +1411,7 @@ class Analyzer(
               // Grouping functions are handled in the rule 
[[ResolveGroupingAnalytics]].
               case e: Expression if grouping.exists(_.semanticEquals(e)) &&
                   !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
-                  !aggregate.output.exists(_.semanticEquals(e)) =>
+                  !agg.output.exists(_.semanticEquals(e)) =>
                 e match {
                   case ne: NamedExpression =>
                     aggregateExpressions += ne
@@ -1450,22 +1425,22 @@ class Analyzer(
 
             // Push the aggregate expressions into the aggregate (if any).
             if (aggregateExpressions.nonEmpty) {
-              Project(aggregate.output,
+              Project(agg.output,
                 Filter(transformedAggregateFilter,
-                  aggregate.copy(aggregateExpressions = originalAggExprs ++ 
aggregateExpressions)))
+                  agg.copy(aggregateExpressions = originalAggExprs ++ 
aggregateExpressions)))
             } else {
-              filter
+              f
             }
           } else {
-            filter
+            f
           }
         } catch {
           // Attempting to resolve in the aggregate can result in ambiguity.  
When this happens,
           // just return the original plan.
-          case ae: AnalysisException => filter
+          case ae: AnalysisException => f
         }
 
-      case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: 
Aggregate)) =>
+      case Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) =>
         apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier)
       case sort @ Sort(sortOrder, global, aggregate: Aggregate) if 
aggregate.resolved =>
 

http://git-wip-us.apache.org/repos/asf/spark/blob/755f2f51/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index c11e37a..07ae3ae 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1562,7 +1562,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils 
with TestHiveSingleton {
   }
 
   test("multi-insert with lateral view") {
-    withTempView("t1") {
+    withTempView("source") {
       spark.range(10)
         .select(array($"id", $"id" + 1).as("arr"), $"id")
         .createOrReplaceTempView("source")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to