https://github.com/python/cpython/commit/38a99568763604ccec5d5027f0658100ad76876f
commit: 38a99568763604ccec5d5027f0658100ad76876f
branch: main
author: Thomas Grainger <[email protected]>
committer: kumaraditya303 <[email protected]>
date: 2025-01-20T22:23:55+05:30
summary:

gh-128308: pass `**kwargs` to asyncio task_factory (#128768)

Co-authored-by: Kumar Aditya <[email protected]>

files:
A Misc/NEWS.d/next/Library/2025-01-13-07-54-32.gh-issue-128308.kYSDRF.rst
M Doc/library/asyncio-eventloop.rst
M Lib/asyncio/base_events.py
M Lib/asyncio/events.py
M Lib/test/test_asyncio/test_base_events.py
M Lib/test/test_asyncio/test_eager_task_factory.py
M Lib/test/test_asyncio/test_free_threading.py
M Lib/test/test_asyncio/test_taskgroups.py

diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst
index 3bf38a2212c0e0..15ef33e195904d 100644
--- a/Doc/library/asyncio-eventloop.rst
+++ b/Doc/library/asyncio-eventloop.rst
@@ -392,9 +392,9 @@ Creating Futures and Tasks
 
    If *factory* is ``None`` the default task factory will be set.
    Otherwise, *factory* must be a *callable* with the signature matching
-   ``(loop, coro, context=None)``, where *loop* is a reference to the active
+   ``(loop, coro, **kwargs)``, where *loop* is a reference to the active
    event loop, and *coro* is a coroutine object.  The callable
-   must return a :class:`asyncio.Future`-compatible object.
+   must pass on all *kwargs*, and return a :class:`asyncio.Task`-compatible object.
 
 .. method:: loop.get_task_factory()
 
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 85018797db33bb..ed852421e44212 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -458,25 +458,18 @@ def create_future(self):
         """Create a Future object attached to the loop."""
         return futures.Future(loop=self)
 
-    def create_task(self, coro, *, name=None, context=None):
+    def create_task(self, coro, **kwargs):
         """Schedule a coroutine object.
 
         Return a task object.
         """
         self._check_closed()
-        if self._task_factory is None:
-            task = tasks.Task(coro, loop=self, name=name, context=context)
-            if task._source_traceback:
-                del task._source_traceback[-1]
-        else:
-            if context is None:
-                # Use legacy API if context is not needed
-                task = self._task_factory(self, coro)
-            else:
-                task = self._task_factory(self, coro, context=context)
-
-            task.set_name(name)
+        if self._task_factory is not None:
+            return self._task_factory(self, coro, **kwargs)
 
+        task = tasks.Task(coro, loop=self, **kwargs)
+        if task._source_traceback:
+            del task._source_traceback[-1]
         try:
             return task
         finally:
@@ -490,9 +483,10 @@ def set_task_factory(self, factory):
         If factory is None the default task factory will be set.
 
         If factory is a callable, it should have a signature matching
-        '(loop, coro)', where 'loop' will be a reference to the active
-        event loop, 'coro' will be a coroutine object.  The callable
-        must return a Future.
+        '(loop, coro, **kwargs)', where 'loop' will be a reference to the active
+        event loop, 'coro' will be a coroutine object, and **kwargs will be
+        arbitrary keyword arguments that should be passed on to Task.
+        The callable must return a Task.
         """
         if factory is not None and not callable(factory):
             raise TypeError('task factory must be a callable or None')
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index 2ee9870e80f20b..2e45b4fe6fa2dd 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -329,7 +329,7 @@ def create_future(self):
 
     # Method scheduling a coroutine object: create a task.
 
-    def create_task(self, coro, *, name=None, context=None):
+    def create_task(self, coro, **kwargs):
         raise NotImplementedError
 
     # Methods for interacting with threads.
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
index 102c9be0ecf031..8cf1f6891faf97 100644
--- a/Lib/test/test_asyncio/test_base_events.py
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -833,8 +833,8 @@ async def test():
             loop.close()
 
     def test_create_named_task_with_custom_factory(self):
-        def task_factory(loop, coro):
-            return asyncio.Task(coro, loop=loop)
+        def task_factory(loop, coro, **kwargs):
+            return asyncio.Task(coro, loop=loop, **kwargs)
 
         async def test():
             pass
diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py
index dcf9ff716ad399..10450c11b68279 100644
--- a/Lib/test/test_asyncio/test_eager_task_factory.py
+++ b/Lib/test/test_asyncio/test_eager_task_factory.py
@@ -302,6 +302,19 @@ async def run():
 
         self.run_coro(run())
 
+    def test_name(self):
+        name = None
+        async def coro():
+            nonlocal name
+            name = asyncio.current_task().get_name()
+
+        async def main():
+            task = self.loop.create_task(coro(), name="test name")
+            self.assertEqual(name, "test name")
+            await task
+
+        self.run_coro(main())
+
 
 class AsyncTaskCounter:
     def __init__(self, loop, *, task_class, eager):
diff --git a/Lib/test/test_asyncio/test_free_threading.py b/Lib/test/test_asyncio/test_free_threading.py
index 8f4bba5f3b97d9..05106a2c2fe3f6 100644
--- a/Lib/test/test_asyncio/test_free_threading.py
+++ b/Lib/test/test_asyncio/test_free_threading.py
@@ -112,8 +112,8 @@ class TestPyFreeThreading(TestFreeThreading, TestCase):
     all_tasks = staticmethod(asyncio.tasks._py_all_tasks)
     current_task = staticmethod(asyncio.tasks._py_current_task)
 
-    def factory(self, loop, coro, context=None):
-        return asyncio.tasks._PyTask(coro, loop=loop, context=context)
+    def factory(self, loop, coro, **kwargs):
+        return asyncio.tasks._PyTask(coro, loop=loop, **kwargs)
 
 
 @unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
@@ -121,16 +121,16 @@ class TestCFreeThreading(TestFreeThreading, TestCase):
     all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None))
     current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None))
 
-    def factory(self, loop, coro, context=None):
-        return asyncio.tasks._CTask(coro, loop=loop, context=context)
+    def factory(self, loop, coro, **kwargs):
+        return asyncio.tasks._CTask(coro, loop=loop, **kwargs)
 
 
 class TestEagerPyFreeThreading(TestPyFreeThreading):
-    def factory(self, loop, coro, context=None):
-        return asyncio.tasks._PyTask(coro, loop=loop, context=context, eager_start=True)
+    def factory(self, loop, coro, eager_start=True, **kwargs):
+        return asyncio.tasks._PyTask(coro, loop=loop, **kwargs, eager_start=eager_start)
 
 
 @unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
 class TestEagerCFreeThreading(TestCFreeThreading, TestCase):
-    def factory(self, loop, coro, context=None):
-        return asyncio.tasks._CTask(coro, loop=loop, context=context, eager_start=True)
+    def factory(self, loop, coro, eager_start=True, **kwargs):
+        return asyncio.tasks._CTask(coro, loop=loop, **kwargs, eager_start=eager_start)
diff --git a/Lib/test/test_asyncio/test_taskgroups.py b/Lib/test/test_asyncio/test_taskgroups.py
index 870fa8dbbf2714..7859b33532fa27 100644
--- a/Lib/test/test_asyncio/test_taskgroups.py
+++ b/Lib/test/test_asyncio/test_taskgroups.py
@@ -1040,6 +1040,18 @@ class MyKeyboardInterrupt(KeyboardInterrupt):
         self.assertIsNotNone(exc)
         self.assertListEqual(gc.get_referrers(exc), no_other_refs())
 
+    async def test_name(self):
+        name = None
+
+        async def asyncfn():
+            nonlocal name
+            name = asyncio.current_task().get_name()
+
+        async with asyncio.TaskGroup() as tg:
+            tg.create_task(asyncfn(), name="example name")
+
+        self.assertEqual(name, "example name")
+
 
 class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase):
     loop_factory = asyncio.EventLoop
diff --git a/Misc/NEWS.d/next/Library/2025-01-13-07-54-32.gh-issue-128308.kYSDRF.rst b/Misc/NEWS.d/next/Library/2025-01-13-07-54-32.gh-issue-128308.kYSDRF.rst
new file mode 100644
index 00000000000000..efa613876a35fd
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2025-01-13-07-54-32.gh-issue-128308.kYSDRF.rst
@@ -0,0 +1 @@
+Support the *name* keyword argument for eager tasks in :func:`asyncio.loop.create_task`, :func:`asyncio.create_task` and :func:`asyncio.TaskGroup.create_task`, by passing on all *kwargs* to the task factory set by :func:`asyncio.loop.set_task_factory`.

_______________________________________________
Python-checkins mailing list -- python-checkins@python.org
To unsubscribe send an email to python-checkins-leave@python.org
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to