[SPARK-9830][SQL] Remove AggregateExpression1 and Aggregate Operator used to 
evaluate AggregateExpression1s

https://issues.apache.org/jira/browse/SPARK-9830

This PR contains the following main changes.
* Removing `AggregateExpression1`.
* Removing the `Aggregate` operator, which was used to evaluate 
`AggregateExpression1`.
* Removing the planner rule used to plan `Aggregate`.
* Linking the `DistinctAggregationRewriter` rule to the analyzer.
* Renaming `AggregateExpression2` to `AggregateExpression` and 
`AggregateFunction2` to `AggregateFunction`.
* Updating the places where we create aggregate expressions. The way to create 
an aggregate expression is now `AggregateExpression(aggregateFunction, mode, 
isDistinct)`; see the sketch after this list.
* Changing `val`s in `DeclarativeAggregate`s that touch the children of the 
function to `lazy val`s (when we create an aggregate expression through the 
DataFrame API, the children of an aggregate function can still be unresolved).
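
For reference, a minimal sketch of the new construction pattern (class and 
method names are taken from the diff below; the `value` attribute is only an 
illustrative placeholder, not part of this patch):

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.expressions.aggregate._

    // An attribute that is still unresolved, as it would be in the DataFrame API.
    val value = UnresolvedAttribute("value")

    // Wrap an AggregateFunction explicitly, as SqlParser now does for COUNT(*).
    val countStar =
      AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false)

    // Or use the toAggregateExpression helper, as the Catalyst DSL now does for
    // SUM(DISTINCT ...).
    val distinctSum = Sum(value).toAggregateExpression(isDistinct = true)

Because the `val`s in `DeclarativeAggregate` are now `lazy val`s, constructing 
`Sum(value)` over a still-unresolved attribute like this no longer fails at 
construction time.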

Author: Yin Huai <yh...@databricks.com>

Closes #9556 from yhuai/removeAgg1.

(cherry picked from commit e0701c75601c43f69ed27fc7c252321703db51f2)
Signed-off-by: Michael Armbrust <mich...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7c4ade0d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7c4ade0d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7c4ade0d

Branch: refs/heads/branch-1.6
Commit: 7c4ade0d7665e0f473d00f4a812fa69a0e0d14b5
Parents: 825e971
Author: Yin Huai <yh...@databricks.com>
Authored: Tue Nov 10 11:06:29 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Tue Nov 10 11:06:48 2015 -0800

----------------------------------------------------------------------
 R/pkg/R/functions.R                             |    2 +-
 python/pyspark/sql/dataframe.py                 |    2 +-
 python/pyspark/sql/functions.py                 |    2 +-
 python/pyspark/sql/tests.py                     |    2 +-
 .../spark/sql/catalyst/CatalystConf.scala       |   10 +-
 .../apache/spark/sql/catalyst/SqlParser.scala   |   14 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   26 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala   |   46 +-
 .../analysis/DistinctAggregationRewriter.scala  |  278 +++++
 .../catalyst/analysis/FunctionRegistry.scala    |    2 +
 .../catalyst/analysis/HiveTypeCoercion.scala    |   20 +-
 .../sql/catalyst/analysis/unresolved.scala      |    4 +
 .../apache/spark/sql/catalyst/dsl/package.scala |   22 +-
 .../expressions/aggregate/Average.scala         |   31 +-
 .../aggregate/CentralMomentAgg.scala            |   13 +-
 .../catalyst/expressions/aggregate/Corr.scala   |   15 +
 .../catalyst/expressions/aggregate/Count.scala  |   28 +-
 .../catalyst/expressions/aggregate/First.scala  |   14 +-
 .../aggregate/HyperLogLogPlusPlus.scala         |   17 +
 .../expressions/aggregate/Kurtosis.scala        |    2 +
 .../catalyst/expressions/aggregate/Last.scala   |   12 +-
 .../catalyst/expressions/aggregate/Max.scala    |   17 +-
 .../catalyst/expressions/aggregate/Min.scala    |   17 +-
 .../expressions/aggregate/Skewness.scala        |    2 +
 .../catalyst/expressions/aggregate/Stddev.scala |   31 +-
 .../catalyst/expressions/aggregate/Sum.scala    |   29 +-
 .../catalyst/expressions/aggregate/Utils.scala  |  467 --------
 .../expressions/aggregate/Variance.scala        |    7 +-
 .../expressions/aggregate/interfaces.scala      |   57 +-
 .../sql/catalyst/expressions/aggregates.scala   | 1073 ------------------
 .../sql/catalyst/optimizer/Optimizer.scala      |   23 +-
 .../spark/sql/catalyst/planning/patterns.scala  |   74 --
 .../spark/sql/catalyst/plans/QueryPlan.scala    |   12 +-
 .../catalyst/plans/logical/basicOperators.scala |    4 +-
 .../catalyst/analysis/AnalysisErrorSuite.scala  |   23 +-
 .../sql/catalyst/analysis/AnalysisSuite.scala   |    2 +-
 .../analysis/DecimalPrecisionSuite.scala        |    1 +
 .../analysis/ExpressionTypeCheckingSuite.scala  |    6 +-
 .../optimizer/ConstantFoldingSuite.scala        |    4 +-
 .../optimizer/FilterPushdownSuite.scala         |   14 +-
 .../scala/org/apache/spark/sql/DataFrame.scala  |   13 +-
 .../org/apache/spark/sql/GroupedData.scala      |   45 +-
 .../scala/org/apache/spark/sql/SQLConf.scala    |   20 +-
 .../apache/spark/sql/execution/Aggregate.scala  |  205 ----
 .../org/apache/spark/sql/execution/Expand.scala |    3 +
 .../spark/sql/execution/SparkPlanner.scala      |    1 -
 .../spark/sql/execution/SparkStrategies.scala   |  238 ++--
 .../aggregate/AggregationIterator.scala         |   28 +-
 .../aggregate/SortBasedAggregate.scala          |    4 +-
 .../SortBasedAggregationIterator.scala          |    8 +-
 .../execution/aggregate/TungstenAggregate.scala |    6 +-
 .../aggregate/TungstenAggregationIterator.scala |   36 +-
 .../spark/sql/execution/aggregate/udaf.scala    |    2 +-
 .../spark/sql/execution/aggregate/utils.scala   |   20 +-
 .../spark/sql/expressions/Aggregator.scala      |    5 +-
 .../spark/sql/expressions/WindowSpec.scala      |   82 +-
 .../org/apache/spark/sql/expressions/udaf.scala |    6 +-
 .../scala/org/apache/spark/sql/functions.scala  |   53 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    |   69 +-
 .../apache/spark/sql/UserDefinedTypeSuite.scala |   15 +-
 .../spark/sql/execution/PlannerSuite.scala      |    2 +-
 .../sql/execution/metric/SQLMetricsSuite.scala  |   30 -
 .../org/apache/spark/sql/hive/HiveContext.scala |    1 -
 .../org/apache/spark/sql/hive/HiveQl.scala      |    8 +-
 .../hive/execution/AggregationQuerySuite.scala  |  188 ++-
 65 files changed, 998 insertions(+), 2515 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/R/pkg/R/functions.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index d7fd279..0b28087 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -1339,7 +1339,7 @@ setMethod("pmod", signature(y = "Column"),
 #' @export
 setMethod("approxCountDistinct",
           signature(x = "Column"),
-          function(x, rsd = 0.95) {
+          function(x, rsd = 0.05) {
             jc <- callJStatic("org.apache.spark.sql.functions", 
"approxCountDistinct", x@jc, rsd)
             column(jc)
           })

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index b97c94d..0dd75ba 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -866,7 +866,7 @@ class DataFrame(object):
         This is a variant of :func:`select` that accepts SQL expressions.
 
         >>> df.selectExpr("age * 2", "abs(age)").collect()
-        [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)]
+        [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)]
         """
         if len(expr) == 1 and isinstance(expr[0], list):
             expr = expr[0]

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 962f676..6e1cbde 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -382,7 +382,7 @@ def expr(str):
     """Parses the expression string into the column that it represents
 
     >>> df.select(expr("length(name)")).collect()
-    [Row('length(name)=5), Row('length(name)=3)]
+    [Row(length(name)=5), Row(length(name)=3)]
     """
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.functions.expr(str))

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index e224574..9f5f7cf 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1017,7 +1017,7 @@ class SQLTests(ReusedPySparkTestCase):
         row = Row(a="length string", b=75)
         df = self.sqlCtx.createDataFrame([row])
         result = df.select(functions.expr("length(a)")).collect()[0].asDict()
-        self.assertEqual(13, result["'length(a)"])
+        self.assertEqual(13, result["length(a)"])
 
     def test_replace(self):
         schema = StructType([

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index 3f351b0..7c2b8a9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
 
 private[spark] trait CatalystConf {
   def caseSensitiveAnalysis: Boolean
+
+  protected[spark] def specializeSingleDistinctAggPlanning: Boolean
 }
 
 /**
@@ -29,7 +31,13 @@ object EmptyConf extends CatalystConf {
   override def caseSensitiveAnalysis: Boolean = {
     throw new UnsupportedOperationException
   }
+
+  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = 
{
+    throw new UnsupportedOperationException
+  }
 }
 
 /** A CatalystConf that can be used for local testing. */
-case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends 
CatalystConf
+case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends 
CatalystConf {
+  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = 
true
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index cd717c0..2a132d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -22,6 +22,7 @@ import scala.language.implicitConversions
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.DataTypeParser
@@ -272,7 +273,7 @@ object SqlParser extends AbstractSparkSQLParser with 
DataTypeParser {
   protected lazy val function: Parser[Expression] =
     ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName =>
       if (lexical.normalizeKeyword(udfName) == "count") {
-        Count(Literal(1))
+        AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = 
false)
       } else {
         throw new AnalysisException(s"invalid expression $udfName(*)")
       }
@@ -281,14 +282,14 @@ object SqlParser extends AbstractSparkSQLParser with 
DataTypeParser {
       { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct 
= false) }
     | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case 
udfName ~ exprs =>
       lexical.normalizeKeyword(udfName) match {
-        case "sum" => SumDistinct(exprs.head)
-        case "count" => CountDistinct(exprs)
+        case "count" =>
+          aggregate.Count(exprs).toAggregateExpression(isDistinct = true)
         case _ => UnresolvedFunction(udfName, exprs, isDistinct = true)
       }
     }
     | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case 
udfName ~ exp =>
       if (lexical.normalizeKeyword(udfName) == "count") {
-        ApproxCountDistinct(exp)
+        AggregateExpression(new HyperLogLogPlusPlus(exp), mode = Complete, 
isDistinct = false)
       } else {
         throw new AnalysisException(s"invalid function approximate $udfName")
       }
@@ -296,7 +297,10 @@ object SqlParser extends AbstractSparkSQLParser with 
DataTypeParser {
     | APPROXIMATE ~> "(" ~> unsignedFloat ~ ")" ~ ident ~ "(" ~ DISTINCT ~ 
expression <~ ")" ^^
       { case s ~ _ ~ udfName ~ _ ~ _ ~ exp =>
         if (lexical.normalizeKeyword(udfName) == "count") {
-          ApproxCountDistinct(exp, s.toDouble)
+          AggregateExpression(
+            HyperLogLogPlusPlus(exp, s.toDouble, 0, 0),
+            mode = Complete,
+            isDistinct = false)
         } else {
           throw new AnalysisException(s"invalid function approximate($s) 
$udfName")
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 899ee67..b1e1439 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, 
AggregateExpression2, AggregateFunction2}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
@@ -79,6 +79,7 @@ class Analyzer(
       ExtractWindowExpressions ::
       GlobalAggregates ::
       ResolveAggregateFunctions ::
+      DistinctAggregationRewriter(conf) ::
       HiveTypeCoercion.typeCoercionRules ++
       extendedResolutionRules : _*),
     Batch("Nondeterministic", Once,
@@ -525,21 +526,14 @@ class Analyzer(
           case u @ UnresolvedFunction(name, children, isDistinct) =>
             withPosition(u) {
               registry.lookupFunction(name, children) match {
-                // We get an aggregate function built based on 
AggregateFunction2 interface.
-                // So, we wrap it in AggregateExpression2.
-                case agg2: AggregateFunction2 => AggregateExpression2(agg2, 
Complete, isDistinct)
-                // Currently, our old aggregate function interface supports 
SUM(DISTINCT ...)
-                // and COUTN(DISTINCT ...).
-                case sumDistinct: SumDistinct => sumDistinct
-                case countDistinct: CountDistinct => countDistinct
-                // DISTINCT is not meaningful with Max and Min.
-                case max: Max if isDistinct => max
-                case min: Min if isDistinct => min
-                // For other aggregate functions, DISTINCT keyword is not 
supported for now.
-                // Once we converted to the new code path, we will allow using 
DISTINCT keyword.
-                case other: AggregateExpression1 if isDistinct =>
-                  failAnalysis(s"$name does not support DISTINCT keyword.")
-                // If it does not have DISTINCT keyword, we will return it as 
is.
+                // DISTINCT is not meaningful for a Max or a Min.
+                case max: Max if isDistinct =>
+                  AggregateExpression(max, Complete, isDistinct = false)
+                case min: Min if isDistinct =>
+                  AggregateExpression(min, Complete, isDistinct = false)
+                // We get an aggregate function; we need to wrap it in an 
AggregateExpression.
+                case agg2: AggregateFunction => AggregateExpression(agg2, 
Complete, isDistinct)
+                // This function is not an aggregate function, just return the 
resolved one.
                 case other => other
               }
             }

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 98d6637..8322e99 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, 
AggregateExpression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
@@ -108,7 +109,19 @@ trait CheckAnalysis {
 
           case Aggregate(groupingExprs, aggregateExprs, child) =>
             def checkValidAggregateExpression(expr: Expression): Unit = expr 
match {
-              case _: AggregateExpression => // OK
+              case aggExpr: AggregateExpression =>
+                // TODO: Is it possible that the child of an agg function is 
another
+                // agg function?
+                aggExpr.aggregateFunction.children.foreach {
+                  // This is just a sanity check, our analysis rule 
PullOutNondeterministic should
+                  // already pull out those nondeterministic expressions and 
evaluate them in
+                  // a Project node.
+                  case child if !child.deterministic =>
+                    failAnalysis(
+                      s"nondeterministic expression ${expr.prettyString} 
should not " +
+                        s"appear in the arguments of an aggregate function.")
+                  case child => // OK
+                }
               case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) 
=>
                 failAnalysis(
                   s"expression '${e.prettyString}' is neither present in the 
group by, " +
@@ -120,14 +133,26 @@ trait CheckAnalysis {
               case e => e.children.foreach(checkValidAggregateExpression)
             }
 
-            def checkValidGroupingExprs(expr: Expression): Unit = 
expr.dataType match {
-              case BinaryType =>
-                failAnalysis(s"binary type expression ${expr.prettyString} 
cannot be used " +
-                  "in grouping expression")
-              case m: MapType =>
-                failAnalysis(s"map type expression ${expr.prettyString} cannot 
be used " +
-                  "in grouping expression")
-              case _ => // OK
+            def checkValidGroupingExprs(expr: Expression): Unit = {
+              expr.dataType match {
+                case BinaryType =>
+                  failAnalysis(s"binary type expression ${expr.prettyString} 
cannot be used " +
+                    "in grouping expression")
+                case a: ArrayType =>
+                  failAnalysis(s"array type expression ${expr.prettyString} 
cannot be used " +
+                    "in grouping expression")
+                case m: MapType =>
+                  failAnalysis(s"map type expression ${expr.prettyString} 
cannot be used " +
+                    "in grouping expression")
+                case _ => // OK
+              }
+              if (!expr.deterministic) {
+                // This is just a sanity check, our analysis rule 
PullOutNondeterministic should
+                // already pull out those nondeterministic expressions and 
evaluate them in
+                // a Project node.
+                failAnalysis(s"nondeterministic expression 
${expr.prettyString} should not " +
+                  s"appear in grouping expression.")
+              }
             }
 
             aggregateExprs.foreach(checkValidAggregateExpression)
@@ -179,7 +204,8 @@ trait CheckAnalysis {
               s"unresolved operator ${operator.simpleString}")
 
           case o if o.expressions.exists(!_.deterministic) &&
-            !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] =>
+            !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && 
!o.isInstanceOf[Aggregate] =>
+            // The rule above is used to check the Aggregate operator.
             failAnalysis(
               s"""nondeterministic expressions are only allowed in Project or 
Filter, found:
                  | ${o.expressions.map(_.prettyString).mkString(",")}

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
new file mode 100644
index 0000000..397eff0
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.expressions._
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
AggregateFunction, Complete}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, 
LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.IntegerType
+
+/**
+ * This rule rewrites an aggregate query with distinct aggregations into an 
expanded double
+ * aggregation in which the regular aggregation expressions and every distinct 
clause is aggregated
+ * in a separate group. The results are then combined in a second aggregate.
+ *
+ * For example (in scala):
+ * {{{
+ *   val data = Seq(
+ *     ("a", "ca1", "cb1", 10),
+ *     ("a", "ca1", "cb2", 5),
+ *     ("b", "ca1", "cb1", 13))
+ *     .toDF("key", "cat1", "cat2", "value")
+ *   data.registerTempTable("data")
+ *
+ *   val agg = data.groupBy($"key")
+ *     .agg(
+ *       countDistinct($"cat1").as("cat1_cnt"),
+ *       countDistinct($"cat2").as("cat2_cnt"),
+ *       sum($"value").as("total"))
+ * }}}
+ *
+ * This translates to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [COUNT(DISTINCT 'cat1),
+ *                 COUNT(DISTINCT 'cat2),
+ *                 sum('value)]
+ *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ *   LocalTableScan [...]
+ * }}}
+ *
+ * This rule rewrites this logical plan to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [count(if (('gid = 1)) 'cat1 else null),
+ *                 count(if (('gid = 2)) 'cat2 else null),
+ *                 first(if (('gid = 0)) 'total else null) ignore nulls]
+ *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ *   Aggregate(
+ *      key = ['key, 'cat1, 'cat2, 'gid]
+ *      functions = [sum('value)]
+ *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
+ *     Expand(
+ *        projections = [('key, null, null, 0, cast('value as bigint)),
+ *                       ('key, 'cat1, null, 1, null),
+ *                       ('key, null, 'cat2, 2, null)]
+ *        output = ['key, 'cat1, 'cat2, 'gid, 'value])
+ *       LocalTableScan [...]
+ * }}}
+ *
+ * The rule does the following things here:
+ * 1. Expand the data. There are three aggregation groups in this query:
+ *    i. the non-distinct group;
+ *    ii. the distinct 'cat1 group;
+ *    iii. the distinct 'cat2 group.
+ *    An expand operator is inserted to expand the child data for each group. 
The expand will null
+ *    out all unused columns for the given group; this must be done in order 
to ensure correctness
later on. Groups can be identified by a group id (gid) column added by 
the expand operator.
+ * 2. De-duplicate the distinct paths and aggregate the non-distinct path. 
The group by clause of
+ *    this aggregate consists of the original group by clause, all the 
requested distinct columns
+ *    and the group id. Both de-duplication of distinct columns and the 
aggregation of the
+ *    non-distinct group take advantage of the fact that we group by the group 
id (gid) and that we
+ *    have nulled out all non-relevant columns for the given group.
+ * 3. Aggregate the distinct groups and combine this with the results of 
the non-distinct
+ *    aggregation. In this step we use the group id to filter the inputs for 
the aggregate
+ *    functions. The results of the non-distinct group are 'aggregated' by 
using the first operator;
+ *    it might be more elegant to use the native UDAF merge mechanism for this 
in the future.
+ *
+ * This rule duplicates the input data two or more times (# distinct groups 
+ an optional
+ * non-distinct group). This will put quite a bit of memory pressure on the 
aggregate and
+ * exchange operators. Keeping the number of distinct groups as low as 
possible should be a priority;
+ * we could improve this in the current rule by applying more advanced 
expression canonicalization
+ * techniques.
+ */
+case class DistinctAggregationRewriter(conf: CatalystConf) extends 
Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    case p if !p.resolved => p
+    // We need to wait until this Aggregate operator is resolved.
+    case a: Aggregate => rewrite(a)
+    case p => p
+  }
+
+  def rewrite(a: Aggregate): Aggregate = {
+
+    // Collect all aggregate expressions.
+    val aggExpressions = a.aggregateExpressions.flatMap { e =>
+      e.collect {
+        case ae: AggregateExpression => ae
+      }
+    }
+
+    // Extract distinct aggregate expressions.
+    val distinctAggGroups = aggExpressions
+      .filter(_.isDistinct)
+      .groupBy(_.aggregateFunction.children.toSet)
+
+    val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) {
+      // When the flag is set to specialize single distinct agg planning,
+      // we will rely on our Aggregation strategy to handle queries with a 
single
+      // distinct column when this Aggregate operator has grouping 
expressions.
+      distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && 
a.groupingExpressions.isEmpty)
+    } else {
+      distinctAggGroups.size >= 1
+    }
+    if (shouldRewrite) {
+      // Create the attributes for the grouping id and the group by clause.
+      val gid = new AttributeReference("gid", IntegerType, false)()
+      val groupByMap = a.groupingExpressions.collect {
+        case ne: NamedExpression => ne -> ne.toAttribute
+        case e => e -> new AttributeReference(e.prettyString, e.dataType, 
e.nullable)()
+      }
+      val groupByAttrs = groupByMap.map(_._2)
+
+      // Functions used to modify aggregate functions and their inputs.
+      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), 
e, nullify(e))
+      def patchAggregateFunctionChildren(
+          af: AggregateFunction)(
+          attrs: Expression => Expression): AggregateFunction = {
+        af.withNewChildren(af.children.map {
+          case afc => attrs(afc)
+        }).asInstanceOf[AggregateFunction]
+      }
+
+      // Setup unique distinct aggregate children.
+      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
+      val distinctAggChildAttrMap = 
distinctAggChildren.map(expressionAttributePair).toMap
+      val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
+
+      // Setup expand & aggregate operators for distinct aggregate expressions.
+      val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
+        case ((group, expressions), i) =>
+          val id = Literal(i + 1)
+
+          // Expand projection
+          val projection = distinctAggChildren.map {
+            case e if group.contains(e) => e
+            case e => nullify(e)
+          } :+ id
+
+          // Final aggregate
+          val operators = expressions.map { e =>
+            val af = e.aggregateFunction
+            val naf = patchAggregateFunctionChildren(af) { x =>
+              evalWithinGroup(id, distinctAggChildAttrMap(x))
+            }
+            (e, e.copy(aggregateFunction = naf, isDistinct = false))
+          }
+
+          (projection, operators)
+      }
+
+      // Setup expand for the 'regular' aggregate expressions.
+      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
+      val regularAggChildren = 
regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+      val regularAggChildAttrMap = 
regularAggChildren.map(expressionAttributePair)
+
+      // Setup aggregates for 'regular' aggregate expressions.
+      val regularGroupId = Literal(0)
+      val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
+      val regularAggOperatorMap = regularAggExprs.map { e =>
+        // Perform the actual aggregation in the initial aggregate.
+        val af = 
patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
+        val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
+
+        // Select the result of the first aggregate in the last aggregate.
+        val result = AggregateExpression(
+          aggregate.First(evalWithinGroup(regularGroupId, 
operator.toAttribute), Literal(true)),
+          mode = Complete,
+          isDistinct = false)
+
+        // Some aggregate functions (COUNT) have the special property that 
they can return a
+        // non-null result without any input. We need to make sure we return a 
result in this case.
+        val resultWithDefault = af.defaultResult match {
+          case Some(lit) => Coalesce(Seq(result, lit))
+          case None => result
+        }
+
+        // Return a Tuple3 containing:
+        // i. The original aggregate expression (used for look ups).
+        // ii. The actual aggregation operator (used in the first aggregate).
+        // iii. The operator that selects and returns the result (used in the 
second aggregate).
+        (e, operator, resultWithDefault)
+      }
+
+      // Construct the regular aggregate input projection only if we need one.
+      val regularAggProjection = if (regularAggExprs.nonEmpty) {
+        Seq(a.groupingExpressions ++
+          distinctAggChildren.map(nullify) ++
+          Seq(regularGroupId) ++
+          regularAggChildren)
+      } else {
+        Seq.empty[Seq[Expression]]
+      }
+
+      // Construct the distinct aggregate input projections.
+      val regularAggNulls = regularAggChildren.map(nullify)
+      val distinctAggProjections = distinctAggOperatorMap.map {
+        case (projection, _) =>
+          a.groupingExpressions ++
+            projection ++
+            regularAggNulls
+      }
+
+      // Construct the expand operator.
+      val expand = Expand(
+        regularAggProjection ++ distinctAggProjections,
+        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ 
regularAggChildAttrMap.map(_._2),
+        a.child)
+
+      // Construct the first aggregate operator. This de-duplicates all 
the children of
+      // distinct operators, and applies the regular aggregate operators.
+      val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
+      val firstAggregate = Aggregate(
+        firstAggregateGroupBy,
+        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
+        expand)
+
+      // Construct the second aggregate
+      val transformations: Map[Expression, Expression] =
+        (distinctAggOperatorMap.flatMap(_._2) ++
+          regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+
+      val patchedAggExpressions = a.aggregateExpressions.map { e =>
+        e.transformDown {
+          case e: Expression =>
+            // The same GROUP BY clauses can have different forms (different 
names for instance) in
+            // the groupBy and aggregate expressions of an aggregate. This 
makes a map lookup
+            // tricky. So we do a linear search for a semantically equal group 
by expression.
+            groupByMap
+              .find(ge => e.semanticEquals(ge._1))
+              .map(_._2)
+              .getOrElse(transformations.getOrElse(e, e))
+        }.asInstanceOf[NamedExpression]
+      }
+      Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
+    } else {
+      a
+    }
+  }
+
+  private def nullify(e: Expression) = Literal.create(null, e.dataType)
+
+  private def expressionAttributePair(e: Expression) =
+    // We are creating a new reference here instead of reusing the attribute 
in case of a
+    // NamedExpression. This is done to prevent collisions between distinct 
and regular aggregate
+    // children; in this case attribute reuse causes the input of the regular 
aggregate to be bound to
+    // the (nulled out) input of the distinct aggregate.
+    e -> new AttributeReference(e.prettyString, e.dataType, true)()
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index d4334d1..dfa749d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -24,6 +24,7 @@ import scala.util.{Failure, Success, Try}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.util.StringKeyHashMap
 
 
@@ -177,6 +178,7 @@ object FunctionRegistry {
     expression[ToRadians]("radians"),
 
     // aggregate functions
+    expression[HyperLogLogPlusPlus]("approx_count_distinct"),
     expression[Average]("avg"),
     expression[Corr]("corr"),
     expression[Count]("count"),

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 84e2b13..bf2bff0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import javax.annotation.Nullable
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types._
@@ -295,14 +296,17 @@ object HiveTypeCoercion {
         i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
 
       case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
-      case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
       case Average(e @ StringType()) => Average(Cast(e, DoubleType))
       case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
       case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
-      case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
-      case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
-      case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))
-      case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType))
+      case VariancePop(e @ StringType(), mutableAggBufferOffset, 
inputAggBufferOffset) =>
+        VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, 
inputAggBufferOffset)
+      case VarianceSamp(e @ StringType(), mutableAggBufferOffset, 
inputAggBufferOffset) =>
+        VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, 
inputAggBufferOffset)
+      case Skewness(e @ StringType(), mutableAggBufferOffset, 
inputAggBufferOffset) =>
+        Skewness(Cast(e, DoubleType), mutableAggBufferOffset, 
inputAggBufferOffset)
+      case Kurtosis(e @ StringType(), mutableAggBufferOffset, 
inputAggBufferOffset) =>
+        Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, 
inputAggBufferOffset)
     }
   }
 
@@ -562,12 +566,6 @@ object HiveTypeCoercion {
       case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, 
LongType))
       case Sum(e @ FractionalType()) if e.dataType != DoubleType => 
Sum(Cast(e, DoubleType))
 
-      case s @ SumDistinct(e @ DecimalType()) => s // Decimal is already the 
biggest.
-      case SumDistinct(e @ IntegralType()) if e.dataType != LongType =>
-        SumDistinct(Cast(e, LongType))
-      case SumDistinct(e @ FractionalType()) if e.dataType != DoubleType =>
-        SumDistinct(Cast(e, DoubleType))
-
       case s @ Average(e @ DecimalType()) => s // Decimal is already the 
biggest.
       case Average(e @ IntegralType()) if e.dataType != LongType =>
         Average(Cast(e, LongType))

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index eae17c8..6485bdf 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -141,6 +141,10 @@ case class UnresolvedFunction(
   override def nullable: Boolean = throw new UnresolvedException(this, 
"nullable")
   override lazy val resolved = false
 
+  override def prettyString: String = {
+    s"${name}(${children.map(_.prettyString).mkString(",")})"
+  }
+
   override def toString: String = s"'$name(${children.mkString(",")})"
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index d8df664..af594c2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -23,6 +23,7 @@ import scala.language.implicitConversions
 
 import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, 
UnresolvedExtractValue, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.types._
@@ -144,17 +145,18 @@ package object dsl {
       }
     }
 
-    def sum(e: Expression): Expression = Sum(e)
-    def sumDistinct(e: Expression): Expression = SumDistinct(e)
-    def count(e: Expression): Expression = Count(e)
-    def countDistinct(e: Expression*): Expression = CountDistinct(e)
+    def sum(e: Expression): Expression = Sum(e).toAggregateExpression()
+    def sumDistinct(e: Expression): Expression = 
Sum(e).toAggregateExpression(isDistinct = true)
+    def count(e: Expression): Expression = Count(e).toAggregateExpression()
+    def countDistinct(e: Expression*): Expression =
+      Count(e).toAggregateExpression(isDistinct = true)
     def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression =
-      ApproxCountDistinct(e, rsd)
-    def avg(e: Expression): Expression = Average(e)
-    def first(e: Expression): Expression = First(e)
-    def last(e: Expression): Expression = Last(e)
-    def min(e: Expression): Expression = Min(e)
-    def max(e: Expression): Expression = Max(e)
+      HyperLogLogPlusPlus(e, rsd).toAggregateExpression()
+    def avg(e: Expression): Expression = Average(e).toAggregateExpression()
+    def first(e: Expression): Expression = new First(e).toAggregateExpression()
+    def last(e: Expression): Expression = new Last(e).toAggregateExpression()
+    def min(e: Expression): Expression = Min(e).toAggregateExpression()
+    def max(e: Expression): Expression = Max(e).toAggregateExpression()
     def upper(e: Expression): Expression = Upper(e)
     def lower(e: Expression): Expression = Lower(e)
     def sqrt(e: Expression): Expression = Sqrt(e)

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index c8c20ad..7f9e503 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 case class Average(child: Expression) extends DeclarativeAggregate {
@@ -32,36 +34,33 @@ case class Average(child: Expression) extends 
DeclarativeAggregate {
   // Return data type.
   override def dataType: DataType = resultType
 
-  // Expected input data type.
-  // TODO: Right now, we replace old aggregate functions (based on 
AggregateExpression1) to the
-  // new version at planning time (after analysis phase). For now, NullType is 
added at here
-  // to make it resolved when we have cases like `select avg(null)`.
-  // We can use our analyzer to cast NullType to the default data type of the 
NumericType once
-  // we remove the old aggregate functions. Then, we will not need NullType at 
here.
-  override def inputTypes: Seq[AbstractDataType] = 
Seq(TypeCollection(NumericType, NullType))
+  override def inputTypes: Seq[AbstractDataType] = 
Seq(TypeCollection(NumericType))
 
-  private val resultType = child.dataType match {
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForNumericExpr(child.dataType, "function average")
+
+  private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(p, s) =>
       DecimalType.bounded(p + 4, s + 4)
     case _ => DoubleType
   }
 
-  private val sumDataType = child.dataType match {
+  private lazy val sumDataType = child.dataType match {
     case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
     case _ => DoubleType
   }
 
-  private val sum = AttributeReference("sum", sumDataType)()
-  private val count = AttributeReference("count", LongType)()
+  private lazy val sum = AttributeReference("sum", sumDataType)()
+  private lazy val count = AttributeReference("count", LongType)()
 
-  override val aggBufferAttributes = sum :: count :: Nil
+  override lazy val aggBufferAttributes = sum :: count :: Nil
 
-  override val initialValues = Seq(
+  override lazy val initialValues = Seq(
     /* sum = */ Cast(Literal(0), sumDataType),
     /* count = */ Literal(0L)
   )
 
-  override val updateExpressions = Seq(
+  override lazy val updateExpressions = Seq(
     /* sum = */
     Add(
       sum,
@@ -69,13 +68,13 @@ case class Average(child: Expression) extends 
DeclarativeAggregate {
     /* count = */ If(IsNull(child), count, count + 1L)
   )
 
-  override val mergeExpressions = Seq(
+  override lazy val mergeExpressions = Seq(
     /* sum = */ sum.left + sum.right,
     /* count = */ count.left + count.right
   )
 
   // If all input are nulls, count will be 0 and we will get null after the 
division.
-  override val evaluateExpression = child.dataType match {
+  override lazy val evaluateExpression = child.dataType match {
     case DecimalType.Fixed(p, s) =>
       // increase the precision and scale to prevent precision loss
       val dt = DecimalType.bounded(p + 14, s + 4)

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index ef08b02..984ce7f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 /**
@@ -55,13 +57,10 @@ abstract class CentralMomentAgg(child: Expression) extends 
ImperativeAggregate w
 
   override def dataType: DataType = DoubleType
 
-  // Expected input data type.
-  // TODO: Right now, we replace old aggregate functions (based on 
AggregateExpression1) to the
-  // new version at planning time (after analysis phase). For now, NullType is 
added at here
-  // to make it resolved when we have cases like `select avg(null)`.
-  // We can use our analyzer to cast NullType to the default data type of the 
NumericType once
-  // we remove the old aggregate functions. Then, we will not need NullType at 
here.
-  override def inputTypes: Seq[AbstractDataType] = 
Seq(TypeCollection(NumericType, NullType))
+  override def inputTypes: Seq[AbstractDataType] = 
Seq(TypeCollection(NumericType))
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName")
 
   override def aggBufferSchema: StructType = 
StructType.fromAttributes(aggBufferAttributes)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index 8323383..00d7436 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 /**
@@ -35,6 +37,9 @@ case class Corr(
     inputAggBufferOffset: Int = 0)
   extends ImperativeAggregate {
 
+  def this(left: Expression, right: Expression) =
+    this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
   override def children: Seq[Expression] = Seq(left, right)
 
   override def nullable: Boolean = false
@@ -43,6 +48,16 @@ case class Corr(
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (left.dataType.isInstanceOf[DoubleType] && 
right.dataType.isInstanceOf[DoubleType]) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"corr requires that both arguments are double type, " +
+          s"not (${left.dataType}, ${right.dataType}).")
+    }
+  }
+
   override def aggBufferSchema: StructType = 
StructType.fromAttributes(aggBufferAttributes)
 
   override def inputAggBufferAttributes: Seq[AttributeReference] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index ec0c8b4..09a1da9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -32,23 +32,39 @@ case class Count(child: Expression) extends 
DeclarativeAggregate {
   // Expected input data type.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
 
-  private val count = AttributeReference("count", LongType)()
+  private lazy val count = AttributeReference("count", LongType)()
 
-  override val aggBufferAttributes = count :: Nil
+  override lazy val aggBufferAttributes = count :: Nil
 
-  override val initialValues = Seq(
+  override lazy val initialValues = Seq(
     /* count = */ Literal(0L)
   )
 
-  override val updateExpressions = Seq(
+  override lazy val updateExpressions = Seq(
     /* count = */ If(IsNull(child), count, count + 1L)
   )
 
-  override val mergeExpressions = Seq(
+  override lazy val mergeExpressions = Seq(
     /* count = */ count.left + count.right
   )
 
-  override val evaluateExpression = Cast(count, LongType)
+  override lazy val evaluateExpression = Cast(count, LongType)
 
   override def defaultResult: Option[Literal] = Option(Literal(0L))
 }
+
+object Count {
+  def apply(children: Seq[Expression]): Count = {
+    // This is used to deal with COUNT DISTINCT. When we have multiple
+    // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT 
(i.e. a Row).
+    // Also, the semantics of COUNT(DISTINCT col1, col2, ...) are that if there 
is any
+    // null in the arguments, we will not count that row. So, we use 
DropAnyNull here
+    // to return a null when any field of the created STRUCT is null.
+    val child = if (children.size > 1) {
+      DropAnyNull(CreateStruct(children))
+    } else {
+      children.head
+    }
+    Count(child)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 9028143..35f5742 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -51,18 +51,18 @@ case class First(child: Expression, ignoreNullsExpr: 
Expression) extends Declara
   // Expected input data type.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
 
-  private val first = AttributeReference("first", child.dataType)()
+  private lazy val first = AttributeReference("first", child.dataType)()
 
-  private val valueSet = AttributeReference("valueSet", BooleanType)()
+  private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
 
-  override val aggBufferAttributes: Seq[AttributeReference] = first :: 
valueSet :: Nil
+  override lazy val aggBufferAttributes: Seq[AttributeReference] = first :: 
valueSet :: Nil
 
-  override val initialValues: Seq[Literal] = Seq(
+  override lazy val initialValues: Seq[Literal] = Seq(
     /* first = */ Literal.create(null, child.dataType),
     /* valueSet = */ Literal.create(false, BooleanType)
   )
 
-  override val updateExpressions: Seq[Expression] = {
+  override lazy val updateExpressions: Seq[Expression] = {
     if (ignoreNulls) {
       Seq(
         /* first = */ If(Or(valueSet, IsNull(child)), first, child),
@@ -76,7 +76,7 @@ case class First(child: Expression, ignoreNullsExpr: 
Expression) extends Declara
     }
   }
 
-  override val mergeExpressions: Seq[Expression] = {
+  override lazy val mergeExpressions: Seq[Expression] = {
     // For first, we can just check if valueSet.left is set to true. If it is 
set
     // to true, we use first.right. If not, we use first.right (even if 
valueSet.right is
     // false, we are safe to do so because first.right will be null in this 
case).
@@ -86,7 +86,7 @@ case class First(child: Expression, ignoreNullsExpr: 
Expression) extends Declara
     )
   }
 
-  override val evaluateExpression: AttributeReference = first
+  override lazy val evaluateExpression: AttributeReference = first
 
   override def toString: String = s"first($child)${if (ignoreNulls) " ignore 
nulls"}"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
index 8d341ee..8a95c54 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -22,6 +22,7 @@ import java.util
 
 import com.clearspring.analytics.hash.MurmurHash
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
@@ -55,6 +56,22 @@ case class HyperLogLogPlusPlus(
   extends ImperativeAggregate {
   import HyperLogLogPlusPlus._
 
+  def this(child: Expression) = {
+    this(child = child, relativeSD = 0.05, mutableAggBufferOffset = 0, 
inputAggBufferOffset = 0)
+  }
+
+  def this(child: Expression, relativeSD: Expression) = {
+    this(
+      child = child,
+      relativeSD = relativeSD match {
+        case Literal(d: Double, DoubleType) => d
+        case _ =>
+          throw new AnalysisException("The second argument should be a double literal.")
+      },
+      mutableAggBufferOffset = 0,
+      inputAggBufferOffset = 0)
+  }
+
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
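
The two auxiliary constructors added above allow the function to be built from just a child expression, or from a child plus a relativeSD literal; anything other than a double literal for the second argument hits the new AnalysisException. A hedged sketch of how they can be invoked (the column and literal setup is illustrative, not taken from this diff):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus
import org.apache.spark.sql.types.LongType

// Sketch only: build the aggregate function directly with the new constructors.
val col = AttributeReference("x", LongType)()

val defaultSD = new HyperLogLogPlusPlus(col)                 // relativeSD defaults to 0.05
val customSD  = new HyperLogLogPlusPlus(col, Literal(0.01))  // matches Literal(d: Double, DoubleType)
// new HyperLogLogPlusPlus(col, col)                         // non-literal second arg: AnalysisException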
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
index 6da39e7..bae78d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala
@@ -24,6 +24,8 @@ case class Kurtosis(child: Expression,
     inputAggBufferOffset: Int = 0)
   extends CentralMomentAgg(child) {
 
+  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index 8636bfe..be7e12d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -51,15 +51,15 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat
   // Expected input data type.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
 
-  private val last = AttributeReference("last", child.dataType)()
+  private lazy val last = AttributeReference("last", child.dataType)()
 
-  override val aggBufferAttributes: Seq[AttributeReference] = last :: Nil
+  override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil
 
-  override val initialValues: Seq[Literal] = Seq(
+  override lazy val initialValues: Seq[Literal] = Seq(
     /* last = */ Literal.create(null, child.dataType)
   )
 
-  override val updateExpressions: Seq[Expression] = {
+  override lazy val updateExpressions: Seq[Expression] = {
     if (ignoreNulls) {
       Seq(
         /* last = */ If(IsNull(child), last, child)
@@ -71,7 +71,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat
     }
   }
 
-  override val mergeExpressions: Seq[Expression] = {
+  override lazy val mergeExpressions: Seq[Expression] = {
     if (ignoreNulls) {
       Seq(
         /* last = */ If(IsNull(last.right), last.left, last.right)
@@ -83,7 +83,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat
     }
   }
 
-  override val evaluateExpression: AttributeReference = last
+  override lazy val evaluateExpression: AttributeReference = last
 
   override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
index b9d75ad..61cae44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 case class Max(child: Expression) extends DeclarativeAggregate {
@@ -32,24 +34,27 @@ case class Max(child: Expression) extends DeclarativeAggregate {
   // Expected input data type.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
 
-  private val max = AttributeReference("max", child.dataType)()
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForOrderingExpr(child.dataType, "function max")
 
-  override val aggBufferAttributes: Seq[AttributeReference] = max :: Nil
+  private lazy val max = AttributeReference("max", child.dataType)()
 
-  override val initialValues: Seq[Literal] = Seq(
+  override lazy val aggBufferAttributes: Seq[AttributeReference] = max :: Nil
+
+  override lazy val initialValues: Seq[Literal] = Seq(
     /* max = */ Literal.create(null, child.dataType)
   )
 
-  override val updateExpressions: Seq[Expression] = Seq(
+  override lazy val updateExpressions: Seq[Expression] = Seq(
+    /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child))))
   )
 
-  override val mergeExpressions: Seq[Expression] = {
+  override lazy val mergeExpressions: Seq[Expression] = {
     val greatest = Greatest(Seq(max.left, max.right))
     Seq(
       /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest))
     )
   }
 
-  override val evaluateExpression: AttributeReference = max
+  override lazy val evaluateExpression: AttributeReference = max
 }
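
The update and merge expressions above encode the null handling declaratively: the buffer stays null until the first non-null input arrives, and merging two partial buffers keeps whichever side is non-null, or the greater of the two. A plain-Scala sketch of that logic over Option values (not the Catalyst expressions themselves; Min is the mirror image with Least):

// Plain-Scala sketch of Max's declarative update/merge semantics; Option stands in
// for nullable Catalyst values, Int for any ordered type.
def update(buffer: Option[Int], input: Option[Int]): Option[Int] = (buffer, input) match {
  case (_, None)          => buffer                 // IsNull(child): keep buffer
  case (None, Some(v))    => Some(v)                // IsNull(max): first non-null value seen
  case (Some(b), Some(v)) => Some(math.max(b, v))   // Greatest(Seq(max, child))
}

def merge(left: Option[Int], right: Option[Int]): Option[Int] = (left, right) match {
  case (_, None)          => left                   // IsNull(max.right) => max.left
  case (None, _)          => right                  // IsNull(max.left)  => max.right
  case (Some(l), Some(r)) => Some(math.max(l, r))   // greatest
}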

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
index 5ed9cd3..242456d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 
@@ -33,24 +35,27 @@ case class Min(child: Expression) extends DeclarativeAggregate {
   // Expected input data type.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
 
-  private val min = AttributeReference("min", child.dataType)()
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForOrderingExpr(child.dataType, "function min")
 
-  override val aggBufferAttributes: Seq[AttributeReference] = min :: Nil
+  private lazy val min = AttributeReference("min", child.dataType)()
 
-  override val initialValues: Seq[Expression] = Seq(
+  override lazy val aggBufferAttributes: Seq[AttributeReference] = min :: Nil
+
+  override lazy val initialValues: Seq[Expression] = Seq(
     /* min = */ Literal.create(null, child.dataType)
   )
 
-  override val updateExpressions: Seq[Expression] = Seq(
+  override lazy val updateExpressions: Seq[Expression] = Seq(
+    /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child))))
   )
 
-  override val mergeExpressions: Seq[Expression] = {
+  override lazy val mergeExpressions: Seq[Expression] = {
     val least = Least(Seq(min.left, min.right))
     Seq(
       /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least))
     )
   }
 
-  override val evaluateExpression: AttributeReference = min
+  override lazy val evaluateExpression: AttributeReference = min
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
index 0def7dd..c593074 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala
@@ -24,6 +24,8 @@ case class Skewness(child: Expression,
     inputAggBufferOffset: Int = 0)
   extends CentralMomentAgg(child) {
 
+  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
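
Skewness gains the same single-argument auxiliary constructor as Kurtosis above: both buffer offsets default to 0 and can be reassigned later through withNewMutableAggBufferOffset. A hedged usage sketch; how the planner actually drives these calls is outside this diff:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.aggregate.{Kurtosis, Skewness}
import org.apache.spark.sql.types.DoubleType

// Sketch: the new one-argument constructors leave both offsets at 0.
val x = AttributeReference("x", DoubleType)()
val kurt = new Kurtosis(x)
val skew = new Skewness(x)

// Buffers can be relocated later without rebuilding the function.
val kurtShifted = kurt.withNewMutableAggBufferOffset(3)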
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
index 3f47ffe..5b9eb7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 
@@ -48,29 +50,26 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
 
   override def dataType: DataType = resultType
 
-  // Expected input data type.
-  // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
-  // new version at planning time (after analysis phase). For now, NullType is added at here
-  // to make it resolved when we have cases like `select stddev(null)`.
-  // We can use our analyzer to cast NullType to the default data type of the NumericType once
-  // we remove the old aggregate functions. Then, we will not need NullType at here.
-  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
 
-  private val resultType = DoubleType
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
 
-  private val count = AttributeReference("count", resultType)()
-  private val avg = AttributeReference("avg", resultType)()
-  private val mk = AttributeReference("mk", resultType)()
+  private lazy val resultType = DoubleType
 
-  override val aggBufferAttributes = count :: avg :: mk :: Nil
+  private lazy val count = AttributeReference("count", resultType)()
+  private lazy val avg = AttributeReference("avg", resultType)()
+  private lazy val mk = AttributeReference("mk", resultType)()
 
-  override val initialValues: Seq[Expression] = Seq(
+  override lazy val aggBufferAttributes = count :: avg :: mk :: Nil
+
+  override lazy val initialValues: Seq[Expression] = Seq(
     /* count = */ Cast(Literal(0), resultType),
     /* avg = */ Cast(Literal(0), resultType),
     /* mk = */ Cast(Literal(0), resultType)
   )
 
-  override val updateExpressions: Seq[Expression] = {
+  override lazy val updateExpressions: Seq[Expression] = {
     val value = Cast(child, resultType)
     val newCount = count + Cast(Literal(1), resultType)
 
@@ -89,7 +88,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
     )
   }
 
-  override val mergeExpressions: Seq[Expression] = {
+  override lazy val mergeExpressions: Seq[Expression] = {
 
     // count merge
     val newCount = count.left + count.right
@@ -114,7 +113,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
     )
   }
 
-  override val evaluateExpression: Expression = {
+  override lazy val evaluateExpression: Expression = {
     // when count == 0, return null
     // when count == 1, return 0
     // when count >1
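
The count/avg/mk buffer these expressions maintain is the usual streaming-moments layout (element count, running mean, and running sum of squared deviations), and merging two partial buffers follows the standard parallel formula. A plain-Scala sketch of that math, assuming the standard formulas rather than quoting the Catalyst expressions, and showing the sample form of the result described by the comments above:

// Sketch of the (count, avg, mk) buffer StddevAgg maintains, with plain doubles.
final case class StddevBuffer(count: Double, avg: Double, mk: Double)

def update(b: StddevBuffer, value: Double): StddevBuffer = {
  val newCount = b.count + 1
  val delta = value - b.avg
  val newAvg = b.avg + delta / newCount
  StddevBuffer(newCount, newAvg, b.mk + delta * (value - newAvg))
}

// Standard parallel merge of two partial buffers (Chan et al.).
def merge(l: StddevBuffer, r: StddevBuffer): StddevBuffer = {
  val newCount = l.count + r.count
  if (newCount == 0) l
  else {
    val delta = r.avg - l.avg
    val newAvg = l.avg + delta * r.count / newCount
    val newMk = l.mk + r.mk + delta * delta * l.count * r.count / newCount
    StddevBuffer(newCount, newAvg, newMk)
  }
}

// Matching the evaluate comments: null for count == 0, 0 for count == 1, otherwise
// sqrt(mk / (count - 1)) for the sample form (a population form divides by count).
def evaluate(b: StddevBuffer): Option[Double] =
  if (b.count == 0) None
  else if (b.count == 1) Some(0.0)
  else Some(math.sqrt(b.mk / (b.count - 1)))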

http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 7f8adbc..c005ec9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 case class Sum(child: Expression) extends DeclarativeAggregate {
@@ -29,16 +31,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
   // Return data type.
   override def dataType: DataType = resultType
 
-  // Expected input data type.
-  // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
-  // new version at planning time (after analysis phase). For now, NullType is added at here
-  // to make it resolved when we have cases like `select sum(null)`.
-  // We can use our analyzer to cast NullType to the default data type of the NumericType once
-  // we remove the old aggregate functions. Then, we will not need NullType at here.
   override def inputTypes: Seq[AbstractDataType] =
     Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
 
-  private val resultType = child.dataType match {
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForNumericExpr(child.dataType, "function sum")
+
+  private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
     // TODO: Remove this line once we remove the NullType from inputTypes.
@@ -46,24 +45,24 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
     case _ => child.dataType
   }
 
-  private val sumDataType = resultType
+  private lazy val sumDataType = resultType
 
-  private val sum = AttributeReference("sum", sumDataType)()
+  private lazy val sum = AttributeReference("sum", sumDataType)()
 
-  private val zero = Cast(Literal(0), sumDataType)
+  private lazy val zero = Cast(Literal(0), sumDataType)
 
-  override val aggBufferAttributes = sum :: Nil
+  override lazy val aggBufferAttributes = sum :: Nil
 
-  override val initialValues: Seq[Expression] = Seq(
+  override lazy val initialValues: Seq[Expression] = Seq(
     /* sum = */ Literal.create(null, sumDataType)
   )
 
-  override val updateExpressions: Seq[Expression] = Seq(
+  override lazy val updateExpressions: Seq[Expression] = Seq(
     /* sum = */
     Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
   )
 
-  override val mergeExpressions: Seq[Expression] = {
+  override lazy val mergeExpressions: Seq[Expression] = {
     val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType))
     Seq(
       /* sum = */
@@ -71,5 +70,5 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
     )
   }
 
-  override val evaluateExpression: Expression = Cast(sum, resultType)
+  override lazy val evaluateExpression: Expression = Cast(sum, resultType)
 }
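
Sum's update expression, Coalesce(Add(Coalesce(sum, zero), cast(child)), sum), keeps the buffer null until the first non-null input and ignores null inputs thereafter; the merge expression applies the same pattern to two partial sums. A plain-Scala sketch of that behaviour with Option in place of nullable values:

// Plain-Scala sketch of Sum's null semantics; Long stands in for the widened sum type.
def update(sum: Option[Long], input: Option[Long]): Option[Long] = input match {
  case None    => sum                          // null input: outer Coalesce falls back to sum
  case Some(v) => Some(sum.getOrElse(0L) + v)  // Add(Coalesce(sum, zero), child)
}

def merge(left: Option[Long], right: Option[Long]): Option[Long] = right match {
  case None    => left                         // nothing seen on the right partition
  case Some(r) => Some(left.getOrElse(0L) + r) // Add(Coalesce(sum.left, zero), sum.right)
}

// e.g. a partition of only nulls leaves the buffer at None, so sum(NULL) stays NULL.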

