Repository: spark
Updated Branches:
  refs/heads/master fa84e4a7b -> 55946e76f


[SPARK-9349] [SQL] UDAF cleanup

https://issues.apache.org/jira/browse/SPARK-9349

With this PR, we only expose `UserDefinedAggregateFunction` (an abstract class) 
and `MutableAggregationBuffer` (an interface). Other internal wrappers and 
helper classes are moved to `org.apache.spark.sql.execution.aggregate` and 
marked as `private[sql]`.
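
For orientation, the resulting public surface is just two names; everything else in the diff below is `private[sql]`. A minimal sketch of the imports a user now needs (nothing here is new API, it only restates the split shown in the diff):

```scala
// Public, user-facing API after this cleanup:
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}

// Internal-only after this cleanup (private[sql], in org.apache.spark.sql.execution.aggregate):
//   ScalaUDAF, MutableAggregationBufferImpl, InputAggregationBuffer
```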

Author: Yin Huai <yh...@databricks.com>

Closes #7687 from yhuai/UDAF-cleanup and squashes the following commits:

db36542 [Yin Huai] Add comments to UDAF examples.
ae17f66 [Yin Huai] Address comments.
9c9fa5f [Yin Huai] UDAF cleanup.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/55946e76
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/55946e76
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/55946e76

Branch: refs/heads/master
Commit: 55946e76fd136958081f073c0c5e3ff8563d505b
Parents: fa84e4a
Author: Yin Huai <yh...@databricks.com>
Authored: Mon Jul 27 13:26:57 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon Jul 27 13:26:57 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/UDAFRegistration.scala |   3 +-
 .../spark/sql/execution/aggregate/udaf.scala    | 231 +++++++++++++++
 .../spark/sql/expressions/aggregate/udaf.scala  | 287 -------------------
 .../org/apache/spark/sql/expressions/udaf.scala | 101 +++++++
 .../spark/sql/hive/aggregate/MyDoubleAvg.java   |  34 ++-
 .../spark/sql/hive/aggregate/MyDoubleSum.java   |  28 +-
 6 files changed, 385 insertions(+), 299 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/55946e76/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
index 5b872f5..0d4e30f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.expressions.{Expression}
-import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction}
+import org.apache.spark.sql.execution.aggregate.ScalaUDAF
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
 
 class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
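
For context, a hedged sketch of how a registered UDAF flows through this class. It assumes `UDAFRegistration` exposes a `register(name, func)` method and that `SQLContext` surfaces the registry as `sqlContext.udaf`; neither is shown in this hunk:

```scala
// Hypothetical wiring sketch (register signature assumed, not shown in this hunk).
// Registration stores the function so that, at resolution time, a call site is
// rewritten to the now-internal ScalaUDAF(children, func).
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction

def registerAndQuery(sqlContext: SQLContext, func: UserDefinedAggregateFunction) = {
  sqlContext.udaf.register("my_agg", func)       // assumed entry point
  sqlContext.sql("SELECT my_agg(value) FROM t")  // resolved to ScalaUDAF(children, func)
}
```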
 

http://git-wip-us.apache.org/repos/asf/spark/blob/55946e76/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
new file mode 100644
index 0000000..073c45a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
+import org.apache.spark.sql.types.{Metadata, StructField, StructType, DataType}
+
+/**
+ * A mutable [[Row]] representing a mutable aggregation buffer.
+ */
+private[sql] class MutableAggregationBufferImpl (
+    schema: StructType,
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int,
+    var underlyingBuffer: MutableRow)
+  extends MutableAggregationBuffer {
+
+  private[this] val offsets: Array[Int] = {
+    val newOffsets = new Array[Int](length)
+    var i = 0
+    while (i < newOffsets.length) {
+      newOffsets(i) = bufferOffset + i
+      i += 1
+    }
+    newOffsets
+  }
+
+  override def length: Int = toCatalystConverters.length
+
+  override def get(i: Int): Any = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not access ${i}th value in this buffer because it only has 
$length values.")
+    }
+    toScalaConverters(i)(underlyingBuffer.get(offsets(i), schema(i).dataType))
+  }
+
+  def update(i: Int, value: Any): Unit = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not update ${i}th value in this buffer because it only has 
$length values.")
+    }
+    underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
+  }
+
+  override def copy(): MutableAggregationBufferImpl = {
+    new MutableAggregationBufferImpl(
+      schema,
+      toCatalystConverters,
+      toScalaConverters,
+      bufferOffset,
+      underlyingBuffer)
+  }
+}
+
+/**
+ * A [[Row]] representing an immutable aggregation buffer.
+ */
+private[sql] class InputAggregationBuffer private[sql] (
+    schema: StructType,
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int,
+    var underlyingInputBuffer: InternalRow)
+  extends Row {
+
+  private[this] val offsets: Array[Int] = {
+    val newOffsets = new Array[Int](length)
+    var i = 0
+    while (i < newOffsets.length) {
+      newOffsets(i) = bufferOffset + i
+      i += 1
+    }
+    newOffsets
+  }
+
+  override def length: Int = toCatalystConverters.length
+
+  override def get(i: Int): Any = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not access ${i}th value in this buffer because it only has 
$length values.")
+    }
+    // TODO: Use buffer schema to avoid using generic getter.
+    toScalaConverters(i)(underlyingInputBuffer.get(offsets(i), schema(i).dataType))
+  }
+
+  override def copy(): InputAggregationBuffer = {
+    new InputAggregationBuffer(
+      schema,
+      toCatalystConverters,
+      toScalaConverters,
+      bufferOffset,
+      underlyingInputBuffer)
+  }
+}
+
+/**
+ * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the
+ * internal aggregation code path.
+ * @param children the input expressions of this UDAF
+ * @param udaf the user-defined aggregate function being wrapped
+ */
+private[sql] case class ScalaUDAF(
+    children: Seq[Expression],
+    udaf: UserDefinedAggregateFunction)
+  extends AggregateFunction2 with Logging {
+
+  require(
+    children.length == udaf.inputSchema.length,
+    s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
+      s"but ${children.length} are provided.")
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = udaf.returnDataType
+
+  override def deterministic: Boolean = udaf.deterministic
+
+  override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType)
+
+  override val bufferSchema: StructType = udaf.bufferSchema
+
+  override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes
+
+  override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+
+  val childrenSchema: StructType = {
+    val inputFields = children.zipWithIndex.map {
+      case (child, index) =>
+        StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
+    }
+    StructType(inputFields)
+  }
+
+  lazy val inputProjection = {
+    val inputAttributes = childrenSchema.toAttributes
+    log.debug(
+      s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
+    try {
+      GenerateMutableProjection.generate(children, inputAttributes)()
+    } catch {
+      case e: Exception =>
+        log.error("Failed to generate mutable projection, fallback to interpreted", e)
+        new InterpretedMutableProjection(children, inputAttributes)
+    }
+  }
+
+  val inputToScalaConverters: Any => Any =
+    CatalystTypeConverters.createToScalaConverter(childrenSchema)
+
+  val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+    CatalystTypeConverters.createToCatalystConverter(field.dataType)
+  }
+
+  val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+    CatalystTypeConverters.createToScalaConverter(field.dataType)
+  }
+
+  lazy val inputAggregateBuffer: InputAggregationBuffer =
+    new InputAggregationBuffer(
+      bufferSchema,
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      bufferOffset,
+      null)
+
+  lazy val mutableAggregateBuffer: MutableAggregationBufferImpl =
+    new MutableAggregationBufferImpl(
+      bufferSchema,
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      bufferOffset,
+      null)
+
+
+  override def initialize(buffer: MutableRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer
+
+    udaf.initialize(mutableAggregateBuffer)
+  }
+
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer
+
+    udaf.update(
+      mutableAggregateBuffer,
+      inputToScalaConverters(inputProjection(input)).asInstanceOf[Row])
+  }
+
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer1
+    inputAggregateBuffer.underlyingInputBuffer = buffer2
+
+    udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
+  }
+
+  override def eval(buffer: InternalRow = null): Any = {
+    inputAggregateBuffer.underlyingInputBuffer = buffer
+
+    udaf.evaluate(inputAggregateBuffer)
+  }
+
+  override def toString: String = {
+    s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
+  }
+
+  override def nodeName: String = udaf.getClass.getSimpleName
+}
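
Both wrappers above are thin views over a row that an aggregation operator shares among all of its aggregate functions: logical slot `i` of one function's buffer lives at physical slot `bufferOffset + i`, and `ScalaUDAF` reuses a single wrapper instance per buffer by swapping in `underlyingBuffer` before each callback. A self-contained sketch of just the index remapping, in plain Scala with an `Array[Any]` standing in for the shared `MutableRow` (illustration only, not Spark code):

```scala
// Illustration of the offset logic in MutableAggregationBufferImpl: a view of
// `length` slots starting at `bufferOffset` within a shared backing array.
class OffsetView(underlying: Array[Any], bufferOffset: Int, val length: Int) {
  private val offsets: Array[Int] = Array.tabulate(length)(bufferOffset + _)

  def get(i: Int): Any = {
    require(i >= 0 && i < length, s"index $i out of range [0, $length)")
    underlying(offsets(i))
  }

  def update(i: Int, value: Any): Unit = underlying(offsets(i)) = value
}

object OffsetViewDemo extends App {
  val shared = new Array[Any](5)          // row shared by several aggregate functions
  val view = new OffsetView(shared, 2, 2) // this function owns physical slots 2 and 3
  view.update(0, 10L)                     // logical slot 0 writes physical slot 2
  assert(shared(2) == 10L)
}
```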

http://git-wip-us.apache.org/repos/asf/spark/blob/55946e76/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
deleted file mode 100644
index 4ada9ec..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
+++ /dev/null
@@ -1,287 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.expressions.aggregate
-
-import org.apache.spark.Logging
-import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
-import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.Row
-
-/**
- * The abstract class for implementing user-defined aggregate function.
- */
-abstract class UserDefinedAggregateFunction extends Serializable {
-
-  /**
-   * A [[StructType]] represents data types of input arguments of this 
aggregate function.
-   * For example, if a [[UserDefinedAggregateFunction]] expects two input 
arguments
-   * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] 
will look like
-   *
-   * ```
-   *   StructType(Seq(StructField("doubleInput", DoubleType), 
StructField("longInput", LongType)))
-   * ```
-   *
-   * The name of a field of this [[StructType]] is only used to identify the 
corresponding
-   * input argument. Users can choose names to identify the input arguments.
-   */
-  def inputSchema: StructType
-
-  /**
-   * A [[StructType]] represents data types of values in the aggregation 
buffer.
-   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
-   * (i.e. two intermediate values) with type of [[DoubleType]] and 
[[LongType]],
-   * the returned [[StructType]] will look like
-   *
-   * ```
-   *   StructType(Seq(StructField("doubleInput", DoubleType), 
StructField("longInput", LongType)))
-   * ```
-   *
-   * The name of a field of this [[StructType]] is only used to identify the 
corresponding
-   * buffer value. Users can choose names to identify the input arguments.
-   */
-  def bufferSchema: StructType
-
-  /**
-   * The [[DataType]] of the returned value of this 
[[UserDefinedAggregateFunction]].
-   */
-  def returnDataType: DataType
-
-  /** Indicates if this function is deterministic. */
-  def deterministic: Boolean
-
-  /**
-   *  Initializes the given aggregation buffer. Initial values set by this 
method should satisfy
-   *  the condition that when merging two buffers with initial values, the new 
buffer should
-   *  still store initial values.
-   */
-  def initialize(buffer: MutableAggregationBuffer): Unit
-
-  /** Updates the given aggregation buffer `buffer` with new input data from 
`input`. */
-  def update(buffer: MutableAggregationBuffer, input: Row): Unit
-
-  /** Merges two aggregation buffers and stores the updated buffer values back 
in `buffer1`. */
-  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
-
-  /**
-   * Calculates the final result of this [[UserDefinedAggregateFunction]] 
based on the given
-   * aggregation buffer.
-   */
-  def evaluate(buffer: Row): Any
-}
-
-private[sql] abstract class AggregationBuffer(
-    toCatalystConverters: Array[Any => Any],
-    toScalaConverters: Array[Any => Any],
-    bufferOffset: Int)
-  extends Row {
-
-  override def length: Int = toCatalystConverters.length
-
-  protected val offsets: Array[Int] = {
-    val newOffsets = new Array[Int](length)
-    var i = 0
-    while (i < newOffsets.length) {
-      newOffsets(i) = bufferOffset + i
-      i += 1
-    }
-    newOffsets
-  }
-}
-
-/**
- * A Mutable [[Row]] representing an mutable aggregation buffer.
- */
-class MutableAggregationBuffer private[sql] (
-    schema: StructType,
-    toCatalystConverters: Array[Any => Any],
-    toScalaConverters: Array[Any => Any],
-    bufferOffset: Int,
-    var underlyingBuffer: MutableRow)
-  extends AggregationBuffer(toCatalystConverters, toScalaConverters, 
bufferOffset) {
-
-  override def get(i: Int): Any = {
-    if (i >= length || i < 0) {
-      throw new IllegalArgumentException(
-        s"Could not access ${i}th value in this buffer because it only has 
$length values.")
-    }
-    toScalaConverters(i)(underlyingBuffer.get(offsets(i), schema(i).dataType))
-  }
-
-  def update(i: Int, value: Any): Unit = {
-    if (i >= length || i < 0) {
-      throw new IllegalArgumentException(
-        s"Could not update ${i}th value in this buffer because it only has 
$length values.")
-    }
-    underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
-  }
-
-  override def copy(): MutableAggregationBuffer = {
-    new MutableAggregationBuffer(
-      schema,
-      toCatalystConverters,
-      toScalaConverters,
-      bufferOffset,
-      underlyingBuffer)
-  }
-}
-
-/**
- * A [[Row]] representing an immutable aggregation buffer.
- */
-class InputAggregationBuffer private[sql] (
-    schema: StructType,
-    toCatalystConverters: Array[Any => Any],
-    toScalaConverters: Array[Any => Any],
-    bufferOffset: Int,
-    var underlyingInputBuffer: InternalRow)
-  extends AggregationBuffer(toCatalystConverters, toScalaConverters, 
bufferOffset) {
-
-  override def get(i: Int): Any = {
-    if (i >= length || i < 0) {
-      throw new IllegalArgumentException(
-        s"Could not access ${i}th value in this buffer because it only has 
$length values.")
-    }
-    // TODO: Use buffer schema to avoid using generic getter.
-    toScalaConverters(i)(underlyingInputBuffer.get(offsets(i), 
schema(i).dataType))
-  }
-
-  override def copy(): InputAggregationBuffer = {
-    new InputAggregationBuffer(
-      schema,
-      toCatalystConverters,
-      toScalaConverters,
-      bufferOffset,
-      underlyingInputBuffer)
-  }
-}
-
-/**
- * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` 
in the
- * internal aggregation code path.
- * @param children
- * @param udaf
- */
-case class ScalaUDAF(
-    children: Seq[Expression],
-    udaf: UserDefinedAggregateFunction)
-  extends AggregateFunction2 with Logging {
-
-  require(
-    children.length == udaf.inputSchema.length,
-    s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
-      s"but ${children.length} are provided.")
-
-  override def nullable: Boolean = true
-
-  override def dataType: DataType = udaf.returnDataType
-
-  override def deterministic: Boolean = udaf.deterministic
-
-  override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType)
-
-  override val bufferSchema: StructType = udaf.bufferSchema
-
-  override val bufferAttributes: Seq[AttributeReference] = 
bufferSchema.toAttributes
-
-  override lazy val cloneBufferAttributes = 
bufferAttributes.map(_.newInstance())
-
-  val childrenSchema: StructType = {
-    val inputFields = children.zipWithIndex.map {
-      case (child, index) =>
-        StructField(s"input$index", child.dataType, child.nullable, 
Metadata.empty)
-    }
-    StructType(inputFields)
-  }
-
-  lazy val inputProjection = {
-    val inputAttributes = childrenSchema.toAttributes
-    log.debug(
-      s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
-    try {
-      GenerateMutableProjection.generate(children, inputAttributes)()
-    } catch {
-      case e: Exception =>
-        log.error("Failed to generate mutable projection, fallback to 
interpreted", e)
-        new InterpretedMutableProjection(children, inputAttributes)
-    }
-  }
-
-  val inputToScalaConverters: Any => Any =
-    CatalystTypeConverters.createToScalaConverter(childrenSchema)
-
-  val bufferValuesToCatalystConverters: Array[Any => Any] = 
bufferSchema.fields.map { field =>
-    CatalystTypeConverters.createToCatalystConverter(field.dataType)
-  }
-
-  val bufferValuesToScalaConverters: Array[Any => Any] = 
bufferSchema.fields.map { field =>
-    CatalystTypeConverters.createToScalaConverter(field.dataType)
-  }
-
-  lazy val inputAggregateBuffer: InputAggregationBuffer =
-    new InputAggregationBuffer(
-      bufferSchema,
-      bufferValuesToCatalystConverters,
-      bufferValuesToScalaConverters,
-      bufferOffset,
-      null)
-
-  lazy val mutableAggregateBuffer: MutableAggregationBuffer =
-    new MutableAggregationBuffer(
-      bufferSchema,
-      bufferValuesToCatalystConverters,
-      bufferValuesToScalaConverters,
-      bufferOffset,
-      null)
-
-
-  override def initialize(buffer: MutableRow): Unit = {
-    mutableAggregateBuffer.underlyingBuffer = buffer
-
-    udaf.initialize(mutableAggregateBuffer)
-  }
-
-  override def update(buffer: MutableRow, input: InternalRow): Unit = {
-    mutableAggregateBuffer.underlyingBuffer = buffer
-
-    udaf.update(
-      mutableAggregateBuffer,
-      inputToScalaConverters(inputProjection(input)).asInstanceOf[Row])
-  }
-
-  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
-    mutableAggregateBuffer.underlyingBuffer = buffer1
-    inputAggregateBuffer.underlyingInputBuffer = buffer2
-
-    udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
-  }
-
-  override def eval(buffer: InternalRow = null): Any = {
-    inputAggregateBuffer.underlyingInputBuffer = buffer
-
-    udaf.evaluate(inputAggregateBuffer)
-  }
-
-  override def toString: String = {
-    s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
-  }
-
-  override def nodeName: String = udaf.getClass.getSimpleName
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/55946e76/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
new file mode 100644
index 0000000..278dd43
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * The abstract class for implementing user-defined aggregate functions.
+ */
+@Experimental
+abstract class UserDefinedAggregateFunction extends Serializable {
+
+  /**
+   * A [[StructType]] representing the data types of the input arguments of this aggregate
+   * function. For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
+   * with types [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
+   *
+   * ```
+   *   new StructType()
+   *    .add("doubleInput", DoubleType)
+   *    .add("longInput", LongType)
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the corresponding
+   * input argument. Users can choose names to identify the input arguments.
+   */
+  def inputSchema: StructType
+
+  /**
+   * A [[StructType]] representing the data types of the values in the aggregation buffer.
+   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
+   * (i.e. two intermediate values) with types [[DoubleType]] and [[LongType]],
+   * the returned [[StructType]] will look like
+   *
+   * ```
+   *   new StructType()
+   *    .add("doubleInput", DoubleType)
+   *    .add("longInput", LongType)
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the corresponding
+   * buffer value. Users can choose names to identify the buffer values.
+   */
+  def bufferSchema: StructType
+
+  /**
+   * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
+   */
+  def returnDataType: DataType
+
+  /** Indicates if this function is deterministic. */
+  def deterministic: Boolean
+
+  /**
+   * Initializes the given aggregation buffer. Initial values set by this method should satisfy
+   * the condition that when merging two buffers with initial values, the merged buffer should
+   * still store initial values.
+   */
+  def initialize(buffer: MutableAggregationBuffer): Unit
+
+  /** Updates the given aggregation buffer `buffer` with new input data from `input`. */
+  def update(buffer: MutableAggregationBuffer, input: Row): Unit
+
+  /** Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. */
+  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
+
+  /**
+   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
+   * aggregation buffer.
+   */
+  def evaluate(buffer: Row): Any
+}
+
+/**
+ * :: Experimental ::
+ * A [[Row]] representing a mutable aggregation buffer.
+ */
+@Experimental
+trait MutableAggregationBuffer extends Row {
+
+  /** Update the ith value of this buffer. */
+  def update(i: Int, value: Any): Unit
+}
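
To make the contract above concrete, here is a hedged end-to-end sketch against this new public API: a hypothetical `MyAverage` with the two-slot (`DoubleType`, `LongType`) buffer layout used in the scaladoc examples. Note how `initialize` satisfies the stated merge condition: merging a freshly initialized `(0.0, 0L)` buffer into another leaves it unchanged.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Hypothetical example class, not part of this commit.
class MyAverage extends UserDefinedAggregateFunction {
  def inputSchema: StructType = new StructType().add("doubleInput", DoubleType)

  def bufferSchema: StructType =
    new StructType().add("sum", DoubleType).add("count", LongType)

  def returnDataType: DataType = DoubleType

  def deterministic: Boolean = true

  // (0.0, 0L) is a no-op under merge, as the initialize contract requires.
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)
    buffer.update(1, 0L)
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
      buffer.update(1, buffer.getLong(1) + 1L)
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  def evaluate(buffer: Row): Any =
    if (buffer.getLong(1) == 0L) null else buffer.getDouble(0) / buffer.getLong(1)
}
```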

http://git-wip-us.apache.org/repos/asf/spark/blob/55946e76/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
index 5c9d0e9..a2247e3 100644
--- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
@@ -21,13 +21,18 @@ import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
-import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
+/**
+ * An example {@link UserDefinedAggregateFunction} that calculates a special average value of a
+ * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the
+ * average of the input values plus 100.0.
+ */
 public class MyDoubleAvg extends UserDefinedAggregateFunction {
 
   private StructType _inputDataType;
@@ -37,10 +42,13 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
   private DataType _returnDataType;
 
   public MyDoubleAvg() {
-    List<StructField> inputfields = new ArrayList<StructField>();
-    inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
-    _inputDataType = DataTypes.createStructType(inputfields);
+    List<StructField> inputFields = new ArrayList<StructField>();
+    inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+    _inputDataType = DataTypes.createStructType(inputFields);
 
+    // The buffer has two values, bufferSum for storing the current sum and
+    // bufferCount for storing the number of non-null input values that have
+    // contributed to the current sum.
     List<StructField> bufferFields = new ArrayList<StructField>();
     bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
     bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true));
@@ -66,16 +74,23 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
   }
 
   @Override public void initialize(MutableAggregationBuffer buffer) {
+    // The initial value of the sum is null.
     buffer.update(0, null);
+    // The initial value of the count is 0.
     buffer.update(1, 0L);
   }
 
   @Override public void update(MutableAggregationBuffer buffer, Row input) {
+    // This input Row has a single column storing the input value as a Double.
+    // We only update the buffer when the input value is not null.
     if (!input.isNullAt(0)) {
+      // If the buffer value (the intermediate result of the sum) is still null,
+      // we set the buffer to the input value and set the bufferCount to 1.
       if (buffer.isNullAt(0)) {
         buffer.update(0, input.getDouble(0));
         buffer.update(1, 1L);
       } else {
+        // Otherwise, update the bufferSum and increment bufferCount.
         Double newValue = input.getDouble(0) + buffer.getDouble(0);
         buffer.update(0, newValue);
         buffer.update(1, buffer.getLong(1) + 1L);
@@ -84,11 +99,16 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
   }
 
   @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+    // buffer1 and buffer2 have the same structure.
+    // We only update buffer1 when the input buffer2's sum value is not null.
     if (!buffer2.isNullAt(0)) {
       if (buffer1.isNullAt(0)) {
+        // If the buffer value (the intermediate result of the sum) is still null,
+        // we set it to the input buffer's value.
         buffer1.update(0, buffer2.getDouble(0));
         buffer1.update(1, buffer2.getLong(1));
       } else {
+        // Otherwise, we update the bufferSum and bufferCount.
         Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
         buffer1.update(0, newValue);
         buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
@@ -98,10 +118,12 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction {
 
   @Override public Object evaluate(Row buffer) {
     if (buffer.isNullAt(0)) {
+      // If the bufferSum is still null, we return null because this function
+      // has not received any input rows.
       return null;
     } else {
+      // Otherwise, we calculate the special average value.
       return buffer.getDouble(0) / buffer.getLong(1) + 100.0;
     }
   }
 }
-
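
A hedged usage sketch of the example above, spark-shell style (the `sqlContext.udaf.register` entry point is assumed as in `UDAFRegistration` earlier; table and function names are made up). For inputs 1.0, 2.0 and 3.0 the documented "special average" is (1.0 + 2.0 + 3.0) / 3 + 100.0 = 102.0:

```scala
import test.org.apache.spark.sql.hive.aggregate.MyDoubleAvg

sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg)  // assumed entry point
val df = sqlContext.createDataFrame(Seq(Tuple1(1.0), Tuple1(2.0), Tuple1(3.0))).toDF("value")
df.registerTempTable("values")
// Expected single result: 2.0 + 100.0 = 102.0
sqlContext.sql("SELECT mydoubleavg(value) FROM values").show()
```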

http://git-wip-us.apache.org/repos/asf/spark/blob/55946e76/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
index 1d4587a..da29e24 100644
--- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
@@ -20,14 +20,18 @@ package test.org.apache.spark.sql.hive.aggregate;
 import java.util.ArrayList;
 import java.util.List;
 
-import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
-import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.Row;
 
+/**
+ * An example {@link UserDefinedAggregateFunction} that calculates the sum of a
+ * {@link org.apache.spark.sql.types.DoubleType} column.
+ */
 public class MyDoubleSum extends UserDefinedAggregateFunction {
 
   private StructType _inputDataType;
@@ -37,9 +41,9 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
   private DataType _returnDataType;
 
   public MyDoubleSum() {
-    List<StructField> inputfields = new ArrayList<StructField>();
-    inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
-    _inputDataType = DataTypes.createStructType(inputfields);
+    List<StructField> inputFields = new ArrayList<StructField>();
+    inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+    _inputDataType = DataTypes.createStructType(inputFields);
 
     List<StructField> bufferFields = new ArrayList<StructField>();
     bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
@@ -65,14 +69,20 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
   }
 
   @Override public void initialize(MutableAggregationBuffer buffer) {
+    // The initial value of the sum is null.
     buffer.update(0, null);
   }
 
   @Override public void update(MutableAggregationBuffer buffer, Row input) {
+    // This input Row has a single column storing the input value as a Double.
+    // We only update the buffer when the input value is not null.
     if (!input.isNullAt(0)) {
       if (buffer.isNullAt(0)) {
+        // If the buffer value (the intermediate result of the sum) is still null,
+        // we set the buffer to the input value.
         buffer.update(0, input.getDouble(0));
       } else {
+        // Otherwise, we add the input value to the buffer value.
         Double newValue = input.getDouble(0) + buffer.getDouble(0);
         buffer.update(0, newValue);
       }
@@ -80,10 +90,16 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
   }
 
   @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+    // buffer1 and buffer2 have the same structure.
+    // We only update buffer1 when the input buffer2's value is not null.
     if (!buffer2.isNullAt(0)) {
       if (buffer1.isNullAt(0)) {
+        // If the buffer value (the intermediate result of the sum) is still null,
+        // we set it to the input buffer's value.
         buffer1.update(0, buffer2.getDouble(0));
       } else {
+        // Otherwise, we add the input buffer's value (buffer2) to the mutable
+        // buffer's value (buffer1).
         Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
         buffer1.update(0, newValue);
       }
@@ -92,8 +108,10 @@ public class MyDoubleSum extends UserDefinedAggregateFunction {
 
   @Override public Object evaluate(Row buffer) {
     if (buffer.isNullAt(0)) {
+      // If the buffer value is still null, we return null.
       return null;
     } else {
+      // Otherwise, the intermediate sum is the final result.
       return buffer.getDouble(0);
     }
   }
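
Similarly for `MyDoubleSum`, a hedged sketch exercising the null handling described in the comments above (same assumed registration entry point; names made up). Null inputs are skipped, and a column with no non-null values would yield null:

```scala
import test.org.apache.spark.sql.hive.aggregate.MyDoubleSum

sqlContext.udaf.register("mydoublesum", new MyDoubleSum)  // assumed entry point
val df = sqlContext.createDataFrame(
  Seq((1, Some(1.5)), (2, None: Option[Double]), (3, Some(2.5)))).toDF("id", "value")
df.registerTempTable("rows")
// The null in row 2 is ignored: 1.5 + 2.5 = 4.0
sqlContext.sql("SELECT mydoublesum(value) FROM rows").show()
```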

