[SPARK-9830][SQL] Remove AggregateExpression1 and Aggregate Operator used to evaluate AggregateExpression1s
https://issues.apache.org/jira/browse/SPARK-9830 This PR contains the following main changes. * Removing `AggregateExpression1`. * Removing `Aggregate` operator, which is used to evaluate `AggregateExpression1`. * Removing planner rule used to plan `Aggregate`. * Linking `MultipleDistinctRewriter` to analyzer. * Renaming `AggregateExpression2` to `AggregateExpression` and `AggregateFunction2` to `AggregateFunction`. * Updating places where we create aggregate expression. The way to create aggregate expressions is `AggregateExpression(aggregateFunction, mode, isDistinct)`. * Changing `val`s in `DeclarativeAggregate`s that touch children of this function to `lazy val`s (when we create aggregate expression in DataFrame API, children of an aggregate function can be unresolved). Author: Yin Huai <yh...@databricks.com> Closes #9556 from yhuai/removeAgg1. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e0701c75 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e0701c75 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e0701c75 Branch: refs/heads/master Commit: e0701c75601c43f69ed27fc7c252321703db51f2 Parents: 6e5fc37 Author: Yin Huai <yh...@databricks.com> Authored: Tue Nov 10 11:06:29 2015 -0800 Committer: Michael Armbrust <mich...@databricks.com> Committed: Tue Nov 10 11:06:29 2015 -0800 ---------------------------------------------------------------------- R/pkg/R/functions.R | 2 +- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/tests.py | 2 +- .../spark/sql/catalyst/CatalystConf.scala | 10 +- .../apache/spark/sql/catalyst/SqlParser.scala | 14 +- .../spark/sql/catalyst/analysis/Analyzer.scala | 26 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 46 +- .../analysis/DistinctAggregationRewriter.scala | 278 +++++ .../catalyst/analysis/FunctionRegistry.scala | 2 + .../catalyst/analysis/HiveTypeCoercion.scala | 20 +- 
.../sql/catalyst/analysis/unresolved.scala | 4 + .../apache/spark/sql/catalyst/dsl/package.scala | 22 +- .../expressions/aggregate/Average.scala | 31 +- .../aggregate/CentralMomentAgg.scala | 13 +- .../catalyst/expressions/aggregate/Corr.scala | 15 + .../catalyst/expressions/aggregate/Count.scala | 28 +- .../catalyst/expressions/aggregate/First.scala | 14 +- .../aggregate/HyperLogLogPlusPlus.scala | 17 + .../expressions/aggregate/Kurtosis.scala | 2 + .../catalyst/expressions/aggregate/Last.scala | 12 +- .../catalyst/expressions/aggregate/Max.scala | 17 +- .../catalyst/expressions/aggregate/Min.scala | 17 +- .../expressions/aggregate/Skewness.scala | 2 + .../catalyst/expressions/aggregate/Stddev.scala | 31 +- .../catalyst/expressions/aggregate/Sum.scala | 29 +- .../catalyst/expressions/aggregate/Utils.scala | 467 -------- .../expressions/aggregate/Variance.scala | 7 +- .../expressions/aggregate/interfaces.scala | 57 +- .../sql/catalyst/expressions/aggregates.scala | 1073 ------------------ .../sql/catalyst/optimizer/Optimizer.scala | 23 +- .../spark/sql/catalyst/planning/patterns.scala | 74 -- .../spark/sql/catalyst/plans/QueryPlan.scala | 12 +- .../catalyst/plans/logical/basicOperators.scala | 4 +- .../catalyst/analysis/AnalysisErrorSuite.scala | 23 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../analysis/DecimalPrecisionSuite.scala | 1 + .../analysis/ExpressionTypeCheckingSuite.scala | 6 +- .../optimizer/ConstantFoldingSuite.scala | 4 +- .../optimizer/FilterPushdownSuite.scala | 14 +- .../scala/org/apache/spark/sql/DataFrame.scala | 13 +- .../org/apache/spark/sql/GroupedData.scala | 45 +- .../scala/org/apache/spark/sql/SQLConf.scala | 20 +- .../apache/spark/sql/execution/Aggregate.scala | 205 ---- .../org/apache/spark/sql/execution/Expand.scala | 3 + .../spark/sql/execution/SparkPlanner.scala | 1 - .../spark/sql/execution/SparkStrategies.scala | 238 ++-- .../aggregate/AggregationIterator.scala | 28 +- .../aggregate/SortBasedAggregate.scala | 4 +- 
.../SortBasedAggregationIterator.scala | 8 +- .../execution/aggregate/TungstenAggregate.scala | 6 +- .../aggregate/TungstenAggregationIterator.scala | 36 +- .../spark/sql/execution/aggregate/udaf.scala | 2 +- .../spark/sql/execution/aggregate/utils.scala | 20 +- .../spark/sql/expressions/Aggregator.scala | 5 +- .../spark/sql/expressions/WindowSpec.scala | 82 +- .../org/apache/spark/sql/expressions/udaf.scala | 6 +- .../scala/org/apache/spark/sql/functions.scala | 53 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 69 +- .../apache/spark/sql/UserDefinedTypeSuite.scala | 15 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../sql/execution/metric/SQLMetricsSuite.scala | 30 - .../org/apache/spark/sql/hive/HiveContext.scala | 1 - .../org/apache/spark/sql/hive/HiveQl.scala | 8 +- .../hive/execution/AggregationQuerySuite.scala | 188 ++- 65 files changed, 998 insertions(+), 2515 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/R/pkg/R/functions.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index d7fd279..0b28087 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1339,7 +1339,7 @@ setMethod("pmod", signature(y = "Column"), #' @export setMethod("approxCountDistinct", signature(x = "Column"), - function(x, rsd = 0.95) { + function(x, rsd = 0.05) { jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) column(jc) }) http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/python/pyspark/sql/dataframe.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b97c94d..0dd75ba 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -866,7 +866,7 @@ class DataFrame(object): This is a variant of :func:`select` that accepts SQL 
expressions. >>> df.selectExpr("age * 2", "abs(age)").collect() - [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)] + [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)] """ if len(expr) == 1 and isinstance(expr[0], list): expr = expr[0] http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/python/pyspark/sql/functions.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 962f676..6e1cbde 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -382,7 +382,7 @@ def expr(str): """Parses the expression string into the column that it represents >>> df.select(expr("length(name)")).collect() - [Row('length(name)=5), Row('length(name)=3)] + [Row(length(name)=5), Row(length(name)=3)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.expr(str)) http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e224574..9f5f7cf 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1017,7 +1017,7 @@ class SQLTests(ReusedPySparkTestCase): row = Row(a="length string", b=75) df = self.sqlCtx.createDataFrame([row]) result = df.select(functions.expr("length(a)")).collect()[0].asDict() - self.assertEqual(13, result["'length(a)"]) + self.assertEqual(13, result["length(a)"]) def test_replace(self): schema = StructType([ http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 
3f351b0..7c2b8a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst private[spark] trait CatalystConf { def caseSensitiveAnalysis: Boolean + + protected[spark] def specializeSingleDistinctAggPlanning: Boolean } /** @@ -29,7 +31,13 @@ object EmptyConf extends CatalystConf { override def caseSensitiveAnalysis: Boolean = { throw new UnsupportedOperationException } + + protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = { + throw new UnsupportedOperationException + } } /** A CatalystConf that can be used for local testing. */ -case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf +case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf { + protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true +} http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index cd717c0..2a132d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -22,6 +22,7 @@ import scala.language.implicitConversions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.DataTypeParser @@ -272,7 +273,7 @@ object 
SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val function: Parser[Expression] = ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName => if (lexical.normalizeKeyword(udfName) == "count") { - Count(Literal(1)) + AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false) } else { throw new AnalysisException(s"invalid expression $udfName(*)") } @@ -281,14 +282,14 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) } | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => lexical.normalizeKeyword(udfName) match { - case "sum" => SumDistinct(exprs.head) - case "count" => CountDistinct(exprs) + case "count" => + aggregate.Count(exprs).toAggregateExpression(isDistinct = true) case _ => UnresolvedFunction(udfName, exprs, isDistinct = true) } } | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => if (lexical.normalizeKeyword(udfName) == "count") { - ApproxCountDistinct(exp) + AggregateExpression(new HyperLogLogPlusPlus(exp), mode = Complete, isDistinct = false) } else { throw new AnalysisException(s"invalid function approximate $udfName") } @@ -296,7 +297,10 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { | APPROXIMATE ~> "(" ~> unsignedFloat ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ { case s ~ _ ~ udfName ~ _ ~ _ ~ exp => if (lexical.normalizeKeyword(udfName) == "count") { - ApproxCountDistinct(exp, s.toDouble) + AggregateExpression( + HyperLogLogPlusPlus(exp, s.toDouble, 0, 0), + mode = Complete, + isDistinct = false) } else { throw new AnalysisException(s"invalid function approximate($s) $udfName") } http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 899ee67..b1e1439 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef @@ -79,6 +79,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: + DistinctAggregationRewriter(conf) :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, @@ -525,21 +526,14 @@ class Analyzer( case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { registry.lookupFunction(name, children) match { - // We get an aggregate function built based on AggregateFunction2 interface. - // So, we wrap it in AggregateExpression2. - case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct) - // Currently, our old aggregate function interface supports SUM(DISTINCT ...) - // and COUTN(DISTINCT ...). - case sumDistinct: SumDistinct => sumDistinct - case countDistinct: CountDistinct => countDistinct - // DISTINCT is not meaningful with Max and Min. - case max: Max if isDistinct => max - case min: Min if isDistinct => min - // For other aggregate functions, DISTINCT keyword is not supported for now. 
- // Once we converted to the new code path, we will allow using DISTINCT keyword. - case other: AggregateExpression1 if isDistinct => - failAnalysis(s"$name does not support DISTINCT keyword.") - // If it does not have DISTINCT keyword, we will return it as is. + // DISTINCT is not meaningful for a Max or a Min. + case max: Max if isDistinct => + AggregateExpression(max, Complete, isDistinct = false) + case min: Min if isDistinct => + AggregateExpression(min, Complete, isDistinct = false) + // We get an aggregate function, we need to wrap it in an AggregateExpression. + case agg2: AggregateFunction => AggregateExpression(agg2, Complete, isDistinct) + // This function is not an aggregate function, just return the resolved one. case other => other } } http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 98d6637..8322e99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -108,7 +109,19 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case _: AggregateExpression => // OK + case aggExpr: AggregateExpression => + // TODO: Is 
it possible that the child of a agg function is another + // agg function? + aggExpr.aggregateFunction.children.foreach { + // This is just a sanity check, our analysis rule PullOutNondeterministic should + // already pull out those nondeterministic expressions and evaluate them in + // a Project node. + case child if !child.deterministic => + failAnalysis( + s"nondeterministic expression ${expr.prettyString} should not " + + s"appear in the arguments of an aggregate function.") + case child => // OK + } case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( s"expression '${e.prettyString}' is neither present in the group by, " + @@ -120,14 +133,26 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } - def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match { - case BinaryType => - failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case m: MapType => - failAnalysis(s"map type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case _ => // OK + def checkValidGroupingExprs(expr: Expression): Unit = { + expr.dataType match { + case BinaryType => + failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " + + "in grouping expression") + case a: ArrayType => + failAnalysis(s"array type expression ${expr.prettyString} cannot be used " + + "in grouping expression") + case m: MapType => + failAnalysis(s"map type expression ${expr.prettyString} cannot be used " + + "in grouping expression") + case _ => // OK + } + if (!expr.deterministic) { + // This is just a sanity check, our analysis rule PullOutNondeterministic should + // already pull out those nondeterministic expressions and evaluate them in + // a Project node. 
+ failAnalysis(s"nondeterministic expression ${expr.prettyString} should not " + + s"appear in grouping expression.") + } } aggregateExprs.foreach(checkValidAggregateExpression) @@ -179,7 +204,8 @@ trait CheckAnalysis { s"unresolved operator ${operator.simpleString}") case o if o.expressions.exists(!_.deterministic) && - !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] => + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] & !o.isInstanceOf[Aggregate] => + // The rule above is used to check Aggregate operator. failAnalysis( s"""nondeterministic expressions are only allowed in Project or Filter, found: | ${o.expressions.map(_.prettyString).mkString(",")} http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala new file mode 100644 index 0000000..397eff0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.IntegerType + +/** + * This rule rewrites an aggregate query with distinct aggregations into an expanded double + * aggregation in which the regular aggregation expressions and every distinct clause is aggregated + * in a separate group. The results are then combined in a second aggregate. + * + * For example (in scala): + * {{{ + * val data = Seq( + * ("a", "ca1", "cb1", 10), + * ("a", "ca1", "cb2", 5), + * ("b", "ca1", "cb1", 13)) + * .toDF("key", "cat1", "cat2", "value") + * data.registerTempTable("data") + * + * val agg = data.groupBy($"key") + * .agg( + * countDistinct($"cat1").as("cat1_cnt"), + * countDistinct($"cat2").as("cat2_cnt"), + * sum($"value").as("total")) + * }}} + * + * This translates to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [COUNT(DISTINCT 'cat1), + * COUNT(DISTINCT 'cat2), + * sum('value)] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * LocalTableScan [...] 
+ * }}} + * + * This rule rewrites this logical plan to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [count(if (('gid = 1)) 'cat1 else null), + * count(if (('gid = 2)) 'cat2 else null), + * first(if (('gid = 0)) 'total else null) ignore nulls] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * Aggregate( + * key = ['key, 'cat1, 'cat2, 'gid] + * functions = [sum('value)] + * output = ['key, 'cat1, 'cat2, 'gid, 'total]) + * Expand( + * projections = [('key, null, null, 0, cast('value as bigint)), + * ('key, 'cat1, null, 1, null), + * ('key, null, 'cat2, 2, null)] + * output = ['key, 'cat1, 'cat2, 'gid, 'value]) + * LocalTableScan [...] + * }}} + * + * The rule does the following things here: + * 1. Expand the data. There are three aggregation groups in this query: + * i. the non-distinct group; + * ii. the distinct 'cat1 group; + * iii. the distinct 'cat2 group. + * An expand operator is inserted to expand the child data for each group. The expand will null + * out all unused columns for the given group; this must be done in order to ensure correctness + * later on. Groups can be identified by a group id (gid) column added by the expand operator. + * 2. De-duplicate the distinct paths and aggregate the non-distinct path. The group by clause of + * this aggregate consists of the original group by clause, all the requested distinct columns + * and the group id. Both de-duplication of distinct columns and the aggregation of the + * non-distinct group take advantage of the fact that we group by the group id (gid) and that we + * have nulled out all non-relevant columns for the given group. + * 3. Aggregating the distinct groups and combining this with the results of the non-distinct + * aggregation. In this step we use the group id to filter the inputs for the aggregate + * functions. 
The results of the non-distinct group are 'aggregated' by using the first operator; + * it might be more elegant to use the native UDAF merge mechanism for this in the future. + * + * This rule duplicates the input data by two or more times (# distinct groups + an optional + * non-distinct group). This will put quite a bit of memory pressure on the used aggregate and + * exchange operators. Keeping the number of distinct groups as low as possible should be a priority; + * we could improve this in the current rule by applying more advanced expression canonicalization + * techniques. + */ +case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p + // We need to wait until this Aggregate operator is resolved. + case a: Aggregate => rewrite(a) + case p => p + } + + def rewrite(a: Aggregate): Aggregate = { + + // Collect all aggregate expressions. + val aggExpressions = a.aggregateExpressions.flatMap { e => + e.collect { + case ae: AggregateExpression => ae + } + } + + // Extract distinct aggregate expressions. + val distinctAggGroups = aggExpressions + .filter(_.isDistinct) + .groupBy(_.aggregateFunction.children.toSet) + + val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) { + // When the flag is set to specialize single distinct agg planning, + // we will rely on our Aggregation strategy to handle queries with a single + // distinct column and this aggregate operator does have grouping expressions. + distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && a.groupingExpressions.isEmpty) + } else { + distinctAggGroups.size >= 1 + } + if (shouldRewrite) { + // Create the attributes for the grouping id and the group by clause. 
+ val gid = new AttributeReference("gid", IntegerType, false)() + val groupByMap = a.groupingExpressions.collect { + case ne: NamedExpression => ne -> ne.toAttribute + case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)() + } + val groupByAttrs = groupByMap.map(_._2) + + // Functions used to modify aggregate functions and their inputs. + def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) + def patchAggregateFunctionChildren( + af: AggregateFunction)( + attrs: Expression => Expression): AggregateFunction = { + af.withNewChildren(af.children.map { + case afc => attrs(afc) + }).asInstanceOf[AggregateFunction] + } + + // Setup unique distinct aggregate children. + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq + val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap + val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq + + // Setup expand & aggregate operators for distinct aggregate expressions. + val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { + case ((group, expressions), i) => + val id = Literal(i + 1) + + // Expand projection + val projection = distinctAggChildren.map { + case e if group.contains(e) => e + case e => nullify(e) + } :+ id + + // Final aggregate + val operators = expressions.map { e => + val af = e.aggregateFunction + val naf = patchAggregateFunctionChildren(af) { x => + evalWithinGroup(id, distinctAggChildAttrMap(x)) + } + (e, e.copy(aggregateFunction = naf, isDistinct = false)) + } + + (projection, operators) + } + + // Setup expand for the 'regular' aggregate expressions. + val regularAggExprs = aggExpressions.filter(!_.isDistinct) + val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct + val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) + + // Setup aggregates for 'regular' aggregate expressions. 
+ val regularGroupId = Literal(0) + val regularAggChildAttrLookup = regularAggChildAttrMap.toMap + val regularAggOperatorMap = regularAggExprs.map { e => + // Perform the actual aggregation in the initial aggregate. + val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup) + val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)() + + // Select the result of the first aggregate in the last aggregate. + val result = AggregateExpression( + aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)), + mode = Complete, + isDistinct = false) + + // Some aggregate functions (COUNT) have the special property that they can return a + // non-null result without any input. We need to make sure we return a result in this case. + val resultWithDefault = af.defaultResult match { + case Some(lit) => Coalesce(Seq(result, lit)) + case None => result + } + + // Return a Tuple3 containing: + // i. The original aggregate expression (used for look ups). + // ii. The actual aggregation operator (used in the first aggregate). + // iii. The operator that selects and returns the result (used in the second aggregate). + (e, operator, resultWithDefault) + } + + // Construct the regular aggregate input projection only if we need one. + val regularAggProjection = if (regularAggExprs.nonEmpty) { + Seq(a.groupingExpressions ++ + distinctAggChildren.map(nullify) ++ + Seq(regularGroupId) ++ + regularAggChildren) + } else { + Seq.empty[Seq[Expression]] + } + + // Construct the distinct aggregate input projections. + val regularAggNulls = regularAggChildren.map(nullify) + val distinctAggProjections = distinctAggOperatorMap.map { + case (projection, _) => + a.groupingExpressions ++ + projection ++ + regularAggNulls + } + + // Construct the expand operator. 
+ val expand = Expand( + regularAggProjection ++ distinctAggProjections, + groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2), + a.child) + + // Construct the first aggregate operator. This de-duplicates the all the children of + // distinct operators, and applies the regular aggregate operators. + val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid + val firstAggregate = Aggregate( + firstAggregateGroupBy, + firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2), + expand) + + // Construct the second aggregate + val transformations: Map[Expression, Expression] = + (distinctAggOperatorMap.flatMap(_._2) ++ + regularAggOperatorMap.map(e => (e._1, e._3))).toMap + + val patchedAggExpressions = a.aggregateExpressions.map { e => + e.transformDown { + case e: Expression => + // The same GROUP BY clauses can have different forms (different names for instance) in + // the groupBy and aggregate expressions of an aggregate. This makes a map lookup + // tricky. So we do a linear search for a semantically equal group by expression. + groupByMap + .find(ge => e.semanticEquals(ge._1)) + .map(_._2) + .getOrElse(transformations.getOrElse(e, e)) + }.asInstanceOf[NamedExpression] + } + Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) + } else { + a + } + } + + private def nullify(e: Expression) = Literal.create(null, e.dataType) + + private def expressionAttributePair(e: Expression) = + // We are creating a new reference here instead of reusing the attribute in case of a + // NamedExpression. This is done to prevent collisions between distinct and regular aggregate + // children, in this case attribute reuse causes the input of the regular aggregate to bound to + // the (nulled out) input of the distinct aggregate. 
+ e -> new AttributeReference(e.prettyString, e.dataType, true)() +} http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d4334d1..dfa749d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -24,6 +24,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.StringKeyHashMap @@ -177,6 +178,7 @@ object FunctionRegistry { expression[ToRadians]("radians"), // aggregate functions + expression[HyperLogLogPlusPlus]("approx_count_distinct"), expression[Average]("avg"), expression[Corr]("corr"), expression[Count]("count"), http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 84e2b13..bf2bff0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -20,6 +20,7 @@ package 
org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -295,14 +296,17 @@ object HiveTypeCoercion { i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) - case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) - case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) - case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) - case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) - case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType)) + case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case Skewness(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + Skewness(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case Kurtosis(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) } } @@ -562,12 +566,6 @@ object HiveTypeCoercion { case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) - case s @ SumDistinct(e @ DecimalType()) => s // Decimal is already the biggest. 
- case SumDistinct(e @ IntegralType()) if e.dataType != LongType => - SumDistinct(Cast(e, LongType)) - case SumDistinct(e @ FractionalType()) if e.dataType != DoubleType => - SumDistinct(Cast(e, DoubleType)) - case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest. case Average(e @ IntegralType()) if e.dataType != LongType => Average(Cast(e, LongType)) http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index eae17c8..6485bdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -141,6 +141,10 @@ case class UnresolvedFunction( override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false + override def prettyString: String = { + s"${name}(${children.map(_.prettyString).mkString(",")})" + } + override def toString: String = s"'$name(${children.mkString(",")})" } http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index d8df664..af594c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -23,6 +23,7 @@ import scala.language.implicitConversions import 
org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.types._ @@ -144,17 +145,18 @@ package object dsl { } } - def sum(e: Expression): Expression = Sum(e) - def sumDistinct(e: Expression): Expression = SumDistinct(e) - def count(e: Expression): Expression = Count(e) - def countDistinct(e: Expression*): Expression = CountDistinct(e) + def sum(e: Expression): Expression = Sum(e).toAggregateExpression() + def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true) + def count(e: Expression): Expression = Count(e).toAggregateExpression() + def countDistinct(e: Expression*): Expression = + Count(e).toAggregateExpression(isDistinct = true) def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = - ApproxCountDistinct(e, rsd) - def avg(e: Expression): Expression = Average(e) - def first(e: Expression): Expression = First(e) - def last(e: Expression): Expression = Last(e) - def min(e: Expression): Expression = Min(e) - def max(e: Expression): Expression = Max(e) + HyperLogLogPlusPlus(e, rsd).toAggregateExpression() + def avg(e: Expression): Expression = Average(e).toAggregateExpression() + def first(e: Expression): Expression = new First(e).toAggregateExpression() + def last(e: Expression): Expression = new Last(e).toAggregateExpression() + def min(e: Expression): Expression = Min(e).toAggregateExpression() + def max(e: Expression): Expression = Max(e).toAggregateExpression() def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) 
http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index c8c20ad..7f9e503 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ case class Average(child: Expression) extends DeclarativeAggregate { @@ -32,36 +34,33 @@ case class Average(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = resultType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. 
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) - private val resultType = child.dataType match { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") + + private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) case _ => DoubleType } - private val sumDataType = child.dataType match { + private lazy val sumDataType = child.dataType match { case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) case _ => DoubleType } - private val sum = AttributeReference("sum", sumDataType)() - private val count = AttributeReference("count", LongType)() + private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = sum :: count :: Nil + override lazy val aggBufferAttributes = sum :: count :: Nil - override val initialValues = Seq( + override lazy val initialValues = Seq( /* sum = */ Cast(Literal(0), sumDataType), /* count = */ Literal(0L) ) - override val updateExpressions = Seq( + override lazy val updateExpressions = Seq( /* sum = */ Add( sum, @@ -69,13 +68,13 @@ case class Average(child: Expression) extends DeclarativeAggregate { /* count = */ If(IsNull(child), count, count + 1L) ) - override val mergeExpressions = Seq( + override lazy val mergeExpressions = Seq( /* sum = */ sum.left + sum.right, /* count = */ count.left + count.right ) // If all input are nulls, count will be 0 and we will get null after the division. 
- override val evaluateExpression = child.dataType match { + override lazy val evaluateExpression = child.dataType match { case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index ef08b02..984ce7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -55,13 +57,10 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def dataType: DataType = DoubleType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. 
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName") override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 8323383..00d7436 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -35,6 +37,9 @@ case class Corr( inputAggBufferOffset: Int = 0) extends ImperativeAggregate { + def this(left: Expression, right: Expression) = + this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def children: Seq[Expression] = Seq(left, right) override def nullable: Boolean = false @@ -43,6 +48,16 @@ case class Corr( override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) { + TypeCheckResult.TypeCheckSuccess + } else 
{ + TypeCheckResult.TypeCheckFailure( + s"corr requires that both arguments are double type, " + + s"not (${left.dataType}, ${right.dataType}).") + } + } + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) override def inputAggBufferAttributes: Seq[AttributeReference] = { http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index ec0c8b4..09a1da9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -32,23 +32,39 @@ case class Count(child: Expression) extends DeclarativeAggregate { // Expected input data type. 
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val count = AttributeReference("count", LongType)() + private lazy val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = count :: Nil + override lazy val aggBufferAttributes = count :: Nil - override val initialValues = Seq( + override lazy val initialValues = Seq( /* count = */ Literal(0L) ) - override val updateExpressions = Seq( + override lazy val updateExpressions = Seq( /* count = */ If(IsNull(child), count, count + 1L) ) - override val mergeExpressions = Seq( + override lazy val mergeExpressions = Seq( /* count = */ count.left + count.right ) - override val evaluateExpression = Cast(count, LongType) + override lazy val evaluateExpression = Cast(count, LongType) override def defaultResult: Option[Literal] = Option(Literal(0L)) } + +object Count { + def apply(children: Seq[Expression]): Count = { + // This is used to deal with COUNT DISTINCT. When we have multiple + // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row). + // Also, the semantics of COUNT(DISTINCT col1, col2, ...) is that if there is any + // null in the arguments, we will not count that row. So, we use DropAnyNull here + // to return a null when any field of the created STRUCT is null. 
+ val child = if (children.size > 1) { + DropAnyNull(CreateStruct(children)) + } else { + children.head + } + Count(child) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 9028143..35f5742 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -51,18 +51,18 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val first = AttributeReference("first", child.dataType)() + private lazy val first = AttributeReference("first", child.dataType)() - private val valueSet = AttributeReference("valueSet", BooleanType)() + private lazy val valueSet = AttributeReference("valueSet", BooleanType)() - override val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil - override val initialValues: Seq[Literal] = Seq( + override lazy val initialValues: Seq[Literal] = Seq( /* first = */ Literal.create(null, child.dataType), /* valueSet = */ Literal.create(false, BooleanType) ) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* first = */ If(Or(valueSet, IsNull(child)), first, child), @@ -76,7 +76,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara } } - override val mergeExpressions: 
Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { // For first, we can just check if valueSet.left is set to true. If it is set // to true, we use first.left. If not, we use first.right (even if valueSet.right is // false, we are safe to do so because first.right will be null in this case). @@ -86,7 +86,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara ) } - override val evaluateExpression: AttributeReference = first + override lazy val evaluateExpression: AttributeReference = first override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}" } http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 8d341ee..8a95c54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -22,6 +22,7 @@ import java.util import com.clearspring.analytics.hash.MurmurHash +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -55,6 +56,22 @@ case class HyperLogLogPlusPlus( extends ImperativeAggregate { import HyperLogLogPlusPlus._ + def this(child: Expression) = { + this(child = child, relativeSD = 0.05, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + } + + def this(child: Expression, relativeSD: Expression) = { + this( + child = child, + relativeSD = relativeSD match { + case 
Literal(d: Double, DoubleType) => d + case _ => + throw new AnalysisException("The second argument should be a double literal.") + }, + mutableAggBufferOffset = 0, + inputAggBufferOffset = 0) + } + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala index 6da39e7..bae78d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala @@ -24,6 +24,8 @@ case class Kurtosis(child: Expression, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 8636bfe..be7e12d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -51,15 +51,15 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val last = AttributeReference("last", child.dataType)() + private lazy val last = AttributeReference("last", child.dataType)() - override val aggBufferAttributes: Seq[AttributeReference] = last :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil - override val initialValues: Seq[Literal] = Seq( + override lazy val initialValues: Seq[Literal] = Seq( /* last = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* last = */ If(IsNull(child), last, child) @@ -71,7 +71,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat } } - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* last = */ If(IsNull(last.right), last.left, last.right) @@ -83,7 +83,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat } } - override val evaluateExpression: AttributeReference = last + override lazy val evaluateExpression: AttributeReference = last override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}" } http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index b9d75ad..61cae44 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ case class Max(child: Expression) extends DeclarativeAggregate { @@ -32,24 +34,27 @@ case class Max(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val max = AttributeReference("max", child.dataType)() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function max") - override val aggBufferAttributes: Seq[AttributeReference] = max :: Nil + private lazy val max = AttributeReference("max", child.dataType)() - override val initialValues: Seq[Literal] = Seq( + override lazy val aggBufferAttributes: Seq[AttributeReference] = max :: Nil + + override lazy val initialValues: Seq[Literal] = Seq( /* max = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = Seq( + override lazy val updateExpressions: Seq[Expression] = Seq( /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) ) - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { val greatest = Greatest(Seq(max.left, max.right)) Seq( /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) ) } - override val evaluateExpression: AttributeReference = max + override lazy val evaluateExpression: AttributeReference = max } 
http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 5ed9cd3..242456d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -33,24 +35,27 @@ case class Min(child: Expression) extends DeclarativeAggregate { // Expected input data type. 
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val min = AttributeReference("min", child.dataType)() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function min") - override val aggBufferAttributes: Seq[AttributeReference] = min :: Nil + private lazy val min = AttributeReference("min", child.dataType)() - override val initialValues: Seq[Expression] = Seq( + override lazy val aggBufferAttributes: Seq[AttributeReference] = min :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( /* min = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = Seq( + override lazy val updateExpressions: Seq[Expression] = Seq( /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) ) - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { val least = Least(Seq(min.left, min.right)) Seq( /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) ) } - override val evaluateExpression: AttributeReference = min + override lazy val evaluateExpression: AttributeReference = min } http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala index 0def7dd..c593074 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala @@ -24,6 +24,8 @@ case class Skewness(child: Expression, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + 
def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala index 3f47ffe..5b9eb7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -48,29 +50,26 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { override def dataType: DataType = resultType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select stddev(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. 
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) - private val resultType = DoubleType + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function stddev") - private val count = AttributeReference("count", resultType)() - private val avg = AttributeReference("avg", resultType)() - private val mk = AttributeReference("mk", resultType)() + private lazy val resultType = DoubleType - override val aggBufferAttributes = count :: avg :: mk :: Nil + private lazy val count = AttributeReference("count", resultType)() + private lazy val avg = AttributeReference("avg", resultType)() + private lazy val mk = AttributeReference("mk", resultType)() - override val initialValues: Seq[Expression] = Seq( + override lazy val aggBufferAttributes = count :: avg :: mk :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( /* count = */ Cast(Literal(0), resultType), /* avg = */ Cast(Literal(0), resultType), /* mk = */ Cast(Literal(0), resultType) ) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = { val value = Cast(child, resultType) val newCount = count + Cast(Literal(1), resultType) @@ -89,7 +88,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { ) } - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { // count merge val newCount = count.left + count.right @@ -114,7 +113,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { ) } - override val evaluateExpression: Expression = { + override lazy val evaluateExpression: Expression = { // when count == 0, return null // when count == 1, return 0 // when count >1 
http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 7f8adbc..c005ec9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ case class Sum(child: Expression) extends DeclarativeAggregate { @@ -29,16 +31,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = resultType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select sum(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. 
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) - private val resultType = child.dataType match { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sum") + + private lazy val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) // TODO: Remove this line once we remove the NullType from inputTypes. @@ -46,24 +45,24 @@ case class Sum(child: Expression) extends DeclarativeAggregate { case _ => child.dataType } - private val sumDataType = resultType + private lazy val sumDataType = resultType - private val sum = AttributeReference("sum", sumDataType)() + private lazy val sum = AttributeReference("sum", sumDataType)() - private val zero = Cast(Literal(0), sumDataType) + private lazy val zero = Cast(Literal(0), sumDataType) - override val aggBufferAttributes = sum :: Nil + override lazy val aggBufferAttributes = sum :: Nil - override val initialValues: Seq[Expression] = Seq( + override lazy val initialValues: Seq[Expression] = Seq( /* sum = */ Literal.create(null, sumDataType) ) - override val updateExpressions: Seq[Expression] = Seq( + override lazy val updateExpressions: Seq[Expression] = Seq( /* sum = */ Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) ) - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) Seq( /* sum = */ @@ -71,5 +70,5 @@ case class Sum(child: Expression) extends DeclarativeAggregate { ) } - override val evaluateExpression: Expression = Cast(sum, resultType) + override lazy val evaluateExpression: Expression = Cast(sum, resultType) } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: 
commits-h...@spark.apache.org