Replaced heap sort algorithm with faster introspective sort algorithm.

Signed-off-by: Andrey Abramov <st5...@yandex.ru>
---
v1: The introspective sort algorithm is faster than heap sort (for example, on
my machine on 100MB of random data it was consistently almost twice as fast),
and unlike qsort it does not have a quadratic worst case.

 lib/sort.c | 127 ++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 112 insertions(+), 15 deletions(-)

diff --git a/lib/sort.c b/lib/sort.c
index d6b7a202b0b6..edc09287e572 100644
--- a/lib/sort.c
+++ b/lib/sort.c
@@ -1,8 +1,9 @@
 // SPDX-License-Identifier: GPL-2.0
 /*
- * A fast, small, non-recursive O(nlog n) sort for the Linux kernel
+ * A fast, recursive O(n log n) sort for the Linux kernel
  *
  * Jan 23 2005  Matt Mackall <m...@selenic.com>
+ * Feb 5  2019  Andrey Abramov <st5...@yandex.ru> (introspective sort)
  */
 
 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
@@ -48,33 +49,22 @@ static void generic_swap(void *a, void *b, int size)
  * @num: number of elements
  * @size: size of each element
  * @cmp_func: pointer to comparison function
- * @swap_func: pointer to swap function or NULL
+ * @swap_func: pointer to swap function
  *
- * This function does a heapsort on the given array. You may provide a
- * swap_func function optimized to your element type.
+ * This function does a heapsort on the given array.
  *
  * Sorting time is O(n log n) both on average and worst-case. While
  * qsort is about 20% faster on average, it suffers from exploitable
  * O(n*n) worst-case behavior and extra memory requirements that make
  * it less suitable for kernel use.
  */
-
-void sort(void *base, size_t num, size_t size,
+void heapsort(void *base, size_t num, size_t size,
          int (*cmp_func)(const void *, const void *),
          void (*swap_func)(void *, void *, int size))
 {
        /* pre-scale counters for performance */
        int i = (num/2 - 1) * size, n = num * size, c, r;
 
-       if (!swap_func) {
-               if (size == 4 && alignment_ok(base, 4))
-                       swap_func = u32_swap;
-               else if (size == 8 && alignment_ok(base, 8))
-                       swap_func = u64_swap;
-               else
-                       swap_func = generic_swap;
-       }
-
        /* heapify */
        for ( ; i >= 0; i -= size) {
                for (r = i; r * 2 + size < n; r  = c) {
@@ -103,4 +93,111 @@ void sort(void *base, size_t num, size_t size,
        }
 }
 
/*
 * introsort - introspective sort worker (quicksort + heapsort fallback)
 * @base: start of the partition being sorted
 * @num: number of elements in this partition
 * @size: size of each element
 * @cmp_func: pointer to comparison function
 * @swap_func: pointer to swap function
 * @max_depth: recursion depth at which we fall back to heapsort
 * @depth: current recursion depth
 *
 * Quicksort with a median-of-three pivot.  When the recursion depth
 * exceeds @max_depth the input is considered pathological and the
 * remaining partition is handed to heapsort, keeping the worst case
 * at O(n log n).
 */
static void introsort(void *base, size_t num, size_t size,
	  int (*cmp_func)(const void *, const void *),
	  void (*swap_func)(void *, void *, int size),
	  unsigned int max_depth, unsigned int depth)
{
	void *last, *pivot, *i, *j;

	if (num <= 1)
		return;

	/* Pivot choices degenerated: heapsort the rest of this partition. */
	if (depth >= max_depth) {
		heapsort(base, num, size, cmp_func, swap_func);
		return;
	}

	/*
	 * Compute these only after the num check: with num == 0 the
	 * expression (num - 1) * size would produce a far out-of-bounds
	 * pointer, which is undefined behaviour even if never dereferenced.
	 */
	last = base + (num - 1) * size;
	pivot = base + ((num - 1) / 2) * size;
	i = base;
	j = last;

	/* Make pivot the median of the first, middle and last elements. */
	if ((cmp_func(pivot, base) >= 0 && cmp_func(base, last) >= 0)
		|| (cmp_func(last, base) >= 0 && cmp_func(base, pivot) >= 0)) {
		pivot = base;
	} else if ((cmp_func(pivot, last) >= 0 && cmp_func(last, base) >= 0)
		|| (cmp_func(base, last) >= 0 && cmp_func(last, pivot) >= 0)) {
		pivot = last;
	}

	/* Hoare-style partition around the pivot element. */
	while (true) {
		while (cmp_func(i, pivot) < 0 && i < last)
			i += size;
		while (cmp_func(j, pivot) > 0 && j > base)
			j -= size;

		if (i >= j)
			break;

		swap_func(i, j, size);

		/* The swap may have moved the pivot element itself. */
		if (i == pivot)
			pivot = j;
		else if (j == pivot)
			pivot = i;

		j -= size;
		i += size;
	}

	/*
	 * Recurse into both partitions.  Stack depth is bounded by
	 * @max_depth regardless of how unbalanced the splits get.
	 */
	if (i < last)
		introsort(i, (last - i) / size + 1,
			size, cmp_func, swap_func, max_depth, depth + 1);
	if (base < j)
		introsort(base, (j - base) / size + 1,
			size, cmp_func, swap_func, max_depth, depth + 1);
}
+
/*
 * log2_up - ceiling of the base-2 logarithm of @val
 * @val: value to take the logarithm of
 *
 * Returns 0 for val <= 1.  Used to derive the introsort recursion
 * limit (2 * log2(n)).  The max_reachable_log check stops the shift
 * before it would move past the top bit of size_t, which would be
 * undefined behaviour.
 *
 * NOTE(review): the kernel already provides ilog2()/fls_long() in
 * <linux/log2.h>; consider using those instead of open-coding this —
 * TODO confirm they are usable from lib/sort.c.
 */
static unsigned int log2_up(size_t val)
{
	unsigned int log = 0;
	size_t current = 1;
	unsigned int max_reachable_log = sizeof(val) * 8 - 1;

	while (current < val) {
		current <<= 1;
		log++;
		if (log == max_reachable_log && current < val)
			return max_reachable_log + 1;
	}

	return log;
}
+
+
+/**
+ * sort - sort an array of elements
+ * @base: pointer to data to sort
+ * @num: number of elements
+ * @size: size of each element
+ * @cmp_func: pointer to comparison function
+ * @swap_func: pointer to swap function or NULL
+ *
+ * This function does an introspective sort on the given array. You may
+ * provide a swap_func function optimized to your element type.
+ *
+ * Introspective sort combines quicksort and heapsort, so it is faster
+ * than heapsort on average, but unlike qsort it does not have a
+ * quadratic worst case.
+ */
+void sort(void *base, size_t num, size_t size,
+         int (*cmp_func)(const void *, const void *),
+         void (*swap_func)(void *, void *, int size))
+{
+       if (!swap_func) {
+               if (size == 4 && alignment_ok(base, 4))
+                       swap_func = u32_swap;
+               else if (size == 8 && alignment_ok(base, 8))
+                       swap_func = u64_swap;
+               else
+                       swap_func = generic_swap;
+       }
+
+       introsort(base, num, size, cmp_func, swap_func, log2_up(num) * 2, 1);
+}
 EXPORT_SYMBOL(sort);
-- 
2.20.1

Reply via email to