[3.14] gh-142830: prevent some crashes when mutating sqlite3 callbacks (GH-143245) (#143322)

(cherry picked from commit 7f6c16a956)
This commit is contained in:
Bénédikt Tran 2026-01-01 12:24:21 +01:00 committed by GitHub
parent 8680b18f26
commit 048edac8be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 184 additions and 28 deletions

View file

@ -24,11 +24,15 @@
import sqlite3 as sqlite
import unittest
from test.support import import_helper
from test.support.os_helper import TESTFN, unlink
from .util import memory_database, cx_limit, with_tracebacks
from .util import MemoryDatabaseMixin
# TODO(picnixz): increase test coverage for other callbacks
# such as 'func', 'step', 'finalize', and 'collation'.
class CollationTests(MemoryDatabaseMixin, unittest.TestCase):
@ -129,8 +133,55 @@ def test_deregister_collation(self):
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
class AuthorizerTests(MemoryDatabaseMixin, unittest.TestCase):
def assert_not_authorized(self, func, /, *args, **kwargs):
with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"):
func(*args, **kwargs)
# When a handler has an invalid signature, the exception raised is
# the same that would be raised if the handler "negatively" replied.
def test_authorizer_invalid_signature(self):
self.cx.execute("create table if not exists test(a number)")
self.cx.set_authorizer(lambda: None)
self.assert_not_authorized(self.cx.execute, "select * from test")
# Tests for checking that callback context mutations do not crash.
# Regression tests for https://github.com/python/cpython/issues/142830.
@with_tracebacks(ZeroDivisionError, regex="hello world")
def test_authorizer_concurrent_mutation_in_call(self):
self.cx.execute("create table if not exists test(a number)")
def handler(*a, **kw):
self.cx.set_authorizer(None)
raise ZeroDivisionError("hello world")
self.cx.set_authorizer(handler)
self.assert_not_authorized(self.cx.execute, "select * from test")
@with_tracebacks(OverflowError)
def test_authorizer_concurrent_mutation_with_overflown_value(self):
_testcapi = import_helper.import_module("_testcapi")
self.cx.execute("create table if not exists test(a number)")
def handler(*a, **kw):
self.cx.set_authorizer(None)
# We expect 'int' at the C level, so this one will raise
# when converting via PyLong_Int().
return _testcapi.INT_MAX + 1
self.cx.set_authorizer(handler)
self.assert_not_authorized(self.cx.execute, "select * from test")
class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
def assert_interrupted(self, func, /, *args, **kwargs):
with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"):
func(*args, **kwargs)
def test_progress_handler_used(self):
"""
Test that the progress handler is invoked once it is set.
@ -219,7 +270,7 @@ def bad_progress():
create table foo(a, b)
""")
def test_progress_handler_keyword_args(self):
def test_set_progress_handler_keyword_args(self):
regex = (
r"Passing keyword argument 'progress_handler' to "
r"_sqlite3.Connection.set_progress_handler\(\) is deprecated. "
@ -231,6 +282,43 @@ def test_progress_handler_keyword_args(self):
self.con.set_progress_handler(progress_handler=lambda: None, n=1)
self.assertEqual(cm.filename, __file__)
# When a handler has an invalid signature, the exception raised is
# the same that would be raised if the handler "negatively" replied.
def test_progress_handler_invalid_signature(self):
self.cx.execute("create table if not exists test(a number)")
self.cx.set_progress_handler(lambda x: None, 1)
self.assert_interrupted(self.cx.execute, "select * from test")
# Tests for checking that callback context mutations do not crash.
# Regression tests for https://github.com/python/cpython/issues/142830.
@with_tracebacks(ZeroDivisionError, regex="hello world")
def test_progress_handler_concurrent_mutation_in_call(self):
self.cx.execute("create table if not exists test(a number)")
def handler(*a, **kw):
self.cx.set_progress_handler(None, 1)
raise ZeroDivisionError("hello world")
self.cx.set_progress_handler(handler, 1)
self.assert_interrupted(self.cx.execute, "select * from test")
def test_progress_handler_concurrent_mutation_in_conversion(self):
self.cx.execute("create table if not exists test(a number)")
class Handler:
def __bool__(_):
# clear the progress handler
self.cx.set_progress_handler(None, 1)
raise ValueError # force PyObject_True() to fail
self.cx.set_progress_handler(Handler.__init__, 1)
self.assert_interrupted(self.cx.execute, "select * from test")
# Running with tracebacks makes the second execution of this
# function raise another exception because of a database change.
class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
@ -352,7 +440,7 @@ def test_trace_bad_handler(self):
cx.set_trace_callback(lambda stmt: 5/0)
cx.execute("select 1")
def test_trace_keyword_args(self):
def test_set_trace_callback_keyword_args(self):
regex = (
r"Passing keyword argument 'trace_callback' to "
r"_sqlite3.Connection.set_trace_callback\(\) is deprecated. "
@ -364,6 +452,35 @@ def test_trace_keyword_args(self):
self.con.set_trace_callback(trace_callback=lambda: None)
self.assertEqual(cm.filename, __file__)
# When a handler has an invalid signature, the exception raised is
# the same that would be raised if the handler "negatively" replied,
# but for the trace handler, exceptions are never re-raised (only
# printed when needed).
@with_tracebacks(
TypeError,
regex=r".*<lambda>\(\) missing 6 required positional arguments",
)
def test_trace_handler_invalid_signature(self):
self.cx.execute("create table if not exists test(a number)")
self.cx.set_trace_callback(lambda x, y, z, t, a, b, c: None)
self.cx.execute("select * from test")
# Tests for checking that callback context mutations do not crash.
# Regression tests for https://github.com/python/cpython/issues/142830.
@with_tracebacks(ZeroDivisionError, regex="hello world")
def test_trace_callback_concurrent_mutation_in_call(self):
self.cx.execute("create table if not exists test(a number)")
def handler(statement):
# clear the progress handler
self.cx.set_trace_callback(None)
raise ZeroDivisionError("hello world")
self.cx.set_trace_callback(handler)
self.cx.execute("select * from test")
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,2 @@
:mod:`sqlite3`: fix use-after-free crashes when the connection's callbacks
are mutated during a callback execution. Patch by Bénédikt Tran.

View file

@ -145,7 +145,8 @@ class _sqlite3.Connection "pysqlite_Connection *" "clinic_state()->ConnectionTyp
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=67369db2faf80891]*/
static int _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self);
static void free_callback_context(callback_context *ctx);
static void incref_callback_context(callback_context *ctx);
static void decref_callback_context(callback_context *ctx);
static void set_callback_context(callback_context **ctx_pp,
callback_context *ctx);
static int connection_close(pysqlite_Connection *self);
@ -937,8 +938,9 @@ func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
args = _pysqlite_build_py_params(context, argc, argv);
if (args) {
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
incref_callback_context(ctx);
py_retval = PyObject_CallObject(ctx->callable, args);
decref_callback_context(ctx);
Py_DECREF(args);
}
@ -965,7 +967,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
PyObject* stepmethod = NULL;
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
incref_callback_context(ctx);
aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
if (aggregate_instance == NULL) {
@ -1003,6 +1005,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
}
error:
decref_callback_context(ctx);
Py_XDECREF(stepmethod);
Py_XDECREF(function_result);
@ -1034,9 +1037,10 @@ final_callback(sqlite3_context *context)
PyObject *exc = PyErr_GetRaisedException();
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
incref_callback_context(ctx);
function_result = PyObject_CallMethodNoArgs(*aggregate_instance,
ctx->state->str_finalize);
decref_callback_context(ctx);
Py_DECREF(*aggregate_instance);
ok = 0;
@ -1108,6 +1112,7 @@ create_callback_context(PyTypeObject *cls, PyObject *callable)
callback_context *ctx = PyMem_Malloc(sizeof(callback_context));
if (ctx != NULL) {
PyObject *module = PyType_GetModule(cls);
ctx->refcount = 1;
ctx->callable = Py_NewRef(callable);
ctx->module = Py_NewRef(module);
ctx->state = pysqlite_get_state(module);
@ -1119,11 +1124,33 @@ static void
free_callback_context(callback_context *ctx)
{
assert(ctx != NULL);
assert(ctx->refcount == 0);
Py_XDECREF(ctx->callable);
Py_XDECREF(ctx->module);
PyMem_Free(ctx);
}
static inline void
incref_callback_context(callback_context *ctx)
{
assert(PyGILState_Check());
assert(ctx != NULL);
assert(ctx->refcount > 0);
ctx->refcount++;
}
static inline void
decref_callback_context(callback_context *ctx)
{
assert(PyGILState_Check());
assert(ctx != NULL);
assert(ctx->refcount > 0);
ctx->refcount--;
if (ctx->refcount == 0) {
free_callback_context(ctx);
}
}
static void
set_callback_context(callback_context **ctx_pp, callback_context *ctx)
{
@ -1131,7 +1158,7 @@ set_callback_context(callback_context **ctx_pp, callback_context *ctx)
callback_context *tmp = *ctx_pp;
*ctx_pp = ctx;
if (tmp != NULL) {
free_callback_context(tmp);
decref_callback_context(tmp);
}
}
@ -1142,7 +1169,7 @@ destructor_callback(void *ctx)
// This function may be called without the GIL held, so we need to
// ensure that we destroy 'ctx' with the GIL held.
PyGILState_STATE gstate = PyGILState_Ensure();
free_callback_context((callback_context *)ctx);
decref_callback_context((callback_context *)ctx);
PyGILState_Release(gstate);
}
}
@ -1204,7 +1231,7 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
func_callback,
NULL,
NULL,
&destructor_callback); // will decref func
&destructor_callback); // will free 'ctx'
if (rc != SQLITE_OK) {
/* Workaround for SQLite bug: no error code or string is available here */
@ -1228,7 +1255,7 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
PyGILState_STATE gilstate = PyGILState_Ensure();
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
incref_callback_context(ctx);
int size = sizeof(PyObject *);
PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size);
@ -1260,6 +1287,7 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
Py_DECREF(res);
exit:
decref_callback_context(ctx);
Py_XDECREF(method);
PyGILState_Release(gilstate);
}
@ -1276,7 +1304,7 @@ value_callback(sqlite3_context *context)
PyGILState_STATE gilstate = PyGILState_Ensure();
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
incref_callback_context(ctx);
int size = sizeof(PyObject *);
PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size);
@ -1284,6 +1312,8 @@ value_callback(sqlite3_context *context)
assert(*cls != NULL);
PyObject *res = PyObject_CallMethodNoArgs(*cls, ctx->state->str_value);
decref_callback_context(ctx);
if (res == NULL) {
int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError);
set_sqlite_error(context, attr_err
@ -1406,7 +1436,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
0,
&step_callback,
&final_callback,
&destructor_callback); // will decref func
&destructor_callback); // will free 'ctx'
if (rc != SQLITE_OK) {
/* Workaround for SQLite bug: no error code or string is available here */
PyErr_SetString(self->OperationalError, "Error creating aggregate");
@ -1416,7 +1446,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
}
static int
authorizer_callback(void *ctx, int action, const char *arg1,
authorizer_callback(void *ctx_vp, int action, const char *arg1,
const char *arg2 , const char *dbname,
const char *access_attempt_source)
{
@ -1425,8 +1455,9 @@ authorizer_callback(void *ctx, int action, const char *arg1,
PyObject *ret;
int rc = SQLITE_DENY;
assert(ctx != NULL);
PyObject *callable = ((callback_context *)ctx)->callable;
callback_context *ctx = (callback_context *)ctx_vp;
incref_callback_context(ctx);
PyObject *callable = ctx->callable;
ret = PyObject_CallFunction(callable, "issss", action, arg1, arg2, dbname,
access_attempt_source);
@ -1448,21 +1479,23 @@ authorizer_callback(void *ctx, int action, const char *arg1,
Py_DECREF(ret);
}
decref_callback_context(ctx);
PyGILState_Release(gilstate);
return rc;
}
static int
progress_callback(void *ctx)
progress_callback(void *ctx_vp)
{
PyGILState_STATE gilstate = PyGILState_Ensure();
int rc;
PyObject *ret;
assert(ctx != NULL);
PyObject *callable = ((callback_context *)ctx)->callable;
ret = PyObject_CallNoArgs(callable);
callback_context *ctx = (callback_context *)ctx_vp;
incref_callback_context(ctx);
ret = PyObject_CallNoArgs(ctx->callable);
if (!ret) {
/* abort query if error occurred */
rc = -1;
@ -1475,6 +1508,7 @@ progress_callback(void *ctx)
print_or_clear_traceback(ctx);
}
decref_callback_context(ctx);
PyGILState_Release(gilstate);
return rc;
}
@ -1486,7 +1520,7 @@ progress_callback(void *ctx)
* to ensure future compatibility.
*/
static int
trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql)
{
if (type != SQLITE_TRACE_STMT) {
return 0;
@ -1494,8 +1528,9 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
PyGILState_STATE gilstate = PyGILState_Ensure();
assert(ctx != NULL);
pysqlite_state *state = ((callback_context *)ctx)->state;
callback_context *ctx = (callback_context *)ctx_vp;
incref_callback_context(ctx);
pysqlite_state *state = ctx->state;
assert(state != NULL);
PyObject *py_statement = NULL;
@ -1509,7 +1544,7 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
PyErr_SetString(state->DataError,
"Expanded SQL string exceeds the maximum string length");
print_or_clear_traceback((callback_context *)ctx);
print_or_clear_traceback(ctx);
// Fall back to unexpanded sql
py_statement = PyUnicode_FromString((const char *)sql);
@ -1519,16 +1554,16 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
sqlite3_free((void *)expanded_sql);
}
if (py_statement) {
PyObject *callable = ((callback_context *)ctx)->callable;
PyObject *ret = PyObject_CallOneArg(callable, py_statement);
PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement);
Py_DECREF(py_statement);
Py_XDECREF(ret);
}
if (PyErr_Occurred()) {
print_or_clear_traceback((callback_context *)ctx);
print_or_clear_traceback(ctx);
}
exit:
decref_callback_context(ctx);
PyGILState_Release(gilstate);
return 0;
}
@ -1952,6 +1987,8 @@ collation_callback(void *context, int text1_length, const void *text1_data,
PyObject* retval = NULL;
long longval;
int result = 0;
callback_context *ctx = (callback_context *)context;
incref_callback_context(ctx);
/* This callback may be executed multiple times per sqlite3_step(). Bail if
* the previous call failed */
@ -1968,8 +2005,6 @@ collation_callback(void *context, int text1_length, const void *text1_data,
goto finally;
}
callback_context *ctx = (callback_context *)context;
assert(ctx != NULL);
PyObject *args[] = { NULL, string1, string2 }; // Borrowed refs.
size_t nargsf = 2 | PY_VECTORCALL_ARGUMENTS_OFFSET;
retval = PyObject_Vectorcall(ctx->callable, args + 1, nargsf, NULL);
@ -1991,6 +2026,7 @@ collation_callback(void *context, int text1_length, const void *text1_data,
}
finally:
decref_callback_context(ctx);
Py_XDECREF(string1);
Py_XDECREF(string2);
Py_XDECREF(retval);

View file

@ -36,6 +36,7 @@ typedef struct _callback_context
PyObject *callable;
PyObject *module;
pysqlite_state *state;
Py_ssize_t refcount;
} callback_context;
enum autocommit_mode {