This is an automated email from the ASF dual-hosted git repository. gengliang pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new a7f0adb2dd8 [SPARK-38589][SQL] New SQL function: try_avg a7f0adb2dd8 is described below commit a7f0adb2dd8449af6f9e9b5a25f11b5dcf5868f1 Author: Gengliang Wang <gengli...@apache.org> AuthorDate: Tue Apr 12 20:39:08 2022 +0800 [SPARK-38589][SQL] New SQL function: try_avg ### What changes were proposed in this pull request? Add a new SQL function: try_avg. It is identical to the function `avg`, except that it returns NULL result instead of throwing an exception on decimal/interval value overflow. Note it is also different from `avg` when ANSI mode is off on interval overflows | Function | avg | try_avg | |------------------|------------------------------------|-------------| | year-month interval overflow | Error | Return NULL | | day-time interval overflow | Error | Return NULL | ### Why are the changes needed? * Users can manage to finish queries without interruptions in ANSI mode. * Users can get NULLs instead of runtime errors if interval overflow occurs when ANSI mode is off. For example ``` > SELECT avg(col) FROM VALUES (interval '2147483647 months'),(interval '1 months') AS tab(col) java.lang.ArithmeticException: integer overflow. > SELECT try_avg(col) FROM VALUES (interval '2147483647 months'),(interval '1 months') AS tab(col) NULL ``` ### Does this PR introduce _any_ user-facing change? Yes, adding a new SQL function: try_avg. It is identical to the function `avg`, except that it returns NULL result instead of throwing an exception on decimal/interval value overflow. ### How was this patch tested? UT Closes #35896 from gengliangwang/tryAvg. 
Lead-authored-by: Gengliang Wang <gengli...@apache.org> Co-authored-by: Gengliang Wang <ltn...@gmail.com> Signed-off-by: Gengliang Wang <gengli...@apache.org> --- docs/sql-ref-ansi-compliance.md | 3 +- .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../catalyst/expressions/aggregate/Average.scala | 125 +++++++++++++++++---- .../sql/catalyst/expressions/aggregate/Sum.scala | 35 +++--- .../sql-functions/sql-expression-schema.md | 5 +- .../resources/sql-tests/inputs/try_aggregates.sql | 14 +++ .../sql-tests/results/ansi/try_aggregates.sql.out | 82 +++++++++++++- .../sql-tests/results/try_aggregates.sql.out | 82 +++++++++++++- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 12 ++ 9 files changed, 313 insertions(+), 46 deletions(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 0f7f29cde7f..66161a112b1 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -316,7 +316,8 @@ When ANSI mode is on, it throws exceptions for invalid operations. You can use t - `try_subtract`: identical to the subtraction operator `-`, except that it returns `NULL` result instead of throwing an exception on integral value overflow. - `try_multiply`: identical to the multiplication operator `*`, except that it returns `NULL` result instead of throwing an exception on integral value overflow. - `try_divide`: identical to the division operator `/`, except that it returns `NULL` result instead of throwing an exception on dividing 0. - - `try_sum`: identical to the function `sum`, except that it returns `NULL` result instead of throwing an exception on integral/decimal value overflow. + - `try_sum`: identical to the function `sum`, except that it returns `NULL` result instead of throwing an exception on integral/decimal/interval value overflow. + - `try_avg`: identical to the function `avg`, except that it returns `NULL` result instead of throwing an exception on decimal/interval value overflow. 
- `try_element_at`: identical to the function `element_at`, except that it returns `NULL` result instead of throwing an exception on array's index out of bound or map's key not found. ### SQL Keywords (optional, disabled by default) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1824fb68f76..80374f769a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -453,6 +453,7 @@ object FunctionRegistry { expression[TrySubtract]("try_subtract"), expression[TryMultiply]("try_multiply"), expression[TryElementAt]("try_element_at"), + expression[TryAverage]("try_avg"), expression[TrySum]("try_sum"), expression[TryToBinary]("try_to_binary"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 533f7f20b25..14914576091 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -26,25 +26,13 @@ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.", - examples = """ - Examples: - > SELECT _FUNC_(col) FROM VALUES (1), (2), (3) AS tab(col); - 2.0 - > SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col); - 1.5 - """, - group = "agg_funcs", - since = "1.0.0") -case class Average( - child: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) +abstract class AverageBase extends 
DeclarativeAggregate with ImplicitCastInputTypes with UnaryLike[Expression] { - def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled) + // Whether to use ANSI add or not during the execution. + def useAnsiAdd: Boolean override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg") @@ -61,7 +49,7 @@ case class Average( final override val nodePatterns: Seq[TreePattern] = Seq(AVERAGE) - private lazy val resultType = child.dataType match { + protected lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) case _: YearMonthIntervalType => YearMonthIntervalType() @@ -86,18 +74,18 @@ case class Average( /* count = */ Literal(0L) ) - override lazy val mergeExpressions = Seq( - /* sum = */ sum.left + sum.right, + protected def getMergeExpressions = Seq( + /* sum = */ Add(sum.left, sum.right, useAnsiAdd), /* count = */ count.left + count.right ) // If all input are nulls, count will be 0 and we will get null after the division. // We can't directly use `/` as it throws an exception under ansi mode. 
- override lazy val evaluateExpression = child.dataType match { + protected def getEvaluateExpression = child.dataType match { case _: DecimalType => DecimalPrecision.decimalAndDecimal()( Divide( - CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !failOnError), + CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !useAnsiAdd), count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType) case _: YearMonthIntervalType => If(EqualTo(count, Literal(0L)), @@ -109,17 +97,106 @@ case class Average( Divide(sum.cast(resultType), count.cast(resultType), failOnError = false) } - override lazy val updateExpressions: Seq[Expression] = Seq( + protected def getUpdateExpressions: Seq[Expression] = Seq( /* sum = */ Add( sum, - coalesce(child.cast(sumDataType), Literal.default(sumDataType))), + coalesce(child.cast(sumDataType), Literal.default(sumDataType)), + failOnError = useAnsiAdd), /* count = */ If(child.isNull, count, count + 1L) ) + // The flag `useAnsiAdd` won't be shown in the `toString` or `toAggString` methods + override def flatArguments: Iterator[Any] = Iterator(child) +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.", + examples = """ + Examples: + > SELECT _FUNC_(col) FROM VALUES (1), (2), (3) AS tab(col); + 2.0 + > SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col); + 1.5 + """, + group = "agg_funcs", + since = "1.0.0") +case class Average( + child: Expression, + useAnsiAdd: Boolean = SQLConf.get.ansiEnabled) extends AverageBase { + def this(child: Expression) = this(child, useAnsiAdd = SQLConf.get.ansiEnabled) + override protected def withNewChildInternal(newChild: Expression): Average = copy(child = newChild) - // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods - override def flatArguments: Iterator[Any] = Iterator(child) + override lazy val updateExpressions: Seq[Expression] = getUpdateExpressions + + override lazy val 
mergeExpressions: Seq[Expression] = getMergeExpressions + + override lazy val evaluateExpression: Expression = getEvaluateExpression +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the mean calculated from values of a group and the result is null on overflow.", + examples = """ + Examples: + > SELECT _FUNC_(col) FROM VALUES (1), (2), (3) AS tab(col); + 2.0 + > SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col); + 1.5 + > SELECT _FUNC_(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col); + NULL + """, + group = "agg_funcs", + since = "3.3.0") +// scalastyle:on line.size.limit +case class TryAverage(child: Expression) extends AverageBase { + override def useAnsiAdd: Boolean = resultType match { + // Double type won't fail, thus we can always use non-Ansi Add. + // For decimal type, it returns NULL on overflow. It behaves the same as TrySum when + // `failOnError` is false. + case _: DoubleType | _: DecimalType => false + case _ => true + } + + private def addTryEvalIfNeeded(expression: Expression): Expression = { + if (useAnsiAdd) { + TryEval(expression) + } else { + expression + } + } + + override lazy val updateExpressions: Seq[Expression] = { + val expressions = getUpdateExpressions + addTryEvalIfNeeded(expressions.head) +: expressions.tail + } + + override lazy val mergeExpressions: Seq[Expression] = { + val expressions = getMergeExpressions + if (useAnsiAdd) { + val bufferOverflow = sum.left.isNull && count.left > 0L + val inputOverflow = sum.right.isNull && count.right > 0L + Seq( + If( + bufferOverflow || inputOverflow, + Literal.create(null, resultType), + // If both the buffer and the input do not overflow, just add them, as they can't be + // null. 
+ TryEval(Add(KnownNotNull(sum.left), KnownNotNull(sum.right), useAnsiAdd))), + expressions(1)) + } else { + expressions + } + } + + override lazy val evaluateExpression: Expression = { + addTryEvalIfNeeded(getEvaluateExpression) + } + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) + + override def prettyName: String = "try_avg" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index fd27edfc8fc..f2c6925b837 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -30,7 +30,8 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes with UnaryLike[Expression] { - def failOnError: Boolean + // Whether to use ANSI add or not during the execution. + def useAnsiAdd: Boolean protected def shouldTrackIsEmpty: Boolean @@ -81,9 +82,9 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate // null if overflow happens under non-ansi mode. val sumExpr = if (child.nullable) { If(child.isNull, sum, - Add(sum, KnownNotNull(child).cast(resultType), failOnError = failOnError)) + Add(sum, KnownNotNull(child).cast(resultType), failOnError = useAnsiAdd)) } else { - Add(sum, child.cast(resultType), failOnError = failOnError) + Add(sum, child.cast(resultType), failOnError = useAnsiAdd) } // The buffer becomes non-empty after seeing the first not-null input. val isEmptyExpr = if (child.nullable) { @@ -98,10 +99,10 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate // in case the input is nullable. The `sum` can only be null if there is no value, as // non-decimal type can produce overflowed value under non-ansi mode. 
if (child.nullable) { - Seq(coalesce(Add(coalesce(sum, zero), child.cast(resultType), failOnError = failOnError), + Seq(coalesce(Add(coalesce(sum, zero), child.cast(resultType), failOnError = useAnsiAdd), sum)) } else { - Seq(Add(coalesce(sum, zero), child.cast(resultType), failOnError = failOnError)) + Seq(Add(coalesce(sum, zero), child.cast(resultType), failOnError = useAnsiAdd)) } } @@ -127,11 +128,11 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate // If both the buffer and the input do not overflow, just add them, as they can't be // null. See the comments inside `updateExpressions`: `sum` can only be null if // overflow happens. - Add(KnownNotNull(sum.left), KnownNotNull(sum.right), failOnError)), + Add(KnownNotNull(sum.left), KnownNotNull(sum.right), useAnsiAdd)), isEmpty.left && isEmpty.right) } else { Seq(coalesce( - Add(coalesce(sum.left, zero), sum.right, failOnError = failOnError), + Add(coalesce(sum.left, zero), sum.right, failOnError = useAnsiAdd), sum.left)) } @@ -145,13 +146,13 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate protected def getEvaluateExpression: Expression = resultType match { case d: DecimalType => If(isEmpty, Literal.create(null, resultType), - CheckOverflowInSum(sum, d, !failOnError)) + CheckOverflowInSum(sum, d, !useAnsiAdd)) case _ if shouldTrackIsEmpty => If(isEmpty, Literal.create(null, resultType), sum) case _ => sum } - // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods + // The flag `useAnsiAdd` won't be shown in the `toString` or `toAggString` methods override def flatArguments: Iterator[Any] = Iterator(child) } @@ -170,9 +171,9 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate since = "1.0.0") case class Sum( child: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) + useAnsiAdd: Boolean = SQLConf.get.ansiEnabled) extends SumBase(child) { - def this(child: Expression) = this(child, failOnError = 
SQLConf.get.ansiEnabled) + def this(child: Expression) = this(child, useAnsiAdd = SQLConf.get.ansiEnabled) override def shouldTrackIsEmpty: Boolean = resultType match { case _: DecimalType => true @@ -207,10 +208,10 @@ case class Sum( // scalastyle:on line.size.limit case class TrySum(child: Expression) extends SumBase(child) { - override def failOnError: Boolean = dataType match { - // Double type won't fail, thus the failOnError is always false + override def useAnsiAdd: Boolean = dataType match { + // Double type won't fail, thus useAnsiAdd is always false // For decimal type, it returns NULL on overflow. It behaves the same as TrySum when - // `failOnError` is false. + // `useAnsiAdd` is false. case _: DoubleType | _: DecimalType => false case _ => true } @@ -224,7 +225,7 @@ case class TrySum(child: Expression) extends SumBase(child) { } override lazy val updateExpressions: Seq[Expression] = - if (failOnError) { + if (useAnsiAdd) { val expressions = getUpdateExpressions // If the length of updateExpressions is larger than 1, the tail expressions are for // tracking whether the input is empty, which doesn't need `TryEval` execution. 
@@ -234,14 +235,14 @@ case class TrySum(child: Expression) extends SumBase(child) { } override lazy val mergeExpressions: Seq[Expression] = - if (failOnError) { + if (useAnsiAdd) { getMergeExpressions.map(TryEval) } else { getMergeExpressions } override lazy val evaluateExpression: Expression = - if (failOnError) { + if (useAnsiAdd) { TryEval(getEvaluateExpression) } else { getEvaluateExpression diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 14902b08549..9f8faf517a4 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -1,6 +1,6 @@ <!-- Automatically generated by ExpressionsSchemaSuite --> ## Summary - - Number of queries: 387 + - Number of queries: 388 - Number of expressions that missing example: 12 - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint ## Schema of Built-in Functions @@ -380,6 +380,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.StddevSamp | stddev | SELECT stddev(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<stddev(col):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.StddevSamp | stddev_samp | SELECT stddev_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<stddev_samp(col):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.Sum | sum | SELECT sum(col) FROM VALUES (5), (10), (15) AS tab(col) | struct<sum(col):bigint> | +| org.apache.spark.sql.catalyst.expressions.aggregate.TryAverage | try_avg | SELECT try_avg(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<try_avg(col):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.TrySum | try_sum | SELECT try_sum(col) FROM VALUES (5), (10), (15) AS tab(col) | struct<try_sum(col):bigint> | | org.apache.spark.sql.catalyst.expressions.aggregate.VariancePop | var_pop | 
SELECT var_pop(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_pop(col):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | var_samp | SELECT var_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_samp(col):double> | @@ -392,4 +393,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()') | struct<xpath(<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>, a/b/text()):array<string>> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_long(<a><b>1</b><b>2</b></a>, sum(a/b)):bigint> | | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_short(<a><b>1</b><b>2</b></a>, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> | +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> | \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql b/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql index ffa8eefe828..cdd2e632319 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql @@ -11,3 +11,17 @@ SELECT try_sum(col) FROM VALUES (interval '1 months'), (interval '1 months') AS SELECT try_sum(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col); SELECT try_sum(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col); SELECT try_sum(col) FROM VALUES (interval 
'106751991 DAYS'), (interval '1 DAYS') AS tab(col); + +-- try_avg +SELECT try_avg(col) FROM VALUES (5), (10), (15) AS tab(col); +SELECT try_avg(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col); +SELECT try_avg(col) FROM VALUES (NULL), (10), (15) AS tab(col); +SELECT try_avg(col) FROM VALUES (NULL), (NULL) AS tab(col); +SELECT try_avg(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col); +-- test overflow in Decimal(38, 0) +SELECT try_avg(col) FROM VALUES (98765432109876543210987654321098765432BD), (98765432109876543210987654321098765432BD) AS tab(col); + +SELECT try_avg(col) FROM VALUES (interval '1 months'), (interval '1 months') AS tab(col); +SELECT try_avg(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col); +SELECT try_avg(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col); +SELECT try_avg(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out index 7ae217ad758..724553f6bd1 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 20 -- !query @@ -80,3 +80,83 @@ SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') struct<try_sum(col):interval day> -- !query output NULL + + +-- !query +SELECT try_avg(col) FROM VALUES (5), (10), (15) AS tab(col) +-- !query schema +struct<try_avg(col):double> +-- !query output +10.0 + + +-- !query +SELECT try_avg(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col) +-- !query schema +struct<try_avg(col):decimal(7,5)> +-- !query output +10.00000 + + +-- !query +SELECT try_avg(col) FROM VALUES (NULL), (10), (15) AS tab(col) +-- !query schema 
+struct<try_avg(col):double> +-- !query output +12.5 + + +-- !query +SELECT try_avg(col) FROM VALUES (NULL), (NULL) AS tab(col) +-- !query schema +struct<try_avg(col):double> +-- !query output +NULL + + +-- !query +SELECT try_avg(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col) +-- !query schema +struct<try_avg(col):double> +-- !query output +4.6116860184273879E18 + + +-- !query +SELECT try_avg(col) FROM VALUES (98765432109876543210987654321098765432BD), (98765432109876543210987654321098765432BD) AS tab(col) +-- !query schema +struct<try_avg(col):decimal(38,4)> +-- !query output +NULL + + +-- !query +SELECT try_avg(col) FROM VALUES (interval '1 months'), (interval '1 months') AS tab(col) +-- !query schema +struct<try_avg(col):interval year to month> +-- !query output +0-1 + + +-- !query +SELECT try_avg(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col) +-- !query schema +struct<try_avg(col):interval year to month> +-- !query output +NULL + + +-- !query +SELECT try_avg(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col) +-- !query schema +struct<try_avg(col):interval day to second> +-- !query output +0 00:00:01.000000000 + + +-- !query +SELECT try_avg(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col) +-- !query schema +struct<try_avg(col):interval day to second> +-- !query output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out b/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out index 7ae217ad758..724553f6bd1 100644 --- a/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 20 -- !query @@ -80,3 +80,83 @@ SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') struct<try_sum(col):interval day> 
-- !query output NULL + + +-- !query +SELECT try_avg(col) FROM VALUES (5), (10), (15) AS tab(col) +-- !query schema +struct<try_avg(col):double> +-- !query output +10.0 + + +-- !query +SELECT try_avg(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col) +-- !query schema +struct<try_avg(col):decimal(7,5)> +-- !query output +10.00000 + + +-- !query +SELECT try_avg(col) FROM VALUES (NULL), (10), (15) AS tab(col) +-- !query schema +struct<try_avg(col):double> +-- !query output +12.5 + + +-- !query +SELECT try_avg(col) FROM VALUES (NULL), (NULL) AS tab(col) +-- !query schema +struct<try_avg(col):double> +-- !query output +NULL + + +-- !query +SELECT try_avg(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col) +-- !query schema +struct<try_avg(col):double> +-- !query output +4.6116860184273879E18 + + +-- !query +SELECT try_avg(col) FROM VALUES (98765432109876543210987654321098765432BD), (98765432109876543210987654321098765432BD) AS tab(col) +-- !query schema +struct<try_avg(col):decimal(38,4)> +-- !query output +NULL + + +-- !query +SELECT try_avg(col) FROM VALUES (interval '1 months'), (interval '1 months') AS tab(col) +-- !query schema +struct<try_avg(col):interval year to month> +-- !query output +0-1 + + +-- !query +SELECT try_avg(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col) +-- !query schema +struct<try_avg(col):interval year to month> +-- !query output +NULL + + +-- !query +SELECT try_avg(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col) +-- !query schema +struct<try_avg(col):interval day to second> +-- !query output +0 00:00:01.000000000 + + +-- !query +SELECT try_avg(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col) +-- !query schema +struct<try_avg(col):interval day to second> +-- !query output +NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 
81067eef401..0b00659f73b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4331,6 +4331,18 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)"), Row(null)) } } + + test("SPARK-38589: try_avg should return null if overflow happens before merging") { + val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2) + .map(Period.ofMonths) + .toDF("v") + val dayTimeDf = Seq(106751991L, 106751991L, 2L) + .map(Duration.ofDays) + .toDF("v") + Seq(yearMonthDf, dayTimeDf).foreach { df => + checkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_avg(v)"), Row(null)) + } + } } case class Foo(bar: Option[String]) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org