From 2dca72bb66c086ef77a27f7d7ff0bb524b4b9108 Mon Sep 17 00:00:00 2001
From: Anthonin Bonnefoy <anthonin.bonnefoy@datadoghq.com>
Date: Thu, 23 May 2024 11:24:44 +0200
Subject: Fix row estimation in gather paths

In parallel plans, the row count of a partial path is estimated as
(rows / parallel_divisor). The parallel_divisor is the number of
parallel_workers plus a possible leader contribution.

When creating a gather path, we currently estimate the total number of
gathered rows as worker_rows * parallel_workers, which underestimates
the row count whenever the leader contributes.

This patch changes the gather path row estimation to
worker_rows * parallel_divisor, yielding a more accurate estimate.
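
For context, the divisor comes from get_parallel_divisor() in
costsize.c, which the new compute_gather_rows() reuses. A sketch of
that existing logic (lightly abridged):

    double  parallel_divisor = path->parallel_workers;

    if (parallel_leader_participation)
    {
        double  leader_contribution = 1.0 - (0.3 * path->parallel_workers);

        if (leader_contribution > 0)
            parallel_divisor += leader_contribution;
    }

As a worked example, assuming parallel_leader_participation is on:
with one planned worker the divisor is 1 + (1.0 - 0.3) = 1.7, so a
10000-row relation gets a partial-path estimate of about 5882 rows.
The previous gather estimate was then 5882 * 1 = 5882 rows, whereas
5882 * 1.7 rounds back to the 9999 rows seen in the added regression
test.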
---
 src/backend/optimizer/path/allpaths.c         |  7 ++--
 src/backend/optimizer/path/costsize.c         | 19 +++++++++
 src/backend/optimizer/plan/planner.c          |  6 +--
 src/include/optimizer/cost.h                  |  1 +
 src/test/regress/expected/join_hash.out       | 19 +++++----
 src/test/regress/expected/select_parallel.out | 39 +++++++++++++++++++
 src/test/regress/expected/test_setup.out      | 20 ++++++++++
 src/test/regress/sql/select_parallel.sql      | 11 ++++++
 src/test/regress/sql/test_setup.sql           | 21 ++++++++++
 9 files changed, 126 insertions(+), 17 deletions(-)

diff --git a/src/backend/optimizer/path/allpaths.c b/src/backend/optimizer/path/allpaths.c
index 4895cee994..fc72dfdeab 100644
--- a/src/backend/optimizer/path/allpaths.c
+++ b/src/backend/optimizer/path/allpaths.c
@@ -3071,8 +3071,7 @@ generate_gather_paths(PlannerInfo *root, RelOptInfo *rel, bool override_rows)
 	 * of partial_pathlist because of the way add_partial_path works.
 	 */
 	cheapest_partial_path = linitial(rel->partial_pathlist);
-	rows =
-		cheapest_partial_path->rows * cheapest_partial_path->parallel_workers;
+	rows = compute_gather_rows(cheapest_partial_path);
 	simple_gather_path = (Path *)
 		create_gather_path(root, rel, cheapest_partial_path, rel->reltarget,
 						   NULL, rowsp);
@@ -3090,7 +3089,7 @@ generate_gather_paths(PlannerInfo *root, RelOptInfo *rel, bool override_rows)
 		if (subpath->pathkeys == NIL)
 			continue;
 
-		rows = subpath->rows * subpath->parallel_workers;
+		rows = compute_gather_rows(subpath);
 		path = create_gather_merge_path(root, rel, subpath, rel->reltarget,
 										subpath->pathkeys, NULL, rowsp);
 		add_path(rel, &path->path);
@@ -3274,7 +3273,7 @@ generate_useful_gather_paths(PlannerInfo *root, RelOptInfo *rel, bool override_r
 													subpath,
 													useful_pathkeys,
 													-1.0);
-				rows = subpath->rows * subpath->parallel_workers;
+				rows = compute_gather_rows(subpath);
 			}
 			else
 				subpath = (Path *) create_incremental_sort_path(root,
diff --git a/src/backend/optimizer/path/costsize.c b/src/backend/optimizer/path/costsize.c
index ee23ed7835..c197d3f9e4 100644
--- a/src/backend/optimizer/path/costsize.c
+++ b/src/backend/optimizer/path/costsize.c
@@ -217,6 +217,25 @@ clamp_row_est(double nrows)
 	return nrows;
 }
 
+/*
+ * compute_gather_rows
+ *		Compute the total number of rows for gather nodes.
+ *
+ * When creating a gather (merge) path, we need to estimate the total number
+ * of rows gathered from all workers. A partial path's row estimate is set
+ * to (rows / parallel_divisor). Since parallel_divisor may include the
+ * leader's contribution, we can't simply multiply the partial rows by the
+ * number of parallel_workers; instead, we must reuse the parallel_divisor
+ * to get a more accurate estimate.
+ */
+double
+compute_gather_rows(Path *partial_path)
+{
+	double		parallel_divisor = get_parallel_divisor(partial_path);
+
+	return clamp_row_est(partial_path->rows * parallel_divisor);
+}
+
 /*
  * clamp_width_est
  *		Force a tuple-width estimate to a sane value.
diff --git a/src/backend/optimizer/plan/planner.c b/src/backend/optimizer/plan/planner.c
index 4711f91239..c7aea3db9f 100644
--- a/src/backend/optimizer/plan/planner.c
+++ b/src/backend/optimizer/plan/planner.c
@@ -5370,8 +5370,8 @@ create_ordered_paths(PlannerInfo *root,
 																	root->sort_pathkeys,
 																	presorted_keys,
 																	limit_tuples);
-			total_groups = input_path->rows *
-				input_path->parallel_workers;
+			total_groups = compute_gather_rows(input_path);
+
 			sorted_path = (Path *)
 				create_gather_merge_path(root, ordered_rel,
 										 sorted_path,
@@ -7543,7 +7543,7 @@ gather_grouping_paths(PlannerInfo *root, RelOptInfo *rel)
 			(presorted_keys == 0 || !enable_incremental_sort))
 			continue;
 
-		total_groups = path->rows * path->parallel_workers;
+		total_groups = compute_gather_rows(path);
 
 		/*
 		 * We've no need to consider both a sort and incremental sort. We'll
diff --git a/src/include/optimizer/cost.h b/src/include/optimizer/cost.h
index b1c51a4e70..393fc8a9e5 100644
--- a/src/include/optimizer/cost.h
+++ b/src/include/optimizer/cost.h
@@ -212,5 +212,6 @@ extern PathTarget *set_pathtarget_cost_width(PlannerInfo *root, PathTarget *targ
 extern double compute_bitmap_pages(PlannerInfo *root, RelOptInfo *baserel,
 								   Path *bitmapqual, double loop_count,
 								   Cost *cost_p, double *tuples_p);
+extern double compute_gather_rows(Path *partial_path);
 
 #endif							/* COST_H */
diff --git a/src/test/regress/expected/join_hash.out b/src/test/regress/expected/join_hash.out
index 262fa71ed8..4fc34a0e72 100644
--- a/src/test/regress/expected/join_hash.out
+++ b/src/test/regress/expected/join_hash.out
@@ -508,18 +508,17 @@ set local hash_mem_multiplier = 1.0;
 set local enable_parallel_hash = on;
 explain (costs off)
   select count(*) from simple r join extremely_skewed s using (id);
-                              QUERY PLAN                               
------------------------------------------------------------------------
- Finalize Aggregate
+                           QUERY PLAN                            
+-----------------------------------------------------------------
+ Aggregate
    ->  Gather
          Workers Planned: 1
-         ->  Partial Aggregate
-               ->  Parallel Hash Join
-                     Hash Cond: (r.id = s.id)
-                     ->  Parallel Seq Scan on simple r
-                     ->  Parallel Hash
-                           ->  Parallel Seq Scan on extremely_skewed s
-(9 rows)
+         ->  Parallel Hash Join
+               Hash Cond: (r.id = s.id)
+               ->  Parallel Seq Scan on simple r
+               ->  Parallel Hash
+                     ->  Parallel Seq Scan on extremely_skewed s
+(8 rows)
 
 select count(*) from simple r join extremely_skewed s using (id);
  count 
diff --git a/src/test/regress/expected/select_parallel.out b/src/test/regress/expected/select_parallel.out
index 5a603f86b7..f95f882704 100644
--- a/src/test/regress/expected/select_parallel.out
+++ b/src/test/regress/expected/select_parallel.out
@@ -1328,4 +1328,43 @@ SELECT 1 FROM tenk1_vw_sec
                  Filter: (f1 < tenk1_vw_sec.unique1)
 (9 rows)
 
+-- test estimated rows in gather nodes with different numbers of workers
+EXPLAIN (COSTS OFF)
+SELECT * FROM tenk1 ORDER BY twenty;
+               QUERY PLAN               
+----------------------------------------
+ Gather Merge
+   Workers Planned: 4
+   ->  Sort
+         Sort Key: twenty
+         ->  Parallel Seq Scan on tenk1
+(5 rows)
+
+SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty');
+ estimated 
+-----------
+     10000
+(1 row)
+
+set max_parallel_workers_per_gather=3;
+SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty');
+ estimated 
+-----------
+     10000
+(1 row)
+
+set max_parallel_workers_per_gather=2;
+SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty');
+ estimated 
+-----------
+     10000
+(1 row)
+
+set max_parallel_workers_per_gather=1;
+SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty');
+ estimated 
+-----------
+      9999
+(1 row)
+
 rollback;
diff --git a/src/test/regress/expected/test_setup.out b/src/test/regress/expected/test_setup.out
index 3d0eeec996..8f2d863b9c 100644
--- a/src/test/regress/expected/test_setup.out
+++ b/src/test/regress/expected/test_setup.out
@@ -239,3 +239,23 @@ create function fipshash(text)
     returns text
     strict immutable parallel safe leakproof
     return substr(encode(sha256($1::bytea), 'hex'), 1, 32);
+-- get the estimated number of rows of a query's top plan node
+create function get_estimated_rows(text) returns table (estimated int)
+language plpgsql as
+$$
+declare
+    ln text;
+    tmp text[];
+    first_row bool := true;
+begin
+    for ln in
+        execute format('explain %s', $1)
+    loop
+        if first_row then
+            first_row := false;
+            tmp := regexp_match(ln, 'rows=(\d*)');
+            return query select tmp[1]::int;
+        end if;
+    end loop;
+end;
+$$;
diff --git a/src/test/regress/sql/select_parallel.sql b/src/test/regress/sql/select_parallel.sql
index c7df8f775c..b162cab7e9 100644
--- a/src/test/regress/sql/select_parallel.sql
+++ b/src/test/regress/sql/select_parallel.sql
@@ -510,4 +510,15 @@ EXPLAIN (COSTS OFF)
 SELECT 1 FROM tenk1_vw_sec
   WHERE (SELECT sum(f1) FROM int4_tbl WHERE f1 < unique1) < 100;
 
+-- test estimated rows in gather nodes with different numbers of workers
+EXPLAIN (COSTS OFF)
+SELECT * FROM tenk1 ORDER BY twenty;
+SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty');
+set max_parallel_workers_per_gather=3;
+SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty');
+set max_parallel_workers_per_gather=2;
+SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty');
+set max_parallel_workers_per_gather=1;
+SELECT * FROM get_estimated_rows('SELECT * FROM tenk1 ORDER BY twenty');
+
 rollback;
diff --git a/src/test/regress/sql/test_setup.sql b/src/test/regress/sql/test_setup.sql
index 06b0e2121f..937d1619c8 100644
--- a/src/test/regress/sql/test_setup.sql
+++ b/src/test/regress/sql/test_setup.sql
@@ -294,3 +294,24 @@ create function fipshash(text)
     returns text
     strict immutable parallel safe leakproof
     return substr(encode(sha256($1::bytea), 'hex'), 1, 32);
+
+-- get the estimated number of rows of a query's top plan node
+create function get_estimated_rows(text) returns table (estimated int)
+language plpgsql as
+$$
+declare
+    ln text;
+    tmp text[];
+    first_row bool := true;
+begin
+    for ln in
+        execute format('explain %s', $1)
+    loop
+        if first_row then
+            first_row := false;
+            tmp := regexp_match(ln, 'rows=(\d*)');
+            return query select tmp[1]::int;
+        end if;
+    end loop;
+end;
+$$;
-- 
2.39.3 (Apple Git-146)

