| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  | # pysqlite2/test/hooks.py: tests for various SQLite-specific hooks | 
					
						
							|  |  |  | # | 
					
						
							| 
									
										
										
										
											2021-01-07 01:36:35 +01:00
										 |  |  | # Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de> | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  | # | 
					
						
							|  |  |  | # This file is part of pysqlite. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # This software is provided 'as-is', without any express or implied | 
					
						
							|  |  |  | # warranty.  In no event will the authors be held liable for any damages | 
					
						
							|  |  |  | # arising from the use of this software. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Permission is granted to anyone to use this software for any purpose, | 
					
						
							|  |  |  | # including commercial applications, and to alter it and redistribute it | 
					
						
							|  |  |  | # freely, subject to the following restrictions: | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # 1. The origin of this software must not be misrepresented; you must not | 
					
						
							|  |  |  | #    claim that you wrote the original software. If you use this software | 
					
						
							|  |  |  | #    in a product, an acknowledgment in the product documentation would be | 
					
						
							|  |  |  | #    appreciated but is not required. | 
					
						
							|  |  |  | # 2. Altered source versions must be plainly marked as such, and must not be | 
					
						
							|  |  |  | #    misrepresented as being the original software. | 
					
						
							|  |  |  | # 3. This notice may not be removed or altered from any source distribution. | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-05-02 08:14:35 -06:00
										 |  |  | import contextlib | 
					
						
							| 
									
										
										
										
											2022-03-09 18:39:49 +01:00
										 |  |  | import sqlite3 as sqlite | 
					
						
							| 
									
										
										
										
											2022-05-02 08:14:35 -06:00
										 |  |  | import unittest | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-07 23:18:38 +08:00
										 |  |  | from test.support.os_helper import TESTFN, unlink | 
					
						
							| 
									
										
										
										
											2022-05-02 08:14:35 -06:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  | from .util import memory_database, cx_limit, with_tracebacks | 
					
						
							|  |  |  | from .util import MemoryDatabaseMixin | 
					
						
							| 
									
										
										
										
											2017-04-09 12:11:59 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-05-02 08:14:35 -06:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  | class CollationTests(MemoryDatabaseMixin, unittest.TestCase): | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_create_collation_not_string(self): | 
					
						
							| 
									
										
										
										
											2016-09-27 00:10:03 +03:00
										 |  |  |         with self.assertRaises(TypeError): | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |             self.con.create_collation(None, lambda x, y: (x > y) - (x < y)) | 
					
						
							| 
									
										
										
										
											2016-09-27 00:10:03 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_create_collation_not_callable(self): | 
					
						
							| 
									
										
										
										
											2016-06-12 22:34:49 +03:00
										 |  |  |         with self.assertRaises(TypeError) as cm: | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |             self.con.create_collation("X", 42) | 
					
						
							| 
									
										
										
										
											2016-06-12 22:34:49 +03:00
										 |  |  |         self.assertEqual(str(cm.exception), 'parameter must be callable') | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_create_collation_not_ascii(self): | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         self.con.create_collation("collä", lambda x, y: (x > y) - (x < y)) | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_create_collation_bad_upper(self): | 
					
						
							| 
									
										
										
										
											2016-09-27 00:10:03 +03:00
										 |  |  |         class BadUpperStr(str): | 
					
						
							|  |  |  |             def upper(self): | 
					
						
							|  |  |  |                 return None | 
					
						
							|  |  |  |         mycoll = lambda x, y: -((x > y) - (x < y)) | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         self.con.create_collation(BadUpperStr("mycoll"), mycoll) | 
					
						
							|  |  |  |         result = self.con.execute("""
 | 
					
						
							| 
									
										
										
										
											2016-09-27 00:10:03 +03:00
										 |  |  |             select x from ( | 
					
						
							|  |  |  |             select 'a' as x | 
					
						
							|  |  |  |             union | 
					
						
							|  |  |  |             select 'b' as x | 
					
						
							|  |  |  |             ) order by x collate mycoll | 
					
						
							|  |  |  |             """).fetchall()
 | 
					
						
							|  |  |  |         self.assertEqual(result[0][0], 'b') | 
					
						
							|  |  |  |         self.assertEqual(result[1][0], 'a') | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_collation_is_used(self): | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  |         def mycoll(x, y): | 
					
						
							|  |  |  |             # reverse order | 
					
						
							| 
									
										
										
										
											2009-01-27 18:17:45 +00:00
										 |  |  |             return -((x > y) - (x < y)) | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         self.con.create_collation("mycoll", mycoll) | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  |         sql = """
 | 
					
						
							|  |  |  |             select x from ( | 
					
						
							|  |  |  |             select 'a' as x | 
					
						
							|  |  |  |             union | 
					
						
							|  |  |  |             select 'b' as x | 
					
						
							|  |  |  |             union | 
					
						
							|  |  |  |             select 'c' as x | 
					
						
							|  |  |  |             ) order by x collate mycoll | 
					
						
							|  |  |  |             """
 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         result = self.con.execute(sql).fetchall() | 
					
						
							| 
									
										
										
										
											2016-06-14 00:42:50 +03:00
										 |  |  |         self.assertEqual(result, [('c',), ('b',), ('a',)], | 
					
						
							|  |  |  |                          msg='the expected order was not returned') | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         self.con.create_collation("mycoll", None) | 
					
						
							| 
									
										
										
										
											2016-06-12 22:34:49 +03:00
										 |  |  |         with self.assertRaises(sqlite.OperationalError) as cm: | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |             result = self.con.execute(sql).fetchall() | 
					
						
							| 
									
										
										
										
											2016-06-12 22:34:49 +03:00
										 |  |  |         self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_collation_returns_large_integer(self): | 
					
						
							| 
									
										
										
										
											2013-02-07 17:01:47 +02:00
										 |  |  |         def mycoll(x, y): | 
					
						
							|  |  |  |             # reverse order | 
					
						
							|  |  |  |             return -((x > y) - (x < y)) * 2**32 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         self.con.create_collation("mycoll", mycoll) | 
					
						
							| 
									
										
										
										
											2013-02-07 17:01:47 +02:00
										 |  |  |         sql = """
 | 
					
						
							|  |  |  |             select x from ( | 
					
						
							|  |  |  |             select 'a' as x | 
					
						
							|  |  |  |             union | 
					
						
							|  |  |  |             select 'b' as x | 
					
						
							|  |  |  |             union | 
					
						
							|  |  |  |             select 'c' as x | 
					
						
							|  |  |  |             ) order by x collate mycoll | 
					
						
							|  |  |  |             """
 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         result = self.con.execute(sql).fetchall() | 
					
						
							| 
									
										
										
										
											2013-02-07 17:01:47 +02:00
										 |  |  |         self.assertEqual(result, [('c',), ('b',), ('a',)], | 
					
						
							|  |  |  |                          msg="the expected order was not returned") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_collation_register_twice(self): | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Register two different collation functions under the same name. | 
					
						
							|  |  |  |         Verify that the last one is actually used. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         con = self.con | 
					
						
							| 
									
										
										
										
											2009-01-27 18:17:45 +00:00
										 |  |  |         con.create_collation("mycoll", lambda x, y: (x > y) - (x < y)) | 
					
						
							|  |  |  |         con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y))) | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  |         result = con.execute("""
 | 
					
						
							|  |  |  |             select x from (select 'a' as x union select 'b' as x) order by x collate mycoll | 
					
						
							|  |  |  |             """).fetchall()
 | 
					
						
							| 
									
										
										
										
											2016-06-12 22:34:49 +03:00
										 |  |  |         self.assertEqual(result[0][0], 'b') | 
					
						
							|  |  |  |         self.assertEqual(result[1][0], 'a') | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_deregister_collation(self): | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Register a collation, then deregister it. Make sure an error is raised if we try | 
					
						
							|  |  |  |         to use it. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         con = self.con | 
					
						
							| 
									
										
										
										
											2009-01-27 18:17:45 +00:00
										 |  |  |         con.create_collation("mycoll", lambda x, y: (x > y) - (x < y)) | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  |         con.create_collation("mycoll", None) | 
					
						
							| 
									
										
										
										
											2016-06-12 22:34:49 +03:00
										 |  |  |         with self.assertRaises(sqlite.OperationalError) as cm: | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  |             con.execute("select 'a' as x union select 'b' as x order by x collate mycoll") | 
					
						
							| 
									
										
										
										
											2016-06-12 22:34:49 +03:00
										 |  |  |         self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | class ProgressTests(MemoryDatabaseMixin, unittest.TestCase): | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_progress_handler_used(self): | 
					
						
							| 
									
										
										
										
											2008-03-29 00:45:29 +00:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Test that the progress handler is invoked once it is set. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         progress_calls = [] | 
					
						
							|  |  |  |         def progress(): | 
					
						
							|  |  |  |             progress_calls.append(None) | 
					
						
							|  |  |  |             return 0 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         self.con.set_progress_handler(progress, 1) | 
					
						
							|  |  |  |         self.con.execute("""
 | 
					
						
							| 
									
										
										
										
											2008-03-29 00:45:29 +00:00
										 |  |  |             create table foo(a, b) | 
					
						
							|  |  |  |             """)
 | 
					
						
							| 
									
										
										
										
											2009-07-04 08:32:15 +00:00
										 |  |  |         self.assertTrue(progress_calls) | 
					
						
							| 
									
										
										
										
											2008-03-29 00:45:29 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_opcode_count(self): | 
					
						
							| 
									
										
										
										
											2008-03-29 00:45:29 +00:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Test that the opcode argument is respected. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         con = self.con | 
					
						
							| 
									
										
										
										
											2008-03-29 00:45:29 +00:00
										 |  |  |         progress_calls = [] | 
					
						
							|  |  |  |         def progress(): | 
					
						
							|  |  |  |             progress_calls.append(None) | 
					
						
							|  |  |  |             return 0 | 
					
						
							|  |  |  |         con.set_progress_handler(progress, 1) | 
					
						
							|  |  |  |         curs = con.cursor() | 
					
						
							|  |  |  |         curs.execute("""
 | 
					
						
							|  |  |  |             create table foo (a, b) | 
					
						
							|  |  |  |             """)
 | 
					
						
							|  |  |  |         first_count = len(progress_calls) | 
					
						
							|  |  |  |         progress_calls = [] | 
					
						
							|  |  |  |         con.set_progress_handler(progress, 2) | 
					
						
							|  |  |  |         curs.execute("""
 | 
					
						
							|  |  |  |             create table bar (a, b) | 
					
						
							|  |  |  |             """)
 | 
					
						
							|  |  |  |         second_count = len(progress_calls) | 
					
						
							| 
									
										
										
										
											2014-03-12 21:51:52 -05:00
										 |  |  |         self.assertGreaterEqual(first_count, second_count) | 
					
						
							| 
									
										
										
										
											2008-03-29 00:45:29 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_cancel_operation(self): | 
					
						
							| 
									
										
										
										
											2008-03-29 00:45:29 +00:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Test that returning a non-zero value stops the operation in progress. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         def progress(): | 
					
						
							|  |  |  |             return 1 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         self.con.set_progress_handler(progress, 1) | 
					
						
							|  |  |  |         curs = self.con.cursor() | 
					
						
							| 
									
										
										
										
											2008-03-29 00:45:29 +00:00
										 |  |  |         self.assertRaises( | 
					
						
							|  |  |  |             sqlite.OperationalError, | 
					
						
							|  |  |  |             curs.execute, | 
					
						
							|  |  |  |             "create table bar (a, b)") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_clear_handler(self): | 
					
						
							| 
									
										
										
										
											2008-03-29 00:45:29 +00:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Test that setting the progress handler to None clears the previously set handler. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         con = self.con | 
					
						
							| 
									
										
										
										
											2008-03-29 00:45:29 +00:00
										 |  |  |         action = 0 | 
					
						
							|  |  |  |         def progress(): | 
					
						
							| 
									
										
										
										
											2012-02-17 21:30:55 +02:00
										 |  |  |             nonlocal action | 
					
						
							| 
									
										
										
										
											2008-03-29 00:45:29 +00:00
										 |  |  |             action = 1 | 
					
						
							|  |  |  |             return 0 | 
					
						
							|  |  |  |         con.set_progress_handler(progress, 1) | 
					
						
							|  |  |  |         con.set_progress_handler(None, 1) | 
					
						
							|  |  |  |         con.execute("select 1 union select 2 union select 3").fetchall() | 
					
						
							| 
									
										
										
										
											2009-07-04 08:32:15 +00:00
										 |  |  |         self.assertEqual(action, 0, "progress handler was not cleared") | 
					
						
							| 
									
										
										
										
											2008-03-29 00:45:29 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-03 13:10:18 +01:00
										 |  |  |     @with_tracebacks(ZeroDivisionError, msg_regex="bad_progress") | 
					
						
							| 
									
										
										
										
											2021-08-08 08:49:44 +03:00
										 |  |  |     def test_error_in_progress_handler(self): | 
					
						
							|  |  |  |         def bad_progress(): | 
					
						
							|  |  |  |             1 / 0 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         self.con.set_progress_handler(bad_progress, 1) | 
					
						
							| 
									
										
										
										
											2021-08-08 08:49:44 +03:00
										 |  |  |         with self.assertRaises(sqlite.OperationalError): | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |             self.con.execute("""
 | 
					
						
							| 
									
										
										
										
											2021-08-08 08:49:44 +03:00
										 |  |  |                 create table foo(a, b) | 
					
						
							|  |  |  |                 """)
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-02-03 13:10:18 +01:00
										 |  |  |     @with_tracebacks(ZeroDivisionError, msg_regex="bad_progress") | 
					
						
							| 
									
										
										
										
											2021-08-08 08:49:44 +03:00
										 |  |  |     def test_error_in_progress_handler_result(self): | 
					
						
							|  |  |  |         class BadBool: | 
					
						
							|  |  |  |             def __bool__(self): | 
					
						
							|  |  |  |                 1 / 0 | 
					
						
							|  |  |  |         def bad_progress(): | 
					
						
							|  |  |  |             return BadBool() | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         self.con.set_progress_handler(bad_progress, 1) | 
					
						
							| 
									
										
										
										
											2021-08-08 08:49:44 +03:00
										 |  |  |         with self.assertRaises(sqlite.OperationalError): | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |             self.con.execute("""
 | 
					
						
							| 
									
										
										
										
											2021-08-08 08:49:44 +03:00
										 |  |  |                 create table foo(a, b) | 
					
						
							|  |  |  |                 """)
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-29 22:02:12 +02:00
										 |  |  |     def test_progress_handler_keyword_args(self): | 
					
						
							| 
									
										
										
										
											2025-05-08 15:42:00 +03:00
										 |  |  |         with self.assertRaisesRegex(TypeError, | 
					
						
							|  |  |  |                 'takes at least 1 positional argument'): | 
					
						
							| 
									
										
										
										
											2023-08-29 22:02:12 +02:00
										 |  |  |             self.con.set_progress_handler(progress_handler=lambda: None, n=1) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-08 08:49:44 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  | class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase): | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-05-02 08:14:35 -06:00
										 |  |  |     @contextlib.contextmanager | 
					
						
							|  |  |  |     def check_stmt_trace(self, cx, expected): | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             traced = [] | 
					
						
							|  |  |  |             cx.set_trace_callback(lambda stmt: traced.append(stmt)) | 
					
						
							|  |  |  |             yield | 
					
						
							|  |  |  |         finally: | 
					
						
							|  |  |  |             self.assertEqual(traced, expected) | 
					
						
							|  |  |  |             cx.set_trace_callback(None) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_trace_callback_used(self): | 
					
						
							| 
									
										
										
										
											2011-04-04 00:12:04 +02:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Test that the trace callback is invoked once it is set. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         traced_statements = [] | 
					
						
							|  |  |  |         def trace(statement): | 
					
						
							|  |  |  |             traced_statements.append(statement) | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         self.con.set_trace_callback(trace) | 
					
						
							|  |  |  |         self.con.execute("create table foo(a, b)") | 
					
						
							| 
									
										
										
										
											2011-04-04 00:12:04 +02:00
										 |  |  |         self.assertTrue(traced_statements) | 
					
						
							|  |  |  |         self.assertTrue(any("create table foo" in stmt for stmt in traced_statements)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_clear_trace_callback(self): | 
					
						
							| 
									
										
										
										
											2011-04-04 00:12:04 +02:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Test that setting the trace callback to None clears the previously set callback. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         con = self.con | 
					
						
							| 
									
										
										
										
											2011-04-04 00:12:04 +02:00
										 |  |  |         traced_statements = [] | 
					
						
							|  |  |  |         def trace(statement): | 
					
						
							|  |  |  |             traced_statements.append(statement) | 
					
						
							|  |  |  |         con.set_trace_callback(trace) | 
					
						
							|  |  |  |         con.set_trace_callback(None) | 
					
						
							|  |  |  |         con.execute("create table foo(a, b)") | 
					
						
							|  |  |  |         self.assertFalse(traced_statements, "trace callback was not cleared") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_unicode_content(self): | 
					
						
							| 
									
										
										
										
											2011-04-04 00:12:04 +02:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Test that the statement can contain unicode literals. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac' | 
					
						
							| 
									
										
										
										
											2023-08-17 08:45:48 +02:00
										 |  |  |         con = self.con | 
					
						
							| 
									
										
										
										
											2011-04-04 00:12:04 +02:00
										 |  |  |         traced_statements = [] | 
					
						
							|  |  |  |         def trace(statement): | 
					
						
							|  |  |  |             traced_statements.append(statement) | 
					
						
							|  |  |  |         con.set_trace_callback(trace) | 
					
						
							|  |  |  |         con.execute("create table foo(x)") | 
					
						
							| 
									
										
										
										
											2021-05-14 12:27:21 +02:00
										 |  |  |         con.execute("insert into foo(x) values ('%s')" % unicode_value) | 
					
						
							| 
									
										
										
										
											2011-04-04 00:12:04 +02:00
										 |  |  |         con.commit() | 
					
						
							|  |  |  |         self.assertTrue(any(unicode_value in stmt for stmt in traced_statements), | 
					
						
							| 
									
										
										
										
											2011-04-04 00:50:01 +02:00
										 |  |  |                         "Unicode data %s garbled in trace callback: %s" | 
					
						
							|  |  |  |                         % (ascii(unicode_value), ', '.join(map(ascii, traced_statements)))) | 
					
						
							| 
									
										
										
										
											2011-04-04 00:12:04 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-07 01:05:07 +01:00
										 |  |  |     def test_trace_callback_content(self): | 
					
						
							| 
									
										
										
										
											2017-04-09 12:11:59 +03:00
										 |  |  |         # set_trace_callback() shouldn't produce duplicate content (bpo-26187) | 
					
						
							|  |  |  |         traced_statements = [] | 
					
						
							|  |  |  |         def trace(statement): | 
					
						
							|  |  |  |             traced_statements.append(statement) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         queries = ["create table foo(x)", | 
					
						
							|  |  |  |                    "insert into foo(x) values(1)"] | 
					
						
							|  |  |  |         self.addCleanup(unlink, TESTFN) | 
					
						
							|  |  |  |         con1 = sqlite.connect(TESTFN, isolation_level=None) | 
					
						
							|  |  |  |         con2 = sqlite.connect(TESTFN) | 
					
						
							| 
									
										
										
										
											2021-06-03 18:38:19 +02:00
										 |  |  |         try: | 
					
						
							|  |  |  |             con1.set_trace_callback(trace) | 
					
						
							|  |  |  |             cur = con1.cursor() | 
					
						
							|  |  |  |             cur.execute(queries[0]) | 
					
						
							|  |  |  |             con2.execute("create table bar(x)") | 
					
						
							|  |  |  |             cur.execute(queries[1]) | 
					
						
							|  |  |  |         finally: | 
					
						
							|  |  |  |             con1.close() | 
					
						
							|  |  |  |             con2.close() | 
					
						
							| 
									
										
										
										
											2017-04-09 12:11:59 +03:00
										 |  |  |         self.assertEqual(traced_statements, queries) | 
					
						
							| 
									
										
										
										
											2011-04-04 00:12:04 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-05-02 08:14:35 -06:00
										 |  |  |     def test_trace_expanded_sql(self): | 
					
						
							|  |  |  |         expected = [ | 
					
						
							|  |  |  |             "create table t(t)", | 
					
						
							|  |  |  |             "BEGIN ", | 
					
						
							|  |  |  |             "insert into t values(0)", | 
					
						
							|  |  |  |             "insert into t values(1)", | 
					
						
							|  |  |  |             "insert into t values(2)", | 
					
						
							|  |  |  |             "COMMIT", | 
					
						
							|  |  |  |         ] | 
					
						
							|  |  |  |         with memory_database() as cx, self.check_stmt_trace(cx, expected): | 
					
						
							|  |  |  |             with cx: | 
					
						
							|  |  |  |                 cx.execute("create table t(t)") | 
					
						
							|  |  |  |                 cx.executemany("insert into t values(?)", ((v,) for v in range(3))) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @with_tracebacks( | 
					
						
							|  |  |  |         sqlite.DataError, | 
					
						
							|  |  |  |         regex="Expanded SQL string exceeds the maximum string length" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     def test_trace_too_much_expanded_sql(self): | 
					
						
							|  |  |  |         # If the expanded string is too large, we'll fall back to the | 
					
						
							| 
									
										
										
										
											2023-06-19 00:29:08 +02:00
										 |  |  |         # unexpanded SQL statement. | 
					
						
							| 
									
										
										
										
											2022-05-02 08:14:35 -06:00
										 |  |  |         # The resulting string length is limited by the runtime limit | 
					
						
							|  |  |  |         # SQLITE_LIMIT_LENGTH. | 
					
						
							|  |  |  |         template = "select 1 as a where a=" | 
					
						
							|  |  |  |         category = sqlite.SQLITE_LIMIT_LENGTH | 
					
						
							|  |  |  |         with memory_database() as cx, cx_limit(cx, category=category) as lim: | 
					
						
							|  |  |  |             ok_param = "a" | 
					
						
							|  |  |  |             bad_param = "a" * lim | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             unexpanded_query = template + "?" | 
					
						
							|  |  |  |             expected = [unexpanded_query] | 
					
						
							|  |  |  |             with self.check_stmt_trace(cx, expected): | 
					
						
							|  |  |  |                 cx.execute(unexpanded_query, (bad_param,)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             expanded_query = f"{template}'{ok_param}'" | 
					
						
							|  |  |  |             with self.check_stmt_trace(cx, [expanded_query]): | 
					
						
							|  |  |  |                 cx.execute(unexpanded_query, (ok_param,)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @with_tracebacks(ZeroDivisionError, regex="division by zero") | 
					
						
							|  |  |  |     def test_trace_bad_handler(self): | 
					
						
							|  |  |  |         with memory_database() as cx: | 
					
						
							|  |  |  |             cx.set_trace_callback(lambda stmt: 5/0) | 
					
						
							|  |  |  |             cx.execute("select 1") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-29 22:02:12 +02:00
										 |  |  |     def test_trace_keyword_args(self): | 
					
						
							| 
									
										
										
										
											2025-05-08 15:42:00 +03:00
										 |  |  |         with self.assertRaisesRegex(TypeError, | 
					
						
							|  |  |  |                 'takes exactly 1 positional argument'): | 
					
						
							| 
									
										
										
										
											2023-08-29 22:02:12 +02:00
										 |  |  |             self.con.set_trace_callback(trace_callback=lambda: None) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2011-04-04 00:12:04 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2006-04-21 10:40:58 +00:00
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()