Repository: spark
Updated Branches:
  refs/heads/master b5c60bcdc -> 297ba3f1b
[SPARK-14275][SQL] Reimplement TypedAggregateExpression to DeclarativeAggregate

## What changes were proposed in this pull request?

`ExpressionEncoder` is just a container for serialization and deserialization expressions, so we can use these expressions to build `TypedAggregateExpression` directly and make it fit into `DeclarativeAggregate`, which is more efficient. One trick is that each buffer serializer expression references the result object produced by the deserialization and function call. To avoid re-computing this result object, we serialize the buffer object into a single struct field, so that a special `Expression` can evaluate the result object only once.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #12067 from cloud-fan/typed_udaf.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/297ba3f1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/297ba3f1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/297ba3f1

Branch: refs/heads/master
Commit: 297ba3f1b49cc37d9891a529142c553e0a5e2d62
Parents: b5c60bc
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Fri Apr 15 12:10:00 2016 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Apr 15 12:10:00 2016 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/Expression.scala   |   2 +-
 .../expressions/ReferenceToExpressions.scala    |  77 ++++++++
 .../sql/catalyst/expressions/literals.scala     |   3 +-
 .../org/apache/spark/sql/types/ObjectType.scala |   2 +
 .../scala/org/apache/spark/sql/Column.scala     |  16 +-
 .../scala/org/apache/spark/sql/Dataset.scala    |   4 +-
 .../spark/sql/KeyValueGroupedDataset.scala      |   3 +-
 .../aggregate/TypedAggregateExpression.scala    | 192 ++++++++++---------
 .../spark/sql/expressions/Aggregator.scala      |   6 +-
 .../org/apache/spark/sql/DatasetBenchmark.scala | 112 ++++++++---
 .../scala/org/apache/spark/sql/QueryTest.scala  |   2 +
 .../sql/execution/WholeStageCodegenSuite.scala  |  14 ++
 12 files changed, 303 insertions(+), 130 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a24a5db..718bb4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -185,7 +185,7 @@ abstract class Expression extends TreeNode[Expression] { * Returns a user-facing string representation of this expression's name. * This should usually match the name of the function in SQL.
*/ - def prettyName: String = getClass.getSimpleName.toLowerCase + def prettyName: String = nodeName.toLowerCase private def flatArguments = productIterator.flatMap { case t: Traversable[_] => t http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala new file mode 100644 index 0000000..22645c9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.DataType + +/** + * A special expression that evaluates [[BoundReference]]s by given expressions instead of the + * input row. + * + * @param result The expression that contains [[BoundReference]] and produces the final output. + * @param children The expressions that used as input values for [[BoundReference]]. 
+ */ +case class ReferenceToExpressions(result: Expression, children: Seq[Expression]) + extends Expression { + + override def nullable: Boolean = result.nullable + override def dataType: DataType = result.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + if (result.references.nonEmpty) { + return TypeCheckFailure("The result expression cannot reference to any attributes.") + } + + var maxOrdinal = -1 + result foreach { + case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal + } + if (maxOrdinal > children.length) { + return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " + + s"there are only ${children.length} inputs.") + } + + TypeCheckSuccess + } + + private lazy val projection = UnsafeProjection.create(children) + + override def eval(input: InternalRow): Any = { + result.eval(projection(input)) + } + + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val childrenGen = children.map(_.gen(ctx)) + val childrenVars = childrenGen.zip(children).map { + case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType) + } + + val resultGen = result.transform { + case b: BoundReference => childrenVars(b.ordinal) + }.gen(ctx) + + ev.value = resultGen.value + ev.isNull = resultGen.isNull + + childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e6804d0..7fd4bc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -60,7 +60,8 @@ object Literal { * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object * into code generation. 
*/ - def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass)) + def fromObject(obj: Any, objType: DataType): Literal = new Literal(obj, objType) + def fromObject(obj: Any): Literal = new Literal(obj, ObjectType(obj.getClass)) def fromJSON(json: JValue): Literal = { val dataType = DataType.parseDataType(json \ "dataType") http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 06ee0fb..b7b1acc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -41,4 +41,6 @@ private[sql] case class ObjectType(cls: Class[_]) extends DataType { throw new UnsupportedOperationException("No size estimation available for objects.") def asNullable: DataType = this + + override def simpleString: String = cls.getName } http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/main/scala/org/apache/spark/sql/Column.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index d64736e..bd96941 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -59,14 +59,14 @@ class TypedColumn[-T, U]( * on a decoded object. */ private[sql] def withInputType( - inputEncoder: ExpressionEncoder[_], - schema: Seq[Attribute]): TypedColumn[T, U] = { - val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]] - new TypedColumn[T, U]( - expr transform { case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => - ta.copy(aEncoder = Some(boundEncoder), children = schema) - }, - encoder) + inputDeserializer: Expression, + inputAttributes: Seq[Attribute]): TypedColumn[T, U] = { + val unresolvedDeserializer = UnresolvedDeserializer(inputDeserializer, inputAttributes) + val newExpr = expr transform { + case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => + ta.copy(inputDeserializer = Some(unresolvedDeserializer)) + } + new TypedColumn[T, U](newExpr, encoder) } } http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e216945..4edc90d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -992,7 +992,7 @@ class Dataset[T] private[sql]( sqlContext, Project( c1.withInputType( - boundTEncoder, + unresolvedTEncoder.deserializer, logicalPlan.output).named :: Nil, logicalPlan), implicitly[Encoder[U1]]) @@ -1006,7 +1006,7 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named) + columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named) val execution = new 
QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index f19ad6e..05e13e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -209,8 +209,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map( - _.withInputType(resolvedVEncoder, dataAttributes).named) + columns.map(_.withInputType(unresolvedVEncoder.deserializer, dataAttributes).named) val keyColumn = if (resolvedKEncoder.flat) { assert(groupingAttributes.length == 1) groupingAttributes.head http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 9abae53..535e64c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -19,133 +19,153 @@ package org.apache.spark.sql.execution.aggregate import scala.language.existentials -import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ object TypedAggregateExpression { - def apply[A, B : Encoder, C : Encoder]( - aggregator: Aggregator[A, B, C]): TypedAggregateExpression = { + def apply[BUF : Encoder, OUT : Encoder]( + aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { + val bufferEncoder = encoderFor[BUF] + // We will insert the deserializer and function call expression at the bottom of each serializer + // expression while executing `TypedAggregateExpression`, which means multiply serializer + // expressions will all evaluate the same sub-expression at bottom. To avoid the re-evaluating, + // here we always use one single serializer expression to serialize the buffer object into a + // single-field row, no matter whether the encoder is flat or not. We also need to update the + // deserializer to read in all fields from that single-field row. 
+ // TODO: remove this trick after we have better integration of subexpression elimination and + // whole stage codegen. + val bufferSerializer = if (bufferEncoder.flat) { + bufferEncoder.namedExpressions.head + } else { + Alias(CreateStruct(bufferEncoder.serializer), "buffer")() + } + + val bufferDeserializer = if (bufferEncoder.flat) { + bufferEncoder.deserializer transformUp { + case b: BoundReference => bufferSerializer.toAttribute + } + } else { + bufferEncoder.deserializer transformUp { + case UnresolvedAttribute(nameParts) => + assert(nameParts.length == 1) + UnresolvedExtractValue(bufferSerializer.toAttribute, Literal(nameParts.head)) + case BoundReference(ordinal, dt, _) => GetStructField(bufferSerializer.toAttribute, ordinal) + } + } + + val outputEncoder = encoderFor[OUT] + val outputType = if (outputEncoder.flat) { + outputEncoder.schema.head.dataType + } else { + outputEncoder.schema + } + new TypedAggregateExpression( aggregator.asInstanceOf[Aggregator[Any, Any, Any]], None, - encoderFor[B].asInstanceOf[ExpressionEncoder[Any]], - encoderFor[C].asInstanceOf[ExpressionEncoder[Any]], - Nil, - 0, - 0) + bufferSerializer, + bufferDeserializer, + outputEncoder.serializer, + outputEncoder.deserializer.dataType, + outputType) } } /** - * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has - * the following limitations: - * - It assumes the aggregator has a zero, `0`. + * A helper class to hook [[Aggregator]] into the aggregation system. */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], - aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. - unresolvedBEncoder: ExpressionEncoder[Any], - cEncoder: ExpressionEncoder[Any], - children: Seq[Attribute], - mutableAggBufferOffset: Int, - inputAggBufferOffset: Int) - extends ImperativeAggregate with Logging { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) + inputDeserializer: Option[Expression], + bufferSerializer: NamedExpression, + bufferDeserializer: Expression, + outputSerializer: Seq[Expression], + outputExternalType: DataType, + dataType: DataType) extends DeclarativeAggregate with NonSQLExpression { override def nullable: Boolean = true - override def dataType: DataType = if (cEncoder.flat) { - cEncoder.schema.head.dataType - } else { - cEncoder.schema - } - override def deterministic: Boolean = true - override lazy val resolved: Boolean = aEncoder.isDefined - - override lazy val inputTypes: Seq[DataType] = Nil + override def children: Seq[Expression] = inputDeserializer.toSeq :+ bufferDeserializer - override val aggBufferSchema: StructType = unresolvedBEncoder.schema + override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved - override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq) - val bEncoder = unresolvedBEncoder - .resolve(aggBufferAttributes, OuterScopes.outerScopes) - .bind(aggBufferAttributes) + override def inputTypes: Seq[AbstractDataType] = Nil - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. 
- override val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) + private def aggregatorLiteral = + Literal.create(aggregator, ObjectType(classOf[Aggregator[Any, Any, Any]])) - // We let the dataset do the binding for us. - lazy val boundA = aEncoder.get + private def bufferExternalType = bufferDeserializer.dataType - private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { - var i = 0 - while (i < aggBufferAttributes.length) { - val offset = mutableAggBufferOffset + i - aggBufferSchema(i).dataType match { - case BooleanType => buffer.setBoolean(offset, value.getBoolean(i)) - case ByteType => buffer.setByte(offset, value.getByte(i)) - case ShortType => buffer.setShort(offset, value.getShort(i)) - case IntegerType => buffer.setInt(offset, value.getInt(i)) - case LongType => buffer.setLong(offset, value.getLong(i)) - case FloatType => buffer.setFloat(offset, value.getFloat(i)) - case DoubleType => buffer.setDouble(offset, value.getDouble(i)) - case other => buffer.update(offset, value.get(i, other)) - } - i += 1 - } - } + override lazy val aggBufferAttributes: Seq[AttributeReference] = + bufferSerializer.toAttribute.asInstanceOf[AttributeReference] :: Nil - override def initialize(buffer: MutableRow): Unit = { - val zero = bEncoder.toRow(aggregator.zero) - updateBuffer(buffer, zero) + override lazy val initialValues: Seq[Expression] = { + val zero = Literal.fromObject(aggregator.zero, bufferExternalType) + ReferenceToExpressions(bufferSerializer, zero :: Nil) :: Nil } - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val inputA = boundA.fromRow(input) - val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) - val merged = aggregator.reduce(currentB, inputA) - val returned = bEncoder.toRow(merged) + override lazy val updateExpressions: Seq[Expression] = { + val reduced = Invoke( + aggregatorLiteral, + "reduce", + bufferExternalType, + bufferDeserializer :: inputDeserializer.get :: Nil) - updateBuffer(buffer, returned) + ReferenceToExpressions(bufferSerializer, reduced :: Nil) :: Nil } - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1) - val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2) - val merged = aggregator.merge(b1, b2) - val returned = bEncoder.toRow(merged) + override lazy val mergeExpressions: Seq[Expression] = { + val leftBuffer = bufferDeserializer transform { + case a: AttributeReference => a.left + } + val rightBuffer = bufferDeserializer transform { + case a: AttributeReference => a.right + } + val merged = Invoke( + aggregatorLiteral, + "merge", + bufferExternalType, + leftBuffer :: rightBuffer :: Nil) - updateBuffer(buffer1, returned) + ReferenceToExpressions(bufferSerializer, merged :: Nil) :: Nil } - override def eval(buffer: InternalRow): Any = { - val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) - val result = cEncoder.toRow(aggregator.finish(b)) + override lazy val evaluateExpression: Expression = { + val resultObj = Invoke( + aggregatorLiteral, + "finish", + outputExternalType, + bufferDeserializer :: Nil) + dataType match { - case _: StructType => result - case _ => result.get(0, dataType) + case s: StructType => + ReferenceToExpressions(CreateStruct(outputSerializer), resultObj :: Nil) + case _ => + assert(outputSerializer.length == 1) + outputSerializer.head transform { + case b: BoundReference => resultObj + } } } override def toString: String = { - 
s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})""" + val input = inputDeserializer match { + case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString + case Some(deserializer) => deserializer.dataType.simpleString + case _ => "unknown" + } + + s"$nodeName($input)" } - override def nodeName: String = aggregator.getClass.getSimpleName + override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") } http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 7da8379..baae9dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} +import org.apache.spark.sql.{Dataset, Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression /** - * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] - * operations to take all of the elements of a group and reduce them to a single value. + * A base class for user-defined aggregations, which can be used in [[Dataset]] operations to take + * all of the elements of a group and reduce them to a single value. 
* * For example, the following aggregator extracts an `int` from a specific class and adds them up: * {{{ http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index 5f3dd90..ae9fb80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql import org.apache.spark.SparkContext +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.expressions.scala.typed +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.StringType import org.apache.spark.util.Benchmark @@ -33,16 +36,17 @@ object DatasetBenchmark { val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back map", numRows) - val func = (d: Data) => Data(d.l + 1, d.s) - benchmark.addCase("Dataset") { iter => - var res = df.as[Data] + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd var i = 0 while (i < numChains) { - res = res.map(func) + res = rdd.map(func) i += 1 } - res.queryExecution.toRdd.foreach(_ => Unit) + res.foreach(_ => Unit) } benchmark.addCase("DataFrame") { iter => @@ -55,15 +59,14 @@ object DatasetBenchmark { res.queryExecution.toRdd.foreach(_ => Unit) } - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) - benchmark.addCase("RDD") { iter => - var res = rdd + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] var i = 0 while (i < numChains) { - res = rdd.map(func) + res = res.map(func) i += 1 } - res.foreach(_ => Unit) + res.queryExecution.toRdd.foreach(_ => Unit) } benchmark @@ -74,19 +77,20 @@ object DatasetBenchmark { val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back filter", numRows) - val func = (d: Data, i: Int) => d.l % (100L + i) == 0L val funcs = 0.until(numChains).map { i => (d: Data) => func(d, i) } - benchmark.addCase("Dataset") { iter => - var res = df.as[Data] + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd var i = 0 while (i < numChains) { - res = res.filter(funcs(i)) + res = rdd.filter(funcs(i)) i += 1 } - res.queryExecution.toRdd.foreach(_ => Unit) + res.foreach(_ => Unit) } benchmark.addCase("DataFrame") { iter => @@ -99,15 +103,54 @@ object DatasetBenchmark { res.queryExecution.toRdd.foreach(_ => Unit) } - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) - benchmark.addCase("RDD") { iter => - var res = rdd + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] var i = 0 while (i < numChains) { - res = rdd.filter(funcs(i)) + res = res.filter(funcs(i)) i += 1 } - res.foreach(_ => Unit) + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark + } + + object ComplexAggregator extends Aggregator[Data, Data, Long] { + override def zero: Data = Data(0, "") + + override def reduce(b: Data, a: Data): Data = Data(b.l + a.l, "") + + override def finish(reduction: Data): Long = reduction.l + + override def 
merge(b1: Data, b2: Data): Data = Data(b1.l + b2.l, "") + + override def bufferEncoder: Encoder[Data] = Encoders.product[Data] + + override def outputEncoder: Encoder[Long] = Encoders.scalaLong + } + + def aggregate(sqlContext: SQLContext, numRows: Long): Benchmark = { + import sqlContext.implicits._ + + val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val benchmark = new Benchmark("aggregate", numRows) + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD sum") { iter => + rdd.aggregate(0L)(_ + _.l, _ + _) + } + + benchmark.addCase("DataFrame sum") { iter => + df.select(sum($"l")).queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset sum using Aggregator") { iter => + df.as[Data].select(typed.sumLong((d: Data) => d.l)).queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset complex Aggregator") { iter => + df.as[Data].select(ComplexAggregator.toColumn).queryExecution.toRdd.foreach(_ => Unit) } benchmark @@ -117,30 +160,45 @@ object DatasetBenchmark { val sparkContext = new SparkContext("local[*]", "Dataset benchmark") val sqlContext = new SQLContext(sparkContext) - val numRows = 10000000 + val numRows = 100000000 val numChains = 10 val benchmark = backToBackMap(sqlContext, numRows, numChains) val benchmark2 = backToBackFilter(sqlContext, numRows, numChains) + val benchmark3 = aggregate(sqlContext, numRows) /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Dataset 902 / 995 11.1 90.2 1.0X - DataFrame 132 / 167 75.5 13.2 6.8X - RDD 216 / 237 46.3 21.6 4.2X + RDD 1935 / 2105 51.7 19.3 1.0X + DataFrame 756 / 799 132.3 7.6 2.6X + Dataset 7359 / 7506 13.6 73.6 0.3X */ benchmark.run() /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz back-to-back filter: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Dataset 585 / 628 17.1 58.5 1.0X - DataFrame 62 / 80 160.7 6.2 9.4X - RDD 205 / 220 48.7 20.5 2.8X + RDD 1974 / 2036 50.6 19.7 1.0X + DataFrame 103 / 127 967.4 1.0 19.1X + Dataset 4343 / 4477 23.0 43.4 0.5X */ benchmark2.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + aggregate: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + RDD sum 2130 / 2166 46.9 21.3 1.0X + DataFrame sum 92 / 128 1085.3 0.9 23.1X + Dataset sum using Aggregator 4111 / 4282 24.3 41.1 0.5X + Dataset complex Aggregator 8782 / 9036 11.4 87.8 0.2X + */ + benchmark3.run() } } http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8268628..23a0ce2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import 
org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.MemoryPlan @@ -205,6 +206,7 @@ abstract class QueryTest extends PlanTest { case _: MemoryPlan => return }.transformAllExpressions { case a: ImperativeAggregate => return + case _: TypedAggregateExpression => return case Literal(_, _: ObjectType) => return } http://git-wip-us.apache.org/repos/asf/spark/blob/297ba3f1/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 4474cfc..8efd9de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.BroadcastHashJoin +import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions.{avg, broadcast, col, max} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -99,4 +100,17 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined) assert(ds.collect() === Array(0, 6)) } + + test("simple typed UDAF should be included in WholeStageCodegen") { + import testImplicits._ + + val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS() + .groupByKey(_._1).agg(typed.sum(_._2)) + + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined) + assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
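
For readers who want to see what the reworked typed aggregation path looks like from the user side, here is a small, self-contained sketch modeled on the `ComplexAggregator` and `typed.sumLong` usages added to `DatasetBenchmark` in this diff. The `SumOfL` object, the `Data` case class, the `TypedAggregateExample` driver and the local-mode setup are illustrative only and are not part of this patch; the `Aggregator`, `typed` and `toColumn` APIs are the ones exercised by the new benchmark and test code above.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Encoder, Encoders, SQLContext}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.scala.typed

// Illustrative input type, mirroring the Data class used in DatasetBenchmark.
case class Data(l: Long, s: String)

// A typed UDAF: the buffer is a Data object, the final output is a Long.
// With this patch the aggregator is planned as a DeclarativeAggregate, so the
// buffer stays in serialized form between updates.
object SumOfL extends Aggregator[Data, Data, Long] {
  override def zero: Data = Data(0, "")
  override def reduce(b: Data, a: Data): Data = Data(b.l + a.l, "")
  override def merge(b1: Data, b2: Data): Data = Data(b1.l + b2.l, "")
  override def finish(reduction: Data): Long = reduction.l
  override def bufferEncoder: Encoder[Data] = Encoders.product[Data]
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

object TypedAggregateExample {
  def main(args: Array[String]): Unit = {
    val sparkContext = new SparkContext("local[*]", "typed-aggregate-example")
    val sqlContext = new SQLContext(sparkContext)
    import sqlContext.implicits._

    val ds = Seq(Data(1, "a"), Data(2, "b"), Data(3, "c")).toDS()

    // Use the Aggregator as a typed column over the whole Dataset.
    ds.select(SumOfL.toColumn).show()

    // Built-in typed aggregates from `typed` go through the same
    // TypedAggregateExpression path when used after groupByKey.
    ds.groupByKey(_.s).agg(typed.sumLong((d: Data) => d.l)).show()

    sparkContext.stop()
  }
}
```

Run against this commit, both queries plan the aggregator through the new DeclarativeAggregate-based `TypedAggregateExpression` instead of the previous ImperativeAggregate implementation, which is what the new WholeStageCodegenSuite test checks for the grouped case.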