Repository: spark
Updated Branches:
  refs/heads/master e78ec1a8f -> 3744b7fd4


[SPARK-9422] [SQL] Remove the placeholder attributes used in the aggregation buffers

https://issues.apache.org/jira/browse/SPARK-9422

Author: Yin Huai <yh...@databricks.com>

Closes #7737 from yhuai/removePlaceHolder and squashes the following commits:

ec29b44 [Yin Huai]  Remove placeholder attributes.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3744b7fd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3744b7fd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3744b7fd

Branch: refs/heads/master
Commit: 3744b7fd42e52011af60cc205fcb4e4b23b35c68
Parents: e78ec1a
Author: Yin Huai <yh...@databricks.com>
Authored: Tue Jul 28 19:01:25 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Tue Jul 28 19:01:25 2015 -0700

----------------------------------------------------------------------
 .../expressions/aggregate/interfaces.scala      |  27 ++-
 .../aggregate/aggregateOperators.scala          |   4 +-
 .../aggregate/sortBasedIterators.scala          | 209 +++++++------------
 .../spark/sql/execution/aggregate/udaf.scala    |  17 +-
 .../spark/sql/execution/aggregate/utils.scala   |   4 +-
 5 files changed, 121 insertions(+), 140 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3744b7fd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 10bd19c..9fb7623 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -103,9 +103,30 @@ abstract class AggregateFunction2
   final override def foldable: Boolean = false
 
   /**
-   * The offset of this function's buffer in the underlying buffer shared with 
other functions.
+   * The offset of this function's start buffer value in the
+   * underlying shared mutable aggregation buffer.
+   * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which 
share
+   * the same aggregation buffer. In this shared buffer, the position of the 
first
+   * buffer value of `avg(x)` will be 0 and the position of the first buffer 
value of `avg(y)`
+   * will be 2.
    */
-  var bufferOffset: Int = 0
+  var mutableBufferOffset: Int = 0
+
+  /**
+   * The offset of this function's start buffer value in the
+   * underlying shared input aggregation buffer. An input aggregation buffer 
is used
+   * when we merge two aggregation buffers and it is basically the immutable 
one
+   * (we merge an input aggregation buffer and a mutable aggregation buffer and
+   * then store the new buffer values to the mutable aggregation buffer).
+   * Usually, an input aggregation buffer also contains extra elements like grouping
+   * keys at the beginning. So, mutableBufferOffset and inputBufferOffset are often
+   * different.
+   * For example, we have a grouping expression `key`, and two aggregate functions
+   * `avg(x)` and `avg(y)`. In this shared input aggregation buffer, the position of the first
+   * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)`
+   * will be 3 (position 0 is used for the value of `key`).
+   */
+  var inputBufferOffset: Int = 0
 
   /** The schema of the aggregation buffer. */
   def bufferSchema: StructType
@@ -176,7 +197,7 @@ abstract class AlgebraicAggregate extends 
AggregateFunction2 with Serializable w
   override def initialize(buffer: MutableRow): Unit = {
     var i = 0
     while (i < bufferAttributes.size) {
-      buffer(i + bufferOffset) = initialValues(i).eval()
+      buffer(i + mutableBufferOffset) = initialValues(i).eval()
       i += 1
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/3744b7fd/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
index 0c90828..98538c4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
@@ -72,8 +72,10 @@ case class Aggregate2Sort(
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, 
"execute") {
     child.execute().mapPartitions { iter =>
       if (aggregateExpressions.length == 0) {
-        new GroupingIterator(
+        new FinalSortAggregationIterator(
           groupingExpressions,
+          Nil,
+          Nil,
           resultExpressions,
           newMutableProjection,
           child.output,

http://git-wip-us.apache.org/repos/asf/spark/blob/3744b7fd/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
index 1b89eda..2ca0cb8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
@@ -41,7 +41,8 @@ private[sql] abstract class SortAggregationIterator(
   ///////////////////////////////////////////////////////////////////////////
 
   protected val aggregateFunctions: Array[AggregateFunction2] = {
-    var bufferOffset = initialBufferOffset
+    var mutableBufferOffset = 0
+    var inputBufferOffset: Int = initialInputBufferOffset
     val functions = new Array[AggregateFunction2](aggregateExpressions.length)
     var i = 0
     while (i < aggregateExpressions.length) {
@@ -54,13 +55,18 @@ private[sql] abstract class SortAggregationIterator(
           // function's children in the update method of this aggregate 
function.
           // Those eval calls require BoundReferences to work.
           BindReferences.bindReference(func, inputAttributes)
-        case _ => func
+        case _ =>
+          // We only need to set inputBufferOffset for aggregate functions 
with mode
+          // PartialMerge and Final.
+          func.inputBufferOffset = inputBufferOffset
+          inputBufferOffset += func.bufferSchema.length
+          func
       }
-      // Set bufferOffset for this function. It is important that setting 
bufferOffset
-      // happens after all potential bindReference operations because 
bindReference
-      // will create a new instance of the function.
-      funcWithBoundReferences.bufferOffset = bufferOffset
-      bufferOffset += funcWithBoundReferences.bufferSchema.length
+      // Set mutableBufferOffset for this function. It is important that 
setting
+      // mutableBufferOffset happens after all potential bindReference 
operations
+      // because bindReference will create a new instance of the function.
+      funcWithBoundReferences.mutableBufferOffset = mutableBufferOffset
+      mutableBufferOffset += funcWithBoundReferences.bufferSchema.length
       functions(i) = funcWithBoundReferences
       i += 1
     }
@@ -97,25 +103,24 @@ private[sql] abstract class SortAggregationIterator(
     // The number of elements of the underlying buffer of this operator.
     // All aggregate functions are sharing this underlying buffer and they 
find their
     // buffer values through bufferOffset.
-    var size = initialBufferOffset
-    var i = 0
-    while (i < aggregateFunctions.length) {
-      size += aggregateFunctions(i).bufferSchema.length
-      i += 1
-    }
-    new GenericMutableRow(size)
+    // var size = 0
+    // var i = 0
+    // while (i < aggregateFunctions.length) {
+    //  size += aggregateFunctions(i).bufferSchema.length
+    //  i += 1
+    // }
+    new GenericMutableRow(aggregateFunctions.map(_.bufferSchema.length).sum)
   }
 
   protected val joinedRow = new JoinedRow
 
-  protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp)
-
   // This projection is used to initialize buffer values for all 
AlgebraicAggregates.
   protected val algebraicInitialProjection = {
-    val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap 
{
+    val initExpressions = aggregateFunctions.flatMap {
       case ae: AlgebraicAggregate => ae.initialValues
       case agg: AggregateFunction2 => 
Seq.fill(agg.bufferAttributes.length)(NoOp)
     }
+
     newMutableProjection(initExpressions, Nil)().target(buffer)
   }
 
@@ -132,10 +137,6 @@ private[sql] abstract class SortAggregationIterator(
   // Indicates if we has new group of rows to process.
   protected var hasNewGroup: Boolean = true
 
-  ///////////////////////////////////////////////////////////////////////////
-  // Private methods
-  ///////////////////////////////////////////////////////////////////////////
-
   /** Initializes buffer values for all aggregate functions. */
   protected def initializeBuffer(): Unit = {
     algebraicInitialProjection(EmptyRow)
@@ -160,6 +161,10 @@ private[sql] abstract class SortAggregationIterator(
     }
   }
 
+  ///////////////////////////////////////////////////////////////////////////
+  // Private methods
+  ///////////////////////////////////////////////////////////////////////////
+
   /** Processes rows in the current group. It will stop when it find a new 
group. */
   private def processCurrentGroup(): Unit = {
     currentGroupingKey = nextGroupingKey
@@ -218,10 +223,13 @@ private[sql] abstract class SortAggregationIterator(
   // Methods that need to be implemented
   ///////////////////////////////////////////////////////////////////////////
 
-  protected def initialBufferOffset: Int
+  /** The initial input buffer offset for `inputBufferOffset` of an 
[[AggregateFunction2]]. */
+  protected def initialInputBufferOffset: Int
 
+  /** The function used to process an input row. */
   protected def processRow(row: InternalRow): Unit
 
+  /** The function used to generate the result row. */
   protected def generateOutput(): InternalRow
 
   ///////////////////////////////////////////////////////////////////////////
@@ -232,37 +240,6 @@ private[sql] abstract class SortAggregationIterator(
 }
 
 /**
- * An iterator only used to group input rows according to values of 
`groupingExpressions`.
- * It assumes that input rows are already grouped by values of 
`groupingExpressions`.
- */
-class GroupingIterator(
-    groupingExpressions: Seq[NamedExpression],
-    resultExpressions: Seq[NamedExpression],
-    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => 
MutableProjection),
-    inputAttributes: Seq[Attribute],
-    inputIter: Iterator[InternalRow])
-  extends SortAggregationIterator(
-    groupingExpressions,
-    Nil,
-    newMutableProjection,
-    inputAttributes,
-    inputIter) {
-
-  private val resultProjection =
-    newMutableProjection(resultExpressions, 
groupingExpressions.map(_.toAttribute))()
-
-  override protected def initialBufferOffset: Int = 0
-
-  override protected def processRow(row: InternalRow): Unit = {
-    // Since we only do grouping, there is nothing to do at here.
-  }
-
-  override protected def generateOutput(): InternalRow = {
-    resultProjection(currentGroupingKey)
-  }
-}
-
-/**
  * An iterator used to do partial aggregations (for those aggregate functions 
with mode Partial).
  * It assumes that input rows are already grouped by values of 
`groupingExpressions`.
  * The format of its output rows is:
@@ -291,7 +268,7 @@ class PartialSortAggregationIterator(
     newMutableProjection(updateExpressions, bufferSchema ++ 
inputAttributes)().target(buffer)
   }
 
-  override protected def initialBufferOffset: Int = 0
+  override protected def initialInputBufferOffset: Int = 0
 
   override protected def processRow(row: InternalRow): Unit = {
     // Process all algebraic aggregate functions.
@@ -318,11 +295,7 @@ class PartialSortAggregationIterator(
  * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
  *
  * The format of its internal buffer is:
- * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
- * Every placeholder is for a grouping expression.
- * The actual buffers are stored after placeholderN.
- * The reason that we have placeholders at here is to make our underlying 
buffer have the same
- * length with a input row.
+ * |aggregationBuffer1|...|aggregationBufferN|
  *
  * The format of its output rows is:
  * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
@@ -340,33 +313,21 @@ class PartialMergeSortAggregationIterator(
     inputAttributes,
     inputIter) {
 
-  private val placeholderAttributes =
-    Seq.fill(initialBufferOffset)(AttributeReference("placeholder", 
NullType)())
-
   // This projection is used to merge buffer values for all 
AlgebraicAggregates.
   private val algebraicMergeProjection = {
-    val bufferSchemata =
-      placeholderAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) 
++
-        placeholderAttributes ++ 
aggregateFunctions.flatMap(_.cloneBufferAttributes)
-    val mergeExpressions = placeholderExpressions ++ 
aggregateFunctions.flatMap {
+    val mergeInputSchema =
+      aggregateFunctions.flatMap(_.bufferAttributes) ++
+        groupingExpressions.map(_.toAttribute) ++
+        aggregateFunctions.flatMap(_.cloneBufferAttributes)
+    val mergeExpressions = aggregateFunctions.flatMap {
       case ae: AlgebraicAggregate => ae.mergeExpressions
       case agg: AggregateFunction2 => 
Seq.fill(agg.bufferAttributes.length)(NoOp)
     }
 
-    newMutableProjection(mergeExpressions, bufferSchemata)()
+    newMutableProjection(mergeExpressions, mergeInputSchema)()
   }
 
-  // This projection is used to extract aggregation buffers from the 
underlying buffer.
-  // We need it because the underlying buffer has placeholders at its 
beginning.
-  private val extractsBufferValues = {
-    val expressions = aggregateFunctions.flatMap {
-      case agg => agg.bufferAttributes
-    }
-
-    newMutableProjection(expressions, inputAttributes)()
-  }
-
-  override protected def initialBufferOffset: Int = groupingExpressions.length
+  override protected def initialInputBufferOffset: Int = 
groupingExpressions.length
 
   override protected def processRow(row: InternalRow): Unit = {
     // Process all algebraic aggregate functions.
@@ -381,7 +342,7 @@ class PartialMergeSortAggregationIterator(
 
   override protected def generateOutput(): InternalRow = {
     // We output grouping expressions and aggregation buffers.
-    joinedRow(currentGroupingKey, extractsBufferValues(buffer))
+    joinedRow(currentGroupingKey, buffer).copy()
   }
 }
 
@@ -393,11 +354,7 @@ class PartialMergeSortAggregationIterator(
  * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
  *
  * The format of its internal buffer is:
- * |placeholder1|...|placeholder N|aggregationBuffer1|...|aggregationBufferN|
- * Every placeholder is for a grouping expression.
- * The actual buffers are stored after placeholderN.
- * The reason that we have placeholders at here is to make our underlying 
buffer have the same
- * length with a input row.
+ * |aggregationBuffer1|...|aggregationBufferN|
  *
  * The format of its output rows is represented by the schema of 
`resultExpressions`.
  */
@@ -425,27 +382,23 @@ class FinalSortAggregationIterator(
     newMutableProjection(
       resultExpressions, groupingExpressions.map(_.toAttribute) ++ 
aggregateAttributes)()
 
-  private val offsetAttributes =
-    Seq.fill(initialBufferOffset)(AttributeReference("placeholder", 
NullType)())
-
   // This projection is used to merge buffer values for all 
AlgebraicAggregates.
   private val algebraicMergeProjection = {
-    val bufferSchemata =
-      offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
-        offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
-    val mergeExpressions = placeholderExpressions ++ 
aggregateFunctions.flatMap {
+    val mergeInputSchema =
+      aggregateFunctions.flatMap(_.bufferAttributes) ++
+        groupingExpressions.map(_.toAttribute) ++
+        aggregateFunctions.flatMap(_.cloneBufferAttributes)
+    val mergeExpressions = aggregateFunctions.flatMap {
       case ae: AlgebraicAggregate => ae.mergeExpressions
       case agg: AggregateFunction2 => 
Seq.fill(agg.bufferAttributes.length)(NoOp)
     }
 
-    newMutableProjection(mergeExpressions, bufferSchemata)()
+    newMutableProjection(mergeExpressions, mergeInputSchema)()
   }
 
   // This projection is used to evaluate all AlgebraicAggregates.
   private val algebraicEvalProjection = {
-    val bufferSchemata =
-      offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
-        offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+    val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes)
     val evalExpressions = aggregateFunctions.map {
       case ae: AlgebraicAggregate => ae.evaluateExpression
       case agg: AggregateFunction2 => NoOp
@@ -454,7 +407,7 @@ class FinalSortAggregationIterator(
     newMutableProjection(evalExpressions, bufferSchemata)()
   }
 
-  override protected def initialBufferOffset: Int = groupingExpressions.length
+  override protected def initialInputBufferOffset: Int = 
groupingExpressions.length
 
   override def initialize(): Unit = {
     if (inputIter.hasNext) {
@@ -471,7 +424,10 @@ class FinalSortAggregationIterator(
         // Right now, the buffer only contains initial buffer values. Because
         // merging two buffers with initial values will generate a row that
         // still store initial values. We set the currentRow as the copy of 
the current buffer.
-        val currentRow = buffer.copy()
+        // Because input aggregation buffer has initialInputBufferOffset extra 
values at the
+        // beginning, we create a dummy row for this part.
+        val currentRow =
+          joinedRow(new GenericInternalRow(initialInputBufferOffset), 
buffer).copy()
         nextGroupingKey = groupGenerator(currentRow).copy()
         firstRowInNextGroup = currentRow
       } else {
@@ -518,18 +474,15 @@ class FinalSortAggregationIterator(
  * Final mode.
  *
  * The format of its internal buffer is:
- * 
|placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)|
- * The first N placeholders represent slots of grouping expressions.
- * Then, next M placeholders represent slots of col1 to colM.
+ * |aggregationBuffer1|...|aggregationBuffer(N+M)|
  * For aggregation buffers, first N aggregation buffers are used by N 
aggregate functions with
  * mode Final. Then, the last M aggregation buffers are used by M aggregate 
functions with mode
- * Complete. The reason that we have placeholders at here is to make our 
underlying buffer
- * have the same length with a input row.
+ * Complete.
  *
  * The format of its output rows is represented by the schema of 
`resultExpressions`.
  */
 class FinalAndCompleteSortAggregationIterator(
-    override protected val initialBufferOffset: Int,
+    override protected val initialInputBufferOffset: Int,
     groupingExpressions: Seq[NamedExpression],
     finalAggregateExpressions: Seq[AggregateExpression2],
     finalAggregateAttributes: Seq[Attribute],
@@ -561,9 +514,6 @@ class FinalAndCompleteSortAggregationIterator(
     newMutableProjection(resultExpressions, inputSchema)()
   }
 
-  private val offsetAttributes =
-    Seq.fill(initialBufferOffset)(AttributeReference("placeholder", 
NullType)())
-
   // All aggregate functions with mode Final.
   private val finalAggregateFunctions: Array[AggregateFunction2] = {
     val functions = new 
Array[AggregateFunction2](finalAggregateExpressions.length)
@@ -601,38 +551,38 @@ class FinalAndCompleteSortAggregationIterator(
   // This projection is used to merge buffer values for all 
AlgebraicAggregates with mode
   // Final.
   private val finalAlgebraicMergeProjection = {
-    val numCompleteOffsetAttributes =
-      completeAggregateFunctions.map(_.bufferAttributes.length).sum
-    val completeOffsetAttributes =
-      Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", 
NullType)())
-    val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp)
-
-    val bufferSchemata =
-      offsetAttributes ++ finalAggregateFunctions.flatMap(_.bufferAttributes) 
++
-        completeOffsetAttributes ++ offsetAttributes ++
-        finalAggregateFunctions.flatMap(_.cloneBufferAttributes) ++ 
completeOffsetAttributes
+    // The first initialInputBufferOffset values of the input aggregation buffer are
+    // for grouping expressions and distinct columns.
+    val groupingAttributesAndDistinctColumns = 
inputAttributes.take(initialInputBufferOffset)
+
+    val completeOffsetExpressions =
+      
Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+
+    val mergeInputSchema =
+      finalAggregateFunctions.flatMap(_.bufferAttributes) ++
+        completeAggregateFunctions.flatMap(_.bufferAttributes) ++
+        groupingAttributesAndDistinctColumns ++
+        finalAggregateFunctions.flatMap(_.cloneBufferAttributes)
     val mergeExpressions =
-      placeholderExpressions ++ finalAggregateFunctions.flatMap {
+      finalAggregateFunctions.flatMap {
         case ae: AlgebraicAggregate => ae.mergeExpressions
         case agg: AggregateFunction2 => 
Seq.fill(agg.bufferAttributes.length)(NoOp)
       } ++ completeOffsetExpressions
-
-    newMutableProjection(mergeExpressions, bufferSchemata)()
+    newMutableProjection(mergeExpressions, mergeInputSchema)()
   }
 
   // This projection is used to update buffer values for all 
AlgebraicAggregates with mode
   // Complete.
   private val completeAlgebraicUpdateProjection = {
-    val numFinalOffsetAttributes = 
finalAggregateFunctions.map(_.bufferAttributes.length).sum
-    val finalOffsetAttributes =
-      Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", 
NullType)())
-    val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp)
+    // We do not touch buffer values of aggregate functions with the Final 
mode.
+    val finalOffsetExpressions =
+      
Seq.fill(finalAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
 
     val bufferSchema =
-      offsetAttributes ++ finalOffsetAttributes ++
+      finalAggregateFunctions.flatMap(_.bufferAttributes) ++
         completeAggregateFunctions.flatMap(_.bufferAttributes)
     val updateExpressions =
-      placeholderExpressions ++ finalOffsetExpressions ++ 
completeAggregateFunctions.flatMap {
+      finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
         case ae: AlgebraicAggregate => ae.updateExpressions
         case agg: AggregateFunction2 => 
Seq.fill(agg.bufferAttributes.length)(NoOp)
       }
@@ -641,9 +591,7 @@ class FinalAndCompleteSortAggregationIterator(
 
   // This projection is used to evaluate all AlgebraicAggregates.
   private val algebraicEvalProjection = {
-    val bufferSchemata =
-      offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
-        offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
+    val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes)
     val evalExpressions = aggregateFunctions.map {
       case ae: AlgebraicAggregate => ae.evaluateExpression
       case agg: AggregateFunction2 => NoOp
@@ -667,7 +615,10 @@ class FinalAndCompleteSortAggregationIterator(
         // Right now, the buffer only contains initial buffer values. Because
         // merging two buffers with initial values will generate a row that
         // still store initial values. We set the currentRow as the copy of 
the current buffer.
-        val currentRow = buffer.copy()
+        // Because input aggregation buffer has initialInputBufferOffset extra 
values at the
+        // beginning, we create a dummy row for this part.
+        val currentRow =
+          joinedRow(new GenericInternalRow(initialInputBufferOffset), 
buffer).copy()
         nextGroupingKey = groupGenerator(currentRow).copy()
         firstRowInNextGroup = currentRow
       } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/3744b7fd/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 073c45a..cc54319 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -184,7 +184,7 @@ private[sql] case class ScalaUDAF(
       bufferSchema,
       bufferValuesToCatalystConverters,
       bufferValuesToScalaConverters,
-      bufferOffset,
+      inputBufferOffset,
       null)
 
   lazy val mutableAggregateBuffer: MutableAggregationBufferImpl =
@@ -192,9 +192,16 @@ private[sql] case class ScalaUDAF(
       bufferSchema,
       bufferValuesToCatalystConverters,
       bufferValuesToScalaConverters,
-      bufferOffset,
+      mutableBufferOffset,
       null)
 
+  lazy val evalAggregateBuffer: InputAggregationBuffer =
+    new InputAggregationBuffer(
+      bufferSchema,
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      mutableBufferOffset,
+      null)
 
   override def initialize(buffer: MutableRow): Unit = {
     mutableAggregateBuffer.underlyingBuffer = buffer
@@ -217,10 +224,10 @@ private[sql] case class ScalaUDAF(
     udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
   }
 
-  override def eval(buffer: InternalRow = null): Any = {
-    inputAggregateBuffer.underlyingInputBuffer = buffer
+  override def eval(buffer: InternalRow): Any = {
+    evalAggregateBuffer.underlyingInputBuffer = buffer
 
-    udaf.evaluate(inputAggregateBuffer)
+    udaf.evaluate(evalAggregateBuffer)
   }
 
   override def toString: String = {

http://git-wip-us.apache.org/repos/asf/spark/blob/3744b7fd/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 5bbe6c1..6549c87 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -292,8 +292,8 @@ object Utils {
         AggregateExpression2(aggregateFunction, PartialMerge, false)
     }
     val partialMergeAggregateAttributes =
-      partialMergeAggregateExpressions.map {
-        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      partialMergeAggregateExpressions.flatMap { agg =>
+        agg.aggregateFunction.bufferAttributes
       }
     val partialMergeAggregate =
       Aggregate2Sort(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to