This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 37307fb0d14e [SPARK-47241][SQL] Fix rule order issues for ExtractGenerator 37307fb0d14e is described below commit 37307fb0d14e03b1085ed12d8d540d2606bf1e9d Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Thu Mar 7 17:02:09 2024 +0800 [SPARK-47241][SQL] Fix rule order issues for ExtractGenerator ### What changes were proposed in this pull request? The rule `ExtractGenerator` does not define any trigger condition when rewriting generator functions in `Project`, which makes the behavior quite unstable and heavily dependent on the execution order of analyzer rules. Two bugs I've found so far: 1. By design, we want to forbid users from using more than one generator function in SELECT. However, we can't really enforce it if two generator functions are not resolved at the same time: the rule thinks there is only one generator function (the other is still unresolved) and rewrites it. The other one gets resolved later and gets rewritten later. 2. When a generator function is put after `SELECT *`, it's possible that `*` is not expanded yet when we enter `ExtractGenerator`. The rule rewrites the generator function: it inserts a `Generate` operator below, and adds a new column to the projectList for the generator function output. Then, when we expand `*` to the output of the child plan, which is now `Generate`, we end up with two identical columns for the generator function output. This PR fixes it by adding a trigger condition when rewriting generator functions in `Project`: every expression in the projectList should be either resolved or a generator function. This is the same trigger condition we use for `Aggregate`. To avoid breaking changes, this PR also allows multiple generator functions in `Project`, which works totally fine. ### Why are the changes needed? bug fix ### Does this PR introduce _any_ user-facing change? Yes, now multiple generator functions are allowed in `Project`. 
And there won't be duplicated columns for generator function output. ### How was this patch tested? new test ### Was this patch authored or co-authored using generative AI tooling? No Closes #45350 from cloud-fan/generate. Lead-authored-by: Wenchen Fan <wenc...@databricks.com> Co-authored-by: Wenchen Fan <cloud0...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit 51f4cfa7560bba576577d3a5f254daaad516849d) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../src/main/resources/error/error-classes.json | 2 +- ...conditions-unsupported-generator-error-class.md | 2 +- .../spark/sql/catalyst/analysis/Analyzer.scala | 43 ++++++++++++++-------- .../sql/catalyst/analysis/CheckAnalysis.scala | 10 ----- .../spark/sql/errors/QueryCompilationErrors.scala | 3 +- .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 14 +------ .../org/apache/spark/sql/DataFrameSuite.scala | 14 ------- .../apache/spark/sql/GeneratorFunctionSuite.scala | 27 +++++++++++++- .../sql/errors/QueryCompilationErrorsSuite.scala | 12 ------ .../spark/sql/hive/execution/HiveQuerySuite.scala | 22 ----------- 10 files changed, 57 insertions(+), 92 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 2d50fe1a1a1a..b9d4c2c297f8 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -3056,7 +3056,7 @@ "subClass" : { "MULTI_GENERATOR" : { "message" : [ - "only one generator allowed per <clause> clause but found <num>: <generators>." + "only one generator allowed per SELECT clause but found <num>: <generators>." 
] }, "NESTED_IN_EXPRESSIONS" : { diff --git a/docs/sql-error-conditions-unsupported-generator-error-class.md b/docs/sql-error-conditions-unsupported-generator-error-class.md index 7960c14767d1..38b3bbfaa3c3 100644 --- a/docs/sql-error-conditions-unsupported-generator-error-class.md +++ b/docs/sql-error-conditions-unsupported-generator-error-class.md @@ -27,7 +27,7 @@ This error class has the following derived error classes: ## MULTI_GENERATOR -only one generator allowed per `<clause>` clause but found `<num>`: `<generators>`. +only one generator allowed per SELECT clause but found `<num>`: `<generators>`. ## NESTED_IN_EXPRESSIONS diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8fe87a05d02d..eae150001249 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2742,28 +2742,36 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } + // We must wait until all expressions except for generator functions are resolved before + // rewriting generator functions in Project/Aggregate. This is necessary to make this rule + // stable for different execution orders of analyzer rules. See also SPARK-47241. 
+ private def canRewriteGenerator(namedExprs: Seq[NamedExpression]): Boolean = { + namedExprs.forall { ne => + ne.resolved || { + trimNonTopLevelAliases(ne) match { + case AliasedGenerator(_, _, _) => true + case _ => false + } + } + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( _.containsPattern(GENERATOR), ruleId) { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) - case Project(projectList, _) if projectList.count(hasGenerator) > 1 => - val generators = projectList.filter(hasGenerator).map(trimAlias) - throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "SELECT") - case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) => val nestedGenerator = aggList.find(hasNestedGenerator).get throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 => val generators = aggList.filter(hasGenerator).map(trimAlias) - throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "aggregate") + throw QueryCompilationErrors.moreThanOneGeneratorError(generators) - case agg @ Aggregate(groupList, aggList, child) if aggList.forall { - case AliasedGenerator(_, _, _) => true - case other => other.resolved - } && aggList.exists(hasGenerator) => + case Aggregate(groupList, aggList, child) if canRewriteGenerator(aggList) && + aggList.exists(hasGenerator) => // If generator in the aggregate list was visited, set the boolean flag true. var generatorVisited = false @@ -2808,16 +2816,16 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // first for replacing `Project` with `Aggregate`. 
p - case p @ Project(projectList, child) => + case p @ Project(projectList, child) if canRewriteGenerator(projectList) && + projectList.exists(hasGenerator) => val (resolvedGenerator, newProjectList) = projectList .map(trimNonTopLevelAliases) .foldLeft((None: Option[Generate], Nil: Seq[NamedExpression])) { (res, e) => e match { - case AliasedGenerator(generator, names, outer) if generator.childrenResolved => - // It's a sanity check, this should not happen as the previous case will throw - // exception earlier. - assert(res._1.isEmpty, "More than one generator found in SELECT.") - + // If there are more than one generator, we only rewrite the first one and wait for + // the next analyzer iteration to rewrite the next one. + case AliasedGenerator(generator, names, outer) if res._1.isEmpty && + generator.childrenResolved => val g = Generate( generator, unrequiredChildIndex = Nil, @@ -2825,7 +2833,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor qualifier = None, generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), child) - (Some(g), res._2 ++ g.nullableOutput) case other => (res._1, res._2 :+ other) @@ -2845,6 +2852,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case u: UnresolvedTableValuedFunction => u + case p: Project => p + + case a: Aggregate => a + case p if p.expressions.exists(hasGenerator) => throw QueryCompilationErrors.generatorOutsideSelectError(p) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 533ea8a2b799..7f10bdbc80ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -64,12 +64,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB 
messageParameters = messageParameters) } - protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { - exprs.flatMap(_.collect { - case e: Generator => e - }).length > 1 - } - protected def hasMapType(dt: DataType): Boolean = { dt.existsRecursively(_.isInstanceOf[MapType]) } @@ -687,10 +681,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB )) } - case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => - val generators = exprs.filter(expr => expr.exists(_.isInstanceOf[Generator])) - throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "SELECT") - case p @ Project(projectList, _) => projectList.foreach(_.transformDownWithPruning( _.containsPattern(UNRESOLVED_WINDOW_EXPRESSION)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 9dca2c5f2822..a78e092c4bfa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -248,11 +248,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("expression" -> toSQLExpr(trimmedNestedGenerator))) } - def moreThanOneGeneratorError(generators: Seq[Expression], clause: String): Throwable = { + def moreThanOneGeneratorError(generators: Seq[Expression]): Throwable = { new AnalysisException( errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", messageParameters = Map( - "clause" -> clause, "num" -> generators.size.toString, "generators" -> generators.map(toSQLExpr).mkString(", "))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index e2e980073307..e8dc9061199c 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -344,11 +344,6 @@ class AnalysisErrorSuite extends AnalysisTest { "inputType" -> "\"BOOLEAN\"", "requiredType" -> "\"INT\"")) - errorTest( - "too many generators", - listRelation.select(Explode($"list").as("a"), Explode($"list").as("b")), - "only one generator" :: "explode" :: Nil) - errorClassTest( "unresolved attributes", testRelation.select($"abcd"), @@ -754,18 +749,11 @@ class AnalysisErrorSuite extends AnalysisTest { "SUM_OF_LIMIT_AND_OFFSET_EXCEEDS_MAX_INT", Map("limit" -> "1000000000", "offset" -> "2000000000")) - errorTest( - "more than one generators in SELECT", - listRelation.select(Explode($"list"), Explode($"list")), - "The generator is not supported: only one generator allowed per select clause but found 2: " + - """"explode(list)", "explode(list)"""" :: Nil - ) - errorTest( "more than one generators for aggregates in SELECT", testRelation.select(Explode(CreateArray(min($"a") :: Nil)), Explode(CreateArray(max($"a") :: Nil))), - "The generator is not supported: only one generator allowed per select clause but found 2: " + + "The generator is not supported: only one generator allowed per SELECT clause but found 2: " + """"explode(array(min(a)))", "explode(array(max(a)))"""" :: Nil ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 002719f06896..c586da6105fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -368,20 +368,6 @@ class DataFrameSuite extends QueryTest Row("a", Seq("a"), 1) :: Nil) } - test("more than one generator in SELECT clause") { - val df = Seq((Array("a"), 1)).toDF("a", "b") - - checkError( - exception = intercept[AnalysisException] { - 
df.select(explode($"a").as("a"), explode($"a").as("b")) - }, - errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", - parameters = Map( - "clause" -> "SELECT", - "num" -> "2", - "generators" -> "\"explode(a)\", \"explode(a)\"")) - } - test("sort after generate with join=true") { val df = Seq((Array("a"), 1)).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 0746a4b92af2..7c285759fcd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -432,7 +432,6 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { }, errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", parameters = Map( - "clause" -> "aggregate", "num" -> "2", "generators" -> ("\"explode(array(min(c2), max(c2)))\", " + "\"posexplode(array(min(c2), max(c2)))\""))) @@ -543,6 +542,32 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { checkAnswer(df, Row(0.7604953758285915d)) } } + + test("SPARK-47241: two generator functions in SELECT") { + def testTwoGenerators(needImplicitCast: Boolean): Unit = { + val df = sql( + s""" + |SELECT + |explode(array('a', 'b')) as c1, + |explode(array(0L, ${if (needImplicitCast) "0L + 1" else "1L"})) as c2 + |""".stripMargin) + checkAnswer(df, Seq(Row("a", 0L), Row("a", 1L), Row("b", 0L), Row("b", 1L))) + } + testTwoGenerators(needImplicitCast = true) + testTwoGenerators(needImplicitCast = false) + } + + test("SPARK-47241: generator function after wildcard in SELECT") { + val df = sql( + s""" + |SELECT *, explode(array('a', 'b')) as c1 + |FROM + |( + | SELECT id FROM range(1) GROUP BY 1 + |) + |""".stripMargin) + checkAnswer(df, Seq(Row(0, "a"), Row(0, "b"))) + } } case class EmptyGenerator() extends Generator with LeafLike[Expression] { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 7f938deaaa64..ac57c958828b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -646,18 +646,6 @@ class QueryCompilationErrorsSuite parameters = Map("expression" -> "\"(explode(array(1, 2, 3)) + 1)\"")) } - test("UNSUPPORTED_GENERATOR: only one generator allowed") { - val e = intercept[AnalysisException]( - sql("""select explode(Array(1, 2, 3)), explode(Array(1, 2, 3))""").collect() - ) - - checkError( - exception = e, - errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", - parameters = Map("clause" -> "SELECT", "num" -> "2", - "generators" -> "\"explode(array(1, 2, 3))\", \"explode(array(1, 2, 3))\"")) - } - test("UNSUPPORTED_GENERATOR: generators are not supported outside the SELECT clause") { val e = intercept[AnalysisException]( sql("""select 1 from t order by explode(Array(1, 2, 3))""").collect() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 82b88ec9f35d..4b85b37b6c2c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -161,28 +161,6 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd | SELECT key FROM gen_tmp ORDER BY key ASC; """.stripMargin) - test("multiple generators in projection") { - checkError( - exception = intercept[AnalysisException] { - sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM src").collect() - }, - errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", - parameters = Map( - "clause" -> "SELECT", - "num" -> 
"2", - "generators" -> "\"explode(array(key, key))\", \"explode(array(key, key))\"")) - - checkError( - exception = intercept[AnalysisException] { - sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) FROM src").collect() - }, - errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", - parameters = Map( - "clause" -> "SELECT", - "num" -> "2", - "generators" -> "\"explode(array(key, key))\", \"explode(array(key, key))\"")) - } - createQueryTest("! operator", """ |SELECT a FROM ( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org