From 0cb474123cb206fac53004b38b83fa6b05c18426 Mon Sep 17 00:00:00 2001
From: David Christensen <david.christensen@crunchydata.com>
Date: Mon, 11 Mar 2024 10:37:05 -0400
Subject: [PATCH v2] Add GROUP BY ALL

GROUP BY ALL is a form of GROUP BY which adds any TargetExpr that does not
contain an Aggref into the groupClause of the query, effectively making it
exactly equivalent to specifying those same expressions in an explicit GROUP BY list.

Since this is exclusive with any other GROUP BY form, this is fairly simple to
add into the grammar and handle without needing to get into grouping sets or
other more complicated forms.

This greatly improves data exploration in particular, as well as making it so
you don't need to trivially wrap more complicated queries in a subquery or
reproduce long, complicated expressions in the literal GROUP BY.
---
 src/backend/parser/analyze.c             | 40 ++++++++++--
 src/backend/parser/gram.y                | 16 ++++-
 src/backend/parser/parse_clause.c        |  7 +-
 src/backend/utils/adt/ruleutils.c        |  3 +
 src/include/nodes/parsenodes.h           |  2 +
 src/include/parser/parse_clause.h        |  2 +
 src/test/regress/expected/aggregates.out | 82 ++++++++++++++++++++++++
 src/test/regress/sql/aggregates.sql      | 47 ++++++++++++++
 8 files changed, 188 insertions(+), 11 deletions(-)

diff --git a/src/backend/parser/analyze.c b/src/backend/parser/analyze.c
index b9763ea1714..4a633bb9337 100644
--- a/src/backend/parser/analyze.c
+++ b/src/backend/parser/analyze.c
@@ -32,6 +32,7 @@
 #include "nodes/makefuncs.h"
 #include "nodes/nodeFuncs.h"
 #include "nodes/queryjumble.h"
+#include "optimizer/clauses.h"
 #include "optimizer/optimizer.h"
 #include "parser/analyze.h"
 #include "parser/parse_agg.h"
@@ -1415,6 +1416,29 @@ transformSelectStmt(ParseState *pstate, SelectStmt *stmt)
 	qry->targetList = transformTargetList(pstate, stmt->targetList,
 										  EXPR_KIND_SELECT_TARGET);
 
+	/*
+	 * If groupByAll, expand targetList into groupClause. In this case, we
+	 * cannot have any other group clauses, so this is safe.
+	 */
+
+	if (stmt->groupAll)
+	{
+		ListCell *l1;
+		/*
+		 * Iterate over targets, any non-aggregate gets added as a Target.
+		 * Note that it's not enough to check for a top-level Aggref; we need
+		 * to ensure that any sub-expression here does not include an Aggref
+		 * (for instance an expression such as `sum(col) + 4` should not be
+		 * added as a grouping target.
+		 */
+		foreach (l1,qry->targetList)
+		{
+			TargetEntry *n = (TargetEntry*)lfirst(l1);
+			if (!contain_agg_clause((Node *)n->expr))
+				qry->groupClause = addTargetToGroupList(pstate, n, qry->groupClause, qry->targetList, 0);
+		}
+	}
+
 	/* mark column origins */
 	markTargetListOrigins(pstate, qry->targetList);
 
@@ -1438,14 +1462,16 @@ transformSelectStmt(ParseState *pstate, SelectStmt *stmt)
 										  EXPR_KIND_ORDER_BY,
 										  false /* allow SQL92 rules */ );
 
-	qry->groupClause = transformGroupClause(pstate,
-											stmt->groupClause,
-											&qry->groupingSets,
-											&qry->targetList,
-											qry->sortClause,
-											EXPR_KIND_GROUP_BY,
-											false /* allow SQL92 rules */ );
+	if (!stmt->groupAll)
+		qry->groupClause = transformGroupClause(pstate,
+												stmt->groupClause,
+												&qry->groupingSets,
+												&qry->targetList,
+												qry->sortClause,
+												EXPR_KIND_GROUP_BY,
+												false /* allow SQL92 rules */ );
 	qry->groupDistinct = stmt->groupDistinct;
+	qry->groupAll = stmt->groupAll;
 
 	if (stmt->distinctClause == NIL)
 	{
diff --git a/src/backend/parser/gram.y b/src/backend/parser/gram.y
index 9fd48acb1f8..3d97ceddc1e 100644
--- a/src/backend/parser/gram.y
+++ b/src/backend/parser/gram.y
@@ -120,6 +120,7 @@ typedef struct SelectLimit
 typedef struct GroupClause
 {
 	bool		distinct;
+	bool		all;
 	List	   *list;
 } GroupClause;
 
@@ -12993,6 +12994,7 @@ simple_select:
 					n->whereClause = $6;
 					n->groupClause = ($7)->list;
 					n->groupDistinct = ($7)->distinct;
+					n->groupAll = ($7)->all;
 					n->havingClause = $8;
 					n->windowClause = $9;
 					$$ = (Node *) n;
@@ -13010,6 +13012,7 @@ simple_select:
 					n->whereClause = $6;
 					n->groupClause = ($7)->list;
 					n->groupDistinct = ($7)->distinct;
+					n->groupAll = ($7)->all;
 					n->havingClause = $8;
 					n->windowClause = $9;
 					$$ = (Node *) n;
@@ -13502,12 +13505,21 @@ first_or_next: FIRST_P								{ $$ = 0; }
  * GroupingSet node of some type.
  */
 group_clause:
-			GROUP_P BY set_quantifier group_by_list
+			GROUP_P BY ALL
+				{
+					GroupClause *n = (GroupClause *) palloc(sizeof(GroupClause));
+					n->distinct = false;
+					n->list = NIL;
+					n->all = true;
+					$$ = n;
+				}
+			| GROUP_P BY set_quantifier group_by_list
 				{
 					GroupClause *n = (GroupClause *) palloc(sizeof(GroupClause));
 
 					n->distinct = $3 == SET_QUANTIFIER_DISTINCT;
 					n->list = $4;
+					n->all = false;
 					$$ = n;
 				}
 			| /*EMPTY*/
@@ -13516,6 +13528,7 @@ group_clause:
 
 					n->distinct = false;
 					n->list = NIL;
+					n->all = false;
 					$$ = n;
 				}
 		;
@@ -17618,6 +17631,7 @@ PLpgSQL_Expr: opt_distinct_clause opt_target_list
 					n->whereClause = $4;
 					n->groupClause = ($5)->list;
 					n->groupDistinct = ($5)->distinct;
+					n->groupAll = ($5)->all;
 					n->havingClause = $6;
 					n->windowClause = $7;
 					n->sortClause = $8;
diff --git a/src/backend/parser/parse_clause.c b/src/backend/parser/parse_clause.c
index 9f20a70ce13..9c26afd1f3c 100644
--- a/src/backend/parser/parse_clause.c
+++ b/src/backend/parser/parse_clause.c
@@ -90,8 +90,6 @@ static int	get_matching_location(int sortgroupref,
 								  List *sortgrouprefs, List *exprs);
 static List *resolve_unique_index_expr(ParseState *pstate, InferClause *infer,
 									   Relation heapRel);
-static List *addTargetToGroupList(ParseState *pstate, TargetEntry *tle,
-								  List *grouplist, List *targetlist, int location);
 static WindowClause *findWindowClause(List *wclist, const char *name);
 static Node *transformFrameOffset(ParseState *pstate, int frameOptions,
 								  Oid rangeopfamily, Oid rangeopcintype, Oid *inRangeFunc,
@@ -2598,6 +2596,9 @@ transformGroupingSet(List **flatresult,
  * GROUP BY items will be added to the targetlist (as resjunk columns)
  * if not already present, so the targetlist must be passed by reference.
  *
+ * If GROUP BY ALL is specified, the groupClause will be inferred to be all
+ * non-aggregate expressions in the targetList.
+ *
  * This is also used for window PARTITION BY clauses (which act almost the
  * same, but are always interpreted per SQL99 rules).
  *
@@ -3533,7 +3534,7 @@ addTargetToSortList(ParseState *pstate, TargetEntry *tle,
  *
  * Returns the updated SortGroupClause list.
  */
-static List *
+List *
 addTargetToGroupList(ParseState *pstate, TargetEntry *tle,
 					 List *grouplist, List *targetlist, int location)
 {
diff --git a/src/backend/utils/adt/ruleutils.c b/src/backend/utils/adt/ruleutils.c
index defcdaa8b34..0827b1e2625 100644
--- a/src/backend/utils/adt/ruleutils.c
+++ b/src/backend/utils/adt/ruleutils.c
@@ -6199,6 +6199,9 @@ get_basic_select_query(Query *query, deparse_context *context)
 				sep = ", ";
 			}
 		}
+		if (query->groupAll)
+			appendContextKeyword(context, " ALL ",
+								 -PRETTYINDENT_STD, PRETTYINDENT_STD, 1);
 		else
 		{
 			sep = "";
diff --git a/src/include/nodes/parsenodes.h b/src/include/nodes/parsenodes.h
index 4ed14fc5b78..ae1a954afeb 100644
--- a/src/include/nodes/parsenodes.h
+++ b/src/include/nodes/parsenodes.h
@@ -215,6 +215,7 @@ typedef struct Query
 
 	List	   *groupClause;	/* a list of SortGroupClause's */
 	bool		groupDistinct;	/* is the group by clause distinct? */
+	bool		groupAll;	/* is the group by clause distinct? */
 
 	List	   *groupingSets;	/* a list of GroupingSet's if present */
 
@@ -2192,6 +2193,7 @@ typedef struct SelectStmt
 	Node	   *whereClause;	/* WHERE qualification */
 	List	   *groupClause;	/* GROUP BY clauses */
 	bool		groupDistinct;	/* Is this GROUP BY DISTINCT? */
+	bool		groupAll;		/* Is this GROUP BY ALL? */
 	Node	   *havingClause;	/* HAVING conditional-expression */
 	List	   *windowClause;	/* WINDOW window_name AS (...), ... */
 
diff --git a/src/include/parser/parse_clause.h b/src/include/parser/parse_clause.h
index 3e9894926de..a8a8f752287 100644
--- a/src/include/parser/parse_clause.h
+++ b/src/include/parser/parse_clause.h
@@ -50,6 +50,8 @@ extern List *addTargetToSortList(ParseState *pstate, TargetEntry *tle,
 								 List *sortlist, List *targetlist, SortBy *sortby);
 extern Index assignSortGroupRef(TargetEntry *tle, List *tlist);
 extern bool targetIsInSortList(TargetEntry *tle, Oid sortop, List *sortList);
+extern List *addTargetToGroupList(ParseState *pstate, TargetEntry *tle,
+								  List *grouplist, List *targetlist, int location);
 
 /* functions in parse_jsontable.c */
 extern ParseNamespaceItem *transformJsonTable(ParseState *pstate, JsonTable *jt);
diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out
index 1f24f6ffd1f..0f04e899365 100644
--- a/src/test/regress/expected/aggregates.out
+++ b/src/test/regress/expected/aggregates.out
@@ -1557,6 +1557,88 @@ drop table t2;
 drop table t3;
 drop table p_t1;
 --
+-- Test GROUP BY ALL
+--
+CREATE TEMP TABLE t1 (
+  a int,
+  b int
+);
+COPY t1 FROM STDIN;
+-- basic field check
+SELECT b, COUNT(*) FROM t1 GROUP BY ALL;
+ b | count 
+---+-------
+ 3 |     2
+ 2 |     2
+ 1 |     1
+(3 rows)
+
+-- throw a null in the values too
+SELECT a, COUNT(a) FROM t1 GROUP BY ALL;
+ a | count 
+---+-------
+   |     0
+ 1 |     4
+(2 rows)
+
+-- multiple columns, non-consecutive order
+SELECT a, SUM(b), b FROM t1 GROUP BY ALL;
+ a | sum | b 
+---+-----+---
+ 1 |   1 | 1
+ 1 |   6 | 3
+ 1 |   2 | 2
+   |   2 | 2
+(4 rows)
+
+-- multi columns, no aggregate
+SELECT a + b FROM t1 GROUP BY ALL;
+ ?column? 
+----------
+         
+        3
+        4
+        2
+(4 rows)
+
+-- non-top-level expression
+SELECT a, SUM(b) + 4 FROM t1 GROUP BY ALL;
+ a | ?column? 
+---+----------
+   |        6
+ 1 |       13
+(2 rows)
+
+-- including grouped column
+SELECT a, SUM(b) + a FROM t1 GROUP BY ALL;
+ a | ?column? 
+---+----------
+   |         
+ 1 |       10
+(2 rows)
+
+-- oops all aggregates
+SELECT COUNT(a), SUM(b) FROM t1 GROUP BY ALL;
+ count | sum 
+-------+-----
+     4 |  11
+(1 row)
+
+-- empty column list
+SELECT FROM t1 GROUP BY ALL;
+--
+(5 rows)
+
+-- filter
+SELECT a, COUNT(a) FILTER(WHERE b = 2) FROM t1 GROUP BY ALL;
+ a | count 
+---+-------
+   |     0
+ 1 |     1
+(2 rows)
+
+DROP TABLE t1;
+--
 -- Test GROUP BY matching of join columns that are type-coerced due to USING
 --
 create temp table t1(f1 int, f2 int);
diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql
index 62540b1ffa4..1f4c282dc62 100644
--- a/src/test/regress/sql/aggregates.sql
+++ b/src/test/regress/sql/aggregates.sql
@@ -549,6 +549,53 @@ drop table t2;
 drop table t3;
 drop table p_t1;
 
+
+--
+-- Test GROUP BY ALL
+--
+
+CREATE TEMP TABLE t1 (
+  a int,
+  b int
+);
+
+COPY t1 FROM STDIN;
+1	1
+1	2
+1	3
+\N	2
+1	3
+\.
+
+-- basic field check
+SELECT b, COUNT(*) FROM t1 GROUP BY ALL;
+
+-- throw a null in the values too
+SELECT a, COUNT(a) FROM t1 GROUP BY ALL;
+
+-- multiple columns, non-consecutive order
+SELECT a, SUM(b), b FROM t1 GROUP BY ALL;
+
+-- multi columns, no aggregate
+SELECT a + b FROM t1 GROUP BY ALL;
+
+-- non-top-level expression
+SELECT a, SUM(b) + 4 FROM t1 GROUP BY ALL;
+
+-- including grouped column
+SELECT a, SUM(b) + a FROM t1 GROUP BY ALL;
+
+-- oops all aggregates
+SELECT COUNT(a), SUM(b) FROM t1 GROUP BY ALL;
+
+-- empty column list
+SELECT FROM t1 GROUP BY ALL;
+
+-- filter
+SELECT a, COUNT(a) FILTER(WHERE b = 2) FROM t1 GROUP BY ALL;
+
+DROP TABLE t1;
+
 --
 -- Test GROUP BY matching of join columns that are type-coerced due to USING
 --
-- 
2.49.0

