Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/15544#discussion_r139600490
  
    --- Diff: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala ---
    @@ -0,0 +1,207 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.catalyst.expressions.aggregate
    +
    +import org.apache.spark.SparkFunSuite
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
    +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, CreateArray, Literal, SpecificInternalRow}
    +import org.apache.spark.sql.catalyst.util.ArrayData
    +import org.apache.spark.sql.types._
    +
    +class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
    +
    +  test("fails analysis if parameters are invalid") {
    +    def assertEqual[T](left: T, right: T): Unit = {
    +      assert(left == right)
    +    }
    +
    +    val wrongColumnTypes = Seq(BinaryType, BooleanType, StringType, ArrayType(IntegerType),
    +      MapType(IntegerType, IntegerType), StructType(Seq(StructField("s", IntegerType))))
    +    wrongColumnTypes.foreach { dataType =>
    +      val wrongColumn = new ApproxCountDistinctForIntervals(
    +        AttributeReference("a", dataType)(),
    +        endpointsExpression = CreateArray(Seq(1, 10).map(Literal(_))))
    +      assert(
    +        wrongColumn.checkInputDataTypes() match {
    +          case TypeCheckFailure(msg)
    +            if msg.contains("requires (numeric or timestamp or date) type") => true
    +          case _ => false
    +        })
    +    }
    +
    +    var wrongEndpoints = new ApproxCountDistinctForIntervals(
    +      AttributeReference("a", DoubleType)(),
    +      endpointsExpression = Literal(0.5d))
    +    assert(
    +      wrongEndpoints.checkInputDataTypes() match {
    +        case TypeCheckFailure(msg) if msg.contains("requires array type") => true
    +        case _ => false
    +      })
    +
    +    wrongEndpoints = new ApproxCountDistinctForIntervals(
    +      AttributeReference("a", DoubleType)(),
    +      endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)())))
    +    assertEqual(
    +      wrongEndpoints.checkInputDataTypes(),
    +      TypeCheckFailure("The intervals provided must be constant literals"))
    +
    +    wrongEndpoints = new ApproxCountDistinctForIntervals(
    +      AttributeReference("a", DoubleType)(),
    +      endpointsExpression = CreateArray(Array(10L).map(Literal(_))))
    +    assertEqual(
    +      wrongEndpoints.checkInputDataTypes(),
    +      TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals"))
    +  }
    +
    +  /** Create an ApproxCountDistinctForIntervals instance and an input and output buffer. */
    +  private def createEstimator(
    +      endpoints: Array[Double],
    +      rsd: Double = 0.05,
    +      dt: DataType = IntegerType): (ApproxCountDistinctForIntervals, InternalRow, InternalRow) = {
    +    val input = new SpecificInternalRow(Seq(dt))
    +    val aggFunc = ApproxCountDistinctForIntervals(
    +      BoundReference(0, dt, nullable = true), CreateArray(endpoints.map(Literal(_))), rsd)
    +    val buffer = createBuffer(aggFunc)
    +    (aggFunc, input, buffer)
    +  }
    +
    +  private def createBuffer(aggFunc: ApproxCountDistinctForIntervals): InternalRow = {
    +    val buffer = new SpecificInternalRow(aggFunc.aggBufferAttributes.map(_.dataType))
    +    aggFunc.initialize(buffer)
    +    buffer
    +  }
    +
    +  test("merging ApproxCountDistinctForIntervals instances") {
    +    val (aggFunc, input, buffer1a) = createEstimator(Array[Double](0, 10, 2000, 345678, 1000000))
    +    val buffer1b = createBuffer(aggFunc)
    +    val buffer2 = createBuffer(aggFunc)
    +
    +    // Create the
    --- End diff --
    
    typo?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to