This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 4cb2ae220cf [SPARK-38666][SQL] Add missing aggregate filter checks
4cb2ae220cf is described below

commit 4cb2ae220cf6c04221eaf70fa2af1526507a38de
Author: Bruce Robbins <bersprock...@gmail.com>
AuthorDate: Fri Apr 22 11:12:32 2022 +0800

[SPARK-38666][SQL] Add missing aggregate filter checks

### What changes were proposed in this pull request?

Add checks in `ResolveFunctions#validateFunction` to ensure the following about each aggregate filter:

- has a data type of boolean
- doesn't contain an aggregate expression
- doesn't contain a window expression

`ExtractGenerator` already handles the case of a generator in an aggregate filter.

### Why are the changes needed?

There are three cases where a query with an aggregate filter produces unhelpful error messages.

1) Window expression in aggregate filter

```
select sum(a) filter (where nth_value(a, 2) over (order by b) > 1)
from (select 1 a, '2' b);
```

The above query should produce an analysis error, but instead produces a stack overflow:

```
java.lang.StackOverflowError: null
  at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62) ~[scala-library.jar:?]
  at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53) ~[scala-library.jar:?]
  at scala.collection.immutable.VectorBuilder.$plus$plus$eq(Vector.scala:668) ~[scala-library.jar:?]
  at scala.collection.immutable.VectorBuilder.$plus$plus$eq(Vector.scala:645) ~[scala-library.jar:?]
  at scala.collection.generic.GenericCompanion.apply(GenericCompanion.scala:56) ~[scala-library.jar:?]
  at org.apache.spark.sql.catalyst.trees.UnaryLike.children(TreeNode.scala:1172) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
  at org.apache.spark.sql.catalyst.trees.UnaryLike.children$(TreeNode.scala:1172) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
  at org.apache.spark.sql.catalyst.expressions.UnaryExpression.children$lzycompute(Expression.scala:494) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
  at org.apache.spark.sql.catalyst.expressions.UnaryExpression.children(Expression.scala:494) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
  at org.apache.spark.sql.catalyst.expressions.Expression.childrenResolved(Expression.scala:223) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
  at org.apache.spark.sql.catalyst.expressions.Alias.resolved$lzycompute(namedExpressions.scala:155) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
  at org.apache.spark.sql.catalyst.expressions.Alias.resolved(namedExpressions.scala:155) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
```

With this PR, the query will instead produce

```
org.apache.spark.sql.AnalysisException: FILTER expression contains window function. It cannot be used in an aggregate function; line 1 pos 7
```

2) Non-boolean filter expression

```
select sum(a) filter (where a)
from (select 1 a, '2' b);
```

This query should produce an analysis error, but instead causes a projection compilation error or whole-stage codegen error (depending on the data type of the expression):

````
org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 50, Column 6: Not a boolean expression
  at org.codehaus.janino.UnitCompiler.compileError(UnitCompiler.java:12021) ~[janino-3.0.16.jar:?]
  at org.codehaus.janino.UnitCompiler.compileBoolean2(UnitCompiler.java:4049) ~[janino-3.0.16.jar:?]
  at org.codehaus.janino.UnitCompiler.access$6300(UnitCompiler.java:226) ~[janino-3.0.16.jar:?]
  at org.codehaus.janino.UnitCompiler$14.visitIntegerLiteral(UnitCompiler.java:4016) ~[janino-3.0.16.jar:?]
  at org.codehaus.janino.UnitCompiler$14.visitIntegerLiteral(UnitCompiler.java:3986) ~[janino-3.0.16.jar:?]
  ...
  at com.google.common.cache.LocalCache$LoadingValueReference.loadFuture(LocalCache.java:3599) ~[guava-14.0.1.jar:?]
  at com.google.common.cache.LocalCache$Segment.loadSync(LocalCache.java:2379) ~[guava-14.0.1.jar:?]
  ... 37 more
NULL
Time taken: 6.132 seconds, Fetched 1 row(s)
````

After the compilation error, _the query returns a result as if `a` were a boolean `false`_. With this PR, the query will instead produce

```
org.apache.spark.sql.AnalysisException: FILTER expression is not of type boolean. It cannot be used in an aggregate function; line 1 pos 7
```

3) Aggregate expression in filter expression

```
select max(b) filter (where max(a) > 1)
from (select 1 a, '2' b);
```

The above query should produce an analysis error, but instead causes a projection compilation error or whole-stage codegen error (depending on the data type of the expression being aggregated):

```
org.apache.spark.SparkUnsupportedOperationException: Cannot generate code for expression: max(1)
  at org.apache.spark.sql.errors.QueryExecutionErrors$.cannotGenerateCodeForExpressionError(QueryExecutionErrors.scala:84) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
  at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode(Expression.scala:347) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
  at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode$(Expression.scala:346) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
  at org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression.doGenCode(interfaces.scala:99) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
```

With this PR, the query will instead produce

```
org.apache.spark.sql.AnalysisException: FILTER expression contains aggregate. It cannot be used in an aggregate function; line 1 pos 7
```

### Does this PR introduce _any_ user-facing change?

No, except in error conditions.

### How was this patch tested?

New unit tests.

Closes #36072 from bersprockets/aggregate_in_aggregate_filter_issue.

Authored-by: Bruce Robbins <bersprock...@gmail.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
(cherry picked from commit 49d2f3c2458863eefd63c8ce38064757874ab4ad)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/catalyst/analysis/Analyzer.scala    | 12 ++++++++++--
 .../apache/spark/sql/errors/QueryCompilationErrors.scala | 15 +++++++++++++++
 .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 16 ++++++++++++++++
 3 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1bc8814b334..44cb7ac0932 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2218,8 +2218,16 @@ class Analyzer(override val catalogManager: CatalogManager)
           }
           // We get an aggregate function, we need to wrap it in an AggregateExpression.
           case agg: AggregateFunction =>
-            if (u.filter.isDefined && !u.filter.get.deterministic) {
-              throw QueryCompilationErrors.nonDeterministicFilterInAggregateError
+            u.filter match {
+              case Some(filter) if !filter.deterministic =>
+                throw QueryCompilationErrors.nonDeterministicFilterInAggregateError
+              case Some(filter) if filter.dataType != BooleanType =>
+                throw QueryCompilationErrors.nonBooleanFilterInAggregateError
+              case Some(filter) if filter.exists(_.isInstanceOf[AggregateExpression]) =>
+                throw QueryCompilationErrors.aggregateInAggregateFilterError
+              case Some(filter) if filter.exists(_.isInstanceOf[WindowExpression]) =>
+                throw QueryCompilationErrors.windowFunctionInAggregateFilterError
+              case _ =>
             }
             if (u.ignoreNulls) {
               val aggFunc = agg match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 07d07fce9ed..502ba4e1909 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -334,6 +334,21 @@ object QueryCompilationErrors extends QueryErrorsBase {
       "it cannot be used in aggregate functions")
   }
 
+  def nonBooleanFilterInAggregateError(): Throwable = {
+    new AnalysisException("FILTER expression is not of type boolean. " +
+      "It cannot be used in an aggregate function")
+  }
+
+  def aggregateInAggregateFilterError(): Throwable = {
+    new AnalysisException("FILTER expression contains aggregate. " +
+      "It cannot be used in an aggregate function")
+  }
+
+  def windowFunctionInAggregateFilterError(): Throwable = {
+    new AnalysisException("FILTER expression contains window function. " +
+      "It cannot be used in an aggregate function")
+  }
+
   def aliasNumberNotMatchColumnNumberError(
       columnSize: Int, outputSize: Int, t: TreeNode[_]): Throwable = {
     new AnalysisException("Number of column aliases does not match number of columns. " +
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index c69d51938ae..a5b8663f5e6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -545,6 +545,22 @@ class AnalysisErrorSuite extends AnalysisTest {
     "explode(array(min(a))), explode(array(max(a)))" :: Nil
   )
 
+  errorTest(
+    "SPARK-38666: non-boolean aggregate filter",
+    CatalystSqlParser.parsePlan("SELECT sum(c) filter (where e) FROM TaBlE2"),
+    "FILTER expression is not of type boolean" :: Nil)
+
+  errorTest(
+    "SPARK-38666: aggregate in aggregate filter",
+    CatalystSqlParser.parsePlan("SELECT sum(c) filter (where max(e) > 1) FROM TaBlE2"),
+    "FILTER expression contains aggregate" :: Nil)
+
+  errorTest(
+    "SPARK-38666: window function in aggregate filter",
+    CatalystSqlParser.parsePlan("SELECT sum(c) " +
+      "filter (where nth_value(e, 2) over(order by b) > 1) FROM TaBlE2"),
+    "FILTER expression contains window function" :: Nil)
+
   test("SPARK-6452 regression test") {
     // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
     // Since we manually construct the logical plan at here and Sum only accept

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org