Repository: spark
Updated Branches:
  refs/heads/branch-2.0 e4f57f2b3 -> d7b9d6235


[SPARK-21332][SQL] Incorrect result type inferred for some decimal expressions

## What changes were proposed in this pull request?

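This PR changes the direction of expression transformation in the
DecimalPrecision rule. Previously, expressions were transformed top-down, which
led to incorrect result types when decimal expressions had other decimal
expressions as their operands. The root cause was that outer nodes were visited
before their children. As a result, an outer expression's result type was
derived from child types that had not yet been adjusted by the rule.
Transforming expressions bottom-up (`transformExpressionsUp`) fixes this.
Consider the example below: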

```
    val inputSchema = StructType(StructField("col", DecimalType(26, 6)) :: Nil)
    val sc = spark.sparkContext
    val rdd = sc.parallelize(1 to 2).map(_ => Row(BigDecimal(12)))
    val df = spark.createDataFrame(rdd, inputSchema)

    // Works correctly since no nested decimal expression is involved
    // Expected result type: (26, 6) * (26, 6) = (38, 12)
    df.select($"col" * $"col").explain(true)
    df.select($"col" * $"col").printSchema()

    // Gives a wrong result since there is a nested decimal expression that should be visited first
    // Expected result type: ((26, 6) * (26, 6)) * (26, 6) = (38, 12) * (26, 6) = (38, 18)
    df.select($"col" * $"col" * $"col").explain(true)
    df.select($"col" * $"col" * $"col").printSchema()
```

Before this fix, the example above produced the following output:

```
// Correct result without sub-expressions
== Parsed Logical Plan ==
'Project [('col * 'col) AS (col * col)#4]
+- LogicalRDD [col#1]

== Analyzed Logical Plan ==
(col * col): decimal(38,12)
Project [CheckOverflow((promote_precision(cast(col#1 as decimal(26,6))) * promote_precision(cast(col#1 as decimal(26,6)))), DecimalType(38,12)) AS (col * col)#4]
+- LogicalRDD [col#1]

== Optimized Logical Plan ==
Project [CheckOverflow((col#1 * col#1), DecimalType(38,12)) AS (col * col)#4]
+- LogicalRDD [col#1]

== Physical Plan ==
*Project [CheckOverflow((col#1 * col#1), DecimalType(38,12)) AS (col * col)#4]
+- Scan ExistingRDD[col#1]

// Schema
root
 |-- (col * col): decimal(38,12) (nullable = true)

// Incorrect result with sub-expressions
== Parsed Logical Plan ==
'Project [(('col * 'col) * 'col) AS ((col * col) * col)#11]
+- LogicalRDD [col#1]

== Analyzed Logical Plan ==
((col * col) * col): decimal(38,12)
Project [CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(col#1 as decimal(26,6))) * promote_precision(cast(col#1 as decimal(26,6)))), DecimalType(38,12)) as decimal(26,6))) * promote_precision(cast(col#1 as decimal(26,6)))), DecimalType(38,12)) AS ((col * col) * col)#11]
+- LogicalRDD [col#1]

== Optimized Logical Plan ==
Project [CheckOverflow((cast(CheckOverflow((col#1 * col#1), DecimalType(38,12)) as decimal(26,6)) * col#1), DecimalType(38,12)) AS ((col * col) * col)#11]
+- LogicalRDD [col#1]

== Physical Plan ==
*Project [CheckOverflow((cast(CheckOverflow((col#1 * col#1), DecimalType(38,12)) as decimal(26,6)) * col#1), DecimalType(38,12)) AS ((col * col) * col)#11]
+- Scan ExistingRDD[col#1]

// Schema
root
 |-- ((col * col) * col): decimal(38,12) (nullable = true)
```

## How was this patch tested?

The existing unit tests pass. New tests were added to cover the previously
failing scenarios.

Author: aokolnychyi <anton.okolnyc...@sap.com>

Closes #18583 from aokolnychyi/spark-21332.

(cherry picked from commit 0be5fb41a6b7ef4da9ba36f3604ac646cb6d4ae3)
Signed-off-by: gatorsmile <gatorsm...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d7b9d623
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d7b9d623
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d7b9d623

Branch: refs/heads/branch-2.0
Commit: d7b9d623598f868a63c2cf872e9188b9181947d5
Parents: e4f57f2
Author: aokolnychyi <anton.okolnyc...@sap.com>
Authored: Mon Jul 17 21:07:50 2017 -0700
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Mon Jul 17 21:08:57 2017 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/analysis/DecimalPrecision.scala  | 2 +-
 .../spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala    | 6 ++++++
 2 files changed, 7 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d7b9d623/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index 9c38dd2..fd2ac78 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -80,7 +80,7 @@ object DecimalPrecision extends Rule[LogicalPlan] {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     // fix decimal precision for expressions
-    case q => q.transformExpressions(
+    case q => q.transformExpressionsUp(
       decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b9d623/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 66d9b4c..f98d5c0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -92,8 +92,14 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
     checkType(Average(d1), DecimalType(6, 5))
 
     checkType(Add(Add(d1, d2), d1), DecimalType(7, 2))
+    checkType(Add(Add(d1, d1), d1), DecimalType(4, 1))
+    checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1))
     checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2))
     checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2))
+    checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2))
+    checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4))
+    checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6))
+    checkType(Sum(Add(d1, d1)), DecimalType(13, 1))
   }
 
   test("Comparison operations") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to