This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b89cd8d  [SPARK-35886][SQL] PromotePrecision should not overwrite 
genCode
b89cd8d is described below

commit b89cd8d75a0e78c6953cdd21c6e9c41495ed018f
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Sat Jun 26 23:19:58 2021 -0700

    [SPARK-35886][SQL] PromotePrecision should not overwrite genCode
    
    ### What changes were proposed in this pull request?
    
    This patch fixes `PromotePrecision` so that it no longer overrides `genCode`,
which is where subexpression elimination takes place.
    
    ### Why are the changes needed?
    
    `PromotePrecision` overrides `genCode`, which is where subexpression
elimination should happen. As a result, if `PromotePrecision` is the top-most
expression of a common subexpression, it is never replaced by the eliminated
subexpression.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added tests in `SubexpressionEliminationSuite` and `DataFrameSuite`.
    
    Closes #33103 from viirya/fix-precision.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com>
---
 .../sql/catalyst/expressions/decimalExpressions.scala     |  6 ++----
 .../expressions/SubexpressionEliminationSuite.scala       | 15 ++++++++++++++-
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala  | 15 +++++++++++++++
 3 files changed, 31 insertions(+), 5 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 7165bca..673458a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
EmptyBlock, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -110,10 +110,8 @@ object MakeDecimal {
 case class PromotePrecision(child: Expression) extends UnaryExpression {
   override def dataType: DataType = child.dataType
   override def eval(input: InternalRow): Any = child.eval(input)
-  /** Just a simple pass-through for code generation. */
-  override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx)
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode =
-    ev.copy(EmptyBlock)
+    child.genCode(ctx)
   override def prettyName: String = "promote_precision"
   override def sql: String = child.sql
   override lazy val canonicalized: Expression = child.canonicalized
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 11f987c..edddfbe 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType}
+import org.apache.spark.sql.types.{BinaryType, DataType, Decimal, IntegerType}
 
 class SubexpressionEliminationSuite extends SparkFunSuite with 
ExpressionEvalHelper {
   test("Semantic equals and hash") {
@@ -391,6 +391,19 @@ class SubexpressionEliminationSuite extends SparkFunSuite 
with ExpressionEvalHel
     assert(exprs2.sorted(exprOrdering) ===
       Seq(add2, add1, add2, add1, add2, add1, caseWhenExpr))
   }
+
+  test("SPARK-35886: PromotePrecision should not overwrite genCode") {
+    val p = PromotePrecision(Literal(Decimal("10.1")))
+
+    val ctx = new CodegenContext()
+    val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(Seq(p, p))
+    val code = ctx.withSubExprEliminationExprs(subExprs.states) {
+      Seq(p.genCode(ctx))
+    }.head
+    // Decimal `Literal` will add the value by `addReferenceObj`.
+    // So if `p` is replaced by subexpression, the literal will be reused.
+    assert(code.value.toString == "((Decimal) references[0] /* literal */)")
+  }
 }
 
 case class CodegenFallbackExpression(child: Expression)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 114eb45..3791666 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2914,6 +2914,21 @@ class DataFrameSuite extends QueryTest
     val df2 = (1 to 10).toDF()
     assert(df2.isLocal)
   }
+
+  test("SPARK-35886: PromotePrecision should be subexpr replaced") {
+    withTable("tbl") {
+      sql(
+        """
+          |CREATE TABLE tbl (
+          |  c1 DECIMAL(18,6),
+          |  c2 DECIMAL(18,6),
+          |  c3 DECIMAL(18,6))
+          |USING parquet;
+          |""".stripMargin)
+      sql("INSERT INTO tbl SELECT 1, 1, 1")
+      checkAnswer(sql("SELECT sum(c1 * c3) + sum(c2 * c3) FROM tbl"), 
Row(2.00000000000) :: Nil)
+    }
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to