"""Test the secrets module.

As most of the functions in secrets are thin wrappers around functions
defined elsewhere, we don't need to test them exhaustively.
"""
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import secrets | 
					
						
							|  |  |  | import unittest | 
					
						
							|  |  |  | import string | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # === Unit tests === | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Compare_Digest_Tests(unittest.TestCase): | 
					
						
							|  |  |  |     """Test secrets.compare_digest function.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_equal(self): | 
					
						
							|  |  |  |         # Test compare_digest functionality with equal (byte/text) strings. | 
					
						
							|  |  |  |         for s in ("a", "bcd", "xyz123"): | 
					
						
							|  |  |  |             a = s*100 | 
					
						
							|  |  |  |             b = s*100 | 
					
						
							|  |  |  |             self.assertTrue(secrets.compare_digest(a, b)) | 
					
						
							|  |  |  |             self.assertTrue(secrets.compare_digest(a.encode('utf-8'), b.encode('utf-8'))) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_unequal(self): | 
					
						
							|  |  |  |         # Test compare_digest functionality with unequal (byte/text) strings. | 
					
						
							|  |  |  |         self.assertFalse(secrets.compare_digest("abc", "abcd")) | 
					
						
							|  |  |  |         self.assertFalse(secrets.compare_digest(b"abc", b"abcd")) | 
					
						
							|  |  |  |         for s in ("x", "mn", "a1b2c3"): | 
					
						
							|  |  |  |             a = s*100 + "q" | 
					
						
							|  |  |  |             b = s*100 + "k" | 
					
						
							|  |  |  |             self.assertFalse(secrets.compare_digest(a, b)) | 
					
						
							|  |  |  |             self.assertFalse(secrets.compare_digest(a.encode('utf-8'), b.encode('utf-8'))) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_bad_types(self): | 
					
						
							|  |  |  |         # Test that compare_digest raises with mixed types. | 
					
						
							|  |  |  |         a = 'abcde' | 
					
						
							|  |  |  |         b = a.encode('utf-8') | 
					
						
							|  |  |  |         assert isinstance(a, str) | 
					
						
							|  |  |  |         assert isinstance(b, bytes) | 
					
						
							|  |  |  |         self.assertRaises(TypeError, secrets.compare_digest, a, b) | 
					
						
							|  |  |  |         self.assertRaises(TypeError, secrets.compare_digest, b, a) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_bool(self): | 
					
						
							|  |  |  |         # Test that compare_digest returns a bool. | 
					
						
							| 
									
										
										
										
											2016-04-15 10:04:24 +10:00
										 |  |  |         self.assertIsInstance(secrets.compare_digest("abc", "abc"), bool) | 
					
						
							|  |  |  |         self.assertIsInstance(secrets.compare_digest("abc", "xyz"), bool) | 
					
						
							| 
									
										
										
										
											2016-04-15 01:51:31 +10:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Random_Tests(unittest.TestCase): | 
					
						
							|  |  |  |     """Test wrappers around SystemRandom methods.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_randbits(self): | 
					
						
							|  |  |  |         # Test randbits. | 
					
						
							|  |  |  |         errmsg = "randbits(%d) returned %d" | 
					
						
							|  |  |  |         for numbits in (3, 12, 30): | 
					
						
							|  |  |  |             for i in range(6): | 
					
						
							|  |  |  |                 n = secrets.randbits(numbits) | 
					
						
							|  |  |  |                 self.assertTrue(0 <= n < 2**numbits, errmsg % (numbits, n)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_choice(self): | 
					
						
							|  |  |  |         # Test choice. | 
					
						
							|  |  |  |         items = [1, 2, 4, 8, 16, 32, 64] | 
					
						
							|  |  |  |         for i in range(10): | 
					
						
							|  |  |  |             self.assertTrue(secrets.choice(items) in items) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_randbelow(self): | 
					
						
							|  |  |  |         # Test randbelow. | 
					
						
							|  |  |  |         for i in range(2, 10): | 
					
						
							| 
									
										
										
										
											2016-04-15 10:04:24 +10:00
										 |  |  |             self.assertIn(secrets.randbelow(i), range(i)) | 
					
						
							| 
									
										
										
										
											2016-04-15 01:51:31 +10:00
										 |  |  |         self.assertRaises(ValueError, secrets.randbelow, 0) | 
					
						
							| 
									
										
										
										
											2016-12-29 22:54:25 -07:00
										 |  |  |         self.assertRaises(ValueError, secrets.randbelow, -1) | 
					
						
							| 
									
										
										
										
											2016-04-15 01:51:31 +10:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Token_Tests(unittest.TestCase): | 
					
						
							|  |  |  |     """Test token functions.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_token_defaults(self): | 
					
						
							|  |  |  |         # Test that token_* functions handle default size correctly. | 
					
						
							|  |  |  |         for func in (secrets.token_bytes, secrets.token_hex, | 
					
						
							|  |  |  |                      secrets.token_urlsafe): | 
					
						
							| 
									
										
										
										
											2016-04-15 10:04:24 +10:00
										 |  |  |             with self.subTest(func=func): | 
					
						
							|  |  |  |                 name = func.__name__ | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     func() | 
					
						
							|  |  |  |                 except TypeError: | 
					
						
							|  |  |  |                     self.fail("%s cannot be called with no argument" % name) | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     func(None) | 
					
						
							|  |  |  |                 except TypeError: | 
					
						
							|  |  |  |                     self.fail("%s cannot be called with None" % name) | 
					
						
							| 
									
										
										
										
											2016-04-15 01:51:31 +10:00
										 |  |  |         size = secrets.DEFAULT_ENTROPY | 
					
						
							|  |  |  |         self.assertEqual(len(secrets.token_bytes(None)), size) | 
					
						
							|  |  |  |         self.assertEqual(len(secrets.token_hex(None)), 2*size) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def test_token_bytes(self): | 
					
						
							|  |  |  |         # Test token_bytes. | 
					
						
							|  |  |  |         for n in (1, 8, 17, 100): | 
					
						
							| 
									
										
										
										
											2016-04-15 10:04:24 +10:00
										 |  |  |             with self.subTest(n=n): | 
					
						
							|  |  |  |                 self.assertIsInstance(secrets.token_bytes(n), bytes) | 
					
						
							|  |  |  |                 self.assertEqual(len(secrets.token_bytes(n)), n) | 
					
						
							| 
									
										
										
										
											2016-04-15 01:51:31 +10:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def test_token_hex(self): | 
					
						
							|  |  |  |         # Test token_hex. | 
					
						
							|  |  |  |         for n in (1, 12, 25, 90): | 
					
						
							| 
									
										
										
										
											2016-04-15 10:04:24 +10:00
										 |  |  |             with self.subTest(n=n): | 
					
						
							|  |  |  |                 s = secrets.token_hex(n) | 
					
						
							| 
									
										
										
										
											2016-04-15 10:06:18 +10:00
										 |  |  |                 self.assertIsInstance(s, str) | 
					
						
							| 
									
										
										
										
											2016-04-15 10:04:24 +10:00
										 |  |  |                 self.assertEqual(len(s), 2*n) | 
					
						
							|  |  |  |                 self.assertTrue(all(c in string.hexdigits for c in s)) | 
					
						
							| 
									
										
										
										
											2016-04-15 01:51:31 +10:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def test_token_urlsafe(self): | 
					
						
							|  |  |  |         # Test token_urlsafe. | 
					
						
							|  |  |  |         legal = string.ascii_letters + string.digits + '-_' | 
					
						
							|  |  |  |         for n in (1, 11, 28, 76): | 
					
						
							| 
									
										
										
										
											2016-04-15 10:04:24 +10:00
										 |  |  |             with self.subTest(n=n): | 
					
						
							|  |  |  |                 s = secrets.token_urlsafe(n) | 
					
						
							| 
									
										
										
										
											2016-04-15 10:06:18 +10:00
										 |  |  |                 self.assertIsInstance(s, str) | 
					
						
							| 
									
										
										
										
											2016-04-15 10:04:24 +10:00
										 |  |  |                 self.assertTrue(all(c in legal for c in s)) | 
					
						
							| 
									
										
										
										
											2016-04-15 01:51:31 +10:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
# Hand control to unittest.main() when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()