Github user sethah commented on a diff in the pull request:

    https://github.com/apache/spark/pull/9003#discussion_r42677457
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
 ---
    @@ -857,3 +857,329 @@ object HyperLogLogPlusPlus {
       )
       // scalastyle:on
     }
    +
    +/**
    + * A central moment is the expected value of a specified power of the 
deviation of a random
    + * variable from the mean. Central moments are often used to characterize 
the properties of
    + * the shape of a distribution.
    + *
    + * This class implements online, one-pass algorithms for computing the 
central moments of a set of
    + * points.
    + *
    + * References:
    + *  - Xiangrui Meng.  "Simpler Online Updates for Arbitrary-Order Central 
Moments."
    + *      2015. http://arxiv.org/abs/1510.04923
    + *
    + * @see [[https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
    + *     Algorithms for calculating variance (Wikipedia)]]
    + *
    + * @param child to compute central moments of.
    + */
    +abstract class CentralMomentAgg(child: Expression) extends 
ImperativeAggregate with Serializable {
    +
    +  /**
    +   * The maximum central moment order to be computed.
    +   */
    +  protected def momentOrder: Int
    +
    +  /**
    +   * Array of sufficient moments needed to compute the aggregate statistic.
    +   */
    +  protected def sufficientMoments: Array[Int]
    +
    +  override def children: Seq[Expression] = Seq(child)
    +
    +  override def nullable: Boolean = false
    +
    +  override def dataType: DataType = DoubleType
    +
    +  // Expected input data type.
    +  // TODO: Right now, we replace old aggregate functions (based on 
AggregateExpression1) to the
    +  // new version at planning time (after analysis phase). For now, 
NullType is added at here
    +  // to make it resolved when we have cases like `select avg(null)`.
    +  // We can use our analyzer to cast NullType to the default data type of 
the NumericType once
    +  // we remove the old aggregate functions. Then, we will not need 
NullType at here.
    +  override def inputTypes: Seq[AbstractDataType] = 
Seq(TypeCollection(NumericType, NullType))
    +
    +  override def aggBufferSchema: StructType = 
StructType.fromAttributes(aggBufferAttributes)
    +
    +  /**
    +   * The number of central moments to store in the buffer.
    +   */
    +  private[this] val numMoments = 5
    +
    +  override val aggBufferAttributes: Seq[AttributeReference] = 
Seq.tabulate(numMoments) { i =>
    +    AttributeReference(s"M$i", DoubleType)()
    +  }
    +
    +  // Note: although this simply copies aggBufferAttributes, this common 
code can not be placed
    +  // in the superclass because that will lead to initialization ordering 
issues.
    +  override val inputAggBufferAttributes: Seq[AttributeReference] =
    +    aggBufferAttributes.map(_.newInstance())
    +
    +  /**
    +   * Initialize all moments to zero.
    +   */
    +  override def initialize(buffer: MutableRow): Unit = {
    +    // Zero every moment slot, starting at this aggregate's buffer offset.
    +    (0 until numMoments).foreach { slot =>
    +      buffer.setDouble(mutableAggBufferOffset + slot, 0.0)
    +    }
    +  }
    +
    +  // frequently used values for online updates
    +  private[this] var delta = 0.0
    +  private[this] var deltaN = 0.0
    +  private[this] var delta2 = 0.0
    +  private[this] var deltaN2 = 0.0
    +
    +  /**
    +   * Update the central moments buffer.
    +   */
    +  override def update(buffer: MutableRow, input: InternalRow): Unit = { // fold one input value into the buffered moments
    +    val v = Cast(child, DoubleType).eval(input) // child's value for this row, cast to double
    +    if (v != null) { // null inputs leave the buffer untouched
    +      val updateValue = v match {
    +        case d: Double => d
    +        case _ => 0.0 // NOTE(review): any non-Double value is counted as 0.0 - confirm this is intended
    +      }
    +      var n = buffer.getDouble(mutableAggBufferOffset) // slot 0: count of values folded in so far
    +      var mean = buffer.getDouble(mutableAggBufferOffset + 1) // slot 1: running mean
    +      var m2 = 0.0
    +      var m3 = 0.0
    +      var m4 = 0.0
    +
    +      n += 1.0
    +      delta = updateValue - mean // deviation from the mean before this value
    +      deltaN = delta / n // n is at least 1.0 here, so this never divides by zero
    +      mean += deltaN
    +      buffer.setDouble(mutableAggBufferOffset, n)
    +      buffer.setDouble(mutableAggBufferOffset + 1, mean)
    +
    +      if (momentOrder >= 2) { // slots 2..4 are only maintained up to momentOrder
    +        m2 = buffer.getDouble(mutableAggBufferOffset + 2)
    +        m2 += delta * (delta - deltaN)
    +        buffer.setDouble(mutableAggBufferOffset + 2, m2)
    +      }
    +
    +      if (momentOrder >= 3) {
    +        delta2 = delta * delta
    +        deltaN2 = deltaN * deltaN
    +        m3 = buffer.getDouble(mutableAggBufferOffset + 3)
    +        m3 += -3.0 * deltaN * m2 + delta * (delta2 - deltaN2) // reads the m2 already updated above
    +        buffer.setDouble(mutableAggBufferOffset + 3, m3)
    +      }
    +
    +      if (momentOrder >= 4) {
    +        m4 = buffer.getDouble(mutableAggBufferOffset + 4)
    +        m4 += -4.0 * deltaN * m3 - 6.0 * deltaN2 * m2 + // reads the updated m3 and m2
    +          delta * (delta * delta2 - deltaN * deltaN2)
    +        buffer.setDouble(mutableAggBufferOffset + 4, m4)
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Merge two central moment buffers.
    +   */
    +  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
    +    // Merges the partial moments held in `buffer2` into `buffer1`, in place.
    +    // Slot layout relative to each buffer's offset:
    +    //   0 = count, 1 = mean, 2..4 = second..fourth moment accumulators.
    +    // Only slots up to `momentOrder` are read and written.
    +    val n1 = buffer1.getDouble(mutableAggBufferOffset)
    +    val n2 = buffer2.getDouble(inputAggBufferOffset)
    +    val mean1 = buffer1.getDouble(mutableAggBufferOffset + 1)
    +    val mean2 = buffer2.getDouble(inputAggBufferOffset + 1)
    +
    +    val n = n1 + n2
    +    delta = mean2 - mean1
    +    // When both buffers are empty, n is 0.0 and delta / n would be NaN.
    +    // That NaN would be stored in the mean and in every moment slot, and
    +    // would then poison all later merges. Use 0.0 for the scaled delta
    +    // instead, so merging two empty buffers leaves an empty buffer.
    +    deltaN = if (n == 0.0) 0.0 else delta / n
    +    val mean = mean1 + deltaN * n2
    +
    +    buffer1.setDouble(mutableAggBufferOffset, n)
    +    buffer1.setDouble(mutableAggBufferOffset + 1, mean)
    +
    +    // Higher order moments are combined according to:
    +    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
    +    // (section "Higher-order statistics").
    +    // The checks are nested because each order reuses the lower-order
    +    // inputs; momentOrder >= k + 1 always implies momentOrder >= k.
    +    if (momentOrder >= 2) {
    +      val secondMoment1 = buffer1.getDouble(mutableAggBufferOffset + 2)
    +      val secondMoment2 = buffer2.getDouble(inputAggBufferOffset + 2)
    +      val secondMoment =
    +        secondMoment1 + secondMoment2 + delta * deltaN * n1 * n2
    +      buffer1.setDouble(mutableAggBufferOffset + 2, secondMoment)
    +
    +      if (momentOrder >= 3) {
    +        val thirdMoment1 = buffer1.getDouble(mutableAggBufferOffset + 3)
    +        val thirdMoment2 = buffer2.getDouble(inputAggBufferOffset + 3)
    +        val thirdMoment = thirdMoment1 + thirdMoment2 +
    +          deltaN * deltaN * delta * n1 * n2 * (n1 - n2) +
    +          3.0 * deltaN * (n1 * secondMoment2 - n2 * secondMoment1)
    +        buffer1.setDouble(mutableAggBufferOffset + 3, thirdMoment)
    +
    +        if (momentOrder >= 4) {
    +          val fourthMoment1 = buffer1.getDouble(mutableAggBufferOffset + 4)
    +          val fourthMoment2 = buffer2.getDouble(inputAggBufferOffset + 4)
    +          val fourthMoment = fourthMoment1 + fourthMoment2 +
    +            deltaN * deltaN * deltaN * delta * n1 * n2 *
    +            (n1 * n1 - n1 * n2 + n2 * n2) +
    +            deltaN * deltaN * 6.0 *
    +            (n1 * n1 * secondMoment2 + n2 * n2 * secondMoment1) +
    +            4.0 * deltaN * (n1 * thirdMoment2 - n2 * thirdMoment1)
    +          buffer1.setDouble(mutableAggBufferOffset + 4, fourthMoment)
    +        }
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Compute aggregate statistic from sufficient moments.
    +   */
    +  def getStatistic(n: Double, moments: Array[Double]): Double
    +
    +  override final def eval(buffer: InternalRow): Any = {
    +    val n = buffer.getDouble(mutableAggBufferOffset)
    +    val moments = sufficientMoments.map { momentIdx =>
    --- End diff --
    
    Per Xiangrui's comments about passing the buffer to subclasses, I think it's 
necessary to do it this way. This way we don't have to pass the InternalRow 
object to the subclasses and can just pass the sufficient moments instead. See 
the discussion below for more details.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to