From bda601d01c2cba6a6544e36d51fc62b79d8a3f53 Mon Sep 17 00:00:00 2001
From: Shlok Kyal <shlok.kyal.oss@gmail.com>
Date: Fri, 27 Sep 2024 16:04:54 +0530
Subject: [PATCH v11 2/2] Selective Invalidation of Cache

When we alter a publication, add/drop namespace to/from publication
all the cache for all the tables are invalidated.
With this patch for the above operationns we will invalidate the
cache of only the desired tables.

Added a new callback function 'rel_sync_cache_publicationrel_cb' which
is called when there is any change in pg_publication catalog and it
invalidates the tables present in the publication modified.
---
 src/backend/replication/pgoutput/pgoutput.c |  58 ++++++---
 src/backend/utils/cache/inval.c             | 130 +++++++++++++++++++-
 src/include/storage/sinval.h                |   9 ++
 src/include/utils/inval.h                   |   4 +
 4 files changed, 179 insertions(+), 22 deletions(-)

diff --git a/src/backend/replication/pgoutput/pgoutput.c b/src/backend/replication/pgoutput/pgoutput.c
index 00e7024563..1d80d27d0f 100644
--- a/src/backend/replication/pgoutput/pgoutput.c
+++ b/src/backend/replication/pgoutput/pgoutput.c
@@ -132,6 +132,8 @@ typedef struct RelationSyncEntry
 	List	   *streamed_txns;	/* streamed toplevel transactions with this
 								 * schema */
 
+	List	   *pub_ids;
+
 	/* are we publishing this rel? */
 	PublicationActions pubactions;
 
@@ -216,6 +218,7 @@ static RelationSyncEntry *get_rel_sync_entry(PGOutputData *data,
 static void rel_sync_cache_relation_cb(Datum arg, Oid relid);
 static void rel_sync_cache_publication_cb(Datum arg, int cacheid,
 										  uint32 hashvalue);
+static void rel_sync_cache_publicationrel_cb(Datum arg, Oid pubid);
 static void set_schema_sent_in_streamed_txn(RelationSyncEntry *entry,
 											TransactionId xid);
 static bool get_schema_sent_in_streamed_txn(RelationSyncEntry *entry,
@@ -1134,7 +1137,7 @@ init_tuple_slot(PGOutputData *data, Relation relation,
 	TupleDesc	oldtupdesc;
 	TupleDesc	newtupdesc;
 
-	oldctx = MemoryContextSwitchTo(data->cachectx);
+	oldctx = MemoryContextSwitchTo(CacheMemoryContext);
 
 	/*
 	 * Create tuple table slots. Create a copy of the TupleDesc as it needs to
@@ -1739,12 +1742,6 @@ static void
 publication_invalidation_cb(Datum arg, int cacheid, uint32 hashvalue)
 {
 	publications_valid = false;
-
-	/*
-	 * Also invalidate per-relation cache so that next time the filtering info
-	 * is checked it will be updated with the new publication settings.
-	 */
-	rel_sync_cache_publication_cb(arg, cacheid, hashvalue);
 }
 
 /*
@@ -1920,17 +1917,7 @@ init_rel_sync_cache(MemoryContext cachectx)
 								  rel_sync_cache_publication_cb,
 								  (Datum) 0);
 
-	/*
-	 * Flush all cache entries after any publication changes.  (We need no
-	 * callback entry for pg_publication, because publication_invalidation_cb
-	 * will take care of it.)
-	 */
-	CacheRegisterSyscacheCallback(PUBLICATIONRELMAP,
-								  rel_sync_cache_publication_cb,
-								  (Datum) 0);
-	CacheRegisterSyscacheCallback(PUBLICATIONNAMESPACEMAP,
-								  rel_sync_cache_publication_cb,
-								  (Datum) 0);
+	CacheRegisterPubcacheCallback(rel_sync_cache_publicationrel_cb, (Datum) 0);
 
 	relation_callbacks_registered = true;
 }
@@ -2000,6 +1987,7 @@ get_rel_sync_entry(PGOutputData *data, Relation relation)
 		entry->publish_as_relid = InvalidOid;
 		entry->columns = NULL;
 		entry->attrmap = NULL;
+		entry->pub_ids = NIL;
 	}
 
 	/* Validate the entry */
@@ -2044,6 +2032,8 @@ get_rel_sync_entry(PGOutputData *data, Relation relation)
 		entry->schema_sent = false;
 		list_free(entry->streamed_txns);
 		entry->streamed_txns = NIL;
+		list_free(entry->pub_ids);
+		entry->pub_ids = NIL;
 		bms_free(entry->columns);
 		entry->columns = NULL;
 		entry->pubactions.pubinsert = false;
@@ -2108,6 +2098,10 @@ get_rel_sync_entry(PGOutputData *data, Relation relation)
 
 					pub_relid = llast_oid(ancestors);
 					ancestor_level = list_length(ancestors);
+
+					oldctx = MemoryContextSwitchTo(CacheMemoryContext);
+					entry->pub_ids = lappend_oid(entry->pub_ids, pub->oid);
+					MemoryContextSwitchTo(oldctx);
 				}
 			}
 
@@ -2145,7 +2139,12 @@ get_rel_sync_entry(PGOutputData *data, Relation relation)
 				if (list_member_oid(pubids, pub->oid) ||
 					list_member_oid(schemaPubids, pub->oid) ||
 					ancestor_published)
+				{
 					publish = true;
+					oldctx = MemoryContextSwitchTo(CacheMemoryContext);
+					entry->pub_ids = lappend_oid(entry->pub_ids, pub->oid);
+					MemoryContextSwitchTo(oldctx);
+				}
 			}
 
 			/*
@@ -2318,6 +2317,29 @@ rel_sync_cache_relation_cb(Datum arg, Oid relid)
 	}
 }
 
+/*
+ * Publication invalidation callback
+ */
+static void
+rel_sync_cache_publicationrel_cb(Datum arg, Oid pubid)
+{
+	HASH_SEQ_STATUS status;
+	RelationSyncEntry *entry;
+
+	if (RelationSyncCache == NULL)
+		return;
+
+	hash_seq_init(&status, RelationSyncCache);
+	while ((entry = (RelationSyncEntry *) hash_seq_search(&status)) != NULL)
+	{
+		if (entry->replicate_valid && list_member_oid(entry->pub_ids, pubid))
+		{
+			entry->replicate_valid = false;
+			entry->pub_ids = NIL;
+		}
+	}
+}
+
 /*
  * Publication relation/schema map syscache invalidation callback
  *
diff --git a/src/backend/utils/cache/inval.c b/src/backend/utils/cache/inval.c
index 603aa4157b..a34be79ee6 100644
--- a/src/backend/utils/cache/inval.c
+++ b/src/backend/utils/cache/inval.c
@@ -160,6 +160,9 @@
  */
 #define CatCacheMsgs 0
 #define RelCacheMsgs 1
+#define PubCacheMsgs 2
+
+#define NumberofCache 3
 
 /* Pointers to main arrays in TopTransactionContext */
 typedef struct InvalMessageArray
@@ -168,13 +171,13 @@ typedef struct InvalMessageArray
 	int			maxmsgs;		/* current allocated size of array */
 } InvalMessageArray;
 
-static InvalMessageArray InvalMessageArrays[2];
+static InvalMessageArray InvalMessageArrays[NumberofCache];
 
 /* Control information for one logical group of messages */
 typedef struct InvalidationMsgsGroup
 {
-	int			firstmsg[2];	/* first index in relevant array */
-	int			nextmsg[2];		/* last+1 index */
+	int			firstmsg[NumberofCache];	/* first index in relevant array */
+	int			nextmsg[NumberofCache];		/* last+1 index */
 } InvalidationMsgsGroup;
 
 /* Macros to help preserve InvalidationMsgsGroup abstraction */
@@ -189,6 +192,7 @@ typedef struct InvalidationMsgsGroup
 	do { \
 		SetSubGroupToFollow(targetgroup, priorgroup, CatCacheMsgs); \
 		SetSubGroupToFollow(targetgroup, priorgroup, RelCacheMsgs); \
+		SetSubGroupToFollow(targetgroup, priorgroup, PubCacheMsgs); \
 	} while (0)
 
 #define NumMessagesInSubGroup(group, subgroup) \
@@ -196,7 +200,8 @@ typedef struct InvalidationMsgsGroup
 
 #define NumMessagesInGroup(group) \
 	(NumMessagesInSubGroup(group, CatCacheMsgs) + \
-	 NumMessagesInSubGroup(group, RelCacheMsgs))
+	 NumMessagesInSubGroup(group, RelCacheMsgs) + \
+	 NumMessagesInSubGroup(group, PubCacheMsgs))
 
 
 /*----------------
@@ -251,6 +256,7 @@ int			debug_discard_caches = 0;
 
 #define MAX_SYSCACHE_CALLBACKS 64
 #define MAX_RELCACHE_CALLBACKS 10
+#define MAX_PUBCACHE_CALLBACKS 10
 
 static struct SYSCACHECALLBACK
 {
@@ -272,6 +278,14 @@ static struct RELCACHECALLBACK
 
 static int	relcache_callback_count = 0;
 
+static struct PUBCACHECALLBACK
+{
+	PubcacheCallbackFunction function;
+	Datum		arg;
+}			pubcache_callback_list[MAX_PUBCACHE_CALLBACKS];
+
+static int	pubcache_callback_count = 0;
+
 /* ----------------------------------------------------------------
  *				Invalidation subgroup support functions
  * ----------------------------------------------------------------
@@ -464,6 +478,38 @@ AddRelcacheInvalidationMessage(InvalidationMsgsGroup *group,
 	AddInvalidationMessage(group, RelCacheMsgs, &msg);
 }
 
+/*
+ * Add a publication inval entry
+ */
+static void
+AddPubcacheInvalidationMessage(InvalidationMsgsGroup *group,
+							   Oid dbId, Oid pubId)
+{
+	SharedInvalidationMessage msg;
+
+	/*
+	 * Don't add a duplicate item. We assume dbId need not be checked because
+	 * it will never change. InvalidOid for relId means all relations so we
+	 * don't need to add individual ones when it is present.
+	 */
+
+	ProcessMessageSubGroup(group, PubCacheMsgs,
+						   if (msg->pc.id == SHAREDINVALPUBCACHE_ID &&
+							   (msg->pc.pubId == pubId ||
+								msg->pc.pubId == InvalidOid))
+						   return);
+
+
+	/* OK, add the item */
+	msg.pc.id = SHAREDINVALPUBCACHE_ID;
+	msg.pc.dbId = dbId;
+	msg.pc.pubId = pubId;
+	/* check AddCatcacheInvalidationMessage() for an explanation */
+	VALGRIND_MAKE_MEM_DEFINED(&msg, sizeof(msg));
+
+	AddInvalidationMessage(group, PubCacheMsgs, &msg);
+}
+
 /*
  * Add a snapshot inval entry
  *
@@ -502,6 +548,7 @@ AppendInvalidationMessages(InvalidationMsgsGroup *dest,
 {
 	AppendInvalidationMessageSubGroup(dest, src, CatCacheMsgs);
 	AppendInvalidationMessageSubGroup(dest, src, RelCacheMsgs);
+	AppendInvalidationMessageSubGroup(dest, src, PubCacheMsgs);
 }
 
 /*
@@ -516,6 +563,7 @@ ProcessInvalidationMessages(InvalidationMsgsGroup *group,
 {
 	ProcessMessageSubGroup(group, CatCacheMsgs, func(msg));
 	ProcessMessageSubGroup(group, RelCacheMsgs, func(msg));
+	ProcessMessageSubGroup(group, PubCacheMsgs, func(msg));
 }
 
 /*
@@ -528,6 +576,7 @@ ProcessInvalidationMessagesMulti(InvalidationMsgsGroup *group,
 {
 	ProcessMessageSubGroupMulti(group, CatCacheMsgs, func(msgs, n));
 	ProcessMessageSubGroupMulti(group, RelCacheMsgs, func(msgs, n));
+	ProcessMessageSubGroupMulti(group, PubCacheMsgs, func(msgs, n));
 }
 
 /* ----------------------------------------------------------------
@@ -590,6 +639,18 @@ RegisterRelcacheInvalidation(Oid dbId, Oid relId)
 		transInvalInfo->RelcacheInitFileInval = true;
 }
 
+/*
+ * RegisterPubcacheInvalidation
+ *
+ * As above, but register a publication invalidation event.
+ */
+static void
+RegisterPubcacheInvalidation(Oid dbId, Oid pubId)
+{
+	AddPubcacheInvalidationMessage(&transInvalInfo->CurrentCmdInvalidMsgs,
+								   dbId, pubId);
+}
+
 /*
  * RegisterSnapshotInvalidation
  *
@@ -660,6 +721,8 @@ PrepareInvalidationState(void)
 		InvalMessageArrays[CatCacheMsgs].maxmsgs = 0;
 		InvalMessageArrays[RelCacheMsgs].msgs = NULL;
 		InvalMessageArrays[RelCacheMsgs].maxmsgs = 0;
+		InvalMessageArrays[PubCacheMsgs].msgs = NULL;
+		InvalMessageArrays[PubCacheMsgs].maxmsgs = 0;
 	}
 
 	transInvalInfo = myInfo;
@@ -773,6 +836,20 @@ LocalExecuteInvalidationMessage(SharedInvalidationMessage *msg)
 		else if (msg->sn.dbId == MyDatabaseId)
 			InvalidateCatalogSnapshot();
 	}
+	else if (msg->id == SHAREDINVALPUBCACHE_ID)
+	{
+		if (msg->pc.dbId == MyDatabaseId || msg->pc.dbId == InvalidOid)
+		{
+			int			i;
+
+			for (i = 0; i < pubcache_callback_count; i++)
+			{
+				struct PUBCACHECALLBACK *pcitem = pubcache_callback_list + i;
+
+				pcitem->function(pcitem->arg, msg->pc.pubId);
+			}
+		}
+	}
 	else
 		elog(FATAL, "unrecognized SI message ID: %d", msg->id);
 }
@@ -944,6 +1021,18 @@ xactGetCommittedInvalidationMessages(SharedInvalidationMessage **msgs,
 										msgs,
 										n * sizeof(SharedInvalidationMessage)),
 								 nmsgs += n));
+	ProcessMessageSubGroupMulti(&transInvalInfo->PriorCmdInvalidMsgs,
+								PubCacheMsgs,
+								(memcpy(msgarray + nmsgs,
+										msgs,
+										n * sizeof(SharedInvalidationMessage)),
+								 nmsgs += n));
+	ProcessMessageSubGroupMulti(&transInvalInfo->CurrentCmdInvalidMsgs,
+								PubCacheMsgs,
+								(memcpy(msgarray + nmsgs,
+										msgs,
+										n * sizeof(SharedInvalidationMessage)),
+								 nmsgs += n));
 	Assert(nmsgs == nummsgs);
 
 	return nmsgs;
@@ -1312,6 +1401,17 @@ CacheInvalidateHeapTuple(Relation relation,
 		else
 			return;
 	}
+	else if (tupleRelId == PublicationRelationId)
+	{
+		Form_pg_publication pubtup = (Form_pg_publication) GETSTRUCT(tuple);
+
+		/* get publication id */
+		relationId = pubtup->oid;
+		databaseId = MyDatabaseId;
+
+		RegisterPubcacheInvalidation(databaseId, relationId);
+		return;
+	}
 	else
 		return;
 
@@ -1567,6 +1667,25 @@ CacheRegisterRelcacheCallback(RelcacheCallbackFunction func,
 	++relcache_callback_count;
 }
 
+/*
+ * CacheRegisterPubcacheCallback
+ *		Register the specified function to be called for all future
+ *		publication invalidation events.  The OID of the publication being
+ *		invalidated will be passed to the function.
+ */
+void
+CacheRegisterPubcacheCallback(PubcacheCallbackFunction func,
+							  Datum arg)
+{
+	if (pubcache_callback_count >= MAX_PUBCACHE_CALLBACKS)
+		elog(FATAL, "out of pubcache_callback_list slots");
+
+	pubcache_callback_list[pubcache_callback_count].function = func;
+	pubcache_callback_list[pubcache_callback_count].arg = arg;
+
+	++pubcache_callback_count;
+}
+
 /*
  * CallSyscacheCallbacks
  *
@@ -1629,6 +1748,9 @@ LogLogicalInvalidations(void)
 		ProcessMessageSubGroupMulti(group, RelCacheMsgs,
 									XLogRegisterData((char *) msgs,
 													 n * sizeof(SharedInvalidationMessage)));
+		ProcessMessageSubGroupMulti(group, PubCacheMsgs,
+									XLogRegisterData((char *) msgs,
+													 n * sizeof(SharedInvalidationMessage)));
 		XLogInsert(RM_XACT_ID, XLOG_XACT_INVALIDATIONS);
 	}
 }
diff --git a/src/include/storage/sinval.h b/src/include/storage/sinval.h
index 8f5744b21b..9a97268b0a 100644
--- a/src/include/storage/sinval.h
+++ b/src/include/storage/sinval.h
@@ -110,6 +110,14 @@ typedef struct
 	Oid			relId;			/* relation ID */
 } SharedInvalSnapshotMsg;
 
+#define SHAREDINVALPUBCACHE_ID	(-6)
+typedef struct
+{
+	int8		id;				/* type field --- must be first */
+	Oid			dbId;			/* database ID, or 0 if a shared relation */
+	Oid			pubId;			/* publication ID */
+} SharedInvalPubcacheMsg;
+
 typedef union
 {
 	int8		id;				/* type field --- must be first */
@@ -119,6 +127,7 @@ typedef union
 	SharedInvalSmgrMsg sm;
 	SharedInvalRelmapMsg rm;
 	SharedInvalSnapshotMsg sn;
+	SharedInvalPubcacheMsg pc;
 } SharedInvalidationMessage;
 
 
diff --git a/src/include/utils/inval.h b/src/include/utils/inval.h
index 24695facf2..66d27b8bee 100644
--- a/src/include/utils/inval.h
+++ b/src/include/utils/inval.h
@@ -22,6 +22,7 @@ extern PGDLLIMPORT int debug_discard_caches;
 
 typedef void (*SyscacheCallbackFunction) (Datum arg, int cacheid, uint32 hashvalue);
 typedef void (*RelcacheCallbackFunction) (Datum arg, Oid relid);
+typedef void (*PubcacheCallbackFunction) (Datum arg, Oid pubid);
 
 
 extern void AcceptInvalidationMessages(void);
@@ -59,6 +60,9 @@ extern void CacheRegisterSyscacheCallback(int cacheid,
 extern void CacheRegisterRelcacheCallback(RelcacheCallbackFunction func,
 										  Datum arg);
 
+extern void CacheRegisterPubcacheCallback(PubcacheCallbackFunction func,
+										  Datum arg);
+
 extern void CallSyscacheCallbacks(int cacheid, uint32 hashvalue);
 
 extern void InvalidateSystemCaches(void);
-- 
2.34.1

