From 92f8fb427fe8b9dc6c5fcfffe51421e6814b9750 Mon Sep 17 00:00:00 2001
From: Jelte Fennema <jelte.fennema@microsoft.com>
Date: Thu, 30 Sep 2021 16:29:53 +0200
Subject: [PATCH] Use tcp_user_timeout and keepalives in PQcancel

PQcancel would not adhere to the timeout specified in the
tcp_user_timeout connection option. It would also not enable keepalives
on the socket, so no keepalives would be used on these connections at
all, not even the system-configured ones.

This means a call to PQcancel could take much longer than expected. This
can especially be an issue because there is no non-blocking version of
PQcancel.
---
 src/interfaces/libpq/fe-connect.c | 177 +++++++++++++++++++++++++++---
 src/interfaces/libpq/libpq-int.h  |   8 ++
 2 files changed, 171 insertions(+), 14 deletions(-)

diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 56755f0796..5f0f0a5945 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -4363,13 +4363,57 @@ PQgetCancel(PGconn *conn)
 	if (conn->sock == PGINVALID_SOCKET)
 		return NULL;
 
-	cancel = malloc(sizeof(PGcancel));
+	cancel = calloc(1, sizeof(PGcancel));
 	if (cancel == NULL)
 		return NULL;
 
 	memcpy(&cancel->raddr, &conn->raddr, sizeof(SockAddr));
 	cancel->be_pid = conn->be_pid;
 	cancel->be_key = conn->be_key;
+	// Set all the socket options to -1, so we can use that to see if the
+	// socket options are set or not.
+	cancel->connect_timeout = -1;
+	cancel->pgtcp_user_timeout = -1;
+	cancel->keepalives = -1;
+	cancel->keepalives_idle = -1;
+	cancel->keepalives_interval = -1;
+	cancel->keepalives_count = -1;
+	if (conn->connect_timeout != NULL) {
+		if (!parse_int_param(conn->connect_timeout, &cancel->connect_timeout, conn,
+							 "connect_timeout")) {
+			return NULL;
+		}
+	}
+	if (conn->pgtcp_user_timeout != NULL) {
+		if (!parse_int_param(conn->pgtcp_user_timeout, &cancel->pgtcp_user_timeout, conn,
+							 "tcp_user_timeout")) {
+			return NULL;
+		}
+	}
+	if (conn->keepalives != NULL) {
+		if (!parse_int_param(conn->keepalives, &cancel->keepalives, conn,
+							 "keepalives")) {
+			return NULL;
+		}
+	}
+	if (conn->keepalives_idle != NULL) {
+		if (!parse_int_param(conn->keepalives_idle, &cancel->keepalives_idle, conn,
+							 "keepalives_idle")) {
+			return NULL;
+		}
+	}
+	if (conn->keepalives_interval != NULL) {
+		if (!parse_int_param(conn->keepalives_interval, &cancel->keepalives_interval, conn,
+							 "keepalives_interval")) {
+			return NULL;
+		}
+	}
+	if (conn->keepalives_count != NULL) {
+		if (!parse_int_param(conn->keepalives_count, &cancel->keepalives_count, conn,
+							 "keepalives_count")) {
+			return NULL;
+		}
+	}
 
 	return cancel;
 }
@@ -4404,8 +4448,7 @@ PQfreeCancel(PGcancel *cancel)
  * between the two versions of the cancel function possible.
  */
 static int
-internal_cancel(SockAddr *raddr, int be_pid, int be_key,
-				char *errbuf, int errbufsize)
+internal_cancel(PGcancel *cancel, char *errbuf, int errbufsize)
 {
 	int			save_errno = SOCK_ERRNO;
 	pgsocket	tmpsock = PGINVALID_SOCKET;
@@ -4421,14 +4464,110 @@ internal_cancel(SockAddr *raddr, int be_pid, int be_key,
 	 * We need to open a temporary connection to the postmaster. Do this with
 	 * only kernel calls.
 	 */
-	if ((tmpsock = socket(raddr->addr.ss_family, SOCK_STREAM, 0)) == PGINVALID_SOCKET)
+	if ((tmpsock = socket(cancel->raddr.addr.ss_family, SOCK_STREAM, 0)) == PGINVALID_SOCKET)
 	{
 		strlcpy(errbuf, "PQcancel() -- socket() failed: ", errbufsize);
 		goto cancel_errReturn;
 	}
+
+	if (!IS_AF_UNIX(cancel->raddr.addr.ss_family)) {
+#ifndef WIN32
+#ifdef TCP_USER_TIMEOUT
+		if (cancel->pgtcp_user_timeout >= 0) {
+			if (setsockopt(tmpsock, IPPROTO_TCP, TCP_USER_TIMEOUT,
+						   (char *) &cancel->pgtcp_user_timeout,
+						   sizeof(cancel->pgtcp_user_timeout)) < 0) {
+				goto cancel_errReturn;
+			}
+		}
+#endif
+
+		if (cancel->keepalives != 0) {
+			int on = 1;
+			if (setsockopt(tmpsock,
+							SOL_SOCKET, SO_KEEPALIVE,
+							(char *) &on, sizeof(on)) < 0)
+			{
+				goto cancel_errReturn;
+			}
+		}
+
+#ifdef PG_TCP_KEEPALIVE_IDLE
+		if (cancel->keepalives_idle >= 0) {
+			if (setsockopt(tmpsock, IPPROTO_TCP, PG_TCP_KEEPALIVE_IDLE,
+						   (char *) &cancel->keepalives_idle,
+						   sizeof(cancel->keepalives_idle)) < 0)
+			{
+				goto cancel_errReturn;
+			}
+		}
+#endif
+
+#ifdef TCP_KEEPINTVL
+		if (cancel->keepalives_interval >= 0) {
+			if (setsockopt(tmpsock, IPPROTO_TCP, TCP_KEEPINTVL,
+						   (char *) &cancel->keepalives_interval,
+						   sizeof(cancel->keepalives_interval)) < 0)
+			{
+				goto cancel_errReturn;
+			}
+		}
+#endif
+
+#ifdef TCP_KEEPCNT
+		if (cancel->keepalives_count >= 0) {
+			if (setsockopt(tmpsock, IPPROTO_TCP, TCP_KEEPCNT,
+						   (char *) &cancel->keepalives_count,
+						   sizeof(cancel->keepalives_count)) < 0)
+			{
+				goto cancel_errReturn;
+			}
+		}
+#endif
+#else
+	if (cancel->keepalives != 0) {
+		int idle = cancel->keepalives_idle;
+		int interval = cancel->keepalives_interval;
+		if (idle <= 0)
+			idle = 2 * 60 * 60;		/* 2 hours = default */
+
+		if (conn->keepalives_interval &&
+			!parse_int_param(conn->keepalives_interval, &interval, conn,
+							 "keepalives_interval"))
+			return 0;
+		if (interval <= 0)
+			interval = 1;			/* 1 second = default */
+
+		ka.onoff = 1;
+		ka.keepalivetime = idle * 1000;
+		ka.keepaliveinterval = interval * 1000;
+
+		if (WSAIoctl(conn->sock,
+					 SIO_KEEPALIVE_VALS,
+					 (LPVOID) &ka,
+					 sizeof(ka),
+					 NULL,
+					 0,
+					 &retsize,
+					 NULL,
+					 NULL)
+			!= 0)
+		{
+			appendPQExpBuffer(&conn->errorMessage,
+							  libpq_gettext("%s(%s) failed: error code %d\n"),
+							  "WSAIoctl", "SIO_KEEPALIVE_VALS",
+							  WSAGetLastError());
+			return 0;
+		}
+	}
+
+#endif
+	}
+
+
 retry3:
-	if (connect(tmpsock, (struct sockaddr *) &raddr->addr,
-				raddr->salen) < 0)
+	if (connect(tmpsock, (struct sockaddr *) &cancel->raddr.addr,
+				cancel->raddr.salen) < 0)
 	{
 		if (SOCK_ERRNO == EINTR)
 			/* Interrupted system call - we'll just try again */
@@ -4445,8 +4584,8 @@ retry3:
 
 	crp.packetlen = pg_hton32((uint32) sizeof(crp));
 	crp.cp.cancelRequestCode = (MsgType) pg_hton32(CANCEL_REQUEST_CODE);
-	crp.cp.backendPID = pg_hton32(be_pid);
-	crp.cp.cancelAuthCode = pg_hton32(be_key);
+	crp.cp.backendPID = pg_hton32(cancel->be_pid);
+	crp.cp.cancelAuthCode = pg_hton32(cancel->be_key);
 
 retry4:
 	if (send(tmpsock, (char *) &crp, sizeof(crp), 0) != (int) sizeof(crp))
@@ -4516,8 +4655,7 @@ PQcancel(PGcancel *cancel, char *errbuf, int errbufsize)
 		return false;
 	}
 
-	return internal_cancel(&cancel->raddr, cancel->be_pid, cancel->be_key,
-						   errbuf, errbufsize);
+	return internal_cancel(cancel, errbuf, errbufsize);
 }
 
 /*
@@ -4550,10 +4688,21 @@ PQrequestCancel(PGconn *conn)
 
 		return false;
 	}
-
-	r = internal_cancel(&conn->raddr, conn->be_pid, conn->be_key,
-						conn->errorMessage.data, conn->errorMessage.maxlen);
-
+	PGcancel cancel = {0};
+
+	memcpy(&cancel.raddr, &conn->raddr, sizeof(SockAddr));
+	cancel.be_pid = conn->be_pid;
+	cancel.be_key = conn->be_key;
+	// Set all the socket options to -1, so we can use that to see if the
+	// socket options are set or not.
+	cancel.connect_timeout = -1;
+	cancel.pgtcp_user_timeout = -1;
+	cancel.keepalives = -1;
+	cancel.keepalives_idle = -1;
+	cancel.keepalives_interval = -1;
+	cancel.keepalives_count = -1;
+
+	r = internal_cancel(&cancel, conn->errorMessage.data, conn->errorMessage.maxlen);
 	if (!r)
 		conn->errorMessage.len = strlen(conn->errorMessage.data);
 
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 334aea4b6e..661c175f14 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -581,6 +581,14 @@ struct pg_cancel
 	SockAddr	raddr;			/* Remote address */
 	int			be_pid;			/* PID of backend --- needed for cancels */
 	int			be_key;			/* key of backend --- needed for cancels */
+	int			connect_timeout;	/* connection timeout */
+	int			pgtcp_user_timeout; /* tcp user timeout */
+	int			keepalives;		/* use TCP keepalives? */
+	int			keepalives_idle;	/* time between TCP keepalives */
+	int			keepalives_interval;	/* time between TCP keepalive
+										 * retransmits */
+	int			keepalives_count;	/* maximum number of TCP keepalive
+									 * retransmits */
 };
 
 
-- 
2.17.1

