diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py index b3da3c425b8035..9681dbdde2b092 100644 --- a/Lib/sqlite3/test/userfunctions.py +++ b/Lib/sqlite3/test/userfunctions.py @@ -23,12 +23,16 @@ import contextlib import functools +import gc import io +import sys import unittest import unittest.mock -import gc import sqlite3 as sqlite +from test.support import bigmemtest + + def with_tracebacks(strings): """Convenience decorator for testing callback tracebacks.""" strings.append('Traceback') @@ -69,6 +73,10 @@ def func_returnlonglong(): return 1<<31 def func_raiseexception(): 5/0 +def func_memoryerror(): + raise MemoryError +def func_overflowerror(): + raise OverflowError def func_isstring(v): return type(v) is str @@ -187,6 +195,8 @@ def setUp(self): self.con.create_function("returnblob", 0, func_returnblob) self.con.create_function("returnlonglong", 0, func_returnlonglong) self.con.create_function("raiseexception", 0, func_raiseexception) + self.con.create_function("memoryerror", 0, func_memoryerror) + self.con.create_function("overflowerror", 0, func_overflowerror) self.con.create_function("isstring", 1, func_isstring) self.con.create_function("isint", 1, func_isint) @@ -279,6 +289,20 @@ def test_func_exception(self): cur.fetchone() self.assertEqual(str(cm.exception), 'user-defined function raised exception') + @with_tracebacks(['func_memoryerror', 'MemoryError']) + def test_func_memory_error(self): + cur = self.con.cursor() + with self.assertRaises(MemoryError): + cur.execute("select memoryerror()") + cur.fetchone() + + @with_tracebacks(['func_overflowerror', 'OverflowError']) + def test_func_overflow_error(self): + cur = self.con.cursor() + with self.assertRaises(sqlite.DataError): + cur.execute("select overflowerror()") + cur.fetchone() + def test_param_string(self): cur = self.con.cursor() for text in ["foo", str()]: @@ -384,6 +408,25 @@ def md5sum(t): del x,y gc.collect() + @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') + @bigmemtest(size=2**31, memuse=3, dry_run=False) + def test_large_text(self, size): + cur = self.con.cursor() + for size in 2**31-1, 2**31: + self.con.create_function("largetext", 0, lambda size=size: "b" * size) + with self.assertRaises(sqlite.DataError): + cur.execute("select largetext()") + + @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') + @bigmemtest(size=2**31, memuse=2, dry_run=False) + def test_large_blob(self, size): + cur = self.con.cursor() + for size in 2**31-1, 2**31: + self.con.create_function("largeblob", 0, lambda size=size: b"b" * size) + with self.assertRaises(sqlite.DataError): + cur.execute("select largeblob()") + + class AggregateTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") diff --git a/Misc/NEWS.d/next/Library/2021-08-05-14-59-39.bpo-44839.MURNL9.rst b/Misc/NEWS.d/next/Library/2021-08-05-14-59-39.bpo-44839.MURNL9.rst new file mode 100644 index 00000000000000..62ad62c5d48d54 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-08-05-14-59-39.bpo-44839.MURNL9.rst @@ -0,0 +1,4 @@ +:class:`MemoryError` raised in user-defined functions will now produce a +``MemoryError`` in :mod:`sqlite3`. :class:`OverflowError` will now be converted +to :class:`~sqlite3.DataError`. Previously +:class:`~sqlite3.OperationalError` was produced in these cases. diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index aae6c66d63faba..0dab3e85160e82 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -619,6 +619,29 @@ _pysqlite_build_py_params(sqlite3_context *context, int argc, return NULL; } +// Checks the Python exception and sets the appropriate SQLite error code. +static void +set_sqlite_error(sqlite3_context *context, const char *msg) +{ + assert(PyErr_Occurred()); + if (PyErr_ExceptionMatches(PyExc_MemoryError)) { + sqlite3_result_error_nomem(context); + } + else if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + sqlite3_result_error_toobig(context); + } + else { + sqlite3_result_error(context, msg, -1); + } + pysqlite_state *state = pysqlite_get_state(NULL); + if (state->enable_callback_tracebacks) { + PyErr_Print(); + } + else { + PyErr_Clear(); + } +} + static void _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv) { @@ -645,14 +668,7 @@ _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv Py_DECREF(py_retval); } if (!ok) { - pysqlite_state *state = pysqlite_get_state(NULL); - if (state->enable_callback_tracebacks) { - PyErr_Print(); - } - else { - PyErr_Clear(); - } - sqlite3_result_error(context, "user-defined function raised exception", -1); + set_sqlite_error(context, "user-defined function raised exception"); } PyGILState_Release(threadstate); @@ -676,18 +692,9 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_ if (*aggregate_instance == NULL) { *aggregate_instance = _PyObject_CallNoArg(aggregate_class); - - if (PyErr_Occurred()) { - *aggregate_instance = 0; - - pysqlite_state *state = pysqlite_get_state(NULL); - if (state->enable_callback_tracebacks) { - PyErr_Print(); - } - else { - PyErr_Clear(); - } - sqlite3_result_error(context, "user-defined aggregate's '__init__' method raised error", -1); + if (!*aggregate_instance) { + set_sqlite_error(context, + "user-defined aggregate's '__init__' method raised error"); goto error; } } @@ -706,14 +713,8 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_ Py_DECREF(args); if (!function_result) { - pysqlite_state *state = pysqlite_get_state(NULL); - if (state->enable_callback_tracebacks) { - PyErr_Print(); - } - else { - PyErr_Clear(); - } - sqlite3_result_error(context, "user-defined aggregate's 'step' method raised error", -1); + set_sqlite_error(context, + "user-defined aggregate's 'step' method raised error"); } error: @@ -761,14 +762,8 @@ _pysqlite_final_callback(sqlite3_context *context) Py_DECREF(function_result); } if (!ok) { - pysqlite_state *state = pysqlite_get_state(NULL); - if (state->enable_callback_tracebacks) { - PyErr_Print(); - } - else { - PyErr_Clear(); - } - sqlite3_result_error(context, "user-defined aggregate's 'finalize' method raised error", -1); + set_sqlite_error(context, + "user-defined aggregate's 'finalize' method raised error"); } /* Restore the exception (if any) of the last call to step(),