maropu commented on a change in pull request #30781:
URL: https://github.com/apache/spark/pull/30781#discussion_r545834754



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2097,9 +2097,14 @@ class Analyzer(override val catalogManager: 
CatalogManager)
         ResolvedFunc(Identifier.of(funcIdent.database.toArray, 
funcIdent.funcName))
 
       case q: LogicalPlan =>
+
+        val isGroupingIdAllowed = q.isInstanceOf[Aggregate] || 
q.isInstanceOf[GroupingSets] ||
+          q.isInstanceOf[Filter] || q.isInstanceOf[Sort]

Review comment:
       nit:
   ```
           val isGroupingIdAllowed = q match {
             case _: Aggregate | _: GroupingSets | _: Filter | _: Sort => true
             case _ => false
           }
   ```

##########
File path: 
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
##########
@@ -1006,4 +1006,58 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       checkAnalysis(plan, expect)
     }
   }
+
+  test("SPARK-22748: grouping_id() can only be used with 
GroupingSets/Cube/Rollup") {

Review comment:
       Could you make the test title clearer, too?

##########
File path: 
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
##########
@@ -1006,4 +1006,58 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       checkAnalysis(plan, expect)
     }
   }
+
+  test("SPARK-22748: grouping_id() can only be used with 
GroupingSets/Cube/Rollup") {
+
+    val rollUpPlan = parsePlan("""select grouping__id from (
+                                 |select grouping__id from (
+                                 |select a, b, count(1), grouping__id from 
TaBlE2
+                                 | group by a, b with rollup))
+                                 |""".stripMargin)

Review comment:
       nit: could you re-format this test like this?
   ```
       assertAnalysisSuccess(parsePlan(
         """
           |SELECT grouping__id FROM (
           |  SELECT grouping__id FROM (
           |    SELECT a, b, count(1), grouping__id FROM TaBlE2
           |      GROUP BY a, b WITH ROLLUP
           |  )
           |)
         """.stripMargin))
   ```
    - I think we don't need intermediate variables (e.g., `rollUpPlan`); pass each parsed plan directly to the assertion.
    - Please use uppercase for SQL keywords.

##########
File path: 
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
##########
@@ -1006,4 +1006,58 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       checkAnalysis(plan, expect)
     }
   }
+
+  test("SPARK-22748: grouping_id() can only be used with 
GroupingSets/Cube/Rollup") {
+
+    val rollUpPlan = parsePlan("""select grouping__id from (
+                                 |select grouping__id from (
+                                 |select a, b, count(1), grouping__id from 
TaBlE2
+                                 | group by a, b with rollup))
+                                 |""".stripMargin)
+
+    val cubePlan = parsePlan("""select grouping__id from (
+                               |select a, b, count(1), grouping__id from TaBlE2
+                               | group by a, b with cube)
+                               |""".stripMargin)
+
+    val groupingSetsPlan = parsePlan("""select grouping__id from (
+                                       |select a, b, count(1), grouping__id 
from TaBlE2
+                                       | group by a, b grouping sets ((a, b), 
()))
+                                       |""".stripMargin)
+
+    val wrongPlan = parsePlan("""select grouping__id from (
+                                |select a, b, count(1), grouping__id from 
TaBlE2
+                                | group by a, b)
+                                |""".stripMargin)
+
+    val plan1 = parsePlan("""select a, b, count(1) from TaBlE2
+                         | group by cube(a, b) having grouping__id > 0
+                            |""".stripMargin)
+
+    val plan2 = parsePlan("""select * from (select a, b, count(1) from TaBlE2
+                         | group by a, b grouping sets ((a, b), ())) where 
grouping__id > 0
+                            |""".stripMargin)
+
+    val plan3 = parsePlan("""select * from (select a, b, count(1) from TaBlE2
+                         | group by a, b grouping sets ((a, b), ())) order by 
grouping__id > 0
+                            |""".stripMargin)
+
+    val plan4 = parsePlan("""select a, b, count(1) from TaBlE2
+                         | group by a, b grouping sets ((a, b), ()) order by 
grouping__id > 0
+                            |""".stripMargin)
+
+    assertAnalysisSuccess(rollUpPlan, false)
+    assertAnalysisSuccess(cubePlan, false)
+    assertAnalysisSuccess(groupingSetsPlan, false)
+
+    assertAnalysisSuccess(plan1, false)
+    assertAnalysisSuccess(plan2, false)
+    assertAnalysisSuccess(plan3, false)
+    assertAnalysisSuccess(plan4, false)
+
+    assertAnalysisError(wrongPlan,
+      Seq("grouping_id() can only be used with GroupingSets/Cube/Rollup"),
+      false)
+  }
+

Review comment:
       nit: Remove this unnecessary blank line.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to