https://github.com/python/cpython/commit/df4d84c3cdca572f1be8f5dc5ef8ead5351b51fb
commit: df4d84c3cdca572f1be8f5dc5ef8ead5351b51fb
branch: main
author: Laurie O <[email protected]>
committer: gvanrossum <[email protected]>
date: 2024-04-06T07:27:13-07:00
summary:

gh-96471: Add asyncio queue shutdown (#104228)

Co-authored-by: Duprat <[email protected]>

files:
A Misc/NEWS.d/next/Library/2023-05-06-05-00-42.gh-issue-96471.S3X5I-.rst
M Doc/library/asyncio-queue.rst
M Doc/whatsnew/3.13.rst
M Lib/asyncio/queues.py
M Lib/test/test_asyncio/test_queues.py

diff --git a/Doc/library/asyncio-queue.rst b/Doc/library/asyncio-queue.rst
index d86fbc21351e2d..030d4310942d7a 100644
--- a/Doc/library/asyncio-queue.rst
+++ b/Doc/library/asyncio-queue.rst
@@ -62,6 +62,9 @@ Queue
       Remove and return an item from the queue. If queue is empty,
       wait until an item is available.
 
+      Raises :exc:`QueueShutDown` if the queue has been shut down and
+      is empty, or if the queue has been shut down immediately.
+
    .. method:: get_nowait()
 
       Return an item if one is immediately available, else raise
@@ -82,6 +85,8 @@ Queue
       Put an item into the queue. If the queue is full, wait until a
       free slot is available before adding the item.
 
+      Raises :exc:`QueueShutDown` if the queue has been shut down.
+
    .. method:: put_nowait(item)
 
       Put an item into the queue without blocking.
@@ -92,6 +97,21 @@ Queue
 
       Return the number of items in the queue.
 
+   .. method:: shutdown(immediate=False)
+
+      Shut down the queue, making :meth:`~Queue.get` and :meth:`~Queue.put`
+      raise :exc:`QueueShutDown`.
+
+      By default, :meth:`~Queue.get` on a shut-down queue will only
+      raise once the queue is empty. Set *immediate* to true to make
+      :meth:`~Queue.get` raise immediately instead.
+
+      All blocked callers of :meth:`~Queue.put` and :meth:`~Queue.get` will
+      be unblocked. If *immediate* is true, a task is marked as done for each
+      remaining item, which may unblock callers of :meth:`~Queue.join`.
+
+      .. versionadded:: 3.13
+
    .. method:: task_done()
 
       Indicate that a formerly enqueued task is complete.
@@ -105,6 +125,9 @@ Queue
       call was received for every item that had been :meth:`~Queue.put`
       into the queue).
 
+      ``shutdown(immediate=True)`` calls :meth:`task_done` for each
+      remaining item in the queue.
+
       Raises :exc:`ValueError` if called more times than there were
       items placed in the queue.
 
@@ -145,6 +168,14 @@ Exceptions
    on a queue that has reached its *maxsize*.
 
 
+.. exception:: QueueShutDown
+
+   Exception raised when :meth:`~Queue.put` or :meth:`~Queue.get` is
+   called on a queue which has been shut down.
+
+   .. versionadded:: 3.13
+
+
 Examples
 ========
 
diff --git a/Doc/whatsnew/3.13.rst b/Doc/whatsnew/3.13.rst
index e31f0c52d4c5f5..c785d4cfa8fdc3 100644
--- a/Doc/whatsnew/3.13.rst
+++ b/Doc/whatsnew/3.13.rst
@@ -296,6 +296,10 @@ asyncio
   with the tasks being completed.
   (Contributed by Justin Arthur in :gh:`77714`.)
 
+* Add :meth:`asyncio.Queue.shutdown` (along with
+  :exc:`asyncio.QueueShutDown`) for queue termination.
+  (Contributed by Laurie Opperman in :gh:`104228`.)
+
 base64
 ------
 
diff --git a/Lib/asyncio/queues.py b/Lib/asyncio/queues.py
index a9656a6df561ba..b8156704b8fc23 100644
--- a/Lib/asyncio/queues.py
+++ b/Lib/asyncio/queues.py
@@ -1,4 +1,11 @@
-__all__ = ('Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty')
+__all__ = (
+    'Queue',
+    'PriorityQueue',
+    'LifoQueue',
+    'QueueFull',
+    'QueueEmpty',
+    'QueueShutDown',
+)
 
 import collections
 import heapq
@@ -18,6 +25,11 @@ class QueueFull(Exception):
     pass
 
 
+class QueueShutDown(Exception):
+    """Raised when putting onto or getting from a shut-down Queue."""
+    pass
+
+
 class Queue(mixins._LoopBoundMixin):
     """A queue, useful for coordinating producer and consumer coroutines.
 
@@ -41,6 +53,7 @@ def __init__(self, maxsize=0):
         self._finished = locks.Event()
         self._finished.set()
         self._init(maxsize)
+        self._is_shutdown = False
 
     # These three are overridable in subclasses.
 
@@ -81,6 +94,8 @@ def _format(self):
             result += f' _putters[{len(self._putters)}]'
         if self._unfinished_tasks:
             result += f' tasks={self._unfinished_tasks}'
+        if self._is_shutdown:
+            result += ' shutdown'
         return result
 
     def qsize(self):
@@ -112,8 +127,12 @@ async def put(self, item):
 
         Put an item into the queue. If the queue is full, wait until a free
         slot is available before adding item.
+
+        Raises QueueShutDown if the queue has been shut down.
         """
         while self.full():
+            if self._is_shutdown:
+                raise QueueShutDown
             putter = self._get_loop().create_future()
             self._putters.append(putter)
             try:
@@ -125,7 +144,7 @@ async def put(self, item):
                     self._putters.remove(putter)
                 except ValueError:
                     # The putter could be removed from self._putters by a
-                    # previous get_nowait call.
+                    # previous get_nowait call or a shutdown call.
                     pass
                 if not self.full() and not putter.cancelled():
                     # We were woken up by get_nowait(), but can't take
@@ -138,7 +157,11 @@ def put_nowait(self, item):
         """Put an item into the queue without blocking.
 
         If no free slot is immediately available, raise QueueFull.
+
+        Raises QueueShutDown if the queue has been shut down.
         """
+        if self._is_shutdown:
+            raise QueueShutDown
         if self.full():
             raise QueueFull
         self._put(item)
@@ -150,8 +173,13 @@ async def get(self):
         """Remove and return an item from the queue.
 
         If queue is empty, wait until an item is available.
+
+        Raises QueueShutDown if the queue has been shut down and is empty, or
+        if the queue has been shut down immediately.
         """
         while self.empty():
+            if self._is_shutdown and self.empty():
+                raise QueueShutDown
             getter = self._get_loop().create_future()
             self._getters.append(getter)
             try:
@@ -163,7 +191,7 @@ async def get(self):
                     self._getters.remove(getter)
                 except ValueError:
                     # The getter could be removed from self._getters by a
-                    # previous put_nowait call.
+                    # previous put_nowait call, or a shutdown call.
                     pass
                 if not self.empty() and not getter.cancelled():
                     # We were woken up by put_nowait(), but can't take
@@ -176,8 +204,13 @@ def get_nowait(self):
         """Remove and return an item from the queue.
 
         Return an item if one is immediately available, else raise QueueEmpty.
+
+        Raises QueueShutDown if the queue has been shut down and is empty, or
+        if the queue has been shut down immediately.
         """
         if self.empty():
+            if self._is_shutdown:
+                raise QueueShutDown
             raise QueueEmpty
         item = self._get()
         self._wakeup_next(self._putters)
@@ -194,6 +227,9 @@ def task_done(self):
         been processed (meaning that a task_done() call was received for every
         item that had been put() into the queue).
 
+        shutdown(immediate=True) calls task_done() for each remaining item in
+        the queue.
+
         Raises ValueError if called more times than there were items placed in
         the queue.
         """
@@ -214,6 +250,32 @@ async def join(self):
         if self._unfinished_tasks > 0:
             await self._finished.wait()
 
+    def shutdown(self, immediate=False):
+        """Shut down the queue, making queue gets and puts raise QueueShutDown.
+
+        By default, gets will only raise once the queue is empty. Set
+        'immediate' to True to make gets raise immediately instead.
+
+        All blocked callers of put() and get() are unblocked. If 'immediate',
+        remaining items are dropped and marked done, which may unblock join().
+        """
+        self._is_shutdown = True
+        if immediate:
+            while not self.empty():  # discard every remaining item
+                self._get()
+                if self._unfinished_tasks > 0:  # never let the count go negative
+                    self._unfinished_tasks -= 1
+            if self._unfinished_tasks == 0:
+                self._finished.set()
+        while self._getters:
+            getter = self._getters.popleft()
+            if not getter.done():
+                getter.set_result(None)  # woken get() re-checks queue state
+        while self._putters:
+            putter = self._putters.popleft()
+            if not putter.done():
+                putter.set_result(None)  # woken put() re-checks queue state
+
 
 class PriorityQueue(Queue):
     """A subclass of Queue; retrieves entries in priority order (lowest first).
diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py
index 2d058ccf6a8c72..5019e9a293525d 100644
--- a/Lib/test/test_asyncio/test_queues.py
+++ b/Lib/test/test_asyncio/test_queues.py
@@ -522,5 +522,204 @@ class PriorityQueueJoinTests(_QueueJoinTestMixin, unittest.IsolatedAsyncioTestCa
     q_class = asyncio.PriorityQueue
 
 
+class _QueueShutdownTestMixin:
+    q_class = None  # set by each concrete TestCase subclass below
+
+    def assertRaisesShutdown(self, msg="Didn't appear to shut-down queue"):
+        return self.assertRaises(asyncio.QueueShutDown, msg=msg)
+
+    async def test_format(self):
+        q = self.q_class()
+        q.shutdown()
+        self.assertEqual(q._format(), 'maxsize=0 shutdown')
+
+    async def test_shutdown_empty(self):
+        # Test shutting down an empty queue
+
+        # Set up an empty queue and a pending get() task
+        q = self.q_class()
+        loop = asyncio.get_running_loop()
+        get_task = loop.create_task(q.get())
+        await asyncio.sleep(0)  # want get task pending before shutdown
+
+        # Perform shut-down
+        q.shutdown(immediate=False)  # unfinished tasks: 0 -> 0
+
+        self.assertEqual(q.qsize(), 0)
+
+        # Ensure join() completes (there are no unfinished tasks)
+        await q.join()
+
+        # Ensure get() task is finished, and raised QueueShutDown
+        await asyncio.sleep(0)
+        self.assertTrue(get_task.done())
+        with self.assertRaisesShutdown():
+            await get_task
+
+        # Ensure put() and get() raise QueueShutDown
+        with self.assertRaisesShutdown():
+            await q.put("data")
+        with self.assertRaisesShutdown():
+            q.put_nowait("data")
+
+        with self.assertRaisesShutdown():
+            await q.get()
+        with self.assertRaisesShutdown():
+            q.get_nowait()
+
+    async def test_shutdown_nonempty(self):
+        # Test shutting down a non-empty queue
+
+        # Set up a full queue (maxsize=1), plus join() and put() tasks
+        q = self.q_class(maxsize=1)
+        loop = asyncio.get_running_loop()
+
+        q.put_nowait("data")
+        join_task = loop.create_task(q.join())
+        put_task = loop.create_task(q.put("data2"))
+
+        # Ensure put() task is not finished
+        await asyncio.sleep(0)
+        self.assertFalse(put_task.done())
+
+        # Perform shut-down
+        q.shutdown(immediate=False)  # unfinished tasks: 1 -> 1
+
+        self.assertEqual(q.qsize(), 1)
+
+        # Ensure put() task is finished, and raised QueueShutDown
+        await asyncio.sleep(0)
+        self.assertTrue(put_task.done())
+        with self.assertRaisesShutdown():
+            await put_task
+
+        # Ensure get() succeeds on enqueued item
+        self.assertEqual(await q.get(), "data")
+
+        # Ensure join() task is not finished
+        await asyncio.sleep(0)
+        self.assertFalse(join_task.done())
+
+        # Ensure put() and get() raise QueueShutDown
+        with self.assertRaisesShutdown():
+            await q.put("data")
+        with self.assertRaisesShutdown():
+            q.put_nowait("data")
+
+        with self.assertRaisesShutdown():
+            await q.get()
+        with self.assertRaisesShutdown():
+            q.get_nowait()
+
+        # Mark the 1 remaining unfinished task done, so join() task succeeds
+        q.task_done()
+
+        await asyncio.sleep(0)
+        self.assertTrue(join_task.done())
+        await join_task
+
+        with self.assertRaises(
+            ValueError, msg="Didn't appear to mark all tasks done"
+        ):
+            q.task_done()
+
+    async def test_shutdown_immediate(self):
+        # Test immediately shutting down a queue
+
+        # Set up a queue with 1 item, and a join() task
+        q = self.q_class()
+        loop = asyncio.get_running_loop()
+        q.put_nowait("data")
+        join_task = loop.create_task(q.join())
+
+        # Perform shut-down
+        q.shutdown(immediate=True)  # unfinished tasks: 1 -> 0
+
+        self.assertEqual(q.qsize(), 0)
+
+        # Ensure join() task has successfully finished
+        await asyncio.sleep(0)
+        self.assertTrue(join_task.done())
+        await join_task
+
+        # Ensure put() and get() raise QueueShutDown
+        with self.assertRaisesShutdown():
+            await q.put("data")
+        with self.assertRaisesShutdown():
+            q.put_nowait("data")
+
+        with self.assertRaisesShutdown():
+            await q.get()
+        with self.assertRaisesShutdown():
+            q.get_nowait()
+
+        # Ensure there are no unfinished tasks
+        with self.assertRaises(
+            ValueError, msg="Didn't appear to mark all tasks done"
+        ):
+            q.task_done()
+
+    async def test_shutdown_immediate_with_unfinished(self):
+        # Test immediately shutting down a queue with unfinished tasks
+
+        # Set up a queue with 2 items (1 retrieved), and a join() task
+        q = self.q_class()
+        loop = asyncio.get_running_loop()
+        q.put_nowait("data")
+        q.put_nowait("data")
+        join_task = loop.create_task(q.join())
+        self.assertEqual(await q.get(), "data")
+
+        # Perform shut-down
+        q.shutdown(immediate=True)  # unfinished tasks: 2 -> 1
+
+        self.assertEqual(q.qsize(), 0)
+
+        # Ensure join() task is not finished
+        await asyncio.sleep(0)
+        self.assertFalse(join_task.done())
+
+        # Ensure put() and get() raise QueueShutDown
+        with self.assertRaisesShutdown():
+            await q.put("data")
+        with self.assertRaisesShutdown():
+            q.put_nowait("data")
+
+        with self.assertRaisesShutdown():
+            await q.get()
+        with self.assertRaisesShutdown():
+            q.get_nowait()
+
+        # Ensure there is 1 unfinished task
+        q.task_done()
+        with self.assertRaises(
+            ValueError, msg="Didn't appear to mark all tasks done"
+        ):
+            q.task_done()
+
+        # Ensure join() task has successfully finished
+        await asyncio.sleep(0)
+        self.assertTrue(join_task.done())
+        await join_task
+
+
+class QueueShutdownTests(
+    _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
+):
+    q_class = asyncio.Queue
+
+
+class LifoQueueShutdownTests(
+    _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
+):
+    q_class = asyncio.LifoQueue
+
+
+class PriorityQueueShutdownTests(
+    _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
+):
+    q_class = asyncio.PriorityQueue
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2023-05-06-05-00-42.gh-issue-96471.S3X5I-.rst b/Misc/NEWS.d/next/Library/2023-05-06-05-00-42.gh-issue-96471.S3X5I-.rst
new file mode 100644
index 00000000000000..128a85d3d73ddf
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2023-05-06-05-00-42.gh-issue-96471.S3X5I-.rst
@@ -0,0 +1,2 @@
+Add :py:meth:`asyncio.Queue.shutdown` and :py:exc:`asyncio.QueueShutDown`
+for queue termination.

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to