bpo-45138: Expand traced SQL statements in sqlite3 trace callback (GH-28240)

This commit is contained in:
Erlend Egeberg Aasland 2022-03-09 03:46:40 +01:00 committed by GitHub
parent b33a1ae703
commit d1777515f9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 109 additions and 16 deletions

View file

@ -560,6 +560,9 @@ Connection Objects
Passing :const:`None` as *trace_callback* will disable the trace callback. Passing :const:`None` as *trace_callback* will disable the trace callback.
For SQLite 3.14.0 and newer, bound parameters are expanded in the passed
statement string.
.. note:: .. note::
Exceptions raised in the trace callback are not propagated. As a Exceptions raised in the trace callback are not propagated. As a
development and debugging aid, use development and debugging aid, use
@ -568,6 +571,9 @@ Connection Objects
.. versionadded:: 3.3 .. versionadded:: 3.3
.. versionchanged:: 3.11
Added support for expanded SQL statements.
.. method:: enable_load_extension(enabled) .. method:: enable_load_extension(enabled)

View file

@ -322,6 +322,10 @@ sqlite3
Instead we leave it to the SQLite library to handle these cases. Instead we leave it to the SQLite library to handle these cases.
(Contributed by Erlend E. Aasland in :issue:`44092`.) (Contributed by Erlend E. Aasland in :issue:`44092`.)
* For SQLite 3.14.0 and newer, bound parameters are expanded in the statement
string passed to the trace callback. See :meth:`~sqlite3.Connection.set_trace_callback`.
(Contributed by Erlend E. Aasland in :issue:`45138`.)
sys sys
--- ---

View file

@ -20,12 +20,16 @@
# misrepresented as being the original software. # misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution. # 3. This notice may not be removed or altered from any source distribution.
import unittest import contextlib
import sqlite3 as sqlite import sqlite3 as sqlite
import unittest
from test.support.os_helper import TESTFN, unlink from test.support.os_helper import TESTFN, unlink
from test.test_sqlite3.test_dbapi import memory_database, cx_limit
from test.test_sqlite3.test_userfunctions import with_tracebacks from test.test_sqlite3.test_userfunctions import with_tracebacks
class CollationTests(unittest.TestCase): class CollationTests(unittest.TestCase):
def test_create_collation_not_string(self): def test_create_collation_not_string(self):
con = sqlite.connect(":memory:") con = sqlite.connect(":memory:")
@ -224,6 +228,16 @@ def bad_progress():
class TraceCallbackTests(unittest.TestCase): class TraceCallbackTests(unittest.TestCase):
@contextlib.contextmanager
def check_stmt_trace(self, cx, expected):
try:
traced = []
cx.set_trace_callback(lambda stmt: traced.append(stmt))
yield
finally:
self.assertEqual(traced, expected)
cx.set_trace_callback(None)
def test_trace_callback_used(self): def test_trace_callback_used(self):
""" """
Test that the trace callback is invoked once it is set. Test that the trace callback is invoked once it is set.
@ -289,6 +303,51 @@ def trace(statement):
con2.close() con2.close()
self.assertEqual(traced_statements, queries) self.assertEqual(traced_statements, queries)
@unittest.skipIf(sqlite.sqlite_version_info < (3, 14, 0),
"Requires SQLite 3.14.0 or newer")
def test_trace_expanded_sql(self):
expected = [
"create table t(t)",
"BEGIN ",
"insert into t values(0)",
"insert into t values(1)",
"insert into t values(2)",
"COMMIT",
]
with memory_database() as cx, self.check_stmt_trace(cx, expected):
with cx:
cx.execute("create table t(t)")
cx.executemany("insert into t values(?)", ((v,) for v in range(3)))
@with_tracebacks(
sqlite.DataError,
regex="Expanded SQL string exceeds the maximum string length"
)
def test_trace_too_much_expanded_sql(self):
# If the expanded string is too large, we'll fall back to the
# unexpanded SQL statement. The resulting string length is limited by
# SQLITE_LIMIT_LENGTH.
template = "select 'b' as \"a\" from sqlite_master where \"a\"="
category = sqlite.SQLITE_LIMIT_LENGTH
with memory_database() as cx, cx_limit(cx, category=category) as lim:
nextra = lim - (len(template) + 2) - 1
ok_param = "a" * nextra
bad_param = "a" * (nextra + 1)
unexpanded_query = template + "?"
with self.check_stmt_trace(cx, [unexpanded_query]):
cx.execute(unexpanded_query, (bad_param,))
expanded_query = f"{template}'{ok_param}'"
with self.check_stmt_trace(cx, [expanded_query]):
cx.execute(unexpanded_query, (ok_param,))
@with_tracebacks(ZeroDivisionError, regex="division by zero")
def test_trace_bad_handler(self):
with memory_database() as cx:
cx.set_trace_callback(lambda stmt: 5/0)
cx.execute("select 1")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View file

@ -0,0 +1,3 @@
For SQLite 3.14.0 and newer, bound parameters are expanded in the statement
string passed to the :mod:`sqlite3` trace callback. Patch by Erlend E.
Aasland.

View file

@ -1079,11 +1079,10 @@ progress_callback(void *ctx)
* to ensure future compatibility. * to ensure future compatibility.
*/ */
static int static int
trace_callback(unsigned int type, void *ctx, void *prepared_statement, trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
void *statement_string)
#else #else
static void static void
trace_callback(void *ctx, const char *statement_string) trace_callback(void *ctx, const char *sql)
#endif #endif
{ {
#ifdef HAVE_TRACE_V2 #ifdef HAVE_TRACE_V2
@ -1094,24 +1093,46 @@ trace_callback(void *ctx, const char *statement_string)
PyGILState_STATE gilstate = PyGILState_Ensure(); PyGILState_STATE gilstate = PyGILState_Ensure();
PyObject *py_statement = NULL;
PyObject *ret = NULL;
py_statement = PyUnicode_DecodeUTF8(statement_string,
strlen(statement_string), "replace");
assert(ctx != NULL); assert(ctx != NULL);
if (py_statement) { PyObject *py_statement = NULL;
PyObject *callable = ((callback_context *)ctx)->callable; #ifdef HAVE_TRACE_V2
ret = PyObject_CallOneArg(callable, py_statement); assert(stmt != NULL);
Py_DECREF(py_statement); const char *expanded_sql = sqlite3_expanded_sql((sqlite3_stmt *)stmt);
if (expanded_sql == NULL) {
sqlite3 *db = sqlite3_db_handle((sqlite3_stmt *)stmt);
if (sqlite3_errcode(db) == SQLITE_NOMEM) {
(void)PyErr_NoMemory();
goto exit;
} }
if (ret) { pysqlite_state *state = ((callback_context *)ctx)->state;
Py_DECREF(ret); assert(state != NULL);
PyErr_SetString(state->DataError,
"Expanded SQL string exceeds the maximum string "
"length");
print_or_clear_traceback((callback_context *)ctx);
// Fall back to unexpanded sql
py_statement = PyUnicode_FromString((const char *)sql);
} }
else { else {
print_or_clear_traceback(ctx); py_statement = PyUnicode_FromString(expanded_sql);
sqlite3_free((void *)expanded_sql);
}
#else
py_statement = PyUnicode_FromString(sql);
#endif
if (py_statement) {
PyObject *callable = ((callback_context *)ctx)->callable;
PyObject *ret = PyObject_CallOneArg(callable, py_statement);
Py_DECREF(py_statement);
Py_XDECREF(ret);
} }
exit:
if (PyErr_Occurred()) {
print_or_clear_traceback((callback_context *)ctx);
}
PyGILState_Release(gilstate); PyGILState_Release(gilstate);
#ifdef HAVE_TRACE_V2 #ifdef HAVE_TRACE_V2
return 0; return 0;