[3.10] gh-79579: Improve DML query detection in sqlite3 (GH-93623) (#93801)

The fix involves using pysqlite_check_remaining_sql(), not only to check
for multiple statements, but now also to strip leading comments and
whitespace from SQL statements, so we can improve DML query detection.

pysqlite_check_remaining_sql() is renamed lstrip_sql(), to more
accurately reflect its function, and hardened to handle more SQL comment
corner cases.

(cherry picked from commit 46740073ef)
This commit is contained in:
Erlend Egeberg Aasland 2022-06-14 15:05:36 +02:00 committed by GitHub
parent f9585e2adc
commit 2229d34a6e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 1939 additions and 75 deletions

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,2 @@
:mod:`sqlite3` now correctly detects DML queries with leading comments.
Patch by Erlend E. Aasland.

View file

@ -29,16 +29,7 @@
#include "util.h" #include "util.h"
/* prototypes */ /* prototypes */
static int pysqlite_check_remaining_sql(const char* tail); static const char *lstrip_sql(const char *sql);
typedef enum {
LINECOMMENT_1,
IN_LINECOMMENT,
COMMENTSTART_1,
IN_COMMENT,
COMMENTEND_1,
NORMAL
} parse_remaining_sql_state;
typedef enum { typedef enum {
TYPE_LONG, TYPE_LONG,
@ -55,7 +46,6 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
int rc; int rc;
const char* sql_cstr; const char* sql_cstr;
Py_ssize_t sql_cstr_len; Py_ssize_t sql_cstr_len;
const char* p;
assert(PyUnicode_Check(sql)); assert(PyUnicode_Check(sql));
@ -87,20 +77,12 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
/* Determine if the statement is a DML statement. /* Determine if the statement is a DML statement.
SELECT is the only exception. See #9924. */ SELECT is the only exception. See #9924. */
for (p = sql_cstr; *p != 0; p++) { const char *p = lstrip_sql(sql_cstr);
switch (*p) { if (p != NULL) {
case ' ':
case '\r':
case '\n':
case '\t':
continue;
}
self->is_dml = (PyOS_strnicmp(p, "insert", 6) == 0) self->is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
|| (PyOS_strnicmp(p, "update", 6) == 0) || (PyOS_strnicmp(p, "update", 6) == 0)
|| (PyOS_strnicmp(p, "delete", 6) == 0) || (PyOS_strnicmp(p, "delete", 6) == 0)
|| (PyOS_strnicmp(p, "replace", 7) == 0); || (PyOS_strnicmp(p, "replace", 7) == 0);
break;
} }
Py_BEGIN_ALLOW_THREADS Py_BEGIN_ALLOW_THREADS
@ -118,7 +100,7 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
goto error; goto error;
} }
if (rc == SQLITE_OK && pysqlite_check_remaining_sql(tail)) { if (rc == SQLITE_OK && lstrip_sql(tail)) {
(void)sqlite3_finalize(self->st); (void)sqlite3_finalize(self->st);
self->st = NULL; self->st = NULL;
PyErr_SetString(pysqlite_Warning, PyErr_SetString(pysqlite_Warning,
@ -431,73 +413,61 @@ stmt_traverse(pysqlite_Statement *self, visitproc visit, void *arg)
} }
/* /*
* Checks if there is anything left in an SQL string after SQLite compiled it. * Strip leading whitespace and comments from incoming SQL (null terminated C
* This is used to check if somebody tried to execute more than one SQL command * string) and return a pointer to the first non-whitespace, non-comment
* with one execute()/executemany() command, which the DB-API and we don't * character.
* allow.
* *
* Returns 1 if there is more left than should be. 0 if ok. * This is used to check if somebody tries to execute more than one SQL query
* with one execute()/executemany() command, which the DB-API don't allow.
*
* It is also used to harden DML query detection.
*/ */
static int pysqlite_check_remaining_sql(const char* tail) static inline const char *
lstrip_sql(const char *sql)
{ {
const char* pos = tail; // This loop is borrowed from the SQLite source code.
for (const char *pos = sql; *pos; pos++) {
parse_remaining_sql_state state = NORMAL;
for (;;) {
switch (*pos) { switch (*pos) {
case 0:
return 0;
case '-':
if (state == NORMAL) {
state = LINECOMMENT_1;
} else if (state == LINECOMMENT_1) {
state = IN_LINECOMMENT;
}
break;
case ' ': case ' ':
case '\t': case '\t':
break; case '\f':
case '\n': case '\n':
case 13: case '\r':
if (state == IN_LINECOMMENT) { // Skip whitespace.
state = NORMAL;
}
break; break;
case '/': case '-':
if (state == NORMAL) { // Skip line comments.
state = COMMENTSTART_1; if (pos[1] == '-') {
} else if (state == COMMENTEND_1) { pos += 2;
state = NORMAL; while (pos[0] && pos[0] != '\n') {
} else if (state == COMMENTSTART_1) {
return 1;
}
break;
case '*':
if (state == NORMAL) {
return 1;
} else if (state == LINECOMMENT_1) {
return 1;
} else if (state == COMMENTSTART_1) {
state = IN_COMMENT;
} else if (state == IN_COMMENT) {
state = COMMENTEND_1;
}
break;
default:
if (state == COMMENTEND_1) {
state = IN_COMMENT;
} else if (state == IN_LINECOMMENT) {
} else if (state == IN_COMMENT) {
} else {
return 1;
}
}
pos++; pos++;
} }
if (pos[0] == '\0') {
return NULL;
}
continue;
}
return pos;
case '/':
// Skip C style comments.
if (pos[1] == '*') {
pos += 2;
while (pos[0] && (pos[0] != '*' || pos[1] != '/')) {
pos++;
}
if (pos[0] == '\0') {
return NULL;
}
pos++;
continue;
}
return pos;
default:
return pos;
}
}
return 0; return NULL;
} }
static PyMemberDef stmt_members[] = { static PyMemberDef stmt_members[] = {