This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 763c448759d [SPARK-44714] Ease restriction of LCA resolution regarding queries with having 763c448759d is described below commit 763c448759df02a5370b8b50cb877f855c4eda10 Author: Xinyi Yu <xinyi...@databricks.com> AuthorDate: Tue Aug 8 17:07:04 2023 +0800 [SPARK-44714] Ease restriction of LCA resolution regarding queries with having ### What changes were proposed in this pull request? This PR eases some restrictions of LCA (lateral column alias) resolution regarding queries with a HAVING clause. Previously, LCA resolution would not rewrite the plan (to the new plan shape) when the query contained `UnresolvedHaving` anywhere, in case the rewrite broke the plan shape `UnresolvedHaving - Aggregate` that is recognized by other rules. But this limitation is too strict, and it causes deadlocks in having - lca - window queries. See https://issues.apache.org/jira/browse/SPARK-42936 for more details and examples. With this PR, LCA resolution is skipped only on an `Aggregate` whose direct parent is `UnresolvedHaving`. This is enabled by a new hand-written bottom-up traversal that does not use the transform or resolve utility functions. This PR also identifies a weakness related to `TEMP_RESOLVED_COLUMN` (a dependency on rule application order) and documents it in a TODO comment in the code. Addressing it should be considered future work. ### Why are the changes needed? More complete functionality and better user experience. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New tests. Closes #42276 from anchovYu/lca-limitation-better-error. 
Authored-by: Xinyi Yu <xinyi...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit 29e8331681c6214390f426806d19ee9673b073e1) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 + .../ResolveLateralColumnAliasReference.scala | 200 ++++++++++++--------- .../apache/spark/sql/LateralColumnAliasSuite.scala | 145 +++++++++++++++ 3 files changed, 261 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 09467c22e2b..6c5d19f58ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -285,6 +285,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor AddMetadataColumns :: DeduplicateRelations :: ResolveReferences :: + // Please do not insert any other rules in between. See the TODO comments in rule + // ResolveLateralColumnAliasReference for more details. 
ResolveLateralColumnAliasReference :: ResolveExpressionsWithNamePlaceholders :: ResolveDeserializer :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index 5d89de00084..c249a3506f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.WindowExpression.hasWindowExpre import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, TEMP_RESOLVED_COLUMN, UNRESOLVED_HAVING} +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, TEMP_RESOLVED_COLUMN} import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -131,95 +131,97 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { (pList.exists(hasWindowExpression) && p.expressions.forall(_.resolved) && p.childrenResolved) } - override def apply(plan: LogicalPlan): LogicalPlan = { - if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { - plan - } else if (plan.containsAnyPattern(TEMP_RESOLVED_COLUMN, UNRESOLVED_HAVING)) { - // It should not change the plan if `TempResolvedColumn` or `UnresolvedHaving` is present in - // the query plan. These plans need certain plan shape to get recognized and resolved by other - // rules, such as Filter/Sort + Aggregate to be matched by ResolveAggregateFunctions. 
- // LCA resolution can break the plan shape, like adding Project above Aggregate. - plan - } else { - // phase 2: unwrap - plan.resolveOperatorsUpWithPruning( - _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) { - case p @ Project(projectList, child) if ruleApplicableOnOperator(p, projectList) - && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => - var aliasMap = AttributeMap.empty[AliasEntry] - val referencedAliases = collection.mutable.Set.empty[AliasEntry] - def unwrapLCAReference(e: NamedExpression): NamedExpression = { - e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { - case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) => - val aliasEntry = aliasMap.get(lcaRef.a).get - // If there is no chaining of lateral column alias reference, push down the alias - // and unwrap the LateralColumnAliasReference to the NamedExpression inside - // If there is chaining, don't resolve and save to future rounds - if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { - referencedAliases += aliasEntry - lcaRef.ne - } else { - lcaRef - } - case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) => - // It shouldn't happen, but restore to unresolved attribute to be safe. - UnresolvedAttribute(lcaRef.nameParts) - }.asInstanceOf[NamedExpression] - } - val newProjectList = projectList.zipWithIndex.map { - case (a: Alias, idx) => - val lcaResolved = unwrapLCAReference(a) - // Insert the original alias instead of rewritten one to detect chained LCA - aliasMap += (a.toAttribute -> AliasEntry(a, idx)) - lcaResolved - case (e, _) => - unwrapLCAReference(e) - } + /** Internal application method. A hand-written bottom-up recursive traverse. 
*/ + private def apply0(plan: LogicalPlan): LogicalPlan = { + plan match { + case p: LogicalPlan if !p.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE) => + p - if (referencedAliases.isEmpty) { - p - } else { - val outerProjectList = collection.mutable.Seq(newProjectList: _*) - val innerProjectList = - collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*) - referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => - outerProjectList.update(idx, alias.toAttribute) - innerProjectList += alias - } - p.copy( - projectList = outerProjectList.toSeq, - child = Project(innerProjectList.toSeq, child) - ) - } + // It should not change the Aggregate (and thus the plan shape) if its parent is an + // UnresolvedHaving, to avoid breaking the shape pattern `UnresolvedHaving - Aggregate` + // matched by ResolveAggregateFunctions. See SPARK-42936 and SPARK-44714 for more details. + case u @ UnresolvedHaving(_, agg: Aggregate) => + u.copy(child = agg.mapChildren(apply0)) - case agg @ Aggregate(groupingExpressions, aggregateExpressions, _) - if ruleApplicableOnOperator(agg, aggregateExpressions) - && aggregateExpressions.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + case pOriginal: Project if ruleApplicableOnOperator(pOriginal, pOriginal.projectList) + && pOriginal.projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + val p @ Project(projectList, child) = pOriginal.mapChildren(apply0) + var aliasMap = AttributeMap.empty[AliasEntry] + val referencedAliases = collection.mutable.Set.empty[AliasEntry] + def unwrapLCAReference(e: NamedExpression): NamedExpression = { + e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) => + val aliasEntry = aliasMap.get(lcaRef.a).get + // If there is no chaining of lateral column alias reference, push down the alias + // and unwrap the LateralColumnAliasReference to the 
NamedExpression inside + // If there is chaining, don't resolve and save to future rounds + if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + referencedAliases += aliasEntry + lcaRef.ne + } else { + lcaRef + } + case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) => + // It shouldn't happen, but restore to unresolved attribute to be safe. + UnresolvedAttribute(lcaRef.nameParts) + }.asInstanceOf[NamedExpression] + } + val newProjectList = projectList.zipWithIndex.map { + case (a: Alias, idx) => + val lcaResolved = unwrapLCAReference(a) + // Insert the original alias instead of rewritten one to detect chained LCA + aliasMap += (a.toAttribute -> AliasEntry(a, idx)) + lcaResolved + case (e, _) => + unwrapLCAReference(e) + } - // Check if current Aggregate is eligible to lift up with Project: the aggregate - // expression only contains: 1) aggregate functions, 2) grouping expressions, 3) leaf - // expressions excluding attributes not in grouping expressions - // This check is to prevent unnecessary transformation on invalid plan, to guarantee it - // throws the same exception. For example, cases like non-aggregate expressions not - // in group by, once transformed, will throw a different exception: missing input. - def eligibleToLiftUp(exp: Expression): Boolean = { - exp match { - case _: AggregateExpression => true - case e if groupingExpressions.exists(_.semanticEquals(e)) => true - case a: Attribute => false - case s: ScalarSubquery if s.children.nonEmpty - && !groupingExpressions.exists(_.semanticEquals(s)) => false - // Manually skip detection on function itself because it can be an aggregate function. - // This is to avoid expressions like sum(salary) over () eligible to lift up. 
- case WindowExpression(function, spec) => - function.children.forall(eligibleToLiftUp) && eligibleToLiftUp(spec) - case e => e.children.forall(eligibleToLiftUp) - } - } - if (!aggregateExpressions.forall(eligibleToLiftUp)) { - return agg + if (referencedAliases.isEmpty) { + p + } else { + val outerProjectList = collection.mutable.Seq(newProjectList: _*) + val innerProjectList = + collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*) + referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => + outerProjectList.update(idx, alias.toAttribute) + innerProjectList += alias } + p.copy( + projectList = outerProjectList.toSeq, + child = Project(innerProjectList.toSeq, child) + ) + } + + case aggOriginal: Aggregate + if ruleApplicableOnOperator(aggOriginal, aggOriginal.aggregateExpressions) + && aggOriginal.aggregateExpressions.exists( + _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + val agg @ Aggregate(groupingExpressions, aggregateExpressions, _) = + aggOriginal.mapChildren(apply0) + // Check if current Aggregate is eligible to lift up with Project: the aggregate + // expression only contains: 1) aggregate functions, 2) grouping expressions, 3) leaf + // expressions excluding attributes not in grouping expressions + // This check is to prevent unnecessary transformation on invalid plan, to guarantee it + // throws the same exception. For example, cases like non-aggregate expressions not + // in group by, once transformed, will throw a different exception: missing input. + def eligibleToLiftUp(exp: Expression): Boolean = { + exp match { + case _: AggregateExpression => true + case e if groupingExpressions.exists(_.semanticEquals(e)) => true + case a: Attribute => false + case s: ScalarSubquery if s.children.nonEmpty + && !groupingExpressions.exists(_.semanticEquals(s)) => false + // Manually skip detection on function itself because it can be an aggregate function. 
+ // This is to avoid expressions like sum(salary) over () eligible to lift up. + case WindowExpression(function, spec) => + function.children.forall(eligibleToLiftUp) && eligibleToLiftUp(spec) + case e => e.children.forall(eligibleToLiftUp) + } + } + if (!aggregateExpressions.forall(eligibleToLiftUp)) { + agg + } else { val newAggExprs = collection.mutable.Set.empty[NamedExpression] val expressionMap = collection.mutable.LinkedHashMap.empty[Expression, NamedExpression] // Extract the expressions to keep in the Aggregate. Return the transformed expression @@ -262,7 +264,33 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { projectList = projectExprs, child = agg.copy(aggregateExpressions = newAggExprs.toSeq) ) - } + } + + case p: LogicalPlan => + p.mapChildren(apply0) + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { + plan + } else if (plan.containsAnyPattern(TEMP_RESOLVED_COLUMN)) { + // It should not change the plan if `TempResolvedColumn` is present in the query plan. These + // plans need certain plan shape to get recognized and resolved by other rules, such as + // Filter/Sort + Aggregate to be matched by ResolveAggregateFunctions. LCA resolution can + // break the plan shape, like adding Project above Aggregate. + // TODO: this condition only guarantees to keep the shape after the plan has + // `TempResolvedColumn`. However, it does not consider the case of breaking the shape even + // before `TempResolvedColumn` is generated by matching Filter/Sort - Aggregate in + // ResolveReferences. Currently the correctness of this case now relies on the rule + // application order, that ResolveReference is right before the application of + // ResolveLateralColumnAliasReference. The condition in the two rules guarantees that the + // case can never happen. We should consider to remove this order dependency but still assure + // correctness in the future. 
+ plan + } else { + // phase 2: unwrap + apply0(plan) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 1e3a0d70c7f..cc4aeb42326 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -669,6 +669,20 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { s"FROM $testTable GROUP BY dept ORDER BY max(name)"), Row(1, 1) :: Row(2, 2) :: Row(6, 6) :: Nil ) + checkAnswer( + sql("SELECT dept, avg(salary) AS a, a + 10 FROM employee GROUP BY dept ORDER BY max(name)"), + Row(1, 9500, 9510) :: Row(2, 11000, 11010) :: Row(6, 12000, 12010) :: Nil + ) + checkAnswer( + sql("SELECT dept, avg(salary) AS a, a + 10 AS b " + + "FROM employee GROUP BY dept ORDER BY max(name)"), + Row(1, 9500, 9510) :: Row(2, 11000, 11010) :: Row(6, 12000, 12010) :: Nil + ) + checkAnswer( + sql("SELECT dept, avg(salary) AS a, a + cast(10 as double) AS b " + + "FROM employee GROUP BY dept ORDER BY max(name)"), + Row(1, 9500, 9510) :: Row(2, 11000, 11010) :: Row(6, 12000, 12010) :: Nil + ) // having cond is resolved by aggregate's child checkAnswer( @@ -676,6 +690,21 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { s"FROM $testTable GROUP BY dept HAVING max(name) = 'david'"), Row(1250, 2, 11000, 11010) :: Nil ) + checkAnswer( + sql("SELECT dept, avg(salary) AS a, a + 10 " + + "FROM employee GROUP BY dept HAVING max(bonus) > 1200"), + Row(2, 11000, 11010) :: Nil + ) + checkAnswer( + sql("SELECT dept, avg(salary) AS a, a + 10 AS b " + + "FROM employee GROUP BY dept HAVING max(bonus) > 1200"), + Row(2, 11000, 11010) :: Nil + ) + checkAnswer( + sql("SELECT dept, avg(salary) AS a, a + cast(10 as double) AS b " + + "FROM employee GROUP BY dept HAVING max(bonus) > 1200"), + Row(2, 11000, 11010) :: Nil + ) // having cond is resolved by 
aggregate itself checkAnswer( sql(s"SELECT avg(bonus) AS a, a FROM $testTable GROUP BY dept HAVING a > 1200"), @@ -1139,4 +1168,120 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { // non group by or non aggregate function in Aggregate queries negative cases are covered in // "Aggregate expressions not eligible to lift up, throws same error as inline". } + + test("Still resolves when Aggregate with LCA is not the direct child of Having") { + // Previously there was a limitation of lca that it can't resolve the query when it satisfies + // all the following criteria: + // 1) the main (outer) query has having clause + // 2) there is a window expression in the query + // 3) in the same SELECT list as the window expression in 2), there is an lca + // Though [UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_WITH_WINDOW_AND_HAVING] is + // still not supported, after SPARK-44714, a lot other limitations are + // lifted because it allows to resolve LCA when the query has UnresolvedHaving but its direct + // child does not contain an LCA. + // Testcases in this test focus on this change regarding enablement of resolution. 
+ + // CTE definition contains window and LCA; outer query contains having + checkAnswer( + sql( + s""" + |with w as ( + | select name, dept, salary, rank() over (partition by dept order by salary) as r, r + | from $testTable + |) + |select dept + |from w + |group by dept + |having max(salary) > 10000 + |""".stripMargin), + Row(2) :: Row(6) :: Nil + ) + checkAnswer( + sql( + s""" + |with w as ( + | select name, dept, salary, rank() over (partition by dept order by salary) as r, r + | from $testTable + |) + |select dept as d, d + |from w + |group by dept + |having max(salary) > 10000 + |""".stripMargin), + Row(2, 2) :: Row(6, 6) :: Nil + ) + checkAnswer( + sql( + s""" + |with w as ( + | select name, dept, salary, rank() over (partition by dept order by salary) as r, r + | from $testTable + |) + |select dept as d + |from w + |group by dept + |having d = 2 + |""".stripMargin), + Row(2) :: Nil + ) + + // inner subquery contains window and LCA; outer query contains having + checkAnswer( + sql( + s""" + |SELECT + | dept + |FROM + | ( + | select + | name, dept, salary, rank() over (partition by dept order by salary) as r, + | 1 as a, a + 1 as e + | FROM + | $testTable + | ) AS inner_t + |GROUP BY + | dept + |HAVING max(salary) > 10000 + |""".stripMargin), + Row(2) :: Row(6) :: Nil + ) + checkAnswer( + sql( + s""" + |SELECT + | dept as d, d + |FROM + | ( + | select + | name, dept, salary, rank() over (partition by dept order by salary) as r, + | 1 as a, a + 1 as e + | FROM + | $testTable + | ) AS inner_t + |GROUP BY + | dept + |HAVING max(salary) > 10000 + |""".stripMargin), + Row(2, 2) :: Row(6, 6) :: Nil + ) + checkAnswer( + sql( + s""" + |SELECT + | dept as d + |FROM + | ( + | select + | name, dept, salary, rank() over (partition by dept order by salary) as r, + | 1 as a, a + 1 as e + | FROM + | $testTable + | ) AS inner_t + |GROUP BY + | dept + |HAVING d = 2 + |""".stripMargin), + Row(2) :: Nil + ) + } } 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org