Repository: spark
Updated Branches:
  refs/heads/master 5c4e6d7ec -> 30c8ba71a


[SPARK-11451][SQL] Support single distinct count on multiple columns.

This PR adds support for multiple columns in a single count distinct aggregate 
to the new aggregation path.

cc yhuai

Author: Herman van Hovell <hvanhov...@questtec.nl>

Closes #9409 from hvanhovell/SPARK-11451.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/30c8ba71
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/30c8ba71
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/30c8ba71

Branch: refs/heads/master
Commit: 30c8ba71a76788cbc6916bc1ba6bc8522925fc2b
Parents: 5c4e6d7
Author: Herman van Hovell <hvanhov...@questtec.nl>
Authored: Sun Nov 8 11:06:10 2015 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Sun Nov 8 11:06:10 2015 -0800

----------------------------------------------------------------------
 .../catalyst/expressions/aggregate/Utils.scala  | 44 +++++++++++---------
 .../expressions/conditionalExpressions.scala    | 30 ++++++++++++-
 .../catalyst/plans/logical/basicOperators.scala |  3 ++
 .../ConditionalExpressionSuite.scala            | 14 +++++++
 .../spark/sql/DataFrameAggregateSuite.scala     | 25 +++++++++++
 .../hive/execution/AggregationQuerySuite.scala  | 37 +++++++++++++---
 6 files changed, 127 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/30c8ba71/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
index ac23f72..9b22ce2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
@@ -22,26 +22,27 @@ import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, 
LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType}
+import org.apache.spark.sql.types._
 
 /**
  * Utility functions used by the query planner to convert our plan to new 
aggregation code path.
  */
 object Utils {
-  // Right now, we do not support complex types in the grouping key schema.
-  private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
-    val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists 
{
-      case array: ArrayType => true
-      case map: MapType => true
-      case struct: StructType => true
-      case _ => false
-    }
 
-    !hasComplexTypes
+  // Check if the DataType given cannot be part of a group by clause.
+  private def isUnGroupable(dt: DataType): Boolean = dt match {
+    case _: ArrayType | _: MapType => true
+    case s: StructType => s.fields.exists(f => isUnGroupable(f.dataType))
+    case _ => false
   }
 
+  // Right now, we do not support complex types in the grouping key schema.
+  private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean =
+    !aggregate.groupingExpressions.exists(e => isUnGroupable(e.dataType))
+
   private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
     case p: Aggregate if supportsGroupingKeySchema(p) =>
+
       val converted = 
MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
         case expressions.Average(child) =>
           aggregate.AggregateExpression2(
@@ -55,10 +56,14 @@ object Utils {
             mode = aggregate.Complete,
             isDistinct = false)
 
-        // We do not support multiple COUNT DISTINCT columns for now.
-        case expressions.CountDistinct(children) if children.length == 1 =>
+        case expressions.CountDistinct(children) =>
+          val child = if (children.size > 1) {
+            DropAnyNull(CreateStruct(children))
+          } else {
+            children.head
+          }
           aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Count(children.head),
+            aggregateFunction = aggregate.Count(child),
             mode = aggregate.Complete,
             isDistinct = true)
 
@@ -320,7 +325,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
       val gid = new AttributeReference("gid", IntegerType, false)()
       val groupByMap = a.groupingExpressions.collect {
         case ne: NamedExpression => ne -> ne.toAttribute
-        case e => e -> new AttributeReference(e.prettyName, e.dataType, 
e.nullable)()
+        case e => e -> new AttributeReference(e.prettyString, e.dataType, 
e.nullable)()
       }
       val groupByAttrs = groupByMap.map(_._2)
 
@@ -365,14 +370,15 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] 
{
       // Setup expand for the 'regular' aggregate expressions.
       val regularAggExprs = aggExpressions.filter(!_.isDistinct)
       val regularAggChildren = 
regularAggExprs.flatMap(_.aggregateFunction.children).distinct
-      val regularAggChildAttrMap = 
regularAggChildren.map(expressionAttributePair).toMap
+      val regularAggChildAttrMap = 
regularAggChildren.map(expressionAttributePair)
 
       // Setup aggregates for 'regular' aggregate expressions.
       val regularGroupId = Literal(0)
+      val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
       val regularAggOperatorMap = regularAggExprs.map { e =>
         // Perform the actual aggregation in the initial aggregate.
-        val af = 
patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap)
-        val operator = Alias(e.copy(aggregateFunction = af), e.toString)()
+        val af = 
patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
+        val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
 
         // Select the result of the first aggregate in the last aggregate.
         val result = AggregateExpression2(
@@ -416,7 +422,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
       // Construct the expand operator.
       val expand = Expand(
         regularAggProjection ++ distinctAggProjections,
-        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ 
regularAggChildAttrMap.values.toSeq,
+        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ 
regularAggChildAttrMap.map(_._2),
         a.child)
 
       // Construct the first aggregate operator. This de-duplicates the all 
the children of
@@ -457,5 +463,5 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] {
     // NamedExpression. This is done to prevent collisions between distinct 
and regular aggregate
     // children, in this case attribute reuse causes the input of the regular 
aggregate to bound to
     // the (nulled out) input of the distinct aggregate.
-    e -> new AttributeReference(e.prettyName, e.dataType, true)()
+    e -> new AttributeReference(e.prettyString, e.dataType, true)()
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/30c8ba71/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index d532629..0d4af43 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.TypeUtils
-import org.apache.spark.sql.types.{NullType, BooleanType, DataType}
+import org.apache.spark.sql.types._
 
 
 case class If(predicate: Expression, trueValue: Expression, falseValue: 
Expression)
@@ -419,3 +419,31 @@ case class Greatest(children: Seq[Expression]) extends 
Expression {
     """
   }
 }
+
+/** Expression that returns null when its struct input contains any null field. */
+case class DropAnyNull(child: Expression) extends UnaryExpression with 
ExpectsInputTypes {
+  override def nullable: Boolean = true
+  override def dataType: DataType = child.dataType
+  override def inputTypes: Seq[AbstractDataType] = Seq(StructType)
+
+  protected override def nullSafeEval(input: Any): InternalRow = {
+    val row = input.asInstanceOf[InternalRow]
+    if (row.anyNull) {
+      null
+    } else {
+      row
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
+    nullSafeCodeGen(ctx, ev, eval => {
+      s"""
+        if ($eval.anyNull()) {
+          ${ev.isNull} = true;
+        } else {
+          ${ev.value} = $eval;
+        }
+      """
+    })
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/30c8ba71/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index fb963e2..09aac00 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -306,6 +306,9 @@ case class Expand(
     output: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {
 
+  override def references: AttributeSet =
+    AttributeSet(projections.flatten.flatMap(_.references))
+
   override def statistics: Statistics = {
     // TODO shouldn't we factor in the size of the projection versus the size 
of the backing child
     //      row?

http://git-wip-us.apache.org/repos/asf/spark/blob/30c8ba71/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index 0df673b..c1e3c17 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -231,4 +231,18 @@ class ConditionalExpressionSuite extends SparkFunSuite 
with ExpressionEvalHelper
       checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2)
     }
   }
+
+  test("function dropAnyNull") {
+    val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1))))
+    val a = create_row("a", "q")
+    val nullStr: String = null
+    checkEvaluation(drop, a, a)
+    checkEvaluation(drop, null, create_row("b", nullStr))
+    checkEvaluation(drop, null, create_row(nullStr, nullStr))
+
+    val row = 'r.struct(
+      StructField("a", StringType, false),
+      StructField("b", StringType, true)).at(0)
+    checkEvaluation(DropAnyNull(row), null, create_row(null))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/30c8ba71/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 2e679e7..eb1ee26 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -162,6 +162,31 @@ class DataFrameAggregateSuite extends QueryTest with 
SharedSQLContext {
     )
   }
 
+  test("multiple column distinct count") {
+    val df1 = Seq(
+      ("a", "b", "c"),
+      ("a", "b", "c"),
+      ("a", "b", "d"),
+      ("x", "y", "z"),
+      ("x", "q", null.asInstanceOf[String]))
+      .toDF("key1", "key2", "key3")
+
+    checkAnswer(
+      df1.agg(countDistinct('key1, 'key2)),
+      Row(3)
+    )
+
+    checkAnswer(
+      df1.agg(countDistinct('key1, 'key2, 'key3)),
+      Row(3)
+    )
+
+    checkAnswer(
+      df1.groupBy('key1).agg(countDistinct('key2, 'key3)),
+      Seq(Row("a", 2), Row("x", 1))
+    )
+  }
+
   test("zero count") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
     checkAnswer(

http://git-wip-us.apache.org/repos/asf/spark/blob/30c8ba71/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 7f6fe33..ea36c13 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -516,21 +516,46 @@ abstract class AggregationQuerySuite extends QueryTest 
with SQLTestUtils with Te
         Row(3, 4, 4, 3, null) :: Nil)
   }
 
-  test("multiple distinct column sets") {
+  test("single distinct multiple columns set") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  key,
+          |  count(distinct value1, value2)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(null, 3) ::
+        Row(1, 3) ::
+        Row(2, 1) ::
+        Row(3, 0) :: Nil)
+  }
+
+  test("multiple distinct multiple columns sets") {
     checkAnswer(
       sqlContext.sql(
         """
           |SELECT
           |  key,
           |  count(distinct value1),
-          |  count(distinct value2)
+          |  sum(distinct value1),
+          |  count(distinct value2),
+          |  sum(distinct value2),
+          |  count(distinct value1, value2),
+          |  count(value1),
+          |  sum(value1),
+          |  count(value2),
+          |  sum(value2),
+          |  count(*),
+          |  count(1)
           |FROM agg2
           |GROUP BY key
         """.stripMargin),
-      Row(null, 3, 3) ::
-        Row(1, 2, 3) ::
-        Row(2, 2, 1) ::
-        Row(3, 0, 1) :: Nil)
+      Row(null, 3, 30, 3, 60, 3, 3, 30, 3, 60, 4, 4) ::
+        Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) ::
+        Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) ::
+        Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil)
   }
 
   test("test count") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to