[
https://issues.apache.org/jira/browse/SPARK-56949?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
]
Kent Yao resolved SPARK-56949.
------------------------------
Resolution: Fixed
> DecimalAggregates should preserve evalMode of Sum/Average
> ---------------------------------------------------------
>
> Key: SPARK-56949
> URL: https://issues.apache.org/jira/browse/SPARK-56949
> Project: Spark
> Issue Type: Bug
> Components: SQL
> Affects Versions: 5.0.0
> Reporter: Kent Yao
> Assignee: Kent Yao
> Priority: Major
> Labels: correctness, pull-request-available, try-functions
> Fix For: 4.3.0
>
>
> h3. What's the issue?
> {{Optimizer.DecimalAggregates}} rewrites {{Sum(decimal_col)}} /
> {{Average(decimal_col)}} to {{Sum(UnscaledValue(decimal_col))}} /
> {{Average(UnscaledValue(decimal_col))}} via the single-argument helper
> constructors:
> {code:scala}
> // Sum.scala
> def this(child: Expression) =
>   this(child, NumericEvalContext.fromSQLConf(SQLConf.get))
> // Average.scala
> def this(child: Expression) = this(child, EvalMode.fromSQLConf(SQLConf.get))
> {code}
> This re-reads {{EvalMode}} from {{SQLConf}} at rewrite time, silently
> dropping {{EvalMode.TRY}} from the original {{Sum}} / {{Average}} node. For
> {{try_sum}} / {{try_avg}} on narrow-precision decimal columns that trigger
> the {{DecimalAggregates}} fast path (SUM gate {{p + 10 <= 18}}, AVG gate
> {{p + 4 <= 15}}), the {{try_*}} "return NULL on overflow" semantics are broken:
> * With {{spark.sql.ansi.enabled=false}}: overflow silently wraps around
> instead of returning NULL.
> * With {{spark.sql.ansi.enabled=true}}: overflow throws instead of returning
> NULL.
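
A minimal toy model of the mechanism described above, not Spark code. The names `Sum`, `mode_from_conf`, `sum_gate`, and `buggy_rewrite` are hypothetical stand-ins, and the sketch assumes that a config-derived mode can only be LEGACY or ANSI (TRY is set only by the `try_*` functions themselves):

```python
from dataclasses import dataclass
from enum import Enum


class EvalMode(Enum):
    LEGACY = "LEGACY"
    ANSI = "ANSI"
    TRY = "TRY"


def mode_from_conf(ansi_enabled: bool) -> EvalMode:
    # Assumption: the session config never yields TRY, only LEGACY or ANSI.
    return EvalMode.ANSI if ansi_enabled else EvalMode.LEGACY


@dataclass(frozen=True)
class Sum:
    child: str
    eval_mode: EvalMode


MAX_LONG_DIGITS = 18


def sum_gate(precision: int) -> bool:
    # SUM fast-path gate from the description: p + 10 <= 18.
    return precision + 10 <= MAX_LONG_DIGITS


def buggy_rewrite(agg: Sum, precision: int, ansi_enabled: bool) -> Sum:
    if not sum_gate(precision):
        return agg  # fast path not taken, node untouched
    # Builds a brand-new node through the config-reading path,
    # so whatever mode the original node carried is discarded.
    return Sum(f"UnscaledValue({agg.child})", mode_from_conf(ansi_enabled))


try_sum = Sum("x", EvalMode.TRY)
print(buggy_rewrite(try_sum, 7, ansi_enabled=False).eval_mode.name)
print(buggy_rewrite(try_sum, 7, ansi_enabled=True).eval_mode.name)
print(buggy_rewrite(try_sum, 12, ansi_enabled=False).eval_mode.name)
```

In this model, decimal(7,2) passes the gate and loses TRY under either config setting, while a precision of 12 fails the gate and keeps it.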
> h3. Reproduction
> Vanilla PySpark 3.5.3 (master reproduces identically: the same helper
> constructors are at {{Sum.scala:52}} / {{Average.scala:50}}):
> {code:python}
> from pyspark.sql import SparkSession
>
> def walk(spark, sql, label):
>     # Print the evalMode of every Sum / Average node in the optimized plan.
>     df = spark.range(10).selectExpr("cast(id as decimal(7,2)) as x")
>     df.createOrReplaceTempView("t")
>     plan = spark.sql(sql)._jdf.queryExecution().optimizedPlan()
>
>     def walk_expr(e, depth=0):
>         cls = e.getClass().getName()
>         if "aggregate.Sum" in cls or "aggregate.Average" in cls:
>             print(f"{' '*depth}{cls.split('.')[-1]} evalMode={e.evalMode()} :: {e}")
>         it = e.children().iterator()
>         while it.hasNext():
>             walk_expr(it.next(), depth+1)
>
>     it = plan.expressions().iterator()
>     while it.hasNext():
>         walk_expr(it.next())
>
> # Rule ON (default)
> spark = SparkSession.builder.master("local[1]") \
>     .config("spark.sql.ansi.enabled", "false").getOrCreate()
> walk(spark, "select try_sum(x) from t", "rule ON")
> walk(spark, "select try_avg(x) from t", "rule ON")
> spark.stop()
>
> # Rule OFF -- exclude DecimalAggregates
> spark2 = SparkSession.builder.master("local[1]") \
>     .config("spark.sql.ansi.enabled", "false") \
>     .config("spark.sql.optimizer.excludedRules",
>             "org.apache.spark.sql.catalyst.optimizer.DecimalAggregates") \
>     .getOrCreate()
> walk(spark2, "select try_sum(x) from t", "rule OFF")
> walk(spark2, "select try_avg(x) from t", "rule OFF")
> spark2.stop()
> {code}
> Output:
> {noformat}
> rule ON : try_sum(x) -> Sum evalMode=LEGACY [broken]
> rule ON : try_avg(x) -> Average evalMode=LEGACY [broken]
> rule OFF : try_sum(x) -> Sum evalMode=TRY [correct]
> rule OFF : try_avg(x) -> Average evalMode=TRY [correct]
> {noformat}
> h3. Why does evalMode matter at runtime?
> {{Sum.scala:58}} explicitly branches on {{evalContext.evalMode ==
> EvalMode.TRY}} to decide the buffer schema ({{shouldTrackIsEmpty}}), which is
> what lets {{try_sum}} track overflow and return NULL. Dropping TRY therefore
> changes user-visible behavior.
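
To make the three user-visible outcomes concrete, here is a toy 64-bit accumulator, not Spark's implementation. `sum_longs` and `wrap_to_long` are hypothetical names, and the sketch ignores empty-input handling (which is what the real buffer's is-empty tracking is for):

```python
from enum import Enum
from typing import Iterable, Optional


class EvalMode(Enum):
    LEGACY = "LEGACY"
    ANSI = "ANSI"
    TRY = "TRY"


LONG_MIN, LONG_MAX = -(2**63), 2**63 - 1


def wrap_to_long(v: int) -> int:
    # Two's-complement wraparound into the signed 64-bit range.
    return (v + 2**63) % 2**64 - 2**63


def sum_longs(values: Iterable[int], mode: EvalMode) -> Optional[int]:
    total = 0
    for v in values:
        exact = total + v  # Python ints are unbounded, so this is exact
        if LONG_MIN <= exact <= LONG_MAX:
            total = exact
        elif mode is EvalMode.TRY:
            return None  # try_* semantics: NULL on overflow
        elif mode is EvalMode.ANSI:
            raise OverflowError("long overflow")
        else:
            total = wrap_to_long(exact)  # LEGACY: silent wraparound
    return total


print(sum_longs([LONG_MAX, 1], EvalMode.TRY))
print(sum_longs([LONG_MAX, 1], EvalMode.LEGACY) == LONG_MIN)
```

An aggregate that was meant to run in TRY mode but is evaluated as LEGACY or ANSI takes one of the other two branches, which matches the two failure modes listed in the description.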
> h3. How long has this been broken?
> Latent since {{evalMode}} / {{evalContext}} fields were added to {{Sum}} /
> {{Average}} as part of the {{try_*}} functions work. The optimizer rule was
> not updated to preserve them.
> h3. Proposed fix
> Pattern-bind the original aggregate node and use case-class {{copy}}, which
> preserves all sibling fields (including {{evalMode}} / {{evalContext}}) and
> is future-proof against new fields:
> {code:scala}
> case s @ Sum(e @ DecimalExpression(prec, scale), _)
>     if prec + 10 <= MAX_LONG_DIGITS =>
>   MakeDecimal(
>     ae.copy(aggregateFunction = s.copy(child = UnscaledValue(e))),
>     prec + 10, scale)
> {code}
> Applied to all 4 {{Aggregate}}-arm sites in {{Optimizer.DecimalAggregates}}.
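
The copy-based idea can be sketched in Python with `dataclasses.replace`, which plays roughly the role of a Scala case-class `copy`. This is an analogy only: the `Sum` class, its `hint` field (a stand-in for some future sibling field), and `preserving_rewrite` are hypothetical:

```python
from dataclasses import dataclass, replace
from enum import Enum


class EvalMode(Enum):
    LEGACY = "LEGACY"
    ANSI = "ANSI"
    TRY = "TRY"


@dataclass(frozen=True)
class Sum:
    child: str
    eval_mode: EvalMode
    hint: str = ""  # hypothetical field added after the rewrite rule was written


def preserving_rewrite(agg: Sum) -> Sum:
    # Only the child is swapped; every other field is carried over as-is,
    # including fields this function has never heard of.
    return replace(agg, child=f"UnscaledValue({agg.child})")


original = Sum("x", EvalMode.TRY, hint="keep-me")
rewritten = preserving_rewrite(original)
print(rewritten.eval_mode.name, rewritten.hint)
```

Because the rewrite names only the field it changes, adding another field to the node later does not require touching the rule again.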
> h3. Scope
> This issue covers the 4 {{Aggregate}}-arm sites. The 2 {{Window}}-arm sites
> (also using the single-arg helper) will be addressed in a separate follow-up
> after a dedicated repro confirms they are reachable through
> {{ExtractWindowExpressions}} hoisting.
--
This message was sent by Atlassian Jira
(v8.20.10#820010)