Repository: spark
Updated Branches:
  refs/heads/master 3121e7816 -> 21c562fa0


[SPARK-9241][SQL] Supporting multiple DISTINCT columns - follow-up (3)

This PR is a 2nd follow-up for 
[SPARK-9241](https://issues.apache.org/jira/browse/SPARK-9241). It contains the 
following improvements:
* Fix for a potential bug in distinct child expression and attribute alignment.
* Improved handling of duplicate distinct child expressions.
* Added test for distinct UDAF with multiple children.

cc yhuai

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #9566 from hvanhovell/SPARK-9241-followup-2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/21c562fa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/21c562fa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/21c562fa

Branch: refs/heads/master
Commit: 21c562fa03430365f5c2b7d6de1f8f60ab2140d4
Parents: 3121e78
Author: Herman van Hovell <hvanhovell@questtec.nl>
Authored: Tue Nov 10 16:28:21 2015 -0800
Committer: Yin Huai <yhuai@databricks.com>
Committed: Tue Nov 10 16:28:21 2015 -0800

----------------------------------------------------------------------
 .../analysis/DistinctAggregationRewriter.scala  |  9 +++--
 .../hive/execution/AggregationQuerySuite.scala  | 41 ++++++++++++++++++--
 2 files changed, 42 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/21c562fa/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
index 397eff0..c0c9604 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
@@ -151,11 +151,12 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
       }
 
       // Setup unique distinct aggregate children.
-      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
-      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
-      val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
+      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
+      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
+      val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
 
       // Setup expand & aggregate operators for distinct aggregate expressions.
+      val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
       val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
         case ((group, expressions), i) =>
           val id = Literal(i + 1)
@@ -170,7 +171,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
           val operators = expressions.map { e =>
             val af = e.aggregateFunction
             val naf = patchAggregateFunctionChildren(af) { x =>
-              evalWithinGroup(id, distinctAggChildAttrMap(x))
+              evalWithinGroup(id, distinctAggChildAttrLookup(x))
             }
             (e, e.copy(aggregateFunction = naf, isDistinct = false))
           }

http://git-wip-us.apache.org/repos/asf/spark/blob/21c562fa/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 6bf2c53..8253921 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -66,6 +66,36 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
   }
 }
 
+class LongProductSum extends UserDefinedAggregateFunction {
+  def inputSchema: StructType = new StructType()
+    .add("a", LongType)
+    .add("b", LongType)
+
+  def bufferSchema: StructType = new StructType()
+    .add("product", LongType)
+
+  def dataType: DataType = LongType
+
+  def deterministic: Boolean = true
+
+  def initialize(buffer: MutableAggregationBuffer): Unit = {
+    buffer(0) = 0L
+  }
+
+  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+    if (!(input.isNullAt(0) || input.isNullAt(1))) {
+      buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1)
+    }
+  }
+
+  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
+  }
+
+  def evaluate(buffer: Row): Any =
+    buffer.getLong(0)
+}
+
 abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   import testImplicits._
 
@@ -110,6 +140,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     // Register UDAFs
     sqlContext.udf.register("mydoublesum", new MyDoubleSum)
     sqlContext.udf.register("mydoubleavg", new MyDoubleAvg)
+    sqlContext.udf.register("longProductSum", new LongProductSum)
   }
 
   override def afterAll(): Unit = {
@@ -545,19 +576,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
           |  count(distinct value2),
           |  sum(distinct value2),
           |  count(distinct value1, value2),
+          |  longProductSum(distinct value1, value2),
           |  count(value1),
           |  sum(value1),
           |  count(value2),
           |  sum(value2),
+          |  longProductSum(value1, value2),
           |  count(*),
           |  count(1)
           |FROM agg2
           |GROUP BY key
         """.stripMargin),
-      Row(null, 3, 30, 3, 60, 3, 3, 30, 3, 60, 4, 4) ::
-        Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) ::
-        Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) ::
-        Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil)
+      Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) ::
+        Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) ::
+        Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) ::
+        Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
   }
 
   test("test count") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

Reply via email to