This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 4cb2ae220cf [SPARK-38666][SQL] Add missing aggregate filter checks
4cb2ae220cf is described below

commit 4cb2ae220cf6c04221eaf70fa2af1526507a38de
Author: Bruce Robbins <bersprock...@gmail.com>
AuthorDate: Fri Apr 22 11:12:32 2022 +0800

    [SPARK-38666][SQL] Add missing aggregate filter checks
    
    ### What changes were proposed in this pull request?
    
    Add checks to `ResolveFunctions#validateFunction` ensuring that each aggregate filter:
    
    - has a datatype of boolean
    - doesn't contain an aggregate expression
    - doesn't contain a window expression
    
    `ExtractGenerator` already handles the case of a generator in an aggregate filter.
    
    ### Why are the changes needed?
    
    There are three cases where a query with an aggregate filter produces unhelpful error messages.
    
    1) Window expression in aggregate filter
    
    ```
    select sum(a) filter (where nth_value(a, 2) over (order by b) > 1)
    from (select 1 a, '2' b);
    ```
    The above query should produce an analysis error, but instead produces a stack overflow:
    ```
    java.lang.StackOverflowError: null
            at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62) ~[scala-library.jar:?]
            at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53) ~[scala-library.jar:?]
            at scala.collection.immutable.VectorBuilder.$plus$plus$eq(Vector.scala:668) ~[scala-library.jar:?]
            at scala.collection.immutable.VectorBuilder.$plus$plus$eq(Vector.scala:645) ~[scala-library.jar:?]
            at scala.collection.generic.GenericCompanion.apply(GenericCompanion.scala:56) ~[scala-library.jar:?]
            at org.apache.spark.sql.catalyst.trees.UnaryLike.children(TreeNode.scala:1172) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
            at org.apache.spark.sql.catalyst.trees.UnaryLike.children$(TreeNode.scala:1172) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
            at org.apache.spark.sql.catalyst.expressions.UnaryExpression.children$lzycompute(Expression.scala:494) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
            at org.apache.spark.sql.catalyst.expressions.UnaryExpression.children(Expression.scala:494) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
            at org.apache.spark.sql.catalyst.expressions.Expression.childrenResolved(Expression.scala:223) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
            at org.apache.spark.sql.catalyst.expressions.Alias.resolved$lzycompute(namedExpressions.scala:155) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
            at org.apache.spark.sql.catalyst.expressions.Alias.resolved(namedExpressions.scala:155) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
    ```
    With this PR, the query will instead produce
    ```
    org.apache.spark.sql.AnalysisException: FILTER expression contains window function. It cannot be used in an aggregate function; line 1 pos 7
    ```
    
    2) Non-boolean filter expression
    
    ```
    select sum(a) filter (where a) from (select 1 a, '2' b);
    ```
    This query should produce an analysis error, but instead causes a projection compilation error or whole-stage codegen error (depending on the datatype of the expression):
    ````
    org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 50, Column 6: Not a boolean expression
            at org.codehaus.janino.UnitCompiler.compileError(UnitCompiler.java:12021) ~[janino-3.0.16.jar:?]
            at org.codehaus.janino.UnitCompiler.compileBoolean2(UnitCompiler.java:4049) ~[janino-3.0.16.jar:?]
            at org.codehaus.janino.UnitCompiler.access$6300(UnitCompiler.java:226) ~[janino-3.0.16.jar:?]
            at org.codehaus.janino.UnitCompiler$14.visitIntegerLiteral(UnitCompiler.java:4016) ~[janino-3.0.16.jar:?]
            at org.codehaus.janino.UnitCompiler$14.visitIntegerLiteral(UnitCompiler.java:3986) ~[janino-3.0.16.jar:?]
    ...
            at com.google.common.cache.LocalCache$LoadingValueReference.loadFuture(LocalCache.java:3599) ~[guava-14.0.1.jar:?]
            at com.google.common.cache.LocalCache$Segment.loadSync(LocalCache.java:2379) ~[guava-14.0.1.jar:?]
            ... 37 more
    NULL
    Time taken: 6.132 seconds, Fetched 1 row(s)
    ````
    After the compilation error, _the query returns a result as if `a` were a boolean `false`_.
    
    With this PR, the query will instead produce
    ```
    org.apache.spark.sql.AnalysisException: FILTER expression is not of type boolean. It cannot be used in an aggregate function; line 1 pos 7
    ```
    
    3) Aggregate expression in filter expression
    
    ```
    select max(b) filter (where max(a) > 1) from (select 1 a, '2' b);
    ```
    The above query should produce an analysis error, but instead causes a projection compilation error or whole-stage codegen error (depending on the datatype of the expression being aggregated):
    ```
    org.apache.spark.SparkUnsupportedOperationException: Cannot generate code for expression: max(1)
            at org.apache.spark.sql.errors.QueryExecutionErrors$.cannotGenerateCodeForExpressionError(QueryExecutionErrors.scala:84) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
            at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode(Expression.scala:347) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
            at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode$(Expression.scala:346) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
            at org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression.doGenCode(interfaces.scala:99) ~[spark-catalyst_2.12-3.4.0-SNAPSHOT.jar:3.4.0-SNAPSHOT]
    ```
    With this PR, the query will instead produce
    ```
    org.apache.spark.sql.AnalysisException: FILTER expression contains aggregate. It cannot be used in an aggregate function; line 1 pos 7
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, except in error conditions.
    
    ### How was this patch tested?
    
    New unit tests.
    
    Closes #36072 from bersprockets/aggregate_in_aggregate_filter_issue.
    
    Authored-by: Bruce Robbins <bersprock...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 49d2f3c2458863eefd63c8ce38064757874ab4ad)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/catalyst/analysis/Analyzer.scala    | 12 ++++++++++--
 .../apache/spark/sql/errors/QueryCompilationErrors.scala | 15 +++++++++++++++
 .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 16 ++++++++++++++++
 3 files changed, 41 insertions(+), 2 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1bc8814b334..44cb7ac0932 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2218,8 +2218,16 @@ class Analyzer(override val catalogManager: 
CatalogManager)
           }
         // We get an aggregate function, we need to wrap it in an 
AggregateExpression.
         case agg: AggregateFunction =>
-          if (u.filter.isDefined && !u.filter.get.deterministic) {
-            throw QueryCompilationErrors.nonDeterministicFilterInAggregateError
+          u.filter match {
+            case Some(filter) if !filter.deterministic =>
+              throw 
QueryCompilationErrors.nonDeterministicFilterInAggregateError
+            case Some(filter) if filter.dataType != BooleanType =>
+              throw QueryCompilationErrors.nonBooleanFilterInAggregateError
+            case Some(filter) if 
filter.exists(_.isInstanceOf[AggregateExpression]) =>
+              throw QueryCompilationErrors.aggregateInAggregateFilterError
+            case Some(filter) if 
filter.exists(_.isInstanceOf[WindowExpression]) =>
+              throw QueryCompilationErrors.windowFunctionInAggregateFilterError
+            case _ =>
           }
           if (u.ignoreNulls) {
             val aggFunc = agg match {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 07d07fce9ed..502ba4e1909 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -334,6 +334,21 @@ object QueryCompilationErrors extends QueryErrorsBase {
       "it cannot be used in aggregate functions")
   }
 
+  def nonBooleanFilterInAggregateError(): Throwable = {
+    new AnalysisException("FILTER expression is not of type boolean. " +
+      "It cannot be used in an aggregate function")
+  }
+
+  def aggregateInAggregateFilterError(): Throwable = {
+    new AnalysisException("FILTER expression contains aggregate. " +
+      "It cannot be used in an aggregate function")
+  }
+
+  def windowFunctionInAggregateFilterError(): Throwable = {
+    new AnalysisException("FILTER expression contains window function. " +
+      "It cannot be used in an aggregate function")
+  }
+
   def aliasNumberNotMatchColumnNumberError(
       columnSize: Int, outputSize: Int, t: TreeNode[_]): Throwable = {
     new AnalysisException("Number of column aliases does not match number of 
columns. " +
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index c69d51938ae..a5b8663f5e6 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -545,6 +545,22 @@ class AnalysisErrorSuite extends AnalysisTest {
       "explode(array(min(a))), explode(array(max(a)))" :: Nil
   )
 
+  errorTest(
+    "SPARK-38666: non-boolean aggregate filter",
+    CatalystSqlParser.parsePlan("SELECT sum(c) filter (where e) FROM TaBlE2"),
+    "FILTER expression is not of type boolean" :: Nil)
+
+  errorTest(
+    "SPARK-38666: aggregate in aggregate filter",
+    CatalystSqlParser.parsePlan("SELECT sum(c) filter (where max(e) > 1) FROM 
TaBlE2"),
+    "FILTER expression contains aggregate" :: Nil)
+
+  errorTest(
+    "SPARK-38666: window function in aggregate filter",
+    CatalystSqlParser.parsePlan("SELECT sum(c) " +
+       "filter (where nth_value(e, 2) over(order by b) > 1) FROM TaBlE2"),
+    "FILTER expression contains window function" :: Nil)
+
   test("SPARK-6452 regression test") {
     // CheckAnalysis should throw AnalysisException when Aggregate contains 
missing attribute(s)
     // Since we manually construct the logical plan at here and Sum only accept


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to