check_shl_overflow() currently evaluates (@a << @s) in an
unsigned long long accumulator. When callers pass __int128/u128
values, the intermediate is truncated to 64 bits before the
comparison, so the helper always reports overflow and returns a
zeroed result even when *@d is wide enough to hold the full shift.

Introduce __shl_eval_type() to derive the internal evaluation type
from @a and *@d. When the compiler provides __int128 (that is, when
__SIZEOF_INT128__ is defined), it promotes the accumulator to
unsigned __int128 if the type of (@a + a zero of *@d's type) is wider
than unsigned long long; otherwise it stays unsigned long long.

This keeps the accumulator unsigned (avoiding UB when left-shifting
negative signed values), preserves existing code generation for all
current 32/64-bit users, and fixes the spurious overflow reporting
for 128-bit shift users.

Signed-off-by: Rafael V. Volkmer <[email protected]>
---
 include/linux/overflow.h | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/include/linux/overflow.h b/include/linux/overflow.h
index 725f95f7e416..ca8252e625d5 100644
--- a/include/linux/overflow.h
+++ b/include/linux/overflow.h
@@ -175,6 +175,26 @@ static inline bool __must_check __must_check_overflow(bool overflow)
                __val;                                          \
        })
 
+/**
+ * __shl_eval_type() - Choose evaluation type for shift checks
+ * @a: value to be shifted
+ * @d: destination pointer
+ *
+ * Expands to the type in which check_shl_overflow() evaluates (@a << @s):
+ * unsigned __int128 when __SIZEOF_INT128__ is defined and (@a + a zero of
+ * *@d's type) is wider than unsigned long long; unsigned long long otherwise.
+ */
+#if defined(__SIZEOF_INT128__)
+#define __shl_eval_type(a, d)                                                  \
+       typeof(__builtin_choose_expr(                                           \
+               sizeof((a) + (typeof(*(d)))0) > sizeof(unsigned long long),     \
+               (unsigned __int128)0,                                           \
+               0ULL))
+#else
+#define __shl_eval_type(a, d)                                                  \
+       typeof(0ULL + (a) + (typeof(*(d)))0)
+#endif
+
 /**
  * check_shl_overflow() - Calculate a left-shifted value and check overflow
  * @a: Value to be shifted
@@ -199,7 +219,7 @@ static inline bool __must_check __must_check_overflow(bool overflow)
        typeof(a) _a = a;                                               \
        typeof(s) _s = s;                                               \
        typeof(d) _d = d;                                               \
-       unsigned long long _a_full = _a;                                \
+       __shl_eval_type(_a, _d) _a_full = (__shl_eval_type(_a, _d))_a;  \
        unsigned int _to_shift =                                        \
                is_non_negative(_s) && _s < 8 * sizeof(*d) ? _s : 0;    \
        *_d = (_a_full << _to_shift);                                   \
-- 
2.43.0


Reply via email to