This adds a check when incrementing the shared count and the weak
count, and will trap if the count would overflow. It also doubles the
effective range of the counts for most 64-bit targets.

The counter type, _Atomic_word, is usually a signed 32-bit int (except
on sparcv9 Solaris, where it is a signed 64-bit long). The return type of
std::shared_ptr::use_count() is long. For targets where long is wider
than _Atomic_word (most 64-bit targets) we can treat the _Atomic_word
reference counts as unsigned and allow them to wrap around from their
most positive value to their most negative value without any problems.
The logic that operates on the counts only cares if they are zero or
non-zero, and never performs relational comparisons. The atomic
fetch_add operations on integers are required by the standard to behave
like unsigned types, so that overflow is well-defined:

  "the result is as if the object value and parameters were converted to
  their corresponding unsigned types, the computation performed on those
  types, and the result converted back to the signed type."

So if we allow the counts to wrap around to negative values, all we need
to do is cast the value to make_unsigned_t<_Atomic_word> before
returning it as long from the use_count() function.
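
For illustration, here is a minimal standalone sketch (not part of the
patch) of how that cast recovers the expected value, assuming
_Atomic_word is a 32-bit int and long is 64 bits wide; the names below
are made up:

  #include <cassert>
  #include <limits>
  #include <type_traits>

  using Word = int;   // stands in for _Atomic_word (assumed 32 bits)

  long get_use_count(Word count)
  {
    // Treat the stored value as unsigned so that a count which has
    // wrapped past INT_MAX (and is now negative) converts back to the
    // positive value users expect from use_count().
    return static_cast<std::make_unsigned_t<Word>>(count);
  }

  int main()
  {
    // A use count of 2147483648 (INT_MAX + 1) is stored as INT_MIN
    // once the counter has wrapped around.
    Word wrapped = std::numeric_limits<Word>::min();
    assert(get_use_count(wrapped) == 2147483648L);
  }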

In practice even exceeding INT_MAX is extremely unlikely, as it would
require billions of shared_ptr or weak_ptr objects to have been
constructed and never destroyed. However, if that does happen, we now
have twice the range before the count wraps around to zero and causes
problems.

This change also makes the increment operations check for the
counter's maximum value, and trap rather than increment past it. The
maximum is the value at which one more increment would produce an
invalid use_count(): either the most positive value of _Atomic_word,
or, for targets where the counters are now allowed to wrap around to
negative values, -1, because incrementing -1 overflows the counter to
zero.
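
As a rough standalone model of that rule (illustrative names only, not
the real class; the first static_assert assumes a 64-bit long):

  #include <limits>

  template<typename Word>
  constexpr Word trap_threshold(bool single_threaded_policy)
  {
    // Wrapping into negative values is only usable when long is wider
    // than the counter and the increments are atomic (well-defined on
    // overflow), i.e. not the _S_single policy.
    return (sizeof(long) > sizeof(Word) && !single_threaded_policy)
             ? Word(-1)
             : std::numeric_limits<Word>::max();
  }

  // With int counters and a 64-bit long, trap when the old value is
  // -1, i.e. just before the count would wrap around to zero.
  static_assert(trap_threshold<int>(false) == -1, "");
  // The _S_single policy has to trap at INT_MAX instead.
  static_assert(trap_threshold<int>(true)
                == std::numeric_limits<int>::max(), "");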

libstdc++-v3/ChangeLog:

        PR libstdc++/71945
        * include/bits/shared_ptr_base.h (_Sp_counted_base::_S_chk):
        Trap if a reference count cannot be incremented any higher.
        (_Sp_counted_base::_M_add_ref_copy): Use _S_chk.
        (_Sp_counted_base::_M_weak_add_ref): Check for overflow and
        trap.
        (_Sp_counted_base<...>::_M_add_ref_lock_nothrow): Use _S_chk.
        (_Sp_counted_base::_M_get_use_count): Cast _M_use_count to
        unsigned before returning as long.
---

This improves safety for std::shared_ptr and std::weak_ptr, removing a
possible source of undefined behaviour (although only in extreme cases
where the reference count somehow exceeds INT_MAX).

However, in a very simple benchmark I see a significant impact on
performance, which is not entirely surprising given that this patch
adds a test to every increment. Maybe we could change [[__unlikely__]]
on those branches to use __builtin_expect_with_probability instead,
with a very, very small probability?
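
For instance, something along these lines (a sketch only; the function
and parameter names and the 1e-12 value are placeholders, not tuned or
taken from the patch):

  // The same check written with an explicit probability hint instead
  // of [[__unlikely__]]: tell GCC the branch is taken with probability
  // (essentially) zero.
  inline void
  check_no_overflow(int count, int max)
  {
    if (__builtin_expect_with_probability(count == max, 1, 1e-12))
      __builtin_trap();
  }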

I like the idea of trapping on these overflows, but we need to benchmark
carefully and decide if it's worth the overhead, or if the overhead can
be reduced. Maybe this patch should be split into two: one which adds
the unsigned cast in use_count(), doubling the usable range of
reference counts without affecting codegen, and a second which adds
the overflow detection and the trap.

Either way (as the complete patch, or just casting to unsigned so that
negative values can be used), it depends on the PR 121148 fix to avoid
UB from signed overflow in the atomics.

Tested x86_64-linux.

 libstdc++-v3/include/bits/shared_ptr_base.h | 62 +++++++++++++++++++--
 1 file changed, 58 insertions(+), 4 deletions(-)

diff --git a/libstdc++-v3/include/bits/shared_ptr_base.h b/libstdc++-v3/include/bits/shared_ptr_base.h
index fb868e7afc36..ccc7cef3cb5d 100644
--- a/libstdc++-v3/include/bits/shared_ptr_base.h
+++ b/libstdc++-v3/include/bits/shared_ptr_base.h
@@ -148,7 +148,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       // Increment the use count (used when the count is greater than zero).
       void
       _M_add_ref_copy()
-      { __gnu_cxx::__atomic_add_dispatch(&_M_use_count, 1); }
+      { _S_chk(__gnu_cxx::__exchange_and_add_dispatch(&_M_use_count, 1)); }
 
       // Increment the use count if it is non-zero, throw otherwise.
       void
@@ -200,7 +200,15 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       // Increment the weak count.
       void
       _M_weak_add_ref() noexcept
-      { __gnu_cxx::__atomic_add_dispatch(&_M_weak_count, 1); }
+      {
+       // _M_weak_count can always use negative values (except for _S_single)
+       // because only _M_use_count can be observed. See _S_chk for details.
+       constexpr _Atomic_word __max = _Lp != _S_single
+             ? -1 : __gnu_cxx::__int_traits<_Atomic_word>::__max;
+
+       if (__gnu_cxx::__exchange_and_add_dispatch(&_M_weak_count, 1) == __max)
+         [[__unlikely__]] __builtin_trap();
+      }
 
       // Decrement the weak count.
       void
@@ -224,15 +232,52 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       long
       _M_get_use_count() const noexcept
       {
+       // If long is wider than _Atomic_word then we can treat _Atomic_word
+       // as unsigned, and so double its usable range. If the widths are the
+       // same then casting to unsigned and then to long is a no-op.
+       using _Up = typename make_unsigned<_Atomic_word>::type;
+
         // No memory barrier is used here so there is no synchronization
         // with other threads.
-        return __atomic_load_n(&_M_use_count, __ATOMIC_RELAXED);
+       return (_Up) __atomic_load_n(&_M_use_count, __ATOMIC_RELAXED);
       }
 
     private:
       _Sp_counted_base(_Sp_counted_base const&) = delete;
       _Sp_counted_base& operator=(_Sp_counted_base const&) = delete;
 
+      // Called when incrementing _M_use_count to cause a trap on overflow.
+      // This should be passed the value of the counter before the increment.
+      static void
+      _S_chk(_Atomic_word __count)
+      {
+       // __max is the maximum allowed value for the shared reference count.
+       // All valid reference count values need to fit into [0,LONG_MAX)
+       // because users can observe the count via shared_ptr::use_count().
+       //
+       // When long is wider than _Atomic_word, _M_use_count can go negative
+       // and the cast in _M_get_use_count() will turn it into a
+       // positive value suitable for returning to users. The implementation
+       // only cares whether _M_use_count reaches zero after a decrement,
+       // so negative values are not a problem internally.
+       // So when possible, use -1 for __max (incrementing past that would
+       // overflow _M_use_count to 0, which means an empty shared_ptr).
+       //
+       // When long is not wider than _Atomic_word, negative counts would
+       // not fit in [0,LONG_MAX) after casting to unsigned, so use_count()
+       // would return invalid negative values, which is not allowed.
+       // So __max is just the type's maximum positive value.
+       //
+       // The _S_single policy cannot use negative counts, because it uses
+       // non-atomic increments with undefined behaviour on signed overflow.
+       constexpr _Atomic_word __max
+         = sizeof(long) > sizeof(_Atomic_word) && _Lp != _S_single
+             ? -1 : __gnu_cxx::__int_traits<_Atomic_word>::__max;
+
+       if (__count == __max) [[__unlikely__]]
+         __builtin_trap();
+      }
+
       _Atomic_word  _M_use_count;     // #shared
       _Atomic_word  _M_weak_count;    // #weak + (#shared != 0)
     };
@@ -244,6 +289,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
     {
       if (_M_use_count == 0)
        return false;
+      _S_chk(_M_use_count);
       ++_M_use_count;
       return true;
     }
@@ -254,8 +300,15 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
     _M_add_ref_lock_nothrow() noexcept
     {
       __gnu_cxx::__scoped_lock sentry(*this);
-      if (__gnu_cxx::__exchange_and_add_dispatch(&_M_use_count, 1) == 0)
+      if (auto __c = __gnu_cxx::__exchange_and_add_dispatch(&_M_use_count, 1))
+       _S_chk(__c);
+      else
        {
+         // Count was zero, so we cannot lock it to get a shared_ptr.
+         // Reset to zero. This isn't racy, because there are no shared_ptr
+         // objects using this count and any other weak_ptr objects using it
+         // must call this function to modify _M_use_count, so would be
+         // synchronized by the mutex.
          _M_use_count = 0;
          return false;
        }
@@ -279,6 +332,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
       while (!__atomic_compare_exchange_n(&_M_use_count, &__count, __count + 1,
                                          true, __ATOMIC_ACQ_REL,
                                          __ATOMIC_RELAXED));
+      _S_chk(__count);
       return true;
     }
 
-- 
2.50.1
