This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 9287fc9dd73 [SPARK-40618][SQL] Fix bug in MergeScalarSubqueries rule with nested subqueries using reference tracking 9287fc9dd73 is described below commit 9287fc9dd73a0909d5705308532b24528b3f1090 Author: Peter Toth <peter.t...@gmail.com> AuthorDate: Thu Oct 13 22:06:15 2022 +0800 [SPARK-40618][SQL] Fix bug in MergeScalarSubqueries rule with nested subqueries using reference tracking ### What changes were proposed in this pull request? This PR reverts the previous fix https://github.com/apache/spark/pull/38052 and adds subquery reference tracking to `MergeScalarSubqueries` to restore previous functionality of merging independent nested subqueries. ### Why are the changes needed? Restore previous functionality but fix the bug discovered in https://issues.apache.org/jira/browse/SPARK-40618. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing and new UTs. Closes #38093 from peter-toth/SPARK-40618-fix-mergescalarsubqueries. Authored-by: Peter Toth <peter.t...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../catalyst/optimizer/MergeScalarSubqueries.scala | 62 +++++++++++++--------- .../scala/org/apache/spark/sql/SubquerySuite.scala | 35 +++++++++--- 2 files changed, 67 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala index 69f77e8f3f4..1cb3f3f157c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions._ @@ -126,8 +127,14 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { * merged as there can be subqueries that are different ([[checkIdenticalPlans]] is * false) due to an extra [[Project]] node in one of them. In that case * `attributes.size` remains 1 after merging, but the merged flag becomes true. + * @param references A set of subquery indexes in the cache to track all (including transitive) + * nested subqueries. */ - case class Header(attributes: Seq[Attribute], plan: LogicalPlan, merged: Boolean) + case class Header( + attributes: Seq[Attribute], + plan: LogicalPlan, + merged: Boolean, + references: Set[Int]) private def extractCommonScalarSubqueries(plan: LogicalPlan) = { val cache = ArrayBuffer.empty[Header] @@ -166,26 +173,39 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { // "Header". private def cacheSubquery(plan: LogicalPlan, cache: ArrayBuffer[Header]): (Int, Int) = { val output = plan.output.head - cache.zipWithIndex.collectFirst(Function.unlift { case (header, subqueryIndex) => - checkIdenticalPlans(plan, header.plan).map { outputMap => - val mappedOutput = mapAttributes(output, outputMap) - val headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId) - subqueryIndex -> headerIndex - }.orElse(tryMergePlans(plan, header.plan).map { - case (mergedPlan, outputMap) => + val references = mutable.HashSet.empty[Int] + plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) { + case ssr: ScalarSubqueryReference => + references += ssr.subqueryIndex + references ++= cache(ssr.subqueryIndex).references + ssr + } + + cache.zipWithIndex.collectFirst(Function.unlift { + case (header, subqueryIndex) if !references.contains(subqueryIndex) => + checkIdenticalPlans(plan, header.plan).map { outputMap => val mappedOutput = mapAttributes(output, outputMap) - var headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId) - val newHeaderAttributes = if (headerIndex == -1) { - headerIndex = header.attributes.size - header.attributes :+ mappedOutput - } else { - header.attributes - } - cache(subqueryIndex) = Header(newHeaderAttributes, mergedPlan, true) + val headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId) subqueryIndex -> headerIndex - }) + }.orElse{ + tryMergePlans(plan, header.plan).map { + case (mergedPlan, outputMap) => + val mappedOutput = mapAttributes(output, outputMap) + var headerIndex = header.attributes.indexWhere(_.exprId == mappedOutput.exprId) + val newHeaderAttributes = if (headerIndex == -1) { + headerIndex = header.attributes.size + header.attributes :+ mappedOutput + } else { + header.attributes + } + cache(subqueryIndex) = + Header(newHeaderAttributes, mergedPlan, true, header.references ++ references) + subqueryIndex -> headerIndex + } + } + case _ => None }).getOrElse { - cache += Header(Seq(output), plan, false) + cache += Header(Seq(output), plan, false, references.toSet) cache.length - 1 -> 0 } } @@ -210,12 +230,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { cachedPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])] = { checkIdenticalPlans(newPlan, cachedPlan).map(cachedPlan -> _).orElse( (newPlan, cachedPlan) match { - case (_, _) if newPlan.containsPattern(SCALAR_SUBQUERY_REFERENCE) || - cachedPlan.containsPattern(SCALAR_SUBQUERY_REFERENCE) => - // Subquery expressions with nested subquery expressions within are not supported for now. - // TODO: support this optimization by collecting the transitive subquery references in the - // new plan and recording them in order to suppress merging the new plan into those. - None case (np: Project, cp: Project) => tryMergePlans(np.child, cp.child).map { case (mergedChild, outputMap) => val (mergedProjectList, newOutputMap) = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index ca78aaae414..2c8c3eda953 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2251,7 +2251,7 @@ class SubquerySuite extends QueryTest } } - test("SPARK-40618: Do not merge scalar subqueries with nested subqueries inside") { + test("Merge non-correlated scalar subqueries from different parent plans") { Seq(false, true).foreach { enableAQE => withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) { @@ -2283,13 +2283,13 @@ class SubquerySuite extends QueryTest } if (enableAQE) { - assert(subqueryIds.size == 4, "Missing or unexpected SubqueryExec in the plan") - assert(reusedSubqueryIds.size == 2, - "Missing or unexpected reused ReusedSubqueryExec in the plan") - } else { assert(subqueryIds.size == 3, "Missing or unexpected SubqueryExec in the plan") assert(reusedSubqueryIds.size == 3, "Missing or unexpected reused ReusedSubqueryExec in the plan") + } else { + assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan") + assert(reusedSubqueryIds.size == 4, + "Missing or unexpected reused ReusedSubqueryExec in the plan") } } } @@ -2426,9 +2426,32 @@ class SubquerySuite extends QueryTest // This test contains a subquery expression with another subquery expression nested inside. // It acts as a regression test to ensure that the MergeScalarSubqueries rule does not attempt // to merge them together. - withTable("t") { + withTable("t", "t2") { sql("create table t(col int) using csv") checkAnswer(sql("select(select sum((select sum(col) from t)) from t)"), Row(null)) + + checkAnswer(sql( + """ + |select + | (select sum( + | (select sum( + | (select sum(col) from t)) + | from t)) + | from t) + |""".stripMargin), + Row(null)) + + sql("create table t2(col int) using csv") + checkAnswer(sql( + """ + |select + | (select sum( + | (select sum( + | (select sum(col) from t)) + | from t2)) + | from t) + |""".stripMargin), + Row(null)) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org