From 9ec5b893114f63332be30f88f9a81402e044501e Mon Sep 17 00:00:00 2001
From: Anthonin Bonnefoy <anthonin.bonnefoy@datadoghq.com>
Date: Wed, 18 Sep 2024 08:17:31 +0200
Subject: Clean psql extended state on extended command

When handling an extended command, bind_params and stmtName would be set
to NULL, leaking possibly allocated memory. On top of that, the
send_mode was not reset, leading to a possible state where we could
try to process an bind_named query with a NULL stmtName, leading to an
assert failure.

This patch reset the extended state to a correct state by freeing
allocated parameters and setting the send_mode back to default whenever
an extended extended backslash query is processed.
---
 src/bin/psql/command.c             | 12 +++----
 src/bin/psql/common.c              | 58 +++++++++++++++++++-----------
 src/bin/psql/common.h              |  1 +
 src/test/regress/expected/psql.out |  8 +++++
 src/test/regress/sql/psql.sql      |  5 +++
 5 files changed, 55 insertions(+), 29 deletions(-)

diff --git a/src/bin/psql/command.c b/src/bin/psql/command.c
index 4dfc7b2d857..16fb7973d23 100644
--- a/src/bin/psql/command.c
+++ b/src/bin/psql/command.c
@@ -483,9 +483,7 @@ exec_command_bind(PsqlScanState scan_state, bool active_branch)
 		int			nparams = 0;
 		int			nalloc = 0;
 
-		pset.bind_params = NULL;
-		pset.stmtName = NULL;
-
+		clean_extended_state();
 		while ((opt = psql_scan_slash_option(scan_state, OT_NORMAL, NULL, false)))
 		{
 			nparams++;
@@ -521,9 +519,7 @@ exec_command_bind_named(PsqlScanState scan_state, bool active_branch,
 		int			nparams = 0;
 		int			nalloc = 0;
 
-		pset.bind_params = NULL;
-		pset.stmtName = NULL;
-
+		clean_extended_state();
 		/* get the mandatory prepared statement name */
 		opt = psql_scan_slash_option(scan_state, OT_NORMAL, NULL, false);
 		if (!opt)
@@ -719,7 +715,7 @@ exec_command_close(PsqlScanState scan_state, bool active_branch, const char *cmd
 		char	   *opt = psql_scan_slash_option(scan_state,
 												 OT_NORMAL, NULL, false);
 
-		pset.stmtName = NULL;
+		clean_extended_state();
 		if (!opt)
 		{
 			pg_log_error("\\%s: missing required argument", cmd);
@@ -2205,7 +2201,7 @@ exec_command_parse(PsqlScanState scan_state, bool active_branch,
 		char	   *opt = psql_scan_slash_option(scan_state,
 												 OT_NORMAL, NULL, false);
 
-		pset.stmtName = NULL;
+		clean_extended_state();
 		if (!opt)
 		{
 			pg_log_error("\\%s: missing required argument", cmd);
diff --git a/src/bin/psql/common.c b/src/bin/psql/common.c
index 066dccbd841..d66477873ee 100644
--- a/src/bin/psql/common.c
+++ b/src/bin/psql/common.c
@@ -1275,27 +1275,7 @@ sendquery_cleanup:
 	}
 
 	/* clean up after extended protocol queries */
-	switch (pset.send_mode)
-	{
-		case PSQL_SEND_EXTENDED_CLOSE:	/* \close */
-			free(pset.stmtName);
-			break;
-		case PSQL_SEND_EXTENDED_PARSE:	/* \parse */
-			free(pset.stmtName);
-			break;
-		case PSQL_SEND_EXTENDED_QUERY_PARAMS:	/* \bind */
-		case PSQL_SEND_EXTENDED_QUERY_PREPARED: /* \bind_named */
-			for (i = 0; i < pset.bind_nparams; i++)
-				free(pset.bind_params[i]);
-			free(pset.bind_params);
-			free(pset.stmtName);
-			pset.bind_params = NULL;
-			break;
-		case PSQL_SEND_QUERY:
-			break;
-	}
-	pset.stmtName = NULL;
-	pset.send_mode = PSQL_SEND_QUERY;
+	clean_extended_state();
 
 	/* reset \gset trigger */
 	if (pset.gset_prefix)
@@ -2287,6 +2267,42 @@ uri_prefix_length(const char *connstr)
 	return 0;
 }
 
+/*
+ * Reset psql extended state
+ *
+ * Handling backslash command for extended protocol will change the
+ * send mode and allocate stmtName and bind params. This state needs
+ * to be cleaned when the query is processed or when a new extended
+ * command is processed, erasing the previous state.
+ */
+void
+clean_extended_state(void)
+{
+	int			i;
+
+	switch (pset.send_mode)
+	{
+		case PSQL_SEND_EXTENDED_CLOSE:	/* \close */
+			free(pset.stmtName);
+			break;
+		case PSQL_SEND_EXTENDED_PARSE:	/* \parse */
+			free(pset.stmtName);
+			break;
+		case PSQL_SEND_EXTENDED_QUERY_PARAMS:	/* \bind */
+		case PSQL_SEND_EXTENDED_QUERY_PREPARED: /* \bind_named */
+			for (i = 0; i < pset.bind_nparams; i++)
+				free(pset.bind_params[i]);
+			free(pset.bind_params);
+			free(pset.stmtName);
+			pset.bind_params = NULL;
+			break;
+		case PSQL_SEND_QUERY:
+			break;
+	}
+	pset.stmtName = NULL;
+	pset.send_mode = PSQL_SEND_QUERY;
+}
+
 /*
  * Recognized connection string either starts with a valid URI prefix or
  * contains a "=" in it.
diff --git a/src/bin/psql/common.h b/src/bin/psql/common.h
index 6efe12274fe..e3762a2c6c7 100644
--- a/src/bin/psql/common.h
+++ b/src/bin/psql/common.h
@@ -41,6 +41,7 @@ extern bool standard_strings(void);
 extern const char *session_username(void);
 
 extern void expand_tilde(char **filename);
+extern void clean_extended_state(void);
 
 extern bool recognized_connection_string(const char *connstr);
 
diff --git a/src/test/regress/expected/psql.out b/src/test/regress/expected/psql.out
index 6aeb7cb9636..6f585fd6a13 100644
--- a/src/test/regress/expected/psql.out
+++ b/src/test/regress/expected/psql.out
@@ -132,6 +132,14 @@ SELECT $1, $2 \parse stmt3
  foo      | bar
 (1 row)
 
+-- Check multiple calls to bind_named
+\bind_named test
+\bind_named
+\bind_named: missing required argument
+\g
+ERROR:  there is no parameter $1
+LINE 1: SELECT $1, $2 
+               ^
 -- \close (extended query protocol)
 \close
 \close: missing required argument
diff --git a/src/test/regress/sql/psql.sql b/src/test/regress/sql/psql.sql
index 0a2f8b46922..71f63b7f124 100644
--- a/src/test/regress/sql/psql.sql
+++ b/src/test/regress/sql/psql.sql
@@ -59,6 +59,11 @@ SELECT $1, $2 \parse stmt3
 \bind_named stmt2 'foo' \g
 \bind_named stmt3 'foo' 'bar' \g
 
+-- Check multiple calls to bind_named
+\bind_named test
+\bind_named
+\g
+
 -- \close (extended query protocol)
 \close
 \close ''
-- 
2.39.3 (Apple Git-146)

