Repository: spark
Updated Branches:
  refs/heads/master 47735cdc2 -> dfcfcbcc0


[SPARK-11578][SQL][FOLLOW-UP] Complete the user-facing API for typed aggregation

Currently the user-facing API for typed aggregation has some limitations:

* a custom typed aggregation must be the first entry in the aggregation list
* a custom typed aggregation can only use Long as its buffer type
* a custom typed aggregation can only use a flat type as its result type

This PR removes these limitations; a sketch of the resulting API follows.
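
For illustration, an Aggregator can now keep a tuple as its buffer, return a
struct-typed (non-flat) result, and sit anywhere in the aggregation list. The
sketch below mirrors the new tests in DatasetAggregatorSuite; it assumes a
SQLContext named sqlContext whose implicits provide toDS() and the tuple
encoders:

    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions.expr
    import sqlContext.implicits._

    // Buffer and result are both (count, sum), i.e. a non-flat result type.
    object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)]
      with Serializable {
      override def zero: (Long, Long) = (0, 0)
      override def reduce(b: (Long, Long), a: (String, Int)): (Long, Long) =
        (b._1 + 1, b._2 + a._2)
      override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
        (b1._1 + b2._1, b1._2 + b2._2)
      override def present(reduction: (Long, Long)): (Long, Long) = reduction
    }

    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
    // The typed aggregation no longer has to be the first aggregate in the list:
    ds.groupBy(_._1).agg(expr("avg(_2)").as[Double], ComplexResultAgg.toColumn)
    // yields ("a", 2.0, (2L, 4L)) and ("b", 3.0, (1L, 3L)), per the new test.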

Author: Wenchen Fan <wenc...@databricks.com>

Closes #9599 from cloud-fan/agg.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dfcfcbcc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dfcfcbcc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dfcfcbcc

Branch: refs/heads/master
Commit: dfcfcbcc0448ebc6f02eba6bf0495832a321c87e
Parents: 47735cd
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Nov 10 11:14:25 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Tue Nov 10 11:14:25 2015 -0800

----------------------------------------------------------------------
 .../catalyst/encoders/ExpressionEncoder.scala   |  6 +++
 .../aggregate/TypedAggregateExpression.scala    | 50 +++++++++++++------
 .../spark/sql/expressions/Aggregator.scala      |  5 ++
 .../spark/sql/DatasetAggregatorSuite.scala      | 52 ++++++++++++++++++++
 4 files changed, 99 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/dfcfcbcc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index c287aeb..005c062 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -185,6 +185,12 @@ case class ExpressionEncoder[T](
     })
   }
 
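+  /**
+   * Returns a copy of this encoder whose construct expressions read their fields at
+   * ordinals shifted by `delta`, e.g. to decode an aggregation buffer stored at an offset.
+   */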
+  def shift(delta: Int): ExpressionEncoder[T] = {
+    copy(constructExpression = constructExpression transform {
+      case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
+    })
+  }
+
   /**
    * Returns a copy of this encoder where the expressions used to create an object given an
    * input row have been modified to pull the object out from a nested struct, instead of the

http://git-wip-us.apache.org/repos/asf/spark/blob/dfcfcbcc/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 24d8122..0e5bc1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.aggregate
 import scala.language.existentials
 
 import org.apache.spark.Logging
+import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
-import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{StructType, DataType}
+import org.apache.spark.sql.types._
 
 object TypedAggregateExpression {
   def apply[A, B : Encoder, C : Encoder](
@@ -67,8 +67,11 @@ case class TypedAggregateExpression(
 
   override def nullable: Boolean = true
 
-  // TODO: this assumes flat results...
-  override def dataType: DataType = cEncoder.schema.head.dataType
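+  // Flat result encoders expose their single field's type; non-flat results surface as a struct.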
+  override def dataType: DataType = if (cEncoder.flat) {
+    cEncoder.schema.head.dataType
+  } else {
+    cEncoder.schema
+  }
 
   override def deterministic: Boolean = true
 
@@ -93,32 +96,51 @@ case class TypedAggregateExpression(
       case a: AttributeReference => inputMapping(a)
     })
 
-  // TODO: this probably only works when we are in the first column.
   val bAttributes = bEncoder.schema.toAttributes
   lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
 
+  private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
+    // TODO: find a neater way to assign the values.
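+    // Only Int and Long buffer fields are handled below; any other type hits a MatchError.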
+    var i = 0
+    while (i < aggBufferAttributes.length) {
+      aggBufferSchema(i).dataType match {
+        case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i))
+        case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i))
+      }
+      i += 1
+    }
+  }
+
   override def initialize(buffer: MutableRow): Unit = {
-    // TODO: We need to either force Aggregator to have a zero or we need to eliminate the need for
-    // this in execution.
-    buffer.setInt(mutableAggBufferOffset, aggregator.zero.asInstanceOf[Int])
+    val zero = bEncoder.toRow(aggregator.zero)
+    updateBuffer(buffer, zero)
   }
 
   override def update(buffer: MutableRow, input: InternalRow): Unit = {
     val inputA = boundA.fromRow(input)
-    val currentB = boundB.fromRow(buffer)
+    val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
     val merged = aggregator.reduce(currentB, inputA)
     val returned = boundB.toRow(merged)
-    buffer.setInt(mutableAggBufferOffset, returned.getInt(0))
+
+    updateBuffer(buffer, returned)
   }
 
   override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
-    buffer1.setLong(
-      mutableAggBufferOffset,
-      buffer1.getLong(mutableAggBufferOffset) + buffer2.getLong(inputAggBufferOffset))
+    val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1)
+    val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2)
+    val merged = aggregator.merge(b1, b2)
+    val returned = boundB.toRow(merged)
+
+    updateBuffer(buffer1, returned)
   }
 
   override def eval(buffer: InternalRow): Any = {
-    buffer.getInt(mutableAggBufferOffset)
+    val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer)
+    val result = cEncoder.toRow(aggregator.present(b))
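+    // Unwrap flat results from their single-field row; struct results are returned as a row.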
+    dataType match {
+      case _: StructType => result
+      case _ => result.get(0, dataType)
+    }
   }
 
   override def toString: String = {

http://git-wip-us.apache.org/repos/asf/spark/blob/dfcfcbcc/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 8cc25c2..3c1c457 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -58,6 +58,11 @@ abstract class Aggregator[-A, B, C] {
   def reduce(b: B, a: A): B
 
   /**
+   * Merge two intermediate values.
+   */
+  def merge(b1: B, b2: B): B
+
+  /**
    * Transform the output of the reduction.
    */
   def present(reduction: B): C

http://git-wip-us.apache.org/repos/asf/spark/blob/dfcfcbcc/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 340470c..206095a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -34,9 +34,41 @@ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializ
 
   override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
 
+  override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
+
   override def present(reduction: N): N = reduction
 }
 
+object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with Serializable {
+  override def zero: (Long, Long) = (0, 0)
+
+  override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
+    (countAndSum._1 + 1, countAndSum._2 + input._2)
+  }
+
+  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
+    (b1._1 + b2._1, b1._2 + b2._2)
+  }
+
+  override def present(countAndSum: (Long, Long)): Double = countAndSum._2.toDouble / countAndSum._1
+}
+
+object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)]
+  with Serializable {
+
+  override def zero: (Long, Long) = (0, 0)
+
+  override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
+    (countAndSum._1 + 1, countAndSum._2 + input._2)
+  }
+
+  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
+    (b1._1 + b2._1, b1._2 + b2._2)
+  }
+
+  override def present(reduction: (Long, Long)): (Long, Long) = reduction
+}
+
 class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
   import testImplicits._
@@ -62,4 +94,24 @@ class DatasetAggregatorSuite extends QueryTest with 
SharedSQLContext {
         count("*")),
       ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L))
   }
+
+  test("typed aggregation: complex case") {
+    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(
+        expr("avg(_2)").as[Double],
+        TypedAverage.toColumn),
+      ("a", 2.0, 2.0), ("b", 3.0, 3.0))
+  }
+
+  test("typed aggregation: complex result type") {
+    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(
+        expr("avg(_2)").as[Double],
+        ComplexResultAgg.toColumn),
+      ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L)))
+  }
 }

