Repository: spark
Updated Branches:
  refs/heads/branch-1.6 825e971d0 -> 7c4ade0d7


http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index ab88c1e..6f8ed41 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -38,6 +38,7 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.{AnalysisException, catalyst}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.{logical, _}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
@@ -1508,9 +1509,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
       UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name)))
 
     /* Aggregate Functions */
-    case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => 
Count(Literal(1))
-    case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => 
CountDistinct(args.map(nodeToExpr))
-    case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => 
SumDistinct(nodeToExpr(arg))
+    case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) =>
+      Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true)
+    case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) =>
+      Count(Literal(1)).toAggregateExpression()
 
     /* Casts */
     case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>

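For context: the hunk above replaces the dedicated CountDistinct expression with a plain Count wrapped in an AggregateExpression whose isDistinct flag is set, and wraps the COUNT(*) case the same way. A minimal sketch of the new mapping, using only the catalyst API visible in this patch (the exprs value is hypothetical, standing in for the parsed argument expressions):

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.expressions.aggregate.Count

    // COUNT(*) now parses to Count(1) wrapped in a non-distinct AggregateExpression.
    val countStar = Count(Literal(1)).toAggregateExpression()

    // COUNT(DISTINCT x, y) keeps Count and sets the isDistinct flag on the
    // wrapper instead of building the removed CountDistinct expression.
    val exprs = Seq(Literal(1), Literal(2))  // hypothetical stand-in arguments
    val countDistinct = Count(exprs).toAggregateExpression(isDistinct = true)
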
http://git-wip-us.apache.org/repos/asf/spark/blob/7c4ade0d/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index ea36c13..6bf2c53 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -69,11 +69,7 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
 abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   import testImplicits._
 
-  var originalUseAggregate2: Boolean = _
-
   override def beforeAll(): Unit = {
-    originalUseAggregate2 = sqlContext.conf.useSqlAggregate2
-    sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true")
     val data1 = Seq[(Integer, Integer)](
       (1, 10),
       (null, -60),
@@ -120,7 +116,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     sqlContext.sql("DROP TABLE IF EXISTS agg1")
     sqlContext.sql("DROP TABLE IF EXISTS agg2")
     sqlContext.dropTempTable("emptyTable")
-    sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString)
   }
 
   test("empty table") {
@@ -447,73 +442,80 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
   }
 
   test("single distinct column set") {
-    // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword.
-    checkAnswer(
-      sqlContext.sql(
-        """
-          |SELECT
-          |  min(distinct value1),
-          |  sum(distinct value1),
-          |  avg(value1),
-          |  avg(value2),
-          |  max(distinct value1)
-          |FROM agg2
-        """.stripMargin),
-      Row(-60, 70.0, 101.0/9.0, 5.6, 100))
-
-    checkAnswer(
-      sqlContext.sql(
-        """
-          |SELECT
-          |  mydoubleavg(distinct value1),
-          |  avg(value1),
-          |  avg(value2),
-          |  key,
-          |  mydoubleavg(value1 - 1),
-          |  mydoubleavg(distinct value1) * 0.1,
-          |  avg(value1 + value2)
-          |FROM agg2
-          |GROUP BY key
-        """.stripMargin),
-      Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
-        Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
-        Row(null, null, 3.0, 3, null, null, null) ::
-        Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
-
-    checkAnswer(
-      sqlContext.sql(
-        """
-          |SELECT
-          |  key,
-          |  mydoubleavg(distinct value1),
-          |  mydoublesum(value2),
-          |  mydoublesum(distinct value1),
-          |  mydoubleavg(distinct value1),
-          |  mydoubleavg(value1)
-          |FROM agg2
-          |GROUP BY key
-        """.stripMargin),
-      Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
-        Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
-        Row(3, null, 3.0, null, null, null) ::
-        Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
-
-    checkAnswer(
-      sqlContext.sql(
-        """
-          |SELECT
-          |  count(value1),
-          |  count(*),
-          |  count(1),
-          |  count(DISTINCT value1),
-          |  key
-          |FROM agg2
-          |GROUP BY key
-        """.stripMargin),
-      Row(3, 3, 3, 2, 1) ::
-        Row(3, 4, 4, 2, 2) ::
-        Row(0, 2, 2, 0, 3) ::
-        Row(3, 4, 4, 3, null) :: Nil)
+    Seq(true, false).foreach { specializeSingleDistinctAgg =>
+      val conf =
+        (SQLConf.SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING.key,
+          specializeSingleDistinctAgg.toString)
+      withSQLConf(conf) {
+        // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword.
+        checkAnswer(
+          sqlContext.sql(
+            """
+              |SELECT
+              |  min(distinct value1),
+              |  sum(distinct value1),
+              |  avg(value1),
+              |  avg(value2),
+              |  max(distinct value1)
+              |FROM agg2
+            """.stripMargin),
+          Row(-60, 70.0, 101.0/9.0, 5.6, 100))
+
+        checkAnswer(
+          sqlContext.sql(
+            """
+              |SELECT
+              |  mydoubleavg(distinct value1),
+              |  avg(value1),
+              |  avg(value2),
+              |  key,
+              |  mydoubleavg(value1 - 1),
+              |  mydoubleavg(distinct value1) * 0.1,
+              |  avg(value1 + value2)
+              |FROM agg2
+              |GROUP BY key
+            """.stripMargin),
+          Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) ::
+            Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) ::
+            Row(null, null, 3.0, 3, null, null, null) ::
+            Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil)
+
+        checkAnswer(
+          sqlContext.sql(
+            """
+              |SELECT
+              |  key,
+              |  mydoubleavg(distinct value1),
+              |  mydoublesum(value2),
+              |  mydoublesum(distinct value1),
+              |  mydoubleavg(distinct value1),
+              |  mydoubleavg(value1)
+              |FROM agg2
+              |GROUP BY key
+            """.stripMargin),
+          Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) ::
+            Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
+            Row(3, null, 3.0, null, null, null) ::
+            Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+
+        checkAnswer(
+          sqlContext.sql(
+            """
+              |SELECT
+              |  count(value1),
+              |  count(*),
+              |  count(1),
+              |  count(DISTINCT value1),
+              |  key
+              |FROM agg2
+              |GROUP BY key
+            """.stripMargin),
+          Row(3, 3, 3, 2, 1) ::
+            Row(3, 4, 4, 2, 2) ::
+            Row(0, 2, 2, 0, 3) ::
+            Row(3, 4, 4, 3, null) :: Nil)
+      }
+    }
   }
 
   test("single distinct multiple columns set") {
@@ -699,48 +701,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
 
     val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0)
     assert(math.abs(corr7 - 0.6633880657639323) < 1e-12)
-
-    withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
-      val errorMessage = intercept[SparkException] {
-        val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
-        val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
-      }.getMessage
-      assert(errorMessage.contains("java.lang.UnsupportedOperationException: " 
+
-        "Corr only supports the new AggregateExpression2"))
-    }
-  }
-
-  test("test Last implemented based on AggregateExpression1") {
-    // TODO: Remove this test once we remove AggregateExpression1.
-    import org.apache.spark.sql.functions._
-    val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1)
-    withSQLConf(
-      SQLConf.SHUFFLE_PARTITIONS.key -> "1",
-      SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
-
-      checkAnswer(
-        df.groupBy("i").agg(last("j")),
-        df
-      )
-    }
-  }
-
-  test("error handling") {
-    withSQLConf("spark.sql.useAggregate2" -> "false") {
-      val errorMessage = intercept[AnalysisException] {
-        sqlContext.sql(
-          """
-            |SELECT
-            |  key,
-            |  sum(value + 1.5 * key),
-            |  mydoublesum(value),
-            |  mydoubleavg(value)
-            |FROM agg1
-            |GROUP BY key
-          """.stripMargin).collect()
-      }.getMessage
-      assert(errorMessage.contains("implemented based on the new Aggregate 
Function interface"))
-    }
   }
 
   test("no aggregation function (SPARK-11486)") {

