From 110ecf4f77324d82a66d0120c886c75209c6688c Mon Sep 17 00:00:00 2001
From: Rahila Syed <rahilasyed.90@gmail.com>
Date: Mon, 3 Feb 2025 15:33:19 +0530
Subject: [PATCH 1/2] Preparatory changes for reporting memory context
 statistics

Ensure that MemoryContextStatsInternal can return number of
contexts. Also, provide an option in MemoryContextStatsInternal
to return without printing stats to either stderr or logs.
---
 src/backend/utils/mmgr/mcxt.c | 65 +++++++++++++++++++++++++++++------
 1 file changed, 55 insertions(+), 10 deletions(-)

diff --git a/src/backend/utils/mmgr/mcxt.c b/src/backend/utils/mmgr/mcxt.c
index aa6da0d035..946a3731fd 100644
--- a/src/backend/utils/mmgr/mcxt.c
+++ b/src/backend/utils/mmgr/mcxt.c
@@ -135,6 +135,17 @@ static const MemoryContextMethods mcxt_methods[] = {
 };
 
 #undef BOGUS_MCTX
+/*
+ * This is passed to MemoryContextStatsInternal to determine whether
+ * to print context statistics or not and where to print them logs or
+ * stderr.
+ */
+typedef enum PrintDestination
+{
+	PRINT_STATS_TO_STDERR = 0,
+	PRINT_STATS_TO_LOGS,
+	PRINT_STATS_NONE
+}			PrintDestination;
 
 /*
  * CurrentMemoryContext
@@ -162,7 +173,7 @@ static void MemoryContextCallResetCallbacks(MemoryContext context);
 static void MemoryContextStatsInternal(MemoryContext context, int level,
 									   int max_level, int max_children,
 									   MemoryContextCounters *totals,
-									   bool print_to_stderr);
+									   PrintDestination print_location, int *num_contexts);
 static void MemoryContextStatsPrint(MemoryContext context, void *passthru,
 									const char *stats_string,
 									bool print_to_stderr);
@@ -831,11 +842,19 @@ MemoryContextStatsDetail(MemoryContext context,
 						 bool print_to_stderr)
 {
 	MemoryContextCounters grand_totals;
+	int			num_contexts;
+	PrintDestination print_location;
 
 	memset(&grand_totals, 0, sizeof(grand_totals));
 
+	if (print_to_stderr)
+		print_location = PRINT_STATS_TO_STDERR;
+	else
+		print_location = PRINT_STATS_TO_LOGS;
+
+	/* num_contexts report number of contexts aggregated in the output */
 	MemoryContextStatsInternal(context, 0, max_level, max_children,
-							   &grand_totals, print_to_stderr);
+							   &grand_totals, print_location, &num_contexts);
 
 	if (print_to_stderr)
 		fprintf(stderr,
@@ -876,18 +895,43 @@ static void
 MemoryContextStatsInternal(MemoryContext context, int level,
 						   int max_level, int max_children,
 						   MemoryContextCounters *totals,
-						   bool print_to_stderr)
+						   PrintDestination print_location, int *num_contexts)
 {
 	MemoryContext child;
 	int			ichild;
+	bool		print_to_stderr = true;
 
+	check_stack_depth();
 	Assert(MemoryContextIsValid(context));
 
-	/* Examine the context itself */
-	context->methods->stats(context,
-							MemoryContextStatsPrint,
-							&level,
-							totals, print_to_stderr);
+	if (print_location == PRINT_STATS_TO_STDERR)
+		print_to_stderr = true;
+	else if (print_location == PRINT_STATS_TO_LOGS)
+		print_to_stderr = false;
+
+	if (print_location != PRINT_STATS_NONE)
+	{
+		/* Examine the context itself */
+		context->methods->stats(context,
+								MemoryContextStatsPrint,
+								&level,
+								totals, print_to_stderr);
+	}
+
+	/*
+	 * Do not print the statistics if print_to_stderr is PRINT_STATS_NONE,
+	 * only compute totals.
+	 */
+	else
+	{
+		/* Examine the context itself */
+		context->methods->stats(context,
+								NULL,
+								NULL,
+								totals, print_to_stderr);
+	}
+	/* Increment the context count */
+	*num_contexts = *num_contexts + 1;
 
 	/*
 	 * Examine children.
@@ -907,7 +951,7 @@ MemoryContextStatsInternal(MemoryContext context, int level,
 			MemoryContextStatsInternal(child, level + 1,
 									   max_level, max_children,
 									   totals,
-									   print_to_stderr);
+									   print_location, num_contexts);
 		}
 	}
 
@@ -925,6 +969,7 @@ MemoryContextStatsInternal(MemoryContext context, int level,
 			ichild++;
 			child = MemoryContextTraverseNext(child, context);
 		}
+		*num_contexts = *num_contexts + ichild;
 
 		if (print_to_stderr)
 		{
@@ -939,7 +984,7 @@ MemoryContextStatsInternal(MemoryContext context, int level,
 					local_totals.freechunks,
 					local_totals.totalspace - local_totals.freespace);
 		}
-		else
+		else if (print_location != PRINT_STATS_NONE)
 			ereport(LOG_SERVER_ONLY,
 					(errhidestmt(true),
 					 errhidecontext(true),
-- 
2.34.1

