Repository: spark Updated Branches: refs/heads/master 1eebfbe19 -> 755f2f518
[SPARK-20392][SQL][FOLLOWUP] should not add extra AnalysisBarrier ## What changes were proposed in this pull request? I found this problem while auditing the analyzer code. It's dangerous to introduce extra `AnalysisBarrier` nodes during analysis, as the plan inside such a barrier bypasses all subsequent analysis, which may not be expected. We should only preserve existing `AnalysisBarrier` nodes, not introduce new ones. ## How was this patch tested? Existing tests. Author: Wenchen Fan <wenc...@databricks.com> Closes #20094 from cloud-fan/barrier. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/755f2f51 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/755f2f51 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/755f2f51 Branch: refs/heads/master Commit: 755f2f5189a08597fddc90b7f9df4ad0ec6bd2ef Parents: 1eebfbe Author: Wenchen Fan <wenc...@databricks.com> Authored: Thu Dec 28 21:33:03 2017 +0800 Committer: gatorsmile <gatorsm...@gmail.com> Committed: Thu Dec 28 21:33:03 2017 +0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/analysis/Analyzer.scala | 191 ++++++++----------- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 2 files changed, 84 insertions(+), 109 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/755f2f51/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 10b237f..7f2128e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,7 +17,6 @@ package 
org.apache.spark.sql.catalyst.analysis -import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException @@ -665,14 +664,18 @@ class Analyzer( * Generate a new logical plan for the right child with different expression IDs * for all conflicting attributes. */ - private def dedupRight (left: LogicalPlan, originalRight: LogicalPlan): LogicalPlan = { - // Remove analysis barrier if any. - val right = EliminateBarriers(originalRight) + private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + s"between $left and $right") right.collect { + // For `AnalysisBarrier`, recursively de-duplicate its child. + case oldVersion: AnalysisBarrier + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + val newVersion = dedupRight(left, oldVersion.child) + (oldVersion, AnalysisBarrier(newVersion)) + // Handle base relations that might appear more than once. case oldVersion: MultiInstanceRelation if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => @@ -710,10 +713,10 @@ class Analyzer( * that this rule cannot handle. When that is the case, there must be another rule * that resolves these conflicts. Otherwise, the analysis will fail. 
*/ - originalRight + right case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { + right transformUp { case r if r == oldRelation => newRelation } transformUp { case other => other transformExpressions { @@ -723,7 +726,6 @@ class Analyzer( s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) } } - AnalysisBarrier(newRight) } } @@ -958,7 +960,8 @@ class Analyzer( protected[sql] def resolveExpression( expr: Expression, plan: LogicalPlan, - throws: Boolean = false) = { + throws: Boolean = false): Expression = { + if (expr.resolved) return expr // Resolve expression in one round. // If throws == false or the desired attribute doesn't exist // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. @@ -1079,100 +1082,74 @@ class Analyzer( case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, originalChild) if !s.resolved && originalChild.resolved => - val child = EliminateBarriers(originalChild) - try { - val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) - val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) - val missingAttrs = requiredAttrs -- child.outputSet - if (missingAttrs.nonEmpty) { - // Add missing attributes and then project them away after the sort. - Project(child.output, - Sort(newOrder, s.global, addMissingAttr(child, missingAttrs))) - } else if (newOrder != order) { - s.copy(order = newOrder) - } else { - s - } - } catch { - // Attempting to resolve it might fail. When this happens, return the original plan. 
- // Users will see an AnalysisException for resolution failure of missing attributes - // in Sort - case ae: AnalysisException => s + case s @ Sort(order, _, child) if !s.resolved && child.resolved => + val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child) + val ordering = newOrder.map(_.asInstanceOf[SortOrder]) + if (child.output == newChild.output) { + s.copy(order = ordering) + } else { + // Add missing attributes and then project them away. + val newSort = s.copy(order = ordering, child = newChild) + Project(child.output, newSort) } - case f @ Filter(cond, originalChild) if !f.resolved && originalChild.resolved => - val child = EliminateBarriers(originalChild) - try { - val newCond = resolveExpressionRecursively(cond, child) - val requiredAttrs = newCond.references.filter(_.resolved) - val missingAttrs = requiredAttrs -- child.outputSet - if (missingAttrs.nonEmpty) { - // Add missing attributes and then project them away. - Project(child.output, - Filter(newCond, addMissingAttr(child, missingAttrs))) - } else if (newCond != cond) { - f.copy(condition = newCond) - } else { - f - } - } catch { - // Attempting to resolve it might fail. When this happens, return the original plan. - // Users will see an AnalysisException for resolution failure of missing attributes - case ae: AnalysisException => f + case f @ Filter(cond, child) if !f.resolved && child.resolved => + val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child) + if (child.output == newChild.output) { + f.copy(condition = newCond.head) + } else { + // Add missing attributes and then project them away. + val newFilter = Filter(newCond.head, newChild) + Project(child.output, newFilter) } } - /** - * Add the missing attributes into projectList of Project/Window or aggregateExpressions of - * Aggregate. 
- */ - private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { - if (missingAttrs.isEmpty) { - return AnalysisBarrier(plan) - } - plan match { - case p: Project => - val missing = missingAttrs -- p.child.outputSet - Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing)) - case a: Aggregate => - // all the missing attributes should be grouping expressions - // TODO: push down AggregateExpression - missingAttrs.foreach { attr => - if (!a.groupingExpressions.exists(_.semanticEquals(attr))) { - throw new AnalysisException(s"Can't add $attr to ${a.simpleString}") - } - } - val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs - a.copy(aggregateExpressions = newAggregateExpressions) - case g: Generate => - // If join is false, we will convert it to true for getting from the child the missing - // attributes that its child might have or could have. - val missing = missingAttrs -- g.child.outputSet - g.copy(join = true, child = addMissingAttr(g.child, missing)) - case d: Distinct => - throw new AnalysisException(s"Can't add $missingAttrs to $d") - case u: UnaryNode => - u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil) - case other => - throw new AnalysisException(s"Can't add $missingAttrs to $other") - } - } - - /** - * Resolve the expression on a specified logical plan and it's child (recursively), until - * the expression is resolved or meet a non-unary node or Subquery. - */ - @tailrec - private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = { - val resolved = resolveExpression(expr, plan) - if (resolved.resolved) { - resolved + private def resolveExprsAndAddMissingAttrs( + exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { + if (exprs.forall(_.resolved)) { + // All given expressions are resolved, no need to continue anymore. 
+ (exprs, plan) } else { plan match { - case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] => - resolveExpressionRecursively(resolved, u.child) - case other => resolved + // For `AnalysisBarrier`, recursively resolve expressions and add missing attributes via + // its child. + case barrier: AnalysisBarrier => + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(exprs, barrier.child) + (newExprs, AnalysisBarrier(newChild)) + + case p: Project => + val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) + val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + (newExprs, Project(p.projectList ++ missingAttrs, newChild)) + + case a @ Aggregate(groupExprs, aggExprs, child) => + val maybeResolvedExprs = exprs.map(resolveExpression(_, a)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) + val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) + if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { + // All the missing attributes are grouping expressions, valid case. + (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) + } else { + // Need to add non-grouping attributes, invalid case. + (exprs, a) + } + + case g: Generate => + val maybeResolvedExprs = exprs.map(resolveExpression(_, g)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) + (newExprs, g.copy(join = true, child = newChild)) + + // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes + // via its children. 
+ case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] => + val maybeResolvedExprs = exprs.map(resolveExpression(_, u)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child) + (newExprs, u.withNewChildren(Seq(newChild))) + + // For other operators, we can't recursively resolve and add attributes via its children. + case other => + (exprs.map(resolveExpression(_, other)), other) } } } @@ -1404,18 +1381,16 @@ class Analyzer( */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { - case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) => - apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier) - case filter @ Filter(havingCondition, - aggregate @ Aggregate(grouping, originalAggExprs, child)) - if aggregate.resolved => + case Filter(cond, AnalysisBarrier(agg: Aggregate)) => + apply(Filter(cond, agg)).mapChildren(AnalysisBarrier) + case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause try { val aggregatedCondition = Aggregate( grouping, - Alias(havingCondition, "havingCondition")() :: Nil, + Alias(cond, "havingCondition")() :: Nil, child) val resolvedOperator = execute(aggregatedCondition) def resolvedAggregateFilter = @@ -1436,7 +1411,7 @@ class Analyzer( // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]]. case e: Expression if grouping.exists(_.semanticEquals(e)) && !ResolveGroupingAnalytics.hasGroupingFunction(e) && - !aggregate.output.exists(_.semanticEquals(e)) => + !agg.output.exists(_.semanticEquals(e)) => e match { case ne: NamedExpression => aggregateExpressions += ne @@ -1450,22 +1425,22 @@ class Analyzer( // Push the aggregate expressions into the aggregate (if any). 
if (aggregateExpressions.nonEmpty) { - Project(aggregate.output, + Project(agg.output, Filter(transformedAggregateFilter, - aggregate.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions))) + agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions))) } else { - filter + f } } else { - filter + f } } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan. - case ae: AnalysisException => filter + case ae: AnalysisException => f } - case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => + case Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => http://git-wip-us.apache.org/repos/asf/spark/blob/755f2f51/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c11e37a..07ae3ae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1562,7 +1562,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("multi-insert with lateral view") { - withTempView("t1") { + withTempView("source") { spark.range(10) .select(array($"id", $"id" + 1).as("arr"), $"id") .createOrReplaceTempView("source") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org