On Wed, Jan 10, 2024 at 09:06:08AM +0700, John Naylor wrote:
> If we have say 25 elements, I mean (for SSE2) check the first 16, then
> the last 16. Some will be checked twice, but that's okay.

I finally got around to trying this.  0001 adds this overlapping logic.
0002 is a rebased version of the AVX2 patch (it needed some updates after
commit 9f225e9).  And 0003 is a benchmark for test_lfind32().  It runs
pg_lfind32() on an array of the given size 100M times.

I've also attached the results of running this benchmark on my machine at
HEAD, after applying 0001, and after applying both 0001 and 0002.  0001
appears to work pretty well.  When there is a small "tail," it regresses a
small amount, but overall, it seems to improve more cases than it harms.
0002 does regress searches on smaller arrays quite a bit, since it
postpones the SIMD optimizations until the arrays are longer.  It might be
possible to mitigate this regression by using 2 registers when the "tail"
is long enough, but I have yet to try that.

-- 
Nathan Bossart
Amazon Web Services: https://aws.amazon.com
>From 3a4d74eeab18d9e8f510e11185109ed910e40268 Mon Sep 17 00:00:00 2001
From: Nathan Bossart <nat...@postgresql.org>
Date: Fri, 15 Mar 2024 12:26:26 -0500
Subject: [PATCH v2 1/3] pg_lfind32: process "tail" with SIMD instructions

---
 src/include/port/pg_lfind.h | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/include/port/pg_lfind.h b/src/include/port/pg_lfind.h
index b8dfa66eef..9d21284724 100644
--- a/src/include/port/pg_lfind.h
+++ b/src/include/port/pg_lfind.h
@@ -103,7 +103,7 @@ pg_lfind32(uint32 key, uint32 *base, uint32 nelem)
 	const uint32 nelem_per_iteration = 4 * nelem_per_vector;
 
 	/* round down to multiple of elements per iteration */
-	const uint32 tail_idx = nelem & ~(nelem_per_iteration - 1);
+	uint32 tail_idx = nelem & ~(nelem_per_iteration - 1);
 
 #if defined(USE_ASSERT_CHECKING)
 	bool		assert_result = false;
@@ -117,9 +117,11 @@ pg_lfind32(uint32 key, uint32 *base, uint32 nelem)
 			break;
 		}
 	}
+	i = 0;
 #endif
 
-	for (i = 0; i < tail_idx; i += nelem_per_iteration)
+retry:
+	for (; i < tail_idx; i += nelem_per_iteration)
 	{
 		Vector32	vals1,
 					vals2,
@@ -157,6 +159,16 @@ pg_lfind32(uint32 key, uint32 *base, uint32 nelem)
 			return true;
 		}
 	}
+
+	if (i == nelem)
+		return false;
+	else if (tail_idx > 0)
+	{
+		tail_idx = nelem;
+		i = nelem - nelem_per_iteration;
+		goto retry;
+	}
+
 #endif							/* ! USE_NO_SIMD */
 
 	/* Process the remaining elements one at a time. */
-- 
2.25.1

>From 0ac61e17b6ed07116086ded2a6a5142da9afa28f Mon Sep 17 00:00:00 2001
From: Nathan Bossart <nat...@postgresql.org>
Date: Fri, 15 Mar 2024 12:26:52 -0500
Subject: [PATCH v2 2/3] add avx2 support in simd.h

---
 src/include/port/simd.h | 58 ++++++++++++++++++++++++++++++++---------
 1 file changed, 45 insertions(+), 13 deletions(-)

diff --git a/src/include/port/simd.h b/src/include/port/simd.h
index 597496f2fb..767127b85c 100644
--- a/src/include/port/simd.h
+++ b/src/include/port/simd.h
@@ -18,7 +18,15 @@
 #ifndef SIMD_H
 #define SIMD_H
 
-#if (defined(__x86_64__) || defined(_M_AMD64))
+#if defined(__AVX2__)
+
+#include <immintrin.h>
+#define USE_AVX2
+typedef __m256i Vector8;
+typedef __m256i Vector32;
+
+#elif (defined(__x86_64__) || defined(_M_AMD64))
+
 /*
  * SSE2 instructions are part of the spec for the 64-bit x86 ISA. We assume
  * that compilers targeting this architecture understand SSE2 intrinsics.
@@ -107,7 +115,9 @@ static inline Vector32 vector32_eq(const Vector32 v1, const Vector32 v2);
 static inline void
 vector8_load(Vector8 *v, const uint8 *s)
 {
-#if defined(USE_SSE2)
+#if defined(USE_AVX2)
+	*v = _mm256_loadu_si256((const __m256i *) s);
+#elif defined(USE_SSE2)
 	*v = _mm_loadu_si128((const __m128i *) s);
 #elif defined(USE_NEON)
 	*v = vld1q_u8(s);
@@ -120,7 +130,9 @@ vector8_load(Vector8 *v, const uint8 *s)
 static inline void
 vector32_load(Vector32 *v, const uint32 *s)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	*v = _mm256_loadu_si256((const __m256i *) s);
+#elif defined(USE_SSE2)
 	*v = _mm_loadu_si128((const __m128i *) s);
 #elif defined(USE_NEON)
 	*v = vld1q_u32(s);
@@ -134,7 +146,9 @@ vector32_load(Vector32 *v, const uint32 *s)
 static inline Vector8
 vector8_broadcast(const uint8 c)
 {
-#if defined(USE_SSE2)
+#if defined(USE_AVX2)
+	return _mm256_set1_epi8(c);
+#elif defined(USE_SSE2)
 	return _mm_set1_epi8(c);
 #elif defined(USE_NEON)
 	return vdupq_n_u8(c);
@@ -147,7 +161,9 @@ vector8_broadcast(const uint8 c)
 static inline Vector32
 vector32_broadcast(const uint32 c)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_set1_epi32(c);
+#elif defined(USE_SSE2)
 	return _mm_set1_epi32(c);
 #elif defined(USE_NEON)
 	return vdupq_n_u32(c);
@@ -270,7 +286,9 @@ vector8_has_le(const Vector8 v, const uint8 c)
 static inline bool
 vector8_is_highbit_set(const Vector8 v)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_movemask_epi8(v) != 0;
+#elif defined(USE_SSE2)
 	return _mm_movemask_epi8(v) != 0;
 #elif defined(USE_NEON)
 	return vmaxvq_u8(v) > 0x7F;
@@ -308,7 +326,9 @@ vector32_is_highbit_set(const Vector32 v)
 static inline uint32
 vector8_highbit_mask(const Vector8 v)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return (uint32) _mm256_movemask_epi8(v);
+#elif defined(USE_SSE2)
 	return (uint32) _mm_movemask_epi8(v);
 #elif defined(USE_NEON)
 	/*
@@ -337,7 +357,9 @@ vector8_highbit_mask(const Vector8 v)
 static inline Vector8
 vector8_or(const Vector8 v1, const Vector8 v2)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_or_si256(v1, v2);
+#elif defined(USE_SSE2)
 	return _mm_or_si128(v1, v2);
 #elif defined(USE_NEON)
 	return vorrq_u8(v1, v2);
@@ -350,7 +372,9 @@ vector8_or(const Vector8 v1, const Vector8 v2)
 static inline Vector32
 vector32_or(const Vector32 v1, const Vector32 v2)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_or_si256(v1, v2);
+#elif defined(USE_SSE2)
 	return _mm_or_si128(v1, v2);
 #elif defined(USE_NEON)
 	return vorrq_u32(v1, v2);
@@ -368,7 +392,9 @@ vector32_or(const Vector32 v1, const Vector32 v2)
 static inline Vector8
 vector8_ssub(const Vector8 v1, const Vector8 v2)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_subs_epu8(v1, v2);
+#elif defined(USE_SSE2)
 	return _mm_subs_epu8(v1, v2);
 #elif defined(USE_NEON)
 	return vqsubq_u8(v1, v2);
@@ -384,7 +410,9 @@ vector8_ssub(const Vector8 v1, const Vector8 v2)
 static inline Vector8
 vector8_eq(const Vector8 v1, const Vector8 v2)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_cmpeq_epi8(v1, v2);
+#elif defined(USE_SSE2)
 	return _mm_cmpeq_epi8(v1, v2);
 #elif defined(USE_NEON)
 	return vceqq_u8(v1, v2);
@@ -396,7 +424,9 @@ vector8_eq(const Vector8 v1, const Vector8 v2)
 static inline Vector32
 vector32_eq(const Vector32 v1, const Vector32 v2)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_cmpeq_epi32(v1, v2);
+#elif defined(USE_SSE2)
 	return _mm_cmpeq_epi32(v1, v2);
 #elif defined(USE_NEON)
 	return vceqq_u32(v1, v2);
@@ -411,7 +441,9 @@ vector32_eq(const Vector32 v1, const Vector32 v2)
 static inline Vector8
 vector8_min(const Vector8 v1, const Vector8 v2)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_min_epu8(v1, v2);
+#elif defined(USE_SSE2)
 	return _mm_min_epu8(v1, v2);
 #elif defined(USE_NEON)
 	return vminq_u8(v1, v2);
-- 
2.25.1

>From 9b2b61927a8b52637f70659d513ddfeba7c03024 Mon Sep 17 00:00:00 2001
From: Nathan Bossart <nat...@postgresql.org>
Date: Fri, 15 Mar 2024 12:28:00 -0500
Subject: [PATCH v2 3/3] test_lfind32() benchmark

---
 .../modules/test_lfind/sql/test_lfind.sql     | 67 +++++++++++++++++++
 .../modules/test_lfind/test_lfind--1.0.sql    |  4 ++
 src/test/modules/test_lfind/test_lfind.c      | 16 +++++
 3 files changed, 87 insertions(+)

diff --git a/src/test/modules/test_lfind/sql/test_lfind.sql b/src/test/modules/test_lfind/sql/test_lfind.sql
index 766c640831..d8fa461bfa 100644
--- a/src/test/modules/test_lfind/sql/test_lfind.sql
+++ b/src/test/modules/test_lfind/sql/test_lfind.sql
@@ -8,3 +8,70 @@ CREATE EXTENSION test_lfind;
 SELECT test_lfind8();
 SELECT test_lfind8_le();
 SELECT test_lfind32();
+
+\timing on
+SELECT drive_lfind32(0);
+SELECT drive_lfind32(1);
+SELECT drive_lfind32(2);
+SELECT drive_lfind32(3);
+SELECT drive_lfind32(4);
+SELECT drive_lfind32(5);
+SELECT drive_lfind32(6);
+SELECT drive_lfind32(7);
+SELECT drive_lfind32(8);
+SELECT drive_lfind32(9);
+SELECT drive_lfind32(10);
+SELECT drive_lfind32(11);
+SELECT drive_lfind32(12);
+SELECT drive_lfind32(13);
+SELECT drive_lfind32(14);
+SELECT drive_lfind32(15);
+SELECT drive_lfind32(16);
+SELECT drive_lfind32(17);
+SELECT drive_lfind32(18);
+SELECT drive_lfind32(19);
+SELECT drive_lfind32(20);
+SELECT drive_lfind32(21);
+SELECT drive_lfind32(22);
+SELECT drive_lfind32(23);
+SELECT drive_lfind32(24);
+SELECT drive_lfind32(25);
+SELECT drive_lfind32(26);
+SELECT drive_lfind32(27);
+SELECT drive_lfind32(28);
+SELECT drive_lfind32(29);
+SELECT drive_lfind32(30);
+SELECT drive_lfind32(31);
+SELECT drive_lfind32(32);
+SELECT drive_lfind32(33);
+SELECT drive_lfind32(34);
+SELECT drive_lfind32(35);
+SELECT drive_lfind32(36);
+SELECT drive_lfind32(37);
+SELECT drive_lfind32(38);
+SELECT drive_lfind32(39);
+SELECT drive_lfind32(40);
+SELECT drive_lfind32(41);
+SELECT drive_lfind32(42);
+SELECT drive_lfind32(43);
+SELECT drive_lfind32(44);
+SELECT drive_lfind32(45);
+SELECT drive_lfind32(46);
+SELECT drive_lfind32(47);
+SELECT drive_lfind32(48);
+SELECT drive_lfind32(49);
+SELECT drive_lfind32(50);
+SELECT drive_lfind32(51);
+SELECT drive_lfind32(52);
+SELECT drive_lfind32(53);
+SELECT drive_lfind32(54);
+SELECT drive_lfind32(55);
+SELECT drive_lfind32(56);
+SELECT drive_lfind32(57);
+SELECT drive_lfind32(58);
+SELECT drive_lfind32(59);
+SELECT drive_lfind32(60);
+SELECT drive_lfind32(61);
+SELECT drive_lfind32(62);
+SELECT drive_lfind32(63);
+SELECT drive_lfind32(64);
diff --git a/src/test/modules/test_lfind/test_lfind--1.0.sql b/src/test/modules/test_lfind/test_lfind--1.0.sql
index 81801926ae..6b396dbd58 100644
--- a/src/test/modules/test_lfind/test_lfind--1.0.sql
+++ b/src/test/modules/test_lfind/test_lfind--1.0.sql
@@ -14,3 +14,7 @@ CREATE FUNCTION test_lfind8()
 CREATE FUNCTION test_lfind8_le()
 	RETURNS pg_catalog.void
 	AS 'MODULE_PATHNAME' LANGUAGE C;
+
+CREATE FUNCTION drive_lfind32(n int)
+	RETURNS pg_catalog.void
+	AS 'MODULE_PATHNAME' LANGUAGE C;
diff --git a/src/test/modules/test_lfind/test_lfind.c b/src/test/modules/test_lfind/test_lfind.c
index c04bc2f6b4..2234f148b6 100644
--- a/src/test/modules/test_lfind/test_lfind.c
+++ b/src/test/modules/test_lfind/test_lfind.c
@@ -146,3 +146,19 @@ test_lfind32(PG_FUNCTION_ARGS)
 
 	PG_RETURN_VOID();
 }
+
+PG_FUNCTION_INFO_V1(drive_lfind32);
+Datum
+drive_lfind32(PG_FUNCTION_ARGS)
+{
+	int			array_size = PG_GETARG_INT32(0);
+	uint32	   *test_array = palloc0(array_size * sizeof(uint32));
+
+	for (int i = 0; i < 100000000; i++)
+	{
+		if (pg_lfind32(1, test_array, array_size))
+			elog(ERROR, "pg_lfind32() found nonexistent element");
+	}
+
+	PG_RETURN_VOID();
+}
-- 
2.25.1

Reply via email to