From cb55070e7e524fa93b906a395908514f82a147a5 Mon Sep 17 00:00:00 2001
From: "shiy.fnst" <shiy.fnst@fujitsu.com>
Date: Tue, 20 Jul 2021 20:36:58 +0800
Subject: [PATCH] test-use-quick-select-to-get-median

---
 src/backend/access/spgist/spgquadtreeproc.c | 11 ++--
 src/backend/utils/adt/geo_spgist.c          | 10 ++--
 src/backend/utils/adt/rangetypes_spgist.c   | 14 +++---
 src/include/lib/sort_template.h             | 56 +++++++++++++++++----
 src/include/port.h                          |  5 ++
 src/port/Makefile                           |  1 +
 src/port/qselect.c                          | 24 +++++++++
 src/tools/msvc/Mkvcbuild.pm                 |  2 +-
 8 files changed, 97 insertions(+), 26 deletions(-)
 create mode 100644 src/port/qselect.c

diff --git a/src/backend/access/spgist/spgquadtreeproc.c b/src/backend/access/spgist/spgquadtreeproc.c
index a52d924..f9df8fc 100644
--- a/src/backend/access/spgist/spgquadtreeproc.c
+++ b/src/backend/access/spgist/spgquadtreeproc.c
@@ -176,17 +176,18 @@ spg_quad_picksplit(PG_FUNCTION_ARGS)
 #ifdef USE_MEDIAN
 	/* Use the median values of x and y as the centroid point */
 	Point	  **sorted;
+	int			median;
 
 	sorted = palloc(sizeof(*sorted) * in->nTuples);
 	for (i = 0; i < in->nTuples; i++)
 		sorted[i] = DatumGetPointP(in->datums[i]);
 
 	centroid = palloc(sizeof(*centroid));
-
-	qsort(sorted, in->nTuples, sizeof(*sorted), x_cmp);
-	centroid->x = sorted[in->nTuples >> 1]->x;
-	qsort(sorted, in->nTuples, sizeof(*sorted), y_cmp);
-	centroid->y = sorted[in->nTuples >> 1]->y;
+	median = in->nTuples >> 1;
+	qselect(sorted, in->nTuples, median, sizeof(*sorted), x_cmp);
+	centroid->x = sorted[median]->x;
+	qselect(sorted, in->nTuples, median, sizeof(*sorted), y_cmp);
+	centroid->y = sorted[median]->y;
 #else
 	/* Use the average values of x and y as the centroid point */
 	centroid = palloc0(sizeof(*centroid));
diff --git a/src/backend/utils/adt/geo_spgist.c b/src/backend/utils/adt/geo_spgist.c
index 6ee75d0..54bfa6c 100644
--- a/src/backend/utils/adt/geo_spgist.c
+++ b/src/backend/utils/adt/geo_spgist.c
@@ -461,13 +461,13 @@ spg_box_quad_picksplit(PG_FUNCTION_ARGS)
 		highYs[i] = box->high.y;
 	}
 
-	qsort(lowXs, in->nTuples, sizeof(float8), compareDoubles);
-	qsort(highXs, in->nTuples, sizeof(float8), compareDoubles);
-	qsort(lowYs, in->nTuples, sizeof(float8), compareDoubles);
-	qsort(highYs, in->nTuples, sizeof(float8), compareDoubles);
-
 	median = in->nTuples / 2;
 
+	qselect(lowXs, in->nTuples, median, sizeof(float8), compareDoubles);
+	qselect(highXs, in->nTuples, median, sizeof(float8), compareDoubles);
+	qselect(lowYs, in->nTuples, median, sizeof(float8), compareDoubles);
+	qselect(highYs, in->nTuples, median, sizeof(float8), compareDoubles);
+
 	centroid = palloc(sizeof(BOX));
 
 	centroid->low.x = lowXs[median];
diff --git a/src/backend/utils/adt/rangetypes_spgist.c b/src/backend/utils/adt/rangetypes_spgist.c
index f29de6a..3128a94 100644
--- a/src/backend/utils/adt/rangetypes_spgist.c
+++ b/src/backend/utils/adt/rangetypes_spgist.c
@@ -204,6 +204,7 @@ spg_range_quad_picksplit(PG_FUNCTION_ARGS)
 	int			i;
 	int			j;
 	int			nonEmptyCount;
+	int			median;
 	RangeType  *centroid;
 	bool		empty;
 	TypeCacheEntry *typcache;
@@ -257,15 +258,16 @@ spg_range_quad_picksplit(PG_FUNCTION_ARGS)
 		PG_RETURN_VOID();
 	}
 
+	median = nonEmptyCount / 2;
 	/* Sort range bounds in order to find medians */
-	qsort_arg(lowerBounds, nonEmptyCount, sizeof(RangeBound),
-			  bound_cmp, typcache);
-	qsort_arg(upperBounds, nonEmptyCount, sizeof(RangeBound),
-			  bound_cmp, typcache);
+	qselect_arg(lowerBounds, nonEmptyCount, median, sizeof(RangeBound),
+			bound_cmp, typcache);
+	qselect_arg(upperBounds, nonEmptyCount, median, sizeof(RangeBound),
+			bound_cmp, typcache);
 
 	/* Construct "centroid" range from medians of lower and upper bounds */
-	centroid = range_serialize(typcache, &lowerBounds[nonEmptyCount / 2],
-							   &upperBounds[nonEmptyCount / 2], false);
+	centroid = range_serialize(typcache, &lowerBounds[median],
+							   &upperBounds[median], false);
 	out->hasPrefix = true;
 	out->prefixDatum = RangeTypePGetDatum(centroid);
 
diff --git a/src/include/lib/sort_template.h b/src/include/lib/sort_template.h
index f52627d..39bc9e0 100644
--- a/src/include/lib/sort_template.h
+++ b/src/include/lib/sort_template.h
@@ -15,6 +15,7 @@
  *
  *	  - ST_SORT - the name of a sort function to be generated
  *	  - ST_ELEMENT_TYPE - type of the referenced elements
+ *	  - ST_TOP_K - if defined the sort will stop when find the top nth element
  *	  - ST_DECLARE - if defined the functions and types are declared
  *	  - ST_DEFINE - if defined the functions and types are defined
  *	  - ST_SCOPE - scope (e.g. extern, static inline) for functions
@@ -172,6 +173,24 @@
 #define ST_SORT_INVOKE_ARG
 #endif
 
+/*
+ * If the user only want to get the top k element, we can stop the sort once we
+ * found the target.
+ */
+#ifdef ST_TOP_K
+#define ST_TOP_K_IN_RIGHT(a_) a_ > target
+#define ST_TOP_K_IN_LEFT(a_) a_ <= target
+#define ST_RESET_TOP_K(a_) target = target - (a_)
+#define ST_SORT_PROTO_NTH , size_t target
+#define ST_SORT_INVOKE_NTH , target
+#else
+#define ST_TOP_K_IN_RIGHT(a_) true
+#define ST_TOP_K_IN_LEFT(a_) true
+#define ST_RESET_TOP_K(a_)
+#define ST_SORT_PROTO_NTH
+#define ST_SORT_INVOKE_NTH
+#endif
+
 #ifdef ST_DECLARE
 
 #ifdef ST_COMPARE_RUNTIME_POINTER
@@ -181,6 +200,7 @@ typedef int (*ST_COMPARATOR_TYPE_NAME) (const ST_ELEMENT_TYPE *,
 
 /* Declare the sort function.  Note optional arguments at end. */
 ST_SCOPE void ST_SORT(ST_ELEMENT_TYPE * first, size_t n
+					  ST_SORT_PROTO_NTH
 					  ST_SORT_PROTO_ELEMENT_SIZE
 					  ST_SORT_PROTO_COMPARE
 					  ST_SORT_PROTO_ARG);
@@ -218,6 +238,7 @@ ST_SCOPE void ST_SORT(ST_ELEMENT_TYPE * first, size_t n
 			ST_SORT_INVOKE_ARG)
 #define DO_SORT(a_, n_)													\
 	ST_SORT((a_), (n_)													\
+			ST_SORT_INVOKE_NTH											\
 			ST_SORT_INVOKE_ELEMENT_SIZE									\
 			ST_SORT_INVOKE_COMPARE										\
 			ST_SORT_INVOKE_ARG)
@@ -277,6 +298,7 @@ ST_SWAPN(ST_POINTER_TYPE * a, ST_POINTER_TYPE * b, size_t n)
  */
 ST_SCOPE void
 ST_SORT(ST_ELEMENT_TYPE * data, size_t n
+		ST_SORT_PROTO_NTH
 		ST_SORT_PROTO_ELEMENT_SIZE
 		ST_SORT_PROTO_COMPARE
 		ST_SORT_PROTO_ARG)
@@ -290,7 +312,9 @@ ST_SORT(ST_ELEMENT_TYPE * data, size_t n
 			   *pm,
 			   *pn;
 	size_t		d1,
-				d2;
+				d2,
+				nd1,
+				nd2;
 	int			r,
 				presorted;
 
@@ -371,30 +395,38 @@ loop:
 	DO_SWAPN(pb, pn - d1, d1);
 	d1 = pb - pa;
 	d2 = pd - pc;
+
+	nd1 = d1 / ST_POINTER_STEP;
+	nd2 = d2 / ST_POINTER_STEP;
+
 	if (d1 <= d2)
 	{
 		/* Recurse on left partition, then iterate on right partition */
-		if (d1 > ST_POINTER_STEP)
-			DO_SORT(a, d1 / ST_POINTER_STEP);
-		if (d2 > ST_POINTER_STEP)
+		if (d1 > ST_POINTER_STEP && ST_TOP_K_IN_RIGHT(nd1))
+			DO_SORT(a, nd1);
+		if (d2 > ST_POINTER_STEP && ST_TOP_K_IN_LEFT(n - nd2))
 		{
 			/* Iterate rather than recurse to save stack space */
 			/* DO_SORT(pn - d2, d2 / ST_POINTER_STEP) */
+			ST_RESET_TOP_K(n - nd2);
 			a = pn - d2;
-			n = d2 / ST_POINTER_STEP;
+			n = nd2;
 			goto loop;
 		}
 	}
 	else
 	{
 		/* Recurse on right partition, then iterate on left partition */
-		if (d2 > ST_POINTER_STEP)
-			DO_SORT(pn - d2, d2 / ST_POINTER_STEP);
-		if (d1 > ST_POINTER_STEP)
+		if (d2 > ST_POINTER_STEP && ST_TOP_K_IN_LEFT(n - nd2))
+		{
+			ST_RESET_TOP_K(n - nd2);
+			DO_SORT(pn - d2, nd2);
+		}
+		if (d1 > ST_POINTER_STEP && ST_TOP_K_IN_RIGHT(nd1))
 		{
 			/* Iterate rather than recurse to save stack space */
 			/* DO_SORT(a, d1 / ST_POINTER_STEP) */
-			n = d1 / ST_POINTER_STEP;
+			n = nd1;
 			goto loop;
 		}
 	}
@@ -429,3 +461,9 @@ loop:
 #undef ST_SORT_PROTO_ELEMENT_SIZE
 #undef ST_SWAP
 #undef ST_SWAPN
+#undef ST_TOP_K
+#undef ST_TOP_K_IN_RIGHT
+#undef ST_TOP_K_IN_LEFT
+#undef ST_RESET_TOP_K
+#undef ST_SORT_PROTO_NTH
+#undef ST_SORT_INVOKE_NTH
diff --git a/src/include/port.h b/src/include/port.h
index 82f63de..14899be 100644
--- a/src/include/port.h
+++ b/src/include/port.h
@@ -508,6 +508,11 @@ typedef int (*qsort_arg_comparator) (const void *a, const void *b, void *arg);
 extern void qsort_arg(void *base, size_t nel, size_t elsize,
 					  qsort_arg_comparator cmp, void *arg);
 
+extern void qselect_arg(void *base, size_t nel, size_t target, size_t elsize,
+					  qsort_arg_comparator cmp, void *arg);
+extern void qselect(void *base, size_t nel, size_t target, size_t elsize,
+					int (*cmp) (const void *, const void *));
+
 extern void *bsearch_arg(const void *key, const void *base,
 						 size_t nmemb, size_t size,
 						 int (*compar) (const void *, const void *, void *),
diff --git a/src/port/Makefile b/src/port/Makefile
index 52dbf57..c82cd65 100644
--- a/src/port/Makefile
+++ b/src/port/Makefile
@@ -54,6 +54,7 @@ OBJS = \
 	pgstrcasecmp.o \
 	pgstrsignal.o \
 	pqsignal.o \
+	qselect.o \
 	qsort.o \
 	qsort_arg.o \
 	quotes.o \
diff --git a/src/port/qselect.c b/src/port/qselect.c
new file mode 100644
index 0000000..c8689d8
--- /dev/null
+++ b/src/port/qselect.c
@@ -0,0 +1,24 @@
+/*
+ *	qselect.c: standard quickselect algorithm
+ */
+
+#include "c.h"
+
+#define ST_SORT qselect
+#define ST_TOP_K
+#define ST_ELEMENT_TYPE_VOID
+#define ST_COMPARE_RUNTIME_POINTER
+#define ST_SCOPE
+#define ST_DECLARE
+#define ST_DEFINE
+#include "lib/sort_template.h"
+
+#define ST_SORT qselect_arg
+#define ST_TOP_K
+#define ST_ELEMENT_TYPE_VOID
+#define ST_COMPARATOR_TYPE_NAME qsort_arg_comparator
+#define ST_COMPARE_RUNTIME_POINTER
+#define ST_COMPARE_ARG_TYPE void
+#define ST_SCOPE
+#define ST_DEFINE
+#include "lib/sort_template.h"
diff --git a/src/tools/msvc/Mkvcbuild.pm b/src/tools/msvc/Mkvcbuild.pm
index 233ddbf..6f64a72 100644
--- a/src/tools/msvc/Mkvcbuild.pm
+++ b/src/tools/msvc/Mkvcbuild.pm
@@ -109,7 +109,7 @@ sub mkvcbuild
 	  dirent.c dlopen.c getopt.c getopt_long.c link.c
 	  pread.c preadv.c pwrite.c pwritev.c pg_bitutils.c
 	  pg_strong_random.c pgcheckdir.c pgmkdirp.c pgsleep.c pgstrcasecmp.c
-	  pqsignal.c mkdtemp.c qsort.c qsort_arg.c bsearch_arg.c quotes.c system.c
+	  pqsignal.c mkdtemp.c qselect.c qsort.c qsort_arg.c bsearch_arg.c quotes.c system.c
 	  strerror.c tar.c thread.c
 	  win32env.c win32error.c win32security.c win32setlocale.c win32stat.c);
 
-- 
2.27.0

