From 8a3845e6754e0d323d498e596bc0d82e40de0cdb Mon Sep 17 00:00:00 2001
From: Jelte Fennema-Nio <jelte.fennema@microsoft.com>
Date: Thu, 7 Mar 2024 10:18:17 +0100
Subject: [PATCH v33 2/5] Add tests for libpq query cancellation APIs

This is in preparation of making changes and additions to these APIs.
---
 .../modules/libpq_pipeline/libpq_pipeline.c   | 138 +++++++++++++++++-
 1 file changed, 137 insertions(+), 1 deletion(-)

diff --git a/src/test/modules/libpq_pipeline/libpq_pipeline.c b/src/test/modules/libpq_pipeline/libpq_pipeline.c
index 5f43aa40de4..3517a852736 100644
--- a/src/test/modules/libpq_pipeline/libpq_pipeline.c
+++ b/src/test/modules/libpq_pipeline/libpq_pipeline.c
@@ -86,6 +86,139 @@ pg_fatal_impl(int line, const char *fmt,...)
 	exit(1);
 }
 
+/*
+ * Check that the query on the given connection got canceled.
+ *
+ * This is a function wrapped in a macro to make the reported line number
+ * in an error match the line number of the invocation.
+ */
+#define confirm_query_canceled(conn) confirm_query_canceled_impl(__LINE__, conn)
+static void
+confirm_query_canceled_impl(int line, PGconn *conn)
+{
+	PGresult   *res = NULL;
+
+	res = PQgetResult(conn);
+	if (res == NULL)
+		pg_fatal_impl(line, "PQgetResult returned null: %s",
+					  PQerrorMessage(conn));
+	if (PQresultStatus(res) != PGRES_FATAL_ERROR)
+		pg_fatal_impl(line, "query did not fail when it was expected");
+	if (strcmp(PQresultErrorField(res, PG_DIAG_SQLSTATE), "57014") != 0)
+		pg_fatal_impl(line, "query failed with a different error than cancellation: %s",
+					  PQerrorMessage(conn));
+	PQclear(res);
+	while (PQisBusy(conn))
+	{
+		PQconsumeInput(conn);
+	}
+}
+
+#define send_cancellable_query(conn, monitorConn) send_cancellable_query_impl(__LINE__, conn, monitorConn)
+static void
+send_cancellable_query_impl(int line, PGconn *conn, PGconn *monitorConn)
+{
+	const char *env_wait;
+	const Oid	paramTypes[1] = {INT4OID};
+
+	env_wait = getenv("PG_TEST_TIMEOUT_DEFAULT");
+	if (env_wait == NULL)
+		env_wait = "180";
+
+	if (PQsendQueryParams(conn, "SELECT pg_sleep($1)", 1, paramTypes, &env_wait, NULL, NULL, 0) != 1)
+		pg_fatal_impl(line, "failed to send query: %s", PQerrorMessage(conn));
+
+	/*
+	 * Wait until the query is actually running. Otherwise sending a
+	 * cancellation request might not cancel the query due to race conditions.
+	 */
+	while (true)
+	{
+		char	   *value = NULL;
+		PGresult   *res = PQexec(
+								 monitorConn,
+								 "SELECT count(*) FROM pg_stat_activity WHERE "
+								 "query = 'SELECT pg_sleep($1)' "
+								 "AND state = 'active'");
+
+		if (PQresultStatus(res) != PGRES_TUPLES_OK)
+		{
+			pg_fatal("Connection to database failed: %s", PQerrorMessage(monitorConn));
+		}
+		if (PQntuples(res) != 1)
+		{
+			pg_fatal("unexpected number of rows received: %d", PQntuples(res));
+		}
+		if (PQnfields(res) != 1)
+		{
+			pg_fatal("unexpected number of columns received: %d", PQnfields(res));
+		}
+		value = PQgetvalue(res, 0, 0);
+		if (*value != '0')
+		{
+			PQclear(res);
+			break;
+		}
+		PQclear(res);
+
+		/*
+		 * wait 10ms before polling again
+		 */
+		pg_usleep(10000);
+	}
+}
+
+static void
+test_cancel(PGconn *conn, const char *conninfo)
+{
+	PGcancel   *cancel = NULL;
+	PGconn	   *monitorConn = NULL;
+	char		errorbuf[256];
+
+	fprintf(stderr, "test cancellations... ");
+
+	if (PQsetnonblocking(conn, 1) != 0)
+		pg_fatal("failed to set nonblocking mode: %s", PQerrorMessage(conn));
+
+	/*
+	 * Make a connection to the database to monitor the query on the main
+	 * connection.
+	 */
+	monitorConn = PQconnectdb(conninfo);
+	if (PQstatus(conn) != CONNECTION_OK)
+	{
+		pg_fatal("Connection to database failed: %s",
+				 PQerrorMessage(conn));
+	}
+
+	/* test PQcancel */
+	send_cancellable_query(conn, monitorConn);
+	cancel = PQgetCancel(conn);
+	if (!PQcancel(cancel, errorbuf, sizeof(errorbuf)))
+	{
+		pg_fatal("failed to run PQcancel: %s", errorbuf);
+	};
+	confirm_query_canceled(conn);
+
+	/* PGcancel object can be reused for the next query */
+	send_cancellable_query(conn, monitorConn);
+	if (!PQcancel(cancel, errorbuf, sizeof(errorbuf)))
+	{
+		pg_fatal("failed to run PQcancel: %s", errorbuf);
+	};
+	confirm_query_canceled(conn);
+
+	PQfreeCancel(cancel);
+
+	/* test PQrequestCancel */
+	send_cancellable_query(conn, monitorConn);
+	if (!PQrequestCancel(conn))
+		pg_fatal("failed to run PQrequestCancel: %s", PQerrorMessage(conn));
+	confirm_query_canceled(conn);
+
+	fprintf(stderr, "ok\n");
+}
+
 static void
 test_disallowed_in_pipeline(PGconn *conn)
 {
@@ -1789,6 +1922,7 @@ usage(const char *progname)
 static void
 print_test_list(void)
 {
+	printf("cancel\n");
 	printf("disallowed_in_pipeline\n");
 	printf("multi_pipelines\n");
 	printf("nosync\n");
@@ -1890,7 +2024,9 @@ main(int argc, char **argv)
 						PQTRACE_SUPPRESS_TIMESTAMPS | PQTRACE_REGRESS_MODE);
 	}
 
-	if (strcmp(testname, "disallowed_in_pipeline") == 0)
+	if (strcmp(testname, "cancel") == 0)
+		test_cancel(conn, conninfo);
+	else if (strcmp(testname, "disallowed_in_pipeline") == 0)
 		test_disallowed_in_pipeline(conn);
 	else if (strcmp(testname, "multi_pipelines") == 0)
 		test_multi_pipelines(conn);
-- 
2.34.1

