Repository: spark
Updated Branches:
  refs/heads/branch-1.6 f061d2539 -> 6c1bf19e8


[SPARK-11654][SQL] add reduce to GroupedDataset

This PR adds a new method, `reduce`, to `GroupedDataset`, which provides an
operation similar to `reduceByKey` on a traditional `PairRDD`.

```scala
val ds = Seq("abc", "xyz", "hello").toDS()
ds.groupBy(_.length).reduce(_ + _).collect()  // not actually commutative :P

res0: Array(3 -> "abcxyz", 5 -> "hello")
```
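
The semantics of the example above can be approximated with plain Scala collections, with no Spark involved: group the values by a key function, then combine each group pairwise. This is only a local model; `GroupReduceModel` and `groupReduce` are hypothetical names, not Spark APIs, and this is not how `GroupedDataset.reduce` is implemented.

```scala
// Local, Spark-free analogue of `ds.groupBy(_.length).reduce(_ + _)`.
// GroupReduceModel / groupReduce are illustrative names only.
object GroupReduceModel {
  // Group `data` by `key`, then fold every group with the binary function `f`.
  def groupReduce[T, K](data: Seq[T])(key: T => K)(f: (T, T) => T): Map[K, T] =
    data.groupBy(key).map { case (k, values) => k -> values.reduce(f) }

  def main(args: Array[String]): Unit = {
    val result = groupReduce(Seq("abc", "xyz", "hello"))(_.length)(_ + _)
    // Sort by key only so that the printed order is deterministic.
    result.toSeq.sortBy(_._1).foreach { case (k, v) => println(s"$k -> $v") }
  }
}
```

String concatenation is associative but not commutative, so the combined value depends on the order in which a group's elements are combined. The local `groupBy` happens to preserve input order within each group.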

While implementing this method and its test cases, we found several more
deficiencies in our encoder handling.  Specifically, in order to support
positional resolution, named resolution and tuple composition, it is important
to keep the unresolved encoder around and to use it when constructing new
`Datasets` with the same object type but different output attributes.  We now
divide the encoder lifecycle into three phases (which mirror the lifecycle of
standard expressions) and add checks at various boundaries:

 - Unresolved Encoders: all user-facing encoders (those constructed by
implicits, static methods, or tuple composition) are unresolved, meaning they
have only `UnresolvedAttributes` for named fields and `BoundReferences` for
fields accessed by ordinal.
 - Resolved Encoders: internal to a `[Grouped]Dataset`, the encoder is resolved,
meaning all input has been resolved to a specific `AttributeReference`.  Any
encoders that are placed into a logical plan for use in object construction
should be resolved.
 - Bound Encoders: constructed by physical plans, right before the actual
conversion from row -> object is performed.
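
The three phases can be illustrated with a toy model in which a field reference is either by name or by ordinal (unresolved), pinned to a schema attribute (resolved), or a physical row slot (bound). All class and method names below are hypothetical and are not Catalyst classes; the sketch only mirrors the ideas behind `assertUnresolved`, `resolve`, `bind` and `shift` in this commit. It shows why composition has to happen before resolution: once `ByName("b")` and `ByOrdinal(0)` have both become plain attributes, there is no way to tell which one was positional.

```scala
// Toy model of the unresolved -> resolved -> bound lifecycle.
// Hypothetical names only; these are not Spark/Catalyst classes.
object EncoderPhasesModel {
  sealed trait Ref
  // Unresolved phase: a field is referenced by name or by ordinal.
  final case class ByName(name: String) extends Ref
  final case class ByOrdinal(ordinal: Int) extends Ref
  // Resolved phase: a field is pinned to one attribute of a specific schema.
  final case class Attr(name: String, id: Long) extends Ref
  // Bound phase: a field is a physical slot in the input row.
  final case class Slot(ordinal: Int) extends Ref

  // Fails if resolution has already happened (same idea as `assertUnresolved`).
  def assertUnresolved(refs: Seq[Ref]): Unit = refs.foreach {
    case a: Attr => sys.error(s"Unresolved encoder expected, but $a was found.")
    case _       => ()
  }

  // Resolution: names are looked up, ordinals map to the attribute at that
  // position. Afterwards the two kinds of reference look identical.
  def resolve(refs: Seq[Ref], schema: Seq[Attr]): Seq[Ref] = refs.map {
    case ByName(n) =>
      schema.find(_.name == n).getOrElse(sys.error(s"Cannot resolve column $n"))
    case ByOrdinal(i) => schema(i)
    case other        => other
  }

  // Binding: resolved attributes become positions in the physical row.
  def bind(refs: Seq[Ref], schema: Seq[Attr]): Seq[Ref] = refs.map {
    case a: Attr =>
      val i = schema.indexOf(a)
      require(i >= 0, s"$a is not in the input schema")
      Slot(i)
    case other => other
  }

  // Tuple composition on an unresolved encoder: shift ordinal references,
  // similar in spirit to `ExpressionEncoder.shift(delta)`.
  def shift(refs: Seq[Ref], delta: Int): Seq[Ref] = refs.map {
    case ByOrdinal(i) => ByOrdinal(i + delta)
    case other        => other
  }

  def main(args: Array[String]): Unit = {
    val unresolved = Seq(ByName("b"), ByOrdinal(0))
    assertUnresolved(unresolved)
    val schema = Seq(Attr("a", 1L), Attr("b", 2L))
    val resolved = resolve(unresolved, schema)
    println(resolved)               // both references are now plain Attr values
    println(bind(resolved, schema)) // physical slots in the row
    println(shift(unresolved, 2))   // still composable before resolution
  }
}
```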

It is left to future work to add explicit checks for resolution and provide 
good error messages when it fails.  We might also consider enforcing the above 
constraints in the type system (i.e. `fromRow` only exists on a 
`ResolvedEncoder`), but we should probably wait before spending too much time 
on this.
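
One way the type-system idea above could look is a phantom type parameter that records the phase, so that `fromRow` only type-checks once an encoder has been resolved. This is a hypothetical sketch, not a Spark API: `PhasedEncoder`, `Phase` and `decode` are invented names, `decode` stands in for the generated `fromRow` projection, and the sketch treats a bound encoder as a refinement of a resolved one.

```scala
// Hypothetical sketch: track the encoder phase as a phantom type so that
// `fromRow` only compiles on a resolved (or bound) encoder. Not Spark code.
object PhasedEncoderSketch {
  sealed trait Phase
  sealed trait Unresolved extends Phase
  sealed trait Resolved extends Phase
  sealed trait Bound extends Resolved // a bound encoder is also resolved

  // `decode` is a placeholder for the generated row -> object projection.
  final class PhasedEncoder[T, P <: Phase] private (decode: Seq[Any] => T) {
    def resolve(implicit ev: P =:= Unresolved): PhasedEncoder[T, Resolved] =
      new PhasedEncoder[T, Resolved](decode)

    def bind(implicit ev: P =:= Resolved): PhasedEncoder[T, Bound] =
      new PhasedEncoder[T, Bound](decode)

    // Requires evidence that P is Resolved or a subtype of it, such as Bound.
    def fromRow(row: Seq[Any])(implicit ev: P <:< Resolved): T = decode(row)
  }

  object PhasedEncoder {
    // User-facing encoders always start out unresolved.
    def apply[T](decode: Seq[Any] => T): PhasedEncoder[T, Unresolved] =
      new PhasedEncoder[T, Unresolved](decode)
  }
}
```

Calling `fromRow` directly on the value returned by `PhasedEncoder(...)` is rejected at compile time, because no `Unresolved <:< Resolved` evidence exists.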

Author: Michael Armbrust <mich...@databricks.com>
Author: Wenchen Fan <wenc...@databricks.com>

Closes #9673 from marmbrus/pr/9628.

(cherry picked from commit 41bbd2300472501d69ed46f0407d5ed7cbede4a8)
Signed-off-by: Michael Armbrust <mich...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6c1bf19e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6c1bf19e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6c1bf19e

Branch: refs/heads/branch-1.6
Commit: 6c1bf19e84aa08291e6595e6bc54f9e2e375169a
Parents: f061d25
Author: Michael Armbrust <mich...@databricks.com>
Authored: Thu Nov 12 17:20:30 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Thu Nov 12 17:20:42 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Encoder.scala    |  10 +-
 .../catalyst/encoders/ExpressionEncoder.scala   | 124 +++++++++++--------
 .../spark/sql/catalyst/encoders/package.scala   |  11 +-
 .../expressions/complexTypeExtractors.scala     |   7 +-
 .../catalyst/plans/logical/basicOperators.scala |  15 ++-
 .../scala/org/apache/spark/sql/Column.scala     |  43 ++++++-
 .../scala/org/apache/spark/sql/DataFrame.scala  |  17 +--
 .../scala/org/apache/spark/sql/Dataset.scala    |  85 ++++++-------
 .../org/apache/spark/sql/GroupedDataset.scala   |  98 +++++++--------
 .../aggregate/TypedAggregateExpression.scala    |  13 +-
 .../spark/sql/execution/basicOperators.scala    |   7 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java  |  12 +-
 .../spark/sql/DatasetAggregatorSuite.scala      |  42 +++++++
 .../org/apache/spark/sql/DatasetSuite.scala     |   9 ++
 .../scala/org/apache/spark/sql/QueryTest.scala  |  13 +-
 15 files changed, 309 insertions(+), 197 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 6134f9e..5f619d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -84,7 +84,7 @@ object Encoders {
   private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
     assert(encoders.length > 1)
     // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
-    assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
+    assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty))
 
     val schema = StructType(encoders.zipWithIndex.map {
       case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
@@ -93,8 +93,8 @@ object Encoders {
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
 
     val extractExpressions = encoders.map {
-      case e if e.flat => e.extractExpressions.head
-      case other => CreateStruct(other.extractExpressions)
+      case e if e.flat => e.toRowExpressions.head
+      case other => CreateStruct(other.toRowExpressions)
     }.zipWithIndex.map { case (expr, index) =>
       expr.transformUp {
         case BoundReference(0, t: ObjectType, _) =>
@@ -107,11 +107,11 @@ object Encoders {
 
     val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
       if (enc.flat) {
-        enc.constructExpression.transform {
+        enc.fromRowExpression.transform {
           case b: BoundReference => b.copy(ordinal = index)
         }
       } else {
-        enc.constructExpression.transformUp {
+        enc.fromRowExpression.transformUp {
           case BoundReference(ordinal, dt, _) =>
             GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 294afde..0d3e4aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
+import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType}
 
 /**
  * A factory for constructing encoders that convert objects and primitves to and from the
@@ -61,20 +61,39 @@ object ExpressionEncoder {
 
   /**
    * Given a set of N encoders, constructs a new encoder that produce objects as items in an
-   * N-tuple.  Note that these encoders should first be bound correctly to the combined input
-   * schema.
+   * N-tuple.  Note that these encoders should be unresolved so that information about
+   * name/positional binding is preserved.
    */
   def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+    encoders.foreach(_.assertUnresolved())
+
     val schema =
       StructType(
-        encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)})
+        encoders.zipWithIndex.map {
+          case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+        })
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
-    val extractExpressions = encoders.map {
-      case e if e.flat => e.extractExpressions.head
-      case other => CreateStruct(other.extractExpressions)
+
+    // Rebind the encoders to the nested schema.
+    val newConstructExpressions = encoders.zipWithIndex.map {
+      case (e, i) if !e.flat => e.nested(i).fromRowExpression
+      case (e, i) => e.shift(i).fromRowExpression
     }
+
     val constructExpression =
-      NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls))
+      NewInstance(cls, newConstructExpressions, false, ObjectType(cls))
+
+    val input = BoundReference(0, ObjectType(cls), false)
+    val extractExpressions = encoders.zipWithIndex.map {
+      case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp {
+        case b: BoundReference =>
+          Invoke(input, s"_${i + 1}", b.dataType, Nil)
+      }))
+      case (e, i) => e.toRowExpressions.head transformUp {
+        case b: BoundReference =>
+          Invoke(input, s"_${i + 1}", b.dataType, Nil)
+      }
+    }
 
     new ExpressionEncoder[Any](
       schema,
@@ -95,35 +114,40 @@ object ExpressionEncoder {
  * A generic encoder for JVM objects.
  *
  * @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- *                           extract the values from a raw object.
+ * @param toRowExpressions A set of expressions, one for each top-level field that can be used to
+ *                           extract the values from a raw object into an [[InternalRow]].
+ * @param fromRowExpression An expression that will construct an object given an [[InternalRow]].
  * @param clsTag A classtag for `T`.
  */
 case class ExpressionEncoder[T](
     schema: StructType,
     flat: Boolean,
-    extractExpressions: Seq[Expression],
-    constructExpression: Expression,
+    toRowExpressions: Seq[Expression],
+    fromRowExpression: Expression,
     clsTag: ClassTag[T])
   extends Encoder[T] {
 
-  if (flat) require(extractExpressions.size == 1)
+  if (flat) require(toRowExpressions.size == 1)
 
   @transient
-  private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+  private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
   private val inputRow = new GenericMutableRow(1)
 
   @transient
-  private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+  private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
 
   /**
    * Returns an encoded version of `t` as a Spark SQL row.  Note that multiple calls to
    * toRow are allowed to return the same actual [[InternalRow]] object.  Thus, the caller should
    * copy the result before making another call if required.
    */
-  def toRow(t: T): InternalRow = {
+  def toRow(t: T): InternalRow = try {
     inputRow(0) = t
     extractProjection(inputRow)
+  } catch {
+    case e: Exception =>
+      throw new RuntimeException(
+        s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e)
   }
 
   /**
@@ -135,7 +159,20 @@ case class ExpressionEncoder[T](
     constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
   } catch {
     case e: Exception =>
-      throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e)
+      throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e)
+  }
+
+  /**
+   * The process of resolution to a given schema throws away information about where a given field
+   * is being bound by ordinal instead of by name.  This method checks to make sure this process
+   * has not been done already in places where we plan to do later composition of encoders.
+   */
+  def assertUnresolved(): Unit = {
+    (fromRowExpression +:  toRowExpressions).foreach(_.foreach {
+      case a: AttributeReference =>
+        sys.error(s"Unresolved encoder expected, but $a was found.")
+      case _ =>
+    })
   }
 
   /**
@@ -143,9 +180,14 @@ case class ExpressionEncoder[T](
    * given schema.
    */
   def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
-    val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema))
+    val positionToAttribute = AttributeMap.toIndex(schema)
+    val unbound = fromRowExpression transform {
+      case b: BoundReference => positionToAttribute(b.ordinal)
+    }
+
+    val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
-    copy(constructExpression = analyzedPlan.expressions.head.children.head)
+    copy(fromRowExpression = analyzedPlan.expressions.head.children.head)
   }
 
   /**
@@ -154,39 +196,14 @@ case class ExpressionEncoder[T](
    * resolve before bind.
    */
   def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
-    copy(constructExpression = BindReferences.bindReference(constructExpression, schema))
-  }
-
-  /**
-   * Replaces any bound references in the schema with the attributes at the corresponding ordinal
-   * in the provided schema.  This can be used to "relocate" a given encoder to pull values from
-   * a different schema than it was initially bound to.  It can also be used to assign attributes
-   * to ordinal based extraction (i.e. because the input data was a tuple).
-   */
-  def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
-    val positionToAttribute = AttributeMap.toIndex(schema)
-    copy(constructExpression = constructExpression transform {
-      case b: BoundReference => positionToAttribute(b.ordinal)
-    })
+    copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema))
   }
 
   /**
-   * Given an encoder that has already been bound to a given schema, returns a new encoder
-   * where the positions are mapped from `oldSchema` to `newSchema`.  This can be used, for example,
-   * when you are trying to use an encoder on grouping keys that were originally part of a larger
-   * row, but now you have projected out only the key expressions.
+   * Returns a new encoder with input columns shifted by `delta` ordinals
    */
-  def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = {
-    val positionToAttribute = AttributeMap.toIndex(oldSchema)
-    val attributeToNewPosition = AttributeMap.byIndex(newSchema)
-    copy(constructExpression = constructExpression transform {
-      case r: BoundReference =>
-        r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
-    })
-  }
-
   def shift(delta: Int): ExpressionEncoder[T] = {
-    copy(constructExpression = constructExpression transform {
+    copy(fromRowExpression = fromRowExpression transform {
       case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
     })
   }
@@ -196,11 +213,14 @@ case class ExpressionEncoder[T](
    * input row have been modified to pull the object out from a nested struct, 
instead of the
    * top level fields.
    */
-  def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = {
-    copy(constructExpression = constructExpression transform {
-      case u: Attribute if u != input =>
+  private def nested(i: Int): ExpressionEncoder[T] = {
+    // We don't always know our input type at this point since it might be unresolved.
+    // We fill in null and it will get unbound to the actual attribute at this position.
+    val input = BoundReference(i, NullType, nullable = true)
+    copy(fromRowExpression = fromRowExpression transformUp {
+      case u: Attribute =>
         UnresolvedExtractValue(input, Literal(u.name))
-      case b: BoundReference if b != input =>
+      case b: BoundReference =>
         GetStructField(
           input,
           StructField(s"i[${b.ordinal}]", b.dataType),
@@ -208,7 +228,7 @@ case class ExpressionEncoder[T](
     })
   }
 
-  protected val attrs = extractExpressions.flatMap(_.collect {
+  protected val attrs = toRowExpressions.flatMap(_.collect {
     case _: UnresolvedAttribute => ""
     case a: Attribute => s"#${a.exprId}"
     case b: BoundReference => s"[${b.ordinal}]"

http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index 2c35adc..9e283f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -18,10 +18,19 @@
 package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 
 package object encoders {
+  /**
+   * Returns an internal encoder object that can be used to serialize / deserialize JVM objects
+   * into Spark SQL rows.  The implicit encoder should always be unresolved (i.e. have no attribute
+   * references from a specific schema.)  This requirement allows us to preserve whether a given
+   * object type is being bound by name or by ordinal when doing resolution.
+   */
   private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
-    case e: ExpressionEncoder[A] => e
+    case e: ExpressionEncoder[A] =>
+      e.assertUnresolved()
+      e
     case _ => sys.error(s"Only expression encoders are supported today")
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 41cd0a1..f871b73 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -97,11 +97,16 @@ object ExtractValue {
  * Returns the value of fields in the Struct `child`.
  *
  * No need to do type checking since it is handled by [[ExtractValue]].
+ * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
  */
 case class GetStructField(child: Expression, field: StructField, ordinal: Int)
   extends UnaryExpression {
 
-  override def dataType: DataType = field.dataType
+  override def dataType: DataType = child.dataType match {
+    case s: StructType => s(ordinal).dataType
+    // This is a hack to avoid breaking existing code until we remove the need for the struct field
+    case _ => field.dataType
+  }
   override def nullable: Boolean = child.nullable || field.nullable
   override def toString: String = s"$child.${field.name}"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 32b09b5..d9f046e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -483,9 +483,12 @@ case class MapPartitions[T, U](
 
 /** Factory for constructing new `AppendColumn` nodes. */
 object AppendColumn {
-  def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
+  def apply[T, U : Encoder](
+      func: T => U,
+      tEncoder: ExpressionEncoder[T],
+      child: LogicalPlan): AppendColumn[T, U] = {
     val attrs = encoderFor[U].schema.toAttributes
-    new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
+    new AppendColumn[T, U](func, tEncoder, encoderFor[U], attrs, child)
   }
 }
 
@@ -506,14 +509,16 @@ case class AppendColumn[T, U](
 
 /** Factory for constructing new `MapGroups` nodes. */
 object MapGroups {
-  def apply[K : Encoder, T : Encoder, U : Encoder](
+  def apply[K, T, U : Encoder](
       func: (K, Iterator[T]) => TraversableOnce[U],
+      kEncoder: ExpressionEncoder[K],
+      tEncoder: ExpressionEncoder[T],
       groupingAttributes: Seq[Attribute],
       child: LogicalPlan): MapGroups[K, T, U] = {
     new MapGroups(
       func,
-      encoderFor[K],
-      encoderFor[T],
+      kEncoder,
+      tEncoder,
       encoderFor[U],
       groupingAttributes,
       encoderFor[U].schema.toAttributes,

http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index f0f275e..9292244 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+
 import scala.language.implicitConversions
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.DataTypeParser
 import org.apache.spark.sql.types._
@@ -45,7 +47,25 @@ private[sql] object Column {
  *           checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
  * @tparam U The output type of this column.
  */
-class TypedColumn[-T, U](expr: Expression, val encoder: Encoder[U]) extends Column(expr)
+class TypedColumn[-T, U](
+    expr: Expression,
+    private[sql] val encoder: ExpressionEncoder[U]) extends Column(expr) {
+
+  /**
+   * Inserts the specific input type and schema into any expressions that are expected to operate
+   * on a decoded object.
+   */
+  private[sql] def withInputType(
+      inputEncoder: ExpressionEncoder[_],
+      schema: Seq[Attribute]): TypedColumn[T, U] = {
+    new TypedColumn[T, U] (expr transform {
+      case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
+        ta.copy(
+          aEncoder = Some(inputEncoder.asInstanceOf[ExpressionEncoder[Any]]),
+          children = schema)
+    }, encoder)
+  }
+}
 
 /**
  * :: Experimental ::
@@ -73,6 +93,25 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   /** Creates a column based on the given expression. */
   private def withExpr(newExpr: Expression): Column = new Column(newExpr)
 
+  /**
+   * Returns the expression for this column either with an existing or auto assigned name.
+   */
+  private[sql] def named: NamedExpression = expr match {
+    // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+    // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+    // make it a NamedExpression.
+    case u: UnresolvedAttribute => UnresolvedAlias(u)
+
+    case expr: NamedExpression => expr
+
+    // Leave an unaliased generator with an empty list of names since the analyzer will generate
+    // the correct defaults after the nested expression's type has been resolved.
+    case explode: Explode => MultiAlias(explode, Nil)
+    case jt: JsonTuple => MultiAlias(jt, Nil)
+
+    case expr: Expression => Alias(expr, expr.prettyString)()
+  }
+
   override def toString: String = expr.prettyString
 
   override def equals(that: Any): Boolean = that match {

http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index a492099..3ba4ba1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -735,22 +735,7 @@ class DataFrame private[sql](
    */
   @scala.annotation.varargs
   def select(cols: Column*): DataFrame = withPlan {
-    val namedExpressions = cols.map {
-      // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
-      // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
-      // make it a NamedExpression.
-      case Column(u: UnresolvedAttribute) => UnresolvedAlias(u)
-
-      case Column(expr: NamedExpression) => expr
-
-      // Leave an unaliased generator with an empty list of names since the analyzer will generate
-      // the correct defaults after the nested expression's type has been resolved.
-      case Column(explode: Explode) => MultiAlias(explode, Nil)
-      case Column(jt: JsonTuple) => MultiAlias(jt, Nil)
-
-      case Column(expr: Expression) => Alias(expr, expr.prettyString)()
-    }
-    Project(namedExpressions.toSeq, logicalPlan)
+    Project(cols.map(_.named), logicalPlan)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 87dae6b..b930e46 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{Queryable, QueryExecution}
-import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -63,15 +62,20 @@ import org.apache.spark.sql.types.StructType
 class Dataset[T] private[sql](
     @transient val sqlContext: SQLContext,
     @transient val queryExecution: QueryExecution,
-    unresolvedEncoder: Encoder[T]) extends Queryable with Serializable {
+    tEncoder: Encoder[T]) extends Queryable with Serializable {
+
+  /**
+   * An unresolved version of the internal encoder for the type of this dataset.  This one is marked
+   * implicit so that we can use it when constructing new [[Dataset]] objects that have the same
+   * object type (that will be possibly resolved to a different schema).
+   */
+  private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
 
   /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
-  private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match {
-    case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output)
-    case _ => throw new IllegalArgumentException("Only expression encoders are currently supported")
-  }
+  private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
+    unresolvedTEncoder.resolve(queryExecution.analyzed.output)
 
-  private implicit def classTag = encoder.clsTag
+  private implicit def classTag = resolvedTEncoder.clsTag
 
   private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
     this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
@@ -81,7 +85,7 @@ class Dataset[T] private[sql](
    *
    * @since 1.6.0
    */
-  def schema: StructType = encoder.schema
+  def schema: StructType = resolvedTEncoder.schema
 
   /* ************* *
    *  Conversions  *
@@ -134,7 +138,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def rdd: RDD[T] = {
-    val tEnc = encoderFor[T]
+    val tEnc = resolvedTEncoder
     val input = queryExecution.analyzed.output
     queryExecution.toRdd.mapPartitions { iter =>
       val bound = tEnc.bind(input)
@@ -195,7 +199,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
-    new Dataset(
+    new Dataset[U](
       sqlContext,
       MapPartitions[T, U](
         func,
@@ -295,12 +299,12 @@ class Dataset[T] private[sql](
    */
   def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
     val inputPlan = queryExecution.analyzed
-    val withGroupingKey = AppendColumn(func, inputPlan)
+    val withGroupingKey = AppendColumn(func, resolvedTEncoder, inputPlan)
     val executed = sqlContext.executePlan(withGroupingKey)
 
     new GroupedDataset(
-      encoderFor[K].resolve(withGroupingKey.newColumns),
-      encoderFor[T].bind(inputPlan.output),
+      encoderFor[K],
+      encoderFor[T],
       executed,
       inputPlan.output,
       withGroupingKey.newColumns)
@@ -360,7 +364,15 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
-    new Dataset[U1](sqlContext, Project(Alias(withEncoder(c1).expr, "_1")() :: Nil, logicalPlan))
+    // We use an unbound encoder since the expression will make up its own schema.
+    // TODO: This probably doesn't work if we are relying on reordering of the input class fields.
+    new Dataset[U1](
+      sqlContext,
+      Project(
+        c1.withInputType(
+          resolvedTEncoder.bind(queryExecution.analyzed.output),
+          queryExecution.analyzed.output).named :: Nil,
+        logicalPlan))
   }
 
   /**
@@ -369,28 +381,14 @@ class Dataset[T] private[sql](
    * that cast appropriately for the user facing interface.
    */
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
-    val withEncoders = columns.map(withEncoder)
-    val aliases = withEncoders.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
-    val unresolvedPlan = Project(aliases, logicalPlan)
-    val execution = new QueryExecution(sqlContext, unresolvedPlan)
-    // Rebind the encoders to the nested schema that will be produced by the select.
-    val encoders = withEncoders.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
-      case (e: ExpressionEncoder[_], a) if !e.flat =>
-        e.nested(a.toAttribute).resolve(execution.analyzed.output)
-      case (e, a) =>
-        e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output)
-    }
-    new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
-  }
+    val encoders = columns.map(_.encoder)
+    // We use an unbound encoder since the expression will make up its own schema.
+    // TODO: This probably doesn't work if we are relying on reordering of the input class fields.
+    val namedColumns =
+      columns.map(_.withInputType(unresolvedTEncoder, queryExecution.analyzed.output).named)
+    val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
 
-  private def withEncoder(c: TypedColumn[_, _]): TypedColumn[_, _] = {
-    val e = c.expr transform {
-      case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
-        ta.copy(
-          aEncoder = Some(encoder.asInstanceOf[ExpressionEncoder[Any]]),
-          children = queryExecution.analyzed.output)
-    }
-    new TypedColumn(e, c.encoder)
+    new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
   }
 
   /**
@@ -497,23 +495,18 @@ class Dataset[T] private[sql](
     val left = this.logicalPlan
     val right = other.logicalPlan
 
-    val leftData = this.encoder match {
+    val leftData = this.unresolvedTEncoder match {
       case e if e.flat => Alias(left.output.head, "_1")()
       case _ => Alias(CreateStruct(left.output), "_1")()
     }
-    val rightData = other.encoder match {
+    val rightData = other.unresolvedTEncoder match {
       case e if e.flat => Alias(right.output.head, "_2")()
       case _ => Alias(CreateStruct(right.output), "_2")()
     }
-    val leftEncoder =
-      if (encoder.flat) encoder else encoder.nested(leftData.toAttribute)
-    val rightEncoder =
-      if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute)
-    implicit val tuple2Encoder: Encoder[(T, U)] =
-      ExpressionEncoder.tuple(
-        leftEncoder,
-        rightEncoder.rebind(right.output, left.output ++ right.output))
 
+
+    implicit val tuple2Encoder: Encoder[(T, U)] =
+      ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
     withPlan[(T, U)](other) { (left, right) =>
       Project(
         leftData :: rightData :: Nil,
@@ -580,7 +573,7 @@ class Dataset[T] private[sql](
   private[sql] def logicalPlan = queryExecution.analyzed
 
   private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
-    new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder)
+    new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder)
 
   private[sql] def withPlan[R : Encoder](
       other: Dataset[_])(

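In the `joinWith` hunk above, each side of the join ends up as a single column: a flat encoder's only output column is aliased directly as `_1`/`_2`, while a non-flat side's columns are packed with `CreateStruct`. The plain-Java sketch here models only that layout decision, treating a row as an `Object[]` and a struct as a nested `Object[]`. `JoinRowLayout`, `pairColumn` and `joinedRow` are hypothetical names for illustration, not Spark APIs.

```java
import java.util.Arrays;

// Hypothetical model of the two-column row that joinWith builds for a matched pair.
// Rows are Object[]; a struct value is represented as a nested Object[].
final class JoinRowLayout {

    // Mirrors the hunk: a flat side contributes its single column, otherwise
    // all of the side's columns are wrapped into one struct (CreateStruct).
    static Object pairColumn(Object[] sideRow, boolean flat) {
        if (flat) {
            return sideRow[0];
        }
        return Arrays.copyOf(sideRow, sideRow.length);
    }

    // Produces the (_1, _2) row: left side first, right side second.
    static Object[] joinedRow(Object[] left, boolean leftFlat, Object[] right, boolean rightFlat) {
        return new Object[] {pairColumn(left, leftFlat), pairColumn(right, rightFlat)};
    }
}
```

Because each side always occupies exactly one column, the result encoder can plausibly be composed straight from the two unresolved encoders, as the new `ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)` line does; that reading is an interpretation of the change rather than something the patch states.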
http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 61e2a95..ae1272a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,20 +17,16 @@
 
 package org.apache.spark.sql
 
-import java.util.{Iterator => JIterator}
 
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
+import org.apache.spark.api.java.function._
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
-import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.execution.QueryExecution
 
-
 /**
  * :: Experimental ::
  * A [[Dataset]] has been logically grouped by a user specified grouping key.  Users should not
@@ -44,23 +40,21 @@ import org.apache.spark.sql.execution.QueryExecution
  */
 @Experimental
 class GroupedDataset[K, T] private[sql](
-    private val kEncoder: Encoder[K],
-    private val tEncoder: Encoder[T],
-    queryExecution: QueryExecution,
+    kEncoder: Encoder[K],
+    tEncoder: Encoder[T],
+    val queryExecution: QueryExecution,
     private val dataAttributes: Seq[Attribute],
     private val groupingAttributes: Seq[Attribute]) extends Serializable {
 
-  private implicit val kEnc = kEncoder match {
-    case e: ExpressionEncoder[K] => e.unbind(groupingAttributes).resolve(groupingAttributes)
-    case other =>
-      throw new UnsupportedOperationException("Only expression encoders are currently supported")
-  }
+  // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved encoders
+  // when constructing new logical plans that will operate on the output of the current
+  // queryexecution.
 
-  private implicit val tEnc = tEncoder match {
-    case e: ExpressionEncoder[T] => e.resolve(dataAttributes)
-    case other =>
-      throw new UnsupportedOperationException("Only expression encoders are currently supported")
-  }
+  private implicit val unresolvedKEncoder = encoderFor(kEncoder)
+  private implicit val unresolvedTEncoder = encoderFor(tEncoder)
+
+  private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes)
+  private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes)
 
   /** Encoders for built in aggregations. */
   private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
@@ -79,7 +73,7 @@ class GroupedDataset[K, T] private[sql](
   def asKey[L : Encoder]: GroupedDataset[L, T] =
     new GroupedDataset(
       encoderFor[L],
-      tEncoder,
+      unresolvedTEncoder,
       queryExecution,
       dataAttributes,
       groupingAttributes)
@@ -95,7 +89,7 @@ class GroupedDataset[K, T] private[sql](
   }
 
   /**
-   * Applies the given function to each group of data.  For each unique group, the function will
+   * Applies the given function to each group of data.  For each unique group, the function  will
    * be passed the group key and an iterator that contains all of the elements in the group. The
    * function can return an iterator containing elements of an arbitrary type which will be returned
    * as a new [[Dataset]].
@@ -108,7 +102,12 @@ class GroupedDataset[K, T] private[sql](
   def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = {
     new Dataset[U](
       sqlContext,
-      MapGroups(f, groupingAttributes, logicalPlan))
+      MapGroups(
+        f,
+        resolvedKEncoder,
+        resolvedTEncoder,
+        groupingAttributes,
+        logicalPlan))
   }
 
   def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
@@ -127,15 +126,28 @@ class GroupedDataset[K, T] private[sql](
    */
   def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = {
     val func = (key: K, it: Iterator[T]) => Iterator(f(key, it))
-    new Dataset[U](
-      sqlContext,
-      MapGroups(func, groupingAttributes, logicalPlan))
+    flatMap(func)
   }
 
   def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
     map((key, data) => f.call(key, data.asJava))(encoder)
   }
 
+  /**
+   * Reduces the elements of each group of data using the specified binary function.
+   * The given function must be commutative and associative or the result may be non-deterministic.
+   */
+  def reduce(f: (T, T) => T): Dataset[(K, T)] = {
+    val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f))
+
+    implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder)
+    flatMap(func)
+  }
+
+  def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = {
+    reduce(f.call _)
+  }
+
   // To ensure valid overloading.
   protected def agg(expr: Column, exprs: Column*): DataFrame =
     groupedData.agg(expr, exprs: _*)
@@ -147,37 +159,17 @@ class GroupedDataset[K, T] private[sql](
    * TODO: does not handle aggrecations that return nonflat results,
    */
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
-    val aliases = (groupingAttributes ++ columns.map(_.expr)).map {
-      case u: UnresolvedAttribute => UnresolvedAlias(u)
-      case expr: NamedExpression => expr
-      case expr: Expression => Alias(expr, expr.prettyString)()
-    }
-
-    val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan)
-
-    // Fill in the input encoders for any aggregators in the plan.
-    val withEncoders = unresolvedPlan transformAllExpressions {
-      case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
-        ta.copy(
-          aEncoder = Some(tEnc.asInstanceOf[ExpressionEncoder[Any]]),
-          children = dataAttributes)
-    }
-    val execution = new QueryExecution(sqlContext, withEncoders)
-
-    val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]])
-
-    // Rebind the encoders to the nested schema that will be produced by the aggregation.
-    val encoders = (kEnc +: columnEncoders).zip(execution.analyzed.output).map {
-      case (e: ExpressionEncoder[_], a) if !e.flat =>
-        e.nested(a).resolve(execution.analyzed.output)
-      case (e, a) =>
-        e.unbind(a :: Nil).resolve(execution.analyzed.output)
-    }
+    val encoders = columns.map(_.encoder)
+    val namedColumns =
+      columns.map(
+        _.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes).named)
+    val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan)
+    val execution = new QueryExecution(sqlContext, aggregate)
 
     new Dataset(
       sqlContext,
       execution,
-      ExpressionEncoder.tuple(encoders))
+      ExpressionEncoder.tuple(unresolvedKEncoder +: encoders))
   }
 
   /**
@@ -230,7 +222,7 @@ class GroupedDataset[K, T] private[sql](
   def cogroup[U, R : Encoder](
       other: GroupedDataset[K, U])(
       f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
-    implicit def uEnc: Encoder[U] = other.tEncoder
+    implicit def uEnc: Encoder[U] = other.unresolvedTEncoder
     new Dataset[R](
       sqlContext,
       CoGroup(

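The new `GroupedDataset.reduce` turns each group into `key -> it.reduce(f)` and passes that function to `flatMap` with a tuple encoder built from the unresolved key and value encoders, giving a Dataset analogue of `reduceByKey`. A single-machine Java sketch of the same semantics follows, assuming a key-extraction function in place of `groupBy`; `GroupReduceSketch.reduceGroups` is a hypothetical helper, not part of Spark.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;
import java.util.function.Function;

// Single-machine model of GroupedDataset.reduce: group by key, then fold each group.
final class GroupReduceSketch {

    static <K, T> Map<K, T> reduceGroups(List<T> data, Function<T, K> keyFn, BinaryOperator<T> f) {
        // Collect the elements of each group; LinkedHashMap keeps first-seen key
        // order only so the output is predictable (a Dataset guarantees no order).
        Map<K, List<T>> groups = new LinkedHashMap<>();
        for (T value : data) {
            groups.computeIfAbsent(keyFn.apply(value), k -> new ArrayList<>()).add(value);
        }

        Map<K, T> result = new LinkedHashMap<>();
        for (Map.Entry<K, List<T>> group : groups.entrySet()) {
            // Same shape as `it.reduce(f)`: a left fold seeded with the first element.
            // Groups are never empty because each key was created by an element.
            Iterator<T> it = group.getValue().iterator();
            T acc = it.next();
            while (it.hasNext()) {
                acc = f.apply(acc, it.next());
            }
            result.put(group.getKey(), acc);
        }
        return result;
    }
}
```

Called as `reduceGroups(Arrays.asList("abc", "xyz", "hello"), String::length, (a, b) -> a + b)`, it yields `{3=abcxyz, 5=hello}`, matching the pairs in the commit message's example.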
http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index dfcbac8..3f27758 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -55,7 +55,7 @@ case class TypedAggregateExpression(
     aEncoder: Option[ExpressionEncoder[Any]],
     bEncoder: ExpressionEncoder[Any],
     cEncoder: ExpressionEncoder[Any],
-    children: Seq[Expression],
+    children: Seq[Attribute],
     mutableAggBufferOffset: Int,
     inputAggBufferOffset: Int)
   extends ImperativeAggregate with Logging {
@@ -78,8 +78,7 @@ case class TypedAggregateExpression(
 
   override lazy val resolved: Boolean = aEncoder.isDefined
 
-  override lazy val inputTypes: Seq[DataType] =
-    aEncoder.map(_.schema.map(_.dataType)).getOrElse(Nil)
+  override lazy val inputTypes: Seq[DataType] = Nil
 
   override val aggBufferSchema: StructType = bEncoder.schema
 
@@ -90,12 +89,8 @@ case class TypedAggregateExpression(
   override val inputAggBufferAttributes: Seq[AttributeReference] =
     aggBufferAttributes.map(_.newInstance())
 
-  lazy val inputAttributes = aEncoder.get.schema.toAttributes
-  lazy val inputMapping = AttributeMap(inputAttributes.zip(children))
-  lazy val boundA =
-    aEncoder.get.copy(constructExpression = aEncoder.get.constructExpression transform {
-      case a: AttributeReference => inputMapping(a)
-    })
+  // We let the dataset do the binding for us.
+  lazy val boundA = aEncoder.get
 
   val bAttributes = bEncoder.schema.toAttributes
   lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)

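With this change `TypedAggregateExpression` stops remapping attributes itself ("We let the dataset do the binding for us"); callers such as `aggUntyped` now pass `resolvedTEncoder.bind(dataAttributes)` through `withInputType`. The commit message describes three encoder phases: unresolved, resolved, and bound. The toy Java model that follows walks one encoder through those phases, with a row as an `Object[]` and an encoder that merely extracts named fields. All classes are hypothetical and far simpler than `ExpressionEncoder`; the point is only that resolving by name and then binding to positions tolerates the column reordering exercised by the "class input with reordering" test.

```java
import java.util.Arrays;
import java.util.List;

// Toy model of the unresolved -> resolved -> bound encoder lifecycle.
// None of these classes exist in Spark; rows are modelled as Object[].
final class EncoderLifecycleSketch {

    // Phase 1: user-facing encoder; it only knows the field names it needs.
    static final class Unresolved {
        final List<String> fieldNames;

        Unresolved(String... fieldNames) {
            this.fieldNames = Arrays.asList(fieldNames);
        }

        // Phase 1 -> 2: check every name against a concrete output schema.
        Resolved resolve(List<String> output) {
            for (String name : fieldNames) {
                if (!output.contains(name)) {
                    throw new IllegalArgumentException("cannot resolve " + name + " in " + output);
                }
            }
            return new Resolved(fieldNames);
        }
    }

    // Phase 2: names are known to exist, but no row positions are fixed yet.
    static final class Resolved {
        final List<String> fieldNames;

        Resolved(List<String> fieldNames) {
            this.fieldNames = fieldNames;
        }

        // Phase 2 -> 3: translate names into ordinals of the operator's input rows.
        Bound bind(List<String> input) {
            int[] ordinals = new int[fieldNames.size()];
            for (int i = 0; i < ordinals.length; i++) {
                ordinals[i] = input.indexOf(fieldNames.get(i));
                if (ordinals[i] < 0) {
                    throw new IllegalArgumentException("cannot bind " + fieldNames.get(i));
                }
            }
            return new Bound(ordinals);
        }
    }

    // Phase 3: decodes rows purely by position.
    static final class Bound {
        private final int[] ordinals;

        Bound(int[] ordinals) {
            this.ordinals = ordinals;
        }

        Object[] fromRow(Object[] row) {
            Object[] fields = new Object[ordinals.length];
            for (int i = 0; i < ordinals.length; i++) {
                fields[i] = row[ordinals[i]];
            }
            return fields;
        }
    }
}
```

For an output schema `[b, a]`, as produced by `SELECT 'one' AS b, 1 as a`, an encoder for fields `(a, b)` binds to ordinals `{1, 0}` and decodes the row `{"one", 1}` as `{1, "one"}`.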
http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index ae08fb7..ed82c9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -311,6 +311,10 @@ case class AppendColumns[T, U](
     newColumns: Seq[Attribute],
     child: SparkPlan) extends UnaryNode {
 
+  // We are using an unsafe combiner.
+  override def canProcessSafeRows: Boolean = false
+  override def canProcessUnsafeRows: Boolean = true
+
   override def output: Seq[Attribute] = child.output ++ newColumns
 
   override protected def doExecute(): RDD[InternalRow] = {
@@ -349,11 +353,12 @@ case class MapGroups[K, T, U](
     child.execute().mapPartitions { iter =>
       val grouped = GroupedIterator(iter, groupingAttributes, child.output)
       val groupKeyEncoder = kEncoder.bind(groupingAttributes)
+      val groupDataEncoder = tEncoder.bind(child.output)
 
       grouped.flatMap { case (key, rowIter) =>
         val result = func(
           groupKeyEncoder.fromRow(key),
-          rowIter.map(tEncoder.fromRow))
+          rowIter.map(groupDataEncoder.fromRow))
         result.map(uEncoder.toRow)
       }
     }

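`MapGroups.doExecute` splits each partition into groups with `GroupedIterator` and, after this patch, binds both the key encoder and the data encoder to their input attributes once per partition, before iterating over the groups. A plain-Java sketch of the grouping step, assuming the rows of a partition already arrive clustered so that equal keys are adjacent; `ConsecutiveGroups` is hypothetical and collects each group into a list, whereas `GroupedIterator` hands each group to the function as an iterator (`rowIter`).

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical model of grouping a clustered row stream into (key, rows) pairs.
final class ConsecutiveGroups {

    // Starts a new group whenever the key changes; assumes non-null keys and
    // that rows sharing a key are adjacent in the input.
    static <K, R> List<Map.Entry<K, List<R>>> group(Iterator<R> rows, Function<R, K> keyOf) {
        List<Map.Entry<K, List<R>>> groups = new ArrayList<>();
        K currentKey = null;
        List<R> current = null;
        while (rows.hasNext()) {
            R row = rows.next();
            K key = keyOf.apply(row);
            if (current == null || !key.equals(currentKey)) {
                current = new ArrayList<>();
                currentKey = key;
                groups.add(new AbstractMap.SimpleEntry<>(key, current));
            }
            current.add(row);
        }
        return groups;
    }
}
```

In the patch, `groupKeyEncoder` and `groupDataEncoder` are created inside `mapPartitions` but outside `grouped.flatMap`, so each row of each group is decoded with an encoder that was bound once for the partition.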
http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 33d8388..46169ca 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -157,7 +157,6 @@ public class JavaDatasetSuite implements Serializable {
     Assert.assertEquals(6, reduced);
   }
 
-  @Test
   public void testGroupBy() {
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = context.createDataset(data, Encoders.STRING());
@@ -196,6 +195,17 @@ public class JavaDatasetSuite implements Serializable {
 
     Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());
 
+    Dataset<Tuple2<Integer, String>> reduced = grouped.reduce(new ReduceFunction<String>() {
+      @Override
+      public String call(String v1, String v2) throws Exception {
+        return v1 + v2;
+      }
+    });
+
+    Assert.assertEquals(
+      Arrays.asList(tuple2(1, "a"), tuple2(3, "foobar")),
+      reduced.collectAsList());
+
     List<Integer> data2 = Arrays.asList(2, 6, 10);
     Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
     GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {

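The Java test passes an anonymous `ReduceFunction<String>`, and the Scala overload simply forwards it as `reduce(f.call _)`. Scala has no checked exceptions, so the `throws Exception` on `call` needs no special handling there; plain Java code adapting an interface of the same shape to `java.util.function.BinaryOperator` has to wrap it. A sketch under that assumption: `CheckedReduce` mirrors the shape of `ReduceFunction` but is a hypothetical stand-in, and `adapt`/`reduce` are illustrative helpers, not Spark methods.

```java
import java.util.Iterator;
import java.util.List;
import java.util.function.BinaryOperator;

final class ReduceAdapterSketch {

    // Hypothetical stand-in with the same shape as Spark's ReduceFunction<T>.
    interface CheckedReduce<T> {
        T call(T v1, T v2) throws Exception;
    }

    // Wraps the checked-exception interface as a BinaryOperator; unchecked
    // exceptions pass through, checked ones are wrapped in a RuntimeException.
    static <T> BinaryOperator<T> adapt(CheckedReduce<T> f) {
        return (a, b) -> {
            try {
                return f.call(a, b);
            } catch (RuntimeException e) {
                throw e;
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        };
    }

    // Left fold over a non-empty list, mirroring the per-group `it.reduce(f)`.
    static <T> T reduce(List<T> values, BinaryOperator<T> op) {
        Iterator<T> it = values.iterator();
        T acc = it.next();
        while (it.hasNext()) {
            acc = op.apply(acc, it.next());
        }
        return acc;
    }
}
```

Applied to the group `["foo", "bar"]` with a concatenating function, this returns `"foobar"`, the value the test expects for key 3.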
http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 378cd36..20896ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -67,6 +67,28 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, L
   override def finish(reduction: (Long, Long)): (Long, Long) = reduction
 }
 
+case class AggData(a: Int, b: String)
+object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable {
+  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
+  override def zero: Int = 0
+
+  /**
+   * Combine two values to produce a new value.  For performance, the function may modify `b` and
+   * return it instead of constructing new object for b.
+   */
+  override def reduce(b: Int, a: AggData): Int = b + a.a
+
+  /**
+   * Transform the output of the reduction.
+   */
+  override def finish(reduction: Int): Int = reduction
+
+  /**
+   * Merge two intermediate values
+   */
+  override def merge(b1: Int, b2: Int): Int = b1 + b2
+}
+
 class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
   import testImplicits._
@@ -123,4 +145,24 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
       11 -> 22)
   }
+
+  test("typed aggregation: class input") {
+    val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
+
+    checkAnswer(
+      ds.select(ClassInputAgg.toColumn),
+      3)
+  }
+
+  test("typed aggregation: class input with reordering") {
+    val ds = sql("SELECT 'one' AS b, 1 as a").as[AggData]
+
+    checkAnswer(
+      ds.select(ClassInputAgg.toColumn),
+      1)
+
+    checkAnswer(
+      ds.groupBy(_.b).agg(ClassInputAgg.toColumn),
+      ("one", 1))
+  }
 }

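`ClassInputAgg` implements the `Aggregator` contract exercised by these tests: `zero`, `reduce(b, a)` folding one input into the buffer, `merge` for combining partial buffers, and `finish`. A plain-Java model of how such an aggregator could be evaluated over partitions (fold each partition from `zero`, merge the partial buffers, then `finish`). `AggregatorSketch`, its `Agg` interface and the `run` driver are hypothetical; the driver is a simplification, not how `TypedAggregateExpression` is actually executed.

```java
import java.util.List;

final class AggregatorSketch {

    // Hypothetical mirror of Aggregator[A, B, C]: input A, buffer B, output C.
    interface Agg<A, B, C> {
        B zero();
        B reduce(B b, A a);
        B merge(B b1, B b2);
        C finish(B reduction);
    }

    // Stand-in for the test's case class AggData(a: Int, b: String).
    static final class AggData {
        final int a;
        final String b;

        AggData(int a, String b) {
            this.a = a;
            this.b = b;
        }
    }

    // Same arithmetic as ClassInputAgg: sum the `a` field.
    static final Agg<AggData, Integer, Integer> CLASS_INPUT_AGG = new Agg<AggData, Integer, Integer>() {
        @Override public Integer zero() { return 0; }
        @Override public Integer reduce(Integer b, AggData a) { return b + a.a; }
        @Override public Integer merge(Integer b1, Integer b2) { return b1 + b2; }
        @Override public Integer finish(Integer reduction) { return reduction; }
    };

    // Folds every partition from zero(), merges the partial buffers, then finishes.
    static <A, B, C> C run(Agg<A, B, C> agg, List<List<A>> partitions) {
        B total = agg.zero();
        for (List<A> partition : partitions) {
            B partial = agg.zero();
            for (A input : partition) {
                partial = agg.reduce(partial, input);
            }
            total = agg.merge(total, partial);
        }
        return agg.finish(total);
    }
}
```

With `AggData(1, "one")` and `AggData(2, "two")` placed in separate partitions, `run` returns 3, the value the "class input" test expects.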
http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6211485..c23dd46 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -218,6 +218,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       "a", "30", "b", "3", "c", "1")
   }
 
+  test("groupBy function, reduce") {
+    val ds = Seq("abc", "xyz", "hello").toDS()
+    val agged = ds.groupBy(_.length).reduce(_ + _)
+
+    checkAnswer(
+      agged,
+      3 -> "abcxyz", 5 -> "hello")
+  }
+
   test("groupBy columns, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1")

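The `reduce` scaladoc warns that the function "must be commutative and associative or the result may be non-deterministic", and the commit message notes that `_ + _` on strings is "not actually commutative". String concatenation is associative but order-sensitive, so the expected `3 -> "abcxyz"` in the "groupBy function, reduce" test depends on `"abc"` reaching its group before `"xyz"`. A small Java check of both properties; `OrderSensitivity` is a hypothetical helper written for this illustration.

```java
import java.util.List;
import java.util.function.BinaryOperator;

// Shows that string concatenation, the `_ + _` used in the reduce test,
// is associative but order-sensitive.
final class OrderSensitivity {

    static final BinaryOperator<String> CONCAT = (a, b) -> a + b;

    // Left fold over a non-empty list, the same shape as the per-group reduce.
    static String reduceLeft(List<String> values) {
        String acc = values.get(0);
        for (int i = 1; i < values.size(); i++) {
            acc = CONCAT.apply(acc, values.get(i));
        }
        return acc;
    }

    // Checks associativity for these specific inputs only.
    static boolean isAssociative(String a, String b, String c) {
        return CONCAT.apply(CONCAT.apply(a, b), c).equals(CONCAT.apply(a, CONCAT.apply(b, c)));
    }

    // Checks commutativity for these specific inputs only.
    static boolean isCommutative(String a, String b) {
        return CONCAT.apply(a, b).equals(CONCAT.apply(b, a));
    }
}
```

Reducing `["abc", "xyz"]` gives `"abcxyz"` while `["xyz", "abc"]` gives `"xyzabc"`; an order-insensitive function such as integer addition would not carry that dependency.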
http://git-wip-us.apache.org/repos/asf/spark/blob/6c1bf19e/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 7a8b7ae..b5417b1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -82,18 +82,21 @@ abstract class QueryTest extends PlanTest {
         fail(
           s"""
              |Exception collecting dataset as objects
-             |${ds.encoder}
-             |${ds.encoder.constructExpression.treeString}
+             |${ds.resolvedTEncoder}
+             |${ds.resolvedTEncoder.fromRowExpression.treeString}
              |${ds.queryExecution}
            """.stripMargin, e)
     }
 
     if (decoded != expectedAnswer.toSet) {
+      val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
+      val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted
+
+      val comparision = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n")
       fail(
         s"""Decoded objects do not match expected objects:
-           |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted}
-            |Actual ${decoded.toSet.toSeq.map((a: Any) => a.toString).sorted}
-            |${ds.encoder.constructExpression.treeString}
+            |$comparision
+            |${ds.resolvedTEncoder.fromRowExpression.treeString}
          """.stripMargin)
     }
   }

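The new failure message in `QueryTest` sorts the string forms of the expected and decoded objects and prints them in two columns via `sideBySide`. A minimal Java approximation of such a helper, loosely modelled on Spark's catalyst `sideBySide` utility: it pads the shorter column with empty strings and prefixes differing rows with `!`. Treat the exact formatting as an assumption; `SideBySideSketch` is hypothetical and not Spark's source.

```java
import java.util.ArrayList;
import java.util.List;

final class SideBySideSketch {

    // Renders two columns; rows that differ are prefixed with '!', equal rows with ' '.
    static List<String> sideBySide(List<String> left, List<String> right) {
        int maxLeft = 0;
        for (String l : left) {
            maxLeft = Math.max(maxLeft, l.length());
        }
        int rows = Math.max(left.size(), right.size());
        List<String> out = new ArrayList<>();
        for (int i = 0; i < rows; i++) {
            // Pad the shorter column with empty strings.
            String l = i < left.size() ? left.get(i) : "";
            String r = i < right.size() ? right.get(i) : "";
            StringBuilder line = new StringBuilder();
            line.append(l.equals(r) ? ' ' : '!').append(l);
            // Three spaces of gutter after the widest left cell.
            for (int pad = l.length(); pad < maxLeft + 3; pad++) {
                line.append(' ');
            }
            out.add(line.append(r).toString());
        }
        return out;
    }
}
```

For `left = ["expected", "1", "2"]` and `right = ["spark", "1"]` the first line is `"!expected   spark"`, the matching `1` row starts with a space, and the unmatched `2` row is flagged with `!`.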
