Repository: spark
Updated Branches:
  refs/heads/master d7b4c0952 -> 2bf338c62


[SPARK-10165] [SQL] Await child resolution in ResolveFunctions

Currently, we eagerly attempt to resolve functions, even before their children 
are resolved.  However, this is not valid in cases where we need to know the 
types of the input arguments (i.e. when resolving Hive UDFs).

As a fix, this PR delays function resolution until the function's children are
resolved.  This change also necessitates a change to the way we resolve
aggregate expressions that are not in aggregate operators (e.g., in `HAVING` or
`ORDER BY` clauses).  Specifically, we can no longer assume that these misplaced
functions are already resolved, which is what previously allowed us to
differentiate aggregate functions from normal functions.  To compensate for this
change, we now attempt to resolve these unresolved expressions in the context of
the aggregate operator before checking whether any aggregate expressions are
present.

Author: Michael Armbrust <mich...@databricks.com>

Closes #8371 from marmbrus/hiveUDFResolution.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2bf338c6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2bf338c6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2bf338c6

Branch: refs/heads/master
Commit: 2bf338c626e9d97ccc033cfadae8b36a82c66fd1
Parents: d7b4c09
Author: Michael Armbrust <mich...@databricks.com>
Authored: Mon Aug 24 18:10:51 2015 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon Aug 24 18:10:51 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 116 ++++++++++++-------
 .../spark/sql/hive/execution/HiveUDFSuite.scala |   5 +
 2 files changed, 77 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2bf338c6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d0eb9c2..1a5de15 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -78,7 +78,7 @@ class Analyzer(
       ResolveAliases ::
       ExtractWindowExpressions ::
       GlobalAggregates ::
-      UnresolvedHavingClauseAttributes ::
+      ResolveAggregateFunctions ::
       HiveTypeCoercion.typeCoercionRules ++
       extendedResolutionRules : _*),
     Batch("Nondeterministic", Once,
@@ -452,37 +452,6 @@ class Analyzer(
           logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
           s // Nothing we can do here. Return original plan.
         }
-      case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child))
-          if !s.resolved && a.resolved =>
-        // A small hack to create an object that will allow us to resolve any 
references that
-        // refer to named expressions that are present in the grouping 
expressions.
-        val groupingRelation = LocalRelation(
-          grouping.collect { case ne: NamedExpression => ne.toAttribute }
-        )
-
-        // Find sort attributes that are projected away so we can temporarily 
add them back in.
-        val (newOrdering, missingAttr) = resolveAndFindMissing(ordering, a, 
groupingRelation)
-
-        // Find aggregate expressions and evaluate them early, since they 
can't be evaluated in a
-        // Sort.
-        val (withAggsRemoved, aliasedAggregateList) = newOrdering.map {
-          case aggOrdering if aggOrdering.collect { case a: 
AggregateExpression => a }.nonEmpty =>
-            val aliased = Alias(aggOrdering.child, "_aggOrdering")()
-            (aggOrdering.copy(child = aliased.toAttribute), Some(aliased))
-
-          case other => (other, None)
-        }.unzip
-
-        val missing = missingAttr ++ aliasedAggregateList.flatten
-
-        if (missing.nonEmpty) {
-          // Add missing grouping exprs and then project them away after the 
sort.
-          Project(a.output,
-            Sort(withAggsRemoved, global,
-              Aggregate(grouping, aggs ++ missing, child)))
-        } else {
-          s // Nothing we can do here. Return original plan.
-        }
     }
 
     /**
@@ -515,6 +484,7 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case q: LogicalPlan =>
         q transformExpressions {
+          case u if !u.childrenResolved => u // Skip until children are 
resolved.
           case u @ UnresolvedFunction(name, children, isDistinct) =>
             withPosition(u) {
               registry.lookupFunction(name, children) match {
@@ -559,21 +529,79 @@ class Analyzer(
   }
 
   /**
-   * This rule finds expressions in HAVING clause filters that depend on
-   * unresolved attributes.  It pushes these expressions down to the underlying
-   * aggregates and then projects them away above the filter.
+   * This rule finds aggregate expressions that are not in an aggregate 
operator.  For example,
+   * those in a HAVING clause or ORDER BY clause.  These expressions are 
pushed down to the
+   * underlying aggregate operator and then projected away after the original 
operator.
    */
-  object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
+  object ResolveAggregateFunctions extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case filter @ Filter(havingCondition, aggregate @ Aggregate(_, 
originalAggExprs, _))
-          if aggregate.resolved && containsAggregate(havingCondition) =>
-
-        val evaluatedCondition = Alias(havingCondition, "havingCondition")()
-        val aggExprsWithHaving = evaluatedCondition +: originalAggExprs
+      case filter @ Filter(havingCondition,
+             aggregate @ Aggregate(grouping, originalAggExprs, child))
+          if aggregate.resolved && !filter.resolved =>
+
+        // Try resolving the condition of the filter as though it is in the 
aggregate clause
+        val aggregatedCondition =
+          Aggregate(grouping, Alias(havingCondition, "havingCondition")() :: 
Nil, child)
+        val resolvedOperator = execute(aggregatedCondition)
+        def resolvedAggregateFilter =
+          resolvedOperator
+            .asInstanceOf[Aggregate]
+            .aggregateExpressions.head
+
+        // If resolution was successful and we see the filter has an aggregate 
in it, add it to
+        // the original aggregate operator.
+        if (resolvedOperator.resolved && 
containsAggregate(resolvedAggregateFilter)) {
+          val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs
+
+          Project(aggregate.output,
+            Filter(resolvedAggregateFilter.toAttribute,
+              aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
+        } else {
+          filter
+        }
 
-        Project(aggregate.output,
-          Filter(evaluatedCondition.toAttribute,
-            aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
+      case sort @ Sort(sortOrder, global,
+             aggregate @ Aggregate(grouping, originalAggExprs, child))
+        if aggregate.resolved && !sort.resolved =>
+
+        // Try resolving the ordering as though it is in the aggregate clause.
+        try {
+          val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")())
+          val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child)
+          val resolvedOperator: Aggregate = 
execute(aggregatedOrdering).asInstanceOf[Aggregate]
+          def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions
+
+          // Expressions that have an aggregate can be pushed down.
+          val needsAggregate = 
resolvedAggregateOrdering.exists(containsAggregate)
+
+          // Attribute references, that are missing from the order but are 
present in the grouping
+          // expressions can also be pushed down.
+          val requiredAttributes = 
resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
+          val missingAttributes = requiredAttributes -- aggregate.outputSet
+          val validPushdownAttributes =
+            missingAttributes.filter(a => grouping.exists(a.semanticEquals))
+
+          // If resolution was successful and we see the ordering either has 
an aggregate in it or
+          // it is missing something that is projected away by the aggregate, 
add the ordering
+          // the original aggregate operator.
+          if (resolvedOperator.resolved && (needsAggregate || 
validPushdownAttributes.nonEmpty)) {
+            val evaluatedOrderings: Seq[SortOrder] = 
sortOrder.zip(resolvedAggregateOrdering).map {
+              case (order, evaluated) => order.copy(child = 
evaluated.toAttribute)
+            }
+            val aggExprsWithOrdering: Seq[NamedExpression] =
+              resolvedAggregateOrdering ++ originalAggExprs
+
+            Project(aggregate.output,
+              Sort(evaluatedOrderings, global,
+                aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
+          } else {
+            sort
+          }
+        } catch {
+          // Attempting to resolve in the aggregate can result in ambiguity.  
When this happens,
+          // just return the original plan.
+          case ae: AnalysisException => sort
+        }
     }
 
     protected def containsAggregate(condition: Expression): Boolean = {

http://git-wip-us.apache.org/repos/asf/spark/blob/2bf338c6/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 10f2902..b03a351 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -276,6 +276,11 @@ class HiveUDFSuite extends QueryTest {
     checkAnswer(
       sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"),
       Seq(Row("hello world"), Row("hello goodbye")))
+
+    checkAnswer(
+      sql("SELECT testStringStringUDF(\"\", testStringStringUDF(\"hello\", s)) 
FROM stringTable"),
+      Seq(Row(" hello world"), Row(" hello goodbye")))
+
     sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF")
 
     TestHive.reset()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to