This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 08123a3  [SPARK-37138][SQL] Support ANSI Interval types in ApproxCountDistinctForIntervals/ApproximatePercentile/Percentile
08123a3 is described below

commit 08123a3795683238352e5bf55452de381349fdd9
Author: Angerszhuuuu <angers....@gmail.com>
AuthorDate: Sat Oct 30 20:03:20 2021 +0300

    [SPARK-37138][SQL] Support ANSI Interval types in ApproxCountDistinctForIntervals/ApproximatePercentile/Percentile

    ### What changes were proposed in this pull request?
    Support ANSI interval types in the agg expressions:
    - ApproxCountDistinctForIntervals
    - ApproximatePercentile
    - Percentile

    ### Why are the changes needed?
    To improve user experience with Spark SQL.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Added new UT.

    Closes #34412 from AngersZhuuuu/SPARK-37138.

    Authored-by: Angerszhuuuu <angers....@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../ApproxCountDistinctForIntervals.scala          | 13 +++---
 .../aggregate/ApproximatePercentile.scala          | 32 ++++++++------
 .../expressions/aggregate/Percentile.scala         | 26 +++++++++---
 .../ApproxCountDistinctForIntervalsSuite.scala     |  6 ++-
 .../expressions/aggregate/PercentileSuite.scala    |  8 ++--
 ...ApproxCountDistinctForIntervalsQuerySuite.scala | 28 +++++++++++++
 .../sql/ApproximatePercentileQuerySuite.scala      | 22 +++++++++-
 .../apache/spark/sql/PercentileQuerySuite.scala    | 49 ++++++++++++++++++++++
 8 files changed, 153 insertions(+), 31 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
index a7e9a22..f3bf251 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
@@ -61,7 +61,8 @@ case class ApproxCountDistinctForIntervals(
   }

   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(TypeCollection(NumericType, TimestampType, DateType, TimestampNTZType), ArrayType)
+    Seq(TypeCollection(NumericType, TimestampType, DateType, TimestampNTZType,
+      YearMonthIntervalType, DayTimeIntervalType), ArrayType)
   }

   // Mark as lazy so that endpointsExpression is not evaluated during tree transformation.
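For context on why the inputTypes change above is enough: Catalyst stores a year-month interval as an Int counting months and a day-time interval as a Long counting microseconds, so both widen to Double exactly the way dates and timestamps already do. A minimal standalone sketch of that widening, using plain java.time values in place of Catalyst's internal representation (the object and method names are illustrative, not part of the patch):

    import java.time.{Duration, Period}

    object IntervalWidening {
      // Year-month intervals: internally an Int count of months.
      def yearMonthToDouble(p: Period): Double =
        p.toTotalMonths.toDouble

      // Day-time intervals: internally a Long count of microseconds.
      def dayTimeToDouble(d: Duration): Double =
        (d.getSeconds * 1000000L + d.getNano / 1000L).toDouble

      def main(args: Array[String]): Unit = {
        println(yearMonthToDouble(Period.ofYears(1).plusMonths(2))) // 14.0
        println(dayTimeToDouble(Duration.ofSeconds(2)))             // 2000000.0
      }
    }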
@@ -79,14 +80,16 @@ case class ApproxCountDistinctForIntervals(
       TypeCheckFailure("The endpoints provided must be constant literals")
     } else {
       endpointsExpression.dataType match {
-        case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType, _) =>
+        case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType |
+          _: AnsiIntervalType, _) =>
           if (endpoints.length < 2) {
             TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")
           } else {
             TypeCheckSuccess
           }
         case _ =>
-          TypeCheckFailure("Endpoints require (numeric or timestamp or date) type")
+          TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
+            "interval year to month or interval day to second) type")
       }
     }
   }
@@ -120,9 +123,9 @@ case class ApproxCountDistinctForIntervals(
     val doubleValue = child.dataType match {
       case n: NumericType =>
         n.numeric.toDouble(value.asInstanceOf[n.InternalType])
-      case _: DateType =>
+      case _: DateType | _: YearMonthIntervalType =>
         value.asInstanceOf[Int].toDouble
-      case TimestampType | TimestampNTZType =>
+      case TimestampType | TimestampNTZType | _: DayTimeIntervalType =>
         value.asInstanceOf[Long].toDouble
     }

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 8cce79c..0dcb906 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -49,15 +49,16 @@ import org.apache.spark.sql.types._
  *                           yields better accuracy, the default value is
  *                           DEFAULT_PERCENTILE_ACCURACY.
  */
+// scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = """
-    _FUNC_(col, percentage [, accuracy]) - Returns the approximate `percentile` of the numeric
-      column `col` which is the smallest value in the ordered `col` values (sorted from least to
-      greatest) such that no more than `percentage` of `col` values is less than the value
-      or equal to that value. The value of percentage must be between 0.0 and 1.0. The `accuracy`
-      parameter (default: 10000) is a positive numeric literal which controls approximation accuracy
-      at the cost of memory. Higher value of `accuracy` yields better accuracy, `1.0/accuracy` is
-      the relative error of the approximation.
+    _FUNC_(col, percentage [, accuracy]) - Returns the approximate `percentile` of the numeric or
+      ansi interval column `col` which is the smallest value in the ordered `col` values (sorted
+      from least to greatest) such that no more than `percentage` of `col` values is less than
+      the value or equal to that value. The value of percentage must be between 0.0 and 1.0.
+      The `accuracy` parameter (default: 10000) is a positive numeric literal which controls
+      approximation accuracy at the cost of memory. Higher value of `accuracy` yields better
+      accuracy, `1.0/accuracy` is the relative error of the approximation.

       When `percentage` is an array, each value of the percentage array must be between 0.0 and 1.0.
       In this case, returns the approximate percentile array of column `col` at the given
       percentage array.
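The SQL examples added in the next hunk show the user-visible effect of this change: for an interval column, percentile_approx returns an interval of the same type rather than a double. A usage sketch against a spark-shell session built with this patch (data, table, and column names are illustrative):

    import java.time.Period
    import spark.implicits._

    val df = Seq(Period.ofMonths(0), Period.ofMonths(1),
      Period.ofMonths(2), Period.ofMonths(10)).toDF("col")
    df.createOrReplaceTempView("tab")
    spark.sql("SELECT percentile_approx(col, 0.5, 100) FROM tab").show()
    // Expected: an INTERVAL YEAR TO MONTH result of 0-1 for this data,
    // matching the example added to the usage string below.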
@@ -68,9 +69,14 @@ import org.apache.spark.sql.types._
       [1,1,0]
       > SELECT _FUNC_(col, 0.5, 100) FROM VALUES (0), (6), (7), (9), (10) AS tab(col);
        7
+      > SELECT _FUNC_(col, 0.5, 100) FROM VALUES (INTERVAL '0' MONTH), (INTERVAL '1' MONTH), (INTERVAL '2' MONTH), (INTERVAL '10' MONTH) AS tab(col);
+       0-1
+      > SELECT _FUNC_(col, array(0.5, 0.7), 100) FROM VALUES (INTERVAL '0' SECOND), (INTERVAL '1' SECOND), (INTERVAL '2' SECOND), (INTERVAL '10' SECOND) AS tab(col);
+       [0 00:00:01.000000000,0 00:00:02.000000000]
   """,
   group = "agg_funcs",
   since = "2.1.0")
+// scalastyle:on line.size.limit
 case class ApproximatePercentile(
     child: Expression,
     percentageExpression: Expression,
@@ -94,7 +100,8 @@ case class ApproximatePercentile(
   override def inputTypes: Seq[AbstractDataType] = {
     // Support NumericType, DateType, TimestampType and TimestampNTZType since their internal types
     // are all numeric, and can be easily cast to double for processing.
-    Seq(TypeCollection(NumericType, DateType, TimestampType, TimestampNTZType),
+    Seq(TypeCollection(NumericType, DateType, TimestampType, TimestampNTZType,
+      YearMonthIntervalType, DayTimeIntervalType),
       TypeCollection(DoubleType, ArrayType(DoubleType, containsNull = false)),
       IntegralType)
   }
@@ -138,8 +145,9 @@ case class ApproximatePercentile(
     if (value != null) {
       // Convert the value to a double value
       val doubleValue = child.dataType match {
-        case DateType => value.asInstanceOf[Int].toDouble
-        case TimestampType | TimestampNTZType => value.asInstanceOf[Long].toDouble
+        case DateType | _: YearMonthIntervalType => value.asInstanceOf[Int].toDouble
+        case TimestampType | TimestampNTZType | _: DayTimeIntervalType =>
+          value.asInstanceOf[Long].toDouble
         case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType])
         case other: DataType =>
           throw QueryExecutionErrors.dataTypeUnexpectedError(other)
@@ -157,8 +165,8 @@ case class ApproximatePercentile(
   override def eval(buffer: PercentileDigest): Any = {
     val doubleResult = buffer.getPercentiles(percentages)
     val result = child.dataType match {
-      case DateType => doubleResult.map(_.toInt)
-      case TimestampType | TimestampNTZType => doubleResult.map(_.toLong)
+      case DateType | _: YearMonthIntervalType => doubleResult.map(_.toInt)
+      case TimestampType | TimestampNTZType | _: DayTimeIntervalType => doubleResult.map(_.toLong)
       case ByteType => doubleResult.map(_.toByte)
       case ShortType => doubleResult.map(_.toShort)
       case IntegerType => doubleResult.map(_.toInt)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 5bce4d3..7d3dd0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -43,12 +43,13 @@ import org.apache.spark.util.collection.OpenHashMap
  *                             percentage values. Each percentage value must be in the range
  *                             [0.0, 1.0].
  */
+// scalastyle:off line.size.limit
 @ExpressionDescription(
   usage = """
-    _FUNC_(col, percentage [, frequency]) - Returns the exact percentile value of numeric column
-      `col` at the given percentage. The value of percentage must be between 0.0 and 1.0. The
-      value of frequency should be positive integral
+    _FUNC_(col, percentage [, frequency]) - Returns the exact percentile value of numeric
+      or ansi interval column `col` at the given percentage. The value of percentage must be
+      between 0.0 and 1.0. The value of frequency should be positive integral

     _FUNC_(col, array(percentage1 [, percentage2]...) [, frequency]) - Returns the exact
       percentile value array of numeric column `col` at the given percentage(s). Each value
@@ -62,9 +63,14 @@ import org.apache.spark.util.collection.OpenHashMap
       3.0
       > SELECT _FUNC_(col, array(0.25, 0.75)) FROM VALUES (0), (10) AS tab(col);
       [2.5,7.5]
+      > SELECT _FUNC_(col, 0.5) FROM VALUES (INTERVAL '0' MONTH), (INTERVAL '10' MONTH) AS tab(col);
+       5.0
+      > SELECT _FUNC_(col, array(0.2, 0.5)) FROM VALUES (INTERVAL '0' SECOND), (INTERVAL '10' SECOND) AS tab(col);
+       [2000000.0,5000000.0]
   """,
   group = "agg_funcs",
   since = "2.1.0")
+// scalastyle:on line.size.limit
 case class Percentile(
     child: Expression,
     percentageExpression: Expression,
@@ -118,7 +124,8 @@ case class Percentile(
       case _: ArrayType => ArrayType(DoubleType, false)
       case _ => DoubleType
     }
-    Seq(NumericType, percentageExpType, IntegralType)
+    Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType),
+      percentageExpType, IntegralType)
   }

   // Check the inputTypes are valid, and the percentageExpression satisfies:
@@ -191,8 +198,15 @@ case class Percentile(
       return Seq.empty
     }

-    val sortedCounts = buffer.toSeq.sortBy(_._1)(
-      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
+    val ordering =
+      if (child.dataType.isInstanceOf[NumericType]) {
+        child.dataType.asInstanceOf[NumericType].ordering
+      } else if (child.dataType.isInstanceOf[YearMonthIntervalType]) {
+        child.dataType.asInstanceOf[YearMonthIntervalType].ordering
+      } else if (child.dataType.isInstanceOf[DayTimeIntervalType]) {
+        child.dataType.asInstanceOf[DayTimeIntervalType].ordering
+      }
+    val sortedCounts = buffer.toSeq.sortBy(_._1)(ordering.asInstanceOf[Ordering[AnyRef]])
     val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
index 9d53673..a017e5b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
@@ -39,7 +39,8 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
     assert(
       wrongColumn.checkInputDataTypes() match {
         case TypeCheckFailure(msg)
-          if msg.contains("requires (numeric or timestamp or date or timestamp_ntz) type") => true
+          if msg.contains("requires (numeric or timestamp or date or timestamp_ntz or " +
+            "interval year to month or interval day to second) type") => true
         case _ => false
       })
   }
@@ -69,7 +70,8 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
     assert(wrongEndpoints.checkInputDataTypes() ==
-      TypeCheckFailure("Endpoints require (numeric or timestamp or date) type"))
+      TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
+        "interval year to month or interval day to second) type"))
   }

   /** Create an ApproxCountDistinctForIntervals instance and an input and output buffer.
    */
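The new ordering selection in Percentile.scala above works because NumericType, YearMonthIntervalType, and DayTimeIntervalType are all atomic types that expose an Ordering over their internal representation (Int months and Long microseconds for the two interval types), so the sort inside getPercentiles stays type-agnostic. A condensed sketch of the same dispatch written as a pattern match — illustrative only, assuming the imports and package of Percentile.scala; the patch itself uses an if/else chain:

    // Hypothetical helper, not part of the patch.
    def orderingFor(dataType: DataType): Ordering[_] = dataType match {
      case n: NumericType            => n.ordering  // e.g. Int, Long, Double, Decimal
      case ym: YearMonthIntervalType => ym.ordering // Int number of months
      case dt: DayTimeIntervalType   => dt.ordering // Long number of microseconds
      case other =>
        throw new IllegalStateException(s"Unexpected data type: $other")
    }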
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index fa87407..b5882b1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -170,8 +170,8 @@ class PercentileSuite extends SparkFunSuite {
       val child = AttributeReference("a", dataType)()
       val percentile = new Percentile(child, percentage)
       assertEqual(percentile.checkInputDataTypes(),
-        TypeCheckFailure(s"argument 1 requires numeric type, however, " +
-          s"'a' is of ${dataType.simpleString} type."))
+        TypeCheckFailure(s"argument 1 requires (numeric or interval year to month or " +
+          s"interval day to second) type, however, 'a' is of ${dataType.simpleString} type."))
     }

     val invalidFrequencyDataTypes = Seq(FloatType, DoubleType, BooleanType,
@@ -184,8 +184,8 @@ class PercentileSuite extends SparkFunSuite {
       val frq = AttributeReference("frq", frequencyType)()
       val percentile = new Percentile(child, percentage, frq)
       assertEqual(percentile.checkInputDataTypes(),
-        TypeCheckFailure(s"argument 1 requires numeric type, however, " +
-          s"'a' is of ${dataType.simpleString} type."))
+        TypeCheckFailure(s"argument 1 requires (numeric or interval year to month or " +
+          s"interval day to second) type, however, 'a' is of ${dataType.simpleString} type."))
     }

     for(dataType <- validDataTypes;

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
index 171e93c..53662c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.sql

+import java.time.{Duration, Period}
+
 import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
@@ -58,4 +60,30 @@ class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSpa
       }
     }
   }
+
+  test("SPARK-37138: Support Ansi Interval type in ApproxCountDistinctForIntervals") {
+    val table = "approx_count_distinct_for_ansi_intervals_tbl"
+    withTable(table) {
+      Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
+        (Period.ofMonths(200), Duration.ofSeconds(200L)),
+        (Period.ofMonths(300), Duration.ofSeconds(300L)))
+        .toDF("col1", "col2").createOrReplaceTempView(table)
+      val endpoints = (0 to 5).map(_ / 10)
+
+      val relation = spark.table(table).logicalPlan
+      val ymAttr = relation.output.find(_.name == "col1").get
+      val ymAggFunc =
+        ApproxCountDistinctForIntervals(ymAttr, CreateArray(endpoints.map(Literal(_))))
+      val ymAggExpr = ymAggFunc.toAggregateExpression()
+      val ymNamedExpr = Alias(ymAggExpr, ymAggExpr.toString)()
+
+      val dtAttr = relation.output.find(_.name == "col2").get
+      val dtAggFunc =
+        ApproxCountDistinctForIntervals(dtAttr, CreateArray(endpoints.map(Literal(_))))
+      val dtAggExpr = dtAggFunc.toAggregateExpression()
+      val dtNamedExpr = Alias(dtAggExpr, dtAggExpr.toString)()
+      val result = Dataset.ofRows(spark, Aggregate(Nil, Seq(ymNamedExpr, dtNamedExpr), relation))
+      checkAnswer(result, Row(Array(1, 1, 1, 1, 1), Array(1, 1, 1, 1, 1)))
+    }
+  }
 }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
index 5ff15c9..9237c9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql

 import java.sql.{Date, Timestamp}
-import java.time.LocalDateTime
+import java.time.{Duration, LocalDateTime, Period}

 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
@@ -32,7 +32,7 @@ import org.apache.spark.sql.test.SharedSparkSession
 class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession {
   import testImplicits._

-  private val table = "percentile_test"
+  private val table = "percentile_approx"

   test("percentile_approx, single percentile value") {
     withTempView(table) {
@@ -319,4 +319,22 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession
       Row(18, 17, 17, 17))
     }
   }
+
+  test("SPARK-37138: Support Ansi Interval type in ApproximatePercentile") {
+    withTempView(table) {
+      Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
+        (Period.ofMonths(200), Duration.ofSeconds(200L)),
+        (Period.ofMonths(300), Duration.ofSeconds(300L)))
+        .toDF("col1", "col2").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+             | percentile_approx(col1, 0.5),
+             | SUM(null),
+             | percentile_approx(col2, 0.5)
+             |FROM $table
+           """.stripMargin),
+        Row(Period.ofMonths(200).normalized(), null, Duration.ofSeconds(200L)))
+    }
+  }
 }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala
new file mode 100644
index 0000000..f39f0c1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/PercentileQuerySuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.time.{Duration, Period}
+
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * End-to-end tests for percentile aggregate function.
+ */
+class PercentileQuerySuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  private val table = "percentile_test"
+
+  test("SPARK-37138: Support Ansi Interval type in Percentile") {
+    withTempView(table) {
+      Seq((Period.ofMonths(100), Duration.ofSeconds(100L)),
+        (Period.ofMonths(200), Duration.ofSeconds(200L)),
+        (Period.ofMonths(300), Duration.ofSeconds(300L)))
+        .toDF("col1", "col2").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+             | CAST(percentile(col1, 0.5) AS STRING),
+             | SUM(null),
+             | CAST(percentile(col2, 0.5) AS STRING)
+             |FROM $table
+           """.stripMargin),
+        Row("200.0", null, "2.0E8"))
+    }
+  }
+}
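One asymmetry worth noting in the expected answers above: percentile_approx returns values of the input interval type (Period/Duration), while the exact percentile interpolates and returns a DOUBLE expressed in the type's internal unit — months for year-month intervals and microseconds for day-time intervals — which is why PercentileQuerySuite expects the strings "200.0" and "2.0E8". A hedged sketch for mapping such a result back to a java.time value on the driver; the helper names are illustrative, not part of the patch:

    import java.time.{Duration, Period}

    // percentile() over an INTERVAL YEAR TO MONTH column yields months as Double.
    def monthsToPeriod(months: Double): Period =
      Period.ofMonths(math.round(months).toInt).normalized()

    // percentile() over an INTERVAL DAY TO SECOND column yields microseconds as Double.
    def microsToDuration(micros: Double): Duration =
      Duration.ofNanos(math.round(micros) * 1000L)

    monthsToPeriod(200.0)   // P16Y8M (200 months)
    microsToDuration(2.0E8) // PT3M20S (200 seconds)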