From 344cb7758d52594392e955efeb33c32cedfcb47a Mon Sep 17 00:00:00 2001
From: Thomas Munro <thomas.munro@gmail.com>
Date: Tue, 11 Jun 2024 14:32:47 +1200
Subject: [PATCH] XXX toy use of read stream API

---
 src/hnsw.h      |  1 +
 src/hnswutils.c | 93 +++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 94 insertions(+)

diff --git a/src/hnsw.h b/src/hnsw.h
index 480ad9f..e738241 100644
--- a/src/hnsw.h
+++ b/src/hnsw.h
@@ -136,6 +136,7 @@ struct HnswElementData
 	uint8		deleted;
 	uint32		hash;
 	HnswNeighborsPtr neighbors;
+	Buffer		buffer;
 	BlockNumber blkno;
 	OffsetNumber offno;
 	OffsetNumber neighborOffno;
diff --git a/src/hnswutils.c b/src/hnswutils.c
index d3ba911..b9addff 100644
--- a/src/hnswutils.c
+++ b/src/hnswutils.c
@@ -14,6 +14,10 @@
 #include "utils/memdebug.h"
 #include "utils/rel.h"
 
+#if PG_VERSION_NUM >= 170000
+#include "storage/read_stream.h"
+#endif
+
 #if PG_VERSION_NUM >= 130000
 #include "common/hashfn.h"
 #else
@@ -278,6 +282,9 @@ HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno)
 	HnswElement element = palloc(sizeof(HnswElementData));
 	char	   *base = NULL;
 
+#if PG_VERSION_NUM >= 170000
+	element->buffer = InvalidBuffer;
+#endif
 	element->blkno = blkno;
 	element->offno = offno;
 	HnswPtrStore(base, element->neighbors, (HnswNeighborArrayPtr *) NULL);
@@ -552,7 +559,20 @@ HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index,
 	HnswElementTuple etup;
 
 	/* Read vector */
+#if PG_VERSION_NUM >= 170000
+	if (element->buffer != InvalidBuffer)
+	{
+		buf = element->buffer;	/* already pinned by whoever set the field */
+		element->buffer = InvalidBuffer;	/* one-shot handoff: never reuse it on a later load */
+		Assert(BufferGetBlockNumber(buf) == element->blkno);
+	}
+	else
+	{
+		buf = ReadBuffer(index, element->blkno);
+	}
+#else
+	buf = ReadBuffer(index, element->blkno);
+#endif
 	LockBuffer(buf, BUFFER_LOCK_SHARE);
 	page = BufferGetPage(buf);
 
@@ -714,6 +734,28 @@ CountElement(char *base, HnswElement skipElement, HnswCandidate * hc)
 	return e->heaptidsLength != 0;
 }
 
+#if PG_VERSION_NUM >= 170000
+typedef struct HnswSearchLayerNextBlockData {	/* cursor state passed to HnswSearchLayerNextBlock */
+	char	   *base;		/* base passed to HnswPtrAccess when resolving elements */
+	HnswCandidate **items;	/* candidates whose blocks are handed out, in order */
+	int			nitems;		/* number of entries in items */
+	int			i;			/* index of the next entry to hand out */
+} HnswSearchLayerNextBlockData;
+
+static BlockNumber		/* block of the next queued candidate, or InvalidBlockNumber */
+HnswSearchLayerNextBlock(ReadStream *stream, void *callback_data, void *per_buffer_data)
+{						/* stream and per_buffer_data are not used here */
+	HnswSearchLayerNextBlockData *data = callback_data;	/* cursor over the candidates */
+	HnswElement hce;
+
+	if (data->i == data->nitems)	/* every candidate has been handed out */
+		return InvalidBlockNumber;
+
+	hce = HnswPtrAccess(data->base, data->items[data->i++]->element);	/* also advances the cursor */
+	return hce->blkno;
+}
+#endif
+
 /*
  * Algorithm 2 from paper
  */
@@ -729,6 +771,11 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
 	HnswNeighborArray *neighborhoodData = NULL;
 	Size		neighborhoodSize;
 
+#if PG_VERSION_NUM >= 170000
+	ReadStream *stream;
+	HnswSearchLayerNextBlockData stream_callback_data;
+#endif
+
 	InitVisited(base, &v, index, ef, m);
 
 	/* Create local memory for neighborhood if needed */
@@ -764,6 +811,8 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
 		HnswCandidate *c = ((HnswPairingHeapNode *) pairingheap_remove_first(C))->inner;
 		HnswCandidate *f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner;
 		HnswElement cElement;
+		HnswCandidate *items[HNSW_MAX_SIZE];
+		int nitems;
 
 		if (c->distance > f->distance)
 			break;
@@ -785,6 +834,8 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
 			neighborhood = neighborhoodData;
 		}
 
+		/* Build a list of the not-yet-visited neighbors to process. */
+		nitems = 0;
 		for (int i = 0; i < neighborhood->length; i++)
 		{
 			HnswCandidate *e = &neighborhood->items[i];
@@ -793,6 +844,38 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
 			AddToVisited(base, &v, e, index, &visited);
 
 			if (!visited)
+				items[nitems++] = e;
+		}
+
+#if PG_VERSION_NUM >= 170000
+		stream_callback_data.base = base;
+		stream_callback_data.items = items;
+		stream_callback_data.nitems = nitems;
+		stream_callback_data.i = 0;
+
+		/*
+		 * XXX For this quick-and-dirty hack, we'll use a temporary stream for
+		 * each set of neighbors we visit...  Should really re-use a stream,
+		 * and reset it after we hit stall points that need more data to look
+		 * further ahead.
+		 */
+		if (index)
+			stream = read_stream_begin_relation(READ_STREAM_FULL,
+												NULL,
+												index,
+												MAIN_FORKNUM,
+												HnswSearchLayerNextBlock,
+												&stream_callback_data,
+												0);
+		else
+			stream = NULL;
+#endif
+
+		/* Visit them. */
+		for (int i = 0; i < nitems; i++)
+		{
+			HnswCandidate *e = items[i];
+
 			{
 				float		eDistance;
 				HnswElement eElement = HnswPtrAccess(base, e->element);
@@ -802,7 +885,13 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
 				if (index == NULL)
 					eDistance = GetCandidateDistance(base, e, q, procinfo, collation);
 				else
+				{
+#if PG_VERSION_NUM >= 170000
+					if (stream)
+						eElement->buffer = read_stream_next_buffer(stream, NULL);
+#endif
 					HnswLoadElement(eElement, &eDistance, &q, index, procinfo, collation, inserting, wlen >= ef ? &f->distance : NULL);
+				}
 
 				if (eDistance < f->distance || wlen < ef)
 				{
@@ -838,6 +927,10 @@ HnswSearchLayer(char *base, Datum q, List *ep, int ef, int lc, Relation index, F
 				}
 			}
 		}
+#if PG_VERSION_NUM >= 170000
+		if (stream)
+			read_stream_end(stream);
+#endif
 	}
 
 	/* Add each element of W to w */
-- 
2.45.1

