https://github.com/python/cpython/commit/016f4b5975a251d353c6e9d399b2b6a234735f25
commit: 016f4b5975a251d353c6e9d399b2b6a234735f25
branch: main
author: Jelle Zijlstra <[email protected]>
committer: JelleZijlstra <[email protected]>
date: 2024-08-11T23:42:57Z
summary:

gh-119180: Improvements to ForwardRef.evaluate (#122210)

Noticed some issues while writing documentation for this method.

files:
M Lib/annotationlib.py
M Lib/test/test_annotationlib.py
M Lib/typing.py

diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py
index 141e31bbf910e3..8f2a93be915832 100644
--- a/Lib/annotationlib.py
+++ b/Lib/annotationlib.py
@@ -74,7 +74,7 @@ def __init_subclass__(cls, /, *args, **kwds):
     def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
         """Evaluate the forward reference and return the value.
 
-        If the forward reference is not evaluatable, raise an exception.
+        If the forward reference cannot be evaluated, raise an exception.
         """
         if self.__forward_evaluated__:
             return self.__forward_value__
@@ -89,12 +89,10 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
                 return value
         if owner is None:
             owner = self.__owner__
-        if type_params is None and owner is None:
-            raise TypeError("Either 'type_params' or 'owner' must be provided")
 
-        if self.__forward_module__ is not None:
+        if globals is None and self.__forward_module__ is not None:
             globals = getattr(
-                sys.modules.get(self.__forward_module__, None), "__dict__", globals
+                sys.modules.get(self.__forward_module__, None), "__dict__", None
             )
         if globals is None:
             globals = self.__globals__
@@ -112,14 +110,14 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
 
         if locals is None:
             locals = {}
-            if isinstance(self.__owner__, type):
-                locals.update(vars(self.__owner__))
+            if isinstance(owner, type):
+                locals.update(vars(owner))
 
-        if type_params is None and self.__owner__ is not None:
+        if type_params is None and owner is not None:
             # "Inject" type parameters into the local namespace
             # (unless they are shadowed by assignments *in* the local namespace),
             # as a way of emulating annotation scopes when calling `eval()`
-            type_params = getattr(self.__owner__, "__type_params__", None)
+            type_params = getattr(owner, "__type_params__", None)
 
         # type parameters require some special handling,
         # as they exist in their own scope
@@ -129,7 +127,14 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
         # but should in turn be overridden by names in the class scope
         # (which here are called `globalns`!)
         if type_params is not None:
-            globals, locals = dict(globals), dict(locals)
+            if globals is None:
+                globals = {}
+            else:
+                globals = dict(globals)
+            if locals is None:
+                locals = {}
+            else:
+                locals = dict(locals)
             for param in type_params:
                 param_name = param.__name__
                 if not self.__forward_is_class__ or param_name not in globals:
diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py
index e4dcdb6b58d009..db8350c2746983 100644
--- a/Lib/test/test_annotationlib.py
+++ b/Lib/test/test_annotationlib.py
@@ -5,7 +5,7 @@
 import itertools
 import pickle
 import unittest
-from annotationlib import Format, get_annotations, get_annotate_function
+from annotationlib import Format, ForwardRef, get_annotations, get_annotate_function
 from typing import Unpack
 
 from test.test_inspect import inspect_stock_annotations
@@ -250,6 +250,46 @@ def test_special_attrs(self):
             with self.assertRaises(TypeError):
                 pickle.dumps(fr, proto)
 
+    def test_evaluate_with_type_params(self):
+        class Gen[T]:
+            alias = int
+
+        with self.assertRaises(NameError):
+            ForwardRef("T").evaluate()
+        with self.assertRaises(NameError):
+            ForwardRef("T").evaluate(type_params=())
+        with self.assertRaises(NameError):
+            ForwardRef("T").evaluate(owner=int)
+
+        T, = Gen.__type_params__
+        self.assertIs(ForwardRef("T").evaluate(type_params=Gen.__type_params__), T)
+        self.assertIs(ForwardRef("T").evaluate(owner=Gen), T)
+
+        with self.assertRaises(NameError):
+            ForwardRef("alias").evaluate(type_params=Gen.__type_params__)
+        self.assertIs(ForwardRef("alias").evaluate(owner=Gen), int)
+        # If you pass custom locals, we don't look at the owner's locals
+        with self.assertRaises(NameError):
+            ForwardRef("alias").evaluate(owner=Gen, locals={})
+        # But if the name exists in the locals, it works
+        self.assertIs(
+            ForwardRef("alias").evaluate(owner=Gen, locals={"alias": str}), str
+        )
+
+    def test_fwdref_with_module(self):
+        self.assertIs(ForwardRef("Format", module=annotationlib).evaluate(), Format)
+
+        with self.assertRaises(NameError):
+            # If globals are passed explicitly, we don't look at the module dict
+            ForwardRef("Format", module=annotationlib).evaluate(globals={})
+
+    def test_fwdref_value_is_cached(self):
+        fr = ForwardRef("hello")
+        with self.assertRaises(NameError):
+            fr.evaluate()
+        self.assertIs(fr.evaluate(globals={"hello": str}), str)
+        self.assertIs(fr.evaluate(), str)
+
 
 class TestGetAnnotations(unittest.TestCase):
     def test_builtin_type(self):
diff --git a/Lib/typing.py b/Lib/typing.py
index 39a14ae6f83c28..bcb7bec23a9aa1 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -474,6 +474,10 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f
         _deprecation_warning_for_no_type_params_passed("typing._eval_type")
         type_params = ()
     if isinstance(t, ForwardRef):
+        # If the forward_ref has __forward_module__ set, evaluate() infers the globals
+        # from the module, and it will probably pick better than the globals we have here.
+        if t.__forward_module__ is not None:
+            globalns = None
         return evaluate_forward_ref(t, globals=globalns, locals=localns,
                                     type_params=type_params, owner=owner,
                                     _recursive_guard=recursive_guard, format=format)

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to