mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 13:41:24 +00:00 
			
		
		
		
	gh-81322: support multiple separators in StreamReader.readuntil (#16429)
This commit is contained in:
		
							parent
							
								
									24a2bd0481
								
							
						
					
					
						commit
						775912a51d
					
				
					 4 changed files with 102 additions and 20 deletions
				
			
		|  | @ -260,8 +260,19 @@ StreamReader | |||
|       buffer is reset.  The :attr:`IncompleteReadError.partial` attribute | ||||
|       may contain a portion of the separator. | ||||
| 
 | ||||
|       The *separator* may also be an :term:`iterable` of separators. In this | ||||
|       case the return value will be the shortest possible that has any | ||||
|       separator as the suffix. For the purposes of :exc:`LimitOverrunError`, | ||||
|       the shortest possible separator is considered to be the one that | ||||
|       matched. | ||||
| 
 | ||||
|       .. versionadded:: 3.5.2 | ||||
| 
 | ||||
|       .. versionchanged:: 3.13 | ||||
| 
 | ||||
|          The *separator* parameter may now be an :term:`iterable` of | ||||
|          separators. | ||||
| 
 | ||||
|    .. method:: at_eof() | ||||
| 
 | ||||
|       Return ``True`` if the buffer is empty and :meth:`feed_eof` | ||||
|  |  | |||
|  | @ -590,20 +590,34 @@ async def readuntil(self, separator=b'\n'): | |||
|         If the data cannot be read because of over limit, a | ||||
|         LimitOverrunError exception will be raised, and the data | ||||
|         will be left in the internal buffer, so it can be read again. | ||||
| 
 | ||||
|         The ``separator`` may also be an iterable of separators. In this | ||||
|         case the return value will be the shortest possible that has any | ||||
|         separator as the suffix. For the purposes of LimitOverrunError, | ||||
|         the shortest possible separator is considered to be the one that | ||||
|         matched. | ||||
|         """ | ||||
|         seplen = len(separator) | ||||
|         if seplen == 0: | ||||
|         if isinstance(separator, bytes): | ||||
|             separator = [separator] | ||||
|         else: | ||||
|             # Make sure the shortest match wins, and support arbitrary iterables | ||||
|             separator = sorted(separator, key=len) | ||||
|         if not separator: | ||||
|             raise ValueError('Separator should contain at least one element') | ||||
|         min_seplen = len(separator[0]) | ||||
|         max_seplen = len(separator[-1]) | ||||
|         if min_seplen == 0: | ||||
|             raise ValueError('Separator should be at least one-byte string') | ||||
| 
 | ||||
|         if self._exception is not None: | ||||
|             raise self._exception | ||||
| 
 | ||||
|         # Consume the whole buffer except the last bytes, whose length is | ||||
|         # one less than seplen. Let's check corner cases with | ||||
|         # separator='SEPARATOR': | ||||
|         # one less than max_seplen. Let's check corner cases with | ||||
|         # separator[-1]='SEPARATOR': | ||||
|         # * we have received almost complete separator (without last | ||||
|         #   byte), i.e. buffer='some textSEPARATO'. In this case we | ||||
|         #   can safely consume len(separator) - 1 bytes. | ||||
|         #   can safely consume max_seplen - 1 bytes. | ||||
|         # * last byte of buffer is first byte of separator, i.e. | ||||
|         #   buffer='abcdefghijklmnopqrS'. We may safely consume | ||||
|         #   everything except that last byte, but this require to | ||||
|  | @ -616,26 +630,35 @@ async def readuntil(self, separator=b'\n'): | |||
|         #   messages :) | ||||
| 
 | ||||
|         # `offset` is the number of bytes from the beginning of the buffer | ||||
|         # where there is no occurrence of `separator`. | ||||
|         # where there is no occurrence of any `separator`. | ||||
|         offset = 0 | ||||
| 
 | ||||
|         # Loop until we find `separator` in the buffer, exceed the buffer size, | ||||
|         # Loop until we find a `separator` in the buffer, exceed the buffer size, | ||||
|         # or an EOF has happened. | ||||
|         while True: | ||||
|             buflen = len(self._buffer) | ||||
| 
 | ||||
|             # Check if we now have enough data in the buffer for `separator` to | ||||
|             # fit. | ||||
|             if buflen - offset >= seplen: | ||||
|                 isep = self._buffer.find(separator, offset) | ||||
|             # Check if we now have enough data in the buffer for shortest | ||||
|             # separator to fit. | ||||
|             if buflen - offset >= min_seplen: | ||||
|                 match_start = None | ||||
|                 match_end = None | ||||
|                 for sep in separator: | ||||
|                     isep = self._buffer.find(sep, offset) | ||||
| 
 | ||||
|                 if isep != -1: | ||||
|                     # `separator` is in the buffer. `isep` will be used later | ||||
|                     # to retrieve the data. | ||||
|                     if isep != -1: | ||||
|                         # `separator` is in the buffer. `match_start` and | ||||
|                         # `match_end` will be used later to retrieve the | ||||
|                         # data. | ||||
|                         end = isep + len(sep) | ||||
|                         if match_end is None or end < match_end: | ||||
|                             match_end = end | ||||
|                             match_start = isep | ||||
|                 if match_end is not None: | ||||
|                     break | ||||
| 
 | ||||
|                 # see upper comment for explanation. | ||||
|                 offset = buflen + 1 - seplen | ||||
|                 offset = max(0, buflen + 1 - max_seplen) | ||||
|                 if offset > self._limit: | ||||
|                     raise exceptions.LimitOverrunError( | ||||
|                         'Separator is not found, and chunk exceed the limit', | ||||
|  | @ -644,7 +667,7 @@ async def readuntil(self, separator=b'\n'): | |||
|             # Complete message (with full separator) may be present in buffer | ||||
|             # even when EOF flag is set. This may happen when the last chunk | ||||
|             # adds data which makes separator be found. That's why we check for | ||||
|             # EOF *ater* inspecting the buffer. | ||||
|             # EOF *after* inspecting the buffer. | ||||
|             if self._eof: | ||||
|                 chunk = bytes(self._buffer) | ||||
|                 self._buffer.clear() | ||||
|  | @ -653,12 +676,12 @@ async def readuntil(self, separator=b'\n'): | |||
|             # _wait_for_data() will resume reading if stream was paused. | ||||
|             await self._wait_for_data('readuntil') | ||||
| 
 | ||||
|         if isep > self._limit: | ||||
|         if match_start > self._limit: | ||||
|             raise exceptions.LimitOverrunError( | ||||
|                 'Separator is found, but chunk is longer than limit', isep) | ||||
|                 'Separator is found, but chunk is longer than limit', match_start) | ||||
| 
 | ||||
|         chunk = self._buffer[:isep + seplen] | ||||
|         del self._buffer[:isep + seplen] | ||||
|         chunk = self._buffer[:match_end] | ||||
|         del self._buffer[:match_end] | ||||
|         self._maybe_resume_transport() | ||||
|         return bytes(chunk) | ||||
| 
 | ||||
|  |  | |||
|  | @ -383,6 +383,10 @@ def test_readuntil_separator(self): | |||
|         stream = asyncio.StreamReader(loop=self.loop) | ||||
|         with self.assertRaisesRegex(ValueError, 'Separator should be'): | ||||
|             self.loop.run_until_complete(stream.readuntil(separator=b'')) | ||||
|         with self.assertRaisesRegex(ValueError, 'Separator should be'): | ||||
|             self.loop.run_until_complete(stream.readuntil(separator=[b''])) | ||||
|         with self.assertRaisesRegex(ValueError, 'Separator should contain'): | ||||
|             self.loop.run_until_complete(stream.readuntil(separator=[])) | ||||
| 
 | ||||
|     def test_readuntil_multi_chunks(self): | ||||
|         stream = asyncio.StreamReader(loop=self.loop) | ||||
|  | @ -466,6 +470,48 @@ def test_readuntil_limit_found_sep(self): | |||
| 
 | ||||
|         self.assertEqual(b'some dataAAA', stream._buffer) | ||||
| 
 | ||||
|     def test_readuntil_multi_separator(self): | ||||
|         stream = asyncio.StreamReader(loop=self.loop) | ||||
| 
 | ||||
|         # Simple case | ||||
|         stream.feed_data(b'line 1\nline 2\r') | ||||
|         data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n'])) | ||||
|         self.assertEqual(b'line 1\n', data) | ||||
|         data = self.loop.run_until_complete(stream.readuntil([b'\r', b'\n'])) | ||||
|         self.assertEqual(b'line 2\r', data) | ||||
|         self.assertEqual(b'', stream._buffer) | ||||
| 
 | ||||
|         # First end position matches, even if that's a longer match | ||||
|         stream.feed_data(b'ABCDEFG') | ||||
|         data = self.loop.run_until_complete(stream.readuntil([b'DEF', b'BCDE'])) | ||||
|         self.assertEqual(b'ABCDE', data) | ||||
|         self.assertEqual(b'FG', stream._buffer) | ||||
| 
 | ||||
|     def test_readuntil_multi_separator_limit(self): | ||||
|         stream = asyncio.StreamReader(loop=self.loop, limit=3) | ||||
|         stream.feed_data(b'some dataA') | ||||
| 
 | ||||
|         with self.assertRaisesRegex(asyncio.LimitOverrunError, | ||||
|                                     'is found') as cm: | ||||
|             self.loop.run_until_complete(stream.readuntil([b'A', b'ome dataA'])) | ||||
| 
 | ||||
|         self.assertEqual(b'some dataA', stream._buffer) | ||||
| 
 | ||||
|     def test_readuntil_multi_separator_negative_offset(self): | ||||
|         # If the buffer is big enough for the smallest separator (but does | ||||
|         # not contain it) but too small for the largest, `offset` must not | ||||
|         # become negative. | ||||
|         stream = asyncio.StreamReader(loop=self.loop) | ||||
|         stream.feed_data(b'data') | ||||
| 
 | ||||
|         readuntil_task = self.loop.create_task(stream.readuntil([b'A', b'long sep'])) | ||||
|         self.loop.call_soon(stream.feed_data, b'Z') | ||||
|         self.loop.call_soon(stream.feed_data, b'Aaaa') | ||||
| 
 | ||||
|         data = self.loop.run_until_complete(readuntil_task) | ||||
|         self.assertEqual(b'dataZA', data) | ||||
|         self.assertEqual(b'aaa', stream._buffer) | ||||
| 
 | ||||
|     def test_readexactly_zero_or_less(self): | ||||
|         # Read exact number of bytes (zero or less). | ||||
|         stream = asyncio.StreamReader(loop=self.loop) | ||||
|  |  | |||
|  | @ -0,0 +1,2 @@ | |||
| Accept an iterable of separators in :meth:`asyncio.StreamReader.readuntil`, stopping | ||||
| when one of them is encountered. | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Bruce Merry
						Bruce Merry