From d7c94eb6eb008ee962dc23ebe24f8ae7aa84895b Mon Sep 17 00:00:00 2001
From: TsinghuaLucky912 <2903807914@qq.com>
Date: Thu, 28 Nov 2024 08:51:39 -0800
Subject: [PATCH] Added prosupport function for estimating numeric
 generate_series rows

---
 src/backend/utils/adt/numeric.c              | 118 +++++++++++++++++++
 src/include/catalog/pg_proc.dat              |   9 +-
 src/test/regress/expected/misc_functions.out |  60 ++++++++++
 src/test/regress/sql/misc_functions.sql      |  37 ++++++
 4 files changed, 222 insertions(+), 2 deletions(-)

diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
index 344d7137f9..5cec3d6bd1 100644
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -32,6 +32,7 @@
 #include "lib/hyperloglog.h"
 #include "libpq/pqformat.h"
 #include "miscadmin.h"
 #include "nodes/nodeFuncs.h"
 #include "nodes/supportnodes.h"
+#include "optimizer/optimizer.h"
 #include "utils/array.h"
@@ -1828,6 +1829,123 @@ generate_series_step_numeric(PG_FUNCTION_ARGS)
 }
 
 
+/*
+ * Planner support function for generate_series(numeric, numeric [, numeric])
+ *
+ * Handles SupportRequestRows only.  With finite constant arguments the
+ * estimate is floor((finish - start) / step) + 1; a NULL constant gives 0.
+ */
+Datum
+generate_series_numeric_support(PG_FUNCTION_ARGS)
+{
+	Node	   *rawreq = (Node *) PG_GETARG_POINTER(0);
+	Node	   *ret = NULL;
+
+	if (IsA(rawreq, SupportRequestRows))
+	{
+		/* Try to estimate the number of rows returned */
+		SupportRequestRows *req = (SupportRequestRows *) rawreq;
+
+		if (is_funcclause(req->node))	/* be paranoid */
+		{
+			List	   *args = ((FuncExpr *) req->node)->args;
+			Node	   *arg1,
+					   *arg2,
+					   *arg3;
+
+			/* We can use estimated argument values here */
+			arg1 = estimate_expression_value(req->root, linitial(args));
+			arg2 = estimate_expression_value(req->root, lsecond(args));
+			if (list_length(args) >= 3)
+				arg3 = estimate_expression_value(req->root, lthird(args));
+			else
+				arg3 = NULL;
+
+			/*
+			 * If any argument is constant NULL, we can safely assume that
+			 * zero rows are returned.  Otherwise, if they're all non-NULL
+			 * constants, we can calculate the number of rows that will be
+			 * returned.
+			 */
+			if ((IsA(arg1, Const) &&
+				 ((Const *) arg1)->constisnull) ||
+				(IsA(arg2, Const) &&
+				 ((Const *) arg2)->constisnull) ||
+				(arg3 != NULL && IsA(arg3, Const) &&
+				 ((Const *) arg3)->constisnull))
+			{
+				req->rows = 0;
+				ret = (Node *) req;
+			}
+			else if (IsA(arg1, Const) &&
+					 IsA(arg2, Const) &&
+					 (arg3 == NULL || IsA(arg3, Const)))
+			{
+				Numeric		start_num;
+				Numeric		finish_num;
+				NumericVar	step = const_one;	/* step of 2-argument form */
+
+				start_num = DatumGetNumeric(((Const *) arg1)->constvalue);
+				finish_num = DatumGetNumeric(((Const *) arg2)->constvalue);
+
+				/*
+				 * NaN or infinite arguments give no usable estimate, so
+				 * return NULL and let the planner fall back to prorows.
+				 */
+				if (NUMERIC_IS_SPECIAL(start_num) ||
+					NUMERIC_IS_SPECIAL(finish_num))
+					PG_RETURN_POINTER(NULL);
+
+				if (arg3)
+				{
+					Numeric		step_num;
+
+					step_num = DatumGetNumeric(((Const *) arg3)->constvalue);
+					if (NUMERIC_IS_SPECIAL(step_num))
+						PG_RETURN_POINTER(NULL);
+
+					init_var_from_num(step_num, &step);
+				}
+
+				/*
+				 * A zero step is left for generate_series() itself to
+				 * reject at execution time; produce no estimate for it.
+				 */
+				if (cmp_var(&step, &const_zero) != 0)
+				{
+					NumericVar	start;
+					NumericVar	finish;
+					NumericVar	res;
+
+					init_var_from_num(start_num, &start);
+					init_var_from_num(finish_num, &finish);
+					init_var(&res);
+					sub_var(&finish, &start, &res);
+
+					/*
+					 * No rows result when step points away from finish.  A
+					 * zero difference has no meaningful sign, so check for it
+					 * first: start = finish yields one row for any step.
+					 */
+					if (cmp_var(&res, &const_zero) != 0 &&
+						step.sign != res.sign)
+						req->rows = 0;
+					else
+					{
+						/* quotient is >= 0, so truncation is floor() */
+						div_var(&res, &step, &res, 0, false, false);
+						req->rows = numericvar_to_double_no_overflow(&res) + 1;
+					}
+					ret = (Node *) req;
+					free_var(&res);
+				}
+			}
+		}
+	}
+
+	PG_RETURN_POINTER(ret);
+}
+
 /*
  * Implements the numeric version of the width_bucket() function
  * defined by SQL2003. See also width_bucket_float8().
diff --git a/src/include/catalog/pg_proc.dat b/src/include/catalog/pg_proc.dat
index cbbe8acd38..9575524007 100644
--- a/src/include/catalog/pg_proc.dat
+++ b/src/include/catalog/pg_proc.dat
@@ -8464,13 +8464,18 @@
   proname => 'generate_series_int8_support', prorettype => 'internal',
   proargtypes => 'internal', prosrc => 'generate_series_int8_support' },
 { oid => '3259', descr => 'non-persistent series generator',
-  proname => 'generate_series', prorows => '1000', proretset => 't',
+  proname => 'generate_series', prorows => '1000',
+  prosupport => 'generate_series_numeric_support', proretset => 't',
   prorettype => 'numeric', proargtypes => 'numeric numeric numeric',
   prosrc => 'generate_series_step_numeric' },
 { oid => '3260', descr => 'non-persistent series generator',
-  proname => 'generate_series', prorows => '1000', proretset => 't',
+  proname => 'generate_series', prorows => '1000',
+  prosupport => 'generate_series_numeric_support', proretset => 't',
   prorettype => 'numeric', proargtypes => 'numeric numeric',
   prosrc => 'generate_series_numeric' },
+{ oid => '8405', descr => 'planner support for generate_series',
+  proname => 'generate_series_numeric_support', prorettype => 'internal',
+  proargtypes => 'internal', prosrc => 'generate_series_numeric_support' },
 { oid => '938', descr => 'non-persistent series generator',
   proname => 'generate_series', prorows => '1000',
   prosupport => 'generate_series_timestamp_support', proretset => 't',
diff --git a/src/test/regress/expected/misc_functions.out b/src/test/regress/expected/misc_functions.out
index 36b1201f9f..2fb8c771db 100644
--- a/src/test/regress/expected/misc_functions.out
+++ b/src/test/regress/expected/misc_functions.out
@@ -712,6 +712,66 @@ false, true, false, true);
 -- the support function.
 SELECT * FROM generate_series(TIMESTAMPTZ '2024-02-01', TIMESTAMPTZ '2024-03-01', INTERVAL '0 day') g(s);
 ERROR:  step size cannot equal zero
+--
+-- Test the SupportRequestRows support function for generate_series_numeric()
+--
+-- Ensure the row estimate matches the actual rows
+SELECT explain_mask_costs($$
+SELECT * FROM generate_series(1.0, 25.0, 2.0) g(s);$$,
+true, true, false, true);
+                                    explain_mask_costs                                    
+------------------------------------------------------------------------------------------
+ Function Scan on generate_series g  (cost=N..N rows=13 width=N) (actual rows=13 loops=1)
+(1 row)
+
+SELECT explain_mask_costs($$
+SELECT * FROM generate_series(1.0, 25.0) g(s);$$,
+true, true, false, true);
+                                    explain_mask_costs                                    
+------------------------------------------------------------------------------------------
+ Function Scan on generate_series g  (cost=N..N rows=25 width=N) (actual rows=25 loops=1)
+(1 row)
+
+-- Ensure the estimates match when step is decreasing
+SELECT explain_mask_costs($$
+SELECT * FROM generate_series(25.0, 1.0, -1.0) g(s);$$,
+true, true, false, true);
+                                    explain_mask_costs                                    
+------------------------------------------------------------------------------------------
+ Function Scan on generate_series g  (cost=N..N rows=25 width=N) (actual rows=25 loops=1)
+(1 row)
+
+-- Ensure an empty range estimates 1 row
+SELECT explain_mask_costs($$
+SELECT * FROM generate_series(25.0, 1.0, 1.0) g(s);$$,
+true, true, false, true);
+                                   explain_mask_costs                                   
+----------------------------------------------------------------------------------------
+ Function Scan on generate_series g  (cost=N..N rows=1 width=N) (actual rows=0 loops=1)
+(1 row)
+
+-- Ensure we get the default row estimate for infinity values
+SELECT explain_mask_costs($$
+SELECT * FROM generate_series('-infinity'::NUMERIC, 'infinity'::NUMERIC, 1.0) g(s);$$,
+false, true, false, true);
+                        explain_mask_costs                         
+-------------------------------------------------------------------
+ Function Scan on generate_series g  (cost=N..N rows=1000 width=N)
+(1 row)
+
+SELECT explain_mask_costs($$
+SELECT * FROM generate_series(1.0, 25.0, '-infinity'::NUMERIC) g(s);$$,
+false, true, false, true);
+                        explain_mask_costs                         
+-------------------------------------------------------------------
+ Function Scan on generate_series g  (cost=N..N rows=1000 width=N)
+(1 row)
+
+-- Ensure the row estimate behaves correctly when step size is zero.
+-- We expect generate_series_numeric() to throw the error rather than in
+-- the support function.
+SELECT * FROM generate_series(25.0, 2.0, 0.0) g(s);
+ERROR:  step size cannot equal zero
 -- Test functions for control data
 SELECT count(*) > 0 AS ok FROM pg_control_checkpoint();
  ok 
diff --git a/src/test/regress/sql/misc_functions.sql b/src/test/regress/sql/misc_functions.sql
index b7495d70eb..3e8171cc09 100644
--- a/src/test/regress/sql/misc_functions.sql
+++ b/src/test/regress/sql/misc_functions.sql
@@ -311,6 +311,43 @@ false, true, false, true);
 -- the support function.
 SELECT * FROM generate_series(TIMESTAMPTZ '2024-02-01', TIMESTAMPTZ '2024-03-01', INTERVAL '0 day') g(s);
 
+--
+-- Test the SupportRequestRows support function for generate_series_numeric()
+--
+
+-- Ensure the row estimate matches the actual rows
+SELECT explain_mask_costs($$
+SELECT * FROM generate_series(1.0, 25.0, 2.0) g(s);$$,
+true, true, false, true);
+
+SELECT explain_mask_costs($$
+SELECT * FROM generate_series(1.0, 25.0) g(s);$$,
+true, true, false, true);
+
+-- Ensure the estimates match when step is decreasing
+SELECT explain_mask_costs($$
+SELECT * FROM generate_series(25.0, 1.0, -1.0) g(s);$$,
+true, true, false, true);
+
+-- Ensure an empty range estimates 1 row
+SELECT explain_mask_costs($$
+SELECT * FROM generate_series(25.0, 1.0, 1.0) g(s);$$,
+true, true, false, true);
+
+-- Ensure we get the default row estimate for infinity values
+SELECT explain_mask_costs($$
+SELECT * FROM generate_series('-infinity'::NUMERIC, 'infinity'::NUMERIC, 1.0) g(s);$$,
+false, true, false, true);
+
+SELECT explain_mask_costs($$
+SELECT * FROM generate_series(1.0, 25.0, '-infinity'::NUMERIC) g(s);$$,
+false, true, false, true);
+
+-- Ensure the row estimate behaves correctly when step size is zero.
+-- We expect generate_series_numeric() to throw the error rather than in
+-- the support function.
+SELECT * FROM generate_series(25.0, 2.0, 0.0) g(s);
+
 -- Test functions for control data
 SELECT count(*) > 0 AS ok FROM pg_control_checkpoint();
 SELECT count(*) > 0 AS ok FROM pg_control_init();
-- 
2.31.1

