From c24d5e870993bebc1d6fcb35e7d8367b9fca6ddf Mon Sep 17 00:00:00 2001
From: David Christensen <david.christensen@crunchydata.com>
Date: Mon, 11 Mar 2024 10:37:05 -0400
Subject: [PATCH v3] Add GROUP BY ALL

GROUP BY ALL is a form of GROUP BY which adds any TargetExpr that does not
contain an Aggref into the groupClause of the query, effectively making it
exactly equivalent to specifying those same expressions in an explicit GROUP BY list.

Since this is exclusive with any other GROUP BY form, this is fairly simple to
add into the grammar and handle without needing to get into grouping sets or
other more complicated forms.

This greatly improves data exploration in particular, as well as making it so
you don't need to trivially wrap more complicated queries in a subquery or
reproduce long, complicated expressions in the literal GROUP BY.
---
 src/backend/parser/analyze.c             |  40 +++++--
 src/backend/parser/gram.y                |  16 ++-
 src/backend/parser/parse_clause.c        |   7 +-
 src/backend/utils/adt/ruleutils.c        |   5 +-
 src/include/nodes/parsenodes.h           |   2 +
 src/include/parser/parse_clause.h        |   2 +
 src/test/regress/expected/aggregates.out | 133 +++++++++++++++++++++++
 src/test/regress/sql/aggregates.sql      |  66 +++++++++++
 8 files changed, 259 insertions(+), 12 deletions(-)

diff --git a/src/backend/parser/analyze.c b/src/backend/parser/analyze.c
index b9763ea1714..2262c0cad08 100644
--- a/src/backend/parser/analyze.c
+++ b/src/backend/parser/analyze.c
@@ -48,6 +48,7 @@
 #include "parser/parse_target.h"
 #include "parser/parse_type.h"
 #include "parser/parsetree.h"
+#include "rewrite/rewriteManip.h"
 #include "utils/backend_status.h"
 #include "utils/builtins.h"
 #include "utils/guc.h"
@@ -1415,6 +1416,29 @@ transformSelectStmt(ParseState *pstate, SelectStmt *stmt)
 	qry->targetList = transformTargetList(pstate, stmt->targetList,
 										  EXPR_KIND_SELECT_TARGET);
 
+	/*
+	 * If groupByAll, expand targetList into groupClause. In this case, we
+	 * cannot have any other group clauses, so this is safe.
+	 */
+
+	if (stmt->groupAll)
+	{
+		ListCell *l1;
+		/*
+		 * Iterate over targets, any non-aggregate gets added as a Target.
+		 * Note that it's not enough to check for a top-level Aggref; we need
+		 * to ensure that any sub-expression here does not include an Aggref
+		 * (for instance an expression such as `sum(col) + 4` should not be
+		 * added as a grouping target.
+		 */
+		foreach (l1,qry->targetList)
+		{
+			TargetEntry *n = (TargetEntry*)lfirst(l1);
+			if (!contain_aggs_of_level((Node *)n->expr, 0))
+				qry->groupClause = addTargetToGroupList(pstate, n, qry->groupClause, qry->targetList, 0);
+		}
+	}
+
 	/* mark column origins */
 	markTargetListOrigins(pstate, qry->targetList);
 
@@ -1438,14 +1462,16 @@ transformSelectStmt(ParseState *pstate, SelectStmt *stmt)
 										  EXPR_KIND_ORDER_BY,
 										  false /* allow SQL92 rules */ );
 
-	qry->groupClause = transformGroupClause(pstate,
-											stmt->groupClause,
-											&qry->groupingSets,
-											&qry->targetList,
-											qry->sortClause,
-											EXPR_KIND_GROUP_BY,
-											false /* allow SQL92 rules */ );
+	if (!stmt->groupAll)
+		qry->groupClause = transformGroupClause(pstate,
+												stmt->groupClause,
+												&qry->groupingSets,
+												&qry->targetList,
+												qry->sortClause,
+												EXPR_KIND_GROUP_BY,
+												false /* allow SQL92 rules */ );
 	qry->groupDistinct = stmt->groupDistinct;
+	qry->groupAll = stmt->groupAll;
 
 	if (stmt->distinctClause == NIL)
 	{
diff --git a/src/backend/parser/gram.y b/src/backend/parser/gram.y
index 9fd48acb1f8..3d97ceddc1e 100644
--- a/src/backend/parser/gram.y
+++ b/src/backend/parser/gram.y
@@ -120,6 +120,7 @@ typedef struct SelectLimit
 typedef struct GroupClause
 {
 	bool		distinct;
+	bool		all;
 	List	   *list;
 } GroupClause;
 
@@ -12993,6 +12994,7 @@ simple_select:
 					n->whereClause = $6;
 					n->groupClause = ($7)->list;
 					n->groupDistinct = ($7)->distinct;
+					n->groupAll = ($7)->all;
 					n->havingClause = $8;
 					n->windowClause = $9;
 					$$ = (Node *) n;
@@ -13010,6 +13012,7 @@ simple_select:
 					n->whereClause = $6;
 					n->groupClause = ($7)->list;
 					n->groupDistinct = ($7)->distinct;
+					n->groupAll = ($7)->all;
 					n->havingClause = $8;
 					n->windowClause = $9;
 					$$ = (Node *) n;
@@ -13502,12 +13505,21 @@ first_or_next: FIRST_P								{ $$ = 0; }
  * GroupingSet node of some type.
  */
 group_clause:
-			GROUP_P BY set_quantifier group_by_list
+			GROUP_P BY ALL
+				{
+					GroupClause *n = (GroupClause *) palloc(sizeof(GroupClause));
+					n->distinct = false;
+					n->list = NIL;
+					n->all = true;
+					$$ = n;
+				}
+			| GROUP_P BY set_quantifier group_by_list
 				{
 					GroupClause *n = (GroupClause *) palloc(sizeof(GroupClause));
 
 					n->distinct = $3 == SET_QUANTIFIER_DISTINCT;
 					n->list = $4;
+					n->all = false;
 					$$ = n;
 				}
 			| /*EMPTY*/
@@ -13516,6 +13528,7 @@ group_clause:
 
 					n->distinct = false;
 					n->list = NIL;
+					n->all = false;
 					$$ = n;
 				}
 		;
@@ -17618,6 +17631,7 @@ PLpgSQL_Expr: opt_distinct_clause opt_target_list
 					n->whereClause = $4;
 					n->groupClause = ($5)->list;
 					n->groupDistinct = ($5)->distinct;
+					n->groupAll = ($5)->all;
 					n->havingClause = $6;
 					n->windowClause = $7;
 					n->sortClause = $8;
diff --git a/src/backend/parser/parse_clause.c b/src/backend/parser/parse_clause.c
index 9f20a70ce13..9c26afd1f3c 100644
--- a/src/backend/parser/parse_clause.c
+++ b/src/backend/parser/parse_clause.c
@@ -90,8 +90,6 @@ static int	get_matching_location(int sortgroupref,
 								  List *sortgrouprefs, List *exprs);
 static List *resolve_unique_index_expr(ParseState *pstate, InferClause *infer,
 									   Relation heapRel);
-static List *addTargetToGroupList(ParseState *pstate, TargetEntry *tle,
-								  List *grouplist, List *targetlist, int location);
 static WindowClause *findWindowClause(List *wclist, const char *name);
 static Node *transformFrameOffset(ParseState *pstate, int frameOptions,
 								  Oid rangeopfamily, Oid rangeopcintype, Oid *inRangeFunc,
@@ -2598,6 +2596,9 @@ transformGroupingSet(List **flatresult,
  * GROUP BY items will be added to the targetlist (as resjunk columns)
  * if not already present, so the targetlist must be passed by reference.
  *
+ * If GROUP BY ALL is specified, the groupClause will be inferred to be all
+ * non-aggregate expressions in the targetList.
+ *
  * This is also used for window PARTITION BY clauses (which act almost the
  * same, but are always interpreted per SQL99 rules).
  *
@@ -3533,7 +3534,7 @@ addTargetToSortList(ParseState *pstate, TargetEntry *tle,
  *
  * Returns the updated SortGroupClause list.
  */
-static List *
+List *
 addTargetToGroupList(ParseState *pstate, TargetEntry *tle,
 					 List *grouplist, List *targetlist, int location)
 {
diff --git a/src/backend/utils/adt/ruleutils.c b/src/backend/utils/adt/ruleutils.c
index defcdaa8b34..aa97cc11c78 100644
--- a/src/backend/utils/adt/ruleutils.c
+++ b/src/backend/utils/adt/ruleutils.c
@@ -6186,7 +6186,10 @@ get_basic_select_query(Query *query, deparse_context *context)
 		save_ingroupby = context->inGroupBy;
 		context->inGroupBy = true;
 
-		if (query->groupingSets == NIL)
+		if (query->groupAll)
+			appendContextKeyword(context, " ALL ",
+								 -PRETTYINDENT_STD, PRETTYINDENT_STD, 1);
+		else if (query->groupingSets == NIL)
 		{
 			sep = "";
 			foreach(l, query->groupClause)
diff --git a/src/include/nodes/parsenodes.h b/src/include/nodes/parsenodes.h
index f1706df58fd..b0c1292d16a 100644
--- a/src/include/nodes/parsenodes.h
+++ b/src/include/nodes/parsenodes.h
@@ -215,6 +215,7 @@ typedef struct Query
 
 	List	   *groupClause;	/* a list of SortGroupClause's */
 	bool		groupDistinct;	/* is the group by clause distinct? */
+	bool		groupAll;	/* is the group by clause distinct? */
 
 	List	   *groupingSets;	/* a list of GroupingSet's if present */
 
@@ -2192,6 +2193,7 @@ typedef struct SelectStmt
 	Node	   *whereClause;	/* WHERE qualification */
 	List	   *groupClause;	/* GROUP BY clauses */
 	bool		groupDistinct;	/* Is this GROUP BY DISTINCT? */
+	bool		groupAll;		/* Is this GROUP BY ALL? */
 	Node	   *havingClause;	/* HAVING conditional-expression */
 	List	   *windowClause;	/* WINDOW window_name AS (...), ... */
 
diff --git a/src/include/parser/parse_clause.h b/src/include/parser/parse_clause.h
index 3e9894926de..a8a8f752287 100644
--- a/src/include/parser/parse_clause.h
+++ b/src/include/parser/parse_clause.h
@@ -50,6 +50,8 @@ extern List *addTargetToSortList(ParseState *pstate, TargetEntry *tle,
 								 List *sortlist, List *targetlist, SortBy *sortby);
 extern Index assignSortGroupRef(TargetEntry *tle, List *tlist);
 extern bool targetIsInSortList(TargetEntry *tle, Oid sortop, List *sortList);
+extern List *addTargetToGroupList(ParseState *pstate, TargetEntry *tle,
+								  List *grouplist, List *targetlist, int location);
 
 /* functions in parse_jsontable.c */
 extern ParseNamespaceItem *transformJsonTable(ParseState *pstate, JsonTable *jt);
diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out
index 1f24f6ffd1f..74bdfc24bdc 100644
--- a/src/test/regress/expected/aggregates.out
+++ b/src/test/regress/expected/aggregates.out
@@ -1557,6 +1557,139 @@ drop table t2;
 drop table t3;
 drop table p_t1;
 --
+-- Test GROUP BY ALL
+--
+CREATE TEMP TABLE t1 (
+  a int,
+  b int,
+  c int
+);
+COPY t1 FROM STDIN;
+-- basic field check
+SELECT b, COUNT(*) FROM t1 GROUP BY ALL;
+ b | count 
+---+-------
+ 3 |     2
+ 2 |     2
+ 1 |     1
+(3 rows)
+
+-- throw a null in the values too
+SELECT a, COUNT(a) FROM t1 GROUP BY ALL;
+ a | count 
+---+-------
+   |     0
+ 1 |     4
+(2 rows)
+
+-- multiple columns, non-consecutive order
+SELECT a, SUM(b), b FROM t1 GROUP BY ALL;
+ a | sum | b 
+---+-----+---
+ 1 |   1 | 1
+ 1 |   6 | 3
+ 1 |   2 | 2
+   |   2 | 2
+(4 rows)
+
+-- multi columns, no aggregate
+SELECT a + b FROM t1 GROUP BY ALL;
+ ?column? 
+----------
+         
+        3
+        4
+        2
+(4 rows)
+
+-- non-top-level expression
+SELECT a, SUM(b) + 4 FROM t1 GROUP BY ALL;
+ a | ?column? 
+---+----------
+   |        6
+ 1 |       13
+(2 rows)
+
+-- including grouped column
+SELECT a, SUM(b) + a FROM t1 GROUP BY ALL;
+ a | ?column? 
+---+----------
+   |         
+ 1 |       10
+(2 rows)
+
+-- oops all aggregates
+SELECT COUNT(a), SUM(b) FROM t1 GROUP BY ALL;
+ count | sum 
+-------+-----
+     4 |  11
+(1 row)
+
+-- empty column list
+SELECT FROM t1 GROUP BY ALL;
+--
+(5 rows)
+
+-- filter
+SELECT a, COUNT(a) FILTER(WHERE b = 2) FROM t1 GROUP BY ALL;
+ a | count 
+---+-------
+   |     0
+ 1 |     1
+(2 rows)
+
+-- all cols
+SELECT * FROM t1 GROUP BY ALL;
+ a | b | c 
+---+---+---
+   | 2 | 2
+ 1 | 3 | 1
+ 1 | 2 | 2
+ 1 | 1 | 1
+(4 rows)
+
+SELECT *, count(*) FROM t1 GROUP BY ALL;
+ a | b | c | count 
+---+---+---+-------
+   | 2 | 2 |     1
+ 1 | 3 | 1 |     2
+ 1 | 2 | 2 |     1
+ 1 | 1 | 1 |     1
+(4 rows)
+
+-- expression without including aggregate columns
+SELECT a, SUM(b) + c FROM t1 GROUP BY ALL;
+ERROR:  column "t1.c" must appear in the GROUP BY clause or be used in an aggregate function
+LINE 1: SELECT a, SUM(b) + c FROM t1 GROUP BY ALL;
+                           ^
+-- subquery
+SELECT (SELECT a FROM t1 LIMIT 1), SUM(b) FROM t1 GROUP BY ALL;
+ a | sum 
+---+-----
+ 1 |  11
+(1 row)
+
+-- cte
+WITH a_query AS (SELECT a, SUM(b) AS sum_b FROM t1 GROUP BY ALL)
+SELECT AVG(a), sum_b  FROM a_query GROUP BY ALL;
+          avg           | sum_b 
+------------------------+-------
+ 1.00000000000000000000 |     9
+                        |     2
+(2 rows)
+
+CREATE VIEW v1 AS SELECT b, COUNT(*) FROM t1 GROUP BY ALL;
+NOTICE:  view "v1" will be a temporary view
+\sv v1
+CREATE OR REPLACE VIEW pg_temp_93.v1 AS
+ SELECT b,
+    count(*) AS count
+   FROM t1
+  GROUP BY
+  ALL 
+DROP VIEW v1;
+DROP TABLE t1;
+--
 -- Test GROUP BY matching of join columns that are type-coerced due to USING
 --
 create temp table t1(f1 int, f2 int);
diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql
index 62540b1ffa4..11d6161cdf4 100644
--- a/src/test/regress/sql/aggregates.sql
+++ b/src/test/regress/sql/aggregates.sql
@@ -549,6 +549,72 @@ drop table t2;
 drop table t3;
 drop table p_t1;
 
+
+--
+-- Test GROUP BY ALL
+--
+
+CREATE TEMP TABLE t1 (
+  a int,
+  b int,
+  c int
+);
+
+COPY t1 FROM STDIN;
+1	1	1
+1	2	2
+1	3	1
+\N	2	2
+1	3	1
+\.
+
+-- basic field check
+SELECT b, COUNT(*) FROM t1 GROUP BY ALL;
+
+-- throw a null in the values too
+SELECT a, COUNT(a) FROM t1 GROUP BY ALL;
+
+-- multiple columns, non-consecutive order
+SELECT a, SUM(b), b FROM t1 GROUP BY ALL;
+
+-- multi columns, no aggregate
+SELECT a + b FROM t1 GROUP BY ALL;
+
+-- non-top-level expression
+SELECT a, SUM(b) + 4 FROM t1 GROUP BY ALL;
+
+-- including grouped column
+SELECT a, SUM(b) + a FROM t1 GROUP BY ALL;
+
+-- oops all aggregates
+SELECT COUNT(a), SUM(b) FROM t1 GROUP BY ALL;
+
+-- empty column list
+SELECT FROM t1 GROUP BY ALL;
+
+-- filter
+SELECT a, COUNT(a) FILTER(WHERE b = 2) FROM t1 GROUP BY ALL;
+
+-- all cols
+SELECT * FROM t1 GROUP BY ALL;
+SELECT *, count(*) FROM t1 GROUP BY ALL;
+
+-- expression without including aggregate columns
+SELECT a, SUM(b) + c FROM t1 GROUP BY ALL;
+
+-- subquery
+SELECT (SELECT a FROM t1 LIMIT 1), SUM(b) FROM t1 GROUP BY ALL;
+
+-- cte
+WITH a_query AS (SELECT a, SUM(b) AS sum_b FROM t1 GROUP BY ALL)
+SELECT AVG(a), sum_b  FROM a_query GROUP BY ALL;
+
+CREATE VIEW v1 AS SELECT b, COUNT(*) FROM t1 GROUP BY ALL;
+\sv v1
+
+DROP VIEW v1;
+DROP TABLE t1;
+
 --
 -- Test GROUP BY matching of join columns that are type-coerced due to USING
 --
-- 
2.49.0

