Repository: spark
Updated Branches:
  refs/heads/branch-2.0 9c0ac6b53 -> 94d52d765
[SPARK-17269][SQL] Move finish analysis optimization stage into its own file

As part of breaking Optimizer.scala apart, this patch moves various finish
analysis optimization stage rules into a single file. I'm submitting separate
pull requests so we can more easily merge this into branch-2.0 to simplify
optimizer backports.

This should be covered by existing tests.

Author: Reynold Xin <r...@databricks.com>

Closes #14838 from rxin/SPARK-17269.

(cherry picked from commit dcefac438788c51d84641bfbc505efe095731a39)
Signed-off-by: Reynold Xin <r...@databricks.com>

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/94d52d76
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/94d52d76
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/94d52d76

Branch: refs/heads/branch-2.0
Commit: 94d52d76569f8b0782f424cfac959a4bb75c54c0
Parents: 9c0ac6b
Author: Reynold Xin <r...@databricks.com>
Authored: Fri Aug 26 22:10:28 2016 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Fri Aug 26 22:12:11 2016 -0700

----------------------------------------------------------------------
 .../analysis/RewriteDistinctAggregates.scala    | 269 -------------------
 .../sql/catalyst/optimizer/Optimizer.scala      |  38 ---
 .../optimizer/RewriteDistinctAggregates.scala   | 269 +++++++++++++++++++
 .../sql/catalyst/optimizer/finishAnalysis.scala |  65 +++++
 4 files changed, 334 insertions(+), 307 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/94d52d76/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala
deleted file
index 8afd28d..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteDistinctAggregates.scala
+++ /dev/null
@@ -1,269 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.analysis
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types.IntegerType
-
-/**
- * This rule rewrites an aggregate query with distinct aggregations into an expanded double
- * aggregation in which the regular aggregation expressions and every distinct clause are
- * aggregated in separate groups. The results are then combined in a second aggregate.
- *
- * For example (in Scala):
- * {{{
- *   val data = Seq(
- *     ("a", "ca1", "cb1", 10),
- *     ("a", "ca1", "cb2", 5),
- *     ("b", "ca1", "cb1", 13))
- *     .toDF("key", "cat1", "cat2", "value")
- *   data.createOrReplaceTempView("data")
- *
- *   val agg = data.groupBy($"key")
- *     .agg(
- *       countDistinct($"cat1").as("cat1_cnt"),
- *       countDistinct($"cat2").as("cat2_cnt"),
- *       sum($"value").as("total"))
- * }}}
- *
- * This translates to the following (pseudo) logical plan:
- * {{{
- * Aggregate(
- *    key = ['key]
- *    functions = [COUNT(DISTINCT 'cat1),
- *                 COUNT(DISTINCT 'cat2),
- *                 sum('value)]
- *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
- *   LocalTableScan [...]
- * }}}
- *
- * This rule rewrites this logical plan to the following (pseudo) logical plan:
- * {{{
- * Aggregate(
- *    key = ['key]
- *    functions = [count(if (('gid = 1)) 'cat1 else null),
- *                 count(if (('gid = 2)) 'cat2 else null),
- *                 first(if (('gid = 0)) 'total else null) ignore nulls]
- *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
- *   Aggregate(
- *      key = ['key, 'cat1, 'cat2, 'gid]
- *      functions = [sum('value)]
- *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
- *     Expand(
- *        projections = [('key, null, null, 0, cast('value as bigint)),
- *                       ('key, 'cat1, null, 1, null),
- *                       ('key, null, 'cat2, 2, null)]
- *        output = ['key, 'cat1, 'cat2, 'gid, 'value])
- *       LocalTableScan [...]
- * }}}
- *
- * The rule does the following things:
- * 1. Expand the data. There are three aggregation groups in this query:
- *    i. the non-distinct group;
- *    ii. the distinct 'cat1 group;
- *    iii. the distinct 'cat2 group.
- *    An expand operator is inserted to expand the child data for each group. The expand will null
- *    out all unused columns for the given group; this must be done in order to ensure correctness
- *    later on. Groups can be identified by a group id (gid) column added by the expand operator.
- * 2. De-duplicate the distinct paths and aggregate the non-distinct path. The group by clause of
- *    this aggregate consists of the original group by clause, all the requested distinct columns
- *    and the group id. Both the de-duplication of the distinct columns and the aggregation of the
- *    non-distinct group take advantage of the fact that we group by the group id (gid) and that we
- *    have nulled out all non-relevant columns for the given group.
- * 3. Aggregate the distinct groups and combine this with the results of the non-distinct
- *    aggregation. In this step we use the group id to filter the inputs for the aggregate
- *    functions. The results of the non-distinct group are 'aggregated' using the first operator;
- *    it might be more elegant to use the native UDAF merge mechanism for this in the future.
- *
- * This rule duplicates the input data two or more times (# distinct groups + an optional
- * non-distinct group). This will put quite a bit of memory pressure on the used aggregate and
- * exchange operators. Keeping the number of distinct groups as low as possible should be a
- * priority; we could improve this in the current rule by applying more advanced expression
- * canonicalization techniques.
- */
-object RewriteDistinctAggregates extends Rule[LogicalPlan] {
-
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case a: Aggregate => rewrite(a)
-  }
-
-  def rewrite(a: Aggregate): Aggregate = {
-
-    // Collect all aggregate expressions.
-    val aggExpressions = a.aggregateExpressions.flatMap { e =>
-      e.collect {
-        case ae: AggregateExpression => ae
-      }
-    }
-
-    // Extract distinct aggregate expressions.
-    val distinctAggGroups = aggExpressions
-      .filter(_.isDistinct)
-      .groupBy(_.aggregateFunction.children.toSet)
-
-    // The aggregation strategy can handle queries with a single distinct group.
-    if (distinctAggGroups.size > 1) {
-      // Create the attributes for the grouping id and the group by clause.
-      val gid =
-        new AttributeReference("gid", IntegerType, false)(isGenerated = true)
-      val groupByMap = a.groupingExpressions.collect {
-        case ne: NamedExpression => ne -> ne.toAttribute
-        case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)()
-      }
-      val groupByAttrs = groupByMap.map(_._2)
-
-      // Functions used to modify aggregate functions and their inputs.
-      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
-      def patchAggregateFunctionChildren(
-          af: AggregateFunction)(
-          attrs: Expression => Expression): AggregateFunction = {
-        af.withNewChildren(af.children.map {
-          case afc => attrs(afc)
-        }).asInstanceOf[AggregateFunction]
-      }
-
-      // Set up unique distinct aggregate children.
-      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
-      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
-      val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
-
-      // Set up expand & aggregate operators for distinct aggregate expressions.
-      val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
-      val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
-        case ((group, expressions), i) =>
-          val id = Literal(i + 1)
-
-          // Expand projection
-          val projection = distinctAggChildren.map {
-            case e if group.contains(e) => e
-            case e => nullify(e)
-          } :+ id
-
-          // Final aggregate
-          val operators = expressions.map { e =>
-            val af = e.aggregateFunction
-            val naf = patchAggregateFunctionChildren(af) { x =>
-              evalWithinGroup(id, distinctAggChildAttrLookup(x))
-            }
-            (e, e.copy(aggregateFunction = naf, isDistinct = false))
-          }
-
-          (projection, operators)
-      }
-
-      // Set up expand for the 'regular' aggregate expressions.
-      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
-      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
-      val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
-
-      // Set up aggregates for 'regular' aggregate expressions.
-      val regularGroupId = Literal(0)
-      val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
-      val regularAggOperatorMap = regularAggExprs.map { e =>
-        // Perform the actual aggregation in the initial aggregate.
-        val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
-        val operator = Alias(e.copy(aggregateFunction = af), e.sql)()
-
-        // Select the result of the first aggregate in the last aggregate.
-        val result = AggregateExpression(
-          aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
-          mode = Complete,
-          isDistinct = false)
-
-        // Some aggregate functions (COUNT) have the special property that they can return a
-        // non-null result without any input. We need to make sure we return a result in this case.
-        val resultWithDefault = af.defaultResult match {
-          case Some(lit) => Coalesce(Seq(result, lit))
-          case None => result
-        }
-
-        // Return a Tuple3 containing:
-        // i. The original aggregate expression (used for lookups).
-        // ii. The actual aggregation operator (used in the first aggregate).
-        // iii. The operator that selects and returns the result (used in the second aggregate).
-        (e, operator, resultWithDefault)
-      }
-
-      // Construct the regular aggregate input projection only if we need one.
-      val regularAggProjection = if (regularAggExprs.nonEmpty) {
-        Seq(a.groupingExpressions ++
-          distinctAggChildren.map(nullify) ++
-          Seq(regularGroupId) ++
-          regularAggChildren)
-      } else {
-        Seq.empty[Seq[Expression]]
-      }
-
-      // Construct the distinct aggregate input projections.
-      val regularAggNulls = regularAggChildren.map(nullify)
-      val distinctAggProjections = distinctAggOperatorMap.map {
-        case (projection, _) =>
-          a.groupingExpressions ++
-            projection ++
-            regularAggNulls
-      }
-
-      // Construct the expand operator.
-      val expand = Expand(
-        regularAggProjection ++ distinctAggProjections,
-        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
-        a.child)
-
-      // Construct the first aggregate operator. This de-duplicates all the children of the
-      // distinct operators, and applies the regular aggregate operators.
-      val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
-      val firstAggregate = Aggregate(
-        firstAggregateGroupBy,
-        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
-        expand)
-
-      // Construct the second aggregate.
-      val transformations: Map[Expression, Expression] =
-        (distinctAggOperatorMap.flatMap(_._2) ++
-          regularAggOperatorMap.map(e => (e._1, e._3))).toMap
-
-      val patchedAggExpressions = a.aggregateExpressions.map { e =>
-        e.transformDown {
-          case e: Expression =>
-            // The same GROUP BY clauses can have different forms (different names for instance) in
-            // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
-            // tricky. So we do a linear search for a semantically equal group by expression.
-            groupByMap
-              .find(ge => e.semanticEquals(ge._1))
-              .map(_._2)
-              .getOrElse(transformations.getOrElse(e, e))
-        }.asInstanceOf[NamedExpression]
-      }
-      Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
-    } else {
-      a
-    }
-  }
-
-  private def nullify(e: Expression) = Literal.create(null, e.dataType)
-
-  private def expressionAttributePair(e: Expression) =
-    // We are creating a new reference here instead of reusing the attribute in case of a
-    // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
-    // children; in this case attribute reuse causes the input of the regular aggregate to be
-    // bound to the (nulled out) input of the distinct aggregate.
-    e -> new AttributeReference(e.sql, e.dataType, true)()
-}

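[Editor's note: the rewrite described in the scaladoc above is easiest to
follow on concrete data. The sketch below emulates the three steps -- expand,
de-duplicating first aggregate, gid-filtered second aggregate -- on the
scaladoc's example data, using plain Scala collections. It is purely
illustrative: the object and value names are mine, and none of this is Spark
API code.]

object ExpandRewriteSketch extends App {
  // (key, cat1, cat2, value): the example data from the scaladoc.
  val data = Seq(
    ("a", "ca1", "cb1", 10),
    ("a", "ca1", "cb2", 5),
    ("b", "ca1", "cb1", 13))

  // Step 1 (Expand): emit one row per aggregation group, nulling out (None)
  // the columns the group does not use.
  // gid = 0: regular aggregates, gid = 1: DISTINCT cat1, gid = 2: DISTINCT cat2.
  val expanded = data.flatMap { case (key, cat1, cat2, value) =>
    Seq(
      (key, None,       None,       0, Some(value)),
      (key, Some(cat1), None,       1, None),
      (key, None,       Some(cat2), 2, None))
  }

  // Step 2 (first Aggregate): group by (key, cat1, cat2, gid). Grouping on the
  // expanded columns de-duplicates the distinct values; sum the regular column.
  val first = expanded
    .groupBy { case (key, c1, c2, gid, _) => (key, c1, c2, gid) }
    .toSeq
    .map { case ((key, c1, c2, gid), rows) =>
      (key, c1, c2, gid, rows.map(_._5.getOrElse(0)).sum)
    }

  // Step 3 (second Aggregate): group by key, using gid to route each row to
  // the right aggregate function.
  val result = first
    .groupBy(_._1)
    .toSeq
    .map { case (key, rows) =>
      val cat1Cnt = rows.count(_._4 == 1)  // count(if ('gid = 1) 'cat1 else null)
      val cat2Cnt = rows.count(_._4 == 2)  // count(if ('gid = 2) 'cat2 else null)
      val total   = rows.filter(_._4 == 0).map(_._5).sum
      (key, cat1Cnt, cat2Cnt, total)
    }
    .sortBy(_._1)

  result.foreach(println)  // prints (a,1,2,15) then (b,1,1,13)
}
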
http://git-wip-us.apache.org/repos/asf/spark/blob/94d52d76/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 4cadbc3..f3f1d21 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1583,44 +1583,6 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
 }
 
 /**
- * Finds all [[RuntimeReplaceable]] expressions and replaces them with the expressions that can
- * be evaluated. This is mainly used to provide compatibility with other databases.
- * For example, we use this to support "nvl" by replacing it with "coalesce".
- */
-object ReplaceExpressions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case e: RuntimeReplaceable => e.replaced
-  }
-}
-
-/**
- * Computes the current date and time to make sure we return the same result in a single query.
- */
-object ComputeCurrentTime extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    val dateExpr = CurrentDate()
-    val timeExpr = CurrentTimestamp()
-    val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
-    val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
-
-    plan transformAllExpressions {
-      case CurrentDate() => currentDate
-      case CurrentTimestamp() => currentTime
-    }
-  }
-}
-
-/** Replaces the expression of CurrentDatabase with the current database name. */
-case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    plan transformAllExpressions {
-      case CurrentDatabase() =>
-        Literal.create(sessionCatalog.getCurrentDatabase, StringType)
-    }
-  }
-}
-
-/**
  * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
  * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
  * the deserializer in filter condition to save the extra serialization at last.


http://git-wip-us.apache.org/repos/asf/spark/blob/94d52d76/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
new file mode 100644
index 0000000..0f43e7b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.IntegerType
+
+/**
+ * This rule rewrites an aggregate query with distinct aggregations into an expanded double
+ * aggregation in which the regular aggregation expressions and every distinct clause are
+ * aggregated in separate groups. The results are then combined in a second aggregate.
+ *
+ * For example (in Scala):
+ * {{{
+ *   val data = Seq(
+ *     ("a", "ca1", "cb1", 10),
+ *     ("a", "ca1", "cb2", 5),
+ *     ("b", "ca1", "cb1", 13))
+ *     .toDF("key", "cat1", "cat2", "value")
+ *   data.createOrReplaceTempView("data")
+ *
+ *   val agg = data.groupBy($"key")
+ *     .agg(
+ *       countDistinct($"cat1").as("cat1_cnt"),
+ *       countDistinct($"cat2").as("cat2_cnt"),
+ *       sum($"value").as("total"))
+ * }}}
+ *
+ * This translates to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [COUNT(DISTINCT 'cat1),
+ *                 COUNT(DISTINCT 'cat2),
+ *                 sum('value)]
+ *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ *   LocalTableScan [...]
+ * }}}
+ *
+ * This rule rewrites this logical plan to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [count(if (('gid = 1)) 'cat1 else null),
+ *                 count(if (('gid = 2)) 'cat2 else null),
+ *                 first(if (('gid = 0)) 'total else null) ignore nulls]
+ *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ *   Aggregate(
+ *      key = ['key, 'cat1, 'cat2, 'gid]
+ *      functions = [sum('value)]
+ *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
+ *     Expand(
+ *        projections = [('key, null, null, 0, cast('value as bigint)),
+ *                       ('key, 'cat1, null, 1, null),
+ *                       ('key, null, 'cat2, 2, null)]
+ *        output = ['key, 'cat1, 'cat2, 'gid, 'value])
+ *       LocalTableScan [...]
+ * }}}
+ *
+ * The rule does the following things:
+ * 1. Expand the data. There are three aggregation groups in this query:
+ *    i. the non-distinct group;
+ *    ii. the distinct 'cat1 group;
+ *    iii. the distinct 'cat2 group.
+ *    An expand operator is inserted to expand the child data for each group. The expand will null
+ *    out all unused columns for the given group; this must be done in order to ensure correctness
+ *    later on. Groups can be identified by a group id (gid) column added by the expand operator.
+ * 2. De-duplicate the distinct paths and aggregate the non-distinct path. The group by clause of
+ *    this aggregate consists of the original group by clause, all the requested distinct columns
+ *    and the group id. Both the de-duplication of the distinct columns and the aggregation of the
+ *    non-distinct group take advantage of the fact that we group by the group id (gid) and that we
+ *    have nulled out all non-relevant columns for the given group.
+ * 3. Aggregate the distinct groups and combine this with the results of the non-distinct
+ *    aggregation. In this step we use the group id to filter the inputs for the aggregate
+ *    functions. The results of the non-distinct group are 'aggregated' using the first operator;
+ *    it might be more elegant to use the native UDAF merge mechanism for this in the future.
+ *
+ * This rule duplicates the input data two or more times (# distinct groups + an optional
+ * non-distinct group). This will put quite a bit of memory pressure on the used aggregate and
+ * exchange operators. Keeping the number of distinct groups as low as possible should be a
+ * priority; we could improve this in the current rule by applying more advanced expression
+ * canonicalization techniques.
+ */
+object RewriteDistinctAggregates extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case a: Aggregate => rewrite(a)
+  }
+
+  def rewrite(a: Aggregate): Aggregate = {
+
+    // Collect all aggregate expressions.
+    val aggExpressions = a.aggregateExpressions.flatMap { e =>
+      e.collect {
+        case ae: AggregateExpression => ae
+      }
+    }
+
+    // Extract distinct aggregate expressions.
+    val distinctAggGroups = aggExpressions
+      .filter(_.isDistinct)
+      .groupBy(_.aggregateFunction.children.toSet)
+
+    // The aggregation strategy can handle queries with a single distinct group.
+    if (distinctAggGroups.size > 1) {
+      // Create the attributes for the grouping id and the group by clause.
+      val gid =
+        new AttributeReference("gid", IntegerType, false)(isGenerated = true)
+      val groupByMap = a.groupingExpressions.collect {
+        case ne: NamedExpression => ne -> ne.toAttribute
+        case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)()
+      }
+      val groupByAttrs = groupByMap.map(_._2)
+
+      // Functions used to modify aggregate functions and their inputs.
+      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
+      def patchAggregateFunctionChildren(
+          af: AggregateFunction)(
+          attrs: Expression => Expression): AggregateFunction = {
+        af.withNewChildren(af.children.map {
+          case afc => attrs(afc)
+        }).asInstanceOf[AggregateFunction]
+      }
+
+      // Set up unique distinct aggregate children.
+      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
+      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
+      val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
+
+      // Set up expand & aggregate operators for distinct aggregate expressions.
+      val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
+      val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
+        case ((group, expressions), i) =>
+          val id = Literal(i + 1)
+
+          // Expand projection
+          val projection = distinctAggChildren.map {
+            case e if group.contains(e) => e
+            case e => nullify(e)
+          } :+ id
+
+          // Final aggregate
+          val operators = expressions.map { e =>
+            val af = e.aggregateFunction
+            val naf = patchAggregateFunctionChildren(af) { x =>
+              evalWithinGroup(id, distinctAggChildAttrLookup(x))
+            }
+            (e, e.copy(aggregateFunction = naf, isDistinct = false))
+          }
+
+          (projection, operators)
+      }
+
+      // Set up expand for the 'regular' aggregate expressions.
+      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
+      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+      val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
+
+      // Set up aggregates for 'regular' aggregate expressions.
+      val regularGroupId = Literal(0)
+      val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
+      val regularAggOperatorMap = regularAggExprs.map { e =>
+        // Perform the actual aggregation in the initial aggregate.
+        val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
+        val operator = Alias(e.copy(aggregateFunction = af), e.sql)()
+
+        // Select the result of the first aggregate in the last aggregate.
+        val result = AggregateExpression(
+          aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
+          mode = Complete,
+          isDistinct = false)
+
+        // Some aggregate functions (COUNT) have the special property that they can return a
+        // non-null result without any input. We need to make sure we return a result in this case.
+        val resultWithDefault = af.defaultResult match {
+          case Some(lit) => Coalesce(Seq(result, lit))
+          case None => result
+        }
+
+        // Return a Tuple3 containing:
+        // i. The original aggregate expression (used for lookups).
+        // ii. The actual aggregation operator (used in the first aggregate).
+        // iii. The operator that selects and returns the result (used in the second aggregate).
+        (e, operator, resultWithDefault)
+      }
+
+      // Construct the regular aggregate input projection only if we need one.
+      val regularAggProjection = if (regularAggExprs.nonEmpty) {
+        Seq(a.groupingExpressions ++
+          distinctAggChildren.map(nullify) ++
+          Seq(regularGroupId) ++
+          regularAggChildren)
+      } else {
+        Seq.empty[Seq[Expression]]
+      }
+
+      // Construct the distinct aggregate input projections.
+      val regularAggNulls = regularAggChildren.map(nullify)
+      val distinctAggProjections = distinctAggOperatorMap.map {
+        case (projection, _) =>
+          a.groupingExpressions ++
+            projection ++
+            regularAggNulls
+      }
+
+      // Construct the expand operator.
+      val expand = Expand(
+        regularAggProjection ++ distinctAggProjections,
+        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
+        a.child)
+
+      // Construct the first aggregate operator. This de-duplicates all the children of the
+      // distinct operators, and applies the regular aggregate operators.
+      val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
+      val firstAggregate = Aggregate(
+        firstAggregateGroupBy,
+        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
+        expand)
+
+      // Construct the second aggregate.
+      val transformations: Map[Expression, Expression] =
+        (distinctAggOperatorMap.flatMap(_._2) ++
+          regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+
+      val patchedAggExpressions = a.aggregateExpressions.map { e =>
+        e.transformDown {
+          case e: Expression =>
+            // The same GROUP BY clauses can have different forms (different names for instance) in
+            // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
+            // tricky. So we do a linear search for a semantically equal group by expression.
+            groupByMap
+              .find(ge => e.semanticEquals(ge._1))
+              .map(_._2)
+              .getOrElse(transformations.getOrElse(e, e))
+        }.asInstanceOf[NamedExpression]
+      }
+      Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
+    } else {
+      a
+    }
+  }
+
+  private def nullify(e: Expression) = Literal.create(null, e.dataType)
+
+  private def expressionAttributePair(e: Expression) =
+    // We are creating a new reference here instead of reusing the attribute in case of a
+    // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
+    // children; in this case attribute reuse causes the input of the regular aggregate to be
+    // bound to the (nulled out) input of the distinct aggregate.
+    e -> new AttributeReference(e.sql, e.dataType, true)()
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/94d52d76/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
new file mode 100644
index 0000000..7c66731
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.catalog.SessionCatalog
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types._
+
+
+/**
+ * Finds all [[RuntimeReplaceable]] expressions and replaces them with the expressions that can
+ * be evaluated. This is mainly used to provide compatibility with other databases.
+ * For example, we use this to support "nvl" by replacing it with "coalesce".
+ */
+object ReplaceExpressions extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case e: RuntimeReplaceable => e.replaced
+  }
+}
+
+
+/**
+ * Computes the current date and time to make sure we return the same result in a single query.
+ */
+object ComputeCurrentTime extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    val dateExpr = CurrentDate()
+    val timeExpr = CurrentTimestamp()
+    val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
+    val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
+
+    plan transformAllExpressions {
+      case CurrentDate() => currentDate
+      case CurrentTimestamp() => currentTime
+    }
+  }
+}
+
+
+/** Replaces the expression of CurrentDatabase with the current database name. */
+case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    plan transformAllExpressions {
+      case CurrentDatabase() =>
+        Literal.create(sessionCatalog.getCurrentDatabase, StringType)
+    }
+  }
+}
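
[Editor's note: the three finish-analysis rules above share one pattern: each
folds an environment-dependent expression into a plain literal exactly once,
before optimization proper. The toy model below (plain Scala; the mini
expression tree and all names are mine, not Catalyst's) shows why
ComputeCurrentTime pre-evaluates the clock a single time and then substitutes
the frozen literal everywhere, so every occurrence within one query agrees.]

object ComputeCurrentTimeSketch extends App {
  sealed trait Expr
  case object CurrentTimestamp extends Expr
  case class Literal(value: Long) extends Expr
  case class GreaterThan(left: Expr, right: Expr) extends Expr

  // Evaluate the clock once, then substitute the frozen literal for every
  // CurrentTimestamp occurrence, mirroring what the rule does to a plan.
  def computeCurrentTime(conditions: List[Expr]): List[Expr] = {
    val now = Literal(System.currentTimeMillis())
    def rewrite(e: Expr): Expr = e match {
      case CurrentTimestamp  => now
      case GreaterThan(l, r) => GreaterThan(rewrite(l), rewrite(r))
      case other             => other
    }
    conditions.map(rewrite)
  }

  val conditions = List[Expr](
    GreaterThan(CurrentTimestamp, Literal(0L)),
    GreaterThan(Literal(Long.MaxValue), CurrentTimestamp))

  // Both occurrences now reference the same instant.
  computeCurrentTime(conditions).foreach(println)
}

[ReplaceExpressions has the same structural shape: each RuntimeReplaceable
node (e.g. nvl) already carries the equivalent expression (coalesce) it
should be rewritten to, so the rule is a single substitution pass.]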