This is an automated email from the ASF dual-hosted git repository. gengliang pushed a commit to branch branch-3.3 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push: new ec6fc741957 [SPARK-39210][SQL] Provide query context of Decimal overflow in AVG when WSCG is off ec6fc741957 is described below commit ec6fc7419571114cfda94bfa15d4a40712b53fea Author: Gengliang Wang <gengli...@apache.org> AuthorDate: Wed May 18 18:52:15 2022 +0800 [SPARK-39210][SQL] Provide query context of Decimal overflow in AVG when WSCG is off ### What changes were proposed in this pull request? Similar to https://github.com/apache/spark/pull/36525, this PR provides the runtime error query context for the Average expression when whole-stage codegen (WSCG) is off. ### Why are the changes needed? To enhance the runtime error query context of the Average function. After this change, the query context is also reported when whole-stage codegen is not available. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? New unit test (UT) Closes #36582 from gengliangwang/fixAvgContext. Authored-by: Gengliang Wang <gengli...@apache.org> Signed-off-by: Gengliang Wang <gengli...@apache.org> (cherry picked from commit 8b5b3e95f8761af97255cbcba35c3d836a419dba) Signed-off-by: Gengliang Wang <gengli...@apache.org> --- .../sql/catalyst/expressions/aggregate/Average.scala | 16 +++++++++++----- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 7 ++++--- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 14914576091..343e27d863b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -81,11 +81,11 @@ abstract class AverageBase // If all input are nulls, count will be 0 and we will get null after the division. // We can't directly use `/` as it throws an exception under ansi mode. 
- protected def getEvaluateExpression = child.dataType match { + protected def getEvaluateExpression(queryContext: String) = child.dataType match { case _: DecimalType => DecimalPrecision.decimalAndDecimal()( Divide( - CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !useAnsiAdd), + CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !useAnsiAdd, queryContext), count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType) case _: YearMonthIntervalType => If(EqualTo(count, Literal(0L)), @@ -123,7 +123,7 @@ abstract class AverageBase since = "1.0.0") case class Average( child: Expression, - useAnsiAdd: Boolean = SQLConf.get.ansiEnabled) extends AverageBase { + useAnsiAdd: Boolean = SQLConf.get.ansiEnabled) extends AverageBase with SupportQueryContext { def this(child: Expression) = this(child, useAnsiAdd = SQLConf.get.ansiEnabled) override protected def withNewChildInternal(newChild: Expression): Average = @@ -133,7 +133,13 @@ case class Average( override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions - override lazy val evaluateExpression: Expression = getEvaluateExpression + override lazy val evaluateExpression: Expression = getEvaluateExpression(queryContext) + + override def initQueryContext(): String = if (useAnsiAdd) { + origin.context + } else { + "" + } } // scalastyle:off line.size.limit @@ -192,7 +198,7 @@ case class TryAverage(child: Expression) extends AverageBase { } override lazy val evaluateExpression: Expression = { - addTryEvalIfNeeded(getEvaluateExpression) + addTryEvalIfNeeded(getEvaluateExpression("")) } override protected def withNewChildInternal(newChild: Expression): Expression = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 422ba7c2a9e..919fe88ec4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4423,8 +4423,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } - test("SPARK-39190, SPARK-39208: Query context of decimal overflow error should be serialized " + - "to executors when WSCG is off") { + test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should " + + "be serialized to executors when WSCG is off") { withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", SQLConf.ANSI_ENABLED.key -> "true") { withTable("t") { @@ -4432,7 +4432,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark sql("insert into t values (6e37BD),(6e37BD)") Seq( "select d / 0.1 from t", - "select sum(d) from t").foreach { query => + "select sum(d) from t", + "select avg(d) from t").foreach { query => val msg = intercept[SparkException] { sql(query).collect() }.getMessage --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org