Github user cloud-fan commented on a diff in the pull request: https://github.com/apache/spark/pull/15544#discussion_r139600490 --- Diff: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala --- @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, CreateArray, Literal, SpecificInternalRow} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ + +class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { + + test("fails analysis if parameters are invalid") { + def assertEqual[T](left: T, right: T): Unit = { + assert(left == right) + } + + val wrongColumnTypes = Seq(BinaryType, BooleanType, StringType, ArrayType(IntegerType), + MapType(IntegerType, IntegerType), StructType(Seq(StructField("s", IntegerType)))) + wrongColumnTypes.foreach { dataType => + val wrongColumn = new ApproxCountDistinctForIntervals( + AttributeReference("a", dataType)(), + endpointsExpression = CreateArray(Seq(1, 10).map(Literal(_)))) + assert( + wrongColumn.checkInputDataTypes() match { + case TypeCheckFailure(msg) + if msg.contains("requires (numeric or timestamp or date) type") => true + case _ => false + }) + } + + var wrongEndpoints = new ApproxCountDistinctForIntervals( + AttributeReference("a", DoubleType)(), + endpointsExpression = Literal(0.5d)) + assert( + wrongEndpoints.checkInputDataTypes() match { + case TypeCheckFailure(msg) if msg.contains("requires array type") => true + case _ => false + }) + + wrongEndpoints = new ApproxCountDistinctForIntervals( + AttributeReference("a", DoubleType)(), + endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)()))) + assertEqual( + wrongEndpoints.checkInputDataTypes(), + TypeCheckFailure("The intervals provided must be constant literals")) + + wrongEndpoints = new ApproxCountDistinctForIntervals( + AttributeReference("a", DoubleType)(), + endpointsExpression = CreateArray(Array(10L).map(Literal(_)))) + assertEqual( + wrongEndpoints.checkInputDataTypes(), + TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")) + } + + /** Create an ApproxCountDistinctForIntervals instance and an input and output buffer. */ + private def createEstimator( + endpoints: Array[Double], + rsd: Double = 0.05, + dt: DataType = IntegerType): (ApproxCountDistinctForIntervals, InternalRow, InternalRow) = { + val input = new SpecificInternalRow(Seq(dt)) + val aggFunc = ApproxCountDistinctForIntervals( + BoundReference(0, dt, nullable = true), CreateArray(endpoints.map(Literal(_))), rsd) + val buffer = createBuffer(aggFunc) + (aggFunc, input, buffer) + } + + private def createBuffer(aggFunc: ApproxCountDistinctForIntervals): InternalRow = { + val buffer = new SpecificInternalRow(aggFunc.aggBufferAttributes.map(_.dataType)) + aggFunc.initialize(buffer) + buffer + } + + test("merging ApproxCountDistinctForIntervals instances") { + val (aggFunc, input, buffer1a) = createEstimator(Array[Double](0, 10, 2000, 345678, 1000000)) + val buffer1b = createBuffer(aggFunc) + val buffer2 = createBuffer(aggFunc) + + // Create the --- End diff -- typo?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org