[ https://issues.apache.org/jira/browse/SPARK-56949?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

Kent Yao resolved SPARK-56949.
------------------------------
    Resolution: Fixed

> DecimalAggregates should preserve evalMode of Sum/Average
> ---------------------------------------------------------
>
>                 Key: SPARK-56949
>                 URL: https://issues.apache.org/jira/browse/SPARK-56949
>             Project: Spark
>          Issue Type: Bug
>          Components: SQL
>    Affects Versions: 5.0.0
>            Reporter: Kent Yao
>            Assignee: Kent Yao
>            Priority: Major
>              Labels: correctness, pull-request-available, try-functions
>             Fix For: 4.3.0
>
>
> h3. What's the issue?
> {{Optimizer.DecimalAggregates}} rewrites {{Sum(decimal_col)}} / 
> {{Average(decimal_col)}} to {{Sum(UnscaledValue(decimal_col))}} / 
> {{Average(UnscaledValue(decimal_col))}} via the single-argument helper 
> constructor:
> {code:scala}
> // Sum.scala
> def this(child: Expression) = this(child, NumericEvalContext.fromSQLConf(SQLConf.get))
> // Average.scala
> def this(child: Expression) = this(child, EvalMode.fromSQLConf(SQLConf.get))
> {code}
> This re-reads {{EvalMode}} from {{SQLConf}} at rewrite time, silently
> dropping {{EvalMode.TRY}} from the original {{Sum}} / {{Average}} node. For
> {{try_sum}} / {{try_avg}} on narrow-precision decimal columns that trigger
> the DecimalAggregates fast path (SUM gate {{p + 10 <= 18}}, AVG gate
> {{p + 4 <= 15}}, where {{p}} is the input precision), the {{try_*}}
> "return NULL on overflow" semantics are broken:
> * With {{spark.sql.ansi.enabled=false}}: overflow silently wraps around
> instead of returning NULL.
> * With {{spark.sql.ansi.enabled=true}}: overflow throws instead of returning
> NULL.
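To see which column types are affected, the two gates quoted above can be evaluated directly. A minimal Python sketch; the helper names are hypothetical, and only the bounds 18 and 15 come from the gates in this issue:

```python
# Hypothetical helpers that evaluate the DecimalAggregates gates quoted above.
# SUM gate: p + 10 <= 18; AVG gate: p + 4 <= 15 (p = input decimal precision).
SUM_MAX_DIGITS = 18  # bound used by the SUM gate
AVG_MAX_DIGITS = 15  # bound used by the AVG gate


def sum_hits_fast_path(precision: int) -> bool:
    """True if the SUM gate admits a decimal input of this precision."""
    return precision + 10 <= SUM_MAX_DIGITS


def avg_hits_fast_path(precision: int) -> bool:
    """True if the AVG gate admits a decimal input of this precision."""
    return precision + 4 <= AVG_MAX_DIGITS


if __name__ == "__main__":
    # decimal(7,2), the column type used in the reproduction, passes both gates.
    print(sum_hits_fast_path(7), avg_hits_fast_path(7))  # True True
    # Largest precisions that still pass: 8 for SUM, 11 for AVG.
    print(max(p for p in range(1, 39) if sum_hits_fast_path(p)))  # 8
    print(max(p for p in range(1, 39) if avg_hits_fast_path(p)))  # 11
```

So any {{try_sum}} over decimal precision up to 8, or {{try_avg}} over precision up to 11, goes through the rewrite.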
> h3. Reproduction
> Vanilla PySpark 3.5.3 (master reproduces identically: the same helper
> constructor is used at {{Sum.scala:52}} / {{Average.scala:50}}):
> {code:python}
> from pyspark.sql import SparkSession
> def walk(spark, sql, label):
>     df = spark.range(10).selectExpr("cast(id as decimal(7,2)) as x")
>     df.createOrReplaceTempView("t")
>     plan = spark.sql(sql)._jdf.queryExecution().optimizedPlan()
>     def walk_expr(e, depth=0):
>         # Report every Sum / Average node in the optimized plan with its evalMode.
>         cls = e.getClass().getName()
>         if "aggregate.Sum" in cls or "aggregate.Average" in cls:
>             name = cls.split(".")[-1]
>             print(f"{label}: {'  ' * depth}{name} evalMode={e.evalMode()} :: {e}")
>         it = e.children().iterator()
>         while it.hasNext():
>             walk_expr(it.next(), depth+1)
>     it = plan.expressions().iterator()
>     while it.hasNext():
>         walk_expr(it.next())
> # Rule ON (default)
> spark = SparkSession.builder.master("local[1]") \
>     .config("spark.sql.ansi.enabled", "false").getOrCreate()
> walk(spark, "select try_sum(x) from t", "rule ON")
> walk(spark, "select try_avg(x) from t", "rule ON")
> spark.stop()
> # Rule OFF -- exclude DecimalAggregates
> spark2 = SparkSession.builder.master("local[1]") \
>     .config("spark.sql.ansi.enabled", "false") \
>     .config("spark.sql.optimizer.excludedRules",
>             "org.apache.spark.sql.catalyst.optimizer.DecimalAggregates") \
>     .getOrCreate()
> walk(spark2, "select try_sum(x) from t", "rule OFF")
> walk(spark2, "select try_avg(x) from t", "rule OFF")
> spark2.stop()
> {code}
> Output (condensed, annotated):
> {noformat}
> rule ON  : try_sum(x) -> Sum evalMode=LEGACY      [broken]
> rule ON  : try_avg(x) -> Average evalMode=LEGACY  [broken]
> rule OFF : try_sum(x) -> Sum evalMode=TRY         [correct]
> rule OFF : try_avg(x) -> Average evalMode=TRY     [correct]
> {noformat}
> h3. Why does evalMode matter at runtime?
> {{Sum.scala:58}} explicitly branches on {{evalContext.evalMode == EvalMode.TRY}}
> to decide the aggregation buffer schema ({{shouldTrackIsEmpty}}), which is
> what lets {{try_sum}} track overflow and return NULL. Dropping TRY therefore
> changes user-visible behavior.
> h3. How long has this been broken?
> Latent since the {{evalMode}} / {{evalContext}} fields were added to {{Sum}} /
> {{Average}} as part of the {{try_*}} functions work; the optimizer rule was
> never updated to preserve them.
> h3. Proposed fix
> Pattern-bind the original aggregate node and use case-class {{copy}}, which 
> preserves all sibling fields (including {{evalMode}} / {{evalContext}}) and 
> is future-proof against new fields:
> {code:scala}
> // `ae` is the enclosing AggregateExpression bound by the outer pattern match.
> case s @ Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
>   MakeDecimal(ae.copy(aggregateFunction = s.copy(child = UnscaledValue(e))), prec + 10, scale)
> {code}
> Applied to all 4 {{Aggregate}}-arm sites in {{Optimizer.DecimalAggregates}}.
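Why {{copy}} fixes this can be modeled in Python, where {{dataclasses.replace}} plays the role of Scala's case-class {{copy}}: it rebuilds the node with one field changed and carries every other field over. A minimal sketch; the {{Sum}} dataclass, {{CURRENT_CONF_EVAL_MODE}}, and both rewrite functions are hypothetical simplifications, not Spark classes:

```python
import dataclasses
from dataclasses import dataclass

# Hypothetical stand-in for the session's SQLConf-derived default eval mode.
CURRENT_CONF_EVAL_MODE = "LEGACY"  # e.g. spark.sql.ansi.enabled=false


@dataclass(frozen=True)
class Sum:
    child: str
    eval_mode: str

    @classmethod
    def from_conf(cls, child: str) -> "Sum":
        # Mirrors the single-argument helper: eval_mode is re-read from config.
        return cls(child, CURRENT_CONF_EVAL_MODE)


def rewrite_buggy(s: Sum) -> Sum:
    # Old rule shape: builds a fresh node, losing the original eval_mode.
    return Sum.from_conf(f"UnscaledValue({s.child})")


def rewrite_fixed(s: Sum) -> Sum:
    # Fixed rule shape: copy the node, replacing only `child`.
    return dataclasses.replace(s, child=f"UnscaledValue({s.child})")


if __name__ == "__main__":
    try_sum = Sum("x", "TRY")
    print(rewrite_buggy(try_sum))  # Sum(child='UnscaledValue(x)', eval_mode='LEGACY')
    print(rewrite_fixed(try_sum))  # Sum(child='UnscaledValue(x)', eval_mode='TRY')
```

Because the copy names only the field being changed, a field added to the class later is also carried over without touching the rewrite.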
> h3. Scope
> This issue covers the 4 {{Aggregate}}-arm sites. The 2 {{Window}}-arm sites 
> (also using the single-arg helper) will be addressed in a separate follow-up 
> after a dedicated repro confirms they are reachable through 
> {{ExtractWindowExpressions}} hoisting.



