Github user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/15544#discussion_r140140577
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala ---
    @@ -0,0 +1,235 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.catalyst.expressions.aggregate
    +
    +import java.util
    +
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
    +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
    +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExpectsInputTypes, Expression}
    +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper}
    +import org.apache.spark.sql.types._
    +
    +/**
    + * This function counts the approximate number of distinct values (ndv) in
    + * intervals constructed from the endpoints specified in `endpointsExpression`. The endpoints
    + * should be sorted in ascending order. E.g., given an array of endpoints
    + * (endpoint_1, endpoint_2, ... endpoint_N), it returns the approximate ndv's for the intervals
    + * [endpoint_1, endpoint_2], (endpoint_2, endpoint_3], ... (endpoint_N-1, endpoint_N].
    + * To count the ndv's in these intervals, the HyperLogLogPlusPlus algorithm is applied to each
    + * of them.
    + * @param child the expression of which to estimate the ndv's.
    + * @param endpointsExpression the expression used to construct the intervals; its elements
    + *                            should be sorted in ascending order.
    + * @param relativeSD the maximum estimation error allowed in the HyperLogLogPlusPlus algorithm.
    + */
    +case class ApproxCountDistinctForIntervals(
    +    child: Expression,
    +    endpointsExpression: Expression,
    +    relativeSD: Double = 0.05,
    +    mutableAggBufferOffset: Int = 0,
    +    inputAggBufferOffset: Int = 0)
    +  extends ImperativeAggregate with ExpectsInputTypes {
    +
    +  def this(child: Expression, endpointsExpression: Expression) = {
    +    this(
    +      child = child,
    +      endpointsExpression = endpointsExpression,
    +      relativeSD = 0.05,
    +      mutableAggBufferOffset = 0,
    +      inputAggBufferOffset = 0)
    +  }
    +
    +  def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = {
    +    this(
    +      child = child,
    +      endpointsExpression = endpointsExpression,
    +      relativeSD = HyperLogLogPlusPlus.validateDoubleLiteral(relativeSD),
    +      mutableAggBufferOffset = 0,
    +      inputAggBufferOffset = 0)
    +  }
    +
    +  override def inputTypes: Seq[AbstractDataType] = {
    +    Seq(TypeCollection(NumericType, TimestampType, DateType), ArrayType)
    +  }
    +
    +  // Mark as lazy so that endpointsExpression is not evaluated during tree transformation.
    +  lazy val endpoints: Array[Double] =
    +    (endpointsExpression.dataType, endpointsExpression.eval()) match {
    +      case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
    +        val numericArray = arrayData.toObjectArray(baseType)
    +        numericArray.map { x =>
    +          baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])
    +        }
    +    }
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    val defaultCheck = super.checkInputDataTypes()
    +    if (defaultCheck.isFailure) {
    +      defaultCheck
    +    } else if (!endpointsExpression.foldable) {
    +      TypeCheckFailure("The intervals provided must be constant literals")
    +    } else if (endpoints.length < 2) {
    +      TypeCheckFailure("The number of endpoints must be >= 2 to construct 
intervals")
    +    } else {
    +      TypeCheckSuccess
    +    }
    +  }
    +
    +  // N endpoints construct N-1 intervals, creating a HLLPP for each interval.
    +  private lazy val hllppArray = {
    +    val array = new Array[HyperLogLogPlusPlusHelper](endpoints.length - 1)
    +    for (i <- array.indices) {
    +      array(i) = new HyperLogLogPlusPlusHelper(relativeSD)
    +    }
    +    // `numWords` in each HLLPPHelper should be the same because it is determined by
    +    // `relativeSD`, which is shared among all HLLPPHelpers.
    +    assert(array.map(_.numWords).distinct.length == 1)
    +    array
    +  }
    +
    +  private lazy val numWordsPerHllpp = hllppArray.head.numWords
    +
    +  private lazy val totalNumWords = numWordsPerHllpp * hllppArray.length
    +
    +  /** Allocate enough words to store all registers. */
    +  override lazy val aggBufferAttributes: Seq[AttributeReference] = {
    +    Seq.tabulate(totalNumWords) { i =>
    +      AttributeReference(s"MS[$i]", LongType)()
    +    }
    +  }
    +
    +  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
    +
    +  // Note: although this simply copies aggBufferAttributes, this common code cannot be placed
    +  // in the superclass because that will lead to initialization ordering issues.
    +  override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
    +    aggBufferAttributes.map(_.newInstance())
    +
    +  /** Fill all words with zeros. */
    +  override def initialize(buffer: InternalRow): Unit = {
    +    var word = 0
    +    while (word < totalNumWords) {
    +      buffer.setLong(mutableAggBufferOffset + word, 0)
    +      word += 1
    +    }
    +  }
    +
    +  override def update(buffer: InternalRow, input: InternalRow): Unit = {
    +    val value = child.eval(input)
    +    // Ignore empty rows
    +    if (value != null) {
    +      // convert the value into a double for searching in the double endpoints array
    +      val doubleValue = child.dataType match {
    +        case n: NumericType =>
    +          n.numeric.toDouble(value.asInstanceOf[n.InternalType])
    +        case _: DateType =>
    +          value.asInstanceOf[Int].toDouble
    +        case _: TimestampType =>
    +          value.asInstanceOf[Long].toDouble
    +      }
    +
    +      // endpoints are sorted into ascending order already
    +      if (endpoints.head > doubleValue || endpoints.last < doubleValue) {
    +        // ignore if the value is out of the whole range
    +        return
    +      }
    +
    +      val hllppIndex = findHllppIndex(doubleValue)
    +      val offset = mutableAggBufferOffset + hllppIndex * numWordsPerHllpp
    +      hllppArray(hllppIndex).update(buffer, offset, value, child.dataType)
    +    }
    +  }
    +
    +  // Find which interval (HyperLogLogPlusPlusHelper) should receive the given value.
    +  def findHllppIndex(value: Double): Int = {
    +    var index = util.Arrays.binarySearch(endpoints, value)
    +    if (index >= 0) {
    +      // The value is found.
    +      if (index == 0) {
    +        0
    +      } else {
    +        // If the endpoints array contains multiple elements equal to the specified value,
    +        // there is no guarantee which one binarySearch will return. We remove this
    +        // uncertainty by moving the index to the first position of these elements.
    --- End diff --
    
    We'd better explain this in the expression comment.
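    
    For context, here is a minimal standalone sketch of the tie-breaking behavior being discussed. The `firstIndexOf` helper and the sample endpoints are made up purely for illustration; the patch does this inline in `findHllppIndex`:
    
    ```scala
    import java.util
    
    object BinarySearchTieBreak {
      // Illustrative helper: find the FIRST index holding `value`, since
      // java.util.Arrays.binarySearch gives no guarantee about which of
      // several equal elements it returns.
      def firstIndexOf(endpoints: Array[Double], value: Double): Int = {
        var index = util.Arrays.binarySearch(endpoints, value)
        if (index >= 0) {
          // Walk left past any duplicates so the result is deterministic.
          while (index > 0 && endpoints(index - 1) == value) {
            index -= 1
          }
        }
        index
      }
    
      def main(args: Array[String]): Unit = {
        val endpoints = Array(1.0, 5.0, 5.0, 5.0, 10.0)
        // binarySearch alone may return 1, 2 or 3 for 5.0;
        // firstIndexOf always returns 1.
        println(firstIndexOf(endpoints, 5.0)) // 1
      }
    }
    ```
    
    This makes the interval assignment deterministic when endpoints repeat, which is exactly the kind of guarantee worth spelling out in the expression comment.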

