gh-105539: Explict resource management for connection objects in sqlite3 tests (#108017)

- Use memory_database() helper
- Move test utility functions to util.py
- Add convenience memory database mixin
- Add check() helper for closed connection tests
This commit is contained in:
Erlend E. Aasland 2023-08-17 08:45:48 +02:00 committed by GitHub
parent c9d83f93d8
commit 1344cfac43
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 371 additions and 385 deletions

View file

@ -26,34 +26,31 @@
from test.support.os_helper import TESTFN, unlink
from test.test_sqlite3.test_dbapi import memory_database, cx_limit
from test.test_sqlite3.test_userfunctions import with_tracebacks
from .util import memory_database, cx_limit, with_tracebacks
from .util import MemoryDatabaseMixin
class CollationTests(unittest.TestCase):
class CollationTests(MemoryDatabaseMixin, unittest.TestCase):
def test_create_collation_not_string(self):
con = sqlite.connect(":memory:")
with self.assertRaises(TypeError):
con.create_collation(None, lambda x, y: (x > y) - (x < y))
self.con.create_collation(None, lambda x, y: (x > y) - (x < y))
def test_create_collation_not_callable(self):
con = sqlite.connect(":memory:")
with self.assertRaises(TypeError) as cm:
con.create_collation("X", 42)
self.con.create_collation("X", 42)
self.assertEqual(str(cm.exception), 'parameter must be callable')
def test_create_collation_not_ascii(self):
con = sqlite.connect(":memory:")
con.create_collation("collä", lambda x, y: (x > y) - (x < y))
self.con.create_collation("collä", lambda x, y: (x > y) - (x < y))
def test_create_collation_bad_upper(self):
class BadUpperStr(str):
def upper(self):
return None
con = sqlite.connect(":memory:")
mycoll = lambda x, y: -((x > y) - (x < y))
con.create_collation(BadUpperStr("mycoll"), mycoll)
result = con.execute("""
self.con.create_collation(BadUpperStr("mycoll"), mycoll)
result = self.con.execute("""
select x from (
select 'a' as x
union
@ -68,8 +65,7 @@ def mycoll(x, y):
# reverse order
return -((x > y) - (x < y))
con = sqlite.connect(":memory:")
con.create_collation("mycoll", mycoll)
self.con.create_collation("mycoll", mycoll)
sql = """
select x from (
select 'a' as x
@ -79,21 +75,20 @@ def mycoll(x, y):
select 'c' as x
) order by x collate mycoll
"""
result = con.execute(sql).fetchall()
result = self.con.execute(sql).fetchall()
self.assertEqual(result, [('c',), ('b',), ('a',)],
msg='the expected order was not returned')
con.create_collation("mycoll", None)
self.con.create_collation("mycoll", None)
with self.assertRaises(sqlite.OperationalError) as cm:
result = con.execute(sql).fetchall()
result = self.con.execute(sql).fetchall()
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
def test_collation_returns_large_integer(self):
def mycoll(x, y):
# reverse order
return -((x > y) - (x < y)) * 2**32
con = sqlite.connect(":memory:")
con.create_collation("mycoll", mycoll)
self.con.create_collation("mycoll", mycoll)
sql = """
select x from (
select 'a' as x
@ -103,7 +98,7 @@ def mycoll(x, y):
select 'c' as x
) order by x collate mycoll
"""
result = con.execute(sql).fetchall()
result = self.con.execute(sql).fetchall()
self.assertEqual(result, [('c',), ('b',), ('a',)],
msg="the expected order was not returned")
@ -112,7 +107,7 @@ def test_collation_register_twice(self):
Register two different collation functions under the same name.
Verify that the last one is actually used.
"""
con = sqlite.connect(":memory:")
con = self.con
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
result = con.execute("""
@ -126,25 +121,26 @@ def test_deregister_collation(self):
Register a collation, then deregister it. Make sure an error is raised if we try
to use it.
"""
con = sqlite.connect(":memory:")
con = self.con
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
con.create_collation("mycoll", None)
with self.assertRaises(sqlite.OperationalError) as cm:
con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
class ProgressTests(unittest.TestCase):
class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
def test_progress_handler_used(self):
"""
Test that the progress handler is invoked once it is set.
"""
con = sqlite.connect(":memory:")
progress_calls = []
def progress():
progress_calls.append(None)
return 0
con.set_progress_handler(progress, 1)
con.execute("""
self.con.set_progress_handler(progress, 1)
self.con.execute("""
create table foo(a, b)
""")
self.assertTrue(progress_calls)
@ -153,7 +149,7 @@ def test_opcode_count(self):
"""
Test that the opcode argument is respected.
"""
con = sqlite.connect(":memory:")
con = self.con
progress_calls = []
def progress():
progress_calls.append(None)
@ -176,11 +172,10 @@ def test_cancel_operation(self):
"""
Test that returning a non-zero value stops the operation in progress.
"""
con = sqlite.connect(":memory:")
def progress():
return 1
con.set_progress_handler(progress, 1)
curs = con.cursor()
self.con.set_progress_handler(progress, 1)
curs = self.con.cursor()
self.assertRaises(
sqlite.OperationalError,
curs.execute,
@ -190,7 +185,7 @@ def test_clear_handler(self):
"""
Test that setting the progress handler to None clears the previously set handler.
"""
con = sqlite.connect(":memory:")
con = self.con
action = 0
def progress():
nonlocal action
@ -203,31 +198,30 @@ def progress():
@with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler(self):
con = sqlite.connect(":memory:")
def bad_progress():
1 / 0
con.set_progress_handler(bad_progress, 1)
self.con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
con.execute("""
self.con.execute("""
create table foo(a, b)
""")
@with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler_result(self):
con = sqlite.connect(":memory:")
class BadBool:
def __bool__(self):
1 / 0
def bad_progress():
return BadBool()
con.set_progress_handler(bad_progress, 1)
self.con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
con.execute("""
self.con.execute("""
create table foo(a, b)
""")
class TraceCallbackTests(unittest.TestCase):
class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
@contextlib.contextmanager
def check_stmt_trace(self, cx, expected):
try:
@ -242,12 +236,11 @@ def test_trace_callback_used(self):
"""
Test that the trace callback is invoked once it is set.
"""
con = sqlite.connect(":memory:")
traced_statements = []
def trace(statement):
traced_statements.append(statement)
con.set_trace_callback(trace)
con.execute("create table foo(a, b)")
self.con.set_trace_callback(trace)
self.con.execute("create table foo(a, b)")
self.assertTrue(traced_statements)
self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
@ -255,7 +248,7 @@ def test_clear_trace_callback(self):
"""
Test that setting the trace callback to None clears the previously set callback.
"""
con = sqlite.connect(":memory:")
con = self.con
traced_statements = []
def trace(statement):
traced_statements.append(statement)
@ -269,7 +262,7 @@ def test_unicode_content(self):
Test that the statement can contain unicode literals.
"""
unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
con = sqlite.connect(":memory:")
con = self.con
traced_statements = []
def trace(statement):
traced_statements.append(statement)