Repository: spark
Updated Branches:
  refs/heads/master 2f3837885 -> 9c740a9dd


[SPARK-11578][SQL] User API for Typed Aggregation

This PR adds a new interface for user-defined aggregations, which can be used in
`DataFrame` and `Dataset` operations to take all of the elements of a group and
reduce them to a single value.

For example, the following aggregator extracts an `Int` from a specific class
and adds the values up:

```scala
  case class Data(i: Int)

  val customSummer = new Aggregator[Data, Int, Int] {
    def zero = 0
    def reduce(b: Int, a: Data) = b + a.i
    def present(r: Int) = r
  }.toColumn

  val ds: Dataset[Data] = ...
  val aggregated = ds.select(customSummer)
```

By using helper functions, users can make a generic `Aggregator` that works on 
any input type:

```scala
/** An `Aggregator` that adds up any numeric type returned by the given 
function. */
class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with 
Serializable {
  val numeric = implicitly[Numeric[N]]
  override def zero: N = numeric.zero
  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
  override def present(reduction: N): N = reduction
}

def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new 
SumOf(f).toColumn
```

These aggregators can then be used alongside other built-in SQL aggregations.

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
ds
  .groupBy(_._1)
  .agg(
    sum(_._2),                // The aggregator defined above.
    expr("sum(_2)").as[Int],  // A built-in dynamically typed aggregation.
    count("*"))               // A built-in statically typed aggregation.
  .collect()

res0: ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)
```

The current implementation focuses on integrating this into the typed API, but
only supports running aggregations that return a single long value, as
explained in `TypedAggregateExpression`. This will be improved in a follow-up
PR.

Author: Michael Armbrust <mich...@databricks.com>

Closes #9555 from marmbrus/dataset-useragg.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9c740a9d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9c740a9d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9c740a9d

Branch: refs/heads/master
Commit: 9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5
Parents: 2f38378
Author: Michael Armbrust <mich...@databricks.com>
Authored: Mon Nov 9 16:11:00 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon Nov 9 16:11:00 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Column.scala     |  11 +-
 .../scala/org/apache/spark/sql/Dataset.scala    |  30 ++---
 .../org/apache/spark/sql/GroupedDataset.scala   |  51 +++++---
 .../scala/org/apache/spark/sql/SQLContext.scala |   1 -
 .../aggregate/TypedAggregateExpression.scala    | 129 +++++++++++++++++++
 .../spark/sql/expressions/Aggregator.scala      |  81 ++++++++++++
 .../scala/org/apache/spark/sql/functions.scala  |  30 ++++-
 .../org/apache/spark/sql/JavaDatasetSuite.java  |   4 +-
 .../spark/sql/DatasetAggregatorSuite.scala      |  65 ++++++++++
 9 files changed, 360 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9c740a9d/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index c32c938..d26b6c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.DataTypeParser
 import org.apache.spark.sql.types._
@@ -39,10 +39,13 @@ private[sql] object Column {
 }
 
 /**
- * A [[Column]] where an [[Encoder]] has been given for the expected return 
type.
+ * A [[Column]] where an [[Encoder]] has been given for the expected input and 
return type.
  * @since 1.6.0
+ * @tparam T The input type expected for this expression.  Can be `Any` if the 
expression is type
+ *           checked by the analyzer instead of the compiler (i.e. 
`expr("sum(...)")`).
+ * @tparam U The output type of this column.
  */
-class TypedColumn[T](expr: Expression)(implicit val encoder: Encoder[T]) 
extends Column(expr)
+class TypedColumn[-T, U](expr: Expression, val encoder: Encoder[U]) extends 
Column(expr)
 
 /**
  * :: Experimental ::
@@ -85,7 +88,7 @@ class Column(protected[sql] val expr: Expression) extends 
Logging {
    * results into the correct JVM types.
    * @since 1.6.0
    */
-  def as[T : Encoder]: TypedColumn[T] = new TypedColumn[T](expr)
+  def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, 
encoderFor[U])
 
   /**
    * Extracts a value or values from a complex type.

http://git-wip-us.apache.org/repos/asf/spark/blob/9c740a9d/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 959e0f5..6d2968e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -358,7 +358,7 @@ class Dataset[T] private[sql](
    * }}}
    * @since 1.6.0
    */
-  def select[U1: Encoder](c1: TypedColumn[U1]): Dataset[U1] = {
+  def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
     new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, 
logicalPlan))
   }
 
@@ -367,7 +367,7 @@ class Dataset[T] private[sql](
    * code reuse, we do this without the help of the type system and then use 
helper functions
    * that cast appropriately for the user facing interface.
    */
-  protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+  protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, 
s"_${i + 1}")() }
     val unresolvedPlan = Project(aliases, logicalPlan)
     val execution = new QueryExecution(sqlContext, unresolvedPlan)
@@ -385,7 +385,7 @@ class Dataset[T] private[sql](
    * Returns a new [[Dataset]] by computing the given [[Column]] expressions 
for each element.
    * @since 1.6.0
    */
-  def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, 
U2)] =
+  def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): 
Dataset[(U1, U2)] =
     selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
 
   /**
@@ -393,9 +393,9 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def select[U1, U2, U3](
-      c1: TypedColumn[U1],
-      c2: TypedColumn[U2],
-      c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] =
+      c1: TypedColumn[T, U1],
+      c2: TypedColumn[T, U2],
+      c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] =
     selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
 
   /**
@@ -403,10 +403,10 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def select[U1, U2, U3, U4](
-      c1: TypedColumn[U1],
-      c2: TypedColumn[U2],
-      c3: TypedColumn[U3],
-      c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] =
+      c1: TypedColumn[T, U1],
+      c2: TypedColumn[T, U2],
+      c3: TypedColumn[T, U3],
+      c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] =
     selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
 
   /**
@@ -414,11 +414,11 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def select[U1, U2, U3, U4, U5](
-      c1: TypedColumn[U1],
-      c2: TypedColumn[U2],
-      c3: TypedColumn[U3],
-      c4: TypedColumn[U4],
-      c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] =
+      c1: TypedColumn[T, U1],
+      c2: TypedColumn[T, U2],
+      c3: TypedColumn[T, U3],
+      c4: TypedColumn[T, U4],
+      c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
     selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, 
U5)]]
 
   /* **************** *

http://git-wip-us.apache.org/repos/asf/spark/blob/9c740a9d/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 850315e..db61499 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import java.util.{Iterator => JIterator}
+
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
@@ -26,8 +27,10 @@ import 
org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttrib
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, 
Encoder}
 import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, 
Alias, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.execution.QueryExecution
 
+
 /**
  * :: Experimental ::
  * A [[Dataset]] has been logically grouped by a user specified grouping key.  
Users should not
@@ -143,7 +146,7 @@ class GroupedDataset[K, T] private[sql](
    * that cast appropriately for the user facing interface.
    * TODO: does not handle aggrecations that return nonflat results,
    */
-  protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+  protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val aliases = (groupingAttributes ++ columns.map(_.expr)).map {
       case u: UnresolvedAttribute => UnresolvedAlias(u)
       case expr: NamedExpression => expr
@@ -151,7 +154,15 @@ class GroupedDataset[K, T] private[sql](
     }
 
     val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan)
-    val execution = new QueryExecution(sqlContext, unresolvedPlan)
+
+    // Fill in the input encoders for any aggregators in the plan.
+    val withEncoders = unresolvedPlan transformAllExpressions {
+      case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
+        ta.copy(
+          aEncoder = Some(tEnc.asInstanceOf[ExpressionEncoder[Any]]),
+          children = dataAttributes)
+    }
+    val execution = new QueryExecution(sqlContext, withEncoders)
 
     val columnEncoders = 
columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]])
 
@@ -162,43 +173,47 @@ class GroupedDataset[K, T] private[sql](
       case (e, a) =>
         e.unbind(a :: Nil).resolve(execution.analyzed.output)
     }
-    new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
+
+    new Dataset(
+      sqlContext,
+      execution,
+      ExpressionEncoder.tuple(encoders))
   }
 
   /**
    * Computes the given aggregation, returning a [[Dataset]] of tuples for 
each unique key
    * and the result of computing this aggregation over all elements in the 
group.
    */
-  def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] =
-    aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]]
+  def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] =
+    aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]
 
   /**
    * Computes the given aggregations, returning a [[Dataset]] of tuples for 
each unique key
    * and the result of computing these aggregations over all elements in the 
group.
    */
-  def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, 
A1, A2)] =
-    aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]]
+  def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): 
Dataset[(K, U1, U2)] =
+    aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]
 
   /**
    * Computes the given aggregations, returning a [[Dataset]] of tuples for 
each unique key
    * and the result of computing these aggregations over all elements in the 
group.
    */
-  def agg[A1, A2, A3](
-      col1: TypedColumn[A1],
-      col2: TypedColumn[A2],
-      col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] =
-    aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]]
+  def agg[U1, U2, U3](
+      col1: TypedColumn[T, U1],
+      col2: TypedColumn[T, U2],
+      col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] =
+    aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]
 
   /**
    * Computes the given aggregations, returning a [[Dataset]] of tuples for 
each unique key
    * and the result of computing these aggregations over all elements in the 
group.
    */
-  def agg[A1, A2, A3, A4](
-      col1: TypedColumn[A1],
-      col2: TypedColumn[A2],
-      col3: TypedColumn[A3],
-      col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] =
-    aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, 
A4)]]
+  def agg[U1, U2, U3, U4](
+      col1: TypedColumn[T, U1],
+      col2: TypedColumn[T, U2],
+      col3: TypedColumn[T, U3],
+      col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] =
+    aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, 
U4)]]
 
   /**
    * Returns a [[Dataset]] that contains a tuple with each key and the number 
of items present

http://git-wip-us.apache.org/repos/asf/spark/blob/9c740a9d/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 5598731..1cf1e30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -21,7 +21,6 @@ import java.beans.{BeanInfo, Introspector}
 import java.util.Properties
 import java.util.concurrent.atomic.AtomicReference
 
-
 import scala.collection.JavaConverters._
 import scala.collection.immutable
 import scala.reflect.runtime.universe.TypeTag

http://git-wip-us.apache.org/repos/asf/spark/blob/9c740a9d/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
new file mode 100644
index 0000000..24d8122
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import scala.language.existentials
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{StructType, DataType}
+
+object TypedAggregateExpression {
+  def apply[A, B : Encoder, C : Encoder](
+      aggregator: Aggregator[A, B, C]): TypedAggregateExpression = {
+    new TypedAggregateExpression(
+      aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
+      None,
+      encoderFor[B].asInstanceOf[ExpressionEncoder[Any]],
+      encoderFor[C].asInstanceOf[ExpressionEncoder[Any]],
+      Nil,
+      0,
+      0)
+  }
+}
+
+/**
+ * This class is a rough sketch of how to hook `Aggregator` into the 
Aggregation system.  It has
+ * the following limitations:
+ *  - It assumes the aggregator reduces and returns a single column of type 
`long`.
+ *  - It might only work when there is a single aggregator in the first column.
+ *  - It assumes the aggregator has a zero, `0`.
+ */
+case class TypedAggregateExpression(
+    aggregator: Aggregator[Any, Any, Any],
+    aEncoder: Option[ExpressionEncoder[Any]],
+    bEncoder: ExpressionEncoder[Any],
+    cEncoder: ExpressionEncoder[Any],
+    children: Seq[Expression],
+    mutableAggBufferOffset: Int,
+    inputAggBufferOffset: Int)
+  extends ImperativeAggregate with Logging {
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): 
ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def nullable: Boolean = true
+
+  // TODO: this assumes flat results...
+  override def dataType: DataType = cEncoder.schema.head.dataType
+
+  override def deterministic: Boolean = true
+
+  override lazy val resolved: Boolean = aEncoder.isDefined
+
+  override lazy val inputTypes: Seq[DataType] =
+    aEncoder.map(_.schema.map(_.dataType)).getOrElse(Nil)
+
+  override val aggBufferSchema: StructType = bEncoder.schema
+
+  override val aggBufferAttributes: Seq[AttributeReference] = 
aggBufferSchema.toAttributes
+
+  // Note: although this simply copies aggBufferAttributes, this common code 
can not be placed
+  // in the superclass because that will lead to initialization ordering 
issues.
+  override val inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
+
+  lazy val inputAttributes = aEncoder.get.schema.toAttributes
+  lazy val inputMapping = AttributeMap(inputAttributes.zip(children))
+  lazy val boundA =
+    aEncoder.get.copy(constructExpression = aEncoder.get.constructExpression 
transform {
+      case a: AttributeReference => inputMapping(a)
+    })
+
+  // TODO: this probably only works when we are in the first column.
+  val bAttributes = bEncoder.schema.toAttributes
+  lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
+
+  override def initialize(buffer: MutableRow): Unit = {
+    // TODO: We need to either force Aggregator to have a zero or we need to 
eliminate the need for
+    // this in execution.
+    buffer.setInt(mutableAggBufferOffset, aggregator.zero.asInstanceOf[Int])
+  }
+
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    val inputA = boundA.fromRow(input)
+    val currentB = boundB.fromRow(buffer)
+    val merged = aggregator.reduce(currentB, inputA)
+    val returned = boundB.toRow(merged)
+    buffer.setInt(mutableAggBufferOffset, returned.getInt(0))
+  }
+
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    buffer1.setLong(
+      mutableAggBufferOffset,
+      buffer1.getLong(mutableAggBufferOffset) + 
buffer2.getLong(inputAggBufferOffset))
+  }
+
+  override def eval(buffer: InternalRow): Any = {
+    buffer.getInt(mutableAggBufferOffset)
+  }
+
+  override def toString: String = {
+    s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})"""
+  }
+
+  override def nodeName: String = aggregator.getClass.getSimpleName
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/9c740a9d/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
new file mode 100644
index 0000000..0b3192a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, 
AggregateExpression2}
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
+
+/**
+ * A base class for user-defined aggregations, which can be used in 
[[DataFrame]] and [[Dataset]]
+ * operations to take all of the elements of a group and reduce them to a 
single value.
+ *
+ * For example, the following aggregator extracts an `int` from a specific 
class and adds them up:
+ * {{{
+ *   case class Data(i: Int)
+ *
+ *   val customSummer =  new Aggregator[Data, Int, Int] {
+ *     def zero = 0
+ *     def reduce(b: Int, a: Data) = b + a.i
+ *     def present(r: Int) = r
+ *   }.toColumn()
+ *
+ *   val ds: Dataset[Data]
+ *   val aggregated = ds.select(customSummer)
+ * }}}
+ *
+ * Based loosely on Aggregator from Algebird: 
https://github.com/twitter/algebird
+ *
+ * @tparam A The input type for the aggregation.
+ * @tparam B The type of the intermediate value of the reduction.
+ * @tparam C The type of the final result.
+ */
+abstract class Aggregator[-A, B, C] {
+
+  /** A zero value for this aggregation. Should satisfy the property that any 
b + zero = b */
+  def zero: B
+
+  /**
+   * Combine two values to produce a new value.  For performance, the function 
may modify `b` and
+   * return it instead of constructing new object for b.
+   */
+  def reduce(b: B, a: A): B
+
+  /**
+   * Transform the output of the reduction.
+   */
+  def present(reduction: B): C
+
+  /**
+   * Returns this `Aggregator` as a [[TypedColumn]] that can be used in 
[[Dataset]] or [[DataFrame]]
+   * operations.
+   */
+  def toColumn(
+      implicit bEncoder: Encoder[B],
+      cEncoder: Encoder[C]): TypedColumn[A, C] = {
+    val expr =
+      new AggregateExpression2(
+        TypedAggregateExpression(this),
+        Complete,
+        false)
+
+    new TypedColumn[A, C](expr, encoderFor[C])
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/9c740a9d/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3f0b24b..6d56542 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+
+
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{TypeTag, typeTag}
 import scala.util.Try
@@ -24,12 +26,33 @@ import scala.util.Try
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
 /**
+ * Ensures that java functions signatures for methods that now return a 
[[TypedColumn]] still have
+ * legacy equivalents in bytecode.  This compatibility is done by forcing the 
compiler to generate
+ * "bridge" methods due to the use of covariant return types.
+ *
+ * {{{
+ * In LegacyFunctions:
+ * public abstract org.apache.spark.sql.Column avg(java.lang.String);
+ *
+ * In functions:
+ * public static org.apache.spark.sql.TypedColumn<java.lang.Object, 
java.lang.Object> avg(...);
+ * }}}
+ *
+ * This allows us to use the same functions both in typed [[Dataset]] 
operations and untyped
+ * [[DataFrame]] operations when the return type for a given function is 
statically known.
+ */
+private[sql] abstract class LegacyFunctions {
+  def count(columnName: String): Column
+}
+
+/**
  * :: Experimental ::
  * Functions available for [[DataFrame]].
  *
@@ -48,11 +71,14 @@ import org.apache.spark.util.Utils
  */
 @Experimental
 // scalastyle:off
-object functions {
+object functions extends LegacyFunctions {
 // scalastyle:on
 
   private def withExpr(expr: Expression): Column = Column(expr)
 
+  private implicit def newLongEncoder: Encoder[Long] = 
ExpressionEncoder[Long](flat = true)
+
+
   /**
    * Returns a [[Column]] based on the given column name.
    *
@@ -234,7 +260,7 @@ object functions {
    * @group agg_funcs
    * @since 1.3.0
    */
-  def count(columnName: String): Column = count(Column(columnName))
+  def count(columnName: String): TypedColumn[Any, Long] = 
count(Column(columnName)).as[Long]
 
   /**
    * Aggregate function: returns the number of distinct items in a group.

http://git-wip-us.apache.org/repos/asf/spark/blob/9c740a9d/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 312cf33..2da63d1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -258,8 +258,8 @@ public class JavaDatasetSuite implements Serializable {
     Dataset<Integer> ds = context.createDataset(data, e.INT());
 
     Dataset<Tuple2<Integer, String>> selected = ds.select(
-      expr("value + 1").as(e.INT()),
-      col("value").cast("string").as(e.STRING()));
+      expr("value + 1"),
+      col("value").cast("string")).as(e.tuple(e.INT(), e.STRING()));
 
     Assert.assertEquals(
       Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),

http://git-wip-us.apache.org/repos/asf/spark/blob/9c740a9d/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
new file mode 100644
index 0000000..340470c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.functions._
+
+import scala.language.postfixOps
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+import org.apache.spark.sql.expressions.Aggregator
+
+/** An `Aggregator` that adds up any numeric type returned by the given 
function. */
+class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with 
Serializable {
+  val numeric = implicitly[Numeric[N]]
+
+  override def zero: N = numeric.zero
+
+  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
+
+  override def present(reduction: N): N = reduction
+}
+
+class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
+
+  import testImplicits._
+
+  def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
+    new SumOf(f).toColumn
+
+  test("typed aggregation: TypedAggregator") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(sum(_._2)),
+      ("a", 30), ("b", 3), ("c", 1))
+  }
+
+  test("typed aggregation: TypedAggregator, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(
+        sum(_._2),
+        expr("sum(_2)").as[Int],
+        count("*")),
+      ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L))
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to