Repository: spark
Updated Branches:
  refs/heads/branch-1.6 aede729a9 -> 696d4a52d


[SPARK-9241][SQL] Supporting multiple DISTINCT columns - follow-up

This PR is a follow-up to PR https://github.com/apache/spark/pull/9406. It
adds more documentation to the rewriting rule, removes a redundant If
expression in the non-distinct aggregation path, and adds a multiple distinct
test to the AggregationQuerySuite.
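The refactoring of patchAggregateFunctionChildren can be illustrated with a minimal sketch. It uses a hypothetical toy expression tree: Expr, Attr, Lit, If, EqualTo and Agg below are illustrative stand-ins, not Catalyst classes, and the _gen_* attribute names are made up. The helper takes the child-rewriting function as a second, curried parameter list. The distinct path wraps each child in a group-id guard. The non-distinct path passes its attribute map directly, because a Scala Map[K, V] is also a K => V. In this model the guard is unnecessary on the non-distinct path, since Expand has already nulled out the value columns for every other group.

```scala
// Hypothetical toy expression tree; these are NOT Catalyst's Expression classes.
sealed trait Expr
final case class Attr(name: String) extends Expr
final case class Lit(value: Any) extends Expr
final case class If(cond: Expr, whenTrue: Expr, whenFalse: Expr) extends Expr
final case class EqualTo(left: Expr, right: Expr) extends Expr
final case class Agg(fn: String, children: Seq[Expr]) extends Expr

object PatchSketch {
  val gid: Expr = Attr("gid")

  // Only rows of group `id` see `e`; every other group sees NULL.
  def evalWithinGroup(id: Lit, e: Expr): Expr = If(EqualTo(gid, id), e, Lit(null))

  // Curried helper: the second parameter list is any child-rewriting function.
  def patchChildren(af: Agg)(attrs: Expr => Expr): Agg =
    af.copy(children = af.children.map(attrs))

  // Hypothetical mapping from original children to the attributes emitted by Expand.
  val childMap: Map[Expr, Expr] = Map(
    Attr("cat1") -> Attr("_gen_cat1"),
    Attr("value") -> Attr("_gen_value"))

  // Distinct path: rename the child AND guard it with the group id.
  val distinctPath: Agg =
    patchChildren(Agg("count", Seq(Attr("cat1"))))(x => evalWithinGroup(Lit(1), childMap(x)))

  // Non-distinct path: a Map[K, V] is also a K => V, so the map is passed
  // directly and no group-id guard is added.
  val regularPath: Agg = patchChildren(Agg("sum", Seq(Attr("value"))))(childMap)
}
```

Passing a function instead of a (Literal, Map) pair keeps the helper agnostic about whether a child needs a group guard; each call site decides.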

cc yhuai marmbrus

Author: Herman van Hovell <hvanhov...@questtec.nl>

Closes #9541 from hvanhovell/SPARK-9241-followup.

(cherry picked from commit ef362846eb448769bcf774fc9090a5013d459464)
Signed-off-by: Yin Huai <yh...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/696d4a52
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/696d4a52
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/696d4a52

Branch: refs/heads/branch-1.6
Commit: 696d4a52d8ee5c1c736ce470ac87255fe58e78c3
Parents: aede729
Author: Herman van Hovell <hvanhov...@questtec.nl>
Authored: Sat Nov 7 13:37:37 2015 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Sat Nov 7 13:38:08 2015 -0800

----------------------------------------------------------------------
 .../catalyst/expressions/aggregate/Utils.scala  | 114 +++++++++++++++----
 .../hive/execution/AggregationQuerySuite.scala  |  17 +++
 2 files changed, 108 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/696d4a52/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index 39010c3..ac23f72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -222,10 +222,76 @@ object Utils {
  * aggregation in which the regular aggregation expressions and every distinct clause is aggregated
  * in a separate group. The results are then combined in a second aggregate.
  *
- * TODO Expression cannocalization
- * TODO Eliminate foldable expressions from distinct clauses.
- * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate
- *      operator. Perhaps this is a good thing? It is much simpler to plan later on...
+ * For example (in Scala):
+ * {{{
+ *   val data = Seq(
+ *     ("a", "ca1", "cb1", 10),
+ *     ("a", "ca1", "cb2", 5),
+ *     ("b", "ca1", "cb1", 13))
+ *     .toDF("key", "cat1", "cat2", "value")
+ *   data.registerTempTable("data")
+ *
+ *   val agg = data.groupBy($"key")
+ *     .agg(
+ *       countDistinct($"cat1").as("cat1_cnt"),
+ *       countDistinct($"cat2").as("cat2_cnt"),
+ *       sum($"value").as("total"))
+ * }}}
+ *
+ * This translates to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [COUNT(DISTINCT 'cat1),
+ *                 COUNT(DISTINCT 'cat2),
+ *                 sum('value)]
+ *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ *   LocalTableScan [...]
+ * }}}
+ *
+ * This rule rewrites this logical plan to the following (pseudo) logical plan:
+ * {{{
+ * Aggregate(
+ *    key = ['key]
+ *    functions = [count(if (('gid = 1)) 'cat1 else null),
+ *                 count(if (('gid = 2)) 'cat2 else null),
+ *                 first(if (('gid = 0)) 'total else null) ignore nulls]
+ *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
+ *   Aggregate(
+ *      key = ['key, 'cat1, 'cat2, 'gid]
+ *      functions = [sum('value)]
+ *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
+ *     Expand(
+ *        projections = [('key, null, null, 0, cast('value as bigint)),
+ *                       ('key, 'cat1, null, 1, null),
+ *                       ('key, null, 'cat2, 2, null)]
+ *        output = ['key, 'cat1, 'cat2, 'gid, 'value])
+ *       LocalTableScan [...]
+ * }}}
+ *
+ * The rule does the following things here:
+ * 1. Expand the data. There are three aggregation groups in this query:
+ *    i. the non-distinct group;
+ *    ii. the distinct 'cat1 group;
+ *    iii. the distinct 'cat2 group.
+ *    An expand operator is inserted to expand the child data for each group. The expand will null
+ *    out all unused columns for the given group; this must be done in order to ensure correctness
+ *    later on. Groups can be identified by a group id (gid) column added by the expand operator.
+ * 2. De-duplicate the distinct paths and aggregate the non-distinct path. The group by clause of
+ *    this aggregate consists of the original group by clause, all the requested distinct columns
+ *    and the group id. Both de-duplication of distinct columns and the aggregation of the
+ *    non-distinct group take advantage of the fact that we group by the group id (gid) and that we
+ *    have nulled out all non-relevant columns for the given group.
+ * 3. Aggregate the distinct groups and combine this with the results of the non-distinct
+ *    aggregation. In this step we use the group id to filter the inputs for the aggregate
+ *    functions. The results of the non-distinct group are 'aggregated' by using the first operator;
+ *    it might be more elegant to use the native UDAF merge mechanism for this in the future.
+ *
+ * This rule duplicates the input data two or more times (# distinct groups + an optional
+ * non-distinct group). This will put quite a bit of memory pressure on the aggregate and
+ * exchange operators used. Keeping the number of distinct groups as low as possible should be a
+ * priority; we could improve this in the current rule by applying more advanced expression
+ * canonicalization techniques.
  */
 object MultipleDistinctRewriter extends Rule[LogicalPlan] {
 
@@ -261,11 +327,10 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
       // Functions used to modify aggregate functions and their inputs.
       def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
       def patchAggregateFunctionChildren(
-          af: AggregateFunction2,
-          id: Literal,
-          attrs: Map[Expression, Expression]): AggregateFunction2 = {
-        af.withNewChildren(af.children.map { case afc =>
-          evalWithinGroup(id, attrs(afc))
+          af: AggregateFunction2)(
+          attrs: Expression => Expression): AggregateFunction2 = {
+        af.withNewChildren(af.children.map {
+          case afc => attrs(afc)
         }).asInstanceOf[AggregateFunction2]
       }
 
@@ -288,7 +353,9 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
           // Final aggregate
           val operators = expressions.map { e =>
             val af = e.aggregateFunction
-            val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap)
+            val naf = patchAggregateFunctionChildren(af) { x =>
+              evalWithinGroup(id, distinctAggChildAttrMap(x))
+            }
             (e, e.copy(aggregateFunction = naf, isDistinct = false))
           }
 
@@ -304,26 +371,27 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
       val regularGroupId = Literal(0)
       val regularAggOperatorMap = regularAggExprs.map { e =>
         // Perform the actual aggregation in the initial aggregate.
-        val af = patchAggregateFunctionChildren(
-          e.aggregateFunction,
-          regularGroupId,
-          regularAggChildAttrMap)
-        val a = Alias(e.copy(aggregateFunction = af), e.toString)()
-
-        // Get the result of the first aggregate in the last aggregate.
-        val b = AggregateExpression2(
-          aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)),
+        val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap)
+        val operator = Alias(e.copy(aggregateFunction = af), e.toString)()
+
+        // Select the result of the first aggregate in the last aggregate.
+        val result = AggregateExpression2(
+          aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
           mode = Complete,
           isDistinct = false)
 
         // Some aggregate functions (COUNT) have the special property that they can return a
         // non-null result without any input. We need to make sure we return a result in this case.
-        val c = af.defaultResult match {
-          case Some(lit) => Coalesce(Seq(b, lit))
-          case None => b
+        val resultWithDefault = af.defaultResult match {
+          case Some(lit) => Coalesce(Seq(result, lit))
+          case None => result
         }
 
-        (e, a, c)
+        // Return a Tuple3 containing:
+        // i. The original aggregate expression (used for lookups).
+        // ii. The actual aggregation operator (used in the first aggregate).
+        // iii. The operator that selects and returns the result (used in the second aggregate).
+        (e, operator, resultWithDefault)
       }
 
       // Construct the regular aggregate input projection only if we need one.
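
The three steps described in the MultipleDistinctRewriter doc comment can be simulated on plain Scala collections using the doc comment's example data. This is a sketch under stated assumptions, not Spark's implementation: None stands in for SQL NULL, the Expanded case class and the other names are hypothetical, and the per-group COUNT and "first ignore nulls" are modeled with ordinary collection operations.

```scala
object MultiDistinctSketch {
  // Doc-comment example rows: (key, cat1, cat2, value).
  val data: Seq[(String, String, String, Int)] = Seq(
    ("a", "ca1", "cb1", 10),
    ("a", "ca1", "cb2", 5),
    ("b", "ca1", "cb1", 13))

  // One expanded row (hypothetical type); None plays the role of SQL NULL.
  final case class Expanded(
      key: String, cat1: Option[String], cat2: Option[String], gid: Int, value: Option[Long])

  // Step 1: Expand. Emit one copy per group and null out the unused columns.
  val expanded: Seq[Expanded] = data.flatMap { case (k, c1, c2, v) =>
    Seq(
      Expanded(k, None, None, 0, Some(v.toLong)), // non-distinct group (gid 0)
      Expanded(k, Some(c1), None, 1, None),       // distinct 'cat1 group (gid 1)
      Expanded(k, None, Some(c2), 2, None))       // distinct 'cat2 group (gid 2)
  }

  // Step 2: group by (key, cat1, cat2, gid). This de-duplicates the distinct
  // groups and computes sum('value) for the non-distinct group.
  val partial: Seq[(String, Option[String], Option[String], Int, Option[Long])] =
    expanded.groupBy(r => (r.key, r.cat1, r.cat2, r.gid)).toSeq.map {
      case ((k, c1, c2, g), rows) =>
        val values = rows.flatMap(_.value)
        (k, c1, c2, g, if (values.isEmpty) None else Some(values.sum))
    }

  // Step 3: group by key and combine the groups.
  val result: Map[String, (Long, Long, Option[Long])] =
    partial.groupBy(_._1).map { case (k, rows) =>
      // count(if ('gid = n) 'catN else null): only non-null inputs are counted.
      val cat1Cnt = rows.count { case (_, c1, _, g, _) => g == 1 && c1.isDefined }.toLong
      val cat2Cnt = rows.count { case (_, _, c2, g, _) => g == 2 && c2.isDefined }.toLong
      // first(if ('gid = 0) 'total else null) ignore nulls
      val total = rows.collectFirst { case (_, _, _, 0, Some(t)) => t }
      k -> ((cat1Cnt, cat2Cnt, total))
    }
}
```

For key "a" this yields cat1_cnt = 1, cat2_cnt = 2 and total = 15; for key "b" it yields 1, 1 and 13, which is what the original groupBy/agg query should return for this data.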

http://git-wip-us.apache.org/repos/asf/spark/blob/696d4a52/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index ea80060..7f6fe33 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -516,6 +516,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
         Row(3, 4, 4, 3, null) :: Nil)
   }
 
+  test("multiple distinct column sets") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  key,
+          |  count(distinct value1),
+          |  count(distinct value2)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(null, 3, 3) ::
+        Row(1, 2, 3) ::
+        Row(2, 2, 1) ::
+        Row(3, 0, 1) :: Nil)
+  }
+
   test("test count") {
     checkAnswer(
       sqlContext.sql(
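
The Coalesce(Seq(result, lit)) wrapper in the regular aggregation path of MultipleDistinctRewriter guards against the selected result being null for functions such as COUNT that have a default result. One case we assume triggers it is a global aggregate (no GROUP BY) over empty input: Expand emits no rows, the inner aggregate emits no groups, and first(...) in the outer aggregate returns null even though COUNT should return 0. The sketch below models this with Option; all names are hypothetical and the model is a simplification of the rewritten plan.

```scala
object DefaultResultSketch {
  // "first(x) ignore nulls": the first non-null value, or None (SQL NULL) if there is none.
  def firstIgnoreNulls[A](xs: Seq[Option[A]]): Option[A] = xs.flatten.headOption

  // Models Coalesce(Seq(result, lit)) for an aggregate whose default result is `default`.
  def withDefault[A](result: Option[A], default: Option[A]): Option[A] = result.orElse(default)

  // Hypothetical: COUNT over a global (no GROUP BY) aggregate after the rewrite.
  def rewrittenCount(input: Seq[Long]): Option[Long] = {
    // Inner aggregate: with no input rows, Expand emits nothing, so no gid group
    // exists and no partial COUNT is produced. Otherwise gid 0 yields one count.
    val inner: Seq[Option[Long]] =
      if (input.isEmpty) Seq.empty else Seq(Some(input.size.toLong))
    // Outer global aggregate: select the partial result and fall back to
    // COUNT's default result (0) when nothing was selected.
    withDefault(firstIgnoreNulls(inner), Some(0L))
  }
}
```

Without the fallback, firstIgnoreNulls alone would return None (NULL) for empty input.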

