Robert Joseph Evans created SPARK-37024:
-------------------------------------------

             Summary: Even more decimal overflow issues in average
                 Key: SPARK-37024
                 URL: https://issues.apache.org/jira/browse/SPARK-37024
             Project: Spark
          Issue Type: Bug
          Components: SQL
    Affects Versions: 3.2.0
            Reporter: Robert Joseph Evans


As part of trying to accelerate {{Decimal}} average aggregation on a 
[GPU|https://nvidia.github.io/spark-rapids/] I noticed a few issues around 
overflow. I think all of these could be fixed by replacing {{Average}} with 
explicit {{Sum}}, {{Count}}, and {{Divide}} operations for decimal inputs 
instead of doing them implicitly, but the extra checks would come with a 
performance cost (a query-level sketch of that rewrite follows the list below).

This is related to SPARK-35955, but goes quite a bit beyond it.
 # There are no ANSI overflow checks on the summation portion of average.
 # Whether a null is inserted or an overflow is detected on summation differs 
depending on code generation and parallelism.
 # If the input decimal precision is 11 or below, all overflow checks are 
disabled, and the answer is wrong instead of null on overflow.
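
To make the shape of that rewrite concrete at the query level (this is only a 
minimal sketch of the equivalent user-facing computation, not the proposed 
internal change to {{Average}}), the same result can be expressed as an 
explicit {{SUM}} divided by a {{COUNT}}, so the summation goes through 
{{Sum}}'s overflow handling:
{code:scala}
// Minimal sketch only: an explicit SUM / COUNT in place of AVG, so the summation
// is subject to Sum's overflow handling. Assumes a decimal column named "v".
import org.apache.spark.sql.functions.{col, count, sum}

val df = spark.range(10)
  .selectExpr("CAST('99999999999999999999999999999999' AS DECIMAL(32, 0)) as v")

df.select((sum(col("v")) / count(col("v"))).as("explicit_avg"))
  .show(truncate = false)
{code}
As the examples below show, the result scale of an explicit division differs 
slightly from {{AVG}} (hence the {{ROUND}} calls later on), so a real fix would 
need to handle that inside {{Average}} itself.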

*Details:*

*there are no ANSI overflow checks on the summation portion.*
{code:scala}
scala> spark.conf.set("spark.sql.ansi.enabled", "true")

scala> spark.time(spark.range(2000001)
    .repartition(2)
    .selectExpr("id", "CAST('99999999999999999999999999999999' AS DECIMAL(32, 
0)) as v")
    .selectExpr("AVG(v)")
    .show(truncate = false))
+------+
|avg(v)|
+------+
|null  |
+------+

Time taken: 622 ms

scala> spark.time(spark.range(2000001)
    .repartition(2)
    .selectExpr("id", "CAST('99999999999999999999999999999999' AS DECIMAL(32, 
0)) as v")
    .selectExpr("SUM(v)")
    .show(truncate = false))
21/10/16 06:08:00 ERROR Executor: Exception in task 0.0 in stage 15.0 (TID 19)
java.lang.ArithmeticException: Overflow in sum of decimals.
...
{code}
*nulls are inserted on summation overflow differently depending on code generation and parallelism.*

Because there are no explicit overflow checks when doing the sum, a user can 
get very inconsistent results as to when a null is inserted on overflow. The 
checks really only take place when the {{Decimal}} value is converted and 
stored into an {{UnsafeRow}}.  This happens when the values are shuffled, or 
after each operation if code gen is disabled.  For a {{DECIMAL(32, 0)}} you can 
add 1,000,000 max values before the summation overflows.
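A back-of-the-envelope check of that 1,000,000 figure, assuming the sum is 
accumulated in a buffer capped at Spark's 38-digit maximum decimal precision 
(input precision 32 plus 10 digits of headroom, capped at 38):
{code:scala}
// Rough arithmetic only, not Spark internals: if the running sum is held in a
// 38-digit decimal (the assumed cap), how many DECIMAL(32, 0) max values fit?
val maxInput  = BigInt("9" * 32)   // largest DECIMAL(32, 0) value
val maxBuffer = BigInt("9" * 38)   // largest value a 38-digit buffer can hold
println(maxBuffer / maxInput)      // 1000000
{code}
The shell runs below show how this plays out in practice.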
{code:scala}
scala> spark.time(spark.range(1000000)
    .coalesce(2)
    .selectExpr("id", "CAST('99999999999999999999999999999999' AS DECIMAL(32, 
0)) as v")
    .selectExpr("SUM(v) as s", "COUNT(v) as c", "AVG(v) as a")
    .selectExpr("s", "c", "s/c as sum_div_count", "a")
    .show(truncate = false))
+--------------------------------------+-------+---------------------------------------+-------------------------------------+
|s                                     |c      |sum_div_count                          |a                                    |
+--------------------------------------+-------+---------------------------------------+-------------------------------------+
|99999999999999999999999999999999000000|1000000|99999999999999999999999999999999.000000|99999999999999999999999999999999.0000|
+--------------------------------------+-------+---------------------------------------+-------------------------------------+
Time taken: 241 ms

scala> spark.time(spark.range(2000000)
    .coalesce(2)
    .selectExpr("id", "CAST('99999999999999999999999999999999' AS DECIMAL(32, 
0)) as v")
    .selectExpr("SUM(v) as s", "COUNT(v) as c", "AVG(v) as a")
    .selectExpr("s", "c", "s/c as sum_div_count", "a")
    .show(truncate = false))
+----+-------+-------------+-------------------------------------+
|s   |c      |sum_div_count|a                                    |
+----+-------+-------------+-------------------------------------+
|null|2000000|null         |99999999999999999999999999999999.0000|
+----+-------+-------------+-------------------------------------+
Time taken: 228 ms

scala> spark.time(spark.range(3000000)
    .coalesce(2)
    .selectExpr("id", "CAST('99999999999999999999999999999999' AS DECIMAL(32, 
0)) as v")
    .selectExpr("SUM(v) as s", "COUNT(v) as c", "AVG(v) as a")
    .selectExpr("s", "c", "s/c as sum_div_count", "a")
    .show(truncate = false))
+----+-------+-------------+----+
|s   |c      |sum_div_count|a   |
+----+-------+-------------+----+
|null|3000000|null         |null|
+----+-------+-------------+----+
Time taken: 347 ms

scala> spark.conf.set("spark.sql.codegen.wholeStage", "false")
scala> spark.time(spark.range(1000001)
    .coalesce(2)
    .selectExpr("id", "CAST('99999999999999999999999999999999' AS DECIMAL(32, 
0)) as v")
    .selectExpr("SUM(v) as s", "COUNT(v) as c", "AVG(v) as a")
    .selectExpr("s", "c", "s/c as sum_div_count", "a")
    .show(truncate = false))
+----+-------+-------------+----+
|s   |c      |sum_div_count|a   |
+----+-------+-------------+----+
|null|1000001|null         |null|
+----+-------+-------------+----+
Time taken: 310 ms
{code}
With code gen disabled the overflow is detected at 1,000,001 entries, just like 
with sum, but if code gen is enabled the point at which it is detected depends 
on the number of upstream tasks and the order of the data. That means if I 
change the size of the cluster I am running on, I might get different results 
from one run to another.
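
One hedged way to see that dependence is to re-run the 3,000,000-row example 
above with a different number of upstream partitions; because the partial sums 
are only checked when they cross an {{UnsafeRow}} boundary, which columns (if 
any) come back null may change with the partition layout:
{code:scala}
// Re-enable whole-stage codegen (it was disabled for the run just above).
spark.conf.set("spark.sql.codegen.wholeStage", "true")

// Same data volume as the 3,000,000-row run above, but a different upstream
// partition count. Each partition's partial sum may now stay within range until
// the merge, so the null pattern can differ from the coalesce(2) run.
spark.time(spark.range(3000000)
    .repartition(8)   // vs. coalesce(2) above
    .selectExpr("id", "CAST('99999999999999999999999999999999' AS DECIMAL(32, 0)) as v")
    .selectExpr("SUM(v) as s", "COUNT(v) as c", "AVG(v) as a")
    .show(truncate = false))
{code}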

*if the input decimal precision is 11 or below all overflow checks are disabled.*

When the precision of a {{Decimal}} in an average is 11 or below, the average 
will be done in terms of a {{Double}}.  The logic is that a {{Double}} has 53 
bits of significand precision, which should give us {{MAX_DOUBLE_DIGITS = 15}} 
worth of decimal precision. The main problem is that the 15 digit limit is 
compared to the output precision, which is input precision + 4, instead of 
being compared to the summation precision, which is input [precision +10|https://github.com/apache/spark/blob/67b547aa1cb7aeaf0f6e1f1017d21f582dacf697/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L73]. 
So instead of getting a full 10 billion entries before an overflow happens, in 
the worst case I was able to see it happen at just 90,100 entries.
{code:scala}
scala> spark.time(spark.range(90100)
    .repartition(1)
    .selectExpr("id", "CAST('99999999999' AS DECIMAL(11, 0)) as v")
    .selectExpr("AVG(v)")
    .show(truncate = false))
+----------------+
|avg(v)          |
+----------------+
|99999999999.0003|
+----------------+
Time taken: 86 ms
{code}
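The rough arithmetic behind that figure, assuming the precision loss starts 
once the running sum passes 2^53 (the largest integer a {{Double}} can 
represent exactly):
{code:scala}
// Back-of-the-envelope only: how many DECIMAL(11, 0) max values can be summed
// as a Double before the running total passes 2^53 and stops being exact?
val maxDecimal11 = BigDecimal("99999999999")   // largest DECIMAL(11, 0) value
val exactLimit   = BigDecimal(2).pow(53)       // 9007199254740992
println(exactLimit / maxDecimal11)             // ~90072, in line with the 90,100 above
{code}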
{{Sum}} has a similar optimization but does its calculations as a long. {{Sum}} 
guarantees 10 billion entries before overflowing (which is what the +10 
precision is for), and in practice it can actually sum over 90 billion max 
values in the worst case before the overflow reaches the point where {{Sum}} 
can no longer detect it and produces an incorrect answer.  Personally, I am 
okay with taking the risk that my data contains 90 billion max values, but not 
with a limit of roughly 100,000 of them.
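
The 10-billion guarantee is just the 10 extra digits of headroom in the sum's 
result precision:
{code:scala}
// Rough arithmetic only: 10 extra decimal digits of precision in the sum mean at
// least 10^10 input-max values can be accumulated before the sum can overflow.
val headroomDigits = 10
println(BigInt(10).pow(headroomDigits))   // 10000000000, i.e. 10 billion entries
{code}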

*Performance cost*

But there would be a performance cost in switching to SUM/COUNT.  First there 
is the general cost of the overflow checks that {{Sum}} performs.
{code:scala}
scala> spark.time(spark.range(Int.MaxValue)
    .selectExpr("id", "CAST('999999.99' AS DECIMAL(8, 2)) as v")
    .selectExpr("ROUND(SUM(v)/COUNT(v), 6) as sum_cnt_avg")
    .show(truncate = false))
+-------------+                                                                 
|sum_cnt_avg  |
+-------------+
|999999.990000|
+-------------+

Time taken: 2201 ms

scala> spark.time(spark.range(Int.MaxValue)
    .selectExpr("id", "CAST('999999.99' AS DECIMAL(8, 2)) as v")
    .selectExpr("AVG(v)")
    .show(truncate = false))
+-------------+                                                                 
|avg(v)       |
+-------------+
|999999.994967|
+-------------+

Time taken: 1845 ms
{code}
But there is also the cost that averages on {{Decimal}} values with a precision 
of 9, 10, or 11 would lose the optimization of being computed as longs/doubles.
{code:scala}
scala> spark.time(spark.range(Int.MaxValue)
    .selectExpr("id", "CAST('999999.99' AS DECIMAL(9, 2)) as v")
    .selectExpr("ROUND(SUM(v)/COUNT(v), 6) as sum_cnt_avg")
    .show(truncate = false))
+-------------+                                                                 
|sum_cnt_avg  |
+-------------+
|999999.990000|
+-------------+
Time taken: 13252 ms

scala> spark.time(spark.range(Int.MaxValue)
    .selectExpr("id", "CAST('999999.99' AS DECIMAL(9, 2)) as v")
    .selectExpr("AVG(v)")
    .show(truncate = false))
+-------------+                                                                 
|avg(v)       |
+-------------+
|999999.994967|
+-------------+
Time taken: 1928 ms
{code}
 


