mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 05:31:20 +00:00 
			
		
		
		
	gh-81322: support multiple separators in StreamReader.readuntil (#16429)
This commit is contained in:
		
							parent
							
								
									24a2bd0481
								
							
						
					
					
						commit
						775912a51d
					
				
					 4 changed files with 102 additions and 20 deletions
				
			
		|  | @ -260,8 +260,19 @@ StreamReader | ||||||
|       buffer is reset.  The :attr:`IncompleteReadError.partial` attribute |       buffer is reset.  The :attr:`IncompleteReadError.partial` attribute | ||||||
|       may contain a portion of the separator. |       may contain a portion of the separator. | ||||||
| 
 | 
 | ||||||
|  |       The *separator* may also be an :term:`iterable` of separators. In this | ||||||
|  |       case the return value will be the shortest possible that has any | ||||||
|  |       separator as the suffix. For the purposes of :exc:`LimitOverrunError`, | ||||||
|  |       the shortest possible separator is considered to be the one that | ||||||
|  |       matched. | ||||||
|  | 
 | ||||||
|       .. versionadded:: 3.5.2 |       .. versionadded:: 3.5.2 | ||||||
| 
 | 
 | ||||||
|  |       .. versionchanged:: 3.13 | ||||||
|  | 
 | ||||||
|  |          The *separator* parameter may now be an :term:`iterable` of | ||||||
|  |          separators. | ||||||
|  | 
 | ||||||
|    .. method:: at_eof() |    .. method:: at_eof() | ||||||
| 
 | 
 | ||||||
|       Return ``True`` if the buffer is empty and :meth:`feed_eof` |       Return ``True`` if the buffer is empty and :meth:`feed_eof` | ||||||
|  |  | ||||||
|  | @ -590,20 +590,34 @@ async def readuntil(self, separator=b'\n'): | ||||||
|         If the data cannot be read because of over limit, a |         If the data cannot be read because of over limit, a | ||||||
|         LimitOverrunError exception  will be raised, and the data |         LimitOverrunError exception  will be raised, and the data | ||||||
|         will be left in the internal buffer, so it can be read again. |         will be left in the internal buffer, so it can be read again. | ||||||
|  | 
 | ||||||
|  |         The ``separator`` may also be an iterable of separators. In this | ||||||
|  |         case the return value will be the shortest possible that has any | ||||||
|  |         separator as the suffix. For the purposes of LimitOverrunError, | ||||||
|  |         the shortest possible separator is considered to be the one that | ||||||
|  |         matched. | ||||||
|         """ |         """ | ||||||
|         seplen = len(separator) |         if isinstance(separator, bytes): | ||||||
|         if seplen == 0: |             separator = [separator] | ||||||
|  |         else: | ||||||
|  |             # Makes sure the shortest match wins, and supports arbitrary iterables | ||||||
|  |             separator = sorted(separator, key=len) | ||||||
|  |         if not separator: | ||||||
|  |             raise ValueError('Separator should contain at least one element') | ||||||
|  |         min_seplen = len(separator[0]) | ||||||
|  |         max_seplen = len(separator[-1]) | ||||||
|  |         if min_seplen == 0: | ||||||
|             raise ValueError('Separator should be at least one-byte string') |             raise ValueError('Separator should be at least one-byte string') | ||||||
| 
 | 
 | ||||||
|         if self._exception is not None: |         if self._exception is not None: | ||||||
|             raise self._exception |             raise self._exception | ||||||
| 
 | 
 | ||||||
|         # Consume whole buffer except last bytes, which length is |         # Consume whole buffer except last bytes, which length is | ||||||
|         # one less than seplen. Let's check corner cases with |         # one less than max_seplen. Let's check corner cases with | ||||||
|         # separator='SEPARATOR': |         # separator[-1]='SEPARATOR': | ||||||
|         # * we have received almost complete separator (without last |         # * we have received almost complete separator (without last | ||||||
|         #   byte). i.e buffer='some textSEPARATO'. In this case we |         #   byte). i.e buffer='some textSEPARATO'. In this case we | ||||||
|         #   can safely consume len(separator) - 1 bytes. |         #   can safely consume max_seplen - 1 bytes. | ||||||
|         # * last byte of buffer is first byte of separator, i.e. |         # * last byte of buffer is first byte of separator, i.e. | ||||||
|         #   buffer='abcdefghijklmnopqrS'. We may safely consume |         #   buffer='abcdefghijklmnopqrS'. We may safely consume | ||||||
|         #   everything except that last byte, but this require to |         #   everything except that last byte, but this require to | ||||||
|  | @ -616,26 +630,35 @@ async def readuntil(self, separator=b'\n'): | ||||||
|         #   messages :) |         #   messages :) | ||||||
| 
 | 
 | ||||||
|         # `offset` is the number of bytes from the beginning of the buffer |         # `offset` is the number of bytes from the beginning of the buffer | ||||||
|         # where there is no occurrence of `separator`. |         # where there is no occurrence of any `separator`. | ||||||
|         offset = 0 |         offset = 0 | ||||||
| 
 | 
 | ||||||
|         # Loop until we find `separator` in the buffer, exceed the buffer size, |         # Loop until we find a `separator` in the buffer, exceed the buffer size, | ||||||
|         # or an EOF has happened. |         # or an EOF has happened. | ||||||
|         while True: |         while True: | ||||||
|             buflen = len(self._buffer) |             buflen = len(self._buffer) | ||||||
| 
 | 
 | ||||||
|             # Check if we now have enough data in the buffer for `separator` to |             # Check if we now have enough data in the buffer for shortest | ||||||
|             # fit. |             # separator to fit. | ||||||
|             if buflen - offset >= seplen: |             if buflen - offset >= min_seplen: | ||||||
|                 isep = self._buffer.find(separator, offset) |                 match_start = None | ||||||
|  |                 match_end = None | ||||||
|  |                 for sep in separator: | ||||||
|  |                     isep = self._buffer.find(sep, offset) | ||||||
| 
 | 
 | ||||||
|                     if isep != -1: |                     if isep != -1: | ||||||
|                     # `separator` is in the buffer. `isep` will be used later |                         # `sep` is in the buffer. `match_start` and | ||||||
|                     # to retrieve the data. |                         # `match_end` will be used later to retrieve the | ||||||
|  |                         # data. | ||||||
|  |                         end = isep + len(sep) | ||||||
|  |                         if match_end is None or end < match_end: | ||||||
|  |                             match_end = end | ||||||
|  |                             match_start = isep | ||||||
|  |                 if match_end is not None: | ||||||
|                     break |                     break | ||||||
| 
 | 
 | ||||||
|                 # see upper comment for explanation. |                 # see upper comment for explanation. | ||||||
|                 offset = buflen + 1 - seplen |                 offset = max(0, buflen + 1 - max_seplen) | ||||||
|                 if offset > self._limit: |                 if offset > self._limit: | ||||||
|                     raise exceptions.LimitOverrunError( |                     raise exceptions.LimitOverrunError( | ||||||
|                         'Separator is not found, and chunk exceed the limit', |                         'Separator is not found, and chunk exceed the limit', | ||||||
|  | @ -644,7 +667,7 @@ async def readuntil(self, separator=b'\n'): | ||||||
|             # Complete message (with full separator) may be present in buffer |             # Complete message (with full separator) may be present in buffer | ||||||
|             # even when EOF flag is set. This may happen when the last chunk |             # even when EOF flag is set. This may happen when the last chunk | ||||||
|             # adds data which makes separator be found. That's why we check for |             # adds data which makes separator be found. That's why we check for | ||||||
|             # EOF *ater* inspecting the buffer. |             # EOF *after* inspecting the buffer. | ||||||
|             if self._eof: |             if self._eof: | ||||||
|                 chunk = bytes(self._buffer) |                 chunk = bytes(self._buffer) | ||||||
|                 self._buffer.clear() |                 self._buffer.clear() | ||||||
|  | @ -653,12 +676,12 @@ async def readuntil(self, separator=b'\n'): | ||||||
|             # _wait_for_data() will resume reading if stream was paused. |             # _wait_for_data() will resume reading if stream was paused. | ||||||
|             await self._wait_for_data('readuntil') |             await self._wait_for_data('readuntil') | ||||||
| 
 | 
 | ||||||
|         if isep > self._limit: |         if match_start > self._limit: | ||||||
|             raise exceptions.LimitOverrunError( |             raise exceptions.LimitOverrunError( | ||||||
|                 'Separator is found, but chunk is longer than limit', isep) |                 'Separator is found, but chunk is longer than limit', match_start) | ||||||
| 
 | 
 | ||||||
|         chunk = self._buffer[:isep + seplen] |         chunk = self._buffer[:match_end] | ||||||
|         del self._buffer[:isep + seplen] |         del self._buffer[:match_end] | ||||||
|         self._maybe_resume_transport() |         self._maybe_resume_transport() | ||||||
|         return bytes(chunk) |         return bytes(chunk) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -383,6 +383,10 @@ def test_readuntil_separator(self): | ||||||
|         stream = asyncio.StreamReader(loop=self.loop) |         stream = asyncio.StreamReader(loop=self.loop) | ||||||
|         with self.assertRaisesRegex(ValueError, 'Separator should be'): |         with self.assertRaisesRegex(ValueError, 'Separator should be'): | ||||||
|             self.loop.run_until_complete(stream.readuntil(separator=b'')) |             self.loop.run_until_complete(stream.readuntil(separator=b'')) | ||||||
|  |         with self.assertRaisesRegex(ValueError, 'Separator should be'): | ||||||
|  |             self.loop.run_until_complete(stream.readuntil(separator=[b''])) | ||||||
|  |         with self.assertRaisesRegex(ValueError, 'Separator should contain'): | ||||||
|  |             self.loop.run_until_complete(stream.readuntil(separator=[])) | ||||||
| 
 | 
 | ||||||
|     def test_readuntil_multi_chunks(self): |     def test_readuntil_multi_chunks(self): | ||||||
|         stream = asyncio.StreamReader(loop=self.loop) |         stream = asyncio.StreamReader(loop=self.loop) | ||||||
|  | @ -466,6 +470,48 @@ def test_readuntil_limit_found_sep(self): | ||||||
| 
 | 
 | ||||||
|         self.assertEqual(b'some dataAAA', stream._buffer) |         self.assertEqual(b'some dataAAA', stream._buffer) | ||||||
| 
 | 
 | ||||||
|  |     def test_readuntil_multi_separator(self): | ||||||
|  |         stream = asyncio.StreamReader(loop=self.loop) | ||||||
|  | 
 | ||||||
|  |         # Simple case | ||||||
|  |         stream.feed_data(b'line 1\nline 2\r') | ||||||
|  |         data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n'])) | ||||||
|  |         self.assertEqual(b'line 1\n', data) | ||||||
|  |         data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n'])) | ||||||
|  |         self.assertEqual(b'line 2\r', data) | ||||||
|  |         self.assertEqual(b'', stream._buffer) | ||||||
|  | 
 | ||||||
|  |         # First end position matches, even if that's a longer match | ||||||
|  |         stream.feed_data(b'ABCDEFG') | ||||||
|  |         data = self.loop.run_until_complete(stream.readuntil([b'DEF', b'BCDE'])) | ||||||
|  |         self.assertEqual(b'ABCDE', data) | ||||||
|  |         self.assertEqual(b'FG', stream._buffer) | ||||||
|  | 
 | ||||||
|  |     def test_readuntil_multi_separator_limit(self): | ||||||
|  |         stream = asyncio.StreamReader(loop=self.loop, limit=3) | ||||||
|  |         stream.feed_data(b'some dataA') | ||||||
|  | 
 | ||||||
|  |         with self.assertRaisesRegex(asyncio.LimitOverrunError, | ||||||
|  |                                     'is found') as cm: | ||||||
|  |             self.loop.run_until_complete(stream.readuntil([b'A', b'ome dataA'])) | ||||||
|  | 
 | ||||||
|  |         self.assertEqual(b'some dataA', stream._buffer) | ||||||
|  | 
 | ||||||
|  |     def test_readuntil_multi_separator_negative_offset(self): | ||||||
|  |         # If the buffer is big enough for the smallest separator (but does | ||||||
|  |         # not contain it) but too small for the largest, `offset` must not | ||||||
|  |         # become negative. | ||||||
|  |         stream = asyncio.StreamReader(loop=self.loop) | ||||||
|  |         stream.feed_data(b'data') | ||||||
|  | 
 | ||||||
|  |         readuntil_task = self.loop.create_task(stream.readuntil([b'A', b'long sep'])) | ||||||
|  |         self.loop.call_soon(stream.feed_data, b'Z') | ||||||
|  |         self.loop.call_soon(stream.feed_data, b'Aaaa') | ||||||
|  | 
 | ||||||
|  |         data = self.loop.run_until_complete(readuntil_task) | ||||||
|  |         self.assertEqual(b'dataZA', data) | ||||||
|  |         self.assertEqual(b'aaa', stream._buffer) | ||||||
|  | 
 | ||||||
|     def test_readexactly_zero_or_less(self): |     def test_readexactly_zero_or_less(self): | ||||||
|         # Read exact number of bytes (zero or less). |         # Read exact number of bytes (zero or less). | ||||||
|         stream = asyncio.StreamReader(loop=self.loop) |         stream = asyncio.StreamReader(loop=self.loop) | ||||||
|  |  | ||||||
|  | @ -0,0 +1,2 @@ | ||||||
|  | Accept an iterable of separators in :meth:`asyncio.StreamReader.readuntil`, stopping | ||||||
|  | when one of them is encountered. | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Bruce Merry
						Bruce Merry