From 3f3597952fac250eeb55aa1b09a6409c4417beab Mon Sep 17 00:00:00 2001
From: Rahila Syed <rahilasyed.90@gmail.com>
Date: Wed, 26 Mar 2025 14:38:39 +0530
Subject: [PATCH v21 1/2] Preparatory changes for reporting memory context
 statistics

Ensure that MemoryContextStatsInternal can return number of
contexts. Also, provide an option in MemoryContextStatsInternal
to return without printing stats to either stderr or logs.
---
 src/backend/utils/mmgr/mcxt.c | 77 ++++++++++++++++++++++++++++++-----
 1 file changed, 66 insertions(+), 11 deletions(-)

diff --git a/src/backend/utils/mmgr/mcxt.c b/src/backend/utils/mmgr/mcxt.c
index d98ae9db6be..2cbde8f39c3 100644
--- a/src/backend/utils/mmgr/mcxt.c
+++ b/src/backend/utils/mmgr/mcxt.c
@@ -135,6 +135,17 @@ static const MemoryContextMethods mcxt_methods[] = {
 };
 
 #undef BOGUS_MCTX
+/*
+ * This is passed to MemoryContextStatsInternal to determine whether
+ * to print context statistics or not and where to print them logs or
+ * stderr.
+ */
+typedef enum PrintDestination
+{
+	PRINT_STATS_TO_STDERR = 0,
+	PRINT_STATS_TO_LOGS,
+	PRINT_STATS_NONE
+}			PrintDestination;
 
 /*
  * CurrentMemoryContext
@@ -162,7 +173,7 @@ static void MemoryContextCallResetCallbacks(MemoryContext context);
 static void MemoryContextStatsInternal(MemoryContext context, int level,
 									   int max_level, int max_children,
 									   MemoryContextCounters *totals,
-									   bool print_to_stderr);
+									   PrintDestination print_location, int *num_contexts);
 static void MemoryContextStatsPrint(MemoryContext context, void *passthru,
 									const char *stats_string,
 									bool print_to_stderr);
@@ -831,11 +842,19 @@ MemoryContextStatsDetail(MemoryContext context,
 						 bool print_to_stderr)
 {
 	MemoryContextCounters grand_totals;
+	int			num_contexts;
+	PrintDestination print_location;
 
 	memset(&grand_totals, 0, sizeof(grand_totals));
 
+	if (print_to_stderr)
+		print_location = PRINT_STATS_TO_STDERR;
+	else
+		print_location = PRINT_STATS_TO_LOGS;
+
+	/* num_contexts report number of contexts aggregated in the output */
 	MemoryContextStatsInternal(context, 0, max_level, max_children,
-							   &grand_totals, print_to_stderr);
+							   &grand_totals, print_location, &num_contexts);
 
 	if (print_to_stderr)
 		fprintf(stderr,
@@ -870,13 +889,14 @@ MemoryContextStatsDetail(MemoryContext context,
  *		One recursion level for MemoryContextStats
  *
  * Print stats for this context if possible, but in any case accumulate counts
- * into *totals (if not NULL).
+ * into *totals (if not NULL). The callers should make sure that print_location
+ * is set to PRINT_STATS_STDERR or PRINT_STATS_TO_LOGS or PRINT_STATS_NONE.
  */
 static void
 MemoryContextStatsInternal(MemoryContext context, int level,
 						   int max_level, int max_children,
 						   MemoryContextCounters *totals,
-						   bool print_to_stderr)
+						   PrintDestination print_location, int *num_contexts)
 {
 	MemoryContext child;
 	int			ichild;
@@ -884,10 +904,39 @@ MemoryContextStatsInternal(MemoryContext context, int level,
 	Assert(MemoryContextIsValid(context));
 
 	/* Examine the context itself */
-	context->methods->stats(context,
-							MemoryContextStatsPrint,
-							&level,
-							totals, print_to_stderr);
+	switch (print_location)
+	{
+		case PRINT_STATS_TO_STDERR:
+			context->methods->stats(context,
+									MemoryContextStatsPrint,
+									&level,
+									totals, true);
+			break;
+
+		case PRINT_STATS_TO_LOGS:
+			context->methods->stats(context,
+									MemoryContextStatsPrint,
+									&level,
+									totals, false);
+			break;
+
+		case PRINT_STATS_NONE:
+
+			/*
+			 * Do not print the statistics if print_location is
+			 * PRINT_STATS_NONE, only compute totals. This is used in
+			 * reporting of memory context statistics via a sql function. Last
+			 * parameter is not relevant.
++			 */
+			context->methods->stats(context,
+									NULL,
+									NULL,
+									totals, false);
+			break;
+	}
+
+	/* Increment the context count for each of the recursive call */
+	*num_contexts = *num_contexts + 1;
 
 	/*
 	 * Examine children.
@@ -907,7 +956,7 @@ MemoryContextStatsInternal(MemoryContext context, int level,
 			MemoryContextStatsInternal(child, level + 1,
 									   max_level, max_children,
 									   totals,
-									   print_to_stderr);
+									   print_location, num_contexts);
 		}
 	}
 
@@ -926,7 +975,13 @@ MemoryContextStatsInternal(MemoryContext context, int level,
 			child = MemoryContextTraverseNext(child, context);
 		}
 
-		if (print_to_stderr)
+		/*
+		 * Add the count of children contexts which are traversed in the
+		 * non-recursive manner.
+		 */
+		*num_contexts = *num_contexts + ichild;
+
+		if (print_location == PRINT_STATS_TO_STDERR)
 		{
 			for (int i = 0; i <= level; i++)
 				fprintf(stderr, "  ");
@@ -939,7 +994,7 @@ MemoryContextStatsInternal(MemoryContext context, int level,
 					local_totals.freechunks,
 					local_totals.totalspace - local_totals.freespace);
 		}
-		else
+		else if (print_location == PRINT_STATS_TO_LOGS)
 			ereport(LOG_SERVER_ONLY,
 					(errhidestmt(true),
 					 errhidecontext(true),
-- 
2.39.3 (Apple Git-146)

