mirror of
				https://github.com/python/cpython.git
				synced 2025-10-31 13:41:24 +00:00 
			
		
		
		
	bpo-40282: Allow random.getrandbits(0) (GH-19539)
This commit is contained in:
		
							parent
							
								
									d7c657d4b1
								
							
						
					
					
						commit
						75a3378810
					
				
					 5 changed files with 42 additions and 44 deletions
				
			
		|  | @ -111,6 +111,9 @@ Bookkeeping functions | |||
|    as an optional part of the API. When available, :meth:`getrandbits` enables | ||||
|    :meth:`randrange` to handle arbitrarily large ranges. | ||||
| 
 | ||||
|    .. versionchanged:: 3.9 | ||||
|       This method now accepts zero for *k*. | ||||
| 
 | ||||
| 
 | ||||
| .. function:: randbytes(n) | ||||
| 
 | ||||
|  |  | |||
|  | @ -261,6 +261,8 @@ def randint(self, a, b): | |||
|     def _randbelow_with_getrandbits(self, n): | ||||
|         "Return a random int in the range [0,n).  Raises ValueError if n==0." | ||||
| 
 | ||||
|         if not n: | ||||
|             raise ValueError("Boundary cannot be zero") | ||||
|         getrandbits = self.getrandbits | ||||
|         k = n.bit_length()  # don't use (n-1) here because n can be 1 | ||||
|         r = getrandbits(k)          # 0 <= r < 2**k | ||||
|  | @ -733,8 +735,8 @@ def random(self): | |||
| 
 | ||||
|     def getrandbits(self, k): | ||||
|         """getrandbits(k) -> x.  Generates an int with k random bits.""" | ||||
|         if k <= 0: | ||||
|             raise ValueError('number of bits must be greater than zero') | ||||
|         if k < 0: | ||||
|             raise ValueError('number of bits must be non-negative') | ||||
|         numbytes = (k + 7) // 8                       # bits / 8 and rounded up | ||||
|         x = int.from_bytes(_urandom(numbytes), 'big') | ||||
|         return x >> (numbytes * 8 - k)                # trim excess bits | ||||
|  |  | |||
|  | @ -263,6 +263,31 @@ def test_gauss(self): | |||
|             self.assertEqual(x1, x2) | ||||
|             self.assertEqual(y1, y2) | ||||
| 
 | ||||
|     def test_getrandbits(self): | ||||
|         # Verify ranges | ||||
|         for k in range(1, 1000): | ||||
|             self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k) | ||||
|         self.assertEqual(self.gen.getrandbits(0), 0) | ||||
| 
 | ||||
|         # Verify all bits active | ||||
|         getbits = self.gen.getrandbits | ||||
|         for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]: | ||||
|             all_bits = 2**span-1 | ||||
|             cum = 0 | ||||
|             cpl_cum = 0 | ||||
|             for i in range(100): | ||||
|                 v = getbits(span) | ||||
|                 cum |= v | ||||
|                 cpl_cum |= all_bits ^ v | ||||
|             self.assertEqual(cum, all_bits) | ||||
|             self.assertEqual(cpl_cum, all_bits) | ||||
| 
 | ||||
|         # Verify argument checking | ||||
|         self.assertRaises(TypeError, self.gen.getrandbits) | ||||
|         self.assertRaises(TypeError, self.gen.getrandbits, 1, 2) | ||||
|         self.assertRaises(ValueError, self.gen.getrandbits, -1) | ||||
|         self.assertRaises(TypeError, self.gen.getrandbits, 10.1) | ||||
| 
 | ||||
|     def test_pickling(self): | ||||
|         for proto in range(pickle.HIGHEST_PROTOCOL + 1): | ||||
|             state = pickle.dumps(self.gen, proto) | ||||
|  | @ -390,26 +415,6 @@ def test_randrange_errors(self): | |||
|         raises(0, 42, 0) | ||||
|         raises(0, 42, 3.14159) | ||||
| 
 | ||||
|     def test_genrandbits(self): | ||||
|         # Verify ranges | ||||
|         for k in range(1, 1000): | ||||
|             self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k) | ||||
| 
 | ||||
|         # Verify all bits active | ||||
|         getbits = self.gen.getrandbits | ||||
|         for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]: | ||||
|             cum = 0 | ||||
|             for i in range(100): | ||||
|                 cum |= getbits(span) | ||||
|             self.assertEqual(cum, 2**span-1) | ||||
| 
 | ||||
|         # Verify argument checking | ||||
|         self.assertRaises(TypeError, self.gen.getrandbits) | ||||
|         self.assertRaises(TypeError, self.gen.getrandbits, 1, 2) | ||||
|         self.assertRaises(ValueError, self.gen.getrandbits, 0) | ||||
|         self.assertRaises(ValueError, self.gen.getrandbits, -1) | ||||
|         self.assertRaises(TypeError, self.gen.getrandbits, 10.1) | ||||
| 
 | ||||
|     def test_randbelow_logic(self, _log=log, int=int): | ||||
|         # check bitcount transition points:  2**i and 2**(i+1)-1 | ||||
|         # show that: k = int(1.001 + _log(n, 2)) | ||||
|  | @ -629,34 +634,18 @@ def test_rangelimits(self): | |||
|             self.assertEqual(set(range(start,stop)), | ||||
|                 set([self.gen.randrange(start,stop) for i in range(100)])) | ||||
| 
 | ||||
|     def test_genrandbits(self): | ||||
|     def test_getrandbits(self): | ||||
|         super().test_getrandbits() | ||||
| 
 | ||||
|         # Verify cross-platform repeatability | ||||
|         self.gen.seed(1234567) | ||||
|         self.assertEqual(self.gen.getrandbits(100), | ||||
|                          97904845777343510404718956115) | ||||
|         # Verify ranges | ||||
|         for k in range(1, 1000): | ||||
|             self.assertTrue(0 <= self.gen.getrandbits(k) < 2**k) | ||||
| 
 | ||||
|         # Verify all bits active | ||||
|         getbits = self.gen.getrandbits | ||||
|         for span in [1, 2, 3, 4, 31, 32, 32, 52, 53, 54, 119, 127, 128, 129]: | ||||
|             cum = 0 | ||||
|             for i in range(100): | ||||
|                 cum |= getbits(span) | ||||
|             self.assertEqual(cum, 2**span-1) | ||||
| 
 | ||||
|         # Verify argument checking | ||||
|         self.assertRaises(TypeError, self.gen.getrandbits) | ||||
|         self.assertRaises(TypeError, self.gen.getrandbits, 'a') | ||||
|         self.assertRaises(TypeError, self.gen.getrandbits, 1, 2) | ||||
|         self.assertRaises(ValueError, self.gen.getrandbits, 0) | ||||
|         self.assertRaises(ValueError, self.gen.getrandbits, -1) | ||||
| 
 | ||||
|     def test_randrange_uses_getrandbits(self): | ||||
|         # Verify use of getrandbits by randrange | ||||
|         # Use same seed as in the cross-platform repeatability test | ||||
|         # in test_genrandbits above. | ||||
|         # in test_getrandbits above. | ||||
|         self.gen.seed(1234567) | ||||
|         # If randrange uses getrandbits, it should pick getrandbits(100) | ||||
|         # when called with a 100-bits stop argument. | ||||
|  |  | |||
|  | @ -0,0 +1 @@ | |||
| Allow ``random.getrandbits(0)`` to succeed and to return 0. | ||||
|  | @ -474,12 +474,15 @@ _random_Random_getrandbits_impl(RandomObject *self, int k) | |||
|     uint32_t *wordarray; | ||||
|     PyObject *result; | ||||
| 
 | ||||
|     if (k <= 0) { | ||||
|     if (k < 0) { | ||||
|         PyErr_SetString(PyExc_ValueError, | ||||
|                         "number of bits must be greater than zero"); | ||||
|                         "number of bits must be non-negative"); | ||||
|         return NULL; | ||||
|     } | ||||
| 
 | ||||
|     if (k == 0) | ||||
|         return PyLong_FromLong(0); | ||||
| 
 | ||||
|     if (k <= 32)  /* Fast path */ | ||||
|         return PyLong_FromUnsignedLong(genrand_uint32(self) >> (32 - k)); | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Antoine Pitrou
						Antoine Pitrou