This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 9d061e3 [SPARK-35926][SQL] Add support YearMonthIntervalType for width_bucket 9d061e3 is described below commit 9d061e3939a021c602c070fc13cef951a8f94c82 Author: PengLei <peng.8...@gmail.com> AuthorDate: Fri Oct 15 17:15:50 2021 +0300 [SPARK-35926][SQL] Add support YearMonthIntervalType for width_bucket ### What changes were proposed in this pull request? Support width_bucket(YearMonthIntervalType, YearMonthIntervalType, YearMonthIntervalType, Long), which returns a long result, e.g.: ``` width_bucket(input_value, min_value, max_value, bucket_nums) width_bucket(INTERVAL '1' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10) It divides the range between min_value and max_value into 10 buckets. [ INTERVAL '0' YEAR, INTERVAL '1' YEAR), [ INTERVAL '1' YEAR, INTERVAL '2' YEAR)...... [INTERVAL '9' YEAR, INTERVAL '10' YEAR) Then it calculates which bucket the given input_value falls into. ``` The function `width_bucket` was introduced in [SPARK-21117](https://issues.apache.org/jira/browse/SPARK-21117) ### Why are the changes needed? [35926](https://issues.apache.org/jira/browse/SPARK-35926) 1. The `WIDTH_BUCKET` function assigns values to buckets (individual segments) in an equiwidth histogram. The ANSI SQL Standard Syntax is as follows: `WIDTH_BUCKET( expression, min, max, buckets)`. [Reference](https://www.oreilly.com/library/view/sql-in-a/9780596155322/re91.html). 2. `WIDTH_BUCKET` currently supports only `Double`. Of course, we can cast `Int` to `Double` to use it, but we cannot cast `YearMonthIntervalType` to `Double`. 3. There is a practical use case, e.g. a histogram of employee years of service, where `years of service` is a column of `YearMonthIntervalType` data type. ### Does this PR introduce _any_ user-facing change? Yes. The user can use `width_bucket` with YearMonthIntervalType. ### How was this patch tested? Added unit tests. Closes #33132 from Peng-Lei/SPARK-35926. 
Authored-by: PengLei <peng.8...@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../sql/catalyst/expressions/mathExpressions.scala | 33 ++++++++++++++++++---- .../expressions/MathExpressionsSuite.scala | 15 ++++++++++ .../test/resources/sql-tests/inputs/interval.sql | 2 ++ .../sql-tests/results/ansi/interval.sql.out | 18 +++++++++++- .../resources/sql-tests/results/interval.sql.out | 18 +++++++++++- .../org/apache/spark/sql/MathFunctionsSuite.scala | 17 +++++++++++ 6 files changed, 96 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index c14fa72..6c34ed6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.NumberConverter +import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1613,6 +1613,10 @@ object WidthBucket { 5 > SELECT _FUNC_(-0.9, 5.2, 0.5, 2); 3 + > SELECT _FUNC_(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10); + 1 + > SELECT _FUNC_(INTERVAL '1' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10); + 2 """, since = "3.1.0", group = "math_funcs") @@ -1623,16 +1627,35 @@ case class WidthBucket( numBucket: Expression) extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant { - override def inputTypes: Seq[AbstractDataType] = 
Seq(DoubleType, DoubleType, DoubleType, LongType) + override def inputTypes: Seq[AbstractDataType] = Seq( + TypeCollection(DoubleType, YearMonthIntervalType), + TypeCollection(DoubleType, YearMonthIntervalType), + TypeCollection(DoubleType, YearMonthIntervalType), + LongType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckSuccess => + (value.dataType, minValue.dataType, maxValue.dataType) match { + case (_: YearMonthIntervalType, _: YearMonthIntervalType, _: YearMonthIntervalType) => + TypeCheckSuccess + case _ => + val types = Seq(value.dataType, minValue.dataType, maxValue.dataType) + TypeUtils.checkForSameTypeInputExpr(types, s"function $prettyName") + } + case f => f + } + } + override def dataType: DataType = LongType override def nullable: Boolean = true override def prettyName: String = "width_bucket" override protected def nullSafeEval(input: Any, min: Any, max: Any, numBucket: Any): Any = { WidthBucket.computeBucketNumber( - input.asInstanceOf[Double], - min.asInstanceOf[Double], - max.asInstanceOf[Double], + input.asInstanceOf[Number].doubleValue(), + min.asInstanceOf[Number].doubleValue(), + max.asInstanceOf[Number].doubleValue(), numBucket.asInstanceOf[Long]) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index aced787..bfb9614 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -725,4 +725,19 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Signum(Literal(Duration.of(Long.MaxValue, ChronoUnit.MICROS))), 1.0) checkEvaluation(Signum(Literal(Duration.of(Long.MinValue, ChronoUnit.MICROS))), -1.0) } + + test("SPARK-35926: 
Support YearMonthIntervalType in width-bucket function") { + Seq( + (Period.ofMonths(-1), Period.ofYears(0), Period.ofYears(10), 10L) -> 0L, + (Period.ofMonths(0), Period.ofYears(0), Period.ofYears(10), 10L) -> 1L, + (Period.ofMonths(13), Period.ofYears(0), Period.ofYears(10), 10L) -> 2L, + (Period.ofYears(1), Period.ofYears(0), Period.ofYears(10), 10L) -> 2L, + (Period.ofYears(1), Period.ofYears(0), Period.ofYears(1), 10L) -> 11L, + (Period.ofMonths(Int.MaxValue), Period.ofYears(0), Period.ofYears(1), 10L) -> 11L, + (Period.ofMonths(0), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10L) -> 6L, + (Period.ofMonths(-1), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10L) -> 5L + ).foreach { case ((v, s, e, n), expected) => + checkEvaluation(WidthBucket(Literal(v), Literal(s), Literal(e), Literal(n)), expected) + } + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval.sql b/sql/core/src/test/resources/sql-tests/inputs/interval.sql index 7dd7e4e..2d1d8c4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/interval.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/interval.sql @@ -382,3 +382,5 @@ SELECT signum(INTERVAL '0-0' YEAR TO MONTH); SELECT signum(INTERVAL '-10' DAY); SELECT signum(INTERVAL '10' HOUR); SELECT signum(INTERVAL '0 0:0:0' DAY TO SECOND); +SELECT width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10); +SELECT width_bucket(INTERVAL '-1' YEAR, INTERVAL -'1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index cff294e..294e5c9 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 282 +-- Number of queries: 284 -- !query @@ -2657,3 
+2657,19 @@ SELECT signum(INTERVAL '0 0:0:0' DAY TO SECOND) struct<SIGNUM(INTERVAL '0 00:00:00' DAY TO SECOND):double> -- !query output 0.0 + + +-- !query +SELECT width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10) +-- !query schema +struct<width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10):bigint> +-- !query output +1 + + +-- !query +SELECT width_bucket(INTERVAL '-1' YEAR, INTERVAL -'1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10) +-- !query schema +struct<width_bucket(INTERVAL '-1' YEAR, INTERVAL '-1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10):bigint> +-- !query output +1 diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 688cde5..5d2edba 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 282 +-- Number of queries: 284 -- !query @@ -2646,3 +2646,19 @@ SELECT signum(INTERVAL '0 0:0:0' DAY TO SECOND) struct<SIGNUM(INTERVAL '0 00:00:00' DAY TO SECOND):double> -- !query output 0.0 + + +-- !query +SELECT width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10) +-- !query schema +struct<width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10):bigint> +-- !query output +1 + + +-- !query +SELECT width_bucket(INTERVAL '-1' YEAR, INTERVAL -'1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10) +-- !query schema +struct<width_bucket(INTERVAL '-1' YEAR, INTERVAL '-1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10):bigint> +-- !query output +1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index 3512e5c..ce25a88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.nio.charset.StandardCharsets +import java.time.Period import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} @@ -520,4 +521,20 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df.selectExpr("positive(a)"), Row(1)) checkAnswer(df.selectExpr("positive(b)"), Row(-1)) } + + test("SPARK-35926: Support YearMonthIntervalType in width-bucket function") { + Seq( + (Period.ofMonths(-1), Period.ofYears(0), Period.ofYears(10), 10) -> 0, + (Period.ofMonths(0), Period.ofYears(0), Period.ofYears(10), 10) -> 1, + (Period.ofMonths(13), Period.ofYears(0), Period.ofYears(10), 10) -> 2, + (Period.ofYears(1), Period.ofYears(0), Period.ofYears(10), 10) -> 2, + (Period.ofYears(1), Period.ofYears(0), Period.ofYears(1), 10) -> 11, + (Period.ofMonths(Int.MaxValue), Period.ofYears(0), Period.ofYears(1), 10) -> 11, + (Period.ofMonths(0), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10) -> 6, + (Period.ofMonths(-1), Period.ofMonths(Int.MinValue), Period.ofMonths(Int.MaxValue), 10) -> 5 + ).foreach { case ((value, start, end, num), expected) => + val df = Seq((value, start, end, num)).toDF("v", "s", "e", "n") + checkAnswer(df.selectExpr("width_bucket(v, s, e, n)"), Row(expected)) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org