https://github.com/python/cpython/commit/5c6e3b715082bfccd0b4cf2bb1c18e8b1afcad3e
commit: 5c6e3b715082bfccd0b4cf2bb1c18e8b1afcad3e
branch: main
author: Eric V. Smith <[email protected]>
committer: ericvsmith <[email protected]>
date: 2024-09-24T21:26:26-04:00
summary:

gh-90562: Support zero argument super with dataclasses when slots=True (gh-124455)

Co-authored-by: @wookie184
Co-authored-by: Carl Meyer <[email protected]>

files:
A Misc/NEWS.d/next/Library/2024-09-23-18-26-17.gh-issue-90562.Yj566G.rst
M Doc/library/dataclasses.rst
M Lib/dataclasses.py
M Lib/test/test_dataclasses/__init__.py

diff --git a/Doc/library/dataclasses.rst b/Doc/library/dataclasses.rst
index cfca11afbd2e41..1457392ce6e86c 100644
--- a/Doc/library/dataclasses.rst
+++ b/Doc/library/dataclasses.rst
@@ -187,13 +187,6 @@ Module contents
      If :attr:`!__slots__` is already defined in the class, then :exc:`TypeError`
      is raised.
 
-    .. warning::
-        Calling no-arg :func:`super` in dataclasses using ``slots=True``
-        will result in the following exception being raised:
-        ``TypeError: super(type, obj): obj must be an instance or subtype of type``.
-        The two-arg :func:`super` is a valid workaround.
-        See :gh:`90562` for full details.
-
     .. warning::
        Passing parameters to a base class :meth:`~object.__init_subclass__`
        when using ``slots=True`` will result in a :exc:`TypeError`.
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index 6255d8980974e0..f5cb97edaf72cd 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -1218,9 +1218,31 @@ def _get_slots(cls):
             raise TypeError(f"Slots of '{cls.__name__}' cannot be determined")
 
 
+def _update_func_cell_for__class__(f, oldcls, newcls):
+    # Returns True if we update a cell, else False.
+    if f is None:
+        # f will be None in the case of a property where not all of
+        # fget, fset, and fdel are used.  Nothing to do in that case.
+        return False
+    try:
+        idx = f.__code__.co_freevars.index("__class__")
+    except ValueError:
+        # This function doesn't reference __class__, so nothing to do.
+        return False
+    # Fix the cell to point to the new class, if it's already pointing
+    # at the old class.  I'm not convinced that the "is oldcls" test
+    # is needed, but other than performance can't hurt.
+    closure = f.__closure__[idx]
+    if closure.cell_contents is oldcls:
+        closure.cell_contents = newcls
+        return True
+    return False
+
+
 def _add_slots(cls, is_frozen, weakref_slot):
-    # Need to create a new class, since we can't set __slots__
-    #  after a class has been created.
+    # Need to create a new class, since we can't set __slots__ after a
+    # class has been created, and the @dataclass decorator is called
+    # after the class is created.
 
     # Make sure __slots__ isn't already set.
     if '__slots__' in cls.__dict__:
@@ -1259,18 +1281,37 @@ def _add_slots(cls, is_frozen, weakref_slot):
 
     # And finally create the class.
     qualname = getattr(cls, '__qualname__', None)
-    cls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
+    newcls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
     if qualname is not None:
-        cls.__qualname__ = qualname
+        newcls.__qualname__ = qualname
 
     if is_frozen:
         # Need this for pickling frozen classes with slots.
         if '__getstate__' not in cls_dict:
-            cls.__getstate__ = _dataclass_getstate
+            newcls.__getstate__ = _dataclass_getstate
         if '__setstate__' not in cls_dict:
-            cls.__setstate__ = _dataclass_setstate
-
-    return cls
+            newcls.__setstate__ = _dataclass_setstate
+
+    # Fix up any closures which reference __class__.  This is used to
+    # fix zero argument super so that it points to the correct class
+    # (the newly created one, which we're returning) and not the
+    # original class.  We can break out of this loop as soon as we
+    # make an update, since all closures for a class will share a
+    # given cell.
+    for member in newcls.__dict__.values():
+        # If this is a wrapped function, unwrap it.
+        member = inspect.unwrap(member)
+
+        if isinstance(member, types.FunctionType):
+            if _update_func_cell_for__class__(member, cls, newcls):
+                break
+        elif isinstance(member, property):
+            if (_update_func_cell_for__class__(member.fget, cls, newcls)
+                or _update_func_cell_for__class__(member.fset, cls, newcls)
+                or _update_func_cell_for__class__(member.fdel, cls, newcls)):
+                break
+
+    return newcls
 
 
 def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py
index 6934e88d9d338c..69e86162e0c11a 100644
--- a/Lib/test/test_dataclasses/__init__.py
+++ b/Lib/test/test_dataclasses/__init__.py
@@ -17,7 +17,7 @@
 from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol, DefaultDict
 from typing import get_type_hints
 from collections import deque, OrderedDict, namedtuple, defaultdict
-from functools import total_ordering
+from functools import total_ordering, wraps
 
 import typing       # Needed for the string "typing.ClassVar[int]" to work as an annotation.
 import dataclasses  # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
@@ -4869,5 +4869,129 @@ class A:
         self.assertEqual(fs[0].name, 'x')
 
 
+class TestZeroArgumentSuperWithSlots(unittest.TestCase):
+    def test_zero_argument_super(self):
+        @dataclass(slots=True)
+        class A:
+            def foo(self):
+                super()
+
+        A().foo()
+
+    def test_dunder_class_with_old_property(self):
+        @dataclass(slots=True)
+        class A:
+            def _get_foo(slf):
+                self.assertIs(__class__, type(slf))
+                self.assertIs(__class__, slf.__class__)
+                return __class__
+
+            def _set_foo(slf, value):
+                self.assertIs(__class__, type(slf))
+                self.assertIs(__class__, slf.__class__)
+
+            def _del_foo(slf):
+                self.assertIs(__class__, type(slf))
+                self.assertIs(__class__, slf.__class__)
+
+            foo = property(_get_foo, _set_foo, _del_foo)
+
+        a = A()
+        self.assertIs(a.foo, A)
+        a.foo = 4
+        del a.foo
+
+    def test_dunder_class_with_new_property(self):
+        @dataclass(slots=True)
+        class A:
+            @property
+            def foo(slf):
+                return slf.__class__
+
+            @foo.setter
+            def foo(slf, value):
+                self.assertIs(__class__, type(slf))
+
+            @foo.deleter
+            def foo(slf):
+                self.assertIs(__class__, type(slf))
+
+        a = A()
+        self.assertIs(a.foo, A)
+        a.foo = 4
+        del a.foo
+
+    # Test the parts of a property individually.
+    def test_slots_dunder_class_property_getter(self):
+        @dataclass(slots=True)
+        class A:
+            @property
+            def foo(slf):
+                return __class__
+
+        a = A()
+        self.assertIs(a.foo, A)
+
+    def test_slots_dunder_class_property_setter(self):
+        @dataclass(slots=True)
+        class A:
+            foo = property()
+            @foo.setter
+            def foo(slf, val):
+                self.assertIs(__class__, type(slf))
+
+        a = A()
+        a.foo = 4
+
+    def test_slots_dunder_class_property_deleter(self):
+        @dataclass(slots=True)
+        class A:
+            foo = property()
+            @foo.deleter
+            def foo(slf):
+                self.assertIs(__class__, type(slf))
+
+        a = A()
+        del a.foo
+
+    def test_wrapped(self):
+        def mydecorator(f):
+            @wraps(f)
+            def wrapper(*args, **kwargs):
+                return f(*args, **kwargs)
+            return wrapper
+
+        @dataclass(slots=True)
+        class A:
+            @mydecorator
+            def foo(self):
+                super()
+
+        A().foo()
+
+    def test_remembered_class(self):
+        # Apply the dataclass decorator manually (not when the class
+        # is created), so that we can keep a reference to the
+        # undecorated class.
+        class A:
+            def cls(self):
+                return __class__
+
+        self.assertIs(A().cls(), A)
+
+        B = dataclass(slots=True)(A)
+        self.assertIs(B().cls(), B)
+
+        # This is undesirable behavior, but is a function of how
+        # modifying __class__ in the closure works.  I'm not sure this
+        # should be tested or not: I don't really want to guarantee
+        # this behavior, but I don't want to lose the point that this
+        # is how it works.
+
+        # The underlying class is "broken" by changing the __class__
+        # cell used by A.cls() to B.  This normally isn't a problem, because no
+        # one will be keeping a reference to the underlying class A.
+        self.assertIs(A().cls(), B)
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2024-09-23-18-26-17.gh-issue-90562.Yj566G.rst b/Misc/NEWS.d/next/Library/2024-09-23-18-26-17.gh-issue-90562.Yj566G.rst
new file mode 100644
index 00000000000000..7a389fefc6c54b
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2024-09-23-18-26-17.gh-issue-90562.Yj566G.rst
@@ -0,0 +1,3 @@
+Modify dataclasses to support zero-argument super() when ``slots=True`` is
+specified.  This works by updating the ``__class__`` closure cell shared by
+the class's methods so that it points to the newly created class.

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to