Skip to content

Commit ab35f4b

Browse files
Erlend Egeberg Aaslanderlend-aasland
Erlend Egeberg Aasland
authored andcommitted
gh-89301: Fix regression with bound values in traced statements
1 parent 117836f commit ab35f4b

File tree

3 files changed

+106
-15
lines changed

3 files changed

+106
-15
lines changed

Lib/test/test_sqlite3/test_hooks.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,16 @@
2020
# misrepresented as being the original software.
2121
# 3. This notice may not be removed or altered from any source distribution.
2222

23-
import unittest
23+
import contextlib
2424
import sqlite3 as sqlite
25+
import unittest
2526

2627
from test.support.os_helper import TESTFN, unlink
28+
29+
from test.test_sqlite3.test_dbapi import memory_database, cx_limit
2730
from test.test_sqlite3.test_userfunctions import with_tracebacks
2831

32+
2933
class CollationTests(unittest.TestCase):
3034
def test_create_collation_not_string(self):
3135
con = sqlite.connect(":memory:")
@@ -224,6 +228,16 @@ def bad_progress():
224228

225229

226230
class TraceCallbackTests(unittest.TestCase):
231+
@contextlib.contextmanager
232+
def check_stmt_trace(self, cx, expected):
233+
try:
234+
traced = []
235+
cx.set_trace_callback(lambda stmt: traced.append(stmt))
236+
yield
237+
finally:
238+
self.assertEqual(traced, expected)
239+
cx.set_trace_callback(None)
240+
227241
def test_trace_callback_used(self):
228242
"""
229243
Test that the trace callback is invoked once it is set.
@@ -289,6 +303,53 @@ def trace(statement):
289303
con2.close()
290304
self.assertEqual(traced_statements, queries)
291305

306+
def test_trace_expanded_sql(self):
307+
expected = [
308+
"create table t(t)",
309+
"BEGIN ",
310+
"insert into t values(0)",
311+
"insert into t values(1)",
312+
"insert into t values(2)",
313+
"COMMIT",
314+
]
315+
with memory_database() as cx, self.check_stmt_trace(cx, expected):
316+
with cx:
317+
cx.execute("create table t(t)")
318+
cx.executemany("insert into t values(?)", ((v,) for v in range(3)))
319+
320+
@with_tracebacks(
321+
sqlite.DataError,
322+
regex="Expanded SQL string exceeds the maximum string length"
323+
)
324+
def test_trace_too_much_expanded_sql(self):
325+
# If the expanded string is too large, we'll fall back to the
326+
# unexpanded SQL statement (for SQLite 3.14.0 and newer).
327+
# The resulting string length is limited by the runtime limit
328+
# SQLITE_LIMIT_LENGTH.
329+
template = "select 'b' as \"a\" from sqlite_master where \"a\"="
330+
category = sqlite.SQLITE_LIMIT_LENGTH
331+
with memory_database() as cx, cx_limit(cx, category=category) as lim:
332+
nextra = lim - (len(template) + 2) - 1
333+
ok_param = "a" * nextra
334+
bad_param = "a" * (nextra + 1)
335+
336+
unexpanded_query = template + "?"
337+
expected = [unexpanded_query]
338+
if sqlite.sqlite_version_info < (3, 14, 0):
339+
expected = []
340+
with self.check_stmt_trace(cx, expected):
341+
cx.execute(unexpanded_query, (bad_param,))
342+
343+
expanded_query = f"{template}'{ok_param}'"
344+
with self.check_stmt_trace(cx, [expanded_query]):
345+
cx.execute(unexpanded_query, (ok_param,))
346+
347+
@with_tracebacks(ZeroDivisionError, regex="division by zero")
348+
def test_trace_bad_handler(self):
349+
with memory_database() as cx:
350+
cx.set_trace_callback(lambda stmt: 5/0)
351+
cx.execute("select 1")
352+
292353

293354
if __name__ == "__main__":
294355
unittest.main()
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fix a regression in the :mod:`sqlite3` trace callback where bound parameters
2+
were not expanded in the passed statement string. The regression was introduced
3+
in Python 3.10 by :issue:`40318`. Patch by Erlend E. Aasland.

Modules/_sqlite/connection.c

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,11 +1332,10 @@ progress_callback(void *ctx)
13321332
* to ensure future compatibility.
13331333
*/
13341334
static int
1335-
trace_callback(unsigned int type, void *ctx, void *prepared_statement,
1336-
void *statement_string)
1335+
trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
13371336
#else
13381337
static void
1339-
trace_callback(void *ctx, const char *statement_string)
1338+
trace_callback(void *ctx, const char *sql)
13401339
#endif
13411340
{
13421341
#ifdef HAVE_TRACE_V2
@@ -1347,24 +1346,52 @@ trace_callback(void *ctx, const char *statement_string)
13471346

13481347
PyGILState_STATE gilstate = PyGILState_Ensure();
13491348

1350-
PyObject *py_statement = NULL;
1351-
PyObject *ret = NULL;
1352-
py_statement = PyUnicode_DecodeUTF8(statement_string,
1353-
strlen(statement_string), "replace");
13541349
assert(ctx != NULL);
1350+
pysqlite_state *state = ((callback_context *)ctx)->state;
1351+
assert(state != NULL);
1352+
1353+
PyObject *py_statement = NULL;
1354+
#ifdef HAVE_TRACE_V2
1355+
const char *expanded_sql = sqlite3_expanded_sql((sqlite3_stmt *)stmt);
1356+
if (expanded_sql == NULL) {
1357+
sqlite3 *db = sqlite3_db_handle((sqlite3_stmt *)stmt);
1358+
if (sqlite3_errcode(db) == SQLITE_NOMEM) {
1359+
(void)PyErr_NoMemory();
1360+
goto exit;
1361+
}
1362+
1363+
PyErr_SetString(state->DataError,
1364+
"Expanded SQL string exceeds the maximum string "
1365+
"length");
1366+
print_or_clear_traceback((callback_context *)ctx);
1367+
1368+
// Fall back to unexpanded sql
1369+
py_statement = PyUnicode_FromString((const char *)sql);
1370+
}
1371+
else {
1372+
py_statement = PyUnicode_FromString(expanded_sql);
1373+
sqlite3_free((void *)expanded_sql);
1374+
}
1375+
#else
1376+
if (sql == NULL) {
1377+
PyErr_SetString(state->DataError,
1378+
"Expanded SQL string exceeds the maximum string length");
1379+
print_or_clear_traceback((callback_context *)ctx);
1380+
goto exit;
1381+
}
1382+
py_statement = PyUnicode_FromString(sql);
1383+
#endif
13551384
if (py_statement) {
13561385
PyObject *callable = ((callback_context *)ctx)->callable;
1357-
ret = PyObject_CallOneArg(callable, py_statement);
1386+
PyObject *ret = PyObject_CallOneArg(callable, py_statement);
13581387
Py_DECREF(py_statement);
1388+
Py_XDECREF(ret);
13591389
}
1360-
1361-
if (ret) {
1362-
Py_DECREF(ret);
1363-
}
1364-
else {
1365-
print_or_clear_traceback(ctx);
1390+
if (PyErr_Occurred()) {
1391+
print_or_clear_traceback((callback_context *)ctx);
13661392
}
13671393

1394+
exit:
13681395
PyGILState_Release(gilstate);
13691396
#ifdef HAVE_TRACE_V2
13701397
return 0;

0 commit comments

Comments
 (0)