Repository: spark
Updated Branches:
  refs/heads/branch-2.0 9c0ac6b53 -> 94d52d765


[SPARK-17269][SQL] Move finish analysis optimization stage into its own file

As part of breaking Optimizer.scala apart, this patch moves various finish 
analysis optimization stage rules into a single file. I'm submitting separate 
pull requests so we can more easily merge this into branch-2.0 to simplify 
optimizer backports.

This should be covered by existing tests.
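For context, the relocated rules are the ones that run once right after analysis. Below is a rough, hedged sketch (not part of this patch) of how they would typically be wired into a "Finish Analysis" batch, assuming Catalyst's RuleExecutor API; the FinishAnalysisOnly executor is hypothetical and only illustrates the grouping.

import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor

// Hypothetical executor that applies only the finish-analysis rules touched by
// this patch; the real batch list lives in Optimizer.scala.
class FinishAnalysisOnly(sessionCatalog: SessionCatalog) extends RuleExecutor[LogicalPlan] {
  override def batches: Seq[Batch] = Seq(
    Batch("Finish Analysis", Once,
      ReplaceExpressions,
      ComputeCurrentTime,
      GetCurrentDatabase(sessionCatalog),
      RewriteDistinctAggregates))
}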

Author: Reynold Xin <r...@databricks.com>

Closes #14838 from rxin/SPARK-17269.

(cherry picked from commit dcefac438788c51d84641bfbc505efe095731a39)
Signed-off-by: Reynold Xin <r...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/94d52d76
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/94d52d76
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/94d52d76

Branch: refs/heads/branch-2.0
Commit: 94d52d76569f8b0782f424cfac959a4bb75c54c0
Parents: 9c0ac6b
Author: Reynold Xin <r...@databricks.com>
Authored: Fri Aug 26 22:10:28 2016 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Fri Aug 26 22:12:11 2016 -0700

----------------------------------------------------------------------
 .../analysis/RewriteDistinctAggregates.scala    | 269 -------------------
 .../sql/catalyst/optimizer/Optimizer.scala      |  38 ---
 .../optimizer/RewriteDistinctAggregates.scala   | 269 +++++++++++++++++++
 .../sql/catalyst/optimizer/finishAnalysis.scala |  65 +++++
 4 files changed, 334 insertions(+), 307 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/94d52d76/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala
deleted file mode 100644
index 8afd28d..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala
+++ /dev/null
@@ -1,269 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.analysis
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types.IntegerType
-
-/**
- * This rule rewrites an aggregate query with distinct aggregations into an expanded double
- * aggregation in which the regular aggregation expressions and every distinct clause are each
- * aggregated in a separate group. The results are then combined in a second aggregate.
- *
- * For example (in Scala):
- * {{{
- *   val data = Seq(
- *     ("a", "ca1", "cb1", 10),
- *     ("a", "ca1", "cb2", 5),
- *     ("b", "ca1", "cb1", 13))
- *     .toDF("key", "cat1", "cat2", "value")
- *   data.createOrReplaceTempView("data")
- *
- *   val agg = data.groupBy($"key")
- *     .agg(
- *       countDistinct($"cat1").as("cat1_cnt"),
- *       countDistinct($"cat2").as("cat2_cnt"),
- *       sum($"value").as("total"))
- * }}}
- *
- * This translates to the following (pseudo) logical plan:
- * {{{
- * Aggregate(
- *    key = ['key]
- *    functions = [COUNT(DISTINCT 'cat1),
- *                 COUNT(DISTINCT 'cat2),
- *                 sum('value)]
- *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
- *   LocalTableScan [...]
- * }}}
- *
- * This rule rewrites this logical plan to the following (pseudo) logical plan:
- * {{{
- * Aggregate(
- *    key = ['key]
- *    functions = [count(if (('gid = 1)) 'cat1 else null),
- *                 count(if (('gid = 2)) 'cat2 else null),
- *                 first(if (('gid = 0)) 'total else null) ignore nulls]
- *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
- *   Aggregate(
- *      key = ['key, 'cat1, 'cat2, 'gid]
- *      functions = [sum('value)]
- *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
- *     Expand(
- *        projections = [('key, null, null, 0, cast('value as bigint)),
- *                       ('key, 'cat1, null, 1, null),
- *                       ('key, null, 'cat2, 2, null)]
- *        output = ['key, 'cat1, 'cat2, 'gid, 'value])
- *       LocalTableScan [...]
- * }}}
- *
- * The rule does the following things here:
- * 1. Expand the data. There are three aggregation groups in this query:
- *    i. the non-distinct group;
- *    ii. the distinct 'cat1 group;
- *    iii. the distinct 'cat2 group.
- *    An expand operator is inserted to expand the child data for each group. The expand will
- *    null out all unused columns for the given group; this must be done in order to ensure
- *    correctness later on. Groups can be identified by a group id (gid) column added by the
- *    expand operator.
- * 2. De-duplicate the distinct paths and aggregate the non-distinct path. The group by clause
- *    of this aggregate consists of the original group by clause, all the requested distinct
- *    columns and the group id. Both the de-duplication of the distinct columns and the
- *    aggregation of the non-distinct group take advantage of the fact that we group by the
- *    group id (gid) and that we have nulled out all non-relevant columns for the given group.
- * 3. Aggregate the distinct groups and combine them with the results of the non-distinct
- *    aggregation. In this step we use the group id to filter the inputs for the aggregate
- *    functions. The results of the non-distinct group are 'aggregated' using the first
- *    operator; it might be more elegant to use the native UDAF merge mechanism for this in
- *    the future.
- *
- * This rule duplicates the input data two or more times (# distinct groups + an optional
- * non-distinct group). This puts quite a bit of memory pressure on the aggregate and exchange
- * operators involved. Keeping the number of distinct groups as low as possible should be a
- * priority; we could improve this rule by applying more advanced expression canonicalization
- * techniques.
- */
-object RewriteDistinctAggregates extends Rule[LogicalPlan] {
-
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case a: Aggregate => rewrite(a)
-  }
-
-  def rewrite(a: Aggregate): Aggregate = {
-
-    // Collect all aggregate expressions.
-    val aggExpressions = a.aggregateExpressions.flatMap { e =>
-      e.collect {
-        case ae: AggregateExpression => ae
-      }
-    }
-
-    // Extract distinct aggregate expressions.
-    val distinctAggGroups = aggExpressions
-      .filter(_.isDistinct)
-      .groupBy(_.aggregateFunction.children.toSet)
-
-    // Aggregation strategy can handle the query with single distinct
-    if (distinctAggGroups.size > 1) {
-      // Create the attributes for the grouping id and the group by clause.
-      val gid =
-        new AttributeReference("gid", IntegerType, false)(isGenerated = true)
-      val groupByMap = a.groupingExpressions.collect {
-        case ne: NamedExpression => ne -> ne.toAttribute
-        case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)()
-      }
-      val groupByAttrs = groupByMap.map(_._2)
-
-      // Functions used to modify aggregate functions and their inputs.
-      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
-      def patchAggregateFunctionChildren(
-          af: AggregateFunction)(
-          attrs: Expression => Expression): AggregateFunction = {
-        af.withNewChildren(af.children.map {
-          case afc => attrs(afc)
-        }).asInstanceOf[AggregateFunction]
-      }
-
-      // Setup unique distinct aggregate children.
-      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
-      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
-      val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
-
-      // Setup expand & aggregate operators for distinct aggregate expressions.
-      val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
-      val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
-        case ((group, expressions), i) =>
-          val id = Literal(i + 1)
-
-          // Expand projection
-          val projection = distinctAggChildren.map {
-            case e if group.contains(e) => e
-            case e => nullify(e)
-          } :+ id
-
-          // Final aggregate
-          val operators = expressions.map { e =>
-            val af = e.aggregateFunction
-            val naf = patchAggregateFunctionChildren(af) { x =>
-              evalWithinGroup(id, distinctAggChildAttrLookup(x))
-            }
-            (e, e.copy(aggregateFunction = naf, isDistinct = false))
-          }
-
-          (projection, operators)
-      }
-
-      // Setup expand for the 'regular' aggregate expressions.
-      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
-      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
-      val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
-
-      // Setup aggregates for 'regular' aggregate expressions.
-      val regularGroupId = Literal(0)
-      val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
-      val regularAggOperatorMap = regularAggExprs.map { e =>
-        // Perform the actual aggregation in the initial aggregate.
-        val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
-        val operator = Alias(e.copy(aggregateFunction = af), e.sql)()
-
-        // Select the result of the first aggregate in the last aggregate.
-        val result = AggregateExpression(
-          aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
-          mode = Complete,
-          isDistinct = false)
-
-        // Some aggregate functions (COUNT) have the special property that they can return a
-        // non-null result without any input. We need to make sure we return a result in this case.
-        val resultWithDefault = af.defaultResult match {
-          case Some(lit) => Coalesce(Seq(result, lit))
-          case None => result
-        }
-
-        // Return a Tuple3 containing:
-        // i. The original aggregate expression (used for look ups).
-        // ii. The actual aggregation operator (used in the first aggregate).
-        // iii. The operator that selects and returns the result (used in the second aggregate).
-        (e, operator, resultWithDefault)
-      }
-
-      // Construct the regular aggregate input projection only if we need one.
-      val regularAggProjection = if (regularAggExprs.nonEmpty) {
-        Seq(a.groupingExpressions ++
-          distinctAggChildren.map(nullify) ++
-          Seq(regularGroupId) ++
-          regularAggChildren)
-      } else {
-        Seq.empty[Seq[Expression]]
-      }
-
-      // Construct the distinct aggregate input projections.
-      val regularAggNulls = regularAggChildren.map(nullify)
-      val distinctAggProjections = distinctAggOperatorMap.map {
-        case (projection, _) =>
-          a.groupingExpressions ++
-            projection ++
-            regularAggNulls
-      }
-
-      // Construct the expand operator.
-      val expand = Expand(
-        regularAggProjection ++ distinctAggProjections,
-        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
-        a.child)
-
-      // Construct the first aggregate operator. This de-duplicates all the children of
-      // distinct operators, and applies the regular aggregate operators.
-      val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
-      val firstAggregate = Aggregate(
-        firstAggregateGroupBy,
-        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
-        expand)
-
-      // Construct the second aggregate
-      val transformations: Map[Expression, Expression] =
-        (distinctAggOperatorMap.flatMap(_._2) ++
-          regularAggOperatorMap.map(e => (e._1, e._3))).toMap
-
-      val patchedAggExpressions = a.aggregateExpressions.map { e =>
-        e.transformDown {
-          case e: Expression =>
-            // The same GROUP BY clauses can have different forms (different names for instance) in
-            // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
-            // tricky. So we do a linear search for a semantically equal group by expression.
-            groupByMap
-              .find(ge => e.semanticEquals(ge._1))
-              .map(_._2)
-              .getOrElse(transformations.getOrElse(e, e))
-        }.asInstanceOf[NamedExpression]
-      }
-      Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
-    } else {
-      a
-    }
-  }
-
-  private def nullify(e: Expression) = Literal.create(null, e.dataType)
-
-  private def expressionAttributePair(e: Expression) =
-    // We are creating a new reference here instead of reusing the attribute in case of a
-    // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
-    // children; in that case attribute reuse would cause the input of the regular aggregate to be
-    // bound to the (nulled out) input of the distinct aggregate.
-    e -> new AttributeReference(e.sql, e.dataType, true)()
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/94d52d76/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 4cadbc3..f3f1d21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1583,44 +1583,6 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
 }
 
 /**
- * Finds all [[RuntimeReplaceable]] expressions and replaces them with the expressions that can
- * be evaluated. This is mainly used to provide compatibility with other databases.
- * For example, we use this to support "nvl" by replacing it with "coalesce".
- */
-object ReplaceExpressions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case e: RuntimeReplaceable => e.replaced
-  }
-}
-
-/**
- * Computes the current date and time to make sure we return the same result in a single query.
- */
-object ComputeCurrentTime extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    val dateExpr = CurrentDate()
-    val timeExpr = CurrentTimestamp()
-    val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
-    val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
-
-    plan transformAllExpressions {
-      case CurrentDate() => currentDate
-      case CurrentTimestamp() => currentTime
-    }
-  }
-}
-
-/** Replaces the expression of CurrentDatabase with the current database name. */
-case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    plan transformAllExpressions {
-      case CurrentDatabase() =>
-        Literal.create(sessionCatalog.getCurrentDatabase, StringType)
-    }
-  }
-}
-
-/**
  * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
  * [[SerializeFromObject]] above it.  If these serializations can't be eliminated, we should
  * embed the deserializer in the filter condition to save the extra serialization at the end.

http://git-wip-us.apache.org/repos/asf/spark/blob/94d52d76/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
new file mode 100644
index 0000000..0f43e7b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.IntegerType
+
+/**
+ * This rule rewrites an aggregate query with distinct aggregations into an expanded double
+ * aggregation in which the regular aggregation expressions and every distinct clause are each
+ * aggregated in a separate group. The results are then combined in a second aggregate.
+ *
+ * For example (in Scala):
+ * {{{
+ *   val data = Seq(
+ *     ("a", "ca1", "cb1", 10),
+ *     ("a", "ca1", "cb2", 5),
+ *     ("b", "ca1", "cb1", 13))
+ *     .toDF("key", "cat1", "cat2", "value")
+ *   data.createOrReplaceTempView("data")
+ *
+ *   val agg = data.groupBy($"key")
+ *     .agg(
+ *       countDistinct($"cat1").as("cat1_cnt"),
+ *       countDistinct($"cat2").as("cat2_cnt"),
+ *       sum($"value").as("total"))
+ * }}}
+ *
+ * This translates to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [COUNT(DISTINCT 'cat1),
+ *                 COUNT(DISTINCT 'cat2),
+ *                 sum('value)]
+ *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ *   LocalTableScan [...]
+ * }}}
+ *
+ * This rule rewrites this logical plan to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [count(if (('gid = 1)) 'cat1 else null),
+ *                 count(if (('gid = 2)) 'cat2 else null),
+ *                 first(if (('gid = 0)) 'total else null) ignore nulls]
+ *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ *   Aggregate(
+ *      key = ['key, 'cat1, 'cat2, 'gid]
+ *      functions = [sum('value)]
+ *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
+ *     Expand(
+ *        projections = [('key, null, null, 0, cast('value as bigint)),
+ *                       ('key, 'cat1, null, 1, null),
+ *                       ('key, null, 'cat2, 2, null)]
+ *        output = ['key, 'cat1, 'cat2, 'gid, 'value])
+ *       LocalTableScan [...]
+ * }}}
+ *
+ * The rule does the following things here:
+ * 1. Expand the data. There are three aggregation groups in this query:
+ *    i. the non-distinct group;
+ *    ii. the distinct 'cat1 group;
+ *    iii. the distinct 'cat2 group.
+ *    An expand operator is inserted to expand the child data for each group. The expand will
+ *    null out all unused columns for the given group; this must be done in order to ensure
+ *    correctness later on. Groups can be identified by a group id (gid) column added by the
+ *    expand operator.
+ * 2. De-duplicate the distinct paths and aggregate the non-distinct path. The group by clause
+ *    of this aggregate consists of the original group by clause, all the requested distinct
+ *    columns and the group id. Both the de-duplication of the distinct columns and the
+ *    aggregation of the non-distinct group take advantage of the fact that we group by the
+ *    group id (gid) and that we have nulled out all non-relevant columns for the given group.
+ * 3. Aggregate the distinct groups and combine them with the results of the non-distinct
+ *    aggregation. In this step we use the group id to filter the inputs for the aggregate
+ *    functions. The results of the non-distinct group are 'aggregated' using the first
+ *    operator; it might be more elegant to use the native UDAF merge mechanism for this in
+ *    the future.
+ *
+ * This rule duplicates the input data two or more times (# distinct groups + an optional
+ * non-distinct group). This puts quite a bit of memory pressure on the aggregate and exchange
+ * operators involved. Keeping the number of distinct groups as low as possible should be a
+ * priority; we could improve this rule by applying more advanced expression canonicalization
+ * techniques.
+ */
+object RewriteDistinctAggregates extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case a: Aggregate => rewrite(a)
+  }
+
+  def rewrite(a: Aggregate): Aggregate = {
+
+    // Collect all aggregate expressions.
+    val aggExpressions = a.aggregateExpressions.flatMap { e =>
+      e.collect {
+        case ae: AggregateExpression => ae
+      }
+    }
+
+    // Extract distinct aggregate expressions.
+    val distinctAggGroups = aggExpressions
+      .filter(_.isDistinct)
+      .groupBy(_.aggregateFunction.children.toSet)
+
+    // Aggregation strategy can handle the query with single distinct
+    if (distinctAggGroups.size > 1) {
+      // Create the attributes for the grouping id and the group by clause.
+      val gid =
+        new AttributeReference("gid", IntegerType, false)(isGenerated = true)
+      val groupByMap = a.groupingExpressions.collect {
+        case ne: NamedExpression => ne -> ne.toAttribute
+        case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)()
+      }
+      val groupByAttrs = groupByMap.map(_._2)
+
+      // Functions used to modify aggregate functions and their inputs.
+      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
+      def patchAggregateFunctionChildren(
+          af: AggregateFunction)(
+          attrs: Expression => Expression): AggregateFunction = {
+        af.withNewChildren(af.children.map {
+          case afc => attrs(afc)
+        }).asInstanceOf[AggregateFunction]
+      }
+
+      // Setup unique distinct aggregate children.
+      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
+      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
+      val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
+
+      // Setup expand & aggregate operators for distinct aggregate expressions.
+      val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
+      val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
+        case ((group, expressions), i) =>
+          val id = Literal(i + 1)
+
+          // Expand projection
+          val projection = distinctAggChildren.map {
+            case e if group.contains(e) => e
+            case e => nullify(e)
+          } :+ id
+
+          // Final aggregate
+          val operators = expressions.map { e =>
+            val af = e.aggregateFunction
+            val naf = patchAggregateFunctionChildren(af) { x =>
+              evalWithinGroup(id, distinctAggChildAttrLookup(x))
+            }
+            (e, e.copy(aggregateFunction = naf, isDistinct = false))
+          }
+
+          (projection, operators)
+      }
+
+      // Setup expand for the 'regular' aggregate expressions.
+      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
+      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+      val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
+
+      // Setup aggregates for 'regular' aggregate expressions.
+      val regularGroupId = Literal(0)
+      val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
+      val regularAggOperatorMap = regularAggExprs.map { e =>
+        // Perform the actual aggregation in the initial aggregate.
+        val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
+        val operator = Alias(e.copy(aggregateFunction = af), e.sql)()
+
+        // Select the result of the first aggregate in the last aggregate.
+        val result = AggregateExpression(
+          aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
+          mode = Complete,
+          isDistinct = false)
+
+        // Some aggregate functions (COUNT) have the special property that they can return a
+        // non-null result without any input. We need to make sure we return a result in this case.
+        val resultWithDefault = af.defaultResult match {
+          case Some(lit) => Coalesce(Seq(result, lit))
+          case None => result
+        }
+
+        // Return a Tuple3 containing:
+        // i. The original aggregate expression (used for look ups).
+        // ii. The actual aggregation operator (used in the first aggregate).
+        // iii. The operator that selects and returns the result (used in the second aggregate).
+        (e, operator, resultWithDefault)
+      }
+
+      // Construct the regular aggregate input projection only if we need one.
+      val regularAggProjection = if (regularAggExprs.nonEmpty) {
+        Seq(a.groupingExpressions ++
+          distinctAggChildren.map(nullify) ++
+          Seq(regularGroupId) ++
+          regularAggChildren)
+      } else {
+        Seq.empty[Seq[Expression]]
+      }
+
+      // Construct the distinct aggregate input projections.
+      val regularAggNulls = regularAggChildren.map(nullify)
+      val distinctAggProjections = distinctAggOperatorMap.map {
+        case (projection, _) =>
+          a.groupingExpressions ++
+            projection ++
+            regularAggNulls
+      }
+
+      // Construct the expand operator.
+      val expand = Expand(
+        regularAggProjection ++ distinctAggProjections,
+        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
+        a.child)
+
+      // Construct the first aggregate operator. This de-duplicates all the children of
+      // distinct operators, and applies the regular aggregate operators.
+      val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
+      val firstAggregate = Aggregate(
+        firstAggregateGroupBy,
+        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
+        expand)
+
+      // Construct the second aggregate
+      val transformations: Map[Expression, Expression] =
+        (distinctAggOperatorMap.flatMap(_._2) ++
+          regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+
+      val patchedAggExpressions = a.aggregateExpressions.map { e =>
+        e.transformDown {
+          case e: Expression =>
+            // The same GROUP BY clauses can have different forms (different names for instance) in
+            // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
+            // tricky. So we do a linear search for a semantically equal group by expression.
+            groupByMap
+              .find(ge => e.semanticEquals(ge._1))
+              .map(_._2)
+              .getOrElse(transformations.getOrElse(e, e))
+        }.asInstanceOf[NamedExpression]
+      }
+      Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
+    } else {
+      a
+    }
+  }
+
+  private def nullify(e: Expression) = Literal.create(null, e.dataType)
+
+  private def expressionAttributePair(e: Expression) =
+    // We are creating a new reference here instead of reusing the attribute in case of a
+    // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
+    // children; in that case attribute reuse would cause the input of the regular aggregate to be
+    // bound to the (nulled out) input of the distinct aggregate.
+    e -> new AttributeReference(e.sql, e.dataType, true)()
+}

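To see the rewrite in action, the query from the rule's Scaladoc can be run end to end; with more than one distinct aggregate present, explain(true) should show the Expand plus double-Aggregate shape described above. A sketch assuming a local SparkSession (not part of this diff):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{countDistinct, sum}

object RewriteDistinctAggExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("rewrite-distinct-aggregates-example")
      .getOrCreate()
    import spark.implicits._

    // Same data and aggregation as in the rule's Scaladoc example.
    val data = Seq(
      ("a", "ca1", "cb1", 10),
      ("a", "ca1", "cb2", 5),
      ("b", "ca1", "cb1", 13)).toDF("key", "cat1", "cat2", "value")

    val agg = data.groupBy($"key").agg(
      countDistinct($"cat1").as("cat1_cnt"),
      countDistinct($"cat2").as("cat2_cnt"),
      sum($"value").as("total"))

    // The optimized plan should contain an Expand below two Aggregates,
    // which is the shape produced by RewriteDistinctAggregates.
    agg.explain(true)

    spark.stop()
  }
}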
http://git-wip-us.apache.org/repos/asf/spark/blob/94d52d76/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
new file mode 100644
index 0000000..7c66731
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.catalog.SessionCatalog
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types._
+
+
+/**
+ * Finds all [[RuntimeReplaceable]] expressions and replaces them with the expressions that can
+ * be evaluated. This is mainly used to provide compatibility with other databases.
+ * For example, we use this to support "nvl" by replacing it with "coalesce".
+ */
+object ReplaceExpressions extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case e: RuntimeReplaceable => e.replaced
+  }
+}
+
+
+/**
+ * Computes the current date and time to make sure we return the same result in a single query.
+ */
+object ComputeCurrentTime extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    val dateExpr = CurrentDate()
+    val timeExpr = CurrentTimestamp()
+    val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
+    val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
+
+    plan transformAllExpressions {
+      case CurrentDate() => currentDate
+      case CurrentTimestamp() => currentTime
+    }
+  }
+}
+
+
+/** Replaces the expression of CurrentDatabase with the current database name. */
+case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    plan transformAllExpressions {
+      case CurrentDatabase() =>
+        Literal.create(sessionCatalog.getCurrentDatabase, StringType)
+    }
+  }
+}

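As a rough usage illustration of the other two relocated rules (not part of this diff, assumes a local SparkSession): ReplaceExpressions rewrites RuntimeReplaceable expressions such as nvl into their evaluable form (coalesce), and GetCurrentDatabase folds current_database() into a string literal.

import org.apache.spark.sql.SparkSession

object FinishAnalysisRulesExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("finish-analysis-rules-example")
      .getOrCreate()

    // ReplaceExpressions: the optimized plan should show coalesce(...) in
    // place of the nvl(...) call.
    spark.sql("SELECT nvl(NULL, 'fallback') AS value").explain(true)

    // GetCurrentDatabase: current_database() is replaced by a literal holding
    // the session's current database name (typically 'default').
    spark.sql("SELECT current_database() AS db").show()

    spark.stop()
  }
}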
