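check_shl_overflow() currently evaluates (@a << @s) in an unsigned long long accumulator. When callers pass __int128/u128 values, @a is truncated to 64 bits and the shifted intermediate wraps modulo 2^64 before the comparison. As a result, the helper reports overflow whenever the result needs more than 64 bits, even when *@d is wide enough to hold the full shift. In addition, the existing 8 * sizeof(*d) bound allows shift counts of 64 or more for a 128-bit *@d, and shifting the 64-bit accumulator by its full width or more is undefined behaviour.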
Introduce __shl_eval_type() to derive the internal evaluation type from @a and *@d. When the compiler provides __int128 (__SIZEOF_INT128__ is defined), it selects unsigned __int128 as the accumulator if the promoted sum of @a and *@d is wider than 64 bits; otherwise, and on compilers without __int128, the accumulator stays unsigned long long. This keeps the accumulator unsigned (avoiding UB when left-shifting negative signed values), leaves the evaluation type and therefore the generated code unchanged for all existing users whose operands are 64 bits or narrower, and fixes the spurious overflow reports for 128-bit shift users.

Signed-off-by: Rafael V. Volkmer <[email protected]>
---
 include/linux/overflow.h | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/include/linux/overflow.h b/include/linux/overflow.h
index 725f95f7e416..ca8252e625d5 100644
--- a/include/linux/overflow.h
+++ b/include/linux/overflow.h
@@ -175,6 +175,26 @@ static inline bool __must_check __must_check_overflow(bool overflow)
 	__val; \
 })
 
+/**
+ * __shl_eval_type() - Choose evaluation type for shift checks
+ * @a: value to be shifted
+ * @d: destination pointer
+ *
+ * Returns the internal type used by check_shl_overflow() to evaluate
+ * (@a << @s), widening to unsigned __int128 when available and either
+ * @a or *@d promote wider than 64 bits, otherwise using unsigned long long.
+ */
+#if defined(__SIZEOF_INT128__)
+#define __shl_eval_type(a, d) \
+	typeof(__builtin_choose_expr( \
+		sizeof((a) + (typeof(*(d)))0) > sizeof(unsigned long long), \
+		(unsigned __int128)0, \
+		0ULL))
+#else
+#define __shl_eval_type(a, d) \
+	typeof(0ULL + (a) + (typeof(*(d)))0)
+#endif
+
 /**
  * check_shl_overflow() - Calculate a left-shifted value and check overflow
  * @a: Value to be shifted
@@ -199,7 +219,7 @@ static inline bool __must_check __must_check_overflow(bool overflow)
 	typeof(a) _a = a; \
 	typeof(s) _s = s; \
 	typeof(d) _d = d; \
-	unsigned long long _a_full = _a; \
+	__shl_eval_type(_a, _d) _a_full = (__shl_eval_type(_a, _d))_a; \
 	unsigned int _to_shift = \
 		is_non_negative(_s) && _s < 8 * sizeof(*d) ? _s : 0; \
 	*_d = (_a_full << _to_shift); \
-- 
2.43.0
