https://github.com/python/cpython/commit/775912a51d6847b0e4fe415fa91f2e0b06a3c43c
commit: 775912a51d6847b0e4fe415fa91f2e0b06a3c43c
branch: main
author: Bruce Merry <[email protected]>
committer: gvanrossum <[email protected]>
date: 2024-04-08T09:58:02-07:00
summary:

gh-81322: support multiple separators in StreamReader.readuntil (#16429)

files:
A Misc/NEWS.d/next/Library/2019-09-26-17-52-52.bpo-37141.onYY2-.rst
M Doc/library/asyncio-stream.rst
M Lib/asyncio/streams.py
M Lib/test/test_asyncio/test_streams.py

diff --git a/Doc/library/asyncio-stream.rst b/Doc/library/asyncio-stream.rst
index 68b1dff20213e1..6231b49b1e2431 100644
--- a/Doc/library/asyncio-stream.rst
+++ b/Doc/library/asyncio-stream.rst
@@ -260,8 +260,19 @@ StreamReader
       buffer is reset.  The :attr:`IncompleteReadError.partial` attribute
       may contain a portion of the separator.
 
+      The *separator* may also be an :term:`iterable` of separators. In this
+      case the return value will be the shortest possible that has any
+      separator as the suffix. For the purposes of :exc:`LimitOverrunError`,
+      the shortest possible separator is considered to be the one that
+      matched.
+
       .. versionadded:: 3.5.2
 
+      .. versionchanged:: 3.13
+
+         The *separator* parameter may now be an :term:`iterable` of
+         separators.
+
    .. method:: at_eof()
 
       Return ``True`` if the buffer is empty and :meth:`feed_eof`
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index 3fe52dbac25c91..4517ca22d74637 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -590,20 +590,34 @@ async def readuntil(self, separator=b'\n'):
         If the data cannot be read because of over limit, a
         LimitOverrunError exception  will be raised, and the data
         will be left in the internal buffer, so it can be read again.
+
+        The ``separator`` may also be an iterable of separators. In this
+        case the return value will be the shortest possible that has any
+        separator as the suffix. For the purposes of LimitOverrunError,
+        the shortest possible separator is considered to be the one that
+        matched.
         """
-        seplen = len(separator)
-        if seplen == 0:
+        if isinstance(separator, bytes):
+            separator = [separator]
+        else:
+            # Makes sure shortest matches wins, and supports arbitrary iterables
+            separator = sorted(separator, key=len)
+        if not separator:
+            raise ValueError('Separator should contain at least one element')
+        min_seplen = len(separator[0])
+        max_seplen = len(separator[-1])
+        if min_seplen == 0:
             raise ValueError('Separator should be at least one-byte string')
 
         if self._exception is not None:
             raise self._exception
 
         # Consume whole buffer except last bytes, which length is
-        # one less than seplen. Let's check corner cases with
-        # separator='SEPARATOR':
+        # one less than max_seplen. Let's check corner cases with
+        # separator[-1]='SEPARATOR':
         # * we have received almost complete separator (without last
         #   byte). i.e buffer='some textSEPARATO'. In this case we
-        #   can safely consume len(separator) - 1 bytes.
+        #   can safely consume max_seplen - 1 bytes.
         # * last byte of buffer is first byte of separator, i.e.
         #   buffer='abcdefghijklmnopqrS'. We may safely consume
         #   everything except that last byte, but this require to
@@ -616,26 +630,35 @@ async def readuntil(self, separator=b'\n'):
         #   messages :)
 
         # `offset` is the number of bytes from the beginning of the buffer
-        # where there is no occurrence of `separator`.
+        # where there is no occurrence of any `separator`.
         offset = 0
 
-        # Loop until we find `separator` in the buffer, exceed the buffer size,
+        # Loop until we find a `separator` in the buffer, exceed the buffer size,
         # or an EOF has happened.
         while True:
             buflen = len(self._buffer)
 
-            # Check if we now have enough data in the buffer for `separator` to
-            # fit.
-            if buflen - offset >= seplen:
-                isep = self._buffer.find(separator, offset)
-
-                if isep != -1:
-                    # `separator` is in the buffer. `isep` will be used later
-                    # to retrieve the data.
+            # Check if we now have enough data in the buffer for shortest
+            # separator to fit.
+            if buflen - offset >= min_seplen:
+                match_start = None
+                match_end = None
+                for sep in separator:
+                    isep = self._buffer.find(sep, offset)
+
+                    if isep != -1:
+                        # `separator` is in the buffer. `match_start` and
+                        # `match_end` will be used later to retrieve the
+                        # data.
+                        end = isep + len(sep)
+                        if match_end is None or end < match_end:
+                            match_end = end
+                            match_start = isep
+                if match_end is not None:
                     break
 
                 # see upper comment for explanation.
-                offset = buflen + 1 - seplen
+                offset = max(0, buflen + 1 - max_seplen)
                 if offset > self._limit:
                     raise exceptions.LimitOverrunError(
                         'Separator is not found, and chunk exceed the limit',
@@ -644,7 +667,7 @@ async def readuntil(self, separator=b'\n'):
             # Complete message (with full separator) may be present in buffer
             # even when EOF flag is set. This may happen when the last chunk
             # adds data which makes separator be found. That's why we check for
-            # EOF *ater* inspecting the buffer.
+            # EOF *after* inspecting the buffer.
             if self._eof:
                 chunk = bytes(self._buffer)
                 self._buffer.clear()
@@ -653,12 +676,12 @@ async def readuntil(self, separator=b'\n'):
             # _wait_for_data() will resume reading if stream was paused.
             await self._wait_for_data('readuntil')
 
-        if isep > self._limit:
+        if match_start > self._limit:
             raise exceptions.LimitOverrunError(
-                'Separator is found, but chunk is longer than limit', isep)
+                'Separator is found, but chunk is longer than limit', match_start)
 
-        chunk = self._buffer[:isep + seplen]
-        del self._buffer[:isep + seplen]
+        chunk = self._buffer[:match_end]
+        del self._buffer[:match_end]
         self._maybe_resume_transport()
         return bytes(chunk)
 
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py
index 2cf48538d5d30d..792e88761acdc2 100644
--- a/Lib/test/test_asyncio/test_streams.py
+++ b/Lib/test/test_asyncio/test_streams.py
@@ -383,6 +383,10 @@ def test_readuntil_separator(self):
         stream = asyncio.StreamReader(loop=self.loop)
         with self.assertRaisesRegex(ValueError, 'Separator should be'):
             self.loop.run_until_complete(stream.readuntil(separator=b''))
+        with self.assertRaisesRegex(ValueError, 'Separator should be'):
+            self.loop.run_until_complete(stream.readuntil(separator=[b'']))
+        with self.assertRaisesRegex(ValueError, 'Separator should contain'):
+            self.loop.run_until_complete(stream.readuntil(separator=[]))
 
     def test_readuntil_multi_chunks(self):
         stream = asyncio.StreamReader(loop=self.loop)
@@ -466,6 +470,48 @@ def test_readuntil_limit_found_sep(self):
 
         self.assertEqual(b'some dataAAA', stream._buffer)
 
+    def test_readuntil_multi_separator(self):
+        stream = asyncio.StreamReader(loop=self.loop)
+
+        # Simple case
+        stream.feed_data(b'line 1\nline 2\r')
+        data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n']))
+        self.assertEqual(b'line 1\n', data)
+        data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n']))
+        self.assertEqual(b'line 2\r', data)
+        self.assertEqual(b'', stream._buffer)
+
+        # First end position matches, even if that's a longer match
+        stream.feed_data(b'ABCDEFG')
+        data = self.loop.run_until_complete(stream.readuntil([b'DEF', b'BCDE']))
+        self.assertEqual(b'ABCDE', data)
+        self.assertEqual(b'FG', stream._buffer)
+
+    def test_readuntil_multi_separator_limit(self):
+        stream = asyncio.StreamReader(loop=self.loop, limit=3)
+        stream.feed_data(b'some dataA')
+
+        with self.assertRaisesRegex(asyncio.LimitOverrunError,
+                                    'is found') as cm:
+            self.loop.run_until_complete(stream.readuntil([b'A', b'ome dataA']))
+
+        self.assertEqual(b'some dataA', stream._buffer)
+
+    def test_readuntil_multi_separator_negative_offset(self):
+        # If the buffer is big enough for the smallest separator (but does
+        # not contain it) but too small for the largest, `offset` must not
+        # become negative.
+        stream = asyncio.StreamReader(loop=self.loop)
+        stream.feed_data(b'data')
+
+        readuntil_task = self.loop.create_task(stream.readuntil([b'A', b'long sep']))
+        self.loop.call_soon(stream.feed_data, b'Z')
+        self.loop.call_soon(stream.feed_data, b'Aaaa')
+
+        data = self.loop.run_until_complete(readuntil_task)
+        self.assertEqual(b'dataZA', data)
+        self.assertEqual(b'aaa', stream._buffer)
+
     def test_readexactly_zero_or_less(self):
         # Read exact number of bytes (zero or less).
         stream = asyncio.StreamReader(loop=self.loop)
diff --git a/Misc/NEWS.d/next/Library/2019-09-26-17-52-52.bpo-37141.onYY2-.rst b/Misc/NEWS.d/next/Library/2019-09-26-17-52-52.bpo-37141.onYY2-.rst
new file mode 100644
index 00000000000000..d916f319947dc4
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2019-09-26-17-52-52.bpo-37141.onYY2-.rst
@@ -0,0 +1,2 @@
+Accept an iterable of separators in :meth:`asyncio.StreamReader.readuntil`, stopping
+when one of them is encountered.

_______________________________________________
Python-checkins mailing list -- [email protected]
To unsubscribe send an email to [email protected]
https://mail.python.org/mailman3/lists/python-checkins.python.org/
Member address: [email protected]

Reply via email to