This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new ca3593288d57 [SPARK-48252][SQL] Update CommonExpressionRef when necessary ca3593288d57 is described below commit ca3593288d577435a193f356b5214cf6f4bd534a Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Thu May 16 09:42:36 2024 +0800 [SPARK-48252][SQL] Update CommonExpressionRef when necessary ### What changes were proposed in this pull request? The `With` expression assumes that it should be created after all input expressions are fully resolved. This is mostly true (function lookup happens after function input expressions are resolved), but there is a special case of column resolution in HAVING: we use `TempResolvedColumn` to try one column resolution option. If it doesn't work, we re-resolve the column, which may then have a different data type. The `With` expression should update its refs when this happens. ### Why are the changes needed? Bug fix; otherwise the query will fail. ### Does this PR introduce _any_ user-facing change? No, this feature is not released yet. ### How was this patch tested? New test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46552 from cloud-fan/with. 
Lead-authored-by: Wenchen Fan <wenc...@databricks.com> Co-authored-by: Wenchen Fan <cloud0...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../apache/spark/sql/catalyst/expressions/With.scala | 18 +++++++++++++++++- .../optimizer/RewriteWithExpressionSuite.scala | 14 ++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala index 29794b33641c..5f6f9afa5797 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala @@ -40,7 +40,23 @@ case class With(child: Expression, defs: Seq[CommonExpressionDef]) override def children: Seq[Expression] = child +: defs override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): Expression = { - copy(child = newChildren.head, defs = newChildren.tail.map(_.asInstanceOf[CommonExpressionDef])) + val newDefs = newChildren.tail.map(_.asInstanceOf[CommonExpressionDef]) + // If any `CommonExpressionDef` has been updated (data type or nullability), also update its + // `CommonExpressionRef` in the `child`. 
+ val newChild = newDefs.filter(_.resolved).foldLeft(newChildren.head) { (result, newDef) => + defs.find(_.id == newDef.id).map { oldDef => + if (newDef.dataType != oldDef.dataType || newDef.nullable != oldDef.nullable) { + val newRef = new CommonExpressionRef(newDef) + result.transform { + case oldRef: CommonExpressionRef if oldRef.id == newRef.id => + newRef + } + } else { + result + } + }.getOrElse(result) + } + copy(child = newChild, defs = newDefs) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala index aa8ffb2b0454..0aeca961aa51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.analysis.TempResolvedColumn import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -438,4 +439,17 @@ class RewriteWithExpressionSuite extends PlanTest { Optimizer.execute(plan) } } + + test("SPARK-48252: TempResolvedColumn in common expression") { + val a = testRelation.output.head + val tempResolved = TempResolvedColumn(a, Seq("a")) + val expr = With(tempResolved) { case Seq(ref) => + ref === 1 + } + val plan = testRelation.having($"b")(avg("a").as("a"))(expr).analyze + comparePlans( + Optimizer.execute(plan), + testRelation.groupBy($"b")(avg("a").as("a")).where($"a" === 1).analyze + ) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org