mirror of
https://github.com/python/cpython.git
synced 2026-01-03 22:12:27 +00:00
[3.14] gh-142830: prevent some crashes when mutating sqlite3 callbacks (GH-143245) (#143322)
(cherry picked from commit 7f6c16a956)
This commit is contained in:
parent
8680b18f26
commit
048edac8be
4 changed files with 184 additions and 28 deletions
|
|
@ -24,11 +24,15 @@
|
|||
import sqlite3 as sqlite
|
||||
import unittest
|
||||
|
||||
from test.support import import_helper
|
||||
from test.support.os_helper import TESTFN, unlink
|
||||
|
||||
from .util import memory_database, cx_limit, with_tracebacks
|
||||
from .util import MemoryDatabaseMixin
|
||||
|
||||
# TODO(picnixz): increase test coverage for other callbacks
|
||||
# such as 'func', 'step', 'finalize', and 'collation'.
|
||||
|
||||
|
||||
class CollationTests(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
|
|
@ -129,8 +133,55 @@ def test_deregister_collation(self):
|
|||
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
|
||||
|
||||
|
||||
class AuthorizerTests(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
def assert_not_authorized(self, func, /, *args, **kwargs):
|
||||
with self.assertRaisesRegex(sqlite.DatabaseError, "not authorized"):
|
||||
func(*args, **kwargs)
|
||||
|
||||
# When a handler has an invalid signature, the exception raised is
|
||||
# the same that would be raised if the handler "negatively" replied.
|
||||
|
||||
def test_authorizer_invalid_signature(self):
|
||||
self.cx.execute("create table if not exists test(a number)")
|
||||
self.cx.set_authorizer(lambda: None)
|
||||
self.assert_not_authorized(self.cx.execute, "select * from test")
|
||||
|
||||
# Tests for checking that callback context mutations do not crash.
|
||||
# Regression tests for https://github.com/python/cpython/issues/142830.
|
||||
|
||||
@with_tracebacks(ZeroDivisionError, regex="hello world")
|
||||
def test_authorizer_concurrent_mutation_in_call(self):
|
||||
self.cx.execute("create table if not exists test(a number)")
|
||||
|
||||
def handler(*a, **kw):
|
||||
self.cx.set_authorizer(None)
|
||||
raise ZeroDivisionError("hello world")
|
||||
|
||||
self.cx.set_authorizer(handler)
|
||||
self.assert_not_authorized(self.cx.execute, "select * from test")
|
||||
|
||||
@with_tracebacks(OverflowError)
|
||||
def test_authorizer_concurrent_mutation_with_overflown_value(self):
|
||||
_testcapi = import_helper.import_module("_testcapi")
|
||||
self.cx.execute("create table if not exists test(a number)")
|
||||
|
||||
def handler(*a, **kw):
|
||||
self.cx.set_authorizer(None)
|
||||
# We expect 'int' at the C level, so this one will raise
|
||||
# when converting via PyLong_Int().
|
||||
return _testcapi.INT_MAX + 1
|
||||
|
||||
self.cx.set_authorizer(handler)
|
||||
self.assert_not_authorized(self.cx.execute, "select * from test")
|
||||
|
||||
|
||||
class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
def assert_interrupted(self, func, /, *args, **kwargs):
|
||||
with self.assertRaisesRegex(sqlite.OperationalError, "interrupted"):
|
||||
func(*args, **kwargs)
|
||||
|
||||
def test_progress_handler_used(self):
|
||||
"""
|
||||
Test that the progress handler is invoked once it is set.
|
||||
|
|
@ -219,7 +270,7 @@ def bad_progress():
|
|||
create table foo(a, b)
|
||||
""")
|
||||
|
||||
def test_progress_handler_keyword_args(self):
|
||||
def test_set_progress_handler_keyword_args(self):
|
||||
regex = (
|
||||
r"Passing keyword argument 'progress_handler' to "
|
||||
r"_sqlite3.Connection.set_progress_handler\(\) is deprecated. "
|
||||
|
|
@ -231,6 +282,43 @@ def test_progress_handler_keyword_args(self):
|
|||
self.con.set_progress_handler(progress_handler=lambda: None, n=1)
|
||||
self.assertEqual(cm.filename, __file__)
|
||||
|
||||
# When a handler has an invalid signature, the exception raised is
|
||||
# the same that would be raised if the handler "negatively" replied.
|
||||
|
||||
def test_progress_handler_invalid_signature(self):
|
||||
self.cx.execute("create table if not exists test(a number)")
|
||||
self.cx.set_progress_handler(lambda x: None, 1)
|
||||
self.assert_interrupted(self.cx.execute, "select * from test")
|
||||
|
||||
# Tests for checking that callback context mutations do not crash.
|
||||
# Regression tests for https://github.com/python/cpython/issues/142830.
|
||||
|
||||
@with_tracebacks(ZeroDivisionError, regex="hello world")
|
||||
def test_progress_handler_concurrent_mutation_in_call(self):
|
||||
self.cx.execute("create table if not exists test(a number)")
|
||||
|
||||
def handler(*a, **kw):
|
||||
self.cx.set_progress_handler(None, 1)
|
||||
raise ZeroDivisionError("hello world")
|
||||
|
||||
self.cx.set_progress_handler(handler, 1)
|
||||
self.assert_interrupted(self.cx.execute, "select * from test")
|
||||
|
||||
def test_progress_handler_concurrent_mutation_in_conversion(self):
|
||||
self.cx.execute("create table if not exists test(a number)")
|
||||
|
||||
class Handler:
|
||||
def __bool__(_):
|
||||
# clear the progress handler
|
||||
self.cx.set_progress_handler(None, 1)
|
||||
raise ValueError # force PyObject_True() to fail
|
||||
|
||||
self.cx.set_progress_handler(Handler.__init__, 1)
|
||||
self.assert_interrupted(self.cx.execute, "select * from test")
|
||||
|
||||
# Running with tracebacks makes the second execution of this
|
||||
# function raise another exception because of a database change.
|
||||
|
||||
|
||||
class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
|
||||
|
||||
|
|
@ -352,7 +440,7 @@ def test_trace_bad_handler(self):
|
|||
cx.set_trace_callback(lambda stmt: 5/0)
|
||||
cx.execute("select 1")
|
||||
|
||||
def test_trace_keyword_args(self):
|
||||
def test_set_trace_callback_keyword_args(self):
|
||||
regex = (
|
||||
r"Passing keyword argument 'trace_callback' to "
|
||||
r"_sqlite3.Connection.set_trace_callback\(\) is deprecated. "
|
||||
|
|
@ -364,6 +452,35 @@ def test_trace_keyword_args(self):
|
|||
self.con.set_trace_callback(trace_callback=lambda: None)
|
||||
self.assertEqual(cm.filename, __file__)
|
||||
|
||||
# When a handler has an invalid signature, the exception raised is
|
||||
# the same that would be raised if the handler "negatively" replied,
|
||||
# but for the trace handler, exceptions are never re-raised (only
|
||||
# printed when needed).
|
||||
|
||||
@with_tracebacks(
|
||||
TypeError,
|
||||
regex=r".*<lambda>\(\) missing 6 required positional arguments",
|
||||
)
|
||||
def test_trace_handler_invalid_signature(self):
|
||||
self.cx.execute("create table if not exists test(a number)")
|
||||
self.cx.set_trace_callback(lambda x, y, z, t, a, b, c: None)
|
||||
self.cx.execute("select * from test")
|
||||
|
||||
# Tests for checking that callback context mutations do not crash.
|
||||
# Regression tests for https://github.com/python/cpython/issues/142830.
|
||||
|
||||
@with_tracebacks(ZeroDivisionError, regex="hello world")
|
||||
def test_trace_callback_concurrent_mutation_in_call(self):
|
||||
self.cx.execute("create table if not exists test(a number)")
|
||||
|
||||
def handler(statement):
|
||||
# clear the progress handler
|
||||
self.cx.set_trace_callback(None)
|
||||
raise ZeroDivisionError("hello world")
|
||||
|
||||
self.cx.set_trace_callback(handler)
|
||||
self.cx.execute("select * from test")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
:mod:`sqlite3`: fix use-after-free crashes when the connection's callbacks
|
||||
are mutated during a callback execution. Patch by Bénédikt Tran.
|
||||
|
|
@ -145,7 +145,8 @@ class _sqlite3.Connection "pysqlite_Connection *" "clinic_state()->ConnectionTyp
|
|||
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=67369db2faf80891]*/
|
||||
|
||||
static int _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self);
|
||||
static void free_callback_context(callback_context *ctx);
|
||||
static void incref_callback_context(callback_context *ctx);
|
||||
static void decref_callback_context(callback_context *ctx);
|
||||
static void set_callback_context(callback_context **ctx_pp,
|
||||
callback_context *ctx);
|
||||
static int connection_close(pysqlite_Connection *self);
|
||||
|
|
@ -937,8 +938,9 @@ func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
|
|||
args = _pysqlite_build_py_params(context, argc, argv);
|
||||
if (args) {
|
||||
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
|
||||
assert(ctx != NULL);
|
||||
incref_callback_context(ctx);
|
||||
py_retval = PyObject_CallObject(ctx->callable, args);
|
||||
decref_callback_context(ctx);
|
||||
Py_DECREF(args);
|
||||
}
|
||||
|
||||
|
|
@ -965,7 +967,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
|
|||
PyObject* stepmethod = NULL;
|
||||
|
||||
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
|
||||
assert(ctx != NULL);
|
||||
incref_callback_context(ctx);
|
||||
|
||||
aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
|
||||
if (aggregate_instance == NULL) {
|
||||
|
|
@ -1003,6 +1005,7 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
|
|||
}
|
||||
|
||||
error:
|
||||
decref_callback_context(ctx);
|
||||
Py_XDECREF(stepmethod);
|
||||
Py_XDECREF(function_result);
|
||||
|
||||
|
|
@ -1034,9 +1037,10 @@ final_callback(sqlite3_context *context)
|
|||
PyObject *exc = PyErr_GetRaisedException();
|
||||
|
||||
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
|
||||
assert(ctx != NULL);
|
||||
incref_callback_context(ctx);
|
||||
function_result = PyObject_CallMethodNoArgs(*aggregate_instance,
|
||||
ctx->state->str_finalize);
|
||||
decref_callback_context(ctx);
|
||||
Py_DECREF(*aggregate_instance);
|
||||
|
||||
ok = 0;
|
||||
|
|
@ -1108,6 +1112,7 @@ create_callback_context(PyTypeObject *cls, PyObject *callable)
|
|||
callback_context *ctx = PyMem_Malloc(sizeof(callback_context));
|
||||
if (ctx != NULL) {
|
||||
PyObject *module = PyType_GetModule(cls);
|
||||
ctx->refcount = 1;
|
||||
ctx->callable = Py_NewRef(callable);
|
||||
ctx->module = Py_NewRef(module);
|
||||
ctx->state = pysqlite_get_state(module);
|
||||
|
|
@ -1119,11 +1124,33 @@ static void
|
|||
free_callback_context(callback_context *ctx)
|
||||
{
|
||||
assert(ctx != NULL);
|
||||
assert(ctx->refcount == 0);
|
||||
Py_XDECREF(ctx->callable);
|
||||
Py_XDECREF(ctx->module);
|
||||
PyMem_Free(ctx);
|
||||
}
|
||||
|
||||
static inline void
|
||||
incref_callback_context(callback_context *ctx)
|
||||
{
|
||||
assert(PyGILState_Check());
|
||||
assert(ctx != NULL);
|
||||
assert(ctx->refcount > 0);
|
||||
ctx->refcount++;
|
||||
}
|
||||
|
||||
static inline void
|
||||
decref_callback_context(callback_context *ctx)
|
||||
{
|
||||
assert(PyGILState_Check());
|
||||
assert(ctx != NULL);
|
||||
assert(ctx->refcount > 0);
|
||||
ctx->refcount--;
|
||||
if (ctx->refcount == 0) {
|
||||
free_callback_context(ctx);
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
set_callback_context(callback_context **ctx_pp, callback_context *ctx)
|
||||
{
|
||||
|
|
@ -1131,7 +1158,7 @@ set_callback_context(callback_context **ctx_pp, callback_context *ctx)
|
|||
callback_context *tmp = *ctx_pp;
|
||||
*ctx_pp = ctx;
|
||||
if (tmp != NULL) {
|
||||
free_callback_context(tmp);
|
||||
decref_callback_context(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1142,7 +1169,7 @@ destructor_callback(void *ctx)
|
|||
// This function may be called without the GIL held, so we need to
|
||||
// ensure that we destroy 'ctx' with the GIL held.
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
free_callback_context((callback_context *)ctx);
|
||||
decref_callback_context((callback_context *)ctx);
|
||||
PyGILState_Release(gstate);
|
||||
}
|
||||
}
|
||||
|
|
@ -1204,7 +1231,7 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
|
|||
func_callback,
|
||||
NULL,
|
||||
NULL,
|
||||
&destructor_callback); // will decref func
|
||||
&destructor_callback); // will free 'ctx'
|
||||
|
||||
if (rc != SQLITE_OK) {
|
||||
/* Workaround for SQLite bug: no error code or string is available here */
|
||||
|
|
@ -1228,7 +1255,7 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
|
|||
PyGILState_STATE gilstate = PyGILState_Ensure();
|
||||
|
||||
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
|
||||
assert(ctx != NULL);
|
||||
incref_callback_context(ctx);
|
||||
|
||||
int size = sizeof(PyObject *);
|
||||
PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size);
|
||||
|
|
@ -1260,6 +1287,7 @@ inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
|
|||
Py_DECREF(res);
|
||||
|
||||
exit:
|
||||
decref_callback_context(ctx);
|
||||
Py_XDECREF(method);
|
||||
PyGILState_Release(gilstate);
|
||||
}
|
||||
|
|
@ -1276,7 +1304,7 @@ value_callback(sqlite3_context *context)
|
|||
PyGILState_STATE gilstate = PyGILState_Ensure();
|
||||
|
||||
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
|
||||
assert(ctx != NULL);
|
||||
incref_callback_context(ctx);
|
||||
|
||||
int size = sizeof(PyObject *);
|
||||
PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size);
|
||||
|
|
@ -1284,6 +1312,8 @@ value_callback(sqlite3_context *context)
|
|||
assert(*cls != NULL);
|
||||
|
||||
PyObject *res = PyObject_CallMethodNoArgs(*cls, ctx->state->str_value);
|
||||
decref_callback_context(ctx);
|
||||
|
||||
if (res == NULL) {
|
||||
int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError);
|
||||
set_sqlite_error(context, attr_err
|
||||
|
|
@ -1406,7 +1436,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
|
|||
0,
|
||||
&step_callback,
|
||||
&final_callback,
|
||||
&destructor_callback); // will decref func
|
||||
&destructor_callback); // will free 'ctx'
|
||||
if (rc != SQLITE_OK) {
|
||||
/* Workaround for SQLite bug: no error code or string is available here */
|
||||
PyErr_SetString(self->OperationalError, "Error creating aggregate");
|
||||
|
|
@ -1416,7 +1446,7 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
|
|||
}
|
||||
|
||||
static int
|
||||
authorizer_callback(void *ctx, int action, const char *arg1,
|
||||
authorizer_callback(void *ctx_vp, int action, const char *arg1,
|
||||
const char *arg2 , const char *dbname,
|
||||
const char *access_attempt_source)
|
||||
{
|
||||
|
|
@ -1425,8 +1455,9 @@ authorizer_callback(void *ctx, int action, const char *arg1,
|
|||
PyObject *ret;
|
||||
int rc = SQLITE_DENY;
|
||||
|
||||
assert(ctx != NULL);
|
||||
PyObject *callable = ((callback_context *)ctx)->callable;
|
||||
callback_context *ctx = (callback_context *)ctx_vp;
|
||||
incref_callback_context(ctx);
|
||||
PyObject *callable = ctx->callable;
|
||||
ret = PyObject_CallFunction(callable, "issss", action, arg1, arg2, dbname,
|
||||
access_attempt_source);
|
||||
|
||||
|
|
@ -1448,21 +1479,23 @@ authorizer_callback(void *ctx, int action, const char *arg1,
|
|||
Py_DECREF(ret);
|
||||
}
|
||||
|
||||
decref_callback_context(ctx);
|
||||
PyGILState_Release(gilstate);
|
||||
return rc;
|
||||
}
|
||||
|
||||
static int
|
||||
progress_callback(void *ctx)
|
||||
progress_callback(void *ctx_vp)
|
||||
{
|
||||
PyGILState_STATE gilstate = PyGILState_Ensure();
|
||||
|
||||
int rc;
|
||||
PyObject *ret;
|
||||
|
||||
assert(ctx != NULL);
|
||||
PyObject *callable = ((callback_context *)ctx)->callable;
|
||||
ret = PyObject_CallNoArgs(callable);
|
||||
callback_context *ctx = (callback_context *)ctx_vp;
|
||||
incref_callback_context(ctx);
|
||||
|
||||
ret = PyObject_CallNoArgs(ctx->callable);
|
||||
if (!ret) {
|
||||
/* abort query if error occurred */
|
||||
rc = -1;
|
||||
|
|
@ -1475,6 +1508,7 @@ progress_callback(void *ctx)
|
|||
print_or_clear_traceback(ctx);
|
||||
}
|
||||
|
||||
decref_callback_context(ctx);
|
||||
PyGILState_Release(gilstate);
|
||||
return rc;
|
||||
}
|
||||
|
|
@ -1486,7 +1520,7 @@ progress_callback(void *ctx)
|
|||
* to ensure future compatibility.
|
||||
*/
|
||||
static int
|
||||
trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
|
||||
trace_callback(unsigned int type, void *ctx_vp, void *stmt, void *sql)
|
||||
{
|
||||
if (type != SQLITE_TRACE_STMT) {
|
||||
return 0;
|
||||
|
|
@ -1494,8 +1528,9 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
|
|||
|
||||
PyGILState_STATE gilstate = PyGILState_Ensure();
|
||||
|
||||
assert(ctx != NULL);
|
||||
pysqlite_state *state = ((callback_context *)ctx)->state;
|
||||
callback_context *ctx = (callback_context *)ctx_vp;
|
||||
incref_callback_context(ctx);
|
||||
pysqlite_state *state = ctx->state;
|
||||
assert(state != NULL);
|
||||
|
||||
PyObject *py_statement = NULL;
|
||||
|
|
@ -1509,7 +1544,7 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
|
|||
|
||||
PyErr_SetString(state->DataError,
|
||||
"Expanded SQL string exceeds the maximum string length");
|
||||
print_or_clear_traceback((callback_context *)ctx);
|
||||
print_or_clear_traceback(ctx);
|
||||
|
||||
// Fall back to unexpanded sql
|
||||
py_statement = PyUnicode_FromString((const char *)sql);
|
||||
|
|
@ -1519,16 +1554,16 @@ trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
|
|||
sqlite3_free((void *)expanded_sql);
|
||||
}
|
||||
if (py_statement) {
|
||||
PyObject *callable = ((callback_context *)ctx)->callable;
|
||||
PyObject *ret = PyObject_CallOneArg(callable, py_statement);
|
||||
PyObject *ret = PyObject_CallOneArg(ctx->callable, py_statement);
|
||||
Py_DECREF(py_statement);
|
||||
Py_XDECREF(ret);
|
||||
}
|
||||
if (PyErr_Occurred()) {
|
||||
print_or_clear_traceback((callback_context *)ctx);
|
||||
print_or_clear_traceback(ctx);
|
||||
}
|
||||
|
||||
exit:
|
||||
decref_callback_context(ctx);
|
||||
PyGILState_Release(gilstate);
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -1952,6 +1987,8 @@ collation_callback(void *context, int text1_length, const void *text1_data,
|
|||
PyObject* retval = NULL;
|
||||
long longval;
|
||||
int result = 0;
|
||||
callback_context *ctx = (callback_context *)context;
|
||||
incref_callback_context(ctx);
|
||||
|
||||
/* This callback may be executed multiple times per sqlite3_step(). Bail if
|
||||
* the previous call failed */
|
||||
|
|
@ -1968,8 +2005,6 @@ collation_callback(void *context, int text1_length, const void *text1_data,
|
|||
goto finally;
|
||||
}
|
||||
|
||||
callback_context *ctx = (callback_context *)context;
|
||||
assert(ctx != NULL);
|
||||
PyObject *args[] = { NULL, string1, string2 }; // Borrowed refs.
|
||||
size_t nargsf = 2 | PY_VECTORCALL_ARGUMENTS_OFFSET;
|
||||
retval = PyObject_Vectorcall(ctx->callable, args + 1, nargsf, NULL);
|
||||
|
|
@ -1991,6 +2026,7 @@ collation_callback(void *context, int text1_length, const void *text1_data,
|
|||
}
|
||||
|
||||
finally:
|
||||
decref_callback_context(ctx);
|
||||
Py_XDECREF(string1);
|
||||
Py_XDECREF(string2);
|
||||
Py_XDECREF(retval);
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ typedef struct _callback_context
|
|||
PyObject *callable;
|
||||
PyObject *module;
|
||||
pysqlite_state *state;
|
||||
Py_ssize_t refcount;
|
||||
} callback_context;
|
||||
|
||||
enum autocommit_mode {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue