This is an automated email from the ASF dual-hosted git repository. viirya pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b89cd8d [SPARK-35886][SQL] PromotePrecision should not overwrite genCode b89cd8d is described below commit b89cd8d75a0e78c6953cdd21c6e9c41495ed018f Author: Liang-Chi Hsieh <vii...@gmail.com> AuthorDate: Sat Jun 26 23:19:58 2021 -0700 [SPARK-35886][SQL] PromotePrecision should not overwrite genCode ### What changes were proposed in this pull request? This patch fixes `PromotePrecision`, which overrides `genCode`, the method where subexpression elimination should happen. ### Why are the changes needed? `PromotePrecision` overrides `genCode`, which is where subexpression elimination should happen. So if it is the top-most expression of a subexpression, it is never replaced. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added test. Closes #33103 from viirya/fix-precision. Authored-by: Liang-Chi Hsieh <vii...@gmail.com> Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com> --- .../sql/catalyst/expressions/decimalExpressions.scala | 6 ++---- .../expressions/SubexpressionEliminationSuite.scala | 15 ++++++++++++++- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 15 +++++++++++++++ 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 7165bca..673458a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import 
org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -110,10 +110,8 @@ object MakeDecimal { case class PromotePrecision(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType override def eval(input: InternalRow): Any = child.eval(input) - /** Just a simple pass-through for code generation. */ - override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = - ev.copy(EmptyBlock) + child.genCode(ctx) override def prettyName: String = "promote_precision" override def sql: String = child.sql override lazy val canonicalized: Expression = child.canonicalized diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 11f987c..edddfbe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType} +import org.apache.spark.sql.types.{BinaryType, DataType, Decimal, IntegerType} class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHelper { test("Semantic equals and hash") { @@ -391,6 +391,19 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel assert(exprs2.sorted(exprOrdering) === Seq(add2, add1, add2, add1, add2, add1, caseWhenExpr)) } + + test("SPARK-35886: PromotePrecision 
should not overwrite genCode") { + val p = PromotePrecision(Literal(Decimal("10.1"))) + + val ctx = new CodegenContext() + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(Seq(p, p)) + val code = ctx.withSubExprEliminationExprs(subExprs.states) { + Seq(p.genCode(ctx)) + }.head + // Decimal `Literal` will add the value by `addReferenceObj`. + // So if `p` is replaced by subexpression, the literal will be reused. + assert(code.value.toString == "((Decimal) references[0] /* literal */)") + } } case class CodegenFallbackExpression(child: Expression) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 114eb45..3791666 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2914,6 +2914,21 @@ class DataFrameSuite extends QueryTest val df2 = (1 to 10).toDF() assert(df2.isLocal) } + + test("SPARK-35886: PromotePrecision should be subexpr replaced") { + withTable("tbl") { + sql( + """ + |CREATE TABLE tbl ( + | c1 DECIMAL(18,6), + | c2 DECIMAL(18,6), + | c3 DECIMAL(18,6)) + |USING parquet; + |""".stripMargin) + sql("INSERT INTO tbl SELECT 1, 1, 1") + checkAnswer(sql("SELECT sum(c1 * c3) + sum(c2 * c3) FROM tbl"), Row(2.00000000000) :: Nil) + } + } } case class GroupByKey(a: Int, b: Int) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org