Following a discussion on changes in the RDMA subsystem [1], there is a use
for helpers that allow an overflow-safe comparison between the result of a
sum or a difference of two variables and some third operand. The classic
use case is checking that an end address is in range, given a start address
and a length, but there are probably others. Add helpers that perform a
mathematically correct comparison regardless of the types being used.

[1] https://lore.kernel.org/all/[email protected]/

Signed-off-by: Michael Margolin <[email protected]>
---
 include/linux/overflow.h   |  34 ++++++++
 lib/tests/overflow_kunit.c | 171 +++++++++++++++++++++++++++++++++++++
 2 files changed, 205 insertions(+)

diff --git a/include/linux/overflow.h b/include/linux/overflow.h
index 154ed0dbb43f..c79bde6cee3d 100644
--- a/include/linux/overflow.h
+++ b/include/linux/overflow.h
@@ -99,6 +99,23 @@ static inline bool __must_check __must_check_overflow(bool 
overflow)
                *__ptr = wrapping_add(typeof(var), *__ptr, offset);     \
        })
 
+/**
+ * overflow_safe_add_cmp() - Calculate addition of first two parameters and
+ * compare to the third parameter. The returned result is correct even when the
+ * sum wraps-around.
+ * @a: first addend
+ * @b: second addend
+ * @c: comparison operand
+ *
+ * The comparison is done as if on infinite precision integers, regardless of
+ * the (possibly mixed signed/unsigned) types of the arguments. When the exact
+ * sum does not fit in the type of @c it lies beyond every value @c can hold,
+ * so only the sign of the exact sum decides the result: sums above the type's
+ * maximum compare greater, sums below its minimum compare less. Each argument
+ * is evaluated exactly once.
+ *
+ * Return: -1 if @a + @b < @c, 0 if @a + @b == @c and 1 if @a + @b > @c.
+ */
+#define overflow_safe_add_cmp(a, b, c)                                        \
+       ({                                                                    \
+               /* __auto_type drops qualifiers, so const args still work */  \
+               __auto_type __a = (a);                                        \
+               __auto_type __b = (b);                                        \
+               __auto_type __c = (c);                                        \
+               typeof(__a) __wa;                                             \
+               typeof(__c) __val;                                            \
+               /* Exact sum fits in typeof(@c): compare it directly. */      \
+               !__builtin_add_overflow(__a, __b, &__val) ?                   \
+                       (__val > __c ? 1 : __val == __c ? 0 : -1) :           \
+               /* Else the sum is non-zero and only its sign matters. */     \
+               !__builtin_add_overflow(__a, __b, &__wa) ?                    \
+                       (__wa > 0 ? 1 : -1) :                                 \
+               /*                                                            \
+                * The sum also left typeof(@a)'s range, which can only      \
+                * happen in the direction @b points; "> 0 || == 0" avoids   \
+                * -Wtype-limits when @b is unsigned.                        \
+                */                                                           \
+                       ((__b > 0 || __b == 0) ? 1 : -1);                     \
+       })
+
 /**
  * check_sub_overflow() - Calculate subtraction with overflow checking
  * @a: minuend; value to subtract from
@@ -145,6 +162,23 @@ static inline bool __must_check __must_check_overflow(bool 
overflow)
                *__ptr = wrapping_sub(typeof(var), *__ptr, offset);     \
        })
 
+/**
+ * overflow_safe_sub_cmp() - Calculate subtraction of second parameter from
+ * the first and compare to the third parameter. The returned result is correct
+ * even when the subtraction wraps-around.
+ * @a: minuend; value to subtract from
+ * @b: subtrahend; value to subtract from @a
+ * @c: comparison operand
+ *
+ * The comparison is done as if on infinite precision integers, regardless of
+ * the (possibly mixed signed/unsigned) types of the arguments. When the exact
+ * difference does not fit in the type of @c it lies beyond every value @c can
+ * hold, so only the sign of the exact difference decides the result:
+ * differences above the type's maximum compare greater, differences below its
+ * minimum compare less. Each argument is evaluated exactly once.
+ *
+ * Return: -1 if @a - @b < @c, 0 if @a - @b == @c and 1 if @a - @b > @c.
+ */
+#define overflow_safe_sub_cmp(a, b, c)                                        \
+       ({                                                                    \
+               /* __auto_type drops qualifiers, so const args still work */  \
+               __auto_type __a = (a);                                        \
+               __auto_type __b = (b);                                        \
+               __auto_type __c = (c);                                        \
+               typeof(__a) __wa;                                             \
+               typeof(__c) __val;                                            \
+               /* Exact difference fits in typeof(@c): compare directly. */  \
+               !__builtin_sub_overflow(__a, __b, &__val) ?                   \
+                       (__val > __c ? 1 : __val == __c ? 0 : -1) :           \
+               /* Else the difference is non-zero; only its sign matters. */ \
+               !__builtin_sub_overflow(__a, __b, &__wa) ?                    \
+                       (__wa > 0 ? 1 : -1) :                                 \
+               /*                                                            \
+                * The difference also left typeof(@a)'s range: downwards    \
+                * for a non-negative @b, upwards for a negative one.         \
+                * "> 0 || == 0" avoids -Wtype-limits for unsigned @b.        \
+                */                                                           \
+                       ((__b > 0 || __b == 0) ? -1 : 1);                     \
+       })
+
 /**
  * check_mul_overflow() - Calculate multiplication with overflow checking
  * @a: first factor
diff --git a/lib/tests/overflow_kunit.c b/lib/tests/overflow_kunit.c
index 19cb03b25dc5..f6d94e55a593 100644
--- a/lib/tests/overflow_kunit.c
+++ b/lib/tests/overflow_kunit.c
@@ -1225,6 +1225,174 @@ static void DEFINE_FLEX_test(struct kunit *test)
        KUNIT_EXPECT_EQ(test, __member_size(two_but_zero->array), 
array_size_override);
 }
 
+/*
+ * Common test macro for all overflow_safe_*_cmp functions. Expects a
+ * "struct kunit *test" and an "int count" in the calling scope; @expected is
+ * evaluated once and compared against the macro's -1/0/1 result.
+ */
+#define TEST_SAFE_CMP(op, a, b, c, expected) do {                              \
+       int _expected = (expected);                                             \
+       int result = overflow_safe_ ## op ## _cmp(a, b, c);                     \
+       KUNIT_EXPECT_EQ_MSG(test, result, _expected,                            \
+               "expected overflow_safe_" #op "_cmp(%s, %s, %s) == %d, got %d\n", \
+               #a, #b, #c, _expected, result);                                 \
+       count++;                                                                \
+} while (0)
+
+static void overflow_safe_add_cmp_test(struct kunit *test)
+{
+       int count = 0;
+
+       /* Basic addition comparisons without overflow */
+       TEST_SAFE_CMP(add, 5, 3, 8, 0);
+       TEST_SAFE_CMP(add, 5, 3, 7, 1);
+       TEST_SAFE_CMP(add, 5, 3, 9, -1);
+       TEST_SAFE_CMP(add, 0, 0, 0, 0);
+       TEST_SAFE_CMP(add, 1, 0, 1, 0);
+       TEST_SAFE_CMP(add, 0, 1, 1, 0);
+
+       /* Test with unsigned 8-bit values */
+       TEST_SAFE_CMP(add, (u8)100, (u8)50, (u8)150, 0);
+       TEST_SAFE_CMP(add, (u8)100, (u8)50, (u8)149, 1);
+       TEST_SAFE_CMP(add, (u8)100, (u8)50, (u8)151, -1);
+
+       /* Test overflow cases with u8 - should behave as if on infinite types */
+       TEST_SAFE_CMP(add, (u8)200, (u8)100, (u8)44, 1);   /* 300 > 44, not 44 == 44 */
+       TEST_SAFE_CMP(add, (u8)255, (u8)1, (u8)0, 1);      /* 256 > 0, not 0 == 0 */
+       TEST_SAFE_CMP(add, (u8)255, (u8)2, (u8)1, 1);      /* 257 > 1, not 1 == 1 */
+       TEST_SAFE_CMP(add, (u8)200, (u8)200, (u8)145, 1);  /* 400 > 145, not 145 == 145 */
+
+       /* Test with signed values */
+       TEST_SAFE_CMP(add, -5, 3, -2, 0);
+       TEST_SAFE_CMP(add, -5, 3, -3, 1);
+       TEST_SAFE_CMP(add, -5, 3, -1, -1);
+
+       /* Test signed overflow cases - should behave as if on infinite types */
+       TEST_SAFE_CMP(add, (s8)100, (s8)50, (s8)-106, 1);  /* 150 > -106, not -106 == -106 */
+       TEST_SAFE_CMP(add, (s8)127, (s8)1, (s8)-128, 1);   /* 128 > -128, not -128 == -128 */
+       TEST_SAFE_CMP(add, (s8)127, (s8)127, (s8)-2, 1);   /* 254 > -2, not -2 == -2 */
+
+       /* Test with larger types */
+       TEST_SAFE_CMP(add, (u32)0xFFFFFFFF, (u32)1, (u32)0, 1);
+       TEST_SAFE_CMP(add, (u32)0xFFFFFFFF, (u32)2, (u32)1, 1);
+       TEST_SAFE_CMP(add, (u32)0xFFFFFFFF, (u32)0xFFFFFFFF, (u32)0, 1);
+
+       /* Test real-world use case: checking if an address range is within bounds */
+       TEST_SAFE_CMP(add, 0xFFFFFFFFFFFF0000ULL, 0x1000, 0x1000000ULL, 1);
+       TEST_SAFE_CMP(add, 0xFFFFFFFFFFFF0000ULL, 0x10000, 0x1000000ULL, 1);
+
+       /* Test with mixed types */
+       TEST_SAFE_CMP(add, (u8)200, (u16)300, (u32)500, 0);
+       TEST_SAFE_CMP(add, (u8)200, (u16)300, (u32)499, 1);
+       TEST_SAFE_CMP(add, (u8)200, (u16)300, (u32)501, -1);
+
+       /* Test with mixed types that would overflow in the smaller type */
+       TEST_SAFE_CMP(add, (u8)200, (u8)100, (u16)300, 0);
+       TEST_SAFE_CMP(add, (u8)255, (u16)1, (u32)256, 0);
+       TEST_SAFE_CMP(add, (u8)255, (u16)256, (u32)511, 0);
+
+       /* Test with mixed signed/unsigned types */
+       TEST_SAFE_CMP(add, (s8)-10, (u8)20, (s16)10, 0);
+       TEST_SAFE_CMP(add, (s8)-10, (u8)5, (s16)-5, 0);
+       TEST_SAFE_CMP(add, (s8)-128, (u16)128, (s32)0, 0);
+       TEST_SAFE_CMP(add, (s8)-128, (u16)127, (s32)-1, 0);
+
+       /* Test with mixed types where the result would overflow in c's type */
+       TEST_SAFE_CMP(add, (u32)40000, (u32)30000, (u16)4464, 1);  /* 70000 - 65536 */
+       TEST_SAFE_CMP(add, (s32)30000, (s32)10000, (s16)-25536, 1); /* 40000 - 65536 */
+
+       /*
+        * Sums that fall below the minimum of c's type must compare less than
+        * c, never greater: overflow is not always upwards.
+        */
+       TEST_SAFE_CMP(add, (s8)-100, (s8)-100, (s8)0, -1);            /* -200 < 0 */
+       TEST_SAFE_CMP(add, (s8)-128, (s8)-1, (s8)-128, -1);           /* -129 < -128 */
+       TEST_SAFE_CMP(add, (s32)-30000, (s32)-10000, (s16)25536, -1); /* -40000 < 25536 */
+       TEST_SAFE_CMP(add, -5, 3, 0U, -1);                            /* -2 < 0, unsigned c */
+
+       kunit_info(test, "%d overflow_safe_add_cmp tests finished\n", count);
+}
+
+static void overflow_safe_sub_cmp_test(struct kunit *test)
+{
+       int count = 0;
+
+       /* Basic subtraction comparisons without overflow */
+       TEST_SAFE_CMP(sub, 8, 3, 5, 0);
+       TEST_SAFE_CMP(sub, 8, 3, 4, 1);
+       TEST_SAFE_CMP(sub, 8, 3, 6, -1);
+       TEST_SAFE_CMP(sub, 5, 5, 0, 0);
+       TEST_SAFE_CMP(sub, 10, 0, 10, 0);
+
+       /* Test with unsigned 8-bit values */
+       TEST_SAFE_CMP(sub, (u8)150, (u8)50, (u8)100, 0);
+       TEST_SAFE_CMP(sub, (u8)150, (u8)50, (u8)99, 1);
+       TEST_SAFE_CMP(sub, (u8)150, (u8)50, (u8)101, -1);
+
+       /* Test underflow cases with u8 - should behave as if on infinite types */
+       TEST_SAFE_CMP(sub, (u8)50, (u8)100, (u8)200, -1);
+       TEST_SAFE_CMP(sub, (u8)0, (u8)1, (u8)255, -1);     /* -1 < 255, not 255 == 255 */
+       TEST_SAFE_CMP(sub, (u8)0, (u8)2, (u8)254, -1);     /* -2 < 254, not 254 == 254 */
+       TEST_SAFE_CMP(sub, (u8)10, (u8)20, (u8)246, -1);   /* -10 < 246, not 246 == 246 */
+
+       /* Test with signed values */
+       TEST_SAFE_CMP(sub, 5, -3, 8, 0);
+       TEST_SAFE_CMP(sub, -5, -3, -2, 0);
+       TEST_SAFE_CMP(sub, -5, 3, -8, 0);
+
+       /* Test signed underflow cases - should behave as if on infinite types */
+       TEST_SAFE_CMP(sub, (s8)-100, (s8)50, (s8)106, -1);  /* -150 < 106, not 106 == 106 */
+       TEST_SAFE_CMP(sub, (s8)-128, (s8)1, (s8)127, -1);   /* -129 < 127, not 127 == 127 */
+       TEST_SAFE_CMP(sub, (s8)-128, (s8)127, (s8)1, -1);   /* -255 < 1, not 1 == 1 */
+
+       /* Test with larger types */
+       TEST_SAFE_CMP(sub, (u32)0, (u32)1, (u32)0xFFFFFFFF, -1);
+       TEST_SAFE_CMP(sub, (u32)1, (u32)2, (u32)0xFFFFFFFF, -1);
+       TEST_SAFE_CMP(sub, (u32)0, (u32)0xFFFFFFFF, (u32)1, -1);
+
+       /* Test with mixed types */
+       TEST_SAFE_CMP(sub, (u16)500, (u8)200, (u32)300, 0);
+       TEST_SAFE_CMP(sub, (u16)500, (u8)200, (u32)299, 1);
+       TEST_SAFE_CMP(sub, (u16)500, (u8)200, (u32)301, -1);
+
+       /* Test with mixed types that would underflow in the smaller type */
+       TEST_SAFE_CMP(sub, (u8)50, (u16)100, (s32)-50, 0);
+       TEST_SAFE_CMP(sub, (u8)0, (u16)1, (s32)-1, 0);
+       TEST_SAFE_CMP(sub, (u8)10, (u16)20, (s32)-10, 0);
+
+       /* Test with mixed signed/unsigned types */
+       TEST_SAFE_CMP(sub, (s8)10, (u8)20, (s16)-10, 0);
+       TEST_SAFE_CMP(sub, (u8)20, (s8)-10, (s16)30, 0);
+       TEST_SAFE_CMP(sub, (s16)-1000, (u16)1000, (s32)-2000, 0);
+
+       /* Test with mixed types where the result would underflow in c's type */
+       TEST_SAFE_CMP(sub, (u16)1000, (u32)40000, (u16)26536, -1);  /* -39000 + 65536 */
+       TEST_SAFE_CMP(sub, (s16)-30000, (s16)10000, (s16)25536, -1); /* -40000 + 65536 */
+
+       /*
+        * Differences that exceed the maximum of c's type must compare greater
+        * than c, never less: overflow is not always downwards.
+        */
+       TEST_SAFE_CMP(sub, (s8)100, (s8)-100, (s8)0, 1);       /* 200 > 0 */
+       TEST_SAFE_CMP(sub, (s8)127, (s8)-1, (s8)127, 1);       /* 128 > 127 */
+       TEST_SAFE_CMP(sub, (u16)500, (u8)100, (u8)255, 1);     /* 400 > 255 */
+       TEST_SAFE_CMP(sub, (u32)70000, (u32)0, (u16)10, 1);    /* 70000 > 10 */
+
+       kunit_info(test, "%d overflow_safe_sub_cmp tests finished\n", count);
+}
+
+#undef TEST_SAFE_CMP
+
+static void overflow_safe_cmp_side_effects_test(struct kunit *test)
+{
+       int a_orig = 5, a_test = 5;
+       int b_orig = 3, b_test = 3;
+       int c_orig = 8, c_test = 8;
+       int count = 0;
+
+       /*
+        * Each argument must be evaluated exactly once, and the comparison
+        * must use the values from before the post-increments: 5 + 3 == 8.
+        * Checking the result (instead of discarding it) pins that down too.
+        */
+       KUNIT_EXPECT_EQ(test, overflow_safe_add_cmp(a_test++, b_test++, c_test++), 0);
+       KUNIT_EXPECT_EQ_MSG(test, a_test, a_orig + 1,
+               "overflow_safe_add_cmp had unexpected side effect on first argument");
+       KUNIT_EXPECT_EQ_MSG(test, b_test, b_orig + 1,
+               "overflow_safe_add_cmp had unexpected side effect on second argument");
+       KUNIT_EXPECT_EQ_MSG(test, c_test, c_orig + 1,
+               "overflow_safe_add_cmp had unexpected side effect on third argument");
+       count += 4;
+
+       /* Same for subtraction: 5 - 3 == 2, which is less than 8. */
+       a_test = a_orig;
+       b_test = b_orig;
+       c_test = c_orig;
+       KUNIT_EXPECT_EQ(test, overflow_safe_sub_cmp(a_test++, b_test++, c_test++), -1);
+       KUNIT_EXPECT_EQ_MSG(test, a_test, a_orig + 1,
+               "overflow_safe_sub_cmp had unexpected side effect on first argument");
+       KUNIT_EXPECT_EQ_MSG(test, b_test, b_orig + 1,
+               "overflow_safe_sub_cmp had unexpected side effect on second argument");
+       KUNIT_EXPECT_EQ_MSG(test, c_test, c_orig + 1,
+               "overflow_safe_sub_cmp had unexpected side effect on third argument");
+       count += 4;
+
+       kunit_info(test, "%d overflow safe comparison side effects tests finished\n", count);
+}
+
 static struct kunit_case overflow_test_cases[] = {
        KUNIT_CASE(u8_u8__u8_overflow_test),
        KUNIT_CASE(s8_s8__s8_overflow_test),
@@ -1248,6 +1416,9 @@ static struct kunit_case overflow_test_cases[] = {
        KUNIT_CASE(same_type_test),
        KUNIT_CASE(castable_to_type_test),
        KUNIT_CASE(DEFINE_FLEX_test),
+       KUNIT_CASE(overflow_safe_add_cmp_test),
+       KUNIT_CASE(overflow_safe_sub_cmp_test),
+       KUNIT_CASE(overflow_safe_cmp_side_effects_test),
        {}
 };
 
-- 
2.47.1


Reply via email to