Repository: spark
Updated Branches:
  refs/heads/master 1ab72b086 -> 6d0ead322


[SPARK-9241][SQL] Supporting multiple DISTINCT columns (2) - Rewriting Rule

The second PR for SPARK-9241, this adds support for multiple distinct columns 
to the new aggregation code path.

This PR solves the multiple DISTINCT column problem by rewriting these 
Aggregates into an Expand-Aggregate-Aggregate combination. See the [JIRA 
ticket](https://issues.apache.org/jira/browse/SPARK-9241) for some information 
on this. The advantages over the - competing - [first 
PR](https://github.com/apache/spark/pull/9280) are:
- This can use the faster TungstenAggregate code path.
- It is impossible to OOM due to an ```OpenHashSet``` allocating too much 
memory. However, this will multiply the number of input rows by the number of 
distinct clauses (plus one), and puts a lot more memory pressure on the 
aggregation code path itself.

The location of this Rule is a bit funny, and should probably change when the 
old aggregation path is changed.

cc yhuai - Could you also tell me where to add tests for this?

Author: Herman van Hovell <hvanhov...@questtec.nl>

Closes #9406 from hvanhovell/SPARK-9241-rewriter.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6d0ead32
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6d0ead32
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6d0ead32

Branch: refs/heads/master
Commit: 6d0ead322e72303c6444c6ac641378a4690cde96
Parents: 1ab72b0
Author: Herman van Hovell <hvanhov...@questtec.nl>
Authored: Fri Nov 6 16:04:20 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Fri Nov 6 16:04:20 2015 -0800

----------------------------------------------------------------------
 .../catalyst/expressions/aggregate/Count.scala  |   2 +
 .../catalyst/expressions/aggregate/Utils.scala  | 186 ++++++++++++++++++-
 .../expressions/aggregate/interfaces.scala      |   6 +
 .../sql/catalyst/optimizer/Optimizer.scala      |   6 +-
 .../catalyst/plans/logical/basicOperators.scala |  80 ++++----
 .../spark/sql/execution/SparkStrategies.scala   |   2 +-
 6 files changed, 238 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6d0ead32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 54df96c..ec0c8b4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -49,4 +49,6 @@ case class Count(child: Expression) extends 
DeclarativeAggregate {
   )
 
   override val evaluateExpression = Cast(count, LongType)
+
+  override def defaultResult: Option[Literal] = Option(Literal(0L))
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6d0ead32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index 644c621..39010c3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, 
LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType}
 
 /**
  * Utility functions used by the query planner to convert our plan to new 
aggregation code path.
@@ -41,7 +42,7 @@ object Utils {
 
   private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
     case p: Aggregate if supportsGroupingKeySchema(p) =>
-      val converted = p.transformExpressionsDown {
+      val converted = 
MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
         case expressions.Average(child) =>
           aggregate.AggregateExpression2(
             aggregateFunction = aggregate.Average(child),
@@ -144,7 +145,8 @@ object Utils {
             aggregateFunction = aggregate.VarianceSamp(child),
             mode = aggregate.Complete,
             isDistinct = false)
-      }
+      })
+
       // Check if there is any expressions.AggregateExpression1 left.
       // If so, we cannot convert this plan.
       val hasAggregateExpression1 = converted.aggregateExpressions.exists { 
expr =>
@@ -156,6 +158,7 @@ object Utils {
       }
 
       // Check if there are multiple distinct columns.
+      // TODO remove this.
       val aggregateExpressions = converted.aggregateExpressions.flatMap { expr 
=>
         expr.collect {
           case agg: AggregateExpression2 => agg
@@ -213,3 +216,178 @@ object Utils {
     case other => None
   }
 }
+
+/**
+ * This rule rewrites an aggregate query with multiple distinct clauses into 
an expanded double
+ * aggregation in which the regular aggregation expressions and every distinct 
clause is aggregated
+ * in a separate group. The results are then combined in a second aggregate.
+ *
+ * TODO Expression canonicalization
+ * TODO Eliminate foldable expressions from distinct clauses.
+ * TODO This eliminates all distinct expressions. We could safely pass one to 
the aggregate
+ *      operator. Perhaps this is a good thing? It is much simpler to plan 
later on...
+ */
+object MultipleDistinctRewriter extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    case a: Aggregate => rewrite(a)
+    case p => p
+  }
+
+  def rewrite(a: Aggregate): Aggregate = {
+
+    // Collect all aggregate expressions.
+    val aggExpressions = a.aggregateExpressions.flatMap { e =>
+      e.collect {
+        case ae: AggregateExpression2 => ae
+      }
+    }
+
+    // Extract distinct aggregate expressions.
+    val distinctAggGroups = aggExpressions
+      .filter(_.isDistinct)
+      .groupBy(_.aggregateFunction.children.toSet)
+
+    // Only continue to rewrite if there is more than one distinct group.
+    if (distinctAggGroups.size > 1) {
+      // Create the attributes for the grouping id and the group by clause.
+      val gid = new AttributeReference("gid", IntegerType, false)()
+      val groupByMap = a.groupingExpressions.collect {
+        case ne: NamedExpression => ne -> ne.toAttribute
+        case e => e -> new AttributeReference(e.prettyName, e.dataType, 
e.nullable)()
+      }
+      val groupByAttrs = groupByMap.map(_._2)
+
+      // Functions used to modify aggregate functions and their inputs.
+      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), 
e, nullify(e))
+      def patchAggregateFunctionChildren(
+          af: AggregateFunction2,
+          id: Literal,
+          attrs: Map[Expression, Expression]): AggregateFunction2 = {
+        af.withNewChildren(af.children.map { case afc =>
+          evalWithinGroup(id, attrs(afc))
+        }).asInstanceOf[AggregateFunction2]
+      }
+
+      // Setup unique distinct aggregate children.
+      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
+      val distinctAggChildAttrMap = 
distinctAggChildren.map(expressionAttributePair).toMap
+      val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
+
+      // Setup expand & aggregate operators for distinct aggregate expressions.
+      val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
+        case ((group, expressions), i) =>
+          val id = Literal(i + 1)
+
+          // Expand projection
+          val projection = distinctAggChildren.map {
+            case e if group.contains(e) => e
+            case e => nullify(e)
+          } :+ id
+
+          // Final aggregate
+          val operators = expressions.map { e =>
+            val af = e.aggregateFunction
+            val naf = patchAggregateFunctionChildren(af, id, 
distinctAggChildAttrMap)
+            (e, e.copy(aggregateFunction = naf, isDistinct = false))
+          }
+
+          (projection, operators)
+      }
+
+      // Setup expand for the 'regular' aggregate expressions.
+      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
+      val regularAggChildren = 
regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+      val regularAggChildAttrMap = 
regularAggChildren.map(expressionAttributePair).toMap
+
+      // Setup aggregates for 'regular' aggregate expressions.
+      val regularGroupId = Literal(0)
+      val regularAggOperatorMap = regularAggExprs.map { e =>
+        // Perform the actual aggregation in the initial aggregate.
+        val af = patchAggregateFunctionChildren(
+          e.aggregateFunction,
+          regularGroupId,
+          regularAggChildAttrMap)
+        val a = Alias(e.copy(aggregateFunction = af), e.toString)()
+
+        // Get the result of the first aggregate in the last aggregate.
+        val b = AggregateExpression2(
+          aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), 
Literal(true)),
+          mode = Complete,
+          isDistinct = false)
+
+        // Some aggregate functions (COUNT) have the special property that 
they can return a
+        // non-null result without any input. We need to make sure we return a 
result in this case.
+        val c = af.defaultResult match {
+          case Some(lit) => Coalesce(Seq(b, lit))
+          case None => b
+        }
+
+        (e, a, c)
+      }
+
+      // Construct the regular aggregate input projection only if we need one.
+      val regularAggProjection = if (regularAggExprs.nonEmpty) {
+        Seq(a.groupingExpressions ++
+          distinctAggChildren.map(nullify) ++
+          Seq(regularGroupId) ++
+          regularAggChildren)
+      } else {
+        Seq.empty[Seq[Expression]]
+      }
+
+      // Construct the distinct aggregate input projections.
+      val regularAggNulls = regularAggChildren.map(nullify)
+      val distinctAggProjections = distinctAggOperatorMap.map {
+        case (projection, _) =>
+          a.groupingExpressions ++
+            projection ++
+            regularAggNulls
+      }
+
+      // Construct the expand operator.
+      val expand = Expand(
+        regularAggProjection ++ distinctAggProjections,
+        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ 
regularAggChildAttrMap.values.toSeq,
+        a.child)
+
+      // Construct the first aggregate operator. This de-duplicates all 
the children of
+      // distinct operators, and applies the regular aggregate operators.
+      val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
+      val firstAggregate = Aggregate(
+        firstAggregateGroupBy,
+        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
+        expand)
+
+      // Construct the second aggregate
+      val transformations: Map[Expression, Expression] =
+        (distinctAggOperatorMap.flatMap(_._2) ++
+          regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+
+      val patchedAggExpressions = a.aggregateExpressions.map { e =>
+        e.transformDown {
+          case e: Expression =>
+            // The same GROUP BY clauses can have different forms (different 
names for instance) in
+            // the groupBy and aggregate expressions of an aggregate. This 
makes a map lookup
+            // tricky. So we do a linear search for a semantically equal group 
by expression.
+            groupByMap
+              .find(ge => e.semanticEquals(ge._1))
+              .map(_._2)
+              .getOrElse(transformations.getOrElse(e, e))
+        }.asInstanceOf[NamedExpression]
+      }
+      Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
+    } else {
+      a
+    }
+  }
+
+  private def nullify(e: Expression) = Literal.create(null, e.dataType)
+
+  private def expressionAttributePair(e: Expression) =
+    // We are creating a new reference here instead of reusing the attribute 
in case of a
+    // NamedExpression. This is done to prevent collisions between distinct 
and regular aggregate
+    // children, in this case attribute reuse causes the input of the regular 
aggregate to be bound to
+    // the (nulled out) input of the distinct aggregate.
+    e -> new AttributeReference(e.prettyName, e.dataType, true)()
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6d0ead32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index a2fab25..5c5b3d1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -133,6 +133,12 @@ sealed abstract class AggregateFunction2 extends 
Expression with ImplicitCastInp
    */
   def supportsPartial: Boolean = true
 
+  /**
+   * Result of the aggregate function when the input is empty. This is 
currently only used for the
+   * proper rewriting of distinct aggregate functions.
+   */
+  def defaultResult: Option[Literal] = None
+
   override protected def genCode(ctx: CodeGenContext, ev: 
GeneratedExpressionCode): String =
     throw new UnsupportedOperationException(s"Cannot evaluate expression: 
$this")
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6d0ead32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 338c519..d222dfa 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -200,9 +200,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with 
PredicateHelper {
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
-      if (child.outputSet -- AttributeSet(groupByExprs) -- 
a.references).nonEmpty =>
-      a.copy(child = e.copy(child = prunedChild(child, 
AttributeSet(groupByExprs) ++ a.references)))
+    case a @ Aggregate(_, _, e @ Expand(_, _, child))
+      if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty 
=>
+      a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) 
++ a.references)))
 
     // Eliminate attributes that are not needed to calculate the specified 
aggregates.
     case a @ Aggregate(_, _, child) if (child.outputSet -- 
a.references).nonEmpty =>

http://git-wip-us.apache.org/repos/asf/spark/blob/6d0ead32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 4cb67aa..fb963e2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -235,33 +235,17 @@ case class Window(
     projectList ++ windowExpressions.map(_.toAttribute)
 }
 
-/**
- * Apply the all of the GroupExpressions to every input row, hence we will get
- * multiple output rows for a input row.
- * @param bitmasks The bitmask set represents the grouping sets
- * @param groupByExprs The grouping by expressions
- * @param child       Child operator
- */
-case class Expand(
-    bitmasks: Seq[Int],
-    groupByExprs: Seq[Expression],
-    gid: Attribute,
-    child: LogicalPlan) extends UnaryNode {
-  override def statistics: Statistics = {
-    val sizeInBytes = child.statistics.sizeInBytes * projections.length
-    Statistics(sizeInBytes = sizeInBytes)
-  }
-
-  val projections: Seq[Seq[Expression]] = expand()
-
+private[sql] object Expand {
   /**
-   * Extract attribute set according to the grouping id
+   * Extract attribute set according to the grouping id.
+   *
    * @param bitmask bitmask to represent the selected of the attribute sequence
    * @param exprs the attributes in sequence
    * @return the attributes of non selected specified via bitmask (with the 
bit set to 1)
    */
-  private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
-  : OpenHashSet[Expression] = {
+  private def buildNonSelectExprSet(
+      bitmask: Int,
+      exprs: Seq[Expression]): OpenHashSet[Expression] = {
     val set = new OpenHashSet[Expression](2)
 
     var bit = exprs.length - 1
@@ -274,18 +258,28 @@ case class Expand(
   }
 
   /**
-   * Create an array of Projections for the child projection, and replace the 
projections'
-   * expressions which equal GroupBy expressions with Literal(null), if those 
expressions
-   * are not set for this grouping set (according to the bit mask).
+   * Apply the all of the GroupExpressions to every input row, hence we will 
get
+   * multiple output rows for a input row.
+   *
+   * @param bitmasks The bitmask set represents the grouping sets
+   * @param groupByExprs The grouping by expressions
+   * @param gid Attribute of the grouping id
+   * @param child Child operator
    */
-  private[this] def expand(): Seq[Seq[Expression]] = {
-    val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
-
-    bitmasks.foreach { bitmask =>
+  def apply(
+    bitmasks: Seq[Int],
+    groupByExprs: Seq[Expression],
+    gid: Attribute,
+    child: LogicalPlan): Expand = {
+    // Create an array of Projections for the child projection, and replace 
the projections'
+    // expressions which equal GroupBy expressions with Literal(null), if 
those expressions
+    // are not set for this grouping set (according to the bit mask).
+    val projections = bitmasks.map { bitmask =>
       // get the non selected grouping attributes according to the bit mask
       val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, 
groupByExprs)
 
-      val substitution = (child.output :+ gid).map(expr => expr transformDown {
+      (child.output :+ gid).map(expr => expr transformDown {
+        // TODO this causes a problem when a column is used both for grouping 
and aggregation.
         case x: Expression if nonSelectedGroupExprSet.contains(x) =>
           // if the input attribute in the Invalid Grouping Expression set of 
for this group
           // replace it with constant null
@@ -294,15 +288,29 @@ case class Expand(
           // replace the groupingId with concrete value (the bit mask)
           Literal.create(bitmask, IntegerType)
       })
-
-      result += substitution
     }
-
-    result.toSeq
+    Expand(projections, child.output :+ gid, child)
   }
+}
 
-  override def output: Seq[Attribute] = {
-    child.output :+ gid
+/**
+ * Apply a number of projections to every input row, hence we will get 
multiple output rows for
+ * a input row.
+ *
+ * @param projections to apply
+ * @param output of all projections.
+ * @param child operator.
+ */
+case class Expand(
+    projections: Seq[Seq[Expression]],
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+
+  override def statistics: Statistics = {
+    // TODO shouldn't we factor in the size of the projection versus the size 
of the backing child
+    //      row?
+    val sizeInBytes = child.statistics.sizeInBytes * projections.length
+    Statistics(sizeInBytes = sizeInBytes)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6d0ead32/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f4464e0..dd3bb33 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -420,7 +420,7 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         }
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
-      case e @ logical.Expand(_, _, _, child) =>
+      case e @ logical.Expand(_, _, child) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
       case a @ logical.Aggregate(group, agg, child) => {
         val useNewAggregation = sqlContext.conf.useSqlAggregate2 && 
sqlContext.conf.codegenEnabled


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to