commit 4e606150f4e5cb105b83c2b6b2e77d8e55a35a2c
Author: Joel Jakobsson <joel@compiler.org>
Date:   Tue Jun 13 22:13:06 2023 +0200

    Refactor hashset comparison functions and improve ordering logic
    
    This commit introduces a significant refactor of the comparison functions
    for the hashset type, specifically hashset_cmp(), hashset_lt(), hashset_le(),
    hashset_gt(), and hashset_ge().
    
    We addressed an issue with the previous implementation where the comparison
    functions did not correctly handle hashsets with empty positions. This
    resulted in incorrect ordering of hashsets when using these comparison functions.
    
    Now, the comparison functions correctly iterate over the elements in the
    hashsets, advancing the iterator for either hashset only when a valid element
    is found. This effectively skips over any empty positions, resulting in the
    comparison and ordering of elements that are actually present in the hashset.
    
    Additionally, the four functions hashset_lt(), hashset_le(), hashset_gt(),
    and hashset_ge() have been simplified by using the hashset_cmp() function,
    which reduces redundancy in the codebase.
    
    Tests have been updated to reflect these changes and ensure the correct
    functionality of the revised comparison functions.

diff --git a/hashset.c b/hashset.c
index ec3ed44..2e6a51f 100644
--- a/hashset.c
+++ b/hashset.c
@@ -863,53 +863,15 @@ Datum hashset_hash(PG_FUNCTION_ARGS)
 Datum
 hashset_lt(PG_FUNCTION_ARGS)
 {
-	hashset_t *a = PG_GETARG_HASHSET(0);
-	hashset_t *b = PG_GETARG_HASHSET(1);
+    hashset_t *a = PG_GETARG_HASHSET(0);
+    hashset_t *b = PG_GETARG_HASHSET(1);
+    int32 cmp;
 
-	char *bitmap_a, *bitmap_b;
-	int32 *values_a, *values_b;
-	int i;
+    cmp = DatumGetInt32(DirectFunctionCall2(hashset_cmp,
+                                            PointerGetDatum(a),
+                                            PointerGetDatum(b)));
 
-	bitmap_a = a->data;
-	values_a = (int32 *)(a->data + CEIL_DIV(a->maxelements, 8));
-
-	bitmap_b = b->data;
-	values_b = (int32 *)(b->data + CEIL_DIV(b->maxelements, 8));
-
-	/* Compare elements in a lexicographic manner */
-	for (i = 0; i < Min(a->maxelements, b->maxelements); i++)
-	{
-		int byte = (i / 8);
-		int bit = (i % 8);
-
-		bool has_elem_a = bitmap_a[byte] & (0x01 << bit);
-		bool has_elem_b = bitmap_b[byte] & (0x01 << bit);
-
-		if (has_elem_a && has_elem_b)
-		{
-			int32 value_a = values_a[i];
-			int32 value_b = values_b[i];
-
-			if (value_a < value_b)
-				PG_RETURN_BOOL(true);
-			else if (value_a > value_b)
-				PG_RETURN_BOOL(false);
-
-		}
-		else if (has_elem_a)
-			PG_RETURN_BOOL(false);
-		else if (has_elem_b)
-			PG_RETURN_BOOL(true);
-	}
-
-	/*
-	 * If all elements are equal up to the shorter hashset length,
-	 * then the hashset with fewer elements is considered "less than"
-	 */
-	if (a->maxelements < b->maxelements)
-		PG_RETURN_BOOL(true);
-	else
-		PG_RETURN_BOOL(false);
+    PG_RETURN_BOOL(cmp < 0);
 }
 
 
@@ -918,13 +880,13 @@ hashset_le(PG_FUNCTION_ARGS)
 {
 	hashset_t *a = PG_GETARG_HASHSET(0);
 	hashset_t *b = PG_GETARG_HASHSET(1);
+	int32 cmp;
 
-	/* If a equals b, or a is less than b, then a is less than or equal to b */
-	if (DatumGetBool(DirectFunctionCall2(hashset_equals, PointerGetDatum(a), PointerGetDatum(b))) ||
-		DatumGetBool(DirectFunctionCall2(hashset_lt, PointerGetDatum(a), PointerGetDatum(b))))
-		PG_RETURN_BOOL(true);
+	cmp = DatumGetInt32(DirectFunctionCall2(hashset_cmp,
+											PointerGetDatum(a),
+											PointerGetDatum(b)));
 
-	PG_RETURN_BOOL(false);
+	PG_RETURN_BOOL(cmp <= 0);
 }
 
 
@@ -933,12 +895,13 @@ hashset_gt(PG_FUNCTION_ARGS)
 {
 	hashset_t *a = PG_GETARG_HASHSET(0);
 	hashset_t *b = PG_GETARG_HASHSET(1);
+	int32 cmp;
 
-	/* If a is not less than or equal to b, then a is greater than b */
-	if (!DatumGetBool(DirectFunctionCall2(hashset_le, PointerGetDatum(a), PointerGetDatum(b))))
-		PG_RETURN_BOOL(true);
+	cmp = DatumGetInt32(DirectFunctionCall2(hashset_cmp,
+											PointerGetDatum(a),
+											PointerGetDatum(b)));
 
-	PG_RETURN_BOOL(false);
+	PG_RETURN_BOOL(cmp > 0);
 }
 
 
@@ -947,13 +910,13 @@ hashset_ge(PG_FUNCTION_ARGS)
 {
 	hashset_t *a = PG_GETARG_HASHSET(0);
 	hashset_t *b = PG_GETARG_HASHSET(1);
+	int32 cmp;
 
-	/* If a equals b, or a is not less than b, then a is greater than or equal to b */
-	if (DatumGetBool(DirectFunctionCall2(hashset_equals, PointerGetDatum(a), PointerGetDatum(b))) ||
-		!DatumGetBool(DirectFunctionCall2(hashset_lt, PointerGetDatum(a), PointerGetDatum(b))))
-		PG_RETURN_BOOL(true);
+	cmp = DatumGetInt32(DirectFunctionCall2(hashset_cmp,
+											PointerGetDatum(a),
+											PointerGetDatum(b)));
 
-	PG_RETURN_BOOL(false);
+	PG_RETURN_BOOL(cmp >= 0);
 }
 
 
@@ -965,7 +928,7 @@ hashset_cmp(PG_FUNCTION_ARGS)
 
 	char *bitmap_a, *bitmap_b;
 	int32 *values_a, *values_b;
-	int i;
+	int i = 0, j = 0;
 
 	bitmap_a = a->data;
 	values_a = (int32 *)(a->data + CEIL_DIV(a->maxelements, 8));
@@ -973,45 +936,52 @@ hashset_cmp(PG_FUNCTION_ARGS)
 	bitmap_b = b->data;
 	values_b = (int32 *)(b->data + CEIL_DIV(b->maxelements, 8));
 
-	/*
-	 * Iterate through the elements
-	 */
-	for (i = 0; i < Min(a->maxelements, b->maxelements); i++)
+	/* Iterate over the elements in each hashset independently */
+	while(i < a->maxelements && j < b->maxelements)
 	{
-		int byte = (i / 8);
-		int bit = (i % 8);
+		int byte_a = (i / 8);
+		int bit_a = (i % 8);
 
-		bool a_contains = bitmap_a[byte] & (0x01 << bit);
-		bool b_contains = bitmap_b[byte] & (0x01 << bit);
+		int byte_b = (j / 8);
+		int bit_b = (j % 8);
 
-		if (a_contains && b_contains)
-		{
-			int32 value_a = values_a[i];
-			int32 value_b = values_b[i];
+		bool has_elem_a = bitmap_a[byte_a] & (0x01 << bit_a);
+		bool has_elem_b = bitmap_b[byte_b] & (0x01 << bit_b);
 
-			if (value_a < value_b)
-				PG_RETURN_INT32(-1);
-			else if (value_a > value_b)
-				PG_RETURN_INT32(1);
-		}
-		else if (a_contains)
+		int32 value_a;
+		int32 value_b;
+
+		/* Skip if position is empty in either bitmap */
+		if (!has_elem_a)
 		{
-			PG_RETURN_INT32(1);
+			i++;
+			continue;
 		}
-		else if (b_contains)
+
+		if (!has_elem_b)
 		{
+			j++;
+			continue;
+		}
+
+		/* Both hashsets have an element at the current position */
+		value_a = values_a[i++];
+		value_b = values_b[j++];
+
+		if (value_a < value_b)
 			PG_RETURN_INT32(-1);
-		}
+		else if (value_a > value_b)
+			PG_RETURN_INT32(1);
 	}
 
 	/*
-	 * If we got here, the elements in the overlap are equal.
-	 * We need to check the number of elements to determine the order.
+	 * If all compared elements are equal,
+	 * then compare the remaining elements in the larger hashset
 	 */
-	if (a->nelements < b->nelements)
-		PG_RETURN_INT32(-1);
-	else if (a->nelements > b->nelements)
+	if (i < a->maxelements)
 		PG_RETURN_INT32(1);
+	else if (j < b->maxelements)
+		PG_RETURN_INT32(-1);
 	else
 		PG_RETURN_INT32(0);
 }
diff --git a/test/expected/order.out b/test/expected/order.out
index 8d3ff61..f8f8d5b 100644
--- a/test/expected/order.out
+++ b/test/expected/order.out
@@ -116,3 +116,59 @@ SELECT '{2}'::hashset <> '{3}'::hashset; -- true
  t
 (1 row)
 
+CREATE OR REPLACE FUNCTION generate_random_hashset(num_elements INT)
+RETURNS hashset AS $$
+DECLARE
+  element INT;
+  random_set hashset;
+BEGIN
+  random_set := hashset_init(num_elements);
+
+  FOR i IN 1..num_elements LOOP
+    element := floor(random() * 1000)::INT;
+    random_set := hashset_add(random_set, element);
+  END LOOP;
+
+  RETURN random_set;
+END;
+$$ LANGUAGE plpgsql;
+SELECT setseed(0.123465);
+ setseed 
+---------
+ 
+(1 row)
+
+CREATE TABLE hashset_order_test AS
+SELECT generate_random_hashset(3) AS hashset_col
+FROM generate_series(1,1000)
+UNION
+SELECT generate_random_hashset(2)
+FROM generate_series(1,1000);
+SELECT hashset_col
+FROM hashset_order_test
+ORDER BY hashset_col
+LIMIT 20;
+ hashset_col 
+-------------
+ {2,857}
+ {3,85,507}
+ {3,569,891}
+ {3,867,610}
+ {5,207,283}
+ {5,283,972}
+ {5,550,991}
+ {5,606,148}
+ {5,734}
+ {5,862}
+ {5,872}
+ {6,431}
+ {6,444,929}
+ {6,521}
+ {6,592}
+ {7,878,229}
+ {8,14,859}
+ {8,605}
+ {8,654}
+ {8,698}
+(20 rows)
+
diff --git a/test/sql/order.sql b/test/sql/order.sql
index e6c9323..1780c0b 100644
--- a/test/sql/order.sql
+++ b/test/sql/order.sql
@@ -27,3 +27,34 @@ SELECT '{2}'::hashset = '{3}'::hashset; -- false
 SELECT '{2}'::hashset <> '{1}'::hashset; -- true
 SELECT '{2}'::hashset <> '{2}'::hashset; -- false
 SELECT '{2}'::hashset <> '{3}'::hashset; -- true
+
+CREATE OR REPLACE FUNCTION generate_random_hashset(num_elements INT)
+RETURNS hashset AS $$
+DECLARE
+  element INT;
+  random_set hashset;
+BEGIN
+  random_set := hashset_init(num_elements);
+
+  FOR i IN 1..num_elements LOOP
+    element := floor(random() * 1000)::INT;
+    random_set := hashset_add(random_set, element);
+  END LOOP;
+
+  RETURN random_set;
+END;
+$$ LANGUAGE plpgsql;
+
+SELECT setseed(0.123465);
+
+CREATE TABLE hashset_order_test AS
+SELECT generate_random_hashset(3) AS hashset_col
+FROM generate_series(1,1000)
+UNION
+SELECT generate_random_hashset(2)
+FROM generate_series(1,1000);
+
+SELECT hashset_col
+FROM hashset_order_test
+ORDER BY hashset_col
+LIMIT 20;
