Repository: spark
Updated Branches:
  refs/heads/master b5c60bcdc -> 297ba3f1b


[SPARK-14275][SQL] Reimplement TypedAggregateExpression to DeclarativeAggregate

## What changes were proposed in this pull request?

`ExpressionEncoder` is just a container for serialization and deserialization 
expressions, so we can use these expressions to build `TypedAggregateExpression` 
directly and make it fit into `DeclarativeAggregate`, which is more efficient.
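
For context, the typed UDAFs that go through this code path are plain `Aggregator` 
implementations turned into a column with `toColumn`. A minimal sketch (modeled on the 
`ComplexAggregator` added to `DatasetBenchmark` below; the `SumL` name is hypothetical):

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

case class Data(l: Long, s: String)

// Sums the `l` field of Data; the Long buffer encoder is flat (single field).
object SumL extends Aggregator[Data, Long, Long] {
  override def zero: Long = 0L                            // initial buffer value
  override def reduce(b: Long, a: Data): Long = b + a.l   // fold one input into the buffer
  override def merge(b1: Long, b2: Long): Long = b1 + b2  // combine two partial buffers
  override def finish(reduction: Long): Long = reduction  // produce the final output
  override def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// Usage: ds.select(SumL.toColumn) or ds.groupByKey(_.s).agg(SumL.toColumn)
```

With this change the encoders' serializer/deserializer expressions are wired directly into 
a `DeclarativeAggregate`, which is what lets the new `WholeStageCodegenSuite` test run a 
simple typed UDAF inside `WholeStageCodegen`.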

One trick: each buffer serializer expression references the result object produced by the 
deserialization and function call. To avoid re-computing this result object, we serialize 
the buffer object into a single struct field, so that a special `Expression` 
(`ReferenceToExpressions`) only needs to evaluate the result object once.
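
This matters most for aggregators whose buffer encoder is not flat (more than one buffer 
field): without the wrapping, every buffer column would re-evaluate the object returned by 
`reduce`/`merge`. A minimal sketch of such an aggregator (hypothetical `AverageL`, reusing 
the `Data` class and imports from the sketch above):

```scala
// Two-field (sum, count) buffer: the encoder is not flat, so the whole buffer object is
// serialized into one struct-typed column via the new ReferenceToExpressions, and the
// reduce/merge result object is evaluated only once per update/merge.
object AverageL extends Aggregator[Data, (Long, Long), Double] {
  override def zero: (Long, Long) = (0L, 0L)
  override def reduce(b: (Long, Long), a: Data): (Long, Long) = (b._1 + a.l, b._2 + 1L)
  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
    (b1._1 + b2._1, b1._2 + b2._2)
  override def finish(r: (Long, Long)): Double =
    if (r._2 == 0L) 0.0 else r._1.toDouble / r._2
  override def bufferEncoder: Encoder[(Long, Long)] =
    Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
```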

## How was this patch tested?

existing tests

Author: Wenchen Fan <wenc...@databricks.com>

Closes #12067 from cloud-fan/typed_udaf.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/297ba3f1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/297ba3f1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/297ba3f1

Branch: refs/heads/master
Commit: 297ba3f1b49cc37d9891a529142c553e0a5e2d62
Parents: b5c60bc
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Fri Apr 15 12:10:00 2016 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Apr 15 12:10:00 2016 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/Expression.scala   |   2 +-
 .../expressions/ReferenceToExpressions.scala    |  77 ++++++++
 .../sql/catalyst/expressions/literals.scala     |   3 +-
 .../org/apache/spark/sql/types/ObjectType.scala |   2 +
 .../scala/org/apache/spark/sql/Column.scala     |  16 +-
 .../scala/org/apache/spark/sql/Dataset.scala    |   4 +-
 .../spark/sql/KeyValueGroupedDataset.scala      |   3 +-
 .../aggregate/TypedAggregateExpression.scala    | 192 ++++++++++---------
 .../spark/sql/expressions/Aggregator.scala      |   6 +-
 .../org/apache/spark/sql/DatasetBenchmark.scala | 112 ++++++++---
 .../scala/org/apache/spark/sql/QueryTest.scala  |   2 +
 .../sql/execution/WholeStageCodegenSuite.scala  |  14 ++
 12 files changed, 303 insertions(+), 130 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index a24a5db..718bb4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -185,7 +185,7 @@ abstract class Expression extends TreeNode[Expression] {
    * Returns a user-facing string representation of this expression's name.
    * This should usually match the name of the function in SQL.
    */
-  def prettyName: String = getClass.getSimpleName.toLowerCase
+  def prettyName: String = nodeName.toLowerCase
 
   private def flatArguments = productIterator.flatMap {
     case t: Traversable[_] => t

http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
new file mode 100644
index 0000000..22645c9
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * A special expression that evaluates [[BoundReference]]s with the given expressions instead
+ * of the input row.
+ *
+ * @param result The expression that contains [[BoundReference]] and produces the final output.
+ * @param children The expressions that are used as input values for [[BoundReference]].
+ */
+case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
+  extends Expression {
+
+  override def nullable: Boolean = result.nullable
+  override def dataType: DataType = result.dataType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (result.references.nonEmpty) {
+      return TypeCheckFailure("The result expression cannot reference to any attributes.")
+    }
+
+    var maxOrdinal = -1
+    result foreach {
+      case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
+    }
+    if (maxOrdinal > children.length) {
+      return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
+        s"there are only ${children.length} inputs.")
+    }
+
+    TypeCheckSuccess
+  }
+
+  private lazy val projection = UnsafeProjection.create(children)
+
+  override def eval(input: InternalRow): Any = {
+    result.eval(projection(input))
+  }
+
+  override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+    val childrenGen = children.map(_.gen(ctx))
+    val childrenVars = childrenGen.zip(children).map {
+      case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType)
+    }
+
+    val resultGen = result.transform {
+      case b: BoundReference => childrenVars(b.ordinal)
+    }.gen(ctx)
+
+    ev.value = resultGen.value
+    ev.isNull = resultGen.isNull
+
+    childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e6804d0..7fd4bc3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -60,7 +60,8 @@ object Literal {
   * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object
   * into code generation.
   */
-  def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass))
+  def fromObject(obj: Any, objType: DataType): Literal = new Literal(obj, objType)
+  def fromObject(obj: Any): Literal = new Literal(obj, ObjectType(obj.getClass))
 
   def fromJSON(json: JValue): Literal = {
     val dataType = DataType.parseDataType(json \ "dataType")

http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
index 06ee0fb..b7b1acc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
@@ -41,4 +41,6 @@ private[sql] case class ObjectType(cls: Class[_]) extends DataType {
     throw new UnsupportedOperationException("No size estimation available for objects.")
 
   def asNullable: DataType = this
+
+  override def simpleString: String = cls.getName
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index d64736e..bd96941 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -59,14 +59,14 @@ class TypedColumn[-T, U](
    * on a decoded object.
    */
   private[sql] def withInputType(
-      inputEncoder: ExpressionEncoder[_],
-      schema: Seq[Attribute]): TypedColumn[T, U] = {
-    val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]]
-    new TypedColumn[T, U](
-      expr transform { case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
-        ta.copy(aEncoder = Some(boundEncoder), children = schema)
-      },
-      encoder)
+      inputDeserializer: Expression,
+      inputAttributes: Seq[Attribute]): TypedColumn[T, U] = {
+    val unresolvedDeserializer = UnresolvedDeserializer(inputDeserializer, inputAttributes)
+    val newExpr = expr transform {
+      case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
+        ta.copy(inputDeserializer = Some(unresolvedDeserializer))
+    }
+    new TypedColumn[T, U](newExpr, encoder)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index e216945..4edc90d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -992,7 +992,7 @@ class Dataset[T] private[sql](
       sqlContext,
       Project(
         c1.withInputType(
-          boundTEncoder,
+          unresolvedTEncoder.deserializer,
           logicalPlan.output).named :: Nil,
         logicalPlan),
       implicitly[Encoder[U1]])
@@ -1006,7 +1006,7 @@ class Dataset[T] private[sql](
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
+      columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named)
     val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
 
     new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))

http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index f19ad6e..05e13e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -209,8 +209,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(
-        _.withInputType(resolvedVEncoder, dataAttributes).named)
+      columns.map(_.withInputType(unresolvedVEncoder.deserializer, dataAttributes).named)
     val keyColumn = if (resolvedKEncoder.flat) {
       assert(groupingAttributes.length == 1)
       groupingAttributes.head

http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 9abae53..535e64c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -19,133 +19,153 @@ package org.apache.spark.sql.execution.aggregate
 
 import scala.language.existentials
 
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.types._
 
 object TypedAggregateExpression {
-  def apply[A, B : Encoder, C : Encoder](
-      aggregator: Aggregator[A, B, C]): TypedAggregateExpression = {
+  def apply[BUF : Encoder, OUT : Encoder](
+      aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = {
+    val bufferEncoder = encoderFor[BUF]
+    // We will insert the deserializer and function call expression at the bottom of each serializer
+    // expression while executing `TypedAggregateExpression`, which means multiple serializer
+    // expressions will all evaluate the same sub-expression at bottom.  To avoid the re-evaluation,
+    // here we always use one single serializer expression to serialize the buffer object into a
+    // single-field row, no matter whether the encoder is flat or not.  We also need to update the
+    // deserializer to read in all fields from that single-field row.
+    // TODO: remove this trick after we have better integration of subexpression elimination and
+    // whole stage codegen.
+    val bufferSerializer = if (bufferEncoder.flat) {
+      bufferEncoder.namedExpressions.head
+    } else {
+      Alias(CreateStruct(bufferEncoder.serializer), "buffer")()
+    }
+
+    val bufferDeserializer = if (bufferEncoder.flat) {
+      bufferEncoder.deserializer transformUp {
+        case b: BoundReference => bufferSerializer.toAttribute
+      }
+    } else {
+      bufferEncoder.deserializer transformUp {
+        case UnresolvedAttribute(nameParts) =>
+          assert(nameParts.length == 1)
+          UnresolvedExtractValue(bufferSerializer.toAttribute, Literal(nameParts.head))
+        case BoundReference(ordinal, dt, _) => GetStructField(bufferSerializer.toAttribute, ordinal)
+      }
+    }
+
+    val outputEncoder = encoderFor[OUT]
+    val outputType = if (outputEncoder.flat) {
+      outputEncoder.schema.head.dataType
+    } else {
+      outputEncoder.schema
+    }
+
     new TypedAggregateExpression(
       aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
       None,
-      encoderFor[B].asInstanceOf[ExpressionEncoder[Any]],
-      encoderFor[C].asInstanceOf[ExpressionEncoder[Any]],
-      Nil,
-      0,
-      0)
+      bufferSerializer,
+      bufferDeserializer,
+      outputEncoder.serializer,
+      outputEncoder.deserializer.dataType,
+      outputType)
   }
 }
 
 /**
- * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system.  It has
- * the following limitations:
- *  - It assumes the aggregator has a zero, `0`.
+ * A helper class to hook [[Aggregator]] into the aggregation system.
  */
 case class TypedAggregateExpression(
     aggregator: Aggregator[Any, Any, Any],
-    aEncoder: Option[ExpressionEncoder[Any]], // Should be bound.
-    unresolvedBEncoder: ExpressionEncoder[Any],
-    cEncoder: ExpressionEncoder[Any],
-    children: Seq[Attribute],
-    mutableAggBufferOffset: Int,
-    inputAggBufferOffset: Int)
-  extends ImperativeAggregate with Logging {
-
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
+    inputDeserializer: Option[Expression],
+    bufferSerializer: NamedExpression,
+    bufferDeserializer: Expression,
+    outputSerializer: Seq[Expression],
+    outputExternalType: DataType,
+    dataType: DataType) extends DeclarativeAggregate with NonSQLExpression {
 
   override def nullable: Boolean = true
 
-  override def dataType: DataType = if (cEncoder.flat) {
-    cEncoder.schema.head.dataType
-  } else {
-    cEncoder.schema
-  }
-
   override def deterministic: Boolean = true
 
-  override lazy val resolved: Boolean = aEncoder.isDefined
-
-  override lazy val inputTypes: Seq[DataType] = Nil
+  override def children: Seq[Expression] = inputDeserializer.toSeq :+ bufferDeserializer
 
-  override val aggBufferSchema: StructType = unresolvedBEncoder.schema
+  override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved
 
-  override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
+  override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq)
 
-  val bEncoder = unresolvedBEncoder
-    .resolve(aggBufferAttributes, OuterScopes.outerScopes)
-    .bind(aggBufferAttributes)
+  override def inputTypes: Seq[AbstractDataType] = Nil
 
-  // Note: although this simply copies aggBufferAttributes, this common code can not be placed
-  // in the superclass because that will lead to initialization ordering issues.
-  override val inputAggBufferAttributes: Seq[AttributeReference] =
-    aggBufferAttributes.map(_.newInstance())
+  private def aggregatorLiteral =
+    Literal.create(aggregator, ObjectType(classOf[Aggregator[Any, Any, Any]]))
 
-  // We let the dataset do the binding for us.
-  lazy val boundA = aEncoder.get
+  private def bufferExternalType = bufferDeserializer.dataType
 
-  private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
-    var i = 0
-    while (i < aggBufferAttributes.length) {
-      val offset = mutableAggBufferOffset + i
-      aggBufferSchema(i).dataType match {
-        case BooleanType => buffer.setBoolean(offset, value.getBoolean(i))
-        case ByteType => buffer.setByte(offset, value.getByte(i))
-        case ShortType => buffer.setShort(offset, value.getShort(i))
-        case IntegerType => buffer.setInt(offset, value.getInt(i))
-        case LongType => buffer.setLong(offset, value.getLong(i))
-        case FloatType => buffer.setFloat(offset, value.getFloat(i))
-        case DoubleType => buffer.setDouble(offset, value.getDouble(i))
-        case other => buffer.update(offset, value.get(i, other))
-      }
-      i += 1
-    }
-  }
+  override lazy val aggBufferAttributes: Seq[AttributeReference] =
+    bufferSerializer.toAttribute.asInstanceOf[AttributeReference] :: Nil
 
-  override def initialize(buffer: MutableRow): Unit = {
-    val zero = bEncoder.toRow(aggregator.zero)
-    updateBuffer(buffer, zero)
+  override lazy val initialValues: Seq[Expression] = {
+    val zero = Literal.fromObject(aggregator.zero, bufferExternalType)
+    ReferenceToExpressions(bufferSerializer, zero :: Nil) :: Nil
   }
 
-  override def update(buffer: MutableRow, input: InternalRow): Unit = {
-    val inputA = boundA.fromRow(input)
-    val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
-    val merged = aggregator.reduce(currentB, inputA)
-    val returned = bEncoder.toRow(merged)
+  override lazy val updateExpressions: Seq[Expression] = {
+    val reduced = Invoke(
+      aggregatorLiteral,
+      "reduce",
+      bufferExternalType,
+      bufferDeserializer :: inputDeserializer.get :: Nil)
 
-    updateBuffer(buffer, returned)
+    ReferenceToExpressions(bufferSerializer, reduced :: Nil) :: Nil
   }
 
-  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
-    val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1)
-    val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2)
-    val merged = aggregator.merge(b1, b2)
-    val returned = bEncoder.toRow(merged)
+  override lazy val mergeExpressions: Seq[Expression] = {
+    val leftBuffer = bufferDeserializer transform {
+      case a: AttributeReference => a.left
+    }
+    val rightBuffer = bufferDeserializer transform {
+      case a: AttributeReference => a.right
+    }
+    val merged = Invoke(
+      aggregatorLiteral,
+      "merge",
+      bufferExternalType,
+      leftBuffer :: rightBuffer :: Nil)
 
-    updateBuffer(buffer1, returned)
+    ReferenceToExpressions(bufferSerializer, merged :: Nil) :: Nil
   }
 
-  override def eval(buffer: InternalRow): Any = {
-    val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer)
-    val result = cEncoder.toRow(aggregator.finish(b))
+  override lazy val evaluateExpression: Expression = {
+    val resultObj = Invoke(
+      aggregatorLiteral,
+      "finish",
+      outputExternalType,
+      bufferDeserializer :: Nil)
+
     dataType match {
-      case _: StructType => result
-      case _ => result.get(0, dataType)
+      case s: StructType =>
+        ReferenceToExpressions(CreateStruct(outputSerializer), resultObj :: Nil)
+      case _ =>
+        assert(outputSerializer.length == 1)
+        outputSerializer.head transform {
+          case b: BoundReference => resultObj
+        }
     }
   }
 
   override def toString: String = {
-    s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})"""
+    val input = inputDeserializer match {
+      case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString
+      case Some(deserializer) => deserializer.dataType.simpleString
+      case _ => "unknown"
+    }
+
+    s"$nodeName($input)"
   }
 
-  override def nodeName: String = aggregator.getClass.getSimpleName
+  override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 7da8379..baae9dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -17,14 +17,14 @@
 
 package org.apache.spark.sql.expressions
 
-import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn}
+import org.apache.spark.sql.{Dataset, Encoder, TypedColumn}
 import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 
 /**
- * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]]
- * operations to take all of the elements of a group and reduce them to a single value.
+ * A base class for user-defined aggregations, which can be used in [[Dataset]] operations to take
+ * all of the elements of a group and reduce them to a single value.
  *
  * For example, the following aggregator extracts an `int` from a specific class and adds them up:
  * {{{

http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
index 5f3dd90..ae9fb80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -18,6 +18,9 @@
 package org.apache.spark.sql
 
 import org.apache.spark.SparkContext
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.expressions.scala.typed
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.StringType
 import org.apache.spark.util.Benchmark
 
@@ -33,16 +36,17 @@ object DatasetBenchmark {
 
     val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
     val benchmark = new Benchmark("back-to-back map", numRows)
-
     val func = (d: Data) => Data(d.l + 1, d.s)
-    benchmark.addCase("Dataset") { iter =>
-      var res = df.as[Data]
+
+    val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+    benchmark.addCase("RDD") { iter =>
+      var res = rdd
       var i = 0
       while (i < numChains) {
-        res = res.map(func)
+        res = rdd.map(func)
         i += 1
       }
-      res.queryExecution.toRdd.foreach(_ => Unit)
+      res.foreach(_ => Unit)
     }
 
     benchmark.addCase("DataFrame") { iter =>
@@ -55,15 +59,14 @@ object DatasetBenchmark {
       res.queryExecution.toRdd.foreach(_ => Unit)
     }
 
-    val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
-    benchmark.addCase("RDD") { iter =>
-      var res = rdd
+    benchmark.addCase("Dataset") { iter =>
+      var res = df.as[Data]
       var i = 0
       while (i < numChains) {
-        res = rdd.map(func)
+        res = res.map(func)
         i += 1
       }
-      res.foreach(_ => Unit)
+      res.queryExecution.toRdd.foreach(_ => Unit)
     }
 
     benchmark
@@ -74,19 +77,20 @@ object DatasetBenchmark {
 
     val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
     val benchmark = new Benchmark("back-to-back filter", numRows)
-
     val func = (d: Data, i: Int) => d.l % (100L + i) == 0L
     val funcs = 0.until(numChains).map { i =>
       (d: Data) => func(d, i)
     }
-    benchmark.addCase("Dataset") { iter =>
-      var res = df.as[Data]
+
+    val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+    benchmark.addCase("RDD") { iter =>
+      var res = rdd
       var i = 0
       while (i < numChains) {
-        res = res.filter(funcs(i))
+        res = rdd.filter(funcs(i))
         i += 1
       }
-      res.queryExecution.toRdd.foreach(_ => Unit)
+      res.foreach(_ => Unit)
     }
 
     benchmark.addCase("DataFrame") { iter =>
@@ -99,15 +103,54 @@ object DatasetBenchmark {
       res.queryExecution.toRdd.foreach(_ => Unit)
     }
 
-    val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
-    benchmark.addCase("RDD") { iter =>
-      var res = rdd
+    benchmark.addCase("Dataset") { iter =>
+      var res = df.as[Data]
       var i = 0
       while (i < numChains) {
-        res = rdd.filter(funcs(i))
+        res = res.filter(funcs(i))
         i += 1
       }
-      res.foreach(_ => Unit)
+      res.queryExecution.toRdd.foreach(_ => Unit)
+    }
+
+    benchmark
+  }
+
+  object ComplexAggregator extends Aggregator[Data, Data, Long] {
+    override def zero: Data = Data(0, "")
+
+    override def reduce(b: Data, a: Data): Data = Data(b.l + a.l, "")
+
+    override def finish(reduction: Data): Long = reduction.l
+
+    override def merge(b1: Data, b2: Data): Data = Data(b1.l + b2.l, "")
+
+    override def bufferEncoder: Encoder[Data] = Encoders.product[Data]
+
+    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
+  }
+
+  def aggregate(sqlContext: SQLContext, numRows: Long): Benchmark = {
+    import sqlContext.implicits._
+
+    val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+    val benchmark = new Benchmark("aggregate", numRows)
+
+    val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+    benchmark.addCase("RDD sum") { iter =>
+      rdd.aggregate(0L)(_ + _.l, _ + _)
+    }
+
+    benchmark.addCase("DataFrame sum") { iter =>
+      df.select(sum($"l")).queryExecution.toRdd.foreach(_ => Unit)
+    }
+
+    benchmark.addCase("Dataset sum using Aggregator") { iter =>
+      df.as[Data].select(typed.sumLong((d: Data) => d.l)).queryExecution.toRdd.foreach(_ => Unit)
+    }
+
+    benchmark.addCase("Dataset complex Aggregator") { iter =>
+      df.as[Data].select(ComplexAggregator.toColumn).queryExecution.toRdd.foreach(_ => Unit)
     }
 
     benchmark
@@ -117,30 +160,45 @@ object DatasetBenchmark {
     val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
     val sqlContext = new SQLContext(sparkContext)
 
-    val numRows = 10000000
+    val numRows = 100000000
     val numChains = 10
 
     val benchmark = backToBackMap(sqlContext, numRows, numChains)
     val benchmark2 = backToBackFilter(sqlContext, numRows, numChains)
+    val benchmark3 = aggregate(sqlContext, numRows)
 
     /*
     Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
     Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
     back-to-back map:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Dataset                                   902 /  995         11.1          90.2       1.0X
-    DataFrame                                 132 /  167         75.5          13.2       6.8X
-    RDD                                       216 /  237         46.3          21.6       4.2X
+    RDD                                      1935 / 2105         51.7          19.3       1.0X
+    DataFrame                                 756 /  799        132.3           7.6       2.6X
+    Dataset                                  7359 / 7506         13.6          73.6       0.3X
     */
     benchmark.run()
 
     /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
     back-to-back filter:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Dataset                                   585 /  628         17.1          58.5       1.0X
-    DataFrame                                  62 /   80        160.7           6.2       9.4X
-    RDD                                       205 /  220         48.7          20.5       2.8X
+    RDD                                      1974 / 2036         50.6          19.7       1.0X
+    DataFrame                                 103 /  127        967.4           1.0      19.1X
+    Dataset                                  4343 / 4477         23.0          43.4       0.5X
     */
     benchmark2.run()
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+    aggregate:                          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    RDD sum                                  2130 / 2166         46.9          21.3       1.0X
+    DataFrame sum                              92 /  128       1085.3           0.9      23.1X
+    Dataset sum using Aggregator             4111 / 4282         24.3          41.1       0.5X
+    Dataset complex Aggregator               8782 / 9036         11.4          87.8       0.2X
+    */
+    benchmark3.run()
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 8268628..23a0ce2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.LogicalRDD
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.streaming.MemoryPlan
@@ -205,6 +206,7 @@ abstract class QueryTest extends PlanTest {
       case _: MemoryPlan => return
     }.transformAllExpressions {
       case a: ImperativeAggregate => return
+      case _: TypedAggregateExpression => return
       case Literal(_, _: ObjectType) => return
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 4474cfc..8efd9de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.joins.BroadcastHashJoin
+import org.apache.spark.sql.expressions.scala.typed
 import org.apache.spark.sql.functions.{avg, broadcast, col, max}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -99,4 +100,17 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
         p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined)
     assert(ds.collect() === Array(0, 6))
   }
+
+  test("simple typed UDAF should be included in WholeStageCodegen") {
+    import testImplicits._
+
+    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
+      .groupByKey(_._1).agg(typed.sum(_._2))
+
+    val plan = ds.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined)
+    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
+  }
 }

